mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
Replace the connector-coupled BUILTIN_TOOLS registry with a pure-data catalog so shared/tools no longer imports any connector module, making the connector packages independently deletable. - add shared/tools/catalog.py (ToolMetadata + TOOL_CATALOG, 41 tools, no imports) - point GET /agent/tools (the only live consumer) at the catalog - relocate ToolDefinition into action_log middleware (its sole consumer); drop the inert tool_definitions wiring (no tool defines reverse) - delete shared/tools/registry.py: connector imports, dead factories, dead get_connector_gated_tools, and BUILTIN_TOOLS - drop stale dedup-propagation test (path removed in C1) + refresh docstrings import-all guardrail + agents unit suite green (987 passed).
388 lines
14 KiB
Python
388 lines
14 KiB
Python
"""Append-only action-log middleware for the SurfSense agent.
|
|
|
|
Wraps every tool call via :meth:`AgentMiddleware.awrap_tool_call` and writes
|
|
a row to :class:`~app.db.AgentActionLog` after the tool returns. Tools opt
|
|
into reversibility by declaring a ``reverse`` callable on their
|
|
:class:`ToolDefinition`; the rendered descriptor is persisted in
|
|
``reverse_descriptor`` for use by
|
|
``/api/threads/{thread_id}/revert/{action_id}``.
|
|
|
|
Design points:
|
|
|
|
* **Defensive.** Logging never blocks the agent. We catch every exception
|
|
on the DB write path and emit a warning; the tool's ``ToolMessage``
|
|
result is always returned untouched.
|
|
* **Lightweight payload.** Only the tool ``name`` + ``args`` (capped) +
|
|
``result_id`` + ``reverse_descriptor`` are stored. Tool output text
|
|
remains in the LangGraph checkpoint / spilled tool-output files.
|
|
* **Best-effort reversibility.** We invoke ``reverse(args, result_obj)``
  with the parsed JSON value when the tool's content is a string holding
  valid JSON (object, array, or scalar); otherwise the content is passed
  through unchanged. Exceptions in the reverse callable are swallowed and
  logged — a failed descriptor render (or a non-dict descriptor) simply
  means the action is NOT marked reversible.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from collections.abc import Awaitable, Callable
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from langchain.agents.middleware import AgentMiddleware
|
|
from langchain_core.callbacks import adispatch_custom_event
|
|
from langchain_core.messages import ToolMessage
|
|
|
|
from app.agents.shared.feature_flags import get_flags
|
|
|
|
if TYPE_CHECKING: # pragma: no cover - type-only
|
|
from langchain.agents.middleware.types import ToolCallRequest
|
|
from langgraph.types import Command
|
|
|
|
|
|
# Module-level logger: reports persistence, event-dispatch, and
# reverse-descriptor rendering failures without interrupting tool calls.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class ToolDefinition:
    """Declarative record for a tool and, optionally, how to undo its effect.

    :class:`ActionLogMiddleware` reads only ``name`` and ``reverse``. The
    other fields exist so call sites and tests can describe a tool in one
    place. An action is logged as reversible only when ``reverse`` is set and
    renders a descriptor without raising.
    """

    # Unique identifier for the tool.
    name: str
    # Human-readable summary of what the tool does.
    description: str = ""
    # Optional builder for the tool; not consulted by the middleware.
    factory: Callable[[dict[str, Any]], Any] | None = None
    # Optional callable mapping the tool's ``(args, result)`` to a
    # descriptor dict for the inverse invocation.
    reverse: Callable[[dict[str, Any], Any], dict[str, Any]] | None = None
|
|
|
|
|
|
# Cap for the persisted ``args`` JSON to avoid bloating the action log with
|
|
# accidentally-huge inputs. Values are truncated and a flag is set in the
|
|
# stored payload so consumers can detect truncation.
|
|
# NOTE: ``_resolve_args_payload`` compares this against ``len()`` of the
# JSON-encoded args string, i.e. a character count of the encoded text.
_MAX_ARGS_PERSIST_BYTES = 32 * 1024  # 32 KiB
|
|
|
|
|
|
class ActionLogMiddleware(AgentMiddleware):
    """Write one :class:`AgentActionLog` row for every tool call it wraps.

    Placement: keep this middleware near the OUTERMOST end of the tool-call
    wrapping stack so the :class:`ToolMessage` it observes is the *final*
    one, after retries, permission checks, and dedup logic have run. In
    practice: just inside :class:`PermissionMiddleware` and outside
    :class:`DedupHITLToolCallsMiddleware`.

    The middleware only delegates to the wrapped handler (no logging) when:

    * the master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set
      (read through :func:`get_flags` as ``disable_new_agent_stack``),
    * the per-feature flag ``enable_action_log`` is off, or
    * ``thread_id`` is ``None``.

    Persistence failures are logged and swallowed, so they never change the
    outcome of the tool call itself.

    Args:
        thread_id: Primary-key id of the current chat thread. When ``None``
            the middleware silently no-ops.
        search_space_id: Search-space id stored on each row.
        user_id: UUID string of the user driving this turn (nullable in
            anonymous mode).
        tool_definitions: Optional mapping of tool name ->
            :class:`ToolDefinition`, used to look up a tool's ``reverse``
            callable. When omitted, no actions are marked reversible.
    """

    tools = ()

    def __init__(
        self,
        *,
        thread_id: int | None,
        search_space_id: int,
        user_id: str | None,
        tool_definitions: dict[str, ToolDefinition] | None = None,
    ) -> None:
        super().__init__()
        self._thread_id = thread_id
        self._search_space_id = search_space_id
        self._user_id = user_id
        # Private copy so later mutation of the caller's mapping has no effect.
        self._tool_definitions = dict(tool_definitions or {})

    def _enabled(self) -> bool:
        """Return ``True`` when feature flags and thread id allow logging."""
        flags = get_flags()
        if flags.disable_new_agent_stack:
            return False
        if not flags.enable_action_log:
            return False
        return self._thread_id is not None

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
    ) -> ToolMessage | Command[Any]:
        """Run ``handler`` and record its outcome in the action log.

        A handler exception is recorded (type name + message) and then
        re-raised unchanged so outer middleware can still handle it.
        """
        if not self._enabled():
            return await handler(request)

        try:
            outcome = await handler(request)
        except Exception as exc:
            # Log the failure for revert/audit, then let it propagate.
            failure = {"type": type(exc).__name__, "message": str(exc)}
            await self._record(
                request=request,
                result=None,
                error_payload=failure,
            )
            raise

        await self._record(request=request, result=outcome, error_payload=None)
        return outcome

    async def _record(
        self,
        *,
        request: ToolCallRequest,
        result: ToolMessage | Command[Any] | None,
        error_payload: dict[str, Any] | None,
    ) -> None:
        """Persist one ``agent_action_log`` row, then announce it.

        Defensive by design: a failure while building or committing the row
        is logged at WARNING and ends the method; a failure while
        dispatching the follow-up event is logged at DEBUG. Nothing is
        raised to the caller.
        """
        try:
            # Imported lazily; an import failure is handled like any other
            # persistence failure below.
            from app.db import AgentActionLog, shielded_async_session

            tool_name = _resolve_tool_name(request)
            stored_args = _resolve_args_payload(request)
            result_id = _resolve_result_id(result)
            reverse_descriptor, reversible = self._render_reverse(
                tool_name=tool_name,
                args=_resolve_args_dict(request),
                result=result,
            )
            tool_call_id = _resolve_tool_call_id(request)
            chat_turn_id = _resolve_chat_turn_id(request)

            entry = AgentActionLog(
                thread_id=self._thread_id,
                user_id=self._user_id,
                search_space_id=self._search_space_id,
                # ``turn_id`` is the deprecated alias of ``tool_call_id``,
                # kept for one release for safe rollback. New consumers
                # should read ``tool_call_id`` directly.
                turn_id=tool_call_id,
                tool_call_id=tool_call_id,
                chat_turn_id=chat_turn_id,
                message_id=_resolve_message_id(request),
                tool_name=tool_name,
                args=stored_args,
                result_id=result_id,
                reversible=reversible,
                reverse_descriptor=reverse_descriptor,
                error=error_payload,
            )
            async with shielded_async_session() as session:
                session.add(entry)
                await session.commit()
                persisted_id = entry.id
                row_id = None if persisted_id is None else int(persisted_id)
                row_created_at = entry.created_at
        except Exception:
            logger.warning(
                "ActionLogMiddleware failed to persist action log row",
                exc_info=True,
            )
            return

        # Side-channel event so the chat tool card can render a Revert
        # button as soon as the row is durable. ``stream_new_chat``
        # translates it into a ``data-action-log`` SSE event. The
        # ``reverse_descriptor`` payload itself is deliberately NOT sent,
        # only a presence flag.
        try:
            event_payload = {
                "id": row_id,
                "lc_tool_call_id": tool_call_id,
                "chat_turn_id": chat_turn_id,
                "tool_name": tool_name,
                "reversible": bool(reversible),
                "reverse_descriptor_present": reverse_descriptor is not None,
                "created_at": row_created_at.isoformat()
                if row_created_at
                else None,
                "error": error_payload is not None,
            }
            await adispatch_custom_event("action_log", event_payload)
        except Exception:
            logger.debug(
                "ActionLogMiddleware failed to dispatch action_log event",
                exc_info=True,
            )

    def _render_reverse(
        self,
        *,
        tool_name: str,
        args: dict[str, Any] | None,
        result: ToolMessage | Command[Any] | None,
    ) -> tuple[dict[str, Any] | None, bool]:
        """Render the reverse descriptor for a finished tool call.

        Returns ``(descriptor, True)`` only when the result is a
        :class:`ToolMessage`, ``args`` is available, the tool has a
        registered ``reverse`` callable, and that callable returns a dict
        without raising. Every other case yields ``(None, False)``.
        """
        if not (result and isinstance(result, ToolMessage)):
            return None, False
        if args is None:
            return None, False

        definition = self._tool_definitions.get(tool_name)
        reverse_fn = None if definition is None else definition.reverse
        if reverse_fn is None:
            return None, False

        try:
            descriptor = reverse_fn(args, _parse_tool_result_content(result))
        except Exception:
            logger.warning(
                "Reverse descriptor render failed for tool %s",
                tool_name,
                exc_info=True,
            )
            return None, False

        if isinstance(descriptor, dict):
            return descriptor, True
        return None, False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Resolution helpers — defensive against tool_call request shape variation.
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _resolve_tool_name(request: Any) -> str:
    """Return the tool's name for ``request``, or ``"unknown"``.

    Prefers ``request.tool.name``; falls back to ``request.tool_call["name"]``.
    Only non-empty strings are accepted from either source.
    """
    try:
        bound_tool = getattr(request, "tool", None)
        from_tool = None if bound_tool is None else getattr(bound_tool, "name", None)
        if isinstance(from_tool, str) and from_tool:
            return from_tool

        payload = getattr(request, "tool_call", None) or {}
        from_call = payload.get("name") if isinstance(payload, dict) else None
        if isinstance(from_call, str) and from_call:
            return from_call
    except Exception:  # pragma: no cover - defensive
        pass
    return "unknown"
|
|
|
|
|
|
def _resolve_args_dict(request: Any) -> dict[str, Any] | None:
    """Return ``request.tool_call["args"]`` when it is a dict, else ``None``.

    The dict is returned as-is (not copied). ``None`` is also returned when
    ``tool_call`` is missing or is not a dict.
    """
    try:
        payload = getattr(request, "tool_call", None)
        candidate = payload.get("args") if isinstance(payload, dict) else None
        return candidate if isinstance(candidate, dict) else None
    except Exception:  # pragma: no cover - defensive
        return None
|
|
|
|
|
|
def _resolve_args_payload(request: Any) -> dict[str, Any] | None:
    """Return a JSON-serializable form of the tool-call args, size-capped.

    Returns:
        * ``None`` when the request carries no args dict.
        * ``{"_repr": ...}`` (capped ``repr`` of the args) when the args
          cannot be JSON-encoded even with ``default=str``.
        * ``{"_truncated": True, "_size": ..., "_preview": ...}`` when the
          encoded form exceeds ``_MAX_ARGS_PERSIST_BYTES``.
        * Otherwise the args after a JSON round-trip. Values that are not
          JSON-native (stringified via ``default=str``) come back as plain
          strings, so the returned dict really is JSON-serializable. For
          args made only of JSON-native values with string keys, the result
          is equal to the input.
    """
    args = _resolve_args_dict(request)
    if args is None:
        return None

    try:
        encoded = json.dumps(args, default=str)
    except Exception:
        return {"_repr": repr(args)[:_MAX_ARGS_PERSIST_BYTES]}

    if len(encoded) > _MAX_ARGS_PERSIST_BYTES:
        return {
            "_truncated": True,
            "_size": len(encoded),
            "_preview": encoded[:_MAX_ARGS_PERSIST_BYTES],
        }

    # Return the normalised form, not the original ``args``: the encode
    # above may only have succeeded thanks to ``default=str``, in which case
    # the original dict still holds values a plain JSON encoder rejects.
    return json.loads(encoded)
|
|
|
|
|
|
def _resolve_tool_call_id(request: Any) -> str | None:
    """Return ``request.tool_call["id"]`` when it is a string, else ``None``."""
    try:
        payload = getattr(request, "tool_call", None) or {}
        candidate = payload.get("id") if isinstance(payload, dict) else None
        if isinstance(candidate, str):
            return candidate
    except Exception:  # pragma: no cover
        pass
    return None
|
|
|
|
|
|
# Deprecated alias kept for one release. Old callers and tests treated
|
|
# ``turn_id`` as if it carried the LangChain tool_call id; the new column
|
|
# lives under ``tool_call_id``. Both resolve to the same value today.
|
|
_resolve_turn_id = _resolve_tool_call_id  # deprecated alias; same function object
|
|
|
|
|
|
def _resolve_chat_turn_id(request: Any) -> str | None:
    """Return the chat-turn correlation id for ``request``, if reachable.

    The id is read from ``request.runtime.config["configurable"]["turn_id"]``
    (``ToolRuntime.config`` is exposed by LangGraph, see
    ``langgraph/prebuilt/tool_node.py``). ``None`` is returned when any link
    in that chain is missing or of the wrong type, or when the value is not
    a non-empty string.
    """
    try:
        config = getattr(getattr(request, "runtime", None), "config", None)
        configurable = config.get("configurable") if isinstance(config, dict) else None
        turn_id = (
            configurable.get("turn_id") if isinstance(configurable, dict) else None
        )
        if isinstance(turn_id, str) and turn_id:
            return turn_id
    except Exception:  # pragma: no cover - defensive
        pass
    return None
|
|
|
|
|
|
def _resolve_message_id(request: Any) -> str | None:
    """Return the message correlator for ``request``.

    At this layer the tool-call id is the best correlator available, so this
    simply delegates to :func:`_resolve_tool_call_id`.
    """
    message_id = _resolve_tool_call_id(request)
    return message_id
|
|
|
|
|
|
def _resolve_result_id(result: Any) -> str | None:
    """Return ``result.id`` when ``result`` is a :class:`ToolMessage` with a string id.

    Any other result type, or a non-string id, yields ``None``.
    """
    if not isinstance(result, ToolMessage):
        return None
    candidate = getattr(result, "id", None)
    return candidate if isinstance(candidate, str) else None
|
|
|
|
|
|
def _parse_tool_result_content(result: ToolMessage) -> Any:
    """Return the tool message content, JSON-decoded when possible.

    String content that parses as JSON is returned as the decoded value
    (object, array, or scalar). String content that does not parse, and any
    non-string content, is returned unchanged.
    """
    raw = result.content
    if not isinstance(raw, str):
        return raw
    try:
        return json.loads(raw)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return raw
|
|
|
|
|
|
# ``ToolDefinition`` is exported alongside the middleware because callers
# need it to build the ``tool_definitions`` mapping the middleware accepts.
__all__ = ["ActionLogMiddleware", "ToolDefinition"]
|