diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index 907a48ea2..18b9ee281 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -323,9 +323,6 @@ LANGSMITH_PROJECT=surfsense # ============================================================================= # OPTIONAL: New-chat agent feature flags # ============================================================================= -# Multi-agent orchestrator switch for authenticated chat streaming. -# MULTI_AGENT_CHAT_ENABLED=false - # Master kill-switch — when true, every flag below is forced OFF. # SURFSENSE_DISABLE_NEW_AGENT_STACK=false diff --git a/surfsense_backend/app/agents/autocomplete/autocomplete_agent.py b/surfsense_backend/app/agents/autocomplete/autocomplete_agent.py deleted file mode 100644 index 890b3e06e..000000000 --- a/surfsense_backend/app/agents/autocomplete/autocomplete_agent.py +++ /dev/null @@ -1,557 +0,0 @@ -"""Vision autocomplete agent with scoped filesystem exploration. - -Converts the stateless single-shot vision autocomplete into an agent that -seeds a virtual filesystem from KB search results and lets the vision LLM -explore documents via ``ls``, ``read_file``, ``glob``, ``grep``, etc. -before generating the final completion. - -Performance: KB search and agent graph compilation run in parallel so -the only sequential latency is KB-search (or agent compile, whichever is -slower) + the agent's LLM turns. There is no separate "query extraction" -LLM call — the window title is used directly as the KB search query. 
-""" - -from __future__ import annotations - -import asyncio -import json -import logging -import re -import uuid -from collections.abc import AsyncGenerator -from typing import Any - -from deepagents.graph import BASE_AGENT_PROMPT -from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware -from langchain.agents import create_agent -from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware -from langchain_core.language_models import BaseChatModel -from langchain_core.messages import AIMessage, ToolMessage - -from app.agents.new_chat.document_xml import build_document_xml -from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware -from app.agents.new_chat.middleware.knowledge_search import ( - search_knowledge_base, -) -from app.agents.new_chat.path_resolver import ( - DOCUMENTS_ROOT, - build_path_index, - doc_to_virtual_path, -) -from app.db import shielded_async_session -from app.services.new_streaming_service import VercelStreamingService - -try: - from deepagents.backends.utils import create_file_data -except Exception: # pragma: no cover - defensive - - def create_file_data(content: str) -> dict[str, Any]: - return {"content": content.split("\n")} - - -async def _build_autocomplete_filesystem( - *, - documents: Any, - search_space_id: int, -) -> tuple[dict[str, Any], dict[int, str]]: - """Build a ``state['files']``-shaped dict from KB search results. - - This is the autocomplete-specific replacement for the previous - ``build_scoped_filesystem`` helper. It uses the canonical path resolver - so paths line up with the rest of the system, including collision - suffixes for duplicate titles. 
- """ - files: dict[str, Any] = {} - doc_id_to_path: dict[int, str] = {} - - if not documents: - return files, doc_id_to_path - - async with shielded_async_session() as session: - index = await build_path_index(session, search_space_id) - - for document in documents: - if not isinstance(document, dict): - continue - meta = document.get("document") or {} - doc_id = meta.get("id") - if not isinstance(doc_id, int): - continue - title = str(meta.get("title") or "untitled") - folder_id = meta.get("folder_id") - path = doc_to_virtual_path( - doc_id=doc_id, title=title, folder_id=folder_id, index=index - ) - chunk_ids = document.get("matched_chunk_ids") or [] - try: - matched_set = {int(c) for c in chunk_ids} - except (TypeError, ValueError): - matched_set = set() - xml = build_document_xml(document, matched_chunk_ids=matched_set) - files[path] = create_file_data(xml) - doc_id_to_path[doc_id] = path - - if not files: - # Ensure the synthetic /documents folder is visible even when empty. - files.setdefault(f"{DOCUMENTS_ROOT}/.placeholder", create_file_data("")) - - return files, doc_id_to_path - - -logger = logging.getLogger(__name__) - -KB_TOP_K = 10 - -# --------------------------------------------------------------------------- -# System prompt -# --------------------------------------------------------------------------- - -AUTOCOMPLETE_SYSTEM_PROMPT = """You are a smart writing assistant that analyzes the user's screen to draft or complete text. - -You will receive a screenshot of the user's screen. Your PRIMARY source of truth is the screenshot itself — the visual context determines what to write. - -Your job: -1. Analyze the ENTIRE screenshot to understand what the user is working on (email thread, chat conversation, document, code editor, form, etc.). -2. Identify the text area where the user will type. -3. Generate the text the user most likely wants to write based on the visual context. 
- -You also have access to the user's knowledge base documents via filesystem tools. However: -- ONLY consult the knowledge base if the screenshot clearly involves a topic where your KB documents are DIRECTLY relevant (e.g., the user is writing about a specific project/topic that matches a document title). -- Do NOT explore documents just because they exist. Most autocomplete requests can be answered purely from the screenshot. -- If you do read a document, only incorporate information that is 100% relevant to what the user is typing RIGHT NOW. Do not add extra details, background, or tangential information from the KB. -- Keep your output SHORT — autocomplete should feel like a natural continuation, not an essay. - -Key behavior: -- If the text area is EMPTY, draft a concise response or message based on what you see on screen (e.g., reply to an email, respond to a chat message, continue a document). -- If the text area already has text, continue it naturally — typically just a sentence or two. - -Rules: -- Be CONCISE. Prefer a single paragraph or a few sentences. Autocomplete is a quick assist, not a full draft. -- Match the tone and formality of the surrounding context. -- If the screen shows code, write code. If it shows a casual chat, be casual. If it shows a formal email, be formal. -- Do NOT describe the screenshot or explain your reasoning. -- Do NOT cite or reference documents explicitly — just let the knowledge inform your writing naturally. -- If you cannot determine what to write, output an empty JSON array: [] - -## Output Format - -You MUST provide exactly 3 different suggestion options. Each should be a distinct, plausible completion — vary the tone, detail level, or angle. - -Return your suggestions as a JSON array of exactly 3 strings. Output ONLY the JSON array, nothing else — no markdown fences, no explanation, no commentary. 
- -Example format: -["First suggestion text here.", "Second suggestion — a different take.", "Third option with another approach."] - -## Filesystem Tools `ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep` - -All file paths must start with a `/`. -- ls: list files and directories at a given path. -- read_file: read a file from the filesystem. -- write_file: create a temporary file in the session (not persisted). -- edit_file: edit a file in the session (not persisted for /documents/ files). -- glob: find files matching a pattern (e.g., "**/*.xml"). -- grep: search for text within files. - -## When to Use Filesystem Tools - -BEFORE reaching for any tool, ask yourself: "Can I write a good completion purely from the screenshot?" If yes, just write it — do NOT explore the KB. - -Only use tools when: -- The user is clearly writing about a specific topic that likely has detailed information in their KB. -- You need a specific fact, name, number, or reference that the screenshot doesn't provide. - -When you do use tools, be surgical: -- Check the `ls` output first. If no document title looks relevant, stop — do not read files just to see what's there. -- If a title looks relevant, read only the `` (first ~20 lines) and jump to matched chunks. Do not read entire documents. -- Extract only the specific information you need and move on to generating the completion. - -## Reading Documents Efficiently - -Documents are formatted as XML. Each document contains: -- `` — title, type, URL, etc. -- `` — a table of every chunk with its **line range** and a - `matched="true"` flag for chunks that matched the search query. -- `` — the actual chunks in original document order. - -**Workflow**: read the first ~20 lines to see the ``, identify -chunks marked `matched="true"`, then use `read_file(path, offset=, -limit=)` to jump directly to those sections.""" - -APP_CONTEXT_BLOCK = """ - -The user is currently working in "{app_name}" (window: "{window_title}"). 
Use this to understand the type of application and adapt your tone and format accordingly.""" - - -def _build_autocomplete_system_prompt(app_name: str, window_title: str) -> str: - prompt = AUTOCOMPLETE_SYSTEM_PROMPT - if app_name: - prompt += APP_CONTEXT_BLOCK.format(app_name=app_name, window_title=window_title) - return prompt - - -# --------------------------------------------------------------------------- -# Pre-compute KB filesystem (runs in parallel with agent compilation) -# --------------------------------------------------------------------------- - - -class _KBResult: - """Container for pre-computed KB filesystem results.""" - - __slots__ = ("files", "ls_ai_msg", "ls_tool_msg") - - def __init__( - self, - files: dict[str, Any] | None = None, - ls_ai_msg: AIMessage | None = None, - ls_tool_msg: ToolMessage | None = None, - ) -> None: - self.files = files - self.ls_ai_msg = ls_ai_msg - self.ls_tool_msg = ls_tool_msg - - @property - def has_documents(self) -> bool: - return bool(self.files) - - -async def precompute_kb_filesystem( - search_space_id: int, - query: str, - top_k: int = KB_TOP_K, -) -> _KBResult: - """Search the KB and build the scoped filesystem outside the agent. - - This is designed to be called via ``asyncio.gather`` alongside agent - graph compilation so the two run concurrently. 
- """ - if not query: - return _KBResult() - - try: - search_results = await search_knowledge_base( - query=query, - search_space_id=search_space_id, - top_k=top_k, - ) - - if not search_results: - return _KBResult() - - new_files, _ = await _build_autocomplete_filesystem( - documents=search_results, - search_space_id=search_space_id, - ) - - if not new_files: - return _KBResult() - - doc_paths = [ - p - for p, v in new_files.items() - if p.startswith("/documents/") and v is not None - ] - tool_call_id = f"auto_ls_{uuid.uuid4().hex[:12]}" - ai_msg = AIMessage( - content="", - tool_calls=[ - {"name": "ls", "args": {"path": "/documents"}, "id": tool_call_id} - ], - ) - tool_msg = ToolMessage( - content=str(doc_paths) if doc_paths else "No documents found.", - tool_call_id=tool_call_id, - ) - return _KBResult(files=new_files, ls_ai_msg=ai_msg, ls_tool_msg=tool_msg) - - except Exception: - logger.warning( - "KB pre-computation failed, proceeding without KB", exc_info=True - ) - return _KBResult() - - -# --------------------------------------------------------------------------- -# Filesystem middleware — no save_document, no persistence -# --------------------------------------------------------------------------- - - -class AutocompleteFilesystemMiddleware(SurfSenseFilesystemMiddleware): - """Filesystem middleware for autocomplete — read-only exploration only. - - Passes ``search_space_id=None`` so the new persistence pipeline is - bypassed; the autocomplete flow only reads, never commits to Postgres. 
- """ - - def __init__(self) -> None: - super().__init__(search_space_id=None, created_by_id=None) - - -# --------------------------------------------------------------------------- -# Agent factory -# --------------------------------------------------------------------------- - - -async def _compile_agent( - llm: BaseChatModel, - app_name: str, - window_title: str, -) -> Any: - """Compile the agent graph (CPU-bound, runs in a thread).""" - system_prompt = _build_autocomplete_system_prompt(app_name, window_title) - final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT - - middleware = [ - AutocompleteFilesystemMiddleware(), - PatchToolCallsMiddleware(), - AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"), - ] - - agent = await asyncio.to_thread( - create_agent, - llm, - system_prompt=final_system_prompt, - tools=[], - middleware=middleware, - ) - return agent.with_config({"recursion_limit": 200}) - - -async def create_autocomplete_agent( - llm: BaseChatModel, - *, - search_space_id: int, - kb_query: str, - app_name: str = "", - window_title: str = "", -) -> tuple[Any, _KBResult]: - """Create the autocomplete agent and pre-compute KB in parallel. - - Returns ``(agent, kb_result)`` so the caller can inject the pre-computed - filesystem into the agent's initial state without any middleware delay. - """ - agent, kb = await asyncio.gather( - _compile_agent(llm, app_name, window_title), - precompute_kb_filesystem(search_space_id, kb_query), - ) - return agent, kb - - -# --------------------------------------------------------------------------- -# JSON suggestion parsing (with fallback) -# --------------------------------------------------------------------------- - - -def _parse_suggestions(raw: str) -> list[str]: - """Extract a list of suggestion strings from the agent's output. - - Tries, in order: - 1. Direct ``json.loads`` - 2. Extract content between ```json ... ``` fences - 3. 
Find the first ``[`` … ``]`` span - Falls back to wrapping the raw text as a single suggestion. - """ - text = raw.strip() - if not text: - return [] - - for candidate in _json_candidates(text): - try: - parsed = json.loads(candidate) - if isinstance(parsed, list) and all(isinstance(s, str) for s in parsed): - return [s for s in parsed if s.strip()] - except (json.JSONDecodeError, ValueError): - continue - - return [text] - - -def _json_candidates(text: str) -> list[str]: - """Yield candidate JSON strings from raw text.""" - candidates = [text] - - fence = re.search(r"```(?:json)?\s*\n?(.*?)```", text, re.DOTALL) - if fence: - candidates.append(fence.group(1).strip()) - - bracket = re.search(r"\[.*]", text, re.DOTALL) - if bracket: - candidates.append(bracket.group(0)) - - return candidates - - -# --------------------------------------------------------------------------- -# Streaming helper -# --------------------------------------------------------------------------- - - -async def stream_autocomplete_agent( - agent: Any, - input_data: dict[str, Any], - streaming_service: VercelStreamingService, - *, - emit_message_start: bool = True, -) -> AsyncGenerator[str, None]: - """Stream agent events as Vercel SSE, with thinking steps for tool calls. - - When ``emit_message_start`` is False the caller has already sent the - ``message_start`` event (e.g. to show preparation steps before the agent - runs). 
- """ - thread_id = uuid.uuid4().hex - config = {"configurable": {"thread_id": thread_id}} - - text_buffer: list[str] = [] - active_tool_depth = 0 - thinking_step_counter = 0 - tool_step_ids: dict[str, str] = {} - step_titles: dict[str, str] = {} - completed_step_ids: set[str] = set() - last_active_step_id: str | None = None - - def next_thinking_step_id() -> str: - nonlocal thinking_step_counter - thinking_step_counter += 1 - return f"autocomplete-step-{thinking_step_counter}" - - def complete_current_step() -> str | None: - nonlocal last_active_step_id - if last_active_step_id and last_active_step_id not in completed_step_ids: - completed_step_ids.add(last_active_step_id) - title = step_titles.get(last_active_step_id, "Done") - event = streaming_service.format_thinking_step( - step_id=last_active_step_id, - title=title, - status="complete", - ) - last_active_step_id = None - return event - return None - - if emit_message_start: - yield streaming_service.format_message_start() - - gen_step_id = next_thinking_step_id() - last_active_step_id = gen_step_id - step_titles[gen_step_id] = "Generating suggestions" - yield streaming_service.format_thinking_step( - step_id=gen_step_id, - title="Generating suggestions", - status="in_progress", - ) - - try: - async for event in agent.astream_events( - input_data, config=config, version="v2" - ): - event_type = event.get("event", "") - if event_type == "on_chat_model_stream": - if active_tool_depth > 0: - continue - if "surfsense:internal" in event.get("tags", []): - continue - chunk = event.get("data", {}).get("chunk") - if chunk and hasattr(chunk, "content"): - content = chunk.content - if content and isinstance(content, str): - text_buffer.append(content) - - elif event_type == "on_chat_model_end": - if active_tool_depth > 0: - continue - if "surfsense:internal" in event.get("tags", []): - continue - output = event.get("data", {}).get("output") - if output and hasattr(output, "content"): - if getattr(output, "tool_calls", 
None): - continue - content = output.content - if content and isinstance(content, str) and not text_buffer: - text_buffer.append(content) - - elif event_type == "on_tool_start": - active_tool_depth += 1 - tool_name = event.get("name", "unknown_tool") - run_id = event.get("run_id", "") - tool_input = event.get("data", {}).get("input", {}) - - step_event = complete_current_step() - if step_event: - yield step_event - - tool_step_id = next_thinking_step_id() - tool_step_ids[run_id] = tool_step_id - last_active_step_id = tool_step_id - - title, items = _describe_tool_call(tool_name, tool_input) - step_titles[tool_step_id] = title - yield streaming_service.format_thinking_step( - step_id=tool_step_id, - title=title, - status="in_progress", - items=items, - ) - - elif event_type == "on_tool_end": - active_tool_depth = max(0, active_tool_depth - 1) - run_id = event.get("run_id", "") - step_id = tool_step_ids.pop(run_id, None) - if step_id and step_id not in completed_step_ids: - completed_step_ids.add(step_id) - title = step_titles.get(step_id, "Done") - yield streaming_service.format_thinking_step( - step_id=step_id, - title=title, - status="complete", - ) - if last_active_step_id == step_id: - last_active_step_id = None - - step_event = complete_current_step() - if step_event: - yield step_event - - raw_text = "".join(text_buffer) - suggestions = _parse_suggestions(raw_text) - - yield streaming_service.format_data("suggestions", {"options": suggestions}) - - yield streaming_service.format_finish() - yield streaming_service.format_done() - - except Exception as e: - logger.error(f"Autocomplete agent streaming error: {e}", exc_info=True) - yield streaming_service.format_error("Autocomplete failed. 
Please try again.") - yield streaming_service.format_done() - - -def _describe_tool_call(tool_name: str, tool_input: Any) -> tuple[str, list[str]]: - """Return a human-readable (title, items) for a tool call thinking step.""" - inp = tool_input if isinstance(tool_input, dict) else {} - if tool_name == "ls": - path = inp.get("path", "/") - return "Listing files", [path] - if tool_name == "read_file": - fp = inp.get("file_path", "") - display = fp if len(fp) <= 80 else "…" + fp[-77:] - return "Reading file", [display] - if tool_name == "write_file": - fp = inp.get("file_path", "") - display = fp if len(fp) <= 80 else "…" + fp[-77:] - return "Writing file", [display] - if tool_name == "edit_file": - fp = inp.get("file_path", "") - display = fp if len(fp) <= 80 else "…" + fp[-77:] - return "Editing file", [display] - if tool_name == "glob": - pat = inp.get("pattern", "") - base = inp.get("path", "/") - return "Searching files", [f"{pat} in {base}"] - if tool_name == "grep": - pat = inp.get("pattern", "") - path = inp.get("path", "") - display_pat = pat[:60] + ("…" if len(pat) > 60 else "") - return "Searching content", [ - f'"{display_pat}"' + (f" in {path}" if path else "") - ] - return f"Using {tool_name}", [] diff --git a/surfsense_backend/app/agents/chat/__init__.py b/surfsense_backend/app/agents/chat/__init__.py new file mode 100644 index 000000000..4f6b7d07f --- /dev/null +++ b/surfsense_backend/app/agents/chat/__init__.py @@ -0,0 +1,5 @@ +"""Chat agents category. + +Groups the conversational agents that share a kernel: ``anonymous_chat`` and +``multi_agent_chat``. Code shared by *both* lives in ``chat/shared/``. +""" diff --git a/surfsense_backend/app/agents/chat/anonymous_chat/__init__.py b/surfsense_backend/app/agents/chat/anonymous_chat/__init__.py new file mode 100644 index 000000000..ba3b2a6f1 --- /dev/null +++ b/surfsense_backend/app/agents/chat/anonymous_chat/__init__.py @@ -0,0 +1,14 @@ +"""Anonymous / free-chat agent. 
+ +The no-login chat experience: a deliberately minimal agent that bypasses the +full SurfSense deep-agent stack (filesystem, knowledge-base persistence, +subagents, skills, memory) and answers with an optional ``web_search`` tool and +an optional read-only uploaded document. See :mod:`.agent` for details. +""" + +from app.agents.chat.anonymous_chat.agent import ( + build_anonymous_system_prompt, + create_anonymous_chat_agent, +) + +__all__ = ["build_anonymous_system_prompt", "create_anonymous_chat_agent"] diff --git a/surfsense_backend/app/agents/new_chat/anonymous_agent.py b/surfsense_backend/app/agents/chat/anonymous_chat/agent.py similarity index 97% rename from surfsense_backend/app/agents/new_chat/anonymous_agent.py rename to surfsense_backend/app/agents/chat/anonymous_chat/agent.py index c783d9a45..250b4c158 100644 --- a/surfsense_backend/app/agents/new_chat/anonymous_agent.py +++ b/surfsense_backend/app/agents/chat/anonymous_chat/agent.py @@ -27,12 +27,12 @@ from langchain.agents.middleware import ( from langchain_core.language_models import BaseChatModel from langgraph.types import Checkpointer -from app.agents.new_chat.context import SurfSenseContextSchema -from app.agents.new_chat.middleware import ( +from app.agents.chat.shared.context import SurfSenseContextSchema +from app.agents.chat.shared.middleware import ( RetryAfterMiddleware, create_surfsense_compaction_middleware, ) -from app.agents.new_chat.tools.web_search import create_web_search_tool +from app.agents.chat.shared.tools.web_search import create_web_search_tool # Cap how much of an uploaded document we inline into the system prompt. 
The # upload endpoint allows files up to several MB, but the doc is re-sent on diff --git a/surfsense_backend/app/agents/multi_agent_chat/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/constants.py b/surfsense_backend/app/agents/chat/multi_agent_chat/constants.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/constants.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/constants.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/context_prune/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/context_prune/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/context_prune/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/context_prune/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/context_prune/prune_tool_names.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/context_prune/prune_tool_names.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/context_prune/prune_tool_names.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/context_prune/prune_tool_names.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/__init__.py 
b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/graph/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/graph/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/graph/compile_graph_sync.py similarity index 89% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/graph/compile_graph_sync.py index b86da932a..2755d5d96 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/graph/compile_graph_sync.py @@ -11,12 +11,12 @@ from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool from langgraph.types import Checkpointer -from app.agents.multi_agent_chat.middleware.stack import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.stack import ( build_main_agent_deepagent_middleware, ) -from app.agents.new_chat.context import SurfSenseContextSchema -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode +from app.agents.chat.shared.context import SurfSenseContextSchema from app.db import ChatVisibility diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/__init__.py rename 
to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/__init__.py diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/action_log/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/action_log/__init__.py new file mode 100644 index 000000000..46fa28009 --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/action_log/__init__.py @@ -0,0 +1,10 @@ +"""Action-log middleware: audit row per tool call (impl + builder).""" + +from .builder import build_action_log_mw +from .middleware import ActionLogMiddleware, ToolDefinition + +__all__ = [ + "ActionLogMiddleware", + "ToolDefinition", + "build_action_log_mw", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/action_log.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/action_log/builder.py similarity index 62% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/action_log.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/action_log/builder.py index c9f893d97..9213f1339 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/action_log.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/action_log/builder.py @@ -4,11 +4,10 @@ from __future__ import annotations import logging -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.middleware import ActionLogMiddleware -from app.agents.new_chat.tools.registry import BUILTIN_TOOLS +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled -from ..shared.flags import enabled +from .middleware import ActionLogMiddleware def build_action_log_mw( @@ -21,12 +20,13 @@ def build_action_log_mw( if not enabled(flags, "enable_action_log") or 
thread_id is None: return None try: - tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS} + # No built-in tool declares a ``reverse`` callable yet, so the action + # log runs without a tool_definitions map. Reversibility is opt-in per + # tool via ``ToolDefinition.reverse`` and can be wired here when used. return ActionLogMiddleware( thread_id=thread_id, search_space_id=search_space_id, user_id=user_id, - tool_definitions=tool_defs_by_name, ) except Exception: # pragma: no cover - defensive logging.warning( diff --git a/surfsense_backend/app/agents/new_chat/middleware/action_log.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/action_log/middleware.py similarity index 86% rename from surfsense_backend/app/agents/new_chat/middleware/action_log.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/action_log/middleware.py index 716a1616c..789705d0e 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/action_log.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/action_log/middleware.py @@ -1,25 +1,15 @@ """Append-only action-log middleware for the SurfSense agent. -Wraps every tool call via :meth:`AgentMiddleware.awrap_tool_call` and writes -a row to :class:`~app.db.AgentActionLog` after the tool returns. Tools opt -into reversibility by declaring a ``reverse`` callable on their -:class:`~app.agents.new_chat.tools.registry.ToolDefinition`; the rendered -descriptor is persisted in ``reverse_descriptor`` for use by +Wraps every tool call and writes a row to :class:`~app.db.AgentActionLog` +after the tool returns. Tools opt into reversibility via a ``reverse`` +callable on their :class:`ToolDefinition`; the rendered descriptor powers ``/api/threads/{thread_id}/revert/{action_id}``. -Design points: - -* **Defensive.** Logging never blocks the agent. 
We catch every exception - on the DB write path and emit a warning; the tool's ``ToolMessage`` - result is always returned untouched. -* **Lightweight payload.** Only the tool ``name`` + ``args`` (capped) + - ``result_id`` + ``reverse_descriptor`` are stored. Tool output text - remains in the LangGraph checkpoint / spilled tool-output files. -* **Best-effort reversibility.** We invoke ``reverse(args, result_obj)`` - with the parsed JSON result when the tool's content is a JSON object; - otherwise the raw text is passed. Exceptions in the reverse callable - are swallowed and logged — a failed descriptor render simply means the - action is NOT marked reversible. +Logging is fully defensive — DB-write failures are swallowed so the tool's +result is always returned untouched. Only metadata (name, capped args, +result_id, reverse_descriptor) is stored; tool output stays in the +checkpoint. Reversibility is best-effort: a reverse callable that raises +just leaves the action non-reversible. """ from __future__ import annotations @@ -27,14 +17,14 @@ from __future__ import annotations import json import logging from collections.abc import Awaitable, Callable +from dataclasses import dataclass from typing import TYPE_CHECKING, Any from langchain.agents.middleware import AgentMiddleware from langchain_core.callbacks import adispatch_custom_event from langchain_core.messages import ToolMessage -from app.agents.new_chat.feature_flags import get_flags -from app.agents.new_chat.tools.registry import ToolDefinition +from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags if TYPE_CHECKING: # pragma: no cover - type-only from langchain.agents.middleware.types import ToolCallRequest @@ -44,6 +34,31 @@ if TYPE_CHECKING: # pragma: no cover - type-only logger = logging.getLogger(__name__) +@dataclass +class ToolDefinition: + """Reversibility descriptor consumed by :class:`ActionLogMiddleware`. 
+ + Only ``name`` and ``reverse`` are read by the middleware; the remaining + fields let callers and tests describe a tool declaratively. A tool is + marked reversible in the action log when ``reverse`` is set and renders a + descriptor without raising. + + Attributes: + name: Unique identifier for the tool. + description: Human-readable description of what the tool does. + factory: Optional callable that builds the tool (unused by the + middleware; retained for declarative call sites/tests). + reverse: Optional callable that, given the tool's ``(args, result)``, + returns a ``ReverseDescriptor`` describing the inverse invocation. + + """ + + name: str + description: str = "" + factory: Callable[[dict[str, Any]], Any] | None = None + reverse: Callable[[dict[str, Any], Any], dict[str, Any]] | None = None + + # Cap for the persisted ``args`` JSON to avoid bloating the action log with # accidentally-huge inputs. Values are truncated and a flag is set in the # stored payload so consumers can detect truncation. @@ -178,11 +193,9 @@ class ActionLogMiddleware(AgentMiddleware): ) return - # Surface a side-channel SSE event so the chat tool card can - # render a Revert button immediately after the row is durable. - # ``stream_new_chat`` translates this into a - # ``data-action-log`` SSE event. We DO NOT include the - # ``reverse_descriptor`` payload here; only a presence flag. + # Side-channel event (relayed by ``stream_new_chat`` as a + # ``data-action-log`` SSE) so the tool card can show a Revert button + # once the row is durable. Carries a presence flag, not the descriptor. 
try: await adispatch_custom_event( "action_log", diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/anonymous_document/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/anonymous_document/__init__.py new file mode 100644 index 000000000..5684a592c --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/anonymous_document/__init__.py @@ -0,0 +1,9 @@ +"""Anonymous-document middleware: Redis hydration, cloud only (impl + builder).""" + +from .builder import build_anonymous_doc_mw +from .middleware import AnonymousDocumentMiddleware + +__all__ = [ + "AnonymousDocumentMiddleware", + "build_anonymous_doc_mw", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/anonymous_doc.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/anonymous_document/builder.py similarity index 73% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/anonymous_doc.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/anonymous_document/builder.py index afd54a2d3..f03543124 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/anonymous_doc.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/anonymous_document/builder.py @@ -2,8 +2,9 @@ from __future__ import annotations -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.middleware import AnonymousDocumentMiddleware +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode + +from .middleware import AnonymousDocumentMiddleware def build_anonymous_doc_mw( diff --git a/surfsense_backend/app/agents/new_chat/middleware/anonymous_document.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/anonymous_document/middleware.py similarity index 93% rename from 
surfsense_backend/app/agents/new_chat/middleware/anonymous_document.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/anonymous_document/middleware.py index 2893d2e11..d29c31230 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/anonymous_document.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/anonymous_document/middleware.py @@ -24,8 +24,13 @@ from typing import Any from langchain.agents.middleware import AgentMiddleware, AgentState from langgraph.runtime import Runtime -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT, safe_filename +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) +from app.agents.chat.runtime.path_resolver import ( + DOCUMENTS_ROOT, + safe_filename, +) logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/busy_mutex/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/busy_mutex/__init__.py new file mode 100644 index 000000000..17c33b8ab --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/busy_mutex/__init__.py @@ -0,0 +1,25 @@ +"""Per-turn cooperative busy-lock middleware + cancel primitives (main-agent).""" + +from .builder import build_busy_mutex_mw +from .middleware import ( + BusyMutexMiddleware, + end_turn, + get_cancel_event, + get_cancel_state, + is_cancel_requested, + manager, + request_cancel, + reset_cancel, +) + +__all__ = [ + "BusyMutexMiddleware", + "build_busy_mutex_mw", + "end_turn", + "get_cancel_event", + "get_cancel_state", + "is_cancel_requested", + "manager", + "request_cancel", + "reset_cancel", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/busy_mutex.py 
b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/busy_mutex/builder.py similarity index 54% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/busy_mutex.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/busy_mutex/builder.py index 0ea53bf16..0daf87e0b 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/busy_mutex.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/busy_mutex/builder.py @@ -2,10 +2,12 @@ from __future__ import annotations -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.middleware import BusyMutexMiddleware +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled -from ..shared.flags import enabled +from .middleware import ( + BusyMutexMiddleware, +) def build_busy_mutex_mw(flags: AgentFeatureFlags) -> BusyMutexMiddleware | None: diff --git a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/busy_mutex/middleware.py similarity index 83% rename from surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/busy_mutex/middleware.py index e7d9b8f75..7a82196d9 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/busy_mutex/middleware.py @@ -1,32 +1,12 @@ -""" -BusyMutexMiddleware — per-thread asyncio lock + cancel token. +"""Per-thread asyncio lock + cooperative cancel token, keyed by ``thread_id``. -LangChain has no built-in concept of "this thread is already running a -turn — refuse the second concurrent request". 
Without it, a user -double-clicking "send" or refreshing the page mid-stream can spawn two -turns racing on the same checkpoint, producing duplicated tool calls -and mangled state. +Refuses a second concurrent turn on the same thread (e.g. double-clicked +"send") that would otherwise race on the same checkpoint and duplicate tool +calls. Also exposes a per-thread cancel event that long-running tools poll +via ``runtime.context.cancel_event.is_set()`` to abort cooperatively. -Ported from OpenCode's ``Stream.scoped(AbortController)`` pattern: a -single-process, in-memory lock + cooperative cancellation token keyed by -``thread_id``. For multi-worker deployments a distributed lock backend -(Redis or PostgreSQL advisory locks) is a phase-2 follow-up. - -What this provides: -- A ``WeakValueDictionary[str, asyncio.Lock]`` keyed by ``thread_id``; - acquiring the lock during ``before_agent`` blocks any concurrent - prompt on the same thread until release. -- A per-thread ``asyncio.Event`` (``cancel_event``) that long-running - tools can poll to abort cooperatively. The event is reset between - turns. Tools should check ``runtime.context.cancel_event.is_set()`` - in tight inner loops. -- A typed :class:`~app.agents.new_chat.errors.BusyError` raised when a - second turn arrives while the lock is held. - -Note: SurfSense's ``stream_new_chat`` is the call site that should -acquire/release. Wiring this as middleware means the contract is -explicit and the lock manager is shared with subagents that compile -their own ``create_agent`` runnables. +Process-local and in-memory; multi-worker deployments need a distributed lock +(Redis / PostgreSQL advisory locks) as a follow-up. 
""" from __future__ import annotations @@ -46,7 +26,7 @@ from langchain.agents.middleware.types import ( from langgraph.config import get_config from langgraph.runtime import Runtime -from app.agents.new_chat.errors import BusyError +from app.agents.chat.runtime.errors import BusyError logger = logging.getLogger(__name__) @@ -152,9 +132,8 @@ class _ThreadLockManager: return True -# Module-level singleton — process-local but reused across all agent -# instances built in this process. Subagents created in nested -# ``create_agent`` calls also get this so locks are coherent. +# Process-local singleton shared across all agents/subagents built in this +# process so per-thread locks stay coherent. manager = _ThreadLockManager() @@ -266,7 +245,6 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo await lock.acquire() epoch = manager.bump_turn_epoch(thread_id) self._held_locks[thread_id] = (lock, epoch) - # Reset the cancel event so this turn starts fresh reset_cancel(thread_id) return None @@ -289,17 +267,14 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo return None if lock.locked(): lock.release() - # Always clear cancel event between turns so a stale signal - # doesn't leak into the next request. + # Clear cancel event so a stale signal doesn't leak into the next turn. reset_cancel(thread_id) return None - # Provide sync no-ops because the middleware base class allows them def before_agent( # type: ignore[override] self, state: AgentState[Any], runtime: Runtime[ContextT] ) -> dict[str, Any] | None: - # Sync path: no asyncio.Lock to acquire. Best we can do is reject - # if anyone else is in flight. + # Sync path can't await an asyncio.Lock; only reject if one is in flight. 
thread_id = self._thread_id(runtime) if thread_id is None: if self._require_thread_id: diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/config.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/config.py similarity index 66% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/config.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/config.py index ad5b58607..72e2282ff 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/config.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/config.py @@ -1,7 +1,9 @@ -"""RunnableConfig wiring for nested subagent invocations. +"""HITL resume side-channel for nested subagent invocations. -Forwards the parent's ``runtime.config`` (thread_id, …) into the subagent and -exposes the side-channel ``stream_resume_chat`` uses to ferry resume payloads. +Exposes the configurable side-channel ``stream_resume_chat`` uses to ferry +resume payloads into a mid-flight subagent. The ``RunnableConfig`` builder and +state-key filter shared with subagents live in +``app.agents.chat.multi_agent_chat.subagents.shared.invocation``. 
""" from __future__ import annotations @@ -11,8 +13,6 @@ from typing import Any from langchain.tools import ToolRuntime -from .constants import DEFAULT_SUBAGENT_RECURSION_LIMIT - logger = logging.getLogger(__name__) # langgraph stores the parent task's scratchpad under this configurable key; @@ -20,39 +20,6 @@ logger = logging.getLogger(__name__) _LANGGRAPH_SCRATCHPAD_KEY = "__pregel_scratchpad" -def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]: - """RunnableConfig for the nested invoke; raises ``recursion_limit`` and isolates ``thread_id``. - - Each parallel subagent invocation lands in its own checkpoint slot keyed - by an extended ``thread_id`` of the form ``{parent_thread}::task:{tool_call_id}``. - The same call across the resume cycle keeps reading from the same snapshot - (``tool_call_id`` is stable per LLM-emitted call). - - We namespace via ``thread_id`` rather than ``checkpoint_ns`` because - langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a - subgraph path and raises ``ValueError("Subgraph X not found")``. - """ - merged: dict[str, Any] = dict(runtime.config) if runtime.config else {} - current_limit = merged.get("recursion_limit") - try: - current_int = int(current_limit) if current_limit is not None else 0 - except (TypeError, ValueError): - current_int = 0 - if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT: - merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT - - configurable: dict[str, Any] = dict(merged.get("configurable") or {}) - parent_thread_id = configurable.get("thread_id") - per_call_suffix = f"task:{runtime.tool_call_id}" - configurable["thread_id"] = ( - f"{parent_thread_id}::{per_call_suffix}" - if parent_thread_id - else per_call_suffix - ) - merged["configurable"] = configurable - return merged - - def consume_surfsense_resume(runtime: ToolRuntime) -> Any: """Pop the resume payload for *this* call's ``tool_call_id``. 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/constants.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/constants.py similarity index 85% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/constants.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/constants.py index e11f3c3ec..d6a328b2a 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/constants.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/constants.py @@ -1,24 +1,14 @@ -"""Constants shared by the checkpointed subagent middleware.""" +"""Tuning constants for the checkpointed subagent middleware. + +``EXCLUDED_STATE_KEYS`` and ``DEFAULT_SUBAGENT_RECURSION_LIMIT`` are part of the +subagent-invocation contract shared with subagents and now live in +``app.agents.chat.multi_agent_chat.subagents.shared.invocation``. +""" from __future__ import annotations import os -# Mirror of deepagents.middleware.subagents._EXCLUDED_STATE_KEYS. -EXCLUDED_STATE_KEYS = frozenset( - { - "messages", - "todos", - "structured_response", - "skills_metadata", - "memory_contents", - } -) - -# Match the parent graph's budget; the LangGraph default of 25 trips on -# multi-step subagent runs. -DEFAULT_SUBAGENT_RECURSION_LIMIT = 10_000 - def _read_timeout_env(name: str, default: float) -> float: """Parse ``name`` from the environment; fall back to ``default`` on bad values. 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/middleware.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/middleware.py similarity index 98% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/middleware.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/middleware.py index 6cc71f252..a1545ba33 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/middleware.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/middleware.py @@ -16,7 +16,7 @@ from langchain.agents import create_agent from langchain.chat_models import init_chat_model from langgraph.types import Checkpointer -from app.agents.multi_agent_chat.subagents.shared.spec import ( +from app.agents.chat.multi_agent_chat.subagents.shared.spec import ( SURF_CONTEXT_HINT_PROVIDER_KEY, ) from app.utils.perf import get_perf_logger diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/propagation.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/propagation.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/propagation.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/propagation.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/resume.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/resume.py similarity index 100% rename from 
surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/resume.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/resume.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/resume_routing.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/resume_routing.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/resume_routing.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/resume_routing.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/spawn_paused.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/spawn_paused.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/spawn_paused.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/spawn_paused.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_description.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/task_description.py similarity index 80% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_description.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/task_description.py index 73afa6823..3464b889a 100644 --- 
a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_description.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/task_description.py @@ -6,7 +6,7 @@ and the ```` block render from the same source. from __future__ import annotations -from app.agents.multi_agent_chat.main_agent.system_prompt.builder.load_md import ( +from app.agents.chat.multi_agent_chat.main_agent.system_prompt.builder.load_md import ( read_prompt_md, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_tool.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/task_tool.py similarity index 86% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_tool.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/task_tool.py index eaed9a55f..ab825501a 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_tool.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/task_tool.py @@ -23,7 +23,11 @@ from langchain_core.tools import StructuredTool from langgraph.errors import GraphInterrupt from langgraph.types import Command, Interrupt -from app.agents.multi_agent_chat.subagents.shared.spec import ( +from app.agents.chat.multi_agent_chat.subagents.shared.invocation import ( + EXCLUDED_STATE_KEYS, + subagent_invoke_config, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import ( SURF_CONTEXT_HINT_PROVIDER_KEY, ContextHintProvider, ) @@ -34,13 +38,11 @@ from .config import ( consume_surfsense_resume, drain_parent_null_resume, has_surfsense_resume, - subagent_invoke_config, ) from .constants import ( 
DEFAULT_SUBAGENT_BATCH_CONCURRENCY, DEFAULT_SUBAGENT_BILLABLE_THRESHOLD, DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS, - EXCLUDED_STATE_KEYS, MAX_SUBAGENT_BATCH_SIZE, ) from .propagation import wrap_with_tool_call_id @@ -80,13 +82,10 @@ _T = TypeVar("_T") async def _ainvoke_with_timeout[T]( coro: Awaitable[_T], *, subagent_type: str, started_at: float ) -> _T: - """Apply :data:`DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS` to ``coro``. + """Apply the subagent invoke timeout to ``coro`` (non-positive disables it). - A non-positive timeout disables the cap (configurable via the - ``SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS`` env var). On expiry the - underlying task is cancelled and :class:`SubagentInvokeTimeoutError` is - raised — the caller wraps it into a synthetic ToolMessage so the - orchestrator can decide what to do. + On expiry the task is cancelled and :class:`SubagentInvokeTimeoutError` is + raised for the caller to turn into a synthetic ToolMessage. """ timeout = DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS if timeout <= 0: @@ -149,12 +148,9 @@ def build_task_tool_with_parent_config( subagent_graphs: dict[str, Runnable] = { spec["name"]: spec["runnable"] for spec in subagents } - # Per-subagent context-hint providers (see ``SurfSenseSubagentSpec``). - # The mapping is sparse: only routes that opted in via ``pack_subagent`` - # appear here, and the value is invoked once per ``task(...)`` call to - # generate a short string prepended to the subagent's first - # ``HumanMessage``. Failures are logged and swallowed — a broken hint - # provider must never prevent the underlying task from running. + # Sparse map of opt-in context-hint providers; each runs once per task() + # call to prepend a string to the subagent's first HumanMessage. Failures + # are swallowed so a broken hint never blocks the task. 
subagent_hint_providers: dict[str, ContextHintProvider] = { spec["name"]: provider for spec in subagents @@ -176,24 +172,18 @@ def build_task_tool_with_parent_config( def _billable_call_update( subagent_type: str, runtime: ToolRuntime ) -> dict[str, Any]: - """Build the per-call ``billable_calls`` delta + an optional warning. + """Build the per-call ``billable_calls`` delta plus an optional soft-cap warning. - The orchestrator's ``billable_calls`` map is summed by - :func:`_int_counter_merge_reducer`, so we always emit - ``{subagent_type: 1}`` and let the reducer accumulate. If the - cumulative count *after* this call would cross the configured - threshold, we also slip a soft ``messages`` entry into the update - so the orchestrator can read it on its next step and self-limit. - Returning a plain ``dict`` (vs. an extra :class:`Command`) keeps - the helper composable with the existing single/batch return paths. + Always emits ``{subagent_type: 1}`` (a reducer accumulates it); when this + call would cross the threshold, also adds a soft ``messages`` entry so the + orchestrator self-limits on its next step. """ delta: dict[str, Any] = {"billable_calls": {subagent_type: 1}} threshold = DEFAULT_SUBAGENT_BILLABLE_THRESHOLD if threshold <= 0: return delta prior = runtime.state.get("billable_calls") or {} - # ``prior`` may be a plain dict or a reducer-managed mapping; only - # int values are counted so a malformed checkpoint can't crash us. + # Count int values only so a malformed checkpoint can't crash us. 
prior_total = sum(v for v in prior.values() if isinstance(v, int)) new_total = prior_total + 1 if prior_total < threshold <= new_total: @@ -212,8 +202,7 @@ def build_task_tool_with_parent_config( """Merge the per-call billable counter (and warning) into ``cmd``.""" delta = _billable_call_update(subagent_type, runtime) warn_text = delta.pop("_billable_warn_text", None) - # ``cmd.update`` may be a dict or LangGraph ``UpdateDict``; defensively - # copy so we don't mutate state shared across other tool returns. + # Copy so we don't mutate state shared with other tool returns. update = dict(getattr(cmd, "update", {}) or {}) for key, value in delta.items(): update[key] = value @@ -226,14 +215,10 @@ def build_task_tool_with_parent_config( return Command(update=update) def _safe_message_text(msg: Any) -> str: - """Pull text out of a BaseMessage without trusting the ``.text`` property. + """Pull text out of a BaseMessage without using the ``.text`` property. - ``BaseMessage.text`` walks ``content_blocks`` and crashes with - ``TypeError: 'NoneType' object is not iterable`` when ``content`` is - ``None`` (common for tool-call AIMessages whose payload is purely - structured). ``getattr(msg, "text", None)`` does not catch this - because Python evaluates the property body before falling back to - the default. Read ``content`` directly and coerce defensively. + ``.text`` crashes when ``content`` is ``None`` (common for tool-call + AIMessages), and ``getattr`` won't catch it, so read ``content`` directly. """ try: content = getattr(msg, "content", None) @@ -256,23 +241,18 @@ def build_task_tool_with_parent_config( return str(content) def _build_tool_trace(messages: list[Any]) -> list[dict[str, Any]]: - """Compress the subagent's message stream into a compact tool trace. + """Compress the subagent's messages into a compact tool trace. 
- Each entry is ``{"tool": , "status": "ok"|"error", "preview": - <≤120 chars>}`` so the orchestrator can show "this is what your - specialist actually did" without dumping the full message stream - back through the prompt. The list is attached to the returned - ToolMessage's ``additional_kwargs`` (under ``"surf_tool_trace"``); - the LLM never sees it, but UI / observability code can pluck it - out of the checkpoint. + Entries (``{tool, status, preview}``) ride on the ToolMessage's + ``additional_kwargs["surf_tool_trace"]`` for UI/observability; the LLM + never sees them. """ trace: list[dict[str, Any]] = [] for msg in messages: tool_name = getattr(msg, "name", None) tool_call_id_attr = getattr(msg, "tool_call_id", None) if not tool_name and not tool_call_id_attr: - # Only ToolMessages have either field; skip AIMessage / - # HumanMessage / SystemMessage frames. + # Only ToolMessages carry either field. continue status = getattr(msg, "status", None) or "ok" preview = _safe_message_text(msg).strip().replace("\n", " ") @@ -306,8 +286,7 @@ def build_task_tool_with_parent_config( ) raise ValueError(msg) message_text = _safe_message_text(messages[-1]).rstrip() - # Tool-trace is purely observability — wrap defensively so a single - # malformed frame never bubbles up and kills the whole user turn. + # Trace is observability-only; never let a bad frame kill the turn. try: tool_trace = _build_tool_trace(messages) except Exception: @@ -318,10 +297,7 @@ def build_task_tool_with_parent_config( tool_trace = [] tool_msg = ToolMessage(message_text, tool_call_id=tool_call_id) if tool_trace: - # ``additional_kwargs`` is a free-form dict on BaseMessage; using - # a ``surf_`` prefix avoids collision with provider-specific keys - # (e.g. Anthropic's ``cache_control``). The LLM doesn't see it; - # consumers (UI, observability) read it off the checkpoint. + # surf_ prefix avoids collision with provider keys (e.g. cache_control). 
tool_msg.additional_kwargs["surf_tool_trace"] = tool_trace return Command( update={ @@ -359,9 +335,7 @@ def build_task_tool_with_parent_config( } hint = _resolve_context_hint(subagent_type, description, runtime) if hint: - # Prepend as a tagged block so the subagent prompt can pattern-match - # on the section (and a future change can lift it into its own - # ``SystemMessage`` if needed). + # Tagged block so the subagent prompt can pattern-match the section. payload = f"\n{hint}\n\n\n{description}" else: payload = description @@ -372,16 +346,12 @@ def build_task_tool_with_parent_config( results: list[tuple[int, str, dict | str, dict | None]], runtime: ToolRuntime, ) -> Command: - """Combine per-child results into one Command with a combined ToolMessage. + """Combine per-child results into one Command with an aggregate ToolMessage. - ``results`` is a list of ``(task_index, subagent_type, - payload_or_error_text, child_state_update)`` tuples — preserving the - input order so the orchestrator can map each block back to the task - it dispatched. State updates are merged by reducer for keys outside - :data:`EXCLUDED_STATE_KEYS`; everything else (``messages``, ``todos``, - etc.) is replaced by the synthesized aggregate ToolMessage. Every - child also contributes a ``billable_calls`` increment so cost - accounting matches single-mode dispatch. + ``results`` tuples are ``(task_index, subagent_type, payload_or_error, + child_state_update)``; output blocks are sorted by index so the LLM can + map them back to dispatch order, and each child contributes a + ``billable_calls`` increment to match single-mode accounting. """ results.sort(key=lambda r: r[0]) merged_state: dict[str, Any] = {} @@ -422,8 +392,8 @@ def build_task_tool_with_parent_config( } ) if state_update: - # Naive merge: later tasks win on scalar collisions; reducer-backed - # fields (``receipts``, ``files`` etc.) accumulate at apply time. 
+ # Later tasks win on scalar collisions; reducer-backed fields + # accumulate at apply time. merged_state.update(state_update) aggregate = "\n\n".join(message_blocks) aggregate_msg = ToolMessage( @@ -467,11 +437,9 @@ def build_task_tool_with_parent_config( ) -> tuple[int, str, dict | str, dict | None]: """Run one child of a batched ``task`` call under the concurrency cap. - Errors are returned as plain text in slot 2 so a single child's - failure does not abort the whole batch. ``GraphInterrupt`` from a - batched child is currently treated as a hard failure for that child - only — batched HITL is intentionally out of scope for the v1 - rollout (see plan tier 2 item 4 risks). + Errors are returned as text (slot 2) so one child's failure doesn't abort + the batch. A child's ``GraphInterrupt`` is a hard failure for that child: + batched HITL is intentionally out of scope. """ async with semaphore: if subagent_type not in subagent_graphs: @@ -505,8 +473,7 @@ def build_task_tool_with_parent_config( ) return (task_index, subagent_type, str(exc), None) except GraphInterrupt: - # Batched HITL is unsupported in v1 — surface as a failure - # for this child so the rest of the batch still completes. + # Batched HITL unsupported; fail this child so the batch finishes. logger.warning( "Batch child %d (%s) raised GraphInterrupt; batched HITL " "is not supported. Re-dispatch this task as a single " @@ -543,14 +510,11 @@ def build_task_tool_with_parent_config( return (task_index, subagent_type, result, child_state_update) def _coerce_batch_arg(tasks: Any) -> list[dict] | str: - """Rescue common LLM-side malformations of the ``tasks`` argument. + """Rescue common LLM malformations of the ``tasks`` argument. - Some providers serialise an array argument as a JSON-encoded string, - and small models occasionally hand back a single ``{description, - subagent_type}`` dict instead of a one-element array. 
Both are - recovered here with a WARN log so the issue is visible in metrics - but the user's turn still completes; truly broken shapes return a - plain string that the caller surfaces as the tool error. + Recovers a JSON-encoded array string and a single dict (instead of a + 1-element array), logging a WARN. Unrecoverable shapes return a string + the caller surfaces as the tool error. """ if isinstance(tasks, list): return tasks @@ -585,13 +549,10 @@ def build_task_tool_with_parent_config( async def _adispatch_batch( tasks: list[dict], runtime: ToolRuntime ) -> Command | str: - """Fan-out helper for the ``tasks`` array shape. + """Fan out the ``tasks`` array (size- and concurrency-capped). - Bounded by :data:`MAX_SUBAGENT_BATCH_SIZE` and concurrency-capped - at :data:`DEFAULT_SUBAGENT_BATCH_CONCURRENCY`. Returns a single - :class:`Command` that the LLM sees as one ToolMessage per child, - prefixed with ``[task ]`` so it can map back to the input - order. + Returns one Command; the LLM sees one ``[task ]``-prefixed block + per child, in input order. """ if not tasks: return "tasks: array is empty; nothing to dispatch." @@ -701,17 +662,16 @@ def build_task_tool_with_parent_config( if pending_value is not None: resume_value = consume_surfsense_resume(runtime) if resume_value is None: - # Bridge invariant: a queued resume must accompany any pending - # subagent interrupt. Fall-through replay would silently re-prompt - # the user; raise so the streaming layer surfaces a clear error. + # A pending interrupt must have a queued resume; otherwise replay + # would silently re-prompt the user. Raise instead. raise RuntimeError( f"Subagent {subagent_type!r} has a pending interrupt but no " "surfsense_resume_value on config; resume bridge is broken." 
) expected = hitlrequest_action_count(pending_value) resume_value = fan_out_decisions_to_match(resume_value, expected) - # Prevent the parent's resume payload from leaking into subagent - # interrupts via langgraph's parent_scratchpad fallback. + # Stop the parent's resume leaking into subagent interrupts via + # langgraph's parent_scratchpad fallback. drain_parent_null_resume(runtime) with ot.subagent_invoke_span( subagent_type=subagent_type, path=invoke_path @@ -827,10 +787,8 @@ def build_task_tool_with_parent_config( ] = None, ) -> str | Command: atask_start = time.perf_counter() - # Kill switch: when ops flips the spawn-paused flag for this - # workspace, every ``task(...)`` invocation (single- or batch-mode) - # short-circuits with a clear ToolMessage so the orchestrator can - # tell the user what happened and stop hammering downstream APIs. + # Ops kill switch: short-circuit every task() call for this workspace + # so the orchestrator stops hammering downstream APIs. if await is_spawn_paused(search_space_id): logger.warning( "[hitl_route] atask SPAWN_PAUSED: search_space_id=%s tool_call_id=%s", @@ -921,8 +879,8 @@ def build_task_tool_with_parent_config( ) expected = hitlrequest_action_count(pending_value) resume_value = fan_out_decisions_to_match(resume_value, expected) - # Prevent the parent's resume payload from leaking into subagent - # interrupts via langgraph's parent_scratchpad fallback. + # Stop the parent's resume leaking into subagent interrupts via + # langgraph's parent_scratchpad fallback. 
drain_parent_null_resume(runtime) with ot.subagent_invoke_span( subagent_type=subagent_type, path=invoke_path diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/context_editing/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/context_editing/__init__.py new file mode 100644 index 000000000..0c86c8cbd --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/context_editing/__init__.py @@ -0,0 +1,15 @@ +"""Context-editing middleware: spill + clear-tool-uses passes (impl + builder).""" + +from .builder import build_context_editing_mw +from .middleware import ( + ClearToolUsesEdit, + SpillingContextEditingMiddleware, + SpillToBackendEdit, +) + +__all__ = [ + "ClearToolUsesEdit", + "SpillToBackendEdit", + "SpillingContextEditingMiddleware", + "build_context_editing_mw", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/context_editing.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/context_editing/builder.py similarity index 82% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/context_editing.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/context_editing/builder.py index e8f99933e..1d7a2f47f 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/context_editing.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/context_editing/builder.py @@ -7,18 +7,18 @@ from typing import Any from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.main_agent.context_prune.prune_tool_names import ( +from app.agents.chat.multi_agent_chat.main_agent.context_prune.prune_tool_names import ( safe_exclude_tools, ) -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.middleware import ( +from 
app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled + +from .middleware import ( ClearToolUsesEdit, SpillingContextEditingMiddleware, SpillToBackendEdit, ) -from ..shared.flags import enabled - def build_context_editing_mw( *, diff --git a/surfsense_backend/app/agents/new_chat/middleware/context_editing.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/context_editing/middleware.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/middleware/context_editing.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/context_editing/middleware.py diff --git a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/dedup_hitl.py similarity index 63% rename from surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/dedup_hitl.py index a6d2ce310..7710731ab 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/dedup_hitl.py @@ -1,4 +1,4 @@ -"""Middleware that deduplicates HITL tool calls within a single LLM response. +"""Drop duplicate HITL tool calls before execution. When the LLM emits multiple calls to the same HITL tool with the same primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``), @@ -9,72 +9,33 @@ the duplicate call is stripped from the AIMessage that gets checkpointed. That means it is also safe across LangGraph ``interrupt()`` boundaries: the removed call will never appear on graph resume. -Dedup-key resolution order: +Dedup-key resolution order (read from each tool's own ``metadata``): -1. :class:`ToolDefinition.dedup_key` — callable provided by the registry - entry. 
This is the canonical mechanism. -2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg name; - used by MCP / Composio tools whose schemas the registry doesn't see. +1. ``tool.metadata["dedup_key"]`` — callable mapping the args dict to a + stable signature string. This is the canonical mechanism. +2. ``tool.metadata["hitl_dedup_key"]`` — string naming a primary arg; + used by MCP / Composio tools that only expose a single key field. A tool with no resolver from either path simply opts out of dedup. """ from __future__ import annotations -import json import logging -from collections.abc import Callable +from collections.abc import Sequence from typing import Any from langchain.agents.middleware import AgentMiddleware, AgentState +from langchain_core.tools import BaseTool from langgraph.runtime import Runtime +from app.agents.chat.multi_agent_chat.shared.middleware.dedup_tool_calls import ( + DedupResolver, + wrap_dedup_key_by_arg_name, +) + logger = logging.getLogger(__name__) -# Resolver type — given the tool ``args`` dict returns a stable -# string used to dedupe consecutive calls. ``None`` means no dedup. -DedupResolver = Callable[[dict[str, Any]], str] - - -def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver: - """Adapt a string-arg name into a :data:`DedupResolver`. - - Convenience helper used by registry entries that just want to dedupe - on a single arg's lowercased value (the most common case for native - HITL tools like ``send_gmail_email`` keyed on ``subject``). - - Example:: - - ToolDefinition( - name="send_gmail_email", - ..., - dedup_key=wrap_dedup_key_by_arg_name("subject"), - ) - """ - - def _resolver(args: dict[str, Any]) -> str: - return str(args.get(arg_name, "")).lower() - - return _resolver - - -def dedup_key_full_args(args: dict[str, Any]) -> str: - """Resolver that collapses calls only when **every** argument is identical. - - Safe default for tools where no single field uniquely identifies a call - (e.g. 
MCP tools whose first required field is a shared workspace id). - """ - - try: - return json.dumps(args, sort_keys=True, default=str) - except (TypeError, ValueError): - return repr(sorted(args.items())) if isinstance(args, dict) else repr(args) - - -# Backwards-compatible alias for code that imported the original -# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`. -_wrap_string_key = wrap_dedup_key_by_arg_name - class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg] """Remove duplicate HITL tool calls from a single LLM response. @@ -84,9 +45,8 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg] The dedup-resolver map is built from two sources, in priority order: - 1. ``tool.metadata["dedup_key"]`` — callable provided by the registry's - ``ToolDefinition.dedup_key``. Receives the args dict and returns - a string signature. This is the canonical mechanism. + 1. ``tool.metadata["dedup_key"]`` — callable that receives the args dict + and returns a string signature. This is the canonical mechanism. 2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg name; primarily used by MCP / Composio tools. 
""" @@ -162,3 +122,7 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg] updated_msg = last_msg.model_copy(update={"tool_calls": deduped}) return {"messages": [updated_msg]} + + +def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware: + return DedupHITLToolCallsMiddleware(agent_tools=list(tools)) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/doom_loop/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/doom_loop/__init__.py new file mode 100644 index 000000000..d0a1126a5 --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/doom_loop/__init__.py @@ -0,0 +1,9 @@ +"""Doom-loop middleware: detect repeated identical tool calls (impl + builder).""" + +from .builder import build_doom_loop_mw +from .middleware import DoomLoopMiddleware + +__all__ = [ + "DoomLoopMiddleware", + "build_doom_loop_mw", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/doom_loop.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/doom_loop/builder.py similarity index 58% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/doom_loop.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/doom_loop/builder.py index d67b8d518..96024adfd 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/doom_loop.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/doom_loop/builder.py @@ -2,10 +2,10 @@ from __future__ import annotations -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.middleware import DoomLoopMiddleware +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled -from ..shared.flags import enabled +from 
.middleware import DoomLoopMiddleware def build_doom_loop_mw(flags: AgentFeatureFlags) -> DoomLoopMiddleware | None: diff --git a/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/doom_loop/middleware.py similarity index 99% rename from surfsense_backend/app/agents/new_chat/middleware/doom_loop.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/doom_loop/middleware.py index a7901c010..4f9b4af1c 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/doom_loop/middleware.py @@ -16,7 +16,7 @@ This ships **OFF by default** until the frontend explicitly handles ``context.permission == "doom_loop"`` interrupts. Wire format: uses SurfSense's existing ``interrupt()`` payload shape -(see ``app/agents/new_chat/tools/hitl.py``): +(see ``app/agents/shared/tools/hitl.py``): { "type": "permission_ask", diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/__init__.py new file mode 100644 index 000000000..b5b0267ff --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/__init__.py @@ -0,0 +1,13 @@ +"""End-of-turn KB persistence middleware (main-agent only).""" + +from .builder import build_kb_persistence_mw +from .middleware import ( + KnowledgeBasePersistenceMiddleware, + commit_staged_filesystem_state, +) + +__all__ = [ + "KnowledgeBasePersistenceMiddleware", + "build_kb_persistence_mw", + "commit_staged_filesystem_state", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/kb_persistence.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/builder.py similarity index 78% rename from 
surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/kb_persistence.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/builder.py index 4b27581e7..7e8e06570 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/kb_persistence.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/builder.py @@ -2,8 +2,11 @@ from __future__ import annotations -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.middleware import KnowledgeBasePersistenceMiddleware +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode + +from .middleware import ( + KnowledgeBasePersistenceMiddleware, +) def build_kb_persistence_mw( diff --git a/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py similarity index 81% rename from surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py index c88dced85..747ddacd3 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py @@ -1,33 +1,19 @@ """End-of-turn persistence for the cloud-mode SurfSense filesystem. -This middleware runs ``aafter_agent`` once per turn (cloud only). It commits -all staged folder creations, file moves, content writes/edits, file deletes -(``rm``), and directory deletes (``rmdir``) to Postgres in a single ordered -pass: +Runs ``aafter_agent`` once per turn (cloud only), committing staged folder +creates, moves, writes/edits, and ``rm``/``rmdir`` to Postgres in one ordered +pass. 
Order matters: moves resolve before writes (so write-then-move lands at +the final path), and file deletes run before directory deletes (so a same-turn +``rm /a/x.md`` + ``rmdir /a`` works). -1. Materialize ``staged_dirs`` into ``Folder`` rows. -2. Apply ``pending_moves`` in order (chained moves resolved via - ``doc_id_by_path``). -3. Normalize ``dirty_paths`` through ``pending_moves`` so write-then-move - sequences commit at the final path. Paths queued for ``rm`` this turn - are dropped here so a write+rm sequence doesn't recreate the doc. -4. Commit content writes / edits for ``/documents/*`` paths, skipping - ``temp_*`` basenames. -5. Apply ``pending_deletes`` (``rm``) — file deletes run BEFORE directory - deletes so a same-turn ``rm /a/x.md`` + ``rmdir /a`` sequence works. -6. Apply ``pending_dir_deletes`` (``rmdir``); re-verifies emptiness against - the post-step-5 DB state. +When ``flags.enable_action_log`` is on, each destructive op also snapshots a +``DocumentRevision`` / ``FolderRevision`` for revert. For ``rm``/``rmdir`` the +snapshot and DELETE share a SAVEPOINT, so a failed snapshot aborts the delete +rather than making the data silently irreversible. -When ``flags.enable_action_log`` is on every destructive op also writes a -``DocumentRevision`` / ``FolderRevision`` snapshot bound to the -originating ``AgentActionLog`` row via ``tool_call_id``. ``rm``/``rmdir`` -share a single ``SAVEPOINT`` with their snapshot — if the snapshot fails -the DELETE rolls back and we surface the error rather than silently -making the data irreversible. - -The commit body is exposed as a free function ``commit_staged_filesystem_state`` -so the optional stream-task fallback (``stream_new_chat.py``) can call the -exact same routine when ``aafter_agent`` was skipped (e.g. client disconnect). +The commit body is a free function (``commit_staged_filesystem_state``) so the +stream-task fallback can run the identical routine when ``aafter_agent`` was +skipped (e.g. 
client disconnect). """ from __future__ import annotations @@ -45,17 +31,22 @@ from sqlalchemy import delete, select, update from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.feature_flags import get_flags -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.path_resolver import ( +from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.receipts.receipt import ( + Receipt, + make_receipt, +) +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) +from app.agents.chat.multi_agent_chat.shared.state.reducers import _CLEAR +from app.agents.chat.runtime.path_resolver import ( DOCUMENTS_ROOT, parse_documents_path, safe_folder_segment, virtual_path_to_doc, ) -from app.agents.new_chat.state_reducers import _CLEAR -from app.agents.shared.receipt import Receipt, make_receipt from app.db import ( AgentActionLog, Chunk, @@ -211,11 +202,9 @@ async def _create_document( virtual_path, search_space_id, ) - # Filesystem-parity invariant: the only thing that *must* be unique is - # the path. Two notes can legitimately share content (e.g. ``cp a b``). - # Guard against the path-derived ``unique_identifier_hash`` constraint - # so we surface a clean ValueError instead of letting the INSERT poison - # the session with an IntegrityError. + # Pre-check the path-derived unique_identifier_hash so a duplicate path + # surfaces as a clean ValueError instead of an INSERT IntegrityError that + # poisons the session. Content is intentionally not unique (cp a b). 
path_collision = await session.execute( select(Document.id).where( Document.search_space_id == search_space_id, @@ -227,13 +216,6 @@ async def _create_document( f"a document already exists at path '{virtual_path}' " "(unique_identifier_hash collision)" ) - # ``content_hash`` is intentionally NOT checked for uniqueness here. - # In a real filesystem two files at different paths can hold identical - # bytes, and the agent's ``write_file`` path needs that semantic to - # support copy/duplicate operations. The hash remains useful as a - # change-detection hint for connector indexers, which still consult it - # via :func:`check_duplicate_document` but do so with a non-unique - # lookup (``.first()``). content_hash = generate_content_hash(content, search_space_id) doc = Document( title=title, @@ -430,15 +412,9 @@ async def _mark_action_reversible( ) -> None: """Flip ``agent_action_log.reversible = TRUE`` for ``action_id``. - Best-effort: caller may invoke from inside a SAVEPOINT and treat - failure as a soft demotion (snapshot persists, just no Revert button). - - Callers should also call ``_dispatch_reversibility_update`` (defined - below) AFTER the enclosing SAVEPOINT block exits successfully so the - chat tool card can light up its Revert button without - re-fetching ``GET /threads/.../actions``. Dispatching from inside the - SAVEPOINT would risk emitting "reversible=true" for rows whose - update gets rolled back if the surrounding destructive op fails. + Pair with ``_dispatch_reversibility_update`` *after* the enclosing + SAVEPOINT commits, so the UI never sees ``reversible=true`` for a row whose + update later rolls back. """ if action_id is None: return @@ -450,22 +426,11 @@ async def _mark_action_reversible( async def _dispatch_reversibility_update(action_id: int | None) -> None: - """Best-effort dispatch of an ``action_log_updated`` custom event. + """Emit an ``action_log_updated`` SSE event so the Revert button lights up. 
- Surfaces the post-SAVEPOINT reversibility flip to the SSE layer so - the chat tool card can flip its Revert button live. Defensive: - failures are logged at debug level and swallowed; the - REST endpoint ``GET /threads/.../actions`` is still authoritative. - - .. warning:: - Inside :func:`commit_staged_filesystem_state` we DEFER all - dispatches until the outer ``session.commit()`` succeeds — see - the ``deferred_dispatches`` queue in that function. Dispatching - from inside a SAVEPOINT block while the outer transaction is - still pending would emit ``reversible=true`` for rows whose - snapshots get rolled back if the outer commit fails. Direct - callers (e.g. the optional stream-task fallback) that own the - full session lifetime can still call this helper inline. + Best-effort (failures swallowed; the REST actions endpoint is + authoritative). Inside :func:`commit_staged_filesystem_state` this is + deferred until after the outer commit via ``deferred_dispatches``. """ if action_id is None: return @@ -484,12 +449,9 @@ async def _dispatch_reversibility_update(action_id: int | None) -> None: # --------------------------------------------------------------------------- # Snapshot helpers # --------------------------------------------------------------------------- -# -# Best-effort helpers swallow + log so a snapshot failure can never break -# the destructive op for non-destructive tools (write/edit/move/mkdir). -# Strict helpers run inside the SAME ``begin_nested()`` SAVEPOINT as the -# destructive DELETE — failure aborts the savepoint and leaves the doc / -# folder intact, so revertable ops never become irreversible silently. +# Best-effort variants (write/edit/move/mkdir) swallow failures. Strict +# variants (rm/rmdir) share the destructive op's SAVEPOINT so a snapshot +# failure aborts the delete instead of making it silently irreversible. 
def _doc_revision_payload( @@ -699,15 +661,9 @@ async def commit_staged_filesystem_state( ) -> dict[str, Any] | None: """Commit all staged filesystem changes; return the state delta for reducers. - Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent` - and the optional stream-task fallback. - - When ``flags.enable_action_log`` is on every destructive op also writes - a ``DocumentRevision`` / ``FolderRevision`` snapshot bound to the - originating ``AgentActionLog`` row via ``tool_call_id``. Snapshot - durability is best-effort for non-destructive ops and STRICT for - ``rm``/``rmdir`` (snapshot + DELETE share a SAVEPOINT — snapshot - failure aborts the delete). + Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent` and + the stream-task fallback. See the module docstring for ordering and the + action-log snapshot/revert semantics. """ if filesystem_mode != FilesystemMode.CLOUD: return None @@ -766,8 +722,7 @@ async def commit_staged_filesystem_state( flags = get_flags() snapshot_enabled = flags.enable_action_log - # De-duplicate pending deletes per-path while preserving the latest - # tool_call_id (the one the user is most likely to revert via the UI). + # De-dup deletes per-path, keeping the latest tool_call_id (likeliest revert). file_delete_paths: dict[str, str] = {} for entry in pending_deletes: if not isinstance(entry, dict): @@ -791,22 +746,14 @@ async def commit_staged_filesystem_state( applied_moves: list[dict[str, Any]] = [] doc_id_path_tombstones: dict[str, int | None] = {} tree_changed = False - # Reversibility-flip dispatches are deferred until AFTER the outer - # ``session.commit()`` succeeds. Dispatching from inside the - # SAVEPOINT chain while the outer transaction is still pending - # would emit ``reversible=true`` for rows whose snapshots get rolled - # back if the final commit raises. 
Snapshot helpers append on - # success; we drain this list after commit and silently abandon it - # on rollback so the UI stays consistent with durable state. + # Reversibility-flip dispatches are drained only after the outer commit + # succeeds (and abandoned on rollback), so the UI never sees reversible=true + # for a snapshot that didn't durably land. deferred_dispatches: list[int] = [] try: async with shielded_async_session() as session: - # ------------------------------------------------------------------ - # Resolve action-id bindings up front. One SELECT per turn for all - # tool_call_ids, NOT one per op — important because a turn that - # touches 50 paths would otherwise issue 50 lookups. - # ------------------------------------------------------------------ + # Resolve all action-id bindings in one SELECT per turn, not per op. action_id_by_call: dict[str, int] = {} if snapshot_enabled and thread_id is not None: tool_call_ids: set[str] = set() @@ -839,10 +786,7 @@ async def commit_staged_filesystem_state( next(iter(action_id_by_call), None) if action_id_by_call else None ) - # ------------------------------------------------------------------ - # 1. staged_dirs -> Folder rows. Snapshot post-flush so the new - # folder_id is available for the FK. - # ------------------------------------------------------------------ + # 1. staged_dirs -> Folder rows (snapshot post-flush for the FK). for folder_path in staged_dirs: if not isinstance(folder_path, str): continue @@ -863,7 +807,6 @@ async def commit_staged_filesystem_state( tcid = staged_dir_tool_calls.get(folder_path) action_id = _action_id_for(tcid) if action_id is not None: - # Re-read the folder for the snapshot. result = await session.execute( select(Folder).where(Folder.id == folder_id) ) @@ -878,16 +821,13 @@ async def commit_staged_filesystem_state( deferred_dispatches=deferred_dispatches, ) - # ------------------------------------------------------------------ - # 2. pending_moves. 
Snapshot pre-move (in-place restore on revert). - # ------------------------------------------------------------------ + # 2. pending_moves (snapshot pre-move for in-place restore on revert). for move in pending_moves: source = str(move.get("source") or "") if snapshot_enabled and source: tcid = str(move.get("tool_call_id") or "") action_id = _action_id_for(tcid) if action_id is not None: - # Resolve the doc to snapshot BEFORE we mutate it. doc_id_pre = doc_id_by_path.get(source) document_pre: Document | None = None if doc_id_pre is not None: @@ -937,10 +877,8 @@ async def commit_staged_filesystem_state( path = move_alias[path] return path - # ------------------------------------------------------------------ - # 3. dirty_paths -> writes/edits. Skip any path queued for ``rm`` - # this turn so a write+rm sequence doesn't recreate the doc. - # ------------------------------------------------------------------ + # 3. dirty_paths -> writes/edits. Paths queued for rm this turn are + # skipped so a write+rm sequence doesn't recreate the doc. kb_dirty_seen: set[str] = set() kb_dirty: list[str] = [] kb_dirty_origin: dict[str, str] = {} @@ -969,9 +907,7 @@ async def commit_staged_filesystem_state( continue content = "\n".join(file_data.get("content") or []) doc_id = doc_id_by_path.get(path) - # Path ↔ tool_call_id binding: the dirty_paths list dedupes via - # _add_unique_reducer, so we look up the latest tool_call_id by - # path (or by the un-renamed origin). + # Look up tool_call_id by final path or its pre-rename origin. origin = kb_dirty_origin.get(path, path) tcid = dirty_path_tool_calls.get(path) or dirty_path_tool_calls.get( origin @@ -979,12 +915,9 @@ async def commit_staged_filesystem_state( action_id = _action_id_for(tcid) if doc_id is None: - # The in-memory ``doc_id_by_path`` is per-thread and starts - # empty in every new chat. If the agent writes to a path - # that already exists in the DB (e.g. 
a previous chat's - # ``notes.md``), we must NOT try to INSERT — it would hit - # ``unique_identifier_hash`` (path-derived). Look up the - # existing doc and update it in place instead. + # doc_id_by_path is per-thread and empty in a new chat, so a + # write to a path already in the DB must update in place, not + # INSERT (which would hit the path-derived unique hash). existing = await virtual_path_to_doc( session, search_space_id=search_space_id, @@ -1033,12 +966,9 @@ async def commit_staged_filesystem_state( } ) else: - # Fresh create. Wrap each create in a SAVEPOINT so a - # residual ``IntegrityError`` (e.g. a deployment that - # hasn't run migration 133 yet, where - # ``documents.content_hash`` still carries its legacy - # global UNIQUE constraint) rolls back only this one - # create instead of poisoning the whole turn. + # Fresh create, wrapped in a SAVEPOINT so a residual + # IntegrityError (e.g. pre-migration-133 content_hash UNIQUE) + # rolls back only this create, not the whole turn. placeholder_revision_id: int | None = None if snapshot_enabled and action_id is not None: placeholder_revision_id = await _snapshot_document_pre_create( @@ -1061,8 +991,7 @@ async def commit_staged_filesystem_state( logger.warning( "kb_persistence: skipping %s create: %s", path, exc ) - # Roll back the placeholder revision since the create - # never happened. + # Create never happened; drop its placeholder revision. if placeholder_revision_id is not None: await session.execute( delete(DocumentRevision).where( @@ -1109,19 +1038,14 @@ async def commit_staged_filesystem_state( ) tree_changed = True - # ------------------------------------------------------------------ - # 4. pending_deletes -> ``rm``. STRICT durability: snapshot + DELETE - # share a SAVEPOINT. If the snapshot insert fails, the DELETE - # rolls back too and we surface the error rather than silently - # making the data irreversible. - # ------------------------------------------------------------------ + # 4. 
pending_deletes -> rm. Strict: snapshot + DELETE share a + # SAVEPOINT, so a failed snapshot rolls the delete back too. for raw_path, tcid in file_delete_paths.items(): final = _final_path(raw_path) if not final.startswith(DOCUMENTS_ROOT + "/"): continue action_id = _action_id_for(tcid) - # Resolve the doc. doc_id_for_delete = doc_id_by_path.get(final) document_to_delete: Document | None = None if doc_id_for_delete is not None: @@ -1150,7 +1074,6 @@ async def commit_staged_filesystem_state( try: async with session.begin_nested(): - # Strict: snapshot first; failure aborts the delete. if snapshot_enabled and action_id is not None: chunks = await _load_chunks_for_snapshot( session, doc_id=doc_pk @@ -1179,10 +1102,7 @@ async def commit_staged_filesystem_state( ) continue - # B1 — SAVEPOINT released. Defer the reversibility-flip - # dispatch until AFTER the outer commit succeeds so we - # never tell the UI a row is reversible if its snapshot - # gets rolled back. + # Defer the reversibility flip until after the outer commit. if snapshot_enabled and action_id is not None: deferred_dispatches.append(int(action_id)) @@ -1201,11 +1121,8 @@ async def commit_staged_filesystem_state( ) tree_changed = True - # ------------------------------------------------------------------ - # 5. pending_dir_deletes -> ``rmdir``. STRICT durability + final - # emptiness check (after step 4's deletes have run, an "empty - # mid-turn" directory really IS empty in DB now). - # ------------------------------------------------------------------ + # 5. pending_dir_deletes -> rmdir. Strict, and re-checks emptiness + # against post-step-4 DB state. for raw_path, tcid in dir_delete_paths.items(): final = _final_path(raw_path) if not final.startswith(DOCUMENTS_ROOT + "/"): @@ -1226,7 +1143,6 @@ async def commit_staged_filesystem_state( ) continue - # Re-check emptiness against in-DB state. 
docs_in_folder = await session.execute( select(Document.id) .where(Document.folder_id == folder_id) @@ -1291,10 +1207,7 @@ async def commit_staged_filesystem_state( ) continue - # B1 — SAVEPOINT released. Defer the reversibility-flip - # dispatch until AFTER the outer commit succeeds so we - # never tell the UI a row is reversible if its snapshot - # gets rolled back. + # Defer the reversibility flip until after the outer commit. if snapshot_enabled and action_id is not None: deferred_dispatches.append(int(action_id)) @@ -1314,18 +1227,13 @@ async def commit_staged_filesystem_state( logger.exception( "kb_persistence: commit failed (search_space=%s)", search_space_id ) - # Outer commit raised — every SAVEPOINT-released change above - # (snapshots + reversibility flips) is now rolled back. Drop - # the deferred SSE dispatches so the UI stays consistent with - # durable state. + # Outer commit raised: everything above rolled back, so drop the + # deferred dispatches. deferred_dispatches.clear() return None - # Outer commit succeeded; flush deferred reversibility-flip - # dispatches now so the chat tool card can light up its Revert - # button without re-fetching ``GET /threads/.../actions``. De-dup - # to avoid emitting the same id twice (e.g. write-then-rm in the - # same turn dispatches once for each snapshot site). + # Commit succeeded; flush deferred reversibility flips (de-duped, since + # write-then-rm in one turn appends an id per snapshot site). if deferred_dispatches and dispatch_events: for action_id in dict.fromkeys(deferred_dispatches): try: @@ -1371,9 +1279,8 @@ async def commit_staged_filesystem_state( p for p in files if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX) ] - # Tombstone every committed-delete path so a stale ``state["files"]`` entry - # (which als_info would otherwise interpret as content) cannot survive into - # the next turn and make a now-empty folder look non-empty. 
+ # Tombstone committed-delete paths so a stale state["files"] entry can't + # survive into the next turn and make a now-empty folder look non-empty. deleted_file_paths = [ str(payload.get("virtualPath") or "") for payload in committed_deletes @@ -1394,11 +1301,8 @@ async def commit_staged_filesystem_state( "dirty_path_tool_calls": {_CLEAR: True}, } - # Emit one Receipt per committed mutation, folded into ``state['receipts']`` - # via ``_list_append_reducer``. The receipts surface what actually committed - # (post-savepoint) rather than what the LLM intended; the orchestrator uses - # them as ground truth in the ```` teaching. KB writes do not - # have public verifiable URLs, so ``verifiable_url`` stays unset. + # One Receipt per committed mutation: ground truth (post-savepoint) for the + # orchestrator's teaching. KB writes have no public URL. receipts: list[Receipt] = [] def _kb_receipt( @@ -1439,8 +1343,6 @@ async def commit_staged_filesystem_state( external_id=payload.get("id"), ) for payload in applied_moves: - # ``applied_moves`` rows carry the destination ``virtualPath`` because - # the move has already landed in the DB by the time we reach this code. path = str(payload.get("virtualPath") or "") _kb_receipt( type="file", @@ -1480,9 +1382,7 @@ async def commit_staged_filesystem_state( if tree_changed: delta["tree_version"] = int(state_dict.get("tree_version") or 0) + 1 - # Avoid 'unused' lint when turn_id_for_revision was only useful for - # diagnostic purposes inside the SAVEPOINT chain above. 
- _ = turn_id_for_revision + _ = turn_id_for_revision # diagnostic-only; silence unused lint logger.info( "kb_persistence: commit (search_space=%s) creates=%d updates=%d " diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/knowledge_priority.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/knowledge_priority.py similarity index 81% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/knowledge_priority.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/knowledge_priority.py index 27cee8b37..310dd676c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/knowledge_priority.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/knowledge_priority.py @@ -4,8 +4,10 @@ from __future__ import annotations from langchain_core.language_models import BaseChatModel -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.middleware import KnowledgePriorityMiddleware +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.middleware.knowledge_search import ( + KnowledgePriorityMiddleware, +) from app.services.llm_service import get_planner_llm diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/knowledge_tree/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/knowledge_tree/__init__.py new file mode 100644 index 000000000..f2d456b34 --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/knowledge_tree/__init__.py @@ -0,0 +1,9 @@ +"""Knowledge-tree middleware: injection, cloud only (impl + builder).""" + +from .builder import build_knowledge_tree_mw +from .middleware import KnowledgeTreeMiddleware + +__all__ = [ + "KnowledgeTreeMiddleware", + "build_knowledge_tree_mw", +] diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/knowledge_tree.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/knowledge_tree/builder.py similarity index 80% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/knowledge_tree.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/knowledge_tree/builder.py index fb4511067..644d1e55a 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/knowledge_tree.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/knowledge_tree/builder.py @@ -4,8 +4,9 @@ from __future__ import annotations from langchain_core.language_models import BaseChatModel -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.middleware import KnowledgeTreeMiddleware +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode + +from .middleware import KnowledgeTreeMiddleware def build_knowledge_tree_mw( diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/knowledge_tree/middleware.py similarity index 97% rename from surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/knowledge_tree/middleware.py index 6bd6430d1..a0c62834a 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/knowledge_tree/middleware.py @@ -33,9 +33,11 @@ from langchain_core.messages import SystemMessage from langgraph.runtime import Runtime from sqlalchemy import select -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.path_resolver 
import ( +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) +from app.agents.chat.runtime.path_resolver import ( DOCUMENTS_ROOT, PathIndex, build_path_index, diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/memory/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/memory/__init__.py new file mode 100644 index 000000000..0106234c0 --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/memory/__init__.py @@ -0,0 +1,5 @@ +"""User/team memory injection middleware (main-agent only).""" + +from .builder import build_memory_mw + +__all__ = ["build_memory_mw"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/memory.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/memory/builder.py similarity index 86% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/memory.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/memory/builder.py index 9316b3e21..4ea171e13 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/memory.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/memory/builder.py @@ -2,9 +2,10 @@ from __future__ import annotations -from app.agents.new_chat.middleware import MemoryInjectionMiddleware from app.db import ChatVisibility +from .middleware import MemoryInjectionMiddleware + def build_memory_mw( *, diff --git a/surfsense_backend/app/agents/new_chat/middleware/memory_injection.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/memory/middleware.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/middleware/memory_injection.py rename to 
surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/memory/middleware.py diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/noop_injection/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/noop_injection/__init__.py new file mode 100644 index 000000000..c4c004618 --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/noop_injection/__init__.py @@ -0,0 +1,9 @@ +"""Noop-injection middleware: provider-compat _noop tool (impl + builder).""" + +from .builder import build_noop_injection_mw +from .middleware import NoopInjectionMiddleware + +__all__ = [ + "NoopInjectionMiddleware", + "build_noop_injection_mw", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/noop_injection.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/noop_injection/builder.py similarity index 59% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/noop_injection.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/noop_injection/builder.py index 6e6467ad0..774cb0f46 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/noop_injection.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/noop_injection/builder.py @@ -2,10 +2,10 @@ from __future__ import annotations -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.middleware import NoopInjectionMiddleware +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled -from ..shared.flags import enabled +from .middleware import NoopInjectionMiddleware def build_noop_injection_mw(flags: AgentFeatureFlags) -> NoopInjectionMiddleware | None: diff --git 
a/surfsense_backend/app/agents/new_chat/middleware/noop_injection.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/noop_injection/middleware.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/middleware/noop_injection.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/noop_injection/middleware.py diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/otel_span/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/otel_span/__init__.py new file mode 100644 index 000000000..801d08962 --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/otel_span/__init__.py @@ -0,0 +1,9 @@ +"""OTel-span middleware: spans on model and tool calls (impl + builder).""" + +from .builder import build_otel_mw +from .middleware import OtelSpanMiddleware + +__all__ = [ + "OtelSpanMiddleware", + "build_otel_mw", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/otel.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/otel_span/builder.py similarity index 53% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/otel.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/otel_span/builder.py index bd7516e65..fe3bce4c5 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/otel.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/otel_span/builder.py @@ -2,10 +2,10 @@ from __future__ import annotations -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.middleware import OtelSpanMiddleware +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled -from ..shared.flags import enabled +from 
.middleware import OtelSpanMiddleware def build_otel_mw(flags: AgentFeatureFlags) -> OtelSpanMiddleware | None: diff --git a/surfsense_backend/app/agents/new_chat/middleware/otel_span.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/otel_span/middleware.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/middleware/otel_span.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/otel_span/middleware.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/plugins.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/plugins.py similarity index 86% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/plugins.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/plugins.py index 4418e3806..43f4136ec 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/plugins.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/plugins.py @@ -7,15 +7,15 @@ from typing import Any from langchain_core.language_models import BaseChatModel -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.plugin_loader import ( +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled +from app.db import ChatVisibility + +from ..plugins.loader import ( PluginContext, load_allowed_plugin_names_from_env, load_plugin_middlewares, ) -from app.db import ChatVisibility - -from ..shared.flags import enabled def build_plugin_middlewares( diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/skills.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/skills.py similarity index 71% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/skills.py 
rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/skills.py index 63a57c5a0..13c62e817 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/skills.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/skills.py @@ -6,14 +6,11 @@ import logging from deepagents.middleware.skills import SkillsMiddleware -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.middleware import ( - build_skills_backend_factory, - default_skills_sources, -) +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled -from ..shared.flags import enabled +from ..skills.backends import build_skills_backend_factory, default_skills_sources def build_skills_mw( diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/stack.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/stack.py similarity index 73% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/stack.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/stack.py index 3b20d8915..6b75688dd 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/stack.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/stack.py @@ -20,50 +20,66 @@ from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool from langgraph.types import Checkpointer -from app.agents.multi_agent_chat.subagents import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.memory import ( + build_memory_mw, +) +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from 
app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.middleware.anthropic_cache import ( + build_anthropic_cache_mw, +) +from app.agents.chat.multi_agent_chat.shared.middleware.compaction import ( + build_compaction_mw, +) +from app.agents.chat.multi_agent_chat.shared.middleware.kb_context_projection import ( + build_kb_context_projection_mw, +) +from app.agents.chat.multi_agent_chat.shared.middleware.patch_tool_calls import ( + build_patch_tool_calls_mw, +) +from app.agents.chat.multi_agent_chat.shared.middleware.resilience import ( + build_resilience_middlewares, +) +from app.agents.chat.multi_agent_chat.shared.middleware.todos import build_todos_mw +from app.agents.chat.multi_agent_chat.shared.permissions import ( + build_permission_mw, +) +from app.agents.chat.multi_agent_chat.subagents import ( build_subagents, get_subagents_to_exclude, ) -from app.agents.multi_agent_chat.subagents.builtins.knowledge_base.agent import ( +from app.agents.chat.multi_agent_chat.subagents.builtins.knowledge_base.agent import ( READONLY_NAME as KB_READONLY_NAME, build_readonly_subagent as build_kb_readonly_subagent, ) -from app.agents.multi_agent_chat.subagents.builtins.knowledge_base.ask_knowledge_base_tool import ( +from app.agents.chat.multi_agent_chat.subagents.builtins.knowledge_base.ask_knowledge_base_tool import ( build_ask_knowledge_base_tool, ) -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.subagents.middleware_stack import ( + build_subagent_middleware_stack, +) from app.db import ChatVisibility -from .main_agent.action_log import build_action_log_mw -from .main_agent.anonymous_doc import build_anonymous_doc_mw -from .main_agent.busy_mutex import build_busy_mutex_mw -from .main_agent.checkpointed_subagent_middleware import ( +from .action_log import build_action_log_mw 
+from .anonymous_document import build_anonymous_doc_mw +from .busy_mutex import build_busy_mutex_mw +from .checkpointed_subagent_middleware import ( SurfSenseCheckpointedSubAgentMiddleware, ) -from .main_agent.checkpointed_subagent_middleware.task_description import ( +from .checkpointed_subagent_middleware.task_description import ( TASK_TOOL_DESCRIPTION, ) -from .main_agent.context_editing import build_context_editing_mw -from .main_agent.dedup_hitl import build_dedup_hitl_mw -from .main_agent.doom_loop import build_doom_loop_mw -from .main_agent.kb_persistence import build_kb_persistence_mw -from .main_agent.knowledge_priority import build_knowledge_priority_mw -from .main_agent.knowledge_tree import build_knowledge_tree_mw -from .main_agent.noop_injection import build_noop_injection_mw -from .main_agent.otel import build_otel_mw -from .main_agent.plugins import build_plugin_middlewares -from .main_agent.repair import build_repair_mw -from .main_agent.skills import build_skills_mw -from .shared.anthropic_cache import build_anthropic_cache_mw -from .shared.compaction import build_compaction_mw -from .shared.kb_context_projection import build_kb_context_projection_mw -from .shared.memory import build_memory_mw -from .shared.patch_tool_calls import build_patch_tool_calls_mw -from .shared.permissions import build_permission_mw -from .shared.resilience import build_resilience_middlewares -from .shared.todos import build_todos_mw -from .subagent.middleware_stack import build_subagent_middleware_stack +from .context_editing import build_context_editing_mw +from .dedup_hitl import build_dedup_hitl_mw +from .doom_loop import build_doom_loop_mw +from .kb_persistence import build_kb_persistence_mw +from .knowledge_priority import build_knowledge_priority_mw +from .knowledge_tree import build_knowledge_tree_mw +from .noop_injection import build_noop_injection_mw +from .otel_span import build_otel_mw +from .plugins import build_plugin_middlewares +from .skills import 
build_skills_mw +from .tool_call_repair import build_repair_mw def build_main_agent_deepagent_middleware( diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/tool_call_repair/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/tool_call_repair/__init__.py new file mode 100644 index 000000000..1e6d93750 --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/tool_call_repair/__init__.py @@ -0,0 +1,9 @@ +"""Tool-call-repair middleware: fix miscased/unknown tool names (impl + builder).""" + +from .builder import build_repair_mw +from .middleware import ToolCallNameRepairMiddleware + +__all__ = [ + "ToolCallNameRepairMiddleware", + "build_repair_mw", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/repair.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/tool_call_repair/builder.py similarity index 83% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/repair.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/tool_call_repair/builder.py index 378b61be1..a1cc558b2 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/repair.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/tool_call_repair/builder.py @@ -6,10 +6,10 @@ from collections.abc import Sequence from langchain_core.tools import BaseTool -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.middleware import ToolCallNameRepairMiddleware +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled -from ..shared.flags import enabled +from .middleware import ToolCallNameRepairMiddleware # deepagents-built-in tool names the repair pass treats as known. 
_DEEPAGENT_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset( diff --git a/surfsense_backend/app/agents/new_chat/middleware/tool_call_repair.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/tool_call_repair/middleware.py similarity index 96% rename from surfsense_backend/app/agents/new_chat/middleware/tool_call_repair.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/tool_call_repair/middleware.py index 9f81a168b..260e5cbd4 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/tool_call_repair.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/tool_call_repair/middleware.py @@ -34,8 +34,6 @@ from langchain.agents.middleware.types import ( from langchain_core.messages import AIMessage from langgraph.runtime import Runtime -from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME - logger = logging.getLogger(__name__) @@ -120,6 +118,12 @@ class ToolCallNameRepairMiddleware( return call # Stage 2 — invalid fallback + # Local import keeps the middleware module import-light and avoids any + # tools <-> middleware import-order coupling at module scope. 
+ from app.agents.chat.multi_agent_chat.main_agent.tools.invalid_tool import ( + INVALID_TOOL_NAME, + ) + if INVALID_TOOL_NAME in registered: original_args = call.get("args") or {} error_msg = ( diff --git a/surfsense_backend/app/agents/new_chat/plugins/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/plugins/__init__.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/plugins/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/plugins/__init__.py diff --git a/surfsense_backend/app/agents/new_chat/plugin_loader.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/plugins/loader.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/plugin_loader.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/plugins/loader.py diff --git a/surfsense_backend/app/agents/new_chat/plugins/year_substituter.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/plugins/year_substituter.py similarity index 95% rename from surfsense_backend/app/agents/new_chat/plugins/year_substituter.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/plugins/year_substituter.py index 2b7781b90..f6564fe6e 100644 --- a/surfsense_backend/app/agents/new_chat/plugins/year_substituter.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/plugins/year_substituter.py @@ -17,7 +17,7 @@ Wire-up in ``pyproject.toml`` (illustrative; the in-repo plugin doesn't need this -- it's already on the import path):: [project.entry-points."surfsense.plugins"] - year_substituter = "app.agents.new_chat.plugins.year_substituter:make_middleware" + year_substituter = "app.agents.chat.multi_agent_chat.main_agent.plugins.year_substituter:make_middleware" """ from __future__ import annotations @@ -34,7 +34,7 @@ if TYPE_CHECKING: # pragma: no cover - type-only from langchain_core.messages import ToolMessage from langgraph.types import Command - 
from app.agents.new_chat.plugin_loader import PluginContext + from .loader import PluginContext logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/agent_cache.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py similarity index 95% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/agent_cache.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py index df1ee1b4c..65fa02749 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/agent_cache.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py @@ -10,18 +10,18 @@ from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool from langgraph.types import Checkpointer -from app.agents.new_chat.agent_cache import ( +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode +from app.db import ChatVisibility + +from ..graph.compile_graph_sync import build_compiled_agent_graph_sync +from .agent_cache_store import ( flags_signature, get_cache, stable_hash, system_prompt_hash, tools_signature, ) -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.db import ChatVisibility - -from ..graph.compile_graph_sync import build_compiled_agent_graph_sync def mcp_signature(mcp_tools_by_agent: 
dict[str, list[BaseTool]]) -> str: diff --git a/surfsense_backend/app/agents/new_chat/agent_cache.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache_store.py similarity index 96% rename from surfsense_backend/app/agents/new_chat/agent_cache.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache_store.py index fa8e6fb72..ee51b4176 100644 --- a/surfsense_backend/app/agents/new_chat/agent_cache.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache_store.py @@ -113,12 +113,11 @@ def tools_signature( MCP tools loaded for the user changes, gating rules flip, etc.). * The available connectors / document types for the search space change (new connector added, last connector removed, new document - type indexed). Because :func:`get_connector_gated_tools` derives - ``modified_disabled_tools`` from ``available_connectors``, the - tool surface is technically already covered — but we hash the - connector list separately so an empty-list "no tools changed" - situation still rotates the key when, say, the user re-adds a - connector that gates a tool we were already not exposing. + type indexed). Connector gating derives disabled tools from + ``available_connectors``, so the tool surface is technically already + covered — but we hash the connector list separately so an empty-list + "no tools changed" situation still rotates the key when, say, the user + re-adds a connector that gates a tool we were already not exposing. 
Stays stable across: diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/connector_searchable_types.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/connector_searchable_types.py new file mode 100644 index 000000000..be193be04 --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/connector_searchable_types.py @@ -0,0 +1,100 @@ +"""Map configured connectors to the searchable document/connector types. + +This is agent-agnostic infrastructure shared by every agent factory (single- +and multi-agent). It translates the connectors a search space has enabled into +the set of searchable type strings that pre-search middleware and ``web_search`` +understand, and always layers in the document types that exist independently of +any connector (uploads, notes, extension captures, YouTube). + +It lives in its own module — rather than inside a specific agent factory — so +that retiring or moving any single agent never disturbs the others' access to +this mapping. +""" + +from __future__ import annotations + +from typing import Any + +# Maps SearchSourceConnectorType enum values to the searchable document/connector types +# used by pre-search middleware and web_search. +# Live search connectors (TAVILY_API, LINKUP_API, BAIDU_SEARCH_API) are routed to +# the web_search tool; all others are considered local/indexed data. 
+_CONNECTOR_TYPE_TO_SEARCHABLE: dict[str, str] = { + # Live search connectors (handled by web_search tool) + "TAVILY_API": "TAVILY_API", + "LINKUP_API": "LINKUP_API", + "BAIDU_SEARCH_API": "BAIDU_SEARCH_API", + # Local/indexed connectors (handled by KB pre-search middleware) + "SLACK_CONNECTOR": "SLACK_CONNECTOR", + "TEAMS_CONNECTOR": "TEAMS_CONNECTOR", + "NOTION_CONNECTOR": "NOTION_CONNECTOR", + "GITHUB_CONNECTOR": "GITHUB_CONNECTOR", + "LINEAR_CONNECTOR": "LINEAR_CONNECTOR", + "DISCORD_CONNECTOR": "DISCORD_CONNECTOR", + "JIRA_CONNECTOR": "JIRA_CONNECTOR", + "CONFLUENCE_CONNECTOR": "CONFLUENCE_CONNECTOR", + "CLICKUP_CONNECTOR": "CLICKUP_CONNECTOR", + "GOOGLE_CALENDAR_CONNECTOR": "GOOGLE_CALENDAR_CONNECTOR", + "GOOGLE_GMAIL_CONNECTOR": "GOOGLE_GMAIL_CONNECTOR", + "GOOGLE_DRIVE_CONNECTOR": "GOOGLE_DRIVE_FILE", # Connector type differs from document type + "AIRTABLE_CONNECTOR": "AIRTABLE_CONNECTOR", + "LUMA_CONNECTOR": "LUMA_CONNECTOR", + "ELASTICSEARCH_CONNECTOR": "ELASTICSEARCH_CONNECTOR", + "WEBCRAWLER_CONNECTOR": "CRAWLED_URL", # Maps to document type + "BOOKSTACK_CONNECTOR": "BOOKSTACK_CONNECTOR", + "CIRCLEBACK_CONNECTOR": "CIRCLEBACK", # Connector type differs from document type + "OBSIDIAN_CONNECTOR": "OBSIDIAN_CONNECTOR", + "DROPBOX_CONNECTOR": "DROPBOX_FILE", # Connector type differs from document type + "ONEDRIVE_CONNECTOR": "ONEDRIVE_FILE", # Connector type differs from document type + # Composio connectors (unified to native document types). + # Reverse of NATIVE_TO_LEGACY_DOCTYPE in app.db. 
+ "COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "GOOGLE_DRIVE_FILE", + "COMPOSIO_GMAIL_CONNECTOR": "GOOGLE_GMAIL_CONNECTOR", + "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "GOOGLE_CALENDAR_CONNECTOR", +} + +# Document types that don't come from SearchSourceConnector but should always be searchable +_ALWAYS_AVAILABLE_DOC_TYPES: list[str] = [ + "EXTENSION", # Browser extension data + "FILE", # Uploaded files + "NOTE", # User notes + "YOUTUBE_VIDEO", # YouTube videos +] + + +def map_connectors_to_searchable_types( + connector_types: list[Any], +) -> list[str]: + """ + Map SearchSourceConnectorType enums to searchable document/connector types. + + This function: + 1. Converts connector type enums to their searchable counterparts + 2. Includes always-available document types (EXTENSION, FILE, NOTE, YOUTUBE_VIDEO) + 3. Deduplicates while preserving order + + Args: + connector_types: List of SearchSourceConnectorType enum values + + Returns: + List of searchable connector/document type strings + """ + result_set: set[str] = set() + result_list: list[str] = [] + + # Add always-available document types first + for doc_type in _ALWAYS_AVAILABLE_DOC_TYPES: + if doc_type not in result_set: + result_set.add(doc_type) + result_list.append(doc_type) + + # Map each connector type to its searchable equivalent + for ct in connector_types: + # Handle both enum and string types + ct_str = ct.value if hasattr(ct, "value") else str(ct) + searchable = _CONNECTOR_TYPE_TO_SEARCHABLE.get(ct_str) + if searchable and searchable not in result_set: + result_set.add(searchable) + result_list.append(searchable) + + return result_list diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py similarity index 90% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py index 
44529d243..d70263841 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py @@ -12,21 +12,28 @@ from langchain_core.tools import BaseTool from langgraph.types import Checkpointer from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.multi_agent_chat.subagents import ( +from app.agents.chat.multi_agent_chat.shared.feature_flags import ( + AgentFeatureFlags, + get_flags, +) +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import ( + FilesystemMode, + FilesystemSelection, +) +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.resolver import ( + build_backend_resolver, +) +from app.agents.chat.multi_agent_chat.subagents import ( get_subagents_to_exclude, main_prompt_registry_subagent_lines, ) -from app.agents.multi_agent_chat.subagents.mcp_tools.index import ( +from app.agents.chat.multi_agent_chat.subagents.mcp_tools.index import ( load_mcp_tools_by_connector, ) -from app.agents.new_chat.chat_deepagent import _map_connectors_to_searchable_types -from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags -from app.agents.new_chat.filesystem_backends import build_backend_resolver -from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection -from app.agents.new_chat.llm_config import AgentConfig -from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching -from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool -from app.agents.new_chat.tools.registry import build_tools_async +from app.agents.chat.runtime.llm_config import AgentConfig +from app.agents.chat.runtime.prompt_caching import ( + apply_litellm_prompt_caching, +) from app.db import ChatVisibility from app.services.connector_service import ConnectorService from app.services.user_tool_allowlist import ( @@ -40,7 +47,10 @@ from ..tools import ( 
MAIN_AGENT_SURFSENSE_TOOL_NAMES, MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED, ) +from ..tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool +from ..tools.registry import build_main_agent_tools from .agent_cache import build_agent_with_cache +from .connector_searchable_types import map_connectors_to_searchable_types _perf_log = get_perf_logger() @@ -90,7 +100,7 @@ async def create_multi_agent_chat_deep_agent( connector_types = await connector_service.get_available_connectors( search_space_id ) - available_connectors = _map_connectors_to_searchable_types(connector_types) + available_connectors = map_connectors_to_searchable_types(connector_types) available_document_types = await connector_service.get_available_document_types( search_space_id @@ -210,12 +220,14 @@ async def create_multi_agent_chat_deep_agent( main_agent_enabled_tools = list(MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED) _t0 = time.perf_counter() - tools = await build_tools_async( + # Main agent builds only its own small SurfSense toolset via the SRP + # main-agent registry; connectors/MCP/deliverables are delegated to + # subagents, so no MCP loading or connector construction happens here. 
+ tools = build_main_agent_tools( dependencies=dependencies, enabled_tools=main_agent_enabled_tools, disabled_tools=modified_disabled_tools, additional_tools=list(additional_tools) if additional_tools else None, - include_mcp_tools=False, ) _flags: AgentFeatureFlags = get_flags() diff --git a/surfsense_backend/app/agents/new_chat/skills/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/skills/__init__.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/skills/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/skills/__init__.py diff --git a/surfsense_backend/app/agents/new_chat/middleware/skills_backends.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/skills/backends.py similarity index 95% rename from surfsense_backend/app/agents/new_chat/middleware/skills_backends.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/skills/backends.py index 072d73401..31620fe9b 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/skills_backends.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/skills/backends.py @@ -16,7 +16,7 @@ prompt at agent build time, not edited at runtime. Two backends are provided: * :class:`BuiltinSkillsBackend` — disk-backed read of bundled skills from - ``app/agents/new_chat/skills/builtin/``. + ``app/agents/chat/multi_agent_chat/main_agent/skills/builtin/``. * :class:`SearchSpaceSkillsBackend` — a thin read-only wrapper over :class:`KBPostgresBackend` that filters notes under the privileged folder ``/documents/_skills/``. 
@@ -47,7 +47,9 @@ from deepagents.backends.state import StateBackend if TYPE_CHECKING: from langchain.tools import ToolRuntime - from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend + from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import ( + KBPostgresBackend, + ) logger = logging.getLogger(__name__) @@ -59,9 +61,10 @@ _MAX_SKILL_FILE_SIZE = 10 * 1024 * 1024 def _default_builtin_root() -> Path: """Return the absolute path to the bundled builtin skills directory. - Located at ``app/agents/new_chat/skills/builtin/`` relative to this module. + Located at ``builtin/`` next to this module (this module lives at + ``app/agents/chat/multi_agent_chat/main_agent/skills/backends.py``). """ - return (Path(__file__).resolve().parent.parent / "skills" / "builtin").resolve() + return (Path(__file__).resolve().parent / "builtin").resolve() class BuiltinSkillsBackend(BackendProtocol): @@ -121,6 +124,8 @@ class BuiltinSkillsBackend(BackendProtocol): else ("/" + str(target.relative_to(self.root)).replace("\\", "/")) ) for child in sorted(target.iterdir()): + if child.name == "__pycache__" or child.name.startswith("."): + continue child_virtual = ( target_virtual.rstrip("/") + "/" + child.name if target_virtual != "/" @@ -305,7 +310,7 @@ def build_skills_backend_factory( # Imported lazily to avoid a hard dependency at module import time: # ``KBPostgresBackend`` pulls in DB models, which are unnecessary for # the unit-tested builtin path. 
- from app.agents.new_chat.middleware.kb_postgres_backend import ( + from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import ( KBPostgresBackend, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/skills/builtin/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/skills/builtin/__init__.py diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/email-drafting/SKILL.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/skills/builtin/email-drafting/SKILL.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/skills/builtin/email-drafting/SKILL.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/skills/builtin/email-drafting/SKILL.md diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/kb-research/SKILL.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/skills/builtin/kb-research/SKILL.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/skills/builtin/kb-research/SKILL.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/skills/builtin/kb-research/SKILL.md diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/meeting-prep/SKILL.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/skills/builtin/meeting-prep/SKILL.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/skills/builtin/meeting-prep/SKILL.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/skills/builtin/meeting-prep/SKILL.md diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/report-writing/SKILL.md 
b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/skills/builtin/report-writing/SKILL.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/skills/builtin/report-writing/SKILL.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/skills/builtin/report-writing/SKILL.md diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/slack-summary/SKILL.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/skills/builtin/slack-summary/SKILL.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/skills/builtin/slack-summary/SKILL.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/skills/builtin/slack-summary/SKILL.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/compose.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/compose.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/compose.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/compose.py 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/load_md.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/load_md.py similarity index 85% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/load_md.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/load_md.py index 61e30b1c7..fae45f520 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/load_md.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/load_md.py @@ -4,7 +4,7 @@ from __future__ import annotations from importlib import resources -_PROMPTS_PACKAGE = "app.agents.multi_agent_chat.main_agent.system_prompt.prompts" +_PROMPTS_PACKAGE = "app.agents.chat.multi_agent_chat.main_agent.system_prompt.prompts" def read_prompt_md(filename: str) -> str: diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/sections/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/sections/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/sections/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/sections/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/sections/citations.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/sections/citations.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/sections/citations.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/sections/citations.py diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/sections/dynamic_context.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/sections/dynamic_context.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/sections/dynamic_context.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/sections/dynamic_context.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/sections/identity.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/sections/identity.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/sections/identity.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/sections/identity.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/sections/memory_protocol.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/sections/memory_protocol.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/sections/memory_protocol.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/sections/memory_protocol.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/sections/specialists.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/sections/specialists.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/sections/specialists.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/sections/specialists.py diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/sections/tools.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/sections/tools.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/sections/tools.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/sections/tools.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/tool_instruction_block.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/tool_instruction_block.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/tool_instruction_block.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/builder/tool_instruction_block.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/citations/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/citations/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/citations/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/citations/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/citations/off.md 
b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/citations/off.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/citations/off.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/citations/off.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/citations/on.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/citations/on.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/citations/on.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/citations/on.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/core_behavior.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/core_behavior.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/core_behavior.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/core_behavior.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/private.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/private.md similarity index 100% 
rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/private.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/private.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/team.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/team.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/team.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/team.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/identity/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/identity/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/identity/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/identity/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/identity/private.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/identity/private.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/identity/private.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/identity/private.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/identity/team.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/identity/team.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/identity/team.md 
rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/identity/team.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/kb_first.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/kb_first.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/kb_first.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/kb_first.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/memory_protocol/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/memory_protocol/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/memory_protocol/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/memory_protocol/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/memory_protocol/private.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/memory_protocol/private.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/memory_protocol/private.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/memory_protocol/private.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/memory_protocol/team.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/memory_protocol/team.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/memory_protocol/team.md rename to 
surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/memory_protocol/team.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/output_format.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/output_format.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/output_format.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/output_format.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/__init__.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/base/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/anthropic.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/anthropic.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/anthropic.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/anthropic.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/deepseek.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/deepseek.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/deepseek.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/deepseek.md diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/default.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/default.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/default.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/default.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/google.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/google.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/google.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/google.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/grok.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/grok.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/grok.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/grok.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/kimi.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/kimi.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/kimi.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/kimi.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/openai_classic.md 
b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/openai_classic.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/openai_classic.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/openai_classic.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/openai_codex.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/openai_codex.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/openai_codex.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/openai_codex.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/openai_reasoning.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/openai_reasoning.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/openai_reasoning.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/providers/openai_reasoning.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/refusal_and_limits.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/refusal_and_limits.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/refusal_and_limits.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/refusal_and_limits.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/reminder.md 
b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/reminder.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/reminder.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/reminder.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/routing.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/routing.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/routing.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/routing.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/description.md 
similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/example.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/example.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/example.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/example.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/scrape_webpage/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/scrape_webpage/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/scrape_webpage/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/scrape_webpage/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/scrape_webpage/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/scrape_webpage/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/scrape_webpage/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/scrape_webpage/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/scrape_webpage/example.md 
b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/scrape_webpage/example.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/scrape_webpage/example.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/scrape_webpage/example.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/task/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/task/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/task/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/task/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/task/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/task/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/task/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/task/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/task/example.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/task/example.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/task/example.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/task/example.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/__init__.py 
b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/private/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/private/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/private/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/private/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/private/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/private/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/private/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/private/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/private/example.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/private/example.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/private/example.md rename to 
surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/private/example.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/team/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/team/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/team/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/team/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/team/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/team/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/team/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/team/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/team/example.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/team/example.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/team/example.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/update_memory/team/example.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/web_search/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/web_search/__init__.py similarity index 100% rename from 
surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/web_search/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/web_search/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/web_search/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/web_search/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/web_search/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/web_search/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/web_search/example.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/web_search/example.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/web_search/example.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/tools/web_search/example.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/automation/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/__init__.py rename to 
surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/automation/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/create.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/automation/create.py similarity index 98% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/create.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/automation/create.py index 62d39fcf2..4472a11ac 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/create.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/automation/create.py @@ -27,7 +27,7 @@ from langchain_core.messages import HumanMessage from langchain_core.tools import tool from pydantic import ValidationError -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) from app.automations.schemas.api import AutomationCreate diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/prompt.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/automation/prompt.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/prompt.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/automation/prompt.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/index.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/index.py diff --git a/surfsense_backend/app/agents/new_chat/tools/invalid_tool.py 
b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/invalid_tool.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/tools/invalid_tool.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/invalid_tool.py diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/registry.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/registry.py new file mode 100644 index 000000000..9e2e20d35 --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/registry.py @@ -0,0 +1,133 @@ +"""SRP main-agent tool registry. + +The main agent exposes only a small, fixed set of SurfSense tools to its LLM; +connector integrations, MCP, and deliverables are delegated to ``task`` +subagents (see :mod:`app.agents.chat.multi_agent_chat.main_agent.tools.index`). + +This module is the *building* counterpart to that name list: it owns the +factories for those few tools and nothing else, so the main agent's tool +surface stays self-contained and connector-free. + +Tool *display* metadata for the whole app (the ``/agent/tools`` listing +endpoint) lives separately in :mod:`app.agents.chat.multi_agent_chat.shared.tools.catalog`, a +pure-data module that imports no connectors. This registry only governs what +the main agent actually builds and binds. 
+""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from langchain_core.tools import BaseTool + +from app.agents.chat.shared.tools.web_search import create_web_search_tool +from app.db import ChatVisibility + +from .scrape_webpage import create_scrape_webpage_tool +from .update_memory import ( + create_update_memory_tool, + create_update_team_memory_tool, +) + + +def _build_scrape_webpage_tool(deps: dict[str, Any]) -> BaseTool: + return create_scrape_webpage_tool(firecrawl_api_key=deps.get("firecrawl_api_key")) + + +def _build_web_search_tool(deps: dict[str, Any]) -> BaseTool: + return create_web_search_tool( + search_space_id=deps.get("search_space_id"), + available_connectors=deps.get("available_connectors"), + ) + + +def _build_create_automation_tool(deps: dict[str, Any]) -> BaseTool: + # Deferred import: the automation package is a sibling under ``main_agent`` + # and is only needed at build time, mirroring the shared registry's + # call-time import to keep module import order robust. + from .automation import create_create_automation_tool + + return create_create_automation_tool( + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + llm=deps["llm"], + ) + + +def _build_update_memory_tool(deps: dict[str, Any]) -> BaseTool: + if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE: + return create_update_team_memory_tool( + search_space_id=deps["search_space_id"], + db_session=deps["db_session"], + llm=deps.get("llm"), + ) + return create_update_memory_tool( + user_id=deps["user_id"], + db_session=deps["db_session"], + llm=deps.get("llm"), + ) + + +# Ordered to match the historical main-agent binding order: +# scrape_webpage, web_search, create_automation, update_memory. +# Each entry is ``(factory, required_dependency_names)``. 
+_MAIN_AGENT_TOOL_FACTORIES: dict[ + str, tuple[Callable[[dict[str, Any]], BaseTool], tuple[str, ...]] +] = { + "scrape_webpage": (_build_scrape_webpage_tool, ()), + "web_search": (_build_web_search_tool, ()), + "create_automation": ( + _build_create_automation_tool, + ("search_space_id", "user_id", "llm"), + ), + "update_memory": ( + _build_update_memory_tool, + ("user_id", "search_space_id", "db_session", "thread_visibility", "llm"), + ), +} + + +def build_main_agent_tools( + dependencies: dict[str, Any], + enabled_tools: list[str] | None = None, + disabled_tools: list[str] | None = None, + additional_tools: list[BaseTool] | None = None, +) -> list[BaseTool]: + """Build the main agent's tool instances. + + Args: + dependencies: Dependency bag passed to each tool factory. + enabled_tools: Explicit allow-list of tool names. When ``None``, all + main-agent tools are enabled. Names not owned by this registry are + ignored. + disabled_tools: Names to drop after the enabled set is resolved. + additional_tools: Extra tools appended verbatim (e.g. custom tools). + + Returns: + Tool instances in the registry's declaration order, with any + ``additional_tools`` appended. 
+ """ + if enabled_tools is None: + names = list(_MAIN_AGENT_TOOL_FACTORIES) + else: + wanted = set(enabled_tools) + names = [n for n in _MAIN_AGENT_TOOL_FACTORIES if n in wanted] + + if disabled_tools: + disabled = set(disabled_tools) + names = [n for n in names if n not in disabled] + + tools: list[BaseTool] = [] + for name in names: + factory, requires = _MAIN_AGENT_TOOL_FACTORIES[name] + missing = [dep for dep in requires if dep not in dependencies] + if missing: + msg = f"Tool '{name}' requires dependencies: {missing}" + raise ValueError(msg) + tools.append(factory(dependencies)) + + if additional_tools: + tools.extend(additional_tools) + + return tools diff --git a/surfsense_backend/app/agents/new_chat/tools/scrape_webpage.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/scrape_webpage.py similarity index 91% rename from surfsense_backend/app/agents/new_chat/tools/scrape_webpage.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/scrape_webpage.py index 014126927..24a686da1 100644 --- a/surfsense_backend/app/agents/new_chat/tools/scrape_webpage.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/scrape_webpage.py @@ -29,7 +29,6 @@ def extract_domain(url: str) -> str: try: parsed = urlparse(url) domain = parsed.netloc - # Remove 'www.' prefix if present if domain.startswith("www."): domain = domain[4:] return domain @@ -53,14 +52,13 @@ def truncate_content(content: str, max_length: int = 50000) -> tuple[str, bool]: if len(content) <= max_length: return content, False - # Try to truncate at a sentence boundary + # Prefer truncating at a sentence/paragraph boundary. 
truncated = content[:max_length] last_period = truncated.rfind(".") last_newline = truncated.rfind("\n\n") - # Use the later of the two boundaries, or just truncate boundary = max(last_period, last_newline) - if boundary > max_length * 0.8: # Only use boundary if it's not too far back + if boundary > max_length * 0.8: # only if the boundary isn't too far back truncated = content[: boundary + 1] return truncated + "\n\n[Content truncated...]", True @@ -111,8 +109,8 @@ async def _scrape_youtube_video( http_client.proxies.update(residential_proxies) ytt_api = YouTubeTranscriptApi(http_client=http_client) - # List all available transcripts and pick the first one - # (the video's primary language) instead of defaulting to English + # Pick the first transcript (video's primary language) rather than + # defaulting to English. transcript_list = ytt_api.list(video_id) transcript = next(iter(transcript_list)) captions = transcript.fetch() @@ -134,10 +132,8 @@ async def _scrape_youtube_video( logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}") transcript_text = f"No captions available for this video. Error: {e!s}" - # Build combined content content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}" - # Truncate if needed content, was_truncated = truncate_content(content, max_length) word_count = len(content.split()) @@ -212,20 +208,16 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None): scrape_id = generate_scrape_id(url) domain = extract_domain(url) - # Validate and normalize URL if not url.startswith(("http://", "https://")): url = f"https://{url}" try: - # Check if this is a YouTube URL and use transcript API instead + # YouTube URLs use the transcript API instead of crawling. 
video_id = get_youtube_video_id(url) if video_id: return await _scrape_youtube_video(url, video_id, max_length) - # Create webcrawler connector connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key) - - # Crawl the URL result, error = await connector.crawl_url(url, formats=["markdown"]) if error: @@ -250,28 +242,21 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None): "error": "No content returned from crawler", } - # Extract content and metadata content = result.get("content", "") metadata = result.get("metadata", {}) - # Get title from metadata title = metadata.get("title", "") if not title: title = domain or url.split("/")[-1] or "Webpage" - # Get description from metadata description = metadata.get("description", "") if not description and content: - # Use first paragraph as description first_para = content.split("\n\n")[0] if content else "" description = ( first_para[:300] + "..." if len(first_para) > 300 else first_para ) - # Truncate content if needed content, was_truncated = truncate_content(content, max_length) - - # Calculate word count word_count = len(content.split()) return { diff --git a/surfsense_backend/app/agents/new_chat/tools/update_memory.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/update_memory.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/tools/update_memory.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/update_memory.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/__init__.py diff --git a/surfsense_backend/app/agents/new_chat/utils.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/date_filters.py similarity 
index 100% rename from surfsense_backend/app/agents/new_chat/utils.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/date_filters.py diff --git a/surfsense_backend/app/agents/new_chat/feature_flags.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/feature_flags.py similarity index 68% rename from surfsense_backend/app/agents/new_chat/feature_flags.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/feature_flags.py index 27188fac3..9564bd195 100644 --- a/surfsense_backend/app/agents/new_chat/feature_flags.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/feature_flags.py @@ -1,37 +1,9 @@ -""" -Feature flags for the SurfSense new_chat agent stack. +"""Feature flags for the SurfSense new_chat agent stack. -These flags gate the newer agent middleware (some ported from OpenCode, -some sourced from ``langchain.agents.middleware`` / ``deepagents``, some -SurfSense-native). Most shipped agent-stack upgrades default ON so Docker -image updates work even when older installs do not have newly introduced -environment variables. Risky/experimental integrations stay default OFF, -and the master kill-switch can still disable everything new. - -All new middleware checks its flag at agent build time. If the master -kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new -middleware is disabled regardless of its individual flag. This gives -operators a single switch to revert to pre-port behavior. 
- -Examples --------- - -Defaults: - - SURFSENSE_ENABLE_CONTEXT_EDITING=true - SURFSENSE_ENABLE_COMPACTION_V2=true - SURFSENSE_ENABLE_RETRY_AFTER=true - SURFSENSE_ENABLE_MODEL_FALLBACK=false - SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true - SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true - SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true - SURFSENSE_ENABLE_PERMISSION=true - SURFSENSE_ENABLE_DOOM_LOOP=true - SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call - -Master kill-switch (overrides everything else): - - SURFSENSE_DISABLE_NEW_AGENT_STACK=true +Flags are resolved at agent build time. Most upgrades default ON so Docker +updates work without operators adding new env vars; risky integrations stay +OFF. The master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` forces every +flag below to False for a one-switch rollback to pre-port behavior. """ from __future__ import annotations @@ -93,39 +65,14 @@ class AgentFeatureFlags: # Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT) enable_otel: bool = False - # Performance — compiled-agent cache (Phase 1 + Phase 2). - # When ON, ``create_surfsense_deep_agent`` reuses a previously-compiled - # graph if the cache key matches (LLM config + thread + tool surface + - # flags + system prompt + filesystem mode). Cuts per-turn agent-build - # wall clock from ~4-5s to <50µs on cache hits. - # - # SAFETY (Phase 2 unblocked this default-on): - # All connector mutation tools (``tools/notion``, ``tools/gmail``, - # ``tools/google_drive``, ``tools/dropbox``, ``tools/onedrive``, - # ``tools/google_calendar``, ``tools/confluence``, ``tools/discord``, - # ``tools/teams``, ``tools/luma``, ``connected_accounts``, - # ``update_memory``) now acquire fresh - # short-lived ``AsyncSession`` instances per call via - # :data:`async_session_maker`. The factory still accepts ``db_session`` - # for registry compatibility but ``del``'s it immediately — see any - # of those files' factory docstrings for the rationale. 
The ``llm`` - # closure is per-(provider, model, config_id) which is already in - # the cache key, so the LLM is safe to share across cached hits of - # the same key. The KB priority middleware reads - # ``mentioned_document_ids`` from ``runtime.context`` (Phase 1.5), - # not its constructor closure, so the same compiled agent serves - # turns with different mention lists correctly. - # - # Rollback: set ``SURFSENSE_ENABLE_AGENT_CACHE=false`` in the - # environment if a regression surfaces. The path is exercised by - # the ``tests/unit/agents/new_chat/test_agent_cache_*`` suite. + # Performance — reuse a compiled agent graph when the cache key matches + # (~4-5s -> <50µs per turn). Safe to default-on because mutation tools take + # fresh short-lived sessions per call and per-turn context (mentions, etc.) + # is read from runtime.context, not the constructor closure. Rollback via + # SURFSENSE_ENABLE_AGENT_CACHE=false. enable_agent_cache: bool = True - # Phase 1 (deferred — measure first): pre-build & share the - # general-purpose subagent ``CompiledSubAgent`` across cold-cache - # misses. Only helps when the outer cache MISSES (cache hits already - # reuse the entire SubAgentMiddleware-compiled graph). Off by default - # until we have data showing cold misses are frequent enough to - # justify the extra global state. + # Deferred: only helps on outer-cache MISSES, so off until data shows cold + # misses are frequent enough to justify the extra global state. 
enable_agent_cache_share_gp_subagent: bool = False @classmethod diff --git a/surfsense_backend/app/agents/new_chat/filesystem_selection.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/filesystem_selection.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/filesystem_selection.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/filesystem_selection.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/anthropic_cache.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/anthropic_cache.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/anthropic_cache.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/anthropic_cache.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/compaction.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/compaction.py similarity index 80% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/compaction.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/compaction.py index b59e7d2c4..c1d26429e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/compaction.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/compaction.py @@ -7,7 +7,7 @@ from typing import Any from deepagents.backends import StateBackend from langchain_core.language_models import BaseChatModel -from app.agents.new_chat.middleware import create_surfsense_compaction_middleware 
+from app.agents.chat.shared.middleware import create_surfsense_compaction_middleware def build_compaction_mw(llm: BaseChatModel) -> Any: diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/dedup_tool_calls.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/dedup_tool_calls.py new file mode 100644 index 000000000..087a69ae6 --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/dedup_tool_calls.py @@ -0,0 +1,59 @@ +"""Dedup-key resolvers for tool-call deduplication. + +A *resolver* maps a tool's ``args`` dict to a stable signature string used to +collapse duplicate calls. These helpers are shared: the MCP tool layer uses +:func:`dedup_key_full_args` as a safe default, and the main-agent +``DedupHITLToolCallsMiddleware`` builds its resolver map from them. + +Resolver resolution order (read from each tool's own ``metadata``): + +1. ``tool.metadata["dedup_key"]`` — callable mapping the args dict to a + stable signature string. This is the canonical mechanism. +2. ``tool.metadata["hitl_dedup_key"]`` — string naming a primary arg; + used by MCP / Composio tools that only expose a single key field. + +A tool with no resolver from either path simply opts out of dedup. +""" + +from __future__ import annotations + +import json +from collections.abc import Callable +from typing import Any + +# Resolver type — given the tool ``args`` dict returns a stable +# string used to dedupe consecutive calls. ``None`` means no dedup. +DedupResolver = Callable[[dict[str, Any]], str] + + +def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver: + """Adapt a string-arg name into a :data:`DedupResolver`. + + Convenience helper for tools that just want to dedupe on a single arg's + lowercased value (the most common case for HITL tools like + ``send_gmail_email`` keyed on ``subject``). Set the result on the tool's + ``metadata["dedup_key"]``. 
+ """ + + def _resolver(args: dict[str, Any]) -> str: + return str(args.get(arg_name, "")).lower() + + return _resolver + + +def dedup_key_full_args(args: dict[str, Any]) -> str: + """Resolver that collapses calls only when **every** argument is identical. + + Safe default for tools where no single field uniquely identifies a call + (e.g. MCP tools whose first required field is a shared workspace id). + """ + + try: + return json.dumps(args, sort_keys=True, default=str) + except (TypeError, ValueError): + return repr(sorted(args.items())) if isinstance(args, dict) else repr(args) + + +# Backwards-compatible alias for code that imported the original +# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`. +_wrap_string_key = wrap_dedup_key_by_arg_name diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/__init__.py diff --git a/surfsense_backend/app/agents/new_chat/document_xml.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/document_xml.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/document_xml.py rename to 
surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/document_xml.py diff --git a/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/kb_postgres.py similarity index 99% rename from surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/kb_postgres.py index 7cf3bf8cd..7b8aaf2b0 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/kb_postgres.py @@ -42,8 +42,10 @@ from langchain.tools import ToolRuntime from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.document_xml import build_document_xml -from app.agents.new_chat.path_resolver import ( +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.document_xml import ( + build_document_xml, +) +from app.agents.chat.runtime.path_resolver import ( DOCUMENTS_ROOT, build_path_index, doc_to_virtual_path, diff --git a/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/local_folder.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/local_folder.py diff --git a/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/multi_root_local_folder.py similarity index 99% rename from surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py rename to 
surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/multi_root_local_folder.py index a5add6248..db84a17eb 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/multi_root_local_folder.py @@ -15,7 +15,9 @@ from deepagents.backends.protocol import ( WriteResult, ) -from app.agents.new_chat.middleware.local_folder_backend import LocalFolderBackend +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.local_folder import ( + LocalFolderBackend, +) _INVALID_PATH = "invalid_path" _FILE_NOT_FOUND = "file_not_found" diff --git a/surfsense_backend/app/agents/new_chat/filesystem_backends.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/resolver.py similarity index 84% rename from surfsense_backend/app/agents/new_chat/filesystem_backends.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/resolver.py index c8288be71..6c35f369f 100644 --- a/surfsense_backend/app/agents/new_chat/filesystem_backends.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/resolver.py @@ -9,9 +9,14 @@ from deepagents.backends.protocol import BackendProtocol from deepagents.backends.state import StateBackend from langgraph.prebuilt.tool_node import ToolRuntime -from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection -from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend -from app.agents.new_chat.middleware.multi_root_local_folder_backend import ( +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import ( + FilesystemMode, + FilesystemSelection, +) +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import ( + KBPostgresBackend, +) +from 
app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.multi_root_local_folder import ( MultiRootLocalFolderBackend, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/index.py similarity index 88% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/index.py index fb8dbe209..91bc4db7c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/index.py @@ -4,7 +4,7 @@ from __future__ import annotations from typing import Any -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode from .middleware import SurfSenseFilesystemMiddleware diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/async_dispatch.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/async_dispatch.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/async_dispatch.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/async_dispatch.py diff 
--git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/index.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/index.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/middleware.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/middleware.py similarity index 91% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/middleware.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/middleware.py index c32e14438..f04390f4a 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/middleware.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/middleware.py @@ -7,9 +7,13 @@ from typing import Any from deepagents import FilesystemMiddleware from langchain_core.tools import BaseTool -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.sandbox import is_sandbox_enabled +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.sandbox import ( + is_sandbox_enabled, +) +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) from ..system_prompt import build_system_prompt from ..tools import ( diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/mode.py 
b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/mode.py similarity index 70% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/mode.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/mode.py index a23d77535..44d69a50a 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/mode.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/mode.py @@ -2,8 +2,8 @@ from __future__ import annotations -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode +from app.agents.chat.runtime.path_resolver import DOCUMENTS_ROOT def is_cloud(mode: FilesystemMode) -> bool: diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/namespace_policy.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/namespace_policy.py similarity index 89% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/namespace_policy.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/namespace_policy.py index 539050414..1eced41d7 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/namespace_policy.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/namespace_policy.py @@ -11,8 +11,10 @@ from typing import TYPE_CHECKING from langchain.tools import ToolRuntime -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT +from 
app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) +from app.agents.chat.runtime.path_resolver import DOCUMENTS_ROOT from ..shared.paths import TEMP_PREFIX, basename from .mode import is_cloud diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/path_resolution.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/path_resolution.py similarity index 94% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/path_resolution.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/path_resolution.py index 2c8ec6b4d..2650d9c34 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/path_resolution.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/path_resolution.py @@ -7,11 +7,13 @@ from typing import TYPE_CHECKING from langchain.tools import ToolRuntime -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.middleware.multi_root_local_folder_backend import ( +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.multi_root_local_folder import ( MultiRootLocalFolderBackend, ) +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) from ..shared.paths import ( extract_mount_from_path, diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/read_only_policy.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/read_only_policy.py similarity index 100% rename from 
surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/read_only_policy.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/middleware/read_only_policy.py diff --git a/surfsense_backend/app/agents/new_chat/sandbox.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/sandbox.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/sandbox.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/sandbox.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/shared/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/shared/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/shared/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/shared/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/shared/paths.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/shared/paths.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/shared/paths.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/shared/paths.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/system_prompt/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/system_prompt/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/system_prompt/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/system_prompt/__init__.py diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/system_prompt/cloud.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/system_prompt/cloud.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/system_prompt/cloud.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/system_prompt/cloud.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/system_prompt/common.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/system_prompt/common.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/system_prompt/common.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/system_prompt/common.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/system_prompt/desktop.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/system_prompt/desktop.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/system_prompt/desktop.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/system_prompt/desktop.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/system_prompt/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/system_prompt/index.py similarity index 86% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/system_prompt/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/system_prompt/index.py index 9d3cdbae3..491b5a762 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/system_prompt/index.py +++ 
b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/system_prompt/index.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode from .cloud import BODY as CLOUD_BODY from .common import HEADER, SANDBOX_ADDENDUM diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/cd/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/cd/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/cd/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/cd/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/cd/description.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/cd/description.py similarity index 83% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/cd/description.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/cd/description.py index 6d7b987c8..bc106efcf 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/cd/description.py +++ 
b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/cd/description.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode _DESCRIPTION = """Changes the current working directory (cwd). diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/cd/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/cd/index.py similarity index 93% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/cd/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/cd/index.py index 8df6b9edb..0e78e8640 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/cd/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/cd/index.py @@ -10,8 +10,10 @@ from langchain_core.messages import ToolMessage from langchain_core.tools import BaseTool, StructuredTool from langgraph.types import Command -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) +from app.agents.chat.runtime.path_resolver import DOCUMENTS_ROOT from ...middleware.async_dispatch import run_async_blocking from ...middleware.path_resolution import resolve_relative diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/edit_file/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/edit_file/__init__.py similarity index 100% rename from 
surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/edit_file/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/edit_file/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/edit_file/description.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/edit_file/description.py similarity index 89% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/edit_file/description.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/edit_file/description.py index de2a47648..5c474e2f8 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/edit_file/description.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/edit_file/description.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode _CLOUD_DESCRIPTION = """Performs exact string replacements in files. 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/edit_file/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/edit_file/index.py similarity index 95% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/edit_file/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/edit_file/index.py index 324ef09b0..775469531 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/edit_file/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/edit_file/index.py @@ -11,8 +11,12 @@ from langchain_core.messages import ToolMessage from langchain_core.tools import BaseTool, StructuredTool from langgraph.types import Command -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import ( + KBPostgresBackend, +) +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) from ...middleware.async_dispatch import run_async_blocking from ...middleware.mode import is_cloud diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/execute_code/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/execute_code/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/description.py 
b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/execute_code/description.py similarity index 86% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/description.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/execute_code/description.py index 89415c2f3..ae19b977e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/description.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/execute_code/description.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode _DESCRIPTION = """Executes Python code in an isolated sandbox environment. diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/helpers.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/execute_code/helpers.py similarity index 93% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/helpers.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/execute_code/helpers.py index cda9f535d..2c3293e14 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/helpers.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/execute_code/helpers.py @@ -14,12 +14,14 @@ from typing import TYPE_CHECKING from daytona.common.errors import DaytonaError from langchain.tools import ToolRuntime -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.sandbox import ( +from 
app.agents.chat.multi_agent_chat.shared.middleware.filesystem.sandbox import ( _evict_sandbox_cache, delete_sandbox, get_or_create_sandbox, ) +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) if TYPE_CHECKING: from ...middleware import SurfSenseFilesystemMiddleware diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/execute_code/index.py similarity index 95% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/execute_code/index.py index 2711636e4..b530c91f2 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/execute_code/index.py @@ -7,7 +7,9 @@ from typing import TYPE_CHECKING, Annotated from langchain.tools import ToolRuntime from langchain_core.tools import BaseTool, StructuredTool -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) from ...middleware.async_dispatch import run_async_blocking from .description import select_description diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/glob/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/glob/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/glob/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/glob/__init__.py diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/glob/description.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/glob/description.py similarity index 77% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/glob/description.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/glob/description.py index d022f9a7a..7c9fafa36 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/glob/description.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/glob/description.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode _DESCRIPTION = """Find files matching a glob pattern. diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/grep/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/grep/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/grep/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/grep/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/grep/description.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/grep/description.py similarity index 89% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/grep/description.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/grep/description.py index 5d7c393a9..4b34ac60b 100644 --- 
a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/grep/description.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/grep/description.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode _CLOUD_DESCRIPTION = """Search for a literal text pattern across files. diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/list_tree/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/list_tree/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/list_tree/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/list_tree/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/list_tree/description.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/list_tree/description.py similarity index 93% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/list_tree/description.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/list_tree/description.py index a24230fb0..619a639d1 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/list_tree/description.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/list_tree/description.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode _CLOUD_DESCRIPTION = """Lists files/folders recursively in 
a single bounded call. diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/list_tree/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/list_tree/index.py similarity index 94% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/list_tree/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/list_tree/index.py index 8bad88a74..21bba1fc3 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/list_tree/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/list_tree/index.py @@ -9,8 +9,12 @@ from deepagents.backends.utils import validate_path from langchain.tools import ToolRuntime from langchain_core.tools import BaseTool, StructuredTool -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import ( + KBPostgresBackend, +) +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) from ...middleware.async_dispatch import run_async_blocking from ...middleware.path_resolution import resolve_list_target_path diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/ls/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/ls/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/ls/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/ls/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/ls/description.py 
b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/ls/description.py similarity index 91% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/ls/description.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/ls/description.py index 8c7e301dc..f49a64772 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/ls/description.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/ls/description.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode _CLOUD_DESCRIPTION = """Lists files and directories at the given path. diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/ls/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/ls/index.py similarity index 93% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/ls/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/ls/index.py index 70f31dd04..e45a279d7 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/ls/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/ls/index.py @@ -8,8 +8,12 @@ from deepagents.backends.utils import validate_path from langchain.tools import ToolRuntime from langchain_core.tools import BaseTool, StructuredTool -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.middleware.kb_postgres_backend import paginate_listing +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import ( + paginate_listing, 
+) +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) from ...middleware.async_dispatch import run_async_blocking from ...middleware.path_resolution import resolve_list_target_path diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/mkdir/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/mkdir/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/mkdir/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/mkdir/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/mkdir/description.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/mkdir/description.py similarity index 89% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/mkdir/description.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/mkdir/description.py index 1c86e72f7..94eb49d2d 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/mkdir/description.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/mkdir/description.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode _CLOUD_DESCRIPTION = """Creates a directory under `/documents/`. 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/mkdir/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/mkdir/index.py similarity index 94% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/mkdir/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/mkdir/index.py index 788381faa..3ea38f525 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/mkdir/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/mkdir/index.py @@ -11,8 +11,10 @@ from langchain_core.messages import ToolMessage from langchain_core.tools import BaseTool, StructuredTool from langgraph.types import Command -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) +from app.agents.chat.runtime.path_resolver import DOCUMENTS_ROOT from ...middleware.async_dispatch import run_async_blocking from ...middleware.mode import is_cloud diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/move_file/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/move_file/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/description.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/move_file/description.py similarity index 
92% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/description.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/move_file/description.py index fdba40b29..520692697 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/description.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/move_file/description.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode _CLOUD_DESCRIPTION = """Moves or renames a file or folder. diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/helpers.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/move_file/helpers.py similarity index 90% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/helpers.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/move_file/helpers.py index 7613f62f1..ded4701f9 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/helpers.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/move_file/helpers.py @@ -8,10 +8,14 @@ from langchain.tools import ToolRuntime from langchain_core.messages import ToolMessage from langgraph.types import Command -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend -from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT -from app.agents.new_chat.state_reducers import _CLEAR +from 
app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import ( + KBPostgresBackend, +) +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) +from app.agents.chat.multi_agent_chat.shared.state.reducers import _CLEAR +from app.agents.chat.runtime.path_resolver import DOCUMENTS_ROOT if TYPE_CHECKING: from ...middleware import SurfSenseFilesystemMiddleware diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/move_file/index.py similarity index 96% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/move_file/index.py index d90535990..b7345b1a0 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/move_file/index.py @@ -11,7 +11,9 @@ from langchain_core.messages import ToolMessage from langchain_core.tools import BaseTool, StructuredTool from langgraph.types import Command -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) from ...middleware.async_dispatch import run_async_blocking from ...middleware.mode import is_cloud diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/pwd/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/pwd/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/pwd/__init__.py rename to 
surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/pwd/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/pwd/description.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/pwd/description.py similarity index 72% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/pwd/description.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/pwd/description.py index 594a38843..11f0b9f91 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/pwd/description.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/pwd/description.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode _DESCRIPTION = """Prints the current working directory.""" diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/pwd/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/pwd/index.py similarity index 89% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/pwd/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/pwd/index.py index c15b67114..2c220efca 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/pwd/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/pwd/index.py @@ -7,7 +7,9 @@ from typing import TYPE_CHECKING from langchain.tools import ToolRuntime from langchain_core.tools import BaseTool, StructuredTool -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState +from 
app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) from ...middleware.path_resolution import current_cwd from .description import select_description diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/read_file/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/read_file/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/read_file/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/read_file/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/read_file/description.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/read_file/description.py similarity index 89% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/read_file/description.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/read_file/description.py index 9b5d7623f..b10ca4acc 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/read_file/description.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/read_file/description.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode _DESCRIPTION = """Reads a file from the filesystem. 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/read_file/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/read_file/index.py similarity index 93% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/read_file/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/read_file/index.py index 8b0a1a1c8..5c20619d6 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/read_file/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/read_file/index.py @@ -10,8 +10,12 @@ from langchain_core.messages import ToolMessage from langchain_core.tools import BaseTool, StructuredTool from langgraph.types import Command -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import ( + KBPostgresBackend, +) +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) from ...middleware.async_dispatch import run_async_blocking from ...middleware.path_resolution import resolve_relative diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rm/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rm/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/description.py 
b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rm/description.py similarity index 92% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/description.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rm/description.py index a9e120e7c..7a8e96c09 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/description.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rm/description.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode _CLOUD_DESCRIPTION = """Deletes a single file under `/documents/`. diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/helpers.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rm/helpers.py similarity index 92% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/helpers.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rm/helpers.py index 8a02544d8..e2e445d08 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/helpers.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rm/helpers.py @@ -12,10 +12,14 @@ from langchain.tools import ToolRuntime from langchain_core.messages import ToolMessage from langgraph.types import Command -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend -from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT -from app.agents.new_chat.state_reducers import _CLEAR +from 
app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import ( + KBPostgresBackend, +) +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) +from app.agents.chat.multi_agent_chat.shared.state.reducers import _CLEAR +from app.agents.chat.runtime.path_resolver import DOCUMENTS_ROOT if TYPE_CHECKING: from ...middleware import SurfSenseFilesystemMiddleware diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rm/index.py similarity index 94% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rm/index.py index 0c4e2fc71..099079476 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rm/index.py @@ -9,7 +9,9 @@ from langchain.tools import ToolRuntime from langchain_core.tools import BaseTool, StructuredTool from langgraph.types import Command -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) from ...middleware.async_dispatch import run_async_blocking from ...middleware.mode import is_cloud diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rmdir/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/__init__.py rename to 
surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rmdir/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/description.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rmdir/description.py similarity index 94% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/description.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rmdir/description.py index 2b72f815b..0880b4d22 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/description.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rmdir/description.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode _CLOUD_DESCRIPTION = """Deletes an empty directory under `/documents/`. 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/helpers.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rmdir/helpers.py similarity index 92% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/helpers.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rmdir/helpers.py index de5afe722..b511a8d79 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/helpers.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rmdir/helpers.py @@ -13,10 +13,14 @@ from langchain.tools import ToolRuntime from langchain_core.messages import ToolMessage from langgraph.types import Command -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend -from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT -from app.agents.new_chat.state_reducers import _CLEAR +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import ( + KBPostgresBackend, +) +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) +from app.agents.chat.multi_agent_chat.shared.state.reducers import _CLEAR +from app.agents.chat.runtime.path_resolver import DOCUMENTS_ROOT from ...middleware.path_resolution import current_cwd from ...shared.paths import is_ancestor_of diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rmdir/index.py similarity index 94% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/index.py rename to 
surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rmdir/index.py index cdf057353..4c52f68ae 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/rmdir/index.py @@ -9,7 +9,9 @@ from langchain.tools import ToolRuntime from langchain_core.tools import BaseTool, StructuredTool from langgraph.types import Command -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) from ...middleware.async_dispatch import run_async_blocking from ...middleware.mode import is_cloud diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/write_file/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/write_file/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/write_file/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/write_file/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/write_file/description.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/write_file/description.py similarity index 93% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/write_file/description.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/write_file/description.py index 223cc3f26..933ba2caf 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/write_file/description.py +++ 
b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/write_file/description.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode _CLOUD_DESCRIPTION = """Writes a new text file to the workspace. diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/write_file/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/write_file/index.py similarity index 96% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/write_file/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/write_file/index.py index a42f7ed62..5aa250143 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/write_file/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/tools/write_file/index.py @@ -11,7 +11,9 @@ from langchain_core.messages import ToolMessage from langchain_core.tools import BaseTool, StructuredTool from langgraph.types import Command -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) from ...middleware.async_dispatch import run_async_blocking from ...middleware.mode import is_cloud diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/flags.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/flags.py similarity index 78% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/flags.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/flags.py index 69994ae00..dfbf3e6ee 100644 --- 
a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/flags.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/flags.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags def enabled(flags: AgentFeatureFlags, attr: str) -> bool: diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/kb_context_projection.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/kb_context_projection.py similarity index 93% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/kb_context_projection.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/kb_context_projection.py index 2685d8a9b..4667441ab 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/kb_context_projection.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/kb_context_projection.py @@ -9,10 +9,13 @@ from langchain.agents.middleware import AgentMiddleware, AgentState from langchain_core.messages import SystemMessage from langgraph.runtime import Runtime -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.middleware.knowledge_search import _render_priority_message +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) from app.utils.perf import get_perf_logger +from .knowledge_search import _render_priority_message + _perf_log = get_perf_logger() diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py similarity index 88% rename from surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py rename to 
surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py index 77b413940..2714c6065 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py @@ -41,15 +41,20 @@ from litellm import token_counter from pydantic import BaseModel, Field, ValidationError from sqlalchemy import select -from app.agents.new_chat.feature_flags import get_flags -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.path_resolver import ( +from app.agents.chat.multi_agent_chat.shared.date_filters import ( + parse_date_or_datetime, + resolve_date_range, +) +from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import ( + SurfSenseFilesystemState, +) +from app.agents.chat.runtime.path_resolver import ( PathIndex, build_path_index, doc_to_virtual_path, ) -from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range from app.db import ( NATIVE_TO_LEGACY_DOCTYPE, Chunk, @@ -589,14 +594,9 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] inject_system_message: bool = True, # For backwards compatibility ) -> None: self.llm = llm - # The planner LLM handles short, structured internal tasks (query - # rewriting, date extraction, recency classification). When an - # operator marks a global config ``is_planner: true`` we route - # those calls to a cheap/fast model (e.g. gpt-4o-mini, Haiku, Azure - # gpt-5.x-nano) instead of the user's chat LLM — those classification - # tasks don't need frontier-tier capability. 
Falls back to the chat - # LLM when no planner config is wired up so deployments without one - # keep working unchanged. + # Cheap model for structured internal tasks (query rewrite, date + # extraction, recency classification) when one is configured; falls back + # to the chat LLM otherwise. self.planner_llm = planner_llm or llm self.search_space_id = search_space_id self.filesystem_mode = filesystem_mode @@ -605,26 +605,17 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] self.top_k = top_k self.mentioned_document_ids = mentioned_document_ids or [] self.inject_system_message = inject_system_message - # Build the kb-planner private Runnable ONCE here so we don't pay - # the ``create_agent`` compile cost (50-200ms) on every turn. - # Disabled by default behind ``enable_kb_planner_runnable``; when - # off the planner falls back to the legacy ``planner_llm.ainvoke`` - # path. + # Compiled lazily and memoized to avoid the per-turn create_agent cost. self._planner: Runnable | None = None self._planner_compile_failed = False def _build_kb_planner_runnable(self) -> Runnable | None: - """Compile the kb-planner private :class:`Runnable` once. + """Lazily compile and memoize the kb-planner Runnable. - Returns ``None`` when the feature flag is disabled, when the LLM is - unavailable, or when ``create_agent`` raises (we fall back to the - legacy ``planner_llm.ainvoke`` path in that case). Compilation happens - lazily on first call, then memoized via ``self._planner``. - - The compiled agent is constructed without tools — the planner's - contract is "answer with structured JSON" — but it inherits the - :class:`RetryAfterMiddleware` so transient rate-limit errors - from the planner LLM call don't fail the whole turn. + Returns ``None`` (and the caller falls back to ``planner_llm.ainvoke``) + when the flag is off, the LLM is missing, or ``create_agent`` raises. 
+ Built without tools but with RetryAfterMiddleware so a transient + rate-limit on the planner call doesn't fail the whole turn. """ if self._planner is not None or self._planner_compile_failed: return self._planner @@ -634,7 +625,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] if not flags.enable_kb_planner_runnable or flags.disable_new_agent_stack: return None - from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware + from app.agents.chat.shared.middleware.retry_after import RetryAfterMiddleware try: self._planner = create_agent( @@ -672,10 +663,8 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] loop = asyncio.get_running_loop() t0 = loop.time() - # Prefer the compiled-once planner Runnable when enabled; otherwise - # fall back to ``planner_llm.ainvoke``. The ``surfsense:internal`` - # tag is preserved on both paths so ``_stream_agent_events`` still - # suppresses the planner's intermediate events from the UI. + # Both paths tag surfsense:internal so the planner's intermediate + # events stay suppressed from the UI. planner = self._build_kb_planner_runnable() try: if planner is not None: @@ -814,32 +803,16 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] user_text=user_text, ) - # Per-turn ``mentioned_document_ids`` flow: - # 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the - # streaming task supplies a fresh :class:`SurfSenseContextSchema` - # on every ``astream_events`` call, so this list is naturally - # scoped to the current turn. Allows cross-turn graph reuse via - # ``agent_cache``. - # 2. Legacy fallback (cache disabled / context not propagated): the - # constructor-injected ``self.mentioned_document_ids`` list. We - # drain it after the first read so a cached graph (no Phase 1.5 - # wiring) doesn't keep replaying the same mentions on every - # turn. 
+ # Prefer per-turn mentions from runtime.context (lets a cached graph + # serve different turns); fall back to the constructor closure, draining + # it after one read so stale mentions can't replay. # - # CRITICAL: distinguish "context absent" (legacy caller, no field at - # all) from "context provided but empty" (turn with no mentions). - # ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in - # Python, so a naive ``if ctx_mentions:`` would fall through to the - # legacy closure on every no-mention follow-up turn — replaying the - # mentions baked in by turn 1's cache-miss build. Always drain the - # closure once the runtime path has fired so a cached middleware - # instance can never resurrect stale state. + # CRITICAL: test ``ctx_mentions is not None``, not truthiness — an empty + # list means "this turn has no mentions", not "use the closure". mention_ids: list[int] = [] ctx = getattr(runtime, "context", None) if runtime is not None else None ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None if ctx_mentions is not None: - # Runtime path is authoritative — even an empty list means - # "this turn has no mentions", NOT "look at the closure". mention_ids = list(ctx_mentions) if self.mentioned_document_ids: self.mentioned_document_ids = [] @@ -847,12 +820,8 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] mention_ids = list(self.mentioned_document_ids) self.mentioned_document_ids = [] - # Folder mentions live alongside doc mentions on the runtime - # context. They never feed hybrid search (folders aren't - # embedded) — they're surfaced purely as ``[USER-MENTIONED]`` - # priority entries so the agent walks the folder with ``ls`` / - # ``find_documents`` instead of ignoring it. Cloud filesystem - # mode only. + # Folder mentions aren't embedded, so they skip hybrid search and are + # surfaced only as [USER-MENTIONED] entries. Cloud mode only. 
folder_mention_ids: list[int] = [] if ( ctx is not None @@ -934,14 +903,10 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] async def _materialize_folder_priority( self, folder_ids: list[int] ) -> list[dict[str, Any]]: - """Resolve user-mentioned folder ids to ```` entries. + """Resolve mentioned folder ids to canonical-path priority entries. - Each entry uses the canonical ``/documents/Folder/Sub/`` virtual - path (matching ``KnowledgeTreeMiddleware`` and the agent's - ``ls`` adapter) and is flagged ``mentioned=True`` so the - rendered line carries ``[USER-MENTIONED]``. ``score`` is left - ``None`` so the renderer prints ``n/a`` — folders aren't - ranked, the agent decides which children to read. + Flagged ``mentioned=True`` with ``score=None`` (folders aren't ranked; + the agent decides which children to read). """ if not folder_ids: return [] @@ -1044,12 +1009,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] return priority, matched_chunk_ids -# Backwards-compatible alias for any external imports. 
-KnowledgeBaseSearchMiddleware = KnowledgePriorityMiddleware - - __all__ = [ - "KnowledgeBaseSearchMiddleware", "KnowledgePriorityMiddleware", "browse_recent_documents", "fetch_mentioned_documents", diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/patch_tool_calls.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/patch_tool_calls.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/patch_tool_calls.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/patch_tool_calls.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/bundle.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/bundle.py similarity index 88% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/bundle.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/bundle.py index 111244784..8b83c9b27 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/bundle.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/bundle.py @@ -10,15 +10,15 @@ from langchain.agents.middleware import ( ToolCallLimitMiddleware, ) -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.middleware import RetryAfterMiddleware -from app.agents.new_chat.middleware.scoped_model_fallback import ( - ScopedModelFallbackMiddleware, -) +from 
app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.shared.middleware import RetryAfterMiddleware from .fallback import build_fallback_mw from .model_call_limit import build_model_call_limit_mw from .retry import build_retry_mw +from .scoped_model_fallback import ( + ScopedModelFallbackMiddleware, +) from .tool_call_limit import build_tool_call_limit_mw diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/fallback.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/fallback.py similarity index 82% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/fallback.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/fallback.py index ea68a764e..5b7dcc6ce 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/fallback.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/fallback.py @@ -4,12 +4,12 @@ from __future__ import annotations import logging -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.middleware.scoped_model_fallback import ( - ScopedModelFallbackMiddleware, -) +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags from ..flags import enabled +from .scoped_model_fallback import ( + ScopedModelFallbackMiddleware, +) def build_fallback_mw( diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/model_call_limit.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/model_call_limit.py similarity index 85% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/model_call_limit.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/model_call_limit.py index 85707a385..2565a4b13 100644 --- 
a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/model_call_limit.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/model_call_limit.py @@ -4,7 +4,7 @@ from __future__ import annotations from langchain.agents.middleware import ModelCallLimitMiddleware -from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags from ..flags import enabled diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/retry.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/retry.py similarity index 69% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/retry.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/retry.py index c98fc4083..b0ce3e324 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/retry.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/retry.py @@ -2,8 +2,8 @@ from __future__ import annotations -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.middleware import RetryAfterMiddleware +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.shared.middleware import RetryAfterMiddleware from ..flags import enabled diff --git a/surfsense_backend/app/agents/new_chat/middleware/scoped_model_fallback.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/scoped_model_fallback.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/middleware/scoped_model_fallback.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/scoped_model_fallback.py diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/tool_call_limit.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/tool_call_limit.py similarity index 85% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/tool_call_limit.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/tool_call_limit.py index dcde81f37..0e4708849 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/tool_call_limit.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/resilience/tool_call_limit.py @@ -4,7 +4,7 @@ from __future__ import annotations from langchain.agents.middleware import ToolCallLimitMiddleware -from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags from ..flags import enabled diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/todos.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/todos.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/todos.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/todos.py diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/__init__.py new file mode 100644 index 000000000..cad69379b --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/__init__.py @@ -0,0 +1,41 @@ +"""Permissions vertical slice: rule model + allow/deny/ask enforcement. + +Self-contained subsystem combining the permission rule engine (:mod:`.model`) +with the pattern-based allow/deny/ask middleware and its HITL fallback +(:mod:`.middleware`, :mod:`.ask`, :mod:`.deny`). 
+ +Public surface: +- rule model: ``Rule``, ``Ruleset``, ``RuleAction`` and the ``evaluate`` / + ``evaluate_many`` / ``aggregate_action`` / ``wildcard_match`` helpers. +- middleware: ``build_permission_mw`` — the construction recipe shared by + every agent stack. +""" + +# isort: off +# Import order matters: the rule model must be bound on this package before the +# middleware loads, because the middleware transitively imports consumers (e.g. +# app.services.user_tool_allowlist) that re-import ``Rule``/``Ruleset`` from this +# package root. Loading ``.model`` first avoids a partially-initialized cycle. +from .model import ( + Rule, + RuleAction, + Ruleset, + aggregate_action, + evaluate, + evaluate_many, + wildcard_match, +) +from .middleware.factory import build_permission_mw + +# isort: on + +__all__ = [ + "Rule", + "RuleAction", + "Ruleset", + "aggregate_action", + "build_permission_mw", + "evaluate", + "evaluate_many", + "wildcard_match", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/subagent/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/ask/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/subagent/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/ask/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/decision.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/ask/decision.py similarity index 97% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/decision.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/ask/decision.py index f507e85ff..e77f16c35 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/decision.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/ask/decision.py @@ -17,7 +17,7 
@@ from __future__ import annotations import logging from typing import Any -from app.agents.multi_agent_chat.subagents.shared.hitl.wire import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.wire import ( LC_DECISION_APPROVE, LC_DECISION_EDIT, LC_DECISION_REJECT, diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/edit/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/ask/edit/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/edit/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/ask/edit/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/edit/merge.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/ask/edit/merge.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/edit/merge.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/ask/edit/merge.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/payload.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/ask/payload.py similarity index 94% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/payload.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/ask/payload.py index 6c5d011df..c16b9072a 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/payload.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/ask/payload.py @@ -6,14 +6,14 @@ from typing import Any from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.hitl.wire import ( +from app.agents.chat.multi_agent_chat.shared.permissions.model import Rule 
+from app.agents.chat.multi_agent_chat.subagents.shared.hitl.wire import ( LC_DECISION_APPROVE, LC_DECISION_EDIT, LC_DECISION_REJECT, SURFSENSE_DECISION_APPROVE_ALWAYS, build_lc_hitl_payload, ) -from app.agents.new_chat.permissions import Rule PERMISSION_ASK_INTERRUPT_TYPE = "permission_ask" diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/request.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/ask/request.py similarity index 96% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/request.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/ask/request.py index 3db51883d..7dc1e0a3c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/request.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/ask/request.py @@ -16,7 +16,7 @@ from typing import Any from langchain_core.tools import BaseTool from langgraph.types import interrupt -from app.agents.new_chat.permissions import Rule +from app.agents.chat.multi_agent_chat.shared.permissions.model import Rule from app.observability import metrics as ot_metrics, otel as ot from .decision import normalize_permission_decision diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/deny.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/deny.py similarity index 89% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/deny.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/deny.py index 196c4040e..83677b4ca 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/deny.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/deny.py @@ -11,8 +11,8 @@ from typing import Any from langchain_core.messages import ToolMessage -from app.agents.new_chat.errors 
import StreamingError -from app.agents.new_chat.permissions import Rule +from app.agents.chat.multi_agent_chat.shared.permissions.model import Rule +from app.agents.chat.runtime.errors import StreamingError def build_deny_message(tool_call: dict[str, Any], rule: Rule) -> ToolMessage: diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/core.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/core.py similarity index 98% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/core.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/core.py index d2950c5b4..a97e32379 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/core.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/core.py @@ -26,8 +26,8 @@ from langchain_core.messages import AIMessage, ToolMessage from langchain_core.tools import BaseTool from langgraph.runtime import Runtime -from app.agents.new_chat.errors import CorrectedError, RejectedError -from app.agents.new_chat.permissions import Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions.model import Ruleset +from app.agents.chat.runtime.errors import CorrectedError, RejectedError from app.services.user_tool_allowlist import TrustedToolSaver from ..ask.edit import merge_edited_args diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/evaluation.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/evaluation.py similarity index 96% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/evaluation.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/evaluation.py index 51531c4eb..745c1d727 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/evaluation.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/evaluation.py @@ -16,7 +16,7 @@ from __future__ import annotations import logging from typing import Any -from app.agents.new_chat.permissions import ( +from app.agents.chat.multi_agent_chat.shared.permissions.model import ( Rule, RuleAction, Ruleset, diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/factory.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/factory.py similarity index 95% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/factory.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/factory.py index 3c061ded6..7f143d640 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/factory.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/factory.py @@ -27,8 +27,8 @@ from collections.abc import Sequence from langchain_core.tools import BaseTool -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.permissions import Rule, Ruleset +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.permissions.model import Rule, Ruleset from 
app.services.user_tool_allowlist import TrustedToolSaver from .core import PermissionMiddleware diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/pattern_resolver.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/pattern_resolver.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/pattern_resolver.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/pattern_resolver.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/ruleset_view.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/ruleset_view.py similarity index 87% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/ruleset_view.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/ruleset_view.py index fbb66d455..da089114e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/ruleset_view.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/ruleset_view.py @@ -9,7 +9,11 @@ newly-promoted rules apply to subsequent calls. 
from __future__ import annotations -from app.agents.new_chat.permissions import Ruleset, aggregate_action, evaluate_many +from app.agents.chat.multi_agent_chat.shared.permissions.model import ( + Ruleset, + aggregate_action, + evaluate_many, +) def all_rulesets( diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/runtime_promote.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/runtime_promote.py similarity index 88% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/runtime_promote.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/runtime_promote.py index afc65fdc0..2ae38db50 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/runtime_promote.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/middleware/runtime_promote.py @@ -7,7 +7,7 @@ is the streaming layer's job — this module keeps the in-memory copy only. 
from __future__ import annotations -from app.agents.new_chat.permissions import Rule, Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions.model import Rule, Ruleset def persist_always( diff --git a/surfsense_backend/app/agents/new_chat/permissions.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/model.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/permissions.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/permissions/model.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/receipts/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/receipts/__init__.py diff --git a/surfsense_backend/app/agents/shared/receipt_command.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/receipts/command.py similarity index 88% rename from surfsense_backend/app/agents/shared/receipt_command.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/receipts/command.py index f1c269e90..d31df998c 100644 --- a/surfsense_backend/app/agents/shared/receipt_command.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/receipts/command.py @@ -6,7 +6,7 @@ participate in the verification teaching from ``multi_agent_chat/subagents/shared/snippets/verifiable_handle.md`` those tools now also need to write a :class:`Receipt` into the parent's ``state['receipts']`` list (declared on -:class:`~app.agents.new_chat.filesystem_state.SurfSenseFilesystemState` +:class:`~app.agents.chat.multi_agent_chat.shared.state.filesystem_state.SurfSenseFilesystemState` and backed by the append reducer). 
:func:`with_receipt` wraps both behaviours: it returns the tool payload as @@ -24,7 +24,7 @@ from typing import Any from langchain_core.messages import ToolMessage from langgraph.types import Command -from app.agents.shared.receipt import Receipt +from app.agents.chat.multi_agent_chat.shared.receipts.receipt import Receipt def _content_to_text(payload: dict[str, Any] | str) -> str: @@ -51,7 +51,7 @@ def with_receipt( """Return a Command that ships ``payload`` as a ToolMessage AND appends ``receipt``. The append happens via the ``_list_append_reducer`` on the ``receipts`` - field of :class:`~app.agents.new_chat.filesystem_state.SurfSenseFilesystemState`, + field of :class:`~app.agents.chat.multi_agent_chat.shared.state.filesystem_state.SurfSenseFilesystemState`, so concurrent subagent batches (item 4 in the plan) won't clobber each other's receipts. """ diff --git a/surfsense_backend/app/agents/shared/receipt.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/receipts/receipt.py similarity index 96% rename from surfsense_backend/app/agents/shared/receipt.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/receipts/receipt.py index 6f30067ee..b1986a224 100644 --- a/surfsense_backend/app/agents/shared/receipt.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/receipts/receipt.py @@ -5,7 +5,7 @@ delegate_tool.py:1663-1697``) for our 5 deliverable types + 15 connectors + KB writes. The supervisor reads the Receipt to verify what actually happened without round-tripping through LLM paraphrase. 
-**Why this lives under ``app.agents.shared`` and not under either of the +**Why this lives under ``app.agents.chat.shared`` and not under either of the two agent packages:** the Receipt is a *contract* shared between ``multi_agent_chat`` (where mutating tools emit it) and ``new_chat`` (where ``filesystem_state.SurfSenseFilesystemState`` declares the @@ -23,7 +23,7 @@ the receipt into the parent's ``receipts`` state via the append reducer. The KB write path is the one exception: file-tool calls cannot emit a durable receipt because the actual DB writes happen end-of-turn inside -:class:`app.agents.new_chat.middleware.kb_persistence.KnowledgeBasePersistenceMiddleware`. +:class:`app.agents.chat.multi_agent_chat.shared.middleware.kb_persistence.KnowledgeBasePersistenceMiddleware`. KB tools therefore emit a *provisional* receipt with ``status="pending"``; the persistence middleware flips it to ``"success"`` or ``"failed"`` before returning control to the parent. diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/state/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/state/__init__.py diff --git a/surfsense_backend/app/agents/new_chat/filesystem_state.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/state/filesystem_state.py similarity index 96% rename from surfsense_backend/app/agents/new_chat/filesystem_state.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/state/filesystem_state.py index de2c94b41..41bed9d62 100644 --- a/surfsense_backend/app/agents/new_chat/filesystem_state.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/state/filesystem_state.py @@ -20,7 +20,7 @@ extra fields needed to implement Postgres-backed virtual filesystem semantics: * 
``workspace_tree_text`` — pre-rendered ```` body for the turn. Tools mutate these fields ONLY via ``Command(update=...)`` returns; the -reducers in :mod:`app.agents.new_chat.state_reducers` handle merging. +reducers in :mod:`app.agents.chat.multi_agent_chat.shared.state.reducers` handle merging. """ from __future__ import annotations @@ -30,14 +30,14 @@ from typing import Annotated, Any, NotRequired from deepagents.middleware.filesystem import FilesystemState from typing_extensions import TypedDict -from app.agents.new_chat.state_reducers import ( +from app.agents.chat.multi_agent_chat.shared.receipts.receipt import Receipt +from app.agents.chat.multi_agent_chat.shared.state.reducers import ( _add_unique_reducer, _dict_merge_with_tombstones_reducer, _int_counter_merge_reducer, _list_append_reducer, _replace_reducer, ) -from app.agents.shared.receipt import Receipt class PendingMove(TypedDict, total=False): @@ -190,7 +190,7 @@ class SurfSenseFilesystemState(FilesystemState): Each mutating tool (deliverables, every connector, KB writes via the persistence middleware) wraps its native return into a - :class:`~app.agents.shared.receipt.Receipt` + :class:`~app.agents.chat.multi_agent_chat.shared.receipts.receipt.Receipt` and returns it under the ``"receipt"`` key alongside its existing payload. 
The subagent's tool-call middleware folds the receipt into this list, and ``_return_command_with_state_update`` in diff --git a/surfsense_backend/app/agents/new_chat/state_reducers.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/state/reducers.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/state_reducers.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/state/reducers.py diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/__init__.py new file mode 100644 index 000000000..a36be01eb --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/__init__.py @@ -0,0 +1 @@ +"""Tools shared across multi_agent_chat (main agent + subagents + boundary).""" diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/catalog.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/catalog.py new file mode 100644 index 000000000..1aff733b2 --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/catalog.py @@ -0,0 +1,83 @@ +"""Pure-data catalog of built-in agent tools. + +This module advertises *what* tools exist and their display metadata. It is +intentionally free of any tool implementation imports (no connectors, no +factories) so it can be consumed without pulling the whole tool dependency +graph — and so connector packages stay independently deletable. + +The single live consumer is the ``GET /agent/tools`` endpoint, which renders +the tool picker in the web UI. 
Tool *construction* lives elsewhere: + +* main-agent tools -> ``app.agents.chat.multi_agent_chat.main_agent.tools.registry`` +* subagent / connector tools -> ``app.agents.chat.multi_agent_chat.subagents.*`` +""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class ToolMetadata: + """Display metadata for a single built-in tool. + + Attributes: + name: Unique identifier for the tool. + description: Human-readable description of what the tool does. + enabled_by_default: Whether the tool is on when no explicit config + is provided. + hidden: WIP tools that should be excluded from public listings. + + """ + + name: str + description: str + enabled_by_default: bool = True + hidden: bool = False + + +# Catalog of all built-in tools. Contributors: add new tools here so they show +# up in the UI tool picker. This list carries metadata only — wire the actual +# implementation in the relevant builder/registry module. +TOOL_CATALOG: list[ToolMetadata] = [ + ToolMetadata(name="generate_podcast", description="Generate an audio podcast from provided content"), + ToolMetadata(name="generate_video_presentation", description="Generate a video presentation with slides and narration from provided content"), + ToolMetadata(name="generate_report", description="Generate a structured report from provided content and export it"), + ToolMetadata(name="generate_resume", description="Generate a professional resume as a Typst document"), + ToolMetadata(name="generate_image", description="Generate images from text descriptions using AI image models"), + ToolMetadata(name="scrape_webpage", description="Scrape and extract the main content from a webpage"), + ToolMetadata(name="web_search", description="Search the web for real-time information using configured search engines"), + ToolMetadata(name="create_automation", description="Draft an automation from an NL intent; user approves the card; tool saves"), + 
ToolMetadata(name="update_memory", description="Save important long-term facts, preferences, and instructions to the (personal or team) memory"), + ToolMetadata(name="create_notion_page", description="Create a new page in the user's Notion workspace"), + ToolMetadata(name="update_notion_page", description="Append new content to an existing Notion page"), + ToolMetadata(name="delete_notion_page", description="Delete an existing Notion page"), + ToolMetadata(name="create_google_drive_file", description="Create a new Google Doc or Google Sheet in Google Drive"), + ToolMetadata(name="delete_google_drive_file", description="Move an indexed Google Drive file to trash"), + ToolMetadata(name="create_dropbox_file", description="Create a new file in Dropbox"), + ToolMetadata(name="delete_dropbox_file", description="Delete a file from Dropbox"), + ToolMetadata(name="create_onedrive_file", description="Create a new file in Microsoft OneDrive"), + ToolMetadata(name="delete_onedrive_file", description="Move a OneDrive file to the recycle bin"), + ToolMetadata(name="search_calendar_events", description="Search Google Calendar events within a date range"), + ToolMetadata(name="create_calendar_event", description="Create a new event on Google Calendar"), + ToolMetadata(name="update_calendar_event", description="Update an existing indexed Google Calendar event"), + ToolMetadata(name="delete_calendar_event", description="Delete an existing indexed Google Calendar event"), + ToolMetadata(name="search_gmail", description="Search emails in Gmail using Gmail search syntax"), + ToolMetadata(name="read_gmail_email", description="Read the full content of a specific Gmail email"), + ToolMetadata(name="create_gmail_draft", description="Create a draft email in Gmail"), + ToolMetadata(name="send_gmail_email", description="Send an email via Gmail"), + ToolMetadata(name="trash_gmail_email", description="Move an indexed email to trash in Gmail"), + ToolMetadata(name="update_gmail_draft", 
description="Update an existing Gmail draft"), + ToolMetadata(name="create_confluence_page", description="Create a new page in the user's Confluence space"), + ToolMetadata(name="update_confluence_page", description="Update an existing indexed Confluence page"), + ToolMetadata(name="delete_confluence_page", description="Delete an existing indexed Confluence page"), + ToolMetadata(name="list_discord_channels", description="List text channels in the connected Discord server"), + ToolMetadata(name="read_discord_messages", description="Read recent messages from a Discord text channel"), + ToolMetadata(name="send_discord_message", description="Send a message to a Discord text channel"), + ToolMetadata(name="list_teams_channels", description="List Microsoft Teams and their channels"), + ToolMetadata(name="read_teams_messages", description="Read recent messages from a Microsoft Teams channel"), + ToolMetadata(name="send_teams_message", description="Send a message to a Microsoft Teams channel"), + ToolMetadata(name="list_luma_events", description="List upcoming and recent Luma events"), + ToolMetadata(name="read_luma_event", description="Read detailed information about a specific Luma event"), + ToolMetadata(name="create_luma_event", description="Create a new event on Luma"), +] diff --git a/surfsense_backend/app/agents/new_chat/tools/hitl.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/hitl.py similarity index 81% rename from surfsense_backend/app/agents/new_chat/tools/hitl.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/hitl.py index 5b64929de..9b16e1a4c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/hitl.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/hitl.py @@ -6,7 +6,7 @@ shared by every sensitive tool (native connectors and MCP tools alike). 
Usage inside a tool:: - from app.agents.new_chat.tools.hitl import request_approval + from app.agents.chat.multi_agent_chat.shared.tools.hitl import request_approval result = request_approval( action_type="gmail_email_send", @@ -30,22 +30,11 @@ from langgraph.types import interrupt logger = logging.getLogger(__name__) -# Tools that mirror the safety profile of ``write_file`` against the -# SurfSense KB: each call creates ONE artifact in the user's own workspace -# with no external visibility (drafts aren't sent; new files aren't shared -# unless the user shares them later). These are auto-approved by default -# so the agent can compose drafts and seed scratch files without a popup -# on every call. -# -# Members of this set still call ``request_approval`` exactly as before; -# the function returns immediately with ``decision_type="auto_approved"`` -# and the original params untouched. This preserves the call-site shape -# (logging, metadata fetching, account fallbacks) so the only behavior -# change is "no interrupt fires". -# -# To re-enable prompting, the future per-search-space rules table -# (``agent_permission_rules``) takes precedence — see the ``# (future)`` -# layer-3 comment in :mod:`app.agents.new_chat.chat_deepagent`. +# Low-stakes creation tools auto-approved by default: each creates one +# artifact in the user's own workspace with no external visibility (drafts +# aren't sent; new files aren't shared). They still call ``request_approval``, +# which returns ``decision_type="auto_approved"`` without firing an interrupt. +# Per-search-space ``agent_permission_rules`` can re-enable prompting. 
DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset( { "create_gmail_draft", @@ -150,10 +139,6 @@ def request_approval( return HITLResult(rejected=False, decision_type="trusted", params=dict(params)) if tool_name in DEFAULT_AUTO_APPROVED_TOOLS: - # Default policy: low-stakes creation tools (drafts + new-file - # creates) skip HITL because they're as recoverable as a local - # ``write_file`` against the SurfSense KB. The user can still - # delete the artifact in <30s if it's wrong. logger.info( "Tool '%s' is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL", tool_name, diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/mcp/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/mcp/__init__.py new file mode 100644 index 000000000..07a5b02de --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/mcp/__init__.py @@ -0,0 +1,7 @@ +"""MCP (Model Context Protocol) integration: client, tool loading, and cache. + +Split by responsibility: +- ``client``: the low-level :class:`MCPClient` connection wrapper. +- ``tool``: discovery + LangChain tool construction and cache invalidation. +- ``cache``: the connector tool-cache refresh helpers. +""" diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tools_cache.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/mcp/cache.py similarity index 94% rename from surfsense_backend/app/agents/new_chat/tools/mcp_tools_cache.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/mcp/cache.py index 81027e1c4..d088fac0b 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tools_cache.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/mcp/cache.py @@ -112,7 +112,9 @@ def refresh_mcp_tools_cache_for_connector( when an event loop is available. Neither path raises. 
""" try: - from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache + from app.agents.chat.multi_agent_chat.shared.tools.mcp.tool import ( + invalidate_mcp_tools_cache, + ) invalidate_mcp_tools_cache(search_space_id) except Exception: @@ -133,7 +135,9 @@ def refresh_mcp_tools_cache_for_connector( async def _run_connector_prefetch(connector_id: int) -> None: - from app.agents.new_chat.tools.mcp_tool import discover_single_mcp_connector + from app.agents.chat.multi_agent_chat.shared.tools.mcp.tool import ( + discover_single_mcp_connector, + ) try: await discover_single_mcp_connector(connector_id) diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_client.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/mcp/client.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/tools/mcp_client.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/mcp/client.py diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/mcp/tool.py similarity index 99% rename from surfsense_backend/app/agents/new_chat/tools/mcp_tool.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/mcp/tool.py index 6c4cfb6be..a1240391b 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/mcp/tool.py @@ -33,14 +33,16 @@ from sqlalchemy import cast, select from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args -from app.agents.new_chat.tools.hitl import request_approval -from app.agents.new_chat.tools.mcp_client import MCPClient -from app.agents.new_chat.tools.mcp_tools_cache import ( +from app.agents.chat.multi_agent_chat.shared.middleware.dedup_tool_calls import ( + dedup_key_full_args, +) +from 
app.agents.chat.multi_agent_chat.shared.tools.hitl import request_approval +from app.agents.chat.multi_agent_chat.shared.tools.mcp.cache import ( CachedMCPTools, read_cached_tools, write_cached_tools, ) +from app.agents.chat.multi_agent_chat.shared.tools.mcp.client import MCPClient from app.db import SearchSourceConnector from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type from app.utils.perf import get_perf_logger diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/agent.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/agent.py similarity index 81% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/agent.py rename to 
surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/agent.py index 396e0ec79..b483b8578 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/agent.py @@ -12,9 +12,13 @@ from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) from .tools.index import NAME, RULESET, load_tools diff --git a/surfsense_backend/app/agents/shared/deliverable_wait.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/deliverable_wait.py similarity index 92% rename from surfsense_backend/app/agents/shared/deliverable_wait.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/deliverable_wait.py index abaa017ea..2fcc98385 100644 --- a/surfsense_backend/app/agents/shared/deliverable_wait.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/deliverable_wait.py @@ -1,10 +1,10 @@ """Shared poll-until-terminal helper for Celery-backed deliverables. 
-Lives in ``app.agents.shared`` (neutral package, no dependencies on either -``new_chat`` or ``multi_agent_chat``) so both the flat single-agent tools -under ``app/agents/new_chat/tools/`` and the multi-agent subagent tools -under ``app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/`` -can import it without creating a circular dependency. +Lives in ``app.agents.chat.multi_agent_chat.subagents.builtins.deliverables`` +(the deliverables subagent package) so the subagent tools under +``app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/`` +(e.g. ``podcast.py``) can import it without creating a circular +dependency. Background ---------- diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/system_prompt.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/system_prompt.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/system_prompt.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/system_prompt.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/__init__.py rename to
surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py similarity index 81% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py index 094371760..7bb4a7c24 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py @@ -11,8 +11,8 @@ from litellm import aimage_generation from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.shared.receipt import make_receipt -from app.agents.shared.receipt_command import with_receipt +from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt +from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt from app.config import config from app.db import ( ImageGeneration, @@ -25,6 +25,7 @@ from app.services.image_gen_router_service import ( ImageGenRouterService, is_image_gen_auto_mode, ) +from app.services.provider_api_base import resolve_api_base from app.utils.signed_image_urls import generate_image_token logger = logging.getLogger(__name__) @@ -43,13 +44,16 @@ _PROVIDER_MAP = { } +def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: + if custom_provider: + return custom_provider + return _PROVIDER_MAP.get(provider.upper(), provider.lower()) + + def _build_model_string( provider: str, model_name: str, custom_provider: str | None ) -> str: - if custom_provider: - return f"{custom_provider}/{model_name}" - prefix = 
_PROVIDER_MAP.get(provider.upper(), provider.lower()) - return f"{prefix}/{model_name}" + return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}" def _get_global_image_gen_config(config_id: int) -> dict | None: @@ -71,7 +75,7 @@ def create_generate_image_tool( captured model), use this config id instead of reading the search space's live ``image_generation_config_id``. """ - del db_session # use a fresh per-call session, see below + del db_session # tool uses a fresh per-call session instead @tool async def generate_image( @@ -136,17 +140,12 @@ def create_generate_image_tool( or IMAGE_GEN_AUTO_MODE_ID ) - # Build generation kwargs - # NOTE: size, quality, and style are intentionally NOT passed. - # Different models support different values for these params - # (e.g. DALL-E 3 wants "hd"/"standard" for quality while - # gpt-image-1 wants "high"/"medium"/"low"; size options also - # differ). Letting the model use its own defaults avoids errors. + # size/quality/style are intentionally omitted: valid values + # differ per model, so we let each model use its own defaults. gen_kwargs: dict[str, Any] = {} if n is not None and n > 1: gen_kwargs["n"] = n - # Call litellm based on config type if is_image_gen_auto_mode(config_id): if not ImageGenRouterService.is_initialized(): err = ( @@ -163,14 +162,20 @@ def create_generate_image_tool( err = f"Image generation config {config_id} not found" return _failed({"error": err}, error=err) - model_string = _build_model_string( - cfg.get("provider", ""), - cfg["model_name"], - cfg.get("custom_provider"), + provider_prefix = _resolve_provider_prefix( + cfg.get("provider", ""), cfg.get("custom_provider") ) + model_string = f"{provider_prefix}/{cfg['model_name']}" gen_kwargs["api_key"] = cfg.get("api_key") - if cfg.get("api_base"): - gen_kwargs["api_base"] = cfg["api_base"] + # Defense-in-depth: an empty ``api_base`` must not fall + # through to LiteLLM's global ``api_base`` (e.g. Azure). 
+ api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=cfg.get("api_base"), + ) + if api_base: + gen_kwargs["api_base"] = api_base if cfg.get("api_version"): gen_kwargs["api_version"] = cfg["api_version"] if cfg.get("litellm_params"): @@ -191,14 +196,20 @@ def create_generate_image_tool( err = f"Image generation config {config_id} not found" return _failed({"error": err}, error=err) - model_string = _build_model_string( - db_cfg.provider.value, - db_cfg.model_name, - db_cfg.custom_provider, + provider_prefix = _resolve_provider_prefix( + db_cfg.provider.value, db_cfg.custom_provider ) + model_string = f"{provider_prefix}/{db_cfg.model_name}" gen_kwargs["api_key"] = db_cfg.api_key - if db_cfg.api_base: - gen_kwargs["api_base"] = db_cfg.api_base + # Defense-in-depth: an empty ``api_base`` must not fall + # through to LiteLLM's global ``api_base`` (e.g. Azure). + api_base = resolve_api_base( + provider=db_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=db_cfg.api_base, + ) + if api_base: + gen_kwargs["api_base"] = api_base if db_cfg.api_version: gen_kwargs["api_version"] = db_cfg.api_version if db_cfg.litellm_params: @@ -208,17 +219,13 @@ def create_generate_image_tool( prompt=prompt, model=model_string, **gen_kwargs ) - # Parse the response and store in DB response_dict = ( response.model_dump() if hasattr(response, "model_dump") else dict(response) ) - # Generate a random access token for this image access_token = generate_image_token() - - # Save to image_generations table for history db_image_gen = ImageGeneration( prompt=prompt, model=getattr(response, "_hidden_params", {}).get("model"), @@ -233,7 +240,6 @@ def create_generate_image_tool( await session.refresh(db_image_gen) db_image_gen_id = db_image_gen.id - # Extract image URLs from response images = response_dict.get("data", []) if not images: return _failed( @@ -244,11 +250,8 @@ def create_generate_image_tool( first_image = 
images[0] revised_prompt = first_image.get("revised_prompt", prompt) - # Resolve image URL: - # - If the API returned a URL, use it directly. - # - If the API returned b64_json (e.g. gpt-image-1), serve the - # image through our backend endpoint to avoid bloating the - # LLM context with megabytes of base64 data. + # b64_json (e.g. gpt-image-1) is served via our backend endpoint so + # megabytes of base64 don't bloat the LLM context. if first_image.get("url"): image_url = first_image["url"] elif first_image.get("b64_json"): diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py similarity index 96% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py index ddfcbd7fb..b968c1701 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py @@ -9,7 +9,7 @@ from typing import Any from langchain_core.tools import BaseTool -from app.agents.new_chat.permissions import Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Ruleset from .generate_image import create_generate_image_tool from .podcast import create_generate_podcast_tool diff --git a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py similarity index 87% rename from surfsense_backend/app/agents/new_chat/tools/knowledge_base.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py index c24497bfd..e99e0291a 100644 --- 
a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py @@ -241,23 +241,12 @@ def _normalize_connectors( connectors_to_search: list[str] | None, available_connectors: list[str] | None = None, ) -> list[str]: + """Normalize model-supplied connectors to canonical ConnectorService types. + + Maps user-facing aliases (e.g. WEBCRAWLER_CONNECTOR), drops unknowns, and + constrains to ``available_connectors`` when given. Empty input defaults to + all available connectors (minus live-search ones). """ - Normalize connectors provided by the model. - - - Accepts user-facing enums like WEBCRAWLER_CONNECTOR and maps them to canonical - ConnectorService types. - - Drops unknown values. - - If available_connectors is provided, only includes connectors from that list. - - If connectors_to_search is None/empty, defaults to available_connectors or all. - - Args: - connectors_to_search: List of connectors requested by the model - available_connectors: List of connectors actually available in the search space - - Returns: - List of normalized connector strings to search - """ - # Determine the set of valid connectors to consider valid_set = ( set(available_connectors) if available_connectors else set(_ALL_CONNECTORS) ) @@ -276,18 +265,16 @@ def _normalize_connectors( c = (raw or "").strip().upper() if not c: continue - # Map user-facing aliases to canonical names if c == "WEBCRAWLER_CONNECTOR": c = "CRAWLED_URL" normalized.append(c) - # de-dupe while preserving order + filter to valid connectors + # De-dupe (order-preserving), keeping only known + available connectors. 
seen: set[str] = set() out: list[str] = [] for c in normalized: if c in seen: continue - # Only include if it's a known connector AND available if c not in _ALL_CONNECTORS: continue if c not in valid_set: @@ -295,7 +282,7 @@ def _normalize_connectors( seen.add(c) out.append(c) - # Fallback to all available if nothing matched + # Nothing matched: fall back to all available. if not out: base = ( list(available_connectors) @@ -377,39 +364,17 @@ def format_documents_for_context( max_chunk_chars: int = _MAX_CHUNK_CHARS, max_chunks_per_doc: int = 0, ) -> str: - """ - Format retrieved documents into a readable context string for the LLM. + """Format retrieved documents into an XML context string for the LLM. - Documents are added in order (highest relevance first) until the character - budget is reached. Individual chunks are capped at ``max_chunk_chars`` and - each document is limited to a dynamically computed chunk cap so a single - large document cannot monopolize the output while still maximising the use - of available context space. - - Args: - documents: List of document dictionaries from connector search - max_chars: Approximate character budget for the entire output. - max_chunk_chars: Per-chunk character cap (content is tail-truncated). - max_chunks_per_doc: Maximum chunks per document. ``0`` (default) means - auto-compute per document using a rank-adaptive formula so - higher-ranked documents receive more chunks. - - Returns: - Formatted string with document contents and metadata + Documents are emitted highest-relevance first until ``max_chars`` is hit. + ``max_chunks_per_doc=0`` auto-computes a rank-adaptive cap so top results get + more chunks and no single large document monopolizes the budget. """ if not documents: return "" - # Group chunks by document id (preferred) to produce the XML structure. 
- # - # IMPORTANT: ConnectorService returns **document-grouped** results of the form: - # { - # "document": {...}, - # "chunks": [{"chunk_id": 123, "content": "..."}, ...], - # "source": "NOTION_CONNECTOR" | "FILE" | ... - # } - # - # We must preserve chunk_id so citations like [citation:123] are possible. + # Group chunks by document id, preserving chunk_id so [citation:123] works. + # ConnectorService returns document-grouped results ({document, chunks, source}). grouped: dict[str, dict[str, Any]] = {} for doc in documents: @@ -430,7 +395,7 @@ def format_documents_for_context( or "UNKNOWN" ) - # Document identity (prefer document_id; otherwise fall back to type+title+url) + # Identity: prefer document_id, else type+title+url. document_id_val = document_info.get("id") title = ( document_info.get("title") or metadata.get("title") or "Untitled Document" @@ -460,7 +425,7 @@ def format_documents_for_context( "chunks": [], } - # Prefer document-grouped chunks if available + # Prefer document-grouped chunks when present. chunks_list = doc.get("chunks") if isinstance(doc, dict) else None if isinstance(chunks_list, list) and chunks_list: for ch in chunks_list: @@ -492,7 +457,6 @@ def format_documents_for_context( "BAIDU_SEARCH_API", } - # Render XML expected by citation instructions, respecting the char budget. parts: list[str] = [] total_chars = 0 total_docs = len(grouped) @@ -594,30 +558,11 @@ async def search_knowledge_base_async( available_document_types: list[str] | None = None, max_input_tokens: int | None = None, ) -> str: - """ - Search the user's knowledge base for relevant documents. + """Search the knowledge base across connectors and return formatted results. - This is the async implementation that searches across multiple connectors. 
- - Args: - query: The search query - search_space_id: The user's search space ID - db_session: Database session - connector_service: Initialized connector service - connectors_to_search: Optional list of connector types to search. If omitted, searches all. - top_k: Number of results per connector - start_date: Optional start datetime (UTC) for filtering documents - end_date: Optional end datetime (UTC) for filtering documents - available_connectors: Optional list of connectors actually available in the search space. - If provided, only these connectors will be searched. - available_document_types: Optional list of document types that actually have indexed - data. When provided, local connectors whose document type is - absent are skipped entirely (no embedding / DB round-trip). - max_input_tokens: Model context window size (tokens). Used to dynamically - size the output so it fits within the model's limits. - - Returns: - Formatted string with search results + ``available_document_types`` lets local connectors with no indexed data be + skipped (no embedding / DB round-trip), and ``max_input_tokens`` sizes the + output to the model's context window. """ perf = get_perf_logger() t0 = time.perf_counter() @@ -692,7 +637,7 @@ async def search_knowledge_base_raw_async( # Preserve the public signature for compatibility even if values are unused. 
_ = (db_session, connector_service) - from app.agents.new_chat.utils import resolve_date_range + from app.agents.chat.multi_agent_chat.shared.date_filters import resolve_date_range resolved_start_date, resolved_end_date = resolve_date_range( start_date=start_date, diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/podcast.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/podcast.py similarity index 94% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/podcast.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/podcast.py index 298257799..03850010e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/podcast.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/podcast.py @@ -16,9 +16,11 @@ from langchain_core.tools import tool from langgraph.types import Command from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.shared.deliverable_wait import wait_for_deliverable -from app.agents.shared.receipt import make_receipt -from app.agents.shared.receipt_command import with_receipt +from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt +from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt +from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.deliverable_wait import ( + wait_for_deliverable, +) from app.db import Podcast, PodcastStatus, shielded_async_session logger = logging.getLogger(__name__) @@ -96,7 +98,7 @@ def create_generate_podcast_tool( # Wait until the Celery worker flips the row to a terminal # state. The wait is bounded only by the subagent invoke # timeout (multi-agent) or HTTP lifetime (single-agent) — - # see app.agents.shared.deliverable_wait for details. 
+ # see app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.deliverable_wait for details. terminal_status, columns, elapsed = await wait_for_deliverable( model=Podcast, row_id=podcast_id, diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/report.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/report.py similarity index 90% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/report.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/report.py index 39b7c4694..d9874638c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/report.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/report.py @@ -12,8 +12,8 @@ from langchain_core.messages import HumanMessage from langchain_core.tools import tool from langgraph.types import Command -from app.agents.shared.receipt import make_receipt -from app.agents.shared.receipt_command import with_receipt +from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt +from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt from app.db import Report, shielded_async_session from app.services.connector_service import ConnectorService from app.services.llm_service import get_agent_llm @@ -196,13 +196,8 @@ def _strip_wrapping_code_fences(text: str) -> str: def _extract_metadata(content: str) -> dict[str, Any]: """Extract metadata from generated Markdown content.""" - # Count section headings headings = re.findall(r"^(#{1,6})\s+(.+)$", content, re.MULTILINE) - - # Word count word_count = len(content.split()) - - # Character count char_count = len(content) return { @@ -227,12 +222,11 @@ def _parse_sections(content: str) -> list[dict[str, str]]: in_code_block = False for line in lines: - # Track code blocks to 
avoid matching headings inside them + # Track fences so headings inside code blocks aren't treated as splits. stripped = line.strip() if stripped.startswith("```"): in_code_block = not in_code_block - # Only split on # or ## headings (not ### or deeper) and only outside code blocks is_section_heading = ( not in_code_block and re.match(r"^#{1,2}\s+", line) @@ -240,7 +234,6 @@ def _parse_sections(content: str) -> list[dict[str, str]]: ) if is_section_heading: - # Save previous section if current_heading or current_body_lines: sections.append( { @@ -253,7 +246,6 @@ def _parse_sections(content: str) -> list[dict[str, str]]: else: current_body_lines.append(line) - # Save last section if current_heading or current_body_lines: sections.append( { @@ -292,7 +284,6 @@ async def _revise_with_sections( Unchanged sections are kept byte-for-byte identical. Returns the revised content, or None to trigger full-document revision fallback. """ - # Parse report into sections sections = _parse_sections(parent_content) if len(sections) < 2: logger.info( @@ -300,7 +291,6 @@ async def _revise_with_sections( ) return None - # Build a sections listing for the LLM sections_listing = "" for i, sec in enumerate(sections): heading = sec["heading"] or "(preamble — content before first heading)" @@ -352,11 +342,9 @@ async def _revise_with_sections( ) return None - # Compute total operations for progress tracking total_ops = len(modify_indices) + len(add_sections) current_op = 0 - # Emit plan summary parts = [] if modify_indices: parts.append( @@ -394,7 +382,6 @@ async def _revise_with_sections( current_op += 1 sec = sections[idx] - # Extract plain section name (strip markdown heading markers) section_name = ( re.sub(r"^#+\s*", "", sec["heading"]).strip() if sec["heading"] @@ -412,7 +399,6 @@ async def _revise_with_sections( f"{sec['heading']}\n\n{sec['body']}" if sec["heading"] else sec["body"] ) - # Build context from surrounding sections context_parts = [] if idx > 0: prev = sections[idx - 1] 
@@ -442,7 +428,6 @@ async def _revise_with_sections( revised_text = resp.content if revised_text and isinstance(revised_text, str): revised_text = _strip_wrapping_code_fences(revised_text).strip() - # Parse the LLM output back into heading + body revised_parsed = _parse_sections(revised_text) if revised_parsed: revised_sections[idx] = revised_parsed[0] @@ -465,7 +450,6 @@ async def _revise_with_sections( heading = add_info.get("heading", "## New Section") description = add_info.get("description", "") - # Extract plain section name for progress display plain_heading = re.sub(r"^#+\s*", "", heading).strip() dispatch_custom_event( "report_progress", @@ -475,7 +459,6 @@ async def _revise_with_sections( }, ) - # Build context from the surrounding sections at the insertion point ctx_parts = [] if 0 <= after_idx < len(revised_sections): before_sec = revised_sections[after_idx] @@ -542,36 +525,13 @@ def create_generate_report_tool( available_connectors: list[str] | None = None, available_document_types: list[str] | None = None, ): - """ - Factory function to create the generate_report tool with injected dependencies. + """Create the generate_report tool with injected dependencies. - The tool generates a Markdown report inline using the search space's - agent LLM, saves it to the database, and returns immediately. - - Uses short-lived database sessions for each DB operation so no connection - is held during the long LLM API call. 
- - Generation strategies: - - New reports: single-shot generation (1 LLM call) - - Revisions (targeted edits): section-level (unchanged sections preserved) - - Revisions (global changes): full-document revision fallback - - Source strategies: - - "provided"/"conversation": use only the supplied source_content - - "kb_search": search the knowledge base internally using targeted queries - - "auto": use source_content if sufficient, otherwise fall back to KB search - - Args: - search_space_id: The user's search space ID - thread_id: The chat thread ID for associating the report - connector_service: Optional connector service for internal KB search. - When provided, the tool can search the knowledge base internally - (used by the "kb_search" and "auto" source strategies). - available_connectors: Optional list of connector types available in the - search space (used to scope internal KB searches). - - Returns: - A configured tool function for generating reports + Uses short-lived DB sessions per operation so no connection is held during + the long LLM call. Generation: new reports are single-shot; revisions try + section-level first (unchanged sections preserved) and fall back to full-doc. + Source strategies: provided/conversation (use source_content), kb_search + (internal KB queries), auto (KB search only when source_content is thin). """ @tool @@ -693,7 +653,7 @@ def create_generate_report_tool( Returns: Dict with status, report_id, title, word_count, and message. """ - # Initialize version tracking variables (used by _save_failed_report closure) + # Shared with the _save_failed_report closure. parent_report_content: str | None = None report_group_id: int | None = None @@ -733,7 +693,7 @@ def create_generate_report_tool( session.add(failed_report) await session.commit() await session.refresh(failed_report) - # If this is a new group (v1 failed), set group to self + # New group (v1 failed): point the group at itself. 
if not failed_report.report_group_id: failed_report.report_group_id = failed_report.id await session.commit() @@ -749,8 +709,8 @@ def create_generate_report_tool( try: # ── Phase 1: READ (short-lived session) ────────────────────── - # Fetch parent report and LLM config, then close the session - # so no DB connection is held during the long LLM call. + # Fetch parent report + LLM config, then release the connection + # before the long LLM call. async with shielded_async_session() as read_session: if parent_report_id: parent_report = await read_session.get(Report, parent_report_id) @@ -768,7 +728,6 @@ def create_generate_report_tool( ) llm = await get_agent_llm(read_session, search_space_id) - # read_session closed — connection returned to pool if not llm: error_msg = ( @@ -785,7 +744,6 @@ def create_generate_report_tool( error=error_msg, ) - # Build the user instructions string user_instructions_section = "" if user_instructions: user_instructions_section = ( @@ -829,7 +787,7 @@ def create_generate_report_tool( try: from .knowledge_base import search_knowledge_base_async - # Run all queries in parallel, each with its own session + # Each query gets its own short-lived session. async def _run_single_query(q: str) -> str: async with shielded_async_session() as kb_session: kb_connector_svc = ConnectorService( @@ -849,7 +807,6 @@ def create_generate_report_tool( *[_run_single_query(q) for q in search_queries[:5]] ) - # Merge non-empty results into source_content kb_text_parts = [r for r in kb_results if r and r.strip()] if kb_text_parts: kb_combined = "\n\n---\n\n".join(kb_text_parts) @@ -903,9 +860,9 @@ def create_generate_report_tool( "provided. Using source_content as-is." ) - capped_source = effective_source[:100000] # Cap source content + capped_source = effective_source[:100000] - # Length constraint — only when user explicitly asks for brevity + # Length constraint only when the user explicitly asked for brevity. 
length_instruction = "" if report_style == "brief": length_instruction = ( @@ -920,11 +877,8 @@ def create_generate_report_tool( report_content: str | None = None if parent_report_content: - # ─── REVISION MODE ─────────────────────────────────────── - # Strategy: Try section-level revision first (preserves - # unchanged sections byte-for-byte). Falls back to full- - # document revision if section identification fails or if - # all sections need changes. + # Revision mode: section-level first (preserves untouched + # sections), falling back to full-doc revision. dispatch_custom_event( "report_progress", { @@ -946,7 +900,6 @@ def create_generate_report_tool( ) if report_content is None: - # Fallback: full-document revision dispatch_custom_event( "report_progress", {"phase": "writing", "message": "Rewriting your full report"}, @@ -969,9 +922,7 @@ def create_generate_report_tool( report_content = response.content else: - # ─── NEW REPORT MODE ───────────────────────────────────── - # Single-shot generation: one LLM call produces the full - # report. Fast, globally coherent, and cost-efficient. + # New report: single-shot generation (one LLM call). dispatch_custom_event( "report_progress", {"phase": "writing", "message": "Writing your report"}, @@ -991,8 +942,6 @@ def create_generate_report_tool( response = await llm.ainvoke([HumanMessage(content=prompt)]) report_content = response.content - # ── Validate LLM output ────────────────────────────────────── - if not report_content or not isinstance(report_content, str): error_msg = "LLM returned empty or invalid content" report_id = await _save_failed_report(error_msg) @@ -1029,14 +978,12 @@ def create_generate_report_tool( if report_content.rstrip().endswith("---"): report_content = report_content.rstrip()[:-3].rstrip() - # Append exactly one standard disclaimer + # Append exactly one standard footer. 
report_content += "\n\n---\n\n" + _REPORT_FOOTER - # Extract metadata (includes "status": "ready") metadata = _extract_metadata(report_content) # ── Phase 3: WRITE (short-lived session) ───────────────────── - # Save the report to the database, then close the session. async with shielded_async_session() as write_session: report = Report( title=topic, @@ -1051,14 +998,13 @@ def create_generate_report_tool( await write_session.commit() await write_session.refresh(report) - # If this is a brand-new report (v1), set report_group_id = own id + # Brand-new report (v1): point the group at itself. if not report.report_group_id: report.report_group_id = report.id await write_session.commit() saved_report_id = report.id saved_group_id = report.report_group_id - # write_session closed — connection returned to pool logger.info( f"[generate_report] Created report {saved_report_id} " diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/resume.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/resume.py similarity index 99% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/resume.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/resume.py index 86332ccbe..f4697b835 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/resume.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/resume.py @@ -14,8 +14,8 @@ from langchain_core.messages import HumanMessage from langchain_core.tools import tool from langgraph.types import Command -from app.agents.shared.receipt import make_receipt -from app.agents.shared.receipt_command import with_receipt +from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt +from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt from app.db 
import Report, shielded_async_session from app.services.llm_service import get_agent_llm diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/video_presentation.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/video_presentation.py similarity index 94% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/video_presentation.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/video_presentation.py index 5407c8834..5c71ebf33 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/video_presentation.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/video_presentation.py @@ -17,9 +17,11 @@ from langchain_core.tools import tool from langgraph.types import Command from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.shared.deliverable_wait import wait_for_deliverable -from app.agents.shared.receipt import make_receipt -from app.agents.shared.receipt_command import with_receipt +from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt +from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt +from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.deliverable_wait import ( + wait_for_deliverable, +) from app.db import VideoPresentation, VideoPresentationStatus, shielded_async_session logger = logging.getLogger(__name__) @@ -83,7 +85,7 @@ def create_generate_video_presentation_tool( # Wait until the Celery worker flips the row to a terminal # state. The wait is bounded only by the subagent invoke # timeout (multi-agent) or HTTP lifetime (single-agent) — - # see app.agents.shared.deliverable_wait for details. + # see app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.deliverable_wait for details. 
terminal_status, _columns, elapsed = await wait_for_deliverable( model=VideoPresentation, row_id=video_pres_id, diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/agent.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/agent.py similarity index 92% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/agent.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/agent.py index c6a0220ec..2720589ef 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/agent.py @@ -13,9 +13,9 @@ from deepagents import SubAgent from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.permissions import Rule, Ruleset +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.permissions import Rule, Ruleset +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec from .middleware_stack import build_kb_middleware from .prompts import load_description, load_readonly_system_prompt, load_system_prompt diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/ask_knowledge_base_tool.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/ask_knowledge_base_tool.py similarity index 92% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/ask_knowledge_base_tool.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/ask_knowledge_base_tool.py index 1708fe52f..321477e11 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/ask_knowledge_base_tool.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/ask_knowledge_base_tool.py @@ -10,11 +10,9 @@ from langchain_core.runnables import Runnable from langchain_core.tools import StructuredTool from langgraph.types import Command -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import ( - subagent_invoke_config, -) -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.constants import ( +from app.agents.chat.multi_agent_chat.subagents.shared.invocation import ( EXCLUDED_STATE_KEYS, + subagent_invoke_config, ) from .prompts import load_readonly_description diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/description_readonly.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/description_readonly.md 
similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/description_readonly.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/description_readonly.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/middleware_stack.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/middleware_stack.py similarity index 84% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/middleware_stack.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/middleware_stack.py index 778bb250c..1407a4d65 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/middleware_stack.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/middleware_stack.py @@ -10,27 +10,27 @@ from typing import Any from langchain_core.language_models import BaseChatModel -from app.agents.multi_agent_chat.middleware.shared.anthropic_cache import ( +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.middleware.anthropic_cache import ( build_anthropic_cache_mw, ) -from app.agents.multi_agent_chat.middleware.shared.compaction import ( +from app.agents.chat.multi_agent_chat.shared.middleware.compaction import ( build_compaction_mw, ) -from app.agents.multi_agent_chat.middleware.shared.filesystem import ( +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem import ( build_filesystem_mw, ) -from app.agents.multi_agent_chat.middleware.shared.kb_context_projection import ( +from app.agents.chat.multi_agent_chat.shared.middleware.kb_context_projection import ( build_kb_context_projection_mw, ) -from 
app.agents.multi_agent_chat.middleware.shared.patch_tool_calls import ( +from app.agents.chat.multi_agent_chat.shared.middleware.patch_tool_calls import ( build_patch_tool_calls_mw, ) -from app.agents.multi_agent_chat.middleware.shared.permissions import ( +from app.agents.chat.multi_agent_chat.shared.permissions import ( + Ruleset, build_permission_mw, ) -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.permissions import Ruleset def _kb_user_allowlist( diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/prompts.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/prompts.py similarity index 83% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/prompts.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/prompts.py index 617bb2a85..ea9ae4706 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/prompts.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/prompts.py @@ -2,8 +2,10 @@ from __future__ import annotations -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) def load_system_prompt(filesystem_mode: FilesystemMode) -> str: diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_cloud.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_cloud.md similarity index 100% rename from 
surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_cloud.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_cloud.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_desktop.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_desktop.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_desktop.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_desktop.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_readonly_cloud.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_readonly_cloud.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_readonly_cloud.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_readonly_cloud.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_readonly_desktop.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_readonly_desktop.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_readonly_desktop.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_readonly_desktop.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/tools/__init__.py similarity index 100% rename 
from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/tools/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/tools/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/tools/index.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/knowledge_base/tools/index.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/memory/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/memory/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/agent.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/memory/agent.py similarity index 78% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/agent.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/memory/agent.py index 84ab0c2fb..4038b13de 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/memory/agent.py @@ -7,9 +7,13 @@ from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from 
app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) from .tools.index import NAME, RULESET, load_tools diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/memory/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/memory/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/system_prompt.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/memory/system_prompt.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/system_prompt.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/memory/system_prompt.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/memory/tools/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/memory/tools/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/memory/tools/index.py similarity index 92% rename from 
surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/memory/tools/index.py index b6e06dcdd..0afce9dec 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/memory/tools/index.py @@ -6,7 +6,7 @@ from typing import Any from langchain_core.tools import BaseTool -from app.agents.new_chat.permissions import Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Ruleset from app.db import ChatVisibility from .update_memory import create_update_memory_tool, create_update_team_memory_tool diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/update_memory.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/memory/tools/update_memory.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/update_memory.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/memory/tools/update_memory.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/agent.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/agent.py similarity index 78% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/agent.py rename to 
surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/agent.py index 37026bebd..9a694872b 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/agent.py @@ -7,9 +7,13 @@ from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) from .tools.index import NAME, RULESET, load_tools diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/system_prompt.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/system_prompt.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/system_prompt.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/system_prompt.md diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/tools/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/tools/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/tools/index.py similarity index 91% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/tools/index.py index d8abce46c..1e823fafa 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/tools/index.py @@ -6,7 +6,7 @@ from typing import Any from langchain_core.tools import BaseTool -from app.agents.new_chat.permissions import Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Ruleset from .scrape_webpage import create_scrape_webpage_tool from .web_search import create_web_search_tool diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/scrape_webpage.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/tools/scrape_webpage.py similarity index 91% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/scrape_webpage.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/tools/scrape_webpage.py index bb7c8e5a3..f4f109761 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/scrape_webpage.py +++ 
b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/tools/scrape_webpage.py @@ -23,7 +23,6 @@ def extract_domain(url: str) -> str: try: parsed = urlparse(url) domain = parsed.netloc - # Remove 'www.' prefix if present if domain.startswith("www."): domain = domain[4:] return domain @@ -47,14 +46,13 @@ def truncate_content(content: str, max_length: int = 50000) -> tuple[str, bool]: if len(content) <= max_length: return content, False - # Try to truncate at a sentence boundary + # Prefer truncating at a sentence/paragraph boundary. truncated = content[:max_length] last_period = truncated.rfind(".") last_newline = truncated.rfind("\n\n") - # Use the later of the two boundaries, or just truncate boundary = max(last_period, last_newline) - if boundary > max_length * 0.8: # Only use boundary if it's not too far back + if boundary > max_length * 0.8: # only if the boundary isn't too far back truncated = content[: boundary + 1] return truncated + "\n\n[Content truncated...]", True @@ -105,8 +103,8 @@ async def _scrape_youtube_video( http_client.proxies.update(residential_proxies) ytt_api = YouTubeTranscriptApi(http_client=http_client) - # List all available transcripts and pick the first one - # (the video's primary language) instead of defaulting to English + # Pick the first transcript (video's primary language) rather than + # defaulting to English. transcript_list = ytt_api.list(video_id) transcript = next(iter(transcript_list)) captions = transcript.fetch() @@ -128,10 +126,8 @@ async def _scrape_youtube_video( logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}") transcript_text = f"No captions available for this video. 
Error: {e!s}" - # Build combined content content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}" - # Truncate if needed content, was_truncated = truncate_content(content, max_length) word_count = len(content.split()) @@ -206,20 +202,16 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None): scrape_id = generate_scrape_id(url) domain = extract_domain(url) - # Validate and normalize URL if not url.startswith(("http://", "https://")): url = f"https://{url}" try: - # Check if this is a YouTube URL and use transcript API instead + # YouTube URLs use the transcript API instead of crawling. video_id = get_youtube_video_id(url) if video_id: return await _scrape_youtube_video(url, video_id, max_length) - # Create webcrawler connector connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key) - - # Crawl the URL result, error = await connector.crawl_url(url, formats=["markdown"]) if error: @@ -244,28 +236,21 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None): "error": "No content returned from crawler", } - # Extract content and metadata content = result.get("content", "") metadata = result.get("metadata", {}) - # Get title from metadata title = metadata.get("title", "") if not title: title = domain or url.split("/")[-1] or "Webpage" - # Get description from metadata description = metadata.get("description", "") if not description and content: - # Use first paragraph as description first_para = content.split("\n\n")[0] if content else "" description = ( first_para[:300] + "..." 
if len(first_para) > 300 else first_para ) - # Truncate content if needed content, was_truncated = truncate_content(content, max_length) - - # Calculate word count word_count = len(content.split()) return { diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/web_search.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/tools/web_search.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/web_search.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/tools/web_search.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/airtable/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/airtable/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/agent.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/airtable/agent.py similarity index 80% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/agent.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/airtable/agent.py index d7648d407..87391371a 100644 --- 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/airtable/agent.py @@ -12,9 +12,13 @@ from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) from .tools.index import NAME, RULESET diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/airtable/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/airtable/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/system_prompt.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/airtable/system_prompt.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/system_prompt.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/airtable/system_prompt.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/airtable/tools/__init__.py 
similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/tools/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/airtable/tools/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/airtable/tools/index.py similarity index 90% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/airtable/tools/index.py index 9eebd2395..52cc8be2d 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/airtable/tools/index.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.permissions import Rule, Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Rule, Ruleset NAME = "airtable" diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/agent.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/agent.py similarity index 81% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/agent.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/agent.py index 7ef706c3d..b9b7b553a 100644 --- 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/agent.py @@ -12,9 +12,13 @@ from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) from .tools.index import NAME, RULESET, load_tools diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/system_prompt.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/system_prompt.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/system_prompt.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/system_prompt.md diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/tools/__init__.py 
b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/tools/__init__.py new file mode 100644 index 000000000..717199fef --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/tools/__init__.py @@ -0,0 +1,11 @@ +from .create_event import create_create_calendar_event_tool +from .delete_event import create_delete_calendar_event_tool +from .search_events import create_search_calendar_events_tool +from .update_event import create_update_calendar_event_tool + +__all__ = [ + "create_create_calendar_event_tool", + "create_delete_calendar_event_tool", + "create_search_calendar_events_tool", + "create_update_calendar_event_tool", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/create_event.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/tools/create_event.py similarity index 99% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/create_event.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/tools/create_event.py index e5262bd43..91a50b3cc 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/create_event.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/tools/create_event.py @@ -8,7 +8,7 @@ from googleapiclient.discovery import build from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) from app.services.google_calendar import GoogleCalendarToolMetadataService diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/delete_event.py 
b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/tools/delete_event.py similarity index 99% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/delete_event.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/tools/delete_event.py index 2f907e746..7682dae33 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/delete_event.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/tools/delete_event.py @@ -8,7 +8,7 @@ from googleapiclient.discovery import build from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) from app.services.google_calendar import GoogleCalendarToolMetadataService diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/tools/index.py similarity index 94% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/tools/index.py index 2570a51b2..b087105d4 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/tools/index.py @@ -10,7 +10,7 @@ from typing import Any from langchain_core.tools import BaseTool -from app.agents.new_chat.permissions import Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Ruleset from .create_event import create_create_calendar_event_tool from .delete_event import 
create_delete_calendar_event_tool diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/search_events.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/tools/search_events.py similarity index 98% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/search_events.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/tools/search_events.py index 6772d5a1e..cf9a015cf 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/search_events.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/tools/search_events.py @@ -5,7 +5,9 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.agents.new_chat.tools.gmail.search_emails import _build_credentials +from app.agents.chat.multi_agent_chat.subagents.connectors.google_auth import ( + build_credentials as _build_credentials, +) from app.db import SearchSourceConnector, SearchSourceConnectorType logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/update_event.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/tools/update_event.py similarity index 99% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/update_event.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/tools/update_event.py index e6f9f098e..78d3b147b 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/update_event.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/calendar/tools/update_event.py @@ -8,7 +8,7 @@ from googleapiclient.discovery import build from langchain_core.tools import 
tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) from app.services.google_calendar import GoogleCalendarToolMetadataService diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/clickup/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/clickup/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/agent.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/clickup/agent.py similarity index 80% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/agent.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/clickup/agent.py index e1308a100..dd6ea6503 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/clickup/agent.py @@ -12,9 +12,13 @@ from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from 
app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) from .tools.index import NAME, RULESET diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/clickup/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/clickup/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/system_prompt.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/clickup/system_prompt.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/system_prompt.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/clickup/system_prompt.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/clickup/tools/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/tools/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/clickup/tools/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/clickup/tools/index.py similarity index 90% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/clickup/tools/index.py index b2c523080..c64da647a 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/tools/index.py 
+++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/clickup/tools/index.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.permissions import Rule, Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Rule, Ruleset NAME = "clickup" diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/agent.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/agent.py similarity index 81% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/agent.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/agent.py index 5e95c876d..8322d901b 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/agent.py @@ -12,9 +12,13 @@ from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from 
app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) from .tools.index import NAME, RULESET, load_tools diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/system_prompt.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/system_prompt.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/system_prompt.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/system_prompt.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/tools/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/tools/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/create_page.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/tools/create_page.py similarity index 98% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/create_page.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/tools/create_page.py index f33dc8e23..17497eee2 100644 --- 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/create_page.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/tools/create_page.py @@ -5,7 +5,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) from app.connectors.confluence_history import ConfluenceHistoryConnector diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/delete_page.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/tools/delete_page.py similarity index 98% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/delete_page.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/tools/delete_page.py index 7a3a4f2c7..5e2bd9868 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/delete_page.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/tools/delete_page.py @@ -5,7 +5,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) from app.connectors.confluence_history import ConfluenceHistoryConnector diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/tools/index.py similarity index 93% 
rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/tools/index.py index b38503c5c..73350974e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/tools/index.py @@ -9,7 +9,7 @@ from typing import Any from langchain_core.tools import BaseTool -from app.agents.new_chat.permissions import Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Ruleset from .create_page import create_create_confluence_page_tool from .delete_page import create_delete_confluence_page_tool diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/update_page.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/tools/update_page.py similarity index 98% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/update_page.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/tools/update_page.py index 7a8207a00..7db9a24dc 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/update_page.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/confluence/tools/update_page.py @@ -5,7 +5,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) from app.connectors.confluence_history import ConfluenceHistoryConnector diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/agent.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/agent.py similarity index 81% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/agent.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/agent.py index 567e72973..fe8f0df1e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/agent.py @@ -12,9 +12,13 @@ from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) from .tools.index import NAME, RULESET, load_tools diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/description.md similarity index 100% rename 
from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/system_prompt.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/system_prompt.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/system_prompt.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/system_prompt.md diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/tools/__init__.py new file mode 100644 index 000000000..e6733a098 --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/tools/__init__.py @@ -0,0 +1,9 @@ +from .list_channels import create_list_discord_channels_tool +from .read_messages import create_read_discord_messages_tool +from .send_message import create_send_discord_message_tool + +__all__ = [ + "create_list_discord_channels_tool", + "create_read_discord_messages_tool", + "create_send_discord_message_tool", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/_auth.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/tools/_auth.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/_auth.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/tools/_auth.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/tools/index.py similarity index 92% 
rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/tools/index.py index c69ef3e5c..fcef3401a 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/tools/index.py @@ -9,7 +9,7 @@ from typing import Any from langchain_core.tools import BaseTool -from app.agents.new_chat.permissions import Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Ruleset from .list_channels import create_list_discord_channels_tool from .read_messages import create_read_discord_messages_tool diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/list_channels.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/tools/list_channels.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/list_channels.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/tools/list_channels.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/read_messages.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/tools/read_messages.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/read_messages.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/tools/read_messages.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/send_message.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/tools/send_message.py similarity index 97% rename from 
surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/send_message.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/tools/send_message.py index 95890ed10..59ea1de30 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/send_message.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/discord/tools/send_message.py @@ -5,7 +5,7 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/agent.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/agent.py similarity index 81% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/agent.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/agent.py index d3ae6dc83..841bcba6e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/agent.py @@ -12,9 +12,13 @@ from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from 
app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) from .tools.index import NAME, RULESET, load_tools diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/system_prompt.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/system_prompt.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/system_prompt.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/system_prompt.md diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/tools/__init__.py new file mode 100644 index 000000000..f2b8303a5 --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/tools/__init__.py @@ -0,0 +1,7 @@ +from .create_file import create_create_dropbox_file_tool +from .trash_file import create_delete_dropbox_file_tool + +__all__ = [ + 
"create_create_dropbox_file_tool", + "create_delete_dropbox_file_tool", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/create_file.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/tools/create_file.py similarity index 99% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/create_file.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/tools/create_file.py index 2de7c301f..7732c35e5 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/create_file.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/tools/create_file.py @@ -8,7 +8,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) from app.connectors.dropbox.client import DropboxClient diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/tools/index.py similarity index 91% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/tools/index.py index 68e02866a..440b4583c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/tools/index.py @@ -9,7 +9,7 @@ from typing import Any from langchain_core.tools import BaseTool -from app.agents.new_chat.permissions import Ruleset +from 
app.agents.chat.multi_agent_chat.shared.permissions import Ruleset from .create_file import create_create_dropbox_file_tool from .trash_file import create_delete_dropbox_file_tool diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/trash_file.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/tools/trash_file.py similarity index 99% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/trash_file.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/tools/trash_file.py index 7cb652d5d..c713bdd00 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/trash_file.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/dropbox/tools/trash_file.py @@ -6,7 +6,7 @@ from sqlalchemy import String, and_, cast, func from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) from app.connectors.dropbox.client import DropboxClient diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/agent.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/agent.py similarity index 81% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/agent.py rename to 
surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/agent.py index 082400eb9..be8adc17c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/agent.py @@ -12,9 +12,13 @@ from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) from .tools.index import NAME, RULESET, load_tools diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/system_prompt.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/system_prompt.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/system_prompt.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/system_prompt.md diff --git 
a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/__init__.py new file mode 100644 index 000000000..1f0839c44 --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/__init__.py @@ -0,0 +1,15 @@ +from .create_draft import create_create_gmail_draft_tool +from .read_email import create_read_gmail_email_tool +from .search_emails import create_search_gmail_tool +from .send_email import create_send_gmail_email_tool +from .trash_email import create_trash_gmail_email_tool +from .update_draft import create_update_gmail_draft_tool + +__all__ = [ + "create_create_gmail_draft_tool", + "create_read_gmail_email_tool", + "create_search_gmail_tool", + "create_send_gmail_email_tool", + "create_trash_gmail_email_tool", + "create_update_gmail_draft_tool", +] diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/_helpers.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/_helpers.py new file mode 100644 index 000000000..12d984352 --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/_helpers.py @@ -0,0 +1,46 @@ +"""Gmail-specific helpers for the Gmail connector tools. + +Google OAuth credential construction lives in +``app.agents.chat.multi_agent_chat.subagents.connectors.google_auth`` (shared +with the Calendar connector). It is re-exported here under the legacy private +names so the existing Gmail tools keep importing it from this module. 
+""" + +from __future__ import annotations + +from typing import Any + +from app.agents.chat.multi_agent_chat.subagents.connectors.google_auth import ( + build_credentials as _build_credentials, + get_token_encryption as _get_token_encryption, +) + +__all__ = [ + "_build_credentials", + "_format_gmail_summary", + "_get_token_encryption", + "_gmail_headers", +] + + +def _gmail_headers(message: dict[str, Any]) -> dict[str, str]: + headers = message.get("payload", {}).get("headers", []) + return { + header.get("name", "").lower(): header.get("value", "") + for header in headers + if isinstance(header, dict) + } + + +def _format_gmail_summary(message: dict[str, Any]) -> dict[str, Any]: + headers = _gmail_headers(message) + return { + "message_id": message.get("id") or message.get("messageId"), + "thread_id": message.get("threadId"), + "subject": message.get("subject") or headers.get("subject", "No Subject"), + "from": message.get("sender") or headers.get("from", "Unknown"), + "to": message.get("to") or headers.get("to", ""), + "date": message.get("messageTimestamp") or headers.get("date", ""), + "snippet": message.get("snippet") or message.get("messageText", "")[:300], + "labels": message.get("labelIds", []), + } diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/create_draft.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/create_draft.py similarity index 99% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/create_draft.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/create_draft.py index fb1461d7c..3f25305c5 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/create_draft.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/create_draft.py @@ -8,7 +8,7 @@ from typing import Any from langchain_core.tools import tool from 
sqlalchemy.ext.asyncio import AsyncSession -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) from app.services.gmail import GmailToolMetadataService diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/index.py similarity index 94% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/index.py index 020089ebb..60405dcf7 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/index.py @@ -9,7 +9,7 @@ from typing import Any from langchain_core.tools import BaseTool -from app.agents.new_chat.permissions import Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Ruleset from .create_draft import create_create_gmail_draft_tool from .read_email import create_read_gmail_email_tool diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/read_email.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/read_email.py similarity index 95% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/read_email.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/read_email.py index 39526f25e..10c64c6c5 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/read_email.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/read_email.py @@ -61,11 +61,10 @@ def create_read_gmail_email_tool( 
"message": "Composio connected account ID not found for this Gmail connector.", } - from app.agents.new_chat.tools.gmail.search_emails import ( - _format_gmail_summary, - ) from app.services.composio_service import ComposioService + from ._helpers import _format_gmail_summary + detail, error = await ComposioService().get_gmail_message_detail( connected_account_id=cca_id, entity_id=f"surfsense_{user_id}", @@ -97,9 +96,7 @@ def create_read_gmail_email_tool( "content": content, } - from app.agents.new_chat.tools.gmail.search_emails import ( - _build_credentials, - ) + from ._helpers import _build_credentials creds = _build_credentials(connector) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/search_emails.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/search_emails.py similarity index 95% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/search_emails.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/search_emails.py index a9d7cdedf..2c633d629 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/search_emails.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/search_emails.py @@ -69,11 +69,10 @@ def create_search_gmail_tool( "message": "Composio connected account ID not found for this Gmail connector.", } - from app.agents.new_chat.tools.gmail.search_emails import ( - _format_gmail_summary, - ) from app.services.composio_service import ComposioService + from ._helpers import _format_gmail_summary + ( messages, _next, @@ -98,9 +97,7 @@ def create_search_gmail_tool( } return {"status": "success", "emails": emails, "total": len(emails)} - from app.agents.new_chat.tools.gmail.search_emails import ( - _build_credentials, - ) + from ._helpers import _build_credentials creds = _build_credentials(connector) diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/send_email.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/send_email.py similarity index 98% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/send_email.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/send_email.py index 0680e51cb..3431a2bc3 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/send_email.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/send_email.py @@ -10,11 +10,11 @@ from langchain_core.tools import tool from langgraph.types import Command from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt +from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) -from app.agents.shared.receipt import make_receipt -from app.agents.shared.receipt_command import with_receipt from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/trash_email.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/trash_email.py similarity index 99% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/trash_email.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/trash_email.py index b24e9ebe4..ef5882074 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/trash_email.py +++ 
b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/trash_email.py @@ -6,7 +6,7 @@ from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) from app.services.gmail import GmailToolMetadataService diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/update_draft.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/update_draft.py similarity index 99% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/update_draft.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/update_draft.py index 1ab9d30cf..ef7839a1a 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/update_draft.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/gmail/tools/update_draft.py @@ -8,7 +8,7 @@ from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) from app.services.gmail import GmailToolMetadataService diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_auth.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_auth.py new file mode 100644 index 000000000..6eb60ef2a --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_auth.py @@ -0,0 +1,59 @@ +"""Google OAuth credential construction shared across Google connectors. 
+ +Both the Gmail and Calendar connector tools are Google OAuth backed and build +``google.oauth2.credentials.Credentials`` from a stored ``SearchSourceConnector`` +the same way. This module is the single owner of that logic so neither connector +has to import the other. +""" + +from __future__ import annotations + +from datetime import datetime + +from app.db import SearchSourceConnector + +_token_encryption_cache: object | None = None + + +def get_token_encryption(): + global _token_encryption_cache + if _token_encryption_cache is None: + from app.config import config + from app.utils.oauth_security import TokenEncryption + + if not config.SECRET_KEY: + raise RuntimeError("SECRET_KEY not configured for token decryption.") + _token_encryption_cache = TokenEncryption(config.SECRET_KEY) + return _token_encryption_cache + + +def build_credentials(connector: SearchSourceConnector): + """Build Google OAuth Credentials from a connector's stored config. + + Handles both native OAuth connectors (with encrypted tokens) and + Composio-backed connectors. Shared by Gmail and Calendar tools. 
+ """ + from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES + + if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: + raise ValueError("Composio connectors must use Composio tool execution.") + + from google.oauth2.credentials import Credentials + + cfg = dict(connector.config) + if cfg.get("_token_encrypted"): + enc = get_token_encryption() + for key in ("token", "refresh_token", "client_secret"): + if cfg.get(key): + cfg[key] = enc.decrypt_token(cfg[key]) + + exp = (cfg.get("expiry") or "").replace("Z", "") + return Credentials( + token=cfg.get("token"), + refresh_token=cfg.get("refresh_token"), + token_uri=cfg.get("token_uri"), + client_id=cfg.get("client_id"), + client_secret=cfg.get("client_secret"), + scopes=cfg.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, + ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/agent.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/agent.py similarity index 81% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/agent.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/agent.py index fb4a24ddd..1597d025e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/agent.py @@ -12,9 +12,13 @@ from typing import Any from langchain_core.language_models 
import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) from .tools.index import NAME, RULESET, load_tools diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/system_prompt.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/system_prompt.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/system_prompt.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/system_prompt.md diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/tools/__init__.py new file mode 100644 index 000000000..403140a5d --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/tools/__init__.py @@ -0,0 +1,7 @@ +from .create_file 
import create_create_google_drive_file_tool +from .trash_file import create_delete_google_drive_file_tool + +__all__ = [ + "create_create_google_drive_file_tool", + "create_delete_google_drive_file_tool", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/create_file.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/tools/create_file.py similarity index 99% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/create_file.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/tools/create_file.py index 70f5eea74..9de4e0a4b 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/create_file.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/tools/create_file.py @@ -5,7 +5,7 @@ from googleapiclient.errors import HttpError from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) from app.connectors.google_drive.client import GoogleDriveClient diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/tools/index.py similarity index 92% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/tools/index.py index dd05374a1..caf06d6ba 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/index.py +++ 
b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/tools/index.py @@ -9,7 +9,7 @@ from typing import Any from langchain_core.tools import BaseTool -from app.agents.new_chat.permissions import Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Ruleset from .create_file import create_create_google_drive_file_tool from .trash_file import create_delete_google_drive_file_tool diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/trash_file.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/tools/trash_file.py similarity index 99% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/trash_file.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/tools/trash_file.py index 7fbcd74a3..c89b54c8e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/trash_file.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/google_drive/tools/trash_file.py @@ -5,7 +5,7 @@ from googleapiclient.errors import HttpError from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) from app.connectors.google_drive.client import GoogleDriveClient diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/jira/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/jira/__init__.py diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/agent.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/jira/agent.py similarity index 80% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/agent.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/jira/agent.py index ff71d4cf7..693d5980a 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/jira/agent.py @@ -12,9 +12,13 @@ from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) from .tools.index import NAME, RULESET diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/jira/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/jira/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/system_prompt.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/jira/system_prompt.md similarity index 100% rename from 
surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/system_prompt.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/jira/system_prompt.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/jira/tools/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/jira/tools/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/jira/tools/index.py similarity index 93% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/jira/tools/index.py index 24f1bdc01..20c67671b 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/jira/tools/index.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.permissions import Rule, Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Rule, Ruleset NAME = "jira" diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/linear/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/linear/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/agent.py 
b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/linear/agent.py similarity index 80% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/agent.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/linear/agent.py index d9b282f2b..d88ec03f1 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/linear/agent.py @@ -12,9 +12,13 @@ from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) from .tools.index import NAME, RULESET diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/linear/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/linear/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/system_prompt.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/linear/system_prompt.md similarity index 100% rename from 
surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/system_prompt.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/linear/system_prompt.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/linear/tools/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/linear/tools/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/linear/tools/index.py similarity index 94% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/linear/tools/index.py index 4a71a31b8..a06b33359 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/linear/tools/index.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.permissions import Rule, Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Rule, Ruleset NAME = "linear" diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/agent.py 
b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/agent.py similarity index 81% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/agent.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/agent.py index d84efaed8..49973d08c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/agent.py @@ -12,9 +12,13 @@ from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) from .tools.index import NAME, RULESET, load_tools diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/system_prompt.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/system_prompt.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/system_prompt.md rename 
to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/system_prompt.md diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/tools/__init__.py new file mode 100644 index 000000000..c089eab4b --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/tools/__init__.py @@ -0,0 +1,9 @@ +from .create_event import create_create_luma_event_tool +from .list_events import create_list_luma_events_tool +from .read_event import create_read_luma_event_tool + +__all__ = [ + "create_create_luma_event_tool", + "create_list_luma_events_tool", + "create_read_luma_event_tool", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/_auth.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/tools/_auth.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/_auth.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/tools/_auth.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/create_event.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/tools/create_event.py similarity index 98% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/create_event.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/tools/create_event.py index e3e1126fd..0dffb2d2c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/create_event.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/tools/create_event.py @@ -5,7 +5,7 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from 
app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/tools/index.py similarity index 92% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/tools/index.py index dbde01061..a479331bb 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/tools/index.py @@ -9,7 +9,7 @@ from typing import Any from langchain_core.tools import BaseTool -from app.agents.new_chat.permissions import Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Ruleset from .create_event import create_create_luma_event_tool from .list_events import create_list_luma_events_tool diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/list_events.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/tools/list_events.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/list_events.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/tools/list_events.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/read_event.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/tools/read_event.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/read_event.py rename to 
surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/luma/tools/read_event.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/agent.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/agent.py similarity index 81% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/agent.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/agent.py index 8de86b2d8..a4b2d61cf 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/agent.py @@ -12,9 +12,13 @@ from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) from .tools.index import NAME, RULESET, load_tools diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/description.md 
b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/system_prompt.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/system_prompt.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/system_prompt.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/system_prompt.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/tools/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/tools/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/create_page.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/tools/create_page.py similarity index 99% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/create_page.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/tools/create_page.py index 20862eb56..49ee0f3aa 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/create_page.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/tools/create_page.py @@ -4,7 +4,7 @@ from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import 
AsyncSession -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/delete_page.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/tools/delete_page.py similarity index 98% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/delete_page.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/tools/delete_page.py index c98b25811..a187b2cbc 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/delete_page.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/tools/delete_page.py @@ -6,11 +6,11 @@ from langchain_core.tools import tool from langgraph.types import Command from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt +from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) -from app.agents.shared.receipt import make_receipt -from app.agents.shared.receipt_command import with_receipt from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector from app.services.notion.tool_metadata_service import NotionToolMetadataService diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/tools/index.py similarity index 92% rename from 
surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/tools/index.py index 0475e9dd0..b8f662b03 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/tools/index.py @@ -9,7 +9,7 @@ from typing import Any from langchain_core.tools import BaseTool -from app.agents.new_chat.permissions import Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Ruleset from .create_page import create_create_notion_page_tool from .delete_page import create_delete_notion_page_tool diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/update_page.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/tools/update_page.py similarity index 99% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/update_page.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/tools/update_page.py index 2b9ce3a6c..6950f0abd 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/update_page.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/notion/tools/update_page.py @@ -4,7 +4,7 @@ from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/__init__.py 
b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/agent.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/agent.py similarity index 81% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/agent.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/agent.py index f7634d8ef..e2fcdac90 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/agent.py @@ -12,9 +12,13 @@ from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) from .tools.index import NAME, RULESET, load_tools diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/description.md similarity index 100% rename from 
surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/system_prompt.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/system_prompt.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/system_prompt.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/system_prompt.md diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/tools/__init__.py new file mode 100644 index 000000000..406b9b6d2 --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/tools/__init__.py @@ -0,0 +1,7 @@ +from .create_file import create_create_onedrive_file_tool +from .trash_file import create_delete_onedrive_file_tool + +__all__ = [ + "create_create_onedrive_file_tool", + "create_delete_onedrive_file_tool", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/create_file.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/tools/create_file.py similarity index 99% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/create_file.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/tools/create_file.py index 41fa65787..11160650d 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/create_file.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/tools/create_file.py @@ -8,7 +8,7 @@ from langchain_core.tools import tool 
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) from app.connectors.onedrive.client import OneDriveClient diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/tools/index.py similarity index 91% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/tools/index.py index e09b43200..4f0a2a7d6 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/tools/index.py @@ -9,7 +9,7 @@ from typing import Any from langchain_core.tools import BaseTool -from app.agents.new_chat.permissions import Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Ruleset from .create_file import create_create_onedrive_file_tool from .trash_file import create_delete_onedrive_file_tool diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/trash_file.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/tools/trash_file.py similarity index 99% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/trash_file.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/tools/trash_file.py index 1f7c51ac5..7b4e0b98c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/trash_file.py +++ 
b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/onedrive/tools/trash_file.py @@ -6,7 +6,7 @@ from sqlalchemy import String, and_, cast, func from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) from app.connectors.onedrive.client import OneDriveClient diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/slack/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/slack/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/agent.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/slack/agent.py similarity index 80% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/agent.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/slack/agent.py index e16956b25..9951a63f0 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/slack/agent.py @@ -12,9 +12,13 @@ from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from 
app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) from .tools.index import NAME, RULESET diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/description.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/slack/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/slack/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/system_prompt.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/slack/system_prompt.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/system_prompt.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/slack/system_prompt.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/slack/tools/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/tools/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/slack/tools/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/slack/tools/index.py similarity index 89% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/slack/tools/index.py index 44b96661c..a26b537a6 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/tools/index.py +++ 
b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/slack/tools/index.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.agents.new_chat.permissions import Rule, Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Rule, Ruleset NAME = "slack" diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/agent.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/agent.py similarity index 81% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/agent.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/agent.py index ab808b745..ab927654b 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/agent.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/agent.py @@ -12,9 +12,13 @@ from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) from .tools.index import NAME, RULESET, load_tools diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/description.md 
b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/description.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/description.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/description.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/system_prompt.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/system_prompt.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/system_prompt.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/system_prompt.md diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/tools/__init__.py new file mode 100644 index 000000000..dbf966307 --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/tools/__init__.py @@ -0,0 +1,9 @@ +from .list_channels import create_list_teams_channels_tool +from .read_messages import create_read_teams_messages_tool +from .send_message import create_send_teams_message_tool + +__all__ = [ + "create_list_teams_channels_tool", + "create_read_teams_messages_tool", + "create_send_teams_message_tool", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/_auth.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/tools/_auth.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/_auth.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/tools/_auth.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/index.py 
b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/tools/index.py similarity index 92% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/tools/index.py index 41661651f..d144eee82 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/tools/index.py @@ -9,7 +9,7 @@ from typing import Any from langchain_core.tools import BaseTool -from app.agents.new_chat.permissions import Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Ruleset from .list_channels import create_list_teams_channels_tool from .read_messages import create_read_teams_messages_tool diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/list_channels.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/tools/list_channels.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/list_channels.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/tools/list_channels.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/read_messages.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/tools/read_messages.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/read_messages.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/tools/read_messages.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/send_message.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/tools/send_message.py similarity 
index 97% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/send_message.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/tools/send_message.py index f1469e3e1..c4491e82e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/send_message.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/connectors/teams/tools/send_message.py @@ -5,7 +5,7 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/mcp_tools/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/mcp_tools/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/mcp_tools/index.py similarity index 96% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/index.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/mcp_tools/index.py index 16dc09ac5..436b13aea 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/mcp_tools/index.py @@ -18,10 +18,10 @@ from sqlalchemy import cast, select from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.multi_agent_chat.constants import ( +from app.agents.chat.multi_agent_chat.constants import ( 
CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS, ) -from app.agents.new_chat.tools.mcp_tool import load_mcp_tools +from app.agents.chat.multi_agent_chat.shared.tools.mcp.tool import load_mcp_tools from app.db import SearchSourceConnector logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/subagent/middleware_stack.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/middleware_stack.py similarity index 82% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/subagent/middleware_stack.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/middleware_stack.py index aa6211fcc..124ccf704 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/subagent/middleware_stack.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/middleware_stack.py @@ -14,11 +14,14 @@ from __future__ import annotations from typing import Any -from app.agents.new_chat.feature_flags import AgentFeatureFlags - -from ..shared.permissions import build_permission_mw -from ..shared.resilience import ResilienceMiddlewares -from ..shared.todos import build_todos_mw +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.middleware.resilience import ( + ResilienceMiddlewares, +) +from app.agents.chat.multi_agent_chat.shared.middleware.todos import build_todos_mw +from app.agents.chat.multi_agent_chat.shared.permissions import ( + build_permission_mw, +) def build_subagent_middleware_stack( diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/registry.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/registry.py similarity index 77% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/registry.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/registry.py index 27c147672..cec9eee3a 100644 --- 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/registry.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/registry.py @@ -8,70 +8,70 @@ from deepagents import SubAgent from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.constants import ( +from app.agents.chat.multi_agent_chat.constants import ( SUBAGENT_TO_REQUIRED_CONNECTOR_MAP, ) -from app.agents.multi_agent_chat.subagents.builtins.deliverables.agent import ( +from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.agent import ( build_subagent as build_deliverables_subagent, ) -from app.agents.multi_agent_chat.subagents.builtins.knowledge_base.agent import ( +from app.agents.chat.multi_agent_chat.subagents.builtins.knowledge_base.agent import ( build_subagent as build_knowledge_base_subagent, ) -from app.agents.multi_agent_chat.subagents.builtins.memory.agent import ( +from app.agents.chat.multi_agent_chat.subagents.builtins.memory.agent import ( build_subagent as build_memory_subagent, ) -from app.agents.multi_agent_chat.subagents.builtins.research.agent import ( +from app.agents.chat.multi_agent_chat.subagents.builtins.research.agent import ( build_subagent as build_research_subagent, ) -from app.agents.multi_agent_chat.subagents.connectors.airtable.agent import ( +from app.agents.chat.multi_agent_chat.subagents.connectors.airtable.agent import ( build_subagent as build_airtable_subagent, ) -from app.agents.multi_agent_chat.subagents.connectors.calendar.agent import ( +from app.agents.chat.multi_agent_chat.subagents.connectors.calendar.agent import ( build_subagent as build_calendar_subagent, ) -from app.agents.multi_agent_chat.subagents.connectors.clickup.agent import ( +from app.agents.chat.multi_agent_chat.subagents.connectors.clickup.agent import ( build_subagent as build_clickup_subagent, ) -from app.agents.multi_agent_chat.subagents.connectors.confluence.agent import ( +from 
app.agents.chat.multi_agent_chat.subagents.connectors.confluence.agent import ( build_subagent as build_confluence_subagent, ) -from app.agents.multi_agent_chat.subagents.connectors.discord.agent import ( +from app.agents.chat.multi_agent_chat.subagents.connectors.discord.agent import ( build_subagent as build_discord_subagent, ) -from app.agents.multi_agent_chat.subagents.connectors.dropbox.agent import ( +from app.agents.chat.multi_agent_chat.subagents.connectors.dropbox.agent import ( build_subagent as build_dropbox_subagent, ) -from app.agents.multi_agent_chat.subagents.connectors.gmail.agent import ( +from app.agents.chat.multi_agent_chat.subagents.connectors.gmail.agent import ( build_subagent as build_gmail_subagent, ) -from app.agents.multi_agent_chat.subagents.connectors.google_drive.agent import ( +from app.agents.chat.multi_agent_chat.subagents.connectors.google_drive.agent import ( build_subagent as build_google_drive_subagent, ) -from app.agents.multi_agent_chat.subagents.connectors.jira.agent import ( +from app.agents.chat.multi_agent_chat.subagents.connectors.jira.agent import ( build_subagent as build_jira_subagent, ) -from app.agents.multi_agent_chat.subagents.connectors.linear.agent import ( +from app.agents.chat.multi_agent_chat.subagents.connectors.linear.agent import ( build_subagent as build_linear_subagent, ) -from app.agents.multi_agent_chat.subagents.connectors.luma.agent import ( +from app.agents.chat.multi_agent_chat.subagents.connectors.luma.agent import ( build_subagent as build_luma_subagent, ) -from app.agents.multi_agent_chat.subagents.connectors.notion.agent import ( +from app.agents.chat.multi_agent_chat.subagents.connectors.notion.agent import ( build_subagent as build_notion_subagent, ) -from app.agents.multi_agent_chat.subagents.connectors.onedrive.agent import ( +from app.agents.chat.multi_agent_chat.subagents.connectors.onedrive.agent import ( build_subagent as build_onedrive_subagent, ) -from 
app.agents.multi_agent_chat.subagents.connectors.slack.agent import ( +from app.agents.chat.multi_agent_chat.subagents.connectors.slack.agent import ( build_subagent as build_slack_subagent, ) -from app.agents.multi_agent_chat.subagents.connectors.teams.agent import ( +from app.agents.chat.multi_agent_chat.subagents.connectors.teams.agent import ( build_subagent as build_teams_subagent, ) -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( read_md_file, ) -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec class SubagentBuilder(Protocol): diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/__init__.py new file mode 100644 index 000000000..4ed3a5d8e --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/__init__.py @@ -0,0 +1,17 @@ +"""Cross-slice helpers for route subagents.""" + +from __future__ import annotations + +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, +) +from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) + +__all__ = [ + "SurfSenseSubagentSpec", + "pack_subagent", + "read_md_file", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/__init__.py rename to 
surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/auto_approved.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/auto_approved.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/auto_approved.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/auto_approved.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/request.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/request.py similarity index 98% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/request.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/request.py index 2f7e3cd35..8771b1506 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/request.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/request.py @@ -19,7 +19,7 @@ from typing import Any from langgraph.types import interrupt -from app.agents.multi_agent_chat.subagents.shared.hitl.wire import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.wire import ( LC_DECISION_APPROVE, LC_DECISION_EDIT, LC_DECISION_REJECT, diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/result.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/result.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/result.py rename to 
surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/result.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/wire/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/hitl/wire/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/wire/__init__.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/hitl/wire/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/wire/decision.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/hitl/wire/decision.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/wire/decision.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/hitl/wire/decision.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/wire/payload.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/hitl/wire/payload.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/wire/payload.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/hitl/wire/payload.py diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/invocation.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/invocation.py new file mode 100644 index 000000000..63a63cbc3 --- /dev/null +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/invocation.py @@ -0,0 +1,69 @@ +"""Subagent-invocation contract shared by the orchestrator and nested subagents. + +Both the main-agent ``task`` middleware (``checkpointed_subagent_middleware``) +and subagents that themselves invoke another subagent (e.g. 
+``ask_knowledge_base``) need the same two things when spawning a child run: + +- a ``RunnableConfig`` that raises the recursion limit and isolates the child's + ``thread_id`` so each invocation lands in its own checkpoint slot + (``subagent_invoke_config``), and +- the set of parent state keys that must *not* be forwarded into / merged back + from the child (``EXCLUDED_STATE_KEYS``). + +Keeping this here (rather than inside the main-agent middleware) lets subagents +reuse the contract without importing main-agent internals. +""" + +from __future__ import annotations + +from typing import Any + +from langchain.tools import ToolRuntime + +# Mirror of deepagents.middleware.subagents._EXCLUDED_STATE_KEYS. +EXCLUDED_STATE_KEYS = frozenset( + { + "messages", + "todos", + "structured_response", + "skills_metadata", + "memory_contents", + } +) + +# Match the parent graph's budget; the LangGraph default of 25 trips on +# multi-step subagent runs. +DEFAULT_SUBAGENT_RECURSION_LIMIT = 10_000 + + +def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]: + """RunnableConfig for the nested invoke; raises ``recursion_limit`` and isolates ``thread_id``. + + Each parallel subagent invocation lands in its own checkpoint slot keyed + by an extended ``thread_id`` of the form ``{parent_thread}::task:{tool_call_id}``. + The same call across the resume cycle keeps reading from the same snapshot + (``tool_call_id`` is stable per LLM-emitted call). + + We namespace via ``thread_id`` rather than ``checkpoint_ns`` because + langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a + subgraph path and raises ``ValueError("Subgraph X not found")``. 
+ """ + merged: dict[str, Any] = dict(runtime.config) if runtime.config else {} + current_limit = merged.get("recursion_limit") + try: + current_int = int(current_limit) if current_limit is not None else 0 + except (TypeError, ValueError): + current_int = 0 + if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT: + merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT + + configurable: dict[str, Any] = dict(merged.get("configurable") or {}) + parent_thread_id = configurable.get("thread_id") + per_call_suffix = f"task:{runtime.tool_call_id}" + configurable["thread_id"] = ( + f"{parent_thread_id}::{per_call_suffix}" + if parent_thread_id + else per_call_suffix + ) + merged["configurable"] = configurable + return merged diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/md_file_reader.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/md_file_reader.py similarity index 90% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/shared/md_file_reader.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/md_file_reader.py index 5694e4326..786086f60 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/md_file_reader.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/md_file_reader.py @@ -5,7 +5,7 @@ from __future__ import annotations from functools import lru_cache from importlib import resources -_SHARED_SNIPPETS_PACKAGE = "app.agents.multi_agent_chat.subagents.shared.snippets" +_SHARED_SNIPPETS_PACKAGE = "app.agents.chat.multi_agent_chat.subagents.shared.snippets" def read_md_file(package: str, stem: str) -> str: diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/snippets/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/snippets/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/shared/snippets/__init__.py rename to 
surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/snippets/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/snippets/output_contract_base.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/snippets/output_contract_base.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/shared/snippets/output_contract_base.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/snippets/output_contract_base.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/snippets/verifiable_handle.md b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/snippets/verifiable_handle.md similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/shared/snippets/verifiable_handle.md rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/snippets/verifiable_handle.md diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/spec.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/spec.py similarity index 97% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/shared/spec.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/spec.py index f891f94d2..6bace8ca4 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/spec.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/spec.py @@ -8,7 +8,7 @@ from typing import Any from deepagents import SubAgent -from app.agents.new_chat.permissions import Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Ruleset # A context-hint provider receives the parent-agent ``runtime.state`` mapping # and the ``description`` the orchestrator wrote, and returns a short string diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py 
b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/subagent_builder.py similarity index 95% rename from surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py rename to surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/subagent_builder.py index 5025b32e7..d03e86685 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/shared/subagent_builder.py @@ -11,18 +11,18 @@ from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.middleware.shared.permissions import ( +from app.agents.chat.multi_agent_chat.shared.permissions import ( + Ruleset, build_permission_mw, ) -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( read_shared_snippet, ) -from app.agents.multi_agent_chat.subagents.shared.spec import ( +from app.agents.chat.multi_agent_chat.subagents.shared.spec import ( SURF_CONTEXT_HINT_PROVIDER_KEY, ContextHintProvider, SurfSenseSubagentSpec, ) -from app.agents.new_chat.permissions import Ruleset logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/chat/runtime/__init__.py b/surfsense_backend/app/agents/chat/runtime/__init__.py new file mode 100644 index 000000000..9cc63f289 --- /dev/null +++ b/surfsense_backend/app/agents/chat/runtime/__init__.py @@ -0,0 +1,16 @@ +"""Lower-level runtime infrastructure for the chat agents. + +Modules here are the foundation layer used to *run* chat agents: wired by the +boundary (routes/tasks) and/or imported by the agent factory + shared +middleware, but never part of any single agent's domain logic. 
Because they sit +below the agent packages, both the boundary and the agents may depend on them +(forward dependency), while they never import agent code. + +Contents: +- ``checkpointer`` LangGraph Postgres checkpoint saver (boundary lifespan) +- ``llm_config`` LLM provider/model configuration resolution +- ``prompt_caching`` LiteLLM prompt-caching configuration +- ``errors`` agent-runtime error contracts (raised by MW, caught at boundary) +- ``path_resolver`` filesystem path resolution helpers +- ``mention_resolver`` @-mention resolution helpers +""" diff --git a/surfsense_backend/app/agents/new_chat/checkpointer.py b/surfsense_backend/app/agents/chat/runtime/checkpointer.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/checkpointer.py rename to surfsense_backend/app/agents/chat/runtime/checkpointer.py diff --git a/surfsense_backend/app/agents/new_chat/errors.py b/surfsense_backend/app/agents/chat/runtime/errors.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/errors.py rename to surfsense_backend/app/agents/chat/runtime/errors.py diff --git a/surfsense_backend/app/agents/new_chat/llm_config.py b/surfsense_backend/app/agents/chat/runtime/llm_config.py similarity index 64% rename from surfsense_backend/app/agents/new_chat/llm_config.py rename to surfsense_backend/app/agents/chat/runtime/llm_config.py index bc37bf1c4..aad432edb 100644 --- a/surfsense_backend/app/agents/new_chat/llm_config.py +++ b/surfsense_backend/app/agents/chat/runtime/llm_config.py @@ -27,7 +27,9 @@ from litellm import get_model_info from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching +from app.agents.chat.runtime.prompt_caching import ( + apply_litellm_prompt_caching, +) from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, @@ -90,15 +92,9 @@ class SanitizedChatLiteLLM(ChatLiteLLM): yield chunk -# Provider mapping for 
LiteLLM model string construction. -# -# Single source of truth lives in -# :mod:`app.services.provider_capabilities` so the YAML loader (which -# runs during ``app.config`` class-body init) can resolve provider -# prefixes without dragging the agent / tools tree into module load -# order. Re-exported here under the historical ``PROVIDER_MAP`` name -# so existing callers (``llm_router_service``, ``image_gen_router_service``, -# tests) keep working unchanged. +# Re-exported under the historical name ``PROVIDER_MAP``. Source of truth lives +# in provider_capabilities so the YAML loader can resolve prefixes during +# app.config init without importing the agent/tools tree. from app.services.provider_capabilities import ( # noqa: E402 _PROVIDER_PREFIX_MAP as PROVIDER_MAP, ) @@ -155,25 +151,14 @@ class AgentConfig: anonymous_enabled: bool = False quota_reserve_tokens: int | None = None - # Capability flag: best-effort True for the chat selector / catalog. - # Resolved via :func:`provider_capabilities.derive_supports_image_input` - # which prefers OpenRouter's ``architecture.input_modalities`` and - # otherwise consults LiteLLM's authoritative model map. Default True - # is the conservative-allow stance — the streaming-task safety net - # (``is_known_text_only_chat_model``) is the *only* place a False - # actually blocks a request. Setting this to False here without an - # authoritative source would silently hide vision-capable models - # (the regression we're fixing). + # Default-allow: only the streaming safety net (is_known_text_only_chat_model) + # actually blocks on False, so defaulting False would silently hide + # vision-capable models. Resolved via derive_supports_image_input. supports_image_input: bool = True @classmethod def from_auto_mode(cls) -> "AgentConfig": - """ - Create an AgentConfig for Auto mode (LiteLLM Router load balancing). 
- - Returns: - AgentConfig instance configured for Auto mode - """ + """Build an AgentConfig for Auto mode (LiteLLM Router load balancing).""" return cls( provider="AUTO", model_name="auto", @@ -191,27 +176,15 @@ class AgentConfig: is_premium=False, anonymous_enabled=False, quota_reserve_tokens=None, - # Auto routes across the configured pool, which usually - # contains at least one vision-capable deployment; the router - # will surface a 404 from a non-vision deployment as a normal - # ``allowed_fails`` event and fail over rather than blocking - # the request outright. + # Auto fails over across the pool, so a non-vision deployment's 404 + # is just an allowed_fails event rather than a hard block. supports_image_input=True, ) @classmethod def from_new_llm_config(cls, config) -> "AgentConfig": - """ - Create an AgentConfig from a NewLLMConfig database model. - - Args: - config: NewLLMConfig database model instance - - Returns: - AgentConfig instance - """ - # Lazy import to avoid pulling provider_capabilities (and its - # transitive litellm import) into module-init order. + """Build an AgentConfig from a NewLLMConfig database model.""" + # Lazy import: keeps provider_capabilities (and litellm) out of init order. from app.services.provider_capabilities import derive_supports_image_input provider_value = ( @@ -243,10 +216,8 @@ class AgentConfig: is_premium=False, anonymous_enabled=False, quota_reserve_tokens=None, - # BYOK rows have no operator-curated capability flag, so we - # ask LiteLLM (default-allow on unknown). The streaming - # safety net still blocks if the model is *explicitly* - # marked text-only. + # BYOK rows have no curated flag; ask LiteLLM (default-allow on + # unknown). The streaming safety net still blocks explicit text-only. 
supports_image_input=derive_supports_image_input( provider=provider_value, model_name=config.model_name, @@ -257,25 +228,14 @@ class AgentConfig: @classmethod def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig": + """Build an AgentConfig from a YAML configuration dictionary. + + Supports the same prompt fields as NewLLMConfig (system_instructions, + use_default_system_instructions, citations_enabled). """ - Create an AgentConfig from a YAML configuration dictionary. - - YAML configs now support the same prompt configuration fields as NewLLMConfig: - - system_instructions: Custom system instructions (empty string uses defaults) - - use_default_system_instructions: Whether to use default instructions - - citations_enabled: Whether citations are enabled - - Args: - yaml_config: Configuration dictionary from YAML file - - Returns: - AgentConfig instance - """ - # Lazy import to avoid pulling provider_capabilities (and its - # transitive litellm import) into module-init order. + # Lazy import: keeps provider_capabilities (and litellm) out of init order. from app.services.provider_capabilities import derive_supports_image_input - # Get system instructions from YAML, default to empty string system_instructions = yaml_config.get("system_instructions", "") provider = yaml_config.get("provider", "").upper() @@ -288,13 +248,8 @@ class AgentConfig: else None ) - # Explicit YAML override wins; otherwise derive from LiteLLM / - # OpenRouter modalities. The YAML loader already populates this - # field, but this method is also called from - # ``load_global_llm_config_by_id``'s file fallback (hot reload), - # so we re-derive here for safety. The bool() coercion preserves - # the loader's behaviour for explicit ``true`` / ``false`` - # strings that PyYAML may surface. + # Explicit YAML override wins; otherwise re-derive (the hot-reload file + # fallback reaches this method without the loader having populated it). 
if "supports_image_input" in yaml_config: supports_image_input = bool(yaml_config.get("supports_image_input")) else: @@ -312,7 +267,6 @@ class AgentConfig: api_base=yaml_config.get("api_base"), custom_provider=custom_provider, litellm_params=yaml_config.get("litellm_params"), - # Prompt configuration from YAML (with defaults for backwards compatibility) system_instructions=system_instructions if system_instructions else None, use_default_system_instructions=yaml_config.get( "use_default_system_instructions", True @@ -330,20 +284,10 @@ class AgentConfig: def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None: - """ - Load a specific LLM config from global_llm_config.yaml. - - Args: - llm_config_id: The id of the config to load (default: -1) - - Returns: - LLM config dict or None if not found - """ - # Get the config file path + """Load a specific LLM config from global_llm_config.yaml.""" base_dir = Path(__file__).resolve().parent.parent.parent.parent config_file = base_dir / "app" / "config" / "global_llm_config.yaml" - # Fallback to example file if main config doesn't exist if not config_file.exists(): config_file = base_dir / "app" / "config" / "global_llm_config.example.yaml" if not config_file.exists(): @@ -366,24 +310,17 @@ def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None: def load_global_llm_config_by_id(llm_config_id: int) -> dict | None: - """ - Load a global LLM config by ID, checking in-memory configs first. + """Load a global LLM config by ID, checking in-memory configs first. - This handles both static YAML configs and dynamically injected configs - (e.g. OpenRouter integration models that only exist in memory). - - Args: - llm_config_id: The negative ID of the global config to load - - Returns: - LLM config dict or None if not found + In-memory covers both static YAML and dynamically injected configs (e.g. + OpenRouter integration models that only exist in memory). 
""" from app.config import config as app_config for cfg in app_config.GLOBAL_LLM_CONFIGS: if cfg.get("id") == llm_config_id: return cfg - # Fallback to YAML file read (covers edge cases like hot-reload) + # Fallback to YAML file read (covers hot-reload edge cases). return load_llm_config_from_yaml(llm_config_id) @@ -391,17 +328,7 @@ async def load_new_llm_config_from_db( session: AsyncSession, config_id: int, ) -> "AgentConfig | None": - """ - Load a NewLLMConfig from the database by ID. - - Args: - session: AsyncSession for database access - config_id: The ID of the NewLLMConfig to load - - Returns: - AgentConfig instance or None if not found - """ - # Import here to avoid circular imports + """Load a NewLLMConfig from the database by ID.""" from app.db import NewLLMConfig try: @@ -424,26 +351,13 @@ async def load_agent_llm_config_for_search_space( session: AsyncSession, search_space_id: int, ) -> "AgentConfig | None": + """Load the agent LLM config for a search space via its agent_llm_id. + + Positive id -> DB; negative -> YAML; None -> first global config (-1). """ - Load the agent LLM configuration for a search space. 
- - This loads the LLM config based on the search space's agent_llm_id setting: - - Positive ID: Load from NewLLMConfig database table - - Negative ID: Load from YAML global configs - - None: Falls back to first global config (id=-1) - - Args: - session: AsyncSession for database access - search_space_id: The search space ID - - Returns: - AgentConfig instance or None if not found - """ - # Import here to avoid circular imports from app.db import SearchSpace try: - # Get the search space to check its agent_llm_id preference result = await session.execute( select(SearchSpace).filter(SearchSpace.id == search_space_id) ) @@ -453,12 +367,9 @@ async def load_agent_llm_config_for_search_space( print(f"Error: SearchSpace with id {search_space_id} not found") return None - # Use agent_llm_id from search space, fallback to -1 (first global config) config_id = ( search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 ) - - # Load the config using the unified loader return await load_agent_config(session, config_id, search_space_id) except Exception as e: print(f"Error loading agent LLM config for search space {search_space_id}: {e}") @@ -470,23 +381,7 @@ async def load_agent_config( config_id: int, search_space_id: int | None = None, ) -> "AgentConfig | None": - """ - Load an agent configuration, supporting Auto mode, YAML, and database configs. 
- - This is the main entry point for loading configurations: - - ID 0: Auto mode (uses LiteLLM Router for load balancing) - - Negative IDs: Load from YAML file (global configs) - - Positive IDs: Load from NewLLMConfig database table - - Args: - session: AsyncSession for database access - config_id: The config ID (0 for Auto, negative for YAML, positive for database) - search_space_id: Optional search space ID for context - - Returns: - AgentConfig instance or None if not found - """ - # Auto mode (ID 0) - use LiteLLM Router + """Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB.""" if is_auto_mode(config_id): if not LLMRouterService.is_initialized(): print("Error: Auto mode requested but LLM Router not initialized") @@ -494,33 +389,22 @@ async def load_agent_config( return AgentConfig.from_auto_mode() if config_id < 0: - # Check in-memory configs first (includes static YAML + dynamic OpenRouter) + # In-memory covers static YAML + dynamic OpenRouter configs. from app.config import config as app_config for cfg in app_config.GLOBAL_LLM_CONFIGS: if cfg.get("id") == config_id: return AgentConfig.from_yaml_config(cfg) - # Fallback to YAML file read for safety yaml_config = load_llm_config_from_yaml(config_id) if yaml_config: return AgentConfig.from_yaml_config(yaml_config) return None else: - # Load from database (NewLLMConfig) return await load_new_llm_config_from_db(session, config_id) def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: - """ - Create a ChatLiteLLM instance from a global LLM config dictionary. 
- - Args: - llm_config: LLM configuration dictionary from YAML - - Returns: - ChatLiteLLM instance or None on error - """ - # Build the model string + """Create a ChatLiteLLM instance from a global LLM config dictionary.""" if llm_config.get("custom_provider"): model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}" else: @@ -528,27 +412,20 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: provider_prefix = PROVIDER_MAP.get(provider, provider.lower()) model_string = f"{provider_prefix}/{llm_config['model_name']}" - # Create ChatLiteLLM instance with streaming enabled litellm_kwargs = { "model": model_string, "api_key": llm_config.get("api_key"), - "streaming": True, # Enable streaming for real-time token streaming + "streaming": True, } - - # Add optional parameters if llm_config.get("api_base"): litellm_kwargs["api_base"] = llm_config["api_base"] - - # Add any additional litellm parameters if llm_config.get("litellm_params"): litellm_kwargs.update(llm_config["litellm_params"]) llm = SanitizedChatLiteLLM(**litellm_kwargs) _attach_model_profile(llm, model_string) - # Configure LiteLLM-native prompt caching (cache_control_injection_points - # for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.). - # ``agent_config=None`` here — the YAML path doesn't have provider intent - # in a structured form, so we set only the universal injection points. + # agent_config=None: the YAML path lacks structured provider intent, so set + # only the universal cache_control_injection_points. apply_litellm_prompt_caching(llm) return llm @@ -556,19 +433,7 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: def create_chat_litellm_from_agent_config( agent_config: AgentConfig, ) -> ChatLiteLLM | ChatLiteLLMRouter | None: - """ - Create a ChatLiteLLM or ChatLiteLLMRouter instance from an AgentConfig. 
- - For Auto mode configs, returns a ChatLiteLLMRouter that uses LiteLLM Router - for automatic load balancing across available providers. - - Args: - agent_config: AgentConfig instance - - Returns: - ChatLiteLLM or ChatLiteLLMRouter instance, or None on error - """ - # Handle Auto mode - return ChatLiteLLMRouter + """Create a ChatLiteLLM (or, for Auto mode, a load-balancing router) from config.""" if agent_config.is_auto_mode: if not LLMRouterService.is_initialized(): print("Error: Auto mode requested but LLM Router not initialized") @@ -576,19 +441,14 @@ def create_chat_litellm_from_agent_config( try: router_llm = get_auto_mode_llm() if router_llm is not None: - # Universal cache_control_injection_points only — auto-mode - # fans out across providers, so OpenAI-only kwargs (e.g. - # ``prompt_cache_key``) are left off here. ``drop_params`` - # would strip them at the provider boundary anyway, but - # there's no point setting them when we don't know the - # destination. + # Universal injection points only: auto-mode fans out across + # providers, so provider-specific kwargs have no known target. 
apply_litellm_prompt_caching(router_llm, agent_config=agent_config) return router_llm except Exception as e: print(f"Error creating ChatLiteLLMRouter: {e}") return None - # Build the model string if agent_config.custom_provider: model_string = f"{agent_config.custom_provider}/{agent_config.model_name}" else: @@ -597,26 +457,19 @@ def create_chat_litellm_from_agent_config( ) model_string = f"{provider_prefix}/{agent_config.model_name}" - # Create ChatLiteLLM instance with streaming enabled litellm_kwargs = { "model": model_string, "api_key": agent_config.api_key, - "streaming": True, # Enable streaming for real-time token streaming + "streaming": True, } - - # Add optional parameters if agent_config.api_base: litellm_kwargs["api_base"] = agent_config.api_base - - # Add any additional litellm parameters if agent_config.litellm_params: litellm_kwargs.update(agent_config.litellm_params) llm = SanitizedChatLiteLLM(**litellm_kwargs) _attach_model_profile(llm, model_string) - # Build-time prompt caching: sets ``cache_control_injection_points`` for - # all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``. - # Per-thread ``prompt_cache_key`` is layered on later in - # ``create_surfsense_deep_agent`` once ``thread_id`` is known. + # Build-time caching only; the per-thread prompt_cache_key is layered on + # later in create_surfsense_deep_agent once thread_id is known. 
apply_litellm_prompt_caching(llm, agent_config=agent_config) return llm diff --git a/surfsense_backend/app/agents/new_chat/mention_resolver.py b/surfsense_backend/app/agents/chat/runtime/mention_resolver.py similarity index 99% rename from surfsense_backend/app/agents/new_chat/mention_resolver.py rename to surfsense_backend/app/agents/chat/runtime/mention_resolver.py index f13dbc6ae..a47ed8f36 100644 --- a/surfsense_backend/app/agents/new_chat/mention_resolver.py +++ b/surfsense_backend/app/agents/chat/runtime/mention_resolver.py @@ -36,7 +36,7 @@ from dataclasses import dataclass, field from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.path_resolver import ( +from app.agents.chat.runtime.path_resolver import ( DOCUMENTS_ROOT, build_path_index, doc_to_virtual_path, diff --git a/surfsense_backend/app/agents/new_chat/path_resolver.py b/surfsense_backend/app/agents/chat/runtime/path_resolver.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/path_resolver.py rename to surfsense_backend/app/agents/chat/runtime/path_resolver.py diff --git a/surfsense_backend/app/agents/chat/runtime/prompt_caching.py b/surfsense_backend/app/agents/chat/runtime/prompt_caching.py new file mode 100644 index 000000000..5a5fd7418 --- /dev/null +++ b/surfsense_backend/app/agents/chat/runtime/prompt_caching.py @@ -0,0 +1,162 @@ +r"""LiteLLM-native prompt caching for SurfSense agents. + +Replaces the legacy ``AnthropicPromptCachingMiddleware`` (its +``isinstance(model, ChatAnthropic)`` gate never matched our LiteLLM stack) +with LiteLLM's universal ``cache_control_injection_points`` mechanism, which +covers the Anthropic/Bedrock/Vertex/Gemini/OpenRouter/etc. marker-based +providers and the auto-caching OpenAI family. + +Two breakpoints per request: + +- ``index: 0`` pins the head-of-request system prompt. 
We use ``index: 0``, + NOT ``role: system``: ``before_agent`` injectors accumulate many + SystemMessages, and tagging all of them overflows Anthropic's 4-block cap + (upstream 400 via OpenRouter). +- ``index: -1`` pins the latest message so longest-prefix lookup compounds + multi-turn savings. + +OpenAI-family configs also get ``prompt_cache_key`` (per-thread routing hint) +and ``prompt_cache_retention="24h"``. Azure is excluded from the latter +because LiteLLM's Azure transformer drops it (see +``_PROMPT_CACHE_RETENTION_PROVIDERS``). + +Safety net: ``litellm.drop_params=True`` (set in ``app.services.llm_service``) +strips any kwarg the destination provider rejects, so an auto-mode fallback +can't 400 on these extras. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from langchain_core.language_models import BaseChatModel + +if TYPE_CHECKING: + from app.agents.chat.runtime.llm_config import AgentConfig + +logger = logging.getLogger(__name__) + + +# Head-of-request + latest message (see module docstring for the index:0 vs +# role:system rationale and Anthropic's 4-block cap). +_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = ( + {"location": "message", "index": 0}, + {"location": "message", "index": -1}, +) + +# Providers that accept the OpenAI ``prompt_cache_key`` routing hint. Strict +# whitelist: many providers route through litellm's ``openai`` prefix without +# the prompt-cache surface, so the prefix alone isn't enough to infer family. +_PROMPT_CACHE_KEY_PROVIDERS: frozenset[str] = frozenset( + {"OPENAI", "DEEPSEEK", "XAI", "AZURE", "AZURE_OPENAI"} +) + +# Subset that also accepts ``prompt_cache_retention="24h"``. Azure is excluded +# because LiteLLM's Azure transformer omits the param (drop_params strips it). 
+_PROMPT_CACHE_RETENTION_PROVIDERS: frozenset[str] = frozenset( + {"OPENAI", "DEEPSEEK", "XAI"} +) + + +def _is_router_llm(llm: BaseChatModel) -> bool: + """Detect ``ChatLiteLLMRouter`` by class name to avoid an import cycle.""" + return type(llm).__name__ == "ChatLiteLLMRouter" + + +def _provider_supports_prompt_cache_key(agent_config: AgentConfig | None) -> bool: + """Whether the config targets a provider that accepts ``prompt_cache_key``. + + Strict — only returns True for explicitly chosen OPENAI, DEEPSEEK, + XAI, AZURE, or AZURE_OPENAI providers. Auto-mode and custom + providers return False because we can't statically know the + destination and the router fans out across mixed providers. + """ + if agent_config is None or not agent_config.provider: + return False + if agent_config.is_auto_mode: + return False + if agent_config.custom_provider: + return False + return agent_config.provider.upper() in _PROMPT_CACHE_KEY_PROVIDERS + + +def _provider_supports_prompt_cache_retention( + agent_config: AgentConfig | None, +) -> bool: + """Whether the config targets a provider that accepts ``prompt_cache_retention``. + + Tighter than :func:`_provider_supports_prompt_cache_key` — Azure + deployments are excluded until LiteLLM ships the param in its Azure + transformer (see module docstring). + """ + if agent_config is None or not agent_config.provider: + return False + if agent_config.is_auto_mode: + return False + if agent_config.custom_provider: + return False + return agent_config.provider.upper() in _PROMPT_CACHE_RETENTION_PROVIDERS + + +def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None: + """Return ``llm.model_kwargs`` as a writable dict, or ``None`` to bail. + + Initialises the field to ``{}`` when present-but-None on a Pydantic v2 + model. Returns ``None`` if the LLM type doesn't expose a writable + ``model_kwargs`` attribute (caller should treat as no-op). 
+ """ + model_kwargs = getattr(llm, "model_kwargs", None) + if isinstance(model_kwargs, dict): + return model_kwargs + try: + llm.model_kwargs = {} # type: ignore[attr-defined] + except Exception: + return None + refreshed = getattr(llm, "model_kwargs", None) + return refreshed if isinstance(refreshed, dict) else None + + +def apply_litellm_prompt_caching( + llm: BaseChatModel, + *, + agent_config: AgentConfig | None = None, + thread_id: int | None = None, +) -> None: + """Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter. + + Idempotent (existing ``model_kwargs`` values are preserved) and mutates + ``llm.model_kwargs`` in place. Without ``agent_config`` (or in auto-mode) + only the universal injection points are set; ``thread_id`` adds a per-thread + ``prompt_cache_key`` for OpenAI-family providers to improve routing affinity. + """ + model_kwargs = _get_or_init_model_kwargs(llm) + if model_kwargs is None: + logger.debug( + "apply_litellm_prompt_caching: %s exposes no writable model_kwargs; skipping", + type(llm).__name__, + ) + return + + if "cache_control_injection_points" not in model_kwargs: + model_kwargs["cache_control_injection_points"] = [ + dict(point) for point in _DEFAULT_INJECTION_POINTS + ] + + # OpenAI-style extras only when the destination is statically known. The + # auto-mode router fans out across mixed providers, so skip them there. 
+ if _is_router_llm(llm): + return + + if ( + thread_id is not None + and "prompt_cache_key" not in model_kwargs + and _provider_supports_prompt_cache_key(agent_config) + ): + model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}" + + if ( + "prompt_cache_retention" not in model_kwargs + and _provider_supports_prompt_cache_retention(agent_config) + ): + model_kwargs["prompt_cache_retention"] = "24h" diff --git a/surfsense_backend/app/agents/shared/__init__.py b/surfsense_backend/app/agents/chat/shared/__init__.py similarity index 79% rename from surfsense_backend/app/agents/shared/__init__.py rename to surfsense_backend/app/agents/chat/shared/__init__.py index 7c46c65ff..e84bc7543 100644 --- a/surfsense_backend/app/agents/shared/__init__.py +++ b/surfsense_backend/app/agents/chat/shared/__init__.py @@ -2,7 +2,7 @@ Symbols here are intentionally framework-light (no LangGraph / deepagents internals) so they can be imported from both ``app.agents.new_chat`` and -``app.agents.multi_agent_chat`` without creating a circular dependency +``app.agents.chat.multi_agent_chat`` without creating a circular dependency between the two packages. See ``receipt.py`` for the rationale. """ diff --git a/surfsense_backend/app/agents/new_chat/context.py b/surfsense_backend/app/agents/chat/shared/context.py similarity index 97% rename from surfsense_backend/app/agents/new_chat/context.py rename to surfsense_backend/app/agents/chat/shared/context.py index 1b3ea3d20..50b761f5b 100644 --- a/surfsense_backend/app/agents/new_chat/context.py +++ b/surfsense_backend/app/agents/chat/shared/context.py @@ -50,8 +50,8 @@ class SurfSenseContextSchema: (cloud filesystem mode). Surfaced as ``[USER-MENTIONED]`` entries in ```` so the agent prioritises walking those folders with ``ls`` / ``find_documents``. - file_operation_contract: One-shot file operation contract emitted - by ``FileIntentMiddleware`` for the upcoming turn. 
+ file_operation_contract: One-shot file operation contract for the + upcoming turn (reserved; not currently populated). turn_id / request_id: Correlation IDs surfaced by the streaming task; populated for telemetry. diff --git a/surfsense_backend/app/agents/chat/shared/middleware/__init__.py b/surfsense_backend/app/agents/chat/shared/middleware/__init__.py new file mode 100644 index 000000000..90339137b --- /dev/null +++ b/surfsense_backend/app/agents/chat/shared/middleware/__init__.py @@ -0,0 +1,13 @@ +"""Shared middleware components for the SurfSense chat agents.""" + +from app.agents.chat.shared.middleware.compaction import ( + SurfSenseCompactionMiddleware, + create_surfsense_compaction_middleware, +) +from app.agents.chat.shared.middleware.retry_after import RetryAfterMiddleware + +__all__ = [ + "RetryAfterMiddleware", + "SurfSenseCompactionMiddleware", + "create_surfsense_compaction_middleware", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/compaction.py b/surfsense_backend/app/agents/chat/shared/middleware/compaction.py similarity index 70% rename from surfsense_backend/app/agents/new_chat/middleware/compaction.py rename to surfsense_backend/app/agents/chat/shared/middleware/compaction.py index f8d340e5d..f91af6a70 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/compaction.py +++ b/surfsense_backend/app/agents/chat/shared/middleware/compaction.py @@ -1,26 +1,13 @@ -""" -SurfSense compaction middleware. +"""SurfSense compaction middleware. -Subclasses :class:`deepagents.middleware.summarization.SummarizationMiddleware` -to add SurfSense-specific behavior: +Extends ``SummarizationMiddleware`` with three SurfSense behaviors: -1. **Structured summary template** (OpenCode-style ``## Goal / Constraints / - Progress / Key Decisions / Next Steps / Critical Context / Relevant Files``) - — see :data:`SURFSENSE_SUMMARY_PROMPT` below. 
The base - ``SummarizationMiddleware`` only ships a freeform "summarize this" - prompt; the structured template is ported from OpenCode's - ``compaction.ts``. -2. **Protect SurfSense-specific SystemMessages** so injected hints - (````, ````, ````, - ````, ````, ````, ````) - are *not* summarized away and are kept verbatim in the post-summary - message list. Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy - (some message types are part of the agent's contract and must survive - compaction unchanged). -3. **Sanitize ``content=None``** when feeding messages into ``get_buffer_string`` - (Azure OpenAI / LiteLLM defense — when a provider streams an AIMessage - containing only tool_calls and no text, ``content`` can be ``None`` and - ``get_buffer_string`` crashes iterating over ``None``). SurfSense-specific. +1. A structured summary template (:data:`SURFSENSE_SUMMARY_PROMPT`) instead of + the base freeform prompt. +2. Protected SystemMessages (injected hints like ````) are + kept verbatim instead of being summarized away. +3. ``content=None`` is sanitized before ``get_buffer_string`` (some providers + stream tool-only AIMessages with ``None`` content, which would crash it). """ from __future__ import annotations @@ -43,9 +30,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -# Structured summary template ported from OpenCode's -# ``opencode/packages/opencode/src/session/compaction.ts:40-75``. Kept as a -# module-level constant so unit tests can assert on its sections. +# Module-level constant so unit tests can assert on its sections. SURFSENSE_SUMMARY_PROMPT = """ SurfSense Conversation Compaction Assistant @@ -94,7 +79,7 @@ Respond ONLY with the structured summary. Do not include any text before or afte PROTECTED_SYSTEM_PREFIXES: tuple[str, ...] 
= ( "", # KnowledgePriorityMiddleware "", # KnowledgeTreeMiddleware - "", # FileIntentMiddleware + "", # reserved file-operation contract prefix "", # MemoryInjectionMiddleware "", # MemoryInjectionMiddleware "", # MemoryInjectionMiddleware @@ -114,13 +99,10 @@ def _is_protected_system_message(msg: AnyMessage) -> bool: def _sanitize_message_content(msg: AnyMessage) -> AnyMessage: - """Return ``msg`` with ``content=None`` coerced to ``""``. + """Return a copy of ``msg`` with ``content=None`` coerced to ``""``. - Folds in the historical defense from ``safe_summarization.py`` — - ``get_buffer_string`` reads ``m.text`` which iterates ``self.content``, - so a ``None`` content (Azure OpenAI / LiteLLM streaming a tool-only - AIMessage) explodes. We return a copy with empty string content so - downstream consumers see an empty body without mutating the original. + ``get_buffer_string`` reads ``m.text`` (iterating ``content``), so a + tool-only AIMessage with ``None`` content would crash it. """ if getattr(msg, "content", "not-missing") is not None: return msg @@ -159,20 +141,11 @@ class SurfSenseCompactionMiddleware(SummarizationMiddleware): conversation_messages: list[AnyMessage], cutoff_index: int, ) -> tuple[list[AnyMessage], list[AnyMessage]]: - """Split messages but always preserve SurfSense protected SystemMessages. + """Split messages, always preserving protected SystemMessages. - Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy - (``opencode/packages/opencode/src/session/compaction.ts``): some - message types are always kept verbatim because they are part of the - agent's working contract, not transient output. - - Also opens a ``compaction.run`` OTel span (no-op when OTel is off) - so dashboards can count compaction events and message-volume - without having to instrument upstream callers. + Also opens a ``compaction.run`` OTel span (no-op when OTel is off) here, + since partitioning is the first call once summarization is decided. 
""" - # Opening a span here is appropriate because partitioning is the - # first call SummarizationMiddleware makes when it has decided to - # summarize; we record the volume and then close as a normal span. with ot.compaction_span( reason="auto", messages_in=len(conversation_messages), @@ -191,20 +164,15 @@ class SurfSenseCompactionMiddleware(SummarizationMiddleware): else: kept_for_summary.append(msg) - # Place protected blocks at the *front* of preserved_messages so - # they keep their original ordering relative to the summary - # HumanMessage that precedes the rest of the preserved tail. + # Protected blocks go at the front of preserved_messages to keep + # ordering relative to the summary HumanMessage. return kept_for_summary, [*protected, *preserved_messages] def _filter_summary_messages( # type: ignore[override] self, messages: list[AnyMessage] ) -> list[AnyMessage]: - """Filter previous summaries AND sanitize ``content=None``. - - Folds the ``safe_summarization.py`` defense in: when the buffer - builder iterates ``m.text`` over ``None`` it explodes; sanitizing - here covers both the sync and async offload paths. - """ + """Filter previous summaries and sanitize ``content=None`` (covers the + sync and async offload paths).""" filtered = super()._filter_summary_messages(messages) return [_sanitize_message_content(m) for m in filtered] diff --git a/surfsense_backend/app/agents/new_chat/middleware/retry_after.py b/surfsense_backend/app/agents/chat/shared/middleware/retry_after.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/middleware/retry_after.py rename to surfsense_backend/app/agents/chat/shared/middleware/retry_after.py diff --git a/surfsense_backend/app/agents/chat/shared/tools/__init__.py b/surfsense_backend/app/agents/chat/shared/tools/__init__.py new file mode 100644 index 000000000..342fe9169 --- /dev/null +++ b/surfsense_backend/app/agents/chat/shared/tools/__init__.py @@ -0,0 +1,5 @@ +"""Cross-agent shared tools. 
+ +Only genuinely cross-agent tool code lives here (currently web_search, imported +directly from its module). +""" diff --git a/surfsense_backend/app/agents/new_chat/tools/web_search.py b/surfsense_backend/app/agents/chat/shared/tools/web_search.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/tools/web_search.py rename to surfsense_backend/app/agents/chat/shared/tools/web_search.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/provider_hints.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/provider_hints.py deleted file mode 100644 index 78d7b08ec..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/provider_hints.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Provider-specific style hints from ``prompts/providers/`` (main agent only).""" - -from __future__ import annotations - -import re - -from .load_md import read_prompt_md - -ProviderVariant = str - -_OPENAI_CODEX_RE = re.compile( - r"\b(gpt-codex|codex-mini|gpt-[\d.]+-codex)\b", re.IGNORECASE -) -_OPENAI_REASONING_RE = re.compile(r"\b(gpt-5|o\d|o-)", re.IGNORECASE) -_OPENAI_CLASSIC_RE = re.compile(r"\bgpt-4", re.IGNORECASE) -_ANTHROPIC_RE = re.compile(r"\bclaude\b", re.IGNORECASE) -_GOOGLE_RE = re.compile(r"\bgemini\b", re.IGNORECASE) -_KIMI_RE = re.compile(r"\b(kimi[-\d.]*|moonshot)\b", re.IGNORECASE) -_GROK_RE = re.compile(r"\bgrok\b", re.IGNORECASE) -_DEEPSEEK_RE = re.compile(r"\bdeepseek\b", re.IGNORECASE) - - -def detect_provider_variant(model_name: str | None) -> ProviderVariant: - if not model_name: - return "default" - name = model_name.strip() - if _OPENAI_CODEX_RE.search(name): - return "openai_codex" - if _OPENAI_REASONING_RE.search(name): - return "openai_reasoning" - if _OPENAI_CLASSIC_RE.search(name): - return "openai_classic" - if _ANTHROPIC_RE.search(name): - return "anthropic" - if _GOOGLE_RE.search(name): - return "google" - if _KIMI_RE.search(name): - 
return "kimi" - if _GROK_RE.search(name): - return "grok" - if _DEEPSEEK_RE.search(name): - return "deepseek" - return "default" - - -def build_provider_hint_block(provider_variant: ProviderVariant) -> str: - if not provider_variant or provider_variant == "default": - return "" - text = read_prompt_md(f"providers/{provider_variant}.md") - return f"\n{text}\n" if text else "" diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/sections/provider.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/sections/provider.py deleted file mode 100644 index 7de722080..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/sections/provider.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Provider-specific style hints.""" - -from __future__ import annotations - -from ..provider_hints import build_provider_hint_block, detect_provider_variant - - -def build_provider_section(*, model_name: str | None) -> str: - return build_provider_hint_block(detect_provider_variant(model_name)) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/dedup_hitl.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/dedup_hitl.py deleted file mode 100644 index 66cae300b..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/dedup_hitl.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Drop duplicate HITL tool calls before execution.""" - -from __future__ import annotations - -from collections.abc import Sequence - -from langchain_core.tools import BaseTool - -from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware - - -def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware: - return DedupHITLToolCallsMiddleware(agent_tools=list(tools)) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/__init__.py 
b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/__init__.py deleted file mode 100644 index c25c2b281..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Pattern-based allow/deny/ask middleware with HITL fallback (vertical slice). - -Public surface (one entry point only — every other symbol is an internal of -the rule engine and stays inside ``middleware/``, ``ask/``, or ``deny.py``): - -- :func:`build_permission_mw` — construction recipe shared by every stack. -""" - -from .middleware.factory import build_permission_mw - -__all__ = ["build_permission_mw"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/__init__.py deleted file mode 100644 index 13d4c06cb..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from app.agents.new_chat.tools.google_calendar.create_event import ( - create_create_calendar_event_tool, -) -from app.agents.new_chat.tools.google_calendar.delete_event import ( - create_delete_calendar_event_tool, -) -from app.agents.new_chat.tools.google_calendar.search_events import ( - create_search_calendar_events_tool, -) -from app.agents.new_chat.tools.google_calendar.update_event import ( - create_update_calendar_event_tool, -) - -__all__ = [ - "create_create_calendar_event_tool", - "create_delete_calendar_event_tool", - "create_search_calendar_events_tool", - "create_update_calendar_event_tool", -] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/__init__.py deleted file mode 100644 index b4eaec1f0..000000000 --- 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from app.agents.new_chat.tools.discord.list_channels import ( - create_list_discord_channels_tool, -) -from app.agents.new_chat.tools.discord.read_messages import ( - create_read_discord_messages_tool, -) -from app.agents.new_chat.tools.discord.send_message import ( - create_send_discord_message_tool, -) - -__all__ = [ - "create_list_discord_channels_tool", - "create_read_discord_messages_tool", - "create_send_discord_message_tool", -] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/__init__.py deleted file mode 100644 index 836b9ee41..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from app.agents.new_chat.tools.dropbox.create_file import ( - create_create_dropbox_file_tool, -) -from app.agents.new_chat.tools.dropbox.trash_file import ( - create_delete_dropbox_file_tool, -) - -__all__ = [ - "create_create_dropbox_file_tool", - "create_delete_dropbox_file_tool", -] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/__init__.py deleted file mode 100644 index 294840122..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -from app.agents.new_chat.tools.gmail.create_draft import ( - create_create_gmail_draft_tool, -) -from app.agents.new_chat.tools.gmail.read_email import ( - create_read_gmail_email_tool, -) -from app.agents.new_chat.tools.gmail.search_emails import ( - create_search_gmail_tool, -) -from app.agents.new_chat.tools.gmail.send_email import ( - create_send_gmail_email_tool, -) -from 
app.agents.new_chat.tools.gmail.trash_email import ( - create_trash_gmail_email_tool, -) -from app.agents.new_chat.tools.gmail.update_draft import ( - create_update_gmail_draft_tool, -) - -__all__ = [ - "create_create_gmail_draft_tool", - "create_read_gmail_email_tool", - "create_search_gmail_tool", - "create_send_gmail_email_tool", - "create_trash_gmail_email_tool", - "create_update_gmail_draft_tool", -] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/__init__.py deleted file mode 100644 index 9c63bceb1..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from app.agents.new_chat.tools.google_drive.create_file import ( - create_create_google_drive_file_tool, -) -from app.agents.new_chat.tools.google_drive.trash_file import ( - create_delete_google_drive_file_tool, -) - -__all__ = [ - "create_create_google_drive_file_tool", - "create_delete_google_drive_file_tool", -] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/__init__.py deleted file mode 100644 index 255119bee..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from app.agents.new_chat.tools.luma.create_event import ( - create_create_luma_event_tool, -) -from app.agents.new_chat.tools.luma.list_events import ( - create_list_luma_events_tool, -) -from app.agents.new_chat.tools.luma.read_event import ( - create_read_luma_event_tool, -) - -__all__ = [ - "create_create_luma_event_tool", - "create_list_luma_events_tool", - "create_read_luma_event_tool", -] diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/__init__.py deleted file mode 100644 index 8edb4857e..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from app.agents.new_chat.tools.onedrive.create_file import ( - create_create_onedrive_file_tool, -) -from app.agents.new_chat.tools.onedrive.trash_file import ( - create_delete_onedrive_file_tool, -) - -__all__ = [ - "create_create_onedrive_file_tool", - "create_delete_onedrive_file_tool", -] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/__init__.py deleted file mode 100644 index 60e2add49..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from app.agents.new_chat.tools.teams.list_channels import ( - create_list_teams_channels_tool, -) -from app.agents.new_chat.tools.teams.read_messages import ( - create_read_teams_messages_tool, -) -from app.agents.new_chat.tools.teams.send_message import ( - create_send_teams_message_tool, -) - -__all__ = [ - "create_list_teams_channels_tool", - "create_read_teams_messages_tool", - "create_send_teams_message_tool", -] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/__init__.py deleted file mode 100644 index 70d3dfe39..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Cross-slice helpers for route subagents.""" - -from __future__ import annotations - -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( - read_md_file, -) -from 
app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) - -__all__ = [ - "SurfSenseSubagentSpec", - "pack_subagent", - "read_md_file", -] diff --git a/surfsense_backend/app/agents/new_chat/__init__.py b/surfsense_backend/app/agents/new_chat/__init__.py deleted file mode 100644 index 4b2eb89eb..000000000 --- a/surfsense_backend/app/agents/new_chat/__init__.py +++ /dev/null @@ -1,91 +0,0 @@ -""" -SurfSense New Chat Agent Module. - -This module provides the SurfSense deep agent with configurable tools, -middleware, and preloaded knowledge-base filesystem behavior. - -Directory Structure: -- tools/: All agent tools (podcast, generate_image, web, memory, etc.) -- middleware/: Custom middleware (knowledge search, filesystem, dedup, etc.) -- chat_deepagent.py: Main agent factory -- system_prompt.py: System prompts and instructions -- context.py: Context schema for the agent -- checkpointer.py: LangGraph checkpointer setup -- llm_config.py: LLM configuration utilities -- utils.py: Shared utilities -""" - -# Agent factory -from .chat_deepagent import create_surfsense_deep_agent - -# Context -from .context import SurfSenseContextSchema - -# LLM config -from .llm_config import ( - create_chat_litellm_from_config, - load_global_llm_config_by_id, - load_llm_config_from_yaml, -) - -# Middleware -from .middleware import ( - DedupHITLToolCallsMiddleware, - KnowledgeBaseSearchMiddleware, - SurfSenseFilesystemMiddleware, -) - -# System prompt -from .system_prompt import ( - SURFSENSE_CITATION_INSTRUCTIONS, - SURFSENSE_SYSTEM_PROMPT, - build_surfsense_system_prompt, -) - -# Tools - registry exports -# Tools - factory exports (for direct use) -# Tools - knowledge base utilities -from .tools import ( - BUILTIN_TOOLS, - ToolDefinition, - build_tools, - create_generate_podcast_tool, - create_scrape_webpage_tool, - format_documents_for_context, - get_all_tool_names, - 
get_default_enabled_tools, - get_tool_by_name, - search_knowledge_base_async, -) - -__all__ = [ - # Tools registry - "BUILTIN_TOOLS", - # System prompt - "SURFSENSE_CITATION_INSTRUCTIONS", - "SURFSENSE_SYSTEM_PROMPT", - # Middleware - "DedupHITLToolCallsMiddleware", - "KnowledgeBaseSearchMiddleware", - # Context - "SurfSenseContextSchema", - "SurfSenseFilesystemMiddleware", - "ToolDefinition", - "build_surfsense_system_prompt", - "build_tools", - # LLM config - "create_chat_litellm_from_config", - # Tool factories - "create_generate_podcast_tool", - "create_scrape_webpage_tool", - # Agent factory - "create_surfsense_deep_agent", - # Knowledge base utilities - "format_documents_for_context", - "get_all_tool_names", - "get_default_enabled_tools", - "get_tool_by_name", - "load_global_llm_config_by_id", - "load_llm_config_from_yaml", - "search_knowledge_base_async", -] diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py deleted file mode 100644 index f8db333ba..000000000 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ /dev/null @@ -1,1166 +0,0 @@ -""" -SurfSense deep agent implementation. - -This module provides the factory function for creating SurfSense deep agents -with configurable tools via the tools registry and configurable prompts -via NewLLMConfig. - -We use ``create_agent`` (from langchain) rather than ``create_deep_agent`` -(from deepagents) so that the middleware stack is fully under our control. -This lets us swap in ``SurfSenseFilesystemMiddleware`` — a customisable -subclass of the default ``FilesystemMiddleware`` — while preserving every -other behaviour that ``create_deep_agent`` provides (todo-list, subagents, -summarisation, etc.). Prompt caching is configured at LLM-build time via -``apply_litellm_prompt_caching`` (LiteLLM-native, multi-provider) rather -than as a middleware. 
-""" - -import asyncio -import logging -import time -from collections.abc import Sequence -from typing import Any - -from deepagents import SubAgent, SubAgentMiddleware, __version__ as deepagents_version -from deepagents.backends import StateBackend -from deepagents.graph import BASE_AGENT_PROMPT -from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware -from deepagents.middleware.skills import SkillsMiddleware -from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT -from langchain.agents import create_agent -from langchain.agents.middleware import ( - LLMToolSelectorMiddleware, - ModelCallLimitMiddleware, - TodoListMiddleware, - ToolCallLimitMiddleware, -) -from langchain_core.language_models import BaseChatModel -from langchain_core.tools import BaseTool -from langgraph.types import Checkpointer -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.agent_cache import ( - flags_signature, - get_cache, - stable_hash, - system_prompt_hash, - tools_signature, -) -from app.agents.new_chat.context import SurfSenseContextSchema -from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags -from app.agents.new_chat.filesystem_backends import build_backend_resolver -from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection -from app.agents.new_chat.llm_config import AgentConfig -from app.agents.new_chat.middleware import ( - ActionLogMiddleware, - AnonymousDocumentMiddleware, - BusyMutexMiddleware, - ClearToolUsesEdit, - DedupHITLToolCallsMiddleware, - DoomLoopMiddleware, - FileIntentMiddleware, - FlattenSystemMessageMiddleware, - KnowledgeBasePersistenceMiddleware, - KnowledgePriorityMiddleware, - KnowledgeTreeMiddleware, - MemoryInjectionMiddleware, - NoopInjectionMiddleware, - OtelSpanMiddleware, - PermissionMiddleware, - RetryAfterMiddleware, - SpillingContextEditingMiddleware, - SpillToBackendEdit, - SurfSenseFilesystemMiddleware, - ToolCallNameRepairMiddleware, - 
build_skills_backend_factory, - create_surfsense_compaction_middleware, - default_skills_sources, -) -from app.agents.new_chat.middleware.scoped_model_fallback import ( - ScopedModelFallbackMiddleware, -) -from app.agents.new_chat.permissions import Rule, Ruleset -from app.agents.new_chat.plugin_loader import ( - PluginContext, - load_allowed_plugin_names_from_env, - load_plugin_middlewares, -) -from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching -from app.agents.new_chat.subagents import build_specialized_subagents -from app.agents.new_chat.system_prompt import ( - build_configurable_system_prompt, - build_surfsense_system_prompt, -) -from app.agents.new_chat.tools.invalid_tool import ( - INVALID_TOOL_NAME, - invalid_tool, -) -from app.agents.new_chat.tools.registry import ( - BUILTIN_TOOLS, - build_tools_async, - get_connector_gated_tools, -) -from app.db import ChatVisibility -from app.services.connector_service import ConnectorService -from app.services.llm_service import get_planner_llm -from app.utils.perf import get_perf_logger - -_perf_log = get_perf_logger() - - -def _resolve_prompt_model_name( - agent_config: AgentConfig | None, - llm: BaseChatModel, -) -> str | None: - """Resolve the model id to feed to provider-variant detection. - - Preference order (matches the established idiom in - ``llm_router_service.py`` — see ``params.get("base_model") or - params.get("model", "")`` usages there): - - 1. ``agent_config.litellm_params["base_model"]`` — required for Azure - deployments where ``model_name`` is the deployment slug, not the - underlying family. Without this, a deployment named e.g. - ``"prod-chat-001"`` would silently miss every provider regex. - 2. ``agent_config.model_name`` — the user's configured model id. - 3. ``getattr(llm, "model", None)`` — fallback for direct callers that - don't supply an ``AgentConfig`` (currently a defensive path; all - production callers pass ``agent_config``). 
- - Returns ``None`` when nothing is available; ``compose_system_prompt`` - treats that as the ``"default"`` variant (no provider block emitted). - """ - if agent_config is not None: - params = agent_config.litellm_params or {} - base_model = params.get("base_model") - if isinstance(base_model, str) and base_model.strip(): - return base_model - if agent_config.model_name: - return agent_config.model_name - return getattr(llm, "model", None) - - -# ============================================================================= -# Connector Type Mapping -# ============================================================================= - -# Maps SearchSourceConnectorType enum values to the searchable document/connector types -# used by pre-search middleware and web_search. -# Live search connectors (TAVILY_API, LINKUP_API, BAIDU_SEARCH_API) are routed to -# the web_search tool; all others are considered local/indexed data. -_CONNECTOR_TYPE_TO_SEARCHABLE: dict[str, str] = { - # Live search connectors (handled by web_search tool) - "TAVILY_API": "TAVILY_API", - "LINKUP_API": "LINKUP_API", - "BAIDU_SEARCH_API": "BAIDU_SEARCH_API", - # Local/indexed connectors (handled by KB pre-search middleware) - "SLACK_CONNECTOR": "SLACK_CONNECTOR", - "TEAMS_CONNECTOR": "TEAMS_CONNECTOR", - "NOTION_CONNECTOR": "NOTION_CONNECTOR", - "GITHUB_CONNECTOR": "GITHUB_CONNECTOR", - "LINEAR_CONNECTOR": "LINEAR_CONNECTOR", - "DISCORD_CONNECTOR": "DISCORD_CONNECTOR", - "JIRA_CONNECTOR": "JIRA_CONNECTOR", - "CONFLUENCE_CONNECTOR": "CONFLUENCE_CONNECTOR", - "CLICKUP_CONNECTOR": "CLICKUP_CONNECTOR", - "GOOGLE_CALENDAR_CONNECTOR": "GOOGLE_CALENDAR_CONNECTOR", - "GOOGLE_GMAIL_CONNECTOR": "GOOGLE_GMAIL_CONNECTOR", - "GOOGLE_DRIVE_CONNECTOR": "GOOGLE_DRIVE_FILE", # Connector type differs from document type - "AIRTABLE_CONNECTOR": "AIRTABLE_CONNECTOR", - "LUMA_CONNECTOR": "LUMA_CONNECTOR", - "ELASTICSEARCH_CONNECTOR": "ELASTICSEARCH_CONNECTOR", - "WEBCRAWLER_CONNECTOR": "CRAWLED_URL", # Maps to document type 
- "BOOKSTACK_CONNECTOR": "BOOKSTACK_CONNECTOR", - "CIRCLEBACK_CONNECTOR": "CIRCLEBACK", # Connector type differs from document type - "OBSIDIAN_CONNECTOR": "OBSIDIAN_CONNECTOR", - "DROPBOX_CONNECTOR": "DROPBOX_FILE", # Connector type differs from document type - "ONEDRIVE_CONNECTOR": "ONEDRIVE_FILE", # Connector type differs from document type - # Composio connectors (unified to native document types). - # Reverse of NATIVE_TO_LEGACY_DOCTYPE in app.db. - "COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "GOOGLE_DRIVE_FILE", - "COMPOSIO_GMAIL_CONNECTOR": "GOOGLE_GMAIL_CONNECTOR", - "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "GOOGLE_CALENDAR_CONNECTOR", -} - -# Document types that don't come from SearchSourceConnector but should always be searchable -_ALWAYS_AVAILABLE_DOC_TYPES: list[str] = [ - "EXTENSION", # Browser extension data - "FILE", # Uploaded files - "NOTE", # User notes - "YOUTUBE_VIDEO", # YouTube videos -] - - -def _map_connectors_to_searchable_types( - connector_types: list[Any], -) -> list[str]: - """ - Map SearchSourceConnectorType enums to searchable document/connector types. - - This function: - 1. Converts connector type enums to their searchable counterparts - 2. Includes always-available document types (EXTENSION, FILE, NOTE, YOUTUBE_VIDEO) - 3. 
Deduplicates while preserving order - - Args: - connector_types: List of SearchSourceConnectorType enum values - - Returns: - List of searchable connector/document type strings - """ - result_set: set[str] = set() - result_list: list[str] = [] - - # Add always-available document types first - for doc_type in _ALWAYS_AVAILABLE_DOC_TYPES: - if doc_type not in result_set: - result_set.add(doc_type) - result_list.append(doc_type) - - # Map each connector type to its searchable equivalent - for ct in connector_types: - # Handle both enum and string types - ct_str = ct.value if hasattr(ct, "value") else str(ct) - searchable = _CONNECTOR_TYPE_TO_SEARCHABLE.get(ct_str) - if searchable and searchable not in result_set: - result_set.add(searchable) - result_list.append(searchable) - - return result_list - - -# ============================================================================= -# Deep Agent Factory -# ============================================================================= - - -async def create_surfsense_deep_agent( - llm: BaseChatModel, - search_space_id: int, - db_session: AsyncSession, - connector_service: ConnectorService, - checkpointer: Checkpointer, - user_id: str | None = None, - thread_id: int | None = None, - agent_config: AgentConfig | None = None, - enabled_tools: list[str] | None = None, - disabled_tools: list[str] | None = None, - additional_tools: Sequence[BaseTool] | None = None, - firecrawl_api_key: str | None = None, - thread_visibility: ChatVisibility | None = None, - mentioned_document_ids: list[int] | None = None, - anon_session_id: str | None = None, - filesystem_selection: FilesystemSelection | None = None, -): - """ - Create a SurfSense deep agent with configurable tools and prompts. 
- - The agent comes with built-in tools that can be configured: - - generate_podcast: Generate audio podcasts from content - - generate_image: Generate images from text descriptions using AI models - - scrape_webpage: Extract content from webpages - - update_memory: Update the user's personal or team memory document - - The agent also includes TodoListMiddleware by default (via create_deep_agent) which provides: - - write_todos: Create and update planning/todo lists for complex tasks - - The system prompt can be configured via agent_config: - - Custom system instructions (or use defaults) - - Citation toggle (enable/disable citation requirements) - - Args: - llm: ChatLiteLLM instance for the agent's language model - search_space_id: The user's search space ID - db_session: Database session for tools that need DB access - connector_service: Initialized connector service for knowledge base search - checkpointer: LangGraph checkpointer for conversation state persistence. - Use AsyncPostgresSaver for production or MemorySaver for testing. - user_id: The current user's UUID string (required for memory tools) - agent_config: Optional AgentConfig from NewLLMConfig for prompt configuration. - If None, uses default system prompt with citations enabled. - enabled_tools: Explicit list of tool names to enable. If None, all default tools - are enabled. Use this to limit which tools are available. - disabled_tools: List of tool names to disable. Applied after enabled_tools. - Use this to exclude specific tools from the defaults. - additional_tools: Extra custom tools to add beyond the built-in ones. - These are always added regardless of enabled/disabled settings. - firecrawl_api_key: Optional Firecrawl API key for premium web scraping. - Falls back to Chromium/Trafilatura if not provided. 
- - Returns: - CompiledStateGraph: The configured deep agent - - Examples: - # Create agent with all default tools and default prompt - agent = create_surfsense_deep_agent(llm, search_space_id, db_session, ...) - - # Create agent with custom prompt configuration - agent = create_surfsense_deep_agent( - llm, search_space_id, db_session, ..., - agent_config=AgentConfig( - provider="OPENAI", - model_name="gpt-4", - api_key="...", - system_instructions="Custom instructions...", - citations_enabled=False, - ) - ) - - # Create agent with only specific tools - agent = create_surfsense_deep_agent( - llm, search_space_id, db_session, ..., - enabled_tools=["scrape_webpage"] - ) - - # Create agent without podcast generation - agent = create_surfsense_deep_agent( - llm, search_space_id, db_session, ..., - disabled_tools=["generate_podcast"] - ) - - # Add custom tools - agent = create_surfsense_deep_agent( - llm, search_space_id, db_session, ..., - additional_tools=[my_custom_tool] - ) - """ - _t_agent_total = time.perf_counter() - - # Layer thread-aware prompt caching onto the LLM. Idempotent with the - # build-time call in ``llm_config.py``; this run merely adds - # ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` for OpenAI-family - # configs now that ``thread_id`` is known. No-op when ``thread_id`` is - # None or the provider is non-OpenAI-family. - apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id) - - filesystem_selection = filesystem_selection or FilesystemSelection() - backend_resolver = build_backend_resolver( - filesystem_selection, - search_space_id=search_space_id - if filesystem_selection.mode == FilesystemMode.CLOUD - else None, - ) - - # Discover available connectors and document types for this search space. - # - # NOTE: These two calls cannot be parallelized via ``asyncio.gather``. 
- # ``ConnectorService`` shares a single ``AsyncSession`` (``self.session``); - # SQLAlchemy explicitly forbids concurrent operations on the same session - # ("This session is provisioning a new connection; concurrent operations - # are not permitted on the same session"). The Phase 1.4 in-process TTL - # cache in ``connector_service`` already collapses the warm path to a - # near-zero pair of dict lookups, so sequential awaits cost nothing in - # the common case while remaining correct on cold cache misses. - available_connectors: list[str] | None = None - available_document_types: list[str] | None = None - - _t0 = time.perf_counter() - try: - try: - connector_types_result = await connector_service.get_available_connectors( - search_space_id - ) - if connector_types_result: - available_connectors = _map_connectors_to_searchable_types( - connector_types_result - ) - except Exception as e: - logging.warning("Failed to discover available connectors: %s", e) - - try: - available_document_types = ( - await connector_service.get_available_document_types(search_space_id) - ) - except Exception as e: - logging.warning("Failed to discover available document types: %s", e) - except Exception as e: # pragma: no cover - defensive outer guard - logging.warning(f"Failed to discover available connectors/document types: {e}") - _perf_log.info( - "[create_agent] Connector/doc-type discovery in %.3fs", - time.perf_counter() - _t0, - ) - - # Build dependencies dict for the tools registry - visibility = thread_visibility or ChatVisibility.PRIVATE - - # Extract the model's context window so tools can size their output. 
- _model_profile = getattr(llm, "profile", None) - _max_input_tokens: int | None = ( - _model_profile.get("max_input_tokens") - if isinstance(_model_profile, dict) - else None - ) - - dependencies = { - "search_space_id": search_space_id, - "db_session": db_session, - "connector_service": connector_service, - "firecrawl_api_key": firecrawl_api_key, - "user_id": user_id, - "thread_id": thread_id, - "thread_visibility": visibility, - "available_connectors": available_connectors, - "available_document_types": available_document_types, - "max_input_tokens": _max_input_tokens, - "llm": llm, - } - - modified_disabled_tools = list(disabled_tools) if disabled_tools else [] - modified_disabled_tools.extend(get_connector_gated_tools(available_connectors)) - - # Remove direct KB search tool; KnowledgePriorityMiddleware now runs hybrid - # search per turn and surfaces hits as a hint plus - # `` markers inside lazy-loaded XML. - if "search_knowledge_base" not in modified_disabled_tools: - modified_disabled_tools.append("search_knowledge_base") - - # Build tools using the async registry (includes MCP tools) - _t0 = time.perf_counter() - tools = await build_tools_async( - dependencies=dependencies, - enabled_tools=enabled_tools, - disabled_tools=modified_disabled_tools, - additional_tools=list(additional_tools) if additional_tools else None, - ) - - # Register the ``invalid`` tool only when tool-call repair is on. It - # is dispatched only when :class:`ToolCallNameRepairMiddleware` - # rewrites a malformed call. We intentionally append it AFTER - # ``build_tools_async`` so it never appears in the system-prompt - # tool list (which is built from the registry, not the bound tool - # list). 
- _flags: AgentFeatureFlags = get_flags() - if _flags.enable_tool_call_repair and INVALID_TOOL_NAME not in { - t.name for t in tools - }: - tools = [*list(tools), invalid_tool] - _perf_log.info( - "[create_agent] build_tools_async in %.3fs (%d tools)", - time.perf_counter() - _t0, - len(tools), - ) - - # Build system prompt based on agent_config, scoped to the tools actually enabled - _t0 = time.perf_counter() - _enabled_tool_names = {t.name for t in tools} - _user_disabled_tool_names = set(disabled_tools) if disabled_tools else set() - - # Collect generic MCP connector info so the system prompt can route queries - # to their tools instead of falling back to "not in knowledge base". - _mcp_connector_tools: dict[str, list[str]] = {} - for t in tools: - meta = getattr(t, "metadata", None) or {} - if meta.get("mcp_is_generic") and meta.get("mcp_connector_name"): - _mcp_connector_tools.setdefault( - meta["mcp_connector_name"], - [], - ).append(t.name) - - if _mcp_connector_tools: - _perf_log.info("MCP connector tool routing: %s", _mcp_connector_tools) - - if agent_config is not None: - system_prompt = build_configurable_system_prompt( - custom_system_instructions=agent_config.system_instructions, - use_default_system_instructions=agent_config.use_default_system_instructions, - citations_enabled=agent_config.citations_enabled, - thread_visibility=thread_visibility, - enabled_tool_names=_enabled_tool_names, - disabled_tool_names=_user_disabled_tool_names, - mcp_connector_tools=_mcp_connector_tools, - model_name=_resolve_prompt_model_name(agent_config, llm), - ) - else: - system_prompt = build_surfsense_system_prompt( - thread_visibility=thread_visibility, - enabled_tool_names=_enabled_tool_names, - disabled_tool_names=_user_disabled_tool_names, - mcp_connector_tools=_mcp_connector_tools, - model_name=_resolve_prompt_model_name(agent_config, llm), - ) - _perf_log.info( - "[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0 - ) - - # Combine 
system_prompt with BASE_AGENT_PROMPT (same as create_deep_agent) - final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT - - # The middleware stack — and especially ``SubAgentMiddleware`` — is *not* - # cheap to build. ``SubAgentMiddleware.__init__`` calls ``create_agent`` - # synchronously to compile the general-purpose subagent's full state graph - # (every tool + every middleware → pydantic schemas + langgraph compile). - # On gpt-5.x agents that's roughly 1.5-2s of pure CPU work. If we run it - # directly here it blocks the asyncio event loop for the whole streaming - # task (and any other coroutine sharing this loop), which is why - # "agent creation" wall-clock time used to stretch to ~3-4s. Move the - # entire middleware build + main-graph compile into a single - # ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the - # event loop stays responsive. - # - # PHASE 1: cache the resulting compiled graph. ``agent_cache`` is keyed - # on every per-request value that any middleware in the stack closes - # over in ``__init__`` — drop one and you risk leaking state across - # threads. Hits collapse this whole block to a microsecond lookup; - # misses pay the original CPU cost AND populate the cache. - config_id = agent_config.config_id if agent_config is not None else None - - async def _build_agent() -> Any: - return await asyncio.to_thread( - _build_compiled_agent_blocking, - llm=llm, - tools=tools, - final_system_prompt=final_system_prompt, - backend_resolver=backend_resolver, - filesystem_mode=filesystem_selection.mode, - search_space_id=search_space_id, - user_id=user_id, - thread_id=thread_id, - visibility=visibility, - anon_session_id=anon_session_id, - available_connectors=available_connectors, - available_document_types=available_document_types, - # ``mentioned_document_ids`` is consumed by - # ``KnowledgePriorityMiddleware`` per turn via - # ``runtime.context`` (Phase 1.5). 
We still pass the - # caller-provided list here for the legacy fallback path - # (cache disabled / context not propagated) — the middleware - # drains its own copy after the first read so a cached graph - # never replays stale mentions. - mentioned_document_ids=mentioned_document_ids, - max_input_tokens=_max_input_tokens, - flags=_flags, - checkpointer=checkpointer, - ) - - _t0 = time.perf_counter() - if _flags.enable_agent_cache and not _flags.disable_new_agent_stack: - # Cache key components — order matters only for human readability; - # the resulting hash is what's stored. Every component must - # rotate on a real shape change AND stay stable across identical - # invocations. - cache_key = stable_hash( - "v1", # schema version of the key — bump if components change - config_id, - thread_id, - user_id, - search_space_id, - visibility, - filesystem_selection.mode, - anon_session_id, - tools_signature( - tools, - available_connectors=available_connectors, - available_document_types=available_document_types, - ), - flags_signature(_flags), - system_prompt_hash(final_system_prompt), - _max_input_tokens, - # ``mentioned_document_ids`` deliberately omitted — middleware - # reads it from ``runtime.context`` (Phase 1.5). - ) - agent = await get_cache().get_or_build(cache_key, builder=_build_agent) - else: - agent = await _build_agent() - _perf_log.info( - "[create_agent] Middleware stack + graph compiled in %.3fs (cache=%s)", - time.perf_counter() - _t0, - "on" - if _flags.enable_agent_cache and not _flags.disable_new_agent_stack - else "off", - ) - - _perf_log.info( - "[create_agent] Total agent creation in %.3fs", - time.perf_counter() - _t_agent_total, - ) - return agent - - -# Tools whose output is too costly / lossy to discard. Keep this -# conservative — anything listed here is *never* pruned by -# :class:`ContextEditingMiddleware`. The list is filtered against -# actually-bound tool names so disabled connectors don't show up here. 
-_PRUNE_PROTECTED_TOOL_NAMES: frozenset[str] = frozenset( - { - "generate_report", - "generate_resume", - "generate_podcast", - "generate_video_presentation", - "generate_image", - # Read-heavy connector reads — recomputing them is expensive - "read_email", - "search_emails", - # The fallback for malformed tool calls — keep its replies visible - "invalid", - } -) - - -def _safe_exclude_tools(tools: Sequence[BaseTool]) -> tuple[str, ...]: - """Return ``exclude_tools`` derived from the actually-bound tool list. - - Filters :data:`_PRUNE_PROTECTED_TOOL_NAMES` against the bound tools - so we never list tools that don't exist (would be a silent no-op). - """ - enabled = {t.name for t in tools} - return tuple(name for name in _PRUNE_PROTECTED_TOOL_NAMES if name in enabled) - - -# Connector gating: any tool whose ``ToolDefinition.required_connector`` -# isn't actually wired up gets a synthesized permission deny rule so -# execution attempts short-circuit with ``permission_denied`` instead of -# bubbling up provider-specific 401/404 errors. Mirrors OpenCode's -# ``Permission.disabled`` (declarative, per-tool gating) — replaces the -# legacy binary ``_CONNECTOR_TYPE_TO_SEARCHABLE`` substring-heuristic. -def _synthesize_connector_deny_rules( - *, - available_connectors: list[str] | None, - enabled_tool_names: set[str], -) -> list[Rule]: - """Build deny rules for tools whose required connector is not enabled. - - Source of truth is ``ToolDefinition.required_connector`` in - :data:`BUILTIN_TOOLS`. A tool only gets a deny rule when: - - 1. It is currently bound (``enabled_tool_names``). - 2. It declares a ``required_connector``. - 3. That connector is *not* in ``available_connectors``. 
- """ - available = set(available_connectors or []) - deny: list[Rule] = [] - for tool_def in BUILTIN_TOOLS: - if tool_def.name not in enabled_tool_names: - continue - rc = tool_def.required_connector - if rc and rc not in available: - deny.append(Rule(permission=tool_def.name, pattern="*", action="deny")) - return deny - - -def _build_compiled_agent_blocking( - *, - llm: BaseChatModel, - tools: Sequence[BaseTool], - final_system_prompt: str, - backend_resolver: Any, - filesystem_mode: FilesystemMode, - search_space_id: int, - user_id: str | None, - thread_id: int | None, - visibility: ChatVisibility, - anon_session_id: str | None, - available_connectors: list[str] | None, - available_document_types: list[str] | None, - mentioned_document_ids: list[int] | None, - max_input_tokens: int | None, - flags: AgentFeatureFlags, - checkpointer: Checkpointer, -): - """Build the middleware stack and compile the agent graph synchronously. - - Runs in a worker thread (see ``asyncio.to_thread`` call site) so the heavy - CPU work — most notably ``SubAgentMiddleware.__init__`` eagerly calling - ``create_agent`` to compile the general-purpose subagent — does not block - the event loop. - """ - _memory_middleware = MemoryInjectionMiddleware( - user_id=user_id, - search_space_id=search_space_id, - thread_visibility=visibility, - ) - - # General-purpose subagent middleware - # Subagent omits AnonymousDocumentMiddleware, KnowledgeTreeMiddleware, - # KnowledgePriorityMiddleware, and KnowledgeBasePersistenceMiddleware - it - # inherits state and tools from the parent, but should not (a) re-load - # anon docs / re-render the tree / re-run hybrid search, or (b) commit at - # its own completion (only the top-level agent's aafter_agent commits). 
- gp_middleware = [ - TodoListMiddleware(), - _memory_middleware, - FileIntentMiddleware(llm=llm), - SurfSenseFilesystemMiddleware( - backend=backend_resolver, - filesystem_mode=filesystem_mode, - search_space_id=search_space_id, - created_by_id=user_id, - thread_id=thread_id, - ), - create_surfsense_compaction_middleware(llm, StateBackend), - PatchToolCallsMiddleware(), - ] - - general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key] - **GENERAL_PURPOSE_SUBAGENT, - "model": llm, - "tools": tools, - "middleware": gp_middleware, - } - - # Specialized user-facing subagents (explore, report_writer, - # connector_negotiator). Registered through SubAgentMiddleware alongside - # the general-purpose spec so the parent's `task` tool can address them - # by name. Off by default until the flag flips so existing deployments - # don't see new agent types in the task tool description. - specialized_subagents: list[SubAgent] = [] - if flags.enable_specialized_subagents and not flags.disable_new_agent_stack: - try: - # Specialized subagents share the parent's filesystem + - # todo view so their system prompts (which promise - # ``read_file``, ``ls``, ``grep``, ``glob``, ``write_todos``) - # actually match runtime behavior. Build *fresh* instances - # rather than aliasing the parent's GP middleware to avoid - # subtle state coupling across compiled graphs. 
- subagent_extra_middleware: list = [ - TodoListMiddleware(), - SurfSenseFilesystemMiddleware( - backend=backend_resolver, - filesystem_mode=filesystem_mode, - search_space_id=search_space_id, - created_by_id=user_id, - thread_id=thread_id, - ), - ] - specialized_subagents = build_specialized_subagents( - tools=tools, - model=llm, - extra_middleware=subagent_extra_middleware, - ) - logging.info( - "Specialized subagents registered for task tool: %s", - [s["name"] for s in specialized_subagents], - ) - except Exception as exc: # pragma: no cover - defensive - logging.warning( - "Specialized subagent build failed; running without them: %s", - exc, - ) - specialized_subagents = [] - - subagent_specs: list[SubAgent] = [general_purpose_spec, *specialized_subagents] - - # Main agent middleware - # Order: AnonDoc -> Tree -> Priority -> FileIntent -> Filesystem -> Persistence -> ... - # before_agent hooks run in declared order; later injections sit closer to - # the latest human turn. Tree (large + cacheable) is injected earliest so - # provider-side prefix caching has more material to hit; FileIntent (most - # actionable per-turn contract) is injected closest to the user message. - # - # ``wrap_model_call`` ordering: the FIRST middleware in the list is the - # OUTERMOST wrapper. To ensure prune executes before summarization, - # place ``SpillingContextEditingMiddleware`` before - # ``SurfSenseCompactionMiddleware``. Compaction is the canonical - # token-budget defense; the Bedrock buffer-empty defense is folded - # into ``SurfSenseCompactionMiddleware``. - summarization_mw = create_surfsense_compaction_middleware(llm, StateBackend) - _ = flags.enable_compaction_v2 # historical flag; retained for telemetry parity - - # ContextEditing prune. Trigger at 55% of ``max_input_tokens``, - # earlier than summarization (~85%). When disabled, no edit runs. 
- context_edit_mw = None - if ( - flags.enable_context_editing - and not flags.disable_new_agent_stack - and max_input_tokens - ): - spill_edit = SpillToBackendEdit( - trigger=int(max_input_tokens * 0.55), - clear_at_least=int(max_input_tokens * 0.15), - keep=5, - exclude_tools=_safe_exclude_tools(tools), - clear_tool_inputs=True, - ) - clear_edit = ClearToolUsesEdit( - trigger=int(max_input_tokens * 0.55), - clear_at_least=int(max_input_tokens * 0.15), - keep=5, - exclude_tools=_safe_exclude_tools(tools), - clear_tool_inputs=True, - placeholder="[cleared - older tool output trimmed for context]", - ) - context_edit_mw = SpillingContextEditingMiddleware( - edits=[spill_edit, clear_edit], - backend_resolver=backend_resolver, - ) - - # Resilience knobs: header-aware retry, model fallback, and - # per-thread / per-run call-count limits. The fallback / limit - # middlewares are vanilla LangChain primitives; ``RetryAfter`` is - # SurfSense's header-aware variant (see its module docstring). - retry_mw = ( - RetryAfterMiddleware(max_retries=3) - if flags.enable_retry_after and not flags.disable_new_agent_stack - else None - ) - # Fallback chain — primary is the agent's own model; we add cheap - # alternatives. Off by default; only the first call site that - # configures the chain via env should enable it. 
- fallback_mw: ScopedModelFallbackMiddleware | None = None - if flags.enable_model_fallback and not flags.disable_new_agent_stack: - try: - fallback_mw = ScopedModelFallbackMiddleware( - "openai:gpt-4o-mini", - "anthropic:claude-3-5-haiku-20241022", - ) - except Exception: - logging.warning("ScopedModelFallbackMiddleware init failed; skipping.") - fallback_mw = None - model_call_limit_mw = ( - ModelCallLimitMiddleware( - thread_limit=120, - run_limit=80, - exit_behavior="end", - ) - if flags.enable_model_call_limit and not flags.disable_new_agent_stack - else None - ) - tool_call_limit_mw = ( - ToolCallLimitMiddleware( - thread_limit=300, run_limit=80, exit_behavior="continue" - ) - if flags.enable_tool_call_limit and not flags.disable_new_agent_stack - else None - ) - - # Provider-compat ``_noop`` injection (mirrors OpenCode's - # ``llm.ts`` workaround for providers that reject empty assistant - # turns or alternating-role constraints). - noop_mw = ( - NoopInjectionMiddleware() - if flags.enable_compaction_v2 and not flags.disable_new_agent_stack - else None - ) - - # Tool-call name repair (lowercase + ``invalid`` fallback). - # - # ``registered_tool_names`` MUST cover every tool the model can legitimately - # call. That includes the bound ``tools`` list AND every tool provided by - # middleware in the stack — ``FilesystemMiddleware`` (read_file, ls, grep, - # glob, edit_file, write_file, execute), ``TodoListMiddleware`` - # (write_todos), ``SubAgentMiddleware`` (task), ``SkillsMiddleware`` (skill - # loaders), etc. If we only inspect ``tools`` here, every call to - # ``read_file`` / ``ls`` / ``grep`` from the model will be rewritten to - # ``invalid`` because the repair middleware doesn't recognize them. The - # built-in deepagents middleware aren't in scope yet at this point of the - # function but they're added unconditionally below, so we hard-code their - # canonical names alongside the dynamic ``tools`` set. 
- repair_mw = None - if flags.enable_tool_call_repair and not flags.disable_new_agent_stack: - registered_names: set[str] = {t.name for t in tools} - # Tools owned by the standard deepagents middleware stack and the - # SurfSense filesystem extension. - registered_names |= { - "write_todos", - "ls", - "read_file", - "write_file", - "edit_file", - "glob", - "grep", - "execute", - "task", - "mkdir", - "cd", - "pwd", - "move_file", - "rm", - "rmdir", - "list_tree", - "execute_code", - } - repair_mw = ToolCallNameRepairMiddleware( - registered_tool_names=registered_names, - # Disable fuzzy matching to avoid silent rewrites; the - # lowercase + ``invalid`` fallback alone covers >95% of - # observed model errors. - fuzzy_match_threshold=None, - ) - - # Doom-loop detector. Off by default until the frontend handles - # ``permission == "doom_loop"`` interrupts. - doom_loop_mw = ( - DoomLoopMiddleware(threshold=3) - if flags.enable_doom_loop and not flags.disable_new_agent_stack - else None - ) - - # PermissionMiddleware. Layers, earliest -> latest (last match wins, - # same evaluation order as OpenCode's ``permission/index.ts``): - # - # 1. ``surfsense_defaults`` — single ``allow */*`` rule. SurfSense - # already runs per-tool HITL (see ``tools/hitl.py``) for mutating - # connector tools, so we only want PermissionMiddleware to *deny* - # things the user has gated off; the default fallback in - # ``permissions.evaluate`` is ``ask``, which would double-prompt - # on every safe read-only call (``ls``, ``read_file``, ``grep``, - # ``glob``, ``web_search`` …) and, on resume, replay the previous - # reject decision into innocent calls. - # 2. ``desktop_safety`` — ``ask`` for destructive filesystem ops when - # the agent is operating against the user's real disk. 
Cloud mode - # has full revision-based revert via ``revert_service``, but - # desktop mode hits disk immediately with no undo, so an - # accidental ``rm`` / ``rmdir`` / ``move_file`` / ``edit_file`` / - # ``write_file`` is unrecoverable. This layer is forced on in - # desktop mode regardless of ``enable_permission`` because the - # safety net is non-negotiable. - # 3. ``connector_synthesized`` — deny rules for tools whose required - # connector is not connected to this space. Overrides #1/#2. - # 4. (future) user-defined rules from ``agent_permission_rules`` table - # via the Agent Permissions UI. Loaded last so they override all. - permission_mw: PermissionMiddleware | None = None - is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER - permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack - # Build the middleware whenever it has work to do: either the user - # opted into the rule engine, OR we're in desktop mode and need the - # safety rules unconditionally. - if permission_enabled or is_desktop_fs: - rulesets: list[Ruleset] = [ - Ruleset( - rules=[Rule(permission="*", pattern="*", action="allow")], - origin="surfsense_defaults", - ), - ] - if is_desktop_fs: - rulesets.append( - Ruleset( - rules=[ - Rule(permission="rm", pattern="*", action="ask"), - Rule(permission="rmdir", pattern="*", action="ask"), - Rule(permission="move_file", pattern="*", action="ask"), - Rule(permission="edit_file", pattern="*", action="ask"), - Rule(permission="write_file", pattern="*", action="ask"), - ], - origin="desktop_safety", - ) - ) - if permission_enabled: - synthesized = _synthesize_connector_deny_rules( - available_connectors=available_connectors, - enabled_tool_names={t.name for t in tools}, - ) - rulesets.append(Ruleset(rules=synthesized, origin="connector_synthesized")) - permission_mw = PermissionMiddleware(rulesets=rulesets) - - # ActionLogMiddleware. Off by default until the ``agent_action_log`` - # table is migrated. 
When enabled, persists one row per tool call - # with optional reverse_descriptor for - # ``POST /api/threads/{thread_id}/revert/{action_id}``. Sits inside - # ``permission`` so denied calls aren't logged as completions. - action_log_mw: ActionLogMiddleware | None = None - if ( - flags.enable_action_log - and not flags.disable_new_agent_stack - and thread_id is not None - ): - try: - tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS} - action_log_mw = ActionLogMiddleware( - thread_id=thread_id, - search_space_id=search_space_id, - user_id=user_id, - tool_definitions=tool_defs_by_name, - ) - except Exception: # pragma: no cover - defensive - logging.warning( - "ActionLogMiddleware init failed; running without it.", - exc_info=True, - ) - action_log_mw = None - - # Per-thread busy mutex (refuse a second concurrent turn on the same - # thread; see :class:`BusyMutexMiddleware` docstring). - busy_mutex_mw: BusyMutexMiddleware | None = ( - BusyMutexMiddleware() - if flags.enable_busy_mutex and not flags.disable_new_agent_stack - else None - ) - - # OpenTelemetry spans (model.call + tool.call). Lives just inside - # BusyMutex so it spans every retry/fallback attempt of the current - # turn but never wraps a queued/blocked turn. - otel_mw: OtelSpanMiddleware | None = ( - OtelSpanMiddleware() - if flags.enable_otel and not flags.disable_new_agent_stack - else None - ) - - # Plugin entry-point loader. Off by default; opt-in via the - # ``SURFSENSE_ENABLE_PLUGIN_LOADER`` flag. The allowlist is read from - # the ``SURFSENSE_ALLOWED_PLUGINS`` env var (comma-separated). A future - # PR can wire it through ``global_llm_config.yaml``. 
- plugin_middlewares: list[Any] = [] - if flags.enable_plugin_loader and not flags.disable_new_agent_stack: - try: - allowed_names = load_allowed_plugin_names_from_env() - if allowed_names: - plugin_middlewares = load_plugin_middlewares( - PluginContext.build( - search_space_id=search_space_id, - user_id=user_id, - thread_visibility=visibility, - llm=llm, - ), - allowed_plugin_names=allowed_names, - ) - except Exception: # pragma: no cover - defensive - logging.warning( - "Plugin loader failed; continuing without plugins.", - exc_info=True, - ) - plugin_middlewares = [] - - # SkillsMiddleware (deepagents) loads built-in + space-authored - # skills via a CompositeBackend. Sources are layered: built-in first, - # space last, so a search-space-authored skill of the same name - # overrides the bundled one. - skills_mw: SkillsMiddleware | None = None - if flags.enable_skills and not flags.disable_new_agent_stack: - try: - skills_factory = build_skills_backend_factory( - search_space_id=search_space_id - if filesystem_mode == FilesystemMode.CLOUD - else None, - ) - skills_mw = SkillsMiddleware( - backend=skills_factory, - sources=default_skills_sources(), - ) - except Exception as exc: # pragma: no cover - defensive - logging.warning("SkillsMiddleware init failed; skipping: %s", exc) - skills_mw = None - - # LangChain's LLM-driven tool selection — only enabled for stacks - # large enough to need narrowing (>30 tools). 
- selector_mw: LLMToolSelectorMiddleware | None = None - if ( - flags.enable_llm_tool_selector - and not flags.disable_new_agent_stack - and len(tools) > 30 - ): - try: - selector_mw = LLMToolSelectorMiddleware( - model="openai:gpt-4o-mini", - max_tools=12, - always_include=[ - name - for name in ( - "update_memory", - "get_connected_accounts", - "scrape_webpage", - ) - if name in {t.name for t in tools} - ], - ) - except Exception: - logging.warning("LLMToolSelectorMiddleware init failed; skipping.") - selector_mw = None - - deepagent_middleware = [ - # BusyMutex is OUTERMOST: it must wrap the entire stream so no - # other turn can sneak in while this one is mid-flight. - busy_mutex_mw, - # OTel spans sit just inside BusyMutex so each retry attempt - # gets its own model.call / tool.call span. - otel_mw, - TodoListMiddleware(), - _memory_middleware, - AnonymousDocumentMiddleware( - anon_session_id=anon_session_id, - ) - if filesystem_mode == FilesystemMode.CLOUD - else None, - KnowledgeTreeMiddleware( - search_space_id=search_space_id, - filesystem_mode=filesystem_mode, - llm=llm, - ) - if filesystem_mode == FilesystemMode.CLOUD - else None, - KnowledgePriorityMiddleware( - llm=llm, - planner_llm=get_planner_llm(), - search_space_id=search_space_id, - filesystem_mode=filesystem_mode, - available_connectors=available_connectors, - available_document_types=available_document_types, - mentioned_document_ids=mentioned_document_ids, - ), - FileIntentMiddleware(llm=llm), - SurfSenseFilesystemMiddleware( - backend=backend_resolver, - filesystem_mode=filesystem_mode, - search_space_id=search_space_id, - created_by_id=user_id, - thread_id=thread_id, - ), - KnowledgeBasePersistenceMiddleware( - search_space_id=search_space_id, - created_by_id=user_id, - filesystem_mode=filesystem_mode, - thread_id=thread_id, - ) - if filesystem_mode == FilesystemMode.CLOUD - else None, - # Skill loader. 
Placed before SubAgentMiddleware so subagents - # inherit the same skill metadata (subagent specs reference the - # same source paths via ``default_skills_sources()``). - skills_mw, - SubAgentMiddleware(backend=StateBackend, subagents=subagent_specs), - # Tool selection (only when >30 tools and flag on). - selector_mw, - # Defensive caps, then prune, then summarize. - model_call_limit_mw, - tool_call_limit_mw, - context_edit_mw, - summarization_mw, - # Provider compatibility + retry chain — placed after prune/compact - # so retries happen on the already-trimmed payload. - noop_mw, - retry_mw, - fallback_mw, - # Coalesce a multi-text-block system message into one block - # immediately before the model call. Sits innermost on the - # system-message-mutation chain so it observes every appender - # (todo / filesystem / skills / subagents …) and prevents - # OpenRouter→Anthropic from redistributing ``cache_control`` - # across N blocks and tripping Anthropic's 4-breakpoint cap. - # See ``middleware/flatten_system.py`` for full rationale. - FlattenSystemMessageMiddleware(), - # Tool-call repair must run after model emits but before - # permission / dedup / doom-loop interpret the calls. - repair_mw, - # Permission deny/ask BEFORE the calls are forwarded to tool nodes. - permission_mw, - doom_loop_mw, - # Action log sits inside permission so denied calls don't appear - # as completions, and outside dedup so each unique tool invocation - # gets its own row. - action_log_mw, - PatchToolCallsMiddleware(), - DedupHITLToolCallsMiddleware(agent_tools=list(tools)), - # Plugin slot — sits at the tail so plugin-side transforms see the - # final tool result. Prompt caching is now applied at LLM build time - # via ``apply_litellm_prompt_caching`` (see prompt_caching.py), so no - # caching middleware is needed here. Multiple plugins run in declared - # order; loader filtered by the admin allowlist already. 
- *plugin_middlewares, - ] - deepagent_middleware = [m for m in deepagent_middleware if m is not None] - - agent = create_agent( - llm, - system_prompt=final_system_prompt, - tools=list(tools), - middleware=deepagent_middleware, - context_schema=SurfSenseContextSchema, - checkpointer=checkpointer, - ) - return agent.with_config( - { - "recursion_limit": 10_000, - "metadata": { - "ls_integration": "deepagents", - "versions": {"deepagents": deepagents_version}, - }, - } - ) diff --git a/surfsense_backend/app/agents/new_chat/middleware/__init__.py b/surfsense_backend/app/agents/new_chat/middleware/__init__.py deleted file mode 100644 index 6742bd8de..000000000 --- a/surfsense_backend/app/agents/new_chat/middleware/__init__.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Middleware components for the SurfSense new chat agent.""" - -from app.agents.new_chat.middleware.action_log import ActionLogMiddleware -from app.agents.new_chat.middleware.anonymous_document import ( - AnonymousDocumentMiddleware, -) -from app.agents.new_chat.middleware.busy_mutex import BusyMutexMiddleware -from app.agents.new_chat.middleware.compaction import ( - SurfSenseCompactionMiddleware, - create_surfsense_compaction_middleware, -) -from app.agents.new_chat.middleware.context_editing import ( - ClearToolUsesEdit, - SpillingContextEditingMiddleware, - SpillToBackendEdit, -) -from app.agents.new_chat.middleware.dedup_tool_calls import ( - DedupHITLToolCallsMiddleware, -) -from app.agents.new_chat.middleware.doom_loop import DoomLoopMiddleware -from app.agents.new_chat.middleware.file_intent import ( - FileIntentMiddleware, -) -from app.agents.new_chat.middleware.filesystem import ( - SurfSenseFilesystemMiddleware, -) -from app.agents.new_chat.middleware.flatten_system import ( - FlattenSystemMessageMiddleware, -) -from app.agents.new_chat.middleware.kb_persistence import ( - KnowledgeBasePersistenceMiddleware, - commit_staged_filesystem_state, -) -from app.agents.new_chat.middleware.knowledge_search 
import ( - KnowledgeBaseSearchMiddleware, - KnowledgePriorityMiddleware, -) -from app.agents.new_chat.middleware.knowledge_tree import ( - KnowledgeTreeMiddleware, -) -from app.agents.new_chat.middleware.memory_injection import ( - MemoryInjectionMiddleware, -) -from app.agents.new_chat.middleware.noop_injection import NoopInjectionMiddleware -from app.agents.new_chat.middleware.otel_span import OtelSpanMiddleware -from app.agents.new_chat.middleware.permission import PermissionMiddleware -from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware -from app.agents.new_chat.middleware.skills_backends import ( - BuiltinSkillsBackend, - SearchSpaceSkillsBackend, - build_skills_backend_factory, - default_skills_sources, -) -from app.agents.new_chat.middleware.tool_call_repair import ( - ToolCallNameRepairMiddleware, -) - -__all__ = [ - "ActionLogMiddleware", - "AnonymousDocumentMiddleware", - "BuiltinSkillsBackend", - "BusyMutexMiddleware", - "ClearToolUsesEdit", - "DedupHITLToolCallsMiddleware", - "DoomLoopMiddleware", - "FileIntentMiddleware", - "FlattenSystemMessageMiddleware", - "KnowledgeBasePersistenceMiddleware", - "KnowledgeBaseSearchMiddleware", - "KnowledgePriorityMiddleware", - "KnowledgeTreeMiddleware", - "MemoryInjectionMiddleware", - "NoopInjectionMiddleware", - "OtelSpanMiddleware", - "PermissionMiddleware", - "RetryAfterMiddleware", - "SearchSpaceSkillsBackend", - "SpillToBackendEdit", - "SpillingContextEditingMiddleware", - "SurfSenseCompactionMiddleware", - "SurfSenseFilesystemMiddleware", - "ToolCallNameRepairMiddleware", - "build_skills_backend_factory", - "commit_staged_filesystem_state", - "create_surfsense_compaction_middleware", - "default_skills_sources", -] diff --git a/surfsense_backend/app/agents/new_chat/middleware/file_intent.py b/surfsense_backend/app/agents/new_chat/middleware/file_intent.py deleted file mode 100644 index 7897e13d6..000000000 --- a/surfsense_backend/app/agents/new_chat/middleware/file_intent.py +++ 
/dev/null @@ -1,334 +0,0 @@ -"""Semantic file-intent routing middleware for new chat turns. - -This middleware classifies the latest human turn into a small intent set: -- chat_only -- file_write -- file_read - -For ``file_write`` turns it injects a strict system contract so the model -uses filesystem tools before claiming success, and provides a deterministic -fallback path when no filename is specified by the user. -""" - -from __future__ import annotations - -import json -import logging -import re -from datetime import UTC, datetime -from enum import StrEnum -from typing import Any - -from langchain.agents.middleware import AgentMiddleware, AgentState -from langchain_core.language_models import BaseChatModel -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage -from langgraph.runtime import Runtime -from pydantic import BaseModel, Field, ValidationError - -logger = logging.getLogger(__name__) - - -class FileOperationIntent(StrEnum): - CHAT_ONLY = "chat_only" - FILE_WRITE = "file_write" - FILE_READ = "file_read" - - -class FileIntentPlan(BaseModel): - intent: FileOperationIntent = Field( - description="Primary user intent for this turn." - ) - confidence: float = Field( - ge=0.0, - le=1.0, - default=0.5, - description="Model confidence in the selected intent.", - ) - suggested_filename: str | None = Field( - default=None, - description="Optional filename (e.g. notes.md) inferred from user request.", - ) - suggested_directory: str | None = Field( - default=None, - description=( - "Optional directory path (e.g. /reports/q2 or reports/q2) inferred from " - "user request." - ), - ) - suggested_path: str | None = Field( - default=None, - description=( - "Optional full file path (e.g. /reports/q2/summary.md). If present, this " - "takes precedence over suggested_directory + suggested_filename." 
- ), - ) - - -def _extract_text_from_message(message: BaseMessage) -> str: - content = getattr(message, "content", "") - if isinstance(content, str): - return content - if isinstance(content, list): - parts: list[str] = [] - for item in content: - if isinstance(item, str): - parts.append(item) - elif isinstance(item, dict) and item.get("type") == "text": - parts.append(str(item.get("text", ""))) - return "\n".join(part for part in parts if part) - return str(content) - - -def _extract_json_payload(text: str) -> str: - stripped = text.strip() - fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL) - if fenced: - return fenced.group(1) - start = stripped.find("{") - end = stripped.rfind("}") - if start != -1 and end != -1 and end > start: - return stripped[start : end + 1] - return stripped - - -def _sanitize_filename(value: str) -> str: - name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip() - name = re.sub(r"\s+", "-", name) - name = name.strip("._-") - if not name: - name = "note" - if len(name) > 80: - name = name[:80].rstrip("-_.") - return name - - -def _sanitize_path_segment(value: str) -> str: - segment = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip() - segment = re.sub(r"\s+", "_", segment) - segment = segment.strip("._-") - return segment - - -def _normalize_directory(value: str) -> str: - raw = value.strip().replace("\\", "/") - raw = raw.strip("/") - if not raw: - return "" - parts = [_sanitize_path_segment(part) for part in raw.split("/") if part.strip()] - parts = [part for part in parts if part] - return "/".join(parts) - - -def _normalize_file_path(value: str) -> str: - raw = value.strip().replace("\\", "/").strip() - if not raw: - return "" - had_trailing_slash = raw.endswith("/") - raw = raw.strip("/") - if not raw: - return "" - parts = [_sanitize_path_segment(part) for part in raw.split("/") if part.strip()] - parts = [part for part in parts if part] - if not parts: - return "" - if had_trailing_slash: - return 
f"/{'/'.join(parts)}/" - return f"/{'/'.join(parts)}" - - -def _infer_directory_from_user_text(user_text: str) -> str | None: - patterns = ( - r"\b(?:in|inside|under)\s+(?:the\s+)?([a-zA-Z0-9 _\-/]+?)\s+folder\b", - r"\b(?:in|inside|under)\s+([a-zA-Z0-9 _\-/]+?)\b", - ) - lowered = user_text.lower() - for pattern in patterns: - match = re.search(pattern, lowered, flags=re.IGNORECASE) - if not match: - continue - candidate = match.group(1).strip() - if candidate in {"the", "a", "an"}: - continue - normalized = _normalize_directory(candidate) - if normalized: - return normalized - return None - - -def _fallback_path( - suggested_filename: str | None, - *, - suggested_directory: str | None = None, - suggested_path: str | None = None, - user_text: str, -) -> str: - inferred_dir = _infer_directory_from_user_text(user_text) - - sanitized_filename = "" - if suggested_filename: - sanitized_filename = _sanitize_filename(suggested_filename) - if sanitized_filename.lower().endswith(".txt"): - sanitized_filename = f"{sanitized_filename[:-4]}.md" - if not sanitized_filename: - sanitized_filename = "notes.md" - elif "." 
not in sanitized_filename: - sanitized_filename = f"{sanitized_filename}.md" - - normalized_suggested_path = ( - _normalize_file_path(suggested_path) if suggested_path else "" - ) - if normalized_suggested_path: - if normalized_suggested_path.endswith("/"): - return f"{normalized_suggested_path.rstrip('/')}/{sanitized_filename}" - return normalized_suggested_path - - directory = _normalize_directory(suggested_directory or "") - if not directory and inferred_dir: - directory = inferred_dir - if directory: - return f"/{directory}/{sanitized_filename}" - - return f"/{sanitized_filename}" - - -def _build_classifier_prompt(*, recent_conversation: str, user_text: str) -> str: - return ( - "Classify the latest user request into a filesystem intent for an AI agent.\n" - "Return JSON only with this exact schema:\n" - '{"intent":"chat_only|file_write|file_read","confidence":0.0,"suggested_filename":"string or null","suggested_directory":"string or null","suggested_path":"string or null"}\n\n' - "Rules:\n" - "- Use semantic intent, not literal keywords.\n" - "- file_write: user asks to create/save/write/update/edit content as a file.\n" - "- file_read: user asks to open/read/list/search existing files.\n" - "- chat_only: conversational/analysis responses without required file operations.\n" - "- For file_write, choose a concise semantic suggested_filename and match the requested format.\n" - "- If the user mentions a folder/directory, populate suggested_directory.\n" - "- If user specifies an explicit full path, populate suggested_path.\n" - "- Use extensions that match user intent (e.g. 
.md, .json, .yaml, .csv, .py, .ts, .js, .html, .css, .sql).\n" - "- Do not use .txt; prefer .md for generic text notes.\n" - "- Do not include dates or timestamps in suggested_filename unless explicitly requested.\n" - "- Never include markdown or explanation.\n\n" - f"Recent conversation:\n{recent_conversation or '(none)'}\n\n" - f"Latest user message:\n{user_text}" - ) - - -def _build_recent_conversation( - messages: list[BaseMessage], *, max_messages: int = 6 -) -> str: - rows: list[str] = [] - filtered: list[tuple[str, BaseMessage]] = [] - for msg in messages: - role: str | None = None - if isinstance(msg, HumanMessage): - role = "user" - elif isinstance(msg, AIMessage): - if getattr(msg, "tool_calls", None): - continue - role = "assistant" - else: - continue - filtered.append((role, msg)) - for role, msg in filtered[-max_messages:]: - text = re.sub(r"\s+", " ", _extract_text_from_message(msg)).strip() - if text: - rows.append(f"{role}: {text[:280]}") - return "\n".join(rows) - - -class FileIntentMiddleware(AgentMiddleware): # type: ignore[type-arg] - """Classify file intent and inject a strict file-write contract.""" - - tools = () - - def __init__(self, *, llm: BaseChatModel | None = None) -> None: - self.llm = llm - - async def _classify_intent( - self, *, messages: list[BaseMessage], user_text: str - ) -> FileIntentPlan: - if self.llm is None: - return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0) - - prompt = _build_classifier_prompt( - recent_conversation=_build_recent_conversation(messages), - user_text=user_text, - ) - try: - response = await self.llm.ainvoke( - [HumanMessage(content=prompt)], - config={"tags": ["surfsense:internal"]}, - ) - payload = json.loads( - _extract_json_payload(_extract_text_from_message(response)) - ) - plan = FileIntentPlan.model_validate(payload) - return plan - except (json.JSONDecodeError, ValidationError, ValueError) as exc: - logger.warning("File intent classifier returned invalid output: %s", exc) 
- except Exception as exc: # pragma: no cover - defensive fallback - logger.warning("File intent classifier failed: %s", exc) - - return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0) - - async def abefore_agent( # type: ignore[override] - self, - state: AgentState, - runtime: Runtime[Any], - ) -> dict[str, Any] | None: - del runtime - messages = state.get("messages") or [] - if not messages: - return None - - last_human: HumanMessage | None = None - for msg in reversed(messages): - if isinstance(msg, HumanMessage): - last_human = msg - break - if last_human is None: - return None - - user_text = _extract_text_from_message(last_human).strip() - if not user_text: - return None - - plan = await self._classify_intent(messages=messages, user_text=user_text) - suggested_path = _fallback_path( - plan.suggested_filename, - suggested_directory=plan.suggested_directory, - suggested_path=plan.suggested_path, - user_text=user_text, - ) - contract = { - "intent": plan.intent.value, - "confidence": plan.confidence, - "suggested_path": suggested_path, - "timestamp": datetime.now(UTC).isoformat(), - "turn_id": state.get("turn_id", ""), - } - - if plan.intent != FileOperationIntent.FILE_WRITE: - return {"file_operation_contract": contract} - - contract_msg = SystemMessage( - content=( - "\n" - "This turn intent is file_write.\n" - f"Suggested default path: {suggested_path}\n" - "Rules:\n" - "- You MUST call write_file or edit_file before claiming success.\n" - "- If no path is provided by the user, use the suggested default path.\n" - "- Do not claim a file was created/updated unless tool output confirms it.\n" - "- If the write/edit fails, clearly report failure instead of success.\n" - "- Do not include timestamps or dates in generated file content unless the user explicitly asks for them.\n" - "- For open-ended requests (e.g., random note), generate useful concrete content, not placeholders.\n" - "" - ) - ) - - # Insert just before the latest human turn so 
it applies to this request. - new_messages = list(messages) - insert_at = max(len(new_messages) - 1, 0) - new_messages.insert(insert_at, contract_msg) - return {"messages": new_messages, "file_operation_contract": contract} diff --git a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py deleted file mode 100644 index c46eb98a5..000000000 --- a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py +++ /dev/null @@ -1,1998 +0,0 @@ -"""Custom filesystem middleware for the SurfSense agent. - -This middleware fully overrides every deepagents filesystem tool so that the -``Command(update=...)`` payload can carry SurfSense-specific state fields -(``cwd``, ``staged_dirs``, ``pending_moves``, ``doc_id_by_path``, -``dirty_paths``) atomically alongside the standard ``files`` update. - -In CLOUD mode the backend is :class:`KBPostgresBackend` (lazy DB reads, no DB -writes). End-of-turn persistence is handled by -:class:`KnowledgeBasePersistenceMiddleware`. In DESKTOP_LOCAL_FOLDER mode the -backend is :class:`MultiRootLocalFolderBackend` and writes go straight to disk. - -New tools introduced here: - -* ``mkdir`` — cloud-only stages folder paths to ``state['staged_dirs']``; - desktop creates real directories. -* ``cd`` / ``pwd`` — manage ``state['cwd']`` (per-thread). -* ``move_file`` — staged commit in cloud, real disk move in desktop. -* ``list_tree`` — works in both modes (cloud uses - :func:`KBPostgresBackend.alist_tree_listing`). - -The middleware no longer ships ``save_document``; persistence is inferred -from ``write_file`` / ``edit_file`` against ``/documents/*`` paths. 
-""" - -from __future__ import annotations - -import asyncio -import json -import logging -import posixpath -import re -import secrets -from typing import Annotated, Any - -from daytona.common.errors import DaytonaError -from deepagents import FilesystemMiddleware -from deepagents.backends.protocol import EditResult, WriteResult -from deepagents.backends.utils import ( - create_file_data, - format_read_response, - validate_path, -) -from langchain.tools import ToolRuntime -from langchain_core.messages import ToolMessage -from langchain_core.tools import BaseTool, StructuredTool -from langgraph.types import Command - -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState -from app.agents.new_chat.middleware.kb_postgres_backend import ( - KBPostgresBackend, - paginate_listing, -) -from app.agents.new_chat.middleware.multi_root_local_folder_backend import ( - MultiRootLocalFolderBackend, -) -from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT -from app.agents.new_chat.sandbox import ( - _evict_sandbox_cache, - delete_sandbox, - get_or_create_sandbox, - is_sandbox_enabled, -) -from app.agents.new_chat.state_reducers import _CLEAR - -logger = logging.getLogger(__name__) - - -# ============================================================================= -# System Prompt (built per-session based on filesystem_mode) -# ============================================================================= -# -# Each chat session runs in exactly one filesystem mode. Including rules for -# the OTHER mode just wastes tokens and confuses the model, so we build the -# prompt + tool descriptions for the active mode only. - -_COMMON_PROMPT_HEADER = """## Following Conventions - -- Read files before editing — understand existing content before making changes. -- Mimic existing style, naming conventions, and patterns. 
-- Never claim a file was created/updated unless filesystem tool output confirms success. -- If a file write/edit fails, explicitly report the failure. -""" - -_CLOUD_SYSTEM_PROMPT = ( - _COMMON_PROMPT_HEADER - + """ -## Filesystem Tools - -All file paths must start with `/`. Relative paths resolve against the -current working directory (`cwd`, default `/documents`). - -- ls(path, offset=0, limit=200): list files and directories at the given path. -- read_file(path, offset, limit): read a file (paginated) from the filesystem. -- write_file(path, content): create a new text file in the workspace. -- edit_file(path, old, new): exact string-replacement edit (lazy-loads KB - documents on first edit). -- glob(pattern, path): find files matching a glob pattern. -- grep(pattern, path, glob): substring search across files. -- mkdir(path): create a folder under `/documents/` (committed at end of turn). -- cd(path): change the current working directory. -- pwd(): print the current working directory. -- move_file(source, dest): move/rename a file under `/documents/`. -- rm(path): delete a single file under `/documents/` (no `-r`). -- rmdir(path): delete an empty directory under `/documents/`. -- list_tree(path, max_depth, page_size): recursively list files/folders. - -## Persistence Rules - -- Files written under `/documents/<...>` are **persisted** at end of turn as - Documents in the user's knowledge base. -- Files whose **basename** starts with `temp_` (e.g. `temp_plan.md` or - `/documents/temp_scratch.md`) are **discarded** at end of turn — use this - prefix for any scratch/working content you do NOT want saved. -- All other paths (outside `/documents/` and not `temp_*`) are rejected. -- mkdir/move_file/rm/rmdir are staged this turn and committed at end of - turn alongside any new/edited documents. Snapshot/revert is enabled - for every destructive operation when action logging is on. - -## Reading Documents Efficiently - -Documents are formatted as XML. 
Each document contains: -- `` — title, type, URL, etc. -- `` — a table of every chunk with its **line range** and a - `matched="true"` flag for chunks that matched the search query. -- `` — the actual chunks in original document order. - -**Workflow**: when reading a large document, read the first ~20 lines to see -the ``, identify chunks marked `matched="true"`, then use -`read_file(path, offset=, limit=)` to jump directly to -those sections instead of reading the entire file sequentially. - -Use `` values as citation IDs in your answers. - -## Priority List - -You receive a `` system message each turn listing the -top-K paths most relevant to the user's query (by hybrid search). Read those -first — matched sections are flagged inside each document's ``. - -## Workspace Tree - -You receive a `` system message each turn with the current -folder/document layout. The tree may be truncated past a hard cap; in that -case, drill into specific folders with `ls(...)` or `list_tree(...)`. - -## grep Line Numbers - -`grep` searches across both your in-memory edits and the indexed chunks in -Postgres. State-cached files return real line numbers; database hits return -`line=0` because their position depends on per-document XML layout — call -`read_file(path)` to find the exact line. -""" -) - -_DESKTOP_SYSTEM_PROMPT = ( - _COMMON_PROMPT_HEADER - + """ -## Local Folder Mode - -This chat operates directly on the user's local folders. Writes and edits -hit disk immediately — there is no end-of-turn staging, no `/documents/` -namespace, and no `temp_` semantics. - -## Filesystem Tools - -All file paths must start with `/` and use mount-prefixed absolute paths -like `//file.ext`. Relative paths resolve against the current working -directory (`cwd`). - -- ls(path, offset=0, limit=200): list files and directories at the given path. -- read_file(path, offset, limit): read a file (paginated) from disk. -- write_file(path, content): write a file to disk. 
-- edit_file(path, old, new): exact string-replacement edit on disk. -- glob(pattern, path): find files matching a glob pattern. -- grep(pattern, path, glob): substring search across files. -- mkdir(path): create a directory on disk. -- cd(path): change the current working directory. -- pwd(): print the current working directory. -- move_file(source, dest): move/rename a file. -- rm(path): delete a single file from disk (no `-r`). NOT reversible. -- rmdir(path): delete an empty directory from disk. NOT reversible. -- list_tree(path, max_depth, page_size): recursively list files/folders. - -## Workflow Tips - -- If you are unsure which mounts are available, call `ls('/')` first. -- For large trees, prefer `list_tree` then `grep` then `read_file` over - brute-force directory traversal. -- Cross-mount moves are not supported. -- Desktop deletes hit disk immediately and cannot be undone via the - agent's revert flow — confirm before calling `rm`/`rmdir`. -""" -) - -_SANDBOX_PROMPT_ADDENDUM = ( - "\n- execute_code: run Python code in an isolated sandbox." - "\n\n## Code Execution" - "\n\nUse execute_code whenever a task benefits from running code." - " Never perform arithmetic manually." - "\n\nDocuments here are XML-wrapped markdown, not raw data files." - " To work with them programmatically, read the document first," - " extract the data, write it as a clean file (CSV, JSON, etc.)," - " and then run your code against it." -) - - -def _build_filesystem_system_prompt( - filesystem_mode: FilesystemMode, - *, - sandbox_available: bool, -) -> str: - """Build the filesystem system prompt for a given session mode. - - The prompt only describes rules and tools that actually apply in the - chosen mode — there is no cross-mode noise. 
- """ - base = ( - _CLOUD_SYSTEM_PROMPT - if filesystem_mode == FilesystemMode.CLOUD - else _DESKTOP_SYSTEM_PROMPT - ) - if sandbox_available: - base += _SANDBOX_PROMPT_ADDENDUM - return base - - -# Backwards-compatible alias retained for any external imports. -SURFSENSE_FILESYSTEM_SYSTEM_PROMPT = _CLOUD_SYSTEM_PROMPT - -# ============================================================================= -# Per-Tool Descriptions (shown to the LLM as the tool's docstring) -# ============================================================================= - -# ============================================================================= -# Per-Tool Descriptions (mode-specific; injected as the tool's docstring) -# ============================================================================= - -# --- mode-agnostic --------------------------------------------------------- - -SURFSENSE_READ_FILE_TOOL_DESCRIPTION = """Reads a file from the filesystem. - -Usage: -- By default, reads up to 100 lines from the beginning. -- Use `offset` and `limit` for pagination when files are large. -- Results include line numbers. -- Documents contain a `` near the top listing every chunk with - its line range and a `matched="true"` flag for search-relevant chunks. - Read the index first, then jump to matched chunks with - `read_file(path, offset=, limit=)`. -- Use chunk IDs (``) as citations in answers. -""" - -SURFSENSE_GLOB_TOOL_DESCRIPTION = """Find files matching a glob pattern. - -Supports standard glob patterns: `*`, `**`, `?`. -Returns absolute file paths. -""" - -SURFSENSE_CD_TOOL_DESCRIPTION = """Changes the current working directory (cwd). - -Args: -- path: absolute or relative directory path. Relative paths resolve against - the current cwd. - -The new cwd is used by other filesystem tools whenever a relative path is -given. Returns the resolved cwd. 
-""" - -SURFSENSE_PWD_TOOL_DESCRIPTION = """Prints the current working directory.""" - -SURFSENSE_EXECUTE_CODE_TOOL_DESCRIPTION = """Executes Python code in an isolated sandbox environment. - -Common data-science packages are pre-installed (pandas, numpy, matplotlib, -scipy, scikit-learn). - -Usage notes: -- No outbound network access. -- Returns combined stdout/stderr with exit code. -- Use print() to produce output. -- Use the optional timeout parameter to override the default timeout. -""" - -# --- cloud-only ------------------------------------------------------------ - -_CLOUD_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path. - -Usage: -- Provide an absolute path under `/documents` (relative paths resolve under - the current cwd, which defaults to `/documents`). -- For very large folders, use `offset` and `limit` to paginate the listing. -- Returns one entry per line; directories end with a trailing `/`. -""" - -_CLOUD_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new text file to the workspace. - -Usage: -- Files written under `/documents/<...>` are persisted as Documents at end - of turn. -- Use a `temp_` filename prefix (e.g. `temp_plan.md` or `/documents/temp_x.md`) - for scratch/working files; they are automatically discarded at end of turn. -- Writes outside `/documents/` are rejected unless the basename starts with - `temp_`. -- Supported outputs include common LLM-friendly text formats like markdown, - json, yaml, csv, xml, html, css, sql, and code files. -- Avoid placeholders; produce concrete and useful text. -""" - -_CLOUD_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files. - -IMPORTANT: -- Read the file before editing. -- Preserve exact indentation and formatting. -- Edits to documents under `/documents/` are persisted at end of turn. -- Edits to `temp_*` files are discarded at end of turn. -""" - -_CLOUD_MOVE_FILE_TOOL_DESCRIPTION = """Moves or renames a file or folder. 
- -Use absolute paths for both source and destination. - -Notes: -- `move_file` is staged this turn and committed at end of turn. -- The agent cannot overwrite an existing destination — pass a fresh dest - path or move the existing destination away first. -- The anonymous uploaded document is read-only and cannot be moved. -- Rename is a special case of move (same folder, different filename). -""" - -_CLOUD_LIST_TREE_TOOL_DESCRIPTION = """Lists files/folders recursively in a single bounded call. - -Args: -- path: absolute path to start from. Defaults to `/documents`. -- max_depth: recursion depth limit (default 8). -- page_size: maximum number of entries returned (max 1000). -- include_files / include_dirs: filter returned entry types. - -Returns JSON with: -- entries: [{path, is_dir, size, modified_at, depth}] -- truncated: true when additional entries were omitted due to page_size. -""" - -_CLOUD_GREP_TOOL_DESCRIPTION = """Search for a literal text pattern across files. - -Searches both your in-memory edits and the indexed chunks in Postgres. -State-cached file matches include real line numbers; database hits return -`line=0` because their position depends on per-document XML layout — call -`read_file(path)` afterwards to find the exact line. -""" - -_CLOUD_MKDIR_TOOL_DESCRIPTION = """Creates a directory under `/documents/`. - -Stages the folder for end-of-turn commit; the Folder row is inserted only -after the agent's turn finishes successfully. - -Args: -- path: absolute path of the new directory (must start with - `/documents/`). - -Notes: -- Parent folders are created as needed. -""" - -_CLOUD_RM_TOOL_DESCRIPTION = """Deletes a single file under `/documents/`. - -Mirrors POSIX `rm path` (no `-r`, no glob expansion). Stages the deletion -for end-of-turn commit; the row is removed only after the agent's turn -finishes successfully. - -Args: -- path: absolute or relative file path. Cannot point at a directory — use - `rmdir` for empty folders. 
Cannot target the root or `/documents`. - -Notes: -- The action is reversible via the per-action revert flow when action - logging is enabled. -- The anonymous uploaded document is read-only and cannot be deleted. -""" - -_CLOUD_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory under `/documents/`. - -Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive -deletion (`rm -r`) is intentionally NOT supported — clear contents with -`rm` first. - -Args: -- path: absolute or relative directory path. Cannot target the root, - `/documents`, the current cwd, or any ancestor of cwd (use `cd` to - move out first). - -Notes: -- Emptiness is evaluated against the post-staged view, so a same-turn - `rm /a/x.md` followed by `rmdir /a` is fine. -- If the directory was added in this same turn via `mkdir` and never - committed, the staged mkdir is dropped instead of issuing a delete. -- The action is reversible via the per-action revert flow when action - logging is enabled. -""" - -# --- desktop-only ---------------------------------------------------------- - -_DESKTOP_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path. - -Usage: -- Provide an absolute path using a mount prefix (e.g. `//sub/dir`). - Use `ls('/')` to discover available mounts. -- For very large folders, use `offset` and `limit` to paginate the listing. -- Returns one entry per line; directories end with a trailing `/`. -""" - -_DESKTOP_WRITE_FILE_TOOL_DESCRIPTION = """Writes a text file to disk. - -Usage: -- Use mount-prefixed absolute paths like `//sub/file.ext`. -- Writes hit disk immediately. There is no end-of-turn staging. -- Supported outputs include common LLM-friendly text formats like markdown, - json, yaml, csv, xml, html, css, sql, and code files. -- Avoid placeholders; produce concrete and useful text. -""" - -_DESKTOP_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files on disk. - -IMPORTANT: -- Read the file before editing. 
-- Preserve exact indentation and formatting. -- Edits hit disk immediately. -""" - -_DESKTOP_MOVE_FILE_TOOL_DESCRIPTION = """Moves or renames a file or folder on disk. - -Use mount-prefixed absolute paths for both source and destination -(e.g. `//old.txt` -> `//new.txt`). - -Notes: -- Cross-mount moves are not supported. -- Rename is a special case of move (same folder, different filename). -""" - -_DESKTOP_LIST_TREE_TOOL_DESCRIPTION = """Lists files/folders recursively in a single bounded call. - -Args: -- path: absolute path to start from. Defaults to `/`. -- max_depth: recursion depth limit (default 8). -- page_size: maximum number of entries returned (max 1000). -- include_files / include_dirs: filter returned entry types. - -Returns JSON with: -- entries: [{path, is_dir, size, modified_at, depth}] -- truncated: true when additional entries were omitted due to page_size. -""" - -_DESKTOP_GREP_TOOL_DESCRIPTION = """Search for a literal text pattern across files. - -Searches files on disk and any in-memory edits. Returns real line numbers. -""" - -_DESKTOP_MKDIR_TOOL_DESCRIPTION = """Creates a directory on disk. - -Args: -- path: absolute mount-prefixed path of the new directory. - -Notes: -- Parent folders are created as needed. -""" - -_DESKTOP_RM_TOOL_DESCRIPTION = """Deletes a single file from disk. - -Mirrors POSIX `rm path` (no `-r`, no glob expansion). The deletion hits -disk immediately. Desktop deletes are NOT reversible via the agent's -revert flow. - -Args: -- path: absolute mount-prefixed file path. Cannot point at a directory — - use `rmdir` for empty folders. -""" - -_DESKTOP_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory from disk. - -Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive -deletion is NOT supported. The deletion hits disk immediately and is -NOT reversible via the agent's revert flow. - -Args: -- path: absolute mount-prefixed directory path. 
Cannot target the mount - root or any directory containing files/subfolders. -""" - - -def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]: - """Pick the active-mode description for every filesystem tool.""" - if filesystem_mode == FilesystemMode.CLOUD: - return { - "ls": _CLOUD_LIST_FILES_TOOL_DESCRIPTION, - "read_file": SURFSENSE_READ_FILE_TOOL_DESCRIPTION, - "write_file": _CLOUD_WRITE_FILE_TOOL_DESCRIPTION, - "edit_file": _CLOUD_EDIT_FILE_TOOL_DESCRIPTION, - "move_file": _CLOUD_MOVE_FILE_TOOL_DESCRIPTION, - "list_tree": _CLOUD_LIST_TREE_TOOL_DESCRIPTION, - "glob": SURFSENSE_GLOB_TOOL_DESCRIPTION, - "grep": _CLOUD_GREP_TOOL_DESCRIPTION, - "mkdir": _CLOUD_MKDIR_TOOL_DESCRIPTION, - "cd": SURFSENSE_CD_TOOL_DESCRIPTION, - "pwd": SURFSENSE_PWD_TOOL_DESCRIPTION, - "rm": _CLOUD_RM_TOOL_DESCRIPTION, - "rmdir": _CLOUD_RMDIR_TOOL_DESCRIPTION, - } - return { - "ls": _DESKTOP_LIST_FILES_TOOL_DESCRIPTION, - "read_file": SURFSENSE_READ_FILE_TOOL_DESCRIPTION, - "write_file": _DESKTOP_WRITE_FILE_TOOL_DESCRIPTION, - "edit_file": _DESKTOP_EDIT_FILE_TOOL_DESCRIPTION, - "move_file": _DESKTOP_MOVE_FILE_TOOL_DESCRIPTION, - "list_tree": _DESKTOP_LIST_TREE_TOOL_DESCRIPTION, - "glob": SURFSENSE_GLOB_TOOL_DESCRIPTION, - "grep": _DESKTOP_GREP_TOOL_DESCRIPTION, - "mkdir": _DESKTOP_MKDIR_TOOL_DESCRIPTION, - "cd": SURFSENSE_CD_TOOL_DESCRIPTION, - "pwd": SURFSENSE_PWD_TOOL_DESCRIPTION, - "rm": _DESKTOP_RM_TOOL_DESCRIPTION, - "rmdir": _DESKTOP_RMDIR_TOOL_DESCRIPTION, - } - - -# Backwards-compatible aliases retained for any external imports/tests that -# referenced the original CLOUD-flavoured constants. 
-SURFSENSE_LIST_FILES_TOOL_DESCRIPTION = _CLOUD_LIST_FILES_TOOL_DESCRIPTION -SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = _CLOUD_WRITE_FILE_TOOL_DESCRIPTION -SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION = _CLOUD_EDIT_FILE_TOOL_DESCRIPTION -SURFSENSE_MOVE_FILE_TOOL_DESCRIPTION = _CLOUD_MOVE_FILE_TOOL_DESCRIPTION -SURFSENSE_LIST_TREE_TOOL_DESCRIPTION = _CLOUD_LIST_TREE_TOOL_DESCRIPTION -SURFSENSE_GREP_TOOL_DESCRIPTION = _CLOUD_GREP_TOOL_DESCRIPTION -SURFSENSE_MKDIR_TOOL_DESCRIPTION = _CLOUD_MKDIR_TOOL_DESCRIPTION - - -# ============================================================================= -# Helpers -# ============================================================================= - - -_TEMP_PREFIX = "temp_" - - -def _basename(path: str) -> str: - return path.rsplit("/", 1)[-1] - - -def _is_ancestor_of(candidate: str, target: str) -> bool: - """True iff ``candidate`` is a strict ancestor directory of ``target``. - - ``target`` itself is NOT considered an ancestor (use equality for that). - Both paths are assumed to be canonicalised, absolute, and free of - trailing slashes (except the root ``/``). 
- """ - if not candidate.startswith("/") or not target.startswith("/"): - return False - if candidate == target: - return False - prefix = candidate.rstrip("/") + "/" - return target.startswith(prefix) - - -class SurfSenseFilesystemMiddleware(FilesystemMiddleware): - """SurfSense-specific filesystem middleware (cloud + desktop).""" - - state_schema = SurfSenseFilesystemState - - _MAX_EXECUTE_TIMEOUT = 300 - - def __init__( - self, - *, - backend: Any = None, - filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, - search_space_id: int | None = None, - created_by_id: str | None = None, - thread_id: int | str | None = None, - tool_token_limit_before_evict: int | None = 20000, - ) -> None: - self._filesystem_mode = filesystem_mode - self._search_space_id = search_space_id - self._created_by_id = created_by_id - self._thread_id = thread_id - self._sandbox_available = is_sandbox_enabled() and thread_id is not None - - # Build the prompt + tool descriptions for the active mode only — - # mixing both modes wastes tokens and confuses the model with rules - # it can't actually use this session. 
- system_prompt = _build_filesystem_system_prompt( - filesystem_mode, - sandbox_available=self._sandbox_available, - ) - - super().__init__( - backend=backend, - system_prompt=system_prompt, - custom_tool_descriptions=_build_tool_descriptions(filesystem_mode), - tool_token_limit_before_evict=tool_token_limit_before_evict, - max_execute_timeout=self._MAX_EXECUTE_TIMEOUT, - ) - self.tools = [t for t in self.tools if t.name != "execute"] - self.tools.append(self._create_mkdir_tool()) - self.tools.append(self._create_cd_tool()) - self.tools.append(self._create_pwd_tool()) - self.tools.append(self._create_move_file_tool()) - self.tools.append(self._create_rm_tool()) - self.tools.append(self._create_rmdir_tool()) - self.tools.append(self._create_list_tree_tool()) - if self._sandbox_available: - self.tools.append(self._create_execute_code_tool()) - - # ------------------------------------------------------------------ helpers - - def _is_cloud(self) -> bool: - return self._filesystem_mode == FilesystemMode.CLOUD - - @staticmethod - def _run_async_blocking(coro: Any) -> Any: - try: - loop = asyncio.get_running_loop() - if loop.is_running(): - return "Error: sync filesystem operation not supported inside an active event loop." 
- except RuntimeError: - pass - return asyncio.run(coro) - - @staticmethod - def _normalize_absolute_path(candidate: str) -> str: - normalized = re.sub(r"/+", "/", candidate.strip().replace("\\", "/")) - if not normalized: - return "/" - if normalized.startswith("/"): - return normalized - return f"/{normalized.lstrip('/')}" - - @staticmethod - def _extract_mount_from_path(path: str, mounts: tuple[str, ...]) -> str | None: - rel = path.lstrip("/") - if not rel: - return None - mount, _, _ = rel.partition("/") - if mount in mounts: - return mount - return None - - @staticmethod - def _local_parent_path(path: str) -> str: - rel = path.lstrip("/") - if "/" not in rel: - return "/" - parent = rel.rsplit("/", 1)[0].strip("/") - if not parent: - return "/" - return f"/{parent}" - - @staticmethod - def _path_exists_under_mount( - backend: MultiRootLocalFolderBackend, - mount: str, - local_path: str, - ) -> bool: - result = backend.list_tree( - f"/{mount}{local_path}", - max_depth=0, - page_size=1, - include_files=True, - include_dirs=True, - ) - return not bool(result.get("error")) - - def _normalize_local_mount_path( - self, - candidate: str, - runtime: ToolRuntime[None, SurfSenseFilesystemState], - ) -> str: - normalized = self._normalize_absolute_path(candidate) - backend = self._get_backend(runtime) - if not isinstance(backend, MultiRootLocalFolderBackend): - return normalized - - mounts = backend.list_mounts() - explicit_mount = self._extract_mount_from_path(normalized, mounts) - if explicit_mount: - return normalized - - if len(mounts) == 1: - return f"/{mounts[0]}{normalized}" - - suggested_mount: str | None = None - contract = runtime.state.get("file_operation_contract") or {} - suggested_path = contract.get("suggested_path") - if isinstance(suggested_path, str) and suggested_path.strip(): - normalized_suggested = self._normalize_absolute_path(suggested_path) - suggested_mount = self._extract_mount_from_path( - normalized_suggested, mounts - ) - - matching_mounts 
= [ - mount - for mount in mounts - if self._path_exists_under_mount(backend, mount, normalized) - ] - if len(matching_mounts) == 1: - return f"/{matching_mounts[0]}{normalized}" - - parent_path = self._local_parent_path(normalized) - if parent_path != "/": - parent_matching_mounts = [ - mount - for mount in mounts - if self._path_exists_under_mount(backend, mount, parent_path) - ] - if len(parent_matching_mounts) == 1: - return f"/{parent_matching_mounts[0]}{normalized}" - - if suggested_mount: - return f"/{suggested_mount}{normalized}" - - return f"/{backend.default_mount()}{normalized}" - - def _default_cwd(self) -> str: - return DOCUMENTS_ROOT if self._is_cloud() else "/" - - def _current_cwd(self, runtime: ToolRuntime[None, SurfSenseFilesystemState]) -> str: - cwd = runtime.state.get("cwd") if hasattr(runtime, "state") else None - if isinstance(cwd, str) and cwd.startswith("/"): - return cwd - return self._default_cwd() - - def _get_contract_suggested_path( - self, runtime: ToolRuntime[None, SurfSenseFilesystemState] - ) -> str: - contract = runtime.state.get("file_operation_contract") or {} - suggested = contract.get("suggested_path") - if isinstance(suggested, str) and suggested.strip(): - return self._normalize_absolute_path(suggested) - return self._default_cwd().rstrip("/") + "/notes.md" - - def _resolve_relative( - self, - path: str, - runtime: ToolRuntime[None, SurfSenseFilesystemState], - ) -> str: - candidate = path.strip() - if not candidate: - return self._current_cwd(runtime) - if candidate.startswith("/"): - return self._normalize_absolute_path(candidate) - cwd = self._current_cwd(runtime) - joined = posixpath.normpath(posixpath.join(cwd, candidate)) - return self._normalize_absolute_path(joined) - - def _resolve_write_target_path( - self, - file_path: str, - runtime: ToolRuntime[None, SurfSenseFilesystemState], - ) -> str: - candidate = file_path.strip() - if not candidate: - return self._get_contract_suggested_path(runtime) - if 
self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: - return self._normalize_local_mount_path(candidate, runtime) - return self._resolve_relative(candidate, runtime) - - def _resolve_move_target_path( - self, - file_path: str, - runtime: ToolRuntime[None, SurfSenseFilesystemState], - ) -> str: - candidate = file_path.strip() - if not candidate: - return "" - if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: - return self._normalize_local_mount_path(candidate, runtime) - return self._resolve_relative(candidate, runtime) - - def _resolve_list_target_path( - self, - path: str, - runtime: ToolRuntime[None, SurfSenseFilesystemState], - ) -> str: - candidate = path.strip() or self._current_cwd(runtime) - if candidate == "/": - return "/" - if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: - return self._normalize_local_mount_path(candidate, runtime) - return self._resolve_relative(candidate, runtime) - - # ------------------------------------------------------------------ namespace policy - - def _check_cloud_write_namespace( - self, - path: str, - runtime: ToolRuntime[None, SurfSenseFilesystemState], - ) -> str | None: - """Return an error string if cloud writes to ``path`` are not allowed. - - Order matters: - 1. Reject writes to the anonymous read-only doc. - 2. Allow ``/documents/*``. - 3. Allow ``temp_*`` basename anywhere. - 4. Reject everything else. - """ - if not self._is_cloud(): - return None - anon = runtime.state.get("kb_anon_doc") or {} - if isinstance(anon, dict): - anon_path = str(anon.get("path") or "") - if anon_path and anon_path == path: - return "Error: the anonymous uploaded document is read-only." - if path.startswith(DOCUMENTS_ROOT + "/") or path == DOCUMENTS_ROOT: - return None - if _basename(path).startswith(_TEMP_PREFIX): - return None - return ( - "Error: cloud writes must target /documents/<...> or use a 'temp_' " - f"basename for scratch (got '{path}')." 
- ) - - # ------------------------------------------------------------------ tool: ls - - def _create_ls_tool(self) -> BaseTool: - tool_description = ( - self._custom_tool_descriptions.get("ls") - or SURFSENSE_LIST_FILES_TOOL_DESCRIPTION - ) - - def sync_ls( - runtime: ToolRuntime[None, SurfSenseFilesystemState], - path: Annotated[ - str, - "Absolute path to the directory to list. Relative paths resolve against the current cwd.", - ] = "", - offset: Annotated[ - int, - "Number of entries to skip. Use for paginating large folders. Defaults to 0.", - ] = 0, - limit: Annotated[ - int, - "Maximum number of entries to return. Defaults to 200.", - ] = 200, - ) -> str: - return self._run_async_blocking( - async_ls(runtime, path=path, offset=offset, limit=limit) - ) - - async def async_ls( - runtime: ToolRuntime[None, SurfSenseFilesystemState], - path: Annotated[ - str, - "Absolute path to the directory to list. Relative paths resolve against the current cwd.", - ] = "", - offset: Annotated[ - int, - "Number of entries to skip. Use for paginating large folders. Defaults to 0.", - ] = 0, - limit: Annotated[ - int, - "Maximum number of entries to return. Defaults to 200.", - ] = 200, - ) -> str: - target = self._resolve_list_target_path(path, runtime) - try: - validated = validate_path(target) - except ValueError as exc: - return f"Error: {exc}" - if offset < 0: - offset = 0 - if limit < 1: - limit = 1 - backend = self._get_backend(runtime) - infos = await backend.als_info(validated) - page = paginate_listing(infos, offset=offset, limit=limit) - paths = [ - f"{fi.get('path', '')}/" if fi.get("is_dir") else fi.get("path", "") - for fi in page - ] - total = len(infos) - shown = len(page) - header = ( - f"{validated} ({shown} of {total} entries" - f"{f', offset={offset}' if offset else ''})" - ) - if not paths: - return f"{header}\n(empty)" - body = "\n".join(paths) - if total > offset + shown: - body += ( - f"\n... 
{total - offset - shown} more — call ls(" - f"'{validated}', offset={offset + shown}, limit={limit})" - ) - return f"{header}\n{body}" - - return StructuredTool.from_function( - name="ls", - description=tool_description, - func=sync_ls, - coroutine=async_ls, - ) - - # ------------------------------------------------------------------ tool: read_file - - def _create_read_file_tool(self) -> BaseTool: - tool_description = ( - self._custom_tool_descriptions.get("read_file") - or SURFSENSE_READ_FILE_TOOL_DESCRIPTION - ) - - async def async_read_file( - file_path: Annotated[ - str, - "Absolute path to the file to read. Relative paths resolve against the current cwd.", - ], - runtime: ToolRuntime[None, SurfSenseFilesystemState], - offset: Annotated[ - int, - "Line number to start reading from (0-indexed).", - ] = 0, - limit: Annotated[ - int, - "Maximum number of lines to read.", - ] = 100, - ) -> Command | str: - target = self._resolve_relative(file_path, runtime) - try: - validated = validate_path(target) - except ValueError as exc: - return f"Error: {exc}" - - files = runtime.state.get("files") or {} - if validated in files: - return format_read_response(files[validated], offset, limit) - - backend = self._get_backend(runtime) - if isinstance(backend, KBPostgresBackend): - loaded = await backend._load_file_data(validated) - if loaded is None: - return f"Error: File '{validated}' not found" - file_data, doc_id = loaded - rendered = format_read_response(file_data, offset, limit) - update: dict[str, Any] = { - "files": {validated: file_data}, - "messages": [ - ToolMessage( - content=rendered, - tool_call_id=runtime.tool_call_id, - ) - ], - } - if doc_id is not None: - update["doc_id_by_path"] = {validated: doc_id} - return Command(update=update) - - try: - rendered = await backend.aread(validated, offset=offset, limit=limit) - except Exception as exc: # pragma: no cover - defensive - return f"Error: {exc}" - return rendered - - def sync_read_file( - file_path: Annotated[ 
- str, - "Absolute path to the file to read. Relative paths resolve against the current cwd.", - ], - runtime: ToolRuntime[None, SurfSenseFilesystemState], - offset: Annotated[ - int, - "Line number to start reading from (0-indexed).", - ] = 0, - limit: Annotated[ - int, - "Maximum number of lines to read.", - ] = 100, - ) -> Command | str: - return self._run_async_blocking( - async_read_file(file_path, runtime, offset, limit) - ) - - return StructuredTool.from_function( - name="read_file", - description=tool_description, - func=sync_read_file, - coroutine=async_read_file, - ) - - # ------------------------------------------------------------------ tool: write_file - - def _create_write_file_tool(self) -> BaseTool: - tool_description = ( - self._custom_tool_descriptions.get("write_file") - or SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION - ) - - async def async_write_file( - file_path: Annotated[ - str, - "Absolute path where the file should be created. Relative paths resolve against the current cwd.", - ], - content: Annotated[str, "Text content to write to the file."], - runtime: ToolRuntime[None, SurfSenseFilesystemState], - ) -> Command | str: - target = self._resolve_write_target_path(file_path, runtime) - try: - validated = validate_path(target) - except ValueError as exc: - return f"Error: {exc}" - - namespace_error = self._check_cloud_write_namespace(validated, runtime) - if namespace_error: - return namespace_error - - backend = self._get_backend(runtime) - res: WriteResult = await backend.awrite(validated, content) - if res.error: - return res.error - - path = res.path or validated - files_update = res.files_update or {path: create_file_data(content)} - update: dict[str, Any] = { - "files": files_update, - "messages": [ - ToolMessage( - content=f"Updated file {path}", - tool_call_id=runtime.tool_call_id, - ) - ], - } - if self._is_cloud(): - update["dirty_paths"] = [path] - update["dirty_path_tool_calls"] = {path: runtime.tool_call_id} - return 
Command(update=update) - - def sync_write_file( - file_path: Annotated[ - str, - "Absolute path where the file should be created. Relative paths resolve against the current cwd.", - ], - content: Annotated[str, "Text content to write to the file."], - runtime: ToolRuntime[None, SurfSenseFilesystemState], - ) -> Command | str: - return self._run_async_blocking( - async_write_file(file_path, content, runtime) - ) - - return StructuredTool.from_function( - name="write_file", - description=tool_description, - func=sync_write_file, - coroutine=async_write_file, - ) - - # ------------------------------------------------------------------ tool: edit_file - - def _create_edit_file_tool(self) -> BaseTool: - tool_description = ( - self._custom_tool_descriptions.get("edit_file") - or SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION - ) - - async def async_edit_file( - file_path: Annotated[ - str, - "Absolute path to the file to edit. Relative paths resolve against the current cwd.", - ], - old_string: Annotated[ - str, - "Exact text to replace. Must be unique unless replace_all is True.", - ], - new_string: Annotated[ - str, - "Replacement text. Must differ from old_string.", - ], - runtime: ToolRuntime[None, SurfSenseFilesystemState], - *, - replace_all: Annotated[ - bool, - "If True, replace all occurrences of old_string. 
Defaults to False.", - ] = False, - ) -> Command | str: - target = self._resolve_relative(file_path, runtime) - try: - validated = validate_path(target) - except ValueError as exc: - return f"Error: {exc}" - - namespace_error = self._check_cloud_write_namespace(validated, runtime) - if namespace_error: - return namespace_error - - backend = self._get_backend(runtime) - files_state = runtime.state.get("files") or {} - doc_id_to_attach: int | None = None - - if ( - self._is_cloud() - and validated not in files_state - and isinstance(backend, KBPostgresBackend) - ): - loaded = await backend._load_file_data(validated) - if loaded is None: - return f"Error: File '{validated}' not found" - _, doc_id_to_attach = loaded - - res: EditResult = await backend.aedit( - validated, old_string, new_string, replace_all=replace_all - ) - if res.error: - return res.error - - path = res.path or validated - files_update = res.files_update or {} - update: dict[str, Any] = { - "files": files_update, - "messages": [ - ToolMessage( - content=( - f"Successfully replaced {res.occurrences} instance(s) " - f"of the string in '{path}'" - ), - tool_call_id=runtime.tool_call_id, - ) - ], - } - if self._is_cloud(): - update["dirty_paths"] = [path] - update["dirty_path_tool_calls"] = {path: runtime.tool_call_id} - if doc_id_to_attach is not None: - update["doc_id_by_path"] = {path: doc_id_to_attach} - return Command(update=update) - - def sync_edit_file( - file_path: Annotated[ - str, - "Absolute path to the file to edit. Relative paths resolve against the current cwd.", - ], - old_string: Annotated[ - str, - "Exact text to replace. Must be unique unless replace_all is True.", - ], - new_string: Annotated[ - str, - "Replacement text. Must differ from old_string.", - ], - runtime: ToolRuntime[None, SurfSenseFilesystemState], - *, - replace_all: Annotated[ - bool, - "If True, replace all occurrences of old_string. 
Defaults to False.", - ] = False, - ) -> Command | str: - return self._run_async_blocking( - async_edit_file( - file_path, old_string, new_string, runtime, replace_all=replace_all - ) - ) - - return StructuredTool.from_function( - name="edit_file", - description=tool_description, - func=sync_edit_file, - coroutine=async_edit_file, - ) - - # ------------------------------------------------------------------ tool: mkdir - - def _create_mkdir_tool(self) -> BaseTool: - tool_description = ( - self._custom_tool_descriptions.get("mkdir") - or SURFSENSE_MKDIR_TOOL_DESCRIPTION - ) - - async def async_mkdir( - path: Annotated[str, "Absolute or relative directory path to create."], - runtime: ToolRuntime[None, SurfSenseFilesystemState], - ) -> Command | str: - target = self._resolve_relative(path, runtime) - try: - validated = validate_path(target) - except ValueError as exc: - return f"Error: {exc}" - - if self._is_cloud(): - if not ( - validated.startswith(DOCUMENTS_ROOT + "/") - or validated == DOCUMENTS_ROOT - ): - return ( - "Error: cloud mkdir must target a path under /documents/ " - f"(got '{validated}')." - ) - return Command( - update={ - "staged_dirs": [validated], - "staged_dir_tool_calls": { - validated: runtime.tool_call_id, - }, - "messages": [ - ToolMessage( - content=( - f"Staged directory '{validated}' (will be created " - "at end of turn)." 
- ), - tool_call_id=runtime.tool_call_id, - ) - ], - } - ) - - backend = self._get_backend(runtime) - local_method = getattr(backend, "amkdir", None) or getattr( - backend, "mkdir", None - ) - if callable(local_method): - try: - res = local_method(validated, parents=True, exist_ok=True) - if asyncio.iscoroutine(res): - await res - except TypeError: - res = local_method(validated) - if asyncio.iscoroutine(res): - await res - except Exception as exc: # pragma: no cover - return f"Error: {exc}" - return f"Created directory {validated}" - - def sync_mkdir( - path: Annotated[str, "Absolute or relative directory path to create."], - runtime: ToolRuntime[None, SurfSenseFilesystemState], - ) -> Command | str: - return self._run_async_blocking(async_mkdir(path, runtime)) - - return StructuredTool.from_function( - name="mkdir", - description=tool_description, - func=sync_mkdir, - coroutine=async_mkdir, - ) - - # ------------------------------------------------------------------ tool: cd - - def _create_cd_tool(self) -> BaseTool: - tool_description = ( - self._custom_tool_descriptions.get("cd") or SURFSENSE_CD_TOOL_DESCRIPTION - ) - - async def async_cd( - path: Annotated[str, "Absolute or relative directory path to switch into."], - runtime: ToolRuntime[None, SurfSenseFilesystemState], - ) -> Command | str: - target = self._resolve_relative(path, runtime) - try: - validated = validate_path(target) - except ValueError as exc: - return f"Error: {exc}" - - backend = self._get_backend(runtime) - try: - infos = await backend.als_info(validated) - except Exception as exc: # pragma: no cover - defensive - return f"Error: {exc}" - staged_dirs = list(runtime.state.get("staged_dirs") or []) - files = runtime.state.get("files") or {} - cwd_exists = ( - bool(infos) - or validated in staged_dirs - or any(p == validated for p in files) - or any( - isinstance(p, str) and p.startswith(validated.rstrip("/") + "/") - for p in files - ) - or validated == "/" - or validated == DOCUMENTS_ROOT - 
) - if not cwd_exists: - return f"Error: directory '{validated}' not found." - return Command( - update={ - "cwd": validated, - "messages": [ - ToolMessage( - content=f"cwd changed to {validated}", - tool_call_id=runtime.tool_call_id, - ) - ], - } - ) - - def sync_cd( - path: Annotated[str, "Absolute or relative directory path to switch into."], - runtime: ToolRuntime[None, SurfSenseFilesystemState], - ) -> Command | str: - return self._run_async_blocking(async_cd(path, runtime)) - - return StructuredTool.from_function( - name="cd", - description=tool_description, - func=sync_cd, - coroutine=async_cd, - ) - - # ------------------------------------------------------------------ tool: pwd - - def _create_pwd_tool(self) -> BaseTool: - tool_description = ( - self._custom_tool_descriptions.get("pwd") or SURFSENSE_PWD_TOOL_DESCRIPTION - ) - - def sync_pwd( - runtime: ToolRuntime[None, SurfSenseFilesystemState], - ) -> str: - return self._current_cwd(runtime) - - async def async_pwd( - runtime: ToolRuntime[None, SurfSenseFilesystemState], - ) -> str: - return self._current_cwd(runtime) - - return StructuredTool.from_function( - name="pwd", - description=tool_description, - func=sync_pwd, - coroutine=async_pwd, - ) - - # ------------------------------------------------------------------ tool: move_file - - def _create_move_file_tool(self) -> BaseTool: - tool_description = ( - self._custom_tool_descriptions.get("move_file") - or SURFSENSE_MOVE_FILE_TOOL_DESCRIPTION - ) - - async def async_move_file( - source_path: Annotated[str, "Absolute or relative source path."], - destination_path: Annotated[str, "Absolute or relative destination path."], - runtime: ToolRuntime[None, SurfSenseFilesystemState], - *, - overwrite: Annotated[ - bool, - "If True, replace existing destination. Cloud mode rejects True. 
Defaults to False.", - ] = False, - ) -> Command | str: - if not source_path.strip() or not destination_path.strip(): - return "Error: source_path and destination_path are required." - - source = self._resolve_move_target_path(source_path, runtime) - dest = self._resolve_move_target_path(destination_path, runtime) - try: - validated_source = validate_path(source) - validated_dest = validate_path(dest) - except ValueError as exc: - return f"Error: {exc}" - - if self._is_cloud(): - return await self._cloud_move_file( - runtime, - validated_source, - validated_dest, - overwrite=overwrite, - ) - - backend = self._get_backend(runtime) - res: WriteResult = await backend.amove( - validated_source, validated_dest, overwrite=overwrite - ) - if res.error: - return res.error - update: dict[str, Any] = { - "messages": [ - ToolMessage( - content=f"Moved '{validated_source}' to '{res.path or validated_dest}'", - tool_call_id=runtime.tool_call_id, - ) - ], - } - if res.files_update is not None: - update["files"] = res.files_update - return Command(update=update) - - def sync_move_file( - source_path: Annotated[str, "Absolute or relative source path."], - destination_path: Annotated[str, "Absolute or relative destination path."], - runtime: ToolRuntime[None, SurfSenseFilesystemState], - *, - overwrite: Annotated[ - bool, - "If True, replace existing destination. Cloud mode rejects True. 
Defaults to False.", - ] = False, - ) -> Command | str: - return self._run_async_blocking( - async_move_file( - source_path, destination_path, runtime, overwrite=overwrite - ) - ) - - return StructuredTool.from_function( - name="move_file", - description=tool_description, - func=sync_move_file, - coroutine=async_move_file, - ) - - async def _cloud_move_file( - self, - runtime: ToolRuntime[None, SurfSenseFilesystemState], - source: str, - dest: str, - *, - overwrite: bool, - ) -> Command | str: - backend = self._get_backend(runtime) - if not isinstance(backend, KBPostgresBackend): - return "Error: cloud move requires KBPostgresBackend." - - if source == dest: - return f"Moved '{source}' to '{dest}' (no-op)" - if overwrite: - return ( - "Error: overwrite=True is not supported in cloud mode. Move/edit " - "the destination doc explicitly first." - ) - if not source.startswith(DOCUMENTS_ROOT + "/"): - return ( - "Error: cloud move_file source must be under /documents/ (got " - f"'{source}')." - ) - if not dest.startswith(DOCUMENTS_ROOT + "/"): - return ( - "Error: cloud move_file destination must be under /documents/ (got " - f"'{dest}')." - ) - anon = runtime.state.get("kb_anon_doc") or {} - if isinstance(anon, dict): - anon_path = str(anon.get("path") or "") - if anon_path and (anon_path in (source, dest)): - return "Error: the anonymous uploaded document is read-only." - - files = runtime.state.get("files") or {} - doc_id_by_path = runtime.state.get("doc_id_by_path") or {} - pending_moves = list(runtime.state.get("pending_moves") or []) - - # Dest collision: occupied in state, in pending moves, or in DB. - if dest in files: - return f"Error: destination '{dest}' already exists." - if any(move.get("dest") == dest for move in pending_moves): - return f"Error: destination '{dest}' already exists." - if dest != source: - existing_dest = await backend._load_file_data(dest) - if existing_dest is not None: - return f"Error: destination '{dest}' already exists." 
- - # Source materialization: lazy load if not in state. - source_file_data = files.get(source) - source_doc_id = doc_id_by_path.get(source) - if source_file_data is None: - loaded = await backend._load_file_data(source) - if loaded is None: - return f"Error: source '{source}' not found." - source_file_data, loaded_doc_id = loaded - if source_doc_id is None: - source_doc_id = loaded_doc_id - - files_update: dict[str, Any] = {source: None, dest: source_file_data} - update: dict[str, Any] = { - "files": files_update, - "pending_moves": [ - { - "source": source, - "dest": dest, - "overwrite": False, - "tool_call_id": runtime.tool_call_id, - } - ], - "messages": [ - ToolMessage( - content=( - f"Moved '{source}' to '{dest}' (will commit at end of turn)." - ), - tool_call_id=runtime.tool_call_id, - ) - ], - } - - doc_id_update: dict[str, int | None] = {source: None} - if source_doc_id is not None: - doc_id_update[dest] = source_doc_id - update["doc_id_by_path"] = doc_id_update - - dirty_paths = list(runtime.state.get("dirty_paths") or []) - if source in dirty_paths: - new_dirty: list[Any] = [_CLEAR] - for entry in dirty_paths: - new_dirty.append(dest if entry == source else entry) - update["dirty_paths"] = new_dirty - return Command(update=update) - - # ------------------------------------------------------------------ tool: rm - - def _create_rm_tool(self) -> BaseTool: - tool_description = ( - self._custom_tool_descriptions.get("rm") or _CLOUD_RM_TOOL_DESCRIPTION - ) - - async def async_rm( - path: Annotated[ - str, - "Absolute or relative path to the file to delete.", - ], - runtime: ToolRuntime[None, SurfSenseFilesystemState], - ) -> Command | str: - if not path or not path.strip(): - return "Error: path is required." 
- - target = self._resolve_relative(path, runtime) - try: - validated = validate_path(target) - except ValueError as exc: - return f"Error: {exc}" - - if self._is_cloud(): - if validated in ("/", DOCUMENTS_ROOT): - return f"Error: refusing to rm '{validated}'." - if not validated.startswith(DOCUMENTS_ROOT + "/"): - return ( - "Error: cloud rm must target a path under /documents/ " - f"(got '{validated}')." - ) - - anon = runtime.state.get("kb_anon_doc") or {} - if isinstance(anon, dict) and str(anon.get("path") or "") == validated: - return "Error: the anonymous uploaded document is read-only." - - # Refuse if the path looks like a directory. - staged_dirs = list(runtime.state.get("staged_dirs") or []) - if validated in staged_dirs: - return ( - f"Error: '{validated}' is a directory. Use rmdir for " - "empty directories." - ) - pending_dir_deletes = list( - runtime.state.get("pending_dir_deletes") or [] - ) - if any( - isinstance(d, dict) and d.get("path") == validated - for d in pending_dir_deletes - ): - return f"Error: '{validated}' is already queued for rmdir." - - backend = self._get_backend(runtime) - if isinstance(backend, KBPostgresBackend): - # Detect "is a directory" via `ls`: if the path lists - # children we know it's a folder. Otherwise we still - # need to confirm it's a real file before staging. - children = await backend.als_info(validated) - if children: - return ( - f"Error: '{validated}' is a directory. Use rmdir for " - "empty directories." - ) - - # Already queued for delete this turn? - pending_deletes = list(runtime.state.get("pending_deletes") or []) - if any( - isinstance(d, dict) and d.get("path") == validated - for d in pending_deletes - ): - return f"'{validated}' is already queued for deletion." - - # Resolve doc_id (best-effort): file in state or DB. 
- files_state = runtime.state.get("files") or {} - doc_id_by_path = runtime.state.get("doc_id_by_path") or {} - resolved_doc_id: int | None = doc_id_by_path.get(validated) - if ( - validated not in files_state - and resolved_doc_id is None - and isinstance(backend, KBPostgresBackend) - ): - loaded = await backend._load_file_data(validated) - if loaded is None: - return f"Error: file '{validated}' not found." - _, resolved_doc_id = loaded - - files_update: dict[str, Any] = {validated: None} - update: dict[str, Any] = { - "pending_deletes": [ - { - "path": validated, - "tool_call_id": runtime.tool_call_id, - } - ], - "files": files_update, - "doc_id_by_path": {validated: None}, - "messages": [ - ToolMessage( - content=( - f"Staged delete of '{validated}' (will commit at " - "end of turn)." - ), - tool_call_id=runtime.tool_call_id, - ) - ], - } - - # Drop the path from dirty_paths so a same-turn write+rm - # doesn't recreate the doc at commit time. - dirty_paths = list(runtime.state.get("dirty_paths") or []) - if validated in dirty_paths: - new_dirty: list[Any] = [_CLEAR] - for entry in dirty_paths: - if entry != validated: - new_dirty.append(entry) - update["dirty_paths"] = new_dirty - update["dirty_path_tool_calls"] = {validated: None} - - return Command(update=update) - - # Desktop mode — hit disk immediately. - backend = self._get_backend(runtime) - adelete = getattr(backend, "adelete_file", None) - if not callable(adelete): - return "Error: rm is not supported by the active backend." 
- res: WriteResult = await adelete(validated) - if res.error: - return res.error - update_desktop: dict[str, Any] = { - "files": {validated: None}, - "messages": [ - ToolMessage( - content=f"Deleted file '{res.path or validated}'", - tool_call_id=runtime.tool_call_id, - ) - ], - } - return Command(update=update_desktop) - - def sync_rm( - path: Annotated[ - str, - "Absolute or relative path to the file to delete.", - ], - runtime: ToolRuntime[None, SurfSenseFilesystemState], - ) -> Command | str: - return self._run_async_blocking(async_rm(path, runtime)) - - return StructuredTool.from_function( - name="rm", - description=tool_description, - func=sync_rm, - coroutine=async_rm, - ) - - # ------------------------------------------------------------------ tool: rmdir - - def _create_rmdir_tool(self) -> BaseTool: - tool_description = ( - self._custom_tool_descriptions.get("rmdir") or _CLOUD_RMDIR_TOOL_DESCRIPTION - ) - - async def async_rmdir( - path: Annotated[ - str, - "Absolute or relative path of the empty directory to delete.", - ], - runtime: ToolRuntime[None, SurfSenseFilesystemState], - ) -> Command | str: - if not path or not path.strip(): - return "Error: path is required." - - target = self._resolve_relative(path, runtime) - try: - validated = validate_path(target) - except ValueError as exc: - return f"Error: {exc}" - - if self._is_cloud(): - if validated in ("/", DOCUMENTS_ROOT): - return f"Error: refusing to rmdir '{validated}'." - if not validated.startswith(DOCUMENTS_ROOT + "/"): - return ( - "Error: cloud rmdir must target a path under /documents/ " - f"(got '{validated}')." - ) - - cwd = self._current_cwd(runtime) - if validated == cwd or _is_ancestor_of(validated, cwd): - return ( - f"Error: cannot rmdir '{validated}' because the current " - "cwd is at or under it. cd out first." 
- ) - - staged_dirs = list(runtime.state.get("staged_dirs") or []) - pending_dir_deletes = list( - runtime.state.get("pending_dir_deletes") or [] - ) - if any( - isinstance(d, dict) and d.get("path") == validated - for d in pending_dir_deletes - ): - return f"'{validated}' is already queued for deletion." - - backend = self._get_backend(runtime) - - # The path must currently exist either in DB folder paths or - # in staged_dirs. We rely on KBPostgresBackend.als_info (which - # already accounts for pending deletes/moves) to evaluate - # both existence and emptiness against the post-staged view. - exists_in_staged = validated in staged_dirs - children: list[Any] = [] - if isinstance(backend, KBPostgresBackend): - children = list(await backend.als_info(validated)) - - # Detect "is a file" — if als_info returns no children but - # the path is actually a file, we should reject. We use - # _load_file_data to disambiguate file vs missing folder. - if ( - isinstance(backend, KBPostgresBackend) - and not children - and not exists_in_staged - ): - loaded = await backend._load_file_data(validated) - if loaded is not None: - return ( - f"Error: '{validated}' is a file. Use rm to delete files." - ) - # Confirm folder exists in DB by checking the parent listing. - parent = posixpath.dirname(validated) or "/" - parent_listing = await backend.als_info(parent) - parent_has_dir = any( - info.get("path") == validated and info.get("is_dir") - for info in parent_listing - ) - if not parent_has_dir: - return f"Error: directory '{validated}' not found." - - if children: - return ( - f"Error: directory '{validated}' is not empty. " - "Remove contents first." - ) - - # Same-turn mkdir un-stage: drop the staged_dirs entry - # entirely and skip queuing a DB delete (nothing was ever - # committed). 
- if exists_in_staged: - rest = [d for d in staged_dirs if d != validated] - return Command( - update={ - "staged_dirs": [_CLEAR, *rest], - "staged_dir_tool_calls": {validated: None}, - "messages": [ - ToolMessage( - content=(f"Un-staged directory '{validated}'."), - tool_call_id=runtime.tool_call_id, - ) - ], - } - ) - - return Command( - update={ - "pending_dir_deletes": [ - { - "path": validated, - "tool_call_id": runtime.tool_call_id, - } - ], - "messages": [ - ToolMessage( - content=( - f"Staged rmdir of '{validated}' (will commit " - "at end of turn)." - ), - tool_call_id=runtime.tool_call_id, - ) - ], - } - ) - - # Desktop mode — hit disk immediately. - backend = self._get_backend(runtime) - armdir = getattr(backend, "armdir", None) - if not callable(armdir): - return "Error: rmdir is not supported by the active backend." - res: WriteResult = await armdir(validated) - if res.error: - return res.error - return Command( - update={ - "messages": [ - ToolMessage( - content=f"Deleted directory '{res.path or validated}'", - tool_call_id=runtime.tool_call_id, - ) - ], - } - ) - - def sync_rmdir( - path: Annotated[ - str, - "Absolute or relative path of the empty directory to delete.", - ], - runtime: ToolRuntime[None, SurfSenseFilesystemState], - ) -> Command | str: - return self._run_async_blocking(async_rmdir(path, runtime)) - - return StructuredTool.from_function( - name="rmdir", - description=tool_description, - func=sync_rmdir, - coroutine=async_rmdir, - ) - - # ------------------------------------------------------------------ tool: list_tree - - def _create_list_tree_tool(self) -> BaseTool: - tool_description = ( - self._custom_tool_descriptions.get("list_tree") - or SURFSENSE_LIST_TREE_TOOL_DESCRIPTION - ) - - async def async_list_tree( - runtime: ToolRuntime[None, SurfSenseFilesystemState], - path: Annotated[ - str, - "Absolute path to start from. Defaults to /documents in cloud mode.", - ] = "", - max_depth: Annotated[int, "Recursion depth limit. 
Default 8."] = 8, - page_size: Annotated[int, "Maximum entries returned. Max 1000."] = 500, - include_files: Annotated[bool, "Include file entries."] = True, - include_dirs: Annotated[bool, "Include directory entries."] = True, - ) -> str: - if max_depth < 0: - return "Error: max_depth must be >= 0." - if page_size < 1: - return "Error: page_size must be >= 1." - if not include_files and not include_dirs: - return "Error: include_files and include_dirs cannot both be false." - - target = self._resolve_list_target_path(path, runtime) - try: - validated = validate_path(target) - except ValueError as exc: - return f"Error: {exc}" - - backend = self._get_backend(runtime) - if isinstance(backend, KBPostgresBackend): - result = await backend.alist_tree_listing( - validated, - max_depth=max_depth, - page_size=page_size, - include_files=include_files, - include_dirs=include_dirs, - ) - elif hasattr(backend, "alist_tree"): - result = await backend.alist_tree( - validated, - max_depth=max_depth, - page_size=page_size, - include_files=include_files, - include_dirs=include_dirs, - ) - else: - return "Error: list_tree is not supported by the active backend." - - if isinstance(result, dict) and isinstance(result.get("error"), str): - return result["error"] - return json.dumps(result, ensure_ascii=True) - - def sync_list_tree( - runtime: ToolRuntime[None, SurfSenseFilesystemState], - path: Annotated[ - str, - "Absolute path to start from. Defaults to /documents in cloud mode.", - ] = "", - max_depth: Annotated[int, "Recursion depth limit. Default 8."] = 8, - page_size: Annotated[int, "Maximum entries returned. 
Max 1000."] = 500, - include_files: Annotated[bool, "Include file entries."] = True, - include_dirs: Annotated[bool, "Include directory entries."] = True, - ) -> str: - return self._run_async_blocking( - async_list_tree( - runtime, - path=path, - max_depth=max_depth, - page_size=page_size, - include_files=include_files, - include_dirs=include_dirs, - ) - ) - - return StructuredTool.from_function( - name="list_tree", - description=tool_description, - func=sync_list_tree, - coroutine=async_list_tree, - ) - - # ------------------------------------------------------------------ tool: execute_code (sandbox) - - def _create_execute_code_tool(self) -> BaseTool: - def sync_execute_code( - command: Annotated[ - str, "Python code to execute. Use print() to see output." - ], - runtime: ToolRuntime[None, SurfSenseFilesystemState], - timeout: Annotated[ - int | None, - "Optional timeout in seconds.", - ] = None, - ) -> str: - if timeout is not None: - if timeout < 0: - return f"Error: timeout must be non-negative, got {timeout}." - if timeout > self._MAX_EXECUTE_TIMEOUT: - return f"Error: timeout {timeout}s exceeds maximum ({self._MAX_EXECUTE_TIMEOUT}s)." - return self._run_async_blocking( - self._execute_in_sandbox(command, runtime, timeout) - ) - - async def async_execute_code( - command: Annotated[ - str, "Python code to execute. Use print() to see output." - ], - runtime: ToolRuntime[None, SurfSenseFilesystemState], - timeout: Annotated[ - int | None, - "Optional timeout in seconds.", - ] = None, - ) -> str: - if timeout is not None: - if timeout < 0: - return f"Error: timeout must be non-negative, got {timeout}." - if timeout > self._MAX_EXECUTE_TIMEOUT: - return f"Error: timeout {timeout}s exceeds maximum ({self._MAX_EXECUTE_TIMEOUT}s)." 
- return await self._execute_in_sandbox(command, runtime, timeout) - - return StructuredTool.from_function( - name="execute_code", - description=SURFSENSE_EXECUTE_CODE_TOOL_DESCRIPTION, - func=sync_execute_code, - coroutine=async_execute_code, - ) - - @staticmethod - def _wrap_as_python(code: str) -> str: - sentinel = f"_PYEOF_{secrets.token_hex(8)}" - return f"python3 << '{sentinel}'\n{code}\n{sentinel}" - - async def _execute_in_sandbox( - self, - command: str, - runtime: ToolRuntime[None, SurfSenseFilesystemState], - timeout: int | None, - ) -> str: - assert self._thread_id is not None - command = self._wrap_as_python(command) - try: - return await self._try_sandbox_execute(command, runtime, timeout) - except (DaytonaError, Exception) as first_err: - logger.warning( - "Sandbox execute failed for thread %s, retrying: %s", - self._thread_id, - first_err, - ) - try: - await delete_sandbox(self._thread_id) - except Exception: - _evict_sandbox_cache(self._thread_id) - try: - return await self._try_sandbox_execute(command, runtime, timeout) - except Exception: - logger.exception( - "Sandbox retry also failed for thread %s", self._thread_id - ) - return "Error: Code execution is temporarily unavailable. Please try again." - - async def _try_sandbox_execute( - self, - command: str, - runtime: ToolRuntime[None, SurfSenseFilesystemState], - timeout: int | None, - ) -> str: - sandbox, _is_new = await get_or_create_sandbox(self._thread_id) - result = await sandbox.aexecute(command, timeout=timeout) - output = (result.output or "").strip() - if not output and result.exit_code == 0: - return ( - "[Code executed successfully but produced no output. 
" - "Use print() to display results, then try again.]" - ) - parts = [result.output] - if result.exit_code is not None: - status = "succeeded" if result.exit_code == 0 else "failed" - parts.append(f"\n[Command {status} with exit code {result.exit_code}]") - if result.truncated: - parts.append("\n[Output was truncated due to size limits]") - return "".join(parts) diff --git a/surfsense_backend/app/agents/new_chat/middleware/flatten_system.py b/surfsense_backend/app/agents/new_chat/middleware/flatten_system.py deleted file mode 100644 index 29cd57aa0..000000000 --- a/surfsense_backend/app/agents/new_chat/middleware/flatten_system.py +++ /dev/null @@ -1,233 +0,0 @@ -r"""Coalesce multi-block system messages into a single text block. - -Several middlewares in our deepagent stack each call -``append_to_system_message`` on the way down to the model -(``TodoListMiddleware``, ``SurfSenseFilesystemMiddleware``, -``SkillsMiddleware``, ``SubAgentMiddleware`` …). By the time the -request reaches the LLM, the system message has 5+ separate text blocks. - -Anthropic enforces a hard cap of **4 ``cache_control`` blocks per -request**, and we configure 2 injection points -(``index: 0`` + ``index: -1``). With ``index: 0`` always targeting -the prepended ``request.system_message``, this middleware is the -defensive partner: it guarantees that "the system block" is *one* -content block, so LiteLLM's ``AnthropicCacheControlHook`` and any -OpenRouter→Anthropic transformer can never multiply our budget into -several breakpoints by spreading ``cache_control`` across multiple -text blocks of a multi-block system content. - -Without flattening we used to see:: - - OpenrouterException - {"error":{"message":"Provider returned error", - "code":400,"metadata":{"raw":"...A maximum of 4 blocks with - cache_control may be provided. 
Found 5."}}} - -(Same error class documented in -https://github.com/BerriAI/litellm/issues/15696 and -https://github.com/BerriAI/litellm/issues/20485 — the litellm-side fix -in PR #15395 covers the litellm transformer but does not protect us -when the OpenRouter SaaS itself does the redistribution.) - -A separate fix in :mod:`app.agents.new_chat.prompt_caching` (switching -the first injection point from ``role: system`` to ``index: 0``) -neutralises the *primary* cause of the same 400 — multiple -``SystemMessage``\ s injected by ``before_agent`` middlewares -(priority/tree/memory/file-intent/anonymous-doc) accumulating across -turns, each tagged with ``cache_control`` by the ``role: system`` -matcher. This middleware remains useful as defence-in-depth against -the multi-block redistribution path. - -Placement: innermost on the system-message-mutation chain, after every -appender (``todo``/``filesystem``/``skills``/``subagents``) and after -summarization, but before ``noop``/``retry``/``fallback`` so each retry -attempt sees a flattened payload. See ``chat_deepagent.py``. - -Idempotent: a string-content system message is left untouched. A list -that contains anything other than plain text blocks (e.g. an image) is -also left untouched — those are rare on system messages and we'd lose -the non-text payload by joining. -""" - -from __future__ import annotations - -import logging -from collections.abc import Awaitable, Callable -from typing import Any - -from langchain.agents.middleware.types import ( - AgentMiddleware, - AgentState, - ContextT, - ModelRequest, - ModelResponse, - ResponseT, -) -from langchain_core.messages import SystemMessage - -logger = logging.getLogger(__name__) - - -def _flatten_text_blocks(content: list[Any]) -> str | None: - """Return joined text if every block is a plain ``{"type": "text"}``. 
- - Returns ``None`` when the list contains anything that isn't a text - block we can safely concatenate (image, audio, file, non-standard - blocks, dicts with extra non-cache_control fields). The caller - leaves the original content untouched in that case rather than - silently dropping payload. - - ``cache_control`` on individual blocks is intentionally discarded — - the whole point of flattening is to let LiteLLM's - ``cache_control_injection_points`` re-place a single breakpoint on - the resulting one-block system content. - """ - chunks: list[str] = [] - for block in content: - if isinstance(block, str): - chunks.append(block) - continue - if not isinstance(block, dict): - return None - if block.get("type") != "text": - return None - text = block.get("text") - if not isinstance(text, str): - return None - chunks.append(text) - return "\n\n".join(chunks) - - -def _flattened_request( - request: ModelRequest[ContextT], -) -> ModelRequest[ContextT] | None: - """Return a request with system_message flattened, or ``None`` for no-op.""" - sys_msg = request.system_message - if sys_msg is None: - return None - content = sys_msg.content - if not isinstance(content, list) or len(content) <= 1: - return None - - flattened = _flatten_text_blocks(content) - if flattened is None: - return None - - new_sys = SystemMessage( - content=flattened, - additional_kwargs=dict(sys_msg.additional_kwargs), - response_metadata=dict(sys_msg.response_metadata), - ) - if sys_msg.id is not None: - new_sys.id = sys_msg.id - return request.override(system_message=new_sys) - - -def _diagnostic_summary(request: ModelRequest[Any]) -> str: - """One-line dump of cache_control-relevant request shape. - - Temporary diagnostic to prove where the ``Found N`` cache_control - breakpoints are coming from when Anthropic 400s. Removed once the - root cause is confirmed and a fix is in place. 
- """ - sys_msg = request.system_message - if sys_msg is None: - sys_shape = "none" - elif isinstance(sys_msg.content, str): - sys_shape = f"str(len={len(sys_msg.content)})" - elif isinstance(sys_msg.content, list): - sys_shape = f"list(blocks={len(sys_msg.content)})" - else: - sys_shape = f"other({type(sys_msg.content).__name__})" - - role_hist: list[str] = [] - multi_block_msgs = 0 - msgs_with_cc = 0 - sys_msgs_in_history = 0 - for m in request.messages: - mtype = getattr(m, "type", type(m).__name__) - role_hist.append(mtype) - if isinstance(m, SystemMessage): - sys_msgs_in_history += 1 - c = getattr(m, "content", None) - if isinstance(c, list): - multi_block_msgs += 1 - for blk in c: - if isinstance(blk, dict) and "cache_control" in blk: - msgs_with_cc += 1 - break - if "cache_control" in getattr(m, "additional_kwargs", {}) or {}: - msgs_with_cc += 1 - - tools = request.tools or [] - tools_with_cc = 0 - for t in tools: - if isinstance(t, dict) and ( - "cache_control" in t or "cache_control" in t.get("function", {}) - ): - tools_with_cc += 1 - - return ( - f"sys={sys_shape} msgs={len(request.messages)} " - f"sys_msgs_in_history={sys_msgs_in_history} " - f"multi_block_msgs={multi_block_msgs} pre_existing_msg_cc={msgs_with_cc} " - f"tools={len(tools)} pre_existing_tool_cc={tools_with_cc} " - f"roles={role_hist[-8:]}" - ) - - -class FlattenSystemMessageMiddleware( - AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT] -): - """Collapse a multi-text-block system message to a single string. - - Sits innermost on the system-message-mutation chain so it observes - every middleware's contribution. Has no other side effect — the - body of every block is preserved, just joined with ``"\\n\\n"``. 
- """ - - def __init__(self) -> None: - super().__init__() - self.tools = [] - - def wrap_model_call( # type: ignore[override] - self, - request: ModelRequest[ContextT], - handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], - ) -> Any: - if logger.isEnabledFor(logging.DEBUG): - logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request)) - flattened = _flattened_request(request) - if flattened is not None: - if logger.isEnabledFor(logging.DEBUG): - logger.debug( - "[flatten_system] collapsed %d system blocks to one", - len(request.system_message.content), # type: ignore[arg-type, union-attr] - ) - return handler(flattened) - return handler(request) - - async def awrap_model_call( # type: ignore[override] - self, - request: ModelRequest[ContextT], - handler: Callable[ - [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]] - ], - ) -> Any: - if logger.isEnabledFor(logging.DEBUG): - logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request)) - flattened = _flattened_request(request) - if flattened is not None: - if logger.isEnabledFor(logging.DEBUG): - logger.debug( - "[flatten_system] collapsed %d system blocks to one", - len(request.system_message.content), # type: ignore[arg-type, union-attr] - ) - return await handler(flattened) - return await handler(request) - - -__all__ = [ - "FlattenSystemMessageMiddleware", - "_flatten_text_blocks", - "_flattened_request", -] diff --git a/surfsense_backend/app/agents/new_chat/middleware/permission.py b/surfsense_backend/app/agents/new_chat/middleware/permission.py deleted file mode 100644 index 07549bedb..000000000 --- a/surfsense_backend/app/agents/new_chat/middleware/permission.py +++ /dev/null @@ -1,427 +0,0 @@ -""" -PermissionMiddleware — pattern-based allow/deny/ask with HITL fallback. - -LangChain's :class:`HumanInTheLoopMiddleware` only supports a static -"this tool always asks" decision per tool. 
There's no rule-based -allow/deny/ask layered ruleset, no glob patterns, no per-search-space or -per-thread overrides, and no auto-deny synthesis. - -This middleware ports OpenCode's ``packages/opencode/src/permission/index.ts`` -ruleset model on top of SurfSense's existing ``interrupt({type, action, -context})`` payload shape (see ``app/agents/new_chat/tools/hitl.py``) so -the frontend keeps working unchanged. - -Operation: -1. ``aafter_model`` inspects the latest ``AIMessage.tool_calls``. -2. For each call, the middleware builds a list of ``patterns`` (the - tool name plus any tool-specific patterns from the resolver). It - evaluates each pattern against the layered rulesets and aggregates - the results: ``deny`` > ``ask`` > ``allow``. -3. On ``deny``: replaces the call with a synthetic ``ToolMessage`` - containing a :class:`StreamingError`. -4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. Both the legacy - SurfSense shape and LangChain HITL ``{"decisions": [{"type": ...}]}`` - replies are accepted via :func:`_normalize_permission_decision`. - - ``once``: proceed. - - ``approve_always``: also persist allow rules for ``request.always`` patterns. - - ``reject`` w/o feedback: raise :class:`RejectedError`. - - ``reject`` w/ feedback: raise :class:`CorrectedError`. -5. On ``allow``: proceed unchanged. - -The middleware also performs a *pre-model* tool-filter step (the -``before_model`` hook) so globally denied tools are stripped from the -exposed tool list before the model gets to see them. This mirrors -OpenCode's ``Permission.disabled`` and dramatically reduces the chance -the model emits a deny-only call. 
-""" - -from __future__ import annotations - -import logging -from collections.abc import Callable -from typing import Any - -from langchain.agents.middleware.types import ( - AgentMiddleware, - AgentState, - ContextT, -) -from langchain_core.messages import AIMessage, ToolMessage -from langgraph.runtime import Runtime -from langgraph.types import interrupt - -from app.agents.new_chat.errors import ( - CorrectedError, - RejectedError, - StreamingError, -) -from app.agents.new_chat.permissions import ( - Rule, - Ruleset, - aggregate_action, - evaluate_many, -) -from app.observability import metrics as ot_metrics, otel as ot - -logger = logging.getLogger(__name__) - - -# Mapping ``tool_name -> resolver`` that converts ``args`` to a list of -# patterns to evaluate. The first pattern is conventionally the bare -# tool name; later entries narrow down to specific resources. -PatternResolver = Callable[[dict[str, Any]], list[str]] - - -def _default_pattern_resolver(name: str) -> PatternResolver: - def _resolve(args: dict[str, Any]) -> list[str]: - # Bare name covers the default catch-all; primary-arg fallbacks - # are best added per-tool by callers. - del args - return [name] - - return _resolve - - -# Translation from the LangChain HITL envelope (what ``stream_resume_chat`` -# sends) to SurfSense's legacy ``decision_type`` shape. ``edit`` keeps the -# original tool args — tools needing argument edits should use -# ``request_approval`` from ``app/agents/new_chat/tools/hitl.py``. -_LC_TYPE_TO_PERMISSION_DECISION: dict[str, str] = { - "approve": "once", - "reject": "reject", - "edit": "once", - "approve_always": "approve_always", -} - - -def _normalize_permission_decision(decision: Any) -> dict[str, Any]: - """Coerce any accepted reply shape into ``{"decision_type": ..., "feedback"?}``. - - Falls back to ``reject`` (with a warning) on unrecognized payloads so the - middleware fails closed. 
- """ - if isinstance(decision, str): - return {"decision_type": decision} - if not isinstance(decision, dict): - logger.warning( - "Unrecognized permission resume value (%s); treating as reject", - type(decision).__name__, - ) - return {"decision_type": "reject"} - - if decision.get("decision_type"): - return decision - - payload: dict[str, Any] = decision - decisions = decision.get("decisions") - if isinstance(decisions, list) and decisions: - first = decisions[0] - if isinstance(first, dict): - payload = first - - raw_type = payload.get("type") or payload.get("decision_type") - if not raw_type: - logger.warning( - "Permission resume missing decision type (keys=%s); treating as reject", - list(payload.keys()), - ) - return {"decision_type": "reject"} - - raw_type = str(raw_type).lower() - mapped = _LC_TYPE_TO_PERMISSION_DECISION.get(raw_type) - if mapped is None: - # Tolerate legacy values arriving without ``decision_type`` wrapping. - if raw_type in {"once", "approve_always", "reject"}: - mapped = raw_type - else: - logger.warning( - "Unknown permission decision type %r; treating as reject", raw_type - ) - mapped = "reject" - - if raw_type == "edit": - logger.warning( - "Permission middleware received an 'edit' decision; original args " - "kept (edits not merged here)." - ) - - out: dict[str, Any] = {"decision_type": mapped} - feedback = payload.get("feedback") or payload.get("message") - if isinstance(feedback, str) and feedback.strip(): - out["feedback"] = feedback - return out - - -class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] - """Allow/deny/ask layer over the agent's tool calls. - - Args: - rulesets: Layered rulesets to evaluate. Earlier entries are - overridden by later ones (last-match-wins). Typical layering: - ``defaults < global < space < thread < runtime_approved``. - pattern_resolvers: Optional per-tool callables that return a list - of patterns to evaluate. 
When a tool isn't listed, the bare - tool name is used as the only pattern. - runtime_ruleset: Mutable :class:`Ruleset` that the middleware - extends in-place when the user replies ``"approve_always"`` to - an ask interrupt. Reused across all calls in the same agent - instance so newly-allowed rules apply to subsequent calls. - always_emit_interrupt_payload: If True, every ask uses the - SurfSense interrupt wire format (default). Set False to - disable interrupts and treat ``ask`` as ``deny`` for - non-interactive deployments. - """ - - tools = () - - def __init__( - self, - *, - rulesets: list[Ruleset] | None = None, - pattern_resolvers: dict[str, PatternResolver] | None = None, - runtime_ruleset: Ruleset | None = None, - always_emit_interrupt_payload: bool = True, - ) -> None: - super().__init__() - self._static_rulesets: list[Ruleset] = list(rulesets or []) - self._pattern_resolvers: dict[str, PatternResolver] = dict( - pattern_resolvers or {} - ) - self._runtime_ruleset: Ruleset = runtime_ruleset or Ruleset( - origin="runtime_approved" - ) - self._emit_interrupt = always_emit_interrupt_payload - - # ------------------------------------------------------------------ - # Tool-filter step (mirrors OpenCode's ``Permission.disabled``) - # ------------------------------------------------------------------ - - def _globally_denied(self, tool_name: str) -> bool: - """Return True if a deny rule with no narrowing pattern matches.""" - rules = evaluate_many(tool_name, ["*"], *self._all_rulesets()) - return aggregate_action(rules) == "deny" - - def _all_rulesets(self) -> list[Ruleset]: - return [*self._static_rulesets, self._runtime_ruleset] - - # NOTE: ``before_model`` filtering of the tools list is left to the - # agent factory. This middleware only blocks at execution time — and - # only via the rule-evaluator path, not by mutating ``request.tools``. 
- # Mutating ``request.tools`` per-call would invalidate provider - # prompt-cache prefixes (see Operational risks: prompt-cache regression). - - # ------------------------------------------------------------------ - # Tool-call evaluation - # ------------------------------------------------------------------ - - def _resolve_patterns(self, tool_name: str, args: dict[str, Any]) -> list[str]: - resolver = self._pattern_resolvers.get( - tool_name, _default_pattern_resolver(tool_name) - ) - try: - patterns = resolver(args or {}) - except Exception: - logger.exception( - "Pattern resolver for %s raised; using bare name", tool_name - ) - patterns = [tool_name] - if not patterns: - patterns = [tool_name] - return patterns - - def _evaluate( - self, tool_name: str, args: dict[str, Any] - ) -> tuple[str, list[str], list[Rule]]: - patterns = self._resolve_patterns(tool_name, args) - rules = evaluate_many(tool_name, patterns, *self._all_rulesets()) - action = aggregate_action(rules) - return action, patterns, rules - - # ------------------------------------------------------------------ - # HITL ask flow — SurfSense wire format - # ------------------------------------------------------------------ - - def _raise_interrupt( - self, - *, - tool_name: str, - args: dict[str, Any], - patterns: list[str], - rules: list[Rule], - ) -> dict[str, Any]: - """Block on user approval via SurfSense's ``interrupt`` shape.""" - if not self._emit_interrupt: - return {"decision_type": "reject"} - - # ``params`` (NOT ``args``) is what SurfSense's streaming - # normalizer forwards. Other fields move into ``context``. - payload = { - "type": "permission_ask", - "action": {"tool": tool_name, "params": args or {}}, - "context": { - "patterns": patterns, - "rules": [ - { - "permission": r.permission, - "pattern": r.pattern, - "action": r.action, - } - for r in rules - ], - # Rules of thumb for the frontend: surface the patterns - # the user can promote to "approve_always" with a single reply. 
- "always": patterns, - }, - } - # Open ``permission.asked`` + ``interrupt.raised`` OTel spans - # (no-op when OTel is disabled) so dashboards can correlate - # "we asked X" with "interrupt was actually delivered". - with ( - ot.permission_asked_span( - permission=tool_name, - pattern=patterns[0] if patterns else None, - extra={"permission.patterns": list(patterns)}, - ), - ot.interrupt_span(interrupt_type="permission_ask"), - ): - ot_metrics.record_permission_ask(permission=tool_name) - ot_metrics.record_interrupt(interrupt_type="permission_ask") - decision = interrupt(payload) - return _normalize_permission_decision(decision) - - def _persist_always(self, tool_name: str, patterns: list[str]) -> None: - """Promote ``approve_always`` reply into runtime allow rules. - - Persistence to ``agent_permission_rules`` is done by the - streaming layer (``stream_new_chat``) once it observes the - ``approve_always`` reply — the middleware just keeps an - in-memory copy so subsequent calls in the same stream see the rule. - """ - for pattern in patterns: - self._runtime_ruleset.rules.append( - Rule(permission=tool_name, pattern=pattern, action="allow") - ) - - # ------------------------------------------------------------------ - # Synthesizing deny -> ToolMessage - # ------------------------------------------------------------------ - - @staticmethod - def _deny_message( - tool_call: dict[str, Any], - rule: Rule, - ) -> ToolMessage: - err = StreamingError( - code="permission_denied", - retryable=False, - suggestion=( - f"rule permission={rule.permission!r} pattern={rule.pattern!r} " - f"blocked this call" - ), - ) - return ToolMessage( - content=( - f"Permission denied: rule {rule.permission}/{rule.pattern} " - f"blocked tool {tool_call.get('name')!r}." 
- ), - tool_call_id=tool_call.get("id") or "", - name=tool_call.get("name"), - status="error", - additional_kwargs={"error": err.model_dump()}, - ) - - # ------------------------------------------------------------------ - # The hook: aafter_model - # ------------------------------------------------------------------ - - def _process( - self, - state: AgentState, - runtime: Runtime[Any], - ) -> dict[str, Any] | None: - del runtime # unused - messages = state.get("messages") or [] - if not messages: - return None - last = messages[-1] - if not isinstance(last, AIMessage) or not last.tool_calls: - return None - - deny_messages: list[ToolMessage] = [] - kept_calls: list[dict[str, Any]] = [] - any_change = False - - for raw in last.tool_calls: - call = ( - dict(raw) - if isinstance(raw, dict) - else { - "name": getattr(raw, "name", None), - "args": getattr(raw, "args", {}), - "id": getattr(raw, "id", None), - "type": "tool_call", - } - ) - name = call.get("name") or "" - args = call.get("args") or {} - action, patterns, rules = self._evaluate(name, args) - - if action == "deny": - # Find the deny rule for the suggestion text - deny_rule = next((r for r in rules if r.action == "deny"), rules[0]) - deny_messages.append(self._deny_message(call, deny_rule)) - any_change = True - continue - - if action == "ask": - decision = self._raise_interrupt( - tool_name=name, args=args, patterns=patterns, rules=rules - ) - kind = str(decision.get("decision_type") or "reject").lower() - if kind == "once": - kept_calls.append(call) - elif kind == "approve_always": - self._persist_always(name, patterns) - kept_calls.append(call) - elif kind == "reject": - feedback = decision.get("feedback") - if isinstance(feedback, str) and feedback.strip(): - raise CorrectedError(feedback, tool=name) - raise RejectedError( - tool=name, pattern=patterns[0] if patterns else None - ) - else: - logger.warning( - "Unknown permission decision %r; treating as reject", kind - ) - raise RejectedError(tool=name) 
- continue - - # allow - kept_calls.append(call) - - if not any_change and len(kept_calls) == len(last.tool_calls): - return None - - updated = last.model_copy(update={"tool_calls": kept_calls}) - result_messages: list[Any] = [updated] - if deny_messages: - result_messages.extend(deny_messages) - return {"messages": result_messages} - - def after_model( # type: ignore[override] - self, state: AgentState, runtime: Runtime[ContextT] - ) -> dict[str, Any] | None: - return self._process(state, runtime) - - async def aafter_model( # type: ignore[override] - self, state: AgentState, runtime: Runtime[ContextT] - ) -> dict[str, Any] | None: - return self._process(state, runtime) - - -__all__ = [ - "PatternResolver", - "PermissionMiddleware", - "_normalize_permission_decision", -] diff --git a/surfsense_backend/app/agents/new_chat/prompt_caching.py b/surfsense_backend/app/agents/new_chat/prompt_caching.py deleted file mode 100644 index b58a48266..000000000 --- a/surfsense_backend/app/agents/new_chat/prompt_caching.py +++ /dev/null @@ -1,241 +0,0 @@ -r"""LiteLLM-native prompt caching configuration for SurfSense agents. - -Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never -activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)`` -gate always failed) with LiteLLM's universal caching mechanism. - -Coverage: - -- Marker-based providers (need ``cache_control`` injection, which LiteLLM - performs automatically when ``cache_control_injection_points`` is set): - ``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``, - ``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/`` - (Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM). -- Auto-cached (LiteLLM strips the marker silently): ``openai/``, - ``deepseek/``, ``xai/`` — these caches automatically for prompts ≥1024 - tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``. 
- -We inject **two** breakpoints per request: - -- ``index: 0`` — pins the SurfSense system prompt at the head of the - request (provider variant, citation rules, tool catalog, KB tree, - skills metadata). The langchain agent factory always prepends - ``request.system_message`` at index 0 (see ``factory.py`` - ``_execute_model_async``), so this targets exactly the main system - prompt regardless of how many other ``SystemMessage``\ s the - ``before_agent`` injectors (priority, tree, memory, file-intent, - anonymous-doc) have inserted into ``state["messages"]``. Using - ``role: system`` here would apply ``cache_control`` to **every** - system-role message and trip Anthropic's hard cap of 4 cache - breakpoints per request once the conversation accumulates enough - injected system messages — which surfaces as the upstream 400 - ``A maximum of 4 blocks with cache_control may be provided. Found N`` - via OpenRouter→Anthropic. -- ``index: -1`` — pins the latest message so multi-turn savings compound: - Anthropic-family providers use longest-matching-prefix lookup, so turn - N+1 still reads turn N's cache up to the shared prefix. - -For OpenAI-family configs we additionally pass: - -- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that - raises hit rate by sending requests with a shared prefix to the same - backend. Supported by ``openai/``, ``deepseek/``, ``xai/``, and - ``azure/`` (added to LiteLLM's Azure transformer in - https://github.com/BerriAI/litellm/pull/20989, Feb 2026; verified - against ``AzureOpenAIConfig.get_supported_openai_params`` in our - installed litellm 1.83.14 for ``azure/gpt-4o``, ``azure/gpt-4o-mini``, - ``azure/gpt-5.4``, ``azure/gpt-5.4-mini``). -- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default - 5-10 min in-memory cache. 
Set ONLY for OpenAI/DeepSeek/xAI: Azure's - server-side support landed in Microsoft's docs on 2026-05-13 but - LiteLLM 1.83.14's Azure transformer still omits it from its supported - params list, so it gets silently dropped by ``litellm.drop_params``. - Azure's default in-memory retention (5-10 min, max 1 h) already - bridges intra-conversation turns; revisit when LiteLLM bumps Azure. - -Safety net: ``litellm.drop_params=True`` is set globally in -``app.services.llm_service`` at module-load time. Any kwarg the destination -provider doesn't recognise is auto-stripped at the provider transformer -layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on -``prompt_cache_key`` etc. -""" - -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING, Any - -from langchain_core.language_models import BaseChatModel - -if TYPE_CHECKING: - from app.agents.new_chat.llm_config import AgentConfig - -logger = logging.getLogger(__name__) - - -# Two-breakpoint policy: head-of-request + latest message. See module -# docstring for rationale. Anthropic caps requests at 4 ``cache_control`` -# blocks; we use 2 here, leaving headroom for Phase-2 tool caching. -# -# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's -# ``before_agent`` middlewares (priority, tree, memory, file-intent, -# anonymous-doc) insert ``SystemMessage`` instances into -# ``state["messages"]`` that accumulate across turns. With -# ``role: system`` the LiteLLM hook would tag *every* one of them with -# ``cache_control`` and overflow Anthropic's 4-block limit. ``index: 0`` -# always targets the langchain-prepended ``request.system_message`` -# (which our ``FlattenSystemMessageMiddleware`` reduces to a single text -# block), giving us exactly one stable cache breakpoint. -_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] 
= ( - {"location": "message", "index": 0}, - {"location": "message", "index": -1}, -) - -# Providers (uppercase ``AgentConfig.provider`` values) that accept the -# OpenAI ``prompt_cache_key`` routing hint. Microsoft's Azure OpenAI docs -# (2026-05-13) confirm automatic prompt caching applies to every GPT-4o -# or newer Azure deployment at ≥1024 tokens with no configuration needed, -# and that ``prompt_cache_key`` is combined with the prefix hash to -# improve routing affinity and therefore cache hit rate. LiteLLM's Azure -# transformer ships ``prompt_cache_key`` in its supported params as of -# https://github.com/BerriAI/litellm/pull/20989. -# -# Strict whitelist — many other providers in ``PROVIDER_MAP`` route -# through litellm's ``openai`` prefix without implementing the OpenAI -# prompt-cache surface (e.g. MOONSHOT, ZHIPU, MINIMAX), so we can't infer -# family from the litellm prefix alone. -_PROMPT_CACHE_KEY_PROVIDERS: frozenset[str] = frozenset( - {"OPENAI", "DEEPSEEK", "XAI", "AZURE", "AZURE_OPENAI"} -) - -# Subset of ``_PROMPT_CACHE_KEY_PROVIDERS`` that also accept -# ``prompt_cache_retention="24h"``. Azure is excluded: see module -# docstring — LiteLLM 1.83.14's Azure transformer omits the param so -# ``drop_params`` silently strips it. Re-add Azure once a future LiteLLM -# release wires it into ``AzureOpenAIConfig.get_supported_openai_params``. -_PROMPT_CACHE_RETENTION_PROVIDERS: frozenset[str] = frozenset( - {"OPENAI", "DEEPSEEK", "XAI"} -) - - -def _is_router_llm(llm: BaseChatModel) -> bool: - """Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import. - - Importing ``app.services.llm_router_service`` at module-load time would - create a cycle via ``llm_config -> prompt_caching -> llm_router_service``. - Class-name comparison is sufficient since the class is defined in a - single place. 
- """ - return type(llm).__name__ == "ChatLiteLLMRouter" - - -def _provider_supports_prompt_cache_key(agent_config: AgentConfig | None) -> bool: - """Whether the config targets a provider that accepts ``prompt_cache_key``. - - Strict — only returns True for explicitly chosen OPENAI, DEEPSEEK, - XAI, AZURE, or AZURE_OPENAI providers. Auto-mode and custom - providers return False because we can't statically know the - destination and the router fans out across mixed providers. - """ - if agent_config is None or not agent_config.provider: - return False - if agent_config.is_auto_mode: - return False - if agent_config.custom_provider: - return False - return agent_config.provider.upper() in _PROMPT_CACHE_KEY_PROVIDERS - - -def _provider_supports_prompt_cache_retention( - agent_config: AgentConfig | None, -) -> bool: - """Whether the config targets a provider that accepts ``prompt_cache_retention``. - - Tighter than :func:`_provider_supports_prompt_cache_key` — Azure - deployments are excluded until LiteLLM ships the param in its Azure - transformer (see module docstring). - """ - if agent_config is None or not agent_config.provider: - return False - if agent_config.is_auto_mode: - return False - if agent_config.custom_provider: - return False - return agent_config.provider.upper() in _PROMPT_CACHE_RETENTION_PROVIDERS - - -def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None: - """Return ``llm.model_kwargs`` as a writable dict, or ``None`` to bail. - - Initialises the field to ``{}`` when present-but-None on a Pydantic v2 - model. Returns ``None`` if the LLM type doesn't expose a writable - ``model_kwargs`` attribute (caller should treat as no-op). 
- """ - model_kwargs = getattr(llm, "model_kwargs", None) - if isinstance(model_kwargs, dict): - return model_kwargs - try: - llm.model_kwargs = {} # type: ignore[attr-defined] - except Exception: - return None - refreshed = getattr(llm, "model_kwargs", None) - return refreshed if isinstance(refreshed, dict) else None - - -def apply_litellm_prompt_caching( - llm: BaseChatModel, - *, - agent_config: AgentConfig | None = None, - thread_id: int | None = None, -) -> None: - """Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter. - - Idempotent — values already present in ``llm.model_kwargs`` (e.g. from - ``agent_config.litellm_params`` overrides) are preserved. Mutates - ``llm.model_kwargs`` in place; the kwargs flow to ``litellm.completion`` - via ``ChatLiteLLM._default_params`` and via ``self.model_kwargs`` merge - in our custom ``ChatLiteLLMRouter``. - - Args: - llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance. - agent_config: Optional ``AgentConfig`` driving provider-specific - behaviour. When omitted (or auto-mode), only the universal - ``cache_control_injection_points`` are set. - thread_id: Optional thread id used to construct a per-thread - ``prompt_cache_key`` for OpenAI-family providers. Caching still - works without it (server-side automatic), but the key improves - backend routing affinity and therefore hit rate. - """ - model_kwargs = _get_or_init_model_kwargs(llm) - if model_kwargs is None: - logger.debug( - "apply_litellm_prompt_caching: %s exposes no writable model_kwargs; skipping", - type(llm).__name__, - ) - return - - if "cache_control_injection_points" not in model_kwargs: - model_kwargs["cache_control_injection_points"] = [ - dict(point) for point in _DEFAULT_INJECTION_POINTS - ] - - # OpenAI-style extras only when we statically know the destination - # accepts them. 
Auto-mode router fans out across mixed providers so - # we can't safely set destination-specific kwargs there (drop_params - # would strip them but it's wasteful to set them in the first - # place). - if _is_router_llm(llm): - return - - if ( - thread_id is not None - and "prompt_cache_key" not in model_kwargs - and _provider_supports_prompt_cache_key(agent_config) - ): - model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}" - - if ( - "prompt_cache_retention" not in model_kwargs - and _provider_supports_prompt_cache_retention(agent_config) - ): - model_kwargs["prompt_cache_retention"] = "24h" diff --git a/surfsense_backend/app/agents/new_chat/subagents/__init__.py b/surfsense_backend/app/agents/new_chat/subagents/__init__.py deleted file mode 100644 index bd1823b57..000000000 --- a/surfsense_backend/app/agents/new_chat/subagents/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Specialized user-facing subagents for the SurfSense agent. - -The :class:`deepagents.SubAgentMiddleware` already provides the -materialization machinery (each :class:`deepagents.SubAgent` typed-dict -spec is compiled into an ephemeral runnable invoked via the ``task`` -tool); what's specific to SurfSense is the *seeding* of those subagents -with declarative deny rules. - -Per-subagent permission rules are injected as a -:class:`PermissionMiddleware` entry inside the subagent's ``middleware`` -field. The auto-deny pattern (e.g. forbid ``task``/``todowrite`` -recursion, block write tools for read-only research roles) is borrowed -from OpenCode's ``packages/opencode/src/tool/task.ts``, which has -analogous logic for restricting child sessions. 
-""" - -from .config import ( - build_connector_negotiator_subagent, - build_explore_subagent, - build_report_writer_subagent, - build_specialized_subagents, -) -from .providers.linear import build_linear_specialist_subagent -from .providers.slack import build_slack_specialist_subagent - -__all__ = [ - "build_connector_negotiator_subagent", - "build_explore_subagent", - "build_linear_specialist_subagent", - "build_report_writer_subagent", - "build_slack_specialist_subagent", - "build_specialized_subagents", -] diff --git a/surfsense_backend/app/agents/new_chat/subagents/config.py b/surfsense_backend/app/agents/new_chat/subagents/config.py deleted file mode 100644 index 2cfd47441..000000000 --- a/surfsense_backend/app/agents/new_chat/subagents/config.py +++ /dev/null @@ -1,436 +0,0 @@ -"""Builders for specialized SurfSense subagents. - -Each subagent is built from three pieces: - -1. A name + description + system prompt (the user-facing contract for - when ``task`` should delegate to this role). -2. A filtered tool list (subset of the parent's bound tools). -3. A :class:`PermissionMiddleware` instance carrying a deny ruleset that - prevents the subagent from acting outside its scope (e.g. an - explore-only role cannot mutate state). - -Skill sources (``/skills/builtin/`` + ``/skills/space/``) are inherited -from the parent unconditionally — every subagent benefits from the same -authored guidance documents. 
-""" - -from __future__ import annotations - -import logging -from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, Any - -from app.agents.new_chat.middleware.skills_backends import default_skills_sources -from app.agents.new_chat.permissions import Rule, Ruleset -from app.agents.new_chat.subagents.providers.linear import ( - build_linear_specialist_subagent, -) -from app.agents.new_chat.subagents.providers.slack import ( - build_slack_specialist_subagent, -) - -if TYPE_CHECKING: - from deepagents import SubAgent - from langchain_core.language_models import BaseChatModel - from langchain_core.tools import BaseTool - -logger = logging.getLogger(__name__) - - -# --------------------------------------------------------------------------- -# Tool name constants -# --------------------------------------------------------------------------- - -# Read-only tools that ``explore`` is permitted to use. Names match the -# tools provided by the deepagents ``FilesystemMiddleware`` (``ls``, ``read_file``, -# ``glob``, ``grep``) plus the SurfSense-side read tools. -EXPLORE_READ_TOOLS: frozenset[str] = frozenset( - { - "web_search", - "scrape_webpage", - "read_file", - "ls", - "glob", - "grep", - } -) - -# Tools ``report_writer`` may call. The set is intentionally narrow so the -# subagent doesn't drift into tangential research; if richer source-gathering -# is needed, the parent should hand off to ``explore`` first. -REPORT_WRITER_TOOLS: frozenset[str] = frozenset( - { - "read_file", - "generate_report", - } -) - -# Wildcard patterns that match write tools we deny by default in read-only -# subagents. Anchored at start AND end via :func:`Rule` semantics. We use -# substring-style ``*verb*`` patterns because connector tool names typically -# put the verb in the middle (``linear_create_issue``, ``slack_send_message``, -# ``notion_update_page``); strict suffix patterns (``*_create``) miss those. 
-# -# A handful of canonical exact-match names is appended so that bare verbs -# (``edit``, ``write``) are also blocked even when a connector dropped the -# usual prefix. -WRITE_TOOL_DENY_PATTERNS: tuple[str, ...] = ( - "*create*", - "*update*", - "*delete*", - "*send*", - "*write*", - "*edit*", - "*move*", - "*mkdir*", - "*upload*", - "edit_file", - "write_file", - "move_file", - "mkdir", - "rm", - "rmdir", - "update_memory", - "update_memory_team", - "update_memory_private", -) - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -# Tool names that are NOT in the registry's ``tools`` list because they -# are provided dynamically by middleware at compile time. We don't pass -# them through ``_filter_tools`` (the actual ``BaseTool`` instances live -# inside the middleware), but we do exempt them from the "missing" warning -# below — operators were seeing spurious noise like -# ``missing: ['glob', 'grep', 'ls', 'read_file']`` even though those -# tools are reachable via :class:`SurfSenseFilesystemMiddleware` once the -# subagent is compiled. -_MIDDLEWARE_PROVIDED_TOOL_NAMES: frozenset[str] = frozenset( - { - "ls", - "read_file", - "write_file", - "edit_file", - "glob", - "grep", - "execute", - "write_todos", - "task", - } -) - - -def _filter_tools( - tools: Sequence[BaseTool], - allowed_names: Iterable[str], -) -> list[BaseTool]: - """Return only tools whose ``name`` appears in ``allowed_names``. - - Tools are looked up by exact name. Names matching - :data:`_MIDDLEWARE_PROVIDED_TOOL_NAMES` are intentionally absent from - ``tools`` (they're injected by middleware at compile time) and are - silently excluded from the "missing" warning so operators don't see - false positives every build. 
- """ - allowed = set(allowed_names) - selected = [t for t in tools if t.name in allowed] - missing = sorted( - (allowed - {t.name for t in selected}) - _MIDDLEWARE_PROVIDED_TOOL_NAMES - ) - if missing: - logger.info( - "Subagent build: %d/%d registry tools available; missing: %s", - len(selected), - len(allowed - _MIDDLEWARE_PROVIDED_TOOL_NAMES), - missing, - ) - return selected - - -def _read_only_deny_rules() -> list[Rule]: - """Synthesize a list of deny rules covering common write-tool patterns.""" - return [ - Rule(permission=pattern, pattern="*", action="deny") - for pattern in WRITE_TOOL_DENY_PATTERNS - ] - - -def _build_permission_middleware(deny_rules: list[Rule], origin: str): - """Construct a :class:`PermissionMiddleware` seeded with ``deny_rules``. - - Imported lazily because the middleware module pulls in interrupt/HITL - machinery we don't want at import time of this config file. - """ - from app.agents.new_chat.middleware.permission import PermissionMiddleware - - return PermissionMiddleware( - rulesets=[Ruleset(rules=deny_rules, origin=origin)], - ) - - -def _wrap_with_subagent_essentials( - custom_middleware: list, - *, - agent_tools: Sequence[BaseTool], - extra_middleware: Sequence[Any] | None = None, -): - """Compose the final middleware list for a specialized subagent. - - Order, outer to inner: - - 1. ``extra_middleware`` — provided by the caller (typically the parent - agent's ``SurfSenseFilesystemMiddleware`` and ``TodoListMiddleware``) - so the subagent inherits the parent's filesystem/todo view. These - run **before** the subagent-local middleware so their tools are - wired up before permissioning kicks in. - 2. ``custom_middleware`` — subagent-local rules (e.g. permission deny - lists). - 3. :class:`PatchToolCallsMiddleware` — normalizes tool-call shapes. - 4. :class:`DedupHITLToolCallsMiddleware` — collapses duplicate HITL - calls using metadata declared at registry time. 
- - Without ``extra_middleware`` the subagent will only have the registry - tools listed in its ``tools`` field — meaning ``read_file``, ``ls``, - ``grep``, etc. won't exist. Always pass ``extra_middleware`` from the - parent unless you specifically want a sandboxed subagent. - """ - from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware - - from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware - - return [ - *(extra_middleware or []), - *custom_middleware, - PatchToolCallsMiddleware(), - DedupHITLToolCallsMiddleware(agent_tools=list(agent_tools)), - ] - - -# --------------------------------------------------------------------------- -# System prompts -# --------------------------------------------------------------------------- - -EXPLORE_SYSTEM_PROMPT = """You are the **explore** subagent for SurfSense. - -## Your job -Conduct read-only research across the user's knowledge base, the web, and any documents the parent agent has surfaced. Return a synthesized answer with explicit citations — never speculate beyond the sources you have actually inspected. - -## Tools available -- `web_search` — only when the user's KB clearly does not contain the answer. -- `scrape_webpage` — to read a URL the user or the search results provided. -- `read_file`, `ls`, `glob`, `grep` — to inspect specific documents or trees the parent has flagged. - -## Rules -- Read-only. You cannot create, edit, delete, send, or move anything. -- Cite every claim. Use `[citation:chunk_id]` exactly as the chunk tag specifies. -- If a sub-question has no support in the inspected sources, say so explicitly. Do not fabricate. -- Return the most useful synthesis in your single final message. The parent agent will not be able to follow up. -""" - - -REPORT_WRITER_SYSTEM_PROMPT = """You are the **report_writer** subagent for SurfSense. - -## Your job -Produce a single high-quality report deliverable using `generate_report`. 
The parent has already gathered (or knows where to gather) the underlying sources. - -## Workflow -1. **Outline first.** Before calling `generate_report`, write a one-paragraph outline of the sections you plan to produce. Confirm the outline reflects the parent's instructions. -2. **Source resolution.** Decide whether to call `read_file` for any final-checks, or whether the parent's earlier tool calls already cover the source set. -3. **One report.** Call `generate_report` exactly once with `source_strategy` chosen per the topic and chat history (see the `report-writing` skill). -4. **Confirm.** End with a one-sentence summary in your final message — never paste the report back into chat; the artifact card renders itself. -""" - - -CONNECTOR_NEGOTIATOR_SYSTEM_PROMPT = """You are the **connector_negotiator** subagent for SurfSense. - -## Your job -Coordinate cross-connector workflows: chains where the result of one service's tool feeds into another's. Common shapes include "find Linear issues mentioned in last week's Slack messages", "draft a Gmail reply citing a Notion doc", or "list Linear tickets opened by the same person who filed Jira FOO-123". - -## Workflow -1. **Plan.** Identify the connector hops needed and the order they should run in. Write a short plan in your first message. -2. **Verify access.** Use `get_connected_accounts` to confirm the relevant connectors are actually wired up before issuing tool calls. If a connector is missing, stop and report — do not fabricate. -3. **Execute.** Run each hop, citing IDs (issue keys, message ts, page IDs) in your scratch notes so the parent can audit. -4. **Hand back.** Return a structured summary with the final answer plus the chain of evidence (issue → message → page, etc.). - -## Caveats -- If a hop fails, do not retry blindly — return the partial result and explain. -- Mutating tools (create, update, delete, send) require parent permission; you are NOT cleared to call them on your own. 
-""" - - -# --------------------------------------------------------------------------- -# Subagent builders -# --------------------------------------------------------------------------- - - -def build_explore_subagent( - *, - tools: Sequence[BaseTool], - model: BaseChatModel | None = None, - extra_middleware: Sequence[Any] | None = None, -) -> SubAgent: - """Build the read-only ``explore`` subagent spec. - - Pass ``extra_middleware`` (typically the parent's filesystem + todo - middleware) so the subagent can actually use ``read_file``, ``ls``, - ``grep``, ``glob`` — which its system prompt promises but which only - exist when their middleware is mounted. - """ - from deepagents import SubAgent # noqa: F401 (TypedDict for type clarity) - - selected_tools = _filter_tools(tools, EXPLORE_READ_TOOLS) - deny_rules = _read_only_deny_rules() - permission_mw = _build_permission_middleware(deny_rules, origin="subagent_explore") - - spec: dict = { - "name": "explore", - "description": ( - "Read-only research across the user's knowledge base and the web. " - "Use when the parent needs deeply-cited synthesis without " - "modifying anything." - ), - "system_prompt": EXPLORE_SYSTEM_PROMPT, - "tools": selected_tools, - "middleware": _wrap_with_subagent_essentials( - [permission_mw], - agent_tools=selected_tools, - extra_middleware=extra_middleware, - ), - "skills": default_skills_sources(), - } - if model is not None: - spec["model"] = model - return spec # type: ignore[return-value] - - -def build_report_writer_subagent( - *, - tools: Sequence[BaseTool], - model: BaseChatModel | None = None, - extra_middleware: Sequence[Any] | None = None, -) -> SubAgent: - """Build the ``report_writer`` subagent spec. - - Read-only deny ruleset still applies — the subagent should call - ``generate_report`` and nothing else mutating. ``generate_report`` - creates a report artifact via a backend service and is intentionally - **not** denied. 
- - Pass ``extra_middleware`` (typically the parent's filesystem + todo - middleware) so the subagent can run ``read_file`` for source-checks - before calling ``generate_report``. - """ - selected_tools = _filter_tools(tools, REPORT_WRITER_TOOLS) - deny_rules = _read_only_deny_rules() - permission_mw = _build_permission_middleware( - deny_rules, origin="subagent_report_writer" - ) - - spec: dict = { - "name": "report_writer", - "description": ( - "Produce a single Markdown report artifact via generate_report, " - "using the outline-then-fill protocol. Use when the parent has " - "decided a deliverable is needed." - ), - "system_prompt": REPORT_WRITER_SYSTEM_PROMPT, - "tools": selected_tools, - "middleware": _wrap_with_subagent_essentials( - [permission_mw], - agent_tools=selected_tools, - extra_middleware=extra_middleware, - ), - "skills": default_skills_sources(), - } - if model is not None: - spec["model"] = model - return spec # type: ignore[return-value] - - -def build_connector_negotiator_subagent( - *, - tools: Sequence[BaseTool], - model: BaseChatModel | None = None, - extra_middleware: Sequence[Any] | None = None, -) -> SubAgent: - """Build the ``connector_negotiator`` subagent spec. - - Inherits all MCP / connector tools the parent has plus - ``get_connected_accounts``. Read-only by default; permission rules deny - write/mutation patterns. The parent agent re-asks for permission if a - connector mutation is genuinely needed. - - Pass ``extra_middleware`` (typically the parent's filesystem + todo - middleware) so this subagent shares the parent's filesystem view when - citing evidence across hops. - """ - parent_tool_names = {t.name for t in tools} - allowed: set[str] = set() - if "get_connected_accounts" in parent_tool_names: - allowed.add("get_connected_accounts") - # Inherit anything that smells connector- or MCP-related but is not a - # bulk-write API. Heuristic: keep all parent tools; rely on the deny - # ruleset to block mutation patterns. 
This mirrors the plan: "all - # MCP/connector tools the parent has". - for name in parent_tool_names: - allowed.add(name) - selected_tools = _filter_tools(tools, allowed) - - deny_rules = _read_only_deny_rules() - permission_mw = _build_permission_middleware( - deny_rules, origin="subagent_connector_negotiator" - ) - - spec: dict = { - "name": "connector_negotiator", - "description": ( - "Coordinate read-only chains across connectors (Slack → Linear, " - "Notion → Gmail, etc.). Returns a structured summary with the " - "evidence chain. Cannot mutate connector state." - ), - "system_prompt": CONNECTOR_NEGOTIATOR_SYSTEM_PROMPT, - "tools": selected_tools, - "middleware": _wrap_with_subagent_essentials( - [permission_mw], - agent_tools=selected_tools, - extra_middleware=extra_middleware, - ), - "skills": default_skills_sources(), - } - if model is not None: - spec["model"] = model - return spec # type: ignore[return-value] - - -def build_specialized_subagents( - *, - tools: Sequence[BaseTool], - model: BaseChatModel | None = None, - extra_middleware: Sequence[Any] | None = None, -) -> list[SubAgent]: - """Return the canonical list of specialized subagents to register. - - Order matters only for the order they appear in the ``task`` tool - description — most useful first. 
- """ - return [ - build_explore_subagent( - tools=tools, model=model, extra_middleware=extra_middleware - ), - build_report_writer_subagent( - tools=tools, model=model, extra_middleware=extra_middleware - ), - build_linear_specialist_subagent( - tools=tools, model=model, extra_middleware=extra_middleware - ), - build_slack_specialist_subagent( - tools=tools, model=model, extra_middleware=extra_middleware - ), - build_connector_negotiator_subagent( - tools=tools, model=model, extra_middleware=extra_middleware - ), - ] diff --git a/surfsense_backend/app/agents/new_chat/subagents/constants.py b/surfsense_backend/app/agents/new_chat/subagents/constants.py deleted file mode 100644 index cb1da499b..000000000 --- a/surfsense_backend/app/agents/new_chat/subagents/constants.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Shared constants for provider subagent safety policies.""" - -from __future__ import annotations - -# Generic mutation-deny patterns for read-only specialist roles. -WRITE_TOOL_DENY_PATTERNS: tuple[str, ...] = ( - "*create*", - "*update*", - "*delete*", - "*send*", - "*write*", - "*edit*", - "*move*", - "*mkdir*", - "*upload*", - "edit_file", - "write_file", - "move_file", - "mkdir", - "update_memory", - "update_memory_team", - "update_memory_private", -) - -# Tools that mutate virtual KB filesystem or parent/global chat state. -# Provider specialists should not mutate these surfaces directly. -NON_PROVIDER_STATE_MUTATION_DENY: frozenset[str] = frozenset( - { - # Exact tool names from shared deny patterns. - *{name for name in WRITE_TOOL_DENY_PATTERNS if "*" not in name}, - # Additional non-provider state mutation controls. 
- "write_todos", - "task", - } -) diff --git a/surfsense_backend/app/agents/new_chat/subagents/providers/linear.py b/surfsense_backend/app/agents/new_chat/subagents/providers/linear.py deleted file mode 100644 index da332fe28..000000000 --- a/surfsense_backend/app/agents/new_chat/subagents/providers/linear.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Linear provider specialist subagent. - -This file is intentionally standalone so provider specialists can be reviewed -and evolved independently (one provider per file). -""" - -from __future__ import annotations - -from collections.abc import Sequence -from typing import TYPE_CHECKING, Any - -from app.agents.new_chat.permissions import Rule, Ruleset -from app.agents.new_chat.subagents.constants import NON_PROVIDER_STATE_MUTATION_DENY -from app.services.mcp_oauth.registry import ( - LINEAR_MCP_READONLY_TOOL_NAMES, - linear_mcp_original_tool_name, -) - -if TYPE_CHECKING: - from deepagents import SubAgent - from langchain_core.language_models import BaseChatModel - from langchain_core.tools import BaseTool - - -# Read vs write Linear MCP tools are defined in -# ``app.services.mcp_oauth.registry`` (``LINEAR_MCP_READONLY_TOOL_NAMES`` / -# ``LINEAR_MCP_WRITE_TOOL_NAMES``). Any other Linear-domain tool requires approval. - -LINEAR_SYSTEM_PROMPT = """You are the linear_specialist subagent for SurfSense. - -Role: -- You are the Linear domain specialist. Handle Linear-only requests accurately. - -Primary objective: -- Resolve the user's Linear task and return a concise, auditable result. - -Routing boundary: -- Use this subagent for Linear-domain tasks (issues, status, assignees, labels, - teams, and project references). -- If the task is primarily non-Linear or cross-connector orchestration, return - status=needs_input and hand control back to the parent with the exact next hop. - -Execution steps: -1) Verify Linear access first (use get_connected_accounts if needed). 
-2) Prefer read/list tools first to gather current issue facts before concluding. -3) Track key identifiers in your reasoning: issue ID, issue key, team ID, label ID. -4) If required identifiers are missing, ask the parent for exactly what is missing. -5) Return a compact result with findings + evidence references. - -Output format: -- status: success | needs_input | blocked | error -- summary: one short paragraph -- evidence: bullet list of concrete IDs / issue keys used -- next_step: one sentence (only when blocked or needs_input) - -Constraints: -- Do not invent issue keys, IDs, or workflow state names. -- Mutating Linear operations are allowed only with explicit approval. -- If Linear connector access is unavailable, stop and return status=blocked. -""" - - -def _select_linear_tools(tools: Sequence[BaseTool]) -> list[BaseTool]: - """Keep Linear tools plus minimal shared read utilities.""" - allowed_exact = { - "get_connected_accounts", - "read_file", - "ls", - "glob", - "grep", - } - selected: list[BaseTool] = [] - for tool in tools: - if tool.name in allowed_exact: - selected.append(tool) - continue - if linear_mcp_original_tool_name(tool.name) is not None: - selected.append(tool) - continue - if tool.name.startswith("linear_") or tool.name.endswith("_linear_issue"): - selected.append(tool) - return selected - - -def _is_linear_readonly_tool_name(name: str) -> bool: - """Return True when a tool name maps to a read-only Linear MCP operation.""" - base = linear_mcp_original_tool_name(name) - return base is not None and base in LINEAR_MCP_READONLY_TOOL_NAMES - - -def _is_linear_domain_tool_name(name: str) -> bool: - """Return True for Linear-domain tools handled by this specialist.""" - if linear_mcp_original_tool_name(name) is not None: - return True - return name.startswith("linear_") or name.endswith("_linear_issue") - - -def _permission_middleware(*, selected_tools: Sequence[BaseTool]) -> Any: - """Permission policy for Linear specialist.""" - from 
app.agents.new_chat.middleware.permission import PermissionMiddleware - - ask_tools = sorted( - { - tool.name - for tool in selected_tools - if _is_linear_domain_tool_name(tool.name) - and not _is_linear_readonly_tool_name(tool.name) - } - ) - rules: list[Rule] = [Rule(permission="*", pattern="*", action="allow")] - rules.extend( - Rule(permission=name, pattern="*", action="deny") - for name in NON_PROVIDER_STATE_MUTATION_DENY - ) - rules.extend(Rule(permission=name, pattern="*", action="ask") for name in ask_tools) - return PermissionMiddleware( - rulesets=[Ruleset(rules=rules, origin="subagent_linear_specialist")] - ) - - -def _wrap_subagent_middleware( - *, - selected_tools: Sequence[BaseTool], - extra_middleware: Sequence[Any] | None, -) -> list[Any]: - """Apply standard middleware chain used by other subagents.""" - from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware - - from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware - - return [ - *(extra_middleware or []), - _permission_middleware(selected_tools=selected_tools), - PatchToolCallsMiddleware(), - DedupHITLToolCallsMiddleware(agent_tools=list(selected_tools)), - ] - - -def build_linear_specialist_subagent( - *, - tools: Sequence[BaseTool], - model: BaseChatModel | None = None, - extra_middleware: Sequence[Any] | None = None, -) -> SubAgent: - """Build the ``linear_specialist`` provider subagent spec.""" - selected_tools = _select_linear_tools(tools) - spec: dict[str, Any] = { - "name": "linear_specialist", - "description": ( - "Linear operations specialist for issue and workflow requests, " - "with strict evidence tracking and approval-gated mutating operations." 
- ), - "system_prompt": LINEAR_SYSTEM_PROMPT, - "tools": selected_tools, - "middleware": _wrap_subagent_middleware( - selected_tools=selected_tools, - extra_middleware=extra_middleware, - ), - } - if model is not None: - spec["model"] = model - return spec # type: ignore[return-value] diff --git a/surfsense_backend/app/agents/new_chat/subagents/providers/slack.py b/surfsense_backend/app/agents/new_chat/subagents/providers/slack.py deleted file mode 100644 index 90ca80152..000000000 --- a/surfsense_backend/app/agents/new_chat/subagents/providers/slack.py +++ /dev/null @@ -1,170 +0,0 @@ -"""Slack provider specialist subagent. - -This file is intentionally standalone so provider specialists can be reviewed -and evolved independently (one provider per file). -""" - -from __future__ import annotations - -from collections.abc import Sequence -from typing import TYPE_CHECKING, Any - -from app.agents.new_chat.permissions import Rule, Ruleset -from app.agents.new_chat.subagents.constants import NON_PROVIDER_STATE_MUTATION_DENY - -if TYPE_CHECKING: - from deepagents import SubAgent - from langchain_core.language_models import BaseChatModel - from langchain_core.tools import BaseTool - - -# Official references: -# - https://docs.slack.dev/ai/slack-mcp-server -# - https://www.npmjs.com/package/@modelcontextprotocol/server-slack -# -# Policy: only known read-only Slack tools are auto-allowed. Any other -# ``slack_*`` tool is treated as mutating and requires explicit approval. -SLACK_READONLY_TOOL_NAMES: frozenset[str] = frozenset( - { - # Slack-hosted MCP read tools - "slack_search_channels", - "slack_read_channel", - "slack_read_thread", - "slack_read_canvas", - "slack_read_user_profile", - # modelcontextprotocol/server-slack read tools - "slack_list_channels", - "slack_get_channel_history", - "slack_get_thread_replies", - "slack_get_users", - "slack_get_user_profile", - } -) - -SLACK_SYSTEM_PROMPT = """You are the slack_specialist subagent for SurfSense. 
- -Role: -- You are the Slack domain specialist. Handle Slack-only requests accurately. - -Primary objective: -- Resolve the user's Slack task and return a concise, auditable result. - -Routing boundary: -- Use this subagent for Slack-domain tasks (channels, threads, users, messages, - and Slack canvases). -- If the task is primarily non-Slack or cross-connector orchestration, return - status=needs_input and hand control back to the parent with the exact next hop. - -Execution steps: -1) Verify Slack access first (use get_connected_accounts if needed). -2) Prefer read/list tools first to gather facts before concluding. -3) Track key identifiers in your reasoning: channel ID, message ts, thread ts, user ID. -4) If required identifiers are missing, ask the parent for exactly what is missing. -5) Return a compact result with findings + evidence references. - -Output format: -- status: success | needs_input | blocked | error -- summary: one short paragraph -- evidence: bullet list of concrete IDs / timestamps used -- next_step: one sentence (only when blocked or needs_input) - -Constraints: -- Do not invent Slack IDs, channels, users, or message content. -- Mutating Slack operations are allowed only with explicit approval. -- If Slack connector access is unavailable, stop and return status=blocked. -""" - - -def _select_slack_tools(tools: Sequence[BaseTool]) -> list[BaseTool]: - """Keep Slack tools plus minimal shared read utilities.""" - allowed_exact = { - "get_connected_accounts", - "read_file", - "ls", - "glob", - "grep", - } - slack_prefix = "slack_" - selected: list[BaseTool] = [] - for tool in tools: - if tool.name in allowed_exact: - selected.append(tool) - continue - if tool.name.startswith(slack_prefix): - selected.append(tool) - return selected - - -def _permission_middleware(*, selected_tools: Sequence[BaseTool]) -> Any: - """Permission policy for Slack specialist. - - Intent: - - Allow Slack-domain operations by default. 
- - Gate Slack mutating operations behind approval (`ask`). - - Hard-deny non-Slack state mutations, especially KB virtual filesystem - mutation and parent-context mutation tools. - """ - from app.agents.new_chat.middleware.permission import PermissionMiddleware - - ask_tools = sorted( - { - tool.name - for tool in selected_tools - if tool.name.startswith("slack_") - and tool.name not in SLACK_READONLY_TOOL_NAMES - } - ) - rules: list[Rule] = [Rule(permission="*", pattern="*", action="allow")] - rules.extend( - Rule(permission=name, pattern="*", action="deny") - for name in NON_PROVIDER_STATE_MUTATION_DENY - ) - rules.extend(Rule(permission=name, pattern="*", action="ask") for name in ask_tools) - return PermissionMiddleware( - rulesets=[Ruleset(rules=rules, origin="subagent_slack_specialist")] - ) - - -def _wrap_subagent_middleware( - *, - selected_tools: Sequence[BaseTool], - extra_middleware: Sequence[Any] | None, -) -> list[Any]: - """Apply standard middleware chain used by other subagents.""" - from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware - - from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware - - return [ - *(extra_middleware or []), - _permission_middleware(selected_tools=selected_tools), - PatchToolCallsMiddleware(), - DedupHITLToolCallsMiddleware(agent_tools=list(selected_tools)), - ] - - -def build_slack_specialist_subagent( - *, - tools: Sequence[BaseTool], - model: BaseChatModel | None = None, - extra_middleware: Sequence[Any] | None = None, -) -> SubAgent: - """Build the ``slack_specialist`` provider subagent spec.""" - selected_tools = _select_slack_tools(tools) - spec: dict[str, Any] = { - "name": "slack_specialist", - "description": ( - "Slack operations specialist for any Slack-domain request " - "(channels, threads, users, and messages), with strict evidence " - "tracking and approval-gated mutating operations." 
- ), - "system_prompt": SLACK_SYSTEM_PROMPT, - "tools": selected_tools, - "middleware": _wrap_subagent_middleware( - selected_tools=selected_tools, - extra_middleware=extra_middleware, - ), - } - if model is not None: - spec["model"] = model - return spec # type: ignore[return-value] diff --git a/surfsense_backend/app/agents/new_chat/tools/__init__.py b/surfsense_backend/app/agents/new_chat/tools/__init__.py deleted file mode 100644 index 4b5ae3706..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/__init__.py +++ /dev/null @@ -1,55 +0,0 @@ -""" -Tools module for SurfSense deep agent. - -This module contains all the tools available to the SurfSense agent. -To add a new tool, see the documentation in registry.py. - -Available tools: -- generate_podcast: Generate audio podcasts from content -- generate_video_presentation: Generate video presentations with slides and narration -- generate_image: Generate images from text descriptions using AI models -- scrape_webpage: Extract content from webpages -- update_memory: Update the user's / team's memory document -""" - -# Registry exports -# Tool factory exports (for direct use) -from .generate_image import create_generate_image_tool -from .knowledge_base import ( - CONNECTOR_DESCRIPTIONS, - format_documents_for_context, - search_knowledge_base_async, -) -from .podcast import create_generate_podcast_tool -from .registry import ( - BUILTIN_TOOLS, - ToolDefinition, - build_tools, - get_all_tool_names, - get_default_enabled_tools, - get_tool_by_name, -) -from .scrape_webpage import create_scrape_webpage_tool -from .update_memory import create_update_memory_tool, create_update_team_memory_tool -from .video_presentation import create_generate_video_presentation_tool - -__all__ = [ - # Registry - "BUILTIN_TOOLS", - # Knowledge base utilities - "CONNECTOR_DESCRIPTIONS", - "ToolDefinition", - "build_tools", - # Tool factories - "create_generate_image_tool", - "create_generate_podcast_tool", - 
"create_generate_video_presentation_tool", - "create_scrape_webpage_tool", - "create_update_memory_tool", - "create_update_team_memory_tool", - "format_documents_for_context", - "get_all_tool_names", - "get_default_enabled_tools", - "get_tool_by_name", - "search_knowledge_base_async", -] diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/__init__.py b/surfsense_backend/app/agents/new_chat/tools/confluence/__init__.py deleted file mode 100644 index 3bf80b61b..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Confluence tools for creating, updating, and deleting pages.""" - -from .create_page import create_create_confluence_page_tool -from .delete_page import create_delete_confluence_page_tool -from .update_page import create_update_confluence_page_tool - -__all__ = [ - "create_create_confluence_page_tool", - "create_delete_confluence_page_tool", - "create_update_confluence_page_tool", -] diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py deleted file mode 100644 index c56db1528..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py +++ /dev/null @@ -1,232 +0,0 @@ -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm.attributes import flag_modified - -from app.agents.new_chat.tools.hitl import request_approval -from app.connectors.confluence_history import ConfluenceHistoryConnector -from app.db import async_session_maker -from app.services.confluence import ConfluenceToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_create_confluence_page_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, - connector_id: int | None = None, -): - """ - Factory function to create the 
create_confluence_page tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured create_confluence_page tool - """ - del db_session # per-call session — see docstring - - @tool - async def create_confluence_page( - title: str, - content: str | None = None, - space_id: str | None = None, - ) -> dict[str, Any]: - """Create a new page in Confluence. - - Use this tool when the user explicitly asks to create a new Confluence page. - - Args: - title: Title of the page. - content: Optional HTML/storage format content for the page body. - space_id: Optional Confluence space ID to create the page in. - - Returns: - Dictionary with status, page_id, and message. - - IMPORTANT: - - If status is "rejected", do NOT retry. - - If status is "insufficient_permissions", inform user to re-authenticate. 
- """ - logger.info(f"create_confluence_page called: title='{title}'") - - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Confluence tool not properly configured.", - } - - try: - async with async_session_maker() as db_session: - metadata_service = ConfluenceToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - return {"status": "error", "message": context["error"]} - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - return { - "status": "auth_error", - "message": "All connected Confluence accounts need re-authentication.", - "connector_type": "confluence", - } - - result = request_approval( - action_type="confluence_page_creation", - tool_name="create_confluence_page", - params={ - "title": title, - "content": content, - "space_id": space_id, - "connector_id": connector_id, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_title = result.params.get("title", title) - final_content = result.params.get("content", content) or "" - final_space_id = result.params.get("space_id", space_id) - final_connector_id = result.params.get("connector_id", connector_id) - - if not final_title or not final_title.strip(): - return {"status": "error", "message": "Page title cannot be empty."} - if not final_space_id: - return {"status": "error", "message": "A space must be selected."} - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - actual_connector_id = final_connector_id - if actual_connector_id is None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Confluence connector found.", - } - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == actual_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Confluence connector is invalid.", - } - - try: - client = ConfluenceHistoryConnector( - session=db_session, connector_id=actual_connector_id - ) - api_result = await client.create_page( - space_id=final_space_id, - title=final_title, - body=final_content, - ) - await client.close() - except Exception as api_err: - if ( - "http 403" in str(api_err).lower() - or "status code 403" in 
str(api_err).lower() - ): - try: - _conn = connector - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - pass - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.", - } - raise - - page_id = str(api_result.get("id", "")) - page_links = ( - api_result.get("_links", {}) if isinstance(api_result, dict) else {} - ) - page_url = "" - if page_links.get("base") and page_links.get("webui"): - page_url = f"{page_links['base']}{page_links['webui']}" - - kb_message_suffix = "" - try: - from app.services.confluence import ConfluenceKBSyncService - - kb_service = ConfluenceKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - page_id=page_id, - page_title=final_title, - space_id=final_space_id, - body_content=final_content, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." - ) - else: - kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." 
- - return { - "status": "success", - "page_id": page_id, - "page_url": page_url, - "message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}", - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error(f"Error creating Confluence page: {e}", exc_info=True) - return { - "status": "error", - "message": "Something went wrong while creating the page.", - } - - return create_confluence_page diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py deleted file mode 100644 index d4cd5032f..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py +++ /dev/null @@ -1,213 +0,0 @@ -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm.attributes import flag_modified - -from app.agents.new_chat.tools.hitl import request_approval -from app.connectors.confluence_history import ConfluenceHistoryConnector -from app.db import async_session_maker -from app.services.confluence import ConfluenceToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_delete_confluence_page_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, - connector_id: int | None = None, -): - """ - Factory function to create the delete_confluence_page tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. 
- - Returns: - Configured delete_confluence_page tool - """ - del db_session # per-call session — see docstring - - @tool - async def delete_confluence_page( - page_title_or_id: str, - delete_from_kb: bool = False, - ) -> dict[str, Any]: - """Delete a Confluence page. - - Use this tool when the user asks to delete or remove a Confluence page. - - Args: - page_title_or_id: The page title or ID to identify the page. - delete_from_kb: Whether to also remove from the knowledge base. - - Returns: - Dictionary with status, message, and deleted_from_kb. - - IMPORTANT: - - If status is "rejected", do NOT retry. - - If status is "not_found", relay the message to the user. - - If status is "insufficient_permissions", inform user to re-authenticate. - """ - logger.info( - f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'" - ) - - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Confluence tool not properly configured.", - } - - try: - async with async_session_maker() as db_session: - metadata_service = ConfluenceToolMetadataService(db_session) - context = await metadata_service.get_deletion_context( - search_space_id, user_id, page_title_or_id - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "confluence", - } - if "not found" in error_msg.lower(): - return {"status": "not_found", "message": error_msg} - return {"status": "error", "message": error_msg} - - page_data = context["page"] - page_id = page_data["page_id"] - page_title = page_data.get("page_title", "") - document_id = page_data["document_id"] - connector_id_from_context = context.get("account", {}).get("id") - - result = request_approval( - action_type="confluence_page_deletion", - tool_name="delete_confluence_page", - params={ - "page_id": page_id, - "connector_id": 
connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_page_id = result.params.get("page_id", page_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get( - "delete_from_kb", delete_from_kb - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this page.", - } - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Confluence connector is invalid.", - } - - try: - client = ConfluenceHistoryConnector( - session=db_session, connector_id=final_connector_id - ) - await client.delete_page(final_page_id) - await client.close() - except Exception as api_err: - if ( - "http 403" in str(api_err).lower() - or "status code 403" in str(api_err).lower() - ): - try: - connector.config = { - **connector.config, - "auth_expired": True, - } - flag_modified(connector, "config") - await db_session.commit() - except Exception: - pass - return { - "status": "insufficient_permissions", - "connector_id": final_connector_id, - "message": "This Confluence account needs additional permissions. 
Please re-authenticate in connector settings.", - } - raise - - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - from app.db import Document - - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - - message = f"Confluence page '{page_title}' deleted successfully." - if deleted_from_kb: - message += " Also removed from the knowledge base." - - return { - "status": "success", - "page_id": final_page_id, - "deleted_from_kb": deleted_from_kb, - "message": message, - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error(f"Error deleting Confluence page: {e}", exc_info=True) - return { - "status": "error", - "message": "Something went wrong while deleting the page.", - } - - return delete_confluence_page diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py deleted file mode 100644 index 51c205e00..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py +++ /dev/null @@ -1,240 +0,0 @@ -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm.attributes import flag_modified - -from app.agents.new_chat.tools.hitl import request_approval -from app.connectors.confluence_history import ConfluenceHistoryConnector -from app.db import async_session_maker -from app.services.confluence import ConfluenceToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_update_confluence_page_tool( - db_session: AsyncSession | None = None, - 
search_space_id: int | None = None, - user_id: str | None = None, - connector_id: int | None = None, -): - """ - Factory function to create the update_confluence_page tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured update_confluence_page tool - """ - del db_session # per-call session — see docstring - - @tool - async def update_confluence_page( - page_title_or_id: str, - new_title: str | None = None, - new_content: str | None = None, - ) -> dict[str, Any]: - """Update an existing Confluence page. - - Use this tool when the user asks to modify or edit a Confluence page. - - Args: - page_title_or_id: The page title or ID to identify the page. - new_title: Optional new title for the page. - new_content: Optional new HTML/storage format content. - - Returns: - Dictionary with status and message. - - IMPORTANT: - - If status is "rejected", do NOT retry. - - If status is "not_found", relay the message to the user. - - If status is "insufficient_permissions", inform user to re-authenticate. 
- """ - logger.info( - f"update_confluence_page called: page_title_or_id='{page_title_or_id}'" - ) - - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Confluence tool not properly configured.", - } - - try: - async with async_session_maker() as db_session: - metadata_service = ConfluenceToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, page_title_or_id - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "confluence", - } - if "not found" in error_msg.lower(): - return {"status": "not_found", "message": error_msg} - return {"status": "error", "message": error_msg} - - page_data = context["page"] - page_id = page_data["page_id"] - current_title = page_data["page_title"] - current_body = page_data.get("body", "") - current_version = page_data.get("version", 1) - document_id = page_data.get("document_id") - connector_id_from_context = context.get("account", {}).get("id") - - result = request_approval( - action_type="confluence_page_update", - tool_name="update_confluence_page", - params={ - "page_id": page_id, - "document_id": document_id, - "new_title": new_title, - "new_content": new_content, - "version": current_version, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_page_id = result.params.get("page_id", page_id) - final_title = result.params.get("new_title", new_title) or current_title - final_content = result.params.get("new_content", new_content) - if final_content is None: - final_content = current_body - final_version = result.params.get("version", current_version) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_document_id = result.params.get("document_id", document_id) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this page.", - } - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Confluence connector is invalid.", - } - - try: - client = ConfluenceHistoryConnector( - session=db_session, connector_id=final_connector_id - ) - api_result = await client.update_page( - page_id=final_page_id, - title=final_title, - body=final_content, - version_number=final_version + 1, - ) - await client.close() - except Exception as api_err: - if ( - "http 403" in str(api_err).lower() - or "status code 403" in str(api_err).lower() - ): - try: - connector.config = { - **connector.config, - "auth_expired": True, - } - flag_modified(connector, "config") - await db_session.commit() - except Exception: - pass - return { - "status": "insufficient_permissions", - "connector_id": final_connector_id, - "message": "This Confluence account needs additional permissions. 
Please re-authenticate in connector settings.", - } - raise - - page_links = ( - api_result.get("_links", {}) if isinstance(api_result, dict) else {} - ) - page_url = "" - if page_links.get("base") and page_links.get("webui"): - page_url = f"{page_links['base']}{page_links['webui']}" - - kb_message_suffix = "" - if final_document_id: - try: - from app.services.confluence import ConfluenceKBSyncService - - kb_service = ConfluenceKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=final_document_id, - page_id=final_page_id, - user_id=user_id, - search_space_id=search_space_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." - ) - else: - kb_message_suffix = ( - " The knowledge base will be updated in the next sync." - ) - except Exception as kb_err: - logger.warning(f"KB sync after update failed: {kb_err}") - kb_message_suffix = ( - " The knowledge base will be updated in the next sync." - ) - - return { - "status": "success", - "page_id": final_page_id, - "page_url": page_url, - "message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}", - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error(f"Error updating Confluence page: {e}", exc_info=True) - return { - "status": "error", - "message": "Something went wrong while updating the page.", - } - - return update_confluence_page diff --git a/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py b/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py deleted file mode 100644 index 6420a90e6..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Connected-accounts discovery tool. - -Lets the LLM discover which accounts are connected for a given service -(e.g. 
"jira", "linear", "slack") and retrieve the metadata it needs to -call action tools — such as Jira's ``cloudId``. - -The tool returns **only** non-sensitive fields explicitly listed in the -service's ``account_metadata_keys`` (see ``registry.py``), plus the -always-present ``display_name`` and ``connector_id``. -""" - -import logging -from typing import Any - -from langchain_core.tools import StructuredTool -from pydantic import BaseModel, Field -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select - -from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker -from app.services.mcp_oauth.registry import MCP_SERVICES - -logger = logging.getLogger(__name__) - -_SERVICE_KEY_BY_CONNECTOR_TYPE: dict[str, str] = { - cfg.connector_type: key for key, cfg in MCP_SERVICES.items() -} - - -class GetConnectedAccountsInput(BaseModel): - service: str = Field( - description=( - "Service key to look up connected accounts for. " - "Valid values: " + ", ".join(sorted(MCP_SERVICES.keys())) - ), - ) - - -def _extract_display_name(connector: SearchSourceConnector) -> str: - """Best-effort human-readable label for a connector.""" - cfg = connector.config or {} - if cfg.get("display_name"): - return cfg["display_name"] - if cfg.get("base_url"): - return f"{connector.name} ({cfg['base_url']})" - if cfg.get("organization_name"): - return f"{connector.name} ({cfg['organization_name']})" - return connector.name - - -def create_get_connected_accounts_tool( - db_session: AsyncSession, - search_space_id: int, - user_id: str, -) -> StructuredTool: - """Factory function to create the get_connected_accounts tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. 
- - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - search_space_id: Search space ID to scope account discovery to. - user_id: User ID to scope account discovery to. - - Returns: - Configured StructuredTool for connected-accounts discovery. - """ - del db_session # per-call session — see docstring - - async def _run(service: str) -> list[dict[str, Any]]: - svc_cfg = MCP_SERVICES.get(service) - if not svc_cfg: - return [ - { - "error": f"Unknown service '{service}'. Valid: {', '.join(sorted(MCP_SERVICES.keys()))}" - } - ] - - try: - connector_type = SearchSourceConnectorType(svc_cfg.connector_type) - except ValueError: - return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}] - - async with async_session_maker() as db_session: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type == connector_type, - ) - ) - connectors = result.scalars().all() - - if not connectors: - return [ - { - "error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings." - } - ] - - is_multi = len(connectors) > 1 - - accounts: list[dict[str, Any]] = [] - for conn in connectors: - cfg = conn.config or {} - entry: dict[str, Any] = { - "connector_id": conn.id, - "display_name": _extract_display_name(conn), - "service": service, - } - if is_multi: - entry["tool_prefix"] = f"{service}_{conn.id}" - for key in svc_cfg.account_metadata_keys: - if key in cfg: - entry[key] = cfg[key] - accounts.append(entry) - - return accounts - - return StructuredTool( - name="get_connected_accounts", - description=( - "Discover which accounts are connected for a service (e.g. jira, linear, slack, clickup, airtable). " - "Returns display names and service-specific metadata the action tools need " - "(e.g. 
Jira's cloudId). Call this BEFORE using a service's action tools when " - "you need an account identifier or are unsure which account to use." - ), - coroutine=_run, - args_schema=GetConnectedAccountsInput, - metadata={"hitl": False}, - ) diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/__init__.py b/surfsense_backend/app/agents/new_chat/tools/discord/__init__.py deleted file mode 100644 index b4eaec1f0..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/discord/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from app.agents.new_chat.tools.discord.list_channels import ( - create_list_discord_channels_tool, -) -from app.agents.new_chat.tools.discord.read_messages import ( - create_read_discord_messages_tool, -) -from app.agents.new_chat.tools.discord.send_message import ( - create_send_discord_message_tool, -) - -__all__ = [ - "create_list_discord_channels_tool", - "create_read_discord_messages_tool", - "create_send_discord_message_tool", -] diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py b/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py deleted file mode 100644 index c345f8a5e..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Shared auth helper for Discord agent tools (REST API, not gateway bot).""" - -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select - -from app.config import config -from app.db import SearchSourceConnector, SearchSourceConnectorType -from app.utils.oauth_security import TokenEncryption - -DISCORD_API = "https://discord.com/api/v10" - - -async def get_discord_connector( - db_session: AsyncSession, - search_space_id: int, - user_id: str, -) -> SearchSourceConnector | None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - 
== SearchSourceConnectorType.DISCORD_CONNECTOR, - ) - ) - return result.scalars().first() - - -def get_bot_token(connector: SearchSourceConnector) -> str: - """Extract and decrypt the bot token from connector config.""" - cfg = dict(connector.config) - if cfg.get("_token_encrypted") and config.SECRET_KEY: - enc = TokenEncryption(config.SECRET_KEY) - if cfg.get("bot_token"): - cfg["bot_token"] = enc.decrypt_token(cfg["bot_token"]) - token = cfg.get("bot_token") - if not token: - raise ValueError("Discord bot token not found in connector config.") - return token - - -def get_guild_id(connector: SearchSourceConnector) -> str | None: - return connector.config.get("guild_id") diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py deleted file mode 100644 index 01159a261..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py +++ /dev/null @@ -1,107 +0,0 @@ -import logging -from typing import Any - -import httpx -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.db import async_session_maker - -from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id - -logger = logging.getLogger(__name__) - - -def create_list_discord_channels_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the list_discord_channels tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. 
- - Returns: - Configured list_discord_channels tool - """ - del db_session # per-call session — see docstring - - @tool - async def list_discord_channels() -> dict[str, Any]: - """List text channels in the connected Discord server. - - Returns: - Dictionary with status and a list of channels (id, name). - """ - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Discord tool not properly configured.", - } - - try: - async with async_session_maker() as db_session: - connector = await get_discord_connector( - db_session, search_space_id, user_id - ) - if not connector: - return {"status": "error", "message": "No Discord connector found."} - - guild_id = get_guild_id(connector) - if not guild_id: - return { - "status": "error", - "message": "No guild ID in Discord connector config.", - } - - token = get_bot_token(connector) - - async with httpx.AsyncClient() as client: - resp = await client.get( - f"{DISCORD_API}/guilds/{guild_id}/channels", - headers={"Authorization": f"Bot {token}"}, - timeout=15.0, - ) - - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Discord bot token is invalid.", - "connector_type": "discord", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Discord API error: {resp.status_code}", - } - - # Type 0 = text channel - channels = [ - {"id": ch["id"], "name": ch["name"]} - for ch in resp.json() - if ch.get("type") == 0 - ] - return { - "status": "success", - "guild_id": guild_id, - "channels": channels, - "total": len(channels), - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error("Error listing Discord channels: %s", e, exc_info=True) - return {"status": "error", "message": "Failed to list Discord channels."} - - return list_discord_channels diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py 
b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py deleted file mode 100644 index 88d6cdd49..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py +++ /dev/null @@ -1,120 +0,0 @@ -import logging -from typing import Any - -import httpx -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.db import async_session_maker - -from ._auth import DISCORD_API, get_bot_token, get_discord_connector - -logger = logging.getLogger(__name__) - - -def create_read_discord_messages_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the read_discord_messages tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured read_discord_messages tool - """ - del db_session # per-call session — see docstring - - @tool - async def read_discord_messages( - channel_id: str, - limit: int = 25, - ) -> dict[str, Any]: - """Read recent messages from a Discord text channel. - - Args: - channel_id: The Discord channel ID (from list_discord_channels). - limit: Number of messages to fetch (default 25, max 50). - - Returns: - Dictionary with status and a list of messages including - id, author, content, timestamp. 
- """ - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Discord tool not properly configured.", - } - - limit = min(limit, 50) - - try: - async with async_session_maker() as db_session: - connector = await get_discord_connector( - db_session, search_space_id, user_id - ) - if not connector: - return {"status": "error", "message": "No Discord connector found."} - - token = get_bot_token(connector) - - async with httpx.AsyncClient() as client: - resp = await client.get( - f"{DISCORD_API}/channels/{channel_id}/messages", - headers={"Authorization": f"Bot {token}"}, - params={"limit": limit}, - timeout=15.0, - ) - - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Discord bot token is invalid.", - "connector_type": "discord", - } - if resp.status_code == 403: - return { - "status": "error", - "message": "Bot lacks permission to read this channel.", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Discord API error: {resp.status_code}", - } - - messages = [ - { - "id": m["id"], - "author": m.get("author", {}).get("username", "Unknown"), - "content": m.get("content", ""), - "timestamp": m.get("timestamp", ""), - } - for m in resp.json() - ] - - return { - "status": "success", - "channel_id": channel_id, - "messages": messages, - "total": len(messages), - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error("Error reading Discord messages: %s", e, exc_info=True) - return {"status": "error", "message": "Failed to read Discord messages."} - - return read_discord_messages diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py deleted file mode 100644 index 5fe6fde35..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py +++ /dev/null @@ -1,136 +0,0 @@ -import 
logging -from typing import Any - -import httpx -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.db import async_session_maker - -from ._auth import DISCORD_API, get_bot_token, get_discord_connector - -logger = logging.getLogger(__name__) - - -def create_send_discord_message_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the send_discord_message tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured send_discord_message tool - """ - del db_session # per-call session — see docstring - - @tool - async def send_discord_message( - channel_id: str, - content: str, - ) -> dict[str, Any]: - """Send a message to a Discord text channel. - - Args: - channel_id: The Discord channel ID (from list_discord_channels). - content: The message text (max 2000 characters). - - Returns: - Dictionary with status, message_id on success. - - IMPORTANT: - - If status is "rejected", the user explicitly declined. Do NOT retry. 
- """ - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Discord tool not properly configured.", - } - - if len(content) > 2000: - return { - "status": "error", - "message": "Message exceeds Discord's 2000-character limit.", - } - - try: - async with async_session_maker() as db_session: - connector = await get_discord_connector( - db_session, search_space_id, user_id - ) - if not connector: - return {"status": "error", "message": "No Discord connector found."} - - result = request_approval( - action_type="discord_send_message", - tool_name="send_discord_message", - params={"channel_id": channel_id, "content": content}, - context={"connector_id": connector.id}, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Message was not sent.", - } - - final_content = result.params.get("content", content) - final_channel = result.params.get("channel_id", channel_id) - - token = get_bot_token(connector) - - async with httpx.AsyncClient() as client: - resp = await client.post( - f"{DISCORD_API}/channels/{final_channel}/messages", - headers={ - "Authorization": f"Bot {token}", - "Content-Type": "application/json", - }, - json={"content": final_content}, - timeout=15.0, - ) - - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Discord bot token is invalid.", - "connector_type": "discord", - } - if resp.status_code == 403: - return { - "status": "error", - "message": "Bot lacks permission to send messages in this channel.", - } - if resp.status_code not in (200, 201): - return { - "status": "error", - "message": f"Discord API error: {resp.status_code}", - } - - msg_data = resp.json() - return { - "status": "success", - "message_id": msg_data.get("id"), - "message": f"Message sent to channel {final_channel}.", - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error("Error sending Discord 
message: %s", e, exc_info=True) - return {"status": "error", "message": "Failed to send Discord message."} - - return send_discord_message diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/__init__.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/__init__.py deleted file mode 100644 index 836b9ee41..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/dropbox/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from app.agents.new_chat.tools.dropbox.create_file import ( - create_create_dropbox_file_tool, -) -from app.agents.new_chat.tools.dropbox.trash_file import ( - create_delete_dropbox_file_tool, -) - -__all__ = [ - "create_create_dropbox_file_tool", - "create_delete_dropbox_file_tool", -] diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py deleted file mode 100644 index 7aae034cc..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py +++ /dev/null @@ -1,299 +0,0 @@ -import logging -import os -import tempfile -from pathlib import Path -from typing import Any, Literal - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select - -from app.agents.new_chat.tools.hitl import request_approval -from app.connectors.dropbox.client import DropboxClient -from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker - -logger = logging.getLogger(__name__) - -DOCX_MIME = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" - -_FILE_TYPE_LABELS = { - "paper": "Dropbox Paper (.paper)", - "docx": "Word Document (.docx)", -} - -_SUPPORTED_TYPES = [ - {"value": "paper", "label": "Dropbox Paper (.paper)"}, - {"value": "docx", "label": "Word Document (.docx)"}, -] - - -def _ensure_extension(name: str, file_type: str) -> str: - """Strip any existing extension and append the correct one.""" - stem = Path(name).stem - ext = 
".paper" if file_type == "paper" else ".docx" - return f"{stem}{ext}" - - -def _markdown_to_docx(markdown_text: str) -> bytes: - """Convert a markdown string to DOCX bytes using pypandoc.""" - import pypandoc - - fd, tmp_path = tempfile.mkstemp(suffix=".docx") - os.close(fd) - try: - pypandoc.convert_text( - markdown_text, - "docx", - format="gfm", - extra_args=["--standalone"], - outputfile=tmp_path, - ) - with open(tmp_path, "rb") as f: - return f.read() - finally: - os.unlink(tmp_path) - - -def create_create_dropbox_file_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the create_dropbox_file tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured create_dropbox_file tool - """ - del db_session # per-call session — see docstring - - @tool - async def create_dropbox_file( - name: str, - file_type: Literal["paper", "docx"] = "paper", - content: str | None = None, - ) -> dict[str, Any]: - """Create a new document in Dropbox. - - Use this tool when the user explicitly asks to create a new document - in Dropbox. The user MUST specify a topic before you call this tool. - - Args: - name: The document title (without extension). - file_type: Either "paper" (Dropbox Paper, default) or "docx" (Word document). - content: Optional initial content as markdown. - - Returns: - Dictionary with status, file_id, name, web_url, and message. 
- """ - logger.info( - f"create_dropbox_file called: name='{name}', file_type='{file_type}'" - ) - - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Dropbox tool not properly configured.", - } - - try: - async with async_session_maker() as db_session: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.DROPBOX_CONNECTOR, - ) - ) - connectors = result.scalars().all() - - if not connectors: - return { - "status": "error", - "message": "No Dropbox connector found. Please connect Dropbox in your workspace settings.", - } - - accounts = [] - for c in connectors: - cfg = c.config or {} - accounts.append( - { - "id": c.id, - "name": c.name, - "user_email": cfg.get("user_email"), - "auth_expired": cfg.get("auth_expired", False), - } - ) - - if all(a.get("auth_expired") for a in accounts): - return { - "status": "auth_error", - "message": "All connected Dropbox accounts need re-authentication.", - "connector_type": "dropbox", - } - - parent_folders: dict[int, list[dict[str, str]]] = {} - for acc in accounts: - cid = acc["id"] - if acc.get("auth_expired"): - parent_folders[cid] = [] - continue - try: - client = DropboxClient(session=db_session, connector_id=cid) - items, err = await client.list_folder("") - if err: - logger.warning( - "Failed to list folders for connector %s: %s", cid, err - ) - parent_folders[cid] = [] - else: - parent_folders[cid] = [ - { - "folder_path": item.get("path_lower", ""), - "name": item["name"], - } - for item in items - if item.get(".tag") == "folder" and item.get("name") - ] - except Exception: - logger.warning( - "Error fetching folders for connector %s", - cid, - exc_info=True, - ) - parent_folders[cid] = [] - - context: dict[str, Any] = { - "accounts": accounts, - "parent_folders": parent_folders, - 
"supported_types": _SUPPORTED_TYPES, - } - - result = request_approval( - action_type="dropbox_file_creation", - tool_name="create_dropbox_file", - params={ - "name": name, - "file_type": file_type, - "content": content, - "connector_id": None, - "parent_folder_path": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_name = result.params.get("name", name) - final_file_type = result.params.get("file_type", file_type) - final_content = result.params.get("content", content) - final_connector_id = result.params.get("connector_id") - final_parent_folder_path = result.params.get("parent_folder_path") - - if not final_name or not final_name.strip(): - return {"status": "error", "message": "File name cannot be empty."} - - final_name = _ensure_extension(final_name, final_file_type) - - if final_connector_id is not None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.DROPBOX_CONNECTOR, - ) - ) - connector = result.scalars().first() - else: - connector = connectors[0] - - if not connector: - return { - "status": "error", - "message": "Selected Dropbox connector is invalid.", - } - - client = DropboxClient(session=db_session, connector_id=connector.id) - - parent_path = final_parent_folder_path or "" - file_path = ( - f"{parent_path}/{final_name}" if parent_path else f"/{final_name}" - ) - - if final_file_type == "paper": - created = await client.create_paper_doc( - file_path, final_content or "" - ) - file_id = created.get("file_id", "") - web_url = created.get("url", "") - else: - docx_bytes = _markdown_to_docx(final_content or "") - created = await client.upload_file( - file_path, docx_bytes, 
mode="add", autorename=True - ) - file_id = created.get("id", "") - web_url = "" - - logger.info(f"Dropbox file created: id={file_id}, name={final_name}") - - kb_message_suffix = "" - try: - from app.services.dropbox import DropboxKBSyncService - - kb_service = DropboxKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - file_id=file_id, - file_name=final_name, - file_path=file_path, - web_url=web_url, - content=final_content, - connector_id=connector.id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." - ) - else: - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - - return { - "status": "success", - "file_id": file_id, - "name": final_name, - "web_url": web_url, - "message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}", - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error(f"Error creating Dropbox file: {e}", exc_info=True) - return { - "status": "error", - "message": "Something went wrong while creating the file. 
Please try again.", - } - - return create_dropbox_file diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py deleted file mode 100644 index 0e59e49db..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py +++ /dev/null @@ -1,301 +0,0 @@ -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy import String, and_, cast, func -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select - -from app.agents.new_chat.tools.hitl import request_approval -from app.connectors.dropbox.client import DropboxClient -from app.db import ( - Document, - DocumentType, - SearchSourceConnector, - SearchSourceConnectorType, - async_session_maker, -) - -logger = logging.getLogger(__name__) - - -def create_delete_dropbox_file_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the delete_dropbox_file tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured delete_dropbox_file tool - """ - del db_session # per-call session — see docstring - - @tool - async def delete_dropbox_file( - file_name: str, - delete_from_kb: bool = False, - ) -> dict[str, Any]: - """Delete a file from Dropbox. - - Use this tool when the user explicitly asks to delete, remove, or trash - a file in Dropbox. - - Args: - file_name: The exact name of the file to delete. - delete_from_kb: Whether to also remove the file from the knowledge base. 
- Default is False. - - Returns: - Dictionary with: - - status: "success", "rejected", "not_found", or "error" - - file_id: Dropbox file ID (if success) - - deleted_from_kb: whether the document was removed from the knowledge base - - message: Result message - - IMPORTANT: - - If status is "rejected", the user explicitly declined. Respond with a brief - acknowledgment and do NOT retry or suggest alternatives. - - If status is "not_found", relay the exact message to the user and ask them - to verify the file name or check if it has been indexed. - """ - logger.info( - f"delete_dropbox_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" - ) - - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Dropbox tool not properly configured.", - } - - try: - async with async_session_maker() as db_session: - doc_result = await db_session.execute( - select(Document) - .join( - SearchSourceConnector, - Document.connector_id == SearchSourceConnector.id, - ) - .filter( - and_( - Document.search_space_id == search_space_id, - Document.document_type == DocumentType.DROPBOX_FILE, - func.lower(Document.title) == func.lower(file_name), - SearchSourceConnector.user_id == user_id, - ) - ) - .order_by(Document.updated_at.desc().nullslast()) - .limit(1) - ) - document = doc_result.scalars().first() - - if not document: - doc_result = await db_session.execute( - select(Document) - .join( - SearchSourceConnector, - Document.connector_id == SearchSourceConnector.id, - ) - .filter( - and_( - Document.search_space_id == search_space_id, - Document.document_type == DocumentType.DROPBOX_FILE, - func.lower( - cast( - Document.document_metadata["dropbox_file_name"], - String, - ) - ) - == func.lower(file_name), - SearchSourceConnector.user_id == user_id, - ) - ) - .order_by(Document.updated_at.desc().nullslast()) - .limit(1) - ) - document = doc_result.scalars().first() - - if not document: - return { - "status": "not_found", - "message": 
( - f"File '{file_name}' not found in your indexed Dropbox files. " - "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, " - "or (3) the file name is different." - ), - } - - if not document.connector_id: - return { - "status": "error", - "message": "Document has no associated connector.", - } - - meta = document.document_metadata or {} - file_path = meta.get("dropbox_path") - file_id = meta.get("dropbox_file_id") - document_id = document.id - - if not file_path: - return { - "status": "error", - "message": "File path is missing. Please re-index the file.", - } - - conn_result = await db_session.execute( - select(SearchSourceConnector).filter( - and_( - SearchSourceConnector.id == document.connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.DROPBOX_CONNECTOR, - ) - ) - ) - connector = conn_result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Dropbox connector not found or access denied.", - } - - cfg = connector.config or {} - if cfg.get("auth_expired"): - return { - "status": "auth_error", - "message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "dropbox", - } - - context = { - "file": { - "file_id": file_id, - "file_path": file_path, - "name": file_name, - "document_id": document_id, - }, - "account": { - "id": connector.id, - "name": connector.name, - "user_email": cfg.get("user_email"), - }, - } - - result = request_approval( - action_type="dropbox_file_trash", - tool_name="delete_dropbox_file", - params={ - "file_path": file_path, - "connector_id": connector.id, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_file_path = result.params.get("file_path", file_path) - final_connector_id = result.params.get("connector_id", connector.id) - final_delete_from_kb = result.params.get( - "delete_from_kb", delete_from_kb - ) - - if final_connector_id != connector.id: - result = await db_session.execute( - select(SearchSourceConnector).filter( - and_( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id - == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.DROPBOX_CONNECTOR, - ) - ) - ) - validated_connector = result.scalars().first() - if not validated_connector: - return { - "status": "error", - "message": "Selected Dropbox connector is invalid or has been disconnected.", - } - actual_connector_id = validated_connector.id - else: - actual_connector_id = connector.id - - logger.info( - f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}" - ) - - client = DropboxClient( - session=db_session, connector_id=actual_connector_id - ) - await client.delete_file(final_file_path) - - logger.info(f"Dropbox file deleted: path={final_file_path}") - - trash_result: dict[str, Any] = { - "status": "success", - "file_id": file_id, - "message": f"Successfully deleted '{file_name}' from Dropbox.", - } - - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - doc = doc_result.scalars().first() - if doc: - await db_session.delete(doc) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - trash_result["warning"] = ( - f"File 
deleted, but failed to remove from knowledge base: {e!s}" - ) - - trash_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - trash_result["message"] = ( - f"{trash_result.get('message', '')} (also removed from knowledge base)" - ) - - return trash_result - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error(f"Error deleting Dropbox file: {e}", exc_info=True) - return { - "status": "error", - "message": "Something went wrong while deleting the file. Please try again.", - } - - return delete_dropbox_file diff --git a/surfsense_backend/app/agents/new_chat/tools/generate_image.py b/surfsense_backend/app/agents/new_chat/tools/generate_image.py deleted file mode 100644 index 9e287ac51..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/generate_image.py +++ /dev/null @@ -1,280 +0,0 @@ -""" -Image generation tool for the SurfSense agent. - -This module provides a tool that generates images using litellm.aimage_generation() -and returns the result directly in a format the frontend Image component can render. - -Config resolution: -1. Uses the search space's image_generation_config_id preference -2. Falls back to Auto mode (router load balancing) if available -3. 
Supports global YAML configs (negative IDs) and user DB configs (positive IDs) -""" - -import hashlib -import logging -from typing import Any - -from langchain_core.tools import tool -from litellm import aimage_generation -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.config import config -from app.db import ( - ImageGeneration, - ImageGenerationConfig, - SearchSpace, - shielded_async_session, -) -from app.services.image_gen_router_service import ( - IMAGE_GEN_AUTO_MODE_ID, - ImageGenRouterService, - is_image_gen_auto_mode, -) -from app.services.provider_api_base import resolve_api_base -from app.utils.signed_image_urls import generate_image_token - -logger = logging.getLogger(__name__) - -# Provider mapping (same as routes) -_PROVIDER_MAP = { - "OPENAI": "openai", - "AZURE_OPENAI": "azure", - "GOOGLE": "gemini", - "VERTEX_AI": "vertex_ai", - "BEDROCK": "bedrock", - "RECRAFT": "recraft", - "OPENROUTER": "openrouter", - "XINFERENCE": "xinference", - "NSCALE": "nscale", -} - - -def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: - if custom_provider: - return custom_provider - return _PROVIDER_MAP.get(provider.upper(), provider.lower()) - - -def _build_model_string( - provider: str, model_name: str, custom_provider: str | None -) -> str: - prefix = _resolve_provider_prefix(provider, custom_provider) - return f"{prefix}/{model_name}" - - -def _get_global_image_gen_config(config_id: int) -> dict | None: - """Get a global image gen config by negative ID.""" - for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: - if cfg.get("id") == config_id: - return cfg - return None - - -def create_generate_image_tool( - search_space_id: int, - db_session: AsyncSession, -): - """ - Factory function to create the generate_image tool. - - Args: - search_space_id: The search space ID (for config resolution) - db_session: Reserved for compatibility with the tool registry. 
- The streaming task's ``AsyncSession`` is shared by every tool; - because AsyncSession is not concurrency-safe, parallel tool calls - would interleave flushes (e.g. podcast + image in the same step) - and poison the transaction. This tool opens its own session. - """ - del db_session # use a fresh per-call session, see below - - @tool - async def generate_image( - prompt: str, - n: int = 1, - ) -> dict[str, Any]: - """ - Generate an image from a text description using AI image models. - - Use this tool when the user asks you to create, generate, draw, or make an image. - The generated image will be displayed directly in the chat. - - Args: - prompt: A detailed text description of the image to generate. - Be specific about subject, style, colors, composition, and mood. - n: Number of images to generate (1-4). Default: 1 - - Returns: - A dictionary containing the generated image(s) for display in the chat. - """ - try: - # Use a per-call session so concurrent tool calls don't share an - # AsyncSession (which is not concurrency-safe). The streaming - # task's session is shared across every tool; without isolation, - # autoflushes from a concurrent writer poison this tool too. - async with shielded_async_session() as session: - result = await session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - if not search_space: - return {"error": "Search space not found"} - - config_id = ( - search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID - ) - - # Build generation kwargs - # NOTE: size, quality, and style are intentionally NOT passed. - # Different models support different values for these params - # (e.g. DALL-E 3 wants "hd"/"standard" for quality while - # gpt-image-1 wants "high"/"medium"/"low"; size options also - # differ). Letting the model use its own defaults avoids errors. 
- gen_kwargs: dict[str, Any] = {} - if n is not None and n > 1: - gen_kwargs["n"] = n - - # Call litellm based on config type - if is_image_gen_auto_mode(config_id): - if not ImageGenRouterService.is_initialized(): - return { - "error": "No image generation models configured. " - "Please add an image model in Settings > Image Models." - } - response = await ImageGenRouterService.aimage_generation( - prompt=prompt, model="auto", **gen_kwargs - ) - elif config_id < 0: - cfg = _get_global_image_gen_config(config_id) - if not cfg: - return { - "error": f"Image generation config {config_id} not found" - } - - provider_prefix = _resolve_provider_prefix( - cfg.get("provider", ""), cfg.get("custom_provider") - ) - model_string = f"{provider_prefix}/{cfg['model_name']}" - gen_kwargs["api_key"] = cfg.get("api_key") - api_base = resolve_api_base( - provider=cfg.get("provider"), - provider_prefix=provider_prefix, - config_api_base=cfg.get("api_base"), - ) - if api_base: - gen_kwargs["api_base"] = api_base - if cfg.get("api_version"): - gen_kwargs["api_version"] = cfg["api_version"] - if cfg.get("litellm_params"): - gen_kwargs.update(cfg["litellm_params"]) - - response = await aimage_generation( - prompt=prompt, model=model_string, **gen_kwargs - ) - else: - # Positive ID = user-created ImageGenerationConfig - cfg_result = await session.execute( - select(ImageGenerationConfig).filter( - ImageGenerationConfig.id == config_id - ) - ) - db_cfg = cfg_result.scalars().first() - if not db_cfg: - return { - "error": f"Image generation config {config_id} not found" - } - - provider_prefix = _resolve_provider_prefix( - db_cfg.provider.value, db_cfg.custom_provider - ) - model_string = f"{provider_prefix}/{db_cfg.model_name}" - gen_kwargs["api_key"] = db_cfg.api_key - api_base = resolve_api_base( - provider=db_cfg.provider.value, - provider_prefix=provider_prefix, - config_api_base=db_cfg.api_base, - ) - if api_base: - gen_kwargs["api_base"] = api_base - if db_cfg.api_version: - 
gen_kwargs["api_version"] = db_cfg.api_version - if db_cfg.litellm_params: - gen_kwargs.update(db_cfg.litellm_params) - - response = await aimage_generation( - prompt=prompt, model=model_string, **gen_kwargs - ) - - # Parse the response and store in DB - response_dict = ( - response.model_dump() - if hasattr(response, "model_dump") - else dict(response) - ) - - # Generate a random access token for this image - access_token = generate_image_token() - - # Save to image_generations table for history - db_image_gen = ImageGeneration( - prompt=prompt, - model=getattr(response, "_hidden_params", {}).get("model"), - n=n, - image_generation_config_id=config_id, - response_data=response_dict, - search_space_id=search_space_id, - access_token=access_token, - ) - session.add(db_image_gen) - await session.commit() - await session.refresh(db_image_gen) - db_image_gen_id = db_image_gen.id - - # Extract image URLs from response - images = response_dict.get("data", []) - if not images: - return {"error": "No images were generated"} - - first_image = images[0] - revised_prompt = first_image.get("revised_prompt", prompt) - - # Resolve image URL: - # - If the API returned a URL, use it directly. - # - If the API returned b64_json (e.g. gpt-image-1), serve the - # image through our backend endpoint to avoid bloating the - # LLM context with megabytes of base64 data. 
- if first_image.get("url"): - image_url = first_image["url"] - elif first_image.get("b64_json"): - backend_url = config.BACKEND_URL or "http://localhost:8000" - image_url = ( - f"{backend_url}/api/v1/image-generations/" - f"{db_image_gen_id}/image?token={access_token}" - ) - else: - return {"error": "No displayable image data in the response"} - - image_id = f"image-{hashlib.md5(image_url.encode()).hexdigest()[:12]}" - - return { - "id": image_id, - "assetId": image_url, - "src": image_url, - "alt": revised_prompt or prompt, - "title": "Generated Image", - "description": revised_prompt if revised_prompt != prompt else None, - "domain": "ai-generated", - "ratio": "auto", - "generated": True, - "prompt": prompt, - "image_count": len(images), - } - - except Exception as e: - logger.exception("Image generation failed in tool") - return { - "error": f"Image generation failed: {e!s}", - "prompt": prompt, - } - - return generate_image diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/__init__.py b/surfsense_backend/app/agents/new_chat/tools/gmail/__init__.py deleted file mode 100644 index 294840122..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -from app.agents.new_chat.tools.gmail.create_draft import ( - create_create_gmail_draft_tool, -) -from app.agents.new_chat.tools.gmail.read_email import ( - create_read_gmail_email_tool, -) -from app.agents.new_chat.tools.gmail.search_emails import ( - create_search_gmail_tool, -) -from app.agents.new_chat.tools.gmail.send_email import ( - create_send_gmail_email_tool, -) -from app.agents.new_chat.tools.gmail.trash_email import ( - create_trash_gmail_email_tool, -) -from app.agents.new_chat.tools.gmail.update_draft import ( - create_update_gmail_draft_tool, -) - -__all__ = [ - "create_create_gmail_draft_tool", - "create_read_gmail_email_tool", - "create_search_gmail_tool", - "create_send_gmail_email_tool", - "create_trash_gmail_email_tool", - 
"create_update_gmail_draft_tool", -] diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py b/surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py deleted file mode 100644 index 0ca1191a4..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Any - -from app.db import SearchSourceConnector -from app.services.composio_service import ComposioService - - -def split_recipients(value: str | None) -> list[str]: - if not value: - return [] - return [recipient.strip() for recipient in value.split(",") if recipient.strip()] - - -def unwrap_composio_data(data: Any) -> Any: - if isinstance(data, dict): - inner = data.get("data", data) - if isinstance(inner, dict): - return inner.get("response_data", inner) - return inner - return data - - -async def execute_composio_gmail_tool( - connector: SearchSourceConnector, - user_id: str, - tool_name: str, - params: dict[str, Any], -) -> tuple[Any, str | None]: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - return None, "Composio connected account ID not found for this Gmail connector." 
- - result = await ComposioService().execute_tool( - connected_account_id=cca_id, - tool_name=tool_name, - params=params, - entity_id=f"surfsense_{user_id}", - ) - if not result.get("success"): - return None, result.get("error", "Unknown Composio Gmail error") - - return unwrap_composio_data(result.get("data")), None diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py deleted file mode 100644 index c88b48d2d..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py +++ /dev/null @@ -1,361 +0,0 @@ -import asyncio -import base64 -import logging -from datetime import datetime -from email.mime.text import MIMEText -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.db import async_session_maker -from app.services.gmail import GmailToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_create_gmail_draft_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the create_gmail_draft tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured create_gmail_draft tool - """ - del db_session # per-call session — see docstring - - @tool - async def create_gmail_draft( - to: str, - subject: str, - body: str, - cc: str | None = None, - bcc: str | None = None, - ) -> dict[str, Any]: - """Create a draft email in Gmail. 
- - Use when the user asks to draft, compose, or prepare an email without - sending it. - - Args: - to: Recipient email address. - subject: Email subject line. - body: Email body content. - cc: Optional CC recipient(s), comma-separated. - bcc: Optional BCC recipient(s), comma-separated. - - Returns: - Dictionary with: - - status: "success", "rejected", or "error" - - draft_id: Gmail draft ID (if success) - - message: Result message - - IMPORTANT: - - If status is "rejected", the user explicitly declined the action. - Respond with a brief acknowledgment and do NOT retry or suggest alternatives. - - If status is "insufficient_permissions", the connector lacks the required OAuth scope. - Inform the user they need to re-authenticate and do NOT retry the action. - - Examples: - - "Draft an email to alice@example.com about the meeting" - - "Compose a reply to Bob about the project update" - """ - logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'") - - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Gmail tool not properly configured. Please contact support.", - } - - try: - async with async_session_maker() as db_session: - metadata_service = GmailToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error( - f"Failed to fetch creation context: {context['error']}" - ) - return {"status": "error", "message": context["error"]} - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - logger.warning("All Gmail accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Gmail accounts need re-authentication. 
Please re-authenticate in your connector settings.", - "connector_type": "gmail", - } - - logger.info( - f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'" - ) - result = request_approval( - action_type="gmail_draft_creation", - tool_name="create_gmail_draft", - params={ - "to": to, - "subject": subject, - "body": body, - "cc": cc, - "bcc": bcc, - "connector_id": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The draft was not created. Do not ask again or suggest alternatives.", - } - - final_to = result.params.get("to", to) - final_subject = result.params.get("subject", subject) - final_body = result.params.get("body", body) - final_cc = result.params.get("cc", cc) - final_bcc = result.params.get("bcc", bcc) - final_connector_id = result.params.get("connector_id") - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _gmail_types = [ - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, - ] - - if final_connector_id is not None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Gmail connector is invalid or has been disconnected.", - } - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - 
"status": "error", - "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", - } - actual_connector_id = connector.id - - logger.info( - f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}" - ) - - is_composio_gmail = ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ) - if is_composio_gmail: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - return { - "status": "error", - "message": "Composio connected account ID not found for this Gmail connector.", - } - else: - from google.oauth2.credentials import Credentials - - from app.config import config - from app.utils.oauth_security import TokenEncryption - - config_data = dict(connector.config) - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and config.SECRET_KEY: - token_encryption = TokenEncryption(config.SECRET_KEY) - if config_data.get("token"): - config_data["token"] = token_encryption.decrypt_token( - config_data["token"] - ) - if config_data.get("refresh_token"): - config_data["refresh_token"] = ( - token_encryption.decrypt_token( - config_data["refresh_token"] - ) - ) - if config_data.get("client_secret"): - config_data["client_secret"] = ( - token_encryption.decrypt_token( - config_data["client_secret"] - ) - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, - ) - - message = MIMEText(final_body) - message["to"] = final_to - message["subject"] = final_subject - if final_cc: - message["cc"] = final_cc - if final_bcc: - message["bcc"] = final_bcc - raw = 
base64.urlsafe_b64encode(message.as_bytes()).decode() - - try: - if is_composio_gmail: - from app.agents.new_chat.tools.gmail.composio_helpers import ( - execute_composio_gmail_tool, - split_recipients, - ) - - created, error = await execute_composio_gmail_tool( - connector, - user_id, - "GMAIL_CREATE_EMAIL_DRAFT", - { - "user_id": "me", - "recipient_email": final_to, - "subject": final_subject, - "body": final_body, - "cc": split_recipients(final_cc), - "bcc": split_recipients(final_bcc), - "is_html": False, - }, - ) - if error: - raise RuntimeError(error) - if not isinstance(created, dict): - created = {} - else: - from googleapiclient.discovery import build - - gmail_service = build("gmail", "v1", credentials=creds) - created = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .drafts() - .create(userId="me", body={"message": {"raw": raw}}) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, - ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Gmail account needs additional permissions. 
Please re-authenticate in connector settings.", - } - raise - - logger.info(f"Gmail draft created: id={created.get('id')}") - - kb_message_suffix = "" - try: - from app.services.gmail import GmailKBSyncService - - kb_service = GmailKBSyncService(db_session) - draft_message = created.get("message", {}) - kb_result = await kb_service.sync_after_create( - message_id=draft_message.get("id", ""), - thread_id=draft_message.get("threadId", ""), - subject=final_subject, - sender="me", - date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - body_text=final_body, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - draft_id=created.get("id"), - ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." - ) - else: - kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync." - - return { - "status": "success", - "draft_id": created.get("id"), - "message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}", - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error creating Gmail draft: {e}", exc_info=True) - return { - "status": "error", - "message": "Something went wrong while creating the draft. 
Please try again.", - } - - return create_gmail_draft diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py deleted file mode 100644 index 464713591..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py +++ /dev/null @@ -1,172 +0,0 @@ -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select - -from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker - -logger = logging.getLogger(__name__) - -_GMAIL_TYPES = [ - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, -] - - -def create_read_gmail_email_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the read_gmail_email tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured read_gmail_email tool - """ - del db_session # per-call session — see docstring - - @tool - async def read_gmail_email(message_id: str) -> dict[str, Any]: - """Read the full content of a specific Gmail email by its message ID. - - Use after search_gmail to get the complete body of an email. - - Args: - message_id: The Gmail message ID (from search_gmail results). - - Returns: - Dictionary with status and the full email content formatted as markdown. 
- """ - if search_space_id is None or user_id is None: - return {"status": "error", "message": "Gmail tool not properly configured."} - - try: - async with async_session_maker() as db_session: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_GMAIL_TYPES), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", - } - - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ): - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - return { - "status": "error", - "message": "Composio connected account ID not found.", - } - - from app.agents.new_chat.tools.gmail.search_emails import ( - _format_gmail_summary, - ) - from app.services.composio_service import ComposioService - - service = ComposioService() - detail, error = await service.get_gmail_message_detail( - connected_account_id=cca_id, - entity_id=f"surfsense_{user_id}", - message_id=message_id, - ) - if error: - return {"status": "error", "message": error} - if not detail: - return { - "status": "not_found", - "message": f"Email with ID '{message_id}' not found.", - } - - summary = _format_gmail_summary(detail) - content = ( - f"# {summary['subject']}\n\n" - f"**From:** {summary['from']}\n" - f"**To:** {summary['to']}\n" - f"**Date:** {summary['date']}\n\n" - f"## Message Content\n\n" - f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n" - f"## Message Details\n\n" - f"- **Message ID:** {summary['message_id']}\n" - f"- **Thread ID:** {summary['thread_id']}\n" - ) - return { - "status": "success", - "message_id": summary["message_id"] or message_id, - "content": content, - } - - from app.agents.new_chat.tools.gmail.search_emails 
import ( - _build_credentials, - ) - - creds = _build_credentials(connector) - - from app.connectors.google_gmail_connector import GoogleGmailConnector - - gmail = GoogleGmailConnector( - credentials=creds, - session=db_session, - user_id=user_id, - connector_id=connector.id, - ) - - detail, error = await gmail.get_message_details(message_id) - if error: - if ( - "re-authenticate" in error.lower() - or "authentication failed" in error.lower() - ): - return { - "status": "auth_error", - "message": error, - "connector_type": "gmail", - } - return {"status": "error", "message": error} - - if not detail: - return { - "status": "not_found", - "message": f"Email with ID '{message_id}' not found.", - } - - content = gmail.format_message_to_markdown(detail) - - return { - "status": "success", - "message_id": message_id, - "content": content, - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error("Error reading Gmail email: %s", e, exc_info=True) - return { - "status": "error", - "message": "Failed to read email. 
Please try again.", - } - - return read_gmail_email diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py deleted file mode 100644 index 3ce154c53..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py +++ /dev/null @@ -1,260 +0,0 @@ -import logging -from datetime import datetime -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select - -from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker - -logger = logging.getLogger(__name__) - -_GMAIL_TYPES = [ - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, -] - -_token_encryption_cache: object | None = None - - -def _get_token_encryption(): - global _token_encryption_cache - if _token_encryption_cache is None: - from app.config import config - from app.utils.oauth_security import TokenEncryption - - if not config.SECRET_KEY: - raise RuntimeError("SECRET_KEY not configured for token decryption.") - _token_encryption_cache = TokenEncryption(config.SECRET_KEY) - return _token_encryption_cache - - -def _build_credentials(connector: SearchSourceConnector): - """Build Google OAuth Credentials from a connector's stored config. - - Handles both native OAuth connectors (with encrypted tokens) and - Composio-backed connectors. Shared by Gmail and Calendar tools. 
- """ - from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES - - if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: - raise ValueError("Composio connectors must use Composio tool execution.") - - from google.oauth2.credentials import Credentials - - cfg = dict(connector.config) - if cfg.get("_token_encrypted"): - enc = _get_token_encryption() - for key in ("token", "refresh_token", "client_secret"): - if cfg.get(key): - cfg[key] = enc.decrypt_token(cfg[key]) - - exp = (cfg.get("expiry") or "").replace("Z", "") - return Credentials( - token=cfg.get("token"), - refresh_token=cfg.get("refresh_token"), - token_uri=cfg.get("token_uri"), - client_id=cfg.get("client_id"), - client_secret=cfg.get("client_secret"), - scopes=cfg.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, - ) - - -def _gmail_headers(message: dict[str, Any]) -> dict[str, str]: - headers = message.get("payload", {}).get("headers", []) - return { - header.get("name", "").lower(): header.get("value", "") - for header in headers - if isinstance(header, dict) - } - - -def _format_gmail_summary(message: dict[str, Any]) -> dict[str, Any]: - headers = _gmail_headers(message) - return { - "message_id": message.get("id") or message.get("messageId"), - "thread_id": message.get("threadId"), - "subject": message.get("subject") or headers.get("subject", "No Subject"), - "from": message.get("sender") or headers.get("from", "Unknown"), - "to": message.get("to") or headers.get("to", ""), - "date": message.get("messageTimestamp") or headers.get("date", ""), - "snippet": message.get("snippet") or message.get("messageText", "")[:300], - "labels": message.get("labelIds", []), - } - - -async def _search_composio_gmail( - connector: SearchSourceConnector, - user_id: str, - query: str, - max_results: int, -) -> dict[str, Any]: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - return { - "status": "error", - "message": "Composio connected 
account ID not found.", - } - - from app.services.composio_service import ComposioService - - service = ComposioService() - messages, _next_token, _estimate, error = await service.get_gmail_messages( - connected_account_id=cca_id, - entity_id=f"surfsense_{user_id}", - query=query, - max_results=max_results, - ) - if error: - return {"status": "error", "message": error} - - emails = [_format_gmail_summary(message) for message in messages] - return { - "status": "success", - "emails": emails, - "total": len(emails), - "message": "No emails found." if not emails else None, - } - - -def create_search_gmail_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the search_gmail tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured search_gmail tool - """ - del db_session # per-call session — see docstring - - @tool - async def search_gmail( - query: str, - max_results: int = 10, - ) -> dict[str, Any]: - """Search emails in the user's Gmail inbox using Gmail search syntax. - - Args: - query: Gmail search query, same syntax as the Gmail search bar. - Examples: "from:alice@example.com", "subject:meeting", - "is:unread", "after:2024/01/01 before:2024/02/01", - "has:attachment", "in:sent". - max_results: Number of emails to return (default 10, max 20). - - Returns: - Dictionary with status and a list of email summaries including - message_id, subject, from, date, snippet. 
- """ - if search_space_id is None or user_id is None: - return {"status": "error", "message": "Gmail tool not properly configured."} - - max_results = min(max_results, 20) - - try: - async with async_session_maker() as db_session: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_GMAIL_TYPES), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", - } - - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ): - return await _search_composio_gmail( - connector, str(user_id), query, max_results - ) - - creds = _build_credentials(connector) - - from app.connectors.google_gmail_connector import GoogleGmailConnector - - gmail = GoogleGmailConnector( - credentials=creds, - session=db_session, - user_id=user_id, - connector_id=connector.id, - ) - - messages_list, error = await gmail.get_messages_list( - max_results=max_results, query=query - ) - if error: - if ( - "re-authenticate" in error.lower() - or "authentication failed" in error.lower() - ): - return { - "status": "auth_error", - "message": error, - "connector_type": "gmail", - } - return {"status": "error", "message": error} - - if not messages_list: - return { - "status": "success", - "emails": [], - "total": 0, - "message": "No emails found.", - } - - emails = [] - for msg in messages_list: - detail, err = await gmail.get_message_details(msg["id"]) - if err: - continue - headers = { - h["name"].lower(): h["value"] - for h in detail.get("payload", {}).get("headers", []) - } - emails.append( - { - "message_id": detail.get("id"), - "thread_id": detail.get("threadId"), - "subject": headers.get("subject", "No Subject"), - "from": headers.get("from", "Unknown"), - "to": 
headers.get("to", ""), - "date": headers.get("date", ""), - "snippet": detail.get("snippet", ""), - "labels": detail.get("labelIds", []), - } - ) - - return {"status": "success", "emails": emails, "total": len(emails)} - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error("Error searching Gmail: %s", e, exc_info=True) - return { - "status": "error", - "message": "Failed to search Gmail. Please try again.", - } - - return search_gmail diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py deleted file mode 100644 index 4d5aa3bcc..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py +++ /dev/null @@ -1,363 +0,0 @@ -import asyncio -import base64 -import logging -from datetime import datetime -from email.mime.text import MIMEText -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.db import async_session_maker -from app.services.gmail import GmailToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_send_gmail_email_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the send_gmail_email tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. 
- - Returns: - Configured send_gmail_email tool - """ - del db_session # per-call session — see docstring - - @tool - async def send_gmail_email( - to: str, - subject: str, - body: str, - cc: str | None = None, - bcc: str | None = None, - ) -> dict[str, Any]: - """Send an email via Gmail. - - Use when the user explicitly asks to send an email. This sends the - email immediately - it cannot be unsent. - - Args: - to: Recipient email address. - subject: Email subject line. - body: Email body content. - cc: Optional CC recipient(s), comma-separated. - bcc: Optional BCC recipient(s), comma-separated. - - Returns: - Dictionary with: - - status: "success", "rejected", or "error" - - message_id: Gmail message ID (if success) - - thread_id: Gmail thread ID (if success) - - message: Result message - - IMPORTANT: - - If status is "rejected", the user explicitly declined the action. - Respond with a brief acknowledgment and do NOT retry or suggest alternatives. - - If status is "insufficient_permissions", the connector lacks the required OAuth scope. - Inform the user they need to re-authenticate and do NOT retry the action. - - Examples: - - "Send an email to alice@example.com about the meeting" - - "Email Bob the project update" - """ - logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'") - - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Gmail tool not properly configured. 
Please contact support.", - } - - try: - async with async_session_maker() as db_session: - metadata_service = GmailToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error( - f"Failed to fetch creation context: {context['error']}" - ) - return {"status": "error", "message": context["error"]} - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - logger.warning("All Gmail accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "gmail", - } - - logger.info( - f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'" - ) - result = request_approval( - action_type="gmail_email_send", - tool_name="send_gmail_email", - params={ - "to": to, - "subject": subject, - "body": body, - "cc": cc, - "bcc": bcc, - "connector_id": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The email was not sent. 
Do not ask again or suggest alternatives.", - } - - final_to = result.params.get("to", to) - final_subject = result.params.get("subject", subject) - final_body = result.params.get("body", body) - final_cc = result.params.get("cc", cc) - final_bcc = result.params.get("bcc", bcc) - final_connector_id = result.params.get("connector_id") - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _gmail_types = [ - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, - ] - - if final_connector_id is not None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Gmail connector is invalid or has been disconnected.", - } - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Gmail connector found. 
Please connect Gmail in your workspace settings.", - } - actual_connector_id = connector.id - - logger.info( - f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}" - ) - - is_composio_gmail = ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ) - if is_composio_gmail: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - return { - "status": "error", - "message": "Composio connected account ID not found for this Gmail connector.", - } - else: - from google.oauth2.credentials import Credentials - - from app.config import config - from app.utils.oauth_security import TokenEncryption - - config_data = dict(connector.config) - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and config.SECRET_KEY: - token_encryption = TokenEncryption(config.SECRET_KEY) - if config_data.get("token"): - config_data["token"] = token_encryption.decrypt_token( - config_data["token"] - ) - if config_data.get("refresh_token"): - config_data["refresh_token"] = ( - token_encryption.decrypt_token( - config_data["refresh_token"] - ) - ) - if config_data.get("client_secret"): - config_data["client_secret"] = ( - token_encryption.decrypt_token( - config_data["client_secret"] - ) - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, - ) - - message = MIMEText(final_body) - message["to"] = final_to - message["subject"] = final_subject - if final_cc: - message["cc"] = final_cc - if final_bcc: - message["bcc"] = final_bcc - raw = base64.urlsafe_b64encode(message.as_bytes()).decode() - - try: - if 
is_composio_gmail: - from app.agents.new_chat.tools.gmail.composio_helpers import ( - execute_composio_gmail_tool, - split_recipients, - ) - - sent, error = await execute_composio_gmail_tool( - connector, - user_id, - "GMAIL_SEND_EMAIL", - { - "user_id": "me", - "recipient_email": final_to, - "subject": final_subject, - "body": final_body, - "cc": split_recipients(final_cc), - "bcc": split_recipients(final_bcc), - "is_html": False, - }, - ) - if error: - raise RuntimeError(error) - if not isinstance(sent, dict): - sent = {} - else: - from googleapiclient.discovery import build - - gmail_service = build("gmail", "v1", credentials=creds) - sent = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .messages() - .send(userId="me", body={"raw": raw}) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, - ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Gmail account needs additional permissions. 
Please re-authenticate in connector settings.", - } - raise - - logger.info( - f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}" - ) - - kb_message_suffix = "" - try: - from app.services.gmail import GmailKBSyncService - - kb_service = GmailKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - message_id=sent.get("id", ""), - thread_id=sent.get("threadId", ""), - subject=final_subject, - sender="me", - date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - body_text=final_body, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." - ) - else: - kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after send failed: {kb_err}") - kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync." - - return { - "status": "success", - "message_id": sent.get("id"), - "thread_id": sent.get("threadId"), - "message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}", - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error sending Gmail email: {e}", exc_info=True) - return { - "status": "error", - "message": "Something went wrong while sending the email. 
Please try again.", - } - - return send_gmail_email diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py deleted file mode 100644 index 95f5b4e6c..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py +++ /dev/null @@ -1,344 +0,0 @@ -import asyncio -import logging -from datetime import datetime -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.db import async_session_maker -from app.services.gmail import GmailToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_trash_gmail_email_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the trash_gmail_email tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured trash_gmail_email tool - """ - del db_session # per-call session — see docstring - - @tool - async def trash_gmail_email( - email_subject_or_id: str, - delete_from_kb: bool = False, - ) -> dict[str, Any]: - """Move an email or draft to trash in Gmail. - - Use when the user asks to delete, remove, or trash an email or draft. - - Args: - email_subject_or_id: The exact subject line or message ID of the - email to trash (as it appears in the inbox). - delete_from_kb: Whether to also remove the email from the knowledge base. - Default is False. - Set to True to remove from both Gmail and knowledge base. 
- - Returns: - Dictionary with: - - status: "success", "rejected", "not_found", or "error" - - message_id: Gmail message ID (if success) - - deleted_from_kb: whether the document was removed from the knowledge base - - message: Result message - - IMPORTANT: - - If status is "rejected", the user explicitly declined. Respond with a brief - acknowledgment and do NOT retry or suggest alternatives. - - If status is "not_found", relay the exact message to the user and ask them - to verify the email subject or check if it has been indexed. - - If status is "insufficient_permissions", the connector lacks the required OAuth scope. - Inform the user they need to re-authenticate and do NOT retry this tool. - Examples: - - "Delete the email about 'Meeting Cancelled'" - - "Trash the email from Bob about the project" - """ - logger.info( - f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}" - ) - - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Gmail tool not properly configured. Please contact support.", - } - - try: - async with async_session_maker() as db_session: - metadata_service = GmailToolMetadataService(db_session) - context = await metadata_service.get_trash_context( - search_space_id, user_id, email_subject_or_id - ) - - if "error" in context: - error_msg = context["error"] - if "not found" in error_msg.lower(): - logger.warning(f"Email not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - logger.error(f"Failed to fetch trash context: {error_msg}") - return {"status": "error", "message": error_msg} - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Gmail account %s has expired authentication", - account.get("id"), - ) - return { - "status": "auth_error", - "message": "The Gmail account for this email needs re-authentication. 
Please re-authenticate in your connector settings.", - "connector_type": "gmail", - } - - email = context["email"] - message_id = email["message_id"] - document_id = email.get("document_id") - connector_id_from_context = context["account"]["id"] - - if not message_id: - return { - "status": "error", - "message": "Message ID is missing from the indexed document. Please re-index the email and try again.", - } - - logger.info( - f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})" - ) - result = request_approval( - action_type="gmail_email_trash", - tool_name="trash_gmail_email", - params={ - "message_id": message_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.", - } - - final_message_id = result.params.get("message_id", message_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get( - "delete_from_kb", delete_from_kb - ) - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this email.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _gmail_types = [ - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, - ] - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Gmail 
connector is invalid or has been disconnected.", - } - - logger.info( - f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}" - ) - - is_composio_gmail = ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ) - if is_composio_gmail: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - return { - "status": "error", - "message": "Composio connected account ID not found for this Gmail connector.", - } - else: - from google.oauth2.credentials import Credentials - - from app.config import config - from app.utils.oauth_security import TokenEncryption - - config_data = dict(connector.config) - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and config.SECRET_KEY: - token_encryption = TokenEncryption(config.SECRET_KEY) - if config_data.get("token"): - config_data["token"] = token_encryption.decrypt_token( - config_data["token"] - ) - if config_data.get("refresh_token"): - config_data["refresh_token"] = ( - token_encryption.decrypt_token( - config_data["refresh_token"] - ) - ) - if config_data.get("client_secret"): - config_data["client_secret"] = ( - token_encryption.decrypt_token( - config_data["client_secret"] - ) - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, - ) - - try: - if is_composio_gmail: - from app.agents.new_chat.tools.gmail.composio_helpers import ( - execute_composio_gmail_tool, - ) - - _trashed, error = await execute_composio_gmail_tool( - connector, - user_id, - "GMAIL_MOVE_TO_TRASH", - {"user_id": "me", "message_id": final_message_id}, - ) - if error: - 
raise RuntimeError(error) - else: - from googleapiclient.discovery import build - - gmail_service = build("gmail", "v1", credentials=creds) - await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .messages() - .trash(userId="me", id=final_message_id) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {connector.id}: {api_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - if not connector.config.get("auth_expired"): - connector.config = { - **connector.config, - "auth_expired": True, - } - flag_modified(connector, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - connector.id, - exc_info=True, - ) - return { - "status": "insufficient_permissions", - "connector_id": connector.id, - "message": "This Gmail account needs additional permissions. 
Please re-authenticate in connector settings.", - } - raise - - logger.info(f"Gmail email trashed: message_id={final_message_id}") - - trash_result: dict[str, Any] = { - "status": "success", - "message_id": final_message_id, - "message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.", - } - - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - from app.db import Document - - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - trash_result["warning"] = ( - f"Email trashed, but failed to remove from knowledge base: {e!s}" - ) - - trash_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - trash_result["message"] = ( - f"{trash_result.get('message', '')} (also removed from knowledge base)" - ) - - return trash_result - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error trashing Gmail email: {e}", exc_info=True) - return { - "status": "error", - "message": "Something went wrong while trashing the email. 
Please try again.", - } - - return trash_gmail_email diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py deleted file mode 100644 index 129b7defb..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py +++ /dev/null @@ -1,495 +0,0 @@ -import asyncio -import base64 -import logging -from datetime import datetime -from email.mime.text import MIMEText -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.db import async_session_maker -from app.services.gmail import GmailToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_update_gmail_draft_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the update_gmail_draft tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured update_gmail_draft tool - """ - del db_session # per-call session — see docstring - - @tool - async def update_gmail_draft( - draft_subject_or_id: str, - body: str, - to: str | None = None, - subject: str | None = None, - cc: str | None = None, - bcc: str | None = None, - ) -> dict[str, Any]: - """Update an existing Gmail draft. - - Use when the user asks to modify, edit, or add content to an existing - email draft. This replaces the draft content with the new version. 
- The user will be able to review and edit the content before it is applied. - - If the user simply wants to "edit" a draft without specifying exact changes, - generate the body yourself using your best understanding of the conversation - context. The user will review and can freely edit the content in the approval - card before confirming. - - IMPORTANT: This tool is ONLY for modifying Gmail draft content, NOT for - deleting/trashing drafts (use trash_gmail_email instead), Notion pages, - calendar events, or any other content type. - - Args: - draft_subject_or_id: The exact subject line of the draft to update - (as it appears in Gmail drafts). - body: The full updated body content for the draft. Generate this - yourself based on the user's request and conversation context. - to: Optional new recipient email address (keeps original if omitted). - subject: Optional new subject line (keeps original if omitted). - cc: Optional CC recipient(s), comma-separated. - bcc: Optional BCC recipient(s), comma-separated. - - Returns: - Dictionary with: - - status: "success", "rejected", "not_found", or "error" - - draft_id: Gmail draft ID (if success) - - message: Result message - - IMPORTANT: - - If status is "rejected", the user explicitly declined the action. - Respond with a brief acknowledgment and do NOT retry or suggest alternatives. - - If status is "not_found", relay the exact message to the user and ask them - to verify the draft subject or check if it has been indexed. - - If status is "insufficient_permissions", the connector lacks the required OAuth scope. - Inform the user they need to re-authenticate and do NOT retry the action. 
- - Examples: - - "Update the Kurseong Plan draft with the new itinerary details" - - "Edit my draft about the project proposal and change the recipient" - - "Let me edit the meeting notes draft" (call with current body content so user can edit in the approval card) - """ - logger.info( - f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'" - ) - - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Gmail tool not properly configured. Please contact support.", - } - - try: - async with async_session_maker() as db_session: - metadata_service = GmailToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, draft_subject_or_id - ) - - if "error" in context: - error_msg = context["error"] - if "not found" in error_msg.lower(): - logger.warning(f"Draft not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - logger.error(f"Failed to fetch update context: {error_msg}") - return {"status": "error", "message": error_msg} - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Gmail account %s has expired authentication", - account.get("id"), - ) - return { - "status": "auth_error", - "message": "The Gmail account for this draft needs re-authentication. 
Please re-authenticate in your connector settings.", - "connector_type": "gmail", - } - - email = context["email"] - message_id = email["message_id"] - document_id = email.get("document_id") - connector_id_from_context = account["id"] - draft_id_from_context = context.get("draft_id") - - original_subject = email.get("subject", draft_subject_or_id) - final_subject_default = subject if subject else original_subject - final_to_default = to if to else "" - - logger.info( - f"Requesting approval for updating Gmail draft: '{original_subject}' " - f"(message_id={message_id}, draft_id={draft_id_from_context})" - ) - result = request_approval( - action_type="gmail_draft_update", - tool_name="update_gmail_draft", - params={ - "message_id": message_id, - "draft_id": draft_id_from_context, - "to": final_to_default, - "subject": final_subject_default, - "body": body, - "cc": cc, - "bcc": bcc, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The draft was not updated. 
Do not ask again or suggest alternatives.", - } - - final_to = result.params.get("to", final_to_default) - final_subject = result.params.get("subject", final_subject_default) - final_body = result.params.get("body", body) - final_cc = result.params.get("cc", cc) - final_bcc = result.params.get("bcc", bcc) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_draft_id = result.params.get("draft_id", draft_id_from_context) - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this draft.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _gmail_types = [ - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, - ] - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Gmail connector is invalid or has been disconnected.", - } - - logger.info( - f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}" - ) - - is_composio_gmail = ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ) - if is_composio_gmail: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - return { - "status": "error", - "message": "Composio connected account ID not found for this Gmail connector.", - } - else: - from google.oauth2.credentials import Credentials - - from app.config import config - from app.utils.oauth_security import TokenEncryption - - config_data = dict(connector.config) - token_encrypted = config_data.get("_token_encrypted", False) 
- if token_encrypted and config.SECRET_KEY: - token_encryption = TokenEncryption(config.SECRET_KEY) - if config_data.get("token"): - config_data["token"] = token_encryption.decrypt_token( - config_data["token"] - ) - if config_data.get("refresh_token"): - config_data["refresh_token"] = ( - token_encryption.decrypt_token( - config_data["refresh_token"] - ) - ) - if config_data.get("client_secret"): - config_data["client_secret"] = ( - token_encryption.decrypt_token( - config_data["client_secret"] - ) - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, - ) - - # Resolve draft_id if not already available - if not final_draft_id: - logger.info( - f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}" - ) - if is_composio_gmail: - final_draft_id = await _find_composio_draft_id_by_message( - connector, user_id, message_id - ) - else: - from googleapiclient.discovery import build - - gmail_service = build("gmail", "v1", credentials=creds) - final_draft_id = await _find_draft_id_by_message( - gmail_service, message_id - ) - - if not final_draft_id: - return { - "status": "error", - "message": ( - "Could not find this draft in Gmail. " - "It may have already been sent or deleted." 
- ), - } - - message = MIMEText(final_body) - if final_to: - message["to"] = final_to - message["subject"] = final_subject - if final_cc: - message["cc"] = final_cc - if final_bcc: - message["bcc"] = final_bcc - raw = base64.urlsafe_b64encode(message.as_bytes()).decode() - - try: - if is_composio_gmail: - from app.agents.new_chat.tools.gmail.composio_helpers import ( - execute_composio_gmail_tool, - split_recipients, - ) - - updated, error = await execute_composio_gmail_tool( - connector, - user_id, - "GMAIL_UPDATE_DRAFT", - { - "user_id": "me", - "draft_id": final_draft_id, - "recipient_email": final_to, - "subject": final_subject, - "body": final_body, - "cc": split_recipients(final_cc), - "bcc": split_recipients(final_bcc), - "is_html": False, - }, - ) - if error: - raise RuntimeError(error) - if not isinstance(updated, dict): - updated = {} - else: - from googleapiclient.discovery import build - - gmail_service = build("gmail", "v1", credentials=creds) - updated = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .drafts() - .update( - userId="me", - id=final_draft_id, - body={"message": {"raw": raw}}, - ) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {connector.id}: {api_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - if not connector.config.get("auth_expired"): - connector.config = { - **connector.config, - "auth_expired": True, - } - flag_modified(connector, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - connector.id, - exc_info=True, - ) - return { - "status": "insufficient_permissions", - "connector_id": connector.id, - "message": "This Gmail account needs additional permissions. 
Please re-authenticate in connector settings.", - } - if isinstance(api_err, HttpError) and api_err.resp.status == 404: - return { - "status": "error", - "message": "Draft no longer exists in Gmail. It may have been sent or deleted.", - } - raise - - logger.info(f"Gmail draft updated: id={updated.get('id')}") - - kb_message_suffix = "" - if document_id: - try: - from sqlalchemy.future import select as sa_select - from sqlalchemy.orm.attributes import flag_modified - - from app.db import Document - - doc_result = await db_session.execute( - sa_select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - document.source_markdown = final_body - document.title = final_subject - meta = dict(document.document_metadata or {}) - meta["subject"] = final_subject - meta["draft_id"] = updated.get("id", final_draft_id) - updated_msg = updated.get("message", {}) - if updated_msg.get("id"): - meta["message_id"] = updated_msg["id"] - document.document_metadata = meta - flag_modified(document, "document_metadata") - await db_session.commit() - kb_message_suffix = ( - " Your knowledge base has also been updated." - ) - logger.info( - f"KB document {document_id} updated for draft {final_draft_id}" - ) - else: - kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB update after draft edit failed: {kb_err}") - await db_session.rollback() - kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync." 
- - return { - "status": "success", - "draft_id": updated.get("id"), - "message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}", - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error updating Gmail draft: {e}", exc_info=True) - return { - "status": "error", - "message": "Something went wrong while updating the draft. Please try again.", - } - - return update_gmail_draft - - -async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str | None: - """Look up a draft's ID by its message ID via the Gmail API.""" - try: - page_token = None - while True: - kwargs: dict[str, Any] = {"userId": "me", "maxResults": 100} - if page_token: - kwargs["pageToken"] = page_token - - response = await asyncio.get_event_loop().run_in_executor( - None, - lambda kwargs=kwargs: ( - gmail_service.users().drafts().list(**kwargs).execute() - ), - ) - - for draft in response.get("drafts", []): - if draft.get("message", {}).get("id") == message_id: - return draft["id"] - - page_token = response.get("nextPageToken") - if not page_token: - break - - return None - except Exception as e: - logger.warning(f"Failed to look up draft by message_id: {e}") - return None - - -async def _find_composio_draft_id_by_message( - connector: Any, user_id: str, message_id: str -) -> str | None: - from app.agents.new_chat.tools.gmail.composio_helpers import ( - execute_composio_gmail_tool, - ) - - page_token = "" - while True: - params: dict[str, Any] = { - "user_id": "me", - "max_results": 100, - "verbose": False, - } - if page_token: - params["page_token"] = page_token - - data, error = await execute_composio_gmail_tool( - connector, user_id, "GMAIL_LIST_DRAFTS", params - ) - if error or not isinstance(data, dict): - return None - - for draft in data.get("drafts", []): - if draft.get("message", {}).get("id") == message_id: - return draft.get("id") - - page_token 
= data.get("nextPageToken") or data.get("next_page_token") or "" - if not page_token: - return None diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/__init__.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/__init__.py deleted file mode 100644 index 13d4c06cb..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from app.agents.new_chat.tools.google_calendar.create_event import ( - create_create_calendar_event_tool, -) -from app.agents.new_chat.tools.google_calendar.delete_event import ( - create_delete_calendar_event_tool, -) -from app.agents.new_chat.tools.google_calendar.search_events import ( - create_search_calendar_events_tool, -) -from app.agents.new_chat.tools.google_calendar.update_event import ( - create_update_calendar_event_tool, -) - -__all__ = [ - "create_create_calendar_event_tool", - "create_delete_calendar_event_tool", - "create_search_calendar_events_tool", - "create_update_calendar_event_tool", -] diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py deleted file mode 100644 index dec92cc8b..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py +++ /dev/null @@ -1,382 +0,0 @@ -import asyncio -import logging -from datetime import datetime -from typing import Any - -from google.oauth2.credentials import Credentials -from googleapiclient.discovery import build -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.db import async_session_maker -from app.services.google_calendar import GoogleCalendarToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_create_calendar_event_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: 
str | None = None, -): - """ - Factory function to create the create_calendar_event tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured create_calendar_event tool - """ - del db_session # per-call session — see docstring - - @tool - async def create_calendar_event( - summary: str, - start_datetime: str, - end_datetime: str, - description: str | None = None, - location: str | None = None, - attendees: list[str] | None = None, - ) -> dict[str, Any]: - """Create a new event on Google Calendar. - - Use when the user asks to schedule, create, or add a calendar event. - Ask for event details if not provided. - - Args: - summary: The event title. - start_datetime: Start time in ISO 8601 format (e.g. "2026-03-20T10:00:00"). - end_datetime: End time in ISO 8601 format (e.g. "2026-03-20T11:00:00"). - description: Optional event description. - location: Optional event location. - attendees: Optional list of attendee email addresses. - - Returns: - Dictionary with: - - status: "success", "rejected", "auth_error", or "error" - - event_id: Google Calendar event ID (if success) - - html_link: URL to open the event (if success) - - message: Result message - - IMPORTANT: - - If status is "rejected", the user explicitly declined the action. - Respond with a brief acknowledgment and do NOT retry or suggest alternatives. 
- - Examples: - - "Schedule a meeting with John tomorrow at 10am" - - "Create a calendar event for the team standup" - """ - logger.info( - f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'" - ) - - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Google Calendar tool not properly configured. Please contact support.", - } - - try: - async with async_session_maker() as db_session: - metadata_service = GoogleCalendarToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error( - f"Failed to fetch creation context: {context['error']}" - ) - return {"status": "error", "message": context["error"]} - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - logger.warning( - "All Google Calendar accounts have expired authentication" - ) - return { - "status": "auth_error", - "message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "google_calendar", - } - - logger.info( - f"Requesting approval for creating calendar event: summary='{summary}'" - ) - result = request_approval( - action_type="google_calendar_event_creation", - tool_name="create_calendar_event", - params={ - "summary": summary, - "start_datetime": start_datetime, - "end_datetime": end_datetime, - "description": description, - "location": location, - "attendees": attendees, - "timezone": context.get("timezone"), - "connector_id": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The event was not created. 
Do not ask again or suggest alternatives.", - } - - final_summary = result.params.get("summary", summary) - final_start_datetime = result.params.get( - "start_datetime", start_datetime - ) - final_end_datetime = result.params.get("end_datetime", end_datetime) - final_description = result.params.get("description", description) - final_location = result.params.get("location", location) - final_attendees = result.params.get("attendees", attendees) - final_connector_id = result.params.get("connector_id") - - if not final_summary or not final_summary.strip(): - return { - "status": "error", - "message": "Event summary cannot be empty.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _calendar_types = [ - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, - ] - - if final_connector_id is not None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_calendar_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Google Calendar connector is invalid or has been disconnected.", - } - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_calendar_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Google Calendar connector found. 
Please connect Google Calendar in your workspace settings.", - } - actual_connector_id = connector.id - - logger.info( - f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}" - ) - - is_composio_calendar = ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR - ) - if is_composio_calendar: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - return { - "status": "error", - "message": "Composio connected account ID not found for this connector.", - } - else: - config_data = dict(connector.config) - - from app.config import config as app_config - from app.utils.oauth_security import TokenEncryption - - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and app_config.SECRET_KEY: - token_encryption = TokenEncryption(app_config.SECRET_KEY) - for key in ("token", "refresh_token", "client_secret"): - if config_data.get(key): - config_data[key] = token_encryption.decrypt_token( - config_data[key] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, - ) - - tz = context.get("timezone", "UTC") - event_body: dict[str, Any] = { - "summary": final_summary, - "start": {"dateTime": final_start_datetime, "timeZone": tz}, - "end": {"dateTime": final_end_datetime, "timeZone": tz}, - } - if final_description: - event_body["description"] = final_description - if final_location: - event_body["location"] = final_location - if final_attendees: - event_body["attendees"] = [ - {"email": e.strip()} for e in final_attendees if e.strip() - ] - - try: - if is_composio_calendar: - from 
app.services.composio_service import ComposioService - - composio_params = { - "calendar_id": "primary", - "summary": final_summary, - "start_datetime": final_start_datetime, - "end_datetime": final_end_datetime, - "timezone": tz, - "attendees": final_attendees or [], - } - if final_description: - composio_params["description"] = final_description - if final_location: - composio_params["location"] = final_location - - composio_result = await ComposioService().execute_tool( - connected_account_id=cca_id, - tool_name="GOOGLECALENDAR_CREATE_EVENT", - params=composio_params, - entity_id=f"surfsense_{user_id}", - ) - if not composio_result.get("success"): - raise RuntimeError( - composio_result.get( - "error", "Unknown Composio Calendar error" - ) - ) - created = composio_result.get("data", {}) - if isinstance(created, dict): - created = created.get("data", created) - if isinstance(created, dict): - created = created.get("response_data", created) - else: - service = await asyncio.get_event_loop().run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) - created = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - service.events() - .insert(calendarId="primary", body=event_body) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - 
actual_connector_id, - exc_info=True, - ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", - } - raise - - logger.info( - f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}" - ) - - kb_message_suffix = "" - try: - from app.services.google_calendar import GoogleCalendarKBSyncService - - kb_service = GoogleCalendarKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - event_id=created.get("id"), - event_summary=final_summary, - calendar_id="primary", - start_time=final_start_datetime, - end_time=final_end_datetime, - location=final_location, - html_link=created.get("htmlLink"), - description=final_description, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." - ) - else: - kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync." - - return { - "status": "success", - "event_id": created.get("id"), - "html_link": created.get("htmlLink"), - "message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}", - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error creating calendar event: {e}", exc_info=True) - return { - "status": "error", - "message": "Something went wrong while creating the event. 
Please try again.", - } - - return create_calendar_event diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py deleted file mode 100644 index e7e891b08..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py +++ /dev/null @@ -1,340 +0,0 @@ -import asyncio -import logging -from datetime import datetime -from typing import Any - -from google.oauth2.credentials import Credentials -from googleapiclient.discovery import build -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.db import async_session_maker -from app.services.google_calendar import GoogleCalendarToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_delete_calendar_event_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the delete_calendar_event tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured delete_calendar_event tool - """ - del db_session # per-call session — see docstring - - @tool - async def delete_calendar_event( - event_title_or_id: str, - delete_from_kb: bool = False, - ) -> dict[str, Any]: - """Delete a Google Calendar event. - - Use when the user asks to delete, remove, or cancel a calendar event. - - Args: - event_title_or_id: The exact title or event ID of the event to delete. 
- delete_from_kb: Whether to also remove the event from the knowledge base. - Default is False. - Set to True to remove from both Google Calendar and knowledge base. - - Returns: - Dictionary with: - - status: "success", "rejected", "not_found", "auth_error", or "error" - - event_id: Google Calendar event ID (if success) - - deleted_from_kb: whether the document was removed from the knowledge base - - message: Result message - - IMPORTANT: - - If status is "rejected", the user explicitly declined. Respond with a brief - acknowledgment and do NOT retry or suggest alternatives. - - If status is "not_found", relay the exact message to the user and ask them - to verify the event name or check if it has been indexed. - Examples: - - "Delete the team standup event" - - "Cancel my dentist appointment on Friday" - """ - logger.info( - f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}" - ) - - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Google Calendar tool not properly configured. Please contact support.", - } - - try: - async with async_session_maker() as db_session: - metadata_service = GoogleCalendarToolMetadataService(db_session) - context = await metadata_service.get_deletion_context( - search_space_id, user_id, event_title_or_id - ) - - if "error" in context: - error_msg = context["error"] - if "not found" in error_msg.lower(): - logger.warning(f"Event not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - logger.error(f"Failed to fetch deletion context: {error_msg}") - return {"status": "error", "message": error_msg} - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Google Calendar account %s has expired authentication", - account.get("id"), - ) - return { - "status": "auth_error", - "message": "The Google Calendar account for this event needs re-authentication. 
Please re-authenticate in your connector settings.", - "connector_type": "google_calendar", - } - - event = context["event"] - event_id = event["event_id"] - document_id = event.get("document_id") - connector_id_from_context = context["account"]["id"] - - if not event_id: - return { - "status": "error", - "message": "Event ID is missing from the indexed document. Please re-index the event and try again.", - } - - logger.info( - f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})" - ) - result = request_approval( - action_type="google_calendar_event_deletion", - tool_name="delete_calendar_event", - params={ - "event_id": event_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.", - } - - final_event_id = result.params.get("event_id", event_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get( - "delete_from_kb", delete_from_kb - ) - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this event.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _calendar_types = [ - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, - ] - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_calendar_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - 
"message": "Selected Google Calendar connector is invalid or has been disconnected.", - } - - actual_connector_id = connector.id - - logger.info( - f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}" - ) - - is_composio_calendar = ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR - ) - if is_composio_calendar: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - return { - "status": "error", - "message": "Composio connected account ID not found for this connector.", - } - else: - config_data = dict(connector.config) - - from app.config import config as app_config - from app.utils.oauth_security import TokenEncryption - - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and app_config.SECRET_KEY: - token_encryption = TokenEncryption(app_config.SECRET_KEY) - for key in ("token", "refresh_token", "client_secret"): - if config_data.get(key): - config_data[key] = token_encryption.decrypt_token( - config_data[key] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, - ) - - try: - if is_composio_calendar: - from app.services.composio_service import ComposioService - - composio_result = await ComposioService().execute_tool( - connected_account_id=cca_id, - tool_name="GOOGLECALENDAR_DELETE_EVENT", - params={ - "calendar_id": "primary", - "event_id": final_event_id, - }, - entity_id=f"surfsense_{user_id}", - ) - if not composio_result.get("success"): - raise RuntimeError( - composio_result.get( - "error", "Unknown Composio Calendar error" - ) - ) - else: - service = 
await asyncio.get_event_loop().run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) - await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - service.events() - .delete(calendarId="primary", eventId=final_event_id) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, - ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Google Calendar account needs additional permissions. 
Please re-authenticate in connector settings.", - } - raise - - logger.info(f"Calendar event deleted: event_id={final_event_id}") - - delete_result: dict[str, Any] = { - "status": "success", - "event_id": final_event_id, - "message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.", - } - - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - from app.db import Document - - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - delete_result["warning"] = ( - f"Event deleted, but failed to remove from knowledge base: {e!s}" - ) - - delete_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - delete_result["message"] = ( - f"{delete_result.get('message', '')} (also removed from knowledge base)" - ) - - return delete_result - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error deleting calendar event: {e}", exc_info=True) - return { - "status": "error", - "message": "Something went wrong while deleting the event. 
Please try again.", - } - - return delete_calendar_event diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py deleted file mode 100644 index e5f18f675..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py +++ /dev/null @@ -1,187 +0,0 @@ -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select - -from app.agents.new_chat.tools.gmail.search_emails import _build_credentials -from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker - -logger = logging.getLogger(__name__) - -_CALENDAR_TYPES = [ - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, -] - - -def _to_calendar_boundary(value: str, *, is_end: bool) -> str: - if "T" in value: - return value - time = "23:59:59" if is_end else "00:00:00" - return f"{value}T{time}Z" - - -def _format_calendar_events(events_raw: list[dict[str, Any]]) -> list[dict[str, Any]]: - events = [] - for ev in events_raw: - start = ev.get("start", {}) - end = ev.get("end", {}) - attendees_raw = ev.get("attendees", []) - events.append( - { - "event_id": ev.get("id"), - "summary": ev.get("summary", "No Title"), - "start": start.get("dateTime") or start.get("date", ""), - "end": end.get("dateTime") or end.get("date", ""), - "location": ev.get("location", ""), - "description": ev.get("description", ""), - "html_link": ev.get("htmlLink", ""), - "attendees": [a.get("email", "") for a in attendees_raw[:10]], - "status": ev.get("status", ""), - } - ) - return events - - -def create_search_calendar_events_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the search_calendar_events tool. 
- - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured search_calendar_events tool - """ - del db_session # per-call session — see docstring - - @tool - async def search_calendar_events( - start_date: str, - end_date: str, - max_results: int = 25, - ) -> dict[str, Any]: - """Search Google Calendar events within a date range. - - Args: - start_date: Start date in YYYY-MM-DD format (e.g. "2026-04-01"). - end_date: End date in YYYY-MM-DD format (e.g. "2026-04-30"). - max_results: Maximum number of events to return (default 25, max 50). - - Returns: - Dictionary with status and a list of events including - event_id, summary, start, end, location, attendees. - """ - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Calendar tool not properly configured.", - } - - max_results = min(max_results, 50) - - try: - async with async_session_maker() as db_session: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Google Calendar connector found. 
Please connect Google Calendar in your workspace settings.", - } - - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR - ): - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - return { - "status": "error", - "message": "Composio connected account ID not found for this connector.", - } - - from app.services.composio_service import ComposioService - - events_raw, error = await ComposioService().get_calendar_events( - connected_account_id=cca_id, - entity_id=f"surfsense_{user_id}", - time_min=_to_calendar_boundary(start_date, is_end=False), - time_max=_to_calendar_boundary(end_date, is_end=True), - max_results=max_results, - ) - if not events_raw and not error: - error = "No events found in the specified date range." - else: - creds = _build_credentials(connector) - - from app.connectors.google_calendar_connector import ( - GoogleCalendarConnector, - ) - - cal = GoogleCalendarConnector( - credentials=creds, - session=db_session, - user_id=user_id, - connector_id=connector.id, - ) - - events_raw, error = await cal.get_all_primary_calendar_events( - start_date=start_date, - end_date=end_date, - max_results=max_results, - ) - - if error: - if ( - "re-authenticate" in error.lower() - or "authentication failed" in error.lower() - ): - return { - "status": "auth_error", - "message": error, - "connector_type": "google_calendar", - } - if "no events found" in error.lower(): - return { - "status": "success", - "events": [], - "total": 0, - "message": error, - } - return {"status": "error", "message": error} - - events = _format_calendar_events(events_raw) - - return {"status": "success", "events": events, "total": len(events)} - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error("Error searching calendar events: %s", e, exc_info=True) - return { - "status": "error", - "message": "Failed to search calendar events. 
Please try again.", - } - - return search_calendar_events diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py deleted file mode 100644 index b8561fee6..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py +++ /dev/null @@ -1,419 +0,0 @@ -import asyncio -import logging -from datetime import datetime -from typing import Any - -from google.oauth2.credentials import Credentials -from googleapiclient.discovery import build -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.db import async_session_maker -from app.services.google_calendar import GoogleCalendarToolMetadataService - -logger = logging.getLogger(__name__) - - -def _is_date_only(value: str) -> bool: - """Return True when *value* looks like a bare date (YYYY-MM-DD) with no time component.""" - return len(value) <= 10 and "T" not in value - - -def _build_time_body(value: str, context: dict[str, Any] | Any) -> dict[str, str]: - """Build a Google Calendar start/end body using ``date`` for all-day - events and ``dateTime`` for timed events.""" - if _is_date_only(value): - return {"date": value} - tz = context.get("timezone", "UTC") if isinstance(context, dict) else "UTC" - return {"dateTime": value, "timeZone": tz} - - -def create_update_calendar_event_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the update_calendar_event tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. 
- - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured update_calendar_event tool - """ - del db_session # per-call session — see docstring - - @tool - async def update_calendar_event( - event_title_or_id: str, - new_summary: str | None = None, - new_start_datetime: str | None = None, - new_end_datetime: str | None = None, - new_description: str | None = None, - new_location: str | None = None, - new_attendees: list[str] | None = None, - ) -> dict[str, Any]: - """Update an existing Google Calendar event. - - Use when the user asks to modify, reschedule, or change a calendar event. - - Args: - event_title_or_id: The exact title or event ID of the event to update. - new_summary: New event title (if changing). - new_start_datetime: New start time in ISO 8601 format (if rescheduling). - new_end_datetime: New end time in ISO 8601 format (if rescheduling). - new_description: New event description (if changing). - new_location: New event location (if changing). - new_attendees: New list of attendee email addresses (if changing). - - Returns: - Dictionary with: - - status: "success", "rejected", "not_found", "auth_error", or "error" - - event_id: Google Calendar event ID (if success) - - html_link: URL to open the event (if success) - - message: Result message - - IMPORTANT: - - If status is "rejected", the user explicitly declined. Respond with a brief - acknowledgment and do NOT retry or suggest alternatives. - - If status is "not_found", relay the exact message to the user and ask them - to verify the event name or check if it has been indexed. 
- Examples: - - "Reschedule the team standup to 3pm" - - "Change the location of my dentist appointment" - """ - logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'") - - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Google Calendar tool not properly configured. Please contact support.", - } - - try: - async with async_session_maker() as db_session: - metadata_service = GoogleCalendarToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, event_title_or_id - ) - - if "error" in context: - error_msg = context["error"] - if "not found" in error_msg.lower(): - logger.warning(f"Event not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - logger.error(f"Failed to fetch update context: {error_msg}") - return {"status": "error", "message": error_msg} - - if context.get("auth_expired"): - logger.warning("Google Calendar account has expired authentication") - return { - "status": "auth_error", - "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "google_calendar", - } - - event = context["event"] - event_id = event["event_id"] - document_id = event.get("document_id") - connector_id_from_context = context["account"]["id"] - - if not event_id: - return { - "status": "error", - "message": "Event ID is missing from the indexed document. 
Please re-index the event and try again.", - } - - logger.info( - f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})" - ) - result = request_approval( - action_type="google_calendar_event_update", - tool_name="update_calendar_event", - params={ - "event_id": event_id, - "document_id": document_id, - "connector_id": connector_id_from_context, - "new_summary": new_summary, - "new_start_datetime": new_start_datetime, - "new_end_datetime": new_end_datetime, - "new_description": new_description, - "new_location": new_location, - "new_attendees": new_attendees, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The event was not updated. Do not ask again or suggest alternatives.", - } - - final_event_id = result.params.get("event_id", event_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_new_summary = result.params.get("new_summary", new_summary) - final_new_start_datetime = result.params.get( - "new_start_datetime", new_start_datetime - ) - final_new_end_datetime = result.params.get( - "new_end_datetime", new_end_datetime - ) - final_new_description = result.params.get( - "new_description", new_description - ) - final_new_location = result.params.get("new_location", new_location) - final_new_attendees = result.params.get("new_attendees", new_attendees) - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this event.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _calendar_types = [ - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, - ] - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - 
SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_calendar_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Google Calendar connector is invalid or has been disconnected.", - } - - actual_connector_id = connector.id - - logger.info( - f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}" - ) - - is_composio_calendar = ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR - ) - if is_composio_calendar: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - return { - "status": "error", - "message": "Composio connected account ID not found for this connector.", - } - else: - config_data = dict(connector.config) - - from app.config import config as app_config - from app.utils.oauth_security import TokenEncryption - - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and app_config.SECRET_KEY: - token_encryption = TokenEncryption(app_config.SECRET_KEY) - for key in ("token", "refresh_token", "client_secret"): - if config_data.get(key): - config_data[key] = token_encryption.decrypt_token( - config_data[key] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, - ) - - update_body: dict[str, Any] = {} - if final_new_summary is not None: - update_body["summary"] = final_new_summary - if final_new_start_datetime is not None: - update_body["start"] = _build_time_body( - final_new_start_datetime, context - ) - if final_new_end_datetime is not None: - 
update_body["end"] = _build_time_body( - final_new_end_datetime, context - ) - if final_new_description is not None: - update_body["description"] = final_new_description - if final_new_location is not None: - update_body["location"] = final_new_location - if final_new_attendees is not None: - update_body["attendees"] = [ - {"email": e.strip()} for e in final_new_attendees if e.strip() - ] - - if not update_body: - return { - "status": "error", - "message": "No changes specified. Please provide at least one field to update.", - } - - try: - if is_composio_calendar: - from app.services.composio_service import ComposioService - - composio_params: dict[str, Any] = { - "calendar_id": "primary", - "event_id": final_event_id, - } - if final_new_summary is not None: - composio_params["summary"] = final_new_summary - if final_new_start_datetime is not None: - composio_params["start_time"] = final_new_start_datetime - if final_new_end_datetime is not None: - composio_params["end_time"] = final_new_end_datetime - if final_new_description is not None: - composio_params["description"] = final_new_description - if final_new_location is not None: - composio_params["location"] = final_new_location - if final_new_attendees is not None: - composio_params["attendees"] = [ - e.strip() for e in final_new_attendees if e.strip() - ] - if not _is_date_only( - final_new_start_datetime or final_new_end_datetime or "" - ): - composio_params["timezone"] = context.get("timezone", "UTC") - - composio_result = await ComposioService().execute_tool( - connected_account_id=cca_id, - tool_name="GOOGLECALENDAR_PATCH_EVENT", - params=composio_params, - entity_id=f"surfsense_{user_id}", - ) - if not composio_result.get("success"): - raise RuntimeError( - composio_result.get( - "error", "Unknown Composio Calendar error" - ) - ) - updated = composio_result.get("data", {}) - if isinstance(updated, dict): - updated = updated.get("data", updated) - if isinstance(updated, dict): - updated = 
updated.get("response_data", updated) - else: - service = await asyncio.get_event_loop().run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) - updated = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - service.events() - .patch( - calendarId="primary", - eventId=final_event_id, - body=update_body, - ) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, - ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", - } - raise - - logger.info(f"Calendar event updated: event_id={final_event_id}") - - kb_message_suffix = "" - if document_id is not None: - try: - from app.services.google_calendar import ( - GoogleCalendarKBSyncService, - ) - - kb_service = GoogleCalendarKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=document_id, - event_id=final_event_id, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." 
- ) - else: - kb_message_suffix = " The knowledge base will be updated in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after update failed: {kb_err}") - kb_message_suffix = " The knowledge base will be updated in the next scheduled sync." - - return { - "status": "success", - "event_id": final_event_id, - "html_link": updated.get("htmlLink"), - "message": f"Successfully updated the calendar event.{kb_message_suffix}", - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error updating calendar event: {e}", exc_info=True) - return { - "status": "error", - "message": "Something went wrong while updating the event. Please try again.", - } - - return update_calendar_event diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/__init__.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/__init__.py deleted file mode 100644 index 9c63bceb1..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/google_drive/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from app.agents.new_chat.tools.google_drive.create_file import ( - create_create_google_drive_file_tool, -) -from app.agents.new_chat.tools.google_drive.trash_file import ( - create_delete_google_drive_file_tool, -) - -__all__ = [ - "create_create_google_drive_file_tool", - "create_delete_google_drive_file_tool", -] diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py deleted file mode 100644 index 66199ca67..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py +++ /dev/null @@ -1,340 +0,0 @@ -import logging -from typing import Any, Literal - -from googleapiclient.errors import HttpError -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import 
request_approval -from app.connectors.google_drive.client import GoogleDriveClient -from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET -from app.db import async_session_maker -from app.services.google_drive import GoogleDriveToolMetadataService - -logger = logging.getLogger(__name__) - -_MIME_MAP: dict[str, str] = { - "google_doc": GOOGLE_DOC, - "google_sheet": GOOGLE_SHEET, -} - - -def create_create_google_drive_file_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the create_google_drive_file tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - search_space_id: Search space ID to find the Google Drive connector - user_id: User ID for fetching user-specific context - - Returns: - Configured create_google_drive_file tool - """ - del db_session # per-call session — see docstring - - @tool - async def create_google_drive_file( - name: str, - file_type: Literal["google_doc", "google_sheet"], - content: str | None = None, - ) -> dict[str, Any]: - """Create a new Google Doc or Google Sheet in Google Drive. - - Use this tool when the user explicitly asks to create a new document - or spreadsheet in Google Drive. The user MUST specify a topic before - you call this tool. If the request does not contain a topic (e.g. - "create a drive doc" or "make a Google Sheet"), ask what the file - should be about. Never call this tool without a clear topic from the user. - - Args: - name: The file name (without extension). - file_type: Either "google_doc" or "google_sheet". 
- content: Optional initial content. Generate from the user's topic. - For google_doc, provide markdown text. For google_sheet, provide CSV-formatted text. - - Returns: - Dictionary with: - - status: "success", "rejected", or "error" - - file_id: Google Drive file ID (if success) - - name: File name (if success) - - web_view_link: URL to open the file (if success) - - message: Result message - - IMPORTANT: - - If status is "rejected", the user explicitly declined the action. - Respond with a brief acknowledgment and do NOT retry or suggest alternatives. - - If status is "insufficient_permissions", the connector lacks the required OAuth scope. - Inform the user they need to re-authenticate and do NOT retry the action. - - Examples: - - "Create a Google Doc with today's meeting notes" - - "Create a spreadsheet for the 2026 budget" - """ - logger.info( - f"create_google_drive_file called: name='{name}', type='{file_type}'" - ) - - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Google Drive tool not properly configured. Please contact support.", - } - - if file_type not in _MIME_MAP: - return { - "status": "error", - "message": f"Unsupported file type '{file_type}'. Use 'google_doc' or 'google_sheet'.", - } - - try: - async with async_session_maker() as db_session: - metadata_service = GoogleDriveToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error( - f"Failed to fetch creation context: {context['error']}" - ) - return {"status": "error", "message": context["error"]} - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - logger.warning( - "All Google Drive accounts have expired authentication" - ) - return { - "status": "auth_error", - "message": "All connected Google Drive accounts need re-authentication. 
Please re-authenticate in your connector settings.", - "connector_type": "google_drive", - } - - logger.info( - f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'" - ) - result = request_approval( - action_type="google_drive_file_creation", - tool_name="create_google_drive_file", - params={ - "name": name, - "file_type": file_type, - "content": content, - "connector_id": None, - "parent_folder_id": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The file was not created. Do not ask again or suggest alternatives.", - } - - final_name = result.params.get("name", name) - final_file_type = result.params.get("file_type", file_type) - final_content = result.params.get("content", content) - final_connector_id = result.params.get("connector_id") - final_parent_folder_id = result.params.get("parent_folder_id") - - if not final_name or not final_name.strip(): - return {"status": "error", "message": "File name cannot be empty."} - - mime_type = _MIME_MAP.get(final_file_type) - if not mime_type: - return { - "status": "error", - "message": f"Unsupported file type '{final_file_type}'.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _drive_types = [ - SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, - ] - - if final_connector_id is not None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_drive_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Google Drive connector is invalid or has been disconnected.", - } - actual_connector_id = 
connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_drive_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Google Drive connector found. Please connect Google Drive in your workspace settings.", - } - actual_connector_id = connector.id - - logger.info( - f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}" - ) - - is_composio_drive = ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR - ) - if is_composio_drive: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - return { - "status": "error", - "message": "Composio connected account ID not found for this Drive connector.", - } - client = GoogleDriveClient( - session=db_session, - connector_id=actual_connector_id, - ) - try: - if is_composio_drive: - from app.services.composio_service import ComposioService - - params: dict[str, Any] = { - "name": final_name, - "mimeType": mime_type, - "fields": "id,name,webViewLink,mimeType", - } - if final_parent_folder_id: - params["parents"] = [final_parent_folder_id] - if final_content: - params["description"] = final_content[:4096] - - result = await ComposioService().execute_tool( - connected_account_id=cca_id, - tool_name="GOOGLEDRIVE_CREATE_FILE", - params=params, - entity_id=f"surfsense_{user_id}", - ) - if not result.get("success"): - raise RuntimeError( - result.get("error", "Unknown Composio Drive error") - ) - created = result.get("data", {}) - if isinstance(created, dict): - created = created.get("data", created) - if isinstance(created, dict): - created = created.get("response_data", created) - if not isinstance(created, dict): - created = {} - else: - created = await 
client.create_file( - name=final_name, - mime_type=mime_type, - parent_folder_id=final_parent_folder_id, - content=final_content, - ) - except HttpError as http_err: - if http_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {http_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, - ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", - } - raise - - logger.info( - f"Google Drive file created: id={created.get('id')}, name={created.get('name')}" - ) - - kb_message_suffix = "" - try: - from app.services.google_drive import GoogleDriveKBSyncService - - kb_service = GoogleDriveKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - file_id=created.get("id"), - file_name=created.get("name", final_name), - mime_type=mime_type, - web_view_link=created.get("webViewLink"), - content=final_content, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." - ) - else: - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." 
- except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - - return { - "status": "success", - "file_id": created.get("id"), - "name": created.get("name"), - "web_view_link": created.get("webViewLink"), - "message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}", - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error creating Google Drive file: {e}", exc_info=True) - return { - "status": "error", - "message": "Something went wrong while creating the file. Please try again.", - } - - return create_google_drive_file diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py deleted file mode 100644 index b3c9240d8..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py +++ /dev/null @@ -1,299 +0,0 @@ -import logging -from typing import Any - -from googleapiclient.errors import HttpError -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.connectors.google_drive.client import GoogleDriveClient -from app.db import async_session_maker -from app.services.google_drive import GoogleDriveToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_delete_google_drive_file_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the delete_google_drive_file tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. 
Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - search_space_id: Search space ID to find the Google Drive connector - user_id: User ID for fetching user-specific context - - Returns: - Configured delete_google_drive_file tool - """ - del db_session # per-call session — see docstring - - @tool - async def delete_google_drive_file( - file_name: str, - delete_from_kb: bool = False, - ) -> dict[str, Any]: - """Move a Google Drive file to trash. - - Use this tool when the user explicitly asks to delete, remove, or trash - a file in Google Drive. - - Args: - file_name: The exact name of the file to trash (as it appears in Drive). - delete_from_kb: Whether to also remove the file from the knowledge base. - Default is False. - Set to True to remove from both Google Drive and knowledge base. - - Returns: - Dictionary with: - - status: "success", "rejected", "not_found", or "error" - - file_id: Google Drive file ID (if success) - - deleted_from_kb: whether the document was removed from the knowledge base - - message: Result message - - IMPORTANT: - - If status is "rejected", the user explicitly declined. Respond with a brief - acknowledgment and do NOT retry or suggest alternatives. - - If status is "not_found", relay the exact message to the user and ask them - to verify the file name or check if it has been indexed. - - If status is "insufficient_permissions", the connector lacks the required OAuth scope. - Inform the user they need to re-authenticate and do NOT retry this tool. 
- Examples: - - "Delete the 'Meeting Notes' file from Google Drive" - - "Trash the 'Old Budget' spreadsheet" - """ - logger.info( - f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" - ) - - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Google Drive tool not properly configured. Please contact support.", - } - - try: - async with async_session_maker() as db_session: - metadata_service = GoogleDriveToolMetadataService(db_session) - context = await metadata_service.get_trash_context( - search_space_id, user_id, file_name - ) - - if "error" in context: - error_msg = context["error"] - if "not found" in error_msg.lower(): - logger.warning(f"File not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - logger.error(f"Failed to fetch trash context: {error_msg}") - return {"status": "error", "message": error_msg} - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Google Drive account %s has expired authentication", - account.get("id"), - ) - return { - "status": "auth_error", - "message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "google_drive", - } - - file = context["file"] - file_id = file["file_id"] - document_id = file.get("document_id") - connector_id_from_context = context["account"]["id"] - - if not file_id: - return { - "status": "error", - "message": "File ID is missing from the indexed document. 
Please re-index the file and try again.", - } - - logger.info( - f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})" - ) - result = request_approval( - action_type="google_drive_file_trash", - tool_name="delete_google_drive_file", - params={ - "file_id": file_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.", - } - - final_file_id = result.params.get("file_id", file_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get( - "delete_from_kb", delete_from_kb - ) - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this file.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _drive_types = [ - SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, - ] - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_drive_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Google Drive connector is invalid or has been disconnected.", - } - - logger.info( - f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}" - ) - - is_composio_drive = ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR - ) - if is_composio_drive: - cca_id = 
connector.config.get("composio_connected_account_id") - if not cca_id: - return { - "status": "error", - "message": "Composio connected account ID not found for this Drive connector.", - } - - client = GoogleDriveClient( - session=db_session, - connector_id=connector.id, - ) - try: - if is_composio_drive: - from app.services.composio_service import ComposioService - - result = await ComposioService().execute_tool( - connected_account_id=cca_id, - tool_name="GOOGLEDRIVE_TRASH_FILE", - params={"file_id": final_file_id}, - entity_id=f"surfsense_{user_id}", - ) - if not result.get("success"): - raise RuntimeError( - result.get("error", "Unknown Composio Drive error") - ) - else: - await client.trash_file(file_id=final_file_id) - except HttpError as http_err: - if http_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {connector.id}: {http_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - if not connector.config.get("auth_expired"): - connector.config = { - **connector.config, - "auth_expired": True, - } - flag_modified(connector, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - connector.id, - exc_info=True, - ) - return { - "status": "insufficient_permissions", - "connector_id": connector.id, - "message": "This Google Drive account needs additional permissions. 
Please re-authenticate in connector settings.", - } - raise - - logger.info( - f"Google Drive file deleted (moved to trash): file_id={final_file_id}" - ) - - trash_result: dict[str, Any] = { - "status": "success", - "file_id": final_file_id, - "message": f"Successfully moved '{file['name']}' to trash.", - } - - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - from app.db import Document - - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - trash_result["warning"] = ( - f"File moved to trash, but failed to remove from knowledge base: {e!s}" - ) - - trash_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - trash_result["message"] = ( - f"{trash_result.get('message', '')} (also removed from knowledge base)" - ) - - return trash_result - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error deleting Google Drive file: {e}", exc_info=True) - return { - "status": "error", - "message": "Something went wrong while trashing the file. 
Please try again.", - } - - return delete_google_drive_file diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/__init__.py b/surfsense_backend/app/agents/new_chat/tools/linear/__init__.py deleted file mode 100644 index 31acf1e2a..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/linear/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Linear tools for creating, updating, and deleting issues.""" - -from .create_issue import create_create_linear_issue_tool -from .delete_issue import create_delete_linear_issue_tool -from .update_issue import create_update_linear_issue_tool - -__all__ = [ - "create_create_linear_issue_tool", - "create_delete_linear_issue_tool", - "create_update_linear_issue_tool", -] diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py deleted file mode 100644 index f897bee7a..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py +++ /dev/null @@ -1,266 +0,0 @@ -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.connectors.linear_connector import LinearAPIError, LinearConnector -from app.db import async_session_maker -from app.services.linear import LinearToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_create_linear_issue_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, - connector_id: int | None = None, -): - """Factory function to create the create_linear_issue tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker`. 
This is critical for the compiled-agent - cache: the compiled graph (and therefore this closure) is reused - across HTTP requests, so capturing a per-request session here would - surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - search_space_id: Search space ID to find the Linear connector - user_id: User ID for fetching user-specific context - connector_id: Optional specific connector ID (if known) - - Returns: - Configured create_linear_issue tool - """ - del db_session # per-call session — see docstring - - @tool - async def create_linear_issue( - title: str, - description: str | None = None, - ) -> dict[str, Any]: - """Create a new issue in Linear. - - Use this tool when the user explicitly asks to create, add, or file - a new issue / ticket / task in Linear. The user MUST describe the issue - before you call this tool. If the request is vague, ask what the issue - should be about. Never call this tool without a clear topic from the user. - - Args: - title: Short, descriptive issue title. Infer from the user's request. - description: Optional markdown body for the issue. Generate from context. - - Returns: - Dictionary with: - - status: "success", "rejected", or "error" - - issue_id: Linear issue UUID (if success) - - identifier: Human-readable ID like "ENG-42" (if success) - - url: URL to the created issue (if success) - - message: Result message - - IMPORTANT: If status is "rejected", the user explicitly declined the action. - Respond with a brief acknowledgment (e.g., "Understood, I won't create the issue.") - and move on. Do NOT retry, troubleshoot, or suggest alternatives. 
- - Examples: - - "Create a Linear issue for the login bug" - - "File a ticket about the payment timeout problem" - - "Add an issue for the broken search feature" - """ - logger.info(f"create_linear_issue called: title='{title}'") - - if search_space_id is None or user_id is None: - logger.error( - "Linear tool not properly configured - missing required parameters" - ) - return { - "status": "error", - "message": "Linear tool not properly configured. Please contact support.", - } - - try: - async with async_session_maker() as db_session: - metadata_service = LinearToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error( - f"Failed to fetch creation context: {context['error']}" - ) - return {"status": "error", "message": context["error"]} - - workspaces = context.get("workspaces", []) - if workspaces and all(w.get("auth_expired") for w in workspaces): - logger.warning("All Linear accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "linear", - } - - logger.info(f"Requesting approval for creating Linear issue: '{title}'") - result = request_approval( - action_type="linear_issue_creation", - tool_name="create_linear_issue", - params={ - "title": title, - "description": description, - "team_id": None, - "state_id": None, - "assignee_id": None, - "priority": None, - "label_ids": [], - "connector_id": connector_id, - }, - context=context, - ) - - if result.rejected: - logger.info("Linear issue creation rejected by user") - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_title = result.params.get("title", title) - final_description = result.params.get("description", description) - final_team_id = result.params.get("team_id") - final_state_id = result.params.get("state_id") - final_assignee_id = result.params.get("assignee_id") - final_priority = result.params.get("priority") - final_label_ids = result.params.get("label_ids") or [] - final_connector_id = result.params.get("connector_id", connector_id) - - if not final_title or not final_title.strip(): - logger.error("Title is empty or contains only whitespace") - return { - "status": "error", - "message": "Issue title cannot be empty.", - } - if not final_team_id: - return { - "status": "error", - "message": "A team must be selected to create an issue.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - actual_connector_id = final_connector_id - if actual_connector_id is None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Linear connector found. 
Please connect Linear in your workspace settings.", - } - actual_connector_id = connector.id - logger.info(f"Found Linear connector: id={actual_connector_id}") - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == actual_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Linear connector is invalid or has been disconnected.", - } - logger.info(f"Validated Linear connector: id={actual_connector_id}") - - logger.info( - f"Creating Linear issue with final params: title='{final_title}'" - ) - linear_client = LinearConnector( - session=db_session, connector_id=actual_connector_id - ) - result = await linear_client.create_issue( - team_id=final_team_id, - title=final_title, - description=final_description, - state_id=final_state_id, - assignee_id=final_assignee_id, - priority=final_priority, - label_ids=final_label_ids if final_label_ids else None, - ) - - if result.get("status") == "error": - logger.error( - f"Failed to create Linear issue: {result.get('message')}" - ) - return {"status": "error", "message": result.get("message")} - - logger.info( - f"Linear issue created: {result.get('identifier')} - {result.get('title')}" - ) - - kb_message_suffix = "" - try: - from app.services.linear import LinearKBSyncService - - kb_service = LinearKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - issue_id=result.get("id"), - issue_identifier=result.get("identifier", ""), - issue_title=result.get("title", final_title), - issue_url=result.get("url"), - description=final_description, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - 
kb_message_suffix = ( - " Your knowledge base has also been updated." - ) - else: - kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." - - return { - "status": "success", - "issue_id": result.get("id"), - "identifier": result.get("identifier"), - "url": result.get("url"), - "message": (result.get("message", "") + kb_message_suffix), - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error creating Linear issue: {e}", exc_info=True) - if isinstance(e, ValueError | LinearAPIError): - message = str(e) - else: - message = ( - "Something went wrong while creating the issue. Please try again." - ) - return {"status": "error", "message": message} - - return create_linear_issue diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py deleted file mode 100644 index c5039a8eb..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py +++ /dev/null @@ -1,256 +0,0 @@ -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.connectors.linear_connector import LinearAPIError, LinearConnector -from app.db import async_session_maker -from app.services.linear import LinearToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_delete_linear_issue_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, - connector_id: int | None = None, -): - """Factory function to create the delete_linear_issue tool. 
- - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker`. This is critical for the compiled-agent - cache: the compiled graph (and therefore this closure) is reused - across HTTP requests, so capturing a per-request session here would - surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - search_space_id: Search space ID to find the Linear connector - user_id: User ID for finding the correct Linear connector - connector_id: Optional specific connector ID (if known) - - Returns: - Configured delete_linear_issue tool - """ - del db_session # per-call session — see docstring - - @tool - async def delete_linear_issue( - issue_ref: str, - delete_from_kb: bool = False, - ) -> dict[str, Any]: - """Archive (delete) a Linear issue. - - Use this tool when the user asks to delete, remove, or archive a Linear issue. - Note that Linear archives issues rather than permanently deleting them - (they can be restored from the archive). - - - Args: - issue_ref: The issue to delete. Can be the issue title (e.g. "Fix login bug"), - the identifier (e.g. "ENG-42"), or the full document title - (e.g. "ENG-42: Fix login bug"). - delete_from_kb: Whether to also remove the issue from the knowledge base. - Default is False. Set to True to remove from both Linear - and the knowledge base. - - Returns: - Dictionary with: - - status: "success", "rejected", "not_found", or "error" - - identifier: Human-readable ID like "ENG-42" (if success) - - message: Success or error message - - deleted_from_kb: Whether the issue was also removed from the knowledge base (if success) - - IMPORTANT: - - If status is "rejected", the user explicitly declined the action. - Respond with a brief acknowledgment (e.g., "Understood, I won't delete the issue.") - and move on. Do NOT ask for alternatives or troubleshoot. 
- - If status is "not_found", inform the user conversationally using the exact message - provided. Do NOT treat this as an error. Simply relay the message and ask the user - to verify the issue title or identifier, or check if it has been indexed. - Examples: - - "Delete the 'Fix login bug' Linear issue" - - "Archive ENG-42" - - "Remove the 'Old payment flow' issue from Linear" - """ - logger.info( - f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}" - ) - - if search_space_id is None or user_id is None: - logger.error( - "Linear tool not properly configured - missing required parameters" - ) - return { - "status": "error", - "message": "Linear tool not properly configured. Please contact support.", - } - - try: - async with async_session_maker() as db_session: - metadata_service = LinearToolMetadataService(db_session) - context = await metadata_service.get_delete_context( - search_space_id, user_id, issue_ref - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - logger.warning(f"Auth expired for delete context: {error_msg}") - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "linear", - } - if "not found" in error_msg.lower(): - logger.warning(f"Issue not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - else: - logger.error(f"Failed to fetch delete context: {error_msg}") - return {"status": "error", "message": error_msg} - - issue_id = context["issue"]["id"] - issue_identifier = context["issue"].get("identifier", "") - document_id = context["issue"]["document_id"] - connector_id_from_context = context.get("workspace", {}).get("id") - - logger.info( - f"Requesting approval for deleting Linear issue: '{issue_ref}' " - f"(id={issue_id}, delete_from_kb={delete_from_kb})" - ) - result = request_approval( - action_type="linear_issue_deletion", - tool_name="delete_linear_issue", - 
params={ - "issue_id": issue_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - logger.info("Linear issue deletion rejected by user") - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_issue_id = result.params.get("issue_id", issue_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get( - "delete_from_kb", delete_from_kb - ) - - logger.info( - f"Deleting Linear issue with final params: issue_id={final_issue_id}, " - f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}" - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if final_connector_id: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: - logger.error( - f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "Selected Linear connector is invalid or has been disconnected.", - } - actual_connector_id = connector.id - logger.info(f"Validated Linear connector: id={actual_connector_id}") - else: - logger.error("No connector found for this issue") - return { - "status": "error", - "message": "No connector found for this issue.", - } - - linear_client = LinearConnector( - session=db_session, connector_id=actual_connector_id - ) - - result = await linear_client.archive_issue(issue_id=final_issue_id) - - logger.info( - f"archive_issue result: {result.get('status')} - {result.get('message', 
'')}" - ) - - deleted_from_kb = False - if ( - result.get("status") == "success" - and final_delete_from_kb - and document_id - ): - try: - from app.db import Document - - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - result["warning"] = ( - f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}" - ) - - if result.get("status") == "success": - result["deleted_from_kb"] = deleted_from_kb - if issue_identifier: - result["message"] = ( - f"Issue {issue_identifier} archived successfully." - ) - if deleted_from_kb: - result["message"] = ( - f"{result.get('message', '')} Also removed from the knowledge base." - ) - - return result - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error deleting Linear issue: {e}", exc_info=True) - if isinstance(e, ValueError | LinearAPIError): - message = str(e) - else: - message = ( - "Something went wrong while deleting the issue. Please try again." 
- ) - return {"status": "error", "message": message} - - return delete_linear_issue diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py deleted file mode 100644 index d610ce2b7..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py +++ /dev/null @@ -1,327 +0,0 @@ -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.connectors.linear_connector import LinearAPIError, LinearConnector -from app.db import async_session_maker -from app.services.linear import LinearKBSyncService, LinearToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_update_linear_issue_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, - connector_id: int | None = None, -): - """Factory function to create the update_linear_issue tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker`. This is critical for the compiled-agent - cache: the compiled graph (and therefore this closure) is reused - across HTTP requests, so capturing a per-request session here would - surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. 
- search_space_id: Search space ID to find the Linear connector - user_id: User ID for fetching user-specific context - connector_id: Optional specific connector ID (if known) - - Returns: - Configured update_linear_issue tool - """ - del db_session # per-call session — see docstring - - @tool - async def update_linear_issue( - issue_ref: str, - new_title: str | None = None, - new_description: str | None = None, - new_state_name: str | None = None, - new_assignee_email: str | None = None, - new_priority: int | None = None, - new_label_names: list[str] | None = None, - ) -> dict[str, Any]: - """Update an existing Linear issue that has been indexed in the knowledge base. - - Use this tool when the user asks to modify, change, or update a Linear issue — - for example, changing its status, reassigning it, updating its title or description, - adjusting its priority, or changing its labels. - - Only issues already indexed in the knowledge base can be updated. - - Args: - issue_ref: The issue to update. Can be the issue title (e.g. "Fix login bug"), - the identifier (e.g. "ENG-42"), or the full document title - (e.g. "ENG-42: Fix login bug"). Matched case-insensitively. - new_title: New title for the issue (optional). - new_description: New markdown body for the issue (optional). - new_state_name: New workflow state name (e.g. "In Progress", "Done"). - Matched case-insensitively against the team's states. - new_assignee_email: Email address of the new assignee. - Matched case-insensitively against the team's members. - new_priority: New priority (0 = No Priority, 1 = Urgent, 2 = High, - 3 = Medium, 4 = Low). - new_label_names: New set of label names to apply. - Matched case-insensitively against the team's labels. - Unrecognised names are silently skipped. 
- - Returns: - Dictionary with: - - status: "success", "rejected", "not_found", or "error" - - identifier: Human-readable ID like "ENG-42" (if success) - - url: URL to the updated issue (if success) - - message: Result message - - IMPORTANT: - - If status is "rejected", the user explicitly declined the action. - Respond with a brief acknowledgment (e.g., "Understood, I didn't update the issue.") - and move on. Do NOT ask for alternatives or troubleshoot. - - If status is "not_found", inform the user conversationally using the exact message - provided. Do NOT treat this as an error. Simply relay the message and ask the user - to verify the issue title or identifier, or check if it has been indexed. - - Examples: - - "Mark the 'Fix login bug' issue as done" - - "Assign ENG-42 to john@company.com" - - "Change the priority of 'Payment timeout' to urgent" - """ - logger.info(f"update_linear_issue called: issue_ref='{issue_ref}'") - - if search_space_id is None or user_id is None: - logger.error( - "Linear tool not properly configured - missing required parameters" - ) - return { - "status": "error", - "message": "Linear tool not properly configured. 
Please contact support.", - } - - try: - async with async_session_maker() as db_session: - metadata_service = LinearToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, issue_ref - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - logger.warning(f"Auth expired for update context: {error_msg}") - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "linear", - } - if "not found" in error_msg.lower(): - logger.warning(f"Issue not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - else: - logger.error(f"Failed to fetch update context: {error_msg}") - return {"status": "error", "message": error_msg} - - issue_id = context["issue"]["id"] - document_id = context["issue"]["document_id"] - connector_id_from_context = context.get("workspace", {}).get("id") - - team = context.get("team", {}) - new_state_id = _resolve_state(team, new_state_name) - new_assignee_id = _resolve_assignee(team, new_assignee_email) - new_label_ids = _resolve_labels(team, new_label_names) - - logger.info( - f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})" - ) - result = request_approval( - action_type="linear_issue_update", - tool_name="update_linear_issue", - params={ - "issue_id": issue_id, - "document_id": document_id, - "new_title": new_title, - "new_description": new_description, - "new_state_id": new_state_id, - "new_assignee_id": new_assignee_id, - "new_priority": new_priority, - "new_label_ids": new_label_ids, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - logger.info("Linear issue update rejected by user") - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_issue_id = result.params.get("issue_id", issue_id) - final_document_id = result.params.get("document_id", document_id) - final_new_title = result.params.get("new_title", new_title) - final_new_description = result.params.get( - "new_description", new_description - ) - final_new_state_id = result.params.get("new_state_id", new_state_id) - final_new_assignee_id = result.params.get( - "new_assignee_id", new_assignee_id - ) - final_new_priority = result.params.get("new_priority", new_priority) - final_new_label_ids: list[str] | None = result.params.get( - "new_label_ids", new_label_ids - ) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - - if not final_connector_id: - logger.error("No connector found for this issue") - return { - "status": "error", - "message": "No connector found for this issue.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: - logger.error( - f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "Selected Linear connector is invalid or has been disconnected.", - } - logger.info(f"Validated Linear connector: id={final_connector_id}") - - logger.info( - f"Updating Linear issue with final params: issue_id={final_issue_id}" - ) - linear_client = LinearConnector( - session=db_session, connector_id=final_connector_id - ) - updated_issue = await linear_client.update_issue( - issue_id=final_issue_id, - title=final_new_title, - 
description=final_new_description, - state_id=final_new_state_id, - assignee_id=final_new_assignee_id, - priority=final_new_priority, - label_ids=final_new_label_ids, - ) - - if updated_issue.get("status") == "error": - logger.error( - f"Failed to update Linear issue: {updated_issue.get('message')}" - ) - return { - "status": "error", - "message": updated_issue.get("message"), - } - - logger.info( - f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}" - ) - - if final_document_id is not None: - logger.info( - f"Updating knowledge base for document {final_document_id}..." - ) - kb_service = LinearKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=final_document_id, - issue_id=final_issue_id, - user_id=user_id, - search_space_id=search_space_id, - ) - if kb_result["status"] == "success": - logger.info( - f"Knowledge base successfully updated for issue {final_issue_id}" - ) - kb_message = " Your knowledge base has also been updated." - elif kb_result["status"] == "not_indexed": - kb_message = " This issue will be added to your knowledge base in the next scheduled sync." - else: - logger.warning( - f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}" - ) - kb_message = " Your knowledge base will be updated in the next scheduled sync." - else: - kb_message = "" - - identifier = updated_issue.get("identifier") - default_msg = f"Issue {identifier} updated successfully." - return { - "status": "success", - "identifier": identifier, - "url": updated_issue.get("url"), - "message": f"{updated_issue.get('message', default_msg)}{kb_message}", - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error updating Linear issue: {e}", exc_info=True) - if isinstance(e, ValueError | LinearAPIError): - message = str(e) - else: - message = ( - "Something went wrong while updating the issue. 
Please try again." - ) - return {"status": "error", "message": message} - - return update_linear_issue - - -def _resolve_state(team: dict, state_name: str | None) -> str | None: - if not state_name: - return None - name_lower = state_name.lower() - for state in team.get("states", []): - if state.get("name", "").lower() == name_lower: - return state["id"] - return None - - -def _resolve_assignee(team: dict, assignee_email: str | None) -> str | None: - if not assignee_email: - return None - email_lower = assignee_email.lower() - for member in team.get("members", []): - if member.get("email", "").lower() == email_lower: - return member["id"] - return None - - -def _resolve_labels(team: dict, label_names: list[str] | None) -> list[str] | None: - if label_names is None: - return None - if not label_names: - return [] - name_set = {n.lower() for n in label_names} - return [ - label["id"] - for label in team.get("labels", []) - if label.get("name", "").lower() in name_set - ] diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/__init__.py b/surfsense_backend/app/agents/new_chat/tools/luma/__init__.py deleted file mode 100644 index 255119bee..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/luma/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from app.agents.new_chat.tools.luma.create_event import ( - create_create_luma_event_tool, -) -from app.agents.new_chat.tools.luma.list_events import ( - create_list_luma_events_tool, -) -from app.agents.new_chat.tools.luma.read_event import ( - create_read_luma_event_tool, -) - -__all__ = [ - "create_create_luma_event_tool", - "create_list_luma_events_tool", - "create_read_luma_event_tool", -] diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py b/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py deleted file mode 100644 index 37deb1525..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Shared auth helper for Luma agent tools.""" - 
-from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select - -from app.db import SearchSourceConnector, SearchSourceConnectorType - -LUMA_API = "https://public-api.luma.com/v1" - - -async def get_luma_connector( - db_session: AsyncSession, - search_space_id: int, - user_id: str, -) -> SearchSourceConnector | None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LUMA_CONNECTOR, - ) - ) - return result.scalars().first() - - -def get_api_key(connector: SearchSourceConnector) -> str: - """Extract the API key from connector config (handles both key names).""" - key = connector.config.get("api_key") or connector.config.get("LUMA_API_KEY") - if not key: - raise ValueError("Luma API key not found in connector config.") - return key - - -def luma_headers(api_key: str) -> dict[str, str]: - return { - "Content-Type": "application/json", - "x-luma-api-key": api_key, - } diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py deleted file mode 100644 index 65c177d7a..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py +++ /dev/null @@ -1,150 +0,0 @@ -import logging -from typing import Any - -import httpx -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.db import async_session_maker - -from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers - -logger = logging.getLogger(__name__) - - -def create_create_luma_event_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the create_luma_event tool. 
- - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured create_luma_event tool - """ - del db_session # per-call session — see docstring - - @tool - async def create_luma_event( - name: str, - start_at: str, - end_at: str, - description: str | None = None, - timezone: str = "UTC", - ) -> dict[str, Any]: - """Create a new event on Luma. - - Args: - name: The event title. - start_at: Start time in ISO 8601 format (e.g. "2026-05-01T18:00:00"). - end_at: End time in ISO 8601 format (e.g. "2026-05-01T20:00:00"). - description: Optional event description (markdown supported). - timezone: Timezone string (default "UTC", e.g. "America/New_York"). - - Returns: - Dictionary with status, event_id on success. - - IMPORTANT: - - If status is "rejected", the user explicitly declined. Do NOT retry. - """ - if search_space_id is None or user_id is None: - return {"status": "error", "message": "Luma tool not properly configured."} - - try: - async with async_session_maker() as db_session: - connector = await get_luma_connector( - db_session, search_space_id, user_id - ) - if not connector: - return {"status": "error", "message": "No Luma connector found."} - - result = request_approval( - action_type="luma_create_event", - tool_name="create_luma_event", - params={ - "name": name, - "start_at": start_at, - "end_at": end_at, - "description": description, - "timezone": timezone, - }, - context={"connector_id": connector.id}, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. 
Event was not created.", - } - - final_name = result.params.get("name", name) - final_start = result.params.get("start_at", start_at) - final_end = result.params.get("end_at", end_at) - final_desc = result.params.get("description", description) - final_tz = result.params.get("timezone", timezone) - - api_key = get_api_key(connector) - headers = luma_headers(api_key) - - body: dict[str, Any] = { - "name": final_name, - "start_at": final_start, - "end_at": final_end, - "timezone": final_tz, - } - if final_desc: - body["description_md"] = final_desc - - async with httpx.AsyncClient(timeout=20.0) as client: - resp = await client.post( - f"{LUMA_API}/event/create", - headers=headers, - json=body, - ) - - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Luma API key is invalid.", - "connector_type": "luma", - } - if resp.status_code == 403: - return { - "status": "error", - "message": "Luma Plus subscription required to create events via API.", - } - if resp.status_code not in (200, 201): - return { - "status": "error", - "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}", - } - - data = resp.json() - event_id = data.get("api_id") or data.get("event", {}).get("api_id") - - return { - "status": "success", - "event_id": event_id, - "message": f"Event '{final_name}' created on Luma.", - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error("Error creating Luma event: %s", e, exc_info=True) - return {"status": "error", "message": "Failed to create Luma event."} - - return create_luma_event diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py deleted file mode 100644 index 6885c2049..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py +++ /dev/null @@ -1,133 +0,0 @@ -import logging -from typing import Any - -import httpx -from 
langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.db import async_session_maker - -from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers - -logger = logging.getLogger(__name__) - - -def create_list_luma_events_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the list_luma_events tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured list_luma_events tool - """ - del db_session # per-call session — see docstring - - @tool - async def list_luma_events( - max_results: int = 25, - ) -> dict[str, Any]: - """List upcoming and recent Luma events. - - Args: - max_results: Maximum events to return (default 25, max 50). - - Returns: - Dictionary with status and a list of events including - event_id, name, start_at, end_at, location, url. 
- """ - if search_space_id is None or user_id is None: - return {"status": "error", "message": "Luma tool not properly configured."} - - max_results = min(max_results, 50) - - try: - async with async_session_maker() as db_session: - connector = await get_luma_connector( - db_session, search_space_id, user_id - ) - if not connector: - return {"status": "error", "message": "No Luma connector found."} - - api_key = get_api_key(connector) - headers = luma_headers(api_key) - - all_entries: list[dict] = [] - cursor = None - - async with httpx.AsyncClient(timeout=20.0) as client: - while len(all_entries) < max_results: - params: dict[str, Any] = { - "limit": min(100, max_results - len(all_entries)) - } - if cursor: - params["cursor"] = cursor - - resp = await client.get( - f"{LUMA_API}/calendar/list-events", - headers=headers, - params=params, - ) - - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Luma API key is invalid.", - "connector_type": "luma", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Luma API error: {resp.status_code}", - } - - data = resp.json() - entries = data.get("entries", []) - if not entries: - break - all_entries.extend(entries) - - next_cursor = data.get("next_cursor") - if not next_cursor: - break - cursor = next_cursor - - events = [] - for entry in all_entries[:max_results]: - ev = entry.get("event", {}) - geo = ev.get("geo_info", {}) - events.append( - { - "event_id": entry.get("api_id"), - "name": ev.get("name", "Untitled"), - "start_at": ev.get("start_at", ""), - "end_at": ev.get("end_at", ""), - "timezone": ev.get("timezone", ""), - "location": geo.get("name", ""), - "url": ev.get("url", ""), - "visibility": ev.get("visibility", ""), - } - ) - - return {"status": "success", "events": events, "total": len(events)} - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error("Error listing Luma events: %s", 
e, exc_info=True) - return {"status": "error", "message": "Failed to list Luma events."} - - return list_luma_events diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py deleted file mode 100644 index a8484e9c0..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py +++ /dev/null @@ -1,114 +0,0 @@ -import logging -from typing import Any - -import httpx -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.db import async_session_maker - -from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers - -logger = logging.getLogger(__name__) - - -def create_read_luma_event_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the read_luma_event tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured read_luma_event tool - """ - del db_session # per-call session — see docstring - - @tool - async def read_luma_event(event_id: str) -> dict[str, Any]: - """Read detailed information about a specific Luma event. - - Args: - event_id: The Luma event API ID (from list_luma_events). - - Returns: - Dictionary with status and full event details including - description, attendees count, meeting URL. 
- """ - if search_space_id is None or user_id is None: - return {"status": "error", "message": "Luma tool not properly configured."} - - try: - async with async_session_maker() as db_session: - connector = await get_luma_connector( - db_session, search_space_id, user_id - ) - if not connector: - return {"status": "error", "message": "No Luma connector found."} - - api_key = get_api_key(connector) - headers = luma_headers(api_key) - - async with httpx.AsyncClient(timeout=15.0) as client: - resp = await client.get( - f"{LUMA_API}/events/{event_id}", - headers=headers, - ) - - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Luma API key is invalid.", - "connector_type": "luma", - } - if resp.status_code == 404: - return { - "status": "not_found", - "message": f"Event '{event_id}' not found.", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Luma API error: {resp.status_code}", - } - - data = resp.json() - ev = data.get("event", data) - geo = ev.get("geo_info", {}) - - event_detail = { - "event_id": event_id, - "name": ev.get("name", ""), - "description": ev.get("description", ""), - "start_at": ev.get("start_at", ""), - "end_at": ev.get("end_at", ""), - "timezone": ev.get("timezone", ""), - "location_name": geo.get("name", ""), - "address": geo.get("address", ""), - "url": ev.get("url", ""), - "meeting_url": ev.get("meeting_url", ""), - "visibility": ev.get("visibility", ""), - "cover_url": ev.get("cover_url", ""), - } - - return {"status": "success", "event": event_detail} - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error("Error reading Luma event: %s", e, exc_info=True) - return {"status": "error", "message": "Failed to read Luma event."} - - return read_luma_event diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/__init__.py b/surfsense_backend/app/agents/new_chat/tools/notion/__init__.py deleted file mode 
100644 index 6ce825dca..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/notion/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Notion tools for creating, updating, and deleting pages.""" - -from .create_page import create_create_notion_page_tool -from .delete_page import create_delete_notion_page_tool -from .update_page import create_update_notion_page_tool - -__all__ = [ - "create_create_notion_page_tool", - "create_delete_notion_page_tool", - "create_update_notion_page_tool", -] diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py deleted file mode 100644 index 6ec95e9f0..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py +++ /dev/null @@ -1,258 +0,0 @@ -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector -from app.db import async_session_maker -from app.services.notion import NotionToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_create_notion_page_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, - connector_id: int | None = None, -): - """ - Factory function to create the create_notion_page tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker`. This is critical for the compiled-agent - cache: the compiled graph (and therefore this closure) is reused - across HTTP requests, so capturing a per-request session here would - surface stale/closed sessions on cache hits. Per-call sessions also - keep the request's outer transaction free of long-running Notion API - blocking. - - Args: - db_session: Reserved for registry compatibility. 
Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - search_space_id: Search space ID to find the Notion connector - user_id: User ID for fetching user-specific context - connector_id: Optional specific connector ID (if known) - - Returns: - Configured create_notion_page tool - """ - del db_session # per-call session — see docstring - - @tool - async def create_notion_page( - title: str, - content: str | None = None, - ) -> dict[str, Any]: - """Create a new page in Notion with the given title and content. - - Use this tool when the user asks you to create, save, or publish - something to Notion. The page will be created in the user's - configured Notion workspace. The user MUST specify a topic before you - call this tool. If the request does not contain a topic (e.g. "create a - notion page"), ask what the page should be about. Never call this tool - without a clear topic from the user. - - Args: - title: The title of the Notion page. - content: Optional markdown content for the page body (supports headings, lists, paragraphs). - Generate this yourself based on the user's topic. - - Returns: - Dictionary with: - - status: "success", "rejected", or "error" - - page_id: Created page ID (if success) - - url: URL to the created page (if success) - - title: Page title (if success) - - message: Result message - - IMPORTANT: If status is "rejected", the user explicitly declined the action. - Respond with a brief acknowledgment (e.g., "Understood, I didn't create the page.") - and move on. Do NOT troubleshoot or suggest alternatives. - - Examples: - - "Create a Notion page about our Q2 roadmap" - - "Save a summary of today's discussion to Notion" - """ - logger.info(f"create_notion_page called: title='{title}'") - - if search_space_id is None or user_id is None: - logger.error( - "Notion tool not properly configured - missing required parameters" - ) - return { - "status": "error", - "message": "Notion tool not properly configured. 
Please contact support.", - } - - try: - async with async_session_maker() as db_session: - metadata_service = NotionToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error( - f"Failed to fetch creation context: {context['error']}" - ) - return { - "status": "error", - "message": context["error"], - } - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - logger.warning("All Notion accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "notion", - } - - logger.info(f"Requesting approval for creating Notion page: '{title}'") - result = request_approval( - action_type="notion_page_creation", - tool_name="create_notion_page", - params={ - "title": title, - "content": content, - "parent_page_id": None, - "connector_id": connector_id, - }, - context=context, - ) - - if result.rejected: - logger.info("Notion page creation rejected by user") - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_title = result.params.get("title", title) - final_content = result.params.get("content", content) - final_parent_page_id = result.params.get("parent_page_id") - final_connector_id = result.params.get("connector_id", connector_id) - - if not final_title or not final_title.strip(): - logger.error("Title is empty or contains only whitespace") - return { - "status": "error", - "message": "Page title cannot be empty. 
Please provide a valid title.", - } - - logger.info( - f"Creating Notion page with final params: title='{final_title}'" - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - actual_connector_id = final_connector_id - if actual_connector_id is None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.NOTION_CONNECTOR, - ) - ) - connector = result.scalars().first() - - if not connector: - logger.warning( - f"No Notion connector found for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "No Notion connector found. Please connect Notion in your workspace settings.", - } - - actual_connector_id = connector.id - logger.info(f"Found Notion connector: id={actual_connector_id}") - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == actual_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.NOTION_CONNECTOR, - ) - ) - connector = result.scalars().first() - - if not connector: - logger.error( - f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "Selected Notion account is invalid or has been disconnected. 
Please select a valid account.", - } - logger.info(f"Validated Notion connector: id={actual_connector_id}") - - notion_connector = NotionHistoryConnector( - session=db_session, - connector_id=actual_connector_id, - ) - - result = await notion_connector.create_page( - title=final_title, - content=final_content, - parent_page_id=final_parent_page_id, - ) - logger.info( - f"create_page result: {result.get('status')} - {result.get('message', '')}" - ) - - if result.get("status") == "success": - kb_message_suffix = "" - try: - from app.services.notion import NotionKBSyncService - - kb_service = NotionKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - page_id=result.get("page_id"), - page_title=result.get("title", final_title), - page_url=result.get("url"), - content=final_content, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." - ) - else: - kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." - - result["message"] = result.get("message", "") + kb_message_suffix - - return result - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error creating Notion page: {e}", exc_info=True) - if isinstance(e, ValueError | NotionAPIError): - message = str(e) - else: - message = ( - "Something went wrong while creating the page. Please try again." 
- ) - return {"status": "error", "message": message} - - return create_notion_page diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py deleted file mode 100644 index 7b85da4c2..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py +++ /dev/null @@ -1,273 +0,0 @@ -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector -from app.db import async_session_maker -from app.services.notion.tool_metadata_service import NotionToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_delete_notion_page_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, - connector_id: int | None = None, -): - """ - Factory function to create the delete_notion_page tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - search_space_id: Search space ID to find the Notion connector - user_id: User ID for finding the correct Notion connector - connector_id: Optional specific connector ID (if known) - - Returns: - Configured delete_notion_page tool - """ - del db_session # per-call session — see docstring - - @tool - async def delete_notion_page( - page_title: str, - delete_from_kb: bool = False, - ) -> dict[str, Any]: - """Delete (archive) a Notion page. 
- - Use this tool when the user asks you to delete, remove, or archive - a Notion page. Note that Notion doesn't permanently delete pages, - it archives them (they can be restored from trash). - - Args: - page_title: The title of the Notion page to delete. - delete_from_kb: Whether to also remove the page from the knowledge base. - Default is False. - Set to True to permanently remove from both Notion and knowledge base. - - Returns: - Dictionary with: - - status: "success", "rejected", "not_found", or "error" - - page_id: Deleted page ID (if success) - - message: Success or error message - - deleted_from_kb: Whether the page was also removed from knowledge base (if success) - - Examples: - - "Delete the 'Meeting Notes' Notion page" - - "Remove the 'Old Project Plan' Notion page" - - "Archive the 'Draft Ideas' Notion page" - """ - logger.info( - f"delete_notion_page called: page_title='{page_title}', delete_from_kb={delete_from_kb}" - ) - - if search_space_id is None or user_id is None: - logger.error( - "Notion tool not properly configured - missing required parameters" - ) - return { - "status": "error", - "message": "Notion tool not properly configured. 
Please contact support.", - } - - try: - async with async_session_maker() as db_session: - # Get page context (page_id, account, title) from indexed data - metadata_service = NotionToolMetadataService(db_session) - context = await metadata_service.get_delete_context( - search_space_id, user_id, page_title - ) - - if "error" in context: - error_msg = context["error"] - # Check if it's a "not found" error (softer handling for LLM) - if "not found" in error_msg.lower(): - logger.warning(f"Page not found: {error_msg}") - return { - "status": "not_found", - "message": error_msg, - } - else: - logger.error(f"Failed to fetch delete context: {error_msg}") - return { - "status": "error", - "message": error_msg, - } - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Notion account %s has expired authentication", - account.get("id"), - ) - return { - "status": "auth_error", - "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.", - } - - page_id = context.get("page_id") - connector_id_from_context = account.get("id") - document_id = context.get("document_id") - - logger.info( - f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})" - ) - - result = request_approval( - action_type="notion_page_deletion", - tool_name="delete_notion_page", - params={ - "page_id": page_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - logger.info("Notion page deletion rejected by user") - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_page_id = result.params.get("page_id", page_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get( - "delete_from_kb", delete_from_kb - ) - - logger.info( - f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}" - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - # Validate the connector - if final_connector_id: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.NOTION_CONNECTOR, - ) - ) - connector = result.scalars().first() - - if not connector: - logger.error( - f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "Selected Notion account is invalid or has been disconnected. 
Please select a valid account.", - } - actual_connector_id = connector.id - logger.info(f"Validated Notion connector: id={actual_connector_id}") - else: - logger.error("No connector found for this page") - return { - "status": "error", - "message": "No connector found for this page.", - } - - # Create connector instance - notion_connector = NotionHistoryConnector( - session=db_session, - connector_id=actual_connector_id, - ) - - # Delete the page from Notion - result = await notion_connector.delete_page(page_id=final_page_id) - logger.info( - f"delete_page result: {result.get('status')} - {result.get('message', '')}" - ) - - # If deletion was successful and user wants to delete from KB - deleted_from_kb = False - if ( - result.get("status") == "success" - and final_delete_from_kb - and document_id - ): - try: - from sqlalchemy.future import select - - from app.db import Document - - # Get the document - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - result["warning"] = ( - f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}" - ) - - # Update result with KB deletion status - if result.get("status") == "success": - result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - result["message"] = ( - f"{result.get('message', '')} (also removed from knowledge base)" - ) - - return result - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error deleting Notion page: {e}", exc_info=True) - error_str = 
str(e).lower() - if isinstance(e, NotionAPIError) and ( - "401" in error_str or "unauthorized" in error_str - ): - return { - "status": "auth_error", - "message": str(e), - "connector_id": connector_id_from_context - if "connector_id_from_context" in dir() - else None, - "connector_type": "notion", - } - if isinstance(e, ValueError | NotionAPIError): - message = str(e) - else: - message = ( - "Something went wrong while deleting the page. Please try again." - ) - return {"status": "error", "message": message} - - return delete_notion_page diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py deleted file mode 100644 index df757476a..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py +++ /dev/null @@ -1,276 +0,0 @@ -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector -from app.db import async_session_maker -from app.services.notion import NotionToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_update_notion_page_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, - connector_id: int | None = None, -): - """ - Factory function to create the update_notion_page tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache (see - ``create_create_notion_page_tool`` for the full rationale). - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. 
- search_space_id: Search space ID to find the Notion connector - user_id: User ID for fetching user-specific context - connector_id: Optional specific connector ID (if known) - - Returns: - Configured update_notion_page tool - """ - del db_session # per-call session — see docstring - - @tool - async def update_notion_page( - page_title: str, - content: str | None = None, - ) -> dict[str, Any]: - """Update an existing Notion page by appending new content. - - Use this tool when the user asks you to add content to, modify, or update - a Notion page. The new content will be appended to the existing page content. - The user MUST specify what to add before you call this tool. If the - request is vague, ask what content they want added. - - Args: - page_title: The title of the Notion page to update. - content: Optional markdown content to append to the page body (supports headings, lists, paragraphs). - Generate this yourself based on the user's request. - - Returns: - Dictionary with: - - status: "success", "rejected", "not_found", or "error" - - page_id: Updated page ID (if success) - - url: URL to the updated page (if success) - - title: Current page title (if success) - - message: Result message - - IMPORTANT: - - If status is "rejected", the user explicitly declined the action. - Respond with a brief acknowledgment (e.g., "Understood, I didn't update the page.") - and move on. Do NOT ask for alternatives or troubleshoot. - - If status is "not_found", inform the user conversationally using the exact message provided. - Example: "I couldn't find the page '[page_title]' in your indexed Notion pages. [message details]" - Do NOT treat this as an error. Do NOT invent information. Simply relay the message and - ask the user to verify the page title or check if it's been indexed. 
- Examples: - - "Add today's meeting notes to the 'Meeting Notes' Notion page" - - "Update the 'Project Plan' page with a status update on phase 1" - """ - logger.info( - f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}" - ) - - if search_space_id is None or user_id is None: - logger.error( - "Notion tool not properly configured - missing required parameters" - ) - return { - "status": "error", - "message": "Notion tool not properly configured. Please contact support.", - } - - if not content or not content.strip(): - logger.error(f"Empty content provided for page '{page_title}'") - return { - "status": "error", - "message": "Content is required to update the page. Please provide the actual content you want to add.", - } - - try: - async with async_session_maker() as db_session: - metadata_service = NotionToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, page_title - ) - - if "error" in context: - error_msg = context["error"] - # Check if it's a "not found" error (softer handling for LLM) - if "not found" in error_msg.lower(): - logger.warning(f"Page not found: {error_msg}") - return { - "status": "not_found", - "message": error_msg, - } - else: - logger.error(f"Failed to fetch update context: {error_msg}") - return { - "status": "error", - "message": error_msg, - } - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Notion account %s has expired authentication", - account.get("id"), - ) - return { - "status": "auth_error", - "message": "The Notion account for this page needs re-authentication. 
Please re-authenticate in your connector settings.", - } - - page_id = context.get("page_id") - document_id = context.get("document_id") - connector_id_from_context = context.get("account", {}).get("id") - - logger.info( - f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})" - ) - result = request_approval( - action_type="notion_page_update", - tool_name="update_notion_page", - params={ - "page_id": page_id, - "content": content, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - logger.info("Notion page update rejected by user") - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_page_id = result.params.get("page_id", page_id) - final_content = result.params.get("content", content) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - - logger.info( - f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}" - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if final_connector_id: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.NOTION_CONNECTOR, - ) - ) - connector = result.scalars().first() - - if not connector: - logger.error( - f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "Selected Notion account is invalid or has been disconnected. 
Please select a valid account.", - } - actual_connector_id = connector.id - logger.info(f"Validated Notion connector: id={actual_connector_id}") - else: - logger.error("No connector found for this page") - return { - "status": "error", - "message": "No connector found for this page.", - } - - notion_connector = NotionHistoryConnector( - session=db_session, - connector_id=actual_connector_id, - ) - - result = await notion_connector.update_page( - page_id=final_page_id, - content=final_content, - ) - logger.info( - f"update_page result: {result.get('status')} - {result.get('message', '')}" - ) - - if result.get("status") == "success" and document_id is not None: - from app.services.notion import NotionKBSyncService - - logger.info( - f"Updating knowledge base for document {document_id}..." - ) - kb_service = NotionKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=document_id, - appended_content=final_content, - user_id=user_id, - search_space_id=search_space_id, - appended_block_ids=result.get("appended_block_ids"), - ) - - if kb_result["status"] == "success": - result["message"] = ( - f"{result['message']}. Your knowledge base has also been updated." - ) - logger.info( - f"Knowledge base successfully updated for page {final_page_id}" - ) - elif kb_result["status"] == "not_indexed": - result["message"] = ( - f"{result['message']}. This page will be added to your knowledge base in the next scheduled sync." - ) - else: - result["message"] = ( - f"{result['message']}. Your knowledge base will be updated in the next scheduled sync." 
- ) - logger.warning( - f"KB update failed for page {final_page_id}: {kb_result['message']}" - ) - - return result - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error updating Notion page: {e}", exc_info=True) - error_str = str(e).lower() - if isinstance(e, NotionAPIError) and ( - "401" in error_str or "unauthorized" in error_str - ): - return { - "status": "auth_error", - "message": str(e), - "connector_id": connector_id_from_context - if "connector_id_from_context" in dir() - else None, - "connector_type": "notion", - } - if isinstance(e, ValueError | NotionAPIError): - message = str(e) - else: - message = ( - "Something went wrong while updating the page. Please try again." - ) - return {"status": "error", "message": message} - - return update_notion_page diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/__init__.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/__init__.py deleted file mode 100644 index 8edb4857e..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/onedrive/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from app.agents.new_chat.tools.onedrive.create_file import ( - create_create_onedrive_file_tool, -) -from app.agents.new_chat.tools.onedrive.trash_file import ( - create_delete_onedrive_file_tool, -) - -__all__ = [ - "create_create_onedrive_file_tool", - "create_delete_onedrive_file_tool", -] diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py deleted file mode 100644 index 5f199a41b..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py +++ /dev/null @@ -1,274 +0,0 @@ -import logging -import os -import tempfile -from pathlib import Path -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select - -from 
app.agents.new_chat.tools.hitl import request_approval -from app.connectors.onedrive.client import OneDriveClient -from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker - -logger = logging.getLogger(__name__) - -DOCX_MIME = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" - - -def _ensure_docx_extension(name: str) -> str: - """Strip any existing extension and append .docx.""" - stem = Path(name).stem - return f"{stem}.docx" - - -def _markdown_to_docx(markdown_text: str) -> bytes: - """Convert a markdown string to DOCX bytes using pypandoc.""" - import pypandoc - - fd, tmp_path = tempfile.mkstemp(suffix=".docx") - os.close(fd) - try: - pypandoc.convert_text( - markdown_text, - "docx", - format="gfm", - extra_args=["--standalone"], - outputfile=tmp_path, - ) - with open(tmp_path, "rb") as f: - return f.read() - finally: - os.unlink(tmp_path) - - -def create_create_onedrive_file_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the create_onedrive_file tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured create_onedrive_file tool - """ - del db_session # per-call session — see docstring - - @tool - async def create_onedrive_file( - name: str, - content: str | None = None, - ) -> dict[str, Any]: - """Create a new Word document (.docx) in Microsoft OneDrive. - - Use this tool when the user explicitly asks to create a new document - in OneDrive. The user MUST specify a topic before you call this tool. 
- - The file is always saved as a .docx Word document. Provide content as - markdown and it will be automatically converted to a formatted Word file. - - Args: - name: The document title (without extension). Extension will be set to .docx automatically. - content: Optional initial content as markdown. Will be converted to a formatted Word document. - - Returns: - Dictionary with status, file_id, name, web_url, and message. - """ - logger.info(f"create_onedrive_file called: name='{name}'") - - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "OneDrive tool not properly configured.", - } - - try: - async with async_session_maker() as db_session: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, - ) - ) - connectors = result.scalars().all() - - if not connectors: - return { - "status": "error", - "message": "No OneDrive connector found. 
Please connect OneDrive in your workspace settings.", - } - - accounts = [] - for c in connectors: - cfg = c.config or {} - accounts.append( - { - "id": c.id, - "name": c.name, - "user_email": cfg.get("user_email"), - "auth_expired": cfg.get("auth_expired", False), - } - ) - - if all(a.get("auth_expired") for a in accounts): - return { - "status": "auth_error", - "message": "All connected OneDrive accounts need re-authentication.", - "connector_type": "onedrive", - } - - parent_folders: dict[int, list[dict[str, str]]] = {} - for acc in accounts: - cid = acc["id"] - if acc.get("auth_expired"): - parent_folders[cid] = [] - continue - try: - client = OneDriveClient(session=db_session, connector_id=cid) - items, err = await client.list_children("root") - if err: - logger.warning( - "Failed to list folders for connector %s: %s", cid, err - ) - parent_folders[cid] = [] - else: - parent_folders[cid] = [ - {"folder_id": item["id"], "name": item["name"]} - for item in items - if item.get("folder") is not None - and item.get("id") - and item.get("name") - ] - except Exception: - logger.warning( - "Error fetching folders for connector %s", - cid, - exc_info=True, - ) - parent_folders[cid] = [] - - context: dict[str, Any] = { - "accounts": accounts, - "parent_folders": parent_folders, - } - - result = request_approval( - action_type="onedrive_file_creation", - tool_name="create_onedrive_file", - params={ - "name": name, - "content": content, - "connector_id": None, - "parent_folder_id": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_name = result.params.get("name", name) - final_content = result.params.get("content", content) - final_connector_id = result.params.get("connector_id") - final_parent_folder_id = result.params.get("parent_folder_id") - - if not final_name or not final_name.strip(): - return {"status": "error", "message": "File name cannot be empty."} - - final_name = _ensure_docx_extension(final_name) - - if final_connector_id is not None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, - ) - ) - connector = result.scalars().first() - else: - connector = connectors[0] - - if not connector: - return { - "status": "error", - "message": "Selected OneDrive connector is invalid.", - } - - docx_bytes = _markdown_to_docx(final_content or "") - - client = OneDriveClient(session=db_session, connector_id=connector.id) - created = await client.create_file( - name=final_name, - parent_id=final_parent_folder_id, - content=docx_bytes, - mime_type=DOCX_MIME, - ) - - logger.info( - f"OneDrive file created: id={created.get('id')}, name={created.get('name')}" - ) - - kb_message_suffix = "" - try: - from app.services.onedrive import OneDriveKBSyncService - - kb_service = OneDriveKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - file_id=created.get("id"), - file_name=created.get("name", final_name), - mime_type=DOCX_MIME, - web_url=created.get("webUrl"), - content=final_content, - connector_id=connector.id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." 
- ) - else: - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - - return { - "status": "success", - "file_id": created.get("id"), - "name": created.get("name"), - "web_url": created.get("webUrl"), - "message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}", - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error(f"Error creating OneDrive file: {e}", exc_info=True) - return { - "status": "error", - "message": "Something went wrong while creating the file. Please try again.", - } - - return create_onedrive_file diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py deleted file mode 100644 index 4857ea988..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py +++ /dev/null @@ -1,305 +0,0 @@ -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy import String, and_, cast, func -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select - -from app.agents.new_chat.tools.hitl import request_approval -from app.connectors.onedrive.client import OneDriveClient -from app.db import ( - Document, - DocumentType, - SearchSourceConnector, - SearchSourceConnectorType, - async_session_maker, -) - -logger = logging.getLogger(__name__) - - -def create_delete_onedrive_file_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the delete_onedrive_file tool. 
- - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured delete_onedrive_file tool - """ - del db_session # per-call session — see docstring - - @tool - async def delete_onedrive_file( - file_name: str, - delete_from_kb: bool = False, - ) -> dict[str, Any]: - """Move a OneDrive file to the recycle bin. - - Use this tool when the user explicitly asks to delete, remove, or trash - a file in OneDrive. - - Args: - file_name: The exact name of the file to trash. - delete_from_kb: Whether to also remove the file from the knowledge base. - Default is False. - Set to True to remove from both OneDrive and knowledge base. - - Returns: - Dictionary with: - - status: "success", "rejected", "not_found", or "error" - - file_id: OneDrive file ID (if success) - - deleted_from_kb: whether the document was removed from the knowledge base - - message: Result message - - IMPORTANT: - - If status is "rejected", the user explicitly declined. Respond with a brief - acknowledgment and do NOT retry or suggest alternatives. - - If status is "not_found", relay the exact message to the user and ask them - to verify the file name or check if it has been indexed. 
- """ - logger.info( - f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" - ) - - if search_space_id is None or user_id is None: - return { - "status": "error", - "message": "OneDrive tool not properly configured.", - } - - try: - async with async_session_maker() as db_session: - doc_result = await db_session.execute( - select(Document) - .join( - SearchSourceConnector, - Document.connector_id == SearchSourceConnector.id, - ) - .filter( - and_( - Document.search_space_id == search_space_id, - Document.document_type == DocumentType.ONEDRIVE_FILE, - func.lower(Document.title) == func.lower(file_name), - SearchSourceConnector.user_id == user_id, - ) - ) - .order_by(Document.updated_at.desc().nullslast()) - .limit(1) - ) - document = doc_result.scalars().first() - - if not document: - doc_result = await db_session.execute( - select(Document) - .join( - SearchSourceConnector, - Document.connector_id == SearchSourceConnector.id, - ) - .filter( - and_( - Document.search_space_id == search_space_id, - Document.document_type == DocumentType.ONEDRIVE_FILE, - func.lower( - cast( - Document.document_metadata[ - "onedrive_file_name" - ], - String, - ) - ) - == func.lower(file_name), - SearchSourceConnector.user_id == user_id, - ) - ) - .order_by(Document.updated_at.desc().nullslast()) - .limit(1) - ) - document = doc_result.scalars().first() - - if not document: - return { - "status": "not_found", - "message": ( - f"File '{file_name}' not found in your indexed OneDrive files. " - "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, " - "or (3) the file name is different." - ), - } - - if not document.connector_id: - return { - "status": "error", - "message": "Document has no associated connector.", - } - - meta = document.document_metadata or {} - file_id = meta.get("onedrive_file_id") - document_id = document.id - - if not file_id: - return { - "status": "error", - "message": "File ID is missing. 
Please re-index the file.", - } - - conn_result = await db_session.execute( - select(SearchSourceConnector).filter( - and_( - SearchSourceConnector.id == document.connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, - ) - ) - ) - connector = conn_result.scalars().first() - if not connector: - return { - "status": "error", - "message": "OneDrive connector not found or access denied.", - } - - cfg = connector.config or {} - if cfg.get("auth_expired"): - return { - "status": "auth_error", - "message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "onedrive", - } - - context = { - "file": { - "file_id": file_id, - "name": file_name, - "document_id": document_id, - "web_url": meta.get("web_url"), - }, - "account": { - "id": connector.id, - "name": connector.name, - "user_email": cfg.get("user_email"), - }, - } - - result = request_approval( - action_type="onedrive_file_trash", - tool_name="delete_onedrive_file", - params={ - "file_id": file_id, - "connector_id": connector.id, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_file_id = result.params.get("file_id", file_id) - final_connector_id = result.params.get("connector_id", connector.id) - final_delete_from_kb = result.params.get( - "delete_from_kb", delete_from_kb - ) - - if final_connector_id != connector.id: - result = await db_session.execute( - select(SearchSourceConnector).filter( - and_( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id - == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, - ) - ) - ) - validated_connector = result.scalars().first() - if not validated_connector: - return { - "status": "error", - "message": "Selected OneDrive connector is invalid or has been disconnected.", - } - actual_connector_id = validated_connector.id - else: - actual_connector_id = connector.id - - logger.info( - f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}" - ) - - client = OneDriveClient( - session=db_session, connector_id=actual_connector_id - ) - await client.trash_file(final_file_id) - - logger.info( - f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}" - ) - - trash_result: dict[str, Any] = { - "status": "success", - "file_id": final_file_id, - "message": f"Successfully moved '{file_name}' to the recycle bin.", - } - - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - doc = doc_result.scalars().first() - if doc: - await db_session.delete(doc) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - 
trash_result["warning"] = ( - f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}" - ) - - trash_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - trash_result["message"] = ( - f"{trash_result.get('message', '')} (also removed from knowledge base)" - ) - - return trash_result - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error(f"Error deleting OneDrive file: {e}", exc_info=True) - return { - "status": "error", - "message": "Something went wrong while trashing the file. Please try again.", - } - - return delete_onedrive_file diff --git a/surfsense_backend/app/agents/new_chat/tools/podcast.py b/surfsense_backend/app/agents/new_chat/tools/podcast.py deleted file mode 100644 index 83ac98768..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/podcast.py +++ /dev/null @@ -1,160 +0,0 @@ -""" -Podcast generation tool for the SurfSense agent. - -This module provides a factory function for creating the generate_podcast tool -that submits a Celery task for background podcast generation. The tool then -polls the podcast row until it reaches a terminal status (READY/FAILED) and -returns that status. The wait is bounded by the chat's HTTP / process -lifetime; see app.agents.shared.deliverable_wait for details. -""" - -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.shared.deliverable_wait import wait_for_deliverable -from app.db import Podcast, PodcastStatus, shielded_async_session - -logger = logging.getLogger(__name__) - - -def create_generate_podcast_tool( - search_space_id: int, - db_session: AsyncSession, - thread_id: int | None = None, -): - """ - Factory function to create the generate_podcast tool with injected dependencies. - - Pre-creates podcast record with pending status so podcast_id is available - immediately for frontend polling. 
- - Args: - search_space_id: The user's search space ID - db_session: Reserved for future read-side use; the row is written via a - fresh, tool-local session so parallel tool calls (e.g. podcast + - video presentation in the same agent step) don't share an - ``AsyncSession`` (which is not concurrency-safe). - thread_id: The chat thread ID for associating the podcast - - Returns: - A configured tool function for generating podcasts - """ - del db_session # writes use a fresh tool-local session, see below - - @tool - async def generate_podcast( - source_content: str, - podcast_title: str = "SurfSense Podcast", - user_prompt: str | None = None, - ) -> dict[str, Any]: - """ - Generate a podcast from the provided content. - - Use this tool when the user asks to create, generate, or make a podcast. - Common triggers include phrases like: - - "Give me a podcast about this" - - "Create a podcast from this conversation" - - "Generate a podcast summary" - - "Make a podcast about..." - - "Turn this into a podcast" - - Args: - source_content: The text content to convert into a podcast. - podcast_title: Title for the podcast (default: "SurfSense Podcast") - user_prompt: Optional instructions for podcast style, tone, or format. - - Returns: - A dictionary containing: - - status: PodcastStatus value (pending, generating, or failed) - - podcast_id: The podcast ID for polling (when status is pending or generating) - - title: The podcast title - - message: Status message (or "error" field if status is failed) - """ - try: - # Open a fresh session per call. The streaming task's session is - # shared between every tool, and ``AsyncSession`` is NOT safe for - # concurrent use: when the LLM emits parallel tool calls, two - # concurrent ``add()`` / ``commit()`` paths interleave and the - # second one hits "Session.add() during flush" → the transaction - # is poisoned for both tools. 
- async with shielded_async_session() as session: - podcast = Podcast( - title=podcast_title, - status=PodcastStatus.PENDING, - search_space_id=search_space_id, - thread_id=thread_id, - ) - session.add(podcast) - await session.commit() - await session.refresh(podcast) - podcast_id = podcast.id - - from app.tasks.celery_tasks.podcast_tasks import ( - generate_content_podcast_task, - ) - - task = generate_content_podcast_task.delay( - podcast_id=podcast_id, - source_content=source_content, - search_space_id=search_space_id, - user_prompt=user_prompt, - ) - - logger.info( - "[generate_podcast] Created podcast %s, task: %s", - podcast_id, - task.id, - ) - - # Wait until the Celery worker flips the row to a terminal - # state. No internal budget — see deliverable_wait module. - terminal_status, columns, elapsed = await wait_for_deliverable( - model=Podcast, - row_id=podcast_id, - columns=[Podcast.status, Podcast.file_location], - terminal_statuses={PodcastStatus.READY, PodcastStatus.FAILED}, - ) - - if terminal_status == PodcastStatus.READY: - file_location = columns[1] if columns else None - logger.info( - "[generate_podcast] Podcast %s READY in %.2fs (file=%s)", - podcast_id, - elapsed, - file_location, - ) - return { - "status": PodcastStatus.READY.value, - "podcast_id": podcast_id, - "title": podcast_title, - "file_location": file_location, - "message": ("Podcast generated and saved to your podcast panel."), - } - - # Only other terminal state is FAILED. 
- logger.warning( - "[generate_podcast] Podcast %s FAILED in %.2fs", - podcast_id, - elapsed, - ) - return { - "status": PodcastStatus.FAILED.value, - "podcast_id": podcast_id, - "title": podcast_title, - "error": ("Background worker reported FAILED status for this podcast."), - } - - except Exception as e: - error_message = str(e) - logger.exception("[generate_podcast] Error: %s", error_message) - return { - "status": PodcastStatus.FAILED.value, - "error": error_message, - "title": podcast_title, - "podcast_id": None, - } - - return generate_podcast diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py deleted file mode 100644 index 6f011e372..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ /dev/null @@ -1,962 +0,0 @@ -"""Tools registry for SurfSense deep agent. - -This module provides a registry pattern for managing tools in the SurfSense agent. -It makes it easy for OSS contributors to add new tools by: -1. Creating a tool factory function in a new file in this directory -2. Registering the tool in the BUILTIN_TOOLS list below - -Example of adding a new tool: ------------------------------- -1. Create your tool file (e.g., `tools/my_tool.py`): - - from langchain_core.tools import tool - from sqlalchemy.ext.asyncio import AsyncSession - - def create_my_tool(search_space_id: int, db_session: AsyncSession): - @tool - async def my_tool(param: str) -> dict: - '''My tool description.''' - # Your implementation - return {"result": "success"} - return my_tool - -2. 
Import and register in this file: - - from .my_tool import create_my_tool - - # Add to BUILTIN_TOOLS list: - ToolDefinition( - name="my_tool", - description="Description of what your tool does", - factory=lambda deps: create_my_tool( - search_space_id=deps["search_space_id"], - db_session=deps["db_session"], - ), - requires=["search_space_id", "db_session"], - ), -""" - -import logging -from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - -from langchain_core.tools import BaseTool - -from app.agents.new_chat.middleware.dedup_tool_calls import ( - wrap_dedup_key_by_arg_name, -) -from app.db import ChatVisibility - -from .confluence import ( - create_create_confluence_page_tool, - create_delete_confluence_page_tool, - create_update_confluence_page_tool, -) -from .connected_accounts import create_get_connected_accounts_tool -from .discord import ( - create_list_discord_channels_tool, - create_read_discord_messages_tool, - create_send_discord_message_tool, -) -from .dropbox import ( - create_create_dropbox_file_tool, - create_delete_dropbox_file_tool, -) -from .generate_image import create_generate_image_tool -from .gmail import ( - create_create_gmail_draft_tool, - create_read_gmail_email_tool, - create_search_gmail_tool, - create_send_gmail_email_tool, - create_trash_gmail_email_tool, - create_update_gmail_draft_tool, -) -from .google_calendar import ( - create_create_calendar_event_tool, - create_delete_calendar_event_tool, - create_search_calendar_events_tool, - create_update_calendar_event_tool, -) -from .google_drive import ( - create_create_google_drive_file_tool, - create_delete_google_drive_file_tool, -) -from .luma import ( - create_create_luma_event_tool, - create_list_luma_events_tool, - create_read_luma_event_tool, -) -from .mcp_tool import load_mcp_tools -from .notion import ( - create_create_notion_page_tool, - create_delete_notion_page_tool, - create_update_notion_page_tool, -) -from .onedrive import ( 
- create_create_onedrive_file_tool, - create_delete_onedrive_file_tool, -) -from .podcast import create_generate_podcast_tool -from .report import create_generate_report_tool -from .resume import create_generate_resume_tool -from .scrape_webpage import create_scrape_webpage_tool -from .teams import ( - create_list_teams_channels_tool, - create_read_teams_messages_tool, - create_send_teams_message_tool, -) -from .update_memory import create_update_memory_tool, create_update_team_memory_tool -from .video_presentation import create_generate_video_presentation_tool -from .web_search import create_web_search_tool - -logger = logging.getLogger(__name__) - -# ============================================================================= -# Tool Definition -# ============================================================================= - - -@dataclass -class ToolDefinition: - """Definition of a tool that can be added to the agent. - - Attributes: - name: Unique identifier for the tool - description: Human-readable description of what the tool does - factory: Callable that creates the tool. Receives a dict of dependencies. - requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session") - enabled_by_default: Whether the tool is enabled when no explicit config is provided - required_connector: Searchable type string (e.g. ``"LINEAR_CONNECTOR"``) - that must be in ``available_connectors`` for the tool to be enabled. - dedup_key: Optional callable that maps a tool's ``args`` dict to a - string signature used by :class:`DedupHITLToolCallsMiddleware` - to drop duplicate calls within a single LLM response. - reverse: Optional callable that, given the tool's ``(args, result)``, - returns a ``ReverseDescriptor`` describing the inverse tool - invocation. Consumed by the snapshot/revert pipeline. 
- - """ - - name: str - description: str - factory: Callable[[dict[str, Any]], BaseTool] - requires: list[str] = field(default_factory=list) - enabled_by_default: bool = True - hidden: bool = False - required_connector: str | None = None - dedup_key: Callable[[dict[str, Any]], str] | None = None - reverse: Callable[[dict[str, Any], Any], dict[str, Any]] | None = None - - -# ============================================================================= -# Deferred-import factories -# ============================================================================= -# Used for tools whose impls live under ``multi_agent_chat``. Importing those -# at module-load time would cycle (``multi_agent_chat`` middleware imports -# this registry). The import inside the factory runs only when -# ``build_tools`` is called, by which point ``multi_agent_chat`` is fully -# initialised. - - -def _build_create_automation_tool(deps: dict[str, Any]) -> BaseTool: - from app.agents.multi_agent_chat.main_agent.tools.automation import ( - create_create_automation_tool, - ) - - return create_create_automation_tool( - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - llm=deps["llm"], - ) - - -# ============================================================================= -# Built-in Tools Registry -# ============================================================================= - -# Registry of all built-in tools -# Contributors: Add your new tools here! 
-BUILTIN_TOOLS: list[ToolDefinition] = [ - # Podcast generation tool - ToolDefinition( - name="generate_podcast", - description="Generate an audio podcast from provided content", - factory=lambda deps: create_generate_podcast_tool( - search_space_id=deps["search_space_id"], - db_session=deps["db_session"], - thread_id=deps["thread_id"], - ), - requires=["search_space_id", "db_session", "thread_id"], - ), - # Video presentation generation tool - ToolDefinition( - name="generate_video_presentation", - description="Generate a video presentation with slides and narration from provided content", - factory=lambda deps: create_generate_video_presentation_tool( - search_space_id=deps["search_space_id"], - db_session=deps["db_session"], - thread_id=deps["thread_id"], - ), - requires=["search_space_id", "db_session", "thread_id"], - ), - # Report generation tool (inline, short-lived sessions for DB ops) - # Supports internal KB search via source_strategy so the agent does not - # need a separate search step before generating. 
- ToolDefinition( - name="generate_report", - description="Generate a structured report from provided content and export it", - factory=lambda deps: create_generate_report_tool( - search_space_id=deps["search_space_id"], - thread_id=deps["thread_id"], - connector_service=deps.get("connector_service"), - available_connectors=deps.get("available_connectors"), - available_document_types=deps.get("available_document_types"), - ), - requires=["search_space_id", "thread_id"], - # connector_service, available_connectors, and available_document_types - # are optional — when missing, source_strategy="kb_search" degrades - # gracefully to "provided" - ), - # Resume generation tool (Typst-based, uses rendercv package) - ToolDefinition( - name="generate_resume", - description="Generate a professional resume as a Typst document", - factory=lambda deps: create_generate_resume_tool( - search_space_id=deps["search_space_id"], - thread_id=deps["thread_id"], - ), - requires=["search_space_id", "thread_id"], - ), - # Generate image tool - creates images using AI models (DALL-E, GPT Image, etc.) 
- ToolDefinition( - name="generate_image", - description="Generate images from text descriptions using AI image models", - factory=lambda deps: create_generate_image_tool( - search_space_id=deps["search_space_id"], - db_session=deps["db_session"], - ), - requires=["search_space_id", "db_session"], - ), - # Web scraping tool - extracts content from webpages - ToolDefinition( - name="scrape_webpage", - description="Scrape and extract the main content from a webpage", - factory=lambda deps: create_scrape_webpage_tool( - firecrawl_api_key=deps.get("firecrawl_api_key"), - ), - requires=[], # firecrawl_api_key is optional - ), - # Web search tool — real-time web search via SearXNG + user-configured engines - ToolDefinition( - name="web_search", - description="Search the web for real-time information using configured search engines", - factory=lambda deps: create_web_search_tool( - search_space_id=deps.get("search_space_id"), - available_connectors=deps.get("available_connectors"), - ), - requires=[], - ), - # ========================================================================= - # SERVICE ACCOUNT DISCOVERY - # Generic tool for the LLM to discover connected accounts and resolve - # service-specific identifiers (e.g. Jira cloudId, Slack team, etc.) - # ========================================================================= - ToolDefinition( - name="get_connected_accounts", - description="Discover connected accounts for a service and their metadata", - factory=lambda deps: create_get_connected_accounts_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - ), - # ========================================================================= - # AUTOMATION AUTHORING - single HITL tool. 
The tool takes an NL ``intent`` - # from the main agent, drafts the full AutomationCreate JSON via a focused - # sub-LLM, surfaces it on an approval card, and persists on approval. The - # factory defers its import because the impl lives under ``multi_agent_chat`` - # and that package transitively pulls this registry via middleware; - # deferring to ``build_tools`` call-time breaks the cycle without a - # parallel registry. - # ========================================================================= - ToolDefinition( - name="create_automation", - description="Draft an automation from an NL intent; user approves the card; tool saves", - factory=_build_create_automation_tool, - requires=["search_space_id", "user_id", "llm"], - ), - # ========================================================================= - # MEMORY TOOL - single update_memory, private or team by thread_visibility - # ========================================================================= - ToolDefinition( - name="update_memory", - description="Save important long-term facts, preferences, and instructions to the (personal or team) memory", - factory=lambda deps: ( - create_update_team_memory_tool( - search_space_id=deps["search_space_id"], - db_session=deps["db_session"], - llm=deps.get("llm"), - ) - if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE - else create_update_memory_tool( - user_id=deps["user_id"], - db_session=deps["db_session"], - llm=deps.get("llm"), - ) - ), - requires=[ - "user_id", - "search_space_id", - "db_session", - "thread_visibility", - "llm", - ], - ), - # ========================================================================= - # NOTION TOOLS - create, update, delete pages - # Auto-disabled when no Notion connector is configured (see chat_deepagent.py) - # ========================================================================= - ToolDefinition( - name="create_notion_page", - description="Create a new page in the user's Notion workspace", - factory=lambda 
deps: create_create_notion_page_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="NOTION_CONNECTOR", - dedup_key=wrap_dedup_key_by_arg_name("title"), - ), - ToolDefinition( - name="update_notion_page", - description="Append new content to an existing Notion page", - factory=lambda deps: create_update_notion_page_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="NOTION_CONNECTOR", - dedup_key=wrap_dedup_key_by_arg_name("page_title"), - ), - ToolDefinition( - name="delete_notion_page", - description="Delete an existing Notion page", - factory=lambda deps: create_delete_notion_page_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="NOTION_CONNECTOR", - dedup_key=wrap_dedup_key_by_arg_name("page_title"), - ), - # ========================================================================= - # GOOGLE DRIVE TOOLS - create files, delete files - # Auto-disabled when no Google Drive connector is configured (see chat_deepagent.py) - # ========================================================================= - ToolDefinition( - name="create_google_drive_file", - description="Create a new Google Doc or Google Sheet in Google Drive", - factory=lambda deps: create_create_google_drive_file_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="GOOGLE_DRIVE_FILE", - dedup_key=wrap_dedup_key_by_arg_name("file_name"), - ), - ToolDefinition( - name="delete_google_drive_file", - description="Move an indexed Google Drive file to 
trash", - factory=lambda deps: create_delete_google_drive_file_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="GOOGLE_DRIVE_FILE", - dedup_key=wrap_dedup_key_by_arg_name("file_name"), - ), - # ========================================================================= - # DROPBOX TOOLS - create and trash files - # Auto-disabled when no Dropbox connector is configured (see chat_deepagent.py) - # ========================================================================= - ToolDefinition( - name="create_dropbox_file", - description="Create a new file in Dropbox", - factory=lambda deps: create_create_dropbox_file_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="DROPBOX_FILE", - dedup_key=wrap_dedup_key_by_arg_name("file_name"), - ), - ToolDefinition( - name="delete_dropbox_file", - description="Delete a file from Dropbox", - factory=lambda deps: create_delete_dropbox_file_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="DROPBOX_FILE", - dedup_key=wrap_dedup_key_by_arg_name("file_name"), - ), - # ========================================================================= - # ONEDRIVE TOOLS - create and trash files - # Auto-disabled when no OneDrive connector is configured (see chat_deepagent.py) - # ========================================================================= - ToolDefinition( - name="create_onedrive_file", - description="Create a new file in Microsoft OneDrive", - factory=lambda deps: create_create_onedrive_file_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - 
requires=["db_session", "search_space_id", "user_id"], - required_connector="ONEDRIVE_FILE", - dedup_key=wrap_dedup_key_by_arg_name("file_name"), - ), - ToolDefinition( - name="delete_onedrive_file", - description="Move a OneDrive file to the recycle bin", - factory=lambda deps: create_delete_onedrive_file_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="ONEDRIVE_FILE", - dedup_key=wrap_dedup_key_by_arg_name("file_name"), - ), - # ========================================================================= - # GOOGLE CALENDAR TOOLS - search, create, update, delete events - # Auto-disabled when no Google Calendar connector is configured - # ========================================================================= - ToolDefinition( - name="search_calendar_events", - description="Search Google Calendar events within a date range", - factory=lambda deps: create_search_calendar_events_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="GOOGLE_CALENDAR_CONNECTOR", - ), - ToolDefinition( - name="create_calendar_event", - description="Create a new event on Google Calendar", - factory=lambda deps: create_create_calendar_event_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="GOOGLE_CALENDAR_CONNECTOR", - dedup_key=wrap_dedup_key_by_arg_name("title"), - ), - ToolDefinition( - name="update_calendar_event", - description="Update an existing indexed Google Calendar event", - factory=lambda deps: create_update_calendar_event_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", 
"search_space_id", "user_id"], - required_connector="GOOGLE_CALENDAR_CONNECTOR", - dedup_key=wrap_dedup_key_by_arg_name("event_title_or_id"), - ), - ToolDefinition( - name="delete_calendar_event", - description="Delete an existing indexed Google Calendar event", - factory=lambda deps: create_delete_calendar_event_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="GOOGLE_CALENDAR_CONNECTOR", - dedup_key=wrap_dedup_key_by_arg_name("event_title_or_id"), - ), - # ========================================================================= - # GMAIL TOOLS - search, read, create drafts, update drafts, send, trash - # Auto-disabled when no Gmail connector is configured - # ========================================================================= - ToolDefinition( - name="search_gmail", - description="Search emails in Gmail using Gmail search syntax", - factory=lambda deps: create_search_gmail_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="GOOGLE_GMAIL_CONNECTOR", - ), - ToolDefinition( - name="read_gmail_email", - description="Read the full content of a specific Gmail email", - factory=lambda deps: create_read_gmail_email_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="GOOGLE_GMAIL_CONNECTOR", - ), - ToolDefinition( - name="create_gmail_draft", - description="Create a draft email in Gmail", - factory=lambda deps: create_create_gmail_draft_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - 
required_connector="GOOGLE_GMAIL_CONNECTOR", - dedup_key=wrap_dedup_key_by_arg_name("subject"), - ), - ToolDefinition( - name="send_gmail_email", - description="Send an email via Gmail", - factory=lambda deps: create_send_gmail_email_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="GOOGLE_GMAIL_CONNECTOR", - dedup_key=wrap_dedup_key_by_arg_name("subject"), - ), - ToolDefinition( - name="trash_gmail_email", - description="Move an indexed email to trash in Gmail", - factory=lambda deps: create_trash_gmail_email_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="GOOGLE_GMAIL_CONNECTOR", - dedup_key=wrap_dedup_key_by_arg_name("email_subject_or_id"), - ), - ToolDefinition( - name="update_gmail_draft", - description="Update an existing Gmail draft", - factory=lambda deps: create_update_gmail_draft_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="GOOGLE_GMAIL_CONNECTOR", - dedup_key=wrap_dedup_key_by_arg_name("draft_subject_or_id"), - ), - # ========================================================================= - # CONFLUENCE TOOLS - create, update, delete pages - # Auto-disabled when no Confluence connector is configured (see chat_deepagent.py) - # ========================================================================= - ToolDefinition( - name="create_confluence_page", - description="Create a new page in the user's Confluence space", - factory=lambda deps: create_create_confluence_page_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", 
"user_id"], - required_connector="CONFLUENCE_CONNECTOR", - dedup_key=wrap_dedup_key_by_arg_name("title"), - ), - ToolDefinition( - name="update_confluence_page", - description="Update an existing indexed Confluence page", - factory=lambda deps: create_update_confluence_page_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="CONFLUENCE_CONNECTOR", - dedup_key=wrap_dedup_key_by_arg_name("page_title_or_id"), - ), - ToolDefinition( - name="delete_confluence_page", - description="Delete an existing indexed Confluence page", - factory=lambda deps: create_delete_confluence_page_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="CONFLUENCE_CONNECTOR", - dedup_key=wrap_dedup_key_by_arg_name("page_title_or_id"), - ), - # ========================================================================= - # DISCORD TOOLS - list channels, read messages, send messages - # Auto-disabled when no Discord connector is configured - # ========================================================================= - ToolDefinition( - name="list_discord_channels", - description="List text channels in the connected Discord server", - factory=lambda deps: create_list_discord_channels_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="DISCORD_CONNECTOR", - ), - ToolDefinition( - name="read_discord_messages", - description="Read recent messages from a Discord text channel", - factory=lambda deps: create_read_discord_messages_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", 
"user_id"], - required_connector="DISCORD_CONNECTOR", - ), - ToolDefinition( - name="send_discord_message", - description="Send a message to a Discord text channel", - factory=lambda deps: create_send_discord_message_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="DISCORD_CONNECTOR", - ), - # ========================================================================= - # TEAMS TOOLS - list channels, read messages, send messages - # Auto-disabled when no Teams connector is configured - # ========================================================================= - ToolDefinition( - name="list_teams_channels", - description="List Microsoft Teams and their channels", - factory=lambda deps: create_list_teams_channels_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="TEAMS_CONNECTOR", - ), - ToolDefinition( - name="read_teams_messages", - description="Read recent messages from a Microsoft Teams channel", - factory=lambda deps: create_read_teams_messages_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="TEAMS_CONNECTOR", - ), - ToolDefinition( - name="send_teams_message", - description="Send a message to a Microsoft Teams channel", - factory=lambda deps: create_send_teams_message_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="TEAMS_CONNECTOR", - ), - # ========================================================================= - # LUMA TOOLS - list events, read event details, create events - # Auto-disabled 
when no Luma connector is configured - # ========================================================================= - ToolDefinition( - name="list_luma_events", - description="List upcoming and recent Luma events", - factory=lambda deps: create_list_luma_events_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="LUMA_CONNECTOR", - ), - ToolDefinition( - name="read_luma_event", - description="Read detailed information about a specific Luma event", - factory=lambda deps: create_read_luma_event_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="LUMA_CONNECTOR", - ), - ToolDefinition( - name="create_luma_event", - description="Create a new event on Luma", - factory=lambda deps: create_create_luma_event_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - required_connector="LUMA_CONNECTOR", - ), -] - - -# ============================================================================= -# Registry Functions -# ============================================================================= - - -def get_tool_by_name(name: str) -> ToolDefinition | None: - """Get a tool definition by its name.""" - for tool_def in BUILTIN_TOOLS: - if tool_def.name == name: - return tool_def - return None - - -def get_connector_gated_tools( - available_connectors: list[str] | None, -) -> list[str]: - """Return tool names to disable""" - available = set() if available_connectors is None else set(available_connectors) - - disabled: list[str] = [] - for tool_def in BUILTIN_TOOLS: - if tool_def.required_connector and tool_def.required_connector not in available: - disabled.append(tool_def.name) - return 
disabled - - -def get_all_tool_names() -> list[str]: - """Get names of all registered tools.""" - return [tool_def.name for tool_def in BUILTIN_TOOLS] - - -def get_default_enabled_tools() -> list[str]: - """Get names of tools that are enabled by default (excludes hidden tools).""" - return [tool_def.name for tool_def in BUILTIN_TOOLS if tool_def.enabled_by_default] - - -def build_tools( - dependencies: dict[str, Any], - enabled_tools: list[str] | None = None, - disabled_tools: list[str] | None = None, - additional_tools: list[BaseTool] | None = None, -) -> list[BaseTool]: - """Build the list of tools for the agent. - - Args: - dependencies: Dict containing all possible dependencies: - - search_space_id: The search space ID - - db_session: Database session - - connector_service: Connector service instance - - firecrawl_api_key: Optional Firecrawl API key - enabled_tools: Explicit list of tool names to enable. If None, uses defaults. - disabled_tools: List of tool names to disable (applied after enabled_tools). - additional_tools: Extra tools to add (e.g., custom tools not in registry). - - Returns: - List of configured tool instances ready for the agent. 
- - Example: - # Use all default tools - tools = build_tools(deps) - - # Use only specific tools - tools = build_tools(deps, enabled_tools=["generate_report"]) - - # Use defaults but disable podcast - tools = build_tools(deps, disabled_tools=["generate_podcast"]) - - # Add custom tools - tools = build_tools(deps, additional_tools=[my_custom_tool]) - - """ - # Determine which tools to enable - if enabled_tools is not None: - tool_names_to_use = set(enabled_tools) - else: - tool_names_to_use = set(get_default_enabled_tools()) - - # Apply disabled list - if disabled_tools: - tool_names_to_use -= set(disabled_tools) - - # Build the tools (skip hidden/WIP tools unconditionally) - tools: list[BaseTool] = [] - for tool_def in BUILTIN_TOOLS: - if tool_def.hidden or tool_def.name not in tool_names_to_use: - continue - - # Check that all required dependencies are provided - missing_deps = [dep for dep in tool_def.requires if dep not in dependencies] - if missing_deps: - msg = f"Tool '{tool_def.name}' requires dependencies: {missing_deps}" - raise ValueError( - msg, - ) - - # Create the tool - tool = tool_def.factory(dependencies) - # Propagate the registry-level metadata so middleware (e.g. - # ``DedupHITLToolCallsMiddleware``) and the action-log/revert - # pipeline can pick the resolvers up via ``tool.metadata`` without - # re-importing :data:`BUILTIN_TOOLS`. 
- if tool_def.dedup_key is not None or tool_def.reverse is not None: - existing_meta = getattr(tool, "metadata", None) or {} - merged_meta = dict(existing_meta) - if tool_def.dedup_key is not None: - merged_meta.setdefault("dedup_key", tool_def.dedup_key) - if tool_def.reverse is not None: - merged_meta.setdefault("reverse", tool_def.reverse) - try: - tool.metadata = merged_meta - except Exception: - logger.debug( - "Tool %s rejected metadata mutation; relying on registry lookup", - tool_def.name, - ) - tools.append(tool) - - # Add any additional custom tools - if additional_tools: - tools.extend(additional_tools) - - return tools - - -async def build_tools_async( - dependencies: dict[str, Any], - enabled_tools: list[str] | None = None, - disabled_tools: list[str] | None = None, - additional_tools: list[BaseTool] | None = None, - include_mcp_tools: bool = True, -) -> list[BaseTool]: - """Async version of build_tools that also loads MCP tools from database. - - Design Note: - This function exists because MCP tools require database queries to load - user configs, while built-in tools are created synchronously from static - code. - - Alternative: We could make build_tools() itself async and always query - the database, but that would force async everywhere even when only using - built-in tools. The current design keeps the simple case (static tools - only) synchronous while supporting dynamic database-loaded tools through - this async wrapper. - - Phase 1.3: built-in tool construction (CPU; runs in a thread pool to - avoid event-loop stalls) and MCP tool loading (HTTP/DB I/O; runs on - the event loop) are kicked off concurrently. Cold-path savings are - bounded by the slower of the two — typically MCP at ~200ms-1.7s — - so the parallelization recovers the ~50-200ms previously spent - serially on built-in construction. - - Args: - dependencies: Dict containing all possible dependencies - enabled_tools: Explicit list of tool names to enable. If None, uses defaults. 
- disabled_tools: List of tool names to disable (applied after enabled_tools). - additional_tools: Extra tools to add (e.g., custom tools not in registry). - include_mcp_tools: Whether to load user's MCP tools from database. - - Returns: - List of configured tool instances ready for the agent, including MCP tools. - - """ - import asyncio - import time - - _perf_log = logging.getLogger("surfsense.perf") - _perf_log.setLevel(logging.DEBUG) - - can_load_mcp = ( - include_mcp_tools - and "db_session" in dependencies - and "search_space_id" in dependencies - ) - - # Built-in tool construction is synchronous + CPU-only. Off-loop it so - # MCP's HTTP/DB I/O can fire concurrently. ``build_tools`` is pure - # function over its inputs — safe to thread-shift. - _t0 = time.perf_counter() - builtin_task = asyncio.create_task( - asyncio.to_thread( - build_tools, dependencies, enabled_tools, disabled_tools, additional_tools - ) - ) - - mcp_task: asyncio.Task | None = None - if can_load_mcp: - mcp_task = asyncio.create_task( - load_mcp_tools( - dependencies["db_session"], - dependencies["search_space_id"], - ) - ) - - # Surface failures from each task independently so a flaky MCP - # endpoint never poisons built-in tool registration. ``return_exceptions`` - # gives us per-task exceptions instead of dropping the second result - # when the first raises. 
- if mcp_task is not None: - builtin_result, mcp_result = await asyncio.gather( - builtin_task, mcp_task, return_exceptions=True - ) - else: - builtin_result = await builtin_task - mcp_result = None - - if isinstance(builtin_result, BaseException): - raise builtin_result # built-in registration failure is non-recoverable - tools: list[BaseTool] = builtin_result - _perf_log.info( - "[build_tools_async] Built-in tools in %.3fs (%d tools, parallel)", - time.perf_counter() - _t0, - len(tools), - ) - - if mcp_task is not None: - if isinstance(mcp_result, BaseException): - # ``return_exceptions=True`` captures the exception out-of-band, - # so ``sys.exc_info()`` is empty here. Pass the captured - # exception via ``exc_info=`` to get a real traceback. - logging.error( - "Failed to load MCP tools: %s", mcp_result, exc_info=mcp_result - ) - else: - mcp_tools = mcp_result or [] - _perf_log.info( - "[build_tools_async] MCP tools loaded in %.3fs (%d tools, parallel)", - time.perf_counter() - _t0, - len(mcp_tools), - ) - tools.extend(mcp_tools) - logging.info( - "Registered %d MCP tools: %s", - len(mcp_tools), - [t.name for t in mcp_tools], - ) - - logging.info( - "Total tools for agent: %d — %s", - len(tools), - [t.name for t in tools], - ) - - return tools diff --git a/surfsense_backend/app/agents/new_chat/tools/report.py b/surfsense_backend/app/agents/new_chat/tools/report.py deleted file mode 100644 index 8c0bd95ea..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/report.py +++ /dev/null @@ -1,1084 +0,0 @@ -""" -Report generation tool for the SurfSense agent. - -This module provides a factory function for creating the generate_report tool -that generates a structured Markdown report inline (no Celery). The LLM is -called within the tool, the result is saved to the database, and the tool -returns immediately with a ready status. - -Uses short-lived database sessions to avoid holding connections during long -LLM calls (30-120+ seconds). 
Each DB operation (read config, save report) -opens and closes its own session, ensuring no connection is held idle during -the LLM API call. - -Generation strategies: - - Single-shot generation for all new reports - - Section-level revision for targeted edits (preserves unchanged sections) - - Full-document revision as fallback for global changes - -Source strategies (how source content is collected): - - "provided" — Use only the supplied source_content (default, backward-compat) - - "conversation" — Same as "provided"; agent passes conversation summary - - "kb_search" — Tool searches knowledge base internally with targeted queries - - "auto" — Use source_content if sufficient, else search KB as fallback -""" - -import asyncio -import json -import logging -import re -from typing import Any - -from langchain_core.callbacks import dispatch_custom_event -from langchain_core.messages import HumanMessage -from langchain_core.tools import tool - -from app.db import Report, shielded_async_session -from app.services.connector_service import ConnectorService -from app.services.llm_service import get_agent_llm - -logger = logging.getLogger(__name__) - -# ─── Shared Formatting Rules ──────────────────────────────────────────────── -# Reusable formatting instructions appended to section-level and review prompts. - -_FORMATTING_RULES = """\ -- IMPORTANT: Output raw Markdown directly. Do NOT wrap the entire output in a \ -code fence (e.g. ```markdown, ````markdown, or any backtick fence). Individual \ -code examples and diagrams inside the report should still use fenced code blocks, \ -but the report itself must NOT be enclosed in one. -- Maintain proper Markdown formatting throughout. -- When including code examples, ALWAYS format them as proper fenced code blocks \ -with the correct language identifier (e.g. ```java, ```python). Code inside code \ -blocks MUST have proper line breaks and indentation — NEVER put multiple statements \ -on a single line. 
Each statement, brace, and logical block must be on its own line \ -with correct indentation. -- When including Mermaid diagrams, use ```mermaid fenced code blocks. Each Mermaid \ -statement MUST be on its own line — NEVER use semicolons to join multiple statements \ -on one line. For line breaks inside node labels, use
(NOT
). -- When including mathematical formulas or equations, ALWAYS use LaTeX notation. \ -NEVER use backtick code spans or Unicode symbols for math.""" - -# ─── Standard Report Footer ───────────────────────────────────────────────── -# Appended to every generated report after content generation. - -_REPORT_FOOTER = "Powered by SurfSense AI." - -# ─── Prompt: Single-Shot Report Generation ─────────────────────────────────── - -_REPORT_PROMPT = """You are an expert report writer. Generate a comprehensive Markdown report. - -**Topic:** {topic} -**Report Style:** {report_style} -{user_instructions_section} -{previous_version_section} - -**Source Content:** -{source_content} - ---- - -{length_instruction} - -Write a well-structured Markdown report with a # title, executive summary, organized sections, and conclusion. Cite facts from the source content. Be thorough and professional. - -{formatting_rules} -""" - -# ─── Prompt: Full-Document Revision (fallback when section-level fails) ────── - -_REVISION_PROMPT = """You are an expert report editor. Apply ONLY the requested changes — do NOT rewrite from scratch. - -**Topic:** {topic} -**Report Style:** {report_style} -**Modification Instructions:** {user_instructions_section} - -**Source Content (use if relevant):** -{source_content} - ---- - -**EXISTING REPORT:** - -{previous_report_content} - ---- - -{length_instruction} - -Preserve all structure and content not affected by the modification. - -{formatting_rules} -""" - -# ─── Prompt: Section-Level Revision — Identify Affected Sections ───────────── - -_IDENTIFY_SECTIONS_PROMPT = """You are analyzing a Markdown report to determine which sections need modification based on the user's request. - -**User's Modification Request:** {user_instructions} - -**Report Sections (indexed starting at 0):** -{sections_listing} - ---- - -Determine which sections need to be modified, added, or removed to fulfill the user's request. 
- -Return ONLY a JSON object with these fields: -- "modify": Array of section indices (0-based) that need content changes -- "add": Array of objects like {{"after_index": 2, "heading": "## New Section Title", "description": "What this section should cover"}} for new sections to insert -- "remove": Array of section indices to remove entirely (use sparingly) -- "reasoning": A brief explanation of your decisions - -Guidelines: -- If the change is GLOBAL (e.g., "change the tone", "make the whole report shorter", "translate to Spanish"), include ALL section indices in "modify". -- If the change is TARGETED (e.g., "expand the budget section", "fix the conclusion"), include ONLY the affected section indices. -- For "add a section about X", use the "add" field with the appropriate insertion point. -- Prefer modifying over removing+adding when possible. - -Return ONLY valid JSON, no markdown fences: -""" - -# ─── Prompt: Section-Level Revision — Revise a Single Section ──────────────── - -_REVISE_SECTION_PROMPT = """Revise ONLY this section based on the instructions. If the instructions don't apply, return it UNCHANGED. - -**Modification Instructions:** {user_instructions} - -**Current Section:** -{section_content} - -**Context (surrounding sections — for coherence only, do NOT output them):** -{context_sections} - -**Source Content:** -{source_content} - ---- - -Keep the same heading and heading level. Preserve content not affected by the modification. -{formatting_rules} -""" - -# ─── Prompt: New Section Generation (for section-level add) ───────────────── - -_NEW_SECTION_PROMPT = """You are an expert report writer. Write a new section to be inserted into an existing report. - -**Report Topic:** {topic} -**Report Style:** {report_style} -**Section Heading:** {heading} -**Section Goal:** {description} -**User Instructions:** {user_instructions} - -**Surrounding Context:** -{context_sections} - -**Source Content:** -{source_content} - ---- - -**Rules:** -1. 
Write ONLY this section, starting with the heading "{heading}". -2. Ensure the section flows naturally with the surrounding context. -3. Be comprehensive — cover the topic described above. -{formatting_rules} - -Write the new section now: -""" - - -# ─── Utility Functions ────────────────────────────────────────────────────── - - -def _strip_wrapping_code_fences(text: str) -> str: - """Remove wrapping code fences that LLMs often add around Markdown output. - - Handles patterns like: - ```markdown\\n...content...\\n``` - ````markdown\\n...content...\\n```` - ```md\\n...content...\\n``` - ```\\n...content...\\n``` - ```json\\n...content...\\n``` - Supports 3 or more backticks (LLMs escalate when content has triple-backtick blocks). - """ - stripped = text.strip() - # Match opening fence with 3+ backticks and optional language tag - m = re.match(r"^(`{3,})(?:markdown|md|json)?\s*\n", stripped) - if m: - fence = m.group(1) # e.g. "```" or "````" - if stripped.endswith(fence): - stripped = stripped[m.end() :] # remove opening fence - stripped = stripped[: -len(fence)].rstrip() # remove closing fence - return stripped - - -def _extract_metadata(content: str) -> dict[str, Any]: - """Extract metadata from generated Markdown content.""" - # Count section headings - headings = re.findall(r"^(#{1,6})\s+(.+)$", content, re.MULTILINE) - - # Word count - word_count = len(content.split()) - - # Character count - char_count = len(content) - - return { - "status": "ready", - "word_count": word_count, - "char_count": char_count, - "section_count": len(headings), - } - - -def _parse_sections(content: str) -> list[dict[str, str]]: - """Parse Markdown content into sections split by # and ## headings. - - Returns a list of dicts: [{"heading": "## Title", "body": "content..."}, ...] - Content before the first heading is captured with heading="". - ### and deeper headings are kept inside their parent ## section's body. 
- """ - lines = content.split("\n") - sections: list[dict[str, str]] = [] - current_heading = "" - current_body_lines: list[str] = [] - in_code_block = False - - for line in lines: - # Track code blocks to avoid matching headings inside them - stripped = line.strip() - if stripped.startswith("```"): - in_code_block = not in_code_block - - # Only split on # or ## headings (not ### or deeper) and only outside code blocks - is_section_heading = ( - not in_code_block - and re.match(r"^#{1,2}\s+", line) - and not re.match(r"^#{3,}\s+", line) - ) - - if is_section_heading: - # Save previous section - if current_heading or current_body_lines: - sections.append( - { - "heading": current_heading, - "body": "\n".join(current_body_lines).strip(), - } - ) - current_heading = line.strip() - current_body_lines = [] - else: - current_body_lines.append(line) - - # Save last section - if current_heading or current_body_lines: - sections.append( - { - "heading": current_heading, - "body": "\n".join(current_body_lines).strip(), - } - ) - - return sections - - -def _stitch_sections(sections: list[dict[str, str]]) -> str: - """Stitch parsed sections back into a single Markdown string.""" - parts = [] - for section in sections: - if section["heading"]: - parts.append(section["heading"]) - if section["body"]: - parts.append(section["body"]) - return "\n\n".join(parts) - - -# ─── Async Generation Helpers ─────────────────────────────────────────────── - - -async def _revise_with_sections( - llm: Any, - parent_content: str, - user_instructions: str, - source_content: str, - topic: str, - report_style: str, -) -> str | None: - """Section-level revision: identify affected sections and revise only those. - - Unchanged sections are kept byte-for-byte identical. - Returns the revised content, or None to trigger full-document revision fallback. 
- """ - # Parse report into sections - sections = _parse_sections(parent_content) - if len(sections) < 2: - logger.info( - "[generate_report] Too few sections for section-level revision, using full revision" - ) - return None - - # Build a sections listing for the LLM - sections_listing = "" - for i, sec in enumerate(sections): - heading = sec["heading"] or "(preamble — content before first heading)" - body_preview = ( - sec["body"][:200] + "..." if len(sec["body"]) > 200 else sec["body"] - ) - sections_listing += f"\n[{i}] {heading}\n Preview: {body_preview}\n" - - # Step 1: Ask LLM which sections need modification - identify_prompt = _IDENTIFY_SECTIONS_PROMPT.format( - user_instructions=user_instructions, - sections_listing=sections_listing, - ) - - try: - response = await llm.ainvoke([HumanMessage(content=identify_prompt)]) - raw = response.content - if not raw or not isinstance(raw, str): - return None - - raw = _strip_wrapping_code_fences(raw).strip() - json_match = re.search(r"\{[\s\S]*\}", raw) - if json_match: - raw = json_match.group(0) - - plan = json.loads(raw) - modify_indices: list[int] = plan.get("modify", []) - add_sections: list[dict[str, Any]] = plan.get("add", []) - remove_indices: list[int] = plan.get("remove", []) - reasoning = plan.get("reasoning", "") - - logger.info( - f"[generate_report] Section-level revision plan: " - f"modify={modify_indices}, add={len(add_sections)}, " - f"remove={remove_indices}, reasoning={reasoning}" - ) - except Exception: - logger.warning( - "[generate_report] Failed to identify sections for revision, " - "falling back to full revision", - exc_info=True, - ) - return None - - # If ALL sections need modification, full revision is more efficient and coherent - if len(modify_indices) >= len(sections): - logger.info( - "[generate_report] All sections need modification, deferring to full revision" - ) - return None - - # Compute total operations for progress tracking - total_ops = len(modify_indices) + len(add_sections) 
- current_op = 0 - - # Emit plan summary - parts = [] - if modify_indices: - parts.append( - f"modifying {len(modify_indices)} section{'s' if len(modify_indices) > 1 else ''}" - ) - if add_sections: - parts.append( - f"adding {len(add_sections)} new section{'s' if len(add_sections) > 1 else ''}" - ) - if remove_indices: - parts.append( - f"removing {len(remove_indices)} section{'s' if len(remove_indices) > 1 else ''}" - ) - plan_summary = ", ".join(parts) if parts else "no changes needed" - - dispatch_custom_event( - "report_progress", - { - "phase": "revision_plan", - "message": plan_summary.capitalize(), - "modify_count": len(modify_indices), - "add_count": len(add_sections), - "remove_count": len(remove_indices), - "total_ops": total_ops, - }, - ) - - # Step 2: Revise only the affected sections - revised_sections = list(sections) # shallow copy — unmodified sections stay as-is - - for idx in modify_indices: - if idx < 0 or idx >= len(sections): - continue - - current_op += 1 - sec = sections[idx] - - # Extract plain section name (strip markdown heading markers) - section_name = ( - re.sub(r"^#+\s*", "", sec["heading"]).strip() - if sec["heading"] - else "Preamble" - ) - dispatch_custom_event( - "report_progress", - { - "phase": "revising_section", - "message": f"Revising: {section_name} ({current_op}/{total_ops})...", - }, - ) - - section_content = ( - f"{sec['heading']}\n\n{sec['body']}" if sec["heading"] else sec["body"] - ) - - # Build context from surrounding sections - context_parts = [] - if idx > 0: - prev = sections[idx - 1] - prev_preview = prev["body"][:300] + ( - "..." if len(prev["body"]) > 300 else "" - ) - context_parts.append( - f"**Previous section:** {prev['heading']}\n{prev_preview}" - ) - if idx < len(sections) - 1: - nxt = sections[idx + 1] - nxt_preview = nxt["body"][:300] + ("..." 
if len(nxt["body"]) > 300 else "") - context_parts.append(f"**Next section:** {nxt['heading']}\n{nxt_preview}") - context = ( - "\n\n".join(context_parts) if context_parts else "(No surrounding sections)" - ) - - revise_prompt = _REVISE_SECTION_PROMPT.format( - user_instructions=user_instructions, - section_content=section_content, - context_sections=context, - source_content=source_content[:40000], - formatting_rules=_FORMATTING_RULES, - ) - - resp = await llm.ainvoke([HumanMessage(content=revise_prompt)]) - revised_text = resp.content - if revised_text and isinstance(revised_text, str): - revised_text = _strip_wrapping_code_fences(revised_text).strip() - # Parse the LLM output back into heading + body - revised_parsed = _parse_sections(revised_text) - if revised_parsed: - revised_sections[idx] = revised_parsed[0] - else: - revised_sections[idx] = { - "heading": sec["heading"], - "body": revised_text, - } - - logger.info(f"[generate_report] Revised section [{idx}]: {sec['heading']}") - - # Step 3: Handle new section additions (insert in reverse order to preserve indices) - for add_info in sorted( - add_sections, - key=lambda x: x.get("after_index", len(revised_sections) - 1), - reverse=True, - ): - current_op += 1 - after_idx = add_info.get("after_index", len(revised_sections) - 1) - heading = add_info.get("heading", "## New Section") - description = add_info.get("description", "") - - # Extract plain section name for progress display - plain_heading = re.sub(r"^#+\s*", "", heading).strip() - dispatch_custom_event( - "report_progress", - { - "phase": "adding_section", - "message": f"Adding: {plain_heading} ({current_op}/{total_ops})...", - }, - ) - - # Build context from the surrounding sections at the insertion point - ctx_parts = [] - if 0 <= after_idx < len(revised_sections): - before_sec = revised_sections[after_idx] - ctx_parts.append( - f"**Section before:** {before_sec['heading']}\n{before_sec['body'][:300]}" - ) - insert_idx = min(after_idx + 1, 
len(revised_sections)) - if insert_idx < len(revised_sections): - after_sec = revised_sections[insert_idx] - ctx_parts.append( - f"**Section after:** {after_sec['heading']}\n{after_sec['body'][:300]}" - ) - - new_prompt = _NEW_SECTION_PROMPT.format( - topic=topic, - report_style=report_style, - heading=heading, - description=description, - user_instructions=user_instructions, - context_sections="\n\n".join(ctx_parts) if ctx_parts else "(None)", - source_content=source_content[:30000], - formatting_rules=_FORMATTING_RULES, - ) - - resp = await llm.ainvoke([HumanMessage(content=new_prompt)]) - new_content = resp.content - if new_content and isinstance(new_content, str): - new_content = _strip_wrapping_code_fences(new_content).strip() - new_parsed = _parse_sections(new_content) - if new_parsed: - revised_sections.insert(insert_idx, new_parsed[0]) - else: - revised_sections.insert( - insert_idx, - { - "heading": heading, - "body": new_content, - }, - ) - - logger.info( - f"[generate_report] Added new section after [{after_idx}]: {heading}" - ) - - # Step 4: Handle removals (reverse order to preserve indices) - for idx in sorted(remove_indices, reverse=True): - if 0 <= idx < len(revised_sections): - logger.info( - f"[generate_report] Removed section [{idx}]: " - f"{revised_sections[idx]['heading']}" - ) - revised_sections.pop(idx) - - return _stitch_sections(revised_sections) - - -# ─── Tool Factory ─────────────────────────────────────────────────────────── - - -def create_generate_report_tool( - search_space_id: int, - thread_id: int | None = None, - connector_service: ConnectorService | None = None, - available_connectors: list[str] | None = None, - available_document_types: list[str] | None = None, -): - """ - Factory function to create the generate_report tool with injected dependencies. - - The tool generates a Markdown report inline using the search space's - agent LLM, saves it to the database, and returns immediately. 
- - Uses short-lived database sessions for each DB operation so no connection - is held during the long LLM API call. - - Generation strategies: - - New reports: single-shot generation (1 LLM call) - - Revisions (targeted edits): section-level (unchanged sections preserved) - - Revisions (global changes): full-document revision fallback - - Source strategies: - - "provided"/"conversation": use only the supplied source_content - - "kb_search": search the knowledge base internally using targeted queries - - "auto": use source_content if sufficient, otherwise fall back to KB search - - Args: - search_space_id: The user's search space ID - thread_id: The chat thread ID for associating the report - connector_service: Optional connector service for internal KB search. - When provided, the tool can search the knowledge base internally - (used by the "kb_search" and "auto" source strategies). - available_connectors: Optional list of connector types available in the - search space (used to scope internal KB searches). - - Returns: - A configured tool function for generating reports - """ - - @tool - async def generate_report( - topic: str, - source_content: str = "", - source_strategy: str = "provided", - search_queries: list[str] | None = None, - report_style: str = "detailed", - user_instructions: str | None = None, - parent_report_id: int | None = None, - ) -> dict[str, Any]: - """ - Generate a structured Markdown report artifact from provided content. - - Use this tool when the user asks to create, generate, write, produce, - draft, or summarize into a report-style deliverable. - - Trigger classes include: - - Direct trigger words WITH creation/modification verb: report, - document, memo, letter, template, article, guide, blog post, - one-pager, briefing, comprehensive guide. - - Creation-intent phrases: "write a report", "generate a document", - "draft a summary", "create an executive summary". 
- - Modification-intent phrases: "revise the report", "update the - report", "make it shorter", "add a section about X", "expand the - budget section", "rewrite in formal tone". - - IMPORTANT — what does NOT count as "asking for a report": - - Questions or discussion about a report or its topic are NOT report - requests. Respond to these conversationally in chat. - Examples: "What other examples to put there?", "What else could be - added?", "Can you explain section 2?", "Is the data accurate?", - "What's missing?", "How could this be improved?", "What other - topics are related?" - - Quick summary requests, explanations, or follow-up questions. - - The test: Does the message contain a creation/modification VERB - (write, create, generate, draft, add, revise, update, expand, - rewrite, make) directed at producing a deliverable? If no verb - → answer in chat. - - FORMAT/EXPORT RULE: - - Always generate the report content in Markdown. - - If the user requests DOCX/Word/PDF or another file format, export - from the generated Markdown report. - - SOURCE STRATEGY (how to collect source material): - - source_strategy="conversation" — The conversation already has - enough context (prior Q&A, filesystem exploration, pasted text, - uploaded files, scraped webpages). Pass a thorough summary as - source_content. - - source_strategy="kb_search" — Search the knowledge base - internally. Provide 1-5 targeted search_queries. The tool - handles searching internally — do NOT manually read and dump - /documents/ files into source_content. - - source_strategy="provided" — Use only what is in source_content - (default, backward-compatible). - - source_strategy="auto" — Use source_content if it has enough - material; otherwise fall back to internal KB search using - search_queries. 
- - CONVERSATION REUSE (HIGH PRIORITY): - - If the user has been asking questions in this chat and the - conversation contains substantive answers/discussion on the - topic, prefer source_strategy="conversation" with a thorough - summary of the full chat history as source_content. - - The user's prior questions and your answers ARE the source - material. Do NOT redundantly search the knowledge base for - information that is already in the chat. - - VERSIONING — parent_report_id: - - Set parent_report_id when the user wants to MODIFY, REVISE, - IMPROVE, UPDATE, EXPAND, or ADD CONTENT TO an existing report - that was already generated in this conversation. - - This includes both explicit AND implicit modification requests. - If the user references the existing report using words like "it", - "this", "here", "the report", or clearly refers to a previously - generated report, treat it as a revision request. - - The value must be the report_id from a previous generate_report - result in this same conversation. - - Do NOT set parent_report_id when: - * The user asks for a report on a completely NEW/DIFFERENT topic - * The user says "generate another report" (new report, not revision) - * There is no prior report to reference - - Examples of when to SET parent_report_id: - User: "Make that report shorter" → parent_report_id = - User: "Add a cost analysis section to the report" → parent_report_id = - User: "Rewrite the report in a more formal tone" → parent_report_id = - User: "I want more details about pricing in here" → parent_report_id = - User: "Include more examples" → parent_report_id = - User: "Can you also cover nutrition in this?" 
→ parent_report_id = - User: "Make it more detailed" → parent_report_id = - User: "Not bad, but expand on the budget section" → parent_report_id = - User: "Also mention the competitor landscape" → parent_report_id = - - Examples of when to LEAVE parent_report_id as None: - User: "Generate a report on climate change" → None (new topic) - User: "Write me a report about the budget" → None (new topic) - User: "Create another report, this time about marketing" → None - User: "Now write one about travel trends in Europe" → None (new topic) - - Args: - topic: Short title for the report (max ~8 words). - source_content: Text to base the report on. Can be empty when - using source_strategy="kb_search". - source_strategy: How to collect source material. One of - "provided", "conversation", "kb_search", or "auto". - search_queries: When source_strategy is "kb_search" or "auto", - provide 1-5 targeted search queries for the knowledge base. - These should be specific, not just the topic repeated. - report_style: "detailed", "deep_research", or "brief". - user_instructions: Optional focus or modification instructions. - When revising (parent_report_id set), describe WHAT TO CHANGE. - parent_report_id: ID of a previous report to revise (creates new - version in the same version group). - - Returns: - Dict with status, report_id, title, word_count, and message. 
- """ - # Initialize version tracking variables (used by _save_failed_report closure) - parent_report_content: str | None = None - report_group_id: int | None = None - - async def _save_failed_report(error_msg: str) -> int | None: - """Persist a failed report row using a short-lived session.""" - try: - async with shielded_async_session() as session: - failed_report = Report( - title=topic, - content=None, - report_metadata={ - "status": "failed", - "error_message": error_msg, - }, - report_style=report_style, - search_space_id=search_space_id, - thread_id=thread_id, - report_group_id=report_group_id, - ) - session.add(failed_report) - await session.commit() - await session.refresh(failed_report) - # If this is a new group (v1 failed), set group to self - if not failed_report.report_group_id: - failed_report.report_group_id = failed_report.id - await session.commit() - logger.info( - f"[generate_report] Saved failed report {failed_report.id}: {error_msg}" - ) - return failed_report.id - except Exception: - logger.exception( - "[generate_report] Could not persist failed report row" - ) - return None - - try: - # ── Phase 1: READ (short-lived session) ────────────────────── - # Fetch parent report and LLM config, then close the session - # so no DB connection is held during the long LLM call. 
- async with shielded_async_session() as read_session: - if parent_report_id: - parent_report = await read_session.get(Report, parent_report_id) - if parent_report: - report_group_id = parent_report.report_group_id - parent_report_content = parent_report.content - logger.info( - f"[generate_report] Creating new version from parent {parent_report_id} " - f"(group {report_group_id})" - ) - else: - logger.warning( - f"[generate_report] parent_report_id={parent_report_id} not found, " - "creating standalone report" - ) - - llm = await get_agent_llm(read_session, search_space_id) - # read_session closed — connection returned to pool - - if not llm: - error_msg = ( - "No LLM configured. Please configure a language model in Settings." - ) - report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": topic, - } - - # Build the user instructions string - user_instructions_section = "" - if user_instructions: - user_instructions_section = ( - f"**Additional Instructions:** {user_instructions}" - ) - - # ── Phase 1b: SOURCE COLLECTION (smart KB search) ──────────── - # Decide whether to augment source_content with KB search results. - effective_source = source_content or "" - - strategy = (source_strategy or "provided").lower().strip() - - needs_kb_search = False - if strategy == "kb_search": - needs_kb_search = True - elif strategy == "auto": - # Heuristic: if source_content has fewer than 200 words, - # it's likely insufficient — augment with KB search. 
- word_count_estimate = len(effective_source.split()) - if word_count_estimate < 200: - needs_kb_search = True - logger.info( - f"[generate_report] auto strategy: source has ~{word_count_estimate} words, " - "triggering KB search" - ) - # "provided" and "conversation" → use source_content as-is - - if needs_kb_search and connector_service and search_queries: - query_count = min(len(search_queries), 5) - dispatch_custom_event( - "report_progress", - { - "phase": "kb_search", - "message": f"Searching knowledge base ({query_count} queries)...", - }, - ) - logger.info( - f"[generate_report] Running internal KB search with " - f"{query_count} queries: {search_queries[:5]}" - ) - try: - from .knowledge_base import search_knowledge_base_async - - # Run all queries in parallel, each with its own session - async def _run_single_query(q: str) -> str: - async with shielded_async_session() as kb_session: - kb_connector_svc = ConnectorService( - kb_session, search_space_id - ) - return await search_knowledge_base_async( - query=q, - search_space_id=search_space_id, - db_session=kb_session, - connector_service=kb_connector_svc, - top_k=10, - available_connectors=available_connectors, - available_document_types=available_document_types, - ) - - kb_results = await asyncio.gather( - *[_run_single_query(q) for q in search_queries[:5]] - ) - - # Merge non-empty results into source_content - kb_text_parts = [r for r in kb_results if r and r.strip()] - if kb_text_parts: - kb_combined = "\n\n---\n\n".join(kb_text_parts) - if effective_source.strip(): - effective_source = ( - effective_source - + "\n\n--- Knowledge Base Search Results ---\n\n" - + kb_combined - ) - else: - effective_source = kb_combined - - # Count docs found (rough: count tags) - doc_count = kb_combined.count("") - dispatch_custom_event( - "report_progress", - { - "phase": "kb_search_done", - "message": f"Found {doc_count} relevant documents" - if doc_count - else f"Found results from {len(kb_text_parts)} queries", - }, 
- ) - logger.info( - f"[generate_report] KB search added ~{len(kb_combined)} chars " - f"from {len(kb_text_parts)} queries" - ) - else: - dispatch_custom_event( - "report_progress", - { - "phase": "kb_search_done", - "message": "No results found in knowledge base", - }, - ) - logger.info("[generate_report] KB search returned no results") - - except Exception as e: - logger.warning( - f"[generate_report] Internal KB search failed: {e}. " - "Proceeding with existing source_content." - ) - elif needs_kb_search and not connector_service: - logger.warning( - "[generate_report] KB search requested but connector_service " - "not available. Using source_content as-is." - ) - elif needs_kb_search and not search_queries: - logger.warning( - "[generate_report] KB search requested but no search_queries " - "provided. Using source_content as-is." - ) - - capped_source = effective_source[:100000] # Cap source content - - # Length constraint — only when user explicitly asks for brevity - length_instruction = "" - if report_style == "brief": - length_instruction = ( - "**LENGTH CONSTRAINT (MANDATORY):** The user wants a SHORT report. " - "Keep it concise — aim for ~400 words (~1 page) unless a different " - "length is specified in the Additional Instructions above. " - "Prioritize brevity over thoroughness. Do NOT write a long report." - ) - - # ── Phase 2: LLM GENERATION (no DB connection held) ────────── - - report_content: str | None = None - - if parent_report_content: - # ─── REVISION MODE ─────────────────────────────────────── - # Strategy: Try section-level revision first (preserves - # unchanged sections byte-for-byte). Falls back to full- - # document revision if section identification fails or if - # all sections need changes. 
- dispatch_custom_event( - "report_progress", - { - "phase": "revision_start", - "message": "Analyzing sections to modify...", - }, - ) - logger.info( - "[generate_report] Revision mode — attempting section-level revision" - ) - report_content = await _revise_with_sections( - llm=llm, - parent_content=parent_report_content, - user_instructions=user_instructions - or "Improve and refine the report.", - source_content=capped_source, - topic=topic, - report_style=report_style, - ) - - if report_content is None: - # Fallback: full-document revision - dispatch_custom_event( - "report_progress", - {"phase": "writing", "message": "Rewriting your full report"}, - ) - logger.info( - "[generate_report] Section-level revision deferred, " - "using full-document revision" - ) - prompt = _REVISION_PROMPT.format( - topic=topic, - report_style=report_style, - user_instructions_section=user_instructions_section - or "Improve and refine the report.", - source_content=capped_source, - previous_report_content=parent_report_content, - length_instruction=length_instruction, - formatting_rules=_FORMATTING_RULES, - ) - response = await llm.ainvoke([HumanMessage(content=prompt)]) - report_content = response.content - - else: - # ─── NEW REPORT MODE ───────────────────────────────────── - # Single-shot generation: one LLM call produces the full - # report. Fast, globally coherent, and cost-efficient. 
- dispatch_custom_event( - "report_progress", - {"phase": "writing", "message": "Writing your report"}, - ) - logger.info( - "[generate_report] New report — using single-shot generation" - ) - prompt = _REPORT_PROMPT.format( - topic=topic, - report_style=report_style, - user_instructions_section=user_instructions_section, - previous_version_section="", - source_content=capped_source, - length_instruction=length_instruction, - formatting_rules=_FORMATTING_RULES, - ) - response = await llm.ainvoke([HumanMessage(content=prompt)]) - report_content = response.content - - # ── Validate LLM output ────────────────────────────────────── - - if not report_content or not isinstance(report_content, str): - error_msg = "LLM returned empty or invalid content" - report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": topic, - } - - # LLMs often wrap output in ```markdown ... ``` fences — strip them - report_content = _strip_wrapping_code_fences(report_content) - - if not report_content: - error_msg = "LLM returned empty or invalid content" - report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": topic, - } - - # Strip any existing footer(s) carried over from parent version(s) - while report_content.rstrip().endswith(_REPORT_FOOTER): - idx = report_content.rstrip().rfind(_REPORT_FOOTER) - report_content = report_content[:idx].rstrip() - if report_content.rstrip().endswith("---"): - report_content = report_content.rstrip()[:-3].rstrip() - - # Append exactly one standard disclaimer - report_content += "\n\n---\n\n" + _REPORT_FOOTER - - # Extract metadata (includes "status": "ready") - metadata = _extract_metadata(report_content) - - # ── Phase 3: WRITE (short-lived session) ───────────────────── - # Save the report to the database, then close the session. 
- async with shielded_async_session() as write_session: - report = Report( - title=topic, - content=report_content, - report_metadata=metadata, - report_style=report_style, - search_space_id=search_space_id, - thread_id=thread_id, - report_group_id=report_group_id, - ) - write_session.add(report) - await write_session.commit() - await write_session.refresh(report) - - # If this is a brand-new report (v1), set report_group_id = own id - if not report.report_group_id: - report.report_group_id = report.id - await write_session.commit() - - saved_report_id = report.id - saved_group_id = report.report_group_id - # write_session closed — connection returned to pool - - logger.info( - f"[generate_report] Created report {saved_report_id} " - f"(group={saved_group_id}): " - f"{metadata.get('word_count', 0)} words, " - f"{metadata.get('section_count', 0)} sections" - ) - - return { - "status": "ready", - "report_id": saved_report_id, - "title": topic, - "word_count": metadata.get("word_count", 0), - "is_revision": bool(parent_report_content), - "report_markdown": report_content, - "message": f"Report generated successfully: {topic}", - } - - except Exception as e: - error_message = str(e) - logger.exception(f"[generate_report] Error: {error_message}") - report_id = await _save_failed_report(error_message) - - return { - "status": "failed", - "error": error_message, - "report_id": report_id, - "title": topic, - } - - return generate_report diff --git a/surfsense_backend/app/agents/new_chat/tools/resume.py b/surfsense_backend/app/agents/new_chat/tools/resume.py deleted file mode 100644 index 17849bce7..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/resume.py +++ /dev/null @@ -1,812 +0,0 @@ -""" -Resume generation tool for the SurfSense agent. - -Generates a structured resume as Typst source code using the rendercv package. 
-The LLM outputs only the content body (= heading, sections, entries) while -the template header (import + show rule) is hardcoded and prepended by the -backend. This eliminates LLM errors in the complex configuration block. - -Templates are stored in a registry so new designs can be added by defining -a new entry in _TEMPLATES. - -Uses the same short-lived session pattern as generate_report so no DB -connection is held during the long LLM call. -""" - -import io -import logging -import re -from datetime import UTC, datetime -from typing import Any - -import pypdf -import typst -from langchain_core.callbacks import dispatch_custom_event -from langchain_core.messages import HumanMessage -from langchain_core.tools import tool - -from app.db import Report, shielded_async_session -from app.services.llm_service import get_agent_llm - -logger = logging.getLogger(__name__) - - -# ─── Template Registry ─────────────────────────────────────────────────────── -# Each template defines: -# header - Typst import + show rule with {name}, {year}, {month}, {day} placeholders -# component_reference - component docs shown to the LLM -# rules - generation rules for the LLM - -_TEMPLATES: dict[str, dict[str, str]] = { - "classic": { - "header": """\ -#import "@preview/rendercv:0.3.0": * - -#show: rendercv.with( - name: "{name}", - title: "{name} - Resume", - footer: context {{ [#emph[{name} -- #str(here().page())\\/#str(counter(page).final().first())]] }}, - top-note: [ #emph[Last updated in {month_name} {year}] ], - locale-catalog-language: "en", - text-direction: ltr, - page-size: "us-letter", - page-top-margin: 0.7in, - page-bottom-margin: 0.7in, - page-left-margin: 0.7in, - page-right-margin: 0.7in, - page-show-footer: false, - page-show-top-note: true, - colors-body: rgb(0, 0, 0), - colors-name: rgb(0, 0, 0), - colors-headline: rgb(0, 0, 0), - colors-connections: rgb(0, 0, 0), - colors-section-titles: rgb(0, 0, 0), - colors-links: rgb(0, 0, 0), - colors-footer: rgb(128, 128, 
128), - colors-top-note: rgb(128, 128, 128), - typography-line-spacing: 0.6em, - typography-alignment: "justified", - typography-date-and-location-column-alignment: right, - typography-font-family-body: "XCharter", - typography-font-family-name: "XCharter", - typography-font-family-headline: "XCharter", - typography-font-family-connections: "XCharter", - typography-font-family-section-titles: "XCharter", - typography-font-size-body: 10pt, - typography-font-size-name: 25pt, - typography-font-size-headline: 10pt, - typography-font-size-connections: 10pt, - typography-font-size-section-titles: 1.2em, - typography-small-caps-name: false, - typography-small-caps-headline: false, - typography-small-caps-connections: false, - typography-small-caps-section-titles: false, - typography-bold-name: false, - typography-bold-headline: false, - typography-bold-connections: false, - typography-bold-section-titles: true, - links-underline: true, - links-show-external-link-icon: false, - header-alignment: center, - header-photo-width: 3.5cm, - header-space-below-name: 0.7cm, - header-space-below-headline: 0.7cm, - header-space-below-connections: 0.7cm, - header-connections-hyperlink: true, - header-connections-show-icons: false, - header-connections-display-urls-instead-of-usernames: true, - header-connections-separator: "|", - header-connections-space-between-connections: 0.5cm, - section-titles-type: "with_full_line", - section-titles-line-thickness: 0.5pt, - section-titles-space-above: 0.5cm, - section-titles-space-below: 0.3cm, - sections-allow-page-break: true, - sections-space-between-text-based-entries: 0.15cm, - sections-space-between-regular-entries: 0.42cm, - entries-date-and-location-width: 4.15cm, - entries-side-space: 0cm, - entries-space-between-columns: 0.1cm, - entries-allow-page-break: false, - entries-short-second-row: false, - entries-degree-width: 1cm, - entries-summary-space-left: 0cm, - entries-summary-space-above: 0.08cm, - entries-highlights-bullet: 
text(13pt, [\\u{2022}], baseline: -0.6pt), - entries-highlights-nested-bullet: text(13pt, [\\u{2022}], baseline: -0.6pt), - entries-highlights-space-left: 0cm, - entries-highlights-space-above: 0.08cm, - entries-highlights-space-between-items: 0.02cm, - entries-highlights-space-between-bullet-and-text: 0.3em, - date: datetime( - year: {year}, - month: {month}, - day: {day}, - ), -) - -""", - "component_reference": """\ -Available components (use ONLY these): - -= Full Name // Top-level heading — person's full name - -#connections( // Contact info row (pipe-separated) - [City, Country], - [#link("mailto:email@example.com", icon: false, if-underline: false, if-color: false)[email\\@example.com]], - [#link("https://linkedin.com/in/user", icon: false, if-underline: false, if-color: false)[linkedin.com\\/in\\/user]], - [#link("https://github.com/user", icon: false, if-underline: false, if-color: false)[github.com\\/user]], -) - -== Section Title // Section heading (arbitrary name) - -#regular-entry( // Work experience, projects, publications, etc. - [ - #strong[Role/Title], Company Name -- Location - ], - [ - Start -- End - ], - main-column-second-row: [ - - Achievement or responsibility - - Another bullet point - ], -) - -#education-entry( // Education entries - [ - #strong[Institution], Degree in Field -- Location - ], - [ - Start -- End - ], - main-column-second-row: [ - - GPA, honours, relevant coursework - ], -) - -#summary([Short paragraph summary]) // Optional summary inside an entry -#content-area([Free-form content]) // Freeform text block - -For skills sections, use one bullet per category label: -- #strong[Category:] item1, item2, item3 - -For simple list sections (e.g. Honors), use plain bullet points: -- Item one -- Item two -""", - "rules": """\ -RULES: -- Do NOT include any #import or #show lines. Start directly with = Full Name. -- Output ONLY valid Typst content. No explanatory text before or after. -- Do NOT wrap output in ```typst code fences. 
-- The = heading MUST use the person's COMPLETE full name exactly as provided. NEVER shorten or abbreviate. -- Escape @ symbols inside link labels with a backslash: email\\@example.com -- Escape forward slashes in link display text: linkedin.com\\/in\\/user -- Every section MUST use == heading. -- Use #regular-entry() for experience, projects, publications, certifications, and similar entries. -- Use #education-entry() for education. -- For skills sections, use one bullet line per category with a bold label. -- Keep content professional, concise, and achievement-oriented. -- Use action verbs for bullet points (Led, Built, Designed, Reduced, etc.). -- This template works for ALL professions — adapt sections to the user's field. -- Default behavior should prioritize concise one-page content. -""", - }, -} - -DEFAULT_TEMPLATE = "classic" -MIN_RESUME_PAGES = 1 -MAX_RESUME_PAGES = 5 -MAX_COMPRESSION_ATTEMPTS = 2 - - -# ─── Template Helpers ───────────────────────────────────────────────────────── - - -def _get_template(template_id: str | None = None) -> dict[str, str]: - """Get a template by ID, falling back to default.""" - return _TEMPLATES.get(template_id or DEFAULT_TEMPLATE, _TEMPLATES[DEFAULT_TEMPLATE]) - - -_MONTH_NAMES = [ - "", - "Jan", - "Feb", - "Mar", - "Apr", - "May", - "Jun", - "Jul", - "Aug", - "Sep", - "Oct", - "Nov", - "Dec", -] - - -def _build_header(template: dict[str, str], name: str) -> str: - """Build the template header with the person's name and current date.""" - now = datetime.now(tz=UTC) - return ( - template["header"] - .replace("{name}", name) - .replace("{year}", str(now.year)) - .replace("{month}", str(now.month)) - .replace("{day}", str(now.day)) - .replace("{month_name}", _MONTH_NAMES[now.month]) - ) - - -def _strip_header(full_source: str) -> str: - """Strip the import + show rule from stored source to get the body only. - - Finds the closing parenthesis of the rendercv.with(...) 
block by tracking - nesting depth, then returns everything after it. - """ - show_match = re.search(r"#show:\s*rendercv\.with\(", full_source) - if not show_match: - return full_source - - start = show_match.end() - depth = 1 - i = start - while i < len(full_source) and depth > 0: - if full_source[i] == "(": - depth += 1 - elif full_source[i] == ")": - depth -= 1 - i += 1 - - return full_source[i:].lstrip("\n") - - -def _extract_name(body: str) -> str | None: - """Extract the person's full name from the = heading in the body.""" - match = re.search(r"^=\s+(.+)$", body, re.MULTILINE) - return match.group(1).strip() if match else None - - -def _strip_imports(body: str) -> str: - """Remove any #import or #show lines the LLM might accidentally include.""" - lines = body.split("\n") - cleaned: list[str] = [] - skip_show = False - depth = 0 - - for line in lines: - stripped = line.strip() - - if stripped.startswith("#import"): - continue - - if skip_show: - depth += stripped.count("(") - stripped.count(")") - if depth <= 0: - skip_show = False - continue - - if stripped.startswith("#show:") and "rendercv" in stripped: - depth = stripped.count("(") - stripped.count(")") - if depth > 0: - skip_show = True - continue - - cleaned.append(line) - - result = "\n".join(cleaned).strip() - return result - - -def _build_llm_reference(template: dict[str, str]) -> str: - """Build the LLM prompt reference from a template.""" - return f"""\ -You MUST output valid Typst content for a resume. -Do NOT include any #import or #show lines — those are handled automatically. -Start directly with the = Full Name heading. - -{template["component_reference"]} - -{template["rules"]}""" - - -# ─── Prompts ───────────────────────────────────────────────────────────────── - -_RESUME_PROMPT = """\ -You are an expert resume writer. Generate professional resume content as Typst markup. 
- -{llm_reference} - -**User Information:** -{user_info} - -**Target Maximum Pages:** {max_pages} - -{user_instructions_section} - -Generate the resume content now (starting with = Full Name): -""" - -_REVISION_PROMPT = """\ -You are an expert resume editor. Modify the existing resume according to the instructions. -Apply ONLY the requested changes — do NOT rewrite sections that are not affected. - -{llm_reference} - -**Target Maximum Pages:** {max_pages} - -**Modification Instructions:** {user_instructions} - -**EXISTING RESUME CONTENT:** - -{previous_content} - ---- - -Output the complete, updated resume content with the changes applied (starting with = Full Name): -""" - -_FIX_COMPILE_PROMPT = """\ -The resume content you generated failed to compile. Fix the error while preserving all content. - -{llm_reference} - -**Compilation Error:** -{error} - -**Full Typst Source (for context — error line numbers refer to this):** -{full_source} - -**Your content starts after the template header. Output ONLY the content portion \ -(starting with = Full Name), NOT the #import or #show rule:** -""" - -_COMPRESS_TO_PAGE_LIMIT_PROMPT = """\ -The resume compiles, but it exceeds the maximum allowed page count. -Compress the resume while preserving high-impact accomplishments and role relevance. - -{llm_reference} - -**Target Maximum Pages:** {max_pages} -**Current Page Count:** {actual_pages} -**Compression Attempt:** {attempt_number} - -Compression priorities (in this order): -1) Keep recent, high-impact, role-relevant bullets. -2) Remove low-impact or redundant bullets. -3) Shorten verbose wording while preserving meaning. -4) Trim older or less relevant details before recent ones. - -Return the complete updated Typst content (starting with = Full Name), and keep it at or below the target pages. 
- -**EXISTING RESUME CONTENT:** -{previous_content} -""" - - -# ─── Helpers ───────────────────────────────────────────────────────────────── - - -def _strip_typst_fences(text: str) -> str: - """Remove wrapping ```typst ... ``` fences that LLMs sometimes add.""" - stripped = text.strip() - m = re.match(r"^(`{3,})(?:typst|typ)?\s*\n", stripped) - if m: - fence = m.group(1) - if stripped.endswith(fence): - stripped = stripped[m.end() :] - stripped = stripped[: -len(fence)].rstrip() - return stripped - - -def _compile_typst(source: str) -> bytes: - """Compile Typst source to PDF bytes. Raises on failure.""" - return typst.compile(source.encode("utf-8")) - - -def _count_pdf_pages(pdf_bytes: bytes) -> int: - """Count the number of pages in compiled PDF bytes.""" - with io.BytesIO(pdf_bytes) as pdf_stream: - reader = pypdf.PdfReader(pdf_stream) - return len(reader.pages) - - -def _validate_max_pages(max_pages: int) -> int: - """Validate and normalize max_pages input.""" - if MIN_RESUME_PAGES <= max_pages <= MAX_RESUME_PAGES: - return max_pages - msg = ( - f"max_pages must be between {MIN_RESUME_PAGES} and " - f"{MAX_RESUME_PAGES}. Received: {max_pages}" - ) - raise ValueError(msg) - - -# ─── Tool Factory ─────────────────────────────────────────────────────────── - - -def create_generate_resume_tool( - search_space_id: int, - thread_id: int | None = None, -): - """ - Factory function to create the generate_resume tool. - - Generates a Typst-based resume, validates it via compilation, - and stores the source in the Report table with content_type='typst'. - The LLM generates only the content body; the template header is - prepended by the backend. - """ - - @tool - async def generate_resume( - user_info: str, - user_instructions: str | None = None, - parent_report_id: int | None = None, - max_pages: int = 1, - ) -> dict[str, Any]: - """ - Generate a professional resume as a Typst document. 
- - Use this tool when the user asks to create, build, generate, write, - or draft a resume or CV. Also use it when the user wants to modify, - update, or revise an existing resume generated in this conversation. - - Trigger phrases include: - - "build me a resume", "create my resume", "generate a CV" - - "update my resume", "change my title", "add my new job" - - "make my resume more concise", "reformat my resume" - - Do NOT use this tool for: - - General questions about resumes or career advice - - Reviewing or critiquing a resume without changes - - Cover letters (use generate_report instead) - - VERSIONING — parent_report_id: - - Set parent_report_id when the user wants to MODIFY an existing - resume that was already generated in this conversation. - - Leave as None for new resumes. - - Args: - user_info: The user's resume content — work experience, - education, skills, contact info, etc. Can be structured - or unstructured text. - user_instructions: Optional style or content preferences - (e.g. "emphasize leadership", "keep it to one page", - "use a modern style"). For revisions, describe what to change. - parent_report_id: ID of a previous resume to revise (creates - new version in the same version group). - max_pages: Maximum number of pages for the generated resume. - Defaults to 1. Allowed range: 1-5. - - Returns: - Dict with status, report_id, title, and content_type. 
- """ - report_group_id: int | None = None - parent_content: str | None = None - - template = _get_template() - llm_reference = _build_llm_reference(template) - - async def _save_failed_report(error_msg: str) -> int | None: - try: - async with shielded_async_session() as session: - failed = Report( - title="Resume", - content=None, - content_type="typst", - report_metadata={ - "status": "failed", - "error_message": error_msg, - }, - report_style="resume", - search_space_id=search_space_id, - thread_id=thread_id, - report_group_id=report_group_id, - ) - session.add(failed) - await session.commit() - await session.refresh(failed) - if not failed.report_group_id: - failed.report_group_id = failed.id - await session.commit() - logger.info( - f"[generate_resume] Saved failed report {failed.id}: {error_msg}" - ) - return failed.id - except Exception: - logger.exception( - "[generate_resume] Could not persist failed report row" - ) - return None - - try: - try: - validated_max_pages = _validate_max_pages(max_pages) - except ValueError as e: - error_msg = str(e) - report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } - - # ── Phase 1: READ ───────────────────────────────────────────── - async with shielded_async_session() as read_session: - if parent_report_id: - parent_report = await read_session.get(Report, parent_report_id) - if parent_report: - report_group_id = parent_report.report_group_id - parent_content = parent_report.content - logger.info( - f"[generate_resume] Revising from parent {parent_report_id} " - f"(group {report_group_id})" - ) - - llm = await get_agent_llm(read_session, search_space_id) - - if not llm: - error_msg = ( - "No LLM configured. Please configure a language model in Settings." 
- ) - report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } - - # ── Phase 2: LLM GENERATION ─────────────────────────────────── - - user_instructions_section = "" - if user_instructions: - user_instructions_section = ( - f"**Additional Instructions:** {user_instructions}" - ) - - if parent_content: - dispatch_custom_event( - "report_progress", - {"phase": "writing", "message": "Updating your resume"}, - ) - parent_body = _strip_header(parent_content) - prompt = _REVISION_PROMPT.format( - llm_reference=llm_reference, - max_pages=validated_max_pages, - user_instructions=user_instructions - or "Improve and refine the resume.", - previous_content=parent_body, - ) - else: - dispatch_custom_event( - "report_progress", - {"phase": "writing", "message": "Building your resume"}, - ) - prompt = _RESUME_PROMPT.format( - llm_reference=llm_reference, - user_info=user_info, - max_pages=validated_max_pages, - user_instructions_section=user_instructions_section, - ) - - response = await llm.ainvoke([HumanMessage(content=prompt)]) - body = response.content - - if not body or not isinstance(body, str): - error_msg = "LLM returned empty or invalid content" - report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } - - body = _strip_typst_fences(body) - body = _strip_imports(body) - - # ── Phase 3: ASSEMBLE + COMPILE ─────────────────────────────── - dispatch_custom_event( - "report_progress", - {"phase": "compiling", "message": "Compiling resume..."}, - ) - - name = _extract_name(body) or "Resume" - typst_source = "" - actual_pages = 0 - compression_attempts = 0 - target_page_met = False - - for compression_round in range(MAX_COMPRESSION_ATTEMPTS + 1): - header = _build_header(template, name) - typst_source = header + body - 
compile_error: str | None = None - pdf_bytes: bytes | None = None - - for compile_attempt in range(2): - try: - pdf_bytes = _compile_typst(typst_source) - compile_error = None - break - except Exception as e: - compile_error = str(e) - logger.warning( - "[generate_resume] Compile attempt %s failed: %s", - compile_attempt + 1, - compile_error, - ) - - if compile_attempt == 0: - dispatch_custom_event( - "report_progress", - { - "phase": "fixing", - "message": "Fixing compilation issue...", - }, - ) - fix_prompt = _FIX_COMPILE_PROMPT.format( - llm_reference=llm_reference, - error=compile_error, - full_source=typst_source, - ) - fix_response = await llm.ainvoke( - [HumanMessage(content=fix_prompt)] - ) - if fix_response.content and isinstance( - fix_response.content, str - ): - body = _strip_typst_fences(fix_response.content) - body = _strip_imports(body) - name = _extract_name(body) or name - header = _build_header(template, name) - typst_source = header + body - - if compile_error or not pdf_bytes: - error_msg = ( - "Typst compilation failed after 2 attempts: " - f"{compile_error or 'Unknown compile error'}" - ) - report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } - - actual_pages = _count_pdf_pages(pdf_bytes) - if actual_pages <= validated_max_pages: - target_page_met = True - break - - if compression_round >= MAX_COMPRESSION_ATTEMPTS: - break - - compression_attempts += 1 - dispatch_custom_event( - "report_progress", - { - "phase": "compressing", - "message": f"Condensing resume to {validated_max_pages} page(s)...", - }, - ) - compress_prompt = _COMPRESS_TO_PAGE_LIMIT_PROMPT.format( - llm_reference=llm_reference, - max_pages=validated_max_pages, - actual_pages=actual_pages, - attempt_number=compression_attempts, - previous_content=body, - ) - compress_response = await llm.ainvoke( - [HumanMessage(content=compress_prompt)] - ) - if not 
compress_response.content or not isinstance( - compress_response.content, str - ): - error_msg = "LLM returned empty content while compressing resume" - report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } - - body = _strip_typst_fences(compress_response.content) - body = _strip_imports(body) - name = _extract_name(body) or name - - if actual_pages > MAX_RESUME_PAGES: - error_msg = ( - "Resume exceeds hard page limit after compression retries. " - f"Hard limit: <= {MAX_RESUME_PAGES} page(s), actual: {actual_pages}." - ) - report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } - - # ── Phase 4: SAVE ───────────────────────────────────────────── - dispatch_custom_event( - "report_progress", - {"phase": "saving", "message": "Saving your resume"}, - ) - - resume_title = f"{name} - Resume" if name != "Resume" else "Resume" - - metadata: dict[str, Any] = { - "status": "ready", - "word_count": len(typst_source.split()), - "char_count": len(typst_source), - "target_max_pages": validated_max_pages, - "actual_page_count": actual_pages, - "page_limit_enforced": True, - "compression_attempts": compression_attempts, - "target_page_met": target_page_met, - } - - async with shielded_async_session() as write_session: - report = Report( - title=resume_title, - content=typst_source, - content_type="typst", - report_metadata=metadata, - report_style="resume", - search_space_id=search_space_id, - thread_id=thread_id, - report_group_id=report_group_id, - ) - write_session.add(report) - await write_session.commit() - await write_session.refresh(report) - - if not report.report_group_id: - report.report_group_id = report.id - await write_session.commit() - - saved_id = report.id - - logger.info(f"[generate_resume] Created resume 
{saved_id}: {resume_title}") - - return { - "status": "ready", - "report_id": saved_id, - "title": resume_title, - "content_type": "typst", - "is_revision": bool(parent_content), - "message": ( - f"Resume generated successfully: {resume_title}" - if target_page_met - else ( - f"Resume generated, but could not fit the target of <= {validated_max_pages} " - f"page(s). Final length: {actual_pages} page(s)." - ) - ), - } - - except Exception as e: - error_message = str(e) - logger.exception(f"[generate_resume] Error: {error_message}") - report_id = await _save_failed_report(error_message) - return { - "status": "failed", - "error": error_message, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } - - return generate_resume diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/__init__.py b/surfsense_backend/app/agents/new_chat/tools/teams/__init__.py deleted file mode 100644 index 60e2add49..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/teams/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from app.agents.new_chat.tools.teams.list_channels import ( - create_list_teams_channels_tool, -) -from app.agents.new_chat.tools.teams.read_messages import ( - create_read_teams_messages_tool, -) -from app.agents.new_chat.tools.teams.send_message import ( - create_send_teams_message_tool, -) - -__all__ = [ - "create_list_teams_channels_tool", - "create_read_teams_messages_tool", - "create_send_teams_message_tool", -] diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py b/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py deleted file mode 100644 index 4345bb476..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Shared auth helper for Teams agent tools (Microsoft Graph REST API).""" - -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select - -from app.db import SearchSourceConnector, SearchSourceConnectorType - -GRAPH_API 
= "https://graph.microsoft.com/v1.0" - - -async def get_teams_connector( - db_session: AsyncSession, - search_space_id: int, - user_id: str, -) -> SearchSourceConnector | None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.TEAMS_CONNECTOR, - ) - ) - return result.scalars().first() - - -async def get_access_token( - db_session: AsyncSession, - connector: SearchSourceConnector, -) -> str: - """Get a valid Microsoft Graph access token, refreshing if expired.""" - from app.connectors.teams_connector import TeamsConnector - - tc = TeamsConnector( - session=db_session, - connector_id=connector.id, - ) - return await tc._get_valid_token() diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py deleted file mode 100644 index 0fc52b5c7..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py +++ /dev/null @@ -1,114 +0,0 @@ -import logging -from typing import Any - -import httpx -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.db import async_session_maker - -from ._auth import GRAPH_API, get_access_token, get_teams_connector - -logger = logging.getLogger(__name__) - - -def create_list_teams_channels_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the list_teams_channels tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. 
- - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured list_teams_channels tool - """ - del db_session # per-call session — see docstring - - @tool - async def list_teams_channels() -> dict[str, Any]: - """List all Microsoft Teams and their channels the user has access to. - - Returns: - Dictionary with status and a list of teams, each containing - team_id, team_name, and a list of channels (id, name). - """ - if search_space_id is None or user_id is None: - return {"status": "error", "message": "Teams tool not properly configured."} - - try: - async with async_session_maker() as db_session: - connector = await get_teams_connector( - db_session, search_space_id, user_id - ) - if not connector: - return {"status": "error", "message": "No Teams connector found."} - - token = await get_access_token(db_session, connector) - headers = {"Authorization": f"Bearer {token}"} - - async with httpx.AsyncClient(timeout=20.0) as client: - teams_resp = await client.get( - f"{GRAPH_API}/me/joinedTeams", headers=headers - ) - - if teams_resp.status_code == 401: - return { - "status": "auth_error", - "message": "Teams token expired. 
Please re-authenticate.", - "connector_type": "teams", - } - if teams_resp.status_code != 200: - return { - "status": "error", - "message": f"Graph API error: {teams_resp.status_code}", - } - - teams_data = teams_resp.json().get("value", []) - result_teams = [] - - async with httpx.AsyncClient(timeout=20.0) as client: - for team in teams_data: - team_id = team["id"] - ch_resp = await client.get( - f"{GRAPH_API}/teams/{team_id}/channels", - headers=headers, - ) - channels = [] - if ch_resp.status_code == 200: - channels = [ - {"id": ch["id"], "name": ch.get("displayName", "")} - for ch in ch_resp.json().get("value", []) - ] - result_teams.append( - { - "team_id": team_id, - "team_name": team.get("displayName", ""), - "channels": channels, - } - ) - - return { - "status": "success", - "teams": result_teams, - "total_teams": len(result_teams), - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error("Error listing Teams channels: %s", e, exc_info=True) - return {"status": "error", "message": "Failed to list Teams channels."} - - return list_teams_channels diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py b/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py deleted file mode 100644 index 0ebda021e..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py +++ /dev/null @@ -1,125 +0,0 @@ -import logging -from typing import Any - -import httpx -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.db import async_session_maker - -from ._auth import GRAPH_API, get_access_token, get_teams_connector - -logger = logging.getLogger(__name__) - - -def create_read_teams_messages_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - Factory function to create the read_teams_messages tool. 
- - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured read_teams_messages tool - """ - del db_session # per-call session — see docstring - - @tool - async def read_teams_messages( - team_id: str, - channel_id: str, - limit: int = 25, - ) -> dict[str, Any]: - """Read recent messages from a Microsoft Teams channel. - - Args: - team_id: The team ID (from list_teams_channels). - channel_id: The channel ID (from list_teams_channels). - limit: Number of messages to fetch (default 25, max 50). - - Returns: - Dictionary with status and a list of messages including - id, sender, content, timestamp. - """ - if search_space_id is None or user_id is None: - return {"status": "error", "message": "Teams tool not properly configured."} - - limit = min(limit, 50) - - try: - async with async_session_maker() as db_session: - connector = await get_teams_connector( - db_session, search_space_id, user_id - ) - if not connector: - return {"status": "error", "message": "No Teams connector found."} - - token = await get_access_token(db_session, connector) - - async with httpx.AsyncClient(timeout=20.0) as client: - resp = await client.get( - f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages", - headers={"Authorization": f"Bearer {token}"}, - params={"$top": limit}, - ) - - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Teams token expired. 
Please re-authenticate.", - "connector_type": "teams", - } - if resp.status_code == 403: - return { - "status": "error", - "message": "Insufficient permissions to read this channel.", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Graph API error: {resp.status_code}", - } - - raw_msgs = resp.json().get("value", []) - messages = [] - for m in raw_msgs: - sender = m.get("from", {}) - user_info = sender.get("user", {}) if sender else {} - body = m.get("body", {}) - messages.append( - { - "id": m.get("id"), - "sender": user_info.get("displayName", "Unknown"), - "content": body.get("content", ""), - "content_type": body.get("contentType", "text"), - "timestamp": m.get("createdDateTime", ""), - } - ) - - return { - "status": "success", - "team_id": team_id, - "channel_id": channel_id, - "messages": messages, - "total": len(messages), - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error("Error reading Teams messages: %s", e, exc_info=True) - return {"status": "error", "message": "Failed to read Teams messages."} - - return read_teams_messages diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py b/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py deleted file mode 100644 index 6f40d27e1..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py +++ /dev/null @@ -1,136 +0,0 @@ -import logging -from typing import Any - -import httpx -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.db import async_session_maker - -from ._auth import GRAPH_API, get_access_token, get_teams_connector - -logger = logging.getLogger(__name__) - - -def create_send_teams_message_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, -): - """ - 
Factory function to create the send_teams_message tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - Configured send_teams_message tool - """ - del db_session # per-call session — see docstring - - @tool - async def send_teams_message( - team_id: str, - channel_id: str, - content: str, - ) -> dict[str, Any]: - """Send a message to a Microsoft Teams channel. - - Requires the ChannelMessage.Send OAuth scope. If the user gets a - permission error, they may need to re-authenticate with updated scopes. - - Args: - team_id: The team ID (from list_teams_channels). - channel_id: The channel ID (from list_teams_channels). - content: The message text (HTML supported). - - Returns: - Dictionary with status, message_id on success. - - IMPORTANT: - - If status is "rejected", the user explicitly declined. Do NOT retry. - """ - if search_space_id is None or user_id is None: - return {"status": "error", "message": "Teams tool not properly configured."} - - try: - async with async_session_maker() as db_session: - connector = await get_teams_connector( - db_session, search_space_id, user_id - ) - if not connector: - return {"status": "error", "message": "No Teams connector found."} - - result = request_approval( - action_type="teams_send_message", - tool_name="send_teams_message", - params={ - "team_id": team_id, - "channel_id": channel_id, - "content": content, - }, - context={"connector_id": connector.id}, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. 
Message was not sent.", - } - - final_content = result.params.get("content", content) - final_team = result.params.get("team_id", team_id) - final_channel = result.params.get("channel_id", channel_id) - - token = await get_access_token(db_session, connector) - - async with httpx.AsyncClient(timeout=20.0) as client: - resp = await client.post( - f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages", - headers={ - "Authorization": f"Bearer {token}", - "Content-Type": "application/json", - }, - json={"body": {"content": final_content}}, - ) - - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Teams token expired. Please re-authenticate.", - "connector_type": "teams", - } - if resp.status_code == 403: - return { - "status": "insufficient_permissions", - "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.", - } - if resp.status_code not in (200, 201): - return { - "status": "error", - "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}", - } - - msg_data = resp.json() - return { - "status": "success", - "message_id": msg_data.get("id"), - "message": "Message sent to Teams channel.", - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - logger.error("Error sending Teams message: %s", e, exc_info=True) - return {"status": "error", "message": "Failed to send Teams message."} - - return send_teams_message diff --git a/surfsense_backend/app/agents/new_chat/tools/tool_response.py b/surfsense_backend/app/agents/new_chat/tools/tool_response.py deleted file mode 100644 index 8644ada5c..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/tool_response.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Standardised response dict factories for LangChain agent tools.""" - -from __future__ import annotations - -from typing import Any - - -class ToolResponse: - @staticmethod - def success(message: str, **data: Any) 
-> dict[str, Any]: - return {"status": "success", "message": message, **data} - - @staticmethod - def error(error: str, **data: Any) -> dict[str, Any]: - return {"status": "error", "error": error, **data} - - @staticmethod - def auth_error(service: str, **data: Any) -> dict[str, Any]: - return { - "status": "auth_error", - "error": ( - f"{service} authentication has expired or been revoked. " - "Please re-connect the integration in Settings → Connectors." - ), - **data, - } - - @staticmethod - def rejected(message: str = "Action was declined by the user.") -> dict[str, Any]: - return {"status": "rejected", "message": message} - - @staticmethod - def not_found(resource: str, identifier: str, **data: Any) -> dict[str, Any]: - return { - "status": "not_found", - "error": f"{resource} '{identifier}' was not found.", - **data, - } diff --git a/surfsense_backend/app/agents/new_chat/tools/video_presentation.py b/surfsense_backend/app/agents/new_chat/tools/video_presentation.py deleted file mode 100644 index 34f5183ca..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/video_presentation.py +++ /dev/null @@ -1,138 +0,0 @@ -""" -Video presentation generation tool for the SurfSense agent. - -This module provides a factory function for creating the generate_video_presentation -tool that submits a Celery task for background video presentation generation. The -tool then polls the row until it reaches a terminal status (READY/FAILED) and -returns that status. The wait is bounded by the chat's HTTP / process lifetime; -see app.agents.shared.deliverable_wait for details. 
-""" - -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.shared.deliverable_wait import wait_for_deliverable -from app.db import VideoPresentation, VideoPresentationStatus, shielded_async_session - -logger = logging.getLogger(__name__) - - -def create_generate_video_presentation_tool( - search_space_id: int, - db_session: AsyncSession, - thread_id: int | None = None, -): - """ - Factory function to create the generate_video_presentation tool with injected dependencies. - - Pre-creates video presentation record with pending status so the ID is available - immediately for frontend polling. The row is written via a fresh, tool-local - session so parallel tool calls (e.g. video + podcast in the same agent step) - don't share an ``AsyncSession`` (which is not concurrency-safe). - """ - del db_session # writes use a fresh tool-local session, see below - - @tool - async def generate_video_presentation( - source_content: str, - video_title: str = "SurfSense Presentation", - user_prompt: str | None = None, - ) -> dict[str, Any]: - """Generate a video presentation from the provided content. - - Use this tool when the user asks to create a video, presentation, slides, or slide deck. - - Args: - source_content: The text content to turn into a presentation. - video_title: Title for the presentation (default: "SurfSense Presentation") - user_prompt: Optional style/tone instructions. - """ - try: - # See podcast.py for the rationale: parallel tool calls share the - # streaming session, and AsyncSession is not concurrency-safe — - # interleaved flushes produce "Session.add() during flush" and - # poison the transaction for every concurrent tool. 
- async with shielded_async_session() as session: - video_pres = VideoPresentation( - title=video_title, - status=VideoPresentationStatus.PENDING, - search_space_id=search_space_id, - thread_id=thread_id, - ) - session.add(video_pres) - await session.commit() - await session.refresh(video_pres) - video_pres_id = video_pres.id - - from app.tasks.celery_tasks.video_presentation_tasks import ( - generate_video_presentation_task, - ) - - task = generate_video_presentation_task.delay( - video_presentation_id=video_pres_id, - source_content=source_content, - search_space_id=search_space_id, - user_prompt=user_prompt, - ) - - logger.info( - "[generate_video_presentation] Created video presentation %s, task: %s", - video_pres_id, - task.id, - ) - - # Wait until the Celery worker flips the row to a terminal - # state. No internal budget — see deliverable_wait module. - terminal_status, _columns, elapsed = await wait_for_deliverable( - model=VideoPresentation, - row_id=video_pres_id, - columns=[VideoPresentation.status], - terminal_statuses={ - VideoPresentationStatus.READY, - VideoPresentationStatus.FAILED, - }, - ) - - if terminal_status == VideoPresentationStatus.READY: - logger.info( - "[generate_video_presentation] %s READY in %.2fs", - video_pres_id, - elapsed, - ) - return { - "status": VideoPresentationStatus.READY.value, - "video_presentation_id": video_pres_id, - "title": video_title, - "message": "Video presentation generated and saved.", - } - - # Only other terminal state is FAILED. - logger.warning( - "[generate_video_presentation] %s FAILED in %.2fs", - video_pres_id, - elapsed, - ) - return { - "status": VideoPresentationStatus.FAILED.value, - "video_presentation_id": video_pres_id, - "title": video_title, - "error": ( - "Background worker reported FAILED status for this " - "video presentation." 
- ), - } - - except Exception as e: - error_message = str(e) - logger.exception("[generate_video_presentation] Error: %s", error_message) - return { - "status": VideoPresentationStatus.FAILED.value, - "error": error_message, - "title": video_title, - "video_presentation_id": None, - } - - return generate_video_presentation diff --git a/surfsense_backend/app/agents/podcaster/nodes.py b/surfsense_backend/app/agents/podcaster/nodes.py index 277536211..d1f140a44 100644 --- a/surfsense_backend/app/agents/podcaster/nodes.py +++ b/surfsense_backend/app/agents/podcaster/nodes.py @@ -24,36 +24,27 @@ from .utils import get_voice_for_provider async def create_podcast_transcript( state: State, config: RunnableConfig ) -> dict[str, Any]: - """Each node does work.""" - - # Get configuration from runnable config + """Generate the podcast transcript from the source content.""" configuration = Configuration.from_runnable_config(config) search_space_id = configuration.search_space_id user_prompt = configuration.user_prompt - # Use the search space's agent LLM for podcast transcript generation. llm = await get_agent_llm(state.db_session, search_space_id) if not llm: error_message = f"No agent LLM configured for search space {search_space_id}" print(error_message) raise RuntimeError(error_message) - # Get the prompt prompt = get_podcast_generation_prompt(user_prompt) - - # Create the messages messages = [ SystemMessage(content=prompt), HumanMessage( content=f"{state.source_content}" ), ] - - # Generate the podcast transcript llm_response = await llm.ainvoke(messages) - # Reasoning models (e.g. Kimi K2.5) may return content as a list of - # blocks including 'reasoning' entries. Normalise to a plain string. + # Reasoning models may return content as blocks; normalise to a string. 
content = strip_markdown_fences(extract_text_content(llm_response.content)) try: @@ -87,17 +78,13 @@ async def create_merged_podcast_audio( state: State, config: RunnableConfig ) -> dict[str, Any]: """Generate audio for each transcript and merge them into a single podcast file.""" - - # configuration = Configuration.from_runnable_config(config) - starting_transcript = PodcastTranscriptEntry( speaker_id=1, dialog="Welcome to Surfsense Podcast." ) transcript = state.podcast_transcript - # Merge the starting transcript with the podcast transcript - # Check if transcript is a PodcastTranscripts object or already a list + # transcript may be a PodcastTranscripts object or already a list. if hasattr(transcript, "podcast_transcripts"): transcript_entries = transcript.podcast_transcripts else: @@ -105,20 +92,16 @@ async def create_merged_podcast_audio( merged_transcript = [starting_transcript, *transcript_entries] - # Create a temporary directory for audio files temp_dir = Path("temp_audio") temp_dir.mkdir(exist_ok=True) - # Generate a unique session ID for this podcast session_id = str(uuid.uuid4()) output_path = f"podcasts/{session_id}_podcast.mp3" os.makedirs("podcasts", exist_ok=True) - # Generate audio for each transcript segment audio_files = [] async def generate_speech_for_segment(segment, index): - # Handle both dictionary and PodcastTranscriptEntry objects if hasattr(segment, "speaker_id"): speaker_id = segment.speaker_id dialog = segment.dialog @@ -126,20 +109,15 @@ async def create_merged_podcast_audio( speaker_id = segment.get("speaker_id", 0) dialog = segment.get("dialog", "") - # Select voice based on speaker_id voice = get_voice_for_provider(app_config.TTS_SERVICE, speaker_id) - # Generate a unique filename for this segment if app_config.TTS_SERVICE == "local/kokoro": - # Kokoro generates WAV files filename = f"{temp_dir}/{session_id}_{index}.wav" else: - # Other services generate MP3 files filename = f"{temp_dir}/{session_id}_{index}.mp3" try: if 
app_config.TTS_SERVICE == "local/kokoro": - # Use Kokoro TTS service kokoro_service = await get_kokoro_tts_service( lang_code="a" ) # American English @@ -168,7 +146,6 @@ async def create_merged_podcast_audio( timeout=600, ) - # Save the audio to a file - use proper streaming method with open(filename, "wb") as f: f.write(response.content) @@ -177,23 +154,17 @@ async def create_merged_podcast_audio( print(f"Error generating speech for segment {index}: {e!s}") raise - # Generate all audio files concurrently tasks = [ generate_speech_for_segment(segment, i) for i, segment in enumerate(merged_transcript) ] audio_files = await asyncio.gather(*tasks) - # Merge audio files using ffmpeg try: - # Create FFmpeg instance with the first input ffmpeg = FFmpeg().option("y") - - # Add each audio file as input for audio_file in audio_files: ffmpeg = ffmpeg.input(audio_file) - # Configure the concatenation and output filter_complex = [] for i in range(len(audio_files)): filter_complex.append(f"[{i}:0]") @@ -203,8 +174,6 @@ async def create_merged_podcast_audio( ) ffmpeg = ffmpeg.option("filter_complex", filter_complex_str) ffmpeg = ffmpeg.output(output_path, map="[outa]") - - # Execute FFmpeg await ffmpeg.execute() print(f"Successfully created podcast audio: {output_path}") @@ -213,7 +182,6 @@ async def create_merged_podcast_audio( print(f"Error merging audio files: {e!s}") raise finally: - # Clean up temporary files for audio_file in audio_files: try: os.remove(audio_file) diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index 11a55e948..e9ffa74d7 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -23,7 +23,7 @@ from starlette.requests import Request as StarletteRequest from starlette.responses import Response as StarletteResponse from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware -from app.agents.new_chat.checkpointer import ( +from app.agents.chat.runtime.checkpointer import ( close_checkpointer, 
setup_checkpointer_tables, ) @@ -487,7 +487,7 @@ async def _warm_agent_jit_caches() -> None: ) from langchain_core.tools import tool - from app.agents.new_chat.context import SurfSenseContextSchema + from app.agents.chat.shared.context import SurfSenseContextSchema # Minimal LLM stub. ``FakeListChatModel`` satisfies # ``BaseChatModel`` without any network or auth — perfect for diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py b/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py index 99e295f30..aa96e4f6e 100644 --- a/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py @@ -10,9 +10,12 @@ from langchain_core.messages import HumanMessage from langgraph.types import Command from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent -from app.agents.new_chat.context import SurfSenseContextSchema -from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text +from app.agents.chat.multi_agent_chat import create_multi_agent_chat_deep_agent +from app.agents.chat.runtime.mention_resolver import ( + resolve_mentions, + substitute_in_text, +) +from app.agents.chat.shared.context import SurfSenseContextSchema from app.db import ChatVisibility, async_session_maker from app.schemas.new_chat import MentionedDocumentInfo diff --git a/surfsense_backend/app/automations/services/model_policy.py b/surfsense_backend/app/automations/services/model_policy.py index 88e9d5f28..7e3e46b61 100644 --- a/surfsense_backend/app/automations/services/model_policy.py +++ b/surfsense_backend/app/automations/services/model_policy.py @@ -39,7 +39,9 @@ def _is_premium_global(kind: ModelKind, config_id: int) -> bool: cfg: dict | None = None if kind == "llm": - from app.agents.new_chat.llm_config import load_global_llm_config_by_id + from 
app.agents.chat.runtime.llm_config import ( + load_global_llm_config_by_id, + ) cfg = load_global_llm_config_by_id(config_id) elif kind == "image": diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index f3c05f2d6..203e36580 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -645,9 +645,6 @@ class Config: # Anonymous / no-login mode settings NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE" - MULTI_AGENT_CHAT_ENABLED = ( - os.getenv("MULTI_AGENT_CHAT_ENABLED", "FALSE").upper() == "TRUE" - ) ANON_TOKEN_LIMIT = int(os.getenv("ANON_TOKEN_LIMIT", "500000")) ANON_TOKEN_WARNING_THRESHOLD = int( os.getenv("ANON_TOKEN_WARNING_THRESHOLD", "400000") diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index c6fe1ee37..5b232b55c 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -2051,60 +2051,6 @@ class Log(BaseModel, TimestampMixin): search_space = relationship("SearchSpace", back_populates="logs") -class Notification(BaseModel, TimestampMixin): - __tablename__ = "notifications" - __table_args__ = ( - # Composite index for unread-count queries that filter by - # (user_id, read, type) and order by created_at. - Index( - "ix_notifications_user_read_type_created", - "user_id", - "read", - "type", - "created_at", - ), - # Covers the common list query: user_id + search_space_id + created_at DESC - Index( - "ix_notifications_user_space_created", - "user_id", - "search_space_id", - "created_at", - ), - ) - - user_id = Column( - UUID(as_uuid=True), - ForeignKey("user.id", ondelete="CASCADE"), - nullable=False, - index=True, - ) - search_space_id = Column( - Integer, - ForeignKey("searchspaces.id", ondelete="CASCADE"), - nullable=True, - index=True, - ) - type = Column( - String(50), nullable=False, index=True - ) # 'connector_indexing', 'document_processing', etc. 
- title = Column(String(200), nullable=False) - message = Column(Text, nullable=False) - read = Column( - Boolean, nullable=False, default=False, server_default=text("false"), index=True - ) - notification_metadata = Column("metadata", JSONB, nullable=True, default={}) - updated_at = Column( - TIMESTAMP(timezone=True), - nullable=True, - default=lambda: datetime.now(UTC), - onupdate=lambda: datetime.now(UTC), - index=True, - ) - - user = relationship("User", back_populates="notifications") - search_space = relationship("SearchSpace", back_populates="notifications") - - class UserIncentiveTask(BaseModel, TimestampMixin): """ Tracks completed incentive tasks for users. @@ -2928,6 +2874,7 @@ from app.automations.persistence import ( # noqa: E402, F401 AutomationTrigger, ) from app.file_storage.persistence import DocumentFile # noqa: E402, F401 +from app.notifications.persistence import Notification # noqa: E402, F401 engine = create_async_engine( DATABASE_URL, diff --git a/surfsense_backend/app/gateway/agent_invoke.py b/surfsense_backend/app/gateway/agent_invoke.py index 7a2219b1d..dcbf9a954 100644 --- a/surfsense_backend/app/gateway/agent_invoke.py +++ b/surfsense_backend/app/gateway/agent_invoke.py @@ -16,7 +16,7 @@ from app.gateway.bindings import get_or_create_thread_for_binding from app.gateway.hitl_filter import DEFAULT_HITL_TOOL_NAMES from app.gateway.thread_lock import acquire_thread_lock, release_thread_lock from app.observability.metrics import record_gateway_turn_latency -from app.tasks.chat.stream_new_chat import stream_new_chat +from app.tasks.chat.streaming.flows import stream_new_chat logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/notifications/__init__.py b/surfsense_backend/app/notifications/__init__.py new file mode 100644 index 000000000..6ffe45000 --- /dev/null +++ b/surfsense_backend/app/notifications/__init__.py @@ -0,0 +1,19 @@ +"""User notifications: persistence, service, and HTTP API. 
+ +Emit notifications via :class:`~app.notifications.service.NotificationService`; +the router in :mod:`app.notifications.api` exposes the inbox endpoints. +""" + +from __future__ import annotations + +# Initialize app.db first to avoid a partial-init circular import when this +# package is the entry point (e.g. Celery loading it before any ORM code). +import app.db # noqa: F401 + +from app.notifications.persistence import Notification +from app.notifications.service import NotificationService + +__all__ = [ + "Notification", + "NotificationService", +] diff --git a/surfsense_backend/app/notifications/api/__init__.py b/surfsense_backend/app/notifications/api/__init__.py new file mode 100644 index 000000000..2708c8805 --- /dev/null +++ b/surfsense_backend/app/notifications/api/__init__.py @@ -0,0 +1,7 @@ +"""Notifications HTTP API.""" + +from __future__ import annotations + +from app.notifications.api.api import router + +__all__ = ["router"] diff --git a/surfsense_backend/app/routes/notifications_routes.py b/surfsense_backend/app/notifications/api/api.py similarity index 63% rename from surfsense_backend/app/routes/notifications_routes.py rename to surfsense_backend/app/notifications/api/api.py index 611227795..ddca09c66 100644 --- a/surfsense_backend/app/routes/notifications_routes.py +++ b/surfsense_backend/app/notifications/api/api.py @@ -1,124 +1,36 @@ -""" -Notifications API routes. -These endpoints allow marking notifications as read and fetching older notifications. -Zero automatically syncs the changes to all connected clients for recent items. -For older items (beyond the sync window), use the list endpoint. 
-""" +"""HTTP routes for the notifications inbox (list, counts, mark-read).""" + +from __future__ import annotations from datetime import UTC, datetime, timedelta -from typing import Literal from fastapi import APIRouter, Depends, HTTPException, Query, status -from pydantic import BaseModel from sqlalchemy import case, desc, func, literal, literal_column, select, update from sqlalchemy.ext.asyncio import AsyncSession -from app.db import Notification, User, get_async_session +from app.db import User, get_async_session +from app.notifications.api.schemas import ( + BatchUnreadCountResponse, + CategoryUnreadCount, + MarkAllReadResponse, + MarkReadResponse, + NotificationListResponse, + SourceTypeItem, + SourceTypesResponse, + UnreadCountResponse, +) +from app.notifications.api.transform import ( + parse_before_date, + parse_source_type, + to_response, +) +from app.notifications.constants import CATEGORY_TYPES, SYNC_WINDOW_DAYS +from app.notifications.persistence import Notification +from app.notifications.types import NotificationCategory, NotificationType from app.users import current_active_user router = APIRouter(prefix="/notifications", tags=["notifications"]) -# Must match frontend SYNC_WINDOW_DAYS in use-inbox.ts -SYNC_WINDOW_DAYS = 14 - -# Valid notification types - must match frontend InboxItemTypeEnum -NotificationType = Literal[ - "connector_indexing", - "connector_deletion", - "document_processing", - "new_mention", - "comment_reply", - "page_limit_exceeded", -] - -# Category-to-types mapping for filtering by tab -NotificationCategory = Literal["comments", "status"] -CATEGORY_TYPES: dict[str, tuple[str, ...]] = { - "comments": ("new_mention", "comment_reply"), - "status": ( - "connector_indexing", - "connector_deletion", - "document_processing", - "page_limit_exceeded", - ), -} - - -class NotificationResponse(BaseModel): - """Response model for a single notification.""" - - id: int - user_id: str - search_space_id: int | None - type: str - title: str - 
message: str - read: bool - metadata: dict - created_at: str - updated_at: str | None - - class Config: - from_attributes = True - - -class NotificationListResponse(BaseModel): - """Response for listing notifications with pagination.""" - - items: list[NotificationResponse] - total: int - has_more: bool - next_offset: int | None - - -class MarkReadResponse(BaseModel): - """Response for mark as read operations.""" - - success: bool - message: str - - -class MarkAllReadResponse(BaseModel): - """Response for mark all as read operation.""" - - success: bool - message: str - updated_count: int - - -class SourceTypeItem(BaseModel): - """A single source type with its category and count.""" - - key: str - type: str - category: str # "connector" or "document" - count: int - - -class SourceTypesResponse(BaseModel): - """Response for notification source types used in status tab filter.""" - - sources: list[SourceTypeItem] - - -class UnreadCountResponse(BaseModel): - """Response for unread count with split between recent and older items.""" - - total_unread: int - recent_unread: int # Within SYNC_WINDOW_DAYS - - -class CategoryUnreadCount(BaseModel): - total_unread: int - recent_unread: int - - -class BatchUnreadCountResponse(BaseModel): - """Batched unread counts for all categories in a single response.""" - - comments: CategoryUnreadCount - status: CategoryUnreadCount - @router.get("/unread-counts-batch", response_model=BatchUnreadCountResponse) async def get_unread_counts_batch( @@ -126,12 +38,7 @@ async def get_unread_counts_batch( user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session), ) -> BatchUnreadCountResponse: - """ - Get unread counts for all notification categories in a single DB query. - - Replaces multiple separate calls to /unread-count with different category - filters, reducing round-trips from 2+ to 1. 
- """ + """Unread counts for every category in a single query.""" cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS) base_filter = [ @@ -140,6 +47,7 @@ async def get_unread_counts_batch( ] if search_space_id is not None: + # Include global (null search-space) notifications. base_filter.append( (Notification.search_space_id == search_space_id) | (Notification.search_space_id.is_(None)) @@ -181,14 +89,11 @@ async def get_notification_source_types( user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session), ) -> SourceTypesResponse: - """ - Get all distinct connector types and document types from the user's - status notifications. Used to populate the filter dropdown in the - inbox Status tab so that all types are shown regardless of pagination. - """ + """Distinct connector/document source types for the Status tab filter.""" base_filter = [Notification.user_id == user.id] if search_space_id is not None: + # Include global (null search-space) notifications. base_filter.append( (Notification.search_space_id == search_space_id) | (Notification.search_space_id.is_(None)) @@ -258,47 +163,35 @@ async def get_unread_count( user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session), ) -> UnreadCountResponse: - """ - Get the total unread notification count for the current user. + """Total and recent (within sync window) unread counts for the user. - Returns both: - - total_unread: All unread notifications (for accurate badge count) - - recent_unread: Unread notifications within the sync window (last 14 days) - - This allows the frontend to calculate: - - older_unread = total_unread - recent_unread (static until reconciliation) - - Display count = older_unread + live_recent_count (from Zero) + Returning both lets a client hold the older count static while + live-syncing the recent ones. 
""" - # Calculate cutoff date for sync window cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS) - # Base filter for user's unread notifications base_filter = [ Notification.user_id == user.id, Notification.read == False, # noqa: E712 ] - # Add search space filter if provided (include null for global notifications) if search_space_id is not None: + # Include global (null search-space) notifications. base_filter.append( (Notification.search_space_id == search_space_id) | (Notification.search_space_id.is_(None)) ) - # Filter by notification type if provided if type_filter: base_filter.append(Notification.type == type_filter) - # Filter by category (maps to multiple types) if category: base_filter.append(Notification.type.in_(CATEGORY_TYPES[category])) - # Total unread count (all time) total_query = select(func.count(Notification.id)).where(*base_filter) total_result = await session.execute(total_query) total_unread = total_result.scalar() or 0 - # Recent unread count (within sync window) recent_query = select(func.count(Notification.id)).where( *base_filter, Notification.created_at > cutoff_date, @@ -340,22 +233,14 @@ async def list_notifications( user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session), ) -> NotificationListResponse: - """ - List notifications for the current user with pagination. - - This endpoint is used as a fallback for older notifications that are - outside the Zero sync window (2 weeks). - - Use `before_date` to paginate through older notifications efficiently. - """ - # Build base query + """Paginated inbox fallback for items outside the Zero sync window.""" query = select(Notification).where(Notification.user_id == user.id) count_query = select(func.count(Notification.id)).where( Notification.user_id == user.id ) - # Filter by search space (include null search_space_id for global notifications) if search_space_id is not None: + # Include global (null search-space) notifications. 
query = query.where( (Notification.search_space_id == search_space_id) | (Notification.search_space_id.is_(None)) @@ -365,39 +250,26 @@ async def list_notifications( | (Notification.search_space_id.is_(None)) ) - # Filter by type if type_filter: query = query.where(Notification.type == type_filter) count_query = count_query.where(Notification.type == type_filter) - # Filter by category (maps to multiple types) if category: cat_types = CATEGORY_TYPES[category] query = query.where(Notification.type.in_(cat_types)) count_query = count_query.where(Notification.type.in_(cat_types)) - # Filter by source type (connector or document type from JSONB metadata) + # source_type encodes the JSONB facet to match: 'connector:' or 'doctype:'. if source_type: - if source_type.startswith("connector:"): - connector_val = source_type[len("connector:") :] - source_filter = Notification.type.in_( - ("connector_indexing", "connector_deletion") - ) & ( - Notification.notification_metadata["connector_type"].astext - == connector_val - ) - query = query.where(source_filter) - count_query = count_query.where(source_filter) - elif source_type.startswith("doctype:"): - doctype_val = source_type[len("doctype:") :] - source_filter = Notification.type.in_(("document_processing",)) & ( - Notification.notification_metadata["document_type"].astext - == doctype_val + parsed_source = parse_source_type(source_type) + if parsed_source: + source_filter = Notification.type.in_(parsed_source.types) & ( + Notification.notification_metadata[parsed_source.metadata_key].astext + == parsed_source.value ) query = query.where(source_filter) count_query = count_query.where(source_filter) - # Filter by preset: 'unread' or 'errors' if filter == "unread": unread_filter = Notification.read == False # noqa: E712 query = query.where(unread_filter) @@ -409,10 +281,9 @@ async def list_notifications( query = query.where(error_filter) count_query = count_query.where(error_filter) - # Filter by date (for efficient pagination 
of older items) if before_date: try: - before_datetime = datetime.fromisoformat(before_date.replace("Z", "+00:00")) + before_datetime = parse_before_date(before_date) query = query.where(Notification.created_at < before_datetime) count_query = count_query.where(Notification.created_at < before_datetime) except ValueError: @@ -421,7 +292,6 @@ async def list_notifications( detail="Invalid date format. Use ISO format (e.g., 2024-01-15T00:00:00Z)", ) from None - # Filter by search query (case-insensitive title/message search) if search: search_term = f"%{search}%" search_filter = Notification.title.ilike( @@ -430,45 +300,22 @@ async def list_notifications( query = query.where(search_filter) count_query = count_query.where(search_filter) - # Get total count total_result = await session.execute(count_query) total = total_result.scalar() or 0 - # Apply ordering and pagination + # Over-fetch by one to tell whether another page exists. query = ( query.order_by(desc(Notification.created_at)).offset(offset).limit(limit + 1) ) - # Execute query result = await session.execute(query) notifications = result.scalars().all() - # Check if there are more items has_more = len(notifications) > limit if has_more: notifications = notifications[:limit] - # Convert to response format - items = [] - for notification in notifications: - items.append( - NotificationResponse( - id=notification.id, - user_id=str(notification.user_id), - search_space_id=notification.search_space_id, - type=notification.type, - title=notification.title, - message=notification.message, - read=notification.read, - metadata=notification.notification_metadata or {}, - created_at=notification.created_at.isoformat() - if notification.created_at - else "", - updated_at=notification.updated_at.isoformat() - if notification.updated_at - else None, - ) - ) + items = [to_response(notification) for notification in notifications] return NotificationListResponse( items=items, @@ -484,12 +331,8 @@ async def 
mark_notification_as_read( user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session), ) -> MarkReadResponse: - """ - Mark a single notification as read. - - Zero will automatically sync this change to all connected clients. - """ - # Verify the notification belongs to the user + """Mark one of the user's notifications read; Zero syncs the change.""" + # Scope to the caller's own notifications. result = await session.execute( select(Notification).where( Notification.id == notification_id, @@ -510,7 +353,6 @@ async def mark_notification_as_read( message="Notification already marked as read", ) - # Update the notification notification.read = True await session.commit() @@ -525,12 +367,7 @@ async def mark_all_notifications_as_read( user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session), ) -> MarkAllReadResponse: - """ - Mark all notifications as read for the current user. - - Zero will automatically sync these changes to all connected clients. 
- """ - # Update all unread notifications for the user + """Mark all of the user's notifications read; Zero syncs the changes.""" result = await session.execute( update(Notification) .where( diff --git a/surfsense_backend/app/notifications/api/schemas.py b/surfsense_backend/app/notifications/api/schemas.py new file mode 100644 index 000000000..727e5485a --- /dev/null +++ b/surfsense_backend/app/notifications/api/schemas.py @@ -0,0 +1,81 @@ +"""Response shapes for the notifications API.""" + +from __future__ import annotations + +from pydantic import BaseModel + + +class NotificationResponse(BaseModel): + """A single notification.""" + + id: int + user_id: str + search_space_id: int | None + type: str + title: str + message: str + read: bool + metadata: dict + created_at: str + updated_at: str | None + + class Config: + from_attributes = True + + +class NotificationListResponse(BaseModel): + """A page of notifications.""" + + items: list[NotificationResponse] + total: int + has_more: bool + next_offset: int | None + + +class MarkReadResponse(BaseModel): + """Outcome of marking one notification read.""" + + success: bool + message: str + + +class MarkAllReadResponse(BaseModel): + """Outcome of marking every notification read.""" + + success: bool + message: str + updated_count: int + + +class SourceTypeItem(BaseModel): + """A source type with its category and count.""" + + key: str + type: str + category: str # "connector" or "document" + count: int + + +class SourceTypesResponse(BaseModel): + """Source types available for the Status tab filter.""" + + sources: list[SourceTypeItem] + + +class UnreadCountResponse(BaseModel): + """Unread totals, split by sync-window recency.""" + + total_unread: int + recent_unread: int + + +class CategoryUnreadCount(BaseModel): + total_unread: int + recent_unread: int + + +class BatchUnreadCountResponse(BaseModel): + """Per-category unread counts in one response.""" + + comments: CategoryUnreadCount + status: CategoryUnreadCount diff 
--git a/surfsense_backend/app/notifications/api/transform.py b/surfsense_backend/app/notifications/api/transform.py new file mode 100644 index 000000000..8970cb0b8 --- /dev/null +++ b/surfsense_backend/app/notifications/api/transform.py @@ -0,0 +1,62 @@ +"""Pure request/response helpers for the notifications API. + +No DB or framework objects, so these are unit-testable in isolation. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import NamedTuple + +from app.notifications.api.schemas import NotificationResponse +from app.notifications.persistence import Notification + + +class SourceTypeFilter(NamedTuple): + """The notification types and JSONB facet a source-type filter selects.""" + + types: tuple[str, ...] + metadata_key: str + value: str + + +def parse_source_type(source_type: str) -> SourceTypeFilter | None: + """Decode a `connector:` / `doctype:` filter, or None if unknown.""" + if source_type.startswith("connector:"): + return SourceTypeFilter( + types=("connector_indexing", "connector_deletion"), + metadata_key="connector_type", + value=source_type[len("connector:") :], + ) + if source_type.startswith("doctype:"): + return SourceTypeFilter( + types=("document_processing",), + metadata_key="document_type", + value=source_type[len("doctype:") :], + ) + return None + + +def parse_before_date(before_date: str) -> datetime: + """Parse an ISO date for pagination; raises ValueError if malformed.""" + return datetime.fromisoformat(before_date.replace("Z", "+00:00")) + + +def to_response(notification: Notification) -> NotificationResponse: + """Map a persisted notification to its API response shape.""" + return NotificationResponse( + id=notification.id, + user_id=str(notification.user_id), + search_space_id=notification.search_space_id, + type=notification.type, + title=notification.title, + message=notification.message, + read=notification.read, + metadata=notification.notification_metadata or {}, + 
created_at=notification.created_at.isoformat() + if notification.created_at + else "", + updated_at=notification.updated_at.isoformat() + if notification.updated_at + else None, + ) diff --git a/surfsense_backend/app/notifications/constants.py b/surfsense_backend/app/notifications/constants.py new file mode 100644 index 000000000..e8bd8391d --- /dev/null +++ b/surfsense_backend/app/notifications/constants.py @@ -0,0 +1,17 @@ +"""Notification policy constants.""" + +from __future__ import annotations + +# Notifications newer than this are live-synced; older ones load via the list endpoint. +SYNC_WINDOW_DAYS = 14 + +# Maps an inbox tab to the notification types it shows. +CATEGORY_TYPES: dict[str, tuple[str, ...]] = { + "comments": ("new_mention", "comment_reply"), + "status": ( + "connector_indexing", + "connector_deletion", + "document_processing", + "page_limit_exceeded", + ), +} diff --git a/surfsense_backend/app/notifications/persistence/__init__.py b/surfsense_backend/app/notifications/persistence/__init__.py new file mode 100644 index 000000000..82f9e6f01 --- /dev/null +++ b/surfsense_backend/app/notifications/persistence/__init__.py @@ -0,0 +1,7 @@ +"""Notification persistence models.""" + +from __future__ import annotations + +from .models import Notification + +__all__ = ["Notification"] diff --git a/surfsense_backend/app/notifications/persistence/models.py b/surfsense_backend/app/notifications/persistence/models.py new file mode 100644 index 000000000..557c4bf17 --- /dev/null +++ b/surfsense_backend/app/notifications/persistence/models.py @@ -0,0 +1,72 @@ +"""Per-user inbox notifications, synced to clients via Zero.""" + +from __future__ import annotations + +from datetime import UTC, datetime + +from sqlalchemy import ( + TIMESTAMP, + Boolean, + Column, + ForeignKey, + Index, + Integer, + String, + Text, + text, +) +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.orm import relationship + +from app.db import BaseModel, 
TimestampMixin + + +class Notification(BaseModel, TimestampMixin): + __tablename__ = "notifications" + __table_args__ = ( + # Serves unread-count queries. + Index( + "ix_notifications_user_read_type_created", + "user_id", + "read", + "type", + "created_at", + ), + # Serves the paginated inbox list query. + Index( + "ix_notifications_user_space_created", + "user_id", + "search_space_id", + "created_at", + ), + ) + + user_id = Column( + UUID(as_uuid=True), + ForeignKey("user.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + search_space_id = Column( + Integer, + ForeignKey("searchspaces.id", ondelete="CASCADE"), + nullable=True, + index=True, + ) + type = Column(String(50), nullable=False, index=True) + title = Column(String(200), nullable=False) + message = Column(Text, nullable=False) + read = Column( + Boolean, nullable=False, default=False, server_default=text("false"), index=True + ) + notification_metadata = Column("metadata", JSONB, nullable=True, default={}) + updated_at = Column( + TIMESTAMP(timezone=True), + nullable=True, + default=lambda: datetime.now(UTC), + onupdate=lambda: datetime.now(UTC), + index=True, + ) + + user = relationship("User", back_populates="notifications") + search_space = relationship("SearchSpace", back_populates="notifications") diff --git a/surfsense_backend/app/notifications/service/__init__.py b/surfsense_backend/app/notifications/service/__init__.py new file mode 100644 index 000000000..8cb8491ac --- /dev/null +++ b/surfsense_backend/app/notifications/service/__init__.py @@ -0,0 +1,7 @@ +"""Notification creation/update service.""" + +from __future__ import annotations + +from app.notifications.service.facade import NotificationService + +__all__ = ["NotificationService"] diff --git a/surfsense_backend/app/notifications/service/base.py b/surfsense_backend/app/notifications/service/base.py new file mode 100644 index 000000000..31b378cda --- /dev/null +++ b/surfsense_backend/app/notifications/service/base.py @@ -0,0 
+1,119 @@ +"""Shared find/upsert/update logic for a single notification type.""" + +from __future__ import annotations + +import logging +from typing import Any +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm.attributes import flag_modified + +from app.notifications.persistence import Notification +from app.notifications.service.metadata import apply_update, start_metadata + +logger = logging.getLogger(__name__) + + +class BaseNotificationHandler: + """Find, upsert, and update notifications of one ``type``.""" + + def __init__(self, notification_type: str): + self.notification_type = notification_type + + async def find_notification_by_operation( + self, + session: AsyncSession, + user_id: UUID, + operation_id: str, + search_space_id: int | None = None, + ) -> Notification | None: + """Return the notification for ``operation_id``, if one exists.""" + query = select(Notification).where( + Notification.user_id == user_id, + Notification.type == self.notification_type, + Notification.notification_metadata["operation_id"].astext == operation_id, + ) + if search_space_id is not None: + query = query.where(Notification.search_space_id == search_space_id) + + result = await session.execute(query) + return result.scalar_one_or_none() + + async def find_or_create_notification( + self, + session: AsyncSession, + user_id: UUID, + operation_id: str, + title: str, + message: str, + search_space_id: int | None = None, + initial_metadata: dict[str, Any] | None = None, + ) -> Notification: + """Upsert a notification keyed by ``operation_id``.""" + notification = await self.find_notification_by_operation( + session, user_id, operation_id, search_space_id + ) + + if notification: + notification.title = title + notification.message = message + if initial_metadata: + notification.notification_metadata = apply_update( + notification.notification_metadata, + metadata_updates=initial_metadata, + ) + # Tell 
SQLAlchemy the JSONB dict changed in place. + flag_modified(notification, "notification_metadata") + await session.commit() + await session.refresh(notification) + logger.info( + f"Updated notification {notification.id} for operation {operation_id}" + ) + return notification + + metadata = start_metadata(operation_id, initial_metadata) + + notification = Notification( + user_id=user_id, + search_space_id=search_space_id, + type=self.notification_type, + title=title, + message=message, + notification_metadata=metadata, + ) + session.add(notification) + await session.commit() + await session.refresh(notification) + logger.info( + f"Created notification {notification.id} for operation {operation_id}" + ) + return notification + + async def update_notification( + self, + session: AsyncSession, + notification: Notification, + title: str | None = None, + message: str | None = None, + status: str | None = None, + metadata_updates: dict[str, Any] | None = None, + ) -> Notification: + """Apply field/status/metadata changes and persist.""" + if title is not None: + notification.title = title + if message is not None: + notification.message = message + + if status is not None or metadata_updates: + notification.notification_metadata = apply_update( + notification.notification_metadata, status, metadata_updates + ) + # Tell SQLAlchemy the JSONB dict changed in place. 
+ flag_modified(notification, "notification_metadata") + + await session.commit() + await session.refresh(notification) + logger.info(f"Updated notification {notification.id}") + return notification diff --git a/surfsense_backend/app/notifications/service/facade.py b/surfsense_backend/app/notifications/service/facade.py new file mode 100644 index 000000000..63154301c --- /dev/null +++ b/surfsense_backend/app/notifications/service/facade.py @@ -0,0 +1,55 @@ +"""Single entry point that composes the per-type notification handlers.""" + +from __future__ import annotations + +import logging +from typing import Any +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.notifications.persistence import Notification +from app.notifications.service.handlers import ( + CommentReplyNotificationHandler, + ConnectorIndexingNotificationHandler, + DocumentProcessingNotificationHandler, + MentionNotificationHandler, + PageLimitNotificationHandler, +) + +logger = logging.getLogger(__name__) + + +class NotificationService: + """Facade over the per-type handlers; mutations sync via Zero.""" + + connector_indexing = ConnectorIndexingNotificationHandler() + document_processing = DocumentProcessingNotificationHandler() + mention = MentionNotificationHandler() + comment_reply = CommentReplyNotificationHandler() + page_limit = PageLimitNotificationHandler() + + @staticmethod + async def create_notification( + session: AsyncSession, + user_id: UUID, + notification_type: str, + title: str, + message: str, + search_space_id: int | None = None, + notification_metadata: dict[str, Any] | None = None, + ) -> Notification: + """Create a generic notification of any ``notification_type``.""" + notification = Notification( + user_id=user_id, + search_space_id=search_space_id, + type=notification_type, + title=title, + message=message, + notification_metadata=notification_metadata or {}, + ) + session.add(notification) + await session.commit() + await 
session.refresh(notification) + logger.info(f"Created notification {notification.id} for user {user_id}") + return notification diff --git a/surfsense_backend/app/notifications/service/handlers/__init__.py b/surfsense_backend/app/notifications/service/handlers/__init__.py new file mode 100644 index 000000000..8c32dea3b --- /dev/null +++ b/surfsense_backend/app/notifications/service/handlers/__init__.py @@ -0,0 +1,17 @@ +"""Per-type notification handlers.""" + +from __future__ import annotations + +from .comment_reply import CommentReplyNotificationHandler +from .connector_indexing import ConnectorIndexingNotificationHandler +from .document_processing import DocumentProcessingNotificationHandler +from .mention import MentionNotificationHandler +from .page_limit import PageLimitNotificationHandler + +__all__ = [ + "CommentReplyNotificationHandler", + "ConnectorIndexingNotificationHandler", + "DocumentProcessingNotificationHandler", + "MentionNotificationHandler", + "PageLimitNotificationHandler", +] diff --git a/surfsense_backend/app/notifications/service/handlers/comment_reply.py b/surfsense_backend/app/notifications/service/handlers/comment_reply.py new file mode 100644 index 000000000..7d9a9495a --- /dev/null +++ b/surfsense_backend/app/notifications/service/handlers/comment_reply.py @@ -0,0 +1,108 @@ +"""Notifications for replies to a user's comments.""" + +from __future__ import annotations + +import logging +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.notifications.persistence import Notification +from app.notifications.service.base import BaseNotificationHandler +from app.notifications.service.messages.text import truncate + +logger = logging.getLogger(__name__) + + +class CommentReplyNotificationHandler(BaseNotificationHandler): + """Notifications for replies to a user's comments.""" + + def __init__(self): + super().__init__("comment_reply") + + async def find_notification_by_reply( + 
self, + session: AsyncSession, + reply_id: int, + user_id: UUID, + ) -> Notification | None: + query = select(Notification).where( + Notification.type == self.notification_type, + Notification.user_id == user_id, + Notification.notification_metadata["reply_id"].astext == str(reply_id), + ) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def notify_comment_reply( + self, + session: AsyncSession, + user_id: UUID, + reply_id: int, + parent_comment_id: int, + message_id: int, + thread_id: int, + thread_title: str, + author_id: str, + author_name: str, + author_avatar_url: str | None, + author_email: str, + content_preview: str, + search_space_id: int, + ) -> Notification: + """Notify of a reply; idempotent on ``reply_id`` per user.""" + existing = await self.find_notification_by_reply(session, reply_id, user_id) + if existing: + logger.info( + f"Notification already exists for reply {reply_id} to user {user_id}" + ) + return existing + + title = f"{author_name} replied in a thread" + message = truncate(content_preview, 100) + + metadata = { + "reply_id": reply_id, + "parent_comment_id": parent_comment_id, + "message_id": message_id, + "thread_id": thread_id, + "thread_title": thread_title, + "author_id": author_id, + "author_name": author_name, + "author_avatar_url": author_avatar_url, + "author_email": author_email, + "content_preview": content_preview[:200], + } + + try: + notification = Notification( + user_id=user_id, + search_space_id=search_space_id, + type=self.notification_type, + title=title, + message=message, + notification_metadata=metadata, + ) + session.add(notification) + await session.commit() + await session.refresh(notification) + logger.info( + f"Created comment_reply notification {notification.id} for user {user_id}" + ) + return notification + except Exception as e: + await session.rollback() + if ( + "duplicate key" in str(e).lower() + or "unique constraint" in str(e).lower() + ): + logger.warning( + 
f"Duplicate notification for reply {reply_id} to user {user_id}" + ) + existing = await self.find_notification_by_reply( + session, reply_id, user_id + ) + if existing: + return existing + raise diff --git a/surfsense_backend/app/notifications/service/handlers/connector_indexing.py b/surfsense_backend/app/notifications/service/handlers/connector_indexing.py new file mode 100644 index 000000000..9ebfae2ea --- /dev/null +++ b/surfsense_backend/app/notifications/service/handlers/connector_indexing.py @@ -0,0 +1,183 @@ +"""Notifications for connector indexing runs.""" + +from __future__ import annotations + +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.notifications.persistence import Notification +from app.notifications.service.base import BaseNotificationHandler +from app.notifications.service.messages import connector_indexing as msg + + +class ConnectorIndexingNotificationHandler(BaseNotificationHandler): + """Notifications for connector indexing runs.""" + + def __init__(self): + super().__init__("connector_indexing") + + async def notify_indexing_started( + self, + session: AsyncSession, + user_id: UUID, + connector_id: int, + connector_name: str, + connector_type: str, + search_space_id: int, + start_date: str | None = None, + end_date: str | None = None, + ) -> Notification: + """Open (or refresh) the notification when indexing starts.""" + operation_id = msg.operation_id(connector_id, start_date, end_date) + title = f"Syncing: {connector_name}" + message = "Connecting to your account" + + metadata = { + "connector_id": connector_id, + "connector_name": connector_name, + "connector_type": connector_type, + "start_date": start_date, + "end_date": end_date, + "indexed_count": 0, + "sync_stage": "connecting", + } + + return await self.find_or_create_notification( + session=session, + user_id=user_id, + operation_id=operation_id, + title=title, + message=message, + search_space_id=search_space_id, + initial_metadata=metadata, 
+ ) + + async def notify_indexing_progress( + self, + session: AsyncSession, + notification: Notification, + indexed_count: int, + total_count: int | None = None, + stage: str | None = None, + stage_message: str | None = None, + ) -> Notification: + """Update the notification with indexing progress.""" + message, metadata_updates = msg.progress( + indexed_count, total_count, stage, stage_message + ) + return await self.update_notification( + session=session, + notification=notification, + message=message, + status="in_progress", + metadata_updates=metadata_updates, + ) + + async def notify_retry_progress( + self, + session: AsyncSession, + notification: Notification, + indexed_count: int, + retry_reason: str, + attempt: int, + max_attempts: int, + wait_seconds: float | None = None, + service_name: str | None = None, + ) -> Notification: + """Surface that an external service is rate-limiting/retrying.""" + connector_name = notification.notification_metadata.get( + "connector_name", "Service" + ) + message, metadata_updates = msg.retry( + connector_name, + indexed_count, + retry_reason, + attempt, + max_attempts, + wait_seconds, + service_name, + ) + return await self.update_notification( + session=session, + notification=notification, + message=message, + status="in_progress", + metadata_updates=metadata_updates, + ) + + async def notify_indexing_completed( + self, + session: AsyncSession, + notification: Notification, + indexed_count: int, + error_message: str | None = None, + is_warning: bool = False, + skipped_count: int | None = None, + unsupported_count: int | None = None, + ) -> Notification: + """Finalize the notification as ready/failed when indexing ends.""" + connector_name = notification.notification_metadata.get( + "connector_name", "Connector" + ) + title, message, status, metadata_updates = msg.completion( + connector_name, + indexed_count, + error_message, + is_warning, + skipped_count, + unsupported_count, + ) + return await self.update_notification( 
+ session=session, + notification=notification, + title=title, + message=message, + status=status, + metadata_updates=metadata_updates, + ) + + async def notify_google_drive_indexing_started( + self, + session: AsyncSession, + user_id: UUID, + connector_id: int, + connector_name: str, + connector_type: str, + search_space_id: int, + folder_count: int, + file_count: int, + folder_names: list[str] | None = None, + file_names: list[str] | None = None, + ) -> Notification: + """Open (or refresh) the notification when Drive indexing starts.""" + operation_id = msg.google_drive_operation_id( + connector_id, folder_count, file_count + ) + title = f"Syncing: {connector_name}" + message = "Preparing your files" + + metadata = { + "connector_id": connector_id, + "connector_name": connector_name, + "connector_type": connector_type, + "folder_count": folder_count, + "file_count": file_count, + "indexed_count": 0, + "sync_stage": "connecting", + } + + if folder_names: + metadata["folder_names"] = folder_names + if file_names: + metadata["file_names"] = file_names + + return await self.find_or_create_notification( + session=session, + user_id=user_id, + operation_id=operation_id, + title=title, + message=message, + search_space_id=search_space_id, + initial_metadata=metadata, + ) diff --git a/surfsense_backend/app/notifications/service/handlers/document_processing.py b/surfsense_backend/app/notifications/service/handlers/document_processing.py new file mode 100644 index 000000000..8644df2c8 --- /dev/null +++ b/surfsense_backend/app/notifications/service/handlers/document_processing.py @@ -0,0 +1,95 @@ +"""Notifications for single-document processing.""" + +from __future__ import annotations + +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.notifications.persistence import Notification +from app.notifications.service.base import BaseNotificationHandler +from app.notifications.service.messages import document_processing as msg + + +class 
DocumentProcessingNotificationHandler(BaseNotificationHandler): + """Notifications for single-document processing.""" + + def __init__(self): + super().__init__("document_processing") + + async def notify_processing_started( + self, + session: AsyncSession, + user_id: UUID, + document_type: str, + document_name: str, + search_space_id: int, + file_size: int | None = None, + ) -> Notification: + """Open the notification when document processing is queued.""" + operation_id = msg.operation_id(document_type, document_name, search_space_id) + title = f"Processing: {document_name}" + message = "Waiting in queue" + + metadata = { + "document_type": document_type, + "document_name": document_name, + "processing_stage": "queued", + } + + if file_size is not None: + metadata["file_size"] = file_size + + return await self.find_or_create_notification( + session=session, + user_id=user_id, + operation_id=operation_id, + title=title, + message=message, + search_space_id=search_space_id, + initial_metadata=metadata, + ) + + async def notify_processing_progress( + self, + session: AsyncSession, + notification: Notification, + stage: str, + stage_message: str | None = None, + chunks_count: int | None = None, + ) -> Notification: + """Update the notification with the current processing stage.""" + message, metadata_updates = msg.progress(stage, stage_message, chunks_count) + + return await self.update_notification( + session=session, + notification=notification, + message=message, + status="in_progress", + metadata_updates=metadata_updates, + ) + + async def notify_processing_completed( + self, + session: AsyncSession, + notification: Notification, + document_id: int | None = None, + chunks_count: int | None = None, + error_message: str | None = None, + ) -> Notification: + """Finalize the notification as ready/failed when processing ends.""" + document_name = notification.notification_metadata.get( + "document_name", "Document" + ) + title, message, status, metadata_updates = 
msg.completion( + document_name, error_message, document_id, chunks_count + ) + + return await self.update_notification( + session=session, + notification=notification, + title=title, + message=message, + status=status, + metadata_updates=metadata_updates, + ) diff --git a/surfsense_backend/app/notifications/service/handlers/mention.py b/surfsense_backend/app/notifications/service/handlers/mention.py new file mode 100644 index 000000000..568dc01de --- /dev/null +++ b/surfsense_backend/app/notifications/service/handlers/mention.py @@ -0,0 +1,106 @@ +"""Notifications for @mentions in comments.""" + +from __future__ import annotations + +import logging +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.notifications.persistence import Notification +from app.notifications.service.base import BaseNotificationHandler +from app.notifications.service.messages.text import truncate + +logger = logging.getLogger(__name__) + + +class MentionNotificationHandler(BaseNotificationHandler): + """Notifications for @mentions in comments.""" + + def __init__(self): + super().__init__("new_mention") + + async def find_notification_by_mention( + self, + session: AsyncSession, + mention_id: int, + ) -> Notification | None: + """Return the notification for ``mention_id``, if one exists.""" + query = select(Notification).where( + Notification.type == self.notification_type, + Notification.notification_metadata["mention_id"].astext == str(mention_id), + ) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def notify_new_mention( + self, + session: AsyncSession, + mentioned_user_id: UUID, + mention_id: int, + comment_id: int, + message_id: int, + thread_id: int, + thread_title: str, + author_id: str, + author_name: str, + author_avatar_url: str | None, + author_email: str, + content_preview: str, + search_space_id: int, + ) -> Notification: + """Notify a mentioned user; idempotent on 
``mention_id``.""" + existing = await self.find_notification_by_mention(session, mention_id) + if existing: + logger.info( + f"Notification already exists for mention {mention_id}, returning existing" + ) + return existing + + title = f"{author_name} mentioned you" + message = truncate(content_preview, 100) + + metadata = { + "mention_id": mention_id, + "comment_id": comment_id, + "message_id": message_id, + "thread_id": thread_id, + "thread_title": thread_title, + "author_id": author_id, + "author_name": author_name, + "author_avatar_url": author_avatar_url, + "author_email": author_email, + "content_preview": content_preview[:200], + } + + try: + notification = Notification( + user_id=mentioned_user_id, + search_space_id=search_space_id, + type=self.notification_type, + title=title, + message=message, + notification_metadata=metadata, + ) + session.add(notification) + await session.commit() + await session.refresh(notification) + logger.info( + f"Created new_mention notification {notification.id} for user {mentioned_user_id}" + ) + return notification + except Exception as e: + # Race: a concurrent insert won; fetch the existing row instead. 
+ await session.rollback() + if ( + "duplicate key" in str(e).lower() + or "unique constraint" in str(e).lower() + ): + logger.warning( + f"Duplicate notification detected for mention {mention_id}, fetching existing" + ) + existing = await self.find_notification_by_mention(session, mention_id) + if existing: + return existing + raise diff --git a/surfsense_backend/app/notifications/service/handlers/page_limit.py b/surfsense_backend/app/notifications/service/handlers/page_limit.py new file mode 100644 index 000000000..90722dc62 --- /dev/null +++ b/surfsense_backend/app/notifications/service/handlers/page_limit.py @@ -0,0 +1,68 @@ +"""Notifications for exceeding the page limit.""" + +from __future__ import annotations + +import logging +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.notifications.persistence import Notification +from app.notifications.service.base import BaseNotificationHandler +from app.notifications.service.messages import page_limit as msg + +logger = logging.getLogger(__name__) + + +class PageLimitNotificationHandler(BaseNotificationHandler): + """Notifications for exceeding the page limit.""" + + def __init__(self): + super().__init__("page_limit_exceeded") + + async def notify_page_limit_exceeded( + self, + session: AsyncSession, + user_id: UUID, + document_name: str, + document_type: str, + search_space_id: int, + pages_used: int, + pages_limit: int, + pages_to_add: int, + ) -> Notification: + """Notify that a document was blocked by the page limit.""" + operation_id = msg.operation_id(document_name, search_space_id) + title, message = msg.summary( + document_name, pages_used, pages_limit, pages_to_add + ) + + metadata = { + "operation_id": operation_id, + "document_name": document_name, + "document_type": document_type, + "pages_used": pages_used, + "pages_limit": pages_limit, + "pages_to_add": pages_to_add, + "status": "failed", + "error_type": "page_limit_exceeded", + # Where the inbox item links to. 
+ "action_url": f"/dashboard/{search_space_id}/more-pages", + "action_label": "Upgrade Plan", + } + + notification = Notification( + user_id=user_id, + search_space_id=search_space_id, + type=self.notification_type, + title=title, + message=message, + notification_metadata=metadata, + ) + session.add(notification) + await session.commit() + await session.refresh(notification) + logger.info( + f"Created page_limit_exceeded notification {notification.id} for user {user_id}" + ) + return notification diff --git a/surfsense_backend/app/notifications/service/messages/__init__.py b/surfsense_backend/app/notifications/service/messages/__init__.py new file mode 100644 index 000000000..95373537d --- /dev/null +++ b/surfsense_backend/app/notifications/service/messages/__init__.py @@ -0,0 +1,6 @@ +"""Pure, side-effect-free presentation logic for notifications. + +Handlers compute their user-facing title/message/status/metadata here, then +persist the result. Keeping this layer free of I/O makes it unit-testable +without a database. 
+""" diff --git a/surfsense_backend/app/notifications/service/messages/connector_indexing.py b/surfsense_backend/app/notifications/service/messages/connector_indexing.py new file mode 100644 index 000000000..8a2926211 --- /dev/null +++ b/surfsense_backend/app/notifications/service/messages/connector_indexing.py @@ -0,0 +1,164 @@ +"""Pure presentation logic for connector-indexing notifications.""" + +from __future__ import annotations + +from datetime import UTC, datetime +from typing import Any + + +def operation_id( + connector_id: int, + start_date: str | None = None, + end_date: str | None = None, +) -> str: + """Build a unique id for a connector indexing run.""" + timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S") + date_range = "" + if start_date or end_date: + date_range = f"_{start_date or 'none'}_{end_date or 'none'}" + return f"connector_{connector_id}_{timestamp}{date_range}" + + +def google_drive_operation_id( + connector_id: int, folder_count: int, file_count: int +) -> str: + """Build a unique id for a Google Drive indexing run.""" + timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S") + items_info = f"_{folder_count}f_{file_count}files" + return f"drive_{connector_id}_{timestamp}{items_info}" + + +def progress( + indexed_count: int, + total_count: int | None = None, + stage: str | None = None, + stage_message: str | None = None, +) -> tuple[str, dict[str, Any]]: + """Compute the progress message and metadata updates for an indexing run.""" + stage_messages = { + "connecting": "Connecting to your account", + "fetching": "Fetching your content", + "processing": "Preparing for search", + "storing": "Almost done", + } + + if stage or stage_message: + progress_msg = stage_message or stage_messages.get(stage, "Processing") + else: + # Legacy callers that pass neither stage nor message. 
+ progress_msg = "Fetching your content" + + metadata_updates: dict[str, Any] = {"indexed_count": indexed_count} + if total_count is not None: + metadata_updates["total_count"] = total_count + progress_percent = int((indexed_count / total_count) * 100) + metadata_updates["progress_percent"] = progress_percent + if stage: + metadata_updates["sync_stage"] = stage + + return progress_msg, metadata_updates + + +def retry( + connector_name: str, + indexed_count: int, + retry_reason: str, + attempt: int, + max_attempts: int, + wait_seconds: float | None = None, + service_name: str | None = None, +) -> tuple[str, dict[str, Any]]: + """Compute the retry message and metadata, framing the delay as the provider's.""" + if not service_name: + service_name = connector_name + # Strip the workspace suffix, e.g. "Notion - My Workspace" -> "Notion". + if " - " in service_name: + service_name = service_name.split(" - ")[0] + + # Worded so the delay reads as the provider's, not ours. + retry_messages = { + "rate_limit": f"{service_name} rate limit reached", + "server_error": f"{service_name} is slow to respond", + "timeout": f"{service_name} took too long", + "temporary_error": f"{service_name} temporarily unavailable", + } + + base_message = retry_messages.get(retry_reason, f"Waiting for {service_name}") + + # Only surface a wait time when it's long enough to be worth showing. + if wait_seconds and wait_seconds > 5: + message = f"{base_message}. Retrying in {int(wait_seconds)}s..." + else: + message = f"{base_message}. Retrying..." 
+ + if indexed_count > 0: + item_text = "item" if indexed_count == 1 else "items" + message = f"{message} ({indexed_count} {item_text} synced so far)" + + metadata_updates = { + "indexed_count": indexed_count, + "sync_stage": "waiting_retry", + "retry_attempt": attempt, + "retry_max_attempts": max_attempts, + "retry_reason": retry_reason, + "retry_wait_seconds": wait_seconds, + } + + return message, metadata_updates + + +def completion( + connector_name: str, + indexed_count: int, + error_message: str | None = None, + is_warning: bool = False, + skipped_count: int | None = None, + unsupported_count: int | None = None, +) -> tuple[str, str, str, dict[str, Any]]: + """Compute the final title, message, status, and metadata for a finished run.""" + unsupported_text = "" + if unsupported_count and unsupported_count > 0: + file_word = "file was" if unsupported_count == 1 else "files were" + unsupported_text = f" {unsupported_count} {file_word} not supported." + + if error_message: + if indexed_count > 0: + title = f"Ready: {connector_name}" + file_text = "file" if indexed_count == 1 else "files" + message = f"Now searchable! {indexed_count} {file_text} synced.{unsupported_text} Note: {error_message}" + status = "completed" + elif is_warning: + title = f"Ready: {connector_name}" + message = f"Sync complete.{unsupported_text} {error_message}" + status = "completed" + else: + title = f"Failed: {connector_name}" + message = f"Sync failed: {error_message}" + if unsupported_text: + message += unsupported_text + status = "failed" + else: + title = f"Ready: {connector_name}" + if indexed_count == 0: + if unsupported_count and unsupported_count > 0: + message = f"Sync complete.{unsupported_text}" + else: + message = "Already up to date!" + else: + file_text = "file" if indexed_count == 1 else "files" + message = f"Now searchable! {indexed_count} {file_text} synced." 
+ if unsupported_text: + message += unsupported_text + status = "completed" + + metadata_updates = { + "indexed_count": indexed_count, + "skipped_count": skipped_count or 0, + "unsupported_count": unsupported_count or 0, + "sync_stage": "completed" + if (not error_message or is_warning or indexed_count > 0) + else "failed", + "error_message": error_message, + } + + return title, message, status, metadata_updates diff --git a/surfsense_backend/app/notifications/service/messages/document_processing.py b/surfsense_backend/app/notifications/service/messages/document_processing.py new file mode 100644 index 000000000..3805c2847 --- /dev/null +++ b/surfsense_backend/app/notifications/service/messages/document_processing.py @@ -0,0 +1,64 @@ +"""Pure presentation logic for document-processing notifications.""" + +from __future__ import annotations + +import hashlib +from datetime import UTC, datetime +from typing import Any + + +def operation_id(document_type: str, filename: str, search_space_id: int) -> str: + """Build a unique id for a document processing run.""" + timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S_%f") + filename_hash = hashlib.md5(filename.encode()).hexdigest()[:8] + return f"doc_{document_type}_{search_space_id}_{timestamp}_{filename_hash}" + + +def progress( + stage: str, + stage_message: str | None = None, + chunks_count: int | None = None, +) -> tuple[str, dict[str, Any]]: + """Compute the progress message and metadata updates for a processing run.""" + stage_messages = { + "parsing": "Reading your file", + "chunking": "Preparing for search", + "embedding": "Preparing for search", + "storing": "Finalizing", + } + + message = stage_message or stage_messages.get(stage, "Processing") + + metadata_updates: dict[str, Any] = {"processing_stage": stage} + if chunks_count is not None: + metadata_updates["chunks_count"] = chunks_count + + return message, metadata_updates + + +def completion( + document_name: str, + error_message: str | None = None, + 
document_id: int | None = None, + chunks_count: int | None = None, +) -> tuple[str, str, str, dict[str, Any]]: + """Compute the final title, message, status, and metadata for a finished run.""" + if error_message: + title = f"Failed: {document_name}" + message = f"Processing failed: {error_message}" + status = "failed" + else: + title = f"Ready: {document_name}" + message = "Now searchable!" + status = "completed" + + metadata_updates: dict[str, Any] = { + "processing_stage": "completed" if not error_message else "failed", + "error_message": error_message, + } + if document_id is not None: + metadata_updates["document_id"] = document_id + if chunks_count is not None: + metadata_updates["chunks_count"] = chunks_count + + return title, message, status, metadata_updates diff --git a/surfsense_backend/app/notifications/service/messages/page_limit.py b/surfsense_backend/app/notifications/service/messages/page_limit.py new file mode 100644 index 000000000..54e5cbdec --- /dev/null +++ b/surfsense_backend/app/notifications/service/messages/page_limit.py @@ -0,0 +1,25 @@ +"""Pure presentation logic for page-limit notifications.""" + +from __future__ import annotations + +import hashlib +from datetime import UTC, datetime + +from app.notifications.service.messages.text import truncate + + +def operation_id(document_name: str, search_space_id: int) -> str: + """Build a unique id for a page-limit notification.""" + timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S_%f") + doc_hash = hashlib.md5(document_name.encode()).hexdigest()[:8] + return f"page_limit_{search_space_id}_{timestamp}_{doc_hash}" + + +def summary( + document_name: str, pages_used: int, pages_limit: int, pages_to_add: int +) -> tuple[str, str]: + """Compute the title and message for a blocked-by-page-limit document.""" + display_name = truncate(document_name, 40) + title = f"Page limit exceeded: {display_name}" + message = f"This document has ~{pages_to_add} page(s) but you've used {pages_used}/{pages_limit} 
pages. Upgrade to process more documents." + return title, message diff --git a/surfsense_backend/app/notifications/service/messages/text.py b/surfsense_backend/app/notifications/service/messages/text.py new file mode 100644 index 000000000..98d5284cb --- /dev/null +++ b/surfsense_backend/app/notifications/service/messages/text.py @@ -0,0 +1,8 @@ +"""Shared text helpers for notification copy.""" + +from __future__ import annotations + + +def truncate(text: str, limit: int) -> str: + """Return ``text`` capped at ``limit`` chars, appending an ellipsis if cut.""" + return text[:limit] + "..." if len(text) > limit else text diff --git a/surfsense_backend/app/notifications/service/metadata.py b/surfsense_backend/app/notifications/service/metadata.py new file mode 100644 index 000000000..2679893dc --- /dev/null +++ b/surfsense_backend/app/notifications/service/metadata.py @@ -0,0 +1,33 @@ +"""Pure metadata transitions for the notification lifecycle.""" + +from __future__ import annotations + +from datetime import UTC, datetime +from typing import Any + + +def start_metadata( + operation_id: str, initial_metadata: dict[str, Any] | None = None +) -> dict[str, Any]: + """Seed metadata for a freshly opened, in-progress notification.""" + metadata = dict(initial_metadata or {}) + metadata["operation_id"] = operation_id + metadata["status"] = "in_progress" + metadata["started_at"] = datetime.now(UTC).isoformat() + return metadata + + +def apply_update( + current: dict[str, Any], + status: str | None = None, + metadata_updates: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Return metadata with the status/timestamp stamped and updates merged in.""" + metadata = dict(current) + if status is not None: + metadata["status"] = status + if status in ("completed", "failed"): + metadata["completed_at"] = datetime.now(UTC).isoformat() + if metadata_updates: + metadata = {**metadata, **metadata_updates} + return metadata diff --git a/surfsense_backend/app/notifications/types.py 
b/surfsense_backend/app/notifications/types.py new file mode 100644 index 000000000..bb8bcfab1 --- /dev/null +++ b/surfsense_backend/app/notifications/types.py @@ -0,0 +1,16 @@ +"""The notification types the API recognizes.""" + +from __future__ import annotations + +from typing import Literal + +NotificationType = Literal[ + "connector_indexing", + "connector_deletion", + "document_processing", + "new_mention", + "comment_reply", + "page_limit_exceeded", +] + +NotificationCategory = Literal["comments", "status"] diff --git a/surfsense_backend/app/agents/new_chat/system_prompt.py b/surfsense_backend/app/prompts/default_system_instructions.py similarity index 87% rename from surfsense_backend/app/agents/new_chat/system_prompt.py rename to surfsense_backend/app/prompts/default_system_instructions.py index 70634c65d..fd0a8e186 100644 --- a/surfsense_backend/app/agents/new_chat/system_prompt.py +++ b/surfsense_backend/app/prompts/default_system_instructions.py @@ -1,5 +1,5 @@ """ -Thin compatibility wrapper around :mod:`app.agents.new_chat.prompts.composer`. +Thin compatibility wrapper around :mod:`app.prompts.system_prompt_composer.composer`. The composer split the previous monolithic prompt string into a fragment tree under ``prompts/`` plus a model-family dispatch step (see the @@ -7,11 +7,11 @@ composer module docstring for credits). This module preserves the public function surface (``build_surfsense_system_prompt`` / ``build_configurable_system_prompt`` / ``get_default_system_instructions`` / ``SURFSENSE_SYSTEM_PROMPT``) so -that existing call sites — `chat_deepagent.py`, anonymous chat routes, -and the configurable-prompt admin path — keep working without churn. +that existing call sites — the multi-agent chat factory, anonymous chat +routes, and the configurable-prompt admin path — keep working without churn. For new call sites prefer importing ``compose_system_prompt`` directly -from :mod:`app.agents.new_chat.prompts.composer`. 
+from :mod:`app.prompts.system_prompt_composer.composer`. """ from __future__ import annotations @@ -20,7 +20,7 @@ from datetime import UTC, datetime from app.db import ChatVisibility -from .prompts.composer import ( +from .system_prompt_composer.composer import ( _read_fragment, compose_system_prompt, detect_provider_variant, @@ -55,7 +55,7 @@ def build_surfsense_system_prompt( ) -> str: """Build the default SurfSense system prompt (citations on, defaults). - See :func:`app.agents.new_chat.prompts.composer.compose_system_prompt` + See :func:`app.prompts.system_prompt_composer.composer.compose_system_prompt` for full parameter docs. """ return compose_system_prompt( @@ -84,7 +84,7 @@ def build_configurable_system_prompt( ) -> str: """Build a configurable SurfSense system prompt (NewLLMConfig path). - See :func:`app.agents.new_chat.prompts.composer.compose_system_prompt` + See :func:`app.prompts.system_prompt_composer.composer.compose_system_prompt` for full parameter docs. """ return compose_system_prompt( @@ -108,7 +108,9 @@ def get_default_system_instructions() -> str: The output reflects the current fragment tree, not a baked-in constant. 
""" resolved_today = datetime.now(UTC).date().isoformat() - from .prompts.composer import _build_system_instructions # local import + from .system_prompt_composer.composer import ( + _build_system_instructions, # local import + ) return _build_system_instructions( visibility=ChatVisibility.PRIVATE, diff --git a/surfsense_backend/app/agents/new_chat/prompts/__init__.py b/surfsense_backend/app/prompts/system_prompt_composer/__init__.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/__init__.py rename to surfsense_backend/app/prompts/system_prompt_composer/__init__.py diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/__init__.py b/surfsense_backend/app/prompts/system_prompt_composer/base/__init__.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/examples/__init__.py rename to surfsense_backend/app/prompts/system_prompt_composer/base/__init__.py diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/agent_private.md b/surfsense_backend/app/prompts/system_prompt_composer/base/agent_private.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/base/agent_private.md rename to surfsense_backend/app/prompts/system_prompt_composer/base/agent_private.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/agent_team.md b/surfsense_backend/app/prompts/system_prompt_composer/base/agent_team.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/base/agent_team.md rename to surfsense_backend/app/prompts/system_prompt_composer/base/agent_team.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/citations_off.md b/surfsense_backend/app/prompts/system_prompt_composer/base/citations_off.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/base/citations_off.md rename to surfsense_backend/app/prompts/system_prompt_composer/base/citations_off.md diff --git 
a/surfsense_backend/app/agents/new_chat/prompts/base/citations_on.md b/surfsense_backend/app/prompts/system_prompt_composer/base/citations_on.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/base/citations_on.md rename to surfsense_backend/app/prompts/system_prompt_composer/base/citations_on.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_private.md b/surfsense_backend/app/prompts/system_prompt_composer/base/kb_only_policy_private.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_private.md rename to surfsense_backend/app/prompts/system_prompt_composer/base/kb_only_policy_private.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_team.md b/surfsense_backend/app/prompts/system_prompt_composer/base/kb_only_policy_team.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_team.md rename to surfsense_backend/app/prompts/system_prompt_composer/base/kb_only_policy_team.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/memory_protocol_private.md b/surfsense_backend/app/prompts/system_prompt_composer/base/memory_protocol_private.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/base/memory_protocol_private.md rename to surfsense_backend/app/prompts/system_prompt_composer/base/memory_protocol_private.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/memory_protocol_team.md b/surfsense_backend/app/prompts/system_prompt_composer/base/memory_protocol_team.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/base/memory_protocol_team.md rename to surfsense_backend/app/prompts/system_prompt_composer/base/memory_protocol_team.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/parameter_resolution.md 
b/surfsense_backend/app/prompts/system_prompt_composer/base/parameter_resolution.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/base/parameter_resolution.md rename to surfsense_backend/app/prompts/system_prompt_composer/base/parameter_resolution.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_private.md b/surfsense_backend/app/prompts/system_prompt_composer/base/tool_routing_private.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_private.md rename to surfsense_backend/app/prompts/system_prompt_composer/base/tool_routing_private.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_team.md b/surfsense_backend/app/prompts/system_prompt_composer/base/tool_routing_team.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_team.md rename to surfsense_backend/app/prompts/system_prompt_composer/base/tool_routing_team.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/composer.py b/surfsense_backend/app/prompts/system_prompt_composer/composer.py similarity index 99% rename from surfsense_backend/app/agents/new_chat/prompts/composer.py rename to surfsense_backend/app/prompts/system_prompt_composer/composer.py index 412665813..3849af313 100644 --- a/surfsense_backend/app/agents/new_chat/prompts/composer.py +++ b/surfsense_backend/app/prompts/system_prompt_composer/composer.py @@ -2,7 +2,7 @@ Prompt composer for the SurfSense ``new_chat`` agent. This module assembles the agent's system prompt from the markdown fragments -under :mod:`app.agents.new_chat.prompts`. It replaces the monolithic +under :mod:`app.prompts.system_prompt_composer`. 
It replaces the monolithic ``system_prompt.py`` with a clean, fragment-based composition: :: @@ -119,7 +119,7 @@ def detect_provider_variant(model_name: str | None) -> ProviderVariant: # ----------------------------------------------------------------------------- -_PROMPTS_PACKAGE = "app.agents.new_chat.prompts" +_PROMPTS_PACKAGE = "app.prompts.system_prompt_composer" def _read_fragment(subpath: str) -> str: diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/__init__.py b/surfsense_backend/app/prompts/system_prompt_composer/examples/__init__.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/providers/__init__.py rename to surfsense_backend/app/prompts/system_prompt_composer/examples/__init__.py diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/generate_image.md b/surfsense_backend/app/prompts/system_prompt_composer/examples/generate_image.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/examples/generate_image.md rename to surfsense_backend/app/prompts/system_prompt_composer/examples/generate_image.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/generate_podcast.md b/surfsense_backend/app/prompts/system_prompt_composer/examples/generate_podcast.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/examples/generate_podcast.md rename to surfsense_backend/app/prompts/system_prompt_composer/examples/generate_podcast.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/generate_report.md b/surfsense_backend/app/prompts/system_prompt_composer/examples/generate_report.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/examples/generate_report.md rename to surfsense_backend/app/prompts/system_prompt_composer/examples/generate_report.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/generate_resume.md 
b/surfsense_backend/app/prompts/system_prompt_composer/examples/generate_resume.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/examples/generate_resume.md rename to surfsense_backend/app/prompts/system_prompt_composer/examples/generate_resume.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/generate_video_presentation.md b/surfsense_backend/app/prompts/system_prompt_composer/examples/generate_video_presentation.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/examples/generate_video_presentation.md rename to surfsense_backend/app/prompts/system_prompt_composer/examples/generate_video_presentation.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/scrape_webpage.md b/surfsense_backend/app/prompts/system_prompt_composer/examples/scrape_webpage.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/examples/scrape_webpage.md rename to surfsense_backend/app/prompts/system_prompt_composer/examples/scrape_webpage.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/update_memory_private.md b/surfsense_backend/app/prompts/system_prompt_composer/examples/update_memory_private.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/examples/update_memory_private.md rename to surfsense_backend/app/prompts/system_prompt_composer/examples/update_memory_private.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/update_memory_team.md b/surfsense_backend/app/prompts/system_prompt_composer/examples/update_memory_team.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/examples/update_memory_team.md rename to surfsense_backend/app/prompts/system_prompt_composer/examples/update_memory_team.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/web_search.md b/surfsense_backend/app/prompts/system_prompt_composer/examples/web_search.md 
similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/examples/web_search.md rename to surfsense_backend/app/prompts/system_prompt_composer/examples/web_search.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/routing/__init__.py b/surfsense_backend/app/prompts/system_prompt_composer/providers/__init__.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/routing/__init__.py rename to surfsense_backend/app/prompts/system_prompt_composer/providers/__init__.py diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/anthropic.md b/surfsense_backend/app/prompts/system_prompt_composer/providers/anthropic.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/providers/anthropic.md rename to surfsense_backend/app/prompts/system_prompt_composer/providers/anthropic.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/deepseek.md b/surfsense_backend/app/prompts/system_prompt_composer/providers/deepseek.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/providers/deepseek.md rename to surfsense_backend/app/prompts/system_prompt_composer/providers/deepseek.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/default.md b/surfsense_backend/app/prompts/system_prompt_composer/providers/default.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/providers/default.md rename to surfsense_backend/app/prompts/system_prompt_composer/providers/default.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/google.md b/surfsense_backend/app/prompts/system_prompt_composer/providers/google.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/providers/google.md rename to surfsense_backend/app/prompts/system_prompt_composer/providers/google.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/grok.md 
b/surfsense_backend/app/prompts/system_prompt_composer/providers/grok.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/providers/grok.md rename to surfsense_backend/app/prompts/system_prompt_composer/providers/grok.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/kimi.md b/surfsense_backend/app/prompts/system_prompt_composer/providers/kimi.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/providers/kimi.md rename to surfsense_backend/app/prompts/system_prompt_composer/providers/kimi.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/openai_classic.md b/surfsense_backend/app/prompts/system_prompt_composer/providers/openai_classic.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/providers/openai_classic.md rename to surfsense_backend/app/prompts/system_prompt_composer/providers/openai_classic.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/openai_codex.md b/surfsense_backend/app/prompts/system_prompt_composer/providers/openai_codex.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/providers/openai_codex.md rename to surfsense_backend/app/prompts/system_prompt_composer/providers/openai_codex.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/openai_reasoning.md b/surfsense_backend/app/prompts/system_prompt_composer/providers/openai_reasoning.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/providers/openai_reasoning.md rename to surfsense_backend/app/prompts/system_prompt_composer/providers/openai_reasoning.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/__init__.py b/surfsense_backend/app/prompts/system_prompt_composer/routing/__init__.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/tools/__init__.py rename to 
surfsense_backend/app/prompts/system_prompt_composer/routing/__init__.py diff --git a/surfsense_backend/app/agents/new_chat/prompts/routing/jira.md b/surfsense_backend/app/prompts/system_prompt_composer/routing/jira.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/routing/jira.md rename to surfsense_backend/app/prompts/system_prompt_composer/routing/jira.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/routing/linear.md b/surfsense_backend/app/prompts/system_prompt_composer/routing/linear.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/routing/linear.md rename to surfsense_backend/app/prompts/system_prompt_composer/routing/linear.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/routing/slack.md b/surfsense_backend/app/prompts/system_prompt_composer/routing/slack.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/routing/slack.md rename to surfsense_backend/app/prompts/system_prompt_composer/routing/slack.md diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/__init__.py b/surfsense_backend/app/prompts/system_prompt_composer/tools/__init__.py similarity index 100% rename from surfsense_backend/app/agents/new_chat/skills/builtin/__init__.py rename to surfsense_backend/app/prompts/system_prompt_composer/tools/__init__.py diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/_preamble.md b/surfsense_backend/app/prompts/system_prompt_composer/tools/_preamble.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/tools/_preamble.md rename to surfsense_backend/app/prompts/system_prompt_composer/tools/_preamble.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/generate_image.md b/surfsense_backend/app/prompts/system_prompt_composer/tools/generate_image.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/tools/generate_image.md rename to 
surfsense_backend/app/prompts/system_prompt_composer/tools/generate_image.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/generate_podcast.md b/surfsense_backend/app/prompts/system_prompt_composer/tools/generate_podcast.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/tools/generate_podcast.md rename to surfsense_backend/app/prompts/system_prompt_composer/tools/generate_podcast.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/generate_report.md b/surfsense_backend/app/prompts/system_prompt_composer/tools/generate_report.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/tools/generate_report.md rename to surfsense_backend/app/prompts/system_prompt_composer/tools/generate_report.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/generate_resume.md b/surfsense_backend/app/prompts/system_prompt_composer/tools/generate_resume.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/tools/generate_resume.md rename to surfsense_backend/app/prompts/system_prompt_composer/tools/generate_resume.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/generate_video_presentation.md b/surfsense_backend/app/prompts/system_prompt_composer/tools/generate_video_presentation.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/tools/generate_video_presentation.md rename to surfsense_backend/app/prompts/system_prompt_composer/tools/generate_video_presentation.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/scrape_webpage.md b/surfsense_backend/app/prompts/system_prompt_composer/tools/scrape_webpage.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/tools/scrape_webpage.md rename to surfsense_backend/app/prompts/system_prompt_composer/tools/scrape_webpage.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/update_memory_private.md 
b/surfsense_backend/app/prompts/system_prompt_composer/tools/update_memory_private.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/tools/update_memory_private.md rename to surfsense_backend/app/prompts/system_prompt_composer/tools/update_memory_private.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/update_memory_team.md b/surfsense_backend/app/prompts/system_prompt_composer/tools/update_memory_team.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/tools/update_memory_team.md rename to surfsense_backend/app/prompts/system_prompt_composer/tools/update_memory_team.md diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/web_search.md b/surfsense_backend/app/prompts/system_prompt_composer/tools/web_search.md similarity index 100% rename from surfsense_backend/app/agents/new_chat/prompts/tools/web_search.md rename to surfsense_backend/app/prompts/system_prompt_composer/tools/web_search.md diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index dba1bfe7d..4750b9948 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -45,7 +45,7 @@ from .model_list_routes import router as model_list_router from .new_chat_routes import router as new_chat_router from .new_llm_config_routes import router as new_llm_config_router from .notes_routes import router as notes_router -from .notifications_routes import router as notifications_router +from app.notifications.api import router as notifications_router from .notion_add_connector_route import router as notion_add_connector_router from .obsidian_plugin_routes import router as obsidian_plugin_router from .onedrive_add_connector_route import router as onedrive_add_connector_router diff --git a/surfsense_backend/app/routes/agent_action_log_route.py b/surfsense_backend/app/routes/agent_action_log_route.py index 2608aa3b1..9a55fdec3 100644 --- 
a/surfsense_backend/app/routes/agent_action_log_route.py +++ b/surfsense_backend/app/routes/agent_action_log_route.py @@ -28,7 +28,7 @@ from pydantic import BaseModel from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.feature_flags import get_flags +from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags from app.db import ( AgentActionLog, NewChatThread, diff --git a/surfsense_backend/app/routes/agent_flags_route.py b/surfsense_backend/app/routes/agent_flags_route.py index 99388af66..e97608cbe 100644 --- a/surfsense_backend/app/routes/agent_flags_route.py +++ b/surfsense_backend/app/routes/agent_flags_route.py @@ -22,7 +22,10 @@ from dataclasses import asdict from fastapi import APIRouter, Depends from pydantic import BaseModel -from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags +from app.agents.chat.multi_agent_chat.shared.feature_flags import ( + AgentFeatureFlags, + get_flags, +) from app.config import config from app.db import User from app.users import current_active_user diff --git a/surfsense_backend/app/routes/agent_permissions_route.py b/surfsense_backend/app/routes/agent_permissions_route.py index 1c76e00e6..0c07eeb9c 100644 --- a/surfsense_backend/app/routes/agent_permissions_route.py +++ b/surfsense_backend/app/routes/agent_permissions_route.py @@ -30,7 +30,7 @@ from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.feature_flags import get_flags +from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags from app.db import ( AgentPermissionRule, NewChatThread, diff --git a/surfsense_backend/app/routes/agent_revert_route.py b/surfsense_backend/app/routes/agent_revert_route.py index 711081b15..ce21de69d 100644 --- a/surfsense_backend/app/routes/agent_revert_route.py +++ b/surfsense_backend/app/routes/agent_revert_route.py @@ -32,7 +32,7 @@ from 
sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.feature_flags import get_flags +from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags from app.db import ( AgentActionLog, User, diff --git a/surfsense_backend/app/routes/anonymous_chat_routes.py b/surfsense_backend/app/routes/anonymous_chat_routes.py index eb952e684..ad3277375 100644 --- a/surfsense_backend/app/routes/anonymous_chat_routes.py +++ b/surfsense_backend/app/routes/anonymous_chat_routes.py @@ -236,7 +236,7 @@ async def stream_anonymous_chat( detail="No-login mode is not enabled.", ) - from app.agents.new_chat.llm_config import ( + from app.agents.chat.runtime.llm_config import ( AgentConfig, create_chat_litellm_from_agent_config, ) @@ -351,12 +351,13 @@ async def stream_anonymous_chat( async def _generate(): from langchain_core.messages import AIMessage, HumanMessage - from app.agents.new_chat.anonymous_agent import create_anonymous_chat_agent - from app.agents.new_chat.checkpointer import get_checkpointer + from app.agents.chat.anonymous_chat import create_anonymous_chat_agent + from app.agents.chat.runtime.checkpointer import get_checkpointer from app.db import shielded_async_session from app.services.new_streaming_service import VercelStreamingService from app.services.token_tracking_service import start_turn - from app.tasks.chat.stream_new_chat import StreamResult, _stream_agent_events + from app.tasks.chat.streaming.agent.event_loop import stream_agent_events + from app.tasks.chat.streaming.shared.stream_result import StreamResult accumulator = start_turn() streaming_service = VercelStreamingService() @@ -419,7 +420,7 @@ async def stream_anonymous_chat( stream_result = StreamResult() - async for sse in _stream_agent_events( + async for sse in stream_agent_events( agent=agent, config=langgraph_config, input_data=input_state, diff --git a/surfsense_backend/app/routes/documents_routes.py 
b/surfsense_backend/app/routes/documents_routes.py index cafd34ef7..865068fba 100644 --- a/surfsense_backend/app/routes/documents_routes.py +++ b/surfsense_backend/app/routes/documents_routes.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import selectinload -from app.agents.new_chat.path_resolver import virtual_path_to_doc +from app.agents.chat.runtime.path_resolver import virtual_path_to_doc from app.db import ( Chunk, Document, diff --git a/surfsense_backend/app/routes/mcp_oauth_route.py b/surfsense_backend/app/routes/mcp_oauth_route.py index 57248d631..fdeb6ecfd 100644 --- a/surfsense_backend/app/routes/mcp_oauth_route.py +++ b/surfsense_backend/app/routes/mcp_oauth_route.py @@ -665,7 +665,7 @@ def _refresh_mcp_cache(connector_id: int, space_id: int) -> None: isolated from the OAuth response flow. """ try: - from app.agents.new_chat.tools.mcp_tools_cache import ( + from app.agents.chat.multi_agent_chat.shared.tools.mcp.cache import ( refresh_mcp_tools_cache_for_connector, ) diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 63b7732a9..0e4e557be 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -24,18 +24,18 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import selectinload -from app.agents.new_chat.filesystem_selection import ( - ClientPlatform, - FilesystemMode, - FilesystemSelection, - LocalFilesystemMount, -) -from app.agents.new_chat.middleware.busy_mutex import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import ( get_cancel_state, is_cancel_requested, manager, request_cancel, ) +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import ( + ClientPlatform, + FilesystemMode, + FilesystemSelection, + LocalFilesystemMount, +) from app.config import 
config from app.db import ( ChatComment, @@ -71,7 +71,7 @@ from app.schemas.new_chat import ( TokenUsageSummary, TurnStatusResponse, ) -from app.tasks.chat.stream_new_chat import ( +from app.tasks.chat.streaming.flows import ( stream_new_chat, stream_resume_chat, ) @@ -476,7 +476,7 @@ async def _revert_turns_for_regenerate( def _try_delete_sandbox(thread_id: int) -> None: """Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked.""" - from app.agents.new_chat.sandbox import ( + from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.sandbox import ( delete_local_sandbox_files, delete_sandbox, is_sandbox_enabled, @@ -1668,7 +1668,7 @@ async def list_agent_tools( Hidden (WIP) tools are excluded from the response. """ - from app.agents.new_chat.tools.registry import BUILTIN_TOOLS + from app.agents.chat.multi_agent_chat.shared.tools.catalog import TOOL_CATALOG return [ AgentToolInfo( @@ -1676,7 +1676,7 @@ async def list_agent_tools( description=t.description, enabled_by_default=t.enabled_by_default, ) - for t in BUILTIN_TOOLS + for t in TOOL_CATALOG if not t.hidden ] @@ -1934,7 +1934,7 @@ async def regenerate_response( """ from langchain_core.messages import HumanMessage - from app.agents.new_chat.checkpointer import get_checkpointer + from app.agents.chat.runtime.checkpointer import get_checkpointer try: # Verify thread exists and user has permission diff --git a/surfsense_backend/app/routes/new_llm_config_routes.py b/surfsense_backend/app/routes/new_llm_config_routes.py index e090a1a7c..84d66bb13 100644 --- a/surfsense_backend/app/routes/new_llm_config_routes.py +++ b/surfsense_backend/app/routes/new_llm_config_routes.py @@ -13,7 +13,6 @@ from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.agents.new_chat.system_prompt import get_default_system_instructions from app.config import config from app.db import ( NewLLMConfig, @@ -21,6 +20,7 
@@ from app.db import ( User, get_async_session, ) +from app.prompts.default_system_instructions import get_default_system_instructions from app.schemas import ( DefaultSystemInstructionsResponse, GlobalNewLLMConfigRead, diff --git a/surfsense_backend/app/routes/obsidian_plugin_routes.py b/surfsense_backend/app/routes/obsidian_plugin_routes.py index 0dae7a463..512596550 100644 --- a/surfsense_backend/app/routes/obsidian_plugin_routes.py +++ b/surfsense_backend/app/routes/obsidian_plugin_routes.py @@ -43,7 +43,7 @@ from app.schemas.obsidian_plugin import ( SyncAckItem, SyncBatchRequest, ) -from app.services.notification_service import NotificationService +from app.notifications.service import NotificationService from app.services.obsidian_plugin_indexer import ( delete_note, get_manifest, diff --git a/surfsense_backend/app/routes/sandbox_routes.py b/surfsense_backend/app/routes/sandbox_routes.py index f656e8d76..fefe51997 100644 --- a/surfsense_backend/app/routes/sandbox_routes.py +++ b/surfsense_backend/app/routes/sandbox_routes.py @@ -51,7 +51,10 @@ async def download_sandbox_file( ): """Download a file from the Daytona sandbox associated with a chat thread.""" - from app.agents.new_chat.sandbox import get_or_create_sandbox, is_sandbox_enabled + from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.sandbox import ( + get_or_create_sandbox, + is_sandbox_enabled, + ) if not is_sandbox_enabled(): raise HTTPException(status_code=404, detail="Sandbox is not enabled") @@ -71,7 +74,9 @@ async def download_sandbox_file( "You don't have permission to access files in this thread", ) - from app.agents.new_chat.sandbox import get_local_sandbox_file + from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.sandbox import ( + get_local_sandbox_file, + ) # Prefer locally-persisted copy (sandbox may already be deleted) local_content = get_local_sandbox_file(thread_id, path) diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py 
b/surfsense_backend/app/routes/search_source_connectors_routes.py index 3060fdf4a..dc26b4c02 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -43,6 +43,7 @@ from app.db import ( async_session_maker, get_async_session, ) +from app.notifications.service import NotificationService from app.observability import metrics as ot_metrics, otel as ot from app.schemas import ( GoogleDriveIndexRequest, @@ -55,7 +56,6 @@ from app.schemas import ( SearchSourceConnectorUpdate, ) from app.services.composio_service import ComposioService, get_composio_service -from app.services.notification_service import NotificationService from app.users import current_active_user # NOTE: connector indexer functions are imported lazily inside each @@ -675,7 +675,9 @@ async def delete_search_source_connector( await session.commit() if is_mcp: - from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache + from app.agents.chat.multi_agent_chat.shared.tools.mcp.tool import ( + invalidate_mcp_tools_cache, + ) invalidate_mcp_tools_cache(search_space_id) @@ -2687,7 +2689,7 @@ async def create_mcp_connector( f"for user {user.id} in search space {search_space_id}" ) - from app.agents.new_chat.tools.mcp_tools_cache import ( + from app.agents.chat.multi_agent_chat.shared.tools.mcp.cache import ( refresh_mcp_tools_cache_for_connector, ) @@ -2867,7 +2869,7 @@ async def update_mcp_connector( logger.info(f"Updated MCP connector {connector_id}") - from app.agents.new_chat.tools.mcp_tools_cache import ( + from app.agents.chat.multi_agent_chat.shared.tools.mcp.cache import ( refresh_mcp_tools_cache_for_connector, ) @@ -2927,7 +2929,9 @@ async def delete_mcp_connector( await session.delete(connector) await session.commit() - from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache + from app.agents.chat.multi_agent_chat.shared.tools.mcp.tool import ( + invalidate_mcp_tools_cache, + ) 
invalidate_mcp_tools_cache(search_space_id) @@ -2966,7 +2970,7 @@ async def test_mcp_server_connection( Connection status and list of available tools """ try: - from app.agents.new_chat.tools.mcp_client import ( + from app.agents.chat.multi_agent_chat.shared.tools.mcp.client import ( test_mcp_connection, test_mcp_http_connection, ) @@ -3157,7 +3161,9 @@ async def trust_mcp_tool( connectors (``LINEAR_CONNECTOR``, ``JIRA_CONNECTOR``, ...) — the storage primitive is the same JSON list under ``config.trusted_tools``. """ - from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache + from app.agents.chat.multi_agent_chat.shared.tools.mcp.tool import ( + invalidate_mcp_tools_cache, + ) from app.services.user_tool_allowlist import add_user_trust try: @@ -3197,7 +3203,9 @@ async def untrust_mcp_tool( The tool will require HITL approval again on subsequent calls. """ - from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache + from app.agents.chat.multi_agent_chat.shared.tools.mcp.tool import ( + invalidate_mcp_tools_cache, + ) from app.services.user_tool_allowlist import remove_user_trust try: diff --git a/surfsense_backend/app/services/chat_comments_service.py b/surfsense_backend/app/services/chat_comments_service.py index 54662fe5b..c9afb8a67 100644 --- a/surfsense_backend/app/services/chat_comments_service.py +++ b/surfsense_backend/app/services/chat_comments_service.py @@ -31,7 +31,7 @@ from app.schemas.chat_comments import ( MentionListResponse, MentionResponse, ) -from app.services.notification_service import NotificationService +from app.notifications.service import NotificationService from app.utils.chat_comments import parse_mentions, render_mentions from app.utils.rbac import check_permission, get_user_permissions diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 099e7c573..7061a826f 100644 --- a/surfsense_backend/app/services/llm_service.py +++ 
b/surfsense_backend/app/services/llm_service.py @@ -203,7 +203,9 @@ async def validate_llm_config( if litellm_params: litellm_kwargs.update(litellm_params) - from app.agents.new_chat.llm_config import SanitizedChatLiteLLM + from app.agents.chat.runtime.llm_config import ( + SanitizedChatLiteLLM, + ) llm = SanitizedChatLiteLLM(**litellm_kwargs) @@ -375,7 +377,9 @@ async def get_search_space_llm_instance( if disable_streaming: litellm_kwargs["disable_streaming"] = True - from app.agents.new_chat.llm_config import SanitizedChatLiteLLM + from app.agents.chat.runtime.llm_config import ( + SanitizedChatLiteLLM, + ) return SanitizedChatLiteLLM(**litellm_kwargs) @@ -454,7 +458,9 @@ async def get_search_space_llm_instance( if disable_streaming: litellm_kwargs["disable_streaming"] = True - from app.agents.new_chat.llm_config import SanitizedChatLiteLLM + from app.agents.chat.runtime.llm_config import ( + SanitizedChatLiteLLM, + ) return SanitizedChatLiteLLM(**litellm_kwargs) @@ -569,7 +575,9 @@ async def get_vision_llm( if global_cfg.get("litellm_params"): litellm_kwargs.update(global_cfg["litellm_params"]) - from app.agents.new_chat.llm_config import SanitizedChatLiteLLM + from app.agents.chat.runtime.llm_config import ( + SanitizedChatLiteLLM, + ) inner_llm = SanitizedChatLiteLLM(**litellm_kwargs) @@ -623,7 +631,9 @@ async def get_vision_llm( if vision_cfg.litellm_params: litellm_kwargs.update(vision_cfg.litellm_params) - from app.agents.new_chat.llm_config import SanitizedChatLiteLLM + from app.agents.chat.runtime.llm_config import ( + SanitizedChatLiteLLM, + ) return SanitizedChatLiteLLM(**litellm_kwargs) @@ -652,7 +662,9 @@ def get_planner_llm() -> ChatLiteLLM | None: Callers MUST fall back to their chat LLM when this returns ``None`` so deployments without a planner config keep working unchanged. 
""" - from app.agents.new_chat.llm_config import create_chat_litellm_from_config + from app.agents.chat.runtime.llm_config import ( + create_chat_litellm_from_config, + ) planner_cfg = next( (cfg for cfg in config.GLOBAL_LLM_CONFIGS if cfg.get("is_planner") is True), diff --git a/surfsense_backend/app/services/notification_service.py b/surfsense_backend/app/services/notification_service.py deleted file mode 100644 index 5ffee12d7..000000000 --- a/surfsense_backend/app/services/notification_service.py +++ /dev/null @@ -1,1089 +0,0 @@ -"""Service for creating and managing notifications with Zero sync.""" - -import logging -from datetime import UTC, datetime -from typing import Any -from uuid import UUID - -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm.attributes import flag_modified - -from app.db import Notification - -logger = logging.getLogger(__name__) - - -class BaseNotificationHandler: - """Base class for notification handlers - provides common functionality.""" - - def __init__(self, notification_type: str): - """ - Initialize the notification handler. - - Args: - notification_type: Type of notification (e.g., 'connector_indexing', 'document_processing') - """ - self.notification_type = notification_type - - async def find_notification_by_operation( - self, - session: AsyncSession, - user_id: UUID, - operation_id: str, - search_space_id: int | None = None, - ) -> Notification | None: - """ - Find an existing notification by operation ID. 
- - Args: - session: Database session - user_id: User ID - operation_id: Unique operation identifier - search_space_id: Optional search space ID - - Returns: - Notification if found, None otherwise - """ - query = select(Notification).where( - Notification.user_id == user_id, - Notification.type == self.notification_type, - Notification.notification_metadata["operation_id"].astext == operation_id, - ) - if search_space_id is not None: - query = query.where(Notification.search_space_id == search_space_id) - - result = await session.execute(query) - return result.scalar_one_or_none() - - async def find_or_create_notification( - self, - session: AsyncSession, - user_id: UUID, - operation_id: str, - title: str, - message: str, - search_space_id: int | None = None, - initial_metadata: dict[str, Any] | None = None, - ) -> Notification: - """ - Find an existing notification or create a new one. - - Args: - session: Database session - user_id: User ID - operation_id: Unique operation identifier - title: Notification title - message: Notification message - search_space_id: Optional search space ID - initial_metadata: Initial metadata dictionary - - Returns: - Notification: The found or created notification - """ - # Try to find existing notification - notification = await self.find_notification_by_operation( - session, user_id, operation_id, search_space_id - ) - - if notification: - # Update existing notification - notification.title = title - notification.message = message - if initial_metadata: - notification.notification_metadata = { - **notification.notification_metadata, - **initial_metadata, - } - # Mark JSONB column as modified so SQLAlchemy detects the change - flag_modified(notification, "notification_metadata") - await session.commit() - await session.refresh(notification) - logger.info( - f"Updated notification {notification.id} for operation {operation_id}" - ) - return notification - - # Create new notification - metadata = initial_metadata or {} - 
metadata["operation_id"] = operation_id - metadata["status"] = "in_progress" - metadata["started_at"] = datetime.now(UTC).isoformat() - - notification = Notification( - user_id=user_id, - search_space_id=search_space_id, - type=self.notification_type, - title=title, - message=message, - notification_metadata=metadata, - ) - session.add(notification) - await session.commit() - await session.refresh(notification) - logger.info( - f"Created notification {notification.id} for operation {operation_id}" - ) - return notification - - async def update_notification( - self, - session: AsyncSession, - notification: Notification, - title: str | None = None, - message: str | None = None, - status: str | None = None, - metadata_updates: dict[str, Any] | None = None, - ) -> Notification: - """ - Update an existing notification. - - Args: - session: Database session - notification: Notification to update - title: New title (optional) - message: New message (optional) - status: New status (optional) - metadata_updates: Additional metadata to merge (optional) - - Returns: - Updated notification - """ - if title is not None: - notification.title = title - if message is not None: - notification.message = message - - if status is not None: - notification.notification_metadata["status"] = status - if status in ("completed", "failed"): - notification.notification_metadata["completed_at"] = datetime.now( - UTC - ).isoformat() - # Mark JSONB column as modified so SQLAlchemy detects the change - flag_modified(notification, "notification_metadata") - - if metadata_updates: - notification.notification_metadata = { - **notification.notification_metadata, - **metadata_updates, - } - # Mark JSONB column as modified - flag_modified(notification, "notification_metadata") - - await session.commit() - await session.refresh(notification) - logger.info(f"Updated notification {notification.id}") - return notification - - -class ConnectorIndexingNotificationHandler(BaseNotificationHandler): - 
"""Handler for connector indexing notifications.""" - - def __init__(self): - super().__init__("connector_indexing") - - def _generate_operation_id( - self, - connector_id: int, - start_date: str | None = None, - end_date: str | None = None, - ) -> str: - """ - Generate a unique operation ID for a connector indexing operation. - - Args: - connector_id: Connector ID - start_date: Start date (optional) - end_date: End date (optional) - - Returns: - Unique operation ID string - """ - timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S") - date_range = "" - if start_date or end_date: - date_range = f"_{start_date or 'none'}_{end_date or 'none'}" - return f"connector_{connector_id}_{timestamp}{date_range}" - - def _generate_google_drive_operation_id( - self, connector_id: int, folder_count: int, file_count: int - ) -> str: - """ - Generate a unique operation ID for a Google Drive indexing operation. - - Args: - connector_id: Connector ID - folder_count: Number of folders to index - file_count: Number of files to index - - Returns: - Unique operation ID string - """ - timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S") - items_info = f"_{folder_count}f_{file_count}files" - return f"drive_{connector_id}_{timestamp}{items_info}" - - async def notify_indexing_started( - self, - session: AsyncSession, - user_id: UUID, - connector_id: int, - connector_name: str, - connector_type: str, - search_space_id: int, - start_date: str | None = None, - end_date: str | None = None, - ) -> Notification: - """ - Create or update notification when connector indexing starts. 
- - Args: - session: Database session - user_id: User ID - connector_id: Connector ID - connector_name: Connector name - connector_type: Connector type - search_space_id: Search space ID - start_date: Start date for indexing - end_date: End date for indexing - - Returns: - Notification: The created or updated notification - """ - operation_id = self._generate_operation_id(connector_id, start_date, end_date) - title = f"Syncing: {connector_name}" - message = "Connecting to your account" - - metadata = { - "connector_id": connector_id, - "connector_name": connector_name, - "connector_type": connector_type, - "start_date": start_date, - "end_date": end_date, - "indexed_count": 0, - "sync_stage": "connecting", - } - - return await self.find_or_create_notification( - session=session, - user_id=user_id, - operation_id=operation_id, - title=title, - message=message, - search_space_id=search_space_id, - initial_metadata=metadata, - ) - - async def notify_indexing_progress( - self, - session: AsyncSession, - notification: Notification, - indexed_count: int, - total_count: int | None = None, - stage: str | None = None, - stage_message: str | None = None, - ) -> Notification: - """ - Update notification with indexing progress. 
- - Args: - session: Database session - notification: Notification to update - indexed_count: Number of items indexed so far - total_count: Total number of items (optional) - stage: Current sync stage (fetching, processing, storing) (optional) - stage_message: Optional custom message for the stage - - Returns: - Updated notification - """ - # User-friendly stage messages (clean, no ellipsis - spinner shows activity) - stage_messages = { - "connecting": "Connecting to your account", - "fetching": "Fetching your content", - "processing": "Preparing for search", - "storing": "Almost done", - } - - # Use stage-based message if stage provided, otherwise fallback - if stage or stage_message: - progress_msg = stage_message or stage_messages.get(stage, "Processing") - else: - # Fallback for backward compatibility - progress_msg = "Fetching your content" - - metadata_updates = {"indexed_count": indexed_count} - if total_count is not None: - metadata_updates["total_count"] = total_count - progress_percent = int((indexed_count / total_count) * 100) - metadata_updates["progress_percent"] = progress_percent - if stage: - metadata_updates["sync_stage"] = stage - - return await self.update_notification( - session=session, - notification=notification, - message=progress_msg, - status="in_progress", - metadata_updates=metadata_updates, - ) - - async def notify_retry_progress( - self, - session: AsyncSession, - notification: Notification, - indexed_count: int, - retry_reason: str, - attempt: int, - max_attempts: int, - wait_seconds: float | None = None, - service_name: str | None = None, - ) -> Notification: - """ - Update notification when a connector is retrying due to rate limits or errors. - - This method provides user-friendly feedback when external service limitations - (rate limits, temporary outages) cause delays. Users see that the delay is - not our fault and the sync is still progressing. - - This method can be used by ANY connector (Notion, Slack, Airtable, etc.) 
- when they hit rate limits or transient errors. - - Args: - session: Database session - notification: Notification to update - indexed_count: Number of items indexed so far - retry_reason: Reason for retry ('rate_limit', 'server_error', 'timeout') - attempt: Current retry attempt number (1-based) - max_attempts: Maximum number of retry attempts - wait_seconds: Seconds to wait before retry (optional, for display) - service_name: Name of the external service (e.g., 'Notion', 'Slack') - If not provided, extracts from notification metadata - - Returns: - Updated notification - """ - # Get service name from notification if not provided - if not service_name: - service_name = notification.notification_metadata.get( - "connector_name", "Service" - ) - # Extract just the service name if it's "Notion - My Workspace" - if " - " in service_name: - service_name = service_name.split(" - ")[0] - - # User-friendly messages for different retry reasons - # These make it clear the delay is due to the external service, not SurfSense - retry_messages = { - "rate_limit": f"{service_name} rate limit reached", - "server_error": f"{service_name} is slow to respond", - "timeout": f"{service_name} took too long", - "temporary_error": f"{service_name} temporarily unavailable", - } - - base_message = retry_messages.get(retry_reason, f"Waiting for {service_name}") - - # Add wait time and progress info - if wait_seconds and wait_seconds > 5: - # Only show wait time if it's significant - message = f"{base_message}. Retrying in {int(wait_seconds)}s..." - else: - message = f"{base_message}. Retrying..." 
- - # Add progress count if we have any - if indexed_count > 0: - item_text = "item" if indexed_count == 1 else "items" - message = f"{message} ({indexed_count} {item_text} synced so far)" - - metadata_updates = { - "indexed_count": indexed_count, - "sync_stage": "waiting_retry", - "retry_attempt": attempt, - "retry_max_attempts": max_attempts, - "retry_reason": retry_reason, - "retry_wait_seconds": wait_seconds, - } - - return await self.update_notification( - session=session, - notification=notification, - message=message, - status="in_progress", - metadata_updates=metadata_updates, - ) - - async def notify_indexing_completed( - self, - session: AsyncSession, - notification: Notification, - indexed_count: int, - error_message: str | None = None, - is_warning: bool = False, - skipped_count: int | None = None, - unsupported_count: int | None = None, - ) -> Notification: - """ - Update notification when connector indexing completes. - - Args: - session: Database session - notification: Notification to update - indexed_count: Total number of files indexed - error_message: Error message if indexing failed, or warning message (optional) - is_warning: If True, treat error_message as a warning (success case) rather than an error - skipped_count: Number of files skipped (e.g., unchanged) - optional - unsupported_count: Number of files skipped because the ETL parser doesn't support them - - Returns: - Updated notification - """ - connector_name = notification.notification_metadata.get( - "connector_name", "Connector" - ) - - unsupported_text = "" - if unsupported_count and unsupported_count > 0: - file_word = "file was" if unsupported_count == 1 else "files were" - unsupported_text = f" {unsupported_count} {file_word} not supported." - - if error_message: - if indexed_count > 0: - title = f"Ready: {connector_name}" - file_text = "file" if indexed_count == 1 else "files" - message = f"Now searchable! 
{indexed_count} {file_text} synced.{unsupported_text} Note: {error_message}" - status = "completed" - elif is_warning: - title = f"Ready: {connector_name}" - message = f"Sync complete.{unsupported_text} {error_message}" - status = "completed" - else: - title = f"Failed: {connector_name}" - message = f"Sync failed: {error_message}" - if unsupported_text: - message += unsupported_text - status = "failed" - else: - title = f"Ready: {connector_name}" - if indexed_count == 0: - if unsupported_count and unsupported_count > 0: - message = f"Sync complete.{unsupported_text}" - else: - message = "Already up to date!" - else: - file_text = "file" if indexed_count == 1 else "files" - message = f"Now searchable! {indexed_count} {file_text} synced." - if unsupported_text: - message += unsupported_text - status = "completed" - - metadata_updates = { - "indexed_count": indexed_count, - "skipped_count": skipped_count or 0, - "unsupported_count": unsupported_count or 0, - "sync_stage": "completed" - if (not error_message or is_warning or indexed_count > 0) - else "failed", - "error_message": error_message, - } - - return await self.update_notification( - session=session, - notification=notification, - title=title, - message=message, - status=status, - metadata_updates=metadata_updates, - ) - - async def notify_google_drive_indexing_started( - self, - session: AsyncSession, - user_id: UUID, - connector_id: int, - connector_name: str, - connector_type: str, - search_space_id: int, - folder_count: int, - file_count: int, - folder_names: list[str] | None = None, - file_names: list[str] | None = None, - ) -> Notification: - """ - Create or update notification when Google Drive indexing starts. 
- - Args: - session: Database session - user_id: User ID - connector_id: Connector ID - connector_name: Connector name - connector_type: Connector type - search_space_id: Search space ID - folder_count: Number of folders to index - file_count: Number of files to index - folder_names: List of folder names (optional) - file_names: List of file names (optional) - - Returns: - Notification: The created or updated notification - """ - operation_id = self._generate_google_drive_operation_id( - connector_id, folder_count, file_count - ) - title = f"Syncing: {connector_name}" - message = "Preparing your files" - - metadata = { - "connector_id": connector_id, - "connector_name": connector_name, - "connector_type": connector_type, - "folder_count": folder_count, - "file_count": file_count, - "indexed_count": 0, - "sync_stage": "connecting", - } - - if folder_names: - metadata["folder_names"] = folder_names - if file_names: - metadata["file_names"] = file_names - - return await self.find_or_create_notification( - session=session, - user_id=user_id, - operation_id=operation_id, - title=title, - message=message, - search_space_id=search_space_id, - initial_metadata=metadata, - ) - - -class DocumentProcessingNotificationHandler(BaseNotificationHandler): - """Handler for document processing notifications.""" - - def __init__(self): - super().__init__("document_processing") - - def _generate_operation_id( - self, document_type: str, filename: str, search_space_id: int - ) -> str: - """ - Generate a unique operation ID for a document processing operation. - - Args: - document_type: Type of document (FILE, YOUTUBE_VIDEO, CRAWLED_URL, etc.) 
- filename: Name of the file/document - search_space_id: Search space ID - - Returns: - Unique operation ID string - """ - timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S_%f") - # Create a short hash of filename to ensure uniqueness - import hashlib - - filename_hash = hashlib.md5(filename.encode()).hexdigest()[:8] - return f"doc_{document_type}_{search_space_id}_{timestamp}_{filename_hash}" - - async def notify_processing_started( - self, - session: AsyncSession, - user_id: UUID, - document_type: str, - document_name: str, - search_space_id: int, - file_size: int | None = None, - ) -> Notification: - """ - Create notification when document processing starts. - - Args: - session: Database session - user_id: User ID - document_type: Type of document (FILE, YOUTUBE_VIDEO, CRAWLED_URL, etc.) - document_name: Name/title of the document - search_space_id: Search space ID - file_size: Size of file in bytes (optional) - - Returns: - Notification: The created notification - """ - operation_id = self._generate_operation_id( - document_type, document_name, search_space_id - ) - title = f"Processing: {document_name}" - message = "Waiting in queue" - - metadata = { - "document_type": document_type, - "document_name": document_name, - "processing_stage": "queued", - } - - if file_size is not None: - metadata["file_size"] = file_size - - return await self.find_or_create_notification( - session=session, - user_id=user_id, - operation_id=operation_id, - title=title, - message=message, - search_space_id=search_space_id, - initial_metadata=metadata, - ) - - async def notify_processing_progress( - self, - session: AsyncSession, - notification: Notification, - stage: str, - stage_message: str | None = None, - chunks_count: int | None = None, - ) -> Notification: - """ - Update notification with processing progress. 
- - Args: - session: Database session - notification: Notification to update - stage: Current processing stage (parsing, chunking, embedding, storing) - stage_message: Optional custom message for the stage - chunks_count: Number of chunks created (optional, stored in metadata only) - - Returns: - Updated notification - """ - # User-friendly stage messages - stage_messages = { - "parsing": "Reading your file", - "chunking": "Preparing for search", - "embedding": "Preparing for search", - "storing": "Finalizing", - } - - message = stage_message or stage_messages.get(stage, "Processing") - - metadata_updates = {"processing_stage": stage} - # Store chunks_count in metadata for debugging, but don't show to user - if chunks_count is not None: - metadata_updates["chunks_count"] = chunks_count - - return await self.update_notification( - session=session, - notification=notification, - message=message, - status="in_progress", - metadata_updates=metadata_updates, - ) - - async def notify_processing_completed( - self, - session: AsyncSession, - notification: Notification, - document_id: int | None = None, - chunks_count: int | None = None, - error_message: str | None = None, - ) -> Notification: - """ - Update notification when document processing completes. - - Args: - session: Database session - notification: Notification to update - document_id: ID of the created document (optional) - chunks_count: Total number of chunks created (optional) - error_message: Error message if processing failed (optional) - - Returns: - Updated notification - """ - document_name = notification.notification_metadata.get( - "document_name", "Document" - ) - - if error_message: - title = f"Failed: {document_name}" - message = f"Processing failed: {error_message}" - status = "failed" - else: - title = f"Ready: {document_name}" - message = "Now searchable!" 
- status = "completed" - - metadata_updates = { - "processing_stage": "completed" if not error_message else "failed", - "error_message": error_message, - } - - if document_id is not None: - metadata_updates["document_id"] = document_id - # Store chunks_count in metadata for debugging, but don't show to user - if chunks_count is not None: - metadata_updates["chunks_count"] = chunks_count - - return await self.update_notification( - session=session, - notification=notification, - title=title, - message=message, - status=status, - metadata_updates=metadata_updates, - ) - - -class MentionNotificationHandler(BaseNotificationHandler): - """Handler for new mention notifications.""" - - def __init__(self): - super().__init__("new_mention") - - async def find_notification_by_mention( - self, - session: AsyncSession, - mention_id: int, - ) -> Notification | None: - """ - Find an existing notification by mention ID. - - Args: - session: Database session - mention_id: The mention ID to search for - - Returns: - Notification if found, None otherwise - """ - query = select(Notification).where( - Notification.type == self.notification_type, - Notification.notification_metadata["mention_id"].astext == str(mention_id), - ) - result = await session.execute(query) - return result.scalar_one_or_none() - - async def notify_new_mention( - self, - session: AsyncSession, - mentioned_user_id: UUID, - mention_id: int, - comment_id: int, - message_id: int, - thread_id: int, - thread_title: str, - author_id: str, - author_name: str, - author_avatar_url: str | None, - author_email: str, - content_preview: str, - search_space_id: int, - ) -> Notification: - """ - Create notification when a user is @mentioned in a comment. - Uses mention_id for idempotency to prevent duplicate notifications. 
- - Args: - session: Database session - mentioned_user_id: User who was mentioned - mention_id: ID of the mention record (used for idempotency) - comment_id: ID of the comment containing the mention - message_id: ID of the message being commented on - thread_id: ID of the chat thread - thread_title: Title of the chat thread - author_id: ID of the comment author - author_name: Display name of the comment author - author_avatar_url: Avatar URL of the comment author - author_email: Email of the comment author (for fallback initials) - content_preview: First ~100 chars of the comment - search_space_id: Search space ID - - Returns: - Notification: The created or existing notification - """ - # Check if notification already exists for this mention (idempotency) - existing = await self.find_notification_by_mention(session, mention_id) - if existing: - logger.info( - f"Notification already exists for mention {mention_id}, returning existing" - ) - return existing - - title = f"{author_name} mentioned you" - message = content_preview[:100] + ("..." 
if len(content_preview) > 100 else "") - - metadata = { - "mention_id": mention_id, - "comment_id": comment_id, - "message_id": message_id, - "thread_id": thread_id, - "thread_title": thread_title, - "author_id": author_id, - "author_name": author_name, - "author_avatar_url": author_avatar_url, - "author_email": author_email, - "content_preview": content_preview[:200], - } - - try: - notification = Notification( - user_id=mentioned_user_id, - search_space_id=search_space_id, - type=self.notification_type, - title=title, - message=message, - notification_metadata=metadata, - ) - session.add(notification) - await session.commit() - await session.refresh(notification) - logger.info( - f"Created new_mention notification {notification.id} for user {mentioned_user_id}" - ) - return notification - except Exception as e: - # Handle race condition - if duplicate key error, try to fetch existing - await session.rollback() - if ( - "duplicate key" in str(e).lower() - or "unique constraint" in str(e).lower() - ): - logger.warning( - f"Duplicate notification detected for mention {mention_id}, fetching existing" - ) - existing = await self.find_notification_by_mention(session, mention_id) - if existing: - return existing - # Re-raise if not a duplicate key error or couldn't find existing - raise - - -class CommentReplyNotificationHandler(BaseNotificationHandler): - """Handler for comment reply notifications.""" - - def __init__(self): - super().__init__("comment_reply") - - async def find_notification_by_reply( - self, - session: AsyncSession, - reply_id: int, - user_id: UUID, - ) -> Notification | None: - query = select(Notification).where( - Notification.type == self.notification_type, - Notification.user_id == user_id, - Notification.notification_metadata["reply_id"].astext == str(reply_id), - ) - result = await session.execute(query) - return result.scalar_one_or_none() - - async def notify_comment_reply( - self, - session: AsyncSession, - user_id: UUID, - reply_id: int, - 
parent_comment_id: int, - message_id: int, - thread_id: int, - thread_title: str, - author_id: str, - author_name: str, - author_avatar_url: str | None, - author_email: str, - content_preview: str, - search_space_id: int, - ) -> Notification: - existing = await self.find_notification_by_reply(session, reply_id, user_id) - if existing: - logger.info( - f"Notification already exists for reply {reply_id} to user {user_id}" - ) - return existing - - title = f"{author_name} replied in a thread" - message = content_preview[:100] + ("..." if len(content_preview) > 100 else "") - - metadata = { - "reply_id": reply_id, - "parent_comment_id": parent_comment_id, - "message_id": message_id, - "thread_id": thread_id, - "thread_title": thread_title, - "author_id": author_id, - "author_name": author_name, - "author_avatar_url": author_avatar_url, - "author_email": author_email, - "content_preview": content_preview[:200], - } - - try: - notification = Notification( - user_id=user_id, - search_space_id=search_space_id, - type=self.notification_type, - title=title, - message=message, - notification_metadata=metadata, - ) - session.add(notification) - await session.commit() - await session.refresh(notification) - logger.info( - f"Created comment_reply notification {notification.id} for user {user_id}" - ) - return notification - except Exception as e: - await session.rollback() - if ( - "duplicate key" in str(e).lower() - or "unique constraint" in str(e).lower() - ): - logger.warning( - f"Duplicate notification for reply {reply_id} to user {user_id}" - ) - existing = await self.find_notification_by_reply( - session, reply_id, user_id - ) - if existing: - return existing - raise - - -class PageLimitNotificationHandler(BaseNotificationHandler): - """Handler for page limit exceeded notifications.""" - - def __init__(self): - super().__init__("page_limit_exceeded") - - def _generate_operation_id(self, document_name: str, search_space_id: int) -> str: - """ - Generate a unique operation 
ID for a page limit exceeded notification. - - Args: - document_name: Name of the document that triggered the limit - search_space_id: Search space ID - - Returns: - Unique operation ID string - """ - import hashlib - - timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S_%f") - # Create a short hash of document name to ensure uniqueness - doc_hash = hashlib.md5(document_name.encode()).hexdigest()[:8] - return f"page_limit_{search_space_id}_{timestamp}_{doc_hash}" - - async def notify_page_limit_exceeded( - self, - session: AsyncSession, - user_id: UUID, - document_name: str, - document_type: str, - search_space_id: int, - pages_used: int, - pages_limit: int, - pages_to_add: int, - ) -> Notification: - """ - Create notification when a document exceeds the user's page limit. - - Args: - session: Database session - user_id: User ID - document_name: Name of the document that triggered the limit - document_type: Type of document (FILE, YOUTUBE_VIDEO, etc.) - search_space_id: Search space ID - pages_used: Current number of pages used - pages_limit: User's page limit - pages_to_add: Number of pages the document would add - - Returns: - Notification: The created notification - """ - operation_id = self._generate_operation_id(document_name, search_space_id) - - # Truncate document name for title if too long - display_name = ( - document_name[:40] + "..." if len(document_name) > 40 else document_name - ) - title = f"Page limit exceeded: {display_name}" - message = f"This document has ~{pages_to_add} page(s) but you've used {pages_used}/{pages_limit} pages. Upgrade to process more documents." 
- - metadata = { - "operation_id": operation_id, - "document_name": document_name, - "document_type": document_type, - "pages_used": pages_used, - "pages_limit": pages_limit, - "pages_to_add": pages_to_add, - "status": "failed", - "error_type": "page_limit_exceeded", - # Navigation target for frontend - "action_url": f"/dashboard/{search_space_id}/more-pages", - "action_label": "Upgrade Plan", - } - - notification = Notification( - user_id=user_id, - search_space_id=search_space_id, - type=self.notification_type, - title=title, - message=message, - notification_metadata=metadata, - ) - session.add(notification) - await session.commit() - await session.refresh(notification) - logger.info( - f"Created page_limit_exceeded notification {notification.id} for user {user_id}" - ) - return notification - - -class NotificationService: - """Service for creating and managing notifications that sync via Zero.""" - - # Handler instances - connector_indexing = ConnectorIndexingNotificationHandler() - document_processing = DocumentProcessingNotificationHandler() - mention = MentionNotificationHandler() - comment_reply = CommentReplyNotificationHandler() - page_limit = PageLimitNotificationHandler() - - @staticmethod - async def create_notification( - session: AsyncSession, - user_id: UUID, - notification_type: str, - title: str, - message: str, - search_space_id: int | None = None, - notification_metadata: dict[str, Any] | None = None, - ) -> Notification: - """ - Create a notification - Zero will automatically sync it to frontend. 
- - Args: - session: Database session - user_id: User to notify - notification_type: Type of notification (e.g., 'document_processing', 'connector_indexing') - title: Notification title - message: Notification message - search_space_id: Optional search space ID - notification_metadata: Optional metadata dictionary - - Returns: - Notification: The created notification - """ - notification = Notification( - user_id=user_id, - search_space_id=search_space_id, - type=notification_type, - title=title, - message=message, - notification_metadata=notification_metadata or {}, - ) - session.add(notification) - await session.commit() - await session.refresh(notification) - logger.info(f"Created notification {notification.id} for user {user_id}") - return notification diff --git a/surfsense_backend/app/services/provider_capabilities.py b/surfsense_backend/app/services/provider_capabilities.py index e9a1c33e1..f094c9954 100644 --- a/surfsense_backend/app/services/provider_capabilities.py +++ b/surfsense_backend/app/services/provider_capabilities.py @@ -53,10 +53,10 @@ logger = logging.getLogger(__name__) # # Owned here because ``app.services.provider_capabilities`` is the # only edge that's safe to call from ``app.config``'s YAML loader at -# class-body init time. ``app.agents.new_chat.llm_config`` re-exports +# class-body init time. ``app.agents.chat.runtime.llm_config`` re-exports # this constant under the historical ``PROVIDER_MAP`` name; placing the # map there directly would re-introduce the -# ``app.config -> ... -> app.agents.new_chat.tools.generate_image -> +# ``app.config -> ... -> deliverables/tools/generate_image -> # app.config`` cycle that prompted the move. 
_PROVIDER_PREFIX_MAP: dict[str, str] = { "OPENAI": "openai", diff --git a/surfsense_backend/app/services/revert_service.py b/surfsense_backend/app/services/revert_service.py index 60f6503aa..6db5e2604 100644 --- a/surfsense_backend/app/services/revert_service.py +++ b/surfsense_backend/app/services/revert_service.py @@ -38,7 +38,7 @@ from typing import Any, Literal from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.path_resolver import ( +from app.agents.chat.runtime.path_resolver import ( DOCUMENTS_ROOT, safe_filename, safe_folder_segment, diff --git a/surfsense_backend/app/services/user_tool_allowlist.py b/surfsense_backend/app/services/user_tool_allowlist.py index fdfa51560..9b87fbdea 100644 --- a/surfsense_backend/app/services/user_tool_allowlist.py +++ b/surfsense_backend/app/services/user_tool_allowlist.py @@ -16,10 +16,10 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified -from app.agents.multi_agent_chat.constants import ( +from app.agents.chat.multi_agent_chat.constants import ( CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS, ) -from app.agents.new_chat.permissions import Rule, Ruleset +from app.agents.chat.multi_agent_chat.shared.permissions import Rule, Ruleset from app.db import SearchSourceConnector, async_session_maker logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py index 5b7b62f35..211d9e5b3 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py @@ -10,7 +10,7 @@ from uuid import UUID from app.celery_app import celery_app from app.config import config from app.observability import metrics as ot_metrics -from app.services.notification_service import NotificationService +from app.notifications.service import 
NotificationService from app.services.task_logging_service import TaskLoggingService from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task from app.tasks.connector_indexers.local_folder_indexer import ( diff --git a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py index e41251407..e88fb58b9 100644 --- a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py @@ -6,7 +6,8 @@ from datetime import UTC, datetime from sqlalchemy.future import select from app.celery_app import celery_app -from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.notifications.persistence import Notification from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task from app.utils.indexing_locks import is_connector_indexing_locked diff --git a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py index d51c85dee..5bf857d9b 100644 --- a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py @@ -33,7 +33,8 @@ from sqlalchemy.future import select from app.celery_app import celery_app from app.config import config -from app.db import Document, DocumentStatus, Notification +from app.db import Document, DocumentStatus +from app.notifications.persistence import Notification from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py deleted file mode 100644 index e150cf494..000000000 --- 
a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ /dev/null @@ -1,3064 +0,0 @@ -""" -Streaming task for the new SurfSense deep agent chat. - -This module streams responses from the deep agent using the Vercel AI SDK -Data Stream Protocol (SSE format). - -Supports loading LLM configurations from: -- YAML files (negative IDs for global configs) -- NewLLMConfig database table (positive IDs for user-created configs with prompt settings) -""" - -import asyncio -import contextlib -import gc -import json -import logging -import sys -import time -from collections.abc import AsyncGenerator -from dataclasses import dataclass, field -from functools import partial -from typing import Any, Literal -from uuid import UUID - -import anyio -from langchain_core.messages import HumanMessage -from sqlalchemy.future import select - -from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent -from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent -from app.agents.new_chat.checkpointer import get_checkpointer -from app.agents.new_chat.context import SurfSenseContextSchema -from app.agents.new_chat.errors import BusyError -from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection -from app.agents.new_chat.llm_config import ( - AgentConfig, - create_chat_litellm_from_agent_config, - create_chat_litellm_from_config, - load_agent_config, - load_global_llm_config_by_id, -) -from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text -from app.agents.new_chat.middleware.busy_mutex import ( - end_turn, - get_cancel_state, - is_cancel_requested, -) -from app.agents.new_chat.middleware.kb_persistence import ( - commit_staged_filesystem_state, -) -from app.db import ( - ChatVisibility, - NewChatMessage, - NewChatThread, - Report, - SearchSourceConnectorType, - async_session_maker, - shielded_async_session, -) -from app.observability import metrics as ot_metrics, otel as ot -from app.prompts import 
TITLE_GENERATION_PROMPT -from app.services.auto_model_pin_service import ( - mark_runtime_cooldown, - resolve_or_get_pinned_llm_config_id, -) -from app.services.chat_session_state_service import ( - clear_ai_responding, - set_ai_responding, -) -from app.services.connector_service import ConnectorService -from app.services.new_streaming_service import VercelStreamingService -from app.tasks.chat.streaming.graph_stream.event_stream import stream_output -from app.tasks.chat.streaming.helpers.interrupt_inspector import ( - all_interrupt_values, -) -from app.utils.content_utils import bootstrap_history_from_db -from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap -from app.utils.user_message_multimodal import build_human_message_content - -_background_tasks: set[asyncio.Task] = set() -_perf_log = get_perf_logger() -logger = logging.getLogger(__name__) - -TURN_CANCELLING_INITIAL_DELAY_MS = 200 -TURN_CANCELLING_BACKOFF_FACTOR = 2 -TURN_CANCELLING_MAX_DELAY_MS = 1500 - - -def _resume_step_prefix(turn_id: str) -> str: - """Build the per-turn ``step_prefix`` for a resume invocation. - - Each ``_stream_agent_events`` call constructs a fresh - :class:`AgentEventRelayState` with ``thinking_step_counter=0``, so two - consecutive resume turns would otherwise both emit ``thinking-resume-1``, - ``-2`` etc. The frontend rehydrates ``currentThinkingSteps`` from the - immediate prior assistant message at the start of every resume — if the - new stream's IDs collide with the seeded ones, React renders sibling - Timeline rows with the same key. Salting with ``turn_id`` guarantees - disjoint IDs across resumes within one thread. 
- """ - return f"thinking-resume-{turn_id}" - - -def _compute_turn_cancelling_retry_delay(attempt: int) -> int: - if attempt < 1: - attempt = 1 - delay = TURN_CANCELLING_INITIAL_DELAY_MS * ( - TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1) - ) - return min(delay, TURN_CANCELLING_MAX_DELAY_MS) - - -def _extract_chunk_parts(chunk: Any) -> dict[str, Any]: - """Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts. - - Returns a dict with three keys: - - * ``text`` — concatenated string content (empty string if the chunk - contributes none). - * ``reasoning`` — concatenated reasoning content (empty string if the - chunk contributes none). - * ``tool_call_chunks`` — flat list of LangChain ``tool_call_chunk`` - dicts surfaced from either the typed-block list or the - ``tool_call_chunks`` attribute. - - Background - ---------- - ``AIMessageChunk.content`` can be: - - * a ``str`` (most providers), or - * a ``list`` of typed blocks ``{type: 'text' | 'reasoning' | - 'tool_call_chunk' | 'tool_use' | ..., text/content/...}`` for - Anthropic, Bedrock, and several reasoning configurations. - - Reasoning may also live under - ``chunk.additional_kwargs['reasoning_content']`` (some providers - surface it that way instead of as a typed block). Tool-call chunks - may live under ``chunk.tool_call_chunks`` even when ``content`` is a - plain string. - - Earlier versions only handled the ``isinstance(content, str)`` branch - and silently dropped reasoning blocks + tool-call chunks emitted by - LangChain ``AIMessageChunk``s. 
- """ - out: dict[str, Any] = {"text": "", "reasoning": "", "tool_call_chunks": []} - if chunk is None: - return out - - content = getattr(chunk, "content", None) - if isinstance(content, str): - if content: - out["text"] = content - elif isinstance(content, list): - text_parts: list[str] = [] - reasoning_parts: list[str] = [] - for block in content: - if not isinstance(block, dict): - continue - block_type = block.get("type") - if block_type == "text": - value = block.get("text") or block.get("content") or "" - if isinstance(value, str) and value: - text_parts.append(value) - elif block_type == "reasoning": - value = ( - block.get("reasoning") - or block.get("text") - or block.get("content") - or "" - ) - if isinstance(value, str) and value: - reasoning_parts.append(value) - elif block_type in ("tool_call_chunk", "tool_use"): - out["tool_call_chunks"].append(block) - if text_parts: - out["text"] = "".join(text_parts) - if reasoning_parts: - out["reasoning"] = "".join(reasoning_parts) - - additional = getattr(chunk, "additional_kwargs", None) or {} - if isinstance(additional, dict): - extra_reasoning = additional.get("reasoning_content") - if isinstance(extra_reasoning, str) and extra_reasoning: - existing = out["reasoning"] - out["reasoning"] = ( - (existing + extra_reasoning) if existing else extra_reasoning - ) - - extra_tool_chunks = getattr(chunk, "tool_call_chunks", None) - if isinstance(extra_tool_chunks, list): - for tcc in extra_tool_chunks: - if isinstance(tcc, dict): - out["tool_call_chunks"].append(tcc) - - return out - - -def extract_todos_from_deepagents(command_output) -> dict: - """ - Extract todos from deepagents' TodoListMiddleware Command output. - - deepagents returns a Command object with: - - Command.update['todos'] = [{'content': '...', 'status': '...'}] - - Returns the todos directly (no transformation needed - UI matches deepagents format). 
- """ - todos_data = [] - if hasattr(command_output, "update"): - # It's a Command object from deepagents - update = command_output.update - todos_data = update.get("todos", []) - elif isinstance(command_output, dict): - # Already a dict - check if it has todos directly or in update - if "todos" in command_output: - todos_data = command_output.get("todos", []) - elif "update" in command_output and isinstance(command_output["update"], dict): - todos_data = command_output["update"].get("todos", []) - - return {"todos": todos_data} - - -@dataclass -class StreamResult: - accumulated_text: str = "" - is_interrupted: bool = False - sandbox_files: list[str] = field(default_factory=list) - request_id: str | None = None - turn_id: str = "" - filesystem_mode: str = "cloud" - client_platform: str = "web" - intent_detected: str = "chat_only" - intent_confidence: float = 0.0 - write_attempted: bool = False - write_succeeded: bool = False - verification_succeeded: bool = False - commit_gate_passed: bool = True - commit_gate_reason: str = "" - # Pre-allocated assistant ``new_chat_messages.id`` for this turn, - # captured by ``persist_assistant_shell`` right after the user row is - # persisted. ``None`` for the legacy / anonymous code paths that don't - # opt in to server-side ``ContentPart[]`` projection. - assistant_message_id: int | None = None - # In-memory mirror of the FE's assistant-ui ``ContentPartsState``, - # populated by the lifecycle methods called from ``_stream_agent_events`` - # at each ``streaming_service.format_*`` yield site. Snapshot in the - # streaming ``finally`` to produce the rich JSONB persisted by - # ``finalize_assistant_turn``. ``repr=False`` keeps the - # log-on-error path (``StreamResult`` is logged in some error - # branches) from dumping a potentially-large parts list. 
- content_builder: Any | None = field(default=None, repr=False) - - -def _safe_float(value: Any, default: float = 0.0) -> float: - try: - return float(value) - except (TypeError, ValueError): - return default - - -def _tool_output_to_text(tool_output: Any) -> str: - if isinstance(tool_output, dict): - if isinstance(tool_output.get("result"), str): - return tool_output["result"] - if isinstance(tool_output.get("error"), str): - return tool_output["error"] - return json.dumps(tool_output, ensure_ascii=False) - return str(tool_output) - - -def _tool_output_has_error(tool_output: Any) -> bool: - if isinstance(tool_output, dict): - if tool_output.get("error"): - return True - result = tool_output.get("result") - return bool( - isinstance(result, str) and result.strip().lower().startswith("error:") - ) - if isinstance(tool_output, str): - return tool_output.strip().lower().startswith("error:") - return False - - -def _extract_resolved_file_path( - *, tool_name: str, tool_output: Any, tool_input: Any | None = None -) -> str | None: - if isinstance(tool_output, dict): - path_value = tool_output.get("path") - if isinstance(path_value, str) and path_value.strip(): - return path_value.strip() - if tool_name in ("write_file", "edit_file") and isinstance(tool_input, dict): - file_path = tool_input.get("file_path") - if isinstance(file_path, str) and file_path.strip(): - return file_path.strip() - return None - - -def _contract_enforcement_active(result: StreamResult) -> bool: - # Keep policy deterministic with no env-driven progression modes: - # enforce the file-operation contract only in desktop local-folder mode. 
- return result.filesystem_mode == "desktop_local_folder" - - -def _evaluate_file_contract_outcome(result: StreamResult) -> tuple[bool, str]: - if result.intent_detected != "file_write": - return True, "" - if not result.write_attempted: - return False, "no_write_attempt" - if not result.write_succeeded: - return False, "write_failed" - if not result.verification_succeeded: - return False, "verification_failed" - return True, "" - - -def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None: - payload: dict[str, Any] = { - "stage": stage, - "request_id": result.request_id or "unknown", - "turn_id": result.turn_id or "unknown", - "chat_id": result.turn_id.split(":", 1)[0] - if ":" in result.turn_id - else "unknown", - "filesystem_mode": result.filesystem_mode, - "client_platform": result.client_platform, - "intent_detected": result.intent_detected, - "intent_confidence": result.intent_confidence, - "write_attempted": result.write_attempted, - "write_succeeded": result.write_succeeded, - "verification_succeeded": result.verification_succeeded, - "commit_gate_passed": result.commit_gate_passed, - "commit_gate_reason": result.commit_gate_reason or None, - } - payload.update(extra) - _perf_log.info( - "[file_operation_contract] %s", json.dumps(payload, ensure_ascii=False) - ) - - -def _log_chat_stream_error( - *, - flow: Literal["new", "resume", "regenerate"], - error_kind: str, - error_code: str | None, - severity: Literal["info", "warn", "error"], - is_expected: bool, - request_id: str | None, - thread_id: int | None, - search_space_id: int | None, - user_id: str | None, - message: str, - extra: dict[str, Any] | None = None, -) -> None: - payload: dict[str, Any] = { - "event": "chat_stream_error", - "flow": flow, - "error_kind": error_kind, - "error_code": error_code, - "severity": severity, - "is_expected": is_expected, - "request_id": request_id or "unknown", - "thread_id": thread_id, - "search_space_id": search_space_id, - "user_id": user_id, - 
"message": message, - } - if extra: - payload.update(extra) - - logger = logging.getLogger(__name__) - rendered = json.dumps(payload, ensure_ascii=False) - if severity == "error": - logger.error("[chat_stream_error] %s", rendered) - elif severity == "warn": - logger.warning("[chat_stream_error] %s", rendered) - else: - logger.info("[chat_stream_error] %s", rendered) - - -def _parse_error_payload(message: str) -> dict[str, Any] | None: - candidates = [message] - first_brace_idx = message.find("{") - if first_brace_idx >= 0: - candidates.append(message[first_brace_idx:]) - - for candidate in candidates: - try: - parsed = json.loads(candidate) - if isinstance(parsed, dict): - return parsed - except Exception: - continue - return None - - -def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None: - if not isinstance(parsed, dict): - return None - candidates: list[Any] = [parsed.get("code")] - nested = parsed.get("error") - if isinstance(nested, dict): - candidates.append(nested.get("code")) - for value in candidates: - try: - if value is None: - continue - return int(value) - except Exception: - continue - return None - - -def _is_provider_rate_limited(exc: BaseException) -> bool: - """Best-effort detection for provider-side runtime throttling. 
- - Covers LiteLLM/OpenRouter shapes like: - - class name contains ``RateLimit`` - - nested payload ``{"error": {"code": 429}}`` - - nested payload ``{"error": {"type": "rate_limit_error"}}`` - """ - raw = str(exc) - lowered = raw.lower() - if "ratelimit" in type(exc).__name__.lower(): - return True - parsed = _parse_error_payload(raw) - provider_code = _extract_provider_error_code(parsed) - if provider_code == 429: - return True - - provider_error_type = "" - if parsed: - top_type = parsed.get("type") - if isinstance(top_type, str): - provider_error_type = top_type.lower() - nested = parsed.get("error") - if isinstance(nested, dict): - nested_type = nested.get("type") - if isinstance(nested_type, str): - provider_error_type = nested_type.lower() - if provider_error_type == "rate_limit_error": - return True - - return ( - "rate limited" in lowered - or "rate-limited" in lowered - or "temporarily rate-limited upstream" in lowered - ) - - -async def _build_main_agent_for_thread( - agent_factory: Any, - *, - llm: Any, - search_space_id: int, - db_session: Any, - connector_service: ConnectorService, - checkpointer: Any, - user_id: str | None, - thread_id: int | None, - agent_config: AgentConfig | None, - firecrawl_api_key: str | None, - thread_visibility: ChatVisibility | None, - filesystem_selection: FilesystemSelection | None, - disabled_tools: list[str] | None = None, - mentioned_document_ids: list[int] | None = None, -) -> Any: - """Single (re)build path so the agent factory cannot drift across the - initial build and mid-stream 429 recovery for one ``thread_id``: a - graph swap mid-turn would corrupt checkpointer state.""" - return await agent_factory( - llm=llm, - search_space_id=search_space_id, - db_session=db_session, - connector_service=connector_service, - checkpointer=checkpointer, - user_id=user_id, - thread_id=thread_id, - agent_config=agent_config, - firecrawl_api_key=firecrawl_api_key, - thread_visibility=thread_visibility, - 
filesystem_selection=filesystem_selection, - disabled_tools=disabled_tools, - mentioned_document_ids=mentioned_document_ids, - ) - - -def _classify_stream_exception( - exc: Exception, - *, - flow_label: str, -) -> tuple[ - str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None -]: - raw = str(exc) - if isinstance(exc, BusyError) or "Thread is busy with another request" in raw: - busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None - if busy_thread_id and is_cancel_requested(busy_thread_id): - cancel_state = get_cancel_state(busy_thread_id) - attempt = cancel_state[0] if cancel_state else 1 - retry_after_ms = _compute_turn_cancelling_retry_delay(attempt) - retry_after_at = int(time.time() * 1000) + retry_after_ms - return ( - "thread_busy", - "TURN_CANCELLING", - "info", - True, - "A previous response is still stopping. Please try again in a moment.", - { - "retry_after_ms": retry_after_ms, - "retry_after_at": retry_after_at, - }, - ) - return ( - "thread_busy", - "THREAD_BUSY", - "warn", - True, - "Another response is still finishing for this thread. Please try again in a moment.", - None, - ) - - if _is_provider_rate_limited(exc): - return ( - "rate_limited", - "RATE_LIMITED", - "warn", - True, - "This model is temporarily rate-limited. 
Please try again in a few seconds or switch models.", - None, - ) - - return ( - "server_error", - "SERVER_ERROR", - "error", - False, - f"Error during {flow_label}: {raw}", - None, - ) - - -def _emit_stream_terminal_error( - *, - streaming_service: VercelStreamingService, - flow: str, - request_id: str | None, - thread_id: int, - search_space_id: int, - user_id: str | None, - message: str, - error_kind: str = "server_error", - error_code: str = "SERVER_ERROR", - severity: Literal["info", "warn", "error"] = "error", - is_expected: bool = False, - extra: dict[str, Any] | None = None, -) -> str: - _log_chat_stream_error( - flow=flow, - error_kind=error_kind, - error_code=error_code, - severity=severity, - is_expected=is_expected, - request_id=request_id, - thread_id=thread_id, - search_space_id=search_space_id, - user_id=user_id, - message=message, - extra=extra, - ) - return streaming_service.format_error(message, error_code=error_code, extra=extra) - - -def _legacy_match_lc_id( - pending_tool_call_chunks: list[dict[str, Any]], - tool_name: str, - run_id: str, - lc_tool_call_id_by_run: dict[str, str], -) -> str | None: - """Best-effort match a buffered ``tool_call_chunk`` to a tool name. - - Pure extract of the in-line match used at ``on_tool_start`` when the - chunk path didn't register an index for this call. Pops the next - id-bearing chunk whose ``name`` - matches ``tool_name`` (or any id-bearing chunk as a fallback) and - returns its id. Mutates ``pending_tool_call_chunks`` and - ``lc_tool_call_id_by_run`` in place. 
- """ - matched_idx: int | None = None - for idx, tcc in enumerate(pending_tool_call_chunks): - if tcc.get("name") == tool_name and tcc.get("id"): - matched_idx = idx - break - if matched_idx is None: - for idx, tcc in enumerate(pending_tool_call_chunks): - if tcc.get("id"): - matched_idx = idx - break - if matched_idx is None: - return None - matched = pending_tool_call_chunks.pop(matched_idx) - candidate = matched.get("id") - if isinstance(candidate, str) and candidate: - if run_id: - lc_tool_call_id_by_run[run_id] = candidate - return candidate - return None - - -async def _stream_agent_events( - agent: Any, - config: dict[str, Any], - input_data: Any, - streaming_service: VercelStreamingService, - result: StreamResult, - step_prefix: str = "thinking", - initial_step_id: str | None = None, - initial_step_title: str = "", - initial_step_items: list[str] | None = None, - *, - fallback_commit_search_space_id: int | None = None, - fallback_commit_created_by_id: str | None = None, - fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, - fallback_commit_thread_id: int | None = None, - runtime_context: Any = None, - content_builder: Any | None = None, -) -> AsyncGenerator[str, None]: - """Shared async generator that streams and formats astream_events from the agent. - - Yields SSE-formatted strings. After exhausting, inspect the ``result`` - object for accumulated_text and interrupt state. - - Args: - agent: The compiled LangGraph agent. - config: LangGraph config dict (must include configurable.thread_id). - input_data: The input to pass to agent.astream_events (dict or Command). - streaming_service: VercelStreamingService instance for formatting events. - result: Mutable StreamResult populated with accumulated_text / interrupt info. - step_prefix: Prefix for thinking step IDs (e.g. "thinking" or "thinking-resume"). - initial_step_id: If set, the helper inherits an already-active thinking step. 
- initial_step_title: Title of the inherited thinking step. - initial_step_items: Items of the inherited thinking step. - content_builder: Optional ``AssistantContentBuilder``. When set, every - ``streaming_service.format_*`` yield site also drives the matching - builder lifecycle method (``on_text_*``, ``on_reasoning_*``, - ``on_tool_*``, ``on_thinking_step``, ``on_step_separator``) so the - in-memory ``ContentPart[]`` projection stays in lockstep with what - the FE renders live. Pure in-memory accumulation — no DB I/O — - consumed by the streaming ``finally`` to produce the rich JSONB - persisted via ``finalize_assistant_turn``. ``None`` (the default) - is used by the anonymous / legacy code paths and is a no-op. - - Yields: - SSE-formatted strings for each event. - """ - async for sse in stream_output( - agent=agent, - config=config, - input_data=input_data, - streaming_service=streaming_service, - result=result, - step_prefix=step_prefix, - initial_step_id=initial_step_id, - initial_step_title=initial_step_title, - initial_step_items=initial_step_items, - content_builder=content_builder, - runtime_context=runtime_context, - ): - yield sse - - accumulated_text = result.accumulated_text - - state = await agent.aget_state(config) - state_values = getattr(state, "values", {}) or {} - - # Safety net: if astream_events was cancelled before - # KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work - # (dirty_paths / staged_dirs / pending_moves / pending_deletes / - # pending_dir_deletes) will still be in the checkpointed state. Run - # the SAME shared commit helper here so the turn's writes don't get - # lost on client disconnect, then push the delta back into the graph - # using `as_node=...` so reducers fire as if the after_agent hook - # produced it. 
- if ( - fallback_commit_filesystem_mode == FilesystemMode.CLOUD - and fallback_commit_search_space_id is not None - and ( - (state_values.get("dirty_paths") or []) - or (state_values.get("staged_dirs") or []) - or (state_values.get("pending_moves") or []) - or (state_values.get("pending_deletes") or []) - or (state_values.get("pending_dir_deletes") or []) - ) - ): - try: - delta = await commit_staged_filesystem_state( - state_values, - search_space_id=fallback_commit_search_space_id, - created_by_id=fallback_commit_created_by_id, - filesystem_mode=fallback_commit_filesystem_mode, - thread_id=fallback_commit_thread_id, - dispatch_events=False, - ) - if delta: - await agent.aupdate_state( - config, - delta, - as_node="KnowledgeBasePersistenceMiddleware.after_agent", - ) - except Exception as exc: - _perf_log.warning("[stream_new_chat] safety-net commit failed: %s", exc) - - contract_state = state_values.get("file_operation_contract") or {} - contract_turn_id = contract_state.get("turn_id") - current_turn_id = config.get("configurable", {}).get("turn_id", "") - intent_value = contract_state.get("intent") - if ( - isinstance(intent_value, str) - and intent_value in ("chat_only", "file_write", "file_read") - and contract_turn_id == current_turn_id - ): - result.intent_detected = intent_value - if ( - isinstance(intent_value, str) - and intent_value - in ( - "chat_only", - "file_write", - "file_read", - ) - and contract_turn_id != current_turn_id - ): - # Ignore stale intent contracts from previous turns/checkpoints. 
- result.intent_detected = "chat_only" - result.intent_confidence = ( - _safe_float(contract_state.get("confidence"), default=0.0) - if contract_turn_id == current_turn_id - else 0.0 - ) - - if result.intent_detected == "file_write": - result.commit_gate_passed, result.commit_gate_reason = ( - _evaluate_file_contract_outcome(result) - ) - if not result.commit_gate_passed and _contract_enforcement_active(result): - gate_notice = ( - "I could not complete the requested file write because no successful " - "write_file/edit_file operation was confirmed." - ) - gate_text_id = streaming_service.generate_text_id() - yield streaming_service.format_text_start(gate_text_id) - if content_builder is not None: - content_builder.on_text_start(gate_text_id) - yield streaming_service.format_text_delta(gate_text_id, gate_notice) - if content_builder is not None: - content_builder.on_text_delta(gate_text_id, gate_notice) - yield streaming_service.format_text_end(gate_text_id) - if content_builder is not None: - content_builder.on_text_end(gate_text_id) - yield streaming_service.format_terminal_info(gate_notice, "error") - accumulated_text = gate_notice - else: - result.commit_gate_passed = True - result.commit_gate_reason = "" - - result.accumulated_text = accumulated_text - _log_file_contract("turn_outcome", result) - - pending_values = all_interrupt_values(state) - if pending_values: - result.is_interrupted = True - # One frame per paused subagent so each parallel HITL renders its own - # approval card on the wire. Order matches ``state.interrupts``, which - # the resume slicer in ``checkpointed_subagent_middleware.resume_routing`` - # consumes in the same order — keeping emit and resume in lock-step. 
- for interrupt_value in pending_values: - yield streaming_service.format_interrupt_request(interrupt_value) - - -async def stream_new_chat( - user_query: str, - search_space_id: int, - chat_id: int, - user_id: str | None = None, - llm_config_id: int = -1, - mentioned_document_ids: list[int] | None = None, - mentioned_folder_ids: list[int] | None = None, - mentioned_connector_ids: list[int] | None = None, - mentioned_connectors: list[dict[str, Any]] | None = None, - mentioned_documents: list[dict[str, Any]] | None = None, - checkpoint_id: str | None = None, - needs_history_bootstrap: bool = False, - thread_visibility: ChatVisibility | None = None, - current_user_display_name: str | None = None, - disabled_tools: list[str] | None = None, - filesystem_selection: FilesystemSelection | None = None, - request_id: str | None = None, - user_image_data_urls: list[str] | None = None, - flow: Literal["new", "regenerate"] = "new", -) -> AsyncGenerator[str, None]: - """ - Stream chat responses from the new SurfSense deep agent. - - This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming. - The chat_id is used as LangGraph's thread_id for memory/checkpointing. - - The function creates and manages its own database session to guarantee proper - cleanup even when Starlette's middleware cancels the task on client disconnect. 
- - Args: - user_query: The user's query - search_space_id: The search space ID - chat_id: The chat ID (used as LangGraph thread_id for memory) - user_id: The current user's UUID string (for memory tools and session state) - llm_config_id: The LLM configuration ID (default: -1 for first global config) - needs_history_bootstrap: If True, load message history from DB (for cloned chats) - mentioned_document_ids: Optional list of document IDs mentioned with @ in the chat - mentioned_folder_ids: Optional list of knowledge-base folder IDs mentioned with @ (cloud mode) - checkpoint_id: Optional checkpoint ID to rewind/fork from (for edit/reload operations) - - Yields: - str: SSE formatted response strings - """ - streaming_service = VercelStreamingService() - stream_result = StreamResult() - _t_total = time.perf_counter() - fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud" - fs_platform = ( - filesystem_selection.client_platform.value if filesystem_selection else "web" - ) - stream_result.request_id = request_id - stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}" - stream_result.filesystem_mode = fs_mode - stream_result.client_platform = fs_platform - chat_agent_mode = "unknown" - chat_outcome = "success" - chat_error_category: str | None = None - chat_span_cm = ot.chat_request_span( - chat_id=chat_id, - search_space_id=search_space_id, - flow=flow, - request_id=request_id, - turn_id=stream_result.turn_id, - filesystem_mode=fs_mode, - client_platform=fs_platform, - agent_mode=chat_agent_mode, - ) - chat_span = chat_span_cm.__enter__() - _log_file_contract("turn_start", stream_result) - _perf_log.info( - "[stream_new_chat] filesystem_mode=%s client_platform=%s", - fs_mode, - fs_platform, - ) - log_system_snapshot("stream_new_chat_START") - - from app.services.token_tracking_service import start_turn - - accumulator = start_turn() - - # Premium credit (USD micro-units) tracking state. 
Stores the - # amount reserved up front so we can release it on cancellation - # and finalize-debit the actual provider cost reported by LiteLLM. - _premium_reserved_micros = 0 - _premium_request_id: str | None = None - - # ``BusyError`` fires before the lock is acquired; the ``finally`` must - # not release the in-flight caller's lock. - _busy_error_raised = False - - _emit_stream_error = partial( - _emit_stream_terminal_error, - streaming_service=streaming_service, - flow=flow, - request_id=request_id, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - ) - - session = async_session_maker() - try: - # Mark AI as responding to this user for live collaboration - if user_id: - await set_ai_responding(session, chat_id, UUID(user_id)) - # Load LLM config - supports both YAML (negative IDs) and database (positive IDs) - agent_config: AgentConfig | None = None - requested_llm_config_id = llm_config_id - - async def _load_llm_bundle( - config_id: int, - ) -> tuple[Any, AgentConfig | None, str | None]: - if config_id >= 0: - loaded_agent_config = await load_agent_config( - session=session, - config_id=config_id, - search_space_id=search_space_id, - ) - if not loaded_agent_config: - return ( - None, - None, - f"Failed to load NewLLMConfig with id {config_id}", - ) - return ( - create_chat_litellm_from_agent_config(loaded_agent_config), - loaded_agent_config, - None, - ) - - loaded_llm_config = load_global_llm_config_by_id(config_id) - if not loaded_llm_config: - return None, None, f"Failed to load LLM config with id {config_id}" - return ( - create_chat_litellm_from_config(loaded_llm_config), - AgentConfig.from_yaml_config(loaded_llm_config), - None, - ) - - _t0 = time.perf_counter() - # Image-bearing turns force the Auto-pin resolver to filter the - # candidate pool to vision-capable cfgs (and force-repin a - # text-only existing pin). For explicit selections this flag is - # a no-op — the resolver returns the user's chosen id unchanged. 
- _requires_image_input = bool(user_image_data_urls) - try: - llm_config_id = ( - await resolve_or_get_pinned_llm_config_id( - session, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - selected_llm_config_id=llm_config_id, - requires_image_input=_requires_image_input, - ) - ).resolved_llm_config_id - ot.add_event( - "model.pin.resolved", - { - "pin.requested_id": requested_llm_config_id, - "pin.resolved_id": llm_config_id, - "pin.requires_image_input": _requires_image_input, - }, - ) - except ValueError as pin_error: - # Auto-pin's "no vision-capable cfg" path raises a ValueError - # whose message we map to the friendly image-input SSE error - # so the user sees the same message regardless of whether - # the gate fired in Auto-mode or in the agent_config check - # below. - error_code = ( - "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT" - if _requires_image_input and "vision-capable" in str(pin_error) - else "SERVER_ERROR" - ) - error_kind = ( - "user_error" - if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT" - else "server_error" - ) - if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT": - ot.add_event( - "quota.denied", - { - "quota.code": error_code, - }, - ) - yield _emit_stream_error( - message=str(pin_error), - error_kind=error_kind, - error_code=error_code, - ) - yield streaming_service.format_done() - return - - llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) - if llm_load_error: - yield _emit_stream_error( - message=llm_load_error, - error_kind="server_error", - error_code="SERVER_ERROR", - ) - yield streaming_service.format_done() - return - _perf_log.info( - "[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)", - time.perf_counter() - _t0, - llm_config_id, - ) - - # Capability safety net: a turn carrying user-uploaded images - # cannot be routed to a chat config that LiteLLM's authoritative - # model map *explicitly* marks as text-only (``supports_vision`` - # set to False). 
The check is intentionally narrow — it only - # fires when LiteLLM is *certain* the model can't accept image - # input. Unknown / unmapped / vision-capable models pass - # through. Without this guard a known-text-only model would 404 - # at the provider with ``"No endpoints found that support image - # input"``, surfacing as an opaque ``SERVER_ERROR`` SSE chunk; - # failing here lets us return a friendly message that tells the - # user what to change. - if user_image_data_urls and agent_config is not None: - from app.services.provider_capabilities import ( - is_known_text_only_chat_model, - ) - - agent_litellm_params = agent_config.litellm_params or {} - agent_base_model = ( - agent_litellm_params.get("base_model") - if isinstance(agent_litellm_params, dict) - else None - ) - if is_known_text_only_chat_model( - provider=agent_config.provider, - model_name=agent_config.model_name, - base_model=agent_base_model, - custom_provider=agent_config.custom_provider, - ): - model_label = ( - agent_config.config_name or agent_config.model_name or "model" - ) - ot.add_event( - "quota.denied", - { - "quota.code": "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT", - }, - ) - yield _emit_stream_error( - message=( - f"The selected model ({model_label}) does not support " - "image input. Switch to a vision-capable model " - "(e.g. GPT-4o, Claude, Gemini) or remove the image " - "attachment and try again." - ), - error_kind="user_error", - error_code="MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT", - ) - yield streaming_service.format_done() - return - - # Premium quota reservation for pinned premium model only. 
- _needs_premium_quota = ( - agent_config is not None and user_id and agent_config.is_premium - ) - if _needs_premium_quota: - import uuid as _uuid - - from app.services.token_quota_service import ( - TokenQuotaService, - estimate_call_reserve_micros, - ) - - _premium_request_id = _uuid.uuid4().hex[:16] - _agent_litellm_params = agent_config.litellm_params or {} - _agent_base_model = ( - _agent_litellm_params.get("base_model") or agent_config.model_name or "" - ) - reserve_amount_micros = estimate_call_reserve_micros( - base_model=_agent_base_model, - quota_reserve_tokens=agent_config.quota_reserve_tokens, - ) - async with shielded_async_session() as quota_session: - quota_result = await TokenQuotaService.premium_reserve( - db_session=quota_session, - user_id=UUID(user_id), - request_id=_premium_request_id, - reserve_micros=reserve_amount_micros, - ) - _premium_reserved_micros = reserve_amount_micros - if not quota_result.allowed: - ot.add_event( - "quota.denied", - { - "quota.code": "PREMIUM_QUOTA_EXHAUSTED", - }, - ) - if requested_llm_config_id == 0: - try: - llm_config_id = ( - await resolve_or_get_pinned_llm_config_id( - session, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - selected_llm_config_id=0, - force_repin_free=True, - requires_image_input=_requires_image_input, - ) - ).resolved_llm_config_id - ot.add_event( - "model.repin", - { - "repin.reason": "premium_quota_exhausted", - "repin.to_config_id": llm_config_id, - }, - ) - except ValueError as pin_error: - yield _emit_stream_error( - message=str(pin_error), - error_kind="server_error", - error_code="SERVER_ERROR", - ) - yield streaming_service.format_done() - return - - llm, agent_config, llm_load_error = await _load_llm_bundle( - llm_config_id - ) - if llm_load_error: - yield _emit_stream_error( - message=llm_load_error, - error_kind="server_error", - error_code="SERVER_ERROR", - ) - yield streaming_service.format_done() - return - _premium_request_id = None - 
_premium_reserved_micros = 0 - _log_chat_stream_error( - flow=flow, - error_kind="premium_quota_exhausted", - error_code="PREMIUM_QUOTA_EXHAUSTED", - severity="info", - is_expected=True, - request_id=request_id, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - message=( - "Premium quota exhausted on pinned model; auto-fallback switched to a free model" - ), - extra={ - "fallback_config_id": llm_config_id, - "auto_fallback": True, - }, - ) - else: - yield _emit_stream_error( - message=( - "Buy more tokens to continue with this model, or switch to a free model" - ), - error_kind="premium_quota_exhausted", - error_code="PREMIUM_QUOTA_EXHAUSTED", - severity="info", - is_expected=True, - extra={ - "resolved_config_id": llm_config_id, - "auto_fallback": False, - }, - ) - yield streaming_service.format_done() - return - - if not llm: - yield _emit_stream_error( - message="Failed to create LLM instance", - error_kind="server_error", - error_code="SERVER_ERROR", - ) - yield streaming_service.format_done() - return - - # Create connector service - _t0 = time.perf_counter() - connector_service = ConnectorService(session, search_space_id=search_space_id) - - firecrawl_api_key = None - webcrawler_connector = await connector_service.get_connector_by_type( - SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id - ) - if webcrawler_connector and webcrawler_connector.config: - firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY") - _perf_log.info( - "[stream_new_chat] Connector service + firecrawl key in %.3fs", - time.perf_counter() - _t0, - ) - - # Get the PostgreSQL checkpointer for persistent conversation memory - _t0 = time.perf_counter() - checkpointer = await get_checkpointer() - _perf_log.info( - "[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0 - ) - - visibility = thread_visibility or ChatVisibility.PRIVATE - from app.config import config as _app_config - - use_multi_agent = 
bool(_app_config.MULTI_AGENT_CHAT_ENABLED) - chat_agent_mode = "multi" if use_multi_agent else "single" - with contextlib.suppress(Exception): - chat_span.set_attribute("agent.mode", chat_agent_mode) - - _t0 = time.perf_counter() - agent_factory = ( - create_multi_agent_chat_deep_agent - if use_multi_agent - else create_surfsense_deep_agent - ) - # Build the agent inline. Provider 429s surface through the - # in-stream recovery loop below (``_is_provider_rate_limited``), - # which repins the thread to an eligible alternative config and - # rebuilds the agent before the user sees any output. - agent = await _build_main_agent_for_thread( - agent_factory, - llm=llm, - search_space_id=search_space_id, - db_session=session, - connector_service=connector_service, - checkpointer=checkpointer, - user_id=user_id, - thread_id=chat_id, - agent_config=agent_config, - firecrawl_api_key=firecrawl_api_key, - thread_visibility=visibility, - filesystem_selection=filesystem_selection, - disabled_tools=disabled_tools, - mentioned_document_ids=mentioned_document_ids, - ) - _perf_log.info( - "[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0 - ) - - # Build input with message history - langchain_messages = [] - - _t0 = time.perf_counter() - # Bootstrap history for cloned chats (no LangGraph checkpoint exists yet) - if needs_history_bootstrap: - langchain_messages = await bootstrap_history_from_db( - session, chat_id, thread_visibility=visibility - ) - - thread_result = await session.execute( - select(NewChatThread).filter(NewChatThread.id == chat_id) - ) - thread = thread_result.scalars().first() - if thread: - thread.needs_history_bootstrap = False - await session.commit() - - # Mentioned KB documents are now handled by KnowledgeBaseSearchMiddleware - # which merges them into the scoped filesystem with full document - # structure. Only report context is inlined here. 
- - # Fetch the most recent report(s) in this thread so the LLM can - # easily find report_id for versioning decisions, instead of - # having to dig through conversation history. - recent_reports_result = await session.execute( - select(Report) - .filter( - Report.thread_id == chat_id, - Report.content.isnot(None), # exclude failed reports - ) - .order_by(Report.id.desc()) - .limit(3) - ) - recent_reports = list(recent_reports_result.scalars().all()) - - # Resolve @-mention chips to canonical virtual paths and rewrite - # the user-typed text so the LLM sees ``\`/documents/...\``` instead - # of bare ``@title``. The substitution lands in ``agent_user_query`` - # ONLY — the original ``user_query`` (with ``@title`` tokens) flows - # untouched into ``persist_user_turn`` below so chip rendering on - # reload still works (``UserTextPart`` → ``parseMentionSegments`` - # matches ``@title``, not ``\`/documents/...\```). It also feeds - # the human-readable surfaces — SSE "Processing X" status, auto - # thread title, memory seed — which all want what the user typed. - # See ``persistence._build_user_content``. - # - # Cloud mode only: local-folder mode keeps the legacy - # ``@title`` text path; mention support there is a follow-up - # task because the path scheme (mount-rooted) and the picker - # UI both need separate work. 
- agent_user_query = user_query - accepted_folder_ids: list[int] = [] - if fs_mode == FilesystemMode.CLOUD.value and ( - mentioned_document_ids or mentioned_folder_ids or mentioned_documents - ): - from app.schemas.new_chat import ( - MentionedDocumentInfo as _MentionedDocumentInfo, - ) - - chip_objs: list[_MentionedDocumentInfo] | None = None - if mentioned_documents: - chip_objs = [] - for raw in mentioned_documents: - if isinstance(raw, _MentionedDocumentInfo): - chip_objs.append(raw) - continue - try: - chip_objs.append(_MentionedDocumentInfo.model_validate(raw)) - except Exception: - logger.debug( - "stream_new_chat: dropping malformed mention chip %r", - raw, - ) - - resolved = await resolve_mentions( - session, - search_space_id=search_space_id, - mentioned_documents=chip_objs, - mentioned_document_ids=mentioned_document_ids, - mentioned_folder_ids=mentioned_folder_ids, - ) - agent_user_query = substitute_in_text(user_query, resolved.token_to_path) - accepted_folder_ids = resolved.mentioned_folder_ids - - # Format the user query with context (reports only). - # Uses ``agent_user_query`` so the LLM sees backtick-wrapped paths - # instead of bare ``@title`` tokens. - final_query = agent_user_query - context_parts = [] - - if mentioned_connectors: - connector_lines = [] - for connector in mentioned_connectors: - if not isinstance(connector, dict): - continue - connector_id = connector.get("id") - connector_type = connector.get("connector_type") or connector.get( - "document_type" - ) - account_name = connector.get("account_name") or connector.get("title") - if connector_id is None or connector_type is None: - continue - connector_lines.append( - f' - connector_id={connector_id}, connector_type="{connector_type}", ' - f'account_name="{account_name or ""}"' - ) - if connector_lines: - context_parts.append( - "\n" - "The user selected these exact connector accounts with @. " - "These entries are selection metadata, not retrieved connector content. 
" - "When a connector-backed tool needs an account, use the matching " - "connector_id from this list if the tool supports connector_id:\n" - + "\n".join(connector_lines) - + "\n" - ) - - # Surface report IDs prominently so the LLM doesn't have to - # retrieve them from old tool responses in conversation history. - if recent_reports: - report_lines = [] - for r in recent_reports: - report_lines.append( - f' - report_id={r.id}, title="{r.title}", ' - f'style="{r.report_style or "detailed"}"' - ) - reports_listing = "\n".join(report_lines) - context_parts.append( - "\n" - "Previously generated reports in this conversation:\n" - f"{reports_listing}\n\n" - "If the user wants to MODIFY, REVISE, UPDATE, or ADD to one of " - "these reports, set parent_report_id to the relevant report_id above.\n" - "If the user wants a completely NEW report on a different topic, " - "leave parent_report_id unset.\n" - "" - ) - - if context_parts: - context = "\n\n".join(context_parts) - final_query = f"{context}\n\n{agent_user_query}" - - if visibility == ChatVisibility.SEARCH_SPACE and current_user_display_name: - final_query = f"**[{current_user_display_name}]:** {final_query}" - - # if messages: - # # Convert frontend messages to LangChain format - # for msg in messages: - # if msg.role == "user": - # langchain_messages.append(HumanMessage(content=msg.content)) - # elif msg.role == "assistant": - # langchain_messages.append(AIMessage(content=msg.content)) - # else: - human_content = build_human_message_content( - final_query, list(user_image_data_urls or ()) - ) - langchain_messages.append(HumanMessage(content=human_content)) - - input_state = { - # Lets not pass this message atm because we are using the checkpointer to manage the conversation history - # We will use this to simulate group chat functionality in the future - "messages": langchain_messages, - "search_space_id": search_space_id, - "request_id": request_id or "unknown", - "turn_id": stream_result.turn_id, - } - - 
_perf_log.info( - "[stream_new_chat] History bootstrap + doc/report queries in %.3fs", - time.perf_counter() - _t0, - ) - - # All pre-streaming DB reads are done. Commit to release the - # transaction and its ACCESS SHARE locks so we don't block DDL - # (e.g. migrations) for the entire duration of LLM streaming. - # Tools that need DB access during streaming will start their own - # short-lived transactions (or use isolated sessions). - await session.commit() - - # Detach heavy ORM objects (documents with chunks, reports, etc.) - # from the session identity map now that we've extracted the data - # we need. This prevents them from accumulating in memory for the - # entire duration of LLM streaming (which can be several minutes). - session.expunge_all() - - _perf_log.info( - "[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)", - time.perf_counter() - _t_total, - chat_id, - ) - - # Configure LangGraph with thread_id for memory - # If checkpoint_id is provided, fork from that checkpoint (for edit/reload) - configurable = {"thread_id": str(chat_id)} - configurable["request_id"] = request_id or "unknown" - configurable["turn_id"] = stream_result.turn_id - if checkpoint_id: - configurable["checkpoint_id"] = checkpoint_id - - config = { - "configurable": configurable, - # Effectively uncapped, matching the agent-level - # ``with_config`` default in ``chat_deepagent.create_agent`` - # and the unbounded ``while(true)`` loop used by OpenCode's - # ``session/processor.ts``. Real circuit-breakers live in - # middleware: ``DoomLoopMiddleware`` (sliding-window tool - # signature check), plus ``enable_tool_call_limit`` / - # ``enable_model_call_limit`` when those flags are set. The - # original LangGraph default of 25 (and our previous 80 - # bump) hit users on legitimate multi-tool plans. 
- "recursion_limit": 10_000, - } - - # Start the message stream - yield streaming_service.format_message_start() - yield streaming_service.format_start_step() - - # Surface the per-turn correlation id at the very start of the - # stream so the frontend can stamp it onto the in-flight - # assistant message and replay it via ``appendMessage`` - # for durable storage. Tool/action-log events DO carry it later, - # but pure-text turns never produce action-log events; this - # event guarantees the frontend learns the turn id regardless. - yield streaming_service.format_data( - "turn-info", - {"chat_turn_id": stream_result.turn_id}, - ) - yield streaming_service.format_data("turn-status", {"status": "busy"}) - - # Persist the user-side row for this turn before any expensive - # work runs. Closes the "ghost-thread" abuse vector - # (authenticated client hits POST /new_chat then never calls - # /messages — empty new_chat_messages, free LLM completion). - # Idempotent against the unique index in migration 141 so the - # legacy frontend appendMessage call is a no-op on the second - # writer. Hard failure aborts the turn so we never produce a - # title or assistant row that isn't anchored to a persisted - # user message. - from app.tasks.chat.content_builder import AssistantContentBuilder - from app.tasks.chat.persistence import ( - persist_assistant_shell, - persist_user_turn, - ) - - user_message_id = await persist_user_turn( - chat_id=chat_id, - user_id=user_id, - turn_id=stream_result.turn_id, - user_query=user_query, - user_image_data_urls=user_image_data_urls, - mentioned_documents=mentioned_documents, - ) - if user_message_id is None: - yield _emit_stream_error( - message=( - "We couldn't save your message. Please try again in a moment." 
- ), - error_kind="server_error", - error_code="MESSAGE_PERSIST_FAILED", - ) - yield streaming_service.format_data("turn-status", {"status": "idle"}) - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - return - - # Emit canonical user message id BEFORE any LLM streaming so the - # FE can rename its optimistic ``msg-user-XXX`` placeholder to - # ``msg-{user_message_id}`` and unlock features gated on a real - # DB id (comments, edit-from-this-message). See B4 in - # ``sse-based_message_id_handshake`` plan. - yield streaming_service.format_data( - "user-message-id", - {"message_id": user_message_id, "turn_id": stream_result.turn_id}, - ) - - # Pre-write the assistant row for this turn so we have a stable - # ``message_id`` to anchor mid-stream metadata (token_usage, - # future agent_action_log.message_id correlation) and a - # write-once UPDATE target at finalize time. Idempotent against - # the (thread_id, turn_id, ASSISTANT) partial unique index from - # migration 141 — if the legacy frontend appendMessage races - # this, we recover the existing row's id. - assistant_message_id = await persist_assistant_shell( - chat_id=chat_id, - user_id=user_id, - turn_id=stream_result.turn_id, - ) - if assistant_message_id is None: - # Genuine DB failure — abort the turn rather than stream - # into a void. The user row is already persisted so the - # legacy "ghost-thread" gate isn't reopened. - yield _emit_stream_error( - message=( - "We couldn't initialize the assistant message. Please try again." 
- ), - error_kind="server_error", - error_code="MESSAGE_PERSIST_FAILED", - ) - yield streaming_service.format_data("turn-status", {"status": "idle"}) - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - return - - # Emit canonical assistant message id BEFORE any LLM streaming - # so the FE can rename its optimistic ``msg-assistant-XXX`` - # placeholder to ``msg-{assistant_message_id}`` and bind - # ``tokenUsageStore`` / ``pendingInterrupt`` to the real id - # immediately. See B4 in ``sse-based_message_id_handshake`` - # plan. - yield streaming_service.format_data( - "assistant-message-id", - {"message_id": assistant_message_id, "turn_id": stream_result.turn_id}, - ) - - stream_result.assistant_message_id = assistant_message_id - stream_result.content_builder = AssistantContentBuilder() - - # Initial thinking step - analyzing the request - initial_title = "Understanding your request" - action_verb = "Processing" - - processing_parts = [] - if user_query.strip(): - query_text = user_query[:80] + ("..." if len(user_query) > 80 else "") - processing_parts.append(query_text) - elif user_image_data_urls: - processing_parts.append(f"[{len(user_image_data_urls)} image(s)]") - else: - processing_parts.append("(message)") - - initial_items = [f"{action_verb}: {' '.join(processing_parts)}"] - initial_step_id = "thinking-1" - - # Drive the builder for this initial thinking step too — the - # ``_emit_thinking_step`` helper lives inside ``_stream_agent_events`` - # so it isn't in scope here, but the FE folds this step into - # the same singleton ``data-thinking-steps`` part as everything - # the agent stream emits later. Mirror that fold server-side. 
- if stream_result.content_builder is not None: - stream_result.content_builder.on_thinking_step( - initial_step_id, initial_title, "in_progress", initial_items - ) - yield streaming_service.format_thinking_step( - step_id=initial_step_id, - title=initial_title, - status="in_progress", - items=initial_items, - ) - - # These ORM objects can be large. They're only needed to build context - # strings already copied into final_query / langchain_messages — - # release them before streaming. - del recent_reports - del langchain_messages, final_query - - # Check if this is the first assistant response so we can generate - # a title in parallel with the agent stream (better UX than waiting - # until after the full response). - # Use a LIMIT 1 EXISTS-style probe rather than COUNT(*) because - # this is now a hot path executed on every turn, and COUNT scales - # with thread length (server-side persistence can grow rows - # quickly under power users). - # - # IMPORTANT: ``persist_assistant_shell`` above (line ~3112) already - # inserted THIS turn's assistant row. We must therefore exclude - # it from the probe — otherwise the gate fires on every turn - # except the very first, and title generation never runs for new - # threads. Excluding by primary key (``id != assistant_message_id``) - # is bulletproof regardless of ``turn_id`` shape (legacy NULLs, - # resume turns, etc.). - first_assistant_probe = await session.execute( - select(NewChatMessage.id) - .filter( - NewChatMessage.thread_id == chat_id, - NewChatMessage.role == "assistant", - NewChatMessage.id != assistant_message_id, - ) - .limit(1) - ) - is_first_response = first_assistant_probe.scalars().first() is None - - title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None - # Gate title generation on a persisted user message so a stream - # that fails before persistence (we abort above) can never leave - # behind a thread with a generated title and no anchoring rows. 
- if is_first_response and user_message_id is not None: - - async def _generate_title() -> tuple[str | None, dict | None]: - """Generate a short title via litellm.acompletion. - - Returns (title, usage_dict). Usage is extracted directly from - the response object because litellm fires its async callback - via fire-and-forget ``create_task``, so the - ``TokenTrackingCallback`` would run too late. We also blank - the accumulator in this child-task context so the late callback - doesn't double-count. - """ - try: - from litellm import acompletion - - from app.services.llm_router_service import LLMRouterService - from app.services.provider_api_base import resolve_api_base - from app.services.token_tracking_service import _turn_accumulator - - _turn_accumulator.set(None) - - title_seed = user_query.strip() or ( - f"[{len(user_image_data_urls or [])} image(s)]" - if user_image_data_urls - else "" - ) - prompt = TITLE_GENERATION_PROMPT.replace( - "{user_query}", title_seed[:500] or "(message)" - ) - messages = [{"role": "user", "content": prompt}] - - if getattr(llm, "model", None) == "auto": - router = LLMRouterService.get_router() - response = await router.acompletion( - model="auto", messages=messages - ) - else: - # Apply the same ``api_base`` cascade chat / vision / - # image-gen call sites use so we never inherit - # ``litellm.api_base`` (commonly set by - # ``AZURE_OPENAI_ENDPOINT``) when the chat config - # itself ships an empty ``api_base``. Without this - # the title-gen on an OpenRouter chat config would - # 404 against the inherited Azure endpoint — see - # ``provider_api_base`` docstring for the same - # bug repro on the image-gen / vision paths. 
- raw_model = getattr(llm, "model", "") or "" - provider_prefix = ( - raw_model.split("/", 1)[0] if "/" in raw_model else None - ) - provider_value = ( - agent_config.provider if agent_config is not None else None - ) - title_api_base = resolve_api_base( - provider=provider_value, - provider_prefix=provider_prefix, - config_api_base=getattr(llm, "api_base", None), - ) - response = await acompletion( - model=raw_model, - messages=messages, - api_key=getattr(llm, "api_key", None), - api_base=title_api_base, - ) - - usage_info = None - usage = getattr(response, "usage", None) - if usage: - raw_model = getattr(llm, "model", "") or "" - model_name = ( - raw_model.split("/", 1)[-1] - if "/" in raw_model - else (raw_model or response.model or "unknown") - ) - usage_info = { - "model": model_name, - "prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0, - "completion_tokens": getattr(usage, "completion_tokens", 0) - or 0, - "total_tokens": getattr(usage, "total_tokens", 0) or 0, - } - - raw_title = response.choices[0].message.content.strip() - if raw_title and len(raw_title) <= 100: - return raw_title.strip("\"'"), usage_info - return None, usage_info - except Exception: - logging.getLogger(__name__).exception( - "[TitleGen] _generate_title failed" - ) - return None, None - - title_task = asyncio.create_task(_generate_title()) - - title_emitted = False - - # Build the per-invocation runtime context (Phase 1.5). - # ``mentioned_document_ids`` is read by ``KnowledgePriorityMiddleware`` - # via ``runtime.context.mentioned_document_ids`` instead of its - # ``__init__`` closure — that way the same compiled-agent instance - # can serve multiple turns with different mention lists. 
- runtime_context = SurfSenseContextSchema( - search_space_id=search_space_id, - mentioned_document_ids=list(mentioned_document_ids or []), - mentioned_folder_ids=list( - accepted_folder_ids or mentioned_folder_ids or [] - ), - mentioned_connector_ids=list(mentioned_connector_ids or []), - mentioned_connectors=list(mentioned_connectors or []), - request_id=request_id, - turn_id=stream_result.turn_id, - ) - - _t_stream_start = time.perf_counter() - _first_event_logged = False - runtime_rate_limit_recovered = False - while True: - try: - async for sse in _stream_agent_events( - agent=agent, - config=config, - input_data=input_state, - streaming_service=streaming_service, - result=stream_result, - step_prefix="thinking", - initial_step_id=initial_step_id, - initial_step_title=initial_title, - initial_step_items=initial_items, - fallback_commit_search_space_id=search_space_id, - fallback_commit_created_by_id=user_id, - fallback_commit_filesystem_mode=( - filesystem_selection.mode - if filesystem_selection - else FilesystemMode.CLOUD - ), - fallback_commit_thread_id=chat_id, - runtime_context=runtime_context, - content_builder=stream_result.content_builder, - ): - if not _first_event_logged: - _perf_log.info( - "[stream_new_chat] First agent event in %.3fs (time since stream start), " - "%.3fs (total since request start) (chat_id=%s)", - time.perf_counter() - _t_stream_start, - time.perf_counter() - _t_total, - chat_id, - ) - _first_event_logged = True - yield sse - - # Inject title update mid-stream as soon as the background - # task finishes. 
- if ( - title_task is not None - and title_task.done() - and not title_emitted - ): - generated_title, title_usage = title_task.result() - if title_usage: - accumulator.add(**title_usage) - if generated_title: - async with shielded_async_session() as title_session: - title_thread_result = await title_session.execute( - select(NewChatThread).filter( - NewChatThread.id == chat_id - ) - ) - title_thread = title_thread_result.scalars().first() - if title_thread: - title_thread.title = generated_title - await title_session.commit() - yield streaming_service.format_thread_title_update( - chat_id, generated_title - ) - title_emitted = True - break - except Exception as stream_exc: - can_runtime_recover = ( - not runtime_rate_limit_recovered - and requested_llm_config_id == 0 - and llm_config_id < 0 - and not _first_event_logged - and _is_provider_rate_limited(stream_exc) - ) - if not can_runtime_recover: - raise - - runtime_rate_limit_recovered = True - previous_config_id = llm_config_id - # The failed attempt may still hold the per-thread busy mutex - # (middleware teardown can lag behind raised provider errors). - # Force release before we retry within the same request. - end_turn(str(chat_id)) - mark_runtime_cooldown( - previous_config_id, - reason="provider_rate_limited", - ) - - llm_config_id = ( - await resolve_or_get_pinned_llm_config_id( - session, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - selected_llm_config_id=0, - exclude_config_ids={previous_config_id}, - requires_image_input=_requires_image_input, - ) - ).resolved_llm_config_id - - llm, agent_config, llm_load_error = await _load_llm_bundle( - llm_config_id - ) - if llm_load_error: - raise stream_exc - - # Title generation uses the initial llm object. After a runtime - # repin we keep the stream focused on response recovery and skip - # title generation for this turn. 
- if title_task is not None and not title_task.done(): - title_task.cancel() - title_task = None - - _t0 = time.perf_counter() - agent = await _build_main_agent_for_thread( - agent_factory, - llm=llm, - search_space_id=search_space_id, - db_session=session, - connector_service=connector_service, - checkpointer=checkpointer, - user_id=user_id, - thread_id=chat_id, - agent_config=agent_config, - firecrawl_api_key=firecrawl_api_key, - thread_visibility=visibility, - filesystem_selection=filesystem_selection, - disabled_tools=disabled_tools, - mentioned_document_ids=mentioned_document_ids, - ) - _perf_log.info( - "[stream_new_chat] Runtime rate-limit recovery repinned " - "config_id=%s -> %s and rebuilt agent in %.3fs", - previous_config_id, - llm_config_id, - time.perf_counter() - _t0, - ) - ot.add_event( - "chat.rate_limit.recovered", - { - "recovery.reason": "provider_rate_limited", - "recovery.previous_config_id": previous_config_id, - "recovery.fallback_config_id": llm_config_id, - }, - ) - _log_chat_stream_error( - flow=flow, - error_kind="rate_limited", - error_code="RATE_LIMITED", - severity="info", - is_expected=True, - request_id=request_id, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - message=( - "Auto-pinned model hit runtime rate limit; switched to " - "another eligible model and retried." 
- ), - extra={ - "auto_runtime_recover": True, - "previous_config_id": previous_config_id, - "fallback_config_id": llm_config_id, - }, - ) - continue - - _perf_log.info( - "[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)", - time.perf_counter() - _t_stream_start, - chat_id, - ) - log_system_snapshot("stream_new_chat_END") - - if stream_result.is_interrupted: - ot.add_event( - "chat.interrupted", - { - "chat.flow": flow, - }, - ) - if title_task is not None and not title_task.done(): - title_task.cancel() - - usage_summary = accumulator.per_message_summary() - _perf_log.info( - "[token_usage] interrupted new_chat: calls=%d total=%d cost_micros=%d summary=%s", - len(accumulator.calls), - accumulator.grand_total, - accumulator.total_cost_micros, - usage_summary, - ) - if usage_summary: - yield streaming_service.format_data( - "token-usage", - { - "usage": usage_summary, - "prompt_tokens": accumulator.total_prompt_tokens, - "completion_tokens": accumulator.total_completion_tokens, - "total_tokens": accumulator.grand_total, - "cost_micros": accumulator.total_cost_micros, - "call_details": accumulator.serialized_calls(), - }, - ) - - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - return - - # If the title task didn't finish during streaming, await it now - if title_task is not None and not title_emitted: - generated_title, title_usage = await title_task - if title_usage: - accumulator.add(**title_usage) - if generated_title: - async with shielded_async_session() as title_session: - title_thread_result = await title_session.execute( - select(NewChatThread).filter(NewChatThread.id == chat_id) - ) - title_thread = title_thread_result.scalars().first() - if title_thread: - title_thread.title = generated_title - await title_session.commit() - yield streaming_service.format_thread_title_update( - chat_id, generated_title - ) - - # Finalize premium credit debit with the actual 
provider cost - # reported by LiteLLM, summed across every call in the turn. - # Mirrors the pre-cost behaviour of "premium turn → all calls - # count" so free sub-agent calls during a premium turn still - # contribute to the bill (they're $0 in practice anyway). - if _premium_request_id and user_id: - try: - from app.services.token_quota_service import TokenQuotaService - - async with shielded_async_session() as quota_session: - await TokenQuotaService.premium_finalize( - db_session=quota_session, - user_id=UUID(user_id), - request_id=_premium_request_id, - actual_micros=accumulator.total_cost_micros, - reserved_micros=_premium_reserved_micros, - ) - _premium_request_id = None - _premium_reserved_micros = 0 - except Exception: - logging.getLogger(__name__).warning( - "Failed to finalize premium quota for user %s", - user_id, - exc_info=True, - ) - - usage_summary = accumulator.per_message_summary() - _perf_log.info( - "[token_usage] normal new_chat: calls=%d total=%d cost_micros=%d summary=%s", - len(accumulator.calls), - accumulator.grand_total, - accumulator.total_cost_micros, - usage_summary, - ) - if usage_summary: - yield streaming_service.format_data( - "token-usage", - { - "usage": usage_summary, - "prompt_tokens": accumulator.total_prompt_tokens, - "completion_tokens": accumulator.total_completion_tokens, - "total_tokens": accumulator.grand_total, - "cost_micros": accumulator.total_cost_micros, - "call_details": accumulator.serialized_calls(), - }, - ) - - # Finish the step and message - yield streaming_service.format_data("turn-status", {"status": "idle"}) - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - - except Exception as e: - # Handle any errors - import traceback - - # ``BusyError`` fires before the agent acquires the lock; the - # cleanup path must skip lock release to avoid freeing the - # in-flight caller's lock. Classification is handled below. 
- if isinstance(e, BusyError): - _busy_error_raised = True - - ( - error_kind, - error_code, - severity, - is_expected, - user_message, - error_extra, - ) = _classify_stream_exception(e, flow_label="chat") - chat_outcome = error_code or error_kind or "error" - chat_error_category = ot_metrics.categorize_exception(e) - with contextlib.suppress(Exception): - chat_span.set_attribute("chat.outcome", chat_outcome) - chat_span.set_attribute("error.category", chat_error_category) - ot.record_error(chat_span, e) - error_message = f"Error during chat: {e!s}" - print(f"[stream_new_chat] {error_message}") - print(f"[stream_new_chat] Exception type: {type(e).__name__}") - print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}") - if error_code == "TURN_CANCELLING": - status_payload: dict[str, Any] = {"status": "cancelling"} - if error_extra: - status_payload.update(error_extra) - yield streaming_service.format_data("turn-status", status_payload) - else: - yield streaming_service.format_data("turn-status", {"status": "busy"}) - - yield _emit_stream_error( - message=user_message, - error_kind=error_kind, - error_code=error_code, - severity=severity, - is_expected=is_expected, - extra=error_extra, - ) - yield streaming_service.format_data("turn-status", {"status": "idle"}) - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - - finally: - # Shield the ENTIRE async cleanup from anyio cancel-scope - # cancellation. Starlette's BaseHTTPMiddleware uses anyio task - # groups; on client disconnect, it cancels the scope with - # level-triggered cancellation — every unshielded `await` inside - # the cancelled scope raises CancelledError immediately. 
Without - # this shield the very first `await` (session.rollback) would - # raise CancelledError, `except Exception` wouldn't catch it - # (CancelledError is a BaseException), and the rest of the - # finally block — including session.close() — would never run. - with anyio.CancelScope(shield=True): - # Authoritative fallback cleanup for lock/cancel state. Middleware - # teardown can be skipped on some client-abort paths. - end_turn(str(chat_id)) - - # Release premium reservation if not finalized - if _premium_request_id and _premium_reserved_micros > 0 and user_id: - try: - from app.services.token_quota_service import TokenQuotaService - - async with shielded_async_session() as quota_session: - await TokenQuotaService.premium_release( - db_session=quota_session, - user_id=UUID(user_id), - reserved_micros=_premium_reserved_micros, - ) - _premium_reserved_micros = 0 - except Exception: - logging.getLogger(__name__).warning( - "Failed to release premium quota for user %s", user_id - ) - - try: - await session.rollback() - await clear_ai_responding(session, chat_id) - except Exception: - try: - async with shielded_async_session() as fresh_session: - await clear_ai_responding(fresh_session, chat_id) - except Exception: - logging.getLogger(__name__).warning( - "Failed to clear AI responding state for thread %s", chat_id - ) - - with contextlib.suppress(Exception): - session.expunge_all() - - with contextlib.suppress(Exception): - await session.close() - - # Server-side assistant-message + token_usage finalization. - # Runs after the main session has been closed (uses its own - # shielded session) so we don't fight the same DB connection. - # Idempotent against the legacy frontend appendMessage: - # * the assistant row was already INSERTed by - # ``persist_assistant_shell`` above, so this just UPDATEs - # it with the rich ContentPart[] from the builder. - # * token_usage uses INSERT ... 
ON CONFLICT DO NOTHING - # against migration 142's partial unique index, so a - # racing append_message recovery branch can never - # double-write. - # ``mark_interrupted`` closes any open text/reasoning blocks - # and flips running tool-calls (no result) to state=aborted - # so the persisted JSONB reflects a coherent end-state even - # on client disconnect. - # Never raises (best-effort, logs only). - if ( - stream_result - and stream_result.turn_id - and stream_result.assistant_message_id - ): - from app.tasks.chat.persistence import finalize_assistant_turn - - builder_stats: dict[str, int] | None = None - if stream_result.content_builder is not None: - stream_result.content_builder.mark_interrupted() - # Snapshot stats BEFORE deepcopy in ``snapshot()`` so - # the perf log records the actual finalised payload - # (post-mark_interrupted), not the live-mutating - # builder state. - builder_stats = stream_result.content_builder.stats() - content_payload = stream_result.content_builder.snapshot() - else: - # Defensive fallback — we always set the builder - # alongside ``assistant_message_id`` above, so this - # branch only fires if a future refactor ever - # decouples them. Persist whatever accumulated - # text we captured so the row at least renders. 
- content_payload = [ - { - "type": "text", - "text": stream_result.accumulated_text or "", - } - ] - - if builder_stats is not None: - _perf_log.info( - "[stream_new_chat] finalize_payload chat_id=%s " - "message_id=%s parts=%d bytes=%d text=%d " - "reasoning=%d tool_calls=%d " - "tool_calls_completed=%d tool_calls_aborted=%d " - "thinking_step_parts=%d step_separators=%d", - chat_id, - stream_result.assistant_message_id, - builder_stats["parts"], - builder_stats["bytes"], - builder_stats["text"], - builder_stats["reasoning"], - builder_stats["tool_calls"], - builder_stats["tool_calls_completed"], - builder_stats["tool_calls_aborted"], - builder_stats["thinking_step_parts"], - builder_stats["step_separators"], - ) - - await finalize_assistant_turn( - message_id=stream_result.assistant_message_id, - chat_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - turn_id=stream_result.turn_id, - content=content_payload, - accumulator=accumulator, - ) - - # Persist any sandbox-produced files to local storage so they - # remain downloadable after the Daytona sandbox auto-deletes. - if stream_result and stream_result.sandbox_files: - with contextlib.suppress(Exception): - from app.agents.new_chat.sandbox import ( - is_sandbox_enabled, - persist_and_delete_sandbox, - ) - - if is_sandbox_enabled(): - with anyio.CancelScope(shield=True): - await persist_and_delete_sandbox( - chat_id, stream_result.sandbox_files - ) - - # ``aafter_agent`` doesn't fire on ``interrupt()`` or early bailout. - # Skip on ``BusyError`` (caller never acquired the lock). - if not _busy_error_raised: - with contextlib.suppress(Exception): - end_turn(str(chat_id)) - _perf_log.info( - "[stream_new_chat] end_turn cleanup (chat_id=%s)", - chat_id, - ) - - # Break circular refs held by the agent graph, tools, and LLM - # wrappers so the GC can reclaim them in a single pass. 
- agent = llm = connector_service = None - input_state = stream_result = None - session = None - - collected = gc.collect(0) + gc.collect(1) + gc.collect(2) - if collected: - _perf_log.info( - "[stream_new_chat] gc.collect() reclaimed %d objects (chat_id=%s)", - collected, - chat_id, - ) - trim_native_heap() - log_system_snapshot("stream_new_chat_END") - with contextlib.suppress(Exception): - chat_span.set_attribute("chat.outcome", chat_outcome) - ot_metrics.record_chat_request_duration( - (time.perf_counter() - _t_total) * 1000, - flow=flow, - outcome=chat_outcome, - agent_mode=chat_agent_mode, - ) - ot_metrics.record_chat_request_outcome( - flow=flow, - outcome=chat_outcome, - agent_mode=chat_agent_mode, - error_category=chat_error_category, - ) - chat_span_cm.__exit__(*sys.exc_info()) - - -async def stream_resume_chat( - chat_id: int, - search_space_id: int, - decisions: list[dict], - user_id: str | None = None, - llm_config_id: int = -1, - thread_visibility: ChatVisibility | None = None, - filesystem_selection: FilesystemSelection | None = None, - request_id: str | None = None, - disabled_tools: list[str] | None = None, -) -> AsyncGenerator[str, None]: - streaming_service = VercelStreamingService() - stream_result = StreamResult() - _t_total = time.perf_counter() - fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud" - fs_platform = ( - filesystem_selection.client_platform.value if filesystem_selection else "web" - ) - stream_result.request_id = request_id - stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}" - stream_result.filesystem_mode = fs_mode - stream_result.client_platform = fs_platform - chat_agent_mode = "unknown" - chat_outcome = "success" - chat_error_category: str | None = None - chat_span_cm = ot.chat_request_span( - chat_id=chat_id, - search_space_id=search_space_id, - flow="resume", - request_id=request_id, - turn_id=stream_result.turn_id, - filesystem_mode=fs_mode, - client_platform=fs_platform, - 
agent_mode=chat_agent_mode, - ) - chat_span = chat_span_cm.__enter__() - _log_file_contract("turn_start", stream_result) - _perf_log.info( - "[stream_resume] filesystem_mode=%s client_platform=%s", - fs_mode, - fs_platform, - ) - from app.services.token_tracking_service import start_turn - - accumulator = start_turn() - - # Skip the finally release on ``BusyError`` (caller never acquired the lock). - _busy_error_raised = False - - _emit_stream_error = partial( - _emit_stream_terminal_error, - streaming_service=streaming_service, - flow="resume", - request_id=request_id, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - ) - - session = async_session_maker() - try: - if user_id: - await set_ai_responding(session, chat_id, UUID(user_id)) - - agent_config: AgentConfig | None = None - requested_llm_config_id = llm_config_id - - async def _load_llm_bundle( - config_id: int, - ) -> tuple[Any, AgentConfig | None, str | None]: - if config_id >= 0: - loaded_agent_config = await load_agent_config( - session=session, - config_id=config_id, - search_space_id=search_space_id, - ) - if not loaded_agent_config: - return ( - None, - None, - f"Failed to load NewLLMConfig with id {config_id}", - ) - return ( - create_chat_litellm_from_agent_config(loaded_agent_config), - loaded_agent_config, - None, - ) - - loaded_llm_config = load_global_llm_config_by_id(config_id) - if not loaded_llm_config: - return None, None, f"Failed to load LLM config with id {config_id}" - return ( - create_chat_litellm_from_config(loaded_llm_config), - AgentConfig.from_yaml_config(loaded_llm_config), - None, - ) - - _t0 = time.perf_counter() - try: - llm_config_id = ( - await resolve_or_get_pinned_llm_config_id( - session, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - selected_llm_config_id=llm_config_id, - ) - ).resolved_llm_config_id - ot.add_event( - "model.pin.resolved", - { - "pin.requested_id": requested_llm_config_id, - "pin.resolved_id": 
llm_config_id, - "pin.requires_image_input": False, - }, - ) - except ValueError as pin_error: - yield _emit_stream_error( - message=str(pin_error), - error_kind="server_error", - error_code="SERVER_ERROR", - ) - yield streaming_service.format_done() - return - - llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) - if llm_load_error: - yield _emit_stream_error( - message=llm_load_error, - error_kind="server_error", - error_code="SERVER_ERROR", - ) - yield streaming_service.format_done() - return - _perf_log.info( - "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0 - ) - - # Premium credit reservation (same logic as stream_new_chat). - _resume_premium_reserved_micros = 0 - _resume_premium_request_id: str | None = None - _resume_needs_premium = ( - agent_config is not None and user_id and agent_config.is_premium - ) - if _resume_needs_premium: - import uuid as _uuid - - from app.services.token_quota_service import ( - TokenQuotaService, - estimate_call_reserve_micros, - ) - - _resume_premium_request_id = _uuid.uuid4().hex[:16] - _resume_litellm_params = agent_config.litellm_params or {} - _resume_base_model = ( - _resume_litellm_params.get("base_model") - or agent_config.model_name - or "" - ) - reserve_amount_micros = estimate_call_reserve_micros( - base_model=_resume_base_model, - quota_reserve_tokens=agent_config.quota_reserve_tokens, - ) - async with shielded_async_session() as quota_session: - quota_result = await TokenQuotaService.premium_reserve( - db_session=quota_session, - user_id=UUID(user_id), - request_id=_resume_premium_request_id, - reserve_micros=reserve_amount_micros, - ) - _resume_premium_reserved_micros = reserve_amount_micros - if not quota_result.allowed: - ot.add_event( - "quota.denied", - { - "quota.code": "PREMIUM_QUOTA_EXHAUSTED", - }, - ) - if requested_llm_config_id == 0: - try: - llm_config_id = ( - await resolve_or_get_pinned_llm_config_id( - session, - thread_id=chat_id, - 
search_space_id=search_space_id, - user_id=user_id, - selected_llm_config_id=0, - force_repin_free=True, - ) - ).resolved_llm_config_id - ot.add_event( - "model.repin", - { - "repin.reason": "premium_quota_exhausted", - "repin.to_config_id": llm_config_id, - }, - ) - except ValueError as pin_error: - yield _emit_stream_error( - message=str(pin_error), - error_kind="server_error", - error_code="SERVER_ERROR", - ) - yield streaming_service.format_done() - return - - llm, agent_config, llm_load_error = await _load_llm_bundle( - llm_config_id - ) - if llm_load_error: - yield _emit_stream_error( - message=llm_load_error, - error_kind="server_error", - error_code="SERVER_ERROR", - ) - yield streaming_service.format_done() - return - _resume_premium_request_id = None - _resume_premium_reserved_micros = 0 - _log_chat_stream_error( - flow="resume", - error_kind="premium_quota_exhausted", - error_code="PREMIUM_QUOTA_EXHAUSTED", - severity="info", - is_expected=True, - request_id=request_id, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - message=( - "Premium quota exhausted on pinned model; auto-fallback switched to a free model" - ), - extra={ - "fallback_config_id": llm_config_id, - "auto_fallback": True, - }, - ) - else: - yield _emit_stream_error( - message=( - "Buy more tokens to continue with this model, or switch to a free model" - ), - error_kind="premium_quota_exhausted", - error_code="PREMIUM_QUOTA_EXHAUSTED", - severity="info", - is_expected=True, - extra={ - "resolved_config_id": llm_config_id, - "auto_fallback": False, - }, - ) - yield streaming_service.format_done() - return - - if not llm: - yield _emit_stream_error( - message="Failed to create LLM instance", - error_kind="server_error", - error_code="SERVER_ERROR", - ) - yield streaming_service.format_done() - return - - _t0 = time.perf_counter() - connector_service = ConnectorService(session, search_space_id=search_space_id) - - firecrawl_api_key = None - webcrawler_connector = 
await connector_service.get_connector_by_type( - SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id - ) - if webcrawler_connector and webcrawler_connector.config: - firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY") - _perf_log.info( - "[stream_resume] Connector service + firecrawl key in %.3fs", - time.perf_counter() - _t0, - ) - - _t0 = time.perf_counter() - checkpointer = await get_checkpointer() - _perf_log.info( - "[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0 - ) - - visibility = thread_visibility or ChatVisibility.PRIVATE - from app.config import config as _app_config - - chat_agent_mode = "multi" if _app_config.MULTI_AGENT_CHAT_ENABLED else "single" - with contextlib.suppress(Exception): - chat_span.set_attribute("agent.mode", chat_agent_mode) - _t0 = time.perf_counter() - agent_factory = ( - create_multi_agent_chat_deep_agent - if _app_config.MULTI_AGENT_CHAT_ENABLED - else create_surfsense_deep_agent - ) - # Build the agent inline. Provider 429s are handled by the - # in-stream recovery loop, which repins to an eligible - # alternative config and rebuilds the agent before the user sees - # any output. - agent = await _build_main_agent_for_thread( - agent_factory, - llm=llm, - search_space_id=search_space_id, - db_session=session, - connector_service=connector_service, - checkpointer=checkpointer, - user_id=user_id, - thread_id=chat_id, - agent_config=agent_config, - firecrawl_api_key=firecrawl_api_key, - thread_visibility=visibility, - filesystem_selection=filesystem_selection, - disabled_tools=disabled_tools, - ) - _perf_log.info( - "[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0 - ) - - # Release the transaction before streaming (same rationale as stream_new_chat). 
- await session.commit() - session.expunge_all() - - _perf_log.info( - "[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)", - time.perf_counter() - _t_total, - chat_id, - ) - - from langgraph.types import Command - - from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( - build_lg_resume_map, - collect_pending_tool_calls, - slice_decisions_by_tool_call, - ) - - # Each pending interrupt is stamped with its originating ``tool_call_id`` - # (see ``checkpointed_subagent_middleware.propagation``) so we can route - # a flat ``decisions`` list back to the right paused subagent. - parent_state = await agent.aget_state( - {"configurable": {"thread_id": str(chat_id)}} - ) - pending = collect_pending_tool_calls(parent_state) - _perf_log.info( - "[hitl_route] resume_entry chat_id=%s decisions=%d pending_subagents=%d", - chat_id, - len(decisions), - len(pending), - ) - routed_resume_value = slice_decisions_by_tool_call(decisions, pending) - # Langgraph rejects scalar ``Command(resume=...)`` when multiple - # interrupts are pending (parallel HITL); the mapped form works - # for the single-pause case too, so we always use it. - lg_resume_map = build_lg_resume_map(parent_state, routed_resume_value) - - config = { - "configurable": { - "thread_id": str(chat_id), - "request_id": request_id or "unknown", - "turn_id": stream_result.turn_id, - # Per-``tool_call_id`` resume slices read by - # ``SurfSenseCheckpointedSubAgentMiddleware``. Parallel - # siblings each pop their own entry, so they never race. - "surfsense_resume_value": routed_resume_value, - }, - # See ``stream_new_chat`` above for rationale: effectively - # uncapped to mirror the agent default and OpenCode's - # session loop. Doom-loop / call-limit middleware enforce - # the real ceiling. 
- "recursion_limit": 10_000, - } - - yield streaming_service.format_message_start() - yield streaming_service.format_start_step() - # Same rationale as ``stream_new_chat``: emit the turn id so - # resumed streams can be persisted with their correlation id - # intact. - yield streaming_service.format_data( - "turn-info", - {"chat_turn_id": stream_result.turn_id}, - ) - yield streaming_service.format_data("turn-status", {"status": "busy"}) - - # Pre-write a fresh assistant row for this resume turn. The - # original (interrupted) ``stream_new_chat`` invocation already - # persisted its own assistant row anchored to a different - # ``turn_id``; resume allocates a new ``turn_id`` (above) so we - # need a separate row keyed on the same ``(thread_id, turn_id, - # ASSISTANT)`` invariant. Idempotent against migration 141's - # partial unique index — recovers existing id on retry. - from app.tasks.chat.content_builder import AssistantContentBuilder - from app.tasks.chat.persistence import persist_assistant_shell - - assistant_message_id = await persist_assistant_shell( - chat_id=chat_id, - user_id=user_id, - turn_id=stream_result.turn_id, - ) - if assistant_message_id is None: - yield _emit_stream_error( - message=( - "We couldn't initialize the assistant message. Please try again." - ), - error_kind="server_error", - error_code="MESSAGE_PERSIST_FAILED", - ) - yield streaming_service.format_data("turn-status", {"status": "idle"}) - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - return - - # Emit canonical assistant message id BEFORE any LLM streaming - # so the FE can rename ``pendingInterrupt.assistantMsgId`` to - # ``msg-{assistant_message_id}`` immediately. Resume does NOT - # emit ``data-user-message-id`` because the user row is from - # the original interrupted turn (different ``turn_id``) and is - # never re-persisted here. See B5 in the - # ``sse-based_message_id_handshake`` plan. 
- yield streaming_service.format_data( - "assistant-message-id", - {"message_id": assistant_message_id, "turn_id": stream_result.turn_id}, - ) - - stream_result.assistant_message_id = assistant_message_id - stream_result.content_builder = AssistantContentBuilder() - - # Resume path doesn't carry new ``mentioned_document_ids`` — - # those are seeded in the original turn. We still pass a - # context so future middleware extensions (Phase 2) can rely on - # ``runtime.context`` always being populated. - runtime_context = SurfSenseContextSchema( - search_space_id=search_space_id, - request_id=request_id, - turn_id=stream_result.turn_id, - ) - - _t_stream_start = time.perf_counter() - _first_event_logged = False - runtime_rate_limit_recovered = False - while True: - try: - async for sse in _stream_agent_events( - agent=agent, - config=config, - input_data=Command(resume=lg_resume_map), - streaming_service=streaming_service, - result=stream_result, - step_prefix=_resume_step_prefix(stream_result.turn_id), - fallback_commit_search_space_id=search_space_id, - fallback_commit_created_by_id=user_id, - fallback_commit_filesystem_mode=( - filesystem_selection.mode - if filesystem_selection - else FilesystemMode.CLOUD - ), - fallback_commit_thread_id=chat_id, - runtime_context=runtime_context, - content_builder=stream_result.content_builder, - ): - if not _first_event_logged: - _perf_log.info( - "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)", - time.perf_counter() - _t_stream_start, - time.perf_counter() - _t_total, - chat_id, - ) - _first_event_logged = True - yield sse - break - except Exception as stream_exc: - can_runtime_recover = ( - not runtime_rate_limit_recovered - and requested_llm_config_id == 0 - and llm_config_id < 0 - and not _first_event_logged - and _is_provider_rate_limited(stream_exc) - ) - if not can_runtime_recover: - raise - - runtime_rate_limit_recovered = True - previous_config_id = llm_config_id - # Ensure the 
same-request recovery retry does not trip the - # BusyMutex lock retained by the failed attempt. - end_turn(str(chat_id)) - mark_runtime_cooldown( - previous_config_id, - reason="provider_rate_limited", - ) - llm_config_id = ( - await resolve_or_get_pinned_llm_config_id( - session, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - selected_llm_config_id=0, - exclude_config_ids={previous_config_id}, - ) - ).resolved_llm_config_id - - llm, agent_config, llm_load_error = await _load_llm_bundle( - llm_config_id - ) - if llm_load_error: - raise stream_exc - - _t0 = time.perf_counter() - agent = await _build_main_agent_for_thread( - agent_factory, - llm=llm, - search_space_id=search_space_id, - db_session=session, - connector_service=connector_service, - checkpointer=checkpointer, - user_id=user_id, - thread_id=chat_id, - agent_config=agent_config, - firecrawl_api_key=firecrawl_api_key, - thread_visibility=visibility, - filesystem_selection=filesystem_selection, - disabled_tools=disabled_tools, - ) - _perf_log.info( - "[stream_resume] Runtime rate-limit recovery repinned " - "config_id=%s -> %s and rebuilt agent in %.3fs", - previous_config_id, - llm_config_id, - time.perf_counter() - _t0, - ) - ot.add_event( - "chat.rate_limit.recovered", - { - "recovery.reason": "provider_rate_limited", - "recovery.previous_config_id": previous_config_id, - "recovery.fallback_config_id": llm_config_id, - }, - ) - _log_chat_stream_error( - flow="resume", - error_kind="rate_limited", - error_code="RATE_LIMITED", - severity="info", - is_expected=True, - request_id=request_id, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - message=( - "Auto-pinned model hit runtime rate limit; switched to " - "another eligible model and retried." 
- ), - extra={ - "auto_runtime_recover": True, - "previous_config_id": previous_config_id, - "fallback_config_id": llm_config_id, - }, - ) - continue - _perf_log.info( - "[stream_resume] Agent stream completed in %.3fs (chat_id=%s)", - time.perf_counter() - _t_stream_start, - chat_id, - ) - if stream_result.is_interrupted: - ot.add_event( - "chat.interrupted", - { - "chat.flow": "resume", - }, - ) - usage_summary = accumulator.per_message_summary() - _perf_log.info( - "[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s", - len(accumulator.calls), - accumulator.grand_total, - accumulator.total_cost_micros, - usage_summary, - ) - if usage_summary: - yield streaming_service.format_data( - "token-usage", - { - "usage": usage_summary, - "prompt_tokens": accumulator.total_prompt_tokens, - "completion_tokens": accumulator.total_completion_tokens, - "total_tokens": accumulator.grand_total, - "cost_micros": accumulator.total_cost_micros, - "call_details": accumulator.serialized_calls(), - }, - ) - - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - return - - # Finalize premium credit debit for resume path with the actual - # provider cost reported by LiteLLM (sum of cost across all - # calls in the turn). 
- if _resume_premium_request_id and user_id: - try: - from app.services.token_quota_service import TokenQuotaService - - async with shielded_async_session() as quota_session: - await TokenQuotaService.premium_finalize( - db_session=quota_session, - user_id=UUID(user_id), - request_id=_resume_premium_request_id, - actual_micros=accumulator.total_cost_micros, - reserved_micros=_resume_premium_reserved_micros, - ) - _resume_premium_request_id = None - _resume_premium_reserved_micros = 0 - except Exception: - logging.getLogger(__name__).warning( - "Failed to finalize premium quota for user %s (resume)", - user_id, - exc_info=True, - ) - - usage_summary = accumulator.per_message_summary() - _perf_log.info( - "[token_usage] normal resume_chat: calls=%d total=%d cost_micros=%d summary=%s", - len(accumulator.calls), - accumulator.grand_total, - accumulator.total_cost_micros, - usage_summary, - ) - if usage_summary: - yield streaming_service.format_data( - "token-usage", - { - "usage": usage_summary, - "prompt_tokens": accumulator.total_prompt_tokens, - "completion_tokens": accumulator.total_completion_tokens, - "total_tokens": accumulator.grand_total, - "cost_micros": accumulator.total_cost_micros, - "call_details": accumulator.serialized_calls(), - }, - ) - - yield streaming_service.format_data("turn-status", {"status": "idle"}) - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - - except Exception as e: - import traceback - - # ``BusyError`` fires before the agent acquires the lock; the - # cleanup path must skip lock release to avoid freeing the - # in-flight caller's lock. Classification is handled below. 
- if isinstance(e, BusyError): - _busy_error_raised = True - - ( - error_kind, - error_code, - severity, - is_expected, - user_message, - error_extra, - ) = _classify_stream_exception(e, flow_label="resume") - chat_outcome = error_code or error_kind or "error" - chat_error_category = ot_metrics.categorize_exception(e) - with contextlib.suppress(Exception): - chat_span.set_attribute("chat.outcome", chat_outcome) - chat_span.set_attribute("error.category", chat_error_category) - ot.record_error(chat_span, e) - error_message = f"Error during resume: {e!s}" - print(f"[stream_resume_chat] {error_message}") - print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}") - if error_code == "TURN_CANCELLING": - status_payload: dict[str, Any] = {"status": "cancelling"} - if error_extra: - status_payload.update(error_extra) - yield streaming_service.format_data("turn-status", status_payload) - else: - yield streaming_service.format_data("turn-status", {"status": "busy"}) - yield _emit_stream_error( - message=user_message, - error_kind=error_kind, - error_code=error_code, - severity=severity, - is_expected=is_expected, - extra=error_extra, - ) - yield streaming_service.format_data("turn-status", {"status": "idle"}) - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - - finally: - with anyio.CancelScope(shield=True): - # Authoritative fallback cleanup for lock/cancel state. Middleware - # teardown can be skipped on some client-abort paths. 
- end_turn(str(chat_id)) - - # Release premium reservation if not finalized - if ( - _resume_premium_request_id - and _resume_premium_reserved_micros > 0 - and user_id - ): - try: - from app.services.token_quota_service import TokenQuotaService - - async with shielded_async_session() as quota_session: - await TokenQuotaService.premium_release( - db_session=quota_session, - user_id=UUID(user_id), - reserved_micros=_resume_premium_reserved_micros, - ) - _resume_premium_reserved_micros = 0 - except Exception: - logging.getLogger(__name__).warning( - "Failed to release premium quota for user %s (resume)", user_id - ) - - try: - await session.rollback() - await clear_ai_responding(session, chat_id) - except Exception: - try: - async with shielded_async_session() as fresh_session: - await clear_ai_responding(fresh_session, chat_id) - except Exception: - logging.getLogger(__name__).warning( - "Failed to clear AI responding state for thread %s", chat_id - ) - - with contextlib.suppress(Exception): - session.expunge_all() - - with contextlib.suppress(Exception): - await session.close() - - # Server-side assistant-message + token_usage finalization for - # the resume flow. The original user message was persisted by - # the original (interrupted) ``stream_new_chat`` invocation; - # the resume's own ``persist_assistant_shell`` write lives at - # the new ``turn_id`` above. This finalize updates that row - # with the rich ContentPart[] from the builder and writes - # token_usage idempotently via migration 142's partial - # unique index. Best-effort, never raises. 
- if ( - stream_result - and stream_result.turn_id - and stream_result.assistant_message_id - ): - from app.tasks.chat.persistence import finalize_assistant_turn - - builder_stats: dict[str, int] | None = None - if stream_result.content_builder is not None: - stream_result.content_builder.mark_interrupted() - builder_stats = stream_result.content_builder.stats() - content_payload = stream_result.content_builder.snapshot() - else: - content_payload = [ - { - "type": "text", - "text": stream_result.accumulated_text or "", - } - ] - - if builder_stats is not None: - _perf_log.info( - "[stream_resume] finalize_payload chat_id=%s " - "message_id=%s parts=%d bytes=%d text=%d " - "reasoning=%d tool_calls=%d " - "tool_calls_completed=%d tool_calls_aborted=%d " - "thinking_step_parts=%d step_separators=%d", - chat_id, - stream_result.assistant_message_id, - builder_stats["parts"], - builder_stats["bytes"], - builder_stats["text"], - builder_stats["reasoning"], - builder_stats["tool_calls"], - builder_stats["tool_calls_completed"], - builder_stats["tool_calls_aborted"], - builder_stats["thinking_step_parts"], - builder_stats["step_separators"], - ) - - await finalize_assistant_turn( - message_id=stream_result.assistant_message_id, - chat_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - turn_id=stream_result.turn_id, - content=content_payload, - accumulator=accumulator, - ) - - # Release the lock from the original interrupted turn or any - # re-interrupt/bailout. Skip on ``BusyError`` (lock not held here). 
- if not _busy_error_raised: - with contextlib.suppress(Exception): - end_turn(str(chat_id)) - _perf_log.info( - "[stream_resume] end_turn cleanup (chat_id=%s)", - chat_id, - ) - - agent = llm = connector_service = None - stream_result = None - session = None - - collected = gc.collect(0) + gc.collect(1) + gc.collect(2) - if collected: - _perf_log.info( - "[stream_resume] gc.collect() reclaimed %d objects (chat_id=%s)", - collected, - chat_id, - ) - trim_native_heap() - log_system_snapshot("stream_resume_chat_END") - with contextlib.suppress(Exception): - chat_span.set_attribute("chat.outcome", chat_outcome) - ot_metrics.record_chat_request_duration( - (time.perf_counter() - _t_total) * 1000, - flow="resume", - outcome=chat_outcome, - agent_mode=chat_agent_mode, - ) - ot_metrics.record_chat_request_outcome( - flow="resume", - outcome=chat_outcome, - agent_mode=chat_agent_mode, - error_category=chat_error_category, - ) - chat_span_cm.__exit__(*sys.exc_info()) diff --git a/surfsense_backend/app/tasks/chat/streaming/agent/builder.py b/surfsense_backend/app/tasks/chat/streaming/agent/builder.py index 0db42edbf..dcbd37521 100644 --- a/surfsense_backend/app/tasks/chat/streaming/agent/builder.py +++ b/surfsense_backend/app/tasks/chat/streaming/agent/builder.py @@ -9,8 +9,10 @@ from __future__ import annotations from typing import Any -from app.agents.new_chat.filesystem_selection import FilesystemSelection -from app.agents.new_chat.llm_config import AgentConfig +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import ( + FilesystemSelection, +) +from app.agents.chat.runtime.llm_config import AgentConfig from app.db import ChatVisibility from app.services.connector_service import ConnectorService diff --git a/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py b/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py index b77bd3890..d96144bcd 100644 --- a/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py +++ 
b/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py @@ -11,10 +11,10 @@ from __future__ import annotations from collections.abc import AsyncGenerator from typing import Any -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.middleware.kb_persistence import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.kb_persistence import ( commit_staged_filesystem_state, ) +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode from app.services.new_streaming_service import VercelStreamingService from app.tasks.chat.streaming.contract.file_contract import ( contract_enforcement_active, diff --git a/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py b/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py index 3af2b9f9f..6b37df343 100644 --- a/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py +++ b/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py @@ -7,11 +7,11 @@ import logging import time from typing import Any, Literal -from app.agents.new_chat.errors import BusyError -from app.agents.new_chat.middleware.busy_mutex import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import ( get_cancel_state, is_cancel_requested, ) +from app.agents.chat.runtime.errors import BusyError TURN_CANCELLING_INITIAL_DELAY_MS = 200 TURN_CANCELLING_BACKOFF_FACTOR = 2 diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/auto_pin.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/auto_pin.py index af496cee7..dbb8ee2e4 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/auto_pin.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/auto_pin.py @@ -50,8 +50,14 @@ async def resolve_initial_auto_pin( selected_llm_config_id: int, requires_image_input: bool, requested_llm_config_id: int, + force_repin_free: bool = False, ) -> AutoPinResult: - """Run the resolver and 
classify any ``ValueError`` for the SSE error path.""" + """Run the resolver and classify any ``ValueError`` for the SSE error path. + + ``force_repin_free`` forces a fresh re-pin to a free-tier config (used on + the premium-quota-exhausted fallback so an out-of-quota user isn't repinned + onto another paid model). + """ try: pinned = await resolve_or_get_pinned_llm_config_id( session, @@ -60,6 +66,7 @@ async def resolve_initial_auto_pin( user_id=user_id, selected_llm_config_id=selected_llm_config_id, requires_image_input=requires_image_input, + force_repin_free=force_repin_free, ) ot.add_event( "model.pin.resolved", diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/input_state.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/input_state.py index c9ef6edd6..064843aba 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/input_state.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/input_state.py @@ -9,9 +9,9 @@ Pipeline: can resolve ``report_id`` for versioning without spelunking history. 3. **@-mention resolve** (cloud mode) — substitute ``@title`` tokens in the query with canonical ``\`/documents/...\``` paths the LLM expects. - 4. **Context block render** — XML-wrap recent reports, prepend to the - rewritten query, optionally prefix with display name for SEARCH_SPACE - visibility. + 4. **Context block render** — XML-wrap @-mentioned connectors and recent + reports, prepend to the rewritten query, optionally prefix with display + name for SEARCH_SPACE visibility. 5. **HumanMessage** — multimodal content if images are attached. 
Returns the assembled ``input_state`` dict plus side-channel data the @@ -28,8 +28,11 @@ from langchain_core.messages import HumanMessage from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode +from app.agents.chat.runtime.mention_resolver import ( + resolve_mentions, + substitute_in_text, +) from app.db import ( ChatVisibility, NewChatThread, @@ -201,7 +204,10 @@ def _render_query_with_context( mentioned_connectors: list[dict[str, Any]] | None, recent_reports: list[Report], ) -> str: - """Prepend connector/report XML context blocks to the user query.""" + """Prepend the ```` then ```` blocks. + + Order is load-bearing for legacy parity. + """ context_parts: list[str] = [] connector_context = _render_mentioned_connectors(mentioned_connectors) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py index 9f4e5d2d8..69b9f4ab8 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py @@ -15,7 +15,7 @@ tells the user what to change. 
from __future__ import annotations -from app.agents.new_chat.llm_config import AgentConfig +from app.agents.chat.runtime.llm_config import AgentConfig from app.observability import otel as ot diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py index 9c25218bf..e33dca376 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py @@ -29,11 +29,12 @@ from typing import Any, Literal import anyio -from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent -from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent -from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection -from app.agents.new_chat.middleware.busy_mutex import end_turn -from app.config import config as _app_config +from app.agents.chat.multi_agent_chat import create_multi_agent_chat_deep_agent +from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import end_turn +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import ( + FilesystemMode, + FilesystemSelection, +) from app.db import ChatVisibility, async_session_maker from app.observability import otel as ot from app.services.new_streaming_service import VercelStreamingService @@ -274,6 +275,7 @@ async def stream_new_chat( selected_llm_config_id=0, requires_image_input=requires_image_input, requested_llm_config_id=requested_llm_config_id, + force_repin_free=True, ) if pin_fallback.error is not None: message, error_code, error_kind = pin_fallback.error @@ -369,12 +371,6 @@ async def stream_new_chat( mentioned_documents=mentioned_documents, background_tasks=_background_tasks, ) - persist_asst_task = spawn_persist_assistant_shell_task( - chat_id=chat_id, - user_id=user_id, - turn_id=stream_result.turn_id, - background_tasks=_background_tasks, - ) 
_t0 = time.perf_counter() connector_service, firecrawl_api_key = await setup_connector_and_firecrawl( @@ -392,16 +388,11 @@ async def stream_new_chat( ) visibility = thread_visibility or ChatVisibility.PRIVATE - use_multi_agent = bool(_app_config.MULTI_AGENT_CHAT_ENABLED) - chat_agent_mode = "multi" if use_multi_agent else "single" + chat_agent_mode = "multi" set_agent_mode(chat_span, chat_agent_mode) _t0 = time.perf_counter() - agent_factory = ( - create_multi_agent_chat_deep_agent - if use_multi_agent - else create_surfsense_deep_agent - ) + agent_factory = create_multi_agent_chat_deep_agent # Build the agent inline. Provider 429s surface through the in-stream # recovery loop below, which repins the thread to an eligible # alternative config and rebuilds the agent before the user sees any @@ -526,6 +517,14 @@ async def stream_new_chat( {"message_id": user_message_id, "turn_id": stream_result.turn_id}, ) + # Spawned only after the user row is confirmed, so a user-persist + # failure can't orphan an assistant shell on the same turn. + persist_asst_task = spawn_persist_assistant_shell_task( + chat_id=chat_id, + user_id=user_id, + turn_id=stream_result.turn_id, + background_tasks=_background_tasks, + ) assistant_message_id = await await_persist_task( persist_asst_task, chat_id=chat_id, @@ -830,7 +829,7 @@ async def stream_new_chat( # downloadable after the Daytona sandbox auto-deletes. 
if stream_result and stream_result.sandbox_files: with contextlib.suppress(Exception): - from app.agents.new_chat.sandbox import ( + from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.sandbox import ( is_sandbox_enabled, persist_and_delete_sandbox, ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/runtime_context.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/runtime_context.py index 66233dec8..195a16b1e 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/runtime_context.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/runtime_context.py @@ -8,9 +8,7 @@ mention lists / request ids / turn ids without rebuilding the graph. from __future__ import annotations -from typing import Any - -from app.agents.new_chat.context import SurfSenseContextSchema +from app.agents.chat.shared.context import SurfSenseContextSchema def build_new_chat_runtime_context( @@ -20,7 +18,7 @@ def build_new_chat_runtime_context( accepted_folder_ids: list[int], mentioned_folder_ids: list[int] | None, mentioned_connector_ids: list[int] | None, - mentioned_connectors: list[dict[str, Any]] | None, + mentioned_connectors: list[dict[str, object]] | None, request_id: str | None, turn_id: str, ) -> SurfSenseContextSchema: @@ -30,6 +28,9 @@ def build_new_chat_runtime_context( ``mentioned_folder_ids`` from the request: the resolver drops chips that pointed at deleted folders or folders the caller can't see, so middlewares only get authorized ids. + + Connector mentions are set on the schema for legacy parity even though no + middleware reads them yet. 
""" return SurfSenseContextSchema( search_space_id=search_space_id, diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py index 7db45941b..fe3d210bb 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py @@ -30,7 +30,7 @@ from app.prompts import TITLE_GENERATION_PROMPT from app.services.new_streaming_service import VercelStreamingService if TYPE_CHECKING: - from app.agents.new_chat.llm_config import AgentConfig + from app.agents.chat.runtime.llm_config import AgentConfig from app.services.token_tracking_service import TokenAccumulator diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py index e1b95aa63..6d0924850 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py @@ -23,11 +23,12 @@ from uuid import UUID import anyio -from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent -from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent -from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection -from app.agents.new_chat.middleware.busy_mutex import end_turn -from app.config import config as _app_config +from app.agents.chat.multi_agent_chat import create_multi_agent_chat_deep_agent +from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import end_turn +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import ( + FilesystemMode, + FilesystemSelection, +) from app.db import ChatVisibility, async_session_maker from app.observability import otel as ot from app.services.chat_session_state_service import set_ai_responding @@ -326,16 +327,11 
@@ async def stream_resume_chat( ) visibility = thread_visibility or ChatVisibility.PRIVATE - use_multi_agent = bool(_app_config.MULTI_AGENT_CHAT_ENABLED) - chat_agent_mode = "multi" if use_multi_agent else "single" + chat_agent_mode = "multi" set_agent_mode(chat_span, chat_agent_mode) _t0 = time.perf_counter() - agent_factory = ( - create_multi_agent_chat_deep_agent - if use_multi_agent - else create_surfsense_deep_agent - ) + agent_factory = create_multi_agent_chat_deep_agent agent = await build_main_agent_for_thread( agent_factory, llm=llm, diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/resume_routing.py b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/resume_routing.py index 7f4f67aac..d9877c9b0 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/resume_routing.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/resume_routing.py @@ -41,7 +41,7 @@ async def build_resume_routing( ``surfsense_resume_value`` configurable; parallel siblings each pop their own entry so they never race. """ - from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( + from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.resume_routing import ( build_lg_resume_map, collect_pending_tool_calls, slice_decisions_by_tool_call, diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/runtime_context.py b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/runtime_context.py index 59d5d8ca7..54f0dfba0 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/runtime_context.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/runtime_context.py @@ -7,7 +7,7 @@ can rely on ``runtime.context`` always being populated. 
from __future__ import annotations -from app.agents.new_chat.context import SurfSenseContextSchema +from app.agents.chat.shared.context import SurfSenseContextSchema def build_resume_chat_runtime_context( diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py index 2f334114c..7e2bc950b 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py @@ -14,7 +14,7 @@ from typing import Any from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.llm_config import ( +from app.agents.chat.runtime.llm_config import ( AgentConfig, create_chat_litellm_from_agent_config, create_chat_litellm_from_config, diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/pre_stream_setup.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/pre_stream_setup.py index ec92306dd..f717cb325 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/pre_stream_setup.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/pre_stream_setup.py @@ -4,7 +4,7 @@ from __future__ import annotations from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.checkpointer import get_checkpointer +from app.agents.chat.runtime.checkpointer import get_checkpointer from app.db import SearchSourceConnectorType from app.services.connector_service import ConnectorService @@ -33,7 +33,7 @@ async def setup_connector_and_firecrawl( async def get_chat_checkpointer(): """Resolve the PostgreSQL checkpointer for persistent conversation memory. - Thin wrapper around ``app.agents.new_chat.checkpointer.get_checkpointer`` so + Thin wrapper around ``app.agents.chat.runtime.checkpointer.get_checkpointer`` so flow orchestrators can rely on a streaming-local symbol and we have a hook point if the checkpointer source ever needs to vary per flow. 
""" diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/premium_quota.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/premium_quota.py index cbf44764c..6c08cb29f 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/premium_quota.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/premium_quota.py @@ -19,7 +19,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING from uuid import UUID -from app.agents.new_chat.llm_config import AgentConfig +from app.agents.chat.runtime.llm_config import AgentConfig from app.db import shielded_async_session if TYPE_CHECKING: diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/rate_limit_recovery.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/rate_limit_recovery.py index 6b3857594..29018fe07 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/rate_limit_recovery.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/rate_limit_recovery.py @@ -17,7 +17,7 @@ from typing import Literal from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.middleware.busy_mutex import end_turn +from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import end_turn from app.observability import otel as ot from app.services.auto_model_pin_service import ( mark_runtime_cooldown, diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/stream_loop.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/stream_loop.py index 6cf0df855..f455a8ffd 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/stream_loop.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/stream_loop.py @@ -15,7 +15,7 @@ from __future__ import annotations from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Any -from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.filesystem_selection 
import FilesystemMode from app.services.new_streaming_service import VercelStreamingService from app.tasks.chat.streaming.agent.event_loop import stream_agent_events from app.tasks.chat.streaming.shared.stream_result import StreamResult diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/terminal_error.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/terminal_error.py index b305dba23..126149cc1 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/terminal_error.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/terminal_error.py @@ -14,7 +14,7 @@ import traceback from collections.abc import Iterator from typing import Any, Literal -from app.agents.new_chat.errors import BusyError +from app.agents.chat.runtime.errors import BusyError from app.observability import metrics as ot_metrics, otel as ot from app.services.new_streaming_service import VercelStreamingService from app.tasks.chat.streaming.errors.classifier import classify_stream_exception diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tool_end.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tool_end.py index 2ff810447..ae04c2823 100644 --- a/surfsense_backend/app/tasks/chat/streaming/handlers/tool_end.py +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tool_end.py @@ -26,7 +26,7 @@ def _unwrap_command_output(raw_output: Any) -> Any: """Replace a ``Command`` from a tool return with its inner ``ToolMessage``. Tools that participate in receipt-style state writes (see - ``app.agents.shared.receipt_command.with_receipt``) return a + ``app.agents.chat.multi_agent_chat.shared.receipts.command.with_receipt``) return a ``Command(update={"messages": [ToolMessage(...)], "receipts": [...]})``. 
LangChain's ``on_tool_end`` event surfaces that ``Command`` verbatim as ``data.output``, which the rest of this handler can't introspect: it has diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/emission.py index 51a67f369..34283bcdb 100644 --- a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/emission.py +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/emission.py @@ -21,7 +21,7 @@ def iter_completion_emission_frames( # ``ready`` is the live success status now that the tool waits for the # Celery worker to reach a terminal state. ``pending`` is retained as a # legacy branch for old saved chats that pre-date the wait-for-terminal - # change (see ``app.agents.shared.deliverable_wait``). + # change (see ``app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.deliverable_wait``). 
if status == "ready": yield ctx.streaming_service.format_terminal_info( f"Video presentation generated successfully: {out.get('title', 'Presentation')}", diff --git a/surfsense_backend/app/tasks/document_processors/file_processors.py b/surfsense_backend/app/tasks/document_processors/file_processors.py index 805f5554d..f6929b87c 100644 --- a/surfsense_backend/app/tasks/document_processors/file_processors.py +++ b/surfsense_backend/app/tasks/document_processors/file_processors.py @@ -15,8 +15,9 @@ from dataclasses import dataclass from fastapi import HTTPException from sqlalchemy.ext.asyncio import AsyncSession -from app.db import Document, Log, Notification -from app.services.notification_service import NotificationService +from app.db import Document, Log +from app.notifications.persistence import Notification +from app.notifications.service import NotificationService from app.services.task_logging_service import TaskLoggingService from ._helpers import update_document_from_connector diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml index 8fe4081b5..492569c95 100644 --- a/surfsense_backend/pyproject.toml +++ b/surfsense_backend/pyproject.toml @@ -195,10 +195,11 @@ asyncio_mode = "auto" asyncio_default_fixture_loop_scope = "session" asyncio_default_test_loop_scope = "session" testpaths = ["tests"] +pythonpath = ["."] python_files = ["test_*.py"] python_classes = ["Test*"] python_functions = ["test_*"] -addopts = "-v --tb=short -x --strict-markers -ra --durations=5" +addopts = "-v --tb=short -x --strict-markers -ra --durations=5 --import-mode=importlib" markers = [ "unit: pure logic tests, no DB or external services", "integration: tests that require a real PostgreSQL database" diff --git a/surfsense_backend/tests/README.md b/surfsense_backend/tests/README.md new file mode 100644 index 000000000..5764252a5 --- /dev/null +++ b/surfsense_backend/tests/README.md @@ -0,0 +1,62 @@ +# Tests + +How the backend test suite is organized and the 
conventions to follow when adding tests. + +## Layout: type-first, module-mirrored + +Tests are split by **type** at the top level, and each type **mirrors the `app/` module tree** inside: + +``` +tests/ +├── conftest.py # global fixtures + DATABASE_URL pinning +├── unit/ # pure logic: no DB, no app, no network +│ └── notifications/ +│ ├── api/test_transform.py +│ └── service/ +│ ├── messages/test_connector_indexing.py +│ └── test_metadata.py +└── integration/ # real PostgreSQL (pgvector) + ├── conftest.py # async engine, transactional db_session, db_user, ... + └── notifications/ + ├── conftest.py # module-scoped fixtures (e.g. transactional client) + └── test_*_handler.py +``` + +To find a feature's tests, look under `tests/<type>/<module>/`. + +## Unit vs integration + +- `@pytest.mark.unit` — pure, fast, no I/O. Test behavior through a public function's inputs/outputs. +- `@pytest.mark.integration` — requires a real database. Run with `AUTH_TYPE=LOCAL`. + +Maximize logic covered by unit tests; keep integration tests for what genuinely needs the DB (persistence, SQL filters, scoping, HTTP wiring). + +## Principles + +- **Behavior, not implementation.** Assert observable outputs (returned values, persisted rows, HTTP responses), never private helpers. Tests should survive a refactor. +- **Functional core / imperative shell.** Put pure decision logic in a side-effect-free module (e.g. `app/notifications/service/messages/`) so it is unit-testable; keep the persistence shell thin and cover it with a few integration tests. +- **One responsibility per test file**, mirroring the slice it covers. +- **Mock only at system boundaries** (external APIs, brokers), never internal collaborators. Prefer dependency overrides and the transactional `db_session` over mocks. + +## Fixtures + +`conftest.py` is scoped to its directory and below. 
Keep truly global fixtures in `tests/conftest.py`; put module-specific fixtures in that module's `conftest.py` so a DB fixture never loads for a pure unit test. + +For API integration tests, override `get_async_session` and `current_active_user` to ride the test's transactional `db_session` (see `tests/integration/notifications/conftest.py`): rows seeded in the test and rows read via the endpoint share one transaction that rolls back automatically. + +## Import mode + +The suite uses `--import-mode=importlib` with `pythonpath = ["."]` (see `pyproject.toml`). This lets test files share basenames across modules (e.g. many `test_api.py`) without `__init__.py` boilerplate; new test directories do not need an `__init__.py`. + +## Running + +```bash +# fast unit tests +uv run pytest -m unit + +# integration (needs Postgres + pgvector) +AUTH_TYPE=LOCAL uv run pytest -m integration + +# a single module's tests +uv run pytest tests/unit/notifications +``` diff --git a/surfsense_backend/tests/e2e/auth_mint.py b/surfsense_backend/tests/e2e/auth_mint.py index f489ed274..edbf09f1a 100644 --- a/surfsense_backend/tests/e2e/auth_mint.py +++ b/surfsense_backend/tests/e2e/auth_mint.py @@ -51,7 +51,9 @@ async def mint_test_token( raise HTTPException(status_code=403, detail="invalid e2e mint secret") async with async_session_maker() as session: result = await session.execute(select(User).where(User.email == body.email)) - user = result.scalar_one_or_none() + # ``.unique()`` is required because the User mapper eager-loads a + # collection (oauth_accounts) via joined load. 
+ user = result.unique().scalar_one_or_none() if user is None: raise HTTPException( status_code=404, detail=f"e2e user {body.email!r} not seeded" diff --git a/surfsense_backend/tests/e2e/fakes/chat_llm.py b/surfsense_backend/tests/e2e/fakes/chat_llm.py index fa3a2b158..234a18ec1 100644 --- a/surfsense_backend/tests/e2e/fakes/chat_llm.py +++ b/surfsense_backend/tests/e2e/fakes/chat_llm.py @@ -553,6 +553,49 @@ class FakeChatLLM(BaseChatModel): latest_tool_name = getattr(latest_tool, "name", None) latest_tool_text = _content_to_text(latest_tool.content) if latest_tool else "" + # Marker unique to a connector subagent's prompt: the main agent must + # delegate via ``task``; only the subagent has connector tools registered. + in_connector_subagent = ( + "specialist for the user's connected" in _messages_to_text(messages) + ) + + # Main agent: delegate live-tool connector work to its subagent (which + # then runs the real tools below). Indexed connectors are absent here. + if not in_connector_subagent and latest_tool is None: + connector_delegations = ( + ("gmail", ("gmail", "email", "message", GMAIL_CANARY_SUBJECT)), + ("calendar", ("calendar", "event", "meeting", CALENDAR_CANARY_SUMMARY)), + ( + "jira", + ( + "jira", + "atlassian", + JIRA_CANARY_SUMMARY, + JIRA_CANARY_KEY, + "surfsense-e2e.atlassian.net", + "fake-jira-cloud-001", + ), + ), + ("linear", ("linear", "issue", LINEAR_CANARY_TITLE)), + ("slack", ("slack", SLACK_CANARY_TOKEN)), + ("clickup", ("clickup", CLICKUP_CANARY_TITLE)), + ) + for subagent_type, needles in connector_delegations: + if _contains_any(latest_human, needles): + return AIMessage( + content="", + tool_calls=[ + { + "name": "task", + "args": { + "subagent_type": subagent_type, + "description": latest_human, + }, + "id": f"call_e2e_task_{subagent_type}", + } + ], + ) + if ( latest_tool_name == "search_gmail" and GMAIL_CANARY_MESSAGE_ID in latest_tool_text diff --git a/surfsense_backend/tests/e2e/fakes/mcp_runtime.py 
b/surfsense_backend/tests/e2e/fakes/mcp_runtime.py index e772bb63a..5e4ef403f 100644 --- a/surfsense_backend/tests/e2e/fakes/mcp_runtime.py +++ b/surfsense_backend/tests/e2e/fakes/mcp_runtime.py @@ -137,10 +137,10 @@ def install(active_patches: list[Any]) -> None: """Patch production MCP streamable-HTTP boundaries exactly once.""" targets = [ ( - "app.agents.new_chat.tools.mcp_tool.streamablehttp_client", + "app.agents.chat.multi_agent_chat.shared.tools.mcp.tool.streamablehttp_client", _fake_streamablehttp_client, ), - ("app.agents.new_chat.tools.mcp_tool.ClientSession", _FakeClientSession), + ("app.agents.chat.multi_agent_chat.shared.tools.mcp.tool.ClientSession", _FakeClientSession), ] for target, replacement in targets: p = patch(target, replacement) diff --git a/surfsense_backend/tests/e2e/fakes/native_google.py b/surfsense_backend/tests/e2e/fakes/native_google.py index 73c8cc738..1afcaf9c3 100644 --- a/surfsense_backend/tests/e2e/fakes/native_google.py +++ b/surfsense_backend/tests/e2e/fakes/native_google.py @@ -429,9 +429,18 @@ def install(active_patches: list[Any]) -> None: ("app.connectors.google_drive.client.build", _fake_build), ("app.connectors.google_gmail_connector.build", _fake_build), ("app.connectors.google_calendar_connector.build", _fake_build), - ("app.agents.new_chat.tools.google_calendar.create_event.build", _fake_build), - ("app.agents.new_chat.tools.google_calendar.update_event.build", _fake_build), - ("app.agents.new_chat.tools.google_calendar.delete_event.build", _fake_build), + ( + "app.agents.chat.multi_agent_chat.subagents.connectors.calendar.tools.create_event.build", + _fake_build, + ), + ( + "app.agents.chat.multi_agent_chat.subagents.connectors.calendar.tools.update_event.build", + _fake_build, + ), + ( + "app.agents.chat.multi_agent_chat.subagents.connectors.calendar.tools.delete_event.build", + _fake_build, + ), ("googleapiclient.http.MediaIoBaseDownload", _FakeMediaIoBaseDownload), ( 
"app.connectors.google_drive.client._build_thread_http", diff --git a/surfsense_backend/tests/e2e/run_backend.py b/surfsense_backend/tests/e2e/run_backend.py index 6781b1634..87977626f 100644 --- a/surfsense_backend/tests/e2e/run_backend.py +++ b/surfsense_backend/tests/e2e/run_backend.py @@ -72,8 +72,6 @@ def _load_dotenv_and_set_env_defaults() -> None: """ from dotenv import load_dotenv - load_dotenv() - os.environ.setdefault( "DATABASE_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense", @@ -138,6 +136,11 @@ def _load_dotenv_and_set_env_defaults() -> None: os.environ["SLACK_CLIENT_ID"] = "fake-slack-mcp-client-id" os.environ["SLACK_CLIENT_SECRET"] = "fake-slack-mcp-client-secret" + # Load .env last so the E2E defaults above win over a developer's .env + # (e.g. AUTH_TYPE=GOOGLE), while an explicitly exported shell var still + # beats both: setdefault respects it and load_dotenv() never overrides. + load_dotenv() + def _install_synthetic_global_llm_config() -> None: """Materialise a fake ``app/config/global_llm_config.yaml`` for E2E. 
@@ -239,19 +242,19 @@ def _patch_llm_bindings() -> None: chat_targets = [ ( - "app.agents.new_chat.llm_config.create_chat_litellm_from_agent_config", + "app.agents.chat.runtime.llm_config.create_chat_litellm_from_agent_config", fake_create_chat_litellm_from_agent_config, ), ( - "app.agents.new_chat.llm_config.create_chat_litellm_from_config", + "app.agents.chat.runtime.llm_config.create_chat_litellm_from_config", fake_create_chat_litellm_from_config, ), ( - "app.tasks.chat.stream_new_chat.create_chat_litellm_from_agent_config", + "app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_agent_config", fake_create_chat_litellm_from_agent_config, ), ( - "app.tasks.chat.stream_new_chat.create_chat_litellm_from_config", + "app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_config", fake_create_chat_litellm_from_config, ), ] diff --git a/surfsense_backend/tests/e2e/run_celery.py b/surfsense_backend/tests/e2e/run_celery.py index d0fbb4760..bde547083 100644 --- a/surfsense_backend/tests/e2e/run_celery.py +++ b/surfsense_backend/tests/e2e/run_celery.py @@ -57,8 +57,6 @@ def _load_dotenv_and_set_env_defaults() -> None: """ from dotenv import load_dotenv - load_dotenv() - os.environ.setdefault( "DATABASE_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense", @@ -122,6 +120,11 @@ def _load_dotenv_and_set_env_defaults() -> None: os.environ["SLACK_CLIENT_ID"] = "fake-slack-mcp-client-id" os.environ["SLACK_CLIENT_SECRET"] = "fake-slack-mcp-client-secret" + # Load .env last so the E2E defaults above win over a developer's .env + # (e.g. AUTH_TYPE=GOOGLE), while an explicitly exported shell var still + # beats both: setdefault respects it and load_dotenv() never overrides. + load_dotenv() + def _install_synthetic_global_llm_config() -> None: """Materialise a fake ``app/config/global_llm_config.yaml`` for E2E. 
@@ -212,19 +215,19 @@ def _patch_llm_bindings() -> None: chat_targets = [ ( - "app.agents.new_chat.llm_config.create_chat_litellm_from_agent_config", + "app.agents.chat.runtime.llm_config.create_chat_litellm_from_agent_config", fake_create_chat_litellm_from_agent_config, ), ( - "app.agents.new_chat.llm_config.create_chat_litellm_from_config", + "app.agents.chat.runtime.llm_config.create_chat_litellm_from_config", fake_create_chat_litellm_from_config, ), ( - "app.tasks.chat.stream_new_chat.create_chat_litellm_from_agent_config", + "app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_agent_config", fake_create_chat_litellm_from_agent_config, ), ( - "app.tasks.chat.stream_new_chat.create_chat_litellm_from_config", + "app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_config", fake_create_chat_litellm_from_config, ), ] diff --git a/surfsense_backend/tests/integration/agents/multi_agent_chat/test_agent_turn.py b/surfsense_backend/tests/integration/agents/multi_agent_chat/test_agent_turn.py new file mode 100644 index 000000000..b30744177 --- /dev/null +++ b/surfsense_backend/tests/integration/agents/multi_agent_chat/test_agent_turn.py @@ -0,0 +1,142 @@ +"""Guardrail D: the real multi-agent is still assemblable and runnable. + +Builds the production ``create_multi_agent_chat_deep_agent`` factory against a +real (test) DB with a scripted LLM, then drives one turn. This is the only +guard that proves the *assembled* agent — full tool registry, middleware stack, +compiled graph — still executes end to end after files move. A/B/C prove the +parts import, wire, and load; this proves they run together. + +Scripted LLM + faked external tools; everything we own (graph, middleware, +DB-backed connector service) runs for real. 
+""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage +from langgraph.checkpoint.memory import InMemorySaver + +from app.agents.chat.multi_agent_chat import create_multi_agent_chat_deep_agent +from app.services.connector_service import ConnectorService +from tests.integration.harness import ( + ScriptedTurn, + StubToolSpec, + build_scripted_harness, +) + +pytestmark = pytest.mark.integration + + +def _last_ai_text(messages: list) -> str | None: + for m in reversed(messages): + if isinstance(m, AIMessage): + return m.content if isinstance(m.content, str) else str(m.content) + return None + + +@pytest.mark.asyncio +async def test_agent_runs_a_scripted_text_turn(db_session, db_user, db_search_space): + """A freshly assembled agent streams a scripted final-text turn to completion.""" + harness = build_scripted_harness(turns=[ScriptedTurn(text="done")]) + + agent = await create_multi_agent_chat_deep_agent( + llm=harness.model, + search_space_id=db_search_space.id, + db_session=db_session, + connector_service=ConnectorService(db_session), + checkpointer=InMemorySaver(), + user_id=str(db_user.id), + thread_id=db_search_space.id, + agent_config=None, + ) + + result = await agent.ainvoke( + {"messages": [HumanMessage(content="hello")]}, + config={"configurable": {"thread_id": "guard-d-thread-1"}}, + ) + + assert _last_ai_text(result["messages"]) == "done" + + +@pytest.mark.asyncio +async def test_agent_routes_a_scripted_tool_call(db_session, db_user, db_search_space): + """The compiled graph routes a model tool call to its tool and resumes.""" + harness = build_scripted_harness( + turns=[ + ScriptedTurn( + tool_calls=[{"name": "echo", "args": {"x": 1}, "id": "call_1"}] + ), + ScriptedTurn(text="echoed"), + ], + tools=[ + StubToolSpec( + name="echo", + description="Echo the args back.", + handler=lambda **kwargs: {"echoed": kwargs}, + ), + ], + ) + + agent = await 
create_multi_agent_chat_deep_agent( + llm=harness.model, + search_space_id=db_search_space.id, + db_session=db_session, + connector_service=ConnectorService(db_session), + checkpointer=InMemorySaver(), + user_id=str(db_user.id), + thread_id=db_search_space.id, + agent_config=None, + additional_tools=harness.tools, + ) + + result = await agent.ainvoke( + {"messages": [HumanMessage(content="echo please")]}, + config={"configurable": {"thread_id": "guard-d-thread-2"}}, + ) + + tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)] + assert any("echoed" in str(m.content) for m in tool_messages) + assert _last_ai_text(result["messages"]) == "echoed" + + +@pytest.mark.asyncio +async def test_agent_checkpoint_round_trips_across_turns( + db_session, db_user, db_search_space +): + """Turn 2 sees turn 1's history, proving the checkpoint serializes and reloads. + + Uses InMemorySaver, which serializes via the same ``JsonPlusSerializer`` as + the production Postgres checkpointer — so a state class that became + unserializable after a module move would fail here too. 
+ """ + harness = build_scripted_harness( + turns=[ScriptedTurn(text="ok-one"), ScriptedTurn(text="ok-two")] + ) + checkpointer = InMemorySaver() + config = {"configurable": {"thread_id": "guard-e-thread-1"}} + + async def _build(): + return await create_multi_agent_chat_deep_agent( + llm=harness.model, + search_space_id=db_search_space.id, + db_session=db_session, + connector_service=ConnectorService(db_session), + checkpointer=checkpointer, + user_id=str(db_user.id), + thread_id=db_search_space.id, + agent_config=None, + ) + + agent = await _build() + first = await agent.ainvoke( + {"messages": [HumanMessage(content="remember apple")]}, config + ) + second = await agent.ainvoke( + {"messages": [HumanMessage(content="second turn")]}, config + ) + + texts = [ + m.content for m in second["messages"] if isinstance(m, HumanMessage) + ] + assert "remember apple" in texts, "turn 1 history not reloaded from checkpoint" + assert len(second["messages"]) > len(first["messages"]) diff --git a/surfsense_backend/tests/integration/agents/multi_agent_chat/test_kb_filesystem_cloud.py b/surfsense_backend/tests/integration/agents/multi_agent_chat/test_kb_filesystem_cloud.py new file mode 100644 index 000000000..878473f55 --- /dev/null +++ b/surfsense_backend/tests/integration/agents/multi_agent_chat/test_kb_filesystem_cloud.py @@ -0,0 +1,203 @@ +"""Real-behavior tests for the LIVE knowledge-base filesystem middleware (B) in +cloud mode. + +Cloud mode is the default production filesystem for web chat. Unlike desktop, +cloud writes/edits/moves/deletes are *staged* into LangGraph state during the +turn and committed to Postgres at end-of-turn by the persistence middleware. +These tests drive the production ``build_filesystem_mw`` cloud tools through a +real ``create_agent`` graph and assert the staging contract (namespace policy, +read-from-stage, mkdir staging, duplicate rejection) — all deterministic and +DB-free because cloud ``awrite`` is pure in-state staging. 
+ +The end-of-turn DB commit (``commit_staged_filesystem_state``) is covered +separately; here we lock the per-tool behavior that the reorg could break. +""" + +from __future__ import annotations + +import pytest +from langchain.agents import create_agent +from langchain_core.messages import HumanMessage, ToolMessage +from langgraph.checkpoint.memory import InMemorySaver + +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import ( + FilesystemMode, + FilesystemSelection, +) +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem import ( + build_filesystem_mw, +) +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.resolver import ( + build_backend_resolver, +) +from tests.integration.harness import ScriptedTurn, build_scripted_harness + +pytestmark = [pytest.mark.integration, pytest.mark.asyncio] + +_SEARCH_SPACE_ID = 1 + + +def _build_cloud_fs_mw(): + """Build the production filesystem middleware in cloud mode. + + A non-None ``search_space_id`` makes the resolver hand out a + ``KBPostgresBackend``, exactly as production does. Staging operations never + touch the DB, so a dummy id is sufficient for these tests. 
+ """ + selection = FilesystemSelection(mode=FilesystemMode.CLOUD) + resolver = build_backend_resolver(selection, search_space_id=_SEARCH_SPACE_ID) + return build_filesystem_mw( + backend_resolver=resolver, + filesystem_mode=FilesystemMode.CLOUD, + search_space_id=_SEARCH_SPACE_ID, + user_id="00000000-0000-0000-0000-000000000001", + thread_id=_SEARCH_SPACE_ID, + read_only=False, + ) + + +async def _run(turns: list[ScriptedTurn], thread: str): + harness = build_scripted_harness(turns=turns) + agent = create_agent( + harness.model, + tools=[], + middleware=[_build_cloud_fs_mw()], + checkpointer=InMemorySaver(), + ) + return await agent.ainvoke( + {"messages": [HumanMessage(content="do kb work")]}, + config={"configurable": {"thread_id": thread}}, + ) + + +def _tool_text(result, name: str) -> str: + for m in result["messages"]: + if isinstance(m, ToolMessage) and m.name == name: + return str(m.content) + raise AssertionError(f"no ToolMessage from {name!r}") + + +def _write(path: str, content: str, call_id: str) -> ScriptedTurn: + return ScriptedTurn( + tool_calls=[ + { + "name": "write_file", + "args": {"file_path": path, "content": content}, + "id": call_id, + } + ] + ) + + +async def test_cloud_write_then_read_returns_staged_content(): + """A cloud write stages into state and a later read returns that content.""" + result = await _run( + [ + _write("/documents/note.md", "cloud CANARY-CLD-1", "c1"), + ScriptedTurn( + tool_calls=[ + { + "name": "read_file", + "args": {"file_path": "/documents/note.md"}, + "id": "c2", + } + ] + ), + ScriptedTurn(text="done"), + ], + "fs-cloud-write-read", + ) + + assert "Updated file /documents/note.md" in _tool_text(result, "write_file") + assert "CANARY-CLD-1" in _tool_text(result, "read_file") + + +async def test_cloud_write_outside_documents_is_rejected(): + """Cloud namespace policy: writes must target /documents (non-temp paths).""" + result = await _run( + [ + _write("/scratch/note.md", "nope", "c1"), + 
ScriptedTurn(text="done"), + ], + "fs-cloud-namespace", + ) + + msg = _tool_text(result, "write_file") + assert "must target /documents" in msg + + +async def test_cloud_temp_prefixed_write_is_allowed_anywhere(): + """A ``temp_`` basename escapes the /documents namespace restriction.""" + result = await _run( + [ + _write("/temp_scratch.md", "ephemeral", "c1"), + ScriptedTurn(text="done"), + ], + "fs-cloud-temp", + ) + + msg = _tool_text(result, "write_file") + assert "must target /documents" not in msg + assert "Updated file" in msg + + +async def test_cloud_mkdir_stages_directory(): + """Cloud mkdir stages the directory for end-of-turn creation (no immediate IO).""" + result = await _run( + [ + ScriptedTurn( + tool_calls=[ + { + "name": "mkdir", + "args": {"path": "/documents/projects"}, + "id": "c1", + } + ] + ), + ScriptedTurn(text="done"), + ], + "fs-cloud-mkdir", + ) + + msg = _tool_text(result, "mkdir") + assert "Staged directory" in msg + assert "/documents/projects" in msg + + +async def test_cloud_mkdir_outside_documents_is_rejected(): + """Cloud mkdir is also restricted to the /documents namespace.""" + result = await _run( + [ + ScriptedTurn( + tool_calls=[ + {"name": "mkdir", "args": {"path": "/elsewhere"}, "id": "c1"} + ] + ), + ScriptedTurn(text="done"), + ], + "fs-cloud-mkdir-bad", + ) + + assert "must target a path under /documents" in _tool_text(result, "mkdir") + + +async def test_cloud_duplicate_write_is_rejected(): + """Writing to a path already staged this turn is rejected (use edit instead).""" + result = await _run( + [ + _write("/documents/dup.md", "first", "c1"), + _write("/documents/dup.md", "second", "c2"), + ScriptedTurn(text="done"), + ], + "fs-cloud-dup", + ) + + # Two write ToolMessages: first succeeds, second is rejected. 
+ write_msgs = [ + str(m.content) + for m in result["messages"] + if isinstance(m, ToolMessage) and m.name == "write_file" + ] + assert len(write_msgs) == 2 + assert "Updated file" in write_msgs[0] + assert "already exists" in write_msgs[1] diff --git a/surfsense_backend/tests/integration/agents/multi_agent_chat/test_kb_filesystem_desktop.py b/surfsense_backend/tests/integration/agents/multi_agent_chat/test_kb_filesystem_desktop.py new file mode 100644 index 000000000..4c624d80d --- /dev/null +++ b/surfsense_backend/tests/integration/agents/multi_agent_chat/test_kb_filesystem_desktop.py @@ -0,0 +1,351 @@ +"""Real-behavior tests for the LIVE knowledge-base filesystem middleware (B). + +These exercise ``app.agents.chat.multi_agent_chat.shared.middleware.filesystem`` — +the decomposed middleware + tools that production actually mounts on the +knowledge_base subagent (via ``build_filesystem_mw``). The previous +``tests/unit/middleware/test_filesystem_*.py`` suite asserts a *dead twin* +(``app.agents.chat.shared.middleware.filesystem``) that is never instantiated, so the +live tool path had no real coverage. + +Strategy: mount the production ``build_filesystem_mw`` on a minimal +``create_agent`` graph and drive its tools with the scripted harness. Desktop +mode binds a ``MultiRootLocalFolderBackend`` to a real ``tmp_path`` directory, +so every write/edit/move/rm is asserted against the real on-disk filesystem — +no mocks, only the LLM is scripted. 
+""" + +from __future__ import annotations + +from pathlib import Path + +import pytest +from langchain.agents import create_agent +from langchain_core.messages import HumanMessage, ToolMessage +from langgraph.checkpoint.memory import InMemorySaver + +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import ( + FilesystemMode, + FilesystemSelection, + LocalFilesystemMount, +) +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem import ( + build_filesystem_mw, +) +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.resolver import ( + build_backend_resolver, +) +from tests.integration.harness import ScriptedTurn, build_scripted_harness + +pytestmark = [pytest.mark.integration, pytest.mark.asyncio] + +_MOUNT_ID = "workspace" + + +def _build_desktop_fs_mw(root: Path): + """Build the production filesystem middleware bound to a real local folder.""" + selection = FilesystemSelection( + mode=FilesystemMode.DESKTOP_LOCAL_FOLDER, + local_mounts=( + LocalFilesystemMount(mount_id=_MOUNT_ID, root_path=str(root)), + ), + ) + resolver = build_backend_resolver(selection) + return build_filesystem_mw( + backend_resolver=resolver, + filesystem_mode=FilesystemMode.DESKTOP_LOCAL_FOLDER, + search_space_id=1, + user_id="00000000-0000-0000-0000-000000000001", + thread_id=1, + read_only=False, + ) + + +async def _run(root: Path, turns: list[ScriptedTurn], thread: str): + """Assemble a 1-middleware agent and drive the scripted turns to completion.""" + harness = build_scripted_harness(turns=turns) + fs_mw = _build_desktop_fs_mw(root) + agent = create_agent( + harness.model, + tools=[], + middleware=[fs_mw], + checkpointer=InMemorySaver(), + ) + return await agent.ainvoke( + {"messages": [HumanMessage(content="do filesystem work")]}, + config={"configurable": {"thread_id": thread}}, + ) + + +def _tool_messages(result) -> list[ToolMessage]: + return [m for m in result["messages"] if isinstance(m, ToolMessage)] + + +def 
_tool_text(result, name: str) -> str: + for m in _tool_messages(result): + if m.name == name: + return str(m.content) + raise AssertionError(f"no ToolMessage from {name!r} in {_tool_messages(result)}") + + +async def test_write_then_read_round_trip(tmp_path: Path): + """write_file persists to the real folder and read_file returns the content.""" + result = await _run( + tmp_path, + [ + ScriptedTurn( + tool_calls=[ + { + "name": "write_file", + "args": { + "file_path": f"/{_MOUNT_ID}/notes.md", + "content": "hello FS-CANARY-001", + }, + "id": "c1", + } + ] + ), + ScriptedTurn( + tool_calls=[ + { + "name": "read_file", + "args": {"file_path": f"/{_MOUNT_ID}/notes.md"}, + "id": "c2", + } + ] + ), + ScriptedTurn(text="done"), + ], + "fs-desktop-write-read", + ) + + # Real on-disk effect, not a mock. + assert (tmp_path / "notes.md").read_text() == "hello FS-CANARY-001" + # The tool actually returned the file content. + assert "FS-CANARY-001" in _tool_text(result, "read_file") + + +async def test_write_then_ls_lists_file(tmp_path: Path): + """ls reflects a freshly written file in the real folder.""" + result = await _run( + tmp_path, + [ + ScriptedTurn( + tool_calls=[ + { + "name": "write_file", + "args": { + "file_path": f"/{_MOUNT_ID}/report.md", + "content": "x", + }, + "id": "c1", + } + ] + ), + ScriptedTurn( + tool_calls=[ + {"name": "ls", "args": {"path": f"/{_MOUNT_ID}"}, "id": "c2"} + ] + ), + ScriptedTurn(text="done"), + ], + "fs-desktop-ls", + ) + + assert (tmp_path / "report.md").exists() + assert "report.md" in _tool_text(result, "ls") + + +async def test_edit_file_rewrites_on_disk(tmp_path: Path): + """edit_file applies a real string replacement to the on-disk file.""" + result = await _run( + tmp_path, + [ + ScriptedTurn( + tool_calls=[ + { + "name": "write_file", + "args": { + "file_path": f"/{_MOUNT_ID}/doc.md", + "content": "the quick brown fox", + }, + "id": "c1", + } + ] + ), + ScriptedTurn( + tool_calls=[ + { + "name": "edit_file", + "args": { + 
"file_path": f"/{_MOUNT_ID}/doc.md", + "old_string": "brown", + "new_string": "red", + }, + "id": "c2", + } + ] + ), + ScriptedTurn(text="done"), + ], + "fs-desktop-edit", + ) + + assert (tmp_path / "doc.md").read_text() == "the quick red fox" + + +async def test_write_into_existing_subdir(tmp_path: Path): + """A write into an EXISTING subdirectory lands on disk under that folder.""" + (tmp_path / "sub").mkdir() + result = await _run( + tmp_path, + [ + ScriptedTurn( + tool_calls=[ + { + "name": "write_file", + "args": { + "file_path": f"/{_MOUNT_ID}/sub/inner.md", + "content": "nested", + }, + "id": "c1", + } + ] + ), + ScriptedTurn(text="done"), + ], + "fs-desktop-subdir", + ) + + assert "Error" not in _tool_text(result, "write_file") + assert (tmp_path / "sub" / "inner.md").read_text() == "nested" + + +async def test_write_to_missing_parent_dir_is_rejected(tmp_path: Path): + """Desktop write refuses to create a file under a non-existent directory. + + Real current behavior: the local-folder backend requires the parent to + exist (and ``mkdir`` is a no-op for this backend), so the agent cannot + fabricate new nested folders via ``write_file``. Locking this guards against + a silent behavior change during the agents-module reorg. 
+ """ + result = await _run( + tmp_path, + [ + ScriptedTurn( + tool_calls=[ + { + "name": "write_file", + "args": { + "file_path": f"/{_MOUNT_ID}/missing/inner.md", + "content": "nested", + }, + "id": "c1", + } + ] + ), + ScriptedTurn(text="done"), + ], + "fs-desktop-missing-parent", + ) + + write_msg = _tool_text(result, "write_file") + assert "parent directory" in write_msg.lower() + assert not (tmp_path / "missing").exists() + + +async def test_move_file_relocates_on_disk(tmp_path: Path): + """move_file relocates the real file from source to destination.""" + await _run( + tmp_path, + [ + ScriptedTurn( + tool_calls=[ + { + "name": "write_file", + "args": { + "file_path": f"/{_MOUNT_ID}/src.md", + "content": "movable", + }, + "id": "c1", + } + ] + ), + ScriptedTurn( + tool_calls=[ + { + "name": "move_file", + "args": { + "source_path": f"/{_MOUNT_ID}/src.md", + "destination_path": f"/{_MOUNT_ID}/dst.md", + }, + "id": "c2", + } + ] + ), + ScriptedTurn(text="done"), + ], + "fs-desktop-move", + ) + + assert not (tmp_path / "src.md").exists() + assert (tmp_path / "dst.md").read_text() == "movable" + + +async def test_rm_deletes_file_on_disk(tmp_path: Path): + """rm removes the real file (desktop deletes are immediate).""" + await _run( + tmp_path, + [ + ScriptedTurn( + tool_calls=[ + { + "name": "write_file", + "args": { + "file_path": f"/{_MOUNT_ID}/trash.md", + "content": "bye", + }, + "id": "c1", + } + ] + ), + ScriptedTurn( + tool_calls=[ + { + "name": "rm", + "args": {"path": f"/{_MOUNT_ID}/trash.md"}, + "id": "c2", + } + ] + ), + ScriptedTurn(text="done"), + ], + "fs-desktop-rm", + ) + + assert not (tmp_path / "trash.md").exists() + + +async def test_rmdir_removes_empty_dir_on_disk(tmp_path: Path): + """rmdir removes a real empty directory.""" + (tmp_path / "gone").mkdir() + assert (tmp_path / "gone").is_dir() + + result = await _run( + tmp_path, + [ + ScriptedTurn( + tool_calls=[ + { + "name": "rmdir", + "args": {"path": f"/{_MOUNT_ID}/gone"}, + "id": "c1", + 
} + ] + ), + ScriptedTurn(text="done"), + ], + "fs-desktop-rmdir", + ) + + assert "Error" not in _tool_text(result, "rmdir") + assert not (tmp_path / "gone").exists() diff --git a/surfsense_backend/tests/integration/google_unification/conftest.py b/surfsense_backend/tests/integration/google_unification/conftest.py index de68c7acb..390442fdd 100644 --- a/surfsense_backend/tests/integration/google_unification/conftest.py +++ b/surfsense_backend/tests/integration/google_unification/conftest.py @@ -239,7 +239,7 @@ def patched_shielded_session(async_engine, monkeypatch): yield session monkeypatch.setattr( - "app.agents.new_chat.tools.knowledge_base.shielded_async_session", + "app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.knowledge_base.shielded_async_session", _test_shielded, ) diff --git a/surfsense_backend/tests/integration/google_unification/test_browse_includes_legacy_docs.py b/surfsense_backend/tests/integration/google_unification/test_browse_includes_legacy_docs.py index fc2fec5a8..f0d5c6c6c 100644 --- a/surfsense_backend/tests/integration/google_unification/test_browse_includes_legacy_docs.py +++ b/surfsense_backend/tests/integration/google_unification/test_browse_includes_legacy_docs.py @@ -17,7 +17,9 @@ async def test_browse_recent_documents_with_list_type_returns_both( committed_google_data, patched_shielded_session ): """_browse_recent_documents returns docs of all types when given a list.""" - from app.agents.new_chat.tools.knowledge_base import _browse_recent_documents + from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.knowledge_base import ( + _browse_recent_documents, + ) space_id = committed_google_data["search_space_id"] diff --git a/surfsense_backend/tests/integration/notifications/conftest.py b/surfsense_backend/tests/integration/notifications/conftest.py new file mode 100644 index 000000000..17a44a51d --- /dev/null +++ b/surfsense_backend/tests/integration/notifications/conftest.py @@ -0,0 +1,53 @@ 
+"""Notifications integration fixtures. + +The app's DB session and current-user dependencies are overridden to ride the +test's transactional `db_session`, so API calls and seeded rows share one +transaction that rolls back per test. Overriding `current_active_user` also +bypasses real JWT auth, so these tests don't depend on AUTH_TYPE. +""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator + +import httpx +import pytest +import pytest_asyncio +from httpx import ASGITransport +from sqlalchemy.ext.asyncio import AsyncSession + +from app.app import app, limiter +from app.db import User, get_async_session +from app.users import current_active_user + +pytestmark = pytest.mark.integration + +limiter.enabled = False + + +@pytest_asyncio.fixture +async def client( + db_session: AsyncSession, + db_user: User, +) -> AsyncGenerator[httpx.AsyncClient, None]: + async def override_session() -> AsyncGenerator[AsyncSession, None]: + yield db_session + + async def override_user() -> User: + return db_user + + previous_overrides = app.dependency_overrides.copy() + app.dependency_overrides[get_async_session] = override_session + app.dependency_overrides[current_active_user] = override_user + + try: + async with httpx.AsyncClient( + transport=ASGITransport(app=app), + base_url="http://test", + timeout=30.0, + follow_redirects=False, + ) as test_client: + yield test_client + finally: + app.dependency_overrides.clear() + app.dependency_overrides.update(previous_overrides) diff --git a/surfsense_backend/tests/integration/notifications/test_base_handler.py b/surfsense_backend/tests/integration/notifications/test_base_handler.py new file mode 100644 index 000000000..ef7d9ee6c --- /dev/null +++ b/surfsense_backend/tests/integration/notifications/test_base_handler.py @@ -0,0 +1,164 @@ +"""Behavior guard for the shared find/upsert/update logic (BaseNotificationHandler). 
+ +Uses the connector-indexing handler instance to drive the base methods against +real Postgres, pinning upsert dedup, search-space scoping, and status stamping. +""" + +from __future__ import annotations + +import pytest +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import SearchSpace, User +from app.notifications.persistence import Notification +from app.notifications.service import NotificationService + +pytestmark = pytest.mark.integration + +handler = NotificationService.connector_indexing + + +async def test_find_or_create_creates_with_progress_metadata( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, +): + """Creating a notification seeds operation id, in-progress status, and start time.""" + notification = await handler.find_or_create_notification( + session=db_session, + user_id=db_user.id, + operation_id="op-create", + title="Title", + message="Message", + search_space_id=db_search_space.id, + ) + + assert notification.notification_metadata["operation_id"] == "op-create" + assert notification.notification_metadata["status"] == "in_progress" + assert "started_at" in notification.notification_metadata + + +async def test_find_or_create_upserts_same_operation( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, +): + """Reusing an operation id updates the same row instead of creating a duplicate.""" + first = await handler.find_or_create_notification( + session=db_session, + user_id=db_user.id, + operation_id="op-upsert", + title="First", + message="First message", + search_space_id=db_search_space.id, + ) + + second = await handler.find_or_create_notification( + session=db_session, + user_id=db_user.id, + operation_id="op-upsert", + title="Second", + message="Second message", + search_space_id=db_search_space.id, + ) + + assert second.id == first.id + assert second.title == "Second" + assert second.message == "Second message" + + count = await 
db_session.scalar( + select(func.count(Notification.id)).where( + Notification.user_id == db_user.id, + Notification.notification_metadata["operation_id"].astext == "op-upsert", + ) + ) + assert count == 1 + + +async def test_find_by_operation_is_scoped_to_search_space( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, +): + """Operation-id lookup is scoped per search space, so other spaces don't match.""" + await handler.find_or_create_notification( + session=db_session, + user_id=db_user.id, + operation_id="op-scoped", + title="Title", + message="Message", + search_space_id=db_search_space.id, + ) + + other_space = SearchSpace(name="Other Space", user_id=db_user.id) + db_session.add(other_space) + await db_session.flush() + + found_other = await handler.find_notification_by_operation( + session=db_session, + user_id=db_user.id, + operation_id="op-scoped", + search_space_id=other_space.id, + ) + assert found_other is None + + found_same = await handler.find_notification_by_operation( + session=db_session, + user_id=db_user.id, + operation_id="op-scoped", + search_space_id=db_search_space.id, + ) + assert found_same is not None + + +async def test_update_notification_completed_stamps_completed_at( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, +): + """Completing a notification stamps completed_at and merges metadata updates.""" + notification = await handler.find_or_create_notification( + session=db_session, + user_id=db_user.id, + operation_id="op-complete", + title="Title", + message="Message", + search_space_id=db_search_space.id, + ) + + updated = await handler.update_notification( + session=db_session, + notification=notification, + status="completed", + metadata_updates={"indexed_count": 7}, + ) + + assert updated.notification_metadata["status"] == "completed" + assert "completed_at" in updated.notification_metadata + assert updated.notification_metadata["indexed_count"] == 7 + + +async def 
test_update_notification_failed_stamps_completed_at( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, +): + """Failing a notification also stamps completed_at for the terminal state.""" + notification = await handler.find_or_create_notification( + session=db_session, + user_id=db_user.id, + operation_id="op-fail", + title="Title", + message="Message", + search_space_id=db_search_space.id, + ) + + updated = await handler.update_notification( + session=db_session, + notification=notification, + status="failed", + ) + + assert updated.notification_metadata["status"] == "failed" + assert "completed_at" in updated.notification_metadata diff --git a/surfsense_backend/tests/integration/notifications/test_comment_reply_handler.py b/surfsense_backend/tests/integration/notifications/test_comment_reply_handler.py new file mode 100644 index 000000000..eed5b286f --- /dev/null +++ b/surfsense_backend/tests/integration/notifications/test_comment_reply_handler.py @@ -0,0 +1,62 @@ +"""Behavior guard for the comment-reply notification handler.""" + +from __future__ import annotations + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import SearchSpace, User +from app.notifications.service import NotificationService + +pytestmark = pytest.mark.integration + +handler = NotificationService.comment_reply + + +async def _notify(db_session, db_user, db_search_space, *, reply_id=1, preview="hi"): + """Raise a comment-reply notification for the assertions in the tests below.""" + return await handler.notify_comment_reply( + session=db_session, + user_id=db_user.id, + reply_id=reply_id, + parent_comment_id=10, + message_id=20, + thread_id=30, + thread_title="Thread", + author_id="author-1", + author_name="Bob", + author_avatar_url=None, + author_email="bob@surfsense.net", + content_preview=preview, + search_space_id=db_search_space.id, + ) + + +async def test_comment_reply_title_and_message( + db_session: AsyncSession, db_user: User, 
db_search_space: SearchSpace +): + """A reply notification names the author and carries the comment preview.""" + notification = await _notify(db_session, db_user, db_search_space, preview="thanks") + + assert notification.type == "comment_reply" + assert notification.title == "Bob replied in a thread" + assert notification.message == "thanks" + + +async def test_comment_reply_truncates_long_preview( + db_session: AsyncSession, db_user: User, db_search_space: SearchSpace +): + """A long comment preview is truncated in the reply message.""" + notification = await _notify(db_session, db_user, db_search_space, preview="y" * 150) + + assert notification.message == "y" * 100 + "..." + + +async def test_comment_reply_is_idempotent( + db_session: AsyncSession, db_user: User, db_search_space: SearchSpace +): + """Re-notifying the same reply id reuses the existing notification row.""" + first = await _notify(db_session, db_user, db_search_space, reply_id=5) + second = await _notify(db_session, db_user, db_search_space, reply_id=5) + + assert second.id == first.id diff --git a/surfsense_backend/tests/integration/notifications/test_connector_indexing_handler.py b/surfsense_backend/tests/integration/notifications/test_connector_indexing_handler.py new file mode 100644 index 000000000..a882716b9 --- /dev/null +++ b/surfsense_backend/tests/integration/notifications/test_connector_indexing_handler.py @@ -0,0 +1,235 @@ +"""Behavior guard for the connector-indexing notification handler. + +Exercises the real handler against Postgres via the transactional db_session, +pinning the title/message/status/metadata it produces so the upcoming +functional-core extraction cannot drift. 
+""" + +from __future__ import annotations + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import SearchSpace, User +from app.notifications.service import NotificationService + +pytestmark = pytest.mark.integration + + +async def test_indexing_started_opens_notification( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, +): + """Starting indexing opens an unread notification with connecting-stage metadata.""" + notification = await NotificationService.connector_indexing.notify_indexing_started( + session=db_session, + user_id=db_user.id, + connector_id=42, + connector_name="Notion - My Workspace", + connector_type="NOTION_CONNECTOR", + search_space_id=db_search_space.id, + ) + + assert notification.id is not None + assert notification.type == "connector_indexing" + assert notification.title == "Syncing: Notion - My Workspace" + assert notification.message == "Connecting to your account" + assert notification.read is False + + metadata = notification.notification_metadata + assert metadata["connector_id"] == 42 + assert metadata["connector_type"] == "NOTION_CONNECTOR" + assert metadata["indexed_count"] == 0 + assert metadata["sync_stage"] == "connecting" + assert metadata["status"] == "in_progress" + assert "operation_id" in metadata + assert "started_at" in metadata + + +async def _started( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, + *, + connector_name: str = "Notion - My Workspace", +): + """Open a connector-indexing notification to update in the tests below.""" + return await NotificationService.connector_indexing.notify_indexing_started( + session=db_session, + user_id=db_user.id, + connector_id=42, + connector_name=connector_name, + connector_type="NOTION_CONNECTOR", + search_space_id=db_search_space.id, + ) + + +async def test_indexing_progress_reports_stage_and_percent( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, +): + """Progress 
updates surface the stage message and compute a percent complete.""" + notification = await _started(db_session, db_user, db_search_space) + + updated = await NotificationService.connector_indexing.notify_indexing_progress( + session=db_session, + notification=notification, + indexed_count=5, + total_count=10, + stage="fetching", + ) + + assert updated.message == "Fetching your content" + metadata = updated.notification_metadata + assert metadata["indexed_count"] == 5 + assert metadata["total_count"] == 10 + assert metadata["progress_percent"] == 50 + assert metadata["sync_stage"] == "fetching" + assert metadata["status"] == "in_progress" + + +async def test_indexing_completed_clean_success( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, +): + """A clean multi-file sync reports ready/completed with plural wording.""" + notification = await _started(db_session, db_user, db_search_space) + + done = await NotificationService.connector_indexing.notify_indexing_completed( + session=db_session, + notification=notification, + indexed_count=3, + ) + + assert done.title == "Ready: Notion - My Workspace" + assert done.message == "Now searchable! 3 files synced." + assert done.notification_metadata["status"] == "completed" + assert done.notification_metadata["sync_stage"] == "completed" + + +async def test_indexing_completed_singular_file( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, +): + """A single synced file uses singular 'file' wording.""" + notification = await _started(db_session, db_user, db_search_space) + + done = await NotificationService.connector_indexing.notify_indexing_completed( + session=db_session, + notification=notification, + indexed_count=1, + ) + + assert done.message == "Now searchable! 1 file synced." 
+ + +async def test_indexing_completed_nothing_to_sync( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, +): + """Completing with nothing new reports 'Already up to date!'.""" + notification = await _started(db_session, db_user, db_search_space) + + done = await NotificationService.connector_indexing.notify_indexing_completed( + session=db_session, + notification=notification, + indexed_count=0, + ) + + assert done.title == "Ready: Notion - My Workspace" + assert done.message == "Already up to date!" + assert done.notification_metadata["status"] == "completed" + + +async def test_indexing_completed_hard_failure( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, +): + """An error with nothing synced reports a hard failure.""" + notification = await _started(db_session, db_user, db_search_space) + + done = await NotificationService.connector_indexing.notify_indexing_completed( + session=db_session, + notification=notification, + indexed_count=0, + error_message="boom", + ) + + assert done.title == "Failed: Notion - My Workspace" + assert done.message == "Sync failed: boom" + assert done.notification_metadata["status"] == "failed" + assert done.notification_metadata["sync_stage"] == "failed" + + +async def test_indexing_completed_partial_with_error_note( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, +): + """An error after partial progress still completes, with an appended note.""" + notification = await _started(db_session, db_user, db_search_space) + + done = await NotificationService.connector_indexing.notify_indexing_completed( + session=db_session, + notification=notification, + indexed_count=2, + error_message="partial outage", + ) + + assert done.title == "Ready: Notion - My Workspace" + assert done.message == "Now searchable! 2 files synced. 
Note: partial outage" + assert done.notification_metadata["status"] == "completed" + + +async def test_retry_progress_frames_delay_as_providers( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, +): + """A retry message frames the delay as the provider's, using its short name.""" + notification = await _started(db_session, db_user, db_search_space) + + retry = await NotificationService.connector_indexing.notify_retry_progress( + session=db_session, + notification=notification, + indexed_count=0, + retry_reason="rate_limit", + attempt=1, + max_attempts=3, + ) + + # service_name is derived from the connector name, stripping the workspace suffix. + assert retry.message == "Notion rate limit reached. Retrying..." + assert retry.notification_metadata["sync_stage"] == "waiting_retry" + assert retry.notification_metadata["retry_attempt"] == 1 + assert retry.notification_metadata["retry_reason"] == "rate_limit" + + +async def test_retry_progress_shows_wait_and_synced_count( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, +): + """A retry surfaces the wait time and how many items synced so far.""" + notification = await _started(db_session, db_user, db_search_space) + + retry = await NotificationService.connector_indexing.notify_retry_progress( + session=db_session, + notification=notification, + indexed_count=2, + retry_reason="rate_limit", + attempt=2, + max_attempts=3, + wait_seconds=10, + ) + + assert ( + retry.message + == "Notion rate limit reached. Retrying in 10s... 
(2 items synced so far)" + ) diff --git a/surfsense_backend/tests/integration/notifications/test_document_processing_handler.py b/surfsense_backend/tests/integration/notifications/test_document_processing_handler.py new file mode 100644 index 000000000..f602f2e66 --- /dev/null +++ b/surfsense_backend/tests/integration/notifications/test_document_processing_handler.py @@ -0,0 +1,80 @@ +"""Behavior guard for the document-processing notification handler.""" + +from __future__ import annotations + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import SearchSpace, User +from app.notifications.service import NotificationService + +pytestmark = pytest.mark.integration + +handler = NotificationService.document_processing + + +async def _started(db_session, db_user, db_search_space, *, name="report.pdf"): + """Open a document-processing notification to update in the tests below.""" + return await handler.notify_processing_started( + session=db_session, + user_id=db_user.id, + document_type="FILE", + document_name=name, + search_space_id=db_search_space.id, + ) + + +async def test_processing_started_queues( + db_session: AsyncSession, db_user: User, db_search_space: SearchSpace +): + """Starting processing queues a notification in the 'queued' stage.""" + notification = await _started(db_session, db_user, db_search_space) + + assert notification.type == "document_processing" + assert notification.title == "Processing: report.pdf" + assert notification.message == "Waiting in queue" + assert notification.notification_metadata["processing_stage"] == "queued" + + +async def test_processing_progress_maps_stage( + db_session: AsyncSession, db_user: User, db_search_space: SearchSpace +): + """A progress update maps the stage to its user-facing message.""" + notification = await _started(db_session, db_user, db_search_space) + + updated = await handler.notify_processing_progress( + session=db_session, notification=notification, stage="parsing" + ) + + 
assert updated.message == "Reading your file" + assert updated.notification_metadata["processing_stage"] == "parsing" + + +async def test_processing_completed_success( + db_session: AsyncSession, db_user: User, db_search_space: SearchSpace +): + """Successful processing reports ready/searchable and a completed status.""" + notification = await _started(db_session, db_user, db_search_space) + + done = await handler.notify_processing_completed( + session=db_session, notification=notification, document_id=99 + ) + + assert done.title == "Ready: report.pdf" + assert done.message == "Now searchable!" + assert done.notification_metadata["status"] == "completed" + + +async def test_processing_completed_failure( + db_session: AsyncSession, db_user: User, db_search_space: SearchSpace +): + """Failed processing reports a failed status with the error in the message.""" + notification = await _started(db_session, db_user, db_search_space) + + done = await handler.notify_processing_completed( + session=db_session, notification=notification, error_message="bad file" + ) + + assert done.title == "Failed: report.pdf" + assert done.message == "Processing failed: bad file" + assert done.notification_metadata["status"] == "failed" diff --git a/surfsense_backend/tests/integration/notifications/test_inbox_api.py b/surfsense_backend/tests/integration/notifications/test_inbox_api.py new file mode 100644 index 000000000..461e5c857 --- /dev/null +++ b/surfsense_backend/tests/integration/notifications/test_inbox_api.py @@ -0,0 +1,221 @@ +"""Behavior guard for the notifications inbox HTTP API. + +Rows are seeded through the transactional db_session and read back through the +real endpoints (auth + DB bound to the same transaction), pinning list filters, +counts, mark-read semantics, and response mapping. 
+""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import SearchSpace, User +from app.notifications.persistence import Notification + +pytestmark = pytest.mark.integration + +BASE = "/api/v1/notifications" + + +async def _seed( + db_session: AsyncSession, + user: User, + *, + type: str = "document_processing", + title: str = "Title", + message: str = "Message", + read: bool = False, + search_space_id: int | None = None, + metadata: dict | None = None, + created_at: datetime | None = None, +) -> Notification: + """Insert a notification row directly for the API tests to read back.""" + notification = Notification( + user_id=user.id, + search_space_id=search_space_id, + type=type, + title=title, + message=message, + read=read, + notification_metadata=metadata or {}, + ) + if created_at is not None: + notification.created_at = created_at + db_session.add(notification) + await db_session.flush() + return notification + + +async def test_list_returns_user_notifications_mapped(client, db_session, db_user): + """GET / returns the caller's notifications mapped to the response shape.""" + seeded = await _seed( + db_session, db_user, type="document_processing", title="Doc done" + ) + + resp = await client.get(BASE) + + assert resp.status_code == 200 + body = resp.json() + assert body["total"] == 1 + item = body["items"][0] + assert item["id"] == seeded.id + assert item["user_id"] == str(db_user.id) + assert item["type"] == "document_processing" + assert item["title"] == "Doc done" + assert item["read"] is False + assert item["created_at"] # ISO string present + + +async def test_list_orders_newest_first(client, db_session, db_user): + """The list is ordered by creation time, newest first.""" + now = datetime.now(UTC) + await _seed(db_session, db_user, title="older", created_at=now - timedelta(hours=2)) + await _seed(db_session, db_user, title="newer", 
created_at=now) + + resp = await client.get(BASE) + + titles = [item["title"] for item in resp.json()["items"]] + assert titles == ["newer", "older"] + + +async def test_list_filters_by_category(client, db_session, db_user): + """The category filter narrows results to that category's notification types.""" + await _seed(db_session, db_user, type="connector_indexing", title="status item") + await _seed(db_session, db_user, type="comment_reply", title="comment item") + + resp = await client.get(BASE, params={"category": "comments"}) + + titles = [item["title"] for item in resp.json()["items"]] + assert titles == ["comment item"] + + +async def test_list_filters_unread_only(client, db_session, db_user): + """The unread filter returns only notifications that haven't been read.""" + await _seed(db_session, db_user, title="unread one", read=False) + await _seed(db_session, db_user, title="read one", read=True) + + resp = await client.get(BASE, params={"filter": "unread"}) + + titles = [item["title"] for item in resp.json()["items"]] + assert titles == ["unread one"] + + +async def test_list_filters_by_connector_source_type(client, db_session, db_user): + """A 'connector:' source filter selects only that connector's notifications.""" + await _seed( + db_session, + db_user, + type="connector_indexing", + title="github", + metadata={"connector_type": "GITHUB_CONNECTOR"}, + ) + await _seed( + db_session, + db_user, + type="connector_indexing", + title="notion", + metadata={"connector_type": "NOTION_CONNECTOR"}, + ) + + resp = await client.get(BASE, params={"source_type": "connector:GITHUB_CONNECTOR"}) + + titles = [item["title"] for item in resp.json()["items"]] + assert titles == ["github"] + + +async def test_list_rejects_invalid_before_date(client, db_session, db_user): + """A malformed before_date is rejected with a 400.""" + await _seed(db_session, db_user) + + resp = await client.get(BASE, params={"before_date": "not-a-date"}) + + assert resp.status_code == 400 + + 
+async def test_list_paginates_with_has_more(client, db_session, db_user): + """Pagination caps the page and reports has_more plus the next offset.""" + now = datetime.now(UTC) + for i in range(3): + await _seed( + db_session, db_user, title=f"n{i}", created_at=now - timedelta(minutes=i) + ) + + resp = await client.get(BASE, params={"limit": 2, "offset": 0}) + + body = resp.json() + assert len(body["items"]) == 2 + assert body["has_more"] is True + assert body["next_offset"] == 2 + + +async def test_unread_count_splits_total_and_recent(client, db_session, db_user): + """The unread count reports total unread and a recent-window subset.""" + now = datetime.now(UTC) + await _seed(db_session, db_user, read=False, created_at=now) + await _seed(db_session, db_user, read=False, created_at=now - timedelta(days=30)) + await _seed(db_session, db_user, read=True, created_at=now) + + resp = await client.get(f"{BASE}/unread-count") + + body = resp.json() + assert body["total_unread"] == 2 + assert body["recent_unread"] == 1 + + +async def test_unread_counts_batch_by_category(client, db_session, db_user): + """The batch endpoint breaks unread counts down per category.""" + await _seed(db_session, db_user, type="comment_reply", read=False) + await _seed(db_session, db_user, type="connector_indexing", read=False) + + resp = await client.get(f"{BASE}/unread-counts-batch") + + body = resp.json() + assert body["comments"]["total_unread"] == 1 + assert body["status"]["total_unread"] == 1 + + +async def test_mark_read_then_idempotent(client, db_session, db_user): + """Marking read succeeds, and a repeat call is a no-op reporting already-read.""" + notification = await _seed(db_session, db_user, read=False) + + first = await client.patch(f"{BASE}/{notification.id}/read") + assert first.status_code == 200 + assert first.json()["success"] is True + + second = await client.patch(f"{BASE}/{notification.id}/read") + assert second.status_code == 200 + assert second.json()["message"] == 
"Notification already marked as read" + + +async def test_mark_read_foreign_notification_404(client, db_session, db_user): + """Marking another user's notification read returns 404, not a cross-user write.""" + other = User( + email="other@surfsense.net", + hashed_password="hashed", + is_active=True, + is_superuser=False, + is_verified=True, + ) + db_session.add(other) + await db_session.flush() + foreign = await _seed(db_session, other, read=False) + + resp = await client.patch(f"{BASE}/{foreign.id}/read") + + assert resp.status_code == 404 + + +async def test_mark_all_read_returns_count(client, db_session, db_user): + """Mark-all-read flips only the unread rows and returns how many changed.""" + await _seed(db_session, db_user, read=False) + await _seed(db_session, db_user, read=False) + await _seed(db_session, db_user, read=True) + + resp = await client.patch(f"{BASE}/read-all") + + body = resp.json() + assert body["success"] is True + assert body["updated_count"] == 2 diff --git a/surfsense_backend/tests/integration/notifications/test_mention_handler.py b/surfsense_backend/tests/integration/notifications/test_mention_handler.py new file mode 100644 index 000000000..dc25f7888 --- /dev/null +++ b/surfsense_backend/tests/integration/notifications/test_mention_handler.py @@ -0,0 +1,62 @@ +"""Behavior guard for the @mention notification handler.""" + +from __future__ import annotations + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import SearchSpace, User +from app.notifications.service import NotificationService + +pytestmark = pytest.mark.integration + +handler = NotificationService.mention + + +async def _notify(db_session, db_user, db_search_space, *, mention_id=1, preview="hi"): + """Raise an @mention notification for the assertions in the tests below.""" + return await handler.notify_new_mention( + session=db_session, + mentioned_user_id=db_user.id, + mention_id=mention_id, + comment_id=10, + message_id=20, + thread_id=30, + 
thread_title="Thread", + author_id="author-1", + author_name="Alice", + author_avatar_url=None, + author_email="alice@surfsense.net", + content_preview=preview, + search_space_id=db_search_space.id, + ) + + +async def test_new_mention_title_and_message( + db_session: AsyncSession, db_user: User, db_search_space: SearchSpace +): + """A mention notification names the author and carries the comment preview.""" + notification = await _notify(db_session, db_user, db_search_space, preview="hello") + + assert notification.type == "new_mention" + assert notification.title == "Alice mentioned you" + assert notification.message == "hello" + + +async def test_new_mention_truncates_long_preview( + db_session: AsyncSession, db_user: User, db_search_space: SearchSpace +): + """A long comment preview is truncated in the mention message.""" + notification = await _notify(db_session, db_user, db_search_space, preview="x" * 150) + + assert notification.message == "x" * 100 + "..." + + +async def test_new_mention_is_idempotent( + db_session: AsyncSession, db_user: User, db_search_space: SearchSpace +): + """Re-notifying the same mention id reuses the existing notification row.""" + first = await _notify(db_session, db_user, db_search_space, mention_id=7) + second = await _notify(db_session, db_user, db_search_space, mention_id=7) + + assert second.id == first.id diff --git a/surfsense_backend/tests/integration/notifications/test_page_limit_handler.py b/surfsense_backend/tests/integration/notifications/test_page_limit_handler.py new file mode 100644 index 000000000..ab89d63c9 --- /dev/null +++ b/surfsense_backend/tests/integration/notifications/test_page_limit_handler.py @@ -0,0 +1,61 @@ +"""Behavior guard for the page-limit notification handler.""" + +from __future__ import annotations + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import SearchSpace, User +from app.notifications.service import NotificationService + +pytestmark = 
pytest.mark.integration + +handler = NotificationService.page_limit + + +async def test_page_limit_message_and_action( + db_session: AsyncSession, db_user: User, db_search_space: SearchSpace +): + """A page-limit notification states usage and carries an upgrade action link.""" + notification = await handler.notify_page_limit_exceeded( + session=db_session, + user_id=db_user.id, + document_name="short.pdf", + document_type="FILE", + search_space_id=db_search_space.id, + pages_used=95, + pages_limit=100, + pages_to_add=10, + ) + + assert notification.type == "page_limit_exceeded" + assert notification.title == "Page limit exceeded: short.pdf" + assert notification.message == ( + "This document has ~10 page(s) but you've used 95/100 pages. " + "Upgrade to process more documents." + ) + assert notification.notification_metadata["status"] == "failed" + assert notification.notification_metadata["action_label"] == "Upgrade Plan" + assert notification.notification_metadata["action_url"] == ( + f"/dashboard/{db_search_space.id}/more-pages" + ) + + +async def test_page_limit_truncates_long_name( + db_session: AsyncSession, db_user: User, db_search_space: SearchSpace +): + """A long document name is truncated in the notification title.""" + long_name = "a" * 50 + + notification = await handler.notify_page_limit_exceeded( + session=db_session, + user_id=db_user.id, + document_name=long_name, + document_type="FILE", + search_space_id=db_search_space.id, + pages_used=95, + pages_limit=100, + pages_to_add=10, + ) + + assert notification.title == f"Page limit exceeded: {'a' * 40}..." 
diff --git a/surfsense_backend/tests/integration/retriever/test_knowledge_search_date_filters.py b/surfsense_backend/tests/integration/retriever/test_knowledge_search_date_filters.py index 910d882a7..ce076b147 100644 --- a/surfsense_backend/tests/integration/retriever/test_knowledge_search_date_filters.py +++ b/surfsense_backend/tests/integration/retriever/test_knowledge_search_date_filters.py @@ -8,7 +8,10 @@ from datetime import UTC, datetime, timedelta import numpy as np import pytest -from app.agents.new_chat.middleware.knowledge_search import search_knowledge_base +from app.agents.chat.multi_agent_chat.shared.middleware import knowledge_search as ks +from app.agents.chat.multi_agent_chat.shared.middleware.knowledge_search import ( + search_knowledge_base, +) from .conftest import DUMMY_EMBEDDING @@ -26,13 +29,9 @@ async def test_search_knowledge_base_applies_date_filters( async def fake_shielded_async_session(): yield db_session + monkeypatch.setattr(ks, "shielded_async_session", fake_shielded_async_session) monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.shielded_async_session", - fake_shielded_async_session, - ) - monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.embed_texts", - lambda texts: [np.array(DUMMY_EMBEDDING) for _ in texts], + ks, "embed_texts", lambda texts: [np.array(DUMMY_EMBEDDING) for _ in texts] ) space_id = seed_date_filtered_docs["search_space"].id diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_hitl_bridge.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_hitl_bridge.py index 72408a5d9..45db9c901 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_hitl_bridge.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_hitl_bridge.py @@ -14,16 +14,16 @@ from 
langgraph.graph import END, START, StateGraph from langgraph.types import Command, interrupt from typing_extensions import TypedDict -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import ( - subagent_invoke_config, -) -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.resume_routing import ( collect_pending_tool_calls, slice_decisions_by_tool_call, ) -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.task_tool import ( build_task_tool_with_parent_config, ) +from app.agents.chat.multi_agent_chat.subagents.shared.invocation import ( + subagent_invoke_config, +) class _SubagentState(TypedDict, total=False): diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_heterogeneous_decisions.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_heterogeneous_decisions.py index d4a68939e..dd895c54e 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_heterogeneous_decisions.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_heterogeneous_decisions.py @@ -40,12 +40,12 @@ from langgraph.graph.message import add_messages from langgraph.types import Command, Send, interrupt from typing_extensions import TypedDict -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.resume_routing import ( build_lg_resume_map, 
collect_pending_tool_calls, slice_decisions_by_tool_call, ) -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.task_tool import ( build_task_tool_with_parent_config, ) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_partial_pause_routing.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_partial_pause_routing.py index 1aba0c480..7ac7686e9 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_partial_pause_routing.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_partial_pause_routing.py @@ -47,12 +47,12 @@ from langgraph.graph.message import add_messages from langgraph.types import Command, Send, interrupt from typing_extensions import TypedDict -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.resume_routing import ( build_lg_resume_map, collect_pending_tool_calls, slice_decisions_by_tool_call, ) -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.task_tool import ( build_task_tool_with_parent_config, ) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_reject_only_routing.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_reject_only_routing.py index 5810d5394..a1bbb9e7a 100644 --- 
a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_reject_only_routing.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_reject_only_routing.py @@ -37,12 +37,12 @@ from langgraph.graph.message import add_messages from langgraph.types import Command, Send, interrupt from typing_extensions import TypedDict -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.resume_routing import ( build_lg_resume_map, collect_pending_tool_calls, slice_decisions_by_tool_call, ) -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.task_tool import ( build_task_tool_with_parent_config, ) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_resume_command_keying.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_resume_command_keying.py index 839cb7564..b082119e3 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_resume_command_keying.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_resume_command_keying.py @@ -37,12 +37,12 @@ from langgraph.graph.message import add_messages from langgraph.types import Command, Send, interrupt from typing_extensions import TypedDict -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.resume_routing 
import ( build_lg_resume_map, collect_pending_tool_calls, slice_decisions_by_tool_call, ) -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.task_tool import ( build_task_tool_with_parent_config, ) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_self_and_middleware_gated.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_self_and_middleware_gated.py index 921c4a9eb..2c098ef8a 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_self_and_middleware_gated.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_self_and_middleware_gated.py @@ -35,21 +35,21 @@ from langgraph.graph.message import add_messages from langgraph.types import Command, Send from typing_extensions import TypedDict -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.resume_routing import ( build_lg_resume_map, collect_pending_tool_calls, slice_decisions_by_tool_call, ) -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.task_tool import ( build_task_tool_with_parent_config, ) -from app.agents.multi_agent_chat.middleware.shared.permissions.ask.request import ( +from app.agents.chat.multi_agent_chat.shared.permissions import Rule +from app.agents.chat.multi_agent_chat.shared.permissions.ask.request import ( request_permission_decision, ) -from 
app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) -from app.agents.new_chat.permissions import Rule class _SubState(TypedDict, total=False): diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_tasks.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_tasks.py index 81be7d1ac..836822d34 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_tasks.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_tasks.py @@ -18,7 +18,7 @@ from langgraph.graph import END, START, StateGraph from langgraph.types import Command from typing_extensions import TypedDict -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.task_tool import ( build_task_tool_with_parent_config, ) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_pending_interrupt.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_pending_interrupt.py index 75242689d..ec757bcf0 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_pending_interrupt.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_pending_interrupt.py @@ -9,7 +9,7 @@ from __future__ import annotations from types import SimpleNamespace -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume import ( +from 
app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.resume import ( get_first_pending_subagent_interrupt, ) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_resume_decision_routing.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_resume_decision_routing.py index ceb0df830..62f33addc 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_resume_decision_routing.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_resume_decision_routing.py @@ -17,7 +17,7 @@ from types import SimpleNamespace import pytest -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.resume_routing import ( collect_pending_tool_calls, slice_decisions_by_tool_call, ) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_resume_helpers.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_resume_helpers.py index e8aacfc5d..ba9d163a4 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_resume_helpers.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_resume_helpers.py @@ -4,7 +4,7 @@ from __future__ import annotations from langchain.tools import ToolRuntime -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.config import ( consume_surfsense_resume, has_surfsense_resume, ) diff --git 
a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_subagent_interrupt_stamping.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_subagent_interrupt_stamping.py index 7df9dedc6..4bc0ecace 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_subagent_interrupt_stamping.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_subagent_interrupt_stamping.py @@ -30,7 +30,7 @@ from langgraph.graph import END, START, StateGraph from langgraph.types import Send, interrupt from typing_extensions import TypedDict -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.task_tool import ( build_task_tool_with_parent_config, ) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_subagent_invoke_config.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_subagent_invoke_config.py index 3465dd1d8..5044d8fbe 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_subagent_invoke_config.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_subagent_invoke_config.py @@ -16,7 +16,7 @@ from __future__ import annotations from langchain.tools import ToolRuntime -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import ( +from app.agents.chat.multi_agent_chat.subagents.shared.invocation import ( subagent_invoke_config, ) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_lc_hitl_wire.py 
b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_lc_hitl_wire.py index a331190b2..3f89a9707 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_lc_hitl_wire.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_lc_hitl_wire.py @@ -16,10 +16,10 @@ from langgraph.graph import END, START, StateGraph from langgraph.types import Command from typing_extensions import TypedDict -from app.agents.multi_agent_chat.middleware.shared.permissions.ask.request import ( +from app.agents.chat.multi_agent_chat.shared.permissions import Rule +from app.agents.chat.multi_agent_chat.shared.permissions.ask.request import ( request_permission_decision, ) -from app.agents.new_chat.permissions import Rule class _State(TypedDict, total=False): diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_permission_ask_mcp_context.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_permission_ask_mcp_context.py index c9bd4e142..33256c2ff 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_permission_ask_mcp_context.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_permission_ask_mcp_context.py @@ -13,14 +13,15 @@ from langgraph.graph.message import add_messages from pydantic import BaseModel from typing_extensions import TypedDict -from app.agents.multi_agent_chat.middleware.shared.permissions import ( +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.permissions import ( + Rule, + Ruleset, build_permission_mw, ) -from app.agents.multi_agent_chat.middleware.shared.permissions.ask.payload import ( +from app.agents.chat.multi_agent_chat.shared.permissions.ask.payload import ( build_permission_ask_payload, ) 
-from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.permissions import Rule, Ruleset class _NoArgs(BaseModel): diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_subagent_owned_ruleset.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_subagent_owned_ruleset.py index b9ac6cd15..66dec22b0 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_subagent_owned_ruleset.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_subagent_owned_ruleset.py @@ -23,11 +23,12 @@ from langgraph.graph.message import add_messages from langgraph.types import Command from typing_extensions import TypedDict -from app.agents.multi_agent_chat.middleware.shared.permissions import ( +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.permissions import ( + Rule, + Ruleset, build_permission_mw, ) -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.permissions import Rule, Ruleset def _kb_style_ruleset() -> Ruleset: diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_trusted_tool_save_on_always.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_trusted_tool_save_on_always.py index 47d3704ac..479d607f7 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_trusted_tool_save_on_always.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_trusted_tool_save_on_always.py @@ -14,11 +14,12 @@ from langgraph.types import Command from pydantic import BaseModel from typing_extensions import TypedDict -from app.agents.multi_agent_chat.middleware.shared.permissions import ( +from 
app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.permissions import ( + Rule, + Ruleset, build_permission_mw, ) -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.permissions import Rule, Ruleset class _NoArgs(BaseModel): diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/test_lc_hitl_wire.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/test_lc_hitl_wire.py index 195b1bc01..a33d11358 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/test_lc_hitl_wire.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/test_lc_hitl_wire.py @@ -22,7 +22,7 @@ from langgraph.graph import END, START, StateGraph from langgraph.types import Command from typing_extensions import TypedDict -from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/wire/test_hitl_wire.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/wire/test_hitl_wire.py index c06f9a627..cdaa4d71d 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/wire/test_hitl_wire.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/wire/test_hitl_wire.py @@ -18,7 +18,7 @@ These tests pin the shape: from __future__ import annotations -from app.agents.multi_agent_chat.subagents.shared.hitl.wire import ( +from app.agents.chat.multi_agent_chat.subagents.shared.hitl.wire import ( LC_DECISION_APPROVE, LC_DECISION_EDIT, LC_DECISION_REJECT, diff --git 
a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/test_subagent_builder.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/test_subagent_builder.py index 062ea92ec..2f3553a27 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/test_subagent_builder.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/test_subagent_builder.py @@ -19,14 +19,14 @@ from langchain_core.language_models.fake_chat_models import ( from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.outputs import ChatGeneration, ChatResult -from app.agents.multi_agent_chat.middleware.shared.permissions.middleware.core import ( +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.permissions import Rule, Ruleset, evaluate +from app.agents.chat.multi_agent_chat.shared.permissions.middleware.core import ( PermissionMiddleware, ) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( +from app.agents.chat.multi_agent_chat.subagents.shared.subagent_builder import ( pack_subagent, ) -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.permissions import Rule, Ruleset, evaluate class RateLimitError(Exception): diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/test_prompt_resources.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/test_prompt_resources.py new file mode 100644 index 000000000..ccdfc0b98 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/test_prompt_resources.py @@ -0,0 +1,59 @@ +"""Guardrail C: package-relative prompt/snippet resources must resolve. + +Prompt fragments are loaded by *package name* via ``importlib.resources`` — not +by import, so the import-all smoke test (guardrail A) cannot see them, and not +by mocked unit tests. 
A move that relocates a package without its ``.md`` files, +or that leaves a hardcoded package string stale, returns an empty string and +silently degrades the system prompt. These tests assert the resources still +resolve to non-empty content. + +(Builtin skill resources are covered separately by ``test_skills_backends.py``.) +""" + +from __future__ import annotations + +import pytest + +from app.agents.chat.multi_agent_chat.main_agent.system_prompt.builder.load_md import ( + read_prompt_md, +) +from app.agents.chat.multi_agent_chat.subagents.registry import ( + SUBAGENT_BUILDERS_BY_NAME, + _route_resource_package, +) +from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import ( + read_md_file, + read_shared_snippet, +) + +pytestmark = pytest.mark.unit + + +@pytest.mark.parametrize("name", sorted(SUBAGENT_BUILDERS_BY_NAME)) +def test_every_subagent_has_description_md(name: str): + """Each specialist ships a non-empty ``description.md`` next to its agent.""" + package = _route_resource_package(SUBAGENT_BUILDERS_BY_NAME[name]) + assert read_md_file(package, "description").strip(), ( + f"{name}: description.md missing/empty at package {package}" + ) + + +# Real fragments under the hardcoded main-agent prompts package, including a +# nested path — guards both the package string and nested resource resolution. 
+@pytest.mark.parametrize( + "filename", + [ + "core_behavior.md", + "routing.md", + "tools/web_search/description.md", + ], +) +def test_main_agent_prompt_fragments_resolve(filename: str): + """Main-agent prompt fragments resolve to non-empty content.""" + assert read_prompt_md(filename).strip(), f"prompt fragment {filename} is empty" + + +@pytest.mark.parametrize("snippet", ["output_contract_base", "verifiable_handle"]) +def test_shared_snippets_resolve(snippet: str): + """Shared subagent snippets resolve from the snippets package.""" + assert read_shared_snippet(snippet).strip(), f"snippet {snippet} is empty" diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/test_subagent_composition.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/test_subagent_composition.py new file mode 100644 index 000000000..157f1703b --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/test_subagent_composition.py @@ -0,0 +1,72 @@ +"""Guardrail B: the subagent registry composition must stay intact. + +A structural move can silently drop, rename, or mis-wire a subagent builder +(e.g. a forgotten import line). The compiled agent would then quietly lose a +specialist with no ImportError. This test pins the exact registry contents and +their cross-references so any such drift fails loudly. +""" + +from __future__ import annotations + +import pytest + +from app.agents.chat.multi_agent_chat.constants import ( + SUBAGENT_TO_REQUIRED_CONNECTOR_MAP, +) +from app.agents.chat.multi_agent_chat.subagents.registry import ( + SUBAGENT_BUILDERS_BY_NAME, +) + +pytestmark = pytest.mark.unit + +# The full specialist roster the main agent composes from: 4 builtins + 15 +# connector routes. Adding/removing a specialist is a deliberate product change +# and must be reflected here. 
+_EXPECTED_SUBAGENTS = frozenset( + { + "airtable", + "calendar", + "clickup", + "confluence", + "deliverables", + "discord", + "dropbox", + "gmail", + "google_drive", + "jira", + "knowledge_base", + "linear", + "luma", + "memory", + "notion", + "onedrive", + "research", + "slack", + "teams", + } +) + +# Specialists that are always available regardless of connected sources, so they +# carry no required-connector entry. +_CONNECTORLESS = frozenset({"memory", "research"}) + + +def test_registry_contains_exactly_expected_subagents(): + """No specialist is silently added, dropped, or renamed by a move.""" + assert set(SUBAGENT_BUILDERS_BY_NAME) == _EXPECTED_SUBAGENTS + + +def test_every_builder_is_callable_route_agent(): + """Each registry value is a callable defined in its route's ``agent`` module.""" + for name, builder in SUBAGENT_BUILDERS_BY_NAME.items(): + assert callable(builder), f"{name} builder is not callable" + assert builder.__module__.endswith(".agent"), ( + f"{name} builder lives in {builder.__module__}, expected a *.agent module" + ) + + +def test_required_connector_map_covers_connector_subagents(): + """The connector-gating map stays in lockstep with the registry.""" + assert set(SUBAGENT_TO_REQUIRED_CONNECTOR_MAP) == ( + _EXPECTED_SUBAGENTS - _CONNECTORLESS + ) diff --git a/surfsense_backend/tests/unit/agents/new_chat/middleware/test_scoped_model_fallback.py b/surfsense_backend/tests/unit/agents/new_chat/middleware/test_scoped_model_fallback.py index 80b9862e7..361a23f41 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/middleware/test_scoped_model_fallback.py +++ b/surfsense_backend/tests/unit/agents/new_chat/middleware/test_scoped_model_fallback.py @@ -87,7 +87,7 @@ class RateLimitError(Exception): def _build_agent(primary: BaseChatModel, fallback: BaseChatModel): from langchain.agents import create_agent - from app.agents.new_chat.middleware.scoped_model_fallback import ( + from 
app.agents.chat.multi_agent_chat.shared.middleware.resilience.scoped_model_fallback import ( ScopedModelFallbackMiddleware, ) diff --git a/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py index 36fe04aa2..4f0369e12 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py +++ b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py @@ -6,12 +6,12 @@ from datetime import UTC, datetime import pytest -from app.agents.new_chat.prompts.composer import ( +from app.db import ChatVisibility +from app.prompts.system_prompt_composer.composer import ( ALL_TOOL_NAMES_ORDERED, compose_system_prompt, detect_provider_variant, ) -from app.db import ChatVisibility pytestmark = pytest.mark.unit @@ -64,7 +64,7 @@ class TestProviderVariantDetection: ``gpt-5`` reasoning regex first. Codex is the more specialised prompt and mirrors OpenCode's dispatch order. """ - from app.agents.new_chat.prompts.composer import detect_provider_variant + from app.prompts.system_prompt_composer.composer import detect_provider_variant assert detect_provider_variant("openai:gpt-5-codex") == "openai_codex" assert detect_provider_variant("openai:gpt-5") == "openai_reasoning" diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py b/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py index 8ef1430a9..e476538bd 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py @@ -10,9 +10,11 @@ import pytest from langchain_core.messages import ToolMessage from langchain_core.tools import tool -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.middleware.action_log import ActionLogMiddleware -from app.agents.new_chat.tools.registry import ToolDefinition +from 
app.agents.chat.multi_agent_chat.main_agent.middleware.action_log.middleware import ( + ActionLogMiddleware, + ToolDefinition, +) +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags @dataclass @@ -58,7 +60,7 @@ def _disabled_flags() -> AgentFeatureFlags: def patch_get_flags(): def _patch(flags: AgentFeatureFlags): return patch( - "app.agents.new_chat.middleware.action_log.get_flags", + "app.agents.chat.multi_agent_chat.main_agent.middleware.action_log.middleware.get_flags", return_value=flags, ) @@ -360,7 +362,7 @@ class TestActionLogDispatch: patch_get_flags(_enabled_flags()), patch("app.db.shielded_async_session", side_effect=lambda: factory()), patch( - "app.agents.new_chat.middleware.action_log.adispatch_custom_event", + "app.agents.chat.multi_agent_chat.main_agent.middleware.action_log.middleware.adispatch_custom_event", dispatch_mock, ), ): @@ -395,7 +397,7 @@ class TestActionLogDispatch: patch_get_flags(_enabled_flags()), patch("app.db.shielded_async_session", side_effect=_exploding_session), patch( - "app.agents.new_chat.middleware.action_log.adispatch_custom_event", + "app.agents.chat.multi_agent_chat.main_agent.middleware.action_log.middleware.adispatch_custom_event", dispatch_mock, ), ): diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py b/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py index 9b3de2db7..ecc5a1a83 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py @@ -16,7 +16,7 @@ from dataclasses import dataclass import pytest -from app.agents.new_chat.agent_cache import ( +from app.agents.chat.multi_agent_chat.main_agent.runtime.agent_cache_store import ( flags_signature, reload_for_tests, stable_hash, diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py index f0161f605..5a39c6e66 100644 --- 
a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py @@ -4,8 +4,7 @@ from __future__ import annotations import pytest -from app.agents.new_chat.errors import BusyError -from app.agents.new_chat.middleware.busy_mutex import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import ( BusyMutexMiddleware, end_turn, get_cancel_event, @@ -14,6 +13,7 @@ from app.agents.new_chat.middleware.busy_mutex import ( request_cancel, reset_cancel, ) +from app.agents.chat.runtime.errors import BusyError pytestmark = pytest.mark.unit diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_compaction.py b/surfsense_backend/tests/unit/agents/new_chat/test_compaction.py index c6d4cc452..2ac462959 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_compaction.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_compaction.py @@ -10,7 +10,7 @@ from langchain_core.messages import ( ToolMessage, ) -from app.agents.new_chat.middleware.compaction import ( +from app.agents.chat.shared.middleware.compaction import ( PROTECTED_SYSTEM_PREFIXES, _is_protected_system_message, _sanitize_message_content, @@ -72,7 +72,7 @@ class TestPartitionMessages: # SurfSenseCompactionMiddleware without a real model, but the # override path needs ``_lc_helper`` to delegate to. We mock # that with a simple slicing partitioner equivalent to the real one. 
- from app.agents.new_chat.middleware.compaction import ( + from app.agents.chat.shared.middleware.compaction import ( SurfSenseCompactionMiddleware, ) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_context_editing.py b/surfsense_backend/tests/unit/agents/new_chat/test_context_editing.py index ba2246413..9632fd14d 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_context_editing.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_context_editing.py @@ -7,7 +7,7 @@ from typing import Any import pytest from langchain_core.messages import AIMessage, HumanMessage, ToolMessage -from app.agents.new_chat.middleware.context_editing import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.context_editing.middleware import ( SpillToBackendEdit, _build_spill_placeholder, ) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py b/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py index 61d9b499f..61a04c1c1 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py @@ -6,7 +6,7 @@ import pytest from langchain_core.messages import AIMessage from langchain_core.tools import StructuredTool -from app.agents.new_chat.middleware.dedup_tool_calls import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.dedup_hitl import ( DedupHITLToolCallsMiddleware, ) @@ -91,10 +91,9 @@ def test_no_agent_tools_means_no_dedup() -> None: """After the cleanup tier removed the legacy ``_NATIVE_HITL_TOOL_DEDUP_KEYS`` map, dedup is purely declarative — no resolvers means no dedup runs. - Coverage for the previously hardcoded native HITL tools now lives on - each :class:`ToolDefinition.dedup_key` in - :mod:`app.agents.new_chat.tools.registry`, which is wired through to - ``tool.metadata`` by :func:`build_tools`. 
+ Dedup is purely declarative: tools opt in by carrying a ``dedup_key`` + (callable) or ``hitl_dedup_key`` (arg name) in their ``metadata``. With no + agent tools, there are no resolvers and dedup is a no-op. """ mw = DedupHITLToolCallsMiddleware(agent_tools=None) state = { @@ -109,27 +108,6 @@ def test_no_agent_tools_means_no_dedup() -> None: assert out is None -def test_registry_propagates_dedup_key_to_tool_metadata() -> None: - """Smoke-check the wiring path that replaced the legacy native map. - - ``ToolDefinition.dedup_key`` set in the registry must be copied onto - the constructed tool's ``metadata`` so :class:`DedupHITLToolCallsMiddleware` - can pick it up at agent build time. - """ - from app.agents.new_chat.tools.registry import ( - BUILTIN_TOOLS, - wrap_dedup_key_by_arg_name, - ) - - notion_tool_defs = [t for t in BUILTIN_TOOLS if t.name == "create_notion_page"] - assert notion_tool_defs, "registry should still expose create_notion_page" - tool_def = notion_tool_defs[0] - assert tool_def.dedup_key is not None - # Same wrapping helper used in the registry — sanity check identity - sample = wrap_dedup_key_by_arg_name("title")({"title": "Plan"}) - assert sample == "plan" - - def test_full_args_dedup_keeps_distinct_calls_sharing_a_field() -> None: """Regression: MCP tools (e.g. ``createJiraIssue``) used to dedup on the schema's first required field, which is often the workspace / @@ -137,7 +115,9 @@ def test_full_args_dedup_keeps_distinct_calls_sharing_a_field() -> None: With :func:`dedup_key_full_args` only fully identical arg dicts dedup. 
""" - from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args + from app.agents.chat.multi_agent_chat.shared.middleware.dedup_tool_calls import ( + dedup_key_full_args, + ) tool = _make_tool("createJiraIssue", dedup_key=dedup_key_full_args) mw = DedupHITLToolCallsMiddleware(agent_tools=[tool]) @@ -179,7 +159,9 @@ def test_full_args_dedup_keeps_distinct_calls_sharing_a_field() -> None: def test_full_args_dedup_drops_only_exact_duplicates() -> None: - from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args + from app.agents.chat.multi_agent_chat.shared.middleware.dedup_tool_calls import ( + dedup_key_full_args, + ) tool = _make_tool("createJiraIssue", dedup_key=dedup_key_full_args) mw = DedupHITLToolCallsMiddleware(agent_tools=[tool]) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py b/surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py index 2f222e148..b6341bfec 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py @@ -17,7 +17,7 @@ caused two production-painful behaviors: read-only tool calls, raising ``RejectedError("ls")``. * Mutating connector tools got *double* prompted — once via the middleware ``ask`` and again via the per-tool ``interrupt()`` in - ``app.agents.new_chat.tools.hitl``. + ``app.agents.chat.multi_agent_chat.shared.tools.hitl``. These tests pin the layering so a refactor that drops the default ruleset fails loud. 
@@ -27,7 +27,7 @@ from __future__ import annotations import pytest -from app.agents.new_chat.permissions import ( +from app.agents.chat.multi_agent_chat.shared.permissions import ( Rule, Ruleset, aggregate_action, diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py b/surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py index 653175eab..62712e797 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py @@ -10,8 +10,7 @@ from __future__ import annotations import pytest -from app.agents.new_chat.middleware.permission import PermissionMiddleware -from app.agents.new_chat.permissions import ( +from app.agents.chat.multi_agent_chat.shared.permissions import ( Rule, Ruleset, aggregate_action, @@ -87,36 +86,3 @@ class TestDesktopSafetyOverridesAllowDefault: # Correct order: defaults < desktop_safety -> ask wins. action = _action_for("rm", SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET) assert action == "ask" - - -class TestPermissionMiddlewareIntegration: - def test_middleware_raises_interrupt_for_rm_in_desktop_mode(self) -> None: - from langchain_core.messages import AIMessage - - from app.agents.new_chat.errors import RejectedError - - mw = PermissionMiddleware(rulesets=[SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET]) - # Stub the interrupt to a "reject" decision so we can assert the - # ask path was taken without spinning up the LangGraph runtime. 
- mw._raise_interrupt = lambda **kw: {"decision_type": "reject"} # type: ignore[assignment] - - state = { - "messages": [ - AIMessage( - content="", - tool_calls=[ - { - "name": "rm", - "args": {"path": "/Users/me/Documents/important.docx"}, - "id": "tc-rm", - } - ], - ) - ] - } - - class _FakeRuntime: - config: dict = {"configurable": {"thread_id": "test"}} - - with pytest.raises(RejectedError): - mw.after_model(state, _FakeRuntime()) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_doom_loop.py b/surfsense_backend/tests/unit/agents/new_chat/test_doom_loop.py index 802129bf6..47e962242 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_doom_loop.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_doom_loop.py @@ -5,7 +5,10 @@ from __future__ import annotations import pytest from langchain_core.messages import AIMessage -from app.agents.new_chat.middleware.doom_loop import DoomLoopMiddleware, _signature +from app.agents.chat.multi_agent_chat.main_agent.middleware.doom_loop.middleware import ( + DoomLoopMiddleware, + _signature, +) pytestmark = pytest.mark.unit diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py index 099aea882..e715a80c6 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py @@ -4,7 +4,7 @@ from __future__ import annotations import pytest -from app.agents.new_chat.feature_flags import ( +from app.agents.chat.multi_agent_chat.shared.feature_flags import ( AgentFeatureFlags, reload_for_tests, ) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py b/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py deleted file mode 100644 index 6c323d920..000000000 --- a/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py +++ /dev/null @@ -1,344 +0,0 @@ -"""Tests for 
``FlattenSystemMessageMiddleware``. - -The middleware exists to defend against Anthropic's "Found 5 cache_control -blocks" 400 when our deepagent middleware stack stacks 5+ text blocks on -the system message and the OpenRouter→Anthropic adapter redistributes -``cache_control`` across all of them. The flattening collapses every -all-text system content list to a single string before the LLM call. -""" - -from __future__ import annotations - -from typing import Any -from unittest.mock import MagicMock - -import pytest -from langchain_core.messages import HumanMessage, SystemMessage - -from app.agents.new_chat.middleware.flatten_system import ( - FlattenSystemMessageMiddleware, - _flatten_text_blocks, - _flattened_request, -) - -pytestmark = pytest.mark.unit - - -# --------------------------------------------------------------------------- -# _flatten_text_blocks — pure helper, the heart of the middleware. -# --------------------------------------------------------------------------- - - -class TestFlattenTextBlocks: - def test_joins_text_blocks_with_double_newline(self) -> None: - blocks = [ - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - ] - assert ( - _flatten_text_blocks(blocks) - == "\n\n\n\n" - ) - - def test_handles_single_text_block(self) -> None: - blocks = [{"type": "text", "text": "only one"}] - assert _flatten_text_blocks(blocks) == "only one" - - def test_handles_empty_list(self) -> None: - assert _flatten_text_blocks([]) == "" - - def test_passes_through_bare_string_blocks(self) -> None: - # LangChain content can mix bare strings and dict blocks. - blocks = ["raw string", {"type": "text", "text": "dict block"}] - assert _flatten_text_blocks(blocks) == "raw string\n\ndict block" - - def test_returns_none_for_image_block(self) -> None: - # System messages with images are rare — but we never want to - # silently lose the image payload by joining as text. 
- blocks = [ - {"type": "text", "text": "look at this"}, - {"type": "image_url", "image_url": {"url": "data:image/png..."}}, - ] - assert _flatten_text_blocks(blocks) is None - - def test_returns_none_for_non_dict_non_str_block(self) -> None: - blocks = [{"type": "text", "text": "hi"}, 42] # type: ignore[list-item] - assert _flatten_text_blocks(blocks) is None - - def test_returns_none_when_text_field_missing(self) -> None: - blocks = [{"type": "text"}] # no ``text`` key - assert _flatten_text_blocks(blocks) is None - - def test_returns_none_when_text_is_not_string(self) -> None: - blocks = [{"type": "text", "text": ["nested", "list"]}] - assert _flatten_text_blocks(blocks) is None - - def test_drops_cache_control_from_inner_blocks(self) -> None: - # The whole point: existing cache_control on inner blocks is - # discarded so LiteLLM's ``cache_control_injection_points`` can - # re-attach exactly one breakpoint after flattening. - blocks = [ - {"type": "text", "text": "first"}, - { - "type": "text", - "text": "second", - "cache_control": {"type": "ephemeral"}, - }, - ] - flattened = _flatten_text_blocks(blocks) - assert flattened == "first\n\nsecond" - assert "cache_control" not in flattened # type: ignore[operator] - - -# --------------------------------------------------------------------------- -# _flattened_request — decides when to override and when to no-op. -# --------------------------------------------------------------------------- - - -def _make_request(system_message: SystemMessage | None) -> Any: - """Build a minimal ModelRequest stub. We only need .system_message - and .override(system_message=...) — the middleware never touches - other fields. 
- """ - request = MagicMock() - request.system_message = system_message - - def override(**kwargs: Any) -> Any: - new_request = MagicMock() - new_request.system_message = kwargs.get( - "system_message", request.system_message - ) - new_request.messages = kwargs.get("messages", getattr(request, "messages", [])) - new_request.tools = kwargs.get("tools", getattr(request, "tools", [])) - return new_request - - request.override = override - return request - - -class TestFlattenedRequest: - def test_collapses_multi_block_system_to_string(self) -> None: - sys = SystemMessage( - content=[ - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - ] - ) - request = _make_request(sys) - flattened = _flattened_request(request) - - assert flattened is not None - assert isinstance(flattened.system_message, SystemMessage) - assert flattened.system_message.content == ( - "\n\n\n\n\n\n\n\n" - ) - - def test_no_op_for_string_content(self) -> None: - sys = SystemMessage(content="already a string") - request = _make_request(sys) - assert _flattened_request(request) is None - - def test_no_op_for_single_block_list(self) -> None: - # One block already produces one breakpoint — no need to flatten. 
- sys = SystemMessage(content=[{"type": "text", "text": "single"}]) - request = _make_request(sys) - assert _flattened_request(request) is None - - def test_no_op_when_system_message_missing(self) -> None: - request = _make_request(None) - assert _flattened_request(request) is None - - def test_no_op_when_list_contains_non_text_block(self) -> None: - sys = SystemMessage( - content=[ - {"type": "text", "text": "look"}, - {"type": "image_url", "image_url": {"url": "data:..."}}, - ] - ) - request = _make_request(sys) - assert _flattened_request(request) is None - - def test_preserves_additional_kwargs_and_metadata(self) -> None: - # Defensive: nothing in the current chain sets these on a system - # message, but losing them silently when something does in the - # future would be a regression. ``name`` in particular is the only - # ``additional_kwargs`` field that ChatLiteLLM's - # ``_convert_message_to_dict`` propagates onto the wire. - sys = SystemMessage( - content=[ - {"type": "text", "text": "a"}, - {"type": "text", "text": "b"}, - ], - additional_kwargs={"name": "surfsense_system", "x": 1}, - response_metadata={"tokens": 42}, - ) - sys.id = "sys-msg-1" - request = _make_request(sys) - - flattened = _flattened_request(request) - assert flattened is not None - assert flattened.system_message.content == "a\n\nb" - assert flattened.system_message.additional_kwargs == { - "name": "surfsense_system", - "x": 1, - } - assert flattened.system_message.response_metadata == {"tokens": 42} - assert flattened.system_message.id == "sys-msg-1" - - def test_idempotent_when_run_twice(self) -> None: - sys = SystemMessage( - content=[ - {"type": "text", "text": "a"}, - {"type": "text", "text": "b"}, - ] - ) - request = _make_request(sys) - first = _flattened_request(request) - assert first is not None - - # Second pass on the already-flattened request should be a no-op. - # We re-wrap in a request stub since the helper inspects - # ``request.system_message.content``. 
- second_request = _make_request(first.system_message) - assert _flattened_request(second_request) is None - - -# --------------------------------------------------------------------------- -# Middleware integration — verify the handler sees a flattened request. -# --------------------------------------------------------------------------- - - -class TestMiddlewareWrap: - @pytest.mark.asyncio - async def test_async_passes_flattened_request_to_handler(self) -> None: - sys = SystemMessage( - content=[ - {"type": "text", "text": "alpha"}, - {"type": "text", "text": "beta"}, - ] - ) - request = _make_request(sys) - captured: dict[str, Any] = {} - - async def handler(req: Any) -> str: - captured["request"] = req - return "ok" - - mw = FlattenSystemMessageMiddleware() - result = await mw.awrap_model_call(request, handler) - - assert result == "ok" - assert isinstance(captured["request"].system_message, SystemMessage) - assert captured["request"].system_message.content == "alpha\n\nbeta" - - @pytest.mark.asyncio - async def test_async_passes_through_when_already_string(self) -> None: - sys = SystemMessage(content="just a string") - request = _make_request(sys) - captured: dict[str, Any] = {} - - async def handler(req: Any) -> str: - captured["request"] = req - return "ok" - - mw = FlattenSystemMessageMiddleware() - await mw.awrap_model_call(request, handler) - - # Same request object: no override happened. 
- assert captured["request"] is request - - def test_sync_passes_flattened_request_to_handler(self) -> None: - sys = SystemMessage( - content=[ - {"type": "text", "text": "alpha"}, - {"type": "text", "text": "beta"}, - ] - ) - request = _make_request(sys) - captured: dict[str, Any] = {} - - def handler(req: Any) -> str: - captured["request"] = req - return "ok" - - mw = FlattenSystemMessageMiddleware() - result = mw.wrap_model_call(request, handler) - - assert result == "ok" - assert captured["request"].system_message.content == "alpha\n\nbeta" - - def test_sync_passes_through_when_no_system_message(self) -> None: - request = _make_request(None) - captured: dict[str, Any] = {} - - def handler(req: Any) -> str: - captured["request"] = req - return "ok" - - mw = FlattenSystemMessageMiddleware() - mw.wrap_model_call(request, handler) - assert captured["request"] is request - - -# --------------------------------------------------------------------------- -# Regression guard — pin the worst-case shape that triggered the -# "Found 5" 400 in production. Confirms we collapse 5 blocks to 1 so the -# downstream cache_control_injection_points can only place 1 breakpoint -# on the system message regardless of provider redistribution quirks. -# --------------------------------------------------------------------------- - - -def test_regression_five_block_system_collapses_to_one_block() -> None: - sys = SystemMessage( - content=[ - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - ] - ) - request = _make_request(sys) - flattened = _flattened_request(request) - - assert flattened is not None - assert isinstance(flattened.system_message.content, str) - # The exact join doesn't matter for the cache_control accounting — - # only that there is exactly ONE content block when LiteLLM's - # AnthropicCacheControlHook later targets ``role: system``. 
- assert " None: - # Sanity: the middleware MUST NOT touch user messages — only the - # system message. Multi-block user content is the path that carries - # image attachments and would lose its image_url block on - # accidental flatten. - sys = SystemMessage( - content=[ - {"type": "text", "text": "a"}, - {"type": "text", "text": "b"}, - ] - ) - user = HumanMessage( - content=[ - {"type": "text", "text": "look at this"}, - {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}, - ] - ) - request = _make_request(sys) - request.messages = [user] - - flattened = _flattened_request(request) - assert flattened is not None - # System flattened to string … - assert isinstance(flattened.system_message.content, str) - # … user message is untouched (the helper does not even look at it). - assert flattened.messages == [user] - assert isinstance(user.content, list) - assert len(user.content) == 2 diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py b/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py index d0ea73376..9c19cbd6b 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py @@ -10,7 +10,7 @@ from __future__ import annotations import pytest -from app.agents.new_chat.tools.hitl import ( +from app.agents.chat.multi_agent_chat.shared.tools.hitl import ( DEFAULT_AUTO_APPROVED_TOOLS, HITLResult, request_approval, diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_mention_resolver.py b/surfsense_backend/tests/unit/agents/new_chat/test_mention_resolver.py index 1f8d35841..4130c9d4e 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_mention_resolver.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_mention_resolver.py @@ -15,14 +15,17 @@ from unittest.mock import AsyncMock, MagicMock import pytest -from app.agents.new_chat import mention_resolver -from 
app.agents.new_chat.mention_resolver import ( +from app.agents.chat.runtime import mention_resolver +from app.agents.chat.runtime.mention_resolver import ( ResolvedMention, ResolvedMentionSet, resolve_mentions, substitute_in_text, ) -from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT, PathIndex +from app.agents.chat.runtime.path_resolver import ( + DOCUMENTS_ROOT, + PathIndex, +) from app.schemas.new_chat import MentionedDocumentInfo pytestmark = pytest.mark.unit diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_noop_injection.py b/surfsense_backend/tests/unit/agents/new_chat/test_noop_injection.py index 346271f4b..42df4eecf 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_noop_injection.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_noop_injection.py @@ -5,7 +5,7 @@ from __future__ import annotations import pytest from langchain_core.messages import AIMessage, HumanMessage -from app.agents.new_chat.middleware.noop_injection import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.noop_injection.middleware import ( NOOP_TOOL_NAME, NoopInjectionMiddleware, _last_ai_has_tool_calls, diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py b/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py index dc59c6dac..e2978d277 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py @@ -8,7 +8,7 @@ from unittest.mock import MagicMock import pytest from langchain_core.messages import AIMessage, ToolMessage -from app.agents.new_chat.middleware.otel_span import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.otel_span.middleware import ( OtelSpanMiddleware, _annotate_model_response, _annotate_tool_result, @@ -206,13 +206,13 @@ class TestMiddlewareIntegration: duration_calls: list[dict[str, Any]] = [] token_calls: list[dict[str, Any]] = [] monkeypatch.setattr( - 
"app.agents.new_chat.middleware.otel_span.ot_metrics.record_model_call_duration", + "app.agents.chat.multi_agent_chat.main_agent.middleware.otel_span.middleware.ot_metrics.record_model_call_duration", lambda duration_ms, **attrs: duration_calls.append( {"duration_ms": duration_ms, **attrs} ), ) monkeypatch.setattr( - "app.agents.new_chat.middleware.otel_span.ot_metrics.record_model_token_usage", + "app.agents.chat.multi_agent_chat.main_agent.middleware.otel_span.middleware.ot_metrics.record_model_token_usage", lambda **attrs: token_calls.append(attrs), ) @@ -257,11 +257,11 @@ class TestMiddlewareIntegration: errors: list[str] = [] monkeypatch.setattr( - "app.agents.new_chat.middleware.otel_span.ot_metrics.record_tool_call_error", + "app.agents.chat.multi_agent_chat.main_agent.middleware.otel_span.middleware.ot_metrics.record_tool_call_error", lambda *, tool_name: errors.append(tool_name), ) monkeypatch.setattr( - "app.agents.new_chat.middleware.otel_span.ot_metrics.record_tool_call_duration", + "app.agents.chat.multi_agent_chat.main_agent.middleware.otel_span.middleware.ot_metrics.record_tool_call_duration", lambda *args, **kwargs: None, ) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_path_resolver.py b/surfsense_backend/tests/unit/agents/new_chat/test_path_resolver.py index ac6f61767..2617bff8e 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_path_resolver.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_path_resolver.py @@ -7,7 +7,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest -from app.agents.new_chat.path_resolver import ( +from app.agents.chat.runtime.path_resolver import ( DOCUMENTS_ROOT, PathIndex, doc_to_virtual_path, diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py b/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py deleted file mode 100644 index 68db11ba6..000000000 --- 
a/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py +++ /dev/null @@ -1,263 +0,0 @@ -"""Tests for PermissionMiddleware end-to-end behavior.""" - -from __future__ import annotations - -import pytest -from langchain_core.messages import AIMessage, ToolMessage - -from app.agents.new_chat.errors import CorrectedError, RejectedError -from app.agents.new_chat.middleware.permission import ( - PermissionMiddleware, - _normalize_permission_decision, -) -from app.agents.new_chat.permissions import Rule, Ruleset - -pytestmark = pytest.mark.unit - - -class _FakeRuntime: - config: dict = {"configurable": {"thread_id": "test"}} - - -def _msg(*tool_calls: dict) -> AIMessage: - return AIMessage(content="", tool_calls=list(tool_calls)) - - -class TestAllow: - def test_passthrough_when_allow(self) -> None: - rs = Ruleset(rules=[Rule("send_email", "*", "allow")]) - mw = PermissionMiddleware(rulesets=[rs]) - state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} - out = mw.after_model(state, _FakeRuntime()) - assert out is None # no change - - -class TestDeny: - def test_replaces_with_deny_tool_message(self) -> None: - rs = Ruleset(rules=[Rule("send_email", "*", "deny")]) - mw = PermissionMiddleware(rulesets=[rs]) - state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} - out = mw.after_model(state, _FakeRuntime()) - assert out is not None - msgs = out["messages"] - # Find the deny ToolMessage - deny_msgs = [m for m in msgs if isinstance(m, ToolMessage)] - assert len(deny_msgs) == 1 - assert deny_msgs[0].status == "error" - assert "permission_denied" in str(deny_msgs[0].additional_kwargs) - # AIMessage's tool_calls should now be empty (denied call removed) - ai_msg = next(m for m in msgs if isinstance(m, AIMessage)) - assert ai_msg.tool_calls == [] - - def test_mixed_allow_deny(self) -> None: - rs = Ruleset( - rules=[ - Rule("send_email", "*", "deny"), - Rule("read", "*", "allow"), - ] - ) - mw = 
PermissionMiddleware(rulesets=[rs]) - state = { - "messages": [ - _msg( - {"name": "send_email", "args": {}, "id": "1"}, - {"name": "read", "args": {}, "id": "2"}, - ) - ] - } - out = mw.after_model(state, _FakeRuntime()) - assert out is not None - ai_msg = next(m for m in out["messages"] if isinstance(m, AIMessage)) - assert len(ai_msg.tool_calls) == 1 - assert ai_msg.tool_calls[0]["name"] == "read" - - -class TestAsk: - def test_reject_without_feedback_raises(self) -> None: - # Default: nothing matches -> ask - rs = Ruleset(rules=[]) - mw = PermissionMiddleware(rulesets=[rs]) - - # Bypass real interrupt — patch the helper - mw._raise_interrupt = lambda **kw: {"decision_type": "reject"} # type: ignore[assignment] - state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} - with pytest.raises(RejectedError): - mw.after_model(state, _FakeRuntime()) - - def test_reject_with_feedback_raises_corrected(self) -> None: - rs = Ruleset(rules=[]) - mw = PermissionMiddleware(rulesets=[rs]) - mw._raise_interrupt = lambda **kw: { # type: ignore[assignment] - "decision_type": "reject", - "feedback": "use a different subject line", - } - state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} - with pytest.raises(CorrectedError) as excinfo: - mw.after_model(state, _FakeRuntime()) - assert excinfo.value.feedback == "use a different subject line" - - def test_once_proceeds_without_persisting(self) -> None: - mw = PermissionMiddleware(rulesets=[]) - mw._raise_interrupt = lambda **kw: {"decision_type": "once"} # type: ignore[assignment] - state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} - out = mw.after_model(state, _FakeRuntime()) - # No state change because all calls kept - assert out is None - # No new rule persisted - assert mw._runtime_ruleset.rules == [] - - def test_approve_always_persists_runtime_rule(self) -> None: - mw = PermissionMiddleware(rulesets=[]) - mw._raise_interrupt = lambda **kw: {"decision_type": 
"approve_always"} # type: ignore[assignment] - state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} - out = mw.after_model(state, _FakeRuntime()) - assert out is None # call kept - # Runtime ruleset got the always-allow rule - new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"] - assert any(r.permission == "send_email" for r in new_rules) - - -class TestNormalizeDecision: - """Resume shapes ``_normalize_permission_decision`` must accept.""" - - def test_legacy_decision_type_dict_passes_through(self) -> None: - decision = {"decision_type": "once"} - assert _normalize_permission_decision(decision) == {"decision_type": "once"} - - def test_legacy_decision_type_with_feedback_passes_through(self) -> None: - decision = {"decision_type": "reject", "feedback": "no thanks"} - assert _normalize_permission_decision(decision) == decision - - def test_plain_string_wrapped(self) -> None: - assert _normalize_permission_decision("once") == {"decision_type": "once"} - assert _normalize_permission_decision("reject") == {"decision_type": "reject"} - - def test_lc_envelope_approve_maps_to_once(self) -> None: - decision = {"decisions": [{"type": "approve"}]} - assert _normalize_permission_decision(decision) == {"decision_type": "once"} - - def test_lc_envelope_reject_maps_to_reject(self) -> None: - decision = {"decisions": [{"type": "reject"}]} - assert _normalize_permission_decision(decision) == {"decision_type": "reject"} - - def test_lc_envelope_reject_with_message_carries_feedback(self) -> None: - decision = {"decisions": [{"type": "reject", "message": "wrong recipient"}]} - out = _normalize_permission_decision(decision) - assert out == {"decision_type": "reject", "feedback": "wrong recipient"} - - def test_lc_envelope_reject_with_feedback_field(self) -> None: - decision = { - "decisions": [{"type": "reject", "feedback": "tighten the subject"}] - } - out = _normalize_permission_decision(decision) - assert out == {"decision_type": 
"reject", "feedback": "tighten the subject"} - - def test_lc_envelope_edit_maps_to_once(self) -> None: - # Pins the contract: edited args are NOT merged by permission. - decision = { - "decisions": [ - { - "type": "edit", - "edited_action": { - "name": "send_email", - "args": {"subject": "edited"}, - }, - } - ] - } - assert _normalize_permission_decision(decision) == {"decision_type": "once"} - - def test_lc_single_decision_without_envelope(self) -> None: - assert _normalize_permission_decision({"type": "approve"}) == { - "decision_type": "once" - } - - def test_unknown_type_falls_back_to_reject(self) -> None: - decision = {"decisions": [{"type": "totally_unknown"}]} - assert _normalize_permission_decision(decision) == {"decision_type": "reject"} - - def test_missing_type_falls_back_to_reject(self) -> None: - assert _normalize_permission_decision({"decisions": [{}]}) == { - "decision_type": "reject" - } - - def test_non_dict_non_string_falls_back_to_reject(self) -> None: - assert _normalize_permission_decision(None) == {"decision_type": "reject"} - assert _normalize_permission_decision(42) == {"decision_type": "reject"} - - def test_empty_decisions_list_falls_back_to_reject(self) -> None: - # Fail-closed on a malformed reply rather than treat it as approve. 
- assert _normalize_permission_decision({"decisions": []}) == { - "decision_type": "reject" - } - - -class TestResumeShapesEndToEnd: - """LangChain HITL envelope reaches ``_process`` correctly via ``_raise_interrupt``.""" - - def test_lc_approve_envelope_keeps_call(self) -> None: - mw = PermissionMiddleware(rulesets=[]) - mw._raise_interrupt = lambda **kw: { # type: ignore[assignment] - "decisions": [{"type": "approve"}] - } - state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} - original = mw._raise_interrupt - mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment] - original(**kw) - ) - out = mw.after_model(state, _FakeRuntime()) - assert out is None - - def test_lc_reject_envelope_raises(self) -> None: - mw = PermissionMiddleware(rulesets=[]) - original = lambda **kw: {"decisions": [{"type": "reject"}]} # noqa: E731 - mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment] - original(**kw) - ) - state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} - with pytest.raises(RejectedError): - mw.after_model(state, _FakeRuntime()) - - def test_lc_reject_with_message_raises_corrected(self) -> None: - mw = PermissionMiddleware(rulesets=[]) - original = lambda **kw: { # noqa: E731 - "decisions": [{"type": "reject", "message": "wrong recipient"}] - } - mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment] - original(**kw) - ) - state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} - with pytest.raises(CorrectedError) as excinfo: - mw.after_model(state, _FakeRuntime()) - assert excinfo.value.feedback == "wrong recipient" - - def test_lc_edit_envelope_keeps_call_with_original_args(self) -> None: - # Pins the "edit -> once, args unchanged" contract. 
- mw = PermissionMiddleware(rulesets=[]) - original = lambda **kw: { # noqa: E731 - "decisions": [ - { - "type": "edit", - "edited_action": { - "name": "send_email", - "args": {"to": "edited@example.com"}, - }, - } - ] - } - mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment] - original(**kw) - ) - state = { - "messages": [ - _msg( - { - "name": "send_email", - "args": {"to": "original@example.com"}, - "id": "1", - } - ) - ] - } - out = mw.after_model(state, _FakeRuntime()) - assert out is None diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_permissions.py b/surfsense_backend/tests/unit/agents/new_chat/test_permissions.py index 8ec16617a..e680a955b 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_permissions.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_permissions.py @@ -4,7 +4,7 @@ from __future__ import annotations import pytest -from app.agents.new_chat.permissions import ( +from app.agents.chat.multi_agent_chat.shared.permissions import ( Rule, Ruleset, aggregate_action, diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py b/surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py index 5dbf765a7..3aae7cc75 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py @@ -6,13 +6,13 @@ from unittest.mock import MagicMock, patch from langchain.agents.middleware import AgentMiddleware -from app.agents.new_chat.plugin_loader import ( +from app.agents.chat.multi_agent_chat.main_agent.plugins.loader import ( PLUGIN_ENTRY_POINT_GROUP, PluginContext, load_allowed_plugin_names_from_env, load_plugin_middlewares, ) -from app.agents.new_chat.plugins.year_substituter import ( +from app.agents.chat.multi_agent_chat.main_agent.plugins.year_substituter import ( _YearSubstituterMiddleware, make_middleware as year_substituter_factory, ) @@ -66,7 +66,7 @@ class 
TestPluginLoaderBasics: ep = _FakeEntryPoint("dangerous_plugin", factory) with patch( - "app.agents.new_chat.plugin_loader.entry_points", + "app.agents.chat.multi_agent_chat.main_agent.plugins.loader.entry_points", return_value=[ep], ): result = load_plugin_middlewares( @@ -78,7 +78,7 @@ class TestPluginLoaderBasics: def test_loads_allowlisted_plugin(self) -> None: ep = _FakeEntryPoint("year_substituter", year_substituter_factory) with patch( - "app.agents.new_chat.plugin_loader.entry_points", + "app.agents.chat.multi_agent_chat.main_agent.plugins.loader.entry_points", return_value=[ep], ): result = load_plugin_middlewares( @@ -95,7 +95,7 @@ class TestPluginLoaderIsolation: ep = _FakeEntryPoint("buggy", crashing_factory) with patch( - "app.agents.new_chat.plugin_loader.entry_points", + "app.agents.chat.multi_agent_chat.main_agent.plugins.loader.entry_points", return_value=[ep], ): result = load_plugin_middlewares(_ctx(), allowed_plugin_names={"buggy"}) @@ -107,7 +107,7 @@ class TestPluginLoaderIsolation: ep = _FakeEntryPoint("liar", bad_factory) with patch( - "app.agents.new_chat.plugin_loader.entry_points", + "app.agents.chat.multi_agent_chat.main_agent.plugins.loader.entry_points", return_value=[ep], ): result = load_plugin_middlewares(_ctx(), allowed_plugin_names={"liar"}) @@ -121,7 +121,7 @@ class TestPluginLoaderIsolation: raise ImportError("cannot import") with patch( - "app.agents.new_chat.plugin_loader.entry_points", + "app.agents.chat.multi_agent_chat.main_agent.plugins.loader.entry_points", return_value=[_BrokenEP()], ): result = load_plugin_middlewares(_ctx(), allowed_plugin_names={"broken"}) @@ -137,7 +137,7 @@ class TestPluginLoaderIsolation: _FakeEntryPoint("crashing", crashing_factory), _FakeEntryPoint("ok", year_substituter_factory), ] - with patch("app.agents.new_chat.plugin_loader.entry_points", return_value=eps): + with patch("app.agents.chat.multi_agent_chat.main_agent.plugins.loader.entry_points", return_value=eps): result = 
load_plugin_middlewares( _ctx(), allowed_plugin_names={"crashing", "ok"} ) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py b/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py index c3de15c58..6fbe39349 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py @@ -1,5 +1,5 @@ r"""Tests for ``apply_litellm_prompt_caching`` in -:mod:`app.agents.new_chat.prompt_caching`. +:mod:`app.agents.chat.runtime.prompt_caching`. The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never activated for our LiteLLM stack) with LiteLLM-native multi-provider @@ -34,8 +34,10 @@ from typing import Any import pytest -from app.agents.new_chat.llm_config import AgentConfig -from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching +from app.agents.chat.runtime.llm_config import AgentConfig +from app.agents.chat.runtime.prompt_caching import ( + apply_litellm_prompt_caching, +) pytestmark = pytest.mark.unit diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_resolve_prompt_model_name.py b/surfsense_backend/tests/unit/agents/new_chat/test_resolve_prompt_model_name.py deleted file mode 100644 index ffe3dbaa4..000000000 --- a/surfsense_backend/tests/unit/agents/new_chat/test_resolve_prompt_model_name.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Tests for ``_resolve_prompt_model_name`` in :mod:`app.agents.new_chat.chat_deepagent`. - -The helper picks the model id fed to ``detect_provider_variant`` so the -right ```` block lands in the system prompt. The tests -below pin its preference order: - -1. ``agent_config.litellm_params["base_model"]`` (Azure-correct). -2. ``agent_config.model_name``. -3. ``getattr(llm, "model", None)``. - -Without (1) an Azure deployment named e.g. ``"prod-chat-001"`` would -silently miss every provider regex. 
-""" - -from __future__ import annotations - -import pytest - -from app.agents.new_chat.chat_deepagent import _resolve_prompt_model_name -from app.agents.new_chat.llm_config import AgentConfig - -pytestmark = pytest.mark.unit - - -def _make_cfg(**overrides) -> AgentConfig: - """Build an ``AgentConfig`` with sensible defaults for the helper test.""" - defaults = { - "provider": "OPENAI", - "model_name": "x", - "api_key": "k", - } - return AgentConfig(**{**defaults, **overrides}) - - -class _FakeLLM: - """Stand-in for a ``ChatLiteLLM`` / ``ChatLiteLLMRouter`` instance. - - The resolver only reads the ``.model`` attribute via ``getattr``, - matching the established idiom in ``knowledge_search.py`` / - ``stream_new_chat.py`` / ``document_summarizer.py``. - """ - - def __init__(self, model: str | None) -> None: - self.model = model - - -def test_prefers_litellm_params_base_model_over_deployment_name() -> None: - """Azure deployment slug must NOT shadow the underlying model family. - - This is the failure mode the helper exists to prevent: a deployment - named ``"azure/prod-chat-001"`` would not match any provider regex - on its own, but the family ``"gpt-4o"`` lives in - ``litellm_params["base_model"]`` and routes to ``openai_classic``. 
- """ - cfg = _make_cfg( - model_name="azure/prod-chat-001", - litellm_params={"base_model": "gpt-4o"}, - ) - assert _resolve_prompt_model_name(cfg, _FakeLLM("azure/prod-chat-001")) == "gpt-4o" - - -def test_falls_back_to_model_name_when_litellm_params_is_none() -> None: - cfg = _make_cfg( - model_name="anthropic/claude-3-5-sonnet", - litellm_params=None, - ) - got = _resolve_prompt_model_name(cfg, _FakeLLM("anthropic/claude-3-5-sonnet")) - assert got == "anthropic/claude-3-5-sonnet" - - -def test_handles_litellm_params_without_base_model_key() -> None: - cfg = _make_cfg( - model_name="openai/gpt-4o", - litellm_params={"temperature": 0.5}, - ) - assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o" - - -def test_ignores_blank_base_model() -> None: - """Whitespace-only ``base_model`` must not shadow ``model_name``.""" - cfg = _make_cfg( - model_name="openai/gpt-4o", - litellm_params={"base_model": " "}, - ) - assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o" - - -def test_ignores_non_string_base_model() -> None: - """Defensive: a non-string ``base_model`` should not crash the resolver.""" - cfg = _make_cfg( - model_name="openai/gpt-4o", - litellm_params={"base_model": 42}, - ) - assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o" - - -def test_falls_back_to_llm_model_when_no_agent_config() -> None: - """No ``agent_config`` -> use ``llm.model`` directly. Defensive path - for direct callers; production callers always supply a config.""" - assert ( - _resolve_prompt_model_name(None, _FakeLLM("openai/gpt-4o-mini")) - == "openai/gpt-4o-mini" - ) - - -def test_returns_none_when_nothing_available() -> None: - """``compose_system_prompt`` treats ``None`` as the ``"default"`` - variant and emits no provider block.""" - assert _resolve_prompt_model_name(None, _FakeLLM(None)) is None - - -def test_auto_mode_resolves_to_auto_string() -> None: - """Auto mode -> ``"auto"``. 
``detect_provider_variant("auto")`` - returns ``"default"``, which is correct: the child model isn't - known until the LiteLLM Router dispatches.""" - cfg = AgentConfig.from_auto_mode() - assert _resolve_prompt_model_name(cfg, _FakeLLM("auto")) == "auto" diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_retry_after.py b/surfsense_backend/tests/unit/agents/new_chat/test_retry_after.py index d23fd693b..b70718ff9 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_retry_after.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_retry_after.py @@ -4,7 +4,7 @@ from __future__ import annotations import pytest -from app.agents.new_chat.middleware.retry_after import ( +from app.agents.chat.shared.middleware.retry_after import ( RetryAfterMiddleware, _extract_retry_after_seconds, _is_non_retryable, diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_skills_backends.py b/surfsense_backend/tests/unit/agents/new_chat/test_skills_backends.py index eb9cf396c..1c497d99b 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_skills_backends.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_skills_backends.py @@ -7,7 +7,7 @@ from pathlib import Path import pytest -from app.agents.new_chat.middleware.skills_backends import ( +from app.agents.chat.multi_agent_chat.main_agent.skills.backends import ( SKILLS_BUILTIN_PREFIX, SKILLS_SPACE_PREFIX, BuiltinSkillsBackend, diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_specialized_subagents.py b/surfsense_backend/tests/unit/agents/new_chat/test_specialized_subagents.py deleted file mode 100644 index 3c7fe5336..000000000 --- a/surfsense_backend/tests/unit/agents/new_chat/test_specialized_subagents.py +++ /dev/null @@ -1,337 +0,0 @@ -"""Tests for the specialized subagents (explore / report_writer / connector_negotiator).""" - -from __future__ import annotations - -from langchain_core.tools import tool - -from app.agents.new_chat.middleware.permission import 
PermissionMiddleware -from app.agents.new_chat.subagents import ( - build_connector_negotiator_subagent, - build_explore_subagent, - build_report_writer_subagent, - build_specialized_subagents, -) -from app.agents.new_chat.subagents.config import ( - EXPLORE_READ_TOOLS, - REPORT_WRITER_TOOLS, - WRITE_TOOL_DENY_PATTERNS, -) - -# --------------------------------------------------------------------------- -# Fake tools used to verify filtering & permission behavior -# --------------------------------------------------------------------------- - - -@tool -def web_search(query: str) -> str: - """Search the public web.""" - return "" - - -@tool -def scrape_webpage(url: str) -> str: - """Scrape a single webpage.""" - return "" - - -@tool -def read_file(path: str) -> str: - """Read a file.""" - return "" - - -@tool -def ls_tree(path: str) -> str: - """List a tree.""" - return "" - - -@tool -def grep(pattern: str) -> str: - """Grep.""" - return "" - - -@tool -def update_memory(content: str) -> str: - """Update the user's memory.""" - return "" - - -@tool -def edit_file(path: str, old: str, new: str) -> str: - """Edit a file.""" - return "" - - -@tool -def linear_create_issue(title: str) -> str: - """Create a Linear issue.""" - return "" - - -@tool -def slack_send_message(channel: str, text: str) -> str: - """Send a Slack message.""" - return "" - - -@tool -def get_connected_accounts() -> str: - """List connected accounts.""" - return "" - - -@tool -def generate_report(topic: str) -> str: - """Generate a report artifact.""" - return "" - - -ALL_TOOLS = [ - web_search, - scrape_webpage, - read_file, - ls_tree, - grep, - update_memory, - edit_file, - linear_create_issue, - slack_send_message, - get_connected_accounts, - generate_report, -] - - -class TestExploreSubagent: - def test_only_read_tools_are_exposed(self) -> None: - spec = build_explore_subagent(tools=ALL_TOOLS) - names = {t.name for t in spec["tools"]} # type: ignore[index] - assert names == EXPLORE_READ_TOOLS & 
{t.name for t in ALL_TOOLS} - assert "update_memory" not in names - assert "linear_create_issue" not in names - assert "edit_file" not in names - - def test_includes_permission_middleware_with_deny_rules(self) -> None: - spec = build_explore_subagent(tools=ALL_TOOLS) - permission_mws = [ - m - for m in spec["middleware"] - if isinstance(m, PermissionMiddleware) # type: ignore[index] - ] - assert len(permission_mws) == 1 - ruleset = permission_mws[0]._static_rulesets[0] - assert ruleset.origin == "subagent_explore" - deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"} - assert "update_memory" in deny_patterns - assert "edit_file" in deny_patterns - assert "*create*" in deny_patterns - assert "*send*" in deny_patterns - - def test_skills_inherits_default_sources(self) -> None: - spec = build_explore_subagent(tools=ALL_TOOLS) - assert spec["skills"] == ["/skills/builtin/", "/skills/space/"] # type: ignore[index] - - def test_name_and_description_match_contract(self) -> None: - spec = build_explore_subagent(tools=ALL_TOOLS) - assert spec["name"] == "explore" - assert "read-only" in spec["description"].lower() - - def test_includes_dedup_and_patch_middleware(self) -> None: - from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware - - from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware - - spec = build_explore_subagent(tools=ALL_TOOLS) - types = {type(m) for m in spec["middleware"]} # type: ignore[index] - assert PatchToolCallsMiddleware in types - assert DedupHITLToolCallsMiddleware in types - - -class TestReportWriterSubagent: - def test_exposes_only_report_writing_tools(self) -> None: - spec = build_report_writer_subagent(tools=ALL_TOOLS) - names = {t.name for t in spec["tools"]} # type: ignore[index] - assert names == REPORT_WRITER_TOOLS & {t.name for t in ALL_TOOLS} - assert "generate_report" in names - assert "read_file" in names - - def test_deny_rules_block_writes_but_allow_generate_report(self) -> 
None: - spec = build_report_writer_subagent(tools=ALL_TOOLS) - permission_mws = [ - m - for m in spec["middleware"] - if isinstance(m, PermissionMiddleware) # type: ignore[index] - ] - ruleset = permission_mws[0]._static_rulesets[0] - deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"} - assert "update_memory" in deny_patterns - # generate_report MUST not be denied — it's the whole point of the subagent. - assert "generate_report" not in deny_patterns - # No deny pattern should match `generate_report` either. - assert all( - not _wildcard_matches(pattern, "generate_report") - for pattern in deny_patterns - ) - - -class TestConnectorNegotiatorSubagent: - def test_inherits_all_parent_tools(self) -> None: - spec = build_connector_negotiator_subagent(tools=ALL_TOOLS) - names = {t.name for t in spec["tools"]} # type: ignore[index] - # Every parent tool is inherited; the deny ruleset enforces behavior - # at execution time instead of trimming the tool list. - assert names == {t.name for t in ALL_TOOLS} - - def test_get_connected_accounts_is_present(self) -> None: - spec = build_connector_negotiator_subagent(tools=ALL_TOOLS) - names = {t.name for t in spec["tools"]} # type: ignore[index] - assert "get_connected_accounts" in names - - def test_deny_ruleset_blocks_mutating_connector_tools(self) -> None: - spec = build_connector_negotiator_subagent(tools=ALL_TOOLS) - permission_mws = [ - m - for m in spec["middleware"] - if isinstance(m, PermissionMiddleware) # type: ignore[index] - ] - ruleset = permission_mws[0]._static_rulesets[0] - deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"} - # `linear_create_issue` matches the `*_create` deny pattern. 
- assert any(_wildcard_matches(p, "linear_create_issue") for p in deny_patterns) - assert any(_wildcard_matches(p, "slack_send_message") for p in deny_patterns) - - -class TestBuildSpecializedSubagents: - def test_returns_five_specs(self) -> None: - specs = build_specialized_subagents(tools=ALL_TOOLS) - names = [s["name"] for s in specs] # type: ignore[index] - assert names == [ - "explore", - "report_writer", - "linear_specialist", - "slack_specialist", - "connector_negotiator", - ] - - def test_all_specs_have_unique_names(self) -> None: - specs = build_specialized_subagents(tools=ALL_TOOLS) - names = [s["name"] for s in specs] # type: ignore[index] - assert len(set(names)) == len(names) - - def test_extra_middleware_is_prepended_to_each_spec(self) -> None: - """Sentinel middleware passed via ``extra_middleware`` must appear - in each subagent's ``middleware`` list, before the local rules. - - This guards against the regression where specialized subagents - promised filesystem tools (``read_file``, ``ls``, ``grep``) in - their system prompts but had no filesystem middleware mounted. - """ - - class _Sentinel: - pass - - sentinel = _Sentinel() - specs = build_specialized_subagents( - tools=ALL_TOOLS, extra_middleware=[sentinel] - ) - for spec in specs: - mws = spec["middleware"] # type: ignore[index] - assert sentinel in mws - # The sentinel must appear *before* the permission middleware - # (subagent-local rules), preserving the documented composition - # order: extra → custom → patch → dedup. 
- sentinel_idx = mws.index(sentinel) - perm_idx = next( - (i for i, m in enumerate(mws) if isinstance(m, PermissionMiddleware)), - None, - ) - assert perm_idx is not None - assert sentinel_idx < perm_idx - - -class TestFilterToolsWarningSuppression: - """Names provided by middleware (read_file, ls, grep, …) must not - trigger the spurious "missing" warning in :func:`_filter_tools`.""" - - def test_middleware_provided_names_are_silent(self, caplog) -> None: - import logging - - from app.agents.new_chat.subagents.config import _filter_tools - - with caplog.at_level( - logging.INFO, logger="app.agents.new_chat.subagents.config" - ): - # Allowed set asks for two registry tools (one present, one - # not) plus a bunch of middleware-provided names. - _filter_tools( - [web_search], - allowed_names={ - "web_search", - "scrape_webpage", # legitimately missing → should warn - "read_file", # mw-provided → suppressed - "ls", - "grep", - "glob", - "write_todos", - }, - ) - - warnings = [r.message for r in caplog.records if r.levelno >= logging.INFO] - # Exactly one warning, and it should mention scrape_webpage but not - # any middleware-provided name. Inspect the rendered "missing" - # list (between the brackets) so we don't false-match substrings - # like ``ls`` inside ``available``. 
- assert len(warnings) == 1, warnings - msg = warnings[0] - assert "scrape_webpage" in msg - bracket_section = msg.split("missing: ", 1)[1] - for noisy in ("read_file", "ls", "grep", "glob", "write_todos"): - assert f"'{noisy}'" not in bracket_section, msg - - -class TestDenyPatternsCoverage: - def test_deny_patterns_cover_canonical_write_tools(self) -> None: - canonical_writes = [ - "update_memory", - "edit_file", - "write_file", - "move_file", - "mkdir", - "linear_create_issue", - "linear_update_issue", - "linear_delete_issue", - "slack_send_message", - "create_index", - "update_account", - "delete_record", - "send_email", - ] - for tool_name in canonical_writes: - assert any( - _wildcard_matches(pattern, tool_name) - for pattern in WRITE_TOOL_DENY_PATTERNS - ), f"no deny pattern matches {tool_name!r}" - - def test_deny_patterns_do_not_match_safe_read_tools(self) -> None: - canonical_reads = [ - "read_file", - "ls_tree", - "grep", - "web_search", - "scrape_webpage", - "get_connected_accounts", - "generate_report", - ] - for tool_name in canonical_reads: - assert not any( - _wildcard_matches(pattern, tool_name) - for pattern in WRITE_TOOL_DENY_PATTERNS - ), f"deny pattern incorrectly matches read tool {tool_name!r}" - - -def _wildcard_matches(pattern: str, value: str) -> bool: - """Helper using the same matcher the rule evaluator does.""" - from app.agents.new_chat.permissions import wildcard_match - - return wildcard_match(value, pattern) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py b/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py index 185753990..637a10704 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py @@ -4,7 +4,7 @@ from __future__ import annotations import pytest -from app.agents.new_chat.state_reducers import ( +from app.agents.chat.multi_agent_chat.shared.state.reducers import ( _CLEAR, 
_add_unique_reducer, _dict_merge_with_tombstones_reducer, diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_tool_call_repair.py b/surfsense_backend/tests/unit/agents/new_chat/test_tool_call_repair.py index e02a04774..1e11e39ce 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_tool_call_repair.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_tool_call_repair.py @@ -5,10 +5,12 @@ from __future__ import annotations import pytest from langchain_core.messages import AIMessage -from app.agents.new_chat.middleware.tool_call_repair import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.tool_call_repair.middleware import ( ToolCallNameRepairMiddleware, ) -from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME +from app.agents.chat.multi_agent_chat.main_agent.tools.invalid_tool import ( + INVALID_TOOL_NAME, +) pytestmark = pytest.mark.unit diff --git a/surfsense_backend/tests/unit/agents/new_chat/tools/test_mcp_tools_cache.py b/surfsense_backend/tests/unit/agents/new_chat/tools/test_mcp_tools_cache.py index bae97ba9f..7d9d35b55 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/tools/test_mcp_tools_cache.py +++ b/surfsense_backend/tests/unit/agents/new_chat/tools/test_mcp_tools_cache.py @@ -7,7 +7,7 @@ from types import SimpleNamespace import pytest -from app.agents.new_chat.tools.mcp_tools_cache import ( +from app.agents.chat.multi_agent_chat.shared.tools.mcp.cache import ( CachedMCPToolDef, CachedMCPTools, read_cached_tools, diff --git a/surfsense_backend/tests/unit/agents/new_chat/tools/test_resume_page_limits.py b/surfsense_backend/tests/unit/agents/new_chat/tools/test_resume_page_limits.py index f9212a45c..61fa87b76 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/tools/test_resume_page_limits.py +++ b/surfsense_backend/tests/unit/agents/new_chat/tools/test_resume_page_limits.py @@ -1,17 +1,58 @@ -"""Unit tests for resume page-limit helpers and enforcement flow.""" +"""Unit tests for resume 
page-limit helpers and enforcement flow. + +Targets the live deliverables resume tool. The tool returns a +``Command`` (payload JSON-encoded in ``update["messages"][0].content`` +plus a receipt), so flow tests invoke it via a ToolCall dict and unwrap +the payload. +""" import io +import json from types import SimpleNamespace from unittest.mock import AsyncMock import pypdf import pytest +from langchain.tools import ToolRuntime -from app.agents.new_chat.tools import resume as resume_tool +from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools import ( + resume as resume_tool, +) pytestmark = pytest.mark.unit +@pytest.fixture(autouse=True) +def _silence_progress_events(monkeypatch): + """The live tool emits ``dispatch_custom_event`` progress updates that + require a langgraph run context; neutralize them for direct unit calls.""" + monkeypatch.setattr(resume_tool, "dispatch_custom_event", lambda *a, **k: None) + + +def _runtime(tool_call_id: str = "call-1") -> ToolRuntime: + """Minimal ToolRuntime; the resume tool only reads ``tool_call_id``.""" + return ToolRuntime( + state={}, + context=None, + config={}, + stream_writer=None, + tool_call_id=tool_call_id, + store=None, + ) + + +async def _invoke(tool, args: dict) -> dict: + """Drive a Command-returning tool and return its decoded payload. + + These tools take an injected ``ToolRuntime`` and return a + ``Command``; invoke the raw coroutine with a hand-built runtime + (the repo's pattern for unit-testing such tools) and decode the + ToolMessage payload. 
+ """ + command = await tool.coroutine(runtime=_runtime(), **args) + return json.loads(command.update["messages"][0].content) + + class _FakeReport: _next_id = 1000 @@ -108,7 +149,7 @@ async def test_generate_resume_defaults_to_one_page_target(monkeypatch) -> None: monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: 1) tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1) - result = await tool.ainvoke({"user_info": "Jane Doe experience"}) + result = await _invoke(tool, {"user_info": "Jane Doe experience"}) assert result["status"] == "ready" assert prompts @@ -138,7 +179,7 @@ async def test_generate_resume_compresses_when_over_limit(monkeypatch) -> None: monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts)) tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1) - result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1}) + result = await _invoke(tool, {"user_info": "Jane Doe experience", "max_pages": 1}) assert result["status"] == "ready" assert write_session.added, "Expected successful report write" @@ -173,7 +214,7 @@ async def test_generate_resume_returns_ready_when_target_not_met(monkeypatch) -> monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts)) tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1) - result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1}) + result = await _invoke(tool, {"user_info": "Jane Doe experience", "max_pages": 1}) assert result["status"] == "ready" assert "could not fit the target" in (result["message"] or "").lower() @@ -206,7 +247,7 @@ async def test_generate_resume_fails_when_hard_limit_exceeded(monkeypatch) -> No monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts)) tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1) - result = await tool.ainvoke({"user_info": "Jane Doe 
experience", "max_pages": 1}) + result = await _invoke(tool, {"user_info": "Jane Doe experience", "max_pages": 1}) assert result["status"] == "failed" assert "hard page limit" in (result["error"] or "").lower() diff --git a/surfsense_backend/tests/unit/agents/test_import_all.py b/surfsense_backend/tests/unit/agents/test_import_all.py new file mode 100644 index 000000000..b45bf3359 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/test_import_all.py @@ -0,0 +1,53 @@ +"""Guardrail A: every agent module (and its prod entrypoints) must import. + +Static reachability analysis and mocked unit tests cannot catch a module that +fails to import after files move or imports are rewritten. This smoke test +imports every submodule under ``app.agents`` plus the production entrypoints +that consume agents, turning a move-time ``ImportError`` into a fast, local CI +signal instead of a runtime failure in prod. +""" + +from __future__ import annotations + +import importlib +import pkgutil + +import pytest + +import app.agents as agents_pkg + +pytestmark = pytest.mark.unit + +# Prod consumers of app.agents that live OUTSIDE the agents tree; a broken +# importer here would not be caught by walking app.agents alone. 
+_PROD_ENTRYPOINTS = [ + "app.tasks.chat.streaming.flows.new_chat.orchestrator", + "app.tasks.chat.streaming.agent.builder", + "app.gateway.agent_invoke", + "app.routes.new_chat_routes", +] + + +def _iter_agent_modules() -> list[str]: + names: list[str] = [] + + def _record(name: str) -> None: + names.append(name) + + for info in pkgutil.walk_packages( + agents_pkg.__path__, prefix=agents_pkg.__name__ + ".", onerror=_record + ): + names.append(info.name) + return sorted(set(names)) + + +@pytest.mark.parametrize("module_name", _iter_agent_modules()) +def test_agent_module_imports(module_name: str) -> None: + """Importing the module must not raise (no broken or missed imports).""" + importlib.import_module(module_name) + + +@pytest.mark.parametrize("module_name", _PROD_ENTRYPOINTS) +def test_prod_entrypoint_imports(module_name: str) -> None: + """The production code paths that build/invoke agents must import.""" + importlib.import_module(module_name) diff --git a/surfsense_backend/tests/unit/automations/services/test_model_policy.py b/surfsense_backend/tests/unit/automations/services/test_model_policy.py index 2a471b4e9..8e0806151 100644 --- a/surfsense_backend/tests/unit/automations/services/test_model_policy.py +++ b/surfsense_backend/tests/unit/automations/services/test_model_policy.py @@ -44,7 +44,7 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch): -2: {"id": -2, "billing_tier": "free"}, } monkeypatch.setattr( - "app.agents.new_chat.llm_config.load_global_llm_config_by_id", + "app.agents.chat.runtime.llm_config.load_global_llm_config_by_id", lambda cid: llm_configs.get(cid), ) diff --git a/surfsense_backend/tests/unit/gateway/test_webhook_routes.py b/surfsense_backend/tests/unit/gateway/test_webhook_routes.py index 338a35c39..b686ebcb8 100644 --- a/surfsense_backend/tests/unit/gateway/test_webhook_routes.py +++ b/surfsense_backend/tests/unit/gateway/test_webhook_routes.py @@ -13,6 +13,29 @@ from app.db import ExternalChatAccount, ExternalChatAccountMode, 
ExternalChatPla from app.routes import gateway_webhook_routes as routes +@pytest.fixture(autouse=True) +def _enable_gateways(monkeypatch): + """Turn on the Telegram/Slack/Discord gateway flags the routes gate on. + + The routes early-return when their integration is unconfigured, so without + this the handlers never reach the logic these tests assert on. + """ + monkeypatch.setattr(routes.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "webhook") + monkeypatch.setattr(routes.config, "TELEGRAM_SHARED_BOT_TOKEN", "telegram-token") + monkeypatch.setattr(routes.config, "TELEGRAM_SHARED_BOT_USERNAME", "surf_bot") + monkeypatch.setattr(routes.config, "TELEGRAM_WEBHOOK_SECRET", "telegram-webhook-secret") + + monkeypatch.setattr(routes.config, "GATEWAY_SLACK_ENABLED", True) + monkeypatch.setattr(routes.config, "GATEWAY_SLACK_CLIENT_ID", "slack-client") + monkeypatch.setattr(routes.config, "GATEWAY_SLACK_CLIENT_SECRET", "slack-secret") + monkeypatch.setattr(routes.config, "GATEWAY_SLACK_SIGNING_SECRET", "signing-secret") + + monkeypatch.setattr(routes.config, "GATEWAY_DISCORD_ENABLED", True) + monkeypatch.setattr(routes.config, "DISCORD_CLIENT_ID", "discord-client") + monkeypatch.setattr(routes.config, "DISCORD_CLIENT_SECRET", "discord-secret") + monkeypatch.setattr(routes.config, "DISCORD_BOT_TOKEN", "discord-bot-token") + + class RequestStub: def __init__(self, payload=None, *, headers=None, json_exc: Exception | None = None): self.headers = headers or {} diff --git a/surfsense_backend/tests/unit/middleware/test_b_filesystem_path_resolution.py b/surfsense_backend/tests/unit/middleware/test_b_filesystem_path_resolution.py new file mode 100644 index 000000000..a4e23c39f --- /dev/null +++ b/surfsense_backend/tests/unit/middleware/test_b_filesystem_path_resolution.py @@ -0,0 +1,287 @@ +"""Path/cwd/namespace + multi-root mount-normalization tests for LIVE filesystem. 
+ +Ported from the dead-twin suites: +* ``tests/unit/middleware/test_filesystem_middleware.py`` (cwd defaults, + relative resolution, cloud write-namespace policy) +* ``tests/unit/middleware/test_filesystem_verification.py`` (desktop + multi-root mount-prefix normalization) + +Both exercised ``app.agents.chat.shared.middleware.filesystem`` (dead). This drives +the production free functions in +``app.agents.chat.multi_agent_chat.shared.middleware.filesystem.middleware`` instead. +The functions only touch ``mw._filesystem_mode`` and ``mw._get_backend`` so we +pass a lightweight fake ``mw`` rather than constructing the full middleware. +""" + +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.multi_root_local_folder import ( + MultiRootLocalFolderBackend, +) +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.middleware.mode import ( + default_cwd, +) +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.middleware.namespace_policy import ( + check_cloud_write_namespace, +) +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.middleware.path_resolution import ( + current_cwd, + get_contract_suggested_path, + normalize_local_mount_path, + resolve_relative, +) + +pytestmark = pytest.mark.unit + + +def _mw(mode: FilesystemMode = FilesystemMode.CLOUD, backend=None): + return SimpleNamespace(_filesystem_mode=mode, _get_backend=lambda _rt: backend) + + +def _runtime(state: dict | None = None) -> SimpleNamespace: + return SimpleNamespace(state=state or {}) + + +# --------------------------------------------------------------------------- +# cwd defaults +# --------------------------------------------------------------------------- + + +class TestCwdDefaults: + def 
test_default_cwd_in_cloud_is_documents_root(self): + assert default_cwd(FilesystemMode.CLOUD) == "/documents" + + def test_default_cwd_in_desktop_is_root(self): + assert default_cwd(FilesystemMode.DESKTOP_LOCAL_FOLDER) == "/" + + def test_current_cwd_uses_state_when_set(self): + assert ( + current_cwd(_mw(), _runtime({"cwd": "/documents/notes"})) + == "/documents/notes" + ) + + def test_current_cwd_falls_back_to_default(self): + assert current_cwd(_mw(), _runtime({})) == "/documents" + + def test_current_cwd_ignores_invalid(self): + assert current_cwd(_mw(), _runtime({"cwd": "not-absolute"})) == "/documents" + + +# --------------------------------------------------------------------------- +# relative resolution +# --------------------------------------------------------------------------- + + +class TestRelativePathResolution: + def test_relative_path_resolves_against_cwd(self): + assert ( + resolve_relative(_mw(), "notes.md", _runtime({"cwd": "/documents/projects"})) + == "/documents/projects/notes.md" + ) + + def test_relative_path_with_dotdot(self): + assert ( + resolve_relative(_mw(), "../c.md", _runtime({"cwd": "/documents/a/b"})) + == "/documents/a/c.md" + ) + + def test_absolute_path_is_kept(self): + assert ( + resolve_relative(_mw(), "/other/x.md", _runtime({"cwd": "/documents"})) + == "/other/x.md" + ) + + def test_empty_path_returns_cwd(self): + assert ( + resolve_relative(_mw(), "", _runtime({"cwd": "/documents/projects"})) + == "/documents/projects" + ) + + +# --------------------------------------------------------------------------- +# contract suggested-path fallback +# --------------------------------------------------------------------------- + + +class TestContractSuggestedPath: + def test_falls_back_to_documents_notes_md_in_cloud(self): + suggested = get_contract_suggested_path( + _mw(FilesystemMode.CLOUD), + _runtime({"file_operation_contract": {}}), + ) + assert suggested == "/documents/notes.md" + + def 
test_falls_back_to_root_notes_md_in_desktop(self): + suggested = get_contract_suggested_path( + _mw(FilesystemMode.DESKTOP_LOCAL_FOLDER), + _runtime({"file_operation_contract": {}}), + ) + assert suggested == "/notes.md" + + +# --------------------------------------------------------------------------- +# cloud write-namespace policy +# --------------------------------------------------------------------------- + + +class TestCloudWriteNamespacePolicy: + def test_documents_path_allowed(self): + assert ( + check_cloud_write_namespace(_mw(), "/documents/foo.md", _runtime()) is None + ) + + def test_documents_root_allowed(self): + assert check_cloud_write_namespace(_mw(), "/documents", _runtime()) is None + + def test_temp_basename_anywhere_allowed(self): + assert ( + check_cloud_write_namespace(_mw(), "/temp_scratch.md", _runtime()) is None + ) + assert check_cloud_write_namespace(_mw(), "/foo/temp_x.md", _runtime()) is None + assert ( + check_cloud_write_namespace(_mw(), "/documents/temp_x.md", _runtime()) + is None + ) + + def test_other_paths_rejected(self): + err = check_cloud_write_namespace(_mw(), "/foo/bar.md", _runtime()) + assert err is not None + assert "must target /documents" in err + + def test_anon_doc_path_is_read_only(self): + runtime = _runtime( + { + "kb_anon_doc": { + "path": "/documents/uploaded.xml", + "title": "uploaded", + "content": "", + "chunks": [], + } + } + ) + err = check_cloud_write_namespace(_mw(), "/documents/uploaded.xml", runtime) + assert err is not None + assert "read-only" in err + + def test_desktop_mode_skips_namespace_policy(self): + assert ( + check_cloud_write_namespace( + _mw(FilesystemMode.DESKTOP_LOCAL_FOLDER), "/random/path.md", _runtime() + ) + is None + ) + + +# --------------------------------------------------------------------------- +# desktop multi-root mount normalization +# --------------------------------------------------------------------------- + + +def _desktop_mw(backend) -> SimpleNamespace: + return 
_mw(FilesystemMode.DESKTOP_LOCAL_FOLDER, backend) + + +class TestNormalizeLocalMountPath: + def test_prefixes_default_mount(self, tmp_path: Path): + root = tmp_path / "PC Backups" + root.mkdir() + backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),)) + resolved = normalize_local_mount_path( + _desktop_mw(backend), + "/random-note.md", + _runtime({"file_operation_contract": {}}), + ) + assert resolved == "/pc_backups/random-note.md" + + def test_keeps_explicit_mount(self, tmp_path: Path): + root = tmp_path / "PC Backups" + root.mkdir() + backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),)) + resolved = normalize_local_mount_path( + _desktop_mw(backend), + "/pc_backups/notes/random-note.md", + _runtime({"file_operation_contract": {}}), + ) + assert resolved == "/pc_backups/notes/random-note.md" + + def test_windows_backslashes(self, tmp_path: Path): + root = tmp_path / "PC Backups" + root.mkdir() + backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),)) + resolved = normalize_local_mount_path( + _desktop_mw(backend), + r"\notes\random-note.md", + _runtime({"file_operation_contract": {}}), + ) + assert resolved == "/pc_backups/notes/random-note.md" + + def test_normalizes_mixed_separators(self, tmp_path: Path): + root = tmp_path / "PC Backups" + root.mkdir() + backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),)) + resolved = normalize_local_mount_path( + _desktop_mw(backend), + r"\\notes//nested\\random-note.md", + _runtime({"file_operation_contract": {}}), + ) + assert resolved == "/pc_backups/notes/nested/random-note.md" + + def test_keeps_explicit_mount_with_backslashes(self, tmp_path: Path): + root = tmp_path / "PC Backups" + root.mkdir() + backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),)) + resolved = normalize_local_mount_path( + _desktop_mw(backend), + r"\pc_backups\notes\random-note.md", + _runtime({"file_operation_contract": {}}), + ) + assert resolved == 
"/pc_backups/notes/random-note.md" + + def test_prefixes_posix_absolute_path(self, tmp_path: Path): + root = tmp_path / "PC Backups" + root.mkdir() + backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),)) + resolved = normalize_local_mount_path( + _desktop_mw(backend), + "/var/log/app.log", + _runtime({"file_operation_contract": {}}), + ) + assert resolved == "/pc_backups/var/log/app.log" + + def test_prefers_unique_existing_parent_mount(self, tmp_path: Path): + root_a = tmp_path / "RootA" + root_b = tmp_path / "RootB" + (root_a / "other").mkdir(parents=True) + (root_b / "nested" / "deep").mkdir(parents=True) + backend = MultiRootLocalFolderBackend( + (("root_a", str(root_a)), ("root_b", str(root_b))) + ) + resolved = normalize_local_mount_path( + _desktop_mw(backend), + "/nested/deep/new-note.md", + _runtime({"file_operation_contract": {}}), + ) + assert resolved == "/root_b/nested/deep/new-note.md" + + def test_uses_suggested_mount_when_ambiguous(self, tmp_path: Path): + root_a = tmp_path / "RootA" + root_b = tmp_path / "RootB" + root_a.mkdir(parents=True) + root_b.mkdir(parents=True) + backend = MultiRootLocalFolderBackend( + (("root_a", str(root_a)), ("root_b", str(root_b))) + ) + resolved = normalize_local_mount_path( + _desktop_mw(backend), + "/brand-new-note.md", + _runtime( + {"file_operation_contract": {"suggested_path": "/root_b/notes/context.md"}} + ), + ) + assert resolved == "/root_b/brand-new-note.md" diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_rm_rmdir_cloud.py b/surfsense_backend/tests/unit/middleware/test_b_filesystem_rm_rmdir_cloud.py similarity index 61% rename from surfsense_backend/tests/unit/agents/new_chat/test_rm_rmdir_cloud.py rename to surfsense_backend/tests/unit/middleware/test_b_filesystem_rm_rmdir_cloud.py index 7cabb6524..898ec3765 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_rm_rmdir_cloud.py +++ b/surfsense_backend/tests/unit/middleware/test_b_filesystem_rm_rmdir_cloud.py @@ -1,15 
+1,14 @@ -"""Cloud-mode behavior tests for the new ``rm`` and ``rmdir`` filesystem tools. +"""Cloud-mode ``rm``/``rmdir`` staging tests for the LIVE filesystem middleware. -The tools build ``Command(update=...)`` payloads that the persistence -middleware applies at end of turn. These tests stub out the backend and -runtime to assert the staging payload shape: - -* ``rm`` queues into ``pending_deletes`` and tombstones state files. -* ``rm`` rejects directories, ``/documents``, root, and the anonymous doc. -* ``rmdir`` queues into ``pending_dir_deletes`` and rejects non-empty dirs. -* ``rmdir`` un-stages a same-turn ``mkdir`` rather than queuing a delete. -* ``rmdir`` refuses to drop the cwd or any of its ancestors. -* ``KBPostgresBackend`` view-helpers honor staged deletes. +Ported from the former ``tests/unit/agents/new_chat/test_rm_rmdir_cloud.py``, +which exercised the *dead twin* ``app.agents.chat.shared.middleware.filesystem``. +This drives the production decomposed tools +(``app.agents.chat.multi_agent_chat.shared.middleware.filesystem``) instead: it +builds the real middleware via ``build_filesystem_mw``, pulls the real ``rm`` / +``rmdir`` tools off it, and invokes their coroutines with a stubbed +``KBPostgresBackend`` + runtime so we can assert the end-of-turn staging +payloads (``pending_deletes`` / ``pending_dir_deletes``) and the destructive-op +guard rails (root, /documents, anon doc, non-empty, cwd/ancestor, file vs dir). 
""" from __future__ import annotations @@ -20,18 +19,38 @@ from unittest.mock import AsyncMock import pytest -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware -from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import ( + FilesystemMode, + FilesystemSelection, +) +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem import ( + build_filesystem_mw, +) +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import ( + KBPostgresBackend, +) +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.resolver import ( + build_backend_resolver, +) +from app.agents.chat.multi_agent_chat.shared.state.reducers import _CLEAR pytestmark = pytest.mark.unit def _make_middleware(mode: FilesystemMode = FilesystemMode.CLOUD): - middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) - middleware._filesystem_mode = mode - middleware._custom_tool_descriptions = {} - return middleware + selection = FilesystemSelection(mode=mode) + resolver = build_backend_resolver(selection, search_space_id=1) + return build_filesystem_mw( + backend_resolver=resolver, + filesystem_mode=mode, + search_space_id=1, + user_id="00000000-0000-0000-0000-000000000001", + thread_id=1, + ) + + +def _tool(mw, name: str): + return next(t for t in mw.tools if t.name == name) def _runtime(state: dict[str, Any] | None = None, *, tool_call_id: str = "tc-abc"): @@ -41,13 +60,12 @@ def _runtime(state: dict[str, Any] | None = None, *, tool_call_id: str = "tc-abc class _KBBackendStub(KBPostgresBackend): - """Construct-able subclass of :class:`KBPostgresBackend` for tests. + """Construct-able ``KBPostgresBackend`` subclass for tests. 
- We bypass the real ``__init__`` (which expects a runtime + DB session) - and inject just the methods the rm/rmdir tools touch. The class - inheritance keeps ``isinstance(backend, KBPostgresBackend)`` checks - inside the tools happy, which is what gates them from the desktop - code path. + Bypasses the real ``__init__`` (which expects a runtime + DB session) and + injects only the async methods the rm/rmdir tools touch. The class + inheritance keeps the ``isinstance(backend, KBPostgresBackend)`` checks in + the tools on the cloud path. """ def __init__(self, *, children=None, file_data=None) -> None: @@ -61,9 +79,8 @@ def _make_backend_stub(*, children=None, file_data=None) -> KBPostgresBackend: return _KBBackendStub(children=children, file_data=file_data) -def _bind_backend(middleware, backend): - """Inject a backend resolver onto the middleware test instance.""" - middleware._get_backend = lambda runtime: backend +def _bind_backend(mw, backend): + mw._get_backend = lambda runtime: backend return backend @@ -86,8 +103,7 @@ class TestRmStaging: tool_call_id="tc-1", ) - tool = m._create_rm_tool() - result = await tool.coroutine("/documents/notes.md", runtime=runtime) + result = await _tool(m, "rm").coroutine("/documents/notes.md", runtime=runtime) assert hasattr(result, "update"), f"expected Command, got {result!r}" update = result.update @@ -100,31 +116,22 @@ class TestRmStaging: @pytest.mark.asyncio async def test_rejects_documents_root(self): m = _make_middleware() - runtime = _runtime() - tool = m._create_rm_tool() - result = await tool.coroutine("/documents", runtime=runtime) + result = await _tool(m, "rm").coroutine("/documents", runtime=_runtime()) assert isinstance(result, str) assert "refusing to rm" in result @pytest.mark.asyncio async def test_rejects_root(self): m = _make_middleware() - runtime = _runtime() - tool = m._create_rm_tool() - result = await tool.coroutine("/", runtime=runtime) + result = await _tool(m, "rm").coroutine("/", 
runtime=_runtime()) assert isinstance(result, str) assert "refusing to rm" in result @pytest.mark.asyncio async def test_rejects_directory_via_staged_dirs(self): m = _make_middleware() - runtime = _runtime( - { - "staged_dirs": ["/documents/team-x"], - } - ) - tool = m._create_rm_tool() - result = await tool.coroutine("/documents/team-x", runtime=runtime) + runtime = _runtime({"staged_dirs": ["/documents/team-x"]}) + result = await _tool(m, "rm").coroutine("/documents/team-x", runtime=runtime) assert isinstance(result, str) assert "directory" in result.lower() assert "rmdir" in result @@ -138,9 +145,7 @@ class TestRmStaging: children=[{"path": "/documents/foo/x.md", "is_dir": False}] ), ) - runtime = _runtime() - tool = m._create_rm_tool() - result = await tool.coroutine("/documents/foo", runtime=runtime) + result = await _tool(m, "rm").coroutine("/documents/foo", runtime=_runtime()) assert isinstance(result, str) assert "directory" in result.lower() @@ -157,8 +162,9 @@ class TestRmStaging: } } ) - tool = m._create_rm_tool() - result = await tool.coroutine("/documents/uploaded.xml", runtime=runtime) + result = await _tool(m, "rm").coroutine( + "/documents/uploaded.xml", runtime=runtime + ) assert isinstance(result, str) assert "read-only" in result @@ -173,12 +179,9 @@ class TestRmStaging: "dirty_paths": ["/documents/notes.md"], } ) - tool = m._create_rm_tool() - result = await tool.coroutine("/documents/notes.md", runtime=runtime) - update = result.update - # First element is _CLEAR sentinel; the rest must NOT contain the - # rm'd path. - dirty = update.get("dirty_paths") or [] + result = await _tool(m, "rm").coroutine("/documents/notes.md", runtime=runtime) + dirty = result.update.get("dirty_paths") or [] + # First element is the _CLEAR sentinel; the rm'd path must not survive. 
assert "/documents/notes.md" not in dirty[1:] @@ -192,30 +195,19 @@ class TestRmdirStaging: async def test_stages_dir_delete_when_empty_and_db_backed(self): m = _make_middleware() backend = _bind_backend(m, _make_backend_stub(children=[])) - # Override _load_file_data to return None (folder, not a file) and - # parent listing to claim the folder exists. backend._load_file_data = AsyncMock(return_value=None) backend.als_info = AsyncMock( side_effect=[ [], # children of /documents/proj - [ - {"path": "/documents/proj", "is_dir": True}, - ], # parent listing + [{"path": "/documents/proj", "is_dir": True}], # parent listing ] ) - runtime = _runtime( - { - "cwd": "/documents", - }, - tool_call_id="tc-rd", - ) + runtime = _runtime({"cwd": "/documents"}, tool_call_id="tc-rd") - tool = m._create_rmdir_tool() - result = await tool.coroutine("/documents/proj", runtime=runtime) + result = await _tool(m, "rmdir").coroutine("/documents/proj", runtime=runtime) assert hasattr(result, "update") - update = result.update - assert update["pending_dir_deletes"] == [ + assert result.update["pending_dir_deletes"] == [ {"path": "/documents/proj", "tool_call_id": "tc-rd"} ] @@ -228,9 +220,9 @@ class TestRmdirStaging: children=[{"path": "/documents/proj/x.md", "is_dir": False}] ), ) - runtime = _runtime() - tool = m._create_rmdir_tool() - result = await tool.coroutine("/documents/proj", runtime=runtime) + result = await _tool(m, "rmdir").coroutine( + "/documents/proj", runtime=_runtime() + ) assert isinstance(result, str) assert "not empty" in result @@ -239,30 +231,25 @@ class TestRmdirStaging: m = _make_middleware() _bind_backend(m, _make_backend_stub(children=[])) runtime = _runtime( - { - "cwd": "/documents", - "staged_dirs": ["/documents/scratch"], - }, + {"cwd": "/documents", "staged_dirs": ["/documents/scratch"]}, tool_call_id="tc-rd", ) - tool = m._create_rmdir_tool() - result = await tool.coroutine("/documents/scratch", runtime=runtime) + result = await _tool(m, 
"rmdir").coroutine( + "/documents/scratch", runtime=runtime + ) assert hasattr(result, "update") update = result.update assert "pending_dir_deletes" not in update - # _CLEAR sentinel + remaining items (in this case, none). staged_after = update["staged_dirs"] - assert staged_after[0] == "\x00__SURFSENSE_FILESYSTEM_CLEAR__\x00" + assert staged_after[0] == _CLEAR assert "/documents/scratch" not in staged_after[1:] @pytest.mark.asyncio - async def test_rejects_root(self): + async def test_rejects_root_and_documents(self): m = _make_middleware() - runtime = _runtime() - tool = m._create_rmdir_tool() for victim in ("/", "/documents"): - result = await tool.coroutine(victim, runtime=runtime) + result = await _tool(m, "rmdir").coroutine(victim, runtime=_runtime()) assert isinstance(result, str) assert "refusing to rmdir" in result @@ -270,8 +257,7 @@ class TestRmdirStaging: async def test_rejects_cwd(self): m = _make_middleware() runtime = _runtime({"cwd": "/documents/proj"}) - tool = m._create_rmdir_tool() - result = await tool.coroutine("/documents/proj", runtime=runtime) + result = await _tool(m, "rmdir").coroutine("/documents/proj", runtime=runtime) assert isinstance(result, str) assert "cwd" in result.lower() @@ -279,8 +265,7 @@ class TestRmdirStaging: async def test_rejects_ancestor_of_cwd(self): m = _make_middleware() runtime = _runtime({"cwd": "/documents/proj/sub"}) - tool = m._create_rmdir_tool() - result = await tool.coroutine("/documents/proj", runtime=runtime) + result = await _tool(m, "rmdir").coroutine("/documents/proj", runtime=runtime) assert isinstance(result, str) assert "cwd" in result.lower() @@ -288,34 +273,31 @@ class TestRmdirStaging: async def test_rejects_files(self): m = _make_middleware() _bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]})) - runtime = _runtime() - tool = m._create_rmdir_tool() - result = await tool.coroutine("/documents/notes.md", runtime=runtime) + result = await _tool(m, "rmdir").coroutine( + 
"/documents/notes.md", runtime=_runtime() + ) assert isinstance(result, str) assert "is a file" in result # --------------------------------------------------------------------------- -# KBPostgresBackend view filter +# KBPostgresBackend staged-delete view filter (already the live backend) # --------------------------------------------------------------------------- class TestKBPostgresBackendDeleteFilter: - """als_info / glob / grep should suppress paths queued for delete.""" + """``als_info`` / glob / grep must suppress paths queued for delete.""" def _make_backend(self, state: dict[str, Any]) -> KBPostgresBackend: runtime = SimpleNamespace(state=state) - backend = KBPostgresBackend(search_space_id=1, runtime=runtime) - return backend + return KBPostgresBackend(search_space_id=1, runtime=runtime) def test_pending_filesystem_view_returns_deleted_paths(self): backend = self._make_backend( { - "pending_deletes": [ - {"path": "/documents/x.md", "tool_call_id": "t1"}, - ], + "pending_deletes": [{"path": "/documents/x.md", "tool_call_id": "t1"}], "pending_dir_deletes": [ - {"path": "/documents/d1", "tool_call_id": "t2"}, + {"path": "/documents/d1", "tool_call_id": "t2"} ], } ) diff --git a/surfsense_backend/tests/unit/middleware/test_b_filesystem_system_prompt.py b/surfsense_backend/tests/unit/middleware/test_b_filesystem_system_prompt.py new file mode 100644 index 000000000..b68dc5b4b --- /dev/null +++ b/surfsense_backend/tests/unit/middleware/test_b_filesystem_system_prompt.py @@ -0,0 +1,54 @@ +"""Mode-specific system-prompt assembly tests for the LIVE filesystem middleware. + +Ported from ``TestModeSpecificPrompts`` in the former +``tests/unit/middleware/test_filesystem_middleware.py`` (which exercised the +dead twin ``app.agents.chat.shared.middleware.filesystem._build_filesystem_system_prompt``). 
+ +These drive the production ``build_system_prompt`` so the prompt the model +actually receives stays mode-scoped: cloud rules don't leak into desktop +sessions and vice-versa, and the sandbox section appears only when available. + +The per-tool *description* assertions from the old suite are intentionally NOT +ported: they assert exact prompt copy (tightly coupled to the old wording) and +guard prompt token hygiene rather than the code-movement refactor this suite +protects. +""" + +from __future__ import annotations + +import pytest + +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.system_prompt import ( + build_system_prompt, +) + +pytestmark = pytest.mark.unit + + +class TestModeSpecificPrompts: + def test_cloud_prompt_omits_desktop_section(self): + prompt = build_system_prompt(FilesystemMode.CLOUD, sandbox_available=False) + assert "Local Folder Mode" not in prompt + assert "mount-prefixed" not in prompt + assert "Persistence Rules" in prompt + assert "/documents" in prompt + assert "temp_" in prompt + + def test_desktop_prompt_omits_cloud_persistence_rules(self): + prompt = build_system_prompt( + FilesystemMode.DESKTOP_LOCAL_FOLDER, sandbox_available=False + ) + assert "Persistence Rules" not in prompt + assert "Workspace Tree" not in prompt + assert "Local Folder Mode" in prompt + assert "mount-prefixed" in prompt + + def test_sandbox_addendum_appended_when_available(self): + prompt = build_system_prompt(FilesystemMode.CLOUD, sandbox_available=True) + assert "execute_code" in prompt + assert "Code Execution" in prompt + + def test_sandbox_addendum_absent_when_unavailable(self): + prompt = build_system_prompt(FilesystemMode.CLOUD, sandbox_available=False) + assert "execute_code" not in prompt diff --git a/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py b/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py index 
467ba6d5f..91b6bcf3c 100644 --- a/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py +++ b/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py @@ -2,8 +2,10 @@ import pytest from langchain_core.messages import AIMessage from langchain_core.tools import StructuredTool -from app.agents.new_chat.middleware.dedup_tool_calls import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.dedup_hitl import ( DedupHITLToolCallsMiddleware, +) +from app.agents.chat.multi_agent_chat.shared.middleware.dedup_tool_calls import ( wrap_dedup_key_by_arg_name, ) diff --git a/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py b/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py deleted file mode 100644 index 7fd3fe4a7..000000000 --- a/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py +++ /dev/null @@ -1,214 +0,0 @@ -import pytest -from langchain_core.messages import AIMessage, HumanMessage - -from app.agents.new_chat.middleware.file_intent import ( - FileIntentMiddleware, - FileOperationIntent, - _fallback_path, -) - -pytestmark = pytest.mark.unit - - -class _FakeLLM: - def __init__(self, response_text: str): - self._response_text = response_text - - async def ainvoke(self, *_args, **_kwargs): - return AIMessage(content=self._response_text) - - -@pytest.mark.asyncio -async def test_file_write_intent_injects_contract_message(): - llm = _FakeLLM( - '{"intent":"file_write","confidence":0.93,"suggested_filename":"ideas.md"}' - ) - middleware = FileIntentMiddleware(llm=llm) - state = { - "messages": [HumanMessage(content="Create another random note for me")], - "turn_id": "123:456", - } - - result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] - - assert result is not None - contract = result["file_operation_contract"] - assert contract["intent"] == FileOperationIntent.FILE_WRITE.value - assert contract["suggested_path"] == "/ideas.md" - assert 
contract["turn_id"] == "123:456" - assert any( - "file_operation_contract" in str(msg.content) - for msg in result["messages"] - if hasattr(msg, "content") - ) - - -@pytest.mark.asyncio -async def test_non_write_intent_does_not_inject_contract_message(): - llm = _FakeLLM('{"intent":"file_read","confidence":0.88,"suggested_filename":null}') - middleware = FileIntentMiddleware(llm=llm) - original_messages = [HumanMessage(content="Read /notes.md")] - state = {"messages": original_messages, "turn_id": "abc:def"} - - result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] - - assert result is not None - assert ( - result["file_operation_contract"]["intent"] - == FileOperationIntent.FILE_READ.value - ) - assert "messages" not in result - - -@pytest.mark.asyncio -async def test_file_write_null_filename_uses_semantic_default_path(): - llm = _FakeLLM( - '{"intent":"file_write","confidence":0.74,"suggested_filename":null}' - ) - middleware = FileIntentMiddleware(llm=llm) - state = { - "messages": [HumanMessage(content="create a random markdown file")], - "turn_id": "turn:1", - } - - result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] - - assert result is not None - contract = result["file_operation_contract"] - assert contract["intent"] == FileOperationIntent.FILE_WRITE.value - assert contract["suggested_path"] == "/notes.md" - - -@pytest.mark.asyncio -async def test_file_write_null_filename_defaults_to_markdown_path(): - llm = _FakeLLM( - '{"intent":"file_write","confidence":0.71,"suggested_filename":null}' - ) - middleware = FileIntentMiddleware(llm=llm) - state = { - "messages": [HumanMessage(content="create a sample json config file")], - "turn_id": "turn:2", - } - - result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] - - assert result is not None - contract = result["file_operation_contract"] - assert contract["intent"] == FileOperationIntent.FILE_WRITE.value - assert 
contract["suggested_path"] == "/notes.md" - - -@pytest.mark.asyncio -async def test_file_write_txt_suggestion_is_normalized_to_markdown(): - llm = _FakeLLM( - '{"intent":"file_write","confidence":0.82,"suggested_filename":"random.txt"}' - ) - middleware = FileIntentMiddleware(llm=llm) - state = { - "messages": [HumanMessage(content="create a random file")], - "turn_id": "turn:3", - } - - result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] - - assert result is not None - contract = result["file_operation_contract"] - assert contract["intent"] == FileOperationIntent.FILE_WRITE.value - assert contract["suggested_path"] == "/random.md" - - -@pytest.mark.asyncio -async def test_file_write_with_suggested_directory_preserves_folder(): - llm = _FakeLLM( - '{"intent":"file_write","confidence":0.86,"suggested_filename":"random.md","suggested_directory":"pc backups","suggested_path":null}' - ) - middleware = FileIntentMiddleware(llm=llm) - state = { - "messages": [HumanMessage(content="create a random file in pc backups folder")], - "turn_id": "turn:4", - } - - result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] - - assert result is not None - contract = result["file_operation_contract"] - assert contract["intent"] == FileOperationIntent.FILE_WRITE.value - assert contract["suggested_path"] == "/pc_backups/random.md" - - -@pytest.mark.asyncio -async def test_file_write_with_suggested_path_takes_precedence(): - llm = _FakeLLM( - '{"intent":"file_write","confidence":0.9,"suggested_filename":"ignored.md","suggested_directory":"docs","suggested_path":"/reports/q2/summary.md"}' - ) - middleware = FileIntentMiddleware(llm=llm) - state = { - "messages": [HumanMessage(content="create report")], - "turn_id": "turn:5", - } - - result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] - - assert result is not None - contract = result["file_operation_contract"] - assert contract["intent"] == 
FileOperationIntent.FILE_WRITE.value - assert contract["suggested_path"] == "/reports/q2/summary.md" - - -@pytest.mark.asyncio -async def test_file_write_infers_directory_from_user_text_when_missing(): - llm = _FakeLLM( - '{"intent":"file_write","confidence":0.83,"suggested_filename":"random.md","suggested_directory":null,"suggested_path":null}' - ) - middleware = FileIntentMiddleware(llm=llm) - state = { - "messages": [HumanMessage(content="create a random file in pc backups folder")], - "turn_id": "turn:6", - } - - result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] - - assert result is not None - contract = result["file_operation_contract"] - assert contract["intent"] == FileOperationIntent.FILE_WRITE.value - assert contract["suggested_path"] == "/pc_backups/random.md" - - -def test_fallback_path_normalizes_windows_slashes() -> None: - resolved = _fallback_path( - suggested_filename="summary.md", - suggested_path=r"\reports\q2\summary.md", - user_text="create report", - ) - - assert resolved == "/reports/q2/summary.md" - - -def test_fallback_path_normalizes_windows_drive_path() -> None: - resolved = _fallback_path( - suggested_filename=None, - suggested_path=r"C:\Users\anish\notes\todo.md", - user_text="create note", - ) - - assert resolved == "/C/Users/anish/notes/todo.md" - - -def test_fallback_path_normalizes_mixed_separators_and_duplicate_slashes() -> None: - resolved = _fallback_path( - suggested_filename="summary.md", - suggested_path=r"\\reports\\q2//summary.md", - user_text="create report", - ) - - assert resolved == "/reports/q2/summary.md" - - -def test_fallback_path_keeps_posix_style_absolute_path_for_linux_and_macos() -> None: - resolved = _fallback_path( - suggested_filename=None, - suggested_path="/var/log/surfsense/notes.md", - user_text="create note", - ) - - assert resolved == "/var/log/surfsense/notes.md" diff --git a/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py 
b/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py index c71b5efde..dafda17d2 100644 --- a/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py +++ b/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py @@ -2,16 +2,18 @@ from pathlib import Path import pytest -from app.agents.new_chat.filesystem_backends import build_backend_resolver -from app.agents.new_chat.filesystem_selection import ( +from app.agents.chat.multi_agent_chat.shared.filesystem_selection import ( ClientPlatform, FilesystemMode, FilesystemSelection, LocalFilesystemMount, ) -from app.agents.new_chat.middleware.multi_root_local_folder_backend import ( +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.multi_root_local_folder import ( MultiRootLocalFolderBackend, ) +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.resolver import ( + build_backend_resolver, +) pytestmark = pytest.mark.unit diff --git a/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py b/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py deleted file mode 100644 index 70430f4ca..000000000 --- a/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py +++ /dev/null @@ -1,220 +0,0 @@ -"""Unit tests for the SurfSense filesystem middleware new behaviors. - -Covers: -* cloud cwd defaults to ``/documents`` and relative paths resolve under it -* cloud writes outside ``/documents/`` are rejected unless basename starts - with ``temp_`` -* cloud writes/edits to the anonymous document are rejected (read-only) -* helper methods on the middleware (``_resolve_relative``, - ``_check_cloud_write_namespace``, ``_default_cwd``) - -These tests use ``__new__`` to bypass the heavy ``__init__`` and exercise -the helper methods directly so the test surface stays narrow and fast. 
-""" - -from __future__ import annotations - -from types import SimpleNamespace - -import pytest - -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.middleware.filesystem import ( - SurfSenseFilesystemMiddleware, - _build_filesystem_system_prompt, - _build_tool_descriptions, -) - -pytestmark = pytest.mark.unit - - -def _make_middleware(mode: FilesystemMode = FilesystemMode.CLOUD): - middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) - middleware._filesystem_mode = mode - return middleware - - -def _runtime(state: dict | None = None) -> SimpleNamespace: - return SimpleNamespace(state=state or {}) - - -class TestCloudCwdDefaults: - def test_default_cwd_in_cloud_is_documents_root(self): - m = _make_middleware() - assert m._default_cwd() == "/documents" - - def test_default_cwd_in_desktop_is_root(self): - m = _make_middleware(FilesystemMode.DESKTOP_LOCAL_FOLDER) - assert m._default_cwd() == "/" - - def test_current_cwd_uses_state_when_set(self): - m = _make_middleware() - runtime = _runtime({"cwd": "/documents/notes"}) - assert m._current_cwd(runtime) == "/documents/notes" - - def test_current_cwd_falls_back_to_default(self): - m = _make_middleware() - runtime = _runtime({}) - assert m._current_cwd(runtime) == "/documents" - - def test_current_cwd_ignores_invalid(self): - m = _make_middleware() - runtime = _runtime({"cwd": "not-absolute"}) - assert m._current_cwd(runtime) == "/documents" - - -class TestRelativePathResolution: - def test_relative_path_resolves_against_cwd(self): - m = _make_middleware() - runtime = _runtime({"cwd": "/documents/projects"}) - assert ( - m._resolve_relative("notes.md", runtime) == "/documents/projects/notes.md" - ) - - def test_relative_path_with_dotdot(self): - m = _make_middleware() - runtime = _runtime({"cwd": "/documents/a/b"}) - assert m._resolve_relative("../c.md", runtime) == "/documents/a/c.md" - - def test_absolute_path_is_kept(self): - m = 
_make_middleware() - runtime = _runtime({"cwd": "/documents"}) - assert m._resolve_relative("/other/x.md", runtime) == "/other/x.md" - - def test_empty_path_returns_cwd(self): - m = _make_middleware() - runtime = _runtime({"cwd": "/documents/projects"}) - assert m._resolve_relative("", runtime) == "/documents/projects" - - -class TestCloudWriteNamespacePolicy: - def test_documents_path_allowed(self): - m = _make_middleware() - runtime = _runtime() - assert m._check_cloud_write_namespace("/documents/foo.md", runtime) is None - - def test_documents_root_allowed(self): - m = _make_middleware() - runtime = _runtime() - assert m._check_cloud_write_namespace("/documents", runtime) is None - - def test_temp_basename_anywhere_allowed(self): - m = _make_middleware() - runtime = _runtime() - assert m._check_cloud_write_namespace("/temp_scratch.md", runtime) is None - assert m._check_cloud_write_namespace("/foo/temp_x.md", runtime) is None - assert m._check_cloud_write_namespace("/documents/temp_x.md", runtime) is None - - def test_other_paths_rejected(self): - m = _make_middleware() - runtime = _runtime() - err = m._check_cloud_write_namespace("/foo/bar.md", runtime) - assert err is not None - assert "must target /documents" in err - - def test_anon_doc_path_is_read_only(self): - m = _make_middleware() - runtime = _runtime( - { - "kb_anon_doc": { - "path": "/documents/uploaded.xml", - "title": "uploaded", - "content": "", - "chunks": [], - } - } - ) - err = m._check_cloud_write_namespace("/documents/uploaded.xml", runtime) - assert err is not None - assert "read-only" in err - - def test_desktop_mode_skips_namespace_policy(self): - m = _make_middleware(FilesystemMode.DESKTOP_LOCAL_FOLDER) - runtime = _runtime() - assert m._check_cloud_write_namespace("/random/path.md", runtime) is None - - -class TestModeSpecificPrompts: - """The prompt and tool descriptions must only describe the active mode. 
- - Cross-mode noise wastes tokens and confuses the model with rules it - cannot use this session. - """ - - def test_cloud_prompt_omits_desktop_section(self): - prompt = _build_filesystem_system_prompt( - FilesystemMode.CLOUD, sandbox_available=False - ) - assert "Local Folder Mode" not in prompt - assert "mount-prefixed" not in prompt - assert "Persistence Rules" in prompt - assert "/documents" in prompt - assert "temp_" in prompt - - def test_desktop_prompt_omits_cloud_persistence_rules(self): - prompt = _build_filesystem_system_prompt( - FilesystemMode.DESKTOP_LOCAL_FOLDER, sandbox_available=False - ) - assert "Persistence Rules" not in prompt - assert "Workspace Tree" not in prompt - assert "" not in prompt - assert "Local Folder Mode" in prompt - assert "mount-prefixed" in prompt - - def test_cloud_tool_descs_omit_desktop_phrases(self): - descs = _build_tool_descriptions(FilesystemMode.CLOUD) - for name in ( - "write_file", - "edit_file", - "move_file", - "mkdir", - "rm", - "rmdir", - "list_tree", - "grep", - ): - text = descs[name] - assert "Desktop" not in text, f"{name} leaks desktop hints" - assert "Cloud mode:" not in text, f"{name} qualifies a cloud-only desc" - - def test_desktop_tool_descs_omit_cloud_phrases(self): - descs = _build_tool_descriptions(FilesystemMode.DESKTOP_LOCAL_FOLDER) - for name in ( - "write_file", - "edit_file", - "move_file", - "mkdir", - "rm", - "rmdir", - "list_tree", - "grep", - ): - text = descs[name] - assert "Cloud" not in text, f"{name} leaks cloud hints" - assert "/documents/" not in text, f"{name} mentions cloud namespace" - assert "temp_" not in text, f"{name} mentions cloud temp_ semantics" - - def test_cloud_descs_include_rm_and_rmdir(self): - descs = _build_tool_descriptions(FilesystemMode.CLOUD) - assert "rm" in descs and "rmdir" in descs - assert "Deletes a single file" in descs["rm"] - assert "Deletes an empty directory" in descs["rmdir"] - assert "rmdir" in descs["rmdir"] and "POSIX" in descs["rmdir"] - - def 
test_desktop_descs_warn_about_irreversibility(self): - descs = _build_tool_descriptions(FilesystemMode.DESKTOP_LOCAL_FOLDER) - assert "NOT reversible" in descs["rm"] - assert "NOT reversible" in descs["rmdir"] - - def test_sandbox_addendum_appended_when_available(self): - prompt = _build_filesystem_system_prompt( - FilesystemMode.CLOUD, sandbox_available=True - ) - assert "execute_code" in prompt - assert "Code Execution" in prompt - - def test_sandbox_addendum_absent_when_unavailable(self): - prompt = _build_filesystem_system_prompt( - FilesystemMode.CLOUD, sandbox_available=False - ) - assert "execute_code" not in prompt diff --git a/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py b/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py deleted file mode 100644 index 81cf590d3..000000000 --- a/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py +++ /dev/null @@ -1,173 +0,0 @@ -from pathlib import Path - -import pytest - -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware -from app.agents.new_chat.middleware.multi_root_local_folder_backend import ( - MultiRootLocalFolderBackend, -) - -pytestmark = pytest.mark.unit - - -class _RuntimeNoSuggestedPath: - state = {"file_operation_contract": {}} - - -class _RuntimeWithSuggestedPath: - def __init__(self, suggested_path: str) -> None: - self.state = {"file_operation_contract": {"suggested_path": suggested_path}} - - -def test_contract_suggested_path_falls_back_to_documents_notes_md() -> None: - middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) - middleware._filesystem_mode = FilesystemMode.CLOUD - suggested = middleware._get_contract_suggested_path(_RuntimeNoSuggestedPath()) # type: ignore[arg-type] - # Cloud default cwd is /documents so the fallback lands in the KB. 
- assert suggested == "/documents/notes.md" - - -def test_contract_suggested_path_falls_back_to_root_notes_md_in_desktop() -> None: - middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) - middleware._filesystem_mode = FilesystemMode.DESKTOP_LOCAL_FOLDER - suggested = middleware._get_contract_suggested_path(_RuntimeNoSuggestedPath()) # type: ignore[arg-type] - assert suggested == "/notes.md" - - -def test_normalize_local_mount_path_prefixes_default_mount(tmp_path: Path) -> None: - root = tmp_path / "PC Backups" - root.mkdir() - backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),)) - runtime = _RuntimeNoSuggestedPath() - middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) - middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign] - - resolved = middleware._normalize_local_mount_path("/random-note.md", runtime) # type: ignore[arg-type] - - assert resolved == "/pc_backups/random-note.md" - - -def test_normalize_local_mount_path_keeps_explicit_mount(tmp_path: Path) -> None: - root = tmp_path / "PC Backups" - root.mkdir() - backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),)) - runtime = _RuntimeNoSuggestedPath() - middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) - middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign] - - resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type] - "/pc_backups/notes/random-note.md", - runtime, - ) - - assert resolved == "/pc_backups/notes/random-note.md" - - -def test_normalize_local_mount_path_windows_backslashes(tmp_path: Path) -> None: - root = tmp_path / "PC Backups" - root.mkdir() - backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),)) - runtime = _RuntimeNoSuggestedPath() - middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) - middleware._get_backend = lambda _runtime: backend # type: 
ignore[method-assign] - - resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type] - r"\notes\random-note.md", - runtime, - ) - - assert resolved == "/pc_backups/notes/random-note.md" - - -def test_normalize_local_mount_path_normalizes_mixed_separators(tmp_path: Path) -> None: - root = tmp_path / "PC Backups" - root.mkdir() - backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),)) - runtime = _RuntimeNoSuggestedPath() - middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) - middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign] - - resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type] - r"\\notes//nested\\random-note.md", - runtime, - ) - - assert resolved == "/pc_backups/notes/nested/random-note.md" - - -def test_normalize_local_mount_path_keeps_explicit_mount_with_backslashes( - tmp_path: Path, -) -> None: - root = tmp_path / "PC Backups" - root.mkdir() - backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),)) - runtime = _RuntimeNoSuggestedPath() - middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) - middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign] - - resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type] - r"\pc_backups\notes\random-note.md", - runtime, - ) - - assert resolved == "/pc_backups/notes/random-note.md" - - -def test_normalize_local_mount_path_prefixes_posix_absolute_path_for_linux_and_macos( - tmp_path: Path, -) -> None: - root = tmp_path / "PC Backups" - root.mkdir() - backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),)) - runtime = _RuntimeNoSuggestedPath() - middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) - middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign] - - resolved = middleware._normalize_local_mount_path("/var/log/app.log", runtime) # type: ignore[arg-type] - - 
assert resolved == "/pc_backups/var/log/app.log" - - -def test_normalize_local_mount_path_prefers_unique_existing_parent_mount( - tmp_path: Path, -) -> None: - root_a = tmp_path / "RootA" - root_b = tmp_path / "RootB" - (root_a / "other").mkdir(parents=True) - (root_b / "nested" / "deep").mkdir(parents=True) - backend = MultiRootLocalFolderBackend( - (("root_a", str(root_a)), ("root_b", str(root_b))) - ) - runtime = _RuntimeNoSuggestedPath() - middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) - middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign] - - resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type] - "/nested/deep/new-note.md", - runtime, - ) - - assert resolved == "/root_b/nested/deep/new-note.md" - - -def test_normalize_local_mount_path_uses_suggested_mount_when_ambiguous( - tmp_path: Path, -) -> None: - root_a = tmp_path / "RootA" - root_b = tmp_path / "RootB" - root_a.mkdir(parents=True) - root_b.mkdir(parents=True) - backend = MultiRootLocalFolderBackend( - (("root_a", str(root_a)), ("root_b", str(root_b))) - ) - runtime = _RuntimeWithSuggestedPath("/root_b/notes/context.md") - middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) - middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign] - - resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type] - "/brand-new-note.md", - runtime, - ) - - assert resolved == "/root_b/brand-new-note.md" diff --git a/surfsense_backend/tests/unit/middleware/test_kb_persistence_filesystem_parity.py b/surfsense_backend/tests/unit/middleware/test_kb_persistence_filesystem_parity.py index ef95434bf..7724a4852 100644 --- a/surfsense_backend/tests/unit/middleware/test_kb_persistence_filesystem_parity.py +++ b/surfsense_backend/tests/unit/middleware/test_kb_persistence_filesystem_parity.py @@ -15,7 +15,7 @@ from unittest.mock import AsyncMock import numpy as np import pytest -from 
app.agents.new_chat.middleware import kb_persistence +from app.agents.chat.multi_agent_chat.main_agent.middleware.kb_persistence import middleware as kb_persistence from app.db import Document diff --git a/surfsense_backend/tests/unit/middleware/test_kb_persistence_revisions.py b/surfsense_backend/tests/unit/middleware/test_kb_persistence_revisions.py index feca23d27..500c6cc60 100644 --- a/surfsense_backend/tests/unit/middleware/test_kb_persistence_revisions.py +++ b/surfsense_backend/tests/unit/middleware/test_kb_persistence_revisions.py @@ -21,7 +21,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest -from app.agents.new_chat.middleware import kb_persistence +from app.agents.chat.multi_agent_chat.main_agent.middleware.kb_persistence import middleware as kb_persistence pytestmark = pytest.mark.unit diff --git a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py index 3529a946b..25de7308d 100644 --- a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py +++ b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py @@ -5,10 +5,13 @@ import json import pytest from langchain_core.messages import AIMessage, HumanMessage -from app.agents.new_chat.document_xml import build_document_xml as _build_document_xml -from app.agents.new_chat.middleware.knowledge_search import ( +from app.agents.chat.multi_agent_chat.shared.middleware import knowledge_search as ks +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.document_xml import ( + build_document_xml as _build_document_xml, +) +from app.agents.chat.multi_agent_chat.shared.middleware.knowledge_search import ( KBSearchPlan, - KnowledgeBaseSearchMiddleware, + KnowledgePriorityMiddleware, _normalize_optional_date_range, _parse_kb_search_plan_response, _render_recent_conversation, @@ -201,7 +204,7 @@ class FakeBudgetLLM: return sum(len(msg.get("content", "")) for msg in messages) -class 
TestKnowledgeBaseSearchMiddlewarePlanner: +class TestKnowledgePriorityMiddlewarePlanner: @pytest.fixture(autouse=True) def _disable_planner_runnable(self, monkeypatch): # ``FakeLLM`` is a duck-typed mock; ``create_agent`` (used when the @@ -258,7 +261,7 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: return [] monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + ks, "search_knowledge_base", fake_search_knowledge_base, ) @@ -271,7 +274,7 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: } ) ) - middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=37) + middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=37) result = await middleware.abefore_agent( { @@ -301,11 +304,11 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: return [] monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + ks, "search_knowledge_base", fake_search_knowledge_base, ) - middleware = KnowledgeBaseSearchMiddleware( + middleware = KnowledgePriorityMiddleware( llm=FakeLLM("not json"), search_space_id=37, ) @@ -330,11 +333,11 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: return [] monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + ks, "search_knowledge_base", fake_search_knowledge_base, ) - middleware = KnowledgeBaseSearchMiddleware( + middleware = KnowledgePriorityMiddleware( llm=FakeLLM( json.dumps( { @@ -375,11 +378,11 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: return [] monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.browse_recent_documents", + ks, "browse_recent_documents", fake_browse_recent_documents, ) monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + ks, "search_knowledge_base", fake_search_knowledge_base, ) @@ -393,7 +396,7 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: } ) ) - middleware = KnowledgeBaseSearchMiddleware(llm=llm, 
search_space_id=42) + middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=42) result = await middleware.abefore_agent( {"messages": [HumanMessage(content="what's my latest file?")]}, @@ -422,11 +425,11 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: return [] monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.browse_recent_documents", + ks, "browse_recent_documents", fake_browse_recent_documents, ) monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + ks, "search_knowledge_base", fake_search_knowledge_base, ) @@ -440,7 +443,7 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: } ) ) - middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=42) + middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=42) await middleware.abefore_agent( {"messages": [HumanMessage(content="find the quarterly revenue report")]}, @@ -549,15 +552,15 @@ class TestKnowledgePriorityMentionDrain: return [] monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents", + ks, "fetch_mentioned_documents", fake_fetch_mentioned_documents, ) monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + ks, "search_knowledge_base", fake_search_knowledge_base, ) - middleware = KnowledgeBaseSearchMiddleware( + middleware = KnowledgePriorityMiddleware( llm=self._planner_llm(), search_space_id=42, mentioned_document_ids=[1, 2, 3], @@ -597,17 +600,17 @@ class TestKnowledgePriorityMentionDrain: return [] monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents", + ks, "fetch_mentioned_documents", fake_fetch_mentioned_documents, ) monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + ks, "search_knowledge_base", fake_search_knowledge_base, ) # Simulate a cached middleware instance whose closure was seeded # by a previous turn's cache-miss build 
(mentions=[1,2,3]). - middleware = KnowledgeBaseSearchMiddleware( + middleware = KnowledgePriorityMiddleware( llm=self._planner_llm(), search_space_id=42, mentioned_document_ids=[1, 2, 3], @@ -642,15 +645,15 @@ class TestKnowledgePriorityMentionDrain: return [] monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents", + ks, "fetch_mentioned_documents", fake_fetch_mentioned_documents, ) monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + ks, "search_knowledge_base", fake_search_knowledge_base, ) - middleware = KnowledgeBaseSearchMiddleware( + middleware = KnowledgePriorityMiddleware( llm=self._planner_llm(), search_space_id=42, mentioned_document_ids=[7, 8], diff --git a/surfsense_backend/tests/unit/middleware/test_knowledge_tree.py b/surfsense_backend/tests/unit/middleware/test_knowledge_tree.py index caaec3114..c14eca080 100644 --- a/surfsense_backend/tests/unit/middleware/test_knowledge_tree.py +++ b/surfsense_backend/tests/unit/middleware/test_knowledge_tree.py @@ -9,8 +9,10 @@ contract cannot silently regress. 
from __future__ import annotations -from app.agents.new_chat.middleware.knowledge_tree import KnowledgeTreeMiddleware -from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT +from app.agents.chat.multi_agent_chat.main_agent.middleware.knowledge_tree.middleware import ( + KnowledgeTreeMiddleware, +) +from app.agents.chat.runtime.path_resolver import DOCUMENTS_ROOT def _compute(folder_paths: list[str], doc_paths: list[str]) -> set[str]: @@ -86,7 +88,7 @@ class TestFormatTreeRendering: folder_paths: list[str], doc_specs: list[dict], ) -> str: - from app.agents.new_chat.path_resolver import PathIndex + from app.agents.chat.runtime.path_resolver import PathIndex index = PathIndex( folder_paths={i + 1: p for i, p in enumerate(folder_paths)}, diff --git a/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py b/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py index 6e81ecf8e..aaa3b47fb 100644 --- a/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py +++ b/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py @@ -2,7 +2,9 @@ from pathlib import Path import pytest -from app.agents.new_chat.middleware.local_folder_backend import LocalFolderBackend +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.local_folder import ( + LocalFolderBackend, +) pytestmark = pytest.mark.unit diff --git a/surfsense_backend/tests/unit/middleware/test_multi_root_local_folder_backend.py b/surfsense_backend/tests/unit/middleware/test_multi_root_local_folder_backend.py index 43a671178..b2d545f27 100644 --- a/surfsense_backend/tests/unit/middleware/test_multi_root_local_folder_backend.py +++ b/surfsense_backend/tests/unit/middleware/test_multi_root_local_folder_backend.py @@ -2,7 +2,7 @@ from pathlib import Path import pytest -from app.agents.new_chat.middleware.multi_root_local_folder_backend import ( +from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.multi_root_local_folder 
import ( MultiRootLocalFolderBackend, ) diff --git a/surfsense_backend/tests/unit/notifications/api/test_transform.py b/surfsense_backend/tests/unit/notifications/api/test_transform.py new file mode 100644 index 000000000..ba12ab3cf --- /dev/null +++ b/surfsense_backend/tests/unit/notifications/api/test_transform.py @@ -0,0 +1,94 @@ +"""Unit tests for pure notifications API request/response helpers.""" + +from __future__ import annotations + +import uuid +from datetime import UTC, datetime + +import pytest + +from app.notifications.api.transform import ( + parse_before_date, + parse_source_type, + to_response, +) +from app.notifications.persistence import Notification + +pytestmark = pytest.mark.unit + + +class TestParseSourceType: + def test_connector_prefix(self): + """A 'connector:' filter selects the connector types and JSONB facet.""" + parsed = parse_source_type("connector:GITHUB_CONNECTOR") + assert parsed.types == ("connector_indexing", "connector_deletion") + assert parsed.metadata_key == "connector_type" + assert parsed.value == "GITHUB_CONNECTOR" + + def test_doctype_prefix(self): + """A 'doctype:' filter selects the document type and JSONB facet.""" + parsed = parse_source_type("doctype:FILE") + assert parsed.types == ("document_processing",) + assert parsed.metadata_key == "document_type" + assert parsed.value == "FILE" + + def test_unknown_prefix_returns_none(self): + """An unrecognized prefix yields no filter.""" + assert parse_source_type("mystery:thing") is None + + +class TestParseBeforeDate: + def test_parses_iso_with_zulu(self): + """An ISO date with a 'Z' suffix parses to a UTC datetime.""" + parsed = parse_before_date("2024-01-15T00:00:00Z") + assert parsed == datetime(2024, 1, 15, tzinfo=UTC) + + def test_invalid_raises_value_error(self): + """A malformed date raises ValueError for the endpoint to turn into a 400.""" + with pytest.raises(ValueError): + parse_before_date("not-a-date") + + +def _notification(**overrides) -> Notification: + 
defaults = dict( + id=1, + user_id=uuid.uuid4(), + search_space_id=3, + type="document_processing", + title="Title", + message="Message", + read=False, + notification_metadata={"k": "v"}, + created_at=datetime(2024, 1, 1, tzinfo=UTC), + updated_at=datetime(2024, 1, 2, tzinfo=UTC), + ) + defaults.update(overrides) + return Notification(**defaults) + + +class TestToResponse: + def test_maps_core_fields(self): + """A persisted notification maps its core fields onto the response shape.""" + notification = _notification() + response = to_response(notification) + assert response.id == 1 + assert response.user_id == str(notification.user_id) + assert response.type == "document_processing" + assert response.metadata == {"k": "v"} + assert response.created_at == "2024-01-01T00:00:00+00:00" + assert response.updated_at == "2024-01-02T00:00:00+00:00" + + def test_missing_updated_at_maps_to_none(self): + """A missing updated_at is represented as None in the response.""" + response = to_response(_notification(updated_at=None)) + assert response.updated_at is None + + def test_missing_created_at_maps_to_empty_string(self): + """A missing created_at is represented as an empty string in the response.""" + response = to_response(_notification(created_at=None)) + assert response.created_at == "" + + def test_null_metadata_maps_to_empty_dict(self): + """Null metadata is normalized to an empty dict in the response.""" + response = to_response(_notification(notification_metadata=None)) + assert response.metadata == {} diff --git a/surfsense_backend/tests/unit/notifications/service/messages/test_connector_indexing.py b/surfsense_backend/tests/unit/notifications/service/messages/test_connector_indexing.py new file mode 100644 index 000000000..391ce4466 --- /dev/null +++ b/surfsense_backend/tests/unit/notifications/service/messages/test_connector_indexing.py @@ -0,0 +1,176 @@ +"""Unit tests for connector-indexing presentation logic.""" + +from __future__ import annotations + +import 
pytest + +from app.notifications.service.messages import connector_indexing as msg + +pytestmark = pytest.mark.unit + + +class TestOperationId: + def test_encodes_connector_id(self): + """The operation id embeds the connector id.""" + assert msg.operation_id(42).startswith("connector_42_") + + def test_appends_date_range_when_given(self): + """A start/end date range is appended to the operation id.""" + op = msg.operation_id(42, start_date="2024-01-01", end_date="2024-02-01") + assert op.endswith("_2024-01-01_2024-02-01") + + def test_uses_none_placeholder_for_open_ended_range(self): + """A missing range bound is encoded as the 'none' placeholder.""" + assert msg.operation_id(42, start_date="2024-01-01").endswith( + "_2024-01-01_none" + ) + + def test_google_drive_encodes_counts(self): + """The Drive operation id embeds connector id plus folder/file counts.""" + op = msg.google_drive_operation_id(7, folder_count=2, file_count=5) + assert op.startswith("drive_7_") + assert op.endswith("_2f_5files") + + +class TestProgress: + def test_known_stage_maps_to_message(self): + """A known stage maps to its user-facing message and is recorded.""" + message, meta = msg.progress(3, stage="fetching") + assert message == "Fetching your content" + assert meta["indexed_count"] == 3 + assert meta["sync_stage"] == "fetching" + + def test_unknown_stage_falls_back_to_processing(self): + """An unrecognized stage falls back to a generic 'Processing' message.""" + message, _ = msg.progress(1, stage="weird") + assert message == "Processing" + + def test_stage_message_overrides_mapping(self): + """An explicit stage message overrides the stage-to-message mapping.""" + message, _ = msg.progress(1, stage="fetching", stage_message="Custom") + assert message == "Custom" + + def test_no_stage_uses_legacy_default(self): + """With neither stage nor message, the legacy default message is used.""" + message, meta = msg.progress(1) + assert message == "Fetching your content" + assert "sync_stage" not 
in meta + + def test_total_count_yields_percent(self): + """Supplying a total count produces a progress percentage.""" + _, meta = msg.progress(5, total_count=10) + assert meta["total_count"] == 10 + assert meta["progress_percent"] == 50 + + +class TestRetry: + def test_strips_workspace_suffix_from_connector_name(self): + """The provider name is derived by stripping the workspace suffix.""" + message, _ = msg.retry("Notion - My Workspace", 0, "rate_limit", 1, 3) + assert message == "Notion rate limit reached. Retrying..." + + def test_explicit_service_name_wins(self): + """An explicit service name overrides the connector-derived name.""" + message, _ = msg.retry( + "Notion - WS", 0, "rate_limit", 1, 3, service_name="Slack" + ) + assert message.startswith("Slack rate limit reached") + + @pytest.mark.parametrize( + ("reason", "expected"), + [ + ("rate_limit", "Notion rate limit reached"), + ("server_error", "Notion is slow to respond"), + ("timeout", "Notion took too long"), + ("temporary_error", "Notion temporarily unavailable"), + ("something_else", "Waiting for Notion"), + ], + ) + def test_reason_wording(self, reason, expected): + """Each retry reason maps to its wording; unknown reasons get a fallback.""" + message, _ = msg.retry("Notion", 0, reason, 1, 3) + assert message.startswith(expected) + + def test_long_wait_shows_seconds(self): + """A wait longer than the threshold surfaces the retry delay in seconds.""" + message, _ = msg.retry("Notion", 0, "rate_limit", 1, 3, wait_seconds=10) + assert "Retrying in 10s..." 
in message + + def test_short_wait_is_hidden(self): + """A short wait is not worth showing, so no seconds are surfaced.""" + message, _ = msg.retry("Notion", 0, "rate_limit", 1, 3, wait_seconds=3) + assert message.endswith("Retrying...") + + def test_synced_count_suffix_singular_and_plural(self): + """Already-synced items are appended with correct singular/plural wording.""" + one, _ = msg.retry("Notion", 1, "rate_limit", 1, 3) + many, _ = msg.retry("Notion", 2, "rate_limit", 1, 3) + assert one.endswith("(1 item synced so far)") + assert many.endswith("(2 items synced so far)") + + def test_metadata_records_retry_state(self): + """Retry metadata captures the attempt, reason, and wait state.""" + _, meta = msg.retry("Notion", 0, "rate_limit", 2, 5, wait_seconds=8) + assert meta["sync_stage"] == "waiting_retry" + assert meta["retry_attempt"] == 2 + assert meta["retry_max_attempts"] == 5 + assert meta["retry_reason"] == "rate_limit" + assert meta["retry_wait_seconds"] == 8 + + +class TestCompletion: + def test_clean_success_plural(self): + """A clean multi-file sync reports ready/completed with plural wording.""" + title, message, status, meta = msg.completion("GitHub", 3) + assert title == "Ready: GitHub" + assert message == "Now searchable! 3 files synced." + assert status == "completed" + assert meta["sync_stage"] == "completed" + + def test_clean_success_singular(self): + """A single synced file uses singular 'file' wording.""" + _, message, _, _ = msg.completion("GitHub", 1) + assert message == "Now searchable! 1 file synced." + + def test_nothing_to_sync(self): + """Zero new items with no error reports 'Already up to date!'.""" + _, message, status, _ = msg.completion("GitHub", 0) + assert message == "Already up to date!" 
+ assert status == "completed" + + def test_hard_failure(self): + """An error with nothing synced reports a hard failure.""" + title, message, status, meta = msg.completion("GitHub", 0, error_message="boom") + assert title == "Failed: GitHub" + assert message == "Sync failed: boom" + assert status == "failed" + assert meta["sync_stage"] == "failed" + + def test_partial_success_with_error_note(self): + """An error after partial progress still completes, with an appended note.""" + title, message, status, _ = msg.completion("GitHub", 2, error_message="flaky") + assert title == "Ready: GitHub" + assert message == "Now searchable! 2 files synced. Note: flaky" + assert status == "completed" + + def test_warning_is_treated_as_complete(self): + """A warning-level error completes the run rather than failing it.""" + title, message, status, _ = msg.completion( + "GitHub", 0, error_message="partial", is_warning=True + ) + assert title == "Ready: GitHub" + assert message == "Sync complete. partial" + assert status == "completed" + + def test_unsupported_files_note_singular_and_plural(self): + """Unsupported-file counts are described with correct singular/plural wording.""" + _, one, _, _ = msg.completion("GitHub", 2, unsupported_count=1) + _, many, _, _ = msg.completion("GitHub", 2, unsupported_count=3) + assert "1 file was not supported." in one + assert "3 files were not supported." in many + + def test_zero_indexed_with_unsupported_reports_complete(self): + """Nothing synced but some unsupported files still reports completion.""" + _, message, status, _ = msg.completion("GitHub", 0, unsupported_count=2) + assert message == "Sync complete. 2 files were not supported." 
+ assert status == "completed" diff --git a/surfsense_backend/tests/unit/notifications/service/messages/test_document_processing.py b/surfsense_backend/tests/unit/notifications/service/messages/test_document_processing.py new file mode 100644 index 000000000..2f0a6a9d3 --- /dev/null +++ b/surfsense_backend/tests/unit/notifications/service/messages/test_document_processing.py @@ -0,0 +1,63 @@ +"""Unit tests for document-processing presentation logic.""" + +from __future__ import annotations + +import pytest + +from app.notifications.service.messages import document_processing as msg + +pytestmark = pytest.mark.unit + + +def test_operation_id_encodes_type_and_space(): + """The operation id embeds the document type and search space id.""" + op = msg.operation_id("FILE", "report.pdf", 9) + assert op.startswith("doc_FILE_9_") + + +@pytest.mark.parametrize( + ("stage", "expected"), + [ + ("parsing", "Reading your file"), + ("chunking", "Preparing for search"), + ("embedding", "Preparing for search"), + ("storing", "Finalizing"), + ("unknown", "Processing"), + ], +) +def test_progress_stage_messages(stage, expected): + """Each processing stage maps to its message; unknown stages get a fallback.""" + message, meta = msg.progress(stage) + assert message == expected + assert meta["processing_stage"] == stage + + +def test_progress_records_chunks_count(): + """A provided chunk count is stored in metadata for debugging.""" + _, meta = msg.progress("chunking", chunks_count=12) + assert meta["chunks_count"] == 12 + + +def test_progress_message_override(): + """An explicit stage message overrides the stage mapping.""" + message, _ = msg.progress("parsing", stage_message="Scanning") + assert message == "Scanning" + + +def test_completion_success(): + """A successful run reports ready/completed and records the document id.""" + title, message, status, meta = msg.completion("report.pdf", document_id=5) + assert title == "Ready: report.pdf" + assert message == "Now searchable!" 
+ assert status == "completed" + assert meta["document_id"] == 5 + assert meta["processing_stage"] == "completed" + + +def test_completion_failure(): + """An error reports failed status with the error surfaced in the message.""" + title, message, status, meta = msg.completion("report.pdf", error_message="bad") + assert title == "Failed: report.pdf" + assert message == "Processing failed: bad" + assert status == "failed" + assert meta["processing_stage"] == "failed" diff --git a/surfsense_backend/tests/unit/notifications/service/messages/test_page_limit.py b/surfsense_backend/tests/unit/notifications/service/messages/test_page_limit.py new file mode 100644 index 000000000..9b2ac9638 --- /dev/null +++ b/surfsense_backend/tests/unit/notifications/service/messages/test_page_limit.py @@ -0,0 +1,30 @@ +"""Unit tests for page-limit presentation logic.""" + +from __future__ import annotations + +import pytest + +from app.notifications.service.messages import page_limit as msg + +pytestmark = pytest.mark.unit + + +def test_operation_id_encodes_search_space(): + """The operation id embeds the search space id.""" + assert msg.operation_id("doc.pdf", 9).startswith("page_limit_9_") + + +def test_summary_title_and_message(): + """The summary states the document and the used/limit page counts.""" + title, message = msg.summary("short.pdf", pages_used=95, pages_limit=100, pages_to_add=10) + assert title == "Page limit exceeded: short.pdf" + assert message == ( + "This document has ~10 page(s) but you've used 95/100 pages. " + "Upgrade to process more documents." + ) + + +def test_summary_truncates_long_name(): + """A long document name is truncated in the title.""" + title, _ = msg.summary("a" * 50, pages_used=1, pages_limit=2, pages_to_add=1) + assert title == f"Page limit exceeded: {'a' * 40}..." 
diff --git a/surfsense_backend/tests/unit/notifications/service/messages/test_text.py b/surfsense_backend/tests/unit/notifications/service/messages/test_text.py new file mode 100644 index 000000000..bf3611607 --- /dev/null +++ b/surfsense_backend/tests/unit/notifications/service/messages/test_text.py @@ -0,0 +1,24 @@ +"""Unit tests for shared notification text helpers.""" + +from __future__ import annotations + +import pytest + +from app.notifications.service.messages.text import truncate + +pytestmark = pytest.mark.unit + + +def test_truncate_leaves_short_text_unchanged(): + """Text under the limit is returned verbatim, with no ellipsis.""" + assert truncate("hello", 100) == "hello" + + +def test_truncate_keeps_text_at_exact_limit(): + """Text exactly at the limit is not truncated.""" + assert truncate("a" * 40, 40) == "a" * 40 + + +def test_truncate_appends_ellipsis_when_over_limit(): + """Text past the limit is cut to the limit and gains an ellipsis.""" + assert truncate("a" * 41, 40) == "a" * 40 + "..." 
diff --git a/surfsense_backend/tests/unit/notifications/service/test_metadata.py b/surfsense_backend/tests/unit/notifications/service/test_metadata.py new file mode 100644 index 000000000..56f1dc583 --- /dev/null +++ b/surfsense_backend/tests/unit/notifications/service/test_metadata.py @@ -0,0 +1,63 @@ +"""Unit tests for pure notification metadata transitions.""" + +from __future__ import annotations + +import pytest + +from app.notifications.service.metadata import apply_update, start_metadata + +pytestmark = pytest.mark.unit + + +class TestStartMetadata: + def test_seeds_operation_and_progress_fields(self): + """A new notification is seeded with operation id, in-progress status, and start time.""" + meta = start_metadata("op-1") + assert meta["operation_id"] == "op-1" + assert meta["status"] == "in_progress" + assert "started_at" in meta + + def test_preserves_initial_fields(self): + """Caller-provided initial metadata is carried through.""" + meta = start_metadata("op-1", {"connector_id": 7}) + assert meta["connector_id"] == 7 + + def test_does_not_mutate_caller_dict(self): + """Seeding returns a new dict without mutating the caller's input.""" + initial = {"connector_id": 7} + start_metadata("op-1", initial) + assert initial == {"connector_id": 7} + + +class TestApplyUpdate: + def test_completed_stamps_completed_at(self): + """A completed status records a completion timestamp.""" + meta = apply_update({"status": "in_progress"}, status="completed") + assert meta["status"] == "completed" + assert "completed_at" in meta + + def test_failed_stamps_completed_at(self): + """A failed status also records a completion timestamp.""" + meta = apply_update({}, status="failed") + assert "completed_at" in meta + + def test_in_progress_does_not_stamp_completed_at(self): + """A non-terminal status leaves the completion timestamp unset.""" + meta = apply_update({}, status="in_progress") + assert "completed_at" not in meta + + def test_merges_metadata_updates(self): + 
"""Metadata updates are merged into the existing metadata.""" + meta = apply_update({"a": 1}, metadata_updates={"b": 2}) + assert meta == {"a": 1, "b": 2} + + def test_updates_override_existing_keys(self): + """Updates take precedence over existing keys on conflict.""" + meta = apply_update({"a": 1}, metadata_updates={"a": 9}) + assert meta["a"] == 9 + + def test_does_not_mutate_caller_dict(self): + """Applying updates returns a new dict without mutating the caller's input.""" + current = {"a": 1} + apply_update(current, status="completed", metadata_updates={"b": 2}) + assert current == {"a": 1} diff --git a/surfsense_backend/tests/unit/routes/test_revert_turn_route.py b/surfsense_backend/tests/unit/routes/test_revert_turn_route.py index 1e1cbffb3..35d409a40 100644 --- a/surfsense_backend/tests/unit/routes/test_revert_turn_route.py +++ b/surfsense_backend/tests/unit/routes/test_revert_turn_route.py @@ -18,7 +18,7 @@ from unittest.mock import AsyncMock, patch import pytest -from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags from app.routes import agent_revert_route from app.services.revert_service import RevertOutcome diff --git a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py index 9d5fdb190..571e7d15b 100644 --- a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py +++ b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py @@ -20,6 +20,7 @@ from __future__ import annotations from unittest.mock import AsyncMock, MagicMock, patch import pytest +from langchain.tools import ToolRuntime pytestmark = pytest.mark.unit @@ -90,7 +91,9 @@ async def test_global_openrouter_image_gen_sets_api_base_when_config_empty(): async def test_generate_image_tool_global_sets_api_base_when_config_empty(): """Same defense at the agent tool entry point — 
both surfaces share the same OpenRouter config payloads.""" - from app.agents.new_chat.tools import generate_image as gi_module + from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools import ( + generate_image as gi_module, + ) cfg = { "id": -20_001, @@ -150,7 +153,19 @@ async def test_generate_image_tool_global_sets_api_base_when_config_empty(): tool = gi_module.create_generate_image_tool( search_space_id=1, db_session=MagicMock() ) - await tool.ainvoke({"prompt": "a cat", "n": 1}) + # The live tool takes an injected ToolRuntime and returns a Command; + # drive the raw coroutine with a minimal runtime (the tool only reads + # ``tool_call_id``). We assert on what was forwarded to litellm, not + # on the return value. + runtime = ToolRuntime( + state={}, + context=None, + config={}, + stream_writer=None, + tool_call_id="call-1", + store=None, + ) + await tool.coroutine(prompt="a cat", n=1, runtime=runtime) assert captured.get("api_base") == "https://openrouter.ai/api/v1" assert captured["model"] == "openrouter/openai/gpt-image-1" diff --git a/surfsense_backend/tests/unit/services/test_supports_image_input.py b/surfsense_backend/tests/unit/services/test_supports_image_input.py index 71fdee1c7..fabb3587a 100644 --- a/surfsense_backend/tests/unit/services/test_supports_image_input.py +++ b/surfsense_backend/tests/unit/services/test_supports_image_input.py @@ -227,7 +227,7 @@ global_llm_configs: def test_agent_config_from_yaml_explicit_overrides_resolver(): - from app.agents.new_chat.llm_config import AgentConfig + from app.agents.chat.runtime.llm_config import AgentConfig cfg_text_only = AgentConfig.from_yaml_config( { @@ -256,7 +256,7 @@ def test_agent_config_from_yaml_explicit_overrides_resolver(): def test_agent_config_from_yaml_unannotated_uses_resolver(): """Without an explicit YAML key, AgentConfig defers to the catalog resolver — for ``gpt-4o`` LiteLLM's map says supports_vision=True.""" - from app.agents.new_chat.llm_config import 
AgentConfig + from app.agents.chat.runtime.llm_config import AgentConfig cfg = AgentConfig.from_yaml_config( { @@ -275,7 +275,7 @@ def test_agent_config_auto_mode_supports_image_input(): so users can keep their selection on Auto with a vision-capable deployment somewhere in the pool. The router's own `allowed_fails` handles non-vision deployments via fallback.""" - from app.agents.new_chat.llm_config import AgentConfig + from app.agents.chat.runtime.llm_config import AgentConfig auto = AgentConfig.from_auto_mode() assert auto.supports_image_input is True diff --git a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py index b8ba9d80c..5e3aa6eda 100644 --- a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py +++ b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py @@ -61,7 +61,7 @@ async def test_get_vision_llm_global_openrouter_sets_api_base(): return_value=cfg, ), patch( - "app.agents.new_chat.llm_config.SanitizedChatLiteLLM", + "app.agents.chat.runtime.llm_config.SanitizedChatLiteLLM", new=FakeSanitized, ), ): diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_interrupt_inspector_all.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_interrupt_inspector_all.py index 15ab89b73..4457f4768 100644 --- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_interrupt_inspector_all.py +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_interrupt_inspector_all.py @@ -18,7 +18,7 @@ from langgraph.graph import END, START, StateGraph from langgraph.types import Send, interrupt from typing_extensions import TypedDict -from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( +from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.task_tool import ( build_task_tool_with_parent_config, ) from 
app.tasks.chat.streaming.helpers.interrupt_inspector import ( diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_parallel_refactor_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_parallel_refactor_parity.py deleted file mode 100644 index fa6d8b9e2..000000000 --- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_parallel_refactor_parity.py +++ /dev/null @@ -1,561 +0,0 @@ -"""Parity gate for the parallel refactor of ``stream_new_chat.py``. - -The new tree under ``app.tasks.chat.streaming.flows`` is built side-by-side with -the legacy monolithic ``app.tasks.chat.stream_new_chat`` so we can cut over -atomically. This file pins externally-observable behaviour at module -boundaries so a divergence between the two trees fails loudly *before* the -cutover. - -What we verify: - - 1. **Signature parity** — ``stream_new_chat`` / ``stream_resume_chat`` from - the new tree have the same call signature as the originals. - 2. **Helper extraction parity** — the SRP modules in ``flows/`` produce the - same outputs as the inline code in the legacy file for representative - inputs (initial thinking step, image-capability gate, runtime context, - SSE frame sequences, token-usage frame shape, persistence guards). - 3. **Wrapper delegation** — wrappers like ``load_llm_bundle`` / - ``can_recover_provider_rate_limit`` exist and are addressable. - -Delete this file along with ``stream_new_chat.py`` once the cutover is done -(see the parent refactor plan). 
-""" - -from __future__ import annotations - -import asyncio -import inspect -from typing import Any -from unittest.mock import AsyncMock, patch - -import pytest - -from app.agents.new_chat.context import SurfSenseContextSchema -from app.services.new_streaming_service import VercelStreamingService -from app.tasks.chat.stream_new_chat import ( - stream_new_chat as old_stream_new_chat, - stream_resume_chat as old_stream_resume_chat, -) -from app.tasks.chat.streaming.flows import ( - stream_new_chat as new_stream_new_chat, - stream_resume_chat as new_stream_resume_chat, -) -from app.tasks.chat.streaming.flows.new_chat.initial_thinking_step import ( - build_initial_thinking_step, -) -from app.tasks.chat.streaming.flows.new_chat.llm_capability import ( - check_image_input_capability, -) -from app.tasks.chat.streaming.flows.new_chat.persistence_spawn import ( - await_persist_task, - spawn_persist_assistant_shell_task, - spawn_persist_user_task, - spawn_set_ai_responding_bg, -) -from app.tasks.chat.streaming.flows.new_chat.runtime_context import ( - build_new_chat_runtime_context, -) -from app.tasks.chat.streaming.flows.resume_chat.runtime_context import ( - build_resume_chat_runtime_context, -) -from app.tasks.chat.streaming.flows.shared.finalize_emit import iter_token_usage_frame -from app.tasks.chat.streaming.flows.shared.first_frames import ( - iter_final_frames, - iter_initial_frames, -) -from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle -from app.tasks.chat.streaming.flows.shared.premium_quota import ( - PremiumReservation, - needs_premium_quota, -) -from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import ( - can_recover_provider_rate_limit, -) - -pytestmark = pytest.mark.unit - - -# --------------------------------------------------------------------- signature - - -def _normalize_annotation(ann: Any) -> str: - """Compare-friendly form for an annotation. 
- - The legacy ``stream_new_chat.py`` does NOT use ``from __future__ import - annotations``, so its annotations are evaluated at import time and come - back as type objects / typing generics. The new tree DOES use it, so its - annotations are PEP-563 strings. - - Both reprs describe the same types — strip the module prefixes / typing - namespace + the ```` wrapper so we compare the canonical - declared form. - """ - if ann is inspect.Signature.empty: - return "" - raw = ann if isinstance(ann, str) else repr(ann) - cleaned = ( - raw.replace("typing.", "") - .replace("collections.abc.", "") - .replace("app.db.", "") - .replace("app.agents.new_chat.filesystem_selection.", "") - .replace("app.agents.new_chat.context.", "") - ) - # Unwrap ```` → ``int`` (legacy-side type objects). - if cleaned.startswith(""): - cleaned = cleaned[len("")] - return cleaned - - -def _normalize_sig(sig: inspect.Signature) -> list[tuple[str, Any, str]]: - return [ - (p.name, p.default, _normalize_annotation(p.annotation)) - for p in sig.parameters.values() - ] - - -def test_stream_new_chat_signature_matches_legacy() -> None: - old = inspect.signature(old_stream_new_chat) - new = inspect.signature(new_stream_new_chat) - assert _normalize_sig(new) == _normalize_sig(old) - assert _normalize_annotation(new.return_annotation) == _normalize_annotation( - old.return_annotation - ) - - -def test_stream_resume_chat_signature_matches_legacy() -> None: - old = inspect.signature(old_stream_resume_chat) - new = inspect.signature(new_stream_resume_chat) - assert _normalize_sig(new) == _normalize_sig(old) - assert _normalize_annotation(new.return_annotation) == _normalize_annotation( - old.return_annotation - ) - - -def test_orchestrators_are_async_generator_functions() -> None: - assert inspect.isasyncgenfunction(new_stream_new_chat) - assert inspect.isasyncgenfunction(new_stream_resume_chat) - - -# ------------------------------------------------------------ initial thinking - - 
-@pytest.mark.parametrize( - "user_query, image_urls, expected_title, expected_action", - [ - ("hello world", None, "Understanding your request", "Processing"), - ( - "", - ["data:image/png;base64,AAA"], - "Understanding your request", - "Processing", - ), - ("", None, "Understanding your request", "Processing"), - ], -) -def test_initial_thinking_step_branches( - user_query: str, - image_urls: list[str] | None, - expected_title: str, - expected_action: str, -) -> None: - step = build_initial_thinking_step( - user_query=user_query, - user_image_data_urls=image_urls, - ) - assert step.step_id == "thinking-1" - assert step.title == expected_title - assert len(step.items) == 1 - assert step.items[0].startswith(f"{expected_action}: ") - - -def test_initial_thinking_step_truncates_long_query() -> None: - long_query = "x" * 200 - step = build_initial_thinking_step( - user_query=long_query, - user_image_data_urls=None, - ) - # 80-char truncation + ellipsis, sandwiched after "Processing: ". - assert "..." 
in step.items[0] - item = step.items[0] - payload = item[len("Processing: ") :] - assert payload.startswith("x" * 80) and payload.endswith("...") - - -# ------------------------------------------------------------ capability gate - - -def test_image_capability_passes_without_images() -> None: - assert ( - check_image_input_capability(user_image_data_urls=None, agent_config=None) - is None - ) - - -def test_image_capability_passes_when_capability_unknown() -> None: - """Unknown / unmapped models are not blocked — only models LiteLLM has - *explicitly* marked text-only trip the gate.""" - - class _AgentConfig: - provider = "openrouter" - model_name = "unknown-mystery-model" - custom_provider = None - config_name = "Unknown" - litellm_params: dict[str, Any] = {} - - with patch( - "app.services.provider_capabilities.is_known_text_only_chat_model", - return_value=False, - ): - assert ( - check_image_input_capability( - user_image_data_urls=["data:image/png;base64,AAA"], - agent_config=_AgentConfig(), # type: ignore[arg-type] - ) - is None - ) - - -def test_image_capability_blocks_known_text_only_models() -> None: - class _AgentConfig: - provider = "openai" - model_name = "gpt-3.5-turbo" - custom_provider = None - config_name = "GPT-3.5" - litellm_params: dict[str, Any] = {"base_model": "gpt-3.5-turbo"} - - with patch( - "app.services.provider_capabilities.is_known_text_only_chat_model", - return_value=True, - ): - result = check_image_input_capability( - user_image_data_urls=["data:image/png;base64,AAA"], - agent_config=_AgentConfig(), # type: ignore[arg-type] - ) - assert result is not None - message, error_code = result - assert error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT" - assert "GPT-3.5" in message - - -# ---------------------------------------------------------------- runtime ctx - - -def test_new_chat_runtime_context_prefers_accepted_folder_ids() -> None: - ctx = build_new_chat_runtime_context( - search_space_id=7, - mentioned_document_ids=[1, 2], - 
accepted_folder_ids=[10], - mentioned_folder_ids=[20, 30], - mentioned_connector_ids=None, - mentioned_connectors=None, - request_id="req", - turn_id="t1", - ) - assert isinstance(ctx, SurfSenseContextSchema) - assert ctx.search_space_id == 7 - assert list(ctx.mentioned_document_ids) == [1, 2] - assert list(ctx.mentioned_folder_ids) == [10] - assert ctx.request_id == "req" - assert ctx.turn_id == "t1" - - -def test_new_chat_runtime_context_falls_back_to_mentioned_folder_ids() -> None: - ctx = build_new_chat_runtime_context( - search_space_id=7, - mentioned_document_ids=None, - accepted_folder_ids=[], - mentioned_folder_ids=[20, 30], - mentioned_connector_ids=None, - mentioned_connectors=None, - request_id=None, - turn_id="t2", - ) - assert list(ctx.mentioned_folder_ids) == [20, 30] - - -def test_resume_chat_runtime_context_empty_mention_lists() -> None: - ctx = build_resume_chat_runtime_context( - search_space_id=42, request_id="req-r", turn_id="t-r" - ) - assert ctx.search_space_id == 42 - assert ctx.request_id == "req-r" - assert ctx.turn_id == "t-r" - - -# ---------------------------------------------------------------- SSE frames - - -def test_iter_initial_frames_emits_canonical_sequence() -> None: - svc = VercelStreamingService() - frames = list(iter_initial_frames(svc, turn_id="42:1700000000000")) - # Exactly 4 frames: message_start, start_step, turn-info (turn_id), turn-status (busy). 
- assert len(frames) == 4 - assert "42:1700000000000" in frames[2] - assert '"status":"busy"' in frames[3] or '"status": "busy"' in frames[3] - - -def test_iter_final_frames_emits_idle_then_finish_done() -> None: - svc = VercelStreamingService() - frames = list(iter_final_frames(svc)) - assert len(frames) == 4 - assert '"status":"idle"' in frames[0] or '"status": "idle"' in frames[0] - - -# ----------------------------------------------------------- token usage frame - - -class _FakeAccumulator: - """Minimal stand-in covering only the fields ``iter_token_usage_frame`` reads.""" - - def __init__(self, summary: Any = None) -> None: - self._summary = summary - self.calls = [1, 2, 3] - self.grand_total = 100 - self.total_cost_micros = 50_000 - self.total_prompt_tokens = 60 - self.total_completion_tokens = 40 - - def per_message_summary(self) -> Any: - return self._summary - - def serialized_calls(self) -> list[Any]: - return list(self.calls) - - -def test_token_usage_frame_skipped_when_no_summary() -> None: - svc = VercelStreamingService() - frames = list( - iter_token_usage_frame( - svc, - accumulator=_FakeAccumulator(summary=None), # type: ignore[arg-type] - log_label="parity-empty", - ) - ) - assert frames == [] - - -def test_token_usage_frame_emitted_when_summary_present() -> None: - svc = VercelStreamingService() - frames = list( - iter_token_usage_frame( - svc, - accumulator=_FakeAccumulator(summary=[{"m": "x", "t": 100}]), # type: ignore[arg-type] - log_label="parity-populated", - ) - ) - assert len(frames) == 1 - # Field shape on the wire is fixed by the FE; assert each surfaces. 
- payload = frames[0] - for key in ( - '"prompt_tokens":60', - '"completion_tokens":40', - '"total_tokens":100', - '"cost_micros":50000', - ): - assert key in payload.replace(" ", "") - - -# ------------------------------------------------------------------ llm_bundle - - -def test_load_llm_bundle_routes_negative_id_to_yaml_loader() -> None: - async def _run() -> tuple[Any, Any, str | None]: - with ( - patch( - "app.tasks.chat.streaming.flows.shared.llm_bundle.load_global_llm_config_by_id", - return_value=None, - ), - ): - return await load_llm_bundle( - session=AsyncMock(), # type: ignore[arg-type] - config_id=-1, - search_space_id=7, - ) - - llm, agent_config, error = asyncio.run(_run()) - assert llm is None - assert agent_config is None - assert error is not None and "id -1" in error - - -def test_load_llm_bundle_routes_nonnegative_id_to_db_loader() -> None: - async def _run() -> tuple[Any, Any, str | None]: - with ( - patch( - "app.tasks.chat.streaming.flows.shared.llm_bundle.load_agent_config", - new=AsyncMock(return_value=None), - ), - ): - return await load_llm_bundle( - session=AsyncMock(), # type: ignore[arg-type] - config_id=12, - search_space_id=7, - ) - - llm, agent_config, error = asyncio.run(_run()) - assert llm is None - assert agent_config is None - assert error is not None and "id 12" in error - - -# ----------------------------------------------------------------- premium quota - - -def test_needs_premium_quota_requires_user_and_premium_flag() -> None: - class _AgentConfig: - is_premium = True - - class _NonPremium: - is_premium = False - - assert needs_premium_quota(_AgentConfig(), "user-1") is True # type: ignore[arg-type] - assert needs_premium_quota(_AgentConfig(), None) is False # type: ignore[arg-type] - assert needs_premium_quota(_NonPremium(), "user-1") is False # type: ignore[arg-type] - assert needs_premium_quota(None, "user-1") is False - - -def test_premium_reservation_dataclass_shape() -> None: - # Sanity: the dataclass exists and 
carries the fields the orchestrator uses. - r = PremiumReservation(request_id="abc", reserved_micros=100, allowed=True) - assert r.request_id == "abc" - assert r.reserved_micros == 100 - assert r.allowed is True - - -# ----------------------------------------------------------- rate-limit guard - - -@pytest.mark.parametrize( - "first_event_seen, recovered, requested_id, current_id, expected", - [ - (False, False, 0, -1, True), - # Already recovered: no second pass. - (False, True, 0, -1, False), - # User explicitly picked a config: don't silently switch. - (False, False, 5, -1, False), - # Already on a database-backed (positive) id. - (False, False, 0, 7, False), - # User has already seen output: silent rebuild not possible. - (True, False, 0, -1, False), - ], -) -def test_can_recover_provider_rate_limit_truth_table( - first_event_seen: bool, - recovered: bool, - requested_id: int, - current_id: int, - expected: bool, -) -> None: - # Use a known rate-limit-shaped exception so the helper's last condition - # is satisfied; the guard only short-circuits to False when one of the - # *other* preconditions fails. 
- exc = Exception('{"error":{"type":"rate_limit_error","message":"slow"}}') - assert ( - can_recover_provider_rate_limit( - exc, - first_event_seen=first_event_seen, - runtime_rate_limit_recovered=recovered, - requested_llm_config_id=requested_id, - current_llm_config_id=current_id, - ) - is expected - ) - - -def test_can_recover_provider_rate_limit_rejects_non_rate_limit_exception() -> None: - assert ( - can_recover_provider_rate_limit( - ValueError("not a rate limit"), - first_event_seen=False, - runtime_rate_limit_recovered=False, - requested_llm_config_id=0, - current_llm_config_id=-1, - ) - is False - ) - - -# --------------------------------------------------------- persistence spawn - - -def test_spawn_set_ai_responding_bg_noop_without_user_id() -> None: - async def _run() -> set[asyncio.Task]: - background: set[asyncio.Task] = set() - spawn_set_ai_responding_bg(chat_id=1, user_id=None, background_tasks=background) - return background - - bg = asyncio.run(_run()) - assert bg == set() - - -def test_spawn_persist_user_task_registers_and_self_unregisters() -> None: - async def _run() -> tuple[int, int]: - background: set[asyncio.Task] = set() - with patch( - "app.tasks.chat.streaming.flows.new_chat.persistence_spawn.persist_user_turn", - new=AsyncMock(return_value=99), - ): - task = spawn_persist_user_task( - chat_id=1, - user_id="u", - turn_id="t", - user_query="hi", - user_image_data_urls=None, - mentioned_documents=None, - background_tasks=background, - ) - size_before_await = len(background) - result = await asyncio.shield(task) - # Give the done-callback one event-loop tick to run. 
- await asyncio.sleep(0) - return size_before_await, result # type: ignore[return-value] - - size_before, result = asyncio.run(_run()) - assert size_before == 1 - assert result == 99 - - -def test_spawn_persist_assistant_shell_task_registers() -> None: - async def _run() -> int | None: - background: set[asyncio.Task] = set() - with patch( - "app.tasks.chat.streaming.flows.new_chat.persistence_spawn.persist_assistant_shell", - new=AsyncMock(return_value=42), - ): - task = spawn_persist_assistant_shell_task( - chat_id=1, - user_id="u", - turn_id="t", - background_tasks=background, - ) - return await asyncio.shield(task) - - assert asyncio.run(_run()) == 42 - - -def test_await_persist_task_returns_none_on_failure() -> None: - async def _run() -> int | None: - async def _boom() -> int: - raise RuntimeError("DB down") - - task = asyncio.create_task(_boom()) - return await await_persist_task( - task, - chat_id=1, - turn_id="t", - log_label="parity-failure", - ) - - assert asyncio.run(_run()) is None - - -def test_await_persist_task_returns_none_for_none_input() -> None: - async def _run() -> int | None: - return await await_persist_task( - None, - chat_id=1, - turn_id="t", - log_label="parity-none", - ) - - assert asyncio.run(_run()) is None diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py deleted file mode 100644 index 8fde773e3..000000000 --- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py +++ /dev/null @@ -1,240 +0,0 @@ -"""Pin Stage 1 extractions as faithful copies of the old helpers. - -Extractions under ``app.tasks.chat.streaming`` are compared to -``app.tasks.chat.stream_new_chat`` helpers. -For each Stage 1 extraction we assert the new function returns the same -output as the old one for a representative input set. 
The moment the -two diverge - intentionally or otherwise - this file fails loudly so -the divergence is reviewed rather than shipped silently. -""" - -from __future__ import annotations - -import logging -from dataclasses import dataclass, field -from typing import Any - -import pytest - -from app.agents.new_chat.errors import BusyError -from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel -from app.tasks.chat.stream_new_chat import ( - _classify_stream_exception as old_classify, - _emit_stream_terminal_error as old_emit_terminal_error, - _extract_chunk_parts as old_extract_chunk_parts, - _extract_resolved_file_path as old_extract_resolved_file_path, - _tool_output_has_error as old_tool_output_has_error, - _tool_output_to_text as old_tool_output_to_text, -) -from app.tasks.chat.streaming.errors.classifier import ( - classify_stream_exception as new_classify, -) -from app.tasks.chat.streaming.errors.emitter import ( - emit_stream_terminal_error as new_emit_terminal_error, -) -from app.tasks.chat.streaming.helpers.chunk_parts import ( - extract_chunk_parts as new_extract_chunk_parts, -) -from app.tasks.chat.streaming.helpers.tool_output import ( - extract_resolved_file_path as new_extract_resolved_file_path, - tool_output_has_error as new_tool_output_has_error, - tool_output_to_text as new_tool_output_to_text, -) - -pytestmark = pytest.mark.unit - - -# ---------------------------------------------------------------- chunk parts - - -@dataclass -class _Chunk: - content: Any = "" - additional_kwargs: dict[str, Any] = field(default_factory=dict) - tool_call_chunks: list[dict[str, Any]] = field(default_factory=list) - - -_CHUNK_CASES: list[Any] = [ - None, - _Chunk(content=""), - _Chunk(content="hello"), - _Chunk(content=42), # invalid type, defensively coerced to empty - _Chunk( - content=[ - {"type": "text", "text": "Hello "}, - {"type": "text", "text": "world"}, - ] - ), - _Chunk( - content=[ - {"type": "reasoning", "reasoning": "hmm "}, 
- {"type": "reasoning", "text": "still"}, - {"type": "text", "text": "answer"}, - ] - ), - _Chunk( - content=[ - {"type": "tool_call_chunk", "id": "c1", "name": "x", "args": "{"}, - {"type": "tool_use", "id": "c2", "name": "y"}, - {"type": "image_url", "url": "ignored"}, - ] - ), - _Chunk( - content="visible", - additional_kwargs={"reasoning_content": "private"}, - ), - _Chunk( - tool_call_chunks=[ - {"id": None, "name": None, "args": '{"a":1}', "index": 0}, - {"id": "c", "name": "n", "args": "}", "index": 0}, - ] - ), - _Chunk( - content=[{"type": "tool_call_chunk", "id": "from-block", "name": "x"}], - tool_call_chunks=[{"id": "from-attr", "name": "y"}], - ), -] - - -@pytest.mark.parametrize("chunk", _CHUNK_CASES) -def test_extract_chunk_parts_matches_old_implementation(chunk: Any) -> None: - assert new_extract_chunk_parts(chunk) == old_extract_chunk_parts(chunk) - - -# ----------------------------------------------------------- error classifier - - -def _classify_cases() -> list[Exception]: - """Inputs that the FE depends on being mapped to specific error codes.""" - return [ - Exception("totally generic error"), - Exception('{"error":{"type":"rate_limit_error","message":"slow down"}}'), - Exception( - 'OpenrouterException - {"error":{"message":"Provider returned error",' - '"code":429}}' - ), - BusyError(request_id="thread-busy-parity"), - Exception("Thread is busy with another request"), - ] - - -@pytest.mark.parametrize("exc", _classify_cases()) -def test_classify_stream_exception_matches_old_implementation( - exc: Exception, -) -> None: - new = new_classify(exc, flow_label="parity-test") - old = old_classify(exc, flow_label="parity-test") - # Strip the wall-clock retry timestamp before comparing — both - # implementations call ``time.time()`` independently and the call - # order is enough to differ by 1 ms in practice. Every other field - # in the tuple must match exactly. 
- new_extra = dict(new[5]) if isinstance(new[5], dict) else new[5] - old_extra = dict(old[5]) if isinstance(old[5], dict) else old[5] - if isinstance(new_extra, dict) and isinstance(old_extra, dict): - new_extra.pop("retry_after_at", None) - old_extra.pop("retry_after_at", None) - assert new[:5] == old[:5] - assert new_extra == old_extra - - -def test_classify_turn_cancelling_branch_parity() -> None: - """The TURN_CANCELLING branch reads cancel state for the busy thread id; - both implementations must agree on retry-window semantics, not just the - plain THREAD_BUSY code.""" - thread_id = "parity-cancelling-thread" - reset_cancel(thread_id) - request_cancel(thread_id) - exc = BusyError(request_id=thread_id) - new = new_classify(exc, flow_label="parity-test") - old = old_classify(exc, flow_label="parity-test") - assert new[0] == old[0] == "thread_busy" - assert new[1] == old[1] == "TURN_CANCELLING" - assert isinstance(new[5], dict) and isinstance(old[5], dict) - assert new[5]["retry_after_ms"] == old[5]["retry_after_ms"] - - -# ------------------------------------------------------------ terminal emitter - - -class _FakeStreamingService: - """Duck-types ``format_error`` for both old and new emitters.""" - - def __init__(self) -> None: - self.calls: list[dict[str, Any]] = [] - - def format_error( - self, message: str, *, error_code: str, extra: dict[str, Any] | None = None - ) -> str: - self.calls.append( - {"message": message, "error_code": error_code, "extra": extra} - ) - return f'data: {{"type":"error","errorText":"{message}"}}\n\n' - - -def test_emit_stream_terminal_error_matches_old_output_and_logs(caplog) -> None: - """The new emitter must produce the same SSE frame and log the same - structured payload as the old one for the same arguments.""" - args: dict[str, Any] = { - "flow": "new", - "request_id": "req-parity", - "thread_id": 7, - "search_space_id": 9, - "user_id": "user-parity", - "message": "boom", - "error_kind": "server_error", - "error_code": 
"SERVER_ERROR", - "severity": "error", - "is_expected": False, - "extra": {"foo": "bar"}, - } - - new_svc = _FakeStreamingService() - old_svc = _FakeStreamingService() - - with caplog.at_level(logging.ERROR): - new_frame = new_emit_terminal_error(streaming_service=new_svc, **args) - old_frame = old_emit_terminal_error(streaming_service=old_svc, **args) - - assert new_frame == old_frame - assert new_svc.calls == old_svc.calls - chat_error_records = [ - r for r in caplog.records if "[chat_stream_error]" in r.message - ] - # One log line per emit call (two emits -> two records). - assert len(chat_error_records) == 2 - - -# ---------------------------------------------------------------- tool output - - -def test_tool_output_helpers_match_old_implementation() -> None: - samples: list[Any] = [ - {"result": "ok"}, - {"error": "bad"}, - {"result": "Error: x"}, - "Error: plain", - "fine", - {"nested": {"a": 1}}, - ] - for s in samples: - assert new_tool_output_to_text(s) == old_tool_output_to_text(s) - assert new_tool_output_has_error(s) == old_tool_output_has_error(s) - - assert new_extract_resolved_file_path( - tool_name="write_file", - tool_output={"path": " /tmp/x "}, - tool_input=None, - ) == old_extract_resolved_file_path( - tool_name="write_file", - tool_output={"path": " /tmp/x "}, - tool_input=None, - ) - assert new_extract_resolved_file_path( - tool_name="write_file", - tool_output={}, - tool_input={"file_path": " /fallback "}, - ) == old_extract_resolved_file_path( - tool_name="write_file", - tool_output={}, - tool_input={"file_path": " /fallback "}, - ) diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py deleted file mode 100644 index 3ee1ab622..000000000 --- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py +++ /dev/null @@ -1,241 +0,0 @@ -"""Parity tests for Stage 2 extractions (tool matching, thinking step, custom events).""" 
- -from __future__ import annotations - -from typing import Any -from unittest.mock import MagicMock - -import pytest - -from app.tasks.chat.stream_new_chat import _legacy_match_lc_id as old_legacy_match -from app.tasks.chat.streaming.handlers.custom_events import ( - handle_action_log, - handle_action_log_updated, - handle_document_created, - handle_report_progress, -) -from app.tasks.chat.streaming.helpers.tool_call_matching import ( - match_buffered_langchain_tool_call_id as new_legacy_match, -) -from app.tasks.chat.streaming.relay.state import AgentEventRelayState -from app.tasks.chat.streaming.relay.thinking_step_completion import ( - complete_active_thinking_step, -) -from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame - -pytestmark = pytest.mark.unit - - -def _copy_chunk_buffer(raw: list[dict[str, Any]]) -> list[dict[str, Any]]: - return [dict(x) for x in raw] - - -def test_legacy_tool_call_match_matches_old_implementation() -> None: - cases: list[tuple[list[dict[str, Any]], str, str, dict[str, str]]] = [ - ( - [ - {"name": "write_file", "id": "lc-a"}, - {"name": "other", "id": "lc-b"}, - ], - "write_file", - "run-1", - {}, - ), - ( - [{"name": "x", "id": None}, {"name": "y", "id": "lc-fallback"}], - "write_file", - "run-2", - {}, - ), - ([{"name": "no_id"}], "write_file", "run-3", {}), - ] - for chunks_template, tool_name, run_id, lc_map_seed in cases: - old_chunks = _copy_chunk_buffer(chunks_template) - new_chunks = _copy_chunk_buffer(chunks_template) - old_map = dict(lc_map_seed) - new_map = dict(lc_map_seed) - old_out = old_legacy_match(old_chunks, tool_name, run_id, old_map) - new_out = new_legacy_match(new_chunks, tool_name, run_id, new_map) - assert new_out == old_out - assert new_chunks == old_chunks - assert new_map == old_map - - -def test_emit_thinking_step_frame_invokes_builder_before_service() -> None: - order: list[str] = [] - builder = MagicMock() - - def on_ts(*args: Any, **kwargs: Any) -> None: - 
order.append("builder") - - builder.on_thinking_step.side_effect = on_ts - - svc = MagicMock() - - def fmt(**kwargs: Any) -> str: - order.append("service") - return "frame" - - svc.format_thinking_step.side_effect = fmt - - out = emit_thinking_step_frame( - streaming_service=svc, - content_builder=builder, - step_id="thinking-1", - title="Working", - status="in_progress", - items=["a"], - ) - assert out == "frame" - assert order == ["builder", "service"] - builder.on_thinking_step.assert_called_once() - svc.format_thinking_step.assert_called_once() - - -def test_emit_thinking_step_frame_skips_builder_when_none() -> None: - svc = MagicMock(return_value="x") - svc.format_thinking_step.return_value = "frame" - assert ( - emit_thinking_step_frame( - streaming_service=svc, - content_builder=None, - step_id="s", - title="t", - ) - == "frame" - ) - svc.format_thinking_step.assert_called_once() - - -def test_complete_active_thinking_step_mirrors_closure_semantics() -> None: - svc = MagicMock() - svc.format_thinking_step.return_value = "done-frame" - completed: set[str] = set() - relay_state = AgentEventRelayState.for_invocation() - - frame, new_id = complete_active_thinking_step( - state=relay_state, - streaming_service=svc, - content_builder=None, - last_active_step_id="thinking-1", - last_active_step_title="T", - last_active_step_items=["x"], - completed_step_ids=completed, - ) - assert frame == "done-frame" - assert new_id is None - assert "thinking-1" in completed - - frame2, id2 = complete_active_thinking_step( - state=relay_state, - streaming_service=svc, - content_builder=None, - last_active_step_id="thinking-1", - last_active_step_title="T", - last_active_step_items=[], - completed_step_ids=completed, - ) - assert frame2 is None - assert id2 == "thinking-1" - - -def test_agent_event_relay_state_factory_matches_counter_rule() -> None: - s0 = AgentEventRelayState.for_invocation() - assert s0.thinking_step_counter == 0 - assert s0.last_active_step_id is None - - s1 = 
AgentEventRelayState.for_invocation( - initial_step_id="thinking-resume-1", - initial_step_title="Inherited", - initial_step_items=["Topic: X"], - ) - assert s1.thinking_step_counter == 1 - assert s1.last_active_step_id == "thinking-resume-1" - assert s1.next_thinking_step_id("thinking") == "thinking-2" - - -@pytest.mark.parametrize( - ("phase", "message", "start_items", "expected_tail"), - [ - ( - "revising_section", - "progress line", - ["Topic: Foo", "Modifying bar", "stale..."], - ["Topic: Foo", "Modifying bar", "progress line"], - ), - ( - "other", - "phase msg", - ["Topic: Foo", "old line"], - ["Topic: Foo", "phase msg"], - ), - ], -) -def test_report_progress_items_match_reference( - phase: str, - message: str, - start_items: list[str], - expected_tail: list[str], -) -> None: - svc = MagicMock() - svc.format_thinking_step.return_value = "sse" - - items = list(start_items) - frame, new_items = handle_report_progress( - {"message": message, "phase": phase}, - last_active_step_id="step-1", - last_active_step_title="Report", - last_active_step_items=items, - streaming_service=svc, - content_builder=None, - ) - assert frame == "sse" - assert new_items == expected_tail - kwargs = svc.format_thinking_step.call_args.kwargs - assert kwargs["items"] == expected_tail - - -def test_report_progress_noop_when_missing_message_or_step() -> None: - svc = MagicMock() - items = ["Topic: A"] - f1, i1 = handle_report_progress( - {"message": "", "phase": "x"}, - last_active_step_id="s", - last_active_step_title="t", - last_active_step_items=items, - streaming_service=svc, - content_builder=None, - ) - assert f1 is None and i1 is items - - f2, i2 = handle_report_progress( - {"message": "m", "phase": "x"}, - last_active_step_id=None, - last_active_step_title="t", - last_active_step_items=items, - streaming_service=svc, - content_builder=None, - ) - assert f2 is None and i2 is items - - -def test_document_action_handlers_match_format_data_guards() -> None: - svc = MagicMock() - 
svc.format_data.return_value = "data-frame" - - assert handle_document_created({}, streaming_service=svc) is None - assert handle_document_created({"id": 0}, streaming_service=svc) is None - handle_document_created({"id": 42, "title": "x"}, streaming_service=svc) - svc.format_data.assert_called_with( - "documents-updated", {"action": "created", "document": {"id": 42, "title": "x"}} - ) - - svc.reset_mock() - assert handle_action_log({"id": None}, streaming_service=svc) is None - handle_action_log({"id": 1}, streaming_service=svc) - svc.format_data.assert_called_once_with("action-log", {"id": 1}) - - svc.reset_mock() - assert handle_action_log_updated({"id": None}, streaming_service=svc) is None - handle_action_log_updated({"id": 2}, streaming_service=svc) - svc.format_data.assert_called_once_with("action-log-updated", {"id": 2}) diff --git a/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py b/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py index 1263a5fe1..0f154f1dc 100644 --- a/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py +++ b/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py @@ -1,4 +1,4 @@ -"""Unit tests for ``stream_new_chat._extract_chunk_parts``. +"""Unit tests for ``streaming.helpers.chunk_parts.extract_chunk_parts``. 
Earlier versions only handled ``isinstance(chunk.content, str)`` and silently dropped every other shape (Anthropic typed-block lists, @@ -14,7 +14,9 @@ from typing import Any import pytest -from app.tasks.chat.stream_new_chat import _extract_chunk_parts +from app.tasks.chat.streaming.helpers.chunk_parts import ( + extract_chunk_parts as _extract_chunk_parts, +) @dataclass diff --git a/surfsense_backend/tests/unit/tasks/chat/test_thinking_step_id_uniqueness.py b/surfsense_backend/tests/unit/tasks/chat/test_thinking_step_id_uniqueness.py index 50f7b8070..7f8285e98 100644 --- a/surfsense_backend/tests/unit/tasks/chat/test_thinking_step_id_uniqueness.py +++ b/surfsense_backend/tests/unit/tasks/chat/test_thinking_step_id_uniqueness.py @@ -7,7 +7,7 @@ constructs a fresh :class:`AgentEventRelayState` with ``thinking_step_counter=0``), React renders sibling timeline rows with the same key — the warning the user reported in production. -The contract this module pins: each ``_stream_agent_events`` invocation must +The contract this module pins: each ``stream_agent_events`` invocation must receive a ``step_prefix`` that is unique within the thread (we salt with the per-turn ``turn_id``), so the resulting step IDs across consecutive turns are always disjoint. 
@@ -23,10 +23,12 @@ from typing import Any import pytest from app.services.new_streaming_service import VercelStreamingService -from app.tasks.chat.stream_new_chat import ( - StreamResult, - _resume_step_prefix, - _stream_agent_events, +from app.tasks.chat.streaming.agent.event_loop import ( + stream_agent_events as _stream_agent_events, +) +from app.tasks.chat.streaming.shared.stream_result import StreamResult +from app.tasks.chat.streaming.shared.utils import ( + resume_step_prefix as _resume_step_prefix, ) pytestmark = pytest.mark.unit diff --git a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py index ada32d168..f3f28eb1c 100644 --- a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py +++ b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py @@ -1,6 +1,6 @@ """Unit tests for live tool-call argument streaming. -Pins the wire format that ``_stream_agent_events`` emits: +Pins the wire format that ``stream_agent_events`` emits: ``tool-input-start`` → ``tool-input-delta``... → ``tool-input-available`` → ``tool-output-available``, keyed consistently with LangChain ``tool_call.id`` when the model streams indexed chunks. 
@@ -20,11 +20,13 @@ from typing import Any import pytest from app.services.new_streaming_service import VercelStreamingService -from app.tasks.chat.stream_new_chat import ( - StreamResult, - _legacy_match_lc_id, - _stream_agent_events, +from app.tasks.chat.streaming.agent.event_loop import ( + stream_agent_events as _stream_agent_events, ) +from app.tasks.chat.streaming.helpers.tool_call_matching import ( + match_buffered_langchain_tool_call_id as _legacy_match_lc_id, +) +from app.tasks.chat.streaming.shared.stream_result import StreamResult pytestmark = pytest.mark.unit diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py deleted file mode 100644 index 19b06201f..000000000 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ /dev/null @@ -1,474 +0,0 @@ -import inspect -import json -import logging -import re -from pathlib import Path - -import pytest - -import app.tasks.chat.stream_new_chat as stream_new_chat_module -from app.agents.new_chat.errors import BusyError -from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel -from app.tasks.chat.stream_new_chat import ( - StreamResult, - _classify_stream_exception, - _contract_enforcement_active, - _evaluate_file_contract_outcome, - _extract_resolved_file_path, - _log_chat_stream_error, - _tool_output_has_error, -) - -pytestmark = pytest.mark.unit - - -def test_tool_output_error_detection(): - assert _tool_output_has_error("Error: failed to write file") - assert _tool_output_has_error({"error": "boom"}) - assert _tool_output_has_error({"result": "Error: disk is full"}) - assert not _tool_output_has_error({"result": "Updated file /notes.md"}) - - -def test_extract_resolved_file_path_prefers_structured_path(): - assert ( - _extract_resolved_file_path( - tool_name="write_file", - tool_output={"status": "completed", "path": "/docs/note.md"}, - tool_input=None, - ) - == "/docs/note.md" - ) - - 
-def test_extract_resolved_file_path_falls_back_to_tool_input(): - assert ( - _extract_resolved_file_path( - tool_name="edit_file", - tool_output={"status": "completed", "result": "updated"}, - tool_input={"file_path": "/docs/edited.md"}, - ) - == "/docs/edited.md" - ) - - -def test_extract_resolved_file_path_does_not_parse_result_text(): - assert ( - _extract_resolved_file_path( - tool_name="write_file", - tool_output={"result": "Updated file /docs/from-text.md"}, - tool_input=None, - ) - is None - ) - - -def test_file_write_contract_outcome_reasons(): - result = StreamResult(intent_detected="file_write") - passed, reason = _evaluate_file_contract_outcome(result) - assert not passed - assert reason == "no_write_attempt" - - result.write_attempted = True - passed, reason = _evaluate_file_contract_outcome(result) - assert not passed - assert reason == "write_failed" - - result.write_succeeded = True - passed, reason = _evaluate_file_contract_outcome(result) - assert not passed - assert reason == "verification_failed" - - result.verification_succeeded = True - passed, reason = _evaluate_file_contract_outcome(result) - assert passed - assert reason == "" - - -def test_contract_enforcement_local_only(): - result = StreamResult(filesystem_mode="desktop_local_folder") - assert _contract_enforcement_active(result) - - result.filesystem_mode = "cloud" - assert not _contract_enforcement_active(result) - - -def _extract_chat_stream_payload(record_message: str) -> dict: - prefix = "[chat_stream_error] " - assert record_message.startswith(prefix) - return json.loads(record_message[len(prefix) :]) - - -def test_unified_chat_stream_error_log_schema(caplog): - with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"): - _log_chat_stream_error( - flow="new", - error_kind="server_error", - error_code="SERVER_ERROR", - severity="warn", - is_expected=False, - request_id="req-123", - thread_id=101, - search_space_id=202, - user_id="user-1", - message="Error during 
chat: boom", - ) - - record = next(r for r in caplog.records if "[chat_stream_error]" in r.message) - payload = _extract_chat_stream_payload(record.message) - - required_keys = { - "event", - "flow", - "error_kind", - "error_code", - "severity", - "is_expected", - "request_id", - "thread_id", - "search_space_id", - "user_id", - "message", - } - assert required_keys.issubset(payload.keys()) - assert payload["event"] == "chat_stream_error" - assert payload["flow"] == "new" - assert payload["error_code"] == "SERVER_ERROR" - - -def test_premium_quota_uses_unified_chat_stream_log_shape(caplog): - with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"): - _log_chat_stream_error( - flow="resume", - error_kind="premium_quota_exhausted", - error_code="PREMIUM_QUOTA_EXHAUSTED", - severity="info", - is_expected=True, - request_id="req-premium", - thread_id=303, - search_space_id=404, - user_id="user-2", - message="Buy more tokens to continue with this model, or switch to a free model", - extra={"auto_fallback": False}, - ) - - record = next(r for r in caplog.records if "[chat_stream_error]" in r.message) - payload = _extract_chat_stream_payload(record.message) - assert payload["event"] == "chat_stream_error" - assert payload["error_kind"] == "premium_quota_exhausted" - assert payload["error_code"] == "PREMIUM_QUOTA_EXHAUSTED" - assert payload["flow"] == "resume" - assert payload["is_expected"] is True - assert payload["auto_fallback"] is False - - -def test_stream_error_emission_keeps_machine_error_codes(): - source = inspect.getsource(stream_new_chat_module) - format_error_calls = re.findall(r"format_error\(", source) - emitted_error_codes = set(re.findall(r'error_code="([A-Z_]+)"', source)) - - # All stream paths should route through one shared terminal error emitter. 
- assert len(format_error_calls) == 1 - assert { - "PREMIUM_QUOTA_EXHAUSTED", - "SERVER_ERROR", - }.issubset(emitted_error_codes) - assert 'flow: Literal["new", "regenerate"] = "new"' in source - assert "_emit_stream_terminal_error" in source - assert "flow=flow" in source - assert 'flow="resume"' in source - - -def test_stream_exception_classifies_rate_limited(): - exc = Exception( - '{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}' - ) - kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( - exc, flow_label="chat" - ) - assert kind == "rate_limited" - assert code == "RATE_LIMITED" - assert severity == "warn" - assert is_expected is True - assert "temporarily rate-limited" in user_message - assert extra is None - - -def test_stream_exception_classifies_openrouter_429_payload(): - exc = Exception( - 'OpenrouterException - {"error":{"message":"Provider returned error","code":429,' - '"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}' - ) - kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( - exc, flow_label="chat" - ) - assert kind == "rate_limited" - assert code == "RATE_LIMITED" - assert severity == "warn" - assert is_expected is True - assert "temporarily rate-limited" in user_message - assert extra is None - - -def test_stream_exception_classifies_thread_busy(): - exc = BusyError(request_id="thread-123") - kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( - exc, flow_label="chat" - ) - assert kind == "thread_busy" - assert code == "THREAD_BUSY" - assert severity == "warn" - assert is_expected is True - assert "still finishing for this thread" in user_message - assert extra is None - - -def test_stream_exception_classifies_thread_busy_from_message(): - exc = Exception("Thread is busy with another request") - kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( - exc, 
flow_label="chat" - ) - assert kind == "thread_busy" - assert code == "THREAD_BUSY" - assert severity == "warn" - assert is_expected is True - assert "still finishing for this thread" in user_message - assert extra is None - - -def test_stream_exception_classifies_turn_cancelling_when_cancel_requested(): - thread_id = "thread-cancelling-1" - reset_cancel(thread_id) - request_cancel(thread_id) - exc = BusyError(request_id=thread_id) - kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( - exc, flow_label="chat" - ) - assert kind == "thread_busy" - assert code == "TURN_CANCELLING" - assert severity == "info" - assert is_expected is True - assert "stopping" in user_message - assert isinstance(extra, dict) - assert "retry_after_ms" in extra - - -def test_premium_classification_is_error_code_driven(): - classifier_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/lib/chat/chat-error-classifier.ts" - ) - source = classifier_path.read_text(encoding="utf-8") - - assert "PREMIUM_KEYWORDS" not in source - assert "RATE_LIMIT_KEYWORDS" not in source - assert "normalized.includes(" not in source - assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in source - - -def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook(): - page_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" - ) - source = page_path.read_text(encoding="utf-8") - - assert "onPreAcceptFailure?: () => Promise;" in source - assert "if (!accepted) {" in source - assert "await onPreAcceptFailure?.();" in source - assert "await onAcceptedStreamError?.();" in source - assert "setMessages((prev) => prev.filter((m) => m.id !== userMsgId));" in source - assert "setMessageDocumentsMap((prev) => {" in source - - -def test_toast_only_pre_accept_policy_has_no_inline_failed_marker(): - user_message_path = ( - Path(__file__).resolve().parents[3] - / 
"surfsense_web/components/assistant-ui/user-message.tsx" - ) - source = user_message_path.read_text(encoding="utf-8") - - assert "Not sent. Edit and retry." not in source - assert "failed_pre_accept" not in source - - -def test_network_send_failures_use_unified_retry_toast_message(): - classifier_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/lib/chat/chat-error-classifier.ts" - ) - classifier_source = classifier_path.read_text(encoding="utf-8") - request_errors_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/lib/chat/chat-request-errors.ts" - ) - request_errors_source = request_errors_path.read_text(encoding="utf-8") - - assert '"send_failed_pre_accept"' in classifier_source - assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source - assert 'errorCode === "TURN_CANCELLING"' in classifier_source - assert "if (withCode.code) return withCode.code;" in classifier_source - assert 'userMessage: "Message not sent. Please retry."' in classifier_source - assert 'userMessage: "Connection issue. 
Please try again."' in classifier_source - assert "const passthroughCodes = new Set([" in request_errors_source - assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source - assert '"THREAD_BUSY"' in request_errors_source - assert '"TURN_CANCELLING"' in request_errors_source - assert '"AUTH_EXPIRED"' in request_errors_source - assert '"UNAUTHORIZED"' in request_errors_source - assert '"RATE_LIMITED"' in request_errors_source - assert '"NETWORK_ERROR"' in request_errors_source - assert '"STREAM_PARSE_ERROR"' in request_errors_source - assert '"TOOL_EXECUTION_ERROR"' in request_errors_source - assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source - assert '"SERVER_ERROR"' in request_errors_source - assert "passthroughCodes.has(existingCode)" in request_errors_source - assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source - assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source - assert "Failed to start chat. Please try again." not in classifier_source - - -def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows(): - page_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" - ) - source = page_path.read_text(encoding="utf-8") - - # Each flow tracks accepted boundary and passes it into shared terminal handling. - # The acceptance boundary is still meaningful post-refactor: it gates - # local-state cleanup (onPreAcceptFailure path) and lets the shared - # terminal handler distinguish pre-accept aborts from in-stream errors. 
- assert "let newAccepted = false;" in source - assert "let resumeAccepted = false;" in source - assert "let regenerateAccepted = false;" in source - assert "accepted: newAccepted," in source - assert "accepted: resumeAccepted," in source - assert "accepted: regenerateAccepted," in source - - # NOTE: The FE-side persistence guards previously asserted here - # ("if (!resumeAccepted) return;", "if (!regenerateAccepted) return;", - # "if (newAccepted && !userPersisted) {") have been intentionally - # removed by the SSE-based message-id handshake refactor. Persistence - # is now server-authoritative: persist_user_turn / persist_assistant_shell - # run inside stream_new_chat / stream_resume_chat unconditionally and - # the FE consumes data-user-message-id / data-assistant-message-id - # SSE events to learn the canonical primary keys. There is therefore - # no FE call-site to guard, and the shared terminal handler relies - # purely on the `accepted` field above (forwarded to onAbort / - # onAcceptedStreamError) to drive UI cleanup. See - # tests/integration/chat/test_message_id_sse.py for the new - # cross-tier ID coherence guarantees. - - # The TURN_CANCELLING / THREAD_BUSY retry plumbing is independent - # of the persistence refactor and must still exist on every - # start-stream fetch. 
- assert "const fetchWithTurnCancellingRetry = useCallback(" in source - assert "computeFallbackTurnCancellingRetryDelay" in source - assert 'withMeta.errorCode === "TURN_CANCELLING"' in source - assert 'withMeta.errorCode === "THREAD_BUSY"' in source - assert "await fetchWithTurnCancellingRetry(() =>" in source - - -def test_cancel_active_turn_route_contract_exists(): - routes_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_backend/app/routes/new_chat_routes.py" - ) - source = routes_path.read_text(encoding="utf-8") - - assert '@router.post(\n "/threads/{thread_id}/cancel-active-turn",' in source - assert "response_model=CancelActiveTurnResponse" in source - assert 'status="cancelling",' in source - assert 'error_code="TURN_CANCELLING",' in source - assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source - assert "retry_after_at=" in source - assert 'status="idle",' in source - assert 'error_code="NO_ACTIVE_TURN",' in source - - -def test_turn_status_route_contract_exists(): - routes_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_backend/app/routes/new_chat_routes.py" - ) - source = routes_path.read_text(encoding="utf-8") - - assert '@router.get(\n "/threads/{thread_id}/turn-status",' in source - assert "response_model=TurnStatusResponse" in source - assert "_build_turn_status_payload(thread_id)" in source - assert "Permission.CHATS_READ.value" in source - assert "_raise_if_thread_busy_for_start(" in source - - -def test_turn_cancelling_retry_policy_contract_exists(): - routes_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_backend/app/routes/new_chat_routes.py" - ) - source = routes_path.read_text(encoding="utf-8") - - assert "TURN_CANCELLING_INITIAL_DELAY_MS = 200" in source - assert "TURN_CANCELLING_BACKOFF_FACTOR = 2" in source - assert "TURN_CANCELLING_MAX_DELAY_MS = 1500" in source - assert "def _compute_turn_cancelling_retry_delay(" in source - assert "retry-after-ms" in source - assert 
'"Retry-After"' in source - assert '"errorCode": "TURN_CANCELLING"' in source - - -def test_turn_status_sse_contract_exists(): - stream_source = ( - Path(__file__).resolve().parents[3] - / "surfsense_backend/app/tasks/chat/stream_new_chat.py" - ).read_text(encoding="utf-8") - state_source = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/lib/chat/streaming-state.ts" - ).read_text(encoding="utf-8") - pipeline_source = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/lib/chat/stream-pipeline.ts" - ).read_text(encoding="utf-8") - - assert '"turn-status"' in stream_source - assert '"status": "busy"' in stream_source - assert '"status": "idle"' in stream_source - assert 'type: "data-turn-status"' in state_source - assert 'case "data-turn-status":' in pipeline_source - assert "end_turn(str(chat_id))" in stream_source - - -def test_chat_deepagent_forwards_resolved_model_name_to_both_builders(): - """Regression guard: both system-prompt builders in chat_deepagent.py - must receive ``model_name=_resolve_prompt_model_name(...)`` so the - provider-variant dispatch can render the right ```` - block. Without this the prompt silently falls back to the empty - ``"default"`` variant — the original bug being fixed. - - This test mirrors :func:`test_stream_error_emission_keeps_machine_error_codes` - in style: it inspects module source text + a regex to enforce the - call-site shape, not just the wrapper layer (the wrappers already - forward ``model_name`` correctly, so testing them would not catch - the actual missed plumbing). - """ - import app.agents.new_chat.chat_deepagent as chat_deepagent_module - - source = inspect.getsource(chat_deepagent_module) - - # Helper itself must be defined. - assert "def _resolve_prompt_model_name(" in source - - # Both builder calls must forward the resolved model name. Match - # across newlines + whitespace because the kwargs are split over - # multiple lines. 
- pattern = re.compile( - r"build_(?:surfsense|configurable)_system_prompt\([^)]*" - r"model_name=_resolve_prompt_model_name\(", - re.DOTALL, - ) - matches = pattern.findall(source) - assert len(matches) == 2, ( - "Expected both system-prompt builder call sites to forward " - "`model_name=_resolve_prompt_model_name(...)`, found " - f"{len(matches)}" - ) diff --git a/surfsense_web/tests/auth.setup.ts b/surfsense_web/tests/auth.setup.ts index a33a81b3c..7c1e37a39 100644 --- a/surfsense_web/tests/auth.setup.ts +++ b/surfsense_web/tests/auth.setup.ts @@ -1,5 +1,6 @@ import path from "node:path"; import { expect, test as setup } from "@playwright/test"; +import { announcements } from "../lib/announcements/announcements-data"; import { acquireTestToken } from "./helpers/api/auth"; /** @@ -7,21 +8,58 @@ import { acquireTestToken } from "./helpers/api/auth"; * e2e user (rate-limit-free /__e2e__/auth/token first, /auth/jwt/login * fallback) and persists it via localStorage so every test in the * chromium project starts already authenticated. + * + * Also pre-seeds the localStorage flags that gate the two new-user UI + * overlays so they never intercept clicks in journeys: + * - `surfsense_announcements_state` — the blocking AnnouncementSpotlight + * dialog (e.g. "Introducing AI Automations") plus its toasts. + * - `surfsense-tour-` — the OnboardingTour spotlight for new users. */ const authFile = path.join(__dirname, "..", "playwright", ".auth", "user.json"); const STORAGE_KEY = "surfsense_bearer_token"; +const ANNOUNCEMENTS_KEY = "surfsense_announcements_state"; + +/** Decode the user id (`sub`) from a JWT without verifying the signature. */ +function decodeUserId(token: string): string | null { + try { + const payload = token.split(".")[1]; + if (!payload) return null; + const json = Buffer.from(payload, "base64").toString("utf8"); + const obj = JSON.parse(json) as { sub?: string }; + return obj.sub ?? 
null; + } catch { + return null; + } +} setup("authenticate", async ({ page, request }) => { const access_token = await acquireTestToken(request); expect(access_token, "Failed to acquire e2e bearer token").toBeTruthy(); + const userId = decodeUserId(access_token); + // Mark every known announcement read + toasted so spotlight/toast + // announcements never overlay the dashboard during journeys. Sourced + // from the real data file so future announcements are covered too. + const announcementIds = announcements.map((a) => a.id); + const announcementState = { readIds: announcementIds, toastedIds: announcementIds }; + await page.addInitScript( - ({ key, token }) => { + ({ key, token, announcementsKey, state, uid }) => { localStorage.setItem(key, token); + localStorage.setItem(announcementsKey, JSON.stringify(state)); + if (uid) { + localStorage.setItem(`surfsense-tour-${uid}`, "true"); + } }, - { key: STORAGE_KEY, token: access_token } + { + key: STORAGE_KEY, + token: access_token, + announcementsKey: ANNOUNCEMENTS_KEY, + state: announcementState, + uid: userId, + } ); // Use a public page so the init script can write localStorage without