mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 14:22:47 +02:00
fix(multi-agent): cache compiled agent graph keyed on per-request inputs
This commit is contained in:
parent
c8ed70a26c
commit
07a84d1a41
2 changed files with 122 additions and 4 deletions
|
|
@ -0,0 +1,117 @@
|
|||
"""Compiled agent graph caching for the multi-agent path."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions
|
||||
from app.agents.new_chat.agent_cache import (
|
||||
flags_signature,
|
||||
get_cache,
|
||||
stable_hash,
|
||||
system_prompt_hash,
|
||||
tools_signature,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
|
||||
|
||||
|
||||
def mcp_signature(mcp_tools_by_agent: dict[str, ToolsPermissions]) -> str:
|
||||
"""Hash the per-agent MCP tool surface so a change rotates the cache key."""
|
||||
rows = []
|
||||
for agent_name in sorted(mcp_tools_by_agent.keys()):
|
||||
perms = mcp_tools_by_agent[agent_name]
|
||||
allow_names = sorted(item.get("name", "") for item in perms.get("allow", []))
|
||||
ask_names = sorted(item.get("name", "") for item in perms.get("ask", []))
|
||||
rows.append((agent_name, allow_names, ask_names))
|
||||
return stable_hash(rows)
|
||||
|
||||
|
||||
async def build_agent_with_cache(
|
||||
*,
|
||||
llm: BaseChatModel,
|
||||
tools: Sequence[BaseTool],
|
||||
final_system_prompt: str,
|
||||
backend_resolver: Any,
|
||||
filesystem_mode: FilesystemMode,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
thread_id: int | None,
|
||||
visibility: ChatVisibility,
|
||||
anon_session_id: str | None,
|
||||
available_connectors: list[str],
|
||||
available_document_types: list[str],
|
||||
mentioned_document_ids: list[int] | None,
|
||||
max_input_tokens: int | None,
|
||||
flags: AgentFeatureFlags,
|
||||
checkpointer: Checkpointer,
|
||||
subagent_dependencies: dict[str, Any],
|
||||
mcp_tools_by_agent: dict[str, ToolsPermissions],
|
||||
disabled_tools: list[str] | None,
|
||||
config_id: str | None,
|
||||
) -> Any:
|
||||
"""Compile the multi-agent graph, serving from cache when key components are stable."""
|
||||
|
||||
async def _build() -> Any:
|
||||
return await asyncio.to_thread(
|
||||
build_compiled_agent_graph_sync,
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
final_system_prompt=final_system_prompt,
|
||||
backend_resolver=backend_resolver,
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
visibility=visibility,
|
||||
anon_session_id=anon_session_id,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
max_input_tokens=max_input_tokens,
|
||||
flags=flags,
|
||||
checkpointer=checkpointer,
|
||||
subagent_dependencies=subagent_dependencies,
|
||||
mcp_tools_by_agent=mcp_tools_by_agent,
|
||||
disabled_tools=disabled_tools,
|
||||
)
|
||||
|
||||
if not (flags.enable_agent_cache and not flags.disable_new_agent_stack):
|
||||
return await _build()
|
||||
|
||||
# Every per-request value any middleware closes over at __init__ must be in
|
||||
# the key, otherwise a hit will leak state across threads. Bump the schema
|
||||
# version when the component list changes shape.
|
||||
cache_key = stable_hash(
|
||||
"multi-agent-v1",
|
||||
config_id,
|
||||
thread_id,
|
||||
user_id,
|
||||
search_space_id,
|
||||
visibility,
|
||||
filesystem_mode,
|
||||
anon_session_id,
|
||||
tools_signature(
|
||||
tools,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
),
|
||||
mcp_signature(mcp_tools_by_agent),
|
||||
flags_signature(flags),
|
||||
system_prompt_hash(final_system_prompt),
|
||||
max_input_tokens,
|
||||
sorted(disabled_tools) if disabled_tools else None,
|
||||
)
|
||||
return await get_cache().get_or_build(cache_key, builder=_build)
|
||||
|
||||
|
||||
__all__ = ["build_agent_with_cache", "mcp_signature"]
|
||||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
|
|
@ -33,12 +32,12 @@ from app.db import ChatVisibility
|
|||
from app.services.connector_service import ConnectorService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
|
||||
from ..system_prompt import build_main_agent_system_prompt
|
||||
from ..tools import (
|
||||
MAIN_AGENT_SURFSENSE_TOOL_NAMES,
|
||||
MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED,
|
||||
)
|
||||
from .agent_cache import build_agent_with_cache
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
|
@ -210,9 +209,10 @@ async def create_surfsense_deep_agent(
|
|||
|
||||
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
|
||||
|
||||
config_id = agent_config.config_id if agent_config is not None else None
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent = await asyncio.to_thread(
|
||||
build_compiled_agent_graph_sync,
|
||||
agent = await build_agent_with_cache(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
final_system_prompt=final_system_prompt,
|
||||
|
|
@ -232,6 +232,7 @@ async def create_surfsense_deep_agent(
|
|||
subagent_dependencies=dependencies,
|
||||
mcp_tools_by_agent=mcp_tools_by_agent,
|
||||
disabled_tools=disabled_tools,
|
||||
config_id=config_id,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[create_agent] Middleware stack + graph compiled in %.3fs",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue