mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-02 19:55:18 +02:00
refactor(automations): move agent_task to builtin and restructure dispatch
This commit is contained in:
parent
f356e304e8
commit
30fff9e52f
22 changed files with 142 additions and 133 deletions
|
|
@ -0,0 +1,15 @@
|
|||
"""``agent_task`` action: spin up multi_agent_chat for one rendered query.
|
||||
|
||||
Imports ``definition`` for its side-effect (self-registration on the actions
|
||||
registry) and re-exports ``build_handler`` and ``AgentTaskActionParams`` for
direct consumers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .factory import build_handler
|
||||
from .params import AgentTaskActionParams
|
||||
|
||||
# Public surface of the package: the params model and the handler factory
# imported above.
__all__ = ["AgentTaskActionParams", "build_handler"]

# Side-effect: register on the actions store. Importing ``definition`` runs its
# module-level ``register_action(...)`` call; the name itself is unused here,
# hence the ``noqa``.
from . import definition  # noqa: F401
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
"""Synthesize HITL decisions for every pending interrupt (approve-all or reject-all)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def build_auto_decisions(
    state: Any, decision: str
) -> tuple[dict[str, dict[str, Any]], dict[str, dict[str, Any]]]:
    """Return ``(lg_resume_map, surfsense_resume_value)`` covering every pending interrupt.

    ``lg_resume_map`` is keyed by ``Interrupt.id`` for ``Command(resume=...)``;
    ``surfsense_resume_value`` is keyed by ``tool_call_id`` for the subagent
    middleware bridge. Action count is read from ``value.action_requests`` when
    present and falls back to ``1`` for wrapped scalar interrupts.

    Args:
        state: Object exposing an ``interrupts`` iterable (missing or ``None``
            is treated as "no interrupts").
        decision: Decision type written into every synthesized entry, e.g.
            ``"approve"`` or ``"reject"``.

    Returns:
        A pair of dicts. Each value has the shape
        ``{"decisions": [{"type": decision}, ...]}``. Interrupts whose
        ``value`` is not a dict or whose ``id`` is not a string are skipped;
        an interrupt only appears in the second dict when its value carries a
        string ``tool_call_id``.
    """

    def _payload(count: int) -> dict[str, Any]:
        # Build a fresh list of fresh dicts on every call so the two returned
        # maps never share a mutable payload: a consumer that pops from or
        # edits one map's decisions cannot silently corrupt the other's.
        return {"decisions": [{"type": decision} for _ in range(count)]}

    lg_resume_map: dict[str, dict[str, Any]] = {}
    routed: dict[str, dict[str, Any]] = {}

    for interrupt_obj in getattr(state, "interrupts", ()) or ():
        value = getattr(interrupt_obj, "value", None)
        if not isinstance(value, dict):
            continue
        interrupt_id = getattr(interrupt_obj, "id", None)
        if not isinstance(interrupt_id, str):
            continue

        # One decision per requested action; a dict without a list of
        # ``action_requests`` is answered with a single decision.
        action_requests = value.get("action_requests")
        count = len(action_requests) if isinstance(action_requests, list) else 1

        lg_resume_map[interrupt_id] = _payload(count)

        tool_call_id = value.get("tool_call_id")
        if isinstance(tool_call_id, str):
            routed[tool_call_id] = _payload(count)

    return lg_resume_map, routed
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
"""``agent_task`` ``ActionDefinition`` registration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ...store import register_action
|
||||
from ...types import ActionDefinition
|
||||
from .factory import build_handler
|
||||
from .params import AgentTaskActionParams
|
||||
|
||||
# Declarative description of the ``agent_task`` action: its type key, display
# name and description, the pydantic model used for its params, and the
# factory that builds a handler for it.
AGENT_TASK_ACTION = ActionDefinition(
    type="agent_task",
    name="Agent task",
    description="Run a multi_agent_chat turn from an automation step.",
    params_model=AgentTaskActionParams,
    build_handler=build_handler,
)

# Runs at import time: importing this module is what registers the action.
register_action(AGENT_TASK_ACTION)
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
"""Build the per-invocation dependencies the multi_agent_chat factory needs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle
|
||||
from app.tasks.chat.streaming.flows.shared.pre_stream_setup import (
|
||||
setup_connector_and_firecrawl,
|
||||
)
|
||||
|
||||
|
||||
class DependencyError(Exception):
    """An external dependency (LLM config, connector service, ...) refused to load.

    Raised by ``build_dependencies`` when ``load_llm_bundle`` reports an error
    or returns no LLM; the exception message is that error when one is given,
    otherwise a generic "failed to load default LLM config" text.
    """
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class AgentDependencies:
    """Everything ``create_multi_agent_chat_deep_agent`` needs from the environment.

    Immutable (frozen, slotted) bundle assembled by ``build_dependencies`` and
    unpacked field by field when the agent is created.
    """

    # LLM object returned by ``load_llm_bundle``.
    llm: Any
    # Agent configuration returned alongside the LLM by ``load_llm_bundle``.
    agent_config: Any
    # Connector service returned by ``setup_connector_and_firecrawl``.
    connector_service: Any
    # Firecrawl API key from ``setup_connector_and_firecrawl``; may be ``None``.
    firecrawl_api_key: str | None
    # LangGraph checkpointer; ``build_dependencies`` supplies an ``InMemorySaver``.
    checkpointer: Any
|
||||
|
||||
|
||||
async def build_dependencies(
    *,
    session: AsyncSession,
    search_space_id: int,
) -> AgentDependencies:
    """Assemble the LLM bundle, connector service, and a throwaway checkpointer.

    The LLM is loaded with ``config_id=-1`` (the search space's default LLM
    config). Per-step model overrides are planned for a later iteration
    together with a ``model`` param.

    Args:
        session: Async DB session passed to both loaders.
        search_space_id: Search space whose configuration is loaded.

    Returns:
        A populated ``AgentDependencies`` with a fresh ``InMemorySaver``.

    Raises:
        DependencyError: The LLM loader reported an error or returned no LLM.
    """
    bundle = await load_llm_bundle(
        session, config_id=-1, search_space_id=search_space_id
    )
    llm, agent_config, err = bundle
    if llm is None or err is not None:
        raise DependencyError(err if err else "failed to load default LLM config")

    service, api_key = await setup_connector_and_firecrawl(
        session, search_space_id=search_space_id
    )

    # Stop-gap: automation runs get an in-memory checkpointer.
    #
    # The shared Postgres checkpointer keeps its DB connections in a
    # module-level pool, and every pooled connection belongs to the asyncio
    # loop that created it. Celery discards the loop when a task finishes, so
    # the pool fills with connections tied to a dead loop; the following task
    # (on a new loop) cannot use them, waits 30s, and dies with
    # `PoolTimeout: couldn't get a connection after 30.00 sec`.
    #
    # InMemorySaver holds no connections and is not bound to a loop — one is
    # created per Celery task and discarded when the task ends.
    #
    # TODO(checkpointer): the real fix is to dispose the checkpointer pool
    # around each Celery task in `run_async_celery_task`, as
    # `_dispose_shared_db_engine` already does for the SQLAlchemy pool; after
    # that this call site can return to the shared checkpointer.
    return AgentDependencies(
        llm=llm,
        agent_config=agent_config,
        connector_service=service,
        firecrawl_api_key=api_key,
        checkpointer=InMemorySaver(),
    )
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
"""Bind ``ActionContext`` to a callable that runs one ``agent_task`` step."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ...types import ActionContext, ActionHandler
|
||||
from .invoke import run_agent_task
|
||||
from .params import AgentTaskActionParams
|
||||
|
||||
|
||||
def build_handler(ctx: ActionContext) -> ActionHandler:
    """Create the ``agent_task`` handler bound to ``ctx``.

    The returned coroutine function takes the raw params dict, validates it
    against ``AgentTaskActionParams``, and forwards the validated fields to
    ``run_agent_task`` together with the captured context.
    """

    async def handle(params: dict[str, Any]) -> dict[str, Any]:
        task_params = AgentTaskActionParams.model_validate(params)
        # Collect the validated fields first, then hand them over in one call.
        forwarded = {
            "query": task_params.query,
            "auto_approve_all": task_params.auto_approve_all,
            "mentioned_document_ids": task_params.mentioned_document_ids,
            "mentioned_folder_ids": task_params.mentioned_folder_ids,
            "mentioned_connector_ids": task_params.mentioned_connector_ids,
            "mentioned_connectors": task_params.mentioned_connectors,
            "mentioned_documents": task_params.mentioned_documents,
        }
        outcome = await run_agent_task(ctx=ctx, **forwarded)
        return outcome

    return handle
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
"""Extract the agent's final assistant text from the terminal invoke result."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
def extract_final_assistant_message(result: Any) -> str | None:
    """Pull the text of the most recent ``AIMessage`` out of an invoke result.

    ``result`` must be a dict whose ``"messages"`` entry is a list; anything
    else yields ``None``. The list is scanned from the end and only the first
    ``AIMessage`` found is considered. List-shaped content is flattened by
    joining its text parts in order; non-text parts (tool calls, images) are
    ignored. ``None`` is returned when no ``AIMessage`` exists or its text is
    empty.
    """
    messages = result.get("messages") if isinstance(result, dict) else None
    if not isinstance(messages, list):
        return None

    latest_ai = next(
        (candidate for candidate in reversed(messages) if isinstance(candidate, AIMessage)),
        None,
    )
    if latest_ai is None:
        return None
    return _content_to_text(latest_ai.content)
|
||||
|
||||
|
||||
def _content_to_text(content: Any) -> str | None:
|
||||
if isinstance(content, str):
|
||||
text = content.strip()
|
||||
return text or None
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, str):
|
||||
parts.append(part)
|
||||
elif isinstance(part, dict) and part.get("type") == "text":
|
||||
text = part.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
joined = "".join(parts).strip()
|
||||
return joined or None
|
||||
return None
|
||||
|
|
@ -0,0 +1,223 @@
|
|||
"""Run one ``agent_task`` invocation: ainvoke + auto-decision resume loop."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import Command
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text
|
||||
from app.db import ChatVisibility, async_session_maker
|
||||
from app.schemas.new_chat import MentionedDocumentInfo
|
||||
|
||||
from ...types import ActionContext
|
||||
from .auto_decide import build_auto_decisions
|
||||
from .dependencies import build_dependencies
|
||||
from .finalize import extract_final_assistant_message
|
||||
|
||||
# Cap on HITL resume iterations. The agent should not need this many turns in one
# step; treat overshoot as a runaway and fail the step. ``run_agent_task`` raises
# ``RuntimeError`` when interrupts are still pending after this many resumes.
_MAX_RESUMES = 50
|
||||
|
||||
|
||||
def _build_connector_block(connectors: list[dict[str, Any]]) -> str | None:
|
||||
"""Render the ``<mentioned_connectors>`` context block (same shape as chat).
|
||||
|
||||
Mirrors ``stream_new_chat`` so the agent gets the exact connector accounts
|
||||
the user picked. Returns ``None`` when nothing renders.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
for connector in connectors:
|
||||
connector_id = connector.get("id")
|
||||
connector_type = connector.get("connector_type") or connector.get(
|
||||
"document_type"
|
||||
)
|
||||
account_name = connector.get("account_name") or connector.get("title")
|
||||
if connector_id is None or connector_type is None:
|
||||
continue
|
||||
lines.append(
|
||||
f' - connector_id={connector_id}, connector_type="{connector_type}", '
|
||||
f'account_name="{account_name or ""}"'
|
||||
)
|
||||
if not lines:
|
||||
return None
|
||||
return (
|
||||
"<mentioned_connectors>\n"
|
||||
"The user selected these exact connector accounts with @. "
|
||||
"These entries are selection metadata, not retrieved connector content. "
|
||||
"When a connector-backed tool needs an account, use the matching "
|
||||
"connector_id from this list if the tool supports connector_id:\n"
|
||||
+ "\n".join(lines)
|
||||
+ "\n</mentioned_connectors>"
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_mention_context(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
search_space_id: int,
|
||||
query: str,
|
||||
mentioned_document_ids: list[int] | None,
|
||||
mentioned_folder_ids: list[int] | None,
|
||||
mentioned_connector_ids: list[int] | None,
|
||||
mentioned_connectors: list[MentionedDocumentInfo] | None,
|
||||
mentioned_documents: list[MentionedDocumentInfo] | None,
|
||||
) -> tuple[str, SurfSenseContextSchema | None]:
|
||||
"""Resolve @-mentions into a rewritten query + per-invocation context.
|
||||
|
||||
Automation always runs in cloud filesystem mode, so we mirror the chat
|
||||
``new_chat`` flow: substitute ``@title`` tokens with canonical
|
||||
``/documents/...`` paths, prepend a ``<mentioned_connectors>`` block, and
|
||||
build a ``SurfSenseContextSchema`` that ``KnowledgePriorityMiddleware``
|
||||
reads via ``runtime.context``. Returns ``(query, None)`` unchanged when
|
||||
there are no mentions.
|
||||
"""
|
||||
has_mentions = bool(
|
||||
mentioned_document_ids
|
||||
or mentioned_folder_ids
|
||||
or mentioned_connector_ids
|
||||
or mentioned_connectors
|
||||
or mentioned_documents
|
||||
)
|
||||
if not has_mentions:
|
||||
return query, None
|
||||
|
||||
resolved = await resolve_mentions(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
mentioned_documents=mentioned_documents,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
mentioned_folder_ids=mentioned_folder_ids,
|
||||
)
|
||||
agent_query = substitute_in_text(query, resolved.token_to_path)
|
||||
|
||||
# ``SurfSenseContextSchema.mentioned_connectors`` is typed ``list[dict]`` and
|
||||
# the connector block reads dicts, so dump the pydantic chips once.
|
||||
connector_dicts = [c.model_dump() for c in (mentioned_connectors or [])]
|
||||
connector_block = _build_connector_block(connector_dicts)
|
||||
if connector_block:
|
||||
agent_query = f"{connector_block}\n\n<user_query>{agent_query}</user_query>"
|
||||
|
||||
runtime_context = SurfSenseContextSchema(
|
||||
search_space_id=search_space_id,
|
||||
mentioned_document_ids=list(
|
||||
resolved.mentioned_document_ids or (mentioned_document_ids or [])
|
||||
),
|
||||
mentioned_folder_ids=list(
|
||||
resolved.mentioned_folder_ids or (mentioned_folder_ids or [])
|
||||
),
|
||||
mentioned_connector_ids=list(mentioned_connector_ids or []),
|
||||
mentioned_connectors=connector_dicts,
|
||||
)
|
||||
return agent_query, runtime_context
|
||||
|
||||
|
||||
async def run_agent_task(
    *,
    ctx: ActionContext,
    query: str,
    auto_approve_all: bool,
    mentioned_document_ids: list[int] | None = None,
    mentioned_folder_ids: list[int] | None = None,
    mentioned_connector_ids: list[int] | None = None,
    mentioned_connectors: list[MentionedDocumentInfo] | None = None,
    mentioned_documents: list[MentionedDocumentInfo] | None = None,
) -> dict[str, Any]:
    """Invoke multi_agent_chat for one rendered query and return its outcome.

    Opens its own DB session so the executor's bookkeeping session isn't tied
    up for the entire invocation. The LangGraph ``thread_id`` (a fresh UUID)
    is returned as ``agent_session_id`` for later inspection.

    @-mentions (files / folders / connectors) chosen in the task input are
    resolved the same way the chat flow does and forwarded to the agent via the
    per-invocation ``context`` so they actually scope retrieval.

    Args:
        ctx: Action context; ``search_space_id``, ``creator_user_id``,
            ``run_id`` and ``step_id`` are read from it.
        query: Rendered user query for the agent.
        auto_approve_all: When true every pending interrupt is answered with
            ``"approve"``; otherwise with ``"reject"``.
        mentioned_document_ids: Optional @-mentioned document IDs.
        mentioned_folder_ids: Optional @-mentioned folder IDs.
        mentioned_connector_ids: Optional @-mentioned connector account IDs.
        mentioned_connectors: Optional chip metadata for mentioned connectors.
        mentioned_documents: Optional chip metadata for mentioned documents.

    Returns:
        A dict with ``agent_session_id`` (the generated thread UUID),
        ``final_message`` (text of the last assistant message, or ``None``)
        and ``resumes`` (number of auto-decision resumes performed).

    Raises:
        RuntimeError: Interrupts are still pending after ``_MAX_RESUMES``
            resume iterations.
    """
    agent_session_id = str(uuid.uuid4())
    user_id = str(ctx.creator_user_id) if ctx.creator_user_id else None
    # A single decision type is applied to every interrupt of this run.
    decision = "approve" if auto_approve_all else "reject"

    async with async_session_maker() as agent_session:
        deps = await build_dependencies(
            session=agent_session,
            search_space_id=ctx.search_space_id,
        )

        # NOTE(review): ``thread_id=None`` here while the LangGraph thread id is
        # supplied through ``config`` below — presumably this parameter refers
        # to a persisted chat thread; confirm against the factory.
        agent = await create_multi_agent_chat_deep_agent(
            llm=deps.llm,
            search_space_id=ctx.search_space_id,
            db_session=agent_session,
            connector_service=deps.connector_service,
            checkpointer=deps.checkpointer,
            user_id=user_id,
            thread_id=None,
            agent_config=deps.agent_config,
            firecrawl_api_key=deps.firecrawl_api_key,
            thread_visibility=ChatVisibility.PRIVATE,
            mentioned_document_ids=mentioned_document_ids,
        )

        agent_query, runtime_context = await _resolve_mention_context(
            agent_session,
            search_space_id=ctx.search_space_id,
            query=query,
            mentioned_document_ids=mentioned_document_ids,
            mentioned_folder_ids=mentioned_folder_ids,
            mentioned_connector_ids=mentioned_connector_ids,
            mentioned_connectors=mentioned_connectors,
            mentioned_documents=mentioned_documents,
        )

        # Correlation ids: the request id is derived from run/step, and the
        # turn id adds the current wall-clock time in milliseconds.
        request_id = f"automation:{ctx.run_id}:{ctx.step_id}"
        turn_id = f"{request_id}:{int(time.time() * 1000)}"
        input_state: dict[str, Any] = {
            "messages": [HumanMessage(content=agent_query)],
            "search_space_id": ctx.search_space_id,
            "request_id": request_id,
            "turn_id": turn_id,
        }
        config: dict[str, Any] = {
            "configurable": {
                "thread_id": agent_session_id,
                "request_id": request_id,
                "turn_id": turn_id,
            },
            "recursion_limit": 10_000,
        }
        if runtime_context is not None:
            runtime_context.request_id = request_id
            runtime_context.turn_id = turn_id

        # The compiled graph declares ``context_schema=SurfSenseContextSchema``;
        # mentions only reach ``KnowledgePriorityMiddleware`` via ``context=``.
        invoke_kwargs: dict[str, Any] = {"config": config}
        if runtime_context is not None:
            invoke_kwargs["context"] = runtime_context

        result = await agent.ainvoke(input_state, **invoke_kwargs)

        # Auto-decision loop: while the checkpointed state still reports
        # pending interrupts, answer all of them with ``decision`` and resume.
        resumes = 0
        while True:
            state = await agent.aget_state(config)
            if not getattr(state, "interrupts", None):
                break
            if resumes >= _MAX_RESUMES:
                raise RuntimeError(
                    f"agent_task exceeded {_MAX_RESUMES} HITL resume iterations"
                )
            lg_resume_map, routed = build_auto_decisions(state, decision)
            # ``invoke_kwargs["config"]`` is this same dict object, so the
            # value set here is visible to the ``ainvoke`` call below.
            config["configurable"]["surfsense_resume_value"] = routed
            result = await agent.ainvoke(Command(resume=lg_resume_map), **invoke_kwargs)
            resumes += 1

        return {
            "agent_session_id": agent_session_id,
            "final_message": extract_final_assistant_message(result),
            "resumes": resumes,
        }
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
"""``AgentTaskActionParams`` — params for the ``agent_task`` action type."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.schemas.new_chat import MentionedDocumentInfo
|
||||
|
||||
|
||||
class AgentTaskActionParams(BaseModel):
    """Run a multi_agent_chat turn from an automation step.

    Validated params for the ``agent_task`` action: the query, the
    auto-approval switch for HITL interrupts, and optional @-mention
    references. Only ``query`` is required.
    """

    # Unknown keys are rejected at validation time rather than ignored.
    model_config = ConfigDict(extra="forbid")

    query: str = Field(
        ...,
        min_length=1,
        description="User query for the agent; rendered at execute time.",
    )
    auto_approve_all: bool = Field(
        default=False,
        description="If true, every HITL approval is auto-approved; otherwise rejected.",
    )

    # @-mention references chosen in the task input. Mirror the ``new_chat``
    # request fields (minus SurfSense product docs) so the run can scope
    # retrieval to the user's selected files / folders / connectors. All
    # optional and additive; a task with no mentions behaves as before.
    mentioned_document_ids: list[int] | None = Field(
        default=None,
        description="Knowledge-base document IDs the task references with @.",
    )
    mentioned_folder_ids: list[int] | None = Field(
        default=None,
        description="Knowledge-base folder IDs the task references with @.",
    )
    mentioned_connector_ids: list[int] | None = Field(
        default=None,
        description="Concrete connector account IDs the task references with @.",
    )
    # Chip metadata for connectors and documents uses the same
    # ``MentionedDocumentInfo`` schema as the chat request.
    mentioned_connectors: list[MentionedDocumentInfo] | None = Field(
        default=None,
        description="Display/context metadata for the @-mentioned connector accounts.",
    )
    mentioned_documents: list[MentionedDocumentInfo] | None = Field(
        default=None,
        description=(
            "Chip metadata (id, title, kind, ...) for every @-mention so the "
            "run can resolve titles to virtual paths and substitute them in "
            "the query."
        ),
    )
|
||||
Loading…
Add table
Add a link
Reference in a new issue