mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-08 20:25:19 +02:00
refactor(agents): extract subagent-invocation contract to subagents/shared
The knowledge_base subagent imported subagent_invoke_config + EXCLUDED_STATE_KEYS from main_agent's checkpointed_subagent_middleware -- a subagent reaching into main-agent internals. Both symbols (plus the recursion-limit constant they need) are a subagent-invocation contract shared by the orchestrator's task middleware and any nested-invoking subagent. Move them to subagents/shared/invocation.py; config.py keeps the HITL resume side-channel and constants.py keeps the main-agent tuning knobs. All consumers (task_tool, kb tool, tests) repointed.
This commit is contained in:
parent
490bb3c5c5
commit
88fe213176
7 changed files with 90 additions and 64 deletions
|
|
@ -1,7 +1,9 @@
|
||||||
"""RunnableConfig wiring for nested subagent invocations.
|
"""HITL resume side-channel for nested subagent invocations.
|
||||||
|
|
||||||
Forwards the parent's ``runtime.config`` (thread_id, …) into the subagent and
|
Exposes the configurable side-channel ``stream_resume_chat`` uses to ferry
|
||||||
exposes the side-channel ``stream_resume_chat`` uses to ferry resume payloads.
|
resume payloads into a mid-flight subagent. The ``RunnableConfig`` builder and
|
||||||
|
state-key filter shared with subagents live in
|
||||||
|
``app.agents.chat.multi_agent_chat.subagents.shared.invocation``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -11,8 +13,6 @@ from typing import Any
|
||||||
|
|
||||||
from langchain.tools import ToolRuntime
|
from langchain.tools import ToolRuntime
|
||||||
|
|
||||||
from .constants import DEFAULT_SUBAGENT_RECURSION_LIMIT
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# langgraph stores the parent task's scratchpad under this configurable key;
|
# langgraph stores the parent task's scratchpad under this configurable key;
|
||||||
|
|
@ -20,39 +20,6 @@ logger = logging.getLogger(__name__)
|
||||||
_LANGGRAPH_SCRATCHPAD_KEY = "__pregel_scratchpad"
|
_LANGGRAPH_SCRATCHPAD_KEY = "__pregel_scratchpad"
|
||||||
|
|
||||||
|
|
||||||
def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
|
|
||||||
"""RunnableConfig for the nested invoke; raises ``recursion_limit`` and isolates ``thread_id``.
|
|
||||||
|
|
||||||
Each parallel subagent invocation lands in its own checkpoint slot keyed
|
|
||||||
by an extended ``thread_id`` of the form ``{parent_thread}::task:{tool_call_id}``.
|
|
||||||
The same call across the resume cycle keeps reading from the same snapshot
|
|
||||||
(``tool_call_id`` is stable per LLM-emitted call).
|
|
||||||
|
|
||||||
We namespace via ``thread_id`` rather than ``checkpoint_ns`` because
|
|
||||||
langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a
|
|
||||||
subgraph path and raises ``ValueError("Subgraph X not found")``.
|
|
||||||
"""
|
|
||||||
merged: dict[str, Any] = dict(runtime.config) if runtime.config else {}
|
|
||||||
current_limit = merged.get("recursion_limit")
|
|
||||||
try:
|
|
||||||
current_int = int(current_limit) if current_limit is not None else 0
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
current_int = 0
|
|
||||||
if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT:
|
|
||||||
merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT
|
|
||||||
|
|
||||||
configurable: dict[str, Any] = dict(merged.get("configurable") or {})
|
|
||||||
parent_thread_id = configurable.get("thread_id")
|
|
||||||
per_call_suffix = f"task:{runtime.tool_call_id}"
|
|
||||||
configurable["thread_id"] = (
|
|
||||||
f"{parent_thread_id}::{per_call_suffix}"
|
|
||||||
if parent_thread_id
|
|
||||||
else per_call_suffix
|
|
||||||
)
|
|
||||||
merged["configurable"] = configurable
|
|
||||||
return merged
|
|
||||||
|
|
||||||
|
|
||||||
def consume_surfsense_resume(runtime: ToolRuntime) -> Any:
|
def consume_surfsense_resume(runtime: ToolRuntime) -> Any:
|
||||||
"""Pop the resume payload for *this* call's ``tool_call_id``.
|
"""Pop the resume payload for *this* call's ``tool_call_id``.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,24 +1,14 @@
|
||||||
"""Constants shared by the checkpointed subagent middleware."""
|
"""Tuning constants for the checkpointed subagent middleware.
|
||||||
|
|
||||||
|
``EXCLUDED_STATE_KEYS`` and ``DEFAULT_SUBAGENT_RECURSION_LIMIT`` are part of the
|
||||||
|
subagent-invocation contract shared with subagents and now live in
|
||||||
|
``app.agents.chat.multi_agent_chat.subagents.shared.invocation``.
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# Mirror of deepagents.middleware.subagents._EXCLUDED_STATE_KEYS.
|
|
||||||
EXCLUDED_STATE_KEYS = frozenset(
|
|
||||||
{
|
|
||||||
"messages",
|
|
||||||
"todos",
|
|
||||||
"structured_response",
|
|
||||||
"skills_metadata",
|
|
||||||
"memory_contents",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Match the parent graph's budget; the LangGraph default of 25 trips on
|
|
||||||
# multi-step subagent runs.
|
|
||||||
DEFAULT_SUBAGENT_RECURSION_LIMIT = 10_000
|
|
||||||
|
|
||||||
|
|
||||||
def _read_timeout_env(name: str, default: float) -> float:
|
def _read_timeout_env(name: str, default: float) -> float:
|
||||||
"""Parse ``name`` from the environment; fall back to ``default`` on bad values.
|
"""Parse ``name`` from the environment; fall back to ``default`` on bad values.
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,10 @@ from langchain_core.tools import StructuredTool
|
||||||
from langgraph.errors import GraphInterrupt
|
from langgraph.errors import GraphInterrupt
|
||||||
from langgraph.types import Command, Interrupt
|
from langgraph.types import Command, Interrupt
|
||||||
|
|
||||||
|
from app.agents.chat.multi_agent_chat.subagents.shared.invocation import (
|
||||||
|
EXCLUDED_STATE_KEYS,
|
||||||
|
subagent_invoke_config,
|
||||||
|
)
|
||||||
from app.agents.chat.multi_agent_chat.subagents.shared.spec import (
|
from app.agents.chat.multi_agent_chat.subagents.shared.spec import (
|
||||||
SURF_CONTEXT_HINT_PROVIDER_KEY,
|
SURF_CONTEXT_HINT_PROVIDER_KEY,
|
||||||
ContextHintProvider,
|
ContextHintProvider,
|
||||||
|
|
@ -34,13 +38,11 @@ from .config import (
|
||||||
consume_surfsense_resume,
|
consume_surfsense_resume,
|
||||||
drain_parent_null_resume,
|
drain_parent_null_resume,
|
||||||
has_surfsense_resume,
|
has_surfsense_resume,
|
||||||
subagent_invoke_config,
|
|
||||||
)
|
)
|
||||||
from .constants import (
|
from .constants import (
|
||||||
DEFAULT_SUBAGENT_BATCH_CONCURRENCY,
|
DEFAULT_SUBAGENT_BATCH_CONCURRENCY,
|
||||||
DEFAULT_SUBAGENT_BILLABLE_THRESHOLD,
|
DEFAULT_SUBAGENT_BILLABLE_THRESHOLD,
|
||||||
DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS,
|
DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS,
|
||||||
EXCLUDED_STATE_KEYS,
|
|
||||||
MAX_SUBAGENT_BATCH_SIZE,
|
MAX_SUBAGENT_BATCH_SIZE,
|
||||||
)
|
)
|
||||||
from .propagation import wrap_with_tool_call_id
|
from .propagation import wrap_with_tool_call_id
|
||||||
|
|
|
||||||
|
|
@ -10,11 +10,9 @@ from langchain_core.runnables import Runnable
|
||||||
from langchain_core.tools import StructuredTool
|
from langchain_core.tools import StructuredTool
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
|
|
||||||
from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.config import (
|
from app.agents.chat.multi_agent_chat.subagents.shared.invocation import (
|
||||||
subagent_invoke_config,
|
|
||||||
)
|
|
||||||
from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.constants import (
|
|
||||||
EXCLUDED_STATE_KEYS,
|
EXCLUDED_STATE_KEYS,
|
||||||
|
subagent_invoke_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .prompts import load_readonly_description
|
from .prompts import load_readonly_description
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
"""Subagent-invocation contract shared by the orchestrator and nested subagents.
|
||||||
|
|
||||||
|
Both the main-agent ``task`` middleware (``checkpointed_subagent_middleware``)
|
||||||
|
and subagents that themselves invoke another subagent (e.g.
|
||||||
|
``ask_knowledge_base``) need the same two things when spawning a child run:
|
||||||
|
|
||||||
|
- a ``RunnableConfig`` that raises the recursion limit and isolates the child's
|
||||||
|
``thread_id`` so each invocation lands in its own checkpoint slot
|
||||||
|
(``subagent_invoke_config``), and
|
||||||
|
- the set of parent state keys that must *not* be forwarded into / merged back
|
||||||
|
from the child (``EXCLUDED_STATE_KEYS``).
|
||||||
|
|
||||||
|
Keeping this here (rather than inside the main-agent middleware) lets subagents
|
||||||
|
reuse the contract without importing main-agent internals.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.tools import ToolRuntime
|
||||||
|
|
||||||
|
# Mirror of deepagents.middleware.subagents._EXCLUDED_STATE_KEYS.
|
||||||
|
EXCLUDED_STATE_KEYS = frozenset(
|
||||||
|
{
|
||||||
|
"messages",
|
||||||
|
"todos",
|
||||||
|
"structured_response",
|
||||||
|
"skills_metadata",
|
||||||
|
"memory_contents",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Match the parent graph's budget; the LangGraph default of 25 trips on
|
||||||
|
# multi-step subagent runs.
|
||||||
|
DEFAULT_SUBAGENT_RECURSION_LIMIT = 10_000
|
||||||
|
|
||||||
|
|
||||||
|
def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
|
||||||
|
"""RunnableConfig for the nested invoke; raises ``recursion_limit`` and isolates ``thread_id``.
|
||||||
|
|
||||||
|
Each parallel subagent invocation lands in its own checkpoint slot keyed
|
||||||
|
by an extended ``thread_id`` of the form ``{parent_thread}::task:{tool_call_id}``.
|
||||||
|
The same call across the resume cycle keeps reading from the same snapshot
|
||||||
|
(``tool_call_id`` is stable per LLM-emitted call).
|
||||||
|
|
||||||
|
We namespace via ``thread_id`` rather than ``checkpoint_ns`` because
|
||||||
|
langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a
|
||||||
|
subgraph path and raises ``ValueError("Subgraph X not found")``.
|
||||||
|
"""
|
||||||
|
merged: dict[str, Any] = dict(runtime.config) if runtime.config else {}
|
||||||
|
current_limit = merged.get("recursion_limit")
|
||||||
|
try:
|
||||||
|
current_int = int(current_limit) if current_limit is not None else 0
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
current_int = 0
|
||||||
|
if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT:
|
||||||
|
merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT
|
||||||
|
|
||||||
|
configurable: dict[str, Any] = dict(merged.get("configurable") or {})
|
||||||
|
parent_thread_id = configurable.get("thread_id")
|
||||||
|
per_call_suffix = f"task:{runtime.tool_call_id}"
|
||||||
|
configurable["thread_id"] = (
|
||||||
|
f"{parent_thread_id}::{per_call_suffix}"
|
||||||
|
if parent_thread_id
|
||||||
|
else per_call_suffix
|
||||||
|
)
|
||||||
|
merged["configurable"] = configurable
|
||||||
|
return merged
|
||||||
|
|
@ -14,9 +14,6 @@ from langgraph.graph import END, START, StateGraph
|
||||||
from langgraph.types import Command, interrupt
|
from langgraph.types import Command, interrupt
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.config import (
|
|
||||||
subagent_invoke_config,
|
|
||||||
)
|
|
||||||
from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.resume_routing import (
|
from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.resume_routing import (
|
||||||
collect_pending_tool_calls,
|
collect_pending_tool_calls,
|
||||||
slice_decisions_by_tool_call,
|
slice_decisions_by_tool_call,
|
||||||
|
|
@ -24,6 +21,9 @@ from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagen
|
||||||
from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.task_tool import (
|
from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.task_tool import (
|
||||||
build_task_tool_with_parent_config,
|
build_task_tool_with_parent_config,
|
||||||
)
|
)
|
||||||
|
from app.agents.chat.multi_agent_chat.subagents.shared.invocation import (
|
||||||
|
subagent_invoke_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class _SubagentState(TypedDict, total=False):
|
class _SubagentState(TypedDict, total=False):
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
from langchain.tools import ToolRuntime
|
from langchain.tools import ToolRuntime
|
||||||
|
|
||||||
from app.agents.chat.multi_agent_chat.main_agent.middleware.checkpointed_subagent_middleware.config import (
|
from app.agents.chat.multi_agent_chat.subagents.shared.invocation import (
|
||||||
subagent_invoke_config,
|
subagent_invoke_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue