refactor(agents): colocate 8 main-agent-only middleware as per-concept folders

Each main-agent-only middleware now lives in its own folder under
main_agent/middleware/<concept>/ with builder.py (flag-gated construction)
+ middleware.py (the impl), re-exported via __init__.py. This kills the
cross-folder hop into agents/shared/middleware and keeps each middleware's
two responsibilities (build vs behavior) as colocated siblings.

Moved (impl from shared/middleware, builder from main_agent/middleware):
action_log, anonymous_document, context_editing, doom_loop, knowledge_tree,
noop_injection, otel_span, tool_call_repair.

Impls moved verbatim (git rename, no body edits) so behavior is unchanged.
Builders now import from the local .middleware sibling. stack.py import
paths updated for the 3 renamed folders; shared middleware barrel trimmed;
tests repointed (imports + patch targets).
This commit is contained in:
CREDO23 2026-06-05 11:42:58 +02:00
parent fbd5ccc35a
commit 9493519c61
33 changed files with 149 additions and 83 deletions

View file

@ -1,23 +1,10 @@
"""Middleware components for the SurfSense new chat agent."""
"""Shared middleware components for the SurfSense chat agents."""
from app.agents.shared.middleware.action_log import (
ActionLogMiddleware,
ToolDefinition,
)
from app.agents.shared.middleware.anonymous_document import (
AnonymousDocumentMiddleware,
)
from app.agents.shared.middleware.busy_mutex import BusyMutexMiddleware
from app.agents.shared.middleware.compaction import (
SurfSenseCompactionMiddleware,
create_surfsense_compaction_middleware,
)
from app.agents.shared.middleware.context_editing import (
ClearToolUsesEdit,
SpillingContextEditingMiddleware,
SpillToBackendEdit,
)
from app.agents.shared.middleware.doom_loop import DoomLoopMiddleware
from app.agents.shared.middleware.kb_persistence import (
KnowledgeBasePersistenceMiddleware,
commit_staged_filesystem_state,
@ -25,39 +12,20 @@ from app.agents.shared.middleware.kb_persistence import (
from app.agents.shared.middleware.knowledge_search import (
KnowledgePriorityMiddleware,
)
from app.agents.shared.middleware.knowledge_tree import (
KnowledgeTreeMiddleware,
)
from app.agents.shared.middleware.memory_injection import (
MemoryInjectionMiddleware,
)
from app.agents.shared.middleware.noop_injection import NoopInjectionMiddleware
from app.agents.shared.middleware.otel_span import OtelSpanMiddleware
from app.agents.shared.middleware.permission import PermissionMiddleware
from app.agents.shared.middleware.retry_after import RetryAfterMiddleware
from app.agents.shared.middleware.tool_call_repair import (
ToolCallNameRepairMiddleware,
)
__all__ = [
"ActionLogMiddleware",
"AnonymousDocumentMiddleware",
"BusyMutexMiddleware",
"ClearToolUsesEdit",
"DoomLoopMiddleware",
"KnowledgeBasePersistenceMiddleware",
"KnowledgePriorityMiddleware",
"KnowledgeTreeMiddleware",
"MemoryInjectionMiddleware",
"NoopInjectionMiddleware",
"OtelSpanMiddleware",
"PermissionMiddleware",
"RetryAfterMiddleware",
"SpillToBackendEdit",
"SpillingContextEditingMiddleware",
"SurfSenseCompactionMiddleware",
"ToolCallNameRepairMiddleware",
"ToolDefinition",
"commit_staged_filesystem_state",
"create_surfsense_compaction_middleware",
]

View file

@ -1,388 +0,0 @@
"""Append-only action-log middleware for the SurfSense agent.
Wraps every tool call via :meth:`AgentMiddleware.awrap_tool_call` and writes
a row to :class:`~app.db.AgentActionLog` after the tool returns. Tools opt
into reversibility by declaring a ``reverse`` callable on their
:class:`ToolDefinition`; the rendered descriptor is persisted in
``reverse_descriptor`` for use by
``/api/threads/{thread_id}/revert/{action_id}``.
Design points:
* **Defensive.** Logging never blocks the agent. We catch every exception
on the DB write path and emit a warning; the tool's ``ToolMessage``
result is always returned untouched.
* **Lightweight payload.** Only the tool ``name`` + ``args`` (capped) +
``result_id`` + ``reverse_descriptor`` are stored. Tool output text
remains in the LangGraph checkpoint / spilled tool-output files.
* **Best-effort reversibility.** We invoke ``reverse(args, result_obj)``
with the parsed JSON result when the tool's content is a JSON object;
otherwise the raw text is passed. Exceptions in the reverse callable
are swallowed and logged — a failed descriptor render simply means the
action is NOT marked reversible.
"""
from __future__ import annotations
import json
import logging
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from langchain.agents.middleware import AgentMiddleware
from langchain_core.callbacks import adispatch_custom_event
from langchain_core.messages import ToolMessage
from app.agents.shared.feature_flags import get_flags
if TYPE_CHECKING: # pragma: no cover - type-only
from langchain.agents.middleware.types import ToolCallRequest
from langgraph.types import Command
logger = logging.getLogger(__name__)
@dataclass
class ToolDefinition:
    """Declarative description of a tool, used for reversibility lookups.

    :class:`ActionLogMiddleware` reads only ``name`` and ``reverse``. The
    other fields exist so call sites and tests can describe a tool in one
    place. An action is logged as reversible when ``reverse`` is set and
    produces a descriptor without raising.

    Attributes:
        name: Unique identifier for the tool.
        description: Human-readable summary of what the tool does.
        factory: Optional callable that builds the tool. Not used by the
            middleware; kept for declarative call sites and tests.
        reverse: Optional callable taking the tool's ``(args, result)`` and
            returning a ``ReverseDescriptor`` for the inverse invocation.
    """

    name: str
    description: str = ""
    factory: Callable[[dict[str, Any]], Any] | None = None
    reverse: Callable[[dict[str, Any], Any], dict[str, Any]] | None = None
# Cap for the persisted ``args`` JSON to avoid bloating the action log with
# accidentally-huge inputs. Values are truncated and a flag is set in the
# stored payload so consumers can detect truncation.
_MAX_ARGS_PERSIST_BYTES = 32 * 1024 # 32KB
class ActionLogMiddleware(AgentMiddleware):
    """Write one :class:`AgentActionLog` row for every tool call.

    Intended to sit near the OUTERMOST end of the tool-call wrapping stack so
    it observes the *final* :class:`ToolMessage` once retries, permission
    checks, and dedup logic have run — in practice just inside
    :class:`PermissionMiddleware` and outside
    :class:`DedupHITLToolCallsMiddleware`.

    The middleware does nothing when:

    * the master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set
      (read through :func:`get_flags`),
    * the per-feature flag ``enable_action_log`` is off,
    * ``thread_id`` is ``None``, or
    * persistence raises (the tool-call result is still returned).

    Args:
        thread_id: Primary-key id of the current chat thread. ``None``
            disables logging.
        search_space_id: Search-space id stored on each row.
        user_id: UUID string of the user driving this turn (nullable in
            anonymous mode).
        tool_definitions: Optional mapping of tool name ->
            :class:`ToolDefinition`, used to find a tool's ``reverse``
            callable. Without it no action is marked reversible.
    """

    tools = ()

    def __init__(
        self,
        *,
        thread_id: int | None,
        search_space_id: int,
        user_id: str | None,
        tool_definitions: dict[str, ToolDefinition] | None = None,
    ) -> None:
        super().__init__()
        self._thread_id = thread_id
        self._search_space_id = search_space_id
        self._user_id = user_id
        # Copy so later mutation of the caller's mapping cannot affect us.
        self._tool_definitions = dict(tool_definitions or {})

    def _enabled(self) -> bool:
        """Return ``True`` when flags allow logging and a thread id is set."""
        flags = get_flags()
        return (
            not flags.disable_new_agent_stack
            and bool(flags.enable_action_log)
            and self._thread_id is not None
        )

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
    ) -> ToolMessage | Command[Any]:
        """Run the tool via ``handler`` and log the outcome.

        A handler exception is logged as an error row and then re-raised
        unchanged, so downstream middleware still sees it.
        """
        if not self._enabled():
            return await handler(request)

        try:
            outcome = await handler(request)
        except Exception as exc:
            # Record the failure for audit/revert, then let it propagate.
            await self._record(
                request=request,
                result=None,
                error_payload={"type": type(exc).__name__, "message": str(exc)},
            )
            raise

        await self._record(request=request, result=outcome, error_payload=None)
        return outcome

    async def _record(
        self,
        *,
        request: ToolCallRequest,
        result: ToolMessage | Command[Any] | None,
        error_payload: dict[str, Any] | None,
    ) -> None:
        """Persist one ``agent_action_log`` row, then emit an ``action_log`` event.

        Never raises: a persistence failure is logged at WARNING and skips the
        event; an event-dispatch failure is logged at DEBUG.
        """
        try:
            from app.db import AgentActionLog, shielded_async_session

            tool_name = _resolve_tool_name(request)
            stored_args = _resolve_args_payload(request)
            result_id = _resolve_result_id(result)
            reverse_descriptor, reversible = self._render_reverse(
                tool_name=tool_name,
                args=_resolve_args_dict(request),
                result=result,
            )
            tool_call_id = _resolve_tool_call_id(request)
            chat_turn_id = _resolve_chat_turn_id(request)

            entry = AgentActionLog(
                thread_id=self._thread_id,
                user_id=self._user_id,
                search_space_id=self._search_space_id,
                # ``turn_id`` is the deprecated alias of ``tool_call_id``,
                # kept for one release for safe rollback. New consumers
                # should read ``tool_call_id`` directly.
                turn_id=tool_call_id,
                tool_call_id=tool_call_id,
                chat_turn_id=chat_turn_id,
                message_id=_resolve_message_id(request),
                tool_name=tool_name,
                args=stored_args,
                result_id=result_id,
                reversible=reversible,
                reverse_descriptor=reverse_descriptor,
                error=error_payload,
            )
            async with shielded_async_session() as session:
                session.add(entry)
                await session.commit()
                entry_id = int(entry.id) if entry.id is not None else None
                entry_created_at = entry.created_at
        except Exception:
            logger.warning(
                "ActionLogMiddleware failed to persist action log row",
                exc_info=True,
            )
            return

        # Side-channel event so the chat tool card can show a Revert button
        # once the row is durable. ``stream_new_chat`` translates it into a
        # ``data-action-log`` SSE event. The ``reverse_descriptor`` payload
        # itself is deliberately NOT included; only a presence flag.
        try:
            await adispatch_custom_event(
                "action_log",
                {
                    "id": entry_id,
                    "lc_tool_call_id": tool_call_id,
                    "chat_turn_id": chat_turn_id,
                    "tool_name": tool_name,
                    "reversible": bool(reversible),
                    "reverse_descriptor_present": reverse_descriptor is not None,
                    "created_at": entry_created_at.isoformat()
                    if entry_created_at
                    else None,
                    "error": error_payload is not None,
                },
            )
        except Exception:
            logger.debug(
                "ActionLogMiddleware failed to dispatch action_log event",
                exc_info=True,
            )

    def _render_reverse(
        self,
        *,
        tool_name: str,
        args: dict[str, Any] | None,
        result: ToolMessage | Command[Any] | None,
    ) -> tuple[dict[str, Any] | None, bool]:
        """Call the tool's ``reverse`` callable and return its descriptor.

        Returns ``(descriptor, True)`` only when the result is a
        :class:`ToolMessage`, ``args`` is available, the tool declares a
        ``reverse`` callable, and that callable returns a ``dict`` without
        raising. Every other case yields ``(None, False)``.
        """
        not_reversible: tuple[dict[str, Any] | None, bool] = (None, False)

        if not result or not isinstance(result, ToolMessage) or args is None:
            return not_reversible

        definition = self._tool_definitions.get(tool_name)
        if definition is None or definition.reverse is None:
            return not_reversible

        try:
            descriptor = definition.reverse(args, _parse_tool_result_content(result))
        except Exception:
            logger.warning(
                "Reverse descriptor render failed for tool %s",
                tool_name,
                exc_info=True,
            )
            return not_reversible

        if isinstance(descriptor, dict):
            return descriptor, True
        return not_reversible
# ---------------------------------------------------------------------------
# Resolution helpers — defensive against tool_call request shape variation.
# ---------------------------------------------------------------------------
def _resolve_tool_name(request: Any) -> str:
try:
tool = getattr(request, "tool", None)
if tool is not None:
name = getattr(tool, "name", None)
if isinstance(name, str) and name:
return name
call = getattr(request, "tool_call", None) or {}
if isinstance(call, dict):
name = call.get("name")
if isinstance(name, str) and name:
return name
except Exception: # pragma: no cover - defensive
pass
return "unknown"
def _resolve_args_dict(request: Any) -> dict[str, Any] | None:
try:
call = getattr(request, "tool_call", None)
if not isinstance(call, dict):
return None
args = call.get("args")
if isinstance(args, dict):
return args
return None
except Exception: # pragma: no cover - defensive
return None
def _resolve_args_payload(request: Any) -> dict[str, Any] | None:
    """Return a JSON-serializable copy of the tool-call args, capped in size.

    Returns:
        * ``None`` when the request carries no args dict.
        * ``{"_repr": ...}`` (capped ``repr``) when the args cannot be
          JSON-encoded even with ``default=str``.
        * ``{"_truncated": True, "_size": ..., "_preview": ...}`` when the
          encoded form is longer than ``_MAX_ARGS_PERSIST_BYTES``.
        * Otherwise the args after a JSON round-trip, so every value that
          needed the ``default=str`` fallback is stored as its string form.
    """
    args = _resolve_args_dict(request)
    if args is None:
        return None

    try:
        encoded = json.dumps(args, default=str)
    except Exception:
        return {"_repr": repr(args)[:_MAX_ARGS_PERSIST_BYTES]}

    if len(encoded) > _MAX_ARGS_PERSIST_BYTES:
        return {
            "_truncated": True,
            "_size": len(encoded),
            "_preview": encoded[:_MAX_ARGS_PERSIST_BYTES],
        }

    # Returning the raw ``args`` here would hand values that only serialised
    # through ``default=str`` (datetimes, UUIDs, ...) to the JSON column,
    # where they fail again and the row is lost. Decode what we just encoded
    # so the result really is JSON-safe.
    return json.loads(encoded)
def _resolve_tool_call_id(request: Any) -> str | None:
"""Return the LangChain ``tool_call.id`` for this request, if any."""
try:
call = getattr(request, "tool_call", None) or {}
if isinstance(call, dict):
tid = call.get("id")
if isinstance(tid, str):
return tid
except Exception: # pragma: no cover
pass
return None
# Deprecated alias kept for one release. Old callers and tests treated
# ``turn_id`` as if it carried the LangChain tool_call id; the new column
# lives under ``tool_call_id``. Both resolve to the same value today.
_resolve_turn_id = _resolve_tool_call_id
def _resolve_chat_turn_id(request: Any) -> str | None:
"""Return ``configurable.turn_id`` for this request, if accessible.
``ToolRuntime.config`` is exposed by LangGraph (see
``langgraph/prebuilt/tool_node.py``); the chat-turn correlation id
lives at ``runtime.config["configurable"]["turn_id"]``.
"""
try:
runtime = getattr(request, "runtime", None)
if runtime is None:
return None
config = getattr(runtime, "config", None)
if not isinstance(config, dict):
return None
configurable = config.get("configurable")
if not isinstance(configurable, dict):
return None
value = configurable.get("turn_id")
if isinstance(value, str) and value:
return value
except Exception: # pragma: no cover - defensive
pass
return None
def _resolve_message_id(request: Any) -> str | None:
    """Return the message correlator for ``request``.

    At this layer the tool-call id is the best correlator available, so this
    simply delegates to :func:`_resolve_tool_call_id`.
    """
    correlator = _resolve_tool_call_id(request)
    return correlator
def _resolve_result_id(result: Any) -> str | None:
    """Return ``result.id`` when ``result`` is a ToolMessage with a string id.

    Any other result type, or a non-string id, yields ``None``.
    """
    if not isinstance(result, ToolMessage):
        return None
    candidate = getattr(result, "id", None)
    return candidate if isinstance(candidate, str) else None
def _parse_tool_result_content(result: ToolMessage) -> Any:
content = result.content
if isinstance(content, str):
try:
return json.loads(content)
except (json.JSONDecodeError, ValueError):
return content
return content
__all__ = ["ActionLogMiddleware"]

View file

@ -1,93 +0,0 @@
"""Lightweight middleware that loads the anonymous-session document into state.
Anonymous chats receive a single uploaded document via Redis (no DB row,
read-only). This middleware loads it once on the first turn into
``state['kb_anon_doc']`` so:
* :class:`KnowledgeTreeMiddleware` can render the synthetic ``/documents``
view without touching the DB.
* :class:`KnowledgePriorityMiddleware` skips hybrid search and emits a
degenerate priority list.
* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / ``_load_file_data``)
recognises the synthetic path.
The middleware is a no-op when ``anon_session_id`` is not provided or when
the document is already cached in state.
"""
from __future__ import annotations
import json
import logging
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langgraph.runtime import Runtime
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
SurfSenseFilesystemState,
)
from app.agents.shared.path_resolver import DOCUMENTS_ROOT, safe_filename
logger = logging.getLogger(__name__)
class AnonymousDocumentMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """Load the anonymous user's uploaded document from Redis into state.

    On the first turn the document stored under ``anon:doc:<session_id>`` is
    read and returned as the ``kb_anon_doc`` state update. The hook does
    nothing when no ``anon_session_id`` was supplied, when ``kb_anon_doc`` is
    already present in state, or when the document cannot be loaded.
    """

    tools = ()
    state_schema = SurfSenseFilesystemState

    def __init__(self, *, anon_session_id: str | None) -> None:
        # Initialise the base class like the other middleware in this package.
        super().__init__()
        self.anon_session_id = anon_session_id

    async def abefore_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Return ``{"kb_anon_doc": ...}`` once per session, else ``None``."""
        del runtime
        if not self.anon_session_id:
            return None
        if state.get("kb_anon_doc"):
            # Already cached in state by an earlier turn.
            return None

        anon_doc = await self._load_anon_document()
        if anon_doc is None:
            return None
        return {"kb_anon_doc": anon_doc}

    async def _load_anon_document(self) -> dict[str, Any] | None:
        """Read ``anon:doc:<session_id>`` from Redis.

        Returns:
            A dict with ``path``, ``title``, ``content`` and ``chunks`` keys,
            or ``None`` when the key is absent/empty, Redis access or JSON
            decoding fails, or the decoded payload is not a JSON object.
        """
        try:
            import redis.asyncio as aioredis  # local import to keep cold paths cheap

            from app.config import config

            redis_client = aioredis.from_url(
                config.REDIS_APP_URL, decode_responses=True
            )
            try:
                redis_key = f"anon:doc:{self.anon_session_id}"
                data = await redis_client.get(redis_key)
                if not data:
                    return None
                payload = json.loads(data)
            finally:
                await redis_client.aclose()
        except Exception as exc:
            logger.warning("Failed to load anonymous document from Redis: %s", exc)
            return None

        # Valid JSON is not necessarily an object. Without this guard a list
        # or scalar payload would raise AttributeError on ``.get`` below,
        # outside the ``try``, and abort the agent run.
        if not isinstance(payload, dict):
            logger.warning(
                "Anonymous document payload in Redis is not a JSON object (got %s)",
                type(payload).__name__,
            )
            return None

        title = str(payload.get("filename") or "uploaded_document")
        content = str(payload.get("content") or "")
        path = f"{DOCUMENTS_ROOT}/{safe_filename(title)}"
        return {
            "path": path,
            "title": title,
            "content": content,
            # A single synthetic chunk; empty documents carry no chunks.
            "chunks": [{"chunk_id": -1, "content": content}] if content else [],
        }
__all__ = ["AnonymousDocumentMiddleware"]

View file

@ -1,350 +0,0 @@
"""
SpillToBackendEdit + SpillingContextEditingMiddleware.
LangChain's :class:`ClearToolUsesEdit` discards old ``ToolMessage.content``
when the context-editing budget triggers, replacing the body with a fixed
placeholder. That's lossy: anything the agent might want to revisit is
gone. The spill-to-disk pattern (originally from OpenCode's
``opencode/packages/opencode/src/tool/truncate.ts``) keeps the prune
behavior but writes the full original payload to the runtime backend
under ``/tool_outputs/{thread_id}/{message_id}.txt`` first. The
placeholder is then upgraded to point at the spill path so the agent
(or a subagent) can read it back on demand.
Why this is a middleware subclass instead of a plain ``ContextEdit``:
``ContextEdit.apply`` is sync, but writing to the backend is async. We
capture the spill payloads inside ``apply`` and flush them via
``await backend.aupload_files(...)`` from ``awrap_model_call`` *before*
delegating to the handler, so the explore subagent can always read what
the placeholder advertises.
"""
from __future__ import annotations
import logging
import threading
from collections.abc import Awaitable, Callable, Sequence
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from langchain.agents.middleware.context_editing import (
ClearToolUsesEdit,
ContextEdit,
ContextEditingMiddleware,
TokenCounter,
)
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
ToolMessage,
)
from langchain_core.messages.utils import count_tokens_approximately
from langgraph.config import get_config
if TYPE_CHECKING:
from deepagents.backends.protocol import BackendProtocol
from langchain.agents.middleware.types import (
ModelRequest,
ModelResponse,
)
logger = logging.getLogger(__name__)
DEFAULT_SPILL_PREFIX = "/tool_outputs"
def _build_spill_placeholder(spill_path: str) -> str:
"""Build the user-facing placeholder text shown to the model."""
return (
f"[cleared — full output at {spill_path}; ask the explore subagent to read it]"
)
def _get_thread_id_or_session() -> str:
    """Return the current thread id as a string, or ``"no_thread"``.

    The id is read from ``configurable.thread_id`` of the LangGraph config.
    When ``get_config()`` raises ``RuntimeError`` (e.g. no config is
    available, as in unit tests) or the id is absent, the constant
    ``"no_thread"`` is returned. The exact value only has to stay stable
    within one stream so placeholder paths match the upload paths.
    """
    fallback = "no_thread"
    try:
        configurable = get_config().get("configurable", {})
        resolved = configurable.get("thread_id")
        if resolved is None:
            return fallback
        return str(resolved)
    except RuntimeError:
        return fallback
@dataclass(slots=True)
class SpillToBackendEdit(ContextEdit):
    """Capture-and-replace context edit that spills full tool output to the backend.

    Behaves like :class:`ClearToolUsesEdit` (same trigger / keep / exclude
    semantics) **and** records the original ``ToolMessage.content`` in
    :attr:`pending_spills` so the wrapping middleware can flush them
    before the model call.

    Args:
        trigger: Token threshold above which the edit fires.
        clear_at_least: Minimum number of tokens to reclaim (best effort).
        keep: Number of most-recent ``ToolMessage`` instances to leave
            untouched.
        exclude_tools: Names of tools whose output is NOT spilled.
        clear_tool_inputs: Also clear the originating ``AIMessage.tool_calls``
            args when their pair is cleared.
        path_prefix: Path under the backend where spills are written.
            Default ``"/tool_outputs"``.
    """

    trigger: int = 100_000
    clear_at_least: int = 0
    keep: int = 3
    clear_tool_inputs: bool = False
    exclude_tools: Sequence[str] = ()
    path_prefix: str = DEFAULT_SPILL_PREFIX

    # (spill_path, encoded_payload) pairs captured by ``apply`` and awaiting
    # upload. Guarded by ``_lock``.
    pending_spills: list[tuple[str, bytes]] = field(default_factory=list)
    _lock: threading.Lock = field(default_factory=threading.Lock)

    def drain_pending(self) -> list[tuple[str, bytes]]:
        """Return and clear the pending-spill list atomically."""
        with self._lock:
            out = list(self.pending_spills)
            self.pending_spills.clear()
        return out

    def apply(
        self,
        messages: list[AnyMessage],
        *,
        count_tokens: TokenCounter,
    ) -> None:
        """Mirror ``ClearToolUsesEdit.apply`` but capture originals first.

        Mutates ``messages`` in place: each eligible ``ToolMessage`` is
        replaced by a copy whose content is the spill placeholder, and the
        original content is queued in :attr:`pending_spills`. Nothing happens
        when the token count is at or below ``trigger`` or when there are no
        more than ``keep`` tool messages.
        """
        tokens = count_tokens(messages)
        if tokens <= self.trigger:
            return

        candidates = [
            (idx, msg)
            for idx, msg in enumerate(messages)
            if isinstance(msg, ToolMessage)
        ]
        if self.keep >= len(candidates):
            return
        if self.keep:
            # Leave the ``keep`` most recent tool messages untouched.
            candidates = candidates[: -self.keep]

        thread_id = _get_thread_id_or_session()
        excluded_tools = set(self.exclude_tools)

        for idx, tool_message in candidates:
            if tool_message.response_metadata.get("context_editing", {}).get("cleared"):
                continue

            # Scan backwards for the nearest preceding AIMessage and remember
            # its *position*. A later ``messages.index(ai_message)`` compares
            # by equality, costs O(n), and could land on an earlier message
            # that merely compares equal.
            ai_idx = next(
                (
                    pos
                    for pos in range(idx - 1, -1, -1)
                    if isinstance(messages[pos], AIMessage)
                ),
                None,
            )
            if ai_idx is None:
                continue
            ai_message = messages[ai_idx]

            tool_call = next(
                (
                    call
                    for call in ai_message.tool_calls
                    if call.get("id") == tool_message.tool_call_id
                ),
                None,
            )
            if tool_call is None:
                continue

            tool_name = tool_message.name or tool_call["name"]
            if tool_name in excluded_tools:
                continue

            message_id = tool_message.id or tool_message.tool_call_id or "unknown"
            spill_path = f"{self.path_prefix}/{thread_id}/{message_id}.txt"

            # Capture the original payload before the message is overwritten.
            original = tool_message.content
            payload = self._encode_payload(original)
            with self._lock:
                self.pending_spills.append((spill_path, payload))

            messages[idx] = tool_message.model_copy(
                update={
                    "artifact": None,
                    "content": _build_spill_placeholder(spill_path),
                    "response_metadata": {
                        **tool_message.response_metadata,
                        "context_editing": {
                            "cleared": True,
                            "strategy": "spill_to_backend",
                            "spill_path": spill_path,
                        },
                    },
                }
            )

            if self.clear_tool_inputs:
                messages[ai_idx] = self._clear_input_args(
                    ai_message, tool_message.tool_call_id or ""
                )

            if self.clear_at_least > 0:
                new_token_count = count_tokens(messages)
                cleared_tokens = max(0, tokens - new_token_count)
                if cleared_tokens >= self.clear_at_least:
                    break

    @staticmethod
    def _encode_payload(content: Any) -> bytes:
        """Serialize ``ToolMessage.content`` to bytes for upload.

        ``bytes`` pass through, ``str`` is UTF-8 encoded, anything else is
        JSON-encoded (``default=str``) with ``str(content)`` as last resort.
        """
        if isinstance(content, bytes):
            return content
        if isinstance(content, str):
            return content.encode("utf-8")
        try:
            import json

            return json.dumps(content, default=str).encode("utf-8")
        except Exception:
            return str(content).encode("utf-8")

    @staticmethod
    def _clear_input_args(message: AIMessage, tool_call_id: str) -> AIMessage:
        """Return a copy of ``message`` with the matching tool call's args emptied.

        When a call was cleared, its id is added to the sorted
        ``response_metadata["context_editing"]["cleared_tool_inputs"]`` list.
        """
        updated_tool_calls: list[dict[str, Any]] = []
        cleared_any = False
        for tool_call in message.tool_calls:
            updated = dict(tool_call)
            if updated.get("id") == tool_call_id:
                updated["args"] = {}
                cleared_any = True
            updated_tool_calls.append(updated)

        metadata = dict(getattr(message, "response_metadata", {}))
        if cleared_any:
            ctx = dict(metadata.get("context_editing", {}))
            ids = set(ctx.get("cleared_tool_inputs", []))
            ids.add(tool_call_id)
            ctx["cleared_tool_inputs"] = sorted(ids)
            metadata["context_editing"] = ctx

        return message.model_copy(
            update={
                "tool_calls": updated_tool_calls,
                "response_metadata": metadata,
            }
        )
BackendResolver = "Callable[[Any], BackendProtocol] | BackendProtocol"
class SpillingContextEditingMiddleware(ContextEditingMiddleware):
    """:class:`ContextEditingMiddleware` that flushes :class:`SpillToBackendEdit` writes.

    Applies the configured edits to a deep copy of the request messages,
    uploads any spills the edits queued, and only then delegates to the model
    handler with the edited messages. Spill failures are logged but never
    abort the model call — the placeholder text is already in the message, so
    the worst case is a placeholder the agent cannot follow up on.
    """

    def __init__(
        self,
        *,
        edits: Sequence[ContextEdit],
        backend_resolver: BackendResolver | None = None,
        token_count_method: str = "approximate",
    ) -> None:
        super().__init__(edits=list(edits), token_count_method=token_count_method)  # type: ignore[arg-type]
        self._backend_resolver = backend_resolver

    def _resolve_backend(self, request: ModelRequest) -> BackendProtocol | None:
        """Return the spill backend for ``request``, or ``None``.

        A non-callable resolver is returned as-is. A callable resolver is
        invoked with a ``ToolRuntime`` assembled from the request; if that
        raises, the error is logged and ``None`` is returned.
        """
        resolver = self._backend_resolver
        if resolver is None:
            return None
        if not callable(resolver):
            return resolver  # type: ignore[return-value]

        try:
            from langchain.tools import ToolRuntime

            tool_runtime = ToolRuntime(
                state=getattr(request, "state", {}),
                context=getattr(request.runtime, "context", None),
                stream_writer=getattr(request.runtime, "stream_writer", None),
                store=getattr(request.runtime, "store", None),
                config=getattr(request.runtime, "config", None) or {},
                tool_call_id=None,
            )
            return resolver(tool_runtime)
        except Exception:
            logger.exception("Failed to resolve spill backend")
            return None

    def _collect_pending(self) -> list[tuple[str, bytes]]:
        """Drain and return the queued spills of every ``SpillToBackendEdit``."""
        return [
            spill
            for edit in self.edits
            if isinstance(edit, SpillToBackendEdit)
            for spill in edit.drain_pending()
        ]

    async def awrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> Any:
        """Apply edits, flush spills, then call ``handler`` with edited messages."""
        if not request.messages:
            return await handler(request)

        if self.token_count_method == "approximate":

            def count_tokens(messages: Sequence[BaseMessage]) -> int:
                return count_tokens_approximately(messages)

        else:
            system_msg = [request.system_message] if request.system_message else []

            def count_tokens(messages: Sequence[BaseMessage]) -> int:
                return request.model.get_num_tokens_from_messages(
                    system_msg + list(messages), request.tools
                )

        # Edits mutate their input, so work on a private copy.
        edited_messages = deepcopy(list(request.messages))
        for edit in self.edits:
            edit.apply(edited_messages, count_tokens=count_tokens)

        pending = self._collect_pending()
        if pending:
            backend = self._resolve_backend(request)
            if backend is None:
                logger.warning(
                    "SpillToBackendEdit produced %d pending spills but no backend "
                    "resolver was configured; content is unrecoverable",
                    len(pending),
                )
            else:
                # Upload before the model call so the advertised paths exist.
                try:
                    await backend.aupload_files(pending)
                except Exception:
                    logger.exception(
                        "Spill-to-backend upload failed (%d files); placeholders "
                        "remain in messages but content is unrecoverable",
                        len(pending),
                    )

        return await handler(request.override(messages=edited_messages))
__all__ = [
"DEFAULT_SPILL_PREFIX",
"ClearToolUsesEdit",
"SpillToBackendEdit",
"SpillingContextEditingMiddleware",
"_build_spill_placeholder",
]

View file

@ -1,238 +0,0 @@
"""
DoomLoopMiddleware — pattern-based detector for repeated identical tool calls.
LangChain has :class:`ToolCallLimitMiddleware` which caps the *total* number
of tool calls per turn — but it can't tell apart "10 distinct, useful
calls" from "the same call 10 times in a row". This middleware fills that
gap with a sliding-window check on tool-call signatures, ported from
OpenCode's ``packages/opencode/src/session/processor.ts``.
When the same tool with the same arguments is called N times in a row,
the agent has likely entered an infinite loop. We surface this to the
user as an interrupt with ``permission="doom_loop"`` so the UI can
render an "Are you stuck? Continue / cancel?" affordance.
This ships **OFF by default** until the frontend explicitly handles
``context.permission == "doom_loop"`` interrupts.
Wire format: uses SurfSense's existing ``interrupt()`` payload shape
(see ``app/agents/shared/tools/hitl.py``):
{
"type": "permission_ask",
"action": {"tool": <name>, "params": <args>},
"context": {"permission": "doom_loop", "recent_signatures": [...]},
}
so the frontend that already handles HITL prompts can render this with
no changes beyond a string check.
"""
from __future__ import annotations
import hashlib
import json
import logging
from collections import deque
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ResponseT,
)
from langchain_core.messages import AIMessage
from langgraph.config import get_config
from langgraph.runtime import Runtime
from langgraph.types import interrupt
from app.observability import metrics as ot_metrics, otel as ot
logger = logging.getLogger(__name__)
def _signature(name: str, args: Any) -> str:
"""Hash a tool call ``(name, args)`` to a short signature."""
try:
canonical = json.dumps(args, sort_keys=True, default=str)
except (TypeError, ValueError):
canonical = repr(args)
digest = hashlib.sha1(f"{name}::{canonical}".encode()).hexdigest()
return digest[:16]
class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Detect repeated identical tool calls and prompt the user.
Tracks a sliding window of the most-recent ``threshold`` tool-call
signatures across the live request. When all entries match, raise
a SurfSense-style HITL interrupt with ``permission="doom_loop"``.
Args:
threshold: How many consecutive identical signatures count as a
doom loop. Default 3 (matches OpenCode's processor.ts).
"""
def __init__(self, *, threshold: int = 3) -> None:
    """Create the detector.

    Args:
        threshold: Number of consecutive identical signatures that counts
            as a doom loop.

    Raises:
        ValueError: If ``threshold`` is below 2.
    """
    super().__init__()
    if threshold < 2:
        raise ValueError("DoomLoopMiddleware threshold must be >= 2")
    self._threshold = threshold
    self.tools = []
    # One sliding window per thread id, held in memory for the lifetime of
    # this instance. Keeping it out of graph state avoids state-schema
    # gymnastics.
    self._windows: dict[str, deque[str]] = {}
@staticmethod
def _thread_id_from_runtime(runtime: Runtime[ContextT]) -> str:
    """Resolve the thread id used to key the sliding windows.

    LangGraph's ``get_config()`` is tried first — it is the only way to read
    ``RunnableConfig`` inside a node, since :class:`Runtime` does NOT carry a
    ``config`` attribute. ``runtime.config`` is the fallback, for unit tests
    that synthesize a config-bearing stub. ``"no_thread"`` is returned only
    when both lookups fail — that collapses all threads into one window, so
    a debug log is emitted.
    """

    def _extract(cfg: Any) -> str | None:
        # Pull ``configurable.thread_id`` out of a config dict, as a string.
        if not isinstance(cfg, dict):
            return None
        found = (cfg.get("configurable") or {}).get("thread_id")
        return None if found is None else str(found)

    try:
        resolved = _extract(get_config())
    except Exception:
        resolved = None

    if resolved is None:
        resolved = _extract(getattr(runtime, "config", None))
    if resolved is not None:
        return resolved

    logger.debug(
        "DoomLoopMiddleware: no thread_id resolved from RunnableConfig; "
        "falling back to shared 'no_thread' window."
    )
    return "no_thread"
def _window(self, thread_id: str) -> deque[str]:
win = self._windows.get(thread_id)
if win is None:
win = deque(maxlen=self._threshold)
self._windows[thread_id] = win
return win
def _detect(
    self, message: AIMessage, runtime: Runtime[ContextT]
) -> tuple[bool, list[str], dict[str, Any] | None]:
    """Feed the message's tool calls into the thread's window and check for a loop.

    Each tool call with a string name is converted to a signature and
    appended to the sliding window. Scanning stops at the first call
    that leaves the window full of one repeated signature.

    Returns:
        ``(triggered, signatures, action)`` where ``signatures`` is a
        snapshot of the window and ``action`` is
        ``{"name": ..., "params": ...}`` for the offending call, or
        ``None`` when nothing triggered. A message without tool calls
        yields ``(False, [], None)`` without touching the window.
    """
    if not message.tool_calls:
        return False, [], None

    window = self._window(self._thread_id_from_runtime(runtime))
    for call in message.tool_calls:
        # Tool calls may arrive as plain dicts or as attribute-style objects.
        if isinstance(call, dict):
            name, args = call.get("name"), call.get("args")
        else:
            name, args = getattr(call, "name", None), getattr(call, "args", {})
        if not isinstance(name, str):
            continue
        window.append(_signature(name, args))
        window_is_full = len(window) >= self._threshold
        if window_is_full and len(set(window)) == 1:
            return True, list(window), {"name": name, "params": args or {}}
    return False, list(window), None
def after_model(  # type: ignore[override]
    self,
    state: AgentState[ResponseT],
    runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
    """Interrupt the run when the latest AI message completes a doom loop.

    Only the last message in ``state["messages"]`` is inspected, and only
    when it is an ``AIMessage``. When a loop is detected a
    ``permission_ask`` interrupt tagged ``permission="doom_loop"`` is
    raised; once a decision comes back the thread's window is cleared.

    Returns:
        ``{"jump_to": "end"}`` when the decision's ``decision_type`` (or
        ``type``) contains "reject" or "cancel"; otherwise ``None`` so
        the tool call proceeds.
    """
    messages = state.get("messages") or []
    if not messages:
        return None
    last = messages[-1]
    if not isinstance(last, AIMessage):
        return None
    triggered, signatures, action = self._detect(last, runtime)
    if not triggered:
        return None

    # The detected action carries the tool under "name"; the placeholder
    # action used below carries it under "tool". Accept either so the log
    # line and the span attribute always report the real tool name.
    details = action or {}
    tool_name = details.get("name") or details.get("tool") or "<unknown>"

    logger.warning(
        "Doom loop detected: tool %s called %d times in a row (sig=%s)",
        tool_name,
        self._threshold,
        signatures[-1] if signatures else "<empty>",
    )
    # Open an interrupt.raised span with permission=doom_loop attribute
    # so dashboards can break out doom-loop interrupts from regular
    # permission asks via the ``interrupt.permission`` attribute.
    with ot.interrupt_span(
        interrupt_type="permission_ask",
        extra={
            "interrupt.permission": "doom_loop",
            "interrupt.threshold": self._threshold,
            "interrupt.tool": tool_name,
        },
    ):
        ot_metrics.record_interrupt(interrupt_type="permission_ask")
        decision = interrupt(
            {
                "type": "permission_ask",
                "action": action or {"tool": "<unknown>", "params": {}},
                "context": {
                    "permission": "doom_loop",
                    "recent_signatures": signatures,
                    "threshold": self._threshold,
                },
            }
        )
    # Reset window so the next decision (continue/cancel) starts fresh.
    thread_id = self._thread_id_from_runtime(runtime)
    self._windows.pop(thread_id, None)
    # Only dict-shaped decisions are interpreted. Matching is by
    # substring, so any reply whose kind contains "reject" or "cancel"
    # ends the run; everything else lets the tool call go ahead.
    if isinstance(decision, dict):
        kind = str(
            decision.get("decision_type") or decision.get("type") or ""
        ).lower()
        if "reject" in kind or "cancel" in kind:
            return {"jump_to": "end"}
    return None
async def aafter_model(  # type: ignore[override]
    self,
    state: AgentState[ResponseT],
    runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
    """Async entry point; delegates directly to the synchronous ``after_model``."""
    outcome = self.after_model(state, runtime)
    return outcome
__all__ = [
"DoomLoopMiddleware",
"_signature",
]

View file

@ -1,334 +0,0 @@
"""Workspace-tree middleware for the SurfSense agent.
Renders the full ``Folder``+``Document`` tree under ``/documents/`` once per
turn (cloud only), caches it by ``(search_space_id, tree_version)``, and
injects the result as a ``<workspace_tree>`` system message immediately
before the latest human turn.
The render is bounded by two truncation layers:
1. **Entry cap** — at most ``MAX_TREE_ENTRIES`` lines. The remainder is
replaced with a "use ls" hint.
2. **Token cap** — at most ``MAX_TREE_TOKENS`` tokens (using the LLM's
token-count profile when available). If the entry-truncated tree still
exceeds the token cap, we fall back to a root-only summary.
Anonymous mode renders only ``state['kb_anon_doc']`` (no DB calls).
This middleware also performs a one-time initialization of ``state['cwd']``
to ``"/documents"`` so subsequent middlewares and tools always see a valid
cwd in cloud mode.
"""
from __future__ import annotations
import asyncio
import logging
import time
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import SystemMessage
from langgraph.runtime import Runtime
from sqlalchemy import select
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
SurfSenseFilesystemState,
)
from app.agents.shared.filesystem_selection import FilesystemMode
from app.agents.shared.path_resolver import (
DOCUMENTS_ROOT,
PathIndex,
build_path_index,
doc_to_virtual_path,
)
from app.db import Document, shielded_async_session
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
try:
from litellm import token_counter
except Exception: # pragma: no cover - optional dep
token_counter = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
# Default cap on the number of rendered tree lines before the listing is
# cut off and replaced by a "use ls" hint.
MAX_TREE_ENTRIES = 500
# Default token budget for the rendered tree; a tree that exceeds it is
# replaced by a root-only summary.
MAX_TREE_TOKENS = 4000
def _approx_tokens(text: str) -> int:
"""Cheap fallback token estimate (1 token ~= 4 chars)."""
return max(1, (len(text) + 3) // 4)
def _count_tokens(text: str, *, llm: BaseChatModel | None) -> int:
    """Count the tokens in ``text``, preferring the most specific counter available.

    Strategies are tried in this order, each falling through on failure:

    1. ``llm._count_tokens`` when the model exposes a callable of that name.
    2. ``token_counter`` with a model name taken from the model's
       ``profile`` (``token_count_models`` first, then
       ``token_count_model``) or, failing that, from ``llm.model``.
    3. The character-based approximation.

    Args:
        text: The text to measure.
        llm: The chat model whose counting hints should be used, or
            ``None`` to go straight to the approximation.

    Returns:
        The token count as an int.
    """
    if llm is None:
        return _approx_tokens(text)

    native_counter = getattr(llm, "_count_tokens", None)
    if callable(native_counter):
        try:
            return int(native_counter([{"role": "user", "content": text}]))
        except Exception:
            # Best effort: fall through to the next strategy.
            pass

    candidates: list[str] = []
    profile = getattr(llm, "profile", None)
    if isinstance(profile, dict):
        listed = profile.get("token_count_models")
        if isinstance(listed, list):
            candidates += [entry for entry in listed if isinstance(entry, str) and entry]
        single = profile.get("token_count_model")
        if isinstance(single, str) and single and single not in candidates:
            candidates.append(single)

    model_name = candidates[0] if candidates else getattr(llm, "model", None)
    have_model_name = isinstance(model_name, str) and bool(model_name)
    if not have_model_name or token_counter is None:
        return _approx_tokens(text)

    try:
        counted = token_counter(
            messages=[{"role": "user", "content": text}],
            model=model_name,
        )
        return int(counted)
    except Exception:
        return _approx_tokens(text)
class KnowledgeTreeMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """Inject the workspace folder/document tree into the agent's context.

    In cloud filesystem mode the tree is rendered before the agent runs,
    stored in the state update as ``workspace_tree_text`` and, unless
    disabled, inserted as a ``SystemMessage`` just before the last
    message. Rendered trees are cached per
    ``(search_space_id, tree_version)``. In any other mode the middleware
    does nothing.
    """

    tools = ()
    state_schema = SurfSenseFilesystemState

    def __init__(
        self,
        *,
        search_space_id: int,
        filesystem_mode: FilesystemMode,
        llm: BaseChatModel | None = None,
        max_entries: int = MAX_TREE_ENTRIES,
        max_tokens: int = MAX_TREE_TOKENS,
        inject_system_message: bool = True,  # For backwards compatibility
    ) -> None:
        """Store configuration and create an empty render cache.

        Args:
            search_space_id: Search space whose documents are rendered.
            filesystem_mode: Active filesystem mode; only
                ``FilesystemMode.CLOUD`` produces a tree.
            llm: Optional model used for token counting of the render.
            max_entries: Maximum number of tree lines before truncation.
            max_tokens: Token budget above which the root-only summary
                replaces the full render.
            inject_system_message: Whether to also insert the tree into
                ``messages`` as a ``SystemMessage``.
        """
        self.search_space_id = search_space_id
        self.filesystem_mode = filesystem_mode
        self.llm = llm
        self.max_entries = max_entries
        self.max_tokens = max_tokens
        self.inject_system_message = inject_system_message
        # Rendered trees keyed by (search_space_id, tree_version, False).
        # Entries are never evicted.
        self._cache: dict[tuple[int, int, bool], str] = {}

    async def abefore_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Build the state update carrying the rendered workspace tree.

        Returns:
            ``None`` outside cloud mode. Otherwise a dict that always
            contains ``workspace_tree_text``, contains ``cwd`` when the
            state had none, and contains ``messages`` when system-message
            injection is enabled.
        """
        del runtime
        if self.filesystem_mode != FilesystemMode.CLOUD:
            return None

        start = time.perf_counter()
        update: dict[str, Any] = {}

        # One-time initialisation of the working directory.
        if not state.get("cwd"):
            update["cwd"] = DOCUMENTS_ROOT

        anon_doc = state.get("kb_anon_doc")
        if anon_doc:
            tree_msg = self._render_anon_tree(anon_doc)
            cache_outcome = "anon"
        else:
            # The key is computed here only to label the perf log; the
            # actual cache lookup happens inside ``_render_kb_tree``.
            version = int(state.get("tree_version") or 0)
            cache_key = (self.search_space_id, version, False)
            cache_outcome = "hit" if cache_key in self._cache else "miss"
            tree_msg = await self._render_kb_tree(state)

        update["workspace_tree_text"] = tree_msg

        if self.inject_system_message:
            messages = list(state.get("messages") or [])
            # Place the tree immediately before the last message.
            insert_at = max(len(messages) - 1, 0)
            messages.insert(insert_at, SystemMessage(content=tree_msg))
            update["messages"] = messages

        _perf_log.info(
            "[knowledge_tree] cache=%s chars=%d elapsed=%.3fs space=%d",
            cache_outcome,
            len(tree_msg),
            time.perf_counter() - start,
            self.search_space_id,
        )
        return update

    def before_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Synchronous variant of :meth:`abefore_agent`.

        When called from inside a running event loop this returns ``None``
        (no tree is rendered) because ``asyncio.run`` cannot be nested;
        otherwise it drives the async implementation to completion.
        """
        try:
            loop = asyncio.get_running_loop()
            if loop.is_running():
                return None
        except RuntimeError:
            # No running loop: safe to start one below.
            pass
        return asyncio.run(self.abefore_agent(state, runtime))

    # ------------------------------------------------------------------ render

    def _render_anon_tree(self, anon_doc: dict[str, Any]) -> str:
        """Render the single-document tree used for anonymous sessions."""
        path = str(anon_doc.get("path") or "")
        title = str(anon_doc.get("title") or "uploaded_document")
        # NOTE(review): ``path`` and ``title`` are emitted back to back with
        # no separator — confirm this is the intended layout.
        return (
            "<workspace_tree>\n"
            "Anonymous session — only one read-only document is available.\n"
            f"{DOCUMENTS_ROOT}/\n"
            f" {path}{title}\n"
            "</workspace_tree>"
        )

    async def _render_kb_tree(self, state: AgentState) -> str:
        """Render (or fetch from cache) the tree for the current tree version.

        On a database error a placeholder ``(unavailable)`` tree is
        returned and nothing is cached.
        """
        version = int(state.get("tree_version") or 0)
        cache_key = (self.search_space_id, version, False)
        cached = self._cache.get(cache_key)
        if cached is not None:
            return cached

        try:
            async with shielded_async_session() as session:
                index = await build_path_index(session, self.search_space_id)
                doc_rows = await session.execute(
                    select(Document.id, Document.title, Document.folder_id).where(
                        Document.search_space_id == self.search_space_id
                    )
                )
                docs = list(doc_rows.all())
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning("knowledge_tree: DB error %s", exc)
            return "<workspace_tree>\n(unavailable)\n</workspace_tree>"

        rendered = self._format_tree(index, docs)
        self._cache[cache_key] = rendered
        return rendered

    def _format_tree(self, index: PathIndex, docs: list[Any]) -> str:
        """Format folders and documents as an indented listing.

        The listing is cut off after ``max_entries`` lines (with a hint on
        how to expand), and is replaced by the root-only summary when its
        token count exceeds ``max_tokens``.
        """
        folder_paths = sorted(set(index.folder_paths.values()))
        doc_paths = sorted(
            doc_to_virtual_path(
                doc_id=row.id,
                title=str(row.title or "untitled"),
                folder_id=row.folder_id,
                index=index,
            )
            for row in docs
        )
        all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT]))
        # Set copy for O(1) "is this path a folder?" checks in the loop
        # below (the sorted list is still what the helpers receive).
        folder_set = set(folder_paths)

        # Pre-compute which folders have at least one descendant (folder or doc).
        # A folder is "empty" iff no path in `all_paths` is strictly under it.
        # Used to emit an explicit "(empty)" marker so the LLM doesn't have to
        # infer emptiness from indentation alone.
        non_empty_folders = self._compute_non_empty_folders(folder_paths, doc_paths)

        lines: list[str] = []
        for path in all_paths:
            # Depth = number of path segments below the documents root.
            depth = (
                0
                if path == DOCUMENTS_ROOT
                else len([p for p in path[len(DOCUMENTS_ROOT) :].split("/") if p])
            )
            indent = " " * depth
            is_dir = path == DOCUMENTS_ROOT or path in folder_set
            display = (
                path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents"
            )
            if is_dir:
                if path != DOCUMENTS_ROOT and path not in non_empty_folders:
                    lines.append(f"{indent}{display}/ (empty)")
                else:
                    lines.append(f"{indent}{display}/")
            else:
                lines.append(f"{indent}{display}")

            if len(lines) >= self.max_entries:
                remaining = len(all_paths) - len(lines)
                if remaining > 0:
                    lines.append(
                        f"... {remaining} more entries — use "
                        "ls('/documents/<folder>', offset, limit) to expand"
                    )
                break

        body = "\n".join(lines)
        rendered = f"<workspace_tree>\n{body}\n</workspace_tree>"

        token_count = _count_tokens(rendered, llm=self.llm)
        if token_count <= self.max_tokens:
            return rendered

        return self._format_root_summary(folder_paths, doc_paths)

    @staticmethod
    def _compute_non_empty_folders(
        folder_paths: list[str], doc_paths: list[str]
    ) -> set[str]:
        """Return the set of folder paths that contain at least one descendant.

        A folder is "non-empty" if any document path or any other folder
        path is strictly under it. Every document marks each of its known
        ancestor folders as non-empty. Every folder marks its ancestors as
        non-empty for as long as the chain of ancestors stays inside the
        known folder set, so in a chain of otherwise empty nested folders
        only the innermost one is left unmarked.
        """

        def _ancestors(path: str):
            # Yield successive parents of ``path`` up to (excluding) the
            # documents root. Stops when no further "/" can be stripped so
            # a slash-less path cannot loop forever.
            parent = path.rsplit("/", 1)[0]
            while parent and parent != DOCUMENTS_ROOT:
                yield parent
                if "/" not in parent:
                    break
                parent = parent.rsplit("/", 1)[0]

        non_empty: set[str] = set()
        folder_set = set(folder_paths)

        for doc_path in doc_paths:
            for ancestor in _ancestors(doc_path):
                if ancestor in folder_set:
                    non_empty.add(ancestor)

        for child in folder_paths:
            for ancestor in _ancestors(child):
                # Stop climbing as soon as the chain leaves the known folders.
                if ancestor not in folder_set:
                    break
                non_empty.add(ancestor)

        return non_empty

    def _format_root_summary(
        self, folder_paths: list[str], doc_paths: list[str]
    ) -> str:
        """Render the fallback summary: top-level folders with document counts.

        Documents directly under the root are reported as a single
        "loose documents" line; folders without documents appear with a
        count of zero.
        """
        top_level: dict[str, int] = {}
        loose_docs = 0
        for path in doc_paths:
            rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
            if "/" in rel:
                top = rel.split("/", 1)[0]
                top_level[top] = top_level.get(top, 0) + 1
            else:
                loose_docs += 1
        for path in folder_paths:
            rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
            if not rel:
                continue
            top = rel.split("/", 1)[0]
            top_level.setdefault(top, 0)

        lines = [DOCUMENTS_ROOT + "/"]
        for name in sorted(top_level):
            count = top_level[name]
            lines.append(f" {name}/ ({count} document{'s' if count != 1 else ''})")
        if loose_docs:
            lines.append(
                f" ({loose_docs} loose document{'s' if loose_docs != 1 else ''})"
            )
        lines.append(
            "Tree is large; use list_tree('/documents/<folder>') to drill in "
            "or ls('/documents/<folder>', offset, limit) for paginated listings."
        )
        return "<workspace_tree>\n" + "\n".join(lines) + "\n</workspace_tree>"
__all__ = ["KnowledgeTreeMiddleware"]

View file

@ -1,141 +0,0 @@
"""
``_noop`` provider-compatibility tool + injection middleware.
Some providers (LiteLLM, Bedrock, Copilot) return HTTP 400 when a model
call has empty ``tools`` but the message history includes prior
``tool_calls`` — they treat that shape as malformed even though it's
perfectly valid LangChain. SurfSense hits this on the compaction
summarize call (no tools, history full of tool calls).
Ported from OpenCode's ``packages/opencode/src/session/llm.ts:209-228``,
which discovered and codified the workaround: inject a no-op tool *only*
on those provider shapes so the request validates without ever being
called.
Operation: a :class:`NoopInjectionMiddleware` ``wrap_model_call`` checks
if the request has zero tools but the last AI message in history includes
``tool_calls``. If yes, it injects the ``_noop`` tool for that call only —
never globally — mirroring OpenCode's gating exactly. The :func:`noop_tool`
returns empty content when called (which it should never be in
practice).
"""
from __future__ import annotations
import logging
from collections.abc import Awaitable, Callable
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ModelRequest,
ModelResponse,
ResponseT,
)
from langchain_core.messages import AIMessage
from langchain_core.tools import tool
logger = logging.getLogger(__name__)
# Name and description under which the placeholder tool is registered.
# The description tells the model not to call it.
NOOP_TOOL_NAME = "_noop"
NOOP_TOOL_DESCRIPTION = "Do not call this tool. It exists only for API compatibility."
# Placeholder tool used as the default injected tool by
# NoopInjectionMiddleware. It takes no arguments and returns an empty string.
@tool(name_or_callable=NOOP_TOOL_NAME, description=NOOP_TOOL_DESCRIPTION)
def noop_tool() -> str:
    """Return empty content. Never expected to be called."""
    return ""
# Provider markers that benefit from ``_noop`` injection. These match
# OpenCode's gating list (``llm.ts:209-228``). Matching is by lowercase
# substring (see ``_provider_needs_noop``), so e.g. ``litellm`` also
# matches the class name ``ChatLiteLLM``.
_NOOP_NEEDED_PROVIDERS: tuple[str, ...] = (
    "litellm",
    "bedrock",
    "copilot",
)
def _provider_needs_noop(model: Any) -> bool:
    """Decide heuristically whether ``model``'s provider needs the ``_noop`` tool.

    The provider string comes from ``model._get_ls_params()["ls_provider"]``
    when that is available and non-empty; otherwise the lowercased class
    name of the model is used. The answer is ``True`` when that string
    contains any marker from ``_NOOP_NEEDED_PROVIDERS``.
    """
    try:
        provider = str(model._get_ls_params().get("ls_provider", "")).lower()
    except Exception:
        provider = ""
    haystack = provider or type(model).__name__.lower()
    for marker in _NOOP_NEEDED_PROVIDERS:
        if marker in haystack:
            return True
    return False
def _last_ai_has_tool_calls(messages: list[Any]) -> bool:
    """Report whether the most recent ``AIMessage`` carries any tool calls.

    Returns ``False`` when the history contains no ``AIMessage`` at all.
    """
    latest_ai = next(
        (entry for entry in reversed(messages) if isinstance(entry, AIMessage)),
        None,
    )
    if latest_ai is None:
        return False
    return bool(latest_ai.tool_calls)
class NoopInjectionMiddleware(
    AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
    """Add the ``_noop`` tool to tool-less model calls that would otherwise be rejected.

    The decision is made on every model call rather than when the agent
    is built. Injection happens only when all three hold: the request has
    no tools, the latest AI message in the request has tool calls, and
    the model's provider matches the known-problematic list. The
    middleware itself registers no tools on the agent.
    """

    def __init__(self, *, noop_tool_instance: Any | None = None) -> None:
        """Create the middleware.

        Args:
            noop_tool_instance: Tool object to inject; the module-level
                ``noop_tool`` is used when this is omitted or falsy.
        """
        super().__init__()
        self.tools = []
        self._noop_tool = noop_tool_instance or noop_tool

    def _should_inject(self, request: ModelRequest[ContextT]) -> bool:
        """Return ``True`` when this request needs the placeholder tool."""
        return (
            not request.tools
            and _last_ai_has_tool_calls(request.messages)
            and _provider_needs_noop(request.model)
        )

    def _augmented(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
        """Return a copy of ``request`` whose tool list is just the placeholder."""
        placeholder_only = [self._noop_tool]
        return request.override(tools=placeholder_only)

    def wrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
    ) -> Any:
        """Forward the (possibly augmented) request to ``handler`` synchronously."""
        outgoing = request
        if self._should_inject(request):
            logger.debug("Injecting _noop tool for provider compatibility")
            outgoing = self._augmented(request)
        return handler(outgoing)

    async def awrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest[ContextT],
        handler: Callable[
            [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
        ],
    ) -> Any:
        """Forward the (possibly augmented) request to ``handler`` asynchronously."""
        outgoing = request
        if self._should_inject(request):
            logger.debug("Injecting _noop tool for provider compatibility")
            outgoing = self._augmented(request)
        return await handler(outgoing)
__all__ = [
"NOOP_TOOL_DESCRIPTION",
"NOOP_TOOL_NAME",
"NoopInjectionMiddleware",
"_provider_needs_noop",
"noop_tool",
]

View file

@ -1,285 +0,0 @@
"""
OpenTelemetry span middleware for the SurfSense ``new_chat`` agent.
Wraps both ``model.call`` (LLM invocations) and ``tool.call`` (tool
executions) with OTel spans, attaching low-cardinality span names and
high-cardinality identifiers as attributes.
This middleware is intentionally a thin adapter over
:mod:`app.observability.otel`; when OTel is not configured all spans
collapse to no-ops and the wrapper adds <1µs overhead per call. When
OTel **is** configured (``OTEL_EXPORTER_OTLP_ENDPOINT`` set), every
model and tool call gets a span with the standard attributes our
dashboards expect.
"""
from __future__ import annotations
import logging
import time
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import AIMessage, ToolMessage
from app.observability import metrics as ot_metrics, otel as ot
if TYPE_CHECKING: # pragma: no cover — type-only
from langchain.agents.middleware.types import (
ModelRequest,
ModelResponse,
ToolCallRequest,
)
from langgraph.types import Command
logger = logging.getLogger(__name__)
class OtelSpanMiddleware(AgentMiddleware):
    """Wrap every async model and tool invocation in an OTel span.

    Each call also records a duration metric; model calls additionally
    record token usage and tool calls record an error counter. When
    ``ot.is_enabled()`` is false the handler is awaited directly and no
    span or metric is produced.

    Placement guidance carried over from the original author: put this
    near the outer end of the middleware list — between ``BusyMutex`` and
    ``RetryAfter`` — so each retry attempt gets its own ``model.call``
    span.
    """

    def __init__(self, *, instrumentation_name: str = "surfsense.new_chat") -> None:
        """Remember the instrumentation name for this middleware instance."""
        super().__init__()
        self._instrumentation_name = instrumentation_name

    # ------------------------------------------------------------------
    # Model call spans
    # ------------------------------------------------------------------

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]],
    ) -> ModelResponse | AIMessage | Any:
        """Run ``handler`` inside a model-call span and record its metrics.

        The duration is recorded whether the handler succeeds or raises;
        token usage is recorded only on success. Exceptions propagate
        unchanged.
        """
        if not ot.is_enabled():
            return await handler(request)

        model_id, provider = _resolve_model_attrs(request)
        started = time.perf_counter()
        with ot.model_call_span(model_id=model_id, provider=provider) as span:
            _annotate_model_request(span, model_id=model_id, provider=provider)
            try:
                response = await handler(request)
            except Exception:
                ot_metrics.record_model_call_duration(
                    (time.perf_counter() - started) * 1000,
                    model=model_id,
                    provider=provider,
                )
                # Re-raise so the enclosing span context manager sees the error.
                raise

            tokens_in, tokens_out = _annotate_model_response(
                span,
                response,
                model_id=model_id,
                provider=provider,
            )
            ot_metrics.record_model_call_duration(
                (time.perf_counter() - started) * 1000,
                model=model_id,
                provider=provider,
            )
            ot_metrics.record_model_token_usage(
                input_tokens=tokens_in,
                output_tokens=tokens_out,
                model=model_id,
                provider=provider,
            )
            return response

    # ------------------------------------------------------------------
    # Tool call spans
    # ------------------------------------------------------------------

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
    ) -> ToolMessage | Command[Any]:
        """Run ``handler`` inside a tool-call span and record its metrics.

        An error is counted both when the handler raises and when the
        returned result is flagged as errored by ``_annotate_tool_result``.
        """
        if not ot.is_enabled():
            return await handler(request)

        tool_name = _resolve_tool_name(request)
        input_size = _resolve_input_size(request)
        started = time.perf_counter()
        with ot.tool_call_span(tool_name, input_size=input_size) as span:
            try:
                outcome = await handler(request)
            except Exception:
                ot_metrics.record_tool_call_duration(
                    (time.perf_counter() - started) * 1000,
                    tool_name=tool_name,
                )
                ot_metrics.record_tool_call_error(tool_name=tool_name)
                raise
            else:
                failed = _annotate_tool_result(span, outcome)
                ot_metrics.record_tool_call_duration(
                    (time.perf_counter() - started) * 1000,
                    tool_name=tool_name,
                )
                if failed:
                    ot_metrics.record_tool_call_error(tool_name=tool_name)
                return outcome
# ---------------------------------------------------------------------------
# Attribute helpers (kept defensive; we never want OTel bookkeeping to break
# a real model/tool call).
# ---------------------------------------------------------------------------
def _resolve_model_attrs(request: Any) -> tuple[str | None, str | None]:
"""Extract ``model.id`` and ``model.provider`` from a ``ModelRequest``."""
model_id: str | None = None
provider: str | None = None
try:
model = getattr(request, "model", None)
if model is None:
return None, None
# langchain BaseChatModel exposes a few different identifiers
for attr in ("model_name", "model", "model_id"):
value = getattr(model, attr, None)
if value:
model_id = str(value)
break
# provider sometimes lives on ``_llm_type`` (legacy) or ``provider``
for attr in ("provider", "_llm_type"):
value = getattr(model, attr, None)
if value:
provider = str(value)
break
except Exception: # pragma: no cover — defensive
pass
return model_id, provider
def _resolve_tool_name(request: Any) -> str:
try:
tool = getattr(request, "tool", None)
if tool is not None:
name = getattr(tool, "name", None)
if isinstance(name, str) and name:
return name
# Fall back to the tool_call dict
call = getattr(request, "tool_call", None) or {}
name = call.get("name") if isinstance(call, dict) else None
if isinstance(name, str) and name:
return name
except Exception: # pragma: no cover — defensive
pass
return "unknown"
def _resolve_input_size(request: Any) -> int | None:
try:
call = getattr(request, "tool_call", None)
if not isinstance(call, dict) or not call:
return None
args = call.get("args")
if args is None:
return None
return len(repr(args))
except Exception: # pragma: no cover — defensive
return None
def _annotate_model_request(
span: Any, *, model_id: str | None, provider: str | None
) -> None:
try:
span.set_attribute("gen_ai.operation.name", "chat")
if model_id:
span.set_attribute("gen_ai.request.model", model_id)
if provider:
span.set_attribute("gen_ai.provider.name", provider)
except Exception: # pragma: no cover — defensive
pass
def _annotate_model_response(
    span: Any,
    result: Any,
    *,
    model_id: str | None = None,
    provider: str | None = None,
) -> tuple[int | None, int | None]:
    """Attach response-side attributes to ``span`` on a best-effort basis.

    The message is ``result`` itself when it is an ``AIMessage``;
    otherwise it is taken from ``result.result`` (the last element when
    that is a non-empty list). Model/provider names, the response model,
    token usage and the tool-call count are written to the span. Any
    exception is swallowed, so counts gathered before the failure are
    still returned.

    Returns:
        ``(input_tokens, output_tokens)``; either is ``None`` when the
        message's ``usage_metadata`` does not provide it.
    """
    tokens_in: int | None = None
    tokens_out: int | None = None
    try:
        message: Any
        if isinstance(result, AIMessage):
            message = result
        else:
            payload = getattr(result, "result", None)
            message = payload[-1] if isinstance(payload, list) and payload else payload
        if message is None:
            return None, None

        if provider:
            span.set_attribute("gen_ai.provider.name", provider)
        if model_id:
            span.set_attribute("gen_ai.request.model", model_id)

        # Prefer the model name reported in the response metadata and fall
        # back to the requested model id.
        metadata = getattr(message, "response_metadata", {}) or {}
        if isinstance(metadata, dict):
            response_model = (
                metadata.get("model_name")
                or metadata.get("model")
                or metadata.get("model_id")
            )
        else:
            response_model = metadata
        response_model = response_model or model_id
        if response_model:
            span.set_attribute("gen_ai.response.model", str(response_model))
        span.set_attribute("gen_ai.operation.name", "chat")

        usage = getattr(message, "usage_metadata", None) or {}
        if isinstance(usage, dict):
            raw_in = usage.get("input_tokens")
            if raw_in is not None:
                tokens_in = int(raw_in)
                span.set_attribute("gen_ai.usage.input_tokens", tokens_in)
            raw_out = usage.get("output_tokens")
            if raw_out is not None:
                tokens_out = int(raw_out)
                span.set_attribute("gen_ai.usage.output_tokens", tokens_out)
            raw_total = usage.get("total_tokens")
            if raw_total is not None:
                span.set_attribute("gen_ai.usage.total_tokens", int(raw_total))

        requested_calls = getattr(message, "tool_calls", None) or []
        span.set_attribute("model.tool_calls", len(requested_calls))
    except Exception:  # pragma: no cover — defensive
        pass
    return tokens_in, tokens_out
def _annotate_tool_result(span: Any, result: Any) -> bool:
    """Record tool-output attributes on ``span`` and report whether it errored.

    Only ``ToolMessage`` results are inspected. A result counts as
    errored when its ``status`` equals ``"error"`` (case-insensitive) or
    its ``additional_kwargs`` contain a truthy ``error`` entry.
    Exceptions are swallowed and the flag gathered so far is returned.
    """
    errored = False
    try:
        if not isinstance(result, ToolMessage):
            return errored

        raw_content = result.content
        text = raw_content if isinstance(raw_content, str) else repr(raw_content)
        span.set_attribute("tool.output.size", len(text))

        status = getattr(result, "status", None)
        if isinstance(status, str):
            span.set_attribute("tool.status", status)
            errored = status.lower() == "error"

        extras = getattr(result, "additional_kwargs", None) or {}
        if isinstance(extras, dict) and extras.get("error"):
            span.set_attribute("tool.error", True)
            errored = True
    except Exception:  # pragma: no cover — defensive
        pass
    return errored
__all__ = ["OtelSpanMiddleware"]

View file

@ -1,197 +0,0 @@
"""
ToolCallNameRepairMiddleware — two-stage tool-name repair.
Operation:
1. **Stage 1 — lowercase repair:** if a tool call's ``name`` is not in
the registry but ``name.lower()`` is, rewrite in place. Catches
models that emit ``Search`` instead of ``search``.
2. **Stage 2 — invalid fallback:** if still unmatched, rewrite the call
to ``invalid`` with ``args={"tool": original_name, "error": <error>}``
so the registered :func:`invalid_tool` returns the error to the model
for self-correction.
Ported from OpenCode's ``packages/opencode/src/session/llm.ts:339-358``
+ ``packages/opencode/src/tool/invalid.ts``. LangChain has no equivalent:
:class:`deepagents.middleware.PatchToolCallsMiddleware` patches
*dangling* tool calls (no matching ToolMessage) but does nothing about
wrong names, and the model framework's default behavior on an unknown
name is to crash the turn rather than route to a self-correction
fallback.
"""
from __future__ import annotations
import difflib
import logging
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ResponseT,
)
from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
def _coerce_existing_tool_call(call: Any) -> dict[str, Any]:
"""Normalize a tool call entry to a mutable dict."""
if isinstance(call, dict):
return dict(call)
return {
"name": getattr(call, "name", None),
"args": getattr(call, "args", {}),
"id": getattr(call, "id", None),
"type": "tool_call",
}
class ToolCallNameRepairMiddleware(
    AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
    """Repair unknown tool names on the most recent ``AIMessage``.

    For each tool call whose name is not registered, the following are
    tried in order: the lowercased name, an optional ``difflib`` fuzzy
    match, and finally a rewrite to the ``invalid`` tool.

    Args:
        registered_tool_names: Set of canonically-registered tool names.
            ``invalid`` should be in this set so the fallback dispatches.
        fuzzy_match_threshold: ``difflib`` cutoff (0-1) for the fuzzy
            step that runs between the lowercase and invalid stages.
            Defaults to ``0.85``; pass ``None`` to disable fuzzy matching
            and avoid silent rewrites.
            NOTE(review): earlier documentation described fuzzy matching
            as off by default — confirm which default is intended.
    """

    def __init__(
        self,
        *,
        registered_tool_names: set[str],
        fuzzy_match_threshold: float | None = 0.85,
    ) -> None:
        super().__init__()
        self.tools = []
        self._fuzzy_threshold = fuzzy_match_threshold
        # Private copy so later mutation of the caller's set has no effect.
        self._registered = set(registered_tool_names)
        self._registered_lower = {name.lower(): name for name in self._registered}

    def _registered_for_runtime(self, runtime: Runtime[ContextT]) -> set[str]:
        """Return the registered names, widened by any names on the runtime context.

        ``runtime.context.registered_tool_names`` is honoured when it is a
        set, frozenset, list or tuple (e.g. dynamic MCP tools).
        """
        extra = getattr(runtime.context, "registered_tool_names", None)
        if isinstance(extra, (set, frozenset, list, tuple)):
            return self._registered | set(extra)
        return self._registered

    @staticmethod
    def _stamp_repair(call: dict[str, Any], label: str) -> None:
        """Record how ``call`` was repaired without overwriting an earlier label."""
        metadata = dict(call.get("response_metadata") or {})
        metadata.setdefault("repair", label)
        call["response_metadata"] = metadata

    def _repair_one(
        self,
        call: dict[str, Any],
        registered: set[str],
    ) -> dict[str, Any]:
        """Repair a single tool call in place and return it.

        Calls with a non-string name or an already-registered name are
        returned untouched.
        """
        name = call.get("name")
        if not isinstance(name, str) or name in registered:
            return call

        # Stage 1 — lowercase
        lowered = name.lower()
        if lowered in registered:
            call["name"] = lowered
            self._stamp_repair(call, "lowercase")
            return call

        # Optional fuzzy step — skipped when the threshold is ``None``.
        if self._fuzzy_threshold is not None:
            nearest = difflib.get_close_matches(
                name, registered, n=1, cutoff=self._fuzzy_threshold
            )
            if nearest:
                call["name"] = nearest[0]
                self._stamp_repair(call, f"fuzzy:{name}->{nearest[0]}")
                return call

        # Stage 2 — invalid fallback
        # Local import keeps the middleware module import-light and avoids any
        # tools <-> middleware import-order coupling at module scope.
        from app.agents.multi_agent_chat.main_agent.tools.invalid_tool import (
            INVALID_TOOL_NAME,
        )

        if INVALID_TOOL_NAME not in registered:
            logger.warning(
                "Could not repair unknown tool call %r; 'invalid' tool not registered",
                name,
            )
            return call

        original_args = call.get("args") or {}
        error_msg = (
            f"Tool name '{name}' is not registered. "
            f"Original arguments were: {original_args!r}."
        )
        call["name"] = INVALID_TOOL_NAME
        call["args"] = {"tool": name, "error": error_msg}
        self._stamp_repair(call, f"invalid_fallback:{name}")
        return call

    def _maybe_repair(
        self,
        message: AIMessage,
        registered: set[str],
    ) -> AIMessage | None:
        """Return a copy of ``message`` with repaired tool calls, or ``None``.

        ``None`` means the message has no tool calls or none of them had
        its name or args changed.
        """
        if not message.tool_calls:
            return None

        repaired_calls: list[dict[str, Any]] = []
        changed = False
        for raw in message.tool_calls:
            working = _coerce_existing_tool_call(raw)
            # Snapshot before the in-place repair so changes can be detected.
            snapshot = (working.get("name"), working.get("args"))
            fixed = self._repair_one(working, registered)
            if snapshot != (fixed.get("name"), fixed.get("args")):
                changed = True
            repaired_calls.append(fixed)

        if not changed:
            return None
        return message.model_copy(update={"tool_calls": repaired_calls})

    def after_model(  # type: ignore[override]
        self,
        state: AgentState[ResponseT],
        runtime: Runtime[ContextT],
    ) -> dict[str, Any] | None:
        """Repair the last message's tool calls when it is an ``AIMessage``.

        Returns:
            ``{"messages": [repaired]}`` when something changed,
            otherwise ``None``.
        """
        messages = state.get("messages") or []
        if not messages:
            return None
        last = messages[-1]
        if not isinstance(last, AIMessage):
            return None

        repaired = self._maybe_repair(last, self._registered_for_runtime(runtime))
        return None if repaired is None else {"messages": [repaired]}

    async def aafter_model(  # type: ignore[override]
        self,
        state: AgentState[ResponseT],
        runtime: Runtime[ContextT],
    ) -> dict[str, Any] | None:
        """Async entry point; delegates to the synchronous ``after_model``."""
        return self.after_model(state, runtime)
__all__ = [
"ToolCallNameRepairMiddleware",
]