mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-08 20:25:19 +02:00
refactor(agents): colocate 8 main-agent-only middleware as per-concept folders
Each main-agent-only middleware now lives in its own folder under main_agent/middleware/<concept>/ with builder.py (flag-gated construction) + middleware.py (the impl), re-exported via __init__.py. This kills the cross-folder hop into agents/shared/middleware and keeps each middleware's two responsibilities (build vs behavior) as colocated siblings. Moved (impl from shared/middleware, builder from main_agent/middleware): action_log, anonymous_document, context_editing, doom_loop, knowledge_tree, noop_injection, otel_span, tool_call_repair. Impls moved verbatim (git rename, no body edits) so behavior is unchanged. Builders now import from the local .middleware sibling. stack.py import paths updated for the 3 renamed folders; shared middleware barrel trimmed; tests repointed (imports + patch targets).
This commit is contained in:
parent
fbd5ccc35a
commit
9493519c61
33 changed files with 149 additions and 83 deletions
|
|
@ -1,23 +1,10 @@
|
|||
"""Middleware components for the SurfSense new chat agent."""
|
||||
"""Shared middleware components for the SurfSense chat agents."""
|
||||
|
||||
from app.agents.shared.middleware.action_log import (
|
||||
ActionLogMiddleware,
|
||||
ToolDefinition,
|
||||
)
|
||||
from app.agents.shared.middleware.anonymous_document import (
|
||||
AnonymousDocumentMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.busy_mutex import BusyMutexMiddleware
|
||||
from app.agents.shared.middleware.compaction import (
|
||||
SurfSenseCompactionMiddleware,
|
||||
create_surfsense_compaction_middleware,
|
||||
)
|
||||
from app.agents.shared.middleware.context_editing import (
|
||||
ClearToolUsesEdit,
|
||||
SpillingContextEditingMiddleware,
|
||||
SpillToBackendEdit,
|
||||
)
|
||||
from app.agents.shared.middleware.doom_loop import DoomLoopMiddleware
|
||||
from app.agents.shared.middleware.kb_persistence import (
|
||||
KnowledgeBasePersistenceMiddleware,
|
||||
commit_staged_filesystem_state,
|
||||
|
|
@ -25,39 +12,20 @@ from app.agents.shared.middleware.kb_persistence import (
|
|||
from app.agents.shared.middleware.knowledge_search import (
|
||||
KnowledgePriorityMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.knowledge_tree import (
|
||||
KnowledgeTreeMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.memory_injection import (
|
||||
MemoryInjectionMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.noop_injection import NoopInjectionMiddleware
|
||||
from app.agents.shared.middleware.otel_span import OtelSpanMiddleware
|
||||
from app.agents.shared.middleware.permission import PermissionMiddleware
|
||||
from app.agents.shared.middleware.retry_after import RetryAfterMiddleware
|
||||
from app.agents.shared.middleware.tool_call_repair import (
|
||||
ToolCallNameRepairMiddleware,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ActionLogMiddleware",
|
||||
"AnonymousDocumentMiddleware",
|
||||
"BusyMutexMiddleware",
|
||||
"ClearToolUsesEdit",
|
||||
"DoomLoopMiddleware",
|
||||
"KnowledgeBasePersistenceMiddleware",
|
||||
"KnowledgePriorityMiddleware",
|
||||
"KnowledgeTreeMiddleware",
|
||||
"MemoryInjectionMiddleware",
|
||||
"NoopInjectionMiddleware",
|
||||
"OtelSpanMiddleware",
|
||||
"PermissionMiddleware",
|
||||
"RetryAfterMiddleware",
|
||||
"SpillToBackendEdit",
|
||||
"SpillingContextEditingMiddleware",
|
||||
"SurfSenseCompactionMiddleware",
|
||||
"ToolCallNameRepairMiddleware",
|
||||
"ToolDefinition",
|
||||
"commit_staged_filesystem_state",
|
||||
"create_surfsense_compaction_middleware",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,388 +0,0 @@
|
|||
"""Append-only action-log middleware for the SurfSense agent.
|
||||
|
||||
Wraps every tool call via :meth:`AgentMiddleware.awrap_tool_call` and writes
|
||||
a row to :class:`~app.db.AgentActionLog` after the tool returns. Tools opt
|
||||
into reversibility by declaring a ``reverse`` callable on their
|
||||
:class:`ToolDefinition`; the rendered descriptor is persisted in
|
||||
``reverse_descriptor`` for use by
|
||||
``/api/threads/{thread_id}/revert/{action_id}``.
|
||||
|
||||
Design points:
|
||||
|
||||
* **Defensive.** Logging never blocks the agent. We catch every exception
|
||||
on the DB write path and emit a warning; the tool's ``ToolMessage``
|
||||
result is always returned untouched.
|
||||
* **Lightweight payload.** Only the tool ``name`` + ``args`` (capped) +
|
||||
``result_id`` + ``reverse_descriptor`` are stored. Tool output text
|
||||
remains in the LangGraph checkpoint / spilled tool-output files.
|
||||
* **Best-effort reversibility.** We invoke ``reverse(args, result_obj)``
|
||||
with the parsed JSON result when the tool's content is a JSON object;
|
||||
otherwise the raw text is passed. Exceptions in the reverse callable
|
||||
are swallowed and logged — a failed descriptor render simply means the
|
||||
action is NOT marked reversible.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.callbacks import adispatch_custom_event
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
from app.agents.shared.feature_flags import get_flags
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - type-only
|
||||
from langchain.agents.middleware.types import ToolCallRequest
|
||||
from langgraph.types import Command
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
"""Reversibility descriptor consumed by :class:`ActionLogMiddleware`.
|
||||
|
||||
Only ``name`` and ``reverse`` are read by the middleware; the remaining
|
||||
fields let callers and tests describe a tool declaratively. A tool is
|
||||
marked reversible in the action log when ``reverse`` is set and renders a
|
||||
descriptor without raising.
|
||||
|
||||
Attributes:
|
||||
name: Unique identifier for the tool.
|
||||
description: Human-readable description of what the tool does.
|
||||
factory: Optional callable that builds the tool (unused by the
|
||||
middleware; retained for declarative call sites/tests).
|
||||
reverse: Optional callable that, given the tool's ``(args, result)``,
|
||||
returns a ``ReverseDescriptor`` describing the inverse invocation.
|
||||
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str = ""
|
||||
factory: Callable[[dict[str, Any]], Any] | None = None
|
||||
reverse: Callable[[dict[str, Any], Any], dict[str, Any]] | None = None
|
||||
|
||||
|
||||
# Cap for the persisted ``args`` JSON to avoid bloating the action log with
|
||||
# accidentally-huge inputs. Values are truncated and a flag is set in the
|
||||
# stored payload so consumers can detect truncation.
|
||||
_MAX_ARGS_PERSIST_BYTES = 32 * 1024 # 32KB
|
||||
|
||||
|
||||
class ActionLogMiddleware(AgentMiddleware):
|
||||
"""Persist a row in :class:`AgentActionLog` after every tool call.
|
||||
|
||||
Should be placed near the OUTERMOST end of the tool-call wrapping stack
|
||||
so that it sees the *final* :class:`ToolMessage` after all retries,
|
||||
permission checks, and dedup logic have run. In practice that means
|
||||
placing it just inside :class:`PermissionMiddleware` and outside
|
||||
:class:`DedupHITLToolCallsMiddleware`.
|
||||
|
||||
The middleware is fully a no-op when:
|
||||
|
||||
* the master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set
|
||||
(checked via :func:`get_flags`),
|
||||
* the per-feature flag ``enable_action_log`` is off, or
|
||||
* persistence raises (defensive: tool-call dispatch always succeeds).
|
||||
|
||||
Args:
|
||||
thread_id: The current chat thread's primary-key id. Required to
|
||||
persist a row; if ``None`` the middleware silently no-ops.
|
||||
search_space_id: Search-space id for cascade-on-delete safety.
|
||||
user_id: UUID string of the user driving this turn (nullable in
|
||||
anonymous mode).
|
||||
tool_definitions: Optional mapping of tool name -> :class:`ToolDefinition`
|
||||
so the middleware can look up the tool's ``reverse`` callable.
|
||||
When omitted, no actions are marked reversible.
|
||||
"""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
thread_id: int | None,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
tool_definitions: dict[str, ToolDefinition] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._thread_id = thread_id
|
||||
self._search_space_id = search_space_id
|
||||
self._user_id = user_id
|
||||
self._tool_definitions = dict(tool_definitions or {})
|
||||
|
||||
def _enabled(self) -> bool:
|
||||
flags = get_flags()
|
||||
if flags.disable_new_agent_stack:
|
||||
return False
|
||||
return bool(flags.enable_action_log) and self._thread_id is not None
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
if not self._enabled():
|
||||
return await handler(request)
|
||||
|
||||
result: ToolMessage | Command[Any]
|
||||
error_payload: dict[str, Any] | None = None
|
||||
try:
|
||||
result = await handler(request)
|
||||
except Exception as exc:
|
||||
# Persist the failure too so revert/audit can see it, then
|
||||
# re-raise so downstream middleware (RetryAfter, etc.) handles it.
|
||||
error_payload = {"type": type(exc).__name__, "message": str(exc)}
|
||||
await self._record(
|
||||
request=request,
|
||||
result=None,
|
||||
error_payload=error_payload,
|
||||
)
|
||||
raise
|
||||
|
||||
await self._record(request=request, result=result, error_payload=None)
|
||||
return result
|
||||
|
||||
async def _record(
|
||||
self,
|
||||
*,
|
||||
request: ToolCallRequest,
|
||||
result: ToolMessage | Command[Any] | None,
|
||||
error_payload: dict[str, Any] | None,
|
||||
) -> None:
|
||||
"""Persist one ``agent_action_log`` row. Defensive: never raises."""
|
||||
try:
|
||||
from app.db import AgentActionLog, shielded_async_session
|
||||
|
||||
tool_name = _resolve_tool_name(request)
|
||||
args_payload = _resolve_args_payload(request)
|
||||
result_id = _resolve_result_id(result)
|
||||
reverse_descriptor, reversible = self._render_reverse(
|
||||
tool_name=tool_name,
|
||||
args=_resolve_args_dict(request),
|
||||
result=result,
|
||||
)
|
||||
|
||||
tool_call_id = _resolve_tool_call_id(request)
|
||||
chat_turn_id = _resolve_chat_turn_id(request)
|
||||
|
||||
row = AgentActionLog(
|
||||
thread_id=self._thread_id,
|
||||
user_id=self._user_id,
|
||||
search_space_id=self._search_space_id,
|
||||
# ``turn_id`` is the deprecated alias of ``tool_call_id``
|
||||
# kept for one release for safe rollback. New consumers
|
||||
# should read ``tool_call_id`` directly.
|
||||
turn_id=tool_call_id,
|
||||
tool_call_id=tool_call_id,
|
||||
chat_turn_id=chat_turn_id,
|
||||
message_id=_resolve_message_id(request),
|
||||
tool_name=tool_name,
|
||||
args=args_payload,
|
||||
result_id=result_id,
|
||||
reversible=reversible,
|
||||
reverse_descriptor=reverse_descriptor,
|
||||
error=error_payload,
|
||||
)
|
||||
async with shielded_async_session() as session:
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
row_id = int(row.id) if row.id is not None else None
|
||||
row_created_at = row.created_at
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"ActionLogMiddleware failed to persist action log row",
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
# Surface a side-channel SSE event so the chat tool card can
|
||||
# render a Revert button immediately after the row is durable.
|
||||
# ``stream_new_chat`` translates this into a
|
||||
# ``data-action-log`` SSE event. We DO NOT include the
|
||||
# ``reverse_descriptor`` payload here; only a presence flag.
|
||||
try:
|
||||
await adispatch_custom_event(
|
||||
"action_log",
|
||||
{
|
||||
"id": row_id,
|
||||
"lc_tool_call_id": tool_call_id,
|
||||
"chat_turn_id": chat_turn_id,
|
||||
"tool_name": tool_name,
|
||||
"reversible": bool(reversible),
|
||||
"reverse_descriptor_present": reverse_descriptor is not None,
|
||||
"created_at": row_created_at.isoformat()
|
||||
if row_created_at
|
||||
else None,
|
||||
"error": error_payload is not None,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"ActionLogMiddleware failed to dispatch action_log event",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def _render_reverse(
|
||||
self,
|
||||
*,
|
||||
tool_name: str,
|
||||
args: dict[str, Any] | None,
|
||||
result: ToolMessage | Command[Any] | None,
|
||||
) -> tuple[dict[str, Any] | None, bool]:
|
||||
"""Run the tool's ``reverse`` callable and return its descriptor.
|
||||
|
||||
Returns a tuple of ``(descriptor_or_None, reversible_bool)``. When
|
||||
the tool has no ``reverse`` callable, or when the callable raises,
|
||||
the action is marked non-reversible.
|
||||
"""
|
||||
if not result or not isinstance(result, ToolMessage):
|
||||
return None, False
|
||||
if args is None:
|
||||
return None, False
|
||||
tool_def = self._tool_definitions.get(tool_name)
|
||||
if tool_def is None or tool_def.reverse is None:
|
||||
return None, False
|
||||
try:
|
||||
parsed_result = _parse_tool_result_content(result)
|
||||
descriptor = tool_def.reverse(args, parsed_result)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Reverse descriptor render failed for tool %s",
|
||||
tool_name,
|
||||
exc_info=True,
|
||||
)
|
||||
return None, False
|
||||
if not isinstance(descriptor, dict):
|
||||
return None, False
|
||||
return descriptor, True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Resolution helpers — defensive against tool_call request shape variation.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_tool_name(request: Any) -> str:
|
||||
try:
|
||||
tool = getattr(request, "tool", None)
|
||||
if tool is not None:
|
||||
name = getattr(tool, "name", None)
|
||||
if isinstance(name, str) and name:
|
||||
return name
|
||||
call = getattr(request, "tool_call", None) or {}
|
||||
if isinstance(call, dict):
|
||||
name = call.get("name")
|
||||
if isinstance(name, str) and name:
|
||||
return name
|
||||
except Exception: # pragma: no cover - defensive
|
||||
pass
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _resolve_args_dict(request: Any) -> dict[str, Any] | None:
|
||||
try:
|
||||
call = getattr(request, "tool_call", None)
|
||||
if not isinstance(call, dict):
|
||||
return None
|
||||
args = call.get("args")
|
||||
if isinstance(args, dict):
|
||||
return args
|
||||
return None
|
||||
except Exception: # pragma: no cover - defensive
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_args_payload(request: Any) -> dict[str, Any] | None:
|
||||
"""Return a JSON-serializable args dict, truncated if too big."""
|
||||
args = _resolve_args_dict(request)
|
||||
if args is None:
|
||||
return None
|
||||
try:
|
||||
encoded = json.dumps(args, default=str)
|
||||
except Exception:
|
||||
return {"_repr": repr(args)[:_MAX_ARGS_PERSIST_BYTES]}
|
||||
if len(encoded) <= _MAX_ARGS_PERSIST_BYTES:
|
||||
return args
|
||||
return {
|
||||
"_truncated": True,
|
||||
"_size": len(encoded),
|
||||
"_preview": encoded[:_MAX_ARGS_PERSIST_BYTES],
|
||||
}
|
||||
|
||||
|
||||
def _resolve_tool_call_id(request: Any) -> str | None:
|
||||
"""Return the LangChain ``tool_call.id`` for this request, if any."""
|
||||
try:
|
||||
call = getattr(request, "tool_call", None) or {}
|
||||
if isinstance(call, dict):
|
||||
tid = call.get("id")
|
||||
if isinstance(tid, str):
|
||||
return tid
|
||||
except Exception: # pragma: no cover
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
# Deprecated alias kept for one release. Old callers and tests treated
|
||||
# ``turn_id`` as if it carried the LangChain tool_call id; the new column
|
||||
# lives under ``tool_call_id``. Both resolve to the same value today.
|
||||
_resolve_turn_id = _resolve_tool_call_id
|
||||
|
||||
|
||||
def _resolve_chat_turn_id(request: Any) -> str | None:
|
||||
"""Return ``configurable.turn_id`` for this request, if accessible.
|
||||
|
||||
``ToolRuntime.config`` is exposed by LangGraph (see
|
||||
``langgraph/prebuilt/tool_node.py``); the chat-turn correlation id
|
||||
lives at ``runtime.config["configurable"]["turn_id"]``.
|
||||
"""
|
||||
try:
|
||||
runtime = getattr(request, "runtime", None)
|
||||
if runtime is None:
|
||||
return None
|
||||
config = getattr(runtime, "config", None)
|
||||
if not isinstance(config, dict):
|
||||
return None
|
||||
configurable = config.get("configurable")
|
||||
if not isinstance(configurable, dict):
|
||||
return None
|
||||
value = configurable.get("turn_id")
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
except Exception: # pragma: no cover - defensive
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_message_id(request: Any) -> str | None:
|
||||
"""Tool-call IDs serve as best-available message correlator at this layer."""
|
||||
return _resolve_tool_call_id(request)
|
||||
|
||||
|
||||
def _resolve_result_id(result: Any) -> str | None:
|
||||
if isinstance(result, ToolMessage):
|
||||
msg_id = getattr(result, "id", None)
|
||||
if isinstance(msg_id, str):
|
||||
return msg_id
|
||||
return None
|
||||
|
||||
|
||||
def _parse_tool_result_content(result: ToolMessage) -> Any:
|
||||
content = result.content
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
return json.loads(content)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return content
|
||||
return content
|
||||
|
||||
|
||||
__all__ = ["ActionLogMiddleware"]
|
||||
|
|
@ -1,93 +0,0 @@
|
|||
"""Lightweight middleware that loads the anonymous-session document into state.
|
||||
|
||||
Anonymous chats receive a single uploaded document via Redis (no DB row,
|
||||
read-only). This middleware loads it once on the first turn into
|
||||
``state['kb_anon_doc']`` so:
|
||||
|
||||
* :class:`KnowledgeTreeMiddleware` can render the synthetic ``/documents``
|
||||
view without touching the DB.
|
||||
* :class:`KnowledgePriorityMiddleware` skips hybrid search and emits a
|
||||
degenerate priority list.
|
||||
* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / ``_load_file_data``)
|
||||
recognises the synthetic path.
|
||||
|
||||
The middleware is a no-op when ``anon_session_id`` is not provided or when
|
||||
the document is already cached in state.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.agents.shared.path_resolver import DOCUMENTS_ROOT, safe_filename
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnonymousDocumentMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Load the anonymous user's uploaded document from Redis into state."""
|
||||
|
||||
tools = ()
|
||||
state_schema = SurfSenseFilesystemState
|
||||
|
||||
def __init__(self, *, anon_session_id: str | None) -> None:
|
||||
self.anon_session_id = anon_session_id
|
||||
|
||||
async def abefore_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
del runtime
|
||||
if not self.anon_session_id:
|
||||
return None
|
||||
if state.get("kb_anon_doc"):
|
||||
return None
|
||||
|
||||
anon_doc = await self._load_anon_document()
|
||||
if anon_doc is None:
|
||||
return None
|
||||
return {"kb_anon_doc": anon_doc}
|
||||
|
||||
async def _load_anon_document(self) -> dict[str, Any] | None:
|
||||
"""Read ``anon:doc:<session_id>`` from Redis."""
|
||||
try:
|
||||
import redis.asyncio as aioredis # local import to keep cold paths cheap
|
||||
|
||||
from app.config import config
|
||||
|
||||
redis_client = aioredis.from_url(
|
||||
config.REDIS_APP_URL, decode_responses=True
|
||||
)
|
||||
try:
|
||||
redis_key = f"anon:doc:{self.anon_session_id}"
|
||||
data = await redis_client.get(redis_key)
|
||||
if not data:
|
||||
return None
|
||||
payload = json.loads(data)
|
||||
finally:
|
||||
await redis_client.aclose()
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to load anonymous document from Redis: %s", exc)
|
||||
return None
|
||||
|
||||
title = str(payload.get("filename") or "uploaded_document")
|
||||
content = str(payload.get("content") or "")
|
||||
path = f"{DOCUMENTS_ROOT}/{safe_filename(title)}"
|
||||
return {
|
||||
"path": path,
|
||||
"title": title,
|
||||
"content": content,
|
||||
"chunks": [{"chunk_id": -1, "content": content}] if content else [],
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["AnonymousDocumentMiddleware"]
|
||||
|
|
@ -1,350 +0,0 @@
|
|||
"""
|
||||
SpillToBackendEdit + SpillingContextEditingMiddleware.
|
||||
|
||||
LangChain's :class:`ClearToolUsesEdit` discards old ``ToolMessage.content``
|
||||
when the context-editing budget triggers, replacing the body with a fixed
|
||||
placeholder. That's lossy: anything the agent might want to revisit is
|
||||
gone. The spill-to-disk pattern (originally from OpenCode's
|
||||
``opencode/packages/opencode/src/tool/truncate.ts``) keeps the prune
|
||||
behavior but writes the full original payload to the runtime backend
|
||||
under ``/tool_outputs/{thread_id}/{message_id}.txt`` first. The
|
||||
placeholder is then upgraded to point at the spill path so the agent
|
||||
(or a subagent) can read it back on demand.
|
||||
|
||||
Why this is a middleware subclass instead of a plain ``ContextEdit``:
|
||||
``ContextEdit.apply`` is sync, but writing to the backend is async. We
|
||||
capture the spill payloads inside ``apply`` and flush them via
|
||||
``await backend.aupload_files(...)`` from ``awrap_model_call`` *before*
|
||||
delegating to the handler, so the explore subagent can always read what
|
||||
the placeholder advertises.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware.context_editing import (
|
||||
ClearToolUsesEdit,
|
||||
ContextEdit,
|
||||
ContextEditingMiddleware,
|
||||
TokenCounter,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.utils import count_tokens_approximately
|
||||
from langgraph.config import get_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deepagents.backends.protocol import BackendProtocol
|
||||
from langchain.agents.middleware.types import (
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_SPILL_PREFIX = "/tool_outputs"
|
||||
|
||||
|
||||
def _build_spill_placeholder(spill_path: str) -> str:
|
||||
"""Build the user-facing placeholder text shown to the model."""
|
||||
return (
|
||||
f"[cleared — full output at {spill_path}; ask the explore subagent to read it]"
|
||||
)
|
||||
|
||||
|
||||
def _get_thread_id_or_session() -> str:
|
||||
"""Best-effort thread_id discovery for the spill path.
|
||||
|
||||
Falls back to a process-stable string if no LangGraph config is
|
||||
available (e.g. unit tests). The exact value doesn't matter as long
|
||||
as it's stable within one stream so the placeholder paths line up
|
||||
with the actual upload path.
|
||||
"""
|
||||
try:
|
||||
config = get_config()
|
||||
thread_id = config.get("configurable", {}).get("thread_id")
|
||||
if thread_id is not None:
|
||||
return str(thread_id)
|
||||
except RuntimeError:
|
||||
pass
|
||||
return "no_thread"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SpillToBackendEdit(ContextEdit):
|
||||
"""Capture-and-replace context edit that spills full tool output to the backend.
|
||||
|
||||
Behaves like :class:`ClearToolUsesEdit` (same trigger / keep / exclude
|
||||
semantics) **and** records the original ``ToolMessage.content`` in
|
||||
:attr:`pending_spills` so the wrapping middleware can flush them
|
||||
before the model call.
|
||||
|
||||
Args:
|
||||
trigger: Token threshold above which the edit fires.
|
||||
clear_at_least: Minimum number of tokens to reclaim (best effort).
|
||||
keep: Number of most-recent ``ToolMessage`` instances to leave
|
||||
untouched.
|
||||
exclude_tools: Names of tools whose output is NOT spilled.
|
||||
clear_tool_inputs: Also clear the originating ``AIMessage.tool_calls``
|
||||
args when their pair is cleared.
|
||||
path_prefix: Path under the backend where spills are written.
|
||||
Default ``"/tool_outputs"``.
|
||||
"""
|
||||
|
||||
trigger: int = 100_000
|
||||
clear_at_least: int = 0
|
||||
keep: int = 3
|
||||
clear_tool_inputs: bool = False
|
||||
exclude_tools: Sequence[str] = ()
|
||||
path_prefix: str = DEFAULT_SPILL_PREFIX
|
||||
|
||||
pending_spills: list[tuple[str, bytes]] = field(default_factory=list)
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
|
||||
def drain_pending(self) -> list[tuple[str, bytes]]:
|
||||
"""Return and clear the pending-spill list atomically."""
|
||||
with self._lock:
|
||||
out = list(self.pending_spills)
|
||||
self.pending_spills.clear()
|
||||
return out
|
||||
|
||||
def apply(
|
||||
self,
|
||||
messages: list[AnyMessage],
|
||||
*,
|
||||
count_tokens: TokenCounter,
|
||||
) -> None:
|
||||
"""Mirror ``ClearToolUsesEdit.apply`` but capture originals first."""
|
||||
tokens = count_tokens(messages)
|
||||
if tokens <= self.trigger:
|
||||
return
|
||||
|
||||
candidates = [
|
||||
(idx, msg)
|
||||
for idx, msg in enumerate(messages)
|
||||
if isinstance(msg, ToolMessage)
|
||||
]
|
||||
if self.keep >= len(candidates):
|
||||
return
|
||||
if self.keep:
|
||||
candidates = candidates[: -self.keep]
|
||||
|
||||
thread_id = _get_thread_id_or_session()
|
||||
excluded_tools = set(self.exclude_tools)
|
||||
|
||||
for idx, tool_message in candidates:
|
||||
if tool_message.response_metadata.get("context_editing", {}).get("cleared"):
|
||||
continue
|
||||
|
||||
ai_message = next(
|
||||
(m for m in reversed(messages[:idx]) if isinstance(m, AIMessage)),
|
||||
None,
|
||||
)
|
||||
if ai_message is None:
|
||||
continue
|
||||
|
||||
tool_call = next(
|
||||
(
|
||||
call
|
||||
for call in ai_message.tool_calls
|
||||
if call.get("id") == tool_message.tool_call_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
if tool_call is None:
|
||||
continue
|
||||
|
||||
tool_name = tool_message.name or tool_call["name"]
|
||||
if tool_name in excluded_tools:
|
||||
continue
|
||||
|
||||
message_id = tool_message.id or tool_message.tool_call_id or "unknown"
|
||||
spill_path = f"{self.path_prefix}/{thread_id}/{message_id}.txt"
|
||||
|
||||
original = tool_message.content
|
||||
payload = self._encode_payload(original)
|
||||
with self._lock:
|
||||
self.pending_spills.append((spill_path, payload))
|
||||
|
||||
messages[idx] = tool_message.model_copy(
|
||||
update={
|
||||
"artifact": None,
|
||||
"content": _build_spill_placeholder(spill_path),
|
||||
"response_metadata": {
|
||||
**tool_message.response_metadata,
|
||||
"context_editing": {
|
||||
"cleared": True,
|
||||
"strategy": "spill_to_backend",
|
||||
"spill_path": spill_path,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
if self.clear_tool_inputs:
|
||||
ai_idx = messages.index(ai_message)
|
||||
messages[ai_idx] = self._clear_input_args(
|
||||
ai_message, tool_message.tool_call_id or ""
|
||||
)
|
||||
|
||||
if self.clear_at_least > 0:
|
||||
new_token_count = count_tokens(messages)
|
||||
cleared_tokens = max(0, tokens - new_token_count)
|
||||
if cleared_tokens >= self.clear_at_least:
|
||||
break
|
||||
|
||||
@staticmethod
|
||||
def _encode_payload(content: Any) -> bytes:
|
||||
"""Serialize ``ToolMessage.content`` to bytes for upload."""
|
||||
if isinstance(content, bytes):
|
||||
return content
|
||||
if isinstance(content, str):
|
||||
return content.encode("utf-8")
|
||||
try:
|
||||
import json
|
||||
|
||||
return json.dumps(content, default=str).encode("utf-8")
|
||||
except Exception:
|
||||
return str(content).encode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def _clear_input_args(message: AIMessage, tool_call_id: str) -> AIMessage:
|
||||
updated_tool_calls: list[dict[str, Any]] = []
|
||||
cleared_any = False
|
||||
for tool_call in message.tool_calls:
|
||||
updated = dict(tool_call)
|
||||
if updated.get("id") == tool_call_id:
|
||||
updated["args"] = {}
|
||||
cleared_any = True
|
||||
updated_tool_calls.append(updated)
|
||||
|
||||
metadata = dict(getattr(message, "response_metadata", {}))
|
||||
if cleared_any:
|
||||
ctx = dict(metadata.get("context_editing", {}))
|
||||
ids = set(ctx.get("cleared_tool_inputs", []))
|
||||
ids.add(tool_call_id)
|
||||
ctx["cleared_tool_inputs"] = sorted(ids)
|
||||
metadata["context_editing"] = ctx
|
||||
return message.model_copy(
|
||||
update={
|
||||
"tool_calls": updated_tool_calls,
|
||||
"response_metadata": metadata,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
BackendResolver = "Callable[[Any], BackendProtocol] | BackendProtocol"
|
||||
|
||||
|
||||
class SpillingContextEditingMiddleware(ContextEditingMiddleware):
|
||||
""":class:`ContextEditingMiddleware` that flushes :class:`SpillToBackendEdit` writes.
|
||||
|
||||
Runs the configured edits as the parent does, then flushes any
|
||||
pending spills via the supplied backend resolver before delegating
|
||||
to the model handler. Spill failures are logged but never abort the
|
||||
model call — the placeholder text is already in the message, so the
|
||||
worst case is the agent gets a placeholder it cannot follow up on.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
edits: Sequence[ContextEdit],
|
||||
backend_resolver: BackendResolver | None = None,
|
||||
token_count_method: str = "approximate",
|
||||
) -> None:
|
||||
super().__init__(edits=list(edits), token_count_method=token_count_method) # type: ignore[arg-type]
|
||||
self._backend_resolver = backend_resolver
|
||||
|
||||
def _resolve_backend(self, request: ModelRequest) -> BackendProtocol | None:
|
||||
if self._backend_resolver is None:
|
||||
return None
|
||||
if callable(self._backend_resolver):
|
||||
try:
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
tool_runtime = ToolRuntime(
|
||||
state=getattr(request, "state", {}),
|
||||
context=getattr(request.runtime, "context", None),
|
||||
stream_writer=getattr(request.runtime, "stream_writer", None),
|
||||
store=getattr(request.runtime, "store", None),
|
||||
config=getattr(request.runtime, "config", None) or {},
|
||||
tool_call_id=None,
|
||||
)
|
||||
return self._backend_resolver(tool_runtime)
|
||||
except Exception:
|
||||
logger.exception("Failed to resolve spill backend")
|
||||
return None
|
||||
return self._backend_resolver # type: ignore[return-value]
|
||||
|
||||
def _collect_pending(self) -> list[tuple[str, bytes]]:
|
||||
out: list[tuple[str, bytes]] = []
|
||||
for edit in self.edits:
|
||||
if isinstance(edit, SpillToBackendEdit):
|
||||
out.extend(edit.drain_pending())
|
||||
return out
|
||||
|
||||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> Any:
|
||||
if not request.messages:
|
||||
return await handler(request)
|
||||
|
||||
if self.token_count_method == "approximate":
|
||||
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return count_tokens_approximately(messages)
|
||||
|
||||
else:
|
||||
system_msg = [request.system_message] if request.system_message else []
|
||||
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return request.model.get_num_tokens_from_messages(
|
||||
system_msg + list(messages), request.tools
|
||||
)
|
||||
|
||||
edited_messages = deepcopy(list(request.messages))
|
||||
for edit in self.edits:
|
||||
edit.apply(edited_messages, count_tokens=count_tokens)
|
||||
|
||||
pending = self._collect_pending()
|
||||
if pending:
|
||||
backend = self._resolve_backend(request)
|
||||
if backend is not None:
|
||||
try:
|
||||
await backend.aupload_files(pending)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Spill-to-backend upload failed (%d files); placeholders "
|
||||
"remain in messages but content is unrecoverable",
|
||||
len(pending),
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"SpillToBackendEdit produced %d pending spills but no backend "
|
||||
"resolver was configured; content is unrecoverable",
|
||||
len(pending),
|
||||
)
|
||||
|
||||
return await handler(request.override(messages=edited_messages))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_SPILL_PREFIX",
|
||||
"ClearToolUsesEdit",
|
||||
"SpillToBackendEdit",
|
||||
"SpillingContextEditingMiddleware",
|
||||
"_build_spill_placeholder",
|
||||
]
|
||||
|
|
@ -1,238 +0,0 @@
|
|||
"""
|
||||
DoomLoopMiddleware — pattern-based detector for repeated identical tool calls.
|
||||
|
||||
LangChain has :class:`ToolCallLimitMiddleware` which caps the *total* number
|
||||
of tool calls per turn — but it can't tell apart "10 distinct, useful
|
||||
calls" from "the same call 10 times in a row". This middleware fills that
|
||||
gap with a sliding-window check on tool-call signatures, ported from
|
||||
OpenCode's ``packages/opencode/src/session/processor.ts``.
|
||||
|
||||
When the same tool with the same arguments is called N times in a row,
|
||||
the agent has likely entered an infinite loop. We surface this to the
|
||||
user as an interrupt with ``permission="doom_loop"`` so the UI can
|
||||
render an "Are you stuck? Continue / cancel?" affordance.
|
||||
|
||||
This ships **OFF by default** until the frontend explicitly handles
|
||||
``context.permission == "doom_loop"`` interrupts.
|
||||
|
||||
Wire format: uses SurfSense's existing ``interrupt()`` payload shape
|
||||
(see ``app/agents/shared/tools/hitl.py``):
|
||||
|
||||
{
|
||||
"type": "permission_ask",
|
||||
"action": {"tool": <name>, "params": <args>},
|
||||
"context": {"permission": "doom_loop", "recent_signatures": [...]},
|
||||
}
|
||||
|
||||
so the frontend that already handles HITL prompts can render this with
|
||||
no changes beyond a string check.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import interrupt
|
||||
|
||||
from app.observability import metrics as ot_metrics, otel as ot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _signature(name: str, args: Any) -> str:
|
||||
"""Hash a tool call ``(name, args)`` to a short signature."""
|
||||
try:
|
||||
canonical = json.dumps(args, sort_keys=True, default=str)
|
||||
except (TypeError, ValueError):
|
||||
canonical = repr(args)
|
||||
digest = hashlib.sha1(f"{name}::{canonical}".encode()).hexdigest()
|
||||
return digest[:16]
|
||||
|
||||
|
||||
class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
"""Detect repeated identical tool calls and prompt the user.
|
||||
|
||||
Tracks a sliding window of the most-recent ``threshold`` tool-call
|
||||
signatures across the live request. When all entries match, raise
|
||||
a SurfSense-style HITL interrupt with ``permission="doom_loop"``.
|
||||
|
||||
Args:
|
||||
threshold: How many consecutive identical signatures count as a
|
||||
doom loop. Default 3 (matches OpenCode's processor.ts).
|
||||
"""
|
||||
|
||||
def __init__(self, *, threshold: int = 3) -> None:
|
||||
super().__init__()
|
||||
if threshold < 2:
|
||||
raise ValueError("DoomLoopMiddleware threshold must be >= 2")
|
||||
self._threshold = threshold
|
||||
self.tools = []
|
||||
# Per-thread sliding windows. We can't put this in graph state
|
||||
# without state-schema gymnastics; for one process-lifetime it's
|
||||
# fine to keep an in-memory map keyed by thread_id.
|
||||
self._windows: dict[str, deque[str]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _thread_id_from_runtime(runtime: Runtime[ContextT]) -> str:
|
||||
"""Resolve the thread id for sliding-window keying.
|
||||
|
||||
Prefer LangGraph's ``get_config()`` (the only way to read
|
||||
``RunnableConfig`` inside a node — :class:`Runtime` does NOT carry
|
||||
a ``config`` attribute). Fall back to ``runtime.config`` for unit
|
||||
tests that synthesize a config-bearing stub. Default
|
||||
``"no_thread"`` is intentionally only used when both lookups fail
|
||||
— it would collapse all threads into one window so we keep the
|
||||
debug log loud.
|
||||
"""
|
||||
|
||||
def _from_dict(cfg: Any) -> str | None:
|
||||
if not isinstance(cfg, dict):
|
||||
return None
|
||||
tid = (cfg.get("configurable") or {}).get("thread_id")
|
||||
return str(tid) if tid is not None else None
|
||||
|
||||
try:
|
||||
tid = _from_dict(get_config())
|
||||
except Exception:
|
||||
tid = None
|
||||
if tid is not None:
|
||||
return tid
|
||||
|
||||
tid = _from_dict(getattr(runtime, "config", None))
|
||||
if tid is not None:
|
||||
return tid
|
||||
|
||||
logger.debug(
|
||||
"DoomLoopMiddleware: no thread_id resolved from RunnableConfig; "
|
||||
"falling back to shared 'no_thread' window."
|
||||
)
|
||||
return "no_thread"
|
||||
|
||||
def _window(self, thread_id: str) -> deque[str]:
|
||||
win = self._windows.get(thread_id)
|
||||
if win is None:
|
||||
win = deque(maxlen=self._threshold)
|
||||
self._windows[thread_id] = win
|
||||
return win
|
||||
|
||||
def _detect(
|
||||
self, message: AIMessage, runtime: Runtime[ContextT]
|
||||
) -> tuple[bool, list[str], dict[str, Any] | None]:
|
||||
if not message.tool_calls:
|
||||
return False, [], None
|
||||
|
||||
thread_id = self._thread_id_from_runtime(runtime)
|
||||
window = self._window(thread_id)
|
||||
|
||||
triggered_call: dict[str, Any] | None = None
|
||||
for call in message.tool_calls:
|
||||
name = (
|
||||
call.get("name")
|
||||
if isinstance(call, dict)
|
||||
else getattr(call, "name", None)
|
||||
)
|
||||
args = (
|
||||
call.get("args")
|
||||
if isinstance(call, dict)
|
||||
else getattr(call, "args", {})
|
||||
)
|
||||
if not isinstance(name, str):
|
||||
continue
|
||||
sig = _signature(name, args)
|
||||
window.append(sig)
|
||||
if len(window) >= self._threshold and len(set(window)) == 1:
|
||||
triggered_call = {"name": name, "params": args or {}}
|
||||
break
|
||||
|
||||
if triggered_call is None:
|
||||
return False, list(window), None
|
||||
return True, list(window), triggered_call
|
||||
|
||||
def after_model( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[ResponseT],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage):
|
||||
return None
|
||||
|
||||
triggered, signatures, action = self._detect(last, runtime)
|
||||
if not triggered:
|
||||
return None
|
||||
|
||||
logger.warning(
|
||||
"Doom loop detected: tool %s called %d times in a row (sig=%s)",
|
||||
action["name"] if action else "<unknown>",
|
||||
self._threshold,
|
||||
signatures[-1] if signatures else "<empty>",
|
||||
)
|
||||
|
||||
# Open an interrupt.raised span with permission=doom_loop attribute
|
||||
# so dashboards can break out doom-loop interrupts from regular
|
||||
# permission asks via the ``interrupt.permission`` attribute.
|
||||
with ot.interrupt_span(
|
||||
interrupt_type="permission_ask",
|
||||
extra={
|
||||
"interrupt.permission": "doom_loop",
|
||||
"interrupt.threshold": self._threshold,
|
||||
"interrupt.tool": (action or {}).get("tool", "<unknown>"),
|
||||
},
|
||||
):
|
||||
ot_metrics.record_interrupt(interrupt_type="permission_ask")
|
||||
decision = interrupt(
|
||||
{
|
||||
"type": "permission_ask",
|
||||
"action": action or {"tool": "<unknown>", "params": {}},
|
||||
"context": {
|
||||
"permission": "doom_loop",
|
||||
"recent_signatures": signatures,
|
||||
"threshold": self._threshold,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Reset window so the next decision (continue/cancel) starts fresh.
|
||||
thread_id = self._thread_id_from_runtime(runtime)
|
||||
self._windows.pop(thread_id, None)
|
||||
|
||||
# Decision shape mirrors ``tools/hitl.py``: {"decision_type": "..."}
|
||||
# If the user cancelled, jump to end. Otherwise return ``None`` so the
|
||||
# tool call proceeds. The frontend's exact reply names may differ —
|
||||
# we tolerate any shape that contains a string with "reject"/"cancel".
|
||||
if isinstance(decision, dict):
|
||||
kind = str(
|
||||
decision.get("decision_type") or decision.get("type") or ""
|
||||
).lower()
|
||||
if "reject" in kind or "cancel" in kind:
|
||||
return {"jump_to": "end"}
|
||||
return None
|
||||
|
||||
async def aafter_model( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[ResponseT],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DoomLoopMiddleware",
|
||||
"_signature",
|
||||
]
|
||||
|
|
@ -1,334 +0,0 @@
|
|||
"""Workspace-tree middleware for the SurfSense agent.
|
||||
|
||||
Renders the full ``Folder``+``Document`` tree under ``/documents/`` once per
|
||||
turn (cloud only), caches it by ``(search_space_id, tree_version)``, and
|
||||
injects the result as a ``<workspace_tree>`` system message immediately
|
||||
before the latest human turn.
|
||||
|
||||
The render is bounded by two truncation layers:
|
||||
|
||||
1. **Entry cap** — at most ``MAX_TREE_ENTRIES`` lines. The remainder is
|
||||
replaced with a "use ls" hint.
|
||||
2. **Token cap** — at most ``MAX_TREE_TOKENS`` tokens (using the LLM's
|
||||
token-count profile when available). If the entry-truncated tree still
|
||||
exceeds the token cap we fall back to a root-only summary.
|
||||
|
||||
Anonymous mode renders only ``state['kb_anon_doc']`` (no DB calls).
|
||||
|
||||
This middleware also performs a one-time initialization of ``state['cwd']``
|
||||
to ``"/documents"`` so subsequent middlewares and tools always see a valid
|
||||
cwd in cloud mode.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.shared.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
PathIndex,
|
||||
build_path_index,
|
||||
doc_to_virtual_path,
|
||||
)
|
||||
from app.db import Document, shielded_async_session
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
try:
|
||||
from litellm import token_counter
|
||||
except Exception: # pragma: no cover - optional dep
|
||||
token_counter = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MAX_TREE_ENTRIES = 500
|
||||
MAX_TREE_TOKENS = 4000
|
||||
|
||||
|
||||
def _approx_tokens(text: str) -> int:
|
||||
"""Cheap fallback token estimate (1 token ~= 4 chars)."""
|
||||
return max(1, (len(text) + 3) // 4)
|
||||
|
||||
|
||||
def _count_tokens(text: str, *, llm: BaseChatModel | None) -> int:
|
||||
if llm is None:
|
||||
return _approx_tokens(text)
|
||||
count_fn = getattr(llm, "_count_tokens", None)
|
||||
if callable(count_fn):
|
||||
try:
|
||||
return int(count_fn([{"role": "user", "content": text}]))
|
||||
except Exception:
|
||||
pass
|
||||
profile = getattr(llm, "profile", None)
|
||||
model_names: list[str] = []
|
||||
if isinstance(profile, dict):
|
||||
tcms = profile.get("token_count_models")
|
||||
if isinstance(tcms, list):
|
||||
model_names.extend(name for name in tcms if isinstance(name, str) and name)
|
||||
tcm = profile.get("token_count_model")
|
||||
if isinstance(tcm, str) and tcm and tcm not in model_names:
|
||||
model_names.append(tcm)
|
||||
model_name = model_names[0] if model_names else getattr(llm, "model", None)
|
||||
if not isinstance(model_name, str) or not model_name or token_counter is None:
|
||||
return _approx_tokens(text)
|
||||
try:
|
||||
return int(
|
||||
token_counter(
|
||||
messages=[{"role": "user", "content": text}],
|
||||
model=model_name,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
return _approx_tokens(text)
|
||||
|
||||
|
||||
class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Inject the workspace folder/document tree into the agent's context."""
|
||||
|
||||
tools = ()
|
||||
state_schema = SurfSenseFilesystemState
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
search_space_id: int,
|
||||
filesystem_mode: FilesystemMode,
|
||||
llm: BaseChatModel | None = None,
|
||||
max_entries: int = MAX_TREE_ENTRIES,
|
||||
max_tokens: int = MAX_TREE_TOKENS,
|
||||
inject_system_message: bool = True, # For backwards compatibility
|
||||
) -> None:
|
||||
self.search_space_id = search_space_id
|
||||
self.filesystem_mode = filesystem_mode
|
||||
self.llm = llm
|
||||
self.max_entries = max_entries
|
||||
self.max_tokens = max_tokens
|
||||
self.inject_system_message = inject_system_message
|
||||
self._cache: dict[tuple[int, int, bool], str] = {}
|
||||
|
||||
async def abefore_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
del runtime
|
||||
if self.filesystem_mode != FilesystemMode.CLOUD:
|
||||
return None
|
||||
|
||||
start = time.perf_counter()
|
||||
update: dict[str, Any] = {}
|
||||
if not state.get("cwd"):
|
||||
update["cwd"] = DOCUMENTS_ROOT
|
||||
|
||||
anon_doc = state.get("kb_anon_doc")
|
||||
if anon_doc:
|
||||
tree_msg = self._render_anon_tree(anon_doc)
|
||||
cache_outcome = "anon"
|
||||
else:
|
||||
version = int(state.get("tree_version") or 0)
|
||||
cache_key = (self.search_space_id, version, False)
|
||||
cache_outcome = "hit" if cache_key in self._cache else "miss"
|
||||
tree_msg = await self._render_kb_tree(state)
|
||||
|
||||
update["workspace_tree_text"] = tree_msg
|
||||
|
||||
if self.inject_system_message:
|
||||
messages = list(state.get("messages") or [])
|
||||
insert_at = max(len(messages) - 1, 0)
|
||||
messages.insert(insert_at, SystemMessage(content=tree_msg))
|
||||
update["messages"] = messages
|
||||
|
||||
_perf_log.info(
|
||||
"[knowledge_tree] cache=%s chars=%d elapsed=%.3fs space=%d",
|
||||
cache_outcome,
|
||||
len(tree_msg),
|
||||
time.perf_counter() - start,
|
||||
self.search_space_id,
|
||||
)
|
||||
return update
|
||||
|
||||
def before_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
if loop.is_running():
|
||||
return None
|
||||
except RuntimeError:
|
||||
pass
|
||||
return asyncio.run(self.abefore_agent(state, runtime))
|
||||
|
||||
# ------------------------------------------------------------------ render
|
||||
|
||||
def _render_anon_tree(self, anon_doc: dict[str, Any]) -> str:
|
||||
path = str(anon_doc.get("path") or "")
|
||||
title = str(anon_doc.get("title") or "uploaded_document")
|
||||
return (
|
||||
"<workspace_tree>\n"
|
||||
"Anonymous session — only one read-only document is available.\n"
|
||||
f"{DOCUMENTS_ROOT}/\n"
|
||||
f" {path} — {title}\n"
|
||||
"</workspace_tree>"
|
||||
)
|
||||
|
||||
async def _render_kb_tree(self, state: AgentState) -> str:
|
||||
version = int(state.get("tree_version") or 0)
|
||||
cache_key = (self.search_space_id, version, False)
|
||||
cached = self._cache.get(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
try:
|
||||
async with shielded_async_session() as session:
|
||||
index = await build_path_index(session, self.search_space_id)
|
||||
doc_rows = await session.execute(
|
||||
select(Document.id, Document.title, Document.folder_id).where(
|
||||
Document.search_space_id == self.search_space_id
|
||||
)
|
||||
)
|
||||
docs = list(doc_rows.all())
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.warning("knowledge_tree: DB error %s", exc)
|
||||
return "<workspace_tree>\n(unavailable)\n</workspace_tree>"
|
||||
|
||||
rendered = self._format_tree(index, docs)
|
||||
self._cache[cache_key] = rendered
|
||||
return rendered
|
||||
|
||||
def _format_tree(self, index: PathIndex, docs: list[Any]) -> str:
|
||||
folder_paths = sorted(set(index.folder_paths.values()))
|
||||
doc_paths = sorted(
|
||||
doc_to_virtual_path(
|
||||
doc_id=row.id,
|
||||
title=str(row.title or "untitled"),
|
||||
folder_id=row.folder_id,
|
||||
index=index,
|
||||
)
|
||||
for row in docs
|
||||
)
|
||||
all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT]))
|
||||
|
||||
# Pre-compute which folders have at least one descendant (folder or doc).
|
||||
# A folder is "empty" iff no path in `all_paths` is strictly under it.
|
||||
# Used to emit an explicit "(empty)" marker so the LLM doesn't have to
|
||||
# infer emptiness from indentation alone.
|
||||
non_empty_folders = self._compute_non_empty_folders(folder_paths, doc_paths)
|
||||
|
||||
lines: list[str] = []
|
||||
for path in all_paths:
|
||||
depth = (
|
||||
0
|
||||
if path == DOCUMENTS_ROOT
|
||||
else len([p for p in path[len(DOCUMENTS_ROOT) :].split("/") if p])
|
||||
)
|
||||
indent = " " * depth
|
||||
is_dir = path == DOCUMENTS_ROOT or path in folder_paths
|
||||
display = (
|
||||
path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents"
|
||||
)
|
||||
if is_dir:
|
||||
if path != DOCUMENTS_ROOT and path not in non_empty_folders:
|
||||
lines.append(f"{indent}{display}/ (empty)")
|
||||
else:
|
||||
lines.append(f"{indent}{display}/")
|
||||
else:
|
||||
lines.append(f"{indent}{display}")
|
||||
if len(lines) >= self.max_entries:
|
||||
remaining = len(all_paths) - len(lines)
|
||||
if remaining > 0:
|
||||
lines.append(
|
||||
f"... {remaining} more entries — use "
|
||||
"ls('/documents/<folder>', offset, limit) to expand"
|
||||
)
|
||||
break
|
||||
|
||||
body = "\n".join(lines)
|
||||
rendered = f"<workspace_tree>\n{body}\n</workspace_tree>"
|
||||
|
||||
token_count = _count_tokens(rendered, llm=self.llm)
|
||||
if token_count <= self.max_tokens:
|
||||
return rendered
|
||||
|
||||
return self._format_root_summary(folder_paths, doc_paths)
|
||||
|
||||
@staticmethod
|
||||
def _compute_non_empty_folders(
|
||||
folder_paths: list[str], doc_paths: list[str]
|
||||
) -> set[str]:
|
||||
"""Return the set of folder paths that contain at least one descendant.
|
||||
|
||||
A folder is "non-empty" if any document path or any other folder path
|
||||
is strictly under it. Documents propagate emptiness up to every
|
||||
ancestor folder, while a sub-folder only marks its direct ancestors
|
||||
non-empty (so a chain of empty folders all read ``(empty)``).
|
||||
"""
|
||||
non_empty: set[str] = set()
|
||||
folder_set = set(folder_paths)
|
||||
|
||||
for doc_path in doc_paths:
|
||||
parent = doc_path.rsplit("/", 1)[0]
|
||||
while parent and parent != DOCUMENTS_ROOT:
|
||||
if parent in folder_set:
|
||||
non_empty.add(parent)
|
||||
parent = parent.rsplit("/", 1)[0]
|
||||
|
||||
for child in folder_paths:
|
||||
parent = child.rsplit("/", 1)[0]
|
||||
while parent and parent != DOCUMENTS_ROOT and parent in folder_set:
|
||||
non_empty.add(parent)
|
||||
parent = parent.rsplit("/", 1)[0]
|
||||
|
||||
return non_empty
|
||||
|
||||
def _format_root_summary(
|
||||
self, folder_paths: list[str], doc_paths: list[str]
|
||||
) -> str:
|
||||
top_level: dict[str, int] = {}
|
||||
loose_docs = 0
|
||||
for path in doc_paths:
|
||||
rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
|
||||
if "/" in rel:
|
||||
top = rel.split("/", 1)[0]
|
||||
top_level[top] = top_level.get(top, 0) + 1
|
||||
else:
|
||||
loose_docs += 1
|
||||
for path in folder_paths:
|
||||
rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
|
||||
if not rel:
|
||||
continue
|
||||
top = rel.split("/", 1)[0]
|
||||
top_level.setdefault(top, 0)
|
||||
|
||||
lines = [DOCUMENTS_ROOT + "/"]
|
||||
for name in sorted(top_level):
|
||||
count = top_level[name]
|
||||
lines.append(f" {name}/ ({count} document{'s' if count != 1 else ''})")
|
||||
if loose_docs:
|
||||
lines.append(
|
||||
f" ({loose_docs} loose document{'s' if loose_docs != 1 else ''})"
|
||||
)
|
||||
lines.append(
|
||||
"Tree is large; use list_tree('/documents/<folder>') to drill in "
|
||||
"or ls('/documents/<folder>', offset, limit) for paginated listings."
|
||||
)
|
||||
return "<workspace_tree>\n" + "\n".join(lines) + "\n</workspace_tree>"
|
||||
|
||||
|
||||
__all__ = ["KnowledgeTreeMiddleware"]
|
||||
|
|
@ -1,141 +0,0 @@
|
|||
"""
|
||||
``_noop`` provider-compatibility tool + injection middleware.
|
||||
|
||||
Some providers (LiteLLM, Bedrock, Copilot) 400 when a model call has
|
||||
empty ``tools`` but the message history includes prior ``tool_calls`` —
|
||||
they treat that shape as malformed even though it's perfectly valid
|
||||
LangChain. SurfSense hits this on the compaction summarize call (no
|
||||
tools, history full of tool calls).
|
||||
|
||||
Ported from OpenCode's ``packages/opencode/src/session/llm.ts:209-228``,
|
||||
which discovered and codified the workaround: inject a no-op tool *only*
|
||||
on those provider shapes so the request validates without ever being
|
||||
called.
|
||||
|
||||
Operation: a :class:`NoopInjectionMiddleware` ``wrap_model_call`` checks
|
||||
if the request has zero tools but the last AI message in history includes
|
||||
``tool_calls``. If yes, it injects the ``_noop`` tool only — never
|
||||
globally — mirroring OpenCode's gating exactly. The :func:`noop_tool`
|
||||
returns empty content when called (which it should never be in
|
||||
practice).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NOOP_TOOL_NAME = "_noop"
|
||||
NOOP_TOOL_DESCRIPTION = "Do not call this tool. It exists only for API compatibility."
|
||||
|
||||
|
||||
@tool(name_or_callable=NOOP_TOOL_NAME, description=NOOP_TOOL_DESCRIPTION)
|
||||
def noop_tool() -> str:
|
||||
"""Return empty content. Never expected to be called."""
|
||||
return ""
|
||||
|
||||
|
||||
# Provider markers that benefit from ``_noop`` injection. These match
|
||||
# OpenCode's gating list (``llm.ts:209-228``). We also accept any string
|
||||
# containing one of these substrings so e.g. ``litellm`` matches
|
||||
# ``ChatLiteLLM``.
|
||||
_NOOP_NEEDED_PROVIDERS: tuple[str, ...] = (
|
||||
"litellm",
|
||||
"bedrock",
|
||||
"copilot",
|
||||
)
|
||||
|
||||
|
||||
def _provider_needs_noop(model: Any) -> bool:
|
||||
"""Heuristic: does this model's provider need the _noop injection?"""
|
||||
try:
|
||||
ls_params = model._get_ls_params()
|
||||
provider = str(ls_params.get("ls_provider", "")).lower()
|
||||
except Exception:
|
||||
provider = ""
|
||||
|
||||
if not provider:
|
||||
cls_name = type(model).__name__.lower()
|
||||
provider = cls_name
|
||||
|
||||
return any(needle in provider for needle in _NOOP_NEEDED_PROVIDERS)
|
||||
|
||||
|
||||
def _last_ai_has_tool_calls(messages: list[Any]) -> bool:
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
return bool(msg.tool_calls)
|
||||
return False
|
||||
|
||||
|
||||
class NoopInjectionMiddleware(
|
||||
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
|
||||
):
|
||||
"""Inject the ``_noop`` tool only when the provider would otherwise 400.
|
||||
|
||||
The check fires per model call, not at agent build time, because the
|
||||
summarization path generates a no-tool subcall at runtime. The
|
||||
extra tool is appended to ``request.tools`` as an instance — the
|
||||
actual ``langchain_core.tools.BaseTool`` is bound on every call site
|
||||
that creates the agent.
|
||||
"""
|
||||
|
||||
def __init__(self, *, noop_tool_instance: Any | None = None) -> None:
|
||||
super().__init__()
|
||||
self._noop_tool = noop_tool_instance or noop_tool
|
||||
self.tools = []
|
||||
|
||||
def _should_inject(self, request: ModelRequest[ContextT]) -> bool:
|
||||
if request.tools:
|
||||
return False
|
||||
if not _last_ai_has_tool_calls(request.messages):
|
||||
return False
|
||||
return _provider_needs_noop(request.model)
|
||||
|
||||
def _augmented(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
|
||||
return request.override(tools=[self._noop_tool])
|
||||
|
||||
def wrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
|
||||
) -> Any:
|
||||
if self._should_inject(request):
|
||||
logger.debug("Injecting _noop tool for provider compatibility")
|
||||
return handler(self._augmented(request))
|
||||
return handler(request)
|
||||
|
||||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> Any:
|
||||
if self._should_inject(request):
|
||||
logger.debug("Injecting _noop tool for provider compatibility")
|
||||
return await handler(self._augmented(request))
|
||||
return await handler(request)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"NOOP_TOOL_DESCRIPTION",
|
||||
"NOOP_TOOL_NAME",
|
||||
"NoopInjectionMiddleware",
|
||||
"_provider_needs_noop",
|
||||
"noop_tool",
|
||||
]
|
||||
|
|
@ -1,285 +0,0 @@
|
|||
"""
|
||||
OpenTelemetry span middleware for the SurfSense ``new_chat`` agent.
|
||||
|
||||
Wraps both ``model.call`` (LLM invocations) and ``tool.call`` (tool
|
||||
executions) with OTel spans, attaching low-cardinality span names and
|
||||
high-cardinality identifiers as attributes.
|
||||
|
||||
This middleware is intentionally a thin adapter over
|
||||
:mod:`app.observability.otel`; when OTel is not configured all spans
|
||||
collapse to no-ops and the wrapper adds <1µs overhead per call. When
|
||||
OTel **is** configured (``OTEL_EXPORTER_OTLP_ENDPOINT`` set), every
|
||||
model and tool call gets a span with the standard attributes our
|
||||
dashboards expect.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from app.observability import metrics as ot_metrics, otel as ot
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover — type-only
|
||||
from langchain.agents.middleware.types import (
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ToolCallRequest,
|
||||
)
|
||||
from langgraph.types import Command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OtelSpanMiddleware(AgentMiddleware):
|
||||
"""Emit ``model.call`` and ``tool.call`` OTel spans for every invocation.
|
||||
|
||||
Should be placed near the **outer** end of the middleware list so
|
||||
that the spans encompass retry/fallback wrapper effects (i.e. ``N``
|
||||
model.call spans for ``N`` retry attempts) but inside any concurrency/
|
||||
auth gate. Empirically this means **between** ``BusyMutex`` and
|
||||
``RetryAfter``.
|
||||
"""
|
||||
|
||||
def __init__(self, *, instrumentation_name: str = "surfsense.new_chat") -> None:
|
||||
super().__init__()
|
||||
self._instrumentation_name = instrumentation_name
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Model call spans
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]],
|
||||
) -> ModelResponse | AIMessage | Any:
|
||||
if not ot.is_enabled():
|
||||
return await handler(request)
|
||||
|
||||
model_id, provider = _resolve_model_attrs(request)
|
||||
t0 = time.perf_counter()
|
||||
with ot.model_call_span(model_id=model_id, provider=provider) as sp:
|
||||
_annotate_model_request(sp, model_id=model_id, provider=provider)
|
||||
try:
|
||||
result = await handler(request)
|
||||
except Exception:
|
||||
ot_metrics.record_model_call_duration(
|
||||
(time.perf_counter() - t0) * 1000,
|
||||
model=model_id,
|
||||
provider=provider,
|
||||
)
|
||||
# span context manager records + re-raises
|
||||
raise
|
||||
else:
|
||||
input_tokens, output_tokens = _annotate_model_response(
|
||||
sp,
|
||||
result,
|
||||
model_id=model_id,
|
||||
provider=provider,
|
||||
)
|
||||
ot_metrics.record_model_call_duration(
|
||||
(time.perf_counter() - t0) * 1000,
|
||||
model=model_id,
|
||||
provider=provider,
|
||||
)
|
||||
ot_metrics.record_model_token_usage(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
model=model_id,
|
||||
provider=provider,
|
||||
)
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool call spans
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
if not ot.is_enabled():
|
||||
return await handler(request)
|
||||
|
||||
tool_name = _resolve_tool_name(request)
|
||||
input_size = _resolve_input_size(request)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
with ot.tool_call_span(tool_name, input_size=input_size) as sp:
|
||||
try:
|
||||
result = await handler(request)
|
||||
except Exception:
|
||||
ot_metrics.record_tool_call_duration(
|
||||
(time.perf_counter() - t0) * 1000,
|
||||
tool_name=tool_name,
|
||||
)
|
||||
ot_metrics.record_tool_call_error(tool_name=tool_name)
|
||||
raise
|
||||
errored = _annotate_tool_result(sp, result)
|
||||
ot_metrics.record_tool_call_duration(
|
||||
(time.perf_counter() - t0) * 1000,
|
||||
tool_name=tool_name,
|
||||
)
|
||||
if errored:
|
||||
ot_metrics.record_tool_call_error(tool_name=tool_name)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attribute helpers (kept defensive; we never want OTel bookkeeping to break
|
||||
# a real model/tool call).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_model_attrs(request: Any) -> tuple[str | None, str | None]:
|
||||
"""Extract ``model.id`` and ``model.provider`` from a ``ModelRequest``."""
|
||||
model_id: str | None = None
|
||||
provider: str | None = None
|
||||
try:
|
||||
model = getattr(request, "model", None)
|
||||
if model is None:
|
||||
return None, None
|
||||
# langchain BaseChatModel exposes a few different identifiers
|
||||
for attr in ("model_name", "model", "model_id"):
|
||||
value = getattr(model, attr, None)
|
||||
if value:
|
||||
model_id = str(value)
|
||||
break
|
||||
# provider sometimes lives on ``_llm_type`` (legacy) or ``provider``
|
||||
for attr in ("provider", "_llm_type"):
|
||||
value = getattr(model, attr, None)
|
||||
if value:
|
||||
provider = str(value)
|
||||
break
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
return model_id, provider
|
||||
|
||||
|
||||
def _resolve_tool_name(request: Any) -> str:
|
||||
try:
|
||||
tool = getattr(request, "tool", None)
|
||||
if tool is not None:
|
||||
name = getattr(tool, "name", None)
|
||||
if isinstance(name, str) and name:
|
||||
return name
|
||||
# Fall back to the tool_call dict
|
||||
call = getattr(request, "tool_call", None) or {}
|
||||
name = call.get("name") if isinstance(call, dict) else None
|
||||
if isinstance(name, str) and name:
|
||||
return name
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _resolve_input_size(request: Any) -> int | None:
|
||||
try:
|
||||
call = getattr(request, "tool_call", None)
|
||||
if not isinstance(call, dict) or not call:
|
||||
return None
|
||||
args = call.get("args")
|
||||
if args is None:
|
||||
return None
|
||||
return len(repr(args))
|
||||
except Exception: # pragma: no cover — defensive
|
||||
return None
|
||||
|
||||
|
||||
def _annotate_model_request(
|
||||
span: Any, *, model_id: str | None, provider: str | None
|
||||
) -> None:
|
||||
try:
|
||||
span.set_attribute("gen_ai.operation.name", "chat")
|
||||
if model_id:
|
||||
span.set_attribute("gen_ai.request.model", model_id)
|
||||
if provider:
|
||||
span.set_attribute("gen_ai.provider.name", provider)
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
|
||||
|
||||
def _annotate_model_response(
|
||||
span: Any,
|
||||
result: Any,
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
provider: str | None = None,
|
||||
) -> tuple[int | None, int | None]:
|
||||
"""Best-effort: attach prompt/completion token counts when available."""
|
||||
input_tokens: int | None = None
|
||||
output_tokens: int | None = None
|
||||
try:
|
||||
# ModelResponse may be a dataclass with .result containing AIMessage
|
||||
msg: Any
|
||||
if isinstance(result, AIMessage):
|
||||
msg = result
|
||||
else:
|
||||
inner = getattr(result, "result", None)
|
||||
msg = inner[-1] if isinstance(inner, list) and inner else inner
|
||||
if msg is None:
|
||||
return None, None
|
||||
if provider:
|
||||
span.set_attribute("gen_ai.provider.name", provider)
|
||||
if model_id:
|
||||
span.set_attribute("gen_ai.request.model", model_id)
|
||||
response_model = getattr(msg, "response_metadata", {}) or {}
|
||||
if isinstance(response_model, dict):
|
||||
response_model = (
|
||||
response_model.get("model_name")
|
||||
or response_model.get("model")
|
||||
or response_model.get("model_id")
|
||||
)
|
||||
if not response_model:
|
||||
response_model = model_id
|
||||
if response_model:
|
||||
span.set_attribute("gen_ai.response.model", str(response_model))
|
||||
span.set_attribute("gen_ai.operation.name", "chat")
|
||||
usage = getattr(msg, "usage_metadata", None) or {}
|
||||
if isinstance(usage, dict):
|
||||
if (n := usage.get("input_tokens")) is not None:
|
||||
input_tokens = int(n)
|
||||
span.set_attribute("gen_ai.usage.input_tokens", input_tokens)
|
||||
if (n := usage.get("output_tokens")) is not None:
|
||||
output_tokens = int(n)
|
||||
span.set_attribute("gen_ai.usage.output_tokens", output_tokens)
|
||||
if (n := usage.get("total_tokens")) is not None:
|
||||
span.set_attribute("gen_ai.usage.total_tokens", int(n))
|
||||
tool_calls = getattr(msg, "tool_calls", None) or []
|
||||
span.set_attribute("model.tool_calls", len(tool_calls))
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
return input_tokens, output_tokens
|
||||
|
||||
|
||||
def _annotate_tool_result(span: Any, result: Any) -> bool:
|
||||
errored = False
|
||||
try:
|
||||
if isinstance(result, ToolMessage):
|
||||
content = (
|
||||
result.content
|
||||
if isinstance(result.content, str)
|
||||
else repr(result.content)
|
||||
)
|
||||
span.set_attribute("tool.output.size", len(content))
|
||||
status = getattr(result, "status", None)
|
||||
if isinstance(status, str):
|
||||
span.set_attribute("tool.status", status)
|
||||
errored = status.lower() == "error"
|
||||
kwargs = getattr(result, "additional_kwargs", None) or {}
|
||||
if isinstance(kwargs, dict) and kwargs.get("error"):
|
||||
span.set_attribute("tool.error", True)
|
||||
errored = True
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
return errored
|
||||
|
||||
|
||||
__all__ = ["OtelSpanMiddleware"]
|
||||
|
|
@ -1,197 +0,0 @@
|
|||
"""
|
||||
ToolCallNameRepairMiddleware — two-stage tool-name repair.
|
||||
|
||||
Operation:
|
||||
1. **Stage 1 — lowercase repair:** if a tool call's ``name`` is not in
|
||||
the registry but ``name.lower()`` is, rewrite in place. Catches
|
||||
models that emit ``Search`` instead of ``search``.
|
||||
2. **Stage 2 — invalid fallback:** if still unmatched, rewrite the call
|
||||
to ``invalid`` with ``args={"tool": original_name, "error": <error>}``
|
||||
so the registered :func:`invalid_tool` returns the error to the model
|
||||
for self-correction.
|
||||
|
||||
Ported from OpenCode's ``packages/opencode/src/session/llm.ts:339-358``
|
||||
+ ``packages/opencode/src/tool/invalid.ts``. LangChain has no equivalent:
|
||||
:class:`deepagents.middleware.PatchToolCallsMiddleware` patches
|
||||
*dangling* tool calls (no matching ToolMessage) but does nothing about
|
||||
wrong names, and the model framework's default behavior on an unknown
|
||||
name is to crash the turn rather than route to a self-correction
|
||||
fallback.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import difflib
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _coerce_existing_tool_call(call: Any) -> dict[str, Any]:
|
||||
"""Normalize a tool call entry to a mutable dict."""
|
||||
if isinstance(call, dict):
|
||||
return dict(call)
|
||||
return {
|
||||
"name": getattr(call, "name", None),
|
||||
"args": getattr(call, "args", {}),
|
||||
"id": getattr(call, "id", None),
|
||||
"type": "tool_call",
|
||||
}
|
||||
|
||||
|
||||
class ToolCallNameRepairMiddleware(
|
||||
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
|
||||
):
|
||||
"""Two-stage tool-name repair on the most recent ``AIMessage``.
|
||||
|
||||
Args:
|
||||
registered_tool_names: Set of canonically-registered tool names.
|
||||
``invalid`` should be in this set so the fallback dispatches.
|
||||
fuzzy_match_threshold: Optional ``difflib`` ratio (0-1) for the
|
||||
fuzzy-match step that runs *between* lowercase and invalid.
|
||||
Set to ``None`` to disable fuzzy matching (default in
|
||||
OpenCode; we mirror that to avoid silent rewrites).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
registered_tool_names: set[str],
|
||||
fuzzy_match_threshold: float | None = 0.85,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._registered = set(registered_tool_names)
|
||||
self._registered_lower = {name.lower(): name for name in self._registered}
|
||||
self._fuzzy_threshold = fuzzy_match_threshold
|
||||
self.tools = []
|
||||
|
||||
def _registered_for_runtime(self, runtime: Runtime[ContextT]) -> set[str]:
|
||||
"""Allow runtime overrides to expand the set (e.g. dynamic MCP tools)."""
|
||||
ctx_tools = getattr(runtime.context, "registered_tool_names", None)
|
||||
if isinstance(ctx_tools, set | frozenset):
|
||||
return self._registered | set(ctx_tools)
|
||||
if isinstance(ctx_tools, list | tuple):
|
||||
return self._registered | set(ctx_tools)
|
||||
return self._registered
|
||||
|
||||
def _repair_one(
|
||||
self,
|
||||
call: dict[str, Any],
|
||||
registered: set[str],
|
||||
) -> dict[str, Any]:
|
||||
name = call.get("name")
|
||||
if not isinstance(name, str):
|
||||
return call
|
||||
|
||||
if name in registered:
|
||||
return call
|
||||
|
||||
# Stage 1 — lowercase
|
||||
lowered = name.lower()
|
||||
if lowered in registered:
|
||||
call["name"] = lowered
|
||||
metadata = dict(call.get("response_metadata") or {})
|
||||
metadata.setdefault("repair", "lowercase")
|
||||
call["response_metadata"] = metadata
|
||||
return call
|
||||
|
||||
# Optional fuzzy step (off by default — see class docstring)
|
||||
if self._fuzzy_threshold is not None:
|
||||
close = difflib.get_close_matches(
|
||||
name, registered, n=1, cutoff=self._fuzzy_threshold
|
||||
)
|
||||
if close:
|
||||
call["name"] = close[0]
|
||||
metadata = dict(call.get("response_metadata") or {})
|
||||
metadata.setdefault("repair", f"fuzzy:{name}->{close[0]}")
|
||||
call["response_metadata"] = metadata
|
||||
return call
|
||||
|
||||
# Stage 2 — invalid fallback
|
||||
# Local import keeps the middleware module import-light and avoids any
|
||||
# tools <-> middleware import-order coupling at module scope.
|
||||
from app.agents.multi_agent_chat.main_agent.tools.invalid_tool import (
|
||||
INVALID_TOOL_NAME,
|
||||
)
|
||||
|
||||
if INVALID_TOOL_NAME in registered:
|
||||
original_args = call.get("args") or {}
|
||||
error_msg = (
|
||||
f"Tool name '{name}' is not registered. "
|
||||
f"Original arguments were: {original_args!r}."
|
||||
)
|
||||
call["name"] = INVALID_TOOL_NAME
|
||||
call["args"] = {"tool": name, "error": error_msg}
|
||||
metadata = dict(call.get("response_metadata") or {})
|
||||
metadata.setdefault("repair", f"invalid_fallback:{name}")
|
||||
call["response_metadata"] = metadata
|
||||
else:
|
||||
logger.warning(
|
||||
"Could not repair unknown tool call %r; 'invalid' tool not registered",
|
||||
name,
|
||||
)
|
||||
return call
|
||||
|
||||
def _maybe_repair(
|
||||
self,
|
||||
message: AIMessage,
|
||||
registered: set[str],
|
||||
) -> AIMessage | None:
|
||||
if not message.tool_calls:
|
||||
return None
|
||||
|
||||
new_calls: list[dict[str, Any]] = []
|
||||
any_changed = False
|
||||
for raw in message.tool_calls:
|
||||
call = _coerce_existing_tool_call(raw)
|
||||
before = (call.get("name"), call.get("args"))
|
||||
repaired = self._repair_one(call, registered)
|
||||
after = (repaired.get("name"), repaired.get("args"))
|
||||
if before != after:
|
||||
any_changed = True
|
||||
new_calls.append(repaired)
|
||||
|
||||
if not any_changed:
|
||||
return None
|
||||
|
||||
return message.model_copy(update={"tool_calls": new_calls})
|
||||
|
||||
def after_model( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[ResponseT],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage):
|
||||
return None
|
||||
|
||||
registered = self._registered_for_runtime(runtime)
|
||||
repaired = self._maybe_repair(last, registered)
|
||||
if repaired is None:
|
||||
return None
|
||||
return {"messages": [repaired]}
|
||||
|
||||
async def aafter_model( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[ResponseT],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ToolCallNameRepairMiddleware",
|
||||
]
|
||||
Loading…
Add table
Add a link
Reference in a new issue