mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-07 23:02:39 +02:00
fix(multi-agent): cache compiled agent graph keyed on per-request inputs
This commit is contained in:
parent
c8ed70a26c
commit
07a84d1a41
2 changed files with 122 additions and 4 deletions
|
|
@ -0,0 +1,117 @@
|
||||||
|
"""Compiled agent graph caching for the multi-agent path."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
from langgraph.types import Checkpointer
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions
|
||||||
|
from app.agents.new_chat.agent_cache import (
|
||||||
|
flags_signature,
|
||||||
|
get_cache,
|
||||||
|
stable_hash,
|
||||||
|
system_prompt_hash,
|
||||||
|
tools_signature,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.db import ChatVisibility
|
||||||
|
|
||||||
|
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
|
||||||
|
|
||||||
|
|
||||||
|
def mcp_signature(mcp_tools_by_agent: dict[str, ToolsPermissions]) -> str:
|
||||||
|
"""Hash the per-agent MCP tool surface so a change rotates the cache key."""
|
||||||
|
rows = []
|
||||||
|
for agent_name in sorted(mcp_tools_by_agent.keys()):
|
||||||
|
perms = mcp_tools_by_agent[agent_name]
|
||||||
|
allow_names = sorted(item.get("name", "") for item in perms.get("allow", []))
|
||||||
|
ask_names = sorted(item.get("name", "") for item in perms.get("ask", []))
|
||||||
|
rows.append((agent_name, allow_names, ask_names))
|
||||||
|
return stable_hash(rows)
|
||||||
|
|
||||||
|
|
||||||
|
async def build_agent_with_cache(
|
||||||
|
*,
|
||||||
|
llm: BaseChatModel,
|
||||||
|
tools: Sequence[BaseTool],
|
||||||
|
final_system_prompt: str,
|
||||||
|
backend_resolver: Any,
|
||||||
|
filesystem_mode: FilesystemMode,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str | None,
|
||||||
|
thread_id: int | None,
|
||||||
|
visibility: ChatVisibility,
|
||||||
|
anon_session_id: str | None,
|
||||||
|
available_connectors: list[str],
|
||||||
|
available_document_types: list[str],
|
||||||
|
mentioned_document_ids: list[int] | None,
|
||||||
|
max_input_tokens: int | None,
|
||||||
|
flags: AgentFeatureFlags,
|
||||||
|
checkpointer: Checkpointer,
|
||||||
|
subagent_dependencies: dict[str, Any],
|
||||||
|
mcp_tools_by_agent: dict[str, ToolsPermissions],
|
||||||
|
disabled_tools: list[str] | None,
|
||||||
|
config_id: str | None,
|
||||||
|
) -> Any:
|
||||||
|
"""Compile the multi-agent graph, serving from cache when key components are stable."""
|
||||||
|
|
||||||
|
async def _build() -> Any:
|
||||||
|
return await asyncio.to_thread(
|
||||||
|
build_compiled_agent_graph_sync,
|
||||||
|
llm=llm,
|
||||||
|
tools=tools,
|
||||||
|
final_system_prompt=final_system_prompt,
|
||||||
|
backend_resolver=backend_resolver,
|
||||||
|
filesystem_mode=filesystem_mode,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
visibility=visibility,
|
||||||
|
anon_session_id=anon_session_id,
|
||||||
|
available_connectors=available_connectors,
|
||||||
|
available_document_types=available_document_types,
|
||||||
|
mentioned_document_ids=mentioned_document_ids,
|
||||||
|
max_input_tokens=max_input_tokens,
|
||||||
|
flags=flags,
|
||||||
|
checkpointer=checkpointer,
|
||||||
|
subagent_dependencies=subagent_dependencies,
|
||||||
|
mcp_tools_by_agent=mcp_tools_by_agent,
|
||||||
|
disabled_tools=disabled_tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not (flags.enable_agent_cache and not flags.disable_new_agent_stack):
|
||||||
|
return await _build()
|
||||||
|
|
||||||
|
# Every per-request value any middleware closes over at __init__ must be in
|
||||||
|
# the key, otherwise a hit will leak state across threads. Bump the schema
|
||||||
|
# version when the component list changes shape.
|
||||||
|
cache_key = stable_hash(
|
||||||
|
"multi-agent-v1",
|
||||||
|
config_id,
|
||||||
|
thread_id,
|
||||||
|
user_id,
|
||||||
|
search_space_id,
|
||||||
|
visibility,
|
||||||
|
filesystem_mode,
|
||||||
|
anon_session_id,
|
||||||
|
tools_signature(
|
||||||
|
tools,
|
||||||
|
available_connectors=available_connectors,
|
||||||
|
available_document_types=available_document_types,
|
||||||
|
),
|
||||||
|
mcp_signature(mcp_tools_by_agent),
|
||||||
|
flags_signature(flags),
|
||||||
|
system_prompt_hash(final_system_prompt),
|
||||||
|
max_input_tokens,
|
||||||
|
sorted(disabled_tools) if disabled_tools else None,
|
||||||
|
)
|
||||||
|
return await get_cache().get_or_build(cache_key, builder=_build)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["build_agent_with_cache", "mcp_signature"]
|
||||||
|
|
@ -2,7 +2,6 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
@ -33,12 +32,12 @@ from app.db import ChatVisibility
|
||||||
from app.services.connector_service import ConnectorService
|
from app.services.connector_service import ConnectorService
|
||||||
from app.utils.perf import get_perf_logger
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
|
|
||||||
from ..system_prompt import build_main_agent_system_prompt
|
from ..system_prompt import build_main_agent_system_prompt
|
||||||
from ..tools import (
|
from ..tools import (
|
||||||
MAIN_AGENT_SURFSENSE_TOOL_NAMES,
|
MAIN_AGENT_SURFSENSE_TOOL_NAMES,
|
||||||
MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED,
|
MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED,
|
||||||
)
|
)
|
||||||
|
from .agent_cache import build_agent_with_cache
|
||||||
|
|
||||||
_perf_log = get_perf_logger()
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
|
|
@ -210,9 +209,10 @@ async def create_surfsense_deep_agent(
|
||||||
|
|
||||||
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
|
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
|
||||||
|
|
||||||
|
config_id = agent_config.config_id if agent_config is not None else None
|
||||||
|
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
agent = await asyncio.to_thread(
|
agent = await build_agent_with_cache(
|
||||||
build_compiled_agent_graph_sync,
|
|
||||||
llm=llm,
|
llm=llm,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
final_system_prompt=final_system_prompt,
|
final_system_prompt=final_system_prompt,
|
||||||
|
|
@ -232,6 +232,7 @@ async def create_surfsense_deep_agent(
|
||||||
subagent_dependencies=dependencies,
|
subagent_dependencies=dependencies,
|
||||||
mcp_tools_by_agent=mcp_tools_by_agent,
|
mcp_tools_by_agent=mcp_tools_by_agent,
|
||||||
disabled_tools=disabled_tools,
|
disabled_tools=disabled_tools,
|
||||||
|
config_id=config_id,
|
||||||
)
|
)
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[create_agent] Middleware stack + graph compiled in %.3fs",
|
"[create_agent] Middleware stack + graph compiled in %.3fs",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue