mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
Merge pull request #1471 from CREDO23/improvement-code-organization
[Refactor] : Reorganize the agents and notifications modules
This commit is contained in:
commit
26a504f137
786 changed files with 6782 additions and 28394 deletions
|
|
@ -323,9 +323,6 @@ LANGSMITH_PROJECT=surfsense
|
|||
# =============================================================================
|
||||
# OPTIONAL: New-chat agent feature flags
|
||||
# =============================================================================
|
||||
# Multi-agent orchestrator switch for authenticated chat streaming.
|
||||
# MULTI_AGENT_CHAT_ENABLED=false
|
||||
|
||||
# Master kill-switch — when true, every flag below is forced OFF.
|
||||
# SURFSENSE_DISABLE_NEW_AGENT_STACK=false
|
||||
|
||||
|
|
|
|||
|
|
@ -1,557 +0,0 @@
|
|||
"""Vision autocomplete agent with scoped filesystem exploration.
|
||||
|
||||
Converts the stateless single-shot vision autocomplete into an agent that
|
||||
seeds a virtual filesystem from KB search results and lets the vision LLM
|
||||
explore documents via ``ls``, ``read_file``, ``glob``, ``grep``, etc.
|
||||
before generating the final completion.
|
||||
|
||||
Performance: KB search and agent graph compilation run in parallel so
|
||||
the only sequential latency is KB-search (or agent compile, whichever is
|
||||
slower) + the agent's LLM turns. There is no separate "query extraction"
|
||||
LLM call — the window title is used directly as the KB search query.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from deepagents.graph import BASE_AGENT_PROMPT
|
||||
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from langchain.agents import create_agent
|
||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from app.agents.new_chat.document_xml import build_document_xml
|
||||
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
|
||||
from app.agents.new_chat.middleware.knowledge_search import (
|
||||
search_knowledge_base,
|
||||
)
|
||||
from app.agents.new_chat.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
build_path_index,
|
||||
doc_to_virtual_path,
|
||||
)
|
||||
from app.db import shielded_async_session
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
try:
    from deepagents.backends.utils import create_file_data
except Exception:  # pragma: no cover - defensive

    def create_file_data(content: str) -> dict[str, Any]:
        """Fallback used when the ``deepagents`` helper cannot be imported.

        Wraps ``content`` in a mapping whose ``"content"`` entry holds the
        text split on newline characters.
        """
        lines = content.split("\n")
        return {"content": lines}
|
||||
|
||||
|
||||
async def _build_autocomplete_filesystem(
    *,
    documents: Any,
    search_space_id: int,
) -> tuple[dict[str, Any], dict[int, str]]:
    """Turn KB search results into a ``state['files']``-shaped mapping.

    Autocomplete-specific successor of the former ``build_scoped_filesystem``
    helper. Virtual paths come from the canonical path resolver, so they
    agree with the rest of the system (collision suffixes for duplicate
    titles included).

    Args:
        documents: Iterable of search-result entries. Entries that are not
            dicts, or whose ``document.id`` is not an ``int``, are skipped.
        search_space_id: Search space used to build the path index.

    Returns:
        ``(files, doc_id_to_path)``: ``files`` maps each virtual path to its
        file data, ``doc_id_to_path`` maps each document id to that path.
        Both are empty when ``documents`` is falsy.
    """
    virtual_files: dict[str, Any] = {}
    path_by_doc_id: dict[int, str] = {}

    if not documents:
        return virtual_files, path_by_doc_id

    # The session is only handed to the path-index builder.
    async with shielded_async_session() as session:
        index = await build_path_index(session, search_space_id)

    for entry in documents:
        if not isinstance(entry, dict):
            continue

        meta = entry.get("document") or {}
        doc_id = meta.get("id")
        if not isinstance(doc_id, int):
            continue

        virtual_path = doc_to_virtual_path(
            doc_id=doc_id,
            title=str(meta.get("title") or "untitled"),
            folder_id=meta.get("folder_id"),
            index=index,
        )

        # One malformed chunk id discards the whole matched set for the entry.
        raw_chunk_ids = entry.get("matched_chunk_ids") or []
        try:
            matched = {int(chunk_id) for chunk_id in raw_chunk_ids}
        except (TypeError, ValueError):
            matched = set()

        document_xml = build_document_xml(entry, matched_chunk_ids=matched)
        virtual_files[virtual_path] = create_file_data(document_xml)
        path_by_doc_id[doc_id] = virtual_path

    if not virtual_files:
        # Ensure the synthetic /documents folder is visible even when empty.
        virtual_files[f"{DOCUMENTS_ROOT}/.placeholder"] = create_file_data("")

    return virtual_files, path_by_doc_id
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KB_TOP_K = 10
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# System prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
AUTOCOMPLETE_SYSTEM_PROMPT = """You are a smart writing assistant that analyzes the user's screen to draft or complete text.
|
||||
|
||||
You will receive a screenshot of the user's screen. Your PRIMARY source of truth is the screenshot itself — the visual context determines what to write.
|
||||
|
||||
Your job:
|
||||
1. Analyze the ENTIRE screenshot to understand what the user is working on (email thread, chat conversation, document, code editor, form, etc.).
|
||||
2. Identify the text area where the user will type.
|
||||
3. Generate the text the user most likely wants to write based on the visual context.
|
||||
|
||||
You also have access to the user's knowledge base documents via filesystem tools. However:
|
||||
- ONLY consult the knowledge base if the screenshot clearly involves a topic where your KB documents are DIRECTLY relevant (e.g., the user is writing about a specific project/topic that matches a document title).
|
||||
- Do NOT explore documents just because they exist. Most autocomplete requests can be answered purely from the screenshot.
|
||||
- If you do read a document, only incorporate information that is 100% relevant to what the user is typing RIGHT NOW. Do not add extra details, background, or tangential information from the KB.
|
||||
- Keep your output SHORT — autocomplete should feel like a natural continuation, not an essay.
|
||||
|
||||
Key behavior:
|
||||
- If the text area is EMPTY, draft a concise response or message based on what you see on screen (e.g., reply to an email, respond to a chat message, continue a document).
|
||||
- If the text area already has text, continue it naturally — typically just a sentence or two.
|
||||
|
||||
Rules:
|
||||
- Be CONCISE. Prefer a single paragraph or a few sentences. Autocomplete is a quick assist, not a full draft.
|
||||
- Match the tone and formality of the surrounding context.
|
||||
- If the screen shows code, write code. If it shows a casual chat, be casual. If it shows a formal email, be formal.
|
||||
- Do NOT describe the screenshot or explain your reasoning.
|
||||
- Do NOT cite or reference documents explicitly — just let the knowledge inform your writing naturally.
|
||||
- If you cannot determine what to write, output an empty JSON array: []
|
||||
|
||||
## Output Format
|
||||
|
||||
You MUST provide exactly 3 different suggestion options. Each should be a distinct, plausible completion — vary the tone, detail level, or angle.
|
||||
|
||||
Return your suggestions as a JSON array of exactly 3 strings. Output ONLY the JSON array, nothing else — no markdown fences, no explanation, no commentary.
|
||||
|
||||
Example format:
|
||||
["First suggestion text here.", "Second suggestion — a different take.", "Third option with another approach."]
|
||||
|
||||
## Filesystem Tools `ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep`
|
||||
|
||||
All file paths must start with a `/`.
|
||||
- ls: list files and directories at a given path.
|
||||
- read_file: read a file from the filesystem.
|
||||
- write_file: create a temporary file in the session (not persisted).
|
||||
- edit_file: edit a file in the session (not persisted for /documents/ files).
|
||||
- glob: find files matching a pattern (e.g., "**/*.xml").
|
||||
- grep: search for text within files.
|
||||
|
||||
## When to Use Filesystem Tools
|
||||
|
||||
BEFORE reaching for any tool, ask yourself: "Can I write a good completion purely from the screenshot?" If yes, just write it — do NOT explore the KB.
|
||||
|
||||
Only use tools when:
|
||||
- The user is clearly writing about a specific topic that likely has detailed information in their KB.
|
||||
- You need a specific fact, name, number, or reference that the screenshot doesn't provide.
|
||||
|
||||
When you do use tools, be surgical:
|
||||
- Check the `ls` output first. If no document title looks relevant, stop — do not read files just to see what's there.
|
||||
- If a title looks relevant, read only the `<chunk_index>` (first ~20 lines) and jump to matched chunks. Do not read entire documents.
|
||||
- Extract only the specific information you need and move on to generating the completion.
|
||||
|
||||
## Reading Documents Efficiently
|
||||
|
||||
Documents are formatted as XML. Each document contains:
|
||||
- `<document_metadata>` — title, type, URL, etc.
|
||||
- `<chunk_index>` — a table of every chunk with its **line range** and a
|
||||
`matched="true"` flag for chunks that matched the search query.
|
||||
- `<document_content>` — the actual chunks in original document order.
|
||||
|
||||
**Workflow**: read the first ~20 lines to see the `<chunk_index>`, identify
|
||||
chunks marked `matched="true"`, then use `read_file(path, offset=<start_line>,
|
||||
limit=<lines>)` to jump directly to those sections."""
|
||||
|
||||
APP_CONTEXT_BLOCK = """
|
||||
|
||||
The user is currently working in "{app_name}" (window: "{window_title}"). Use this to understand the type of application and adapt your tone and format accordingly."""
|
||||
|
||||
|
||||
def _build_autocomplete_system_prompt(app_name: str, window_title: str) -> str:
    """Return the autocomplete system prompt, optionally with app context.

    The application-context block is appended only when ``app_name`` is
    non-empty; ``window_title`` alone never adds it.
    """
    if not app_name:
        return AUTOCOMPLETE_SYSTEM_PROMPT

    app_context = APP_CONTEXT_BLOCK.format(
        app_name=app_name, window_title=window_title
    )
    return AUTOCOMPLETE_SYSTEM_PROMPT + app_context
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pre-compute KB filesystem (runs in parallel with agent compilation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _KBResult:
|
||||
"""Container for pre-computed KB filesystem results."""
|
||||
|
||||
__slots__ = ("files", "ls_ai_msg", "ls_tool_msg")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
files: dict[str, Any] | None = None,
|
||||
ls_ai_msg: AIMessage | None = None,
|
||||
ls_tool_msg: ToolMessage | None = None,
|
||||
) -> None:
|
||||
self.files = files
|
||||
self.ls_ai_msg = ls_ai_msg
|
||||
self.ls_tool_msg = ls_tool_msg
|
||||
|
||||
@property
|
||||
def has_documents(self) -> bool:
|
||||
return bool(self.files)
|
||||
|
||||
|
||||
async def precompute_kb_filesystem(
    search_space_id: int,
    query: str,
    top_k: int = KB_TOP_K,
) -> _KBResult:
    """Search the KB and build the scoped filesystem outside the agent.

    Designed to be awaited via ``asyncio.gather`` alongside agent graph
    compilation so the two run concurrently.

    Args:
        search_space_id: Search space to query.
        query: KB search query; an empty query short-circuits to an empty
            result.
        top_k: Maximum number of search results requested.

    Returns:
        A ``_KBResult`` carrying the virtual files plus a synthetic
        ``ls /documents`` AI/tool message pair. An empty ``_KBResult`` is
        returned when there is no query, no search hit, no file, or when any
        step raises (best-effort: the failure is logged, never propagated).
    """
    if not query:
        return _KBResult()

    try:
        search_results = await search_knowledge_base(
            query=query,
            search_space_id=search_space_id,
            top_k=top_k,
        )
        if not search_results:
            return _KBResult()

        new_files, _ = await _build_autocomplete_filesystem(
            documents=search_results,
            search_space_id=search_space_id,
        )
        if not new_files:
            return _KBResult()

        # The filesystem builder inserts a "<root>/.placeholder" entry when no
        # result could be converted. It only keeps the folder visible and must
        # not be advertised as a document, otherwise the "No documents found."
        # answer below can never be produced.
        doc_paths = [
            path
            for path, data in new_files.items()
            if path.startswith("/documents/")
            and data is not None
            and not path.endswith("/.placeholder")
        ]

        # Fabricate an ``ls`` tool round-trip so the model starts with the
        # listing already in its history.
        tool_call_id = f"auto_ls_{uuid.uuid4().hex[:12]}"
        ai_msg = AIMessage(
            content="",
            tool_calls=[
                {"name": "ls", "args": {"path": "/documents"}, "id": tool_call_id}
            ],
        )
        tool_msg = ToolMessage(
            content=str(doc_paths) if doc_paths else "No documents found.",
            tool_call_id=tool_call_id,
        )
        return _KBResult(files=new_files, ls_ai_msg=ai_msg, ls_tool_msg=tool_msg)

    except Exception:
        # Deliberate best-effort boundary: autocomplete still works without KB.
        logger.warning(
            "KB pre-computation failed, proceeding without KB", exc_info=True
        )
        return _KBResult()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Filesystem middleware — no save_document, no persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AutocompleteFilesystemMiddleware(SurfSenseFilesystemMiddleware):
    """Filesystem middleware variant used by the autocomplete agent.

    The parent is initialised with ``search_space_id=None`` and
    ``created_by_id=None`` so the persistence pipeline is bypassed: the
    autocomplete flow explores files but never commits anything to Postgres.
    """

    def __init__(self) -> None:
        super().__init__(created_by_id=None, search_space_id=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _compile_agent(
    llm: BaseChatModel,
    app_name: str,
    window_title: str,
) -> Any:
    """Build the autocomplete agent graph off the event loop.

    The system prompt is the autocomplete prompt (with optional app context)
    followed by ``BASE_AGENT_PROMPT``. ``create_agent`` is executed through
    ``asyncio.to_thread`` and the resulting graph is returned with a
    recursion limit of 200.
    """
    prompt_sections = (
        _build_autocomplete_system_prompt(app_name, window_title),
        BASE_AGENT_PROMPT,
    )
    final_system_prompt = "\n\n".join(prompt_sections)

    compiled = await asyncio.to_thread(
        create_agent,
        llm,
        system_prompt=final_system_prompt,
        tools=[],
        middleware=[
            AutocompleteFilesystemMiddleware(),
            PatchToolCallsMiddleware(),
            AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
        ],
    )
    return compiled.with_config({"recursion_limit": 200})
|
||||
|
||||
|
||||
async def create_autocomplete_agent(
    llm: BaseChatModel,
    *,
    search_space_id: int,
    kb_query: str,
    app_name: str = "",
    window_title: str = "",
) -> tuple[Any, _KBResult]:
    """Compile the agent and pre-compute the KB filesystem concurrently.

    Returns:
        ``(agent, kb_result)``. The caller injects the pre-computed
        filesystem into the agent's initial state, avoiding any middleware
        delay.
    """
    compile_coro = _compile_agent(llm, app_name, window_title)
    kb_coro = precompute_kb_filesystem(search_space_id, kb_query)

    results = await asyncio.gather(compile_coro, kb_coro)
    return results[0], results[1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JSON suggestion parsing (with fallback)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_suggestions(raw: str) -> list[str]:
|
||||
"""Extract a list of suggestion strings from the agent's output.
|
||||
|
||||
Tries, in order:
|
||||
1. Direct ``json.loads``
|
||||
2. Extract content between ```json ... ``` fences
|
||||
3. Find the first ``[`` … ``]`` span
|
||||
Falls back to wrapping the raw text as a single suggestion.
|
||||
"""
|
||||
text = raw.strip()
|
||||
if not text:
|
||||
return []
|
||||
|
||||
for candidate in _json_candidates(text):
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
if isinstance(parsed, list) and all(isinstance(s, str) for s in parsed):
|
||||
return [s for s in parsed if s.strip()]
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
continue
|
||||
|
||||
return [text]
|
||||
|
||||
|
||||
def _json_candidates(text: str) -> list[str]:
|
||||
"""Yield candidate JSON strings from raw text."""
|
||||
candidates = [text]
|
||||
|
||||
fence = re.search(r"```(?:json)?\s*\n?(.*?)```", text, re.DOTALL)
|
||||
if fence:
|
||||
candidates.append(fence.group(1).strip())
|
||||
|
||||
bracket = re.search(r"\[.*]", text, re.DOTALL)
|
||||
if bracket:
|
||||
candidates.append(bracket.group(0))
|
||||
|
||||
return candidates
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streaming helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def stream_autocomplete_agent(
    agent: Any,
    input_data: dict[str, Any],
    streaming_service: VercelStreamingService,
    *,
    emit_message_start: bool = True,
) -> AsyncGenerator[str, None]:
    """Stream agent events as Vercel SSE, with thinking steps for tool calls.

    Model text is buffered rather than streamed; once the agent finishes, the
    buffer is parsed into suggestions and emitted as one ``suggestions`` data
    event followed by the finish and done events.

    Args:
        agent: Compiled agent exposing ``astream_events``.
        input_data: Initial agent input/state.
        streaming_service: Formatter for the SSE frames that are yielded.
        emit_message_start: When ``False`` the caller has already sent the
            ``message_start`` event (e.g. to show preparation steps before
            the agent runs).

    Yields:
        Formatted SSE frames. On any exception an error frame and a done
        frame are yielded instead of propagating the failure.
    """
    thread_id = uuid.uuid4().hex
    config = {"configurable": {"thread_id": thread_id}}

    text_buffer: list[str] = []
    active_tool_depth = 0
    thinking_step_counter = 0
    tool_step_ids: dict[str, str] = {}
    step_titles: dict[str, str] = {}
    completed_step_ids: set[str] = set()
    last_active_step_id: str | None = None

    def next_thinking_step_id() -> str:
        nonlocal thinking_step_counter
        thinking_step_counter += 1
        return f"autocomplete-step-{thinking_step_counter}"

    def complete_current_step() -> str | None:
        """Close the currently active step, returning its frame if any."""
        nonlocal last_active_step_id
        if last_active_step_id and last_active_step_id not in completed_step_ids:
            completed_step_ids.add(last_active_step_id)
            title = step_titles.get(last_active_step_id, "Done")
            event = streaming_service.format_thinking_step(
                step_id=last_active_step_id,
                title=title,
                status="complete",
            )
            last_active_step_id = None
            return event
        return None

    if emit_message_start:
        yield streaming_service.format_message_start()

    gen_step_id = next_thinking_step_id()
    last_active_step_id = gen_step_id
    step_titles[gen_step_id] = "Generating suggestions"
    yield streaming_service.format_thinking_step(
        step_id=gen_step_id,
        title="Generating suggestions",
        status="in_progress",
    )

    try:
        async for event in agent.astream_events(
            input_data, config=config, version="v2"
        ):
            event_type = event.get("event", "")

            if event_type == "on_chat_model_stream":
                # Ignore model calls made from inside a tool and internal runs.
                if active_tool_depth > 0:
                    continue
                if "surfsense:internal" in event.get("tags", []):
                    continue
                chunk = event.get("data", {}).get("chunk")
                if chunk and hasattr(chunk, "content"):
                    content = chunk.content
                    if content and isinstance(content, str):
                        text_buffer.append(content)

            elif event_type == "on_chat_model_end":
                if active_tool_depth > 0:
                    continue
                if "surfsense:internal" in event.get("tags", []):
                    continue
                output = event.get("data", {}).get("output")
                if output and hasattr(output, "content"):
                    if getattr(output, "tool_calls", None):
                        # This turn ended in tool calls, so whatever text was
                        # streamed during it is preamble, not the answer.
                        # Drop it so it is not glued onto the final output.
                        text_buffer.clear()
                        continue
                    content = output.content
                    # Fallback for models that did not stream any chunk.
                    if content and isinstance(content, str) and not text_buffer:
                        text_buffer.append(content)

            elif event_type == "on_tool_start":
                active_tool_depth += 1
                tool_name = event.get("name", "unknown_tool")
                run_id = event.get("run_id", "")
                tool_input = event.get("data", {}).get("input", {})

                step_event = complete_current_step()
                if step_event:
                    yield step_event

                tool_step_id = next_thinking_step_id()
                tool_step_ids[run_id] = tool_step_id
                last_active_step_id = tool_step_id

                title, items = _describe_tool_call(tool_name, tool_input)
                step_titles[tool_step_id] = title
                yield streaming_service.format_thinking_step(
                    step_id=tool_step_id,
                    title=title,
                    status="in_progress",
                    items=items,
                )

            elif event_type == "on_tool_end":
                active_tool_depth = max(0, active_tool_depth - 1)
                run_id = event.get("run_id", "")
                step_id = tool_step_ids.pop(run_id, None)
                if step_id and step_id not in completed_step_ids:
                    completed_step_ids.add(step_id)
                    title = step_titles.get(step_id, "Done")
                    yield streaming_service.format_thinking_step(
                        step_id=step_id,
                        title=title,
                        status="complete",
                    )
                    if last_active_step_id == step_id:
                        last_active_step_id = None

        step_event = complete_current_step()
        if step_event:
            yield step_event

        raw_text = "".join(text_buffer)
        suggestions = _parse_suggestions(raw_text)

        yield streaming_service.format_data("suggestions", {"options": suggestions})

        yield streaming_service.format_finish()
        yield streaming_service.format_done()

    except Exception as e:
        # Top-level streaming boundary: log with traceback, tell the client.
        logger.exception("Autocomplete agent streaming error: %s", e)
        yield streaming_service.format_error("Autocomplete failed. Please try again.")
        yield streaming_service.format_done()
|
||||
|
||||
|
||||
def _describe_tool_call(tool_name: str, tool_input: Any) -> tuple[str, list[str]]:
|
||||
"""Return a human-readable (title, items) for a tool call thinking step."""
|
||||
inp = tool_input if isinstance(tool_input, dict) else {}
|
||||
if tool_name == "ls":
|
||||
path = inp.get("path", "/")
|
||||
return "Listing files", [path]
|
||||
if tool_name == "read_file":
|
||||
fp = inp.get("file_path", "")
|
||||
display = fp if len(fp) <= 80 else "…" + fp[-77:]
|
||||
return "Reading file", [display]
|
||||
if tool_name == "write_file":
|
||||
fp = inp.get("file_path", "")
|
||||
display = fp if len(fp) <= 80 else "…" + fp[-77:]
|
||||
return "Writing file", [display]
|
||||
if tool_name == "edit_file":
|
||||
fp = inp.get("file_path", "")
|
||||
display = fp if len(fp) <= 80 else "…" + fp[-77:]
|
||||
return "Editing file", [display]
|
||||
if tool_name == "glob":
|
||||
pat = inp.get("pattern", "")
|
||||
base = inp.get("path", "/")
|
||||
return "Searching files", [f"{pat} in {base}"]
|
||||
if tool_name == "grep":
|
||||
pat = inp.get("pattern", "")
|
||||
path = inp.get("path", "")
|
||||
display_pat = pat[:60] + ("…" if len(pat) > 60 else "")
|
||||
return "Searching content", [
|
||||
f'"{display_pat}"' + (f" in {path}" if path else "")
|
||||
]
|
||||
return f"Using {tool_name}", []
|
||||
5
surfsense_backend/app/agents/chat/__init__.py
Normal file
5
surfsense_backend/app/agents/chat/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
"""Chat agents category.
|
||||
|
||||
Groups the conversational agents that share a kernel: ``anonymous_chat`` and
|
||||
``multi_agent_chat``. Code shared by *both* lives in ``chat/shared/``.
|
||||
"""
|
||||
14
surfsense_backend/app/agents/chat/anonymous_chat/__init__.py
Normal file
14
surfsense_backend/app/agents/chat/anonymous_chat/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
"""Anonymous / free-chat agent.
|
||||
|
||||
The no-login chat experience: a deliberately minimal agent that bypasses the
|
||||
full SurfSense deep-agent stack (filesystem, knowledge-base persistence,
|
||||
subagents, skills, memory) and answers with an optional ``web_search`` tool and
|
||||
an optional read-only uploaded document. See :mod:`.agent` for details.
|
||||
"""
|
||||
|
||||
from app.agents.chat.anonymous_chat.agent import (
|
||||
build_anonymous_system_prompt,
|
||||
create_anonymous_chat_agent,
|
||||
)
|
||||
|
||||
__all__ = ["build_anonymous_system_prompt", "create_anonymous_chat_agent"]
|
||||
|
|
@ -27,12 +27,12 @@ from langchain.agents.middleware import (
|
|||
from langchain_core.language_models import BaseChatModel
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
from app.agents.new_chat.middleware import (
|
||||
from app.agents.chat.shared.context import SurfSenseContextSchema
|
||||
from app.agents.chat.shared.middleware import (
|
||||
RetryAfterMiddleware,
|
||||
create_surfsense_compaction_middleware,
|
||||
)
|
||||
from app.agents.new_chat.tools.web_search import create_web_search_tool
|
||||
from app.agents.chat.shared.tools.web_search import create_web_search_tool
|
||||
|
||||
# Cap how much of an uploaded document we inline into the system prompt. The
|
||||
# upload endpoint allows files up to several MB, but the doc is re-sent on
|
||||
|
|
@ -11,12 +11,12 @@ from langchain_core.language_models import BaseChatModel
|
|||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.stack import (
|
||||
from app.agents.chat.multi_agent_chat.main_agent.middleware.stack import (
|
||||
build_main_agent_deepagent_middleware,
|
||||
)
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.chat.shared.context import SurfSenseContextSchema
|
||||
from app.db import ChatVisibility
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
"""Action-log middleware: audit row per tool call (impl + builder)."""
|
||||
|
||||
from .builder import build_action_log_mw
|
||||
from .middleware import ActionLogMiddleware, ToolDefinition
|
||||
|
||||
__all__ = [
|
||||
"ActionLogMiddleware",
|
||||
"ToolDefinition",
|
||||
"build_action_log_mw",
|
||||
]
|
||||
|
|
@ -4,11 +4,10 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import ActionLogMiddleware
|
||||
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled
|
||||
|
||||
from ..shared.flags import enabled
|
||||
from .middleware import ActionLogMiddleware
|
||||
|
||||
|
||||
def build_action_log_mw(
|
||||
|
|
@ -21,12 +20,13 @@ def build_action_log_mw(
|
|||
if not enabled(flags, "enable_action_log") or thread_id is None:
|
||||
return None
|
||||
try:
|
||||
tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS}
|
||||
# No built-in tool declares a ``reverse`` callable yet, so the action
|
||||
# log runs without a tool_definitions map. Reversibility is opt-in per
|
||||
# tool via ``ToolDefinition.reverse`` and can be wired here when used.
|
||||
return ActionLogMiddleware(
|
||||
thread_id=thread_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
tool_definitions=tool_defs_by_name,
|
||||
)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logging.warning(
|
||||
|
|
@ -1,25 +1,15 @@
|
|||
"""Append-only action-log middleware for the SurfSense agent.
|
||||
|
||||
Wraps every tool call via :meth:`AgentMiddleware.awrap_tool_call` and writes
|
||||
a row to :class:`~app.db.AgentActionLog` after the tool returns. Tools opt
|
||||
into reversibility by declaring a ``reverse`` callable on their
|
||||
:class:`~app.agents.new_chat.tools.registry.ToolDefinition`; the rendered
|
||||
descriptor is persisted in ``reverse_descriptor`` for use by
|
||||
Wraps every tool call and writes a row to :class:`~app.db.AgentActionLog`
|
||||
after the tool returns. Tools opt into reversibility via a ``reverse``
|
||||
callable on their :class:`ToolDefinition`; the rendered descriptor powers
|
||||
``/api/threads/{thread_id}/revert/{action_id}``.
|
||||
|
||||
Design points:
|
||||
|
||||
* **Defensive.** Logging never blocks the agent. We catch every exception
|
||||
on the DB write path and emit a warning; the tool's ``ToolMessage``
|
||||
result is always returned untouched.
|
||||
* **Lightweight payload.** Only the tool ``name`` + ``args`` (capped) +
|
||||
``result_id`` + ``reverse_descriptor`` are stored. Tool output text
|
||||
remains in the LangGraph checkpoint / spilled tool-output files.
|
||||
* **Best-effort reversibility.** We invoke ``reverse(args, result_obj)``
|
||||
with the parsed JSON result when the tool's content is a JSON object;
|
||||
otherwise the raw text is passed. Exceptions in the reverse callable
|
||||
are swallowed and logged — a failed descriptor render simply means the
|
||||
action is NOT marked reversible.
|
||||
Logging is fully defensive — DB-write failures are swallowed so the tool's
|
||||
result is always returned untouched. Only metadata (name, capped args,
|
||||
result_id, reverse_descriptor) is stored; tool output stays in the
|
||||
checkpoint. Reversibility is best-effort: a reverse callable that raises
|
||||
just leaves the action non-reversible.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -27,14 +17,14 @@ from __future__ import annotations
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.callbacks import adispatch_custom_event
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
from app.agents.new_chat.feature_flags import get_flags
|
||||
from app.agents.new_chat.tools.registry import ToolDefinition
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - type-only
|
||||
from langchain.agents.middleware.types import ToolCallRequest
|
||||
|
|
@ -44,6 +34,31 @@ if TYPE_CHECKING: # pragma: no cover - type-only
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ToolDefinition:
    """Declarative description of a tool, used for action-log reversibility.

    :class:`ActionLogMiddleware` reads only ``name`` and ``reverse``; the
    other fields exist so call sites and tests can describe a tool
    declaratively. The action log marks a tool call reversible when
    ``reverse`` is set and renders a descriptor without raising.

    Attributes:
        name: Unique identifier of the tool.
        description: Human-readable summary of what the tool does.
        factory: Optional callable that builds the tool. Not used by the
            middleware; kept for declarative call sites and tests.
        reverse: Optional callable receiving the tool's ``(args, result)``
            and returning a ``ReverseDescriptor`` for the inverse invocation.
    """

    name: str
    description: str = ""
    factory: Callable[[dict[str, Any]], Any] | None = None
    reverse: Callable[[dict[str, Any], Any], dict[str, Any]] | None = None
|
||||
|
||||
|
||||
# Cap for the persisted ``args`` JSON to avoid bloating the action log with
|
||||
# accidentally-huge inputs. Values are truncated and a flag is set in the
|
||||
# stored payload so consumers can detect truncation.
|
||||
|
|
@ -178,11 +193,9 @@ class ActionLogMiddleware(AgentMiddleware):
|
|||
)
|
||||
return
|
||||
|
||||
# Surface a side-channel SSE event so the chat tool card can
|
||||
# render a Revert button immediately after the row is durable.
|
||||
# ``stream_new_chat`` translates this into a
|
||||
# ``data-action-log`` SSE event. We DO NOT include the
|
||||
# ``reverse_descriptor`` payload here; only a presence flag.
|
||||
# Side-channel event (relayed by ``stream_new_chat`` as a
|
||||
# ``data-action-log`` SSE) so the tool card can show a Revert button
|
||||
# once the row is durable. Carries a presence flag, not the descriptor.
|
||||
try:
|
||||
await adispatch_custom_event(
|
||||
"action_log",
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
"""Anonymous-document middleware: Redis hydration, cloud only (impl + builder)."""
|
||||
|
||||
from .builder import build_anonymous_doc_mw
|
||||
from .middleware import AnonymousDocumentMiddleware
|
||||
|
||||
__all__ = [
|
||||
"AnonymousDocumentMiddleware",
|
||||
"build_anonymous_doc_mw",
|
||||
]
|
||||
|
|
@ -2,8 +2,9 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import AnonymousDocumentMiddleware
|
||||
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
from .middleware import AnonymousDocumentMiddleware
|
||||
|
||||
|
||||
def build_anonymous_doc_mw(
|
||||
|
|
@ -24,8 +24,13 @@ from typing import Any
|
|||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
||||
from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT, safe_filename
|
||||
from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.agents.chat.runtime.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
safe_filename,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
"""Per-turn cooperative busy-lock middleware + cancel primitives (main-agent)."""
|
||||
|
||||
from .builder import build_busy_mutex_mw
|
||||
from .middleware import (
|
||||
BusyMutexMiddleware,
|
||||
end_turn,
|
||||
get_cancel_event,
|
||||
get_cancel_state,
|
||||
is_cancel_requested,
|
||||
manager,
|
||||
request_cancel,
|
||||
reset_cancel,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BusyMutexMiddleware",
|
||||
"build_busy_mutex_mw",
|
||||
"end_turn",
|
||||
"get_cancel_event",
|
||||
"get_cancel_state",
|
||||
"is_cancel_requested",
|
||||
"manager",
|
||||
"request_cancel",
|
||||
"reset_cancel",
|
||||
]
|
||||
|
|
@ -2,10 +2,12 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import BusyMutexMiddleware
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled
|
||||
|
||||
from ..shared.flags import enabled
|
||||
from .middleware import (
|
||||
BusyMutexMiddleware,
|
||||
)
|
||||
|
||||
|
||||
def build_busy_mutex_mw(flags: AgentFeatureFlags) -> BusyMutexMiddleware | None:
|
||||
|
|
@ -1,32 +1,12 @@
|
|||
"""
|
||||
BusyMutexMiddleware — per-thread asyncio lock + cancel token.
|
||||
"""Per-thread asyncio lock + cooperative cancel token, keyed by ``thread_id``.
|
||||
|
||||
LangChain has no built-in concept of "this thread is already running a
|
||||
turn — refuse the second concurrent request". Without it, a user
|
||||
double-clicking "send" or refreshing the page mid-stream can spawn two
|
||||
turns racing on the same checkpoint, producing duplicated tool calls
|
||||
and mangled state.
|
||||
Refuses a second concurrent turn on the same thread (e.g. double-clicked
|
||||
"send") that would otherwise race on the same checkpoint and duplicate tool
|
||||
calls. Also exposes a per-thread cancel event that long-running tools poll
|
||||
via ``runtime.context.cancel_event.is_set()`` to abort cooperatively.
|
||||
|
||||
Ported from OpenCode's ``Stream.scoped(AbortController)`` pattern: a
|
||||
single-process, in-memory lock + cooperative cancellation token keyed by
|
||||
``thread_id``. For multi-worker deployments a distributed lock backend
|
||||
(Redis or PostgreSQL advisory locks) is a phase-2 follow-up.
|
||||
|
||||
What this provides:
|
||||
- A ``WeakValueDictionary[str, asyncio.Lock]`` keyed by ``thread_id``;
|
||||
acquiring the lock during ``before_agent`` blocks any concurrent
|
||||
prompt on the same thread until release.
|
||||
- A per-thread ``asyncio.Event`` (``cancel_event``) that long-running
|
||||
tools can poll to abort cooperatively. The event is reset between
|
||||
turns. Tools should check ``runtime.context.cancel_event.is_set()``
|
||||
in tight inner loops.
|
||||
- A typed :class:`~app.agents.new_chat.errors.BusyError` raised when a
|
||||
second turn arrives while the lock is held.
|
||||
|
||||
Note: SurfSense's ``stream_new_chat`` is the call site that should
|
||||
acquire/release. Wiring this as middleware means the contract is
|
||||
explicit and the lock manager is shared with subagents that compile
|
||||
their own ``create_agent`` runnables.
|
||||
Process-local and in-memory; multi-worker deployments need a distributed lock
|
||||
(Redis / PostgreSQL advisory locks) as a follow-up.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -46,7 +26,7 @@ from langchain.agents.middleware.types import (
|
|||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agents.new_chat.errors import BusyError
|
||||
from app.agents.chat.runtime.errors import BusyError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -152,9 +132,8 @@ class _ThreadLockManager:
|
|||
return True
|
||||
|
||||
|
||||
# Module-level singleton — process-local but reused across all agent
|
||||
# instances built in this process. Subagents created in nested
|
||||
# ``create_agent`` calls also get this so locks are coherent.
|
||||
# Process-local singleton shared across all agents/subagents built in this
|
||||
# process so per-thread locks stay coherent.
|
||||
manager = _ThreadLockManager()
|
||||
|
||||
|
||||
|
|
@ -266,7 +245,6 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
|||
await lock.acquire()
|
||||
epoch = manager.bump_turn_epoch(thread_id)
|
||||
self._held_locks[thread_id] = (lock, epoch)
|
||||
# Reset the cancel event so this turn starts fresh
|
||||
reset_cancel(thread_id)
|
||||
return None
|
||||
|
||||
|
|
@ -289,17 +267,14 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
|||
return None
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
# Always clear cancel event between turns so a stale signal
|
||||
# doesn't leak into the next request.
|
||||
# Clear cancel event so a stale signal doesn't leak into the next turn.
|
||||
reset_cancel(thread_id)
|
||||
return None
|
||||
|
||||
# Provide sync no-ops because the middleware base class allows them
|
||||
def before_agent( # type: ignore[override]
|
||||
self, state: AgentState[Any], runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
# Sync path: no asyncio.Lock to acquire. Best we can do is reject
|
||||
# if anyone else is in flight.
|
||||
# Sync path can't await an asyncio.Lock; only reject if one is in flight.
|
||||
thread_id = self._thread_id(runtime)
|
||||
if thread_id is None:
|
||||
if self._require_thread_id:
|
||||
|
|
@ -1,7 +1,9 @@
|
|||
"""RunnableConfig wiring for nested subagent invocations.
|
||||
"""HITL resume side-channel for nested subagent invocations.
|
||||
|
||||
Forwards the parent's ``runtime.config`` (thread_id, …) into the subagent and
|
||||
exposes the side-channel ``stream_resume_chat`` uses to ferry resume payloads.
|
||||
Exposes the configurable side-channel ``stream_resume_chat`` uses to ferry
|
||||
resume payloads into a mid-flight subagent. The ``RunnableConfig`` builder and
|
||||
state-key filter shared with subagents live in
|
||||
``app.agents.chat.multi_agent_chat.subagents.shared.invocation``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -11,8 +13,6 @@ from typing import Any
|
|||
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from .constants import DEFAULT_SUBAGENT_RECURSION_LIMIT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# langgraph stores the parent task's scratchpad under this configurable key;
|
||||
|
|
@ -20,39 +20,6 @@ logger = logging.getLogger(__name__)
|
|||
_LANGGRAPH_SCRATCHPAD_KEY = "__pregel_scratchpad"
|
||||
|
||||
|
||||
def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
    """Build the RunnableConfig used for a nested subagent invoke.

    Starts from a shallow copy of the parent's ``runtime.config`` and makes
    two adjustments:

    * ``recursion_limit`` is raised to ``DEFAULT_SUBAGENT_RECURSION_LIMIT``
      when it is missing, unparsable, or lower than that value.
    * ``configurable["thread_id"]`` becomes
      ``{parent_thread}::task:{tool_call_id}``, or just
      ``task:{tool_call_id}`` when the parent has no thread id. Each
      invocation therefore gets its own checkpoint slot, and the same
      ``tool_call_id`` maps to the same slot across a resume cycle.

    The namespace lives in ``thread_id`` rather than ``checkpoint_ns``
    because langgraph's ``aget_state`` treats a non-empty ``checkpoint_ns``
    as a subgraph path and raises ``ValueError("Subgraph X not found")``.

    Args:
        runtime: Tool runtime of the calling ``task`` tool; its ``config``
            and ``tool_call_id`` are read.

    Returns:
        A new config dict. The parent's config and its ``configurable``
        mapping are copied, not mutated.
    """
    config: dict[str, Any] = {}
    if runtime.config:
        config = dict(runtime.config)

    # Treat a missing or non-numeric limit as 0 so it is always raised.
    raw_limit = config.get("recursion_limit")
    if raw_limit is None:
        limit = 0
    else:
        try:
            limit = int(raw_limit)
        except (TypeError, ValueError):
            limit = 0
    if limit < DEFAULT_SUBAGENT_RECURSION_LIMIT:
        config["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT

    # Copy ``configurable`` before editing so the parent's mapping is untouched.
    scoped: dict[str, Any] = dict(config.get("configurable") or {})
    parent_thread = scoped.get("thread_id")
    task_suffix = f"task:{runtime.tool_call_id}"
    if parent_thread:
        scoped["thread_id"] = f"{parent_thread}::{task_suffix}"
    else:
        scoped["thread_id"] = task_suffix
    config["configurable"] = scoped
    return config
|
||||
|
||||
|
||||
def consume_surfsense_resume(runtime: ToolRuntime) -> Any:
|
||||
"""Pop the resume payload for *this* call's ``tool_call_id``.
|
||||
|
||||
|
|
@ -1,24 +1,14 @@
|
|||
"""Constants shared by the checkpointed subagent middleware."""
|
||||
"""Tuning constants for the checkpointed subagent middleware.
|
||||
|
||||
``EXCLUDED_STATE_KEYS`` and ``DEFAULT_SUBAGENT_RECURSION_LIMIT`` are part of the
|
||||
subagent-invocation contract shared with subagents and now live in
|
||||
``app.agents.chat.multi_agent_chat.subagents.shared.invocation``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
# Mirror of deepagents.middleware.subagents._EXCLUDED_STATE_KEYS.
|
||||
EXCLUDED_STATE_KEYS = frozenset(
|
||||
{
|
||||
"messages",
|
||||
"todos",
|
||||
"structured_response",
|
||||
"skills_metadata",
|
||||
"memory_contents",
|
||||
}
|
||||
)
|
||||
|
||||
# Match the parent graph's budget; the LangGraph default of 25 trips on
|
||||
# multi-step subagent runs.
|
||||
DEFAULT_SUBAGENT_RECURSION_LIMIT = 10_000
|
||||
|
||||
|
||||
def _read_timeout_env(name: str, default: float) -> float:
|
||||
"""Parse ``name`` from the environment; fall back to ``default`` on bad values.
|
||||
|
|
@ -16,7 +16,7 @@ from langchain.agents import create_agent
|
|||
from langchain.chat_models import init_chat_model
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.spec import (
|
||||
from app.agents.chat.multi_agent_chat.subagents.shared.spec import (
|
||||
SURF_CONTEXT_HINT_PROVIDER_KEY,
|
||||
)
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
|
@ -6,7 +6,7 @@ and the ``<tools>`` block render from the same source.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.multi_agent_chat.main_agent.system_prompt.builder.load_md import (
|
||||
from app.agents.chat.multi_agent_chat.main_agent.system_prompt.builder.load_md import (
|
||||
read_prompt_md,
|
||||
)
|
||||
|
||||
|
|
@ -23,7 +23,11 @@ from langchain_core.tools import StructuredTool
|
|||
from langgraph.errors import GraphInterrupt
|
||||
from langgraph.types import Command, Interrupt
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.spec import (
|
||||
from app.agents.chat.multi_agent_chat.subagents.shared.invocation import (
|
||||
EXCLUDED_STATE_KEYS,
|
||||
subagent_invoke_config,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.subagents.shared.spec import (
|
||||
SURF_CONTEXT_HINT_PROVIDER_KEY,
|
||||
ContextHintProvider,
|
||||
)
|
||||
|
|
@ -34,13 +38,11 @@ from .config import (
|
|||
consume_surfsense_resume,
|
||||
drain_parent_null_resume,
|
||||
has_surfsense_resume,
|
||||
subagent_invoke_config,
|
||||
)
|
||||
from .constants import (
|
||||
DEFAULT_SUBAGENT_BATCH_CONCURRENCY,
|
||||
DEFAULT_SUBAGENT_BILLABLE_THRESHOLD,
|
||||
DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS,
|
||||
EXCLUDED_STATE_KEYS,
|
||||
MAX_SUBAGENT_BATCH_SIZE,
|
||||
)
|
||||
from .propagation import wrap_with_tool_call_id
|
||||
|
|
@ -80,13 +82,10 @@ _T = TypeVar("_T")
|
|||
async def _ainvoke_with_timeout[T](
|
||||
coro: Awaitable[_T], *, subagent_type: str, started_at: float
|
||||
) -> _T:
|
||||
"""Apply :data:`DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS` to ``coro``.
|
||||
"""Apply the subagent invoke timeout to ``coro`` (non-positive disables it).
|
||||
|
||||
A non-positive timeout disables the cap (configurable via the
|
||||
``SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS`` env var). On expiry the
|
||||
underlying task is cancelled and :class:`SubagentInvokeTimeoutError` is
|
||||
raised — the caller wraps it into a synthetic ToolMessage so the
|
||||
orchestrator can decide what to do.
|
||||
On expiry the task is cancelled and :class:`SubagentInvokeTimeoutError` is
|
||||
raised for the caller to turn into a synthetic ToolMessage.
|
||||
"""
|
||||
timeout = DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS
|
||||
if timeout <= 0:
|
||||
|
|
@ -149,12 +148,9 @@ def build_task_tool_with_parent_config(
|
|||
subagent_graphs: dict[str, Runnable] = {
|
||||
spec["name"]: spec["runnable"] for spec in subagents
|
||||
}
|
||||
# Per-subagent context-hint providers (see ``SurfSenseSubagentSpec``).
|
||||
# The mapping is sparse: only routes that opted in via ``pack_subagent``
|
||||
# appear here, and the value is invoked once per ``task(...)`` call to
|
||||
# generate a short string prepended to the subagent's first
|
||||
# ``HumanMessage``. Failures are logged and swallowed — a broken hint
|
||||
# provider must never prevent the underlying task from running.
|
||||
# Sparse map of opt-in context-hint providers; each runs once per task()
|
||||
# call to prepend a string to the subagent's first HumanMessage. Failures
|
||||
# are swallowed so a broken hint never blocks the task.
|
||||
subagent_hint_providers: dict[str, ContextHintProvider] = {
|
||||
spec["name"]: provider
|
||||
for spec in subagents
|
||||
|
|
@ -176,24 +172,18 @@ def build_task_tool_with_parent_config(
|
|||
def _billable_call_update(
|
||||
subagent_type: str, runtime: ToolRuntime
|
||||
) -> dict[str, Any]:
|
||||
"""Build the per-call ``billable_calls`` delta + an optional warning.
|
||||
"""Build the per-call ``billable_calls`` delta plus an optional soft-cap warning.
|
||||
|
||||
The orchestrator's ``billable_calls`` map is summed by
|
||||
:func:`_int_counter_merge_reducer`, so we always emit
|
||||
``{subagent_type: 1}`` and let the reducer accumulate. If the
|
||||
cumulative count *after* this call would cross the configured
|
||||
threshold, we also slip a soft ``messages`` entry into the update
|
||||
so the orchestrator can read it on its next step and self-limit.
|
||||
Returning a plain ``dict`` (vs. an extra :class:`Command`) keeps
|
||||
the helper composable with the existing single/batch return paths.
|
||||
Always emits ``{subagent_type: 1}`` (a reducer accumulates it); when this
|
||||
call would cross the threshold, also adds a soft ``messages`` entry so the
|
||||
orchestrator self-limits on its next step.
|
||||
"""
|
||||
delta: dict[str, Any] = {"billable_calls": {subagent_type: 1}}
|
||||
threshold = DEFAULT_SUBAGENT_BILLABLE_THRESHOLD
|
||||
if threshold <= 0:
|
||||
return delta
|
||||
prior = runtime.state.get("billable_calls") or {}
|
||||
# ``prior`` may be a plain dict or a reducer-managed mapping; only
|
||||
# int values are counted so a malformed checkpoint can't crash us.
|
||||
# Count int values only so a malformed checkpoint can't crash us.
|
||||
prior_total = sum(v for v in prior.values() if isinstance(v, int))
|
||||
new_total = prior_total + 1
|
||||
if prior_total < threshold <= new_total:
|
||||
|
|
@ -212,8 +202,7 @@ def build_task_tool_with_parent_config(
|
|||
"""Merge the per-call billable counter (and warning) into ``cmd``."""
|
||||
delta = _billable_call_update(subagent_type, runtime)
|
||||
warn_text = delta.pop("_billable_warn_text", None)
|
||||
# ``cmd.update`` may be a dict or LangGraph ``UpdateDict``; defensively
|
||||
# copy so we don't mutate state shared across other tool returns.
|
||||
# Copy so we don't mutate state shared with other tool returns.
|
||||
update = dict(getattr(cmd, "update", {}) or {})
|
||||
for key, value in delta.items():
|
||||
update[key] = value
|
||||
|
|
@ -226,14 +215,10 @@ def build_task_tool_with_parent_config(
|
|||
return Command(update=update)
|
||||
|
||||
def _safe_message_text(msg: Any) -> str:
|
||||
"""Pull text out of a BaseMessage without trusting the ``.text`` property.
|
||||
"""Pull text out of a BaseMessage without using the ``.text`` property.
|
||||
|
||||
``BaseMessage.text`` walks ``content_blocks`` and crashes with
|
||||
``TypeError: 'NoneType' object is not iterable`` when ``content`` is
|
||||
``None`` (common for tool-call AIMessages whose payload is purely
|
||||
structured). ``getattr(msg, "text", None)`` does not catch this
|
||||
because Python evaluates the property body before falling back to
|
||||
the default. Read ``content`` directly and coerce defensively.
|
||||
``.text`` crashes when ``content`` is ``None`` (common for tool-call
|
||||
AIMessages), and ``getattr`` won't catch it, so read ``content`` directly.
|
||||
"""
|
||||
try:
|
||||
content = getattr(msg, "content", None)
|
||||
|
|
@ -256,23 +241,18 @@ def build_task_tool_with_parent_config(
|
|||
return str(content)
|
||||
|
||||
def _build_tool_trace(messages: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Compress the subagent's message stream into a compact tool trace.
|
||||
"""Compress the subagent's messages into a compact tool trace.
|
||||
|
||||
Each entry is ``{"tool": <name>, "status": "ok"|"error", "preview":
|
||||
<≤120 chars>}`` so the orchestrator can show "this is what your
|
||||
specialist actually did" without dumping the full message stream
|
||||
back through the prompt. The list is attached to the returned
|
||||
ToolMessage's ``additional_kwargs`` (under ``"surf_tool_trace"``);
|
||||
the LLM never sees it, but UI / observability code can pluck it
|
||||
out of the checkpoint.
|
||||
Entries (``{tool, status, preview}``) ride on the ToolMessage's
|
||||
``additional_kwargs["surf_tool_trace"]`` for UI/observability; the LLM
|
||||
never sees them.
|
||||
"""
|
||||
trace: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
tool_name = getattr(msg, "name", None)
|
||||
tool_call_id_attr = getattr(msg, "tool_call_id", None)
|
||||
if not tool_name and not tool_call_id_attr:
|
||||
# Only ToolMessages have either field; skip AIMessage /
|
||||
# HumanMessage / SystemMessage frames.
|
||||
# Only ToolMessages carry either field.
|
||||
continue
|
||||
status = getattr(msg, "status", None) or "ok"
|
||||
preview = _safe_message_text(msg).strip().replace("\n", " ")
|
||||
|
|
@ -306,8 +286,7 @@ def build_task_tool_with_parent_config(
|
|||
)
|
||||
raise ValueError(msg)
|
||||
message_text = _safe_message_text(messages[-1]).rstrip()
|
||||
# Tool-trace is purely observability — wrap defensively so a single
|
||||
# malformed frame never bubbles up and kills the whole user turn.
|
||||
# Trace is observability-only; never let a bad frame kill the turn.
|
||||
try:
|
||||
tool_trace = _build_tool_trace(messages)
|
||||
except Exception:
|
||||
|
|
@ -318,10 +297,7 @@ def build_task_tool_with_parent_config(
|
|||
tool_trace = []
|
||||
tool_msg = ToolMessage(message_text, tool_call_id=tool_call_id)
|
||||
if tool_trace:
|
||||
# ``additional_kwargs`` is a free-form dict on BaseMessage; using
|
||||
# a ``surf_`` prefix avoids collision with provider-specific keys
|
||||
# (e.g. Anthropic's ``cache_control``). The LLM doesn't see it;
|
||||
# consumers (UI, observability) read it off the checkpoint.
|
||||
# surf_ prefix avoids collision with provider keys (e.g. cache_control).
|
||||
tool_msg.additional_kwargs["surf_tool_trace"] = tool_trace
|
||||
return Command(
|
||||
update={
|
||||
|
|
@ -359,9 +335,7 @@ def build_task_tool_with_parent_config(
|
|||
}
|
||||
hint = _resolve_context_hint(subagent_type, description, runtime)
|
||||
if hint:
|
||||
# Prepend as a tagged block so the subagent prompt can pattern-match
|
||||
# on the section (and a future change can lift it into its own
|
||||
# ``SystemMessage`` if needed).
|
||||
# Tagged block so the subagent prompt can pattern-match the section.
|
||||
payload = f"<context_hint>\n{hint}\n</context_hint>\n\n{description}"
|
||||
else:
|
||||
payload = description
|
||||
|
|
@ -372,16 +346,12 @@ def build_task_tool_with_parent_config(
|
|||
results: list[tuple[int, str, dict | str, dict | None]],
|
||||
runtime: ToolRuntime,
|
||||
) -> Command:
|
||||
"""Combine per-child results into one Command with a combined ToolMessage.
|
||||
"""Combine per-child results into one Command with an aggregate ToolMessage.
|
||||
|
||||
``results`` is a list of ``(task_index, subagent_type,
|
||||
payload_or_error_text, child_state_update)`` tuples — preserving the
|
||||
input order so the orchestrator can map each block back to the task
|
||||
it dispatched. State updates are merged by reducer for keys outside
|
||||
:data:`EXCLUDED_STATE_KEYS`; everything else (``messages``, ``todos``,
|
||||
etc.) is replaced by the synthesized aggregate ToolMessage. Every
|
||||
child also contributes a ``billable_calls`` increment so cost
|
||||
accounting matches single-mode dispatch.
|
||||
``results`` tuples are ``(task_index, subagent_type, payload_or_error,
|
||||
child_state_update)``; output blocks are sorted by index so the LLM can
|
||||
map them back to dispatch order, and each child contributes a
|
||||
``billable_calls`` increment to match single-mode accounting.
|
||||
"""
|
||||
results.sort(key=lambda r: r[0])
|
||||
merged_state: dict[str, Any] = {}
|
||||
|
|
@ -422,8 +392,8 @@ def build_task_tool_with_parent_config(
|
|||
}
|
||||
)
|
||||
if state_update:
|
||||
# Naive merge: later tasks win on scalar collisions; reducer-backed
|
||||
# fields (``receipts``, ``files`` etc.) accumulate at apply time.
|
||||
# Later tasks win on scalar collisions; reducer-backed fields
|
||||
# accumulate at apply time.
|
||||
merged_state.update(state_update)
|
||||
aggregate = "\n\n".join(message_blocks)
|
||||
aggregate_msg = ToolMessage(
|
||||
|
|
@ -467,11 +437,9 @@ def build_task_tool_with_parent_config(
|
|||
) -> tuple[int, str, dict | str, dict | None]:
|
||||
"""Run one child of a batched ``task`` call under the concurrency cap.
|
||||
|
||||
Errors are returned as plain text in slot 2 so a single child's
|
||||
failure does not abort the whole batch. ``GraphInterrupt`` from a
|
||||
batched child is currently treated as a hard failure for that child
|
||||
only — batched HITL is intentionally out of scope for the v1
|
||||
rollout (see plan tier 2 item 4 risks).
|
||||
Errors are returned as text (slot 2) so one child's failure doesn't abort
|
||||
the batch. A child's ``GraphInterrupt`` is a hard failure for that child:
|
||||
batched HITL is intentionally out of scope.
|
||||
"""
|
||||
async with semaphore:
|
||||
if subagent_type not in subagent_graphs:
|
||||
|
|
@ -505,8 +473,7 @@ def build_task_tool_with_parent_config(
|
|||
)
|
||||
return (task_index, subagent_type, str(exc), None)
|
||||
except GraphInterrupt:
|
||||
# Batched HITL is unsupported in v1 — surface as a failure
|
||||
# for this child so the rest of the batch still completes.
|
||||
# Batched HITL unsupported; fail this child so the batch finishes.
|
||||
logger.warning(
|
||||
"Batch child %d (%s) raised GraphInterrupt; batched HITL "
|
||||
"is not supported. Re-dispatch this task as a single "
|
||||
|
|
@ -543,14 +510,11 @@ def build_task_tool_with_parent_config(
|
|||
return (task_index, subagent_type, result, child_state_update)
|
||||
|
||||
def _coerce_batch_arg(tasks: Any) -> list[dict] | str:
|
||||
"""Rescue common LLM-side malformations of the ``tasks`` argument.
|
||||
"""Rescue common LLM malformations of the ``tasks`` argument.
|
||||
|
||||
Some providers serialise an array argument as a JSON-encoded string,
|
||||
and small models occasionally hand back a single ``{description,
|
||||
subagent_type}`` dict instead of a one-element array. Both are
|
||||
recovered here with a WARN log so the issue is visible in metrics
|
||||
but the user's turn still completes; truly broken shapes return a
|
||||
plain string that the caller surfaces as the tool error.
|
||||
Recovers a JSON-encoded array string and a single dict (instead of a
|
||||
1-element array), logging a WARN. Unrecoverable shapes return a string
|
||||
the caller surfaces as the tool error.
|
||||
"""
|
||||
if isinstance(tasks, list):
|
||||
return tasks
|
||||
|
|
@ -585,13 +549,10 @@ def build_task_tool_with_parent_config(
|
|||
async def _adispatch_batch(
|
||||
tasks: list[dict], runtime: ToolRuntime
|
||||
) -> Command | str:
|
||||
"""Fan-out helper for the ``tasks`` array shape.
|
||||
"""Fan out the ``tasks`` array (size- and concurrency-capped).
|
||||
|
||||
Bounded by :data:`MAX_SUBAGENT_BATCH_SIZE` and concurrency-capped
|
||||
at :data:`DEFAULT_SUBAGENT_BATCH_CONCURRENCY`. Returns a single
|
||||
:class:`Command` that the LLM sees as one ToolMessage per child,
|
||||
prefixed with ``[task <index>]`` so it can map back to the input
|
||||
order.
|
||||
Returns one Command; the LLM sees one ``[task <index>]``-prefixed block
|
||||
per child, in input order.
|
||||
"""
|
||||
if not tasks:
|
||||
return "tasks: array is empty; nothing to dispatch."
|
||||
|
|
@ -701,17 +662,16 @@ def build_task_tool_with_parent_config(
|
|||
if pending_value is not None:
|
||||
resume_value = consume_surfsense_resume(runtime)
|
||||
if resume_value is None:
|
||||
# Bridge invariant: a queued resume must accompany any pending
|
||||
# subagent interrupt. Fall-through replay would silently re-prompt
|
||||
# the user; raise so the streaming layer surfaces a clear error.
|
||||
# A pending interrupt must have a queued resume; otherwise replay
|
||||
# would silently re-prompt the user. Raise instead.
|
||||
raise RuntimeError(
|
||||
f"Subagent {subagent_type!r} has a pending interrupt but no "
|
||||
"surfsense_resume_value on config; resume bridge is broken."
|
||||
)
|
||||
expected = hitlrequest_action_count(pending_value)
|
||||
resume_value = fan_out_decisions_to_match(resume_value, expected)
|
||||
# Prevent the parent's resume payload from leaking into subagent
|
||||
# interrupts via langgraph's parent_scratchpad fallback.
|
||||
# Stop the parent's resume leaking into subagent interrupts via
|
||||
# langgraph's parent_scratchpad fallback.
|
||||
drain_parent_null_resume(runtime)
|
||||
with ot.subagent_invoke_span(
|
||||
subagent_type=subagent_type, path=invoke_path
|
||||
|
|
@ -827,10 +787,8 @@ def build_task_tool_with_parent_config(
|
|||
] = None,
|
||||
) -> str | Command:
|
||||
atask_start = time.perf_counter()
|
||||
# Kill switch: when ops flips the spawn-paused flag for this
|
||||
# workspace, every ``task(...)`` invocation (single- or batch-mode)
|
||||
# short-circuits with a clear ToolMessage so the orchestrator can
|
||||
# tell the user what happened and stop hammering downstream APIs.
|
||||
# Ops kill switch: short-circuit every task() call for this workspace
|
||||
# so the orchestrator stops hammering downstream APIs.
|
||||
if await is_spawn_paused(search_space_id):
|
||||
logger.warning(
|
||||
"[hitl_route] atask SPAWN_PAUSED: search_space_id=%s tool_call_id=%s",
|
||||
|
|
@ -921,8 +879,8 @@ def build_task_tool_with_parent_config(
|
|||
)
|
||||
expected = hitlrequest_action_count(pending_value)
|
||||
resume_value = fan_out_decisions_to_match(resume_value, expected)
|
||||
# Prevent the parent's resume payload from leaking into subagent
|
||||
# interrupts via langgraph's parent_scratchpad fallback.
|
||||
# Stop the parent's resume leaking into subagent interrupts via
|
||||
# langgraph's parent_scratchpad fallback.
|
||||
drain_parent_null_resume(runtime)
|
||||
with ot.subagent_invoke_span(
|
||||
subagent_type=subagent_type, path=invoke_path
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
"""Context-editing middleware: spill + clear-tool-uses passes (impl + builder)."""
|
||||
|
||||
from .builder import build_context_editing_mw
|
||||
from .middleware import (
|
||||
ClearToolUsesEdit,
|
||||
SpillingContextEditingMiddleware,
|
||||
SpillToBackendEdit,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ClearToolUsesEdit",
|
||||
"SpillToBackendEdit",
|
||||
"SpillingContextEditingMiddleware",
|
||||
"build_context_editing_mw",
|
||||
]
|
||||
|
|
@ -7,18 +7,18 @@ from typing import Any
|
|||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.multi_agent_chat.main_agent.context_prune.prune_tool_names import (
|
||||
from app.agents.chat.multi_agent_chat.main_agent.context_prune.prune_tool_names import (
|
||||
safe_exclude_tools,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import (
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled
|
||||
|
||||
from .middleware import (
|
||||
ClearToolUsesEdit,
|
||||
SpillingContextEditingMiddleware,
|
||||
SpillToBackendEdit,
|
||||
)
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_context_editing_mw(
|
||||
*,
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""Middleware that deduplicates HITL tool calls within a single LLM response.
|
||||
"""Drop duplicate HITL tool calls before execution.
|
||||
|
||||
When the LLM emits multiple calls to the same HITL tool with the same
|
||||
primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``),
|
||||
|
|
@ -9,72 +9,33 @@ the duplicate call is stripped from the AIMessage that gets checkpointed.
|
|||
That means it is also safe across LangGraph ``interrupt()`` boundaries:
|
||||
the removed call will never appear on graph resume.
|
||||
|
||||
Dedup-key resolution order:
|
||||
Dedup-key resolution order (read from each tool's own ``metadata``):
|
||||
|
||||
1. :class:`ToolDefinition.dedup_key` — callable provided by the registry
|
||||
entry. This is the canonical mechanism.
|
||||
2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg name;
|
||||
used by MCP / Composio tools whose schemas the registry doesn't see.
|
||||
1. ``tool.metadata["dedup_key"]`` — callable mapping the args dict to a
|
||||
stable signature string. This is the canonical mechanism.
|
||||
2. ``tool.metadata["hitl_dedup_key"]`` — string naming a primary arg;
|
||||
used by MCP / Composio tools that only expose a single key field.
|
||||
|
||||
A tool with no resolver from either path simply opts out of dedup.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.dedup_tool_calls import (
|
||||
DedupResolver,
|
||||
wrap_dedup_key_by_arg_name,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Resolver type — given the tool ``args`` dict returns a stable
|
||||
# string used to dedupe consecutive calls. A tool without a resolver opts out of dedup.
|
||||
DedupResolver = Callable[[dict[str, Any]], str]
|
||||
|
||||
|
||||
def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver:
    """Build a :data:`DedupResolver` keyed on a single argument.

    The returned resolver reads ``arg_name`` from the tool-call ``args``
    dict, stringifies the value (a missing argument counts as ``""``) and
    lowercases it, so calls that differ only in letter case share one key.
    This is the common case for native HITL tools such as
    ``send_gmail_email`` keyed on ``subject``.

    Example::

        ToolDefinition(
            name="send_gmail_email",
            ...,
            dedup_key=wrap_dedup_key_by_arg_name("subject"),
        )
    """

    def _resolver(args: dict[str, Any]) -> str:
        # Absent argument falls back to the empty string before normalising.
        raw_value = args.get(arg_name, "")
        text = str(raw_value)
        return text.lower()

    return _resolver
|
||||
|
||||
|
||||
def dedup_key_full_args(args: dict[str, Any]) -> str:
    """Resolver that collapses calls only when **every** argument is identical.

    Safe default for tools where no single field uniquely identifies a call
    (e.g. MCP tools whose first required field is a shared workspace id).

    Args:
        args: The tool-call arguments.

    Returns:
        A signature string that does not depend on key insertion order:
        canonical JSON (sorted keys, non-JSON values stringified) when the
        payload can be encoded, otherwise a ``repr`` of the items ordered by
        the ``repr`` of each key.
    """
    try:
        return json.dumps(args, sort_keys=True, default=str)
    except (TypeError, ValueError):
        # json.dumps rejects mixed-type or non-JSON keys (TypeError) and
        # circular references (ValueError).
        if not isinstance(args, dict):
            return repr(args)
        # Order on repr(key) rather than on the items themselves: keys of
        # different types are not mutually comparable, and that is one of the
        # cases that lands here, so a plain sorted(args.items()) would raise
        # again instead of producing a signature.
        ordered_items = sorted(args.items(), key=lambda item: repr(item[0]))
        return repr(ordered_items)
|
||||
|
||||
|
||||
# Backwards-compatible alias for code that imported the original
|
||||
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
|
||||
_wrap_string_key = wrap_dedup_key_by_arg_name
|
||||
|
||||
|
||||
class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Remove duplicate HITL tool calls from a single LLM response.
|
||||
|
|
@ -84,9 +45,8 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
|
||||
The dedup-resolver map is built from two sources, in priority order:
|
||||
|
||||
1. ``tool.metadata["dedup_key"]`` — callable provided by the registry's
|
||||
``ToolDefinition.dedup_key``. Receives the args dict and returns
|
||||
a string signature. This is the canonical mechanism.
|
||||
1. ``tool.metadata["dedup_key"]`` — callable that receives the args dict
|
||||
and returns a string signature. This is the canonical mechanism.
|
||||
2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg
|
||||
name; primarily used by MCP / Composio tools.
|
||||
"""
|
||||
|
|
@ -162,3 +122,7 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
|
||||
updated_msg = last_msg.model_copy(update={"tool_calls": deduped})
|
||||
return {"messages": [updated_msg]}
|
||||
|
||||
|
||||
def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware:
    """Build the HITL dedup middleware for ``tools``.

    The sequence is copied into a fresh list before being passed to the
    middleware as ``agent_tools``.
    """
    agent_tools = [*tools]
    return DedupHITLToolCallsMiddleware(agent_tools=agent_tools)
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
"""Doom-loop middleware: detect repeated identical tool calls (impl + builder)."""
|
||||
|
||||
from .builder import build_doom_loop_mw
|
||||
from .middleware import DoomLoopMiddleware
|
||||
|
||||
__all__ = [
|
||||
"DoomLoopMiddleware",
|
||||
"build_doom_loop_mw",
|
||||
]
|
||||
|
|
@ -2,10 +2,10 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import DoomLoopMiddleware
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled
|
||||
|
||||
from ..shared.flags import enabled
|
||||
from .middleware import DoomLoopMiddleware
|
||||
|
||||
|
||||
def build_doom_loop_mw(flags: AgentFeatureFlags) -> DoomLoopMiddleware | None:
|
||||
|
|
@ -16,7 +16,7 @@ This ships **OFF by default** until the frontend explicitly handles
|
|||
``context.permission == "doom_loop"`` interrupts.
|
||||
|
||||
Wire format: uses SurfSense's existing ``interrupt()`` payload shape
|
||||
(see ``app/agents/new_chat/tools/hitl.py``):
|
||||
(see ``app/agents/shared/tools/hitl.py``):
|
||||
|
||||
{
|
||||
"type": "permission_ask",
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
"""End-of-turn KB persistence middleware (main-agent only)."""
|
||||
|
||||
from .builder import build_kb_persistence_mw
|
||||
from .middleware import (
|
||||
KnowledgeBasePersistenceMiddleware,
|
||||
commit_staged_filesystem_state,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"KnowledgeBasePersistenceMiddleware",
|
||||
"build_kb_persistence_mw",
|
||||
"commit_staged_filesystem_state",
|
||||
]
|
||||
|
|
@ -2,8 +2,11 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import KnowledgeBasePersistenceMiddleware
|
||||
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
from .middleware import (
|
||||
KnowledgeBasePersistenceMiddleware,
|
||||
)
|
||||
|
||||
|
||||
def build_kb_persistence_mw(
|
||||
|
|
@ -1,33 +1,19 @@
|
|||
"""End-of-turn persistence for the cloud-mode SurfSense filesystem.
|
||||
|
||||
This middleware runs ``aafter_agent`` once per turn (cloud only). It commits
|
||||
all staged folder creations, file moves, content writes/edits, file deletes
|
||||
(``rm``), and directory deletes (``rmdir``) to Postgres in a single ordered
|
||||
pass:
|
||||
Runs ``aafter_agent`` once per turn (cloud only), committing staged folder
|
||||
creates, moves, writes/edits, and ``rm``/``rmdir`` to Postgres in one ordered
|
||||
pass. Order matters: moves resolve before writes (so write-then-move lands at
|
||||
the final path), and file deletes run before directory deletes (so a same-turn
|
||||
``rm /a/x.md`` + ``rmdir /a`` works).
|
||||
|
||||
1. Materialize ``staged_dirs`` into ``Folder`` rows.
|
||||
2. Apply ``pending_moves`` in order (chained moves resolved via
|
||||
``doc_id_by_path``).
|
||||
3. Normalize ``dirty_paths`` through ``pending_moves`` so write-then-move
|
||||
sequences commit at the final path. Paths queued for ``rm`` this turn
|
||||
are dropped here so a write+rm sequence doesn't recreate the doc.
|
||||
4. Commit content writes / edits for ``/documents/*`` paths, skipping
|
||||
``temp_*`` basenames.
|
||||
5. Apply ``pending_deletes`` (``rm``) — file deletes run BEFORE directory
|
||||
deletes so a same-turn ``rm /a/x.md`` + ``rmdir /a`` sequence works.
|
||||
6. Apply ``pending_dir_deletes`` (``rmdir``); re-verifies emptiness against
|
||||
the post-step-5 DB state.
|
||||
When ``flags.enable_action_log`` is on, each destructive op also snapshots a
|
||||
``DocumentRevision`` / ``FolderRevision`` for revert. For ``rm``/``rmdir`` the
|
||||
snapshot and DELETE share a SAVEPOINT, so a failed snapshot aborts the delete
|
||||
rather than making the data silently irreversible.
|
||||
|
||||
When ``flags.enable_action_log`` is on every destructive op also writes a
|
||||
``DocumentRevision`` / ``FolderRevision`` snapshot bound to the
|
||||
originating ``AgentActionLog`` row via ``tool_call_id``. ``rm``/``rmdir``
|
||||
share a single ``SAVEPOINT`` with their snapshot — if the snapshot fails
|
||||
the DELETE rolls back and we surface the error rather than silently
|
||||
making the data irreversible.
|
||||
|
||||
The commit body is exposed as a free function ``commit_staged_filesystem_state``
|
||||
so the optional stream-task fallback (``stream_new_chat.py``) can call the
|
||||
exact same routine when ``aafter_agent`` was skipped (e.g. client disconnect).
|
||||
The commit body is a free function (``commit_staged_filesystem_state``) so the
|
||||
stream-task fallback can run the identical routine when ``aafter_agent`` was
|
||||
skipped (e.g. client disconnect).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -45,17 +31,22 @@ from sqlalchemy import delete, select, update
|
|||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.feature_flags import get_flags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
||||
from app.agents.new_chat.path_resolver import (
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags
|
||||
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.chat.multi_agent_chat.shared.receipts.receipt import (
|
||||
Receipt,
|
||||
make_receipt,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.state.reducers import _CLEAR
|
||||
from app.agents.chat.runtime.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
parse_documents_path,
|
||||
safe_folder_segment,
|
||||
virtual_path_to_doc,
|
||||
)
|
||||
from app.agents.new_chat.state_reducers import _CLEAR
|
||||
from app.agents.shared.receipt import Receipt, make_receipt
|
||||
from app.db import (
|
||||
AgentActionLog,
|
||||
Chunk,
|
||||
|
|
@ -211,11 +202,9 @@ async def _create_document(
|
|||
virtual_path,
|
||||
search_space_id,
|
||||
)
|
||||
# Filesystem-parity invariant: the only thing that *must* be unique is
|
||||
# the path. Two notes can legitimately share content (e.g. ``cp a b``).
|
||||
# Guard against the path-derived ``unique_identifier_hash`` constraint
|
||||
# so we surface a clean ValueError instead of letting the INSERT poison
|
||||
# the session with an IntegrityError.
|
||||
# Pre-check the path-derived unique_identifier_hash so a duplicate path
|
||||
# surfaces as a clean ValueError instead of an INSERT IntegrityError that
|
||||
# poisons the session. Content is intentionally not unique (cp a b).
|
||||
path_collision = await session.execute(
|
||||
select(Document.id).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
|
|
@ -227,13 +216,6 @@ async def _create_document(
|
|||
f"a document already exists at path '{virtual_path}' "
|
||||
"(unique_identifier_hash collision)"
|
||||
)
|
||||
# ``content_hash`` is intentionally NOT checked for uniqueness here.
|
||||
# In a real filesystem two files at different paths can hold identical
|
||||
# bytes, and the agent's ``write_file`` path needs that semantic to
|
||||
# support copy/duplicate operations. The hash remains useful as a
|
||||
# change-detection hint for connector indexers, which still consult it
|
||||
# via :func:`check_duplicate_document` but do so with a non-unique
|
||||
# lookup (``.first()``).
|
||||
content_hash = generate_content_hash(content, search_space_id)
|
||||
doc = Document(
|
||||
title=title,
|
||||
|
|
@ -430,15 +412,9 @@ async def _mark_action_reversible(
|
|||
) -> None:
|
||||
"""Flip ``agent_action_log.reversible = TRUE`` for ``action_id``.
|
||||
|
||||
Best-effort: caller may invoke from inside a SAVEPOINT and treat
|
||||
failure as a soft demotion (snapshot persists, just no Revert button).
|
||||
|
||||
Callers should also call ``_dispatch_reversibility_update`` (defined
|
||||
below) AFTER the enclosing SAVEPOINT block exits successfully so the
|
||||
chat tool card can light up its Revert button without
|
||||
re-fetching ``GET /threads/.../actions``. Dispatching from inside the
|
||||
SAVEPOINT would risk emitting "reversible=true" for rows whose
|
||||
update gets rolled back if the surrounding destructive op fails.
|
||||
Pair with ``_dispatch_reversibility_update`` *after* the enclosing
|
||||
SAVEPOINT commits, so the UI never sees ``reversible=true`` for a row whose
|
||||
update later rolls back.
|
||||
"""
|
||||
if action_id is None:
|
||||
return
|
||||
|
|
@ -450,22 +426,11 @@ async def _mark_action_reversible(
|
|||
|
||||
|
||||
async def _dispatch_reversibility_update(action_id: int | None) -> None:
|
||||
"""Best-effort dispatch of an ``action_log_updated`` custom event.
|
||||
"""Emit an ``action_log_updated`` SSE event so the Revert button lights up.
|
||||
|
||||
Surfaces the post-SAVEPOINT reversibility flip to the SSE layer so
|
||||
the chat tool card can flip its Revert button live. Defensive:
|
||||
failures are logged at debug level and swallowed; the
|
||||
REST endpoint ``GET /threads/.../actions`` is still authoritative.
|
||||
|
||||
.. warning::
|
||||
Inside :func:`commit_staged_filesystem_state` we DEFER all
|
||||
dispatches until the outer ``session.commit()`` succeeds — see
|
||||
the ``deferred_dispatches`` queue in that function. Dispatching
|
||||
from inside a SAVEPOINT block while the outer transaction is
|
||||
still pending would emit ``reversible=true`` for rows whose
|
||||
snapshots get rolled back if the outer commit fails. Direct
|
||||
callers (e.g. the optional stream-task fallback) that own the
|
||||
full session lifetime can still call this helper inline.
|
||||
Best-effort (failures swallowed; the REST actions endpoint is
|
||||
authoritative). Inside :func:`commit_staged_filesystem_state` this is
|
||||
deferred until after the outer commit via ``deferred_dispatches``.
|
||||
"""
|
||||
if action_id is None:
|
||||
return
|
||||
|
|
@ -484,12 +449,9 @@ async def _dispatch_reversibility_update(action_id: int | None) -> None:
|
|||
# ---------------------------------------------------------------------------
|
||||
# Snapshot helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Best-effort helpers swallow + log so a snapshot failure can never break
|
||||
# the underlying op for non-destructive tools (write/edit/move/mkdir).
|
||||
# Strict helpers run inside the SAME ``begin_nested()`` SAVEPOINT as the
|
||||
# destructive DELETE — failure aborts the savepoint and leaves the doc /
|
||||
# folder intact, so revertable ops never become irreversible silently.
|
||||
# Best-effort variants (write/edit/move/mkdir) swallow failures. Strict
|
||||
# variants (rm/rmdir) share the destructive op's SAVEPOINT so a snapshot
|
||||
# failure aborts the delete instead of making it silently irreversible.
|
||||
|
||||
|
||||
def _doc_revision_payload(
|
||||
|
|
@ -699,15 +661,9 @@ async def commit_staged_filesystem_state(
|
|||
) -> dict[str, Any] | None:
|
||||
"""Commit all staged filesystem changes; return the state delta for reducers.
|
||||
|
||||
Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent`
|
||||
and the optional stream-task fallback.
|
||||
|
||||
When ``flags.enable_action_log`` is on every destructive op also writes
|
||||
a ``DocumentRevision`` / ``FolderRevision`` snapshot bound to the
|
||||
originating ``AgentActionLog`` row via ``tool_call_id``. Snapshot
|
||||
durability is best-effort for non-destructive ops and STRICT for
|
||||
``rm``/``rmdir`` (snapshot + DELETE share a SAVEPOINT — snapshot
|
||||
failure aborts the delete).
|
||||
Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent` and
|
||||
the stream-task fallback. See the module docstring for ordering and the
|
||||
action-log snapshot/revert semantics.
|
||||
"""
|
||||
if filesystem_mode != FilesystemMode.CLOUD:
|
||||
return None
|
||||
|
|
@ -766,8 +722,7 @@ async def commit_staged_filesystem_state(
|
|||
flags = get_flags()
|
||||
snapshot_enabled = flags.enable_action_log
|
||||
|
||||
# De-duplicate pending deletes per-path while preserving the latest
|
||||
# tool_call_id (the one the user is most likely to revert via the UI).
|
||||
# De-dup deletes per-path, keeping the latest tool_call_id (likeliest revert).
|
||||
file_delete_paths: dict[str, str] = {}
|
||||
for entry in pending_deletes:
|
||||
if not isinstance(entry, dict):
|
||||
|
|
@ -791,22 +746,14 @@ async def commit_staged_filesystem_state(
|
|||
applied_moves: list[dict[str, Any]] = []
|
||||
doc_id_path_tombstones: dict[str, int | None] = {}
|
||||
tree_changed = False
|
||||
# Reversibility-flip dispatches are deferred until AFTER the outer
|
||||
# ``session.commit()`` succeeds. Dispatching from inside the
|
||||
# SAVEPOINT chain while the outer transaction is still pending
|
||||
# would emit ``reversible=true`` for rows whose snapshots get rolled
|
||||
# back if the final commit raises. Snapshot helpers append on
|
||||
# success; we drain this list after commit and silently abandon it
|
||||
# on rollback so the UI stays consistent with durable state.
|
||||
# Reversibility-flip dispatches are drained only after the outer commit
|
||||
# succeeds (and abandoned on rollback), so the UI never sees reversible=true
|
||||
# for a snapshot that didn't durably land.
|
||||
deferred_dispatches: list[int] = []
|
||||
|
||||
try:
|
||||
async with shielded_async_session() as session:
|
||||
# ------------------------------------------------------------------
|
||||
# Resolve action-id bindings up front. One SELECT per turn for all
|
||||
# tool_call_ids, NOT one per op — important because a turn that
|
||||
# touches 50 paths would otherwise issue 50 lookups.
|
||||
# ------------------------------------------------------------------
|
||||
# Resolve all action-id bindings in one SELECT per turn, not per op.
|
||||
action_id_by_call: dict[str, int] = {}
|
||||
if snapshot_enabled and thread_id is not None:
|
||||
tool_call_ids: set[str] = set()
|
||||
|
|
@ -839,10 +786,7 @@ async def commit_staged_filesystem_state(
|
|||
next(iter(action_id_by_call), None) if action_id_by_call else None
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. staged_dirs -> Folder rows. Snapshot post-flush so the new
|
||||
# folder_id is available for the FK.
|
||||
# ------------------------------------------------------------------
|
||||
# 1. staged_dirs -> Folder rows (snapshot post-flush for the FK).
|
||||
for folder_path in staged_dirs:
|
||||
if not isinstance(folder_path, str):
|
||||
continue
|
||||
|
|
@ -863,7 +807,6 @@ async def commit_staged_filesystem_state(
|
|||
tcid = staged_dir_tool_calls.get(folder_path)
|
||||
action_id = _action_id_for(tcid)
|
||||
if action_id is not None:
|
||||
# Re-read the folder for the snapshot.
|
||||
result = await session.execute(
|
||||
select(Folder).where(Folder.id == folder_id)
|
||||
)
|
||||
|
|
@ -878,16 +821,13 @@ async def commit_staged_filesystem_state(
|
|||
deferred_dispatches=deferred_dispatches,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 2. pending_moves. Snapshot pre-move (in-place restore on revert).
|
||||
# ------------------------------------------------------------------
|
||||
# 2. pending_moves (snapshot pre-move for in-place restore on revert).
|
||||
for move in pending_moves:
|
||||
source = str(move.get("source") or "")
|
||||
if snapshot_enabled and source:
|
||||
tcid = str(move.get("tool_call_id") or "")
|
||||
action_id = _action_id_for(tcid)
|
||||
if action_id is not None:
|
||||
# Resolve the doc to snapshot BEFORE we mutate it.
|
||||
doc_id_pre = doc_id_by_path.get(source)
|
||||
document_pre: Document | None = None
|
||||
if doc_id_pre is not None:
|
||||
|
|
@ -937,10 +877,8 @@ async def commit_staged_filesystem_state(
|
|||
path = move_alias[path]
|
||||
return path
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. dirty_paths -> writes/edits. Skip any path queued for ``rm``
|
||||
# this turn so a write+rm sequence doesn't recreate the doc.
|
||||
# ------------------------------------------------------------------
|
||||
# 3. dirty_paths -> writes/edits. Paths queued for rm this turn are
|
||||
# skipped so a write+rm sequence doesn't recreate the doc.
|
||||
kb_dirty_seen: set[str] = set()
|
||||
kb_dirty: list[str] = []
|
||||
kb_dirty_origin: dict[str, str] = {}
|
||||
|
|
@ -969,9 +907,7 @@ async def commit_staged_filesystem_state(
|
|||
continue
|
||||
content = "\n".join(file_data.get("content") or [])
|
||||
doc_id = doc_id_by_path.get(path)
|
||||
# Path ↔ tool_call_id binding: the dirty_paths list dedupes via
|
||||
# _add_unique_reducer, so we look up the latest tool_call_id by
|
||||
# path (or by the un-renamed origin).
|
||||
# Look up tool_call_id by final path or its pre-rename origin.
|
||||
origin = kb_dirty_origin.get(path, path)
|
||||
tcid = dirty_path_tool_calls.get(path) or dirty_path_tool_calls.get(
|
||||
origin
|
||||
|
|
@ -979,12 +915,9 @@ async def commit_staged_filesystem_state(
|
|||
action_id = _action_id_for(tcid)
|
||||
|
||||
if doc_id is None:
|
||||
# The in-memory ``doc_id_by_path`` is per-thread and starts
|
||||
# empty in every new chat. If the agent writes to a path
|
||||
# that already exists in the DB (e.g. a previous chat's
|
||||
# ``notes.md``), we must NOT try to INSERT — it would hit
|
||||
# ``unique_identifier_hash`` (path-derived). Look up the
|
||||
# existing doc and update it in place instead.
|
||||
# doc_id_by_path is per-thread and empty in a new chat, so a
|
||||
# write to a path already in the DB must update in place, not
|
||||
# INSERT (which would hit the path-derived unique hash).
|
||||
existing = await virtual_path_to_doc(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
|
|
@ -1033,12 +966,9 @@ async def commit_staged_filesystem_state(
|
|||
}
|
||||
)
|
||||
else:
|
||||
# Fresh create. Wrap each create in a SAVEPOINT so a
|
||||
# residual ``IntegrityError`` (e.g. a deployment that
|
||||
# hasn't run migration 133 yet, where
|
||||
# ``documents.content_hash`` still carries its legacy
|
||||
# global UNIQUE constraint) rolls back only this one
|
||||
# create instead of poisoning the whole turn.
|
||||
# Fresh create, wrapped in a SAVEPOINT so a residual
|
||||
# IntegrityError (e.g. pre-migration-133 content_hash UNIQUE)
|
||||
# rolls back only this create, not the whole turn.
|
||||
placeholder_revision_id: int | None = None
|
||||
if snapshot_enabled and action_id is not None:
|
||||
placeholder_revision_id = await _snapshot_document_pre_create(
|
||||
|
|
@ -1061,8 +991,7 @@ async def commit_staged_filesystem_state(
|
|||
logger.warning(
|
||||
"kb_persistence: skipping %s create: %s", path, exc
|
||||
)
|
||||
# Roll back the placeholder revision since the create
|
||||
# never happened.
|
||||
# Create never happened; drop its placeholder revision.
|
||||
if placeholder_revision_id is not None:
|
||||
await session.execute(
|
||||
delete(DocumentRevision).where(
|
||||
|
|
@ -1109,19 +1038,14 @@ async def commit_staged_filesystem_state(
|
|||
)
|
||||
tree_changed = True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 4. pending_deletes -> ``rm``. STRICT durability: snapshot + DELETE
|
||||
# share a SAVEPOINT. If the snapshot insert fails, the DELETE
|
||||
# rolls back too and we surface the error rather than silently
|
||||
# making the data irreversible.
|
||||
# ------------------------------------------------------------------
|
||||
# 4. pending_deletes -> rm. Strict: snapshot + DELETE share a
|
||||
# SAVEPOINT, so a failed snapshot rolls the delete back too.
|
||||
for raw_path, tcid in file_delete_paths.items():
|
||||
final = _final_path(raw_path)
|
||||
if not final.startswith(DOCUMENTS_ROOT + "/"):
|
||||
continue
|
||||
action_id = _action_id_for(tcid)
|
||||
|
||||
# Resolve the doc.
|
||||
doc_id_for_delete = doc_id_by_path.get(final)
|
||||
document_to_delete: Document | None = None
|
||||
if doc_id_for_delete is not None:
|
||||
|
|
@ -1150,7 +1074,6 @@ async def commit_staged_filesystem_state(
|
|||
|
||||
try:
|
||||
async with session.begin_nested():
|
||||
# Strict: snapshot first; failure aborts the delete.
|
||||
if snapshot_enabled and action_id is not None:
|
||||
chunks = await _load_chunks_for_snapshot(
|
||||
session, doc_id=doc_pk
|
||||
|
|
@ -1179,10 +1102,7 @@ async def commit_staged_filesystem_state(
|
|||
)
|
||||
continue
|
||||
|
||||
# B1 — SAVEPOINT released. Defer the reversibility-flip
|
||||
# dispatch until AFTER the outer commit succeeds so we
|
||||
# never tell the UI a row is reversible if its snapshot
|
||||
# gets rolled back.
|
||||
# Defer the reversibility flip until after the outer commit.
|
||||
if snapshot_enabled and action_id is not None:
|
||||
deferred_dispatches.append(int(action_id))
|
||||
|
||||
|
|
@ -1201,11 +1121,8 @@ async def commit_staged_filesystem_state(
|
|||
)
|
||||
tree_changed = True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 5. pending_dir_deletes -> ``rmdir``. STRICT durability + final
|
||||
# emptiness check (after step 4's deletes have run, an "empty
|
||||
# mid-turn" directory really IS empty in DB now).
|
||||
# ------------------------------------------------------------------
|
||||
# 5. pending_dir_deletes -> rmdir. Strict, and re-checks emptiness
|
||||
# against post-step-4 DB state.
|
||||
for raw_path, tcid in dir_delete_paths.items():
|
||||
final = _final_path(raw_path)
|
||||
if not final.startswith(DOCUMENTS_ROOT + "/"):
|
||||
|
|
@ -1226,7 +1143,6 @@ async def commit_staged_filesystem_state(
|
|||
)
|
||||
continue
|
||||
|
||||
# Re-check emptiness against in-DB state.
|
||||
docs_in_folder = await session.execute(
|
||||
select(Document.id)
|
||||
.where(Document.folder_id == folder_id)
|
||||
|
|
@ -1291,10 +1207,7 @@ async def commit_staged_filesystem_state(
|
|||
)
|
||||
continue
|
||||
|
||||
# B1 — SAVEPOINT released. Defer the reversibility-flip
|
||||
# dispatch until AFTER the outer commit succeeds so we
|
||||
# never tell the UI a row is reversible if its snapshot
|
||||
# gets rolled back.
|
||||
# Defer the reversibility flip until after the outer commit.
|
||||
if snapshot_enabled and action_id is not None:
|
||||
deferred_dispatches.append(int(action_id))
|
||||
|
||||
|
|
@ -1314,18 +1227,13 @@ async def commit_staged_filesystem_state(
|
|||
logger.exception(
|
||||
"kb_persistence: commit failed (search_space=%s)", search_space_id
|
||||
)
|
||||
# Outer commit raised — every SAVEPOINT-released change above
|
||||
# (snapshots + reversibility flips) is now rolled back. Drop
|
||||
# the deferred SSE dispatches so the UI stays consistent with
|
||||
# durable state.
|
||||
# Outer commit raised: everything above rolled back, so drop the
|
||||
# deferred dispatches.
|
||||
deferred_dispatches.clear()
|
||||
return None
|
||||
|
||||
# Outer commit succeeded; flush deferred reversibility-flip
|
||||
# dispatches now so the chat tool card can light up its Revert
|
||||
# button without re-fetching ``GET /threads/.../actions``. De-dup
|
||||
# to avoid emitting the same id twice (e.g. write-then-rm in the
|
||||
# same turn dispatches once for each snapshot site).
|
||||
# Commit succeeded; flush deferred reversibility flips (de-duped, since
|
||||
# write-then-rm in one turn appends an id per snapshot site).
|
||||
if deferred_dispatches and dispatch_events:
|
||||
for action_id in dict.fromkeys(deferred_dispatches):
|
||||
try:
|
||||
|
|
@ -1371,9 +1279,8 @@ async def commit_staged_filesystem_state(
|
|||
p for p in files if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX)
|
||||
]
|
||||
|
||||
# Tombstone every committed-delete path so a stale ``state["files"]`` entry
|
||||
# (which als_info would otherwise interpret as content) cannot survive into
|
||||
# the next turn and make a now-empty folder look non-empty.
|
||||
# Tombstone committed-delete paths so a stale state["files"] entry can't
|
||||
# survive into the next turn and make a now-empty folder look non-empty.
|
||||
deleted_file_paths = [
|
||||
str(payload.get("virtualPath") or "")
|
||||
for payload in committed_deletes
|
||||
|
|
@ -1394,11 +1301,8 @@ async def commit_staged_filesystem_state(
|
|||
"dirty_path_tool_calls": {_CLEAR: True},
|
||||
}
|
||||
|
||||
# Emit one Receipt per committed mutation, folded into ``state['receipts']``
|
||||
# via ``_list_append_reducer``. The receipts surface what actually committed
|
||||
# (post-savepoint) rather than what the LLM intended; the orchestrator uses
|
||||
# them as ground truth in the ``<verification>`` teaching. KB writes do not
|
||||
# have public verifiable URLs, so ``verifiable_url`` stays unset.
|
||||
# One Receipt per committed mutation: ground truth (post-savepoint) for the
|
||||
# orchestrator's <verification> teaching. KB writes have no public URL.
|
||||
receipts: list[Receipt] = []
|
||||
|
||||
def _kb_receipt(
|
||||
|
|
@ -1439,8 +1343,6 @@ async def commit_staged_filesystem_state(
|
|||
external_id=payload.get("id"),
|
||||
)
|
||||
for payload in applied_moves:
|
||||
# ``applied_moves`` rows carry the destination ``virtualPath`` because
|
||||
# the move has already landed in the DB by the time we reach this code.
|
||||
path = str(payload.get("virtualPath") or "")
|
||||
_kb_receipt(
|
||||
type="file",
|
||||
|
|
@ -1480,9 +1382,7 @@ async def commit_staged_filesystem_state(
|
|||
if tree_changed:
|
||||
delta["tree_version"] = int(state_dict.get("tree_version") or 0) + 1
|
||||
|
||||
# Avoid 'unused' lint when turn_id_for_revision was only useful for
|
||||
# diagnostic purposes inside the SAVEPOINT chain above.
|
||||
_ = turn_id_for_revision
|
||||
_ = turn_id_for_revision # diagnostic-only; silence unused lint
|
||||
|
||||
logger.info(
|
||||
"kb_persistence: commit (search_space=%s) creates=%d updates=%d "
|
||||
|
|
@ -4,8 +4,10 @@ from __future__ import annotations
|
|||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import KnowledgePriorityMiddleware
|
||||
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.knowledge_search import (
|
||||
KnowledgePriorityMiddleware,
|
||||
)
|
||||
from app.services.llm_service import get_planner_llm
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
"""Knowledge-tree middleware: <workspace_tree> injection, cloud only (impl + builder)."""
|
||||
|
||||
from .builder import build_knowledge_tree_mw
|
||||
from .middleware import KnowledgeTreeMiddleware
|
||||
|
||||
__all__ = [
|
||||
"KnowledgeTreeMiddleware",
|
||||
"build_knowledge_tree_mw",
|
||||
]
|
||||
|
|
@ -4,8 +4,9 @@ from __future__ import annotations
|
|||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import KnowledgeTreeMiddleware
|
||||
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
from .middleware import KnowledgeTreeMiddleware
|
||||
|
||||
|
||||
def build_knowledge_tree_mw(
|
||||
|
|
@ -33,9 +33,11 @@ from langchain_core.messages import SystemMessage
|
|||
from langgraph.runtime import Runtime
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
||||
from app.agents.new_chat.path_resolver import (
|
||||
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.agents.chat.runtime.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
PathIndex,
|
||||
build_path_index,
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
"""User/team memory injection middleware (main-agent only)."""
|
||||
|
||||
from .builder import build_memory_mw
|
||||
|
||||
__all__ = ["build_memory_mw"]
|
||||
|
|
@ -2,9 +2,10 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.middleware import MemoryInjectionMiddleware
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from .middleware import MemoryInjectionMiddleware
|
||||
|
||||
|
||||
def build_memory_mw(
|
||||
*,
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
"""Noop-injection middleware: provider-compat _noop tool (impl + builder)."""
|
||||
|
||||
from .builder import build_noop_injection_mw
|
||||
from .middleware import NoopInjectionMiddleware
|
||||
|
||||
__all__ = [
|
||||
"NoopInjectionMiddleware",
|
||||
"build_noop_injection_mw",
|
||||
]
|
||||
|
|
@ -2,10 +2,10 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import NoopInjectionMiddleware
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled
|
||||
|
||||
from ..shared.flags import enabled
|
||||
from .middleware import NoopInjectionMiddleware
|
||||
|
||||
|
||||
def build_noop_injection_mw(flags: AgentFeatureFlags) -> NoopInjectionMiddleware | None:
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
"""OTel-span middleware: spans on model and tool calls (impl + builder)."""
|
||||
|
||||
from .builder import build_otel_mw
|
||||
from .middleware import OtelSpanMiddleware
|
||||
|
||||
__all__ = [
|
||||
"OtelSpanMiddleware",
|
||||
"build_otel_mw",
|
||||
]
|
||||
|
|
@ -2,10 +2,10 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import OtelSpanMiddleware
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled
|
||||
|
||||
from ..shared.flags import enabled
|
||||
from .middleware import OtelSpanMiddleware
|
||||
|
||||
|
||||
def build_otel_mw(flags: AgentFeatureFlags) -> OtelSpanMiddleware | None:
|
||||
|
|
@ -7,15 +7,15 @@ from typing import Any
|
|||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.plugin_loader import (
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from ..plugins.loader import (
|
||||
PluginContext,
|
||||
load_allowed_plugin_names_from_env,
|
||||
load_plugin_middlewares,
|
||||
)
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_plugin_middlewares(
|
||||
|
|
@ -6,14 +6,11 @@ import logging
|
|||
|
||||
from deepagents.middleware.skills import SkillsMiddleware
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import (
|
||||
build_skills_backend_factory,
|
||||
default_skills_sources,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled
|
||||
|
||||
from ..shared.flags import enabled
|
||||
from ..skills.backends import build_skills_backend_factory, default_skills_sources
|
||||
|
||||
|
||||
def build_skills_mw(
|
||||
|
|
@ -20,50 +20,66 @@ from langchain_core.language_models import BaseChatModel
|
|||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.subagents import (
|
||||
from app.agents.chat.multi_agent_chat.main_agent.middleware.memory import (
|
||||
build_memory_mw,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.anthropic_cache import (
|
||||
build_anthropic_cache_mw,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.compaction import (
|
||||
build_compaction_mw,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.kb_context_projection import (
|
||||
build_kb_context_projection_mw,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.patch_tool_calls import (
|
||||
build_patch_tool_calls_mw,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.resilience import (
|
||||
build_resilience_middlewares,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.todos import build_todos_mw
|
||||
from app.agents.chat.multi_agent_chat.shared.permissions import (
|
||||
build_permission_mw,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.subagents import (
|
||||
build_subagents,
|
||||
get_subagents_to_exclude,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.builtins.knowledge_base.agent import (
|
||||
from app.agents.chat.multi_agent_chat.subagents.builtins.knowledge_base.agent import (
|
||||
READONLY_NAME as KB_READONLY_NAME,
|
||||
build_readonly_subagent as build_kb_readonly_subagent,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.builtins.knowledge_base.ask_knowledge_base_tool import (
|
||||
from app.agents.chat.multi_agent_chat.subagents.builtins.knowledge_base.ask_knowledge_base_tool import (
|
||||
build_ask_knowledge_base_tool,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.chat.multi_agent_chat.subagents.middleware_stack import (
|
||||
build_subagent_middleware_stack,
|
||||
)
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from .main_agent.action_log import build_action_log_mw
|
||||
from .main_agent.anonymous_doc import build_anonymous_doc_mw
|
||||
from .main_agent.busy_mutex import build_busy_mutex_mw
|
||||
from .main_agent.checkpointed_subagent_middleware import (
|
||||
from .action_log import build_action_log_mw
|
||||
from .anonymous_document import build_anonymous_doc_mw
|
||||
from .busy_mutex import build_busy_mutex_mw
|
||||
from .checkpointed_subagent_middleware import (
|
||||
SurfSenseCheckpointedSubAgentMiddleware,
|
||||
)
|
||||
from .main_agent.checkpointed_subagent_middleware.task_description import (
|
||||
from .checkpointed_subagent_middleware.task_description import (
|
||||
TASK_TOOL_DESCRIPTION,
|
||||
)
|
||||
from .main_agent.context_editing import build_context_editing_mw
|
||||
from .main_agent.dedup_hitl import build_dedup_hitl_mw
|
||||
from .main_agent.doom_loop import build_doom_loop_mw
|
||||
from .main_agent.kb_persistence import build_kb_persistence_mw
|
||||
from .main_agent.knowledge_priority import build_knowledge_priority_mw
|
||||
from .main_agent.knowledge_tree import build_knowledge_tree_mw
|
||||
from .main_agent.noop_injection import build_noop_injection_mw
|
||||
from .main_agent.otel import build_otel_mw
|
||||
from .main_agent.plugins import build_plugin_middlewares
|
||||
from .main_agent.repair import build_repair_mw
|
||||
from .main_agent.skills import build_skills_mw
|
||||
from .shared.anthropic_cache import build_anthropic_cache_mw
|
||||
from .shared.compaction import build_compaction_mw
|
||||
from .shared.kb_context_projection import build_kb_context_projection_mw
|
||||
from .shared.memory import build_memory_mw
|
||||
from .shared.patch_tool_calls import build_patch_tool_calls_mw
|
||||
from .shared.permissions import build_permission_mw
|
||||
from .shared.resilience import build_resilience_middlewares
|
||||
from .shared.todos import build_todos_mw
|
||||
from .subagent.middleware_stack import build_subagent_middleware_stack
|
||||
from .context_editing import build_context_editing_mw
|
||||
from .dedup_hitl import build_dedup_hitl_mw
|
||||
from .doom_loop import build_doom_loop_mw
|
||||
from .kb_persistence import build_kb_persistence_mw
|
||||
from .knowledge_priority import build_knowledge_priority_mw
|
||||
from .knowledge_tree import build_knowledge_tree_mw
|
||||
from .noop_injection import build_noop_injection_mw
|
||||
from .otel_span import build_otel_mw
|
||||
from .plugins import build_plugin_middlewares
|
||||
from .skills import build_skills_mw
|
||||
from .tool_call_repair import build_repair_mw
|
||||
|
||||
|
||||
def build_main_agent_deepagent_middleware(
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
"""Tool-call-repair middleware: fix miscased/unknown tool names (impl + builder)."""
|
||||
|
||||
from .builder import build_repair_mw
|
||||
from .middleware import ToolCallNameRepairMiddleware
|
||||
|
||||
__all__ = [
|
||||
"ToolCallNameRepairMiddleware",
|
||||
"build_repair_mw",
|
||||
]
|
||||
|
|
@ -6,10 +6,10 @@ from collections.abc import Sequence
|
|||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import ToolCallNameRepairMiddleware
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled
|
||||
|
||||
from ..shared.flags import enabled
|
||||
from .middleware import ToolCallNameRepairMiddleware
|
||||
|
||||
# deepagents-built-in tool names the repair pass treats as known.
|
||||
_DEEPAGENT_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
|
||||
|
|
@ -34,8 +34,6 @@ from langchain.agents.middleware.types import (
|
|||
from langchain_core.messages import AIMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -120,6 +118,12 @@ class ToolCallNameRepairMiddleware(
|
|||
return call
|
||||
|
||||
# Stage 2 — invalid fallback
|
||||
# Local import keeps the middleware module import-light and avoids any
|
||||
# tools <-> middleware import-order coupling at module scope.
|
||||
from app.agents.chat.multi_agent_chat.main_agent.tools.invalid_tool import (
|
||||
INVALID_TOOL_NAME,
|
||||
)
|
||||
|
||||
if INVALID_TOOL_NAME in registered:
|
||||
original_args = call.get("args") or {}
|
||||
error_msg = (
|
||||
|
|
@ -17,7 +17,7 @@ Wire-up in ``pyproject.toml`` (illustrative; the in-repo plugin doesn't
|
|||
need this -- it's already on the import path)::
|
||||
|
||||
[project.entry-points."surfsense.plugins"]
|
||||
year_substituter = "app.agents.new_chat.plugins.year_substituter:make_middleware"
|
||||
year_substituter = "app.agents.chat.multi_agent_chat.main_agent.plugins.year_substituter:make_middleware"
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -34,7 +34,7 @@ if TYPE_CHECKING: # pragma: no cover - type-only
|
|||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
from app.agents.new_chat.plugin_loader import PluginContext
|
||||
from .loader import PluginContext
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -10,18 +10,18 @@ from langchain_core.language_models import BaseChatModel
|
|||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.new_chat.agent_cache import (
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
|
||||
from .agent_cache_store import (
|
||||
flags_signature,
|
||||
get_cache,
|
||||
stable_hash,
|
||||
system_prompt_hash,
|
||||
tools_signature,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
|
||||
|
||||
|
||||
def mcp_signature(mcp_tools_by_agent: dict[str, list[BaseTool]]) -> str:
|
||||
|
|
@ -113,12 +113,11 @@ def tools_signature(
|
|||
MCP tools loaded for the user changes, gating rules flip, etc.).
|
||||
* The available connectors / document types for the search space
|
||||
change (new connector added, last connector removed, new document
|
||||
type indexed). Because :func:`get_connector_gated_tools` derives
|
||||
``modified_disabled_tools`` from ``available_connectors``, the
|
||||
tool surface is technically already covered — but we hash the
|
||||
connector list separately so an empty-list "no tools changed"
|
||||
situation still rotates the key when, say, the user re-adds a
|
||||
connector that gates a tool we were already not exposing.
|
||||
type indexed). Connector gating derives disabled tools from
|
||||
``available_connectors``, so the tool surface is technically already
|
||||
covered — but we hash the connector list separately so an empty-list
|
||||
"no tools changed" situation still rotates the key when, say, the user
|
||||
re-adds a connector that gates a tool we were already not exposing.
|
||||
|
||||
Stays stable across:
|
||||
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
"""Map configured connectors to the searchable document/connector types.
|
||||
|
||||
This is agent-agnostic infrastructure shared by every agent factory (single-
|
||||
and multi-agent). It translates the connectors a search space has enabled into
|
||||
the set of searchable type strings that pre-search middleware and ``web_search``
|
||||
understand, and always layers in the document types that exist independently of
|
||||
any connector (uploads, notes, extension captures, YouTube).
|
||||
|
||||
It lives in its own module — rather than inside a specific agent factory — so
|
||||
that retiring or moving any single agent never disturbs the others' access to
|
||||
this mapping.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
# Maps SearchSourceConnectorType enum values to the searchable document/connector types
|
||||
# used by pre-search middleware and web_search.
|
||||
# Live search connectors (TAVILY_API, LINKUP_API, BAIDU_SEARCH_API) are routed to
|
||||
# the web_search tool; all others are considered local/indexed data.
|
||||
_CONNECTOR_TYPE_TO_SEARCHABLE: dict[str, str] = {
|
||||
# Live search connectors (handled by web_search tool)
|
||||
"TAVILY_API": "TAVILY_API",
|
||||
"LINKUP_API": "LINKUP_API",
|
||||
"BAIDU_SEARCH_API": "BAIDU_SEARCH_API",
|
||||
# Local/indexed connectors (handled by KB pre-search middleware)
|
||||
"SLACK_CONNECTOR": "SLACK_CONNECTOR",
|
||||
"TEAMS_CONNECTOR": "TEAMS_CONNECTOR",
|
||||
"NOTION_CONNECTOR": "NOTION_CONNECTOR",
|
||||
"GITHUB_CONNECTOR": "GITHUB_CONNECTOR",
|
||||
"LINEAR_CONNECTOR": "LINEAR_CONNECTOR",
|
||||
"DISCORD_CONNECTOR": "DISCORD_CONNECTOR",
|
||||
"JIRA_CONNECTOR": "JIRA_CONNECTOR",
|
||||
"CONFLUENCE_CONNECTOR": "CONFLUENCE_CONNECTOR",
|
||||
"CLICKUP_CONNECTOR": "CLICKUP_CONNECTOR",
|
||||
"GOOGLE_CALENDAR_CONNECTOR": "GOOGLE_CALENDAR_CONNECTOR",
|
||||
"GOOGLE_GMAIL_CONNECTOR": "GOOGLE_GMAIL_CONNECTOR",
|
||||
"GOOGLE_DRIVE_CONNECTOR": "GOOGLE_DRIVE_FILE", # Connector type differs from document type
|
||||
"AIRTABLE_CONNECTOR": "AIRTABLE_CONNECTOR",
|
||||
"LUMA_CONNECTOR": "LUMA_CONNECTOR",
|
||||
"ELASTICSEARCH_CONNECTOR": "ELASTICSEARCH_CONNECTOR",
|
||||
"WEBCRAWLER_CONNECTOR": "CRAWLED_URL", # Maps to document type
|
||||
"BOOKSTACK_CONNECTOR": "BOOKSTACK_CONNECTOR",
|
||||
"CIRCLEBACK_CONNECTOR": "CIRCLEBACK", # Connector type differs from document type
|
||||
"OBSIDIAN_CONNECTOR": "OBSIDIAN_CONNECTOR",
|
||||
"DROPBOX_CONNECTOR": "DROPBOX_FILE", # Connector type differs from document type
|
||||
"ONEDRIVE_CONNECTOR": "ONEDRIVE_FILE", # Connector type differs from document type
|
||||
# Composio connectors (unified to native document types).
|
||||
# Reverse of NATIVE_TO_LEGACY_DOCTYPE in app.db.
|
||||
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "GOOGLE_DRIVE_FILE",
|
||||
"COMPOSIO_GMAIL_CONNECTOR": "GOOGLE_GMAIL_CONNECTOR",
|
||||
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "GOOGLE_CALENDAR_CONNECTOR",
|
||||
}
|
||||
|
||||
# Document types that don't come from SearchSourceConnector but should always be searchable
|
||||
_ALWAYS_AVAILABLE_DOC_TYPES: list[str] = [
|
||||
"EXTENSION", # Browser extension data
|
||||
"FILE", # Uploaded files
|
||||
"NOTE", # User notes
|
||||
"YOUTUBE_VIDEO", # YouTube videos
|
||||
]
|
||||
|
||||
|
||||
def map_connectors_to_searchable_types(
    connector_types: list[Any],
) -> list[str]:
    """Translate configured connector types into searchable type names.

    The result always starts with the always-available document types
    (``_ALWAYS_AVAILABLE_DOC_TYPES``), followed by the searchable equivalent
    of each entry in ``connector_types`` as looked up in
    ``_CONNECTOR_TYPE_TO_SEARCHABLE``. Entries with no mapping are skipped,
    and every name appears exactly once, at the position where it was first
    seen.

    Args:
        connector_types: Connector types to translate. Items exposing a
            ``value`` attribute (e.g. enum members) are keyed by that value;
            anything else is keyed by ``str(item)``.

    Returns:
        Ordered, de-duplicated list of searchable connector/document type
        strings.
    """
    # A dict preserves insertion order and offers O(1) membership, so it acts
    # as both the "already seen" set and the ordered output.
    seen: dict[str, None] = dict.fromkeys(_ALWAYS_AVAILABLE_DOC_TYPES)

    for connector in connector_types:
        # Accept enum members as well as plain strings.
        lookup_key = connector.value if hasattr(connector, "value") else str(connector)
        mapped = _CONNECTOR_TYPE_TO_SEARCHABLE.get(lookup_key)
        if mapped:
            # setdefault leaves an existing key (and its position) untouched.
            seen.setdefault(mapped, None)

    return list(seen)
|
||||
|
|
@ -12,21 +12,28 @@ from langchain_core.tools import BaseTool
|
|||
from langgraph.types import Checkpointer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.multi_agent_chat.subagents import (
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import (
|
||||
AgentFeatureFlags,
|
||||
get_flags,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
|
||||
FilesystemMode,
|
||||
FilesystemSelection,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.resolver import (
|
||||
build_backend_resolver,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.subagents import (
|
||||
get_subagents_to_exclude,
|
||||
main_prompt_registry_subagent_lines,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.mcp_tools.index import (
|
||||
from app.agents.chat.multi_agent_chat.subagents.mcp_tools.index import (
|
||||
load_mcp_tools_by_connector,
|
||||
)
|
||||
from app.agents.new_chat.chat_deepagent import _map_connectors_to_searchable_types
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
|
||||
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
|
||||
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool
|
||||
from app.agents.new_chat.tools.registry import build_tools_async
|
||||
from app.agents.chat.runtime.llm_config import AgentConfig
|
||||
from app.agents.chat.runtime.prompt_caching import (
|
||||
apply_litellm_prompt_caching,
|
||||
)
|
||||
from app.db import ChatVisibility
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.user_tool_allowlist import (
|
||||
|
|
@ -40,7 +47,10 @@ from ..tools import (
|
|||
MAIN_AGENT_SURFSENSE_TOOL_NAMES,
|
||||
MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED,
|
||||
)
|
||||
from ..tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool
|
||||
from ..tools.registry import build_main_agent_tools
|
||||
from .agent_cache import build_agent_with_cache
|
||||
from .connector_searchable_types import map_connectors_to_searchable_types
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
|
@ -90,7 +100,7 @@ async def create_multi_agent_chat_deep_agent(
|
|||
connector_types = await connector_service.get_available_connectors(
|
||||
search_space_id
|
||||
)
|
||||
available_connectors = _map_connectors_to_searchable_types(connector_types)
|
||||
available_connectors = map_connectors_to_searchable_types(connector_types)
|
||||
|
||||
available_document_types = await connector_service.get_available_document_types(
|
||||
search_space_id
|
||||
|
|
@ -210,12 +220,14 @@ async def create_multi_agent_chat_deep_agent(
|
|||
main_agent_enabled_tools = list(MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
tools = await build_tools_async(
|
||||
# Main agent builds only its own small SurfSense toolset via the SRP
|
||||
# main-agent registry; connectors/MCP/deliverables are delegated to
|
||||
# subagents, so no MCP loading or connector construction happens here.
|
||||
tools = build_main_agent_tools(
|
||||
dependencies=dependencies,
|
||||
enabled_tools=main_agent_enabled_tools,
|
||||
disabled_tools=modified_disabled_tools,
|
||||
additional_tools=list(additional_tools) if additional_tools else None,
|
||||
include_mcp_tools=False,
|
||||
)
|
||||
|
||||
_flags: AgentFeatureFlags = get_flags()
|
||||
|
|
@ -16,7 +16,7 @@ prompt at agent build time, not edited at runtime.
|
|||
Two backends are provided:
|
||||
|
||||
* :class:`BuiltinSkillsBackend` — disk-backed read of bundled skills from
|
||||
``app/agents/new_chat/skills/builtin/``.
|
||||
``app/agents/chat/multi_agent_chat/main_agent/skills/builtin/``.
|
||||
* :class:`SearchSpaceSkillsBackend` — a thin read-only wrapper over
|
||||
:class:`KBPostgresBackend` that filters notes under the privileged folder
|
||||
``/documents/_skills/``.
|
||||
|
|
@ -47,7 +47,9 @@ from deepagents.backends.state import StateBackend
|
|||
if TYPE_CHECKING:
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import (
|
||||
KBPostgresBackend,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -59,9 +61,10 @@ _MAX_SKILL_FILE_SIZE = 10 * 1024 * 1024
|
|||
def _default_builtin_root() -> Path:
|
||||
"""Return the absolute path to the bundled builtin skills directory.
|
||||
|
||||
Located at ``app/agents/new_chat/skills/builtin/`` relative to this module.
|
||||
Located at ``builtin/`` next to this module (this module lives at
|
||||
``app/agents/chat/multi_agent_chat/main_agent/skills/backends.py``).
|
||||
"""
|
||||
return (Path(__file__).resolve().parent.parent / "skills" / "builtin").resolve()
|
||||
return (Path(__file__).resolve().parent / "builtin").resolve()
|
||||
|
||||
|
||||
class BuiltinSkillsBackend(BackendProtocol):
|
||||
|
|
@ -121,6 +124,8 @@ class BuiltinSkillsBackend(BackendProtocol):
|
|||
else ("/" + str(target.relative_to(self.root)).replace("\\", "/"))
|
||||
)
|
||||
for child in sorted(target.iterdir()):
|
||||
if child.name == "__pycache__" or child.name.startswith("."):
|
||||
continue
|
||||
child_virtual = (
|
||||
target_virtual.rstrip("/") + "/" + child.name
|
||||
if target_virtual != "/"
|
||||
|
|
@ -305,7 +310,7 @@ def build_skills_backend_factory(
|
|||
# Imported lazily to avoid a hard dependency at module import time:
|
||||
# ``KBPostgresBackend`` pulls in DB models, which are unnecessary for
|
||||
# the unit-tested builtin path.
|
||||
from app.agents.new_chat.middleware.kb_postgres_backend import (
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import (
|
||||
KBPostgresBackend,
|
||||
)
|
||||
|
||||
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
from importlib import resources
|
||||
|
||||
_PROMPTS_PACKAGE = "app.agents.multi_agent_chat.main_agent.system_prompt.prompts"
|
||||
_PROMPTS_PACKAGE = "app.agents.chat.multi_agent_chat.main_agent.system_prompt.prompts"
|
||||
|
||||
|
||||
def read_prompt_md(filename: str) -> str:
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue