mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-29 19:35:20 +02:00
feat(automations): wire agent_task to multi_agent_chat with auto-approve loop
This commit is contained in:
parent
7ec3468113
commit
ce45e11009
9 changed files with 285 additions and 31 deletions
1
surfsense_backend/app/automations/actions/__init__.py
Normal file
1
surfsense_backend/app/automations/actions/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""Action implementations. One subpackage per built-in action type."""
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
"""``agent_task`` action: spin up multi_agent_chat for one rendered query."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from .factory import build_handler
|
||||||
|
|
||||||
|
__all__ = ["build_handler"]
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
"""Synthesize HITL decisions for every pending interrupt (approve-all or reject-all)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def build_auto_decisions(
|
||||||
|
state: Any, decision: str
|
||||||
|
) -> tuple[dict[str, dict[str, Any]], dict[str, dict[str, Any]]]:
|
||||||
|
"""Return ``(lg_resume_map, surfsense_resume_value)`` covering every pending interrupt.
|
||||||
|
|
||||||
|
``lg_resume_map`` is keyed by ``Interrupt.id`` for ``Command(resume=...)``;
|
||||||
|
``surfsense_resume_value`` is keyed by ``tool_call_id`` for the subagent
|
||||||
|
middleware bridge. Action count is read from ``value.action_requests`` when
|
||||||
|
present and falls back to ``1`` for wrapped scalar interrupts.
|
||||||
|
"""
|
||||||
|
lg_resume_map: dict[str, dict[str, Any]] = {}
|
||||||
|
routed: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
|
for interrupt_obj in getattr(state, "interrupts", ()) or ():
|
||||||
|
value = getattr(interrupt_obj, "value", None)
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
continue
|
||||||
|
interrupt_id = getattr(interrupt_obj, "id", None)
|
||||||
|
if not isinstance(interrupt_id, str):
|
||||||
|
continue
|
||||||
|
|
||||||
|
action_requests = value.get("action_requests")
|
||||||
|
count = len(action_requests) if isinstance(action_requests, list) else 1
|
||||||
|
decisions = [{"type": decision} for _ in range(count)]
|
||||||
|
|
||||||
|
lg_resume_map[interrupt_id] = {"decisions": decisions}
|
||||||
|
|
||||||
|
tool_call_id = value.get("tool_call_id")
|
||||||
|
if isinstance(tool_call_id, str):
|
||||||
|
routed[tool_call_id] = {"decisions": decisions}
|
||||||
|
|
||||||
|
return lg_resume_map, routed
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
"""Build the per-invocation dependencies the multi_agent_chat factory needs."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle
|
||||||
|
from app.tasks.chat.streaming.flows.shared.pre_stream_setup import (
|
||||||
|
get_chat_checkpointer,
|
||||||
|
setup_connector_and_firecrawl,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DependencyError(Exception):
|
||||||
|
"""An external dependency (LLM config, checkpointer, ...) refused to load."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class AgentDependencies:
|
||||||
|
"""Everything ``create_multi_agent_chat_deep_agent`` needs from the environment."""
|
||||||
|
|
||||||
|
llm: Any
|
||||||
|
agent_config: Any
|
||||||
|
connector_service: Any
|
||||||
|
firecrawl_api_key: str | None
|
||||||
|
checkpointer: Any
|
||||||
|
|
||||||
|
|
||||||
|
async def build_dependencies(
|
||||||
|
*,
|
||||||
|
session: AsyncSession,
|
||||||
|
search_space_id: int,
|
||||||
|
) -> AgentDependencies:
|
||||||
|
"""Load the LLM bundle, connector service, and checkpointer for one invoke.
|
||||||
|
|
||||||
|
Uses the search space's default LLM config (``config_id=-1``). Per-step
|
||||||
|
model overrides land in a future iteration alongside the ``model`` param.
|
||||||
|
"""
|
||||||
|
llm, agent_config, err = await load_llm_bundle(
|
||||||
|
session, config_id=-1, search_space_id=search_space_id
|
||||||
|
)
|
||||||
|
if err is not None or llm is None:
|
||||||
|
raise DependencyError(err or "failed to load default LLM config")
|
||||||
|
|
||||||
|
connector_service, firecrawl_api_key = await setup_connector_and_firecrawl(
|
||||||
|
session, search_space_id=search_space_id
|
||||||
|
)
|
||||||
|
checkpointer = await get_chat_checkpointer()
|
||||||
|
return AgentDependencies(
|
||||||
|
llm=llm,
|
||||||
|
agent_config=agent_config,
|
||||||
|
connector_service=connector_service,
|
||||||
|
firecrawl_api_key=firecrawl_api_key,
|
||||||
|
checkpointer=checkpointer,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
"""Bind ``ActionContext`` to a callable that runs one ``agent_task`` step."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.automations.registries.actions.types import (
|
||||||
|
ActionContext,
|
||||||
|
ActionHandler,
|
||||||
|
)
|
||||||
|
from app.automations.schemas.actions import AgentTaskActionParams
|
||||||
|
|
||||||
|
from .invoke import run_agent_task
|
||||||
|
|
||||||
|
|
||||||
|
def build_handler(ctx: ActionContext) -> ActionHandler:
|
||||||
|
"""Return a handler closure that validates params and runs the agent task."""
|
||||||
|
|
||||||
|
async def handle(params: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
validated = AgentTaskActionParams.model_validate(params)
|
||||||
|
return await run_agent_task(
|
||||||
|
ctx=ctx,
|
||||||
|
query=validated.query,
|
||||||
|
auto_approve_all=validated.auto_approve_all,
|
||||||
|
)
|
||||||
|
|
||||||
|
return handle
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
"""Extract the agent's final assistant text from the terminal invoke result."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
|
||||||
|
def extract_final_assistant_message(result: Any) -> str | None:
|
||||||
|
"""Return the last ``AIMessage`` text content, or ``None`` if there isn't one.
|
||||||
|
|
||||||
|
Multi-part messages (content lists) are flattened by concatenating ``text``
|
||||||
|
parts in order. Non-string content (tool calls, images) is skipped.
|
||||||
|
"""
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
return None
|
||||||
|
messages = result.get("messages")
|
||||||
|
if not isinstance(messages, list):
|
||||||
|
return None
|
||||||
|
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if not isinstance(msg, AIMessage):
|
||||||
|
continue
|
||||||
|
return _content_to_text(msg.content)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _content_to_text(content: Any) -> str | None:
|
||||||
|
if isinstance(content, str):
|
||||||
|
text = content.strip()
|
||||||
|
return text or None
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, str):
|
||||||
|
parts.append(part)
|
||||||
|
elif isinstance(part, dict) and part.get("type") == "text":
|
||||||
|
text = part.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
joined = "".join(parts).strip()
|
||||||
|
return joined or None
|
||||||
|
return None
|
||||||
|
|
@ -0,0 +1,97 @@
|
||||||
|
"""Run one ``agent_task`` invocation: ainvoke + auto-decision resume loop."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent
|
||||||
|
from app.automations.registries.actions.types import ActionContext
|
||||||
|
from app.db import ChatVisibility, async_session_maker
|
||||||
|
|
||||||
|
from .auto_decide import build_auto_decisions
|
||||||
|
from .dependencies import build_dependencies
|
||||||
|
from .finalize import extract_final_assistant_message
|
||||||
|
|
||||||
|
# Cap on HITL resume iterations. The agent should not need this many turns in one
|
||||||
|
# step; treat overshoot as a runaway and fail the step.
|
||||||
|
_MAX_RESUMES = 50
|
||||||
|
|
||||||
|
|
||||||
|
async def run_agent_task(
|
||||||
|
*,
|
||||||
|
ctx: ActionContext,
|
||||||
|
query: str,
|
||||||
|
auto_approve_all: bool,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Invoke multi_agent_chat for one rendered query and return its outcome.
|
||||||
|
|
||||||
|
Opens its own DB session so the executor's bookkeeping session isn't tied
|
||||||
|
up for the entire invocation. The LangGraph ``thread_id`` (a fresh UUID)
|
||||||
|
is returned as ``agent_session_id`` for later inspection.
|
||||||
|
"""
|
||||||
|
agent_session_id = str(uuid.uuid4())
|
||||||
|
user_id = str(ctx.creator_user_id) if ctx.creator_user_id else None
|
||||||
|
decision = "approve" if auto_approve_all else "reject"
|
||||||
|
|
||||||
|
async with async_session_maker() as agent_session:
|
||||||
|
deps = await build_dependencies(
|
||||||
|
session=agent_session,
|
||||||
|
search_space_id=ctx.search_space_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = await create_multi_agent_chat_deep_agent(
|
||||||
|
llm=deps.llm,
|
||||||
|
search_space_id=ctx.search_space_id,
|
||||||
|
db_session=agent_session,
|
||||||
|
connector_service=deps.connector_service,
|
||||||
|
checkpointer=deps.checkpointer,
|
||||||
|
user_id=user_id,
|
||||||
|
thread_id=None,
|
||||||
|
agent_config=deps.agent_config,
|
||||||
|
firecrawl_api_key=deps.firecrawl_api_key,
|
||||||
|
thread_visibility=ChatVisibility.PRIVATE,
|
||||||
|
)
|
||||||
|
|
||||||
|
request_id = f"automation:{ctx.run_id}:{ctx.step_id}"
|
||||||
|
turn_id = f"{request_id}:{int(time.time() * 1000)}"
|
||||||
|
input_state: dict[str, Any] = {
|
||||||
|
"messages": [HumanMessage(content=query)],
|
||||||
|
"search_space_id": ctx.search_space_id,
|
||||||
|
"request_id": request_id,
|
||||||
|
"turn_id": turn_id,
|
||||||
|
}
|
||||||
|
config: dict[str, Any] = {
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": agent_session_id,
|
||||||
|
"request_id": request_id,
|
||||||
|
"turn_id": turn_id,
|
||||||
|
},
|
||||||
|
"recursion_limit": 10_000,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await agent.ainvoke(input_state, config=config)
|
||||||
|
|
||||||
|
resumes = 0
|
||||||
|
while True:
|
||||||
|
state = await agent.aget_state(config)
|
||||||
|
if not getattr(state, "interrupts", None):
|
||||||
|
break
|
||||||
|
if resumes >= _MAX_RESUMES:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"agent_task exceeded {_MAX_RESUMES} HITL resume iterations"
|
||||||
|
)
|
||||||
|
lg_resume_map, routed = build_auto_decisions(state, decision)
|
||||||
|
config["configurable"]["surfsense_resume_value"] = routed
|
||||||
|
result = await agent.ainvoke(Command(resume=lg_resume_map), config=config)
|
||||||
|
resumes += 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"agent_session_id": agent_session_id,
|
||||||
|
"final_message": extract_final_assistant_message(result),
|
||||||
|
"resumes": resumes,
|
||||||
|
}
|
||||||
|
|
@ -2,31 +2,18 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from app.automations.actions.agent_task import build_handler
|
||||||
|
|
||||||
from app.automations.schemas.actions import AgentTaskActionParams
|
from app.automations.schemas.actions import AgentTaskActionParams
|
||||||
|
|
||||||
from .store import register_action
|
from .store import register_action
|
||||||
from .types import ActionContext, ActionDefinition, ActionHandler
|
from .types import ActionDefinition
|
||||||
|
|
||||||
|
|
||||||
def _build_handler(ctx: ActionContext) -> ActionHandler:
|
|
||||||
"""Bind run/session context to the agent_task handler. Real wiring lands in Phase 4b."""
|
|
||||||
del ctx # ignored by the stub; real handler will consume it
|
|
||||||
|
|
||||||
async def handle(params: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
AgentTaskActionParams.model_validate(params)
|
|
||||||
return {"status": "stubbed"}
|
|
||||||
|
|
||||||
return handle
|
|
||||||
|
|
||||||
|
|
||||||
AGENT_TASK_ACTION = ActionDefinition(
|
AGENT_TASK_ACTION = ActionDefinition(
|
||||||
type="agent_task",
|
type="agent_task",
|
||||||
name="Agent task",
|
name="Agent task",
|
||||||
description="Run an agent task with a scoped tool allowlist.",
|
description="Run a multi_agent_chat turn from an automation step.",
|
||||||
params_schema=AgentTaskActionParams.model_json_schema(),
|
params_schema=AgentTaskActionParams.model_json_schema(),
|
||||||
build_handler=_build_handler,
|
build_handler=build_handler,
|
||||||
)
|
)
|
||||||
|
|
||||||
register_action(AGENT_TASK_ACTION)
|
register_action(AGENT_TASK_ACTION)
|
||||||
|
|
|
||||||
|
|
@ -2,26 +2,20 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
|
||||||
class AgentTaskActionParams(BaseModel):
|
class AgentTaskActionParams(BaseModel):
|
||||||
"""Run an agent task with a scoped tool allowlist."""
|
"""Run a multi_agent_chat turn from an automation step."""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
prompt: str = Field(..., min_length=1, description="Task prompt; rendered at execute time.")
|
query: str = Field(
|
||||||
tools: list[str] = Field(
|
...,
|
||||||
default_factory=list,
|
min_length=1,
|
||||||
description="Tool identifiers the agent may call. Empty = no tool access.",
|
description="User query for the agent; rendered at execute time.",
|
||||||
)
|
)
|
||||||
model: str | None = Field(
|
auto_approve_all: bool = Field(
|
||||||
default=None,
|
default=False,
|
||||||
description="Model identifier. Defaults to the search space's agent_llm_id.",
|
description="If true, every HITL approval is auto-approved; otherwise rejected.",
|
||||||
)
|
|
||||||
output_schema: dict[str, Any] | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="JSON Schema (draft 2020-12) the agent must return. Recommended.",
|
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue