mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 14:22:47 +02:00
refactor(multi-agent): add main-agent safety and llm-shaping middleware factories
This commit is contained in:
parent
390dc9307f
commit
b0ee44b2f1
7 changed files with 215 additions and 0 deletions
|
|
@ -0,0 +1,50 @@
|
|||
"""Spill + clear-tool-uses passes to keep payloads under budget."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.multi_agent_chat.main_agent.context_prune.prune_tool_names import (
|
||||
safe_exclude_tools,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import (
|
||||
ClearToolUsesEdit,
|
||||
SpillingContextEditingMiddleware,
|
||||
SpillToBackendEdit,
|
||||
)
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_context_editing_mw(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
max_input_tokens: int | None,
|
||||
tools: Sequence[BaseTool],
|
||||
backend_resolver: Any,
|
||||
) -> SpillingContextEditingMiddleware | None:
|
||||
if not enabled(flags, "enable_context_editing") or not max_input_tokens:
|
||||
return None
|
||||
spill_edit = SpillToBackendEdit(
|
||||
trigger=int(max_input_tokens * 0.55),
|
||||
clear_at_least=int(max_input_tokens * 0.15),
|
||||
keep=5,
|
||||
exclude_tools=safe_exclude_tools(tools),
|
||||
clear_tool_inputs=True,
|
||||
)
|
||||
clear_edit = ClearToolUsesEdit(
|
||||
trigger=int(max_input_tokens * 0.55),
|
||||
clear_at_least=int(max_input_tokens * 0.15),
|
||||
keep=5,
|
||||
exclude_tools=safe_exclude_tools(tools),
|
||||
clear_tool_inputs=True,
|
||||
placeholder="[cleared - older tool output trimmed for context]",
|
||||
)
|
||||
return SpillingContextEditingMiddleware(
|
||||
edits=[spill_edit, clear_edit],
|
||||
backend_resolver=backend_resolver,
|
||||
)
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
"""Drop duplicate HITL tool calls before execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware
|
||||
|
||||
|
||||
def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware:
|
||||
return DedupHITLToolCallsMiddleware(agent_tools=list(tools))
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""Stop N identical tool calls in a row via interrupt."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import DoomLoopMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_doom_loop_mw(flags: AgentFeatureFlags) -> DoomLoopMiddleware | None:
|
||||
return DoomLoopMiddleware(threshold=3) if enabled(flags, "enable_doom_loop") else None
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""Provider-compat: append a `_noop` tool when tools=[] but history has tool calls."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import NoopInjectionMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_noop_injection_mw(flags: AgentFeatureFlags) -> NoopInjectionMiddleware | None:
|
||||
return NoopInjectionMiddleware() if enabled(flags, "enable_compaction_v2") else None
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
"""Repair miscased / unknown tool names to the registered set or invalid_tool."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import ToolCallNameRepairMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
# deepagents-built-in tool names the repair pass treats as known.
|
||||
_DEEPAGENT_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
|
||||
{
|
||||
"write_todos",
|
||||
"ls",
|
||||
"read_file",
|
||||
"write_file",
|
||||
"edit_file",
|
||||
"glob",
|
||||
"grep",
|
||||
"execute",
|
||||
"task",
|
||||
"mkdir",
|
||||
"cd",
|
||||
"pwd",
|
||||
"move_file",
|
||||
"rm",
|
||||
"rmdir",
|
||||
"list_tree",
|
||||
"execute_code",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def build_repair_mw(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
tools: Sequence[BaseTool],
|
||||
) -> ToolCallNameRepairMiddleware | None:
|
||||
if not enabled(flags, "enable_tool_call_repair"):
|
||||
return None
|
||||
registered_names: set[str] = {t.name for t in tools}
|
||||
registered_names |= _DEEPAGENT_BUILTIN_TOOL_NAMES
|
||||
return ToolCallNameRepairMiddleware(
|
||||
registered_tool_names=registered_names,
|
||||
fuzzy_match_threshold=None,
|
||||
)
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
"""LLM-based tool subset selection (only when >30 tools)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain.agents.middleware import LLMToolSelectorMiddleware
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_selector_mw(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
tools: Sequence[BaseTool],
|
||||
) -> LLMToolSelectorMiddleware | None:
|
||||
if not enabled(flags, "enable_llm_tool_selector") or len(tools) <= 30:
|
||||
return None
|
||||
try:
|
||||
return LLMToolSelectorMiddleware(
|
||||
model="openai:gpt-4o-mini",
|
||||
max_tools=12,
|
||||
always_include=[
|
||||
name
|
||||
for name in (
|
||||
"update_memory",
|
||||
"get_connected_accounts",
|
||||
"scrape_webpage",
|
||||
)
|
||||
if name in {t.name for t in tools}
|
||||
],
|
||||
)
|
||||
except Exception:
|
||||
logging.warning("LLMToolSelectorMiddleware init failed; skipping.")
|
||||
return None
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
"""Skill discovery + injection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from deepagents.middleware.skills import SkillsMiddleware
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import (
|
||||
build_skills_backend_factory,
|
||||
default_skills_sources,
|
||||
)
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_skills_mw(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
filesystem_mode: FilesystemMode,
|
||||
search_space_id: int,
|
||||
) -> SkillsMiddleware | None:
|
||||
if not enabled(flags, "enable_skills"):
|
||||
return None
|
||||
try:
|
||||
skills_factory = build_skills_backend_factory(
|
||||
search_space_id=search_space_id
|
||||
if filesystem_mode == FilesystemMode.CLOUD
|
||||
else None,
|
||||
)
|
||||
return SkillsMiddleware(
|
||||
backend=skills_factory,
|
||||
sources=default_skills_sources(),
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logging.warning("SkillsMiddleware init failed; skipping: %s", exc)
|
||||
return None
|
||||
Loading…
Add table
Add a link
Reference in a new issue