mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): move middleware package to app/agents/shared (slice 5c)
Relocate the entire new_chat/middleware/ package to the shared kernel as one cohesive unit (it is live shared infrastructure: the multi-agent stack wraps nearly every middleware via multi_agent_chat/middleware/main_agent/*, and anonymous_agent consumes it too). Flip 69 live importers across both the package-path and submodule-path forms. Shims left for the frozen single-agent stack: a package __init__ re-export plus submodule shims for permission, skills_backends, and scoped_model_fallback (the three imported via submodule path by chat_deepagent/subagents). Cycle break: importing shared.middleware previously reached back into new_chat.tools at module load, which dragged in new_chat.__init__ -> chat_deepagent -> the middleware shim -> half-initialized shared.middleware. Made action_log's ToolDefinition import TYPE_CHECKING-only and tool_call_repair's INVALID_TOOL_NAME import function-local. These tools-package back-edges fully resolve in slice 6. Asset note: skills_backends._default_builtin_root now walks to app/agents/new_chat/skills/builtin (the skills/ tree migrates in slice 7).
This commit is contained in:
parent
6f488d9564
commit
227983a104
98 changed files with 1131 additions and 999 deletions
|
|
@ -0,0 +1,350 @@
|
|||
"""
|
||||
SpillToBackendEdit + SpillingContextEditingMiddleware.
|
||||
|
||||
LangChain's :class:`ClearToolUsesEdit` discards old ``ToolMessage.content``
|
||||
when the context-editing budget triggers, replacing the body with a fixed
|
||||
placeholder. That's lossy: anything the agent might want to revisit is
|
||||
gone. The spill-to-disk pattern (originally from OpenCode's
|
||||
``opencode/packages/opencode/src/tool/truncate.ts``) keeps the prune
|
||||
behavior but writes the full original payload to the runtime backend
|
||||
under ``/tool_outputs/{thread_id}/{message_id}.txt`` first. The
|
||||
placeholder is then upgraded to point at the spill path so the agent
|
||||
(or a subagent) can read it back on demand.
|
||||
|
||||
Why this is a middleware subclass instead of a plain ``ContextEdit``:
|
||||
``ContextEdit.apply`` is sync, but writing to the backend is async. We
|
||||
capture the spill payloads inside ``apply`` and flush them via
|
||||
``await backend.aupload_files(...)`` from ``awrap_model_call`` *before*
|
||||
delegating to the handler, so the explore subagent can always read what
|
||||
the placeholder advertises.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware.context_editing import (
|
||||
ClearToolUsesEdit,
|
||||
ContextEdit,
|
||||
ContextEditingMiddleware,
|
||||
TokenCounter,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.utils import count_tokens_approximately
|
||||
from langgraph.config import get_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deepagents.backends.protocol import BackendProtocol
|
||||
from langchain.agents.middleware.types import (
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_SPILL_PREFIX = "/tool_outputs"
|
||||
|
||||
|
||||
def _build_spill_placeholder(spill_path: str) -> str:
|
||||
"""Build the user-facing placeholder text shown to the model."""
|
||||
return (
|
||||
f"[cleared — full output at {spill_path}; ask the explore subagent to read it]"
|
||||
)
|
||||
|
||||
|
||||
def _get_thread_id_or_session() -> str:
|
||||
"""Best-effort thread_id discovery for the spill path.
|
||||
|
||||
Falls back to a process-stable string if no LangGraph config is
|
||||
available (e.g. unit tests). The exact value doesn't matter as long
|
||||
as it's stable within one stream so the placeholder paths line up
|
||||
with the actual upload path.
|
||||
"""
|
||||
try:
|
||||
config = get_config()
|
||||
thread_id = config.get("configurable", {}).get("thread_id")
|
||||
if thread_id is not None:
|
||||
return str(thread_id)
|
||||
except RuntimeError:
|
||||
pass
|
||||
return "no_thread"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SpillToBackendEdit(ContextEdit):
|
||||
"""Capture-and-replace context edit that spills full tool output to the backend.
|
||||
|
||||
Behaves like :class:`ClearToolUsesEdit` (same trigger / keep / exclude
|
||||
semantics) **and** records the original ``ToolMessage.content`` in
|
||||
:attr:`pending_spills` so the wrapping middleware can flush them
|
||||
before the model call.
|
||||
|
||||
Args:
|
||||
trigger: Token threshold above which the edit fires.
|
||||
clear_at_least: Minimum number of tokens to reclaim (best effort).
|
||||
keep: Number of most-recent ``ToolMessage`` instances to leave
|
||||
untouched.
|
||||
exclude_tools: Names of tools whose output is NOT spilled.
|
||||
clear_tool_inputs: Also clear the originating ``AIMessage.tool_calls``
|
||||
args when their pair is cleared.
|
||||
path_prefix: Path under the backend where spills are written.
|
||||
Default ``"/tool_outputs"``.
|
||||
"""
|
||||
|
||||
trigger: int = 100_000
|
||||
clear_at_least: int = 0
|
||||
keep: int = 3
|
||||
clear_tool_inputs: bool = False
|
||||
exclude_tools: Sequence[str] = ()
|
||||
path_prefix: str = DEFAULT_SPILL_PREFIX
|
||||
|
||||
pending_spills: list[tuple[str, bytes]] = field(default_factory=list)
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
|
||||
def drain_pending(self) -> list[tuple[str, bytes]]:
|
||||
"""Return and clear the pending-spill list atomically."""
|
||||
with self._lock:
|
||||
out = list(self.pending_spills)
|
||||
self.pending_spills.clear()
|
||||
return out
|
||||
|
||||
def apply(
|
||||
self,
|
||||
messages: list[AnyMessage],
|
||||
*,
|
||||
count_tokens: TokenCounter,
|
||||
) -> None:
|
||||
"""Mirror ``ClearToolUsesEdit.apply`` but capture originals first."""
|
||||
tokens = count_tokens(messages)
|
||||
if tokens <= self.trigger:
|
||||
return
|
||||
|
||||
candidates = [
|
||||
(idx, msg)
|
||||
for idx, msg in enumerate(messages)
|
||||
if isinstance(msg, ToolMessage)
|
||||
]
|
||||
if self.keep >= len(candidates):
|
||||
return
|
||||
if self.keep:
|
||||
candidates = candidates[: -self.keep]
|
||||
|
||||
thread_id = _get_thread_id_or_session()
|
||||
excluded_tools = set(self.exclude_tools)
|
||||
|
||||
for idx, tool_message in candidates:
|
||||
if tool_message.response_metadata.get("context_editing", {}).get("cleared"):
|
||||
continue
|
||||
|
||||
ai_message = next(
|
||||
(m for m in reversed(messages[:idx]) if isinstance(m, AIMessage)),
|
||||
None,
|
||||
)
|
||||
if ai_message is None:
|
||||
continue
|
||||
|
||||
tool_call = next(
|
||||
(
|
||||
call
|
||||
for call in ai_message.tool_calls
|
||||
if call.get("id") == tool_message.tool_call_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
if tool_call is None:
|
||||
continue
|
||||
|
||||
tool_name = tool_message.name or tool_call["name"]
|
||||
if tool_name in excluded_tools:
|
||||
continue
|
||||
|
||||
message_id = tool_message.id or tool_message.tool_call_id or "unknown"
|
||||
spill_path = f"{self.path_prefix}/{thread_id}/{message_id}.txt"
|
||||
|
||||
original = tool_message.content
|
||||
payload = self._encode_payload(original)
|
||||
with self._lock:
|
||||
self.pending_spills.append((spill_path, payload))
|
||||
|
||||
messages[idx] = tool_message.model_copy(
|
||||
update={
|
||||
"artifact": None,
|
||||
"content": _build_spill_placeholder(spill_path),
|
||||
"response_metadata": {
|
||||
**tool_message.response_metadata,
|
||||
"context_editing": {
|
||||
"cleared": True,
|
||||
"strategy": "spill_to_backend",
|
||||
"spill_path": spill_path,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
if self.clear_tool_inputs:
|
||||
ai_idx = messages.index(ai_message)
|
||||
messages[ai_idx] = self._clear_input_args(
|
||||
ai_message, tool_message.tool_call_id or ""
|
||||
)
|
||||
|
||||
if self.clear_at_least > 0:
|
||||
new_token_count = count_tokens(messages)
|
||||
cleared_tokens = max(0, tokens - new_token_count)
|
||||
if cleared_tokens >= self.clear_at_least:
|
||||
break
|
||||
|
||||
@staticmethod
|
||||
def _encode_payload(content: Any) -> bytes:
|
||||
"""Serialize ``ToolMessage.content`` to bytes for upload."""
|
||||
if isinstance(content, bytes):
|
||||
return content
|
||||
if isinstance(content, str):
|
||||
return content.encode("utf-8")
|
||||
try:
|
||||
import json
|
||||
|
||||
return json.dumps(content, default=str).encode("utf-8")
|
||||
except Exception:
|
||||
return str(content).encode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def _clear_input_args(message: AIMessage, tool_call_id: str) -> AIMessage:
|
||||
updated_tool_calls: list[dict[str, Any]] = []
|
||||
cleared_any = False
|
||||
for tool_call in message.tool_calls:
|
||||
updated = dict(tool_call)
|
||||
if updated.get("id") == tool_call_id:
|
||||
updated["args"] = {}
|
||||
cleared_any = True
|
||||
updated_tool_calls.append(updated)
|
||||
|
||||
metadata = dict(getattr(message, "response_metadata", {}))
|
||||
if cleared_any:
|
||||
ctx = dict(metadata.get("context_editing", {}))
|
||||
ids = set(ctx.get("cleared_tool_inputs", []))
|
||||
ids.add(tool_call_id)
|
||||
ctx["cleared_tool_inputs"] = sorted(ids)
|
||||
metadata["context_editing"] = ctx
|
||||
return message.model_copy(
|
||||
update={
|
||||
"tool_calls": updated_tool_calls,
|
||||
"response_metadata": metadata,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
BackendResolver = "Callable[[Any], BackendProtocol] | BackendProtocol"
|
||||
|
||||
|
||||
class SpillingContextEditingMiddleware(ContextEditingMiddleware):
|
||||
""":class:`ContextEditingMiddleware` that flushes :class:`SpillToBackendEdit` writes.
|
||||
|
||||
Runs the configured edits as the parent does, then flushes any
|
||||
pending spills via the supplied backend resolver before delegating
|
||||
to the model handler. Spill failures are logged but never abort the
|
||||
model call — the placeholder text is already in the message, so the
|
||||
worst case is the agent gets a placeholder it cannot follow up on.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
edits: Sequence[ContextEdit],
|
||||
backend_resolver: BackendResolver | None = None,
|
||||
token_count_method: str = "approximate",
|
||||
) -> None:
|
||||
super().__init__(edits=list(edits), token_count_method=token_count_method) # type: ignore[arg-type]
|
||||
self._backend_resolver = backend_resolver
|
||||
|
||||
def _resolve_backend(self, request: ModelRequest) -> BackendProtocol | None:
|
||||
if self._backend_resolver is None:
|
||||
return None
|
||||
if callable(self._backend_resolver):
|
||||
try:
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
tool_runtime = ToolRuntime(
|
||||
state=getattr(request, "state", {}),
|
||||
context=getattr(request.runtime, "context", None),
|
||||
stream_writer=getattr(request.runtime, "stream_writer", None),
|
||||
store=getattr(request.runtime, "store", None),
|
||||
config=getattr(request.runtime, "config", None) or {},
|
||||
tool_call_id=None,
|
||||
)
|
||||
return self._backend_resolver(tool_runtime)
|
||||
except Exception:
|
||||
logger.exception("Failed to resolve spill backend")
|
||||
return None
|
||||
return self._backend_resolver # type: ignore[return-value]
|
||||
|
||||
def _collect_pending(self) -> list[tuple[str, bytes]]:
|
||||
out: list[tuple[str, bytes]] = []
|
||||
for edit in self.edits:
|
||||
if isinstance(edit, SpillToBackendEdit):
|
||||
out.extend(edit.drain_pending())
|
||||
return out
|
||||
|
||||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> Any:
|
||||
if not request.messages:
|
||||
return await handler(request)
|
||||
|
||||
if self.token_count_method == "approximate":
|
||||
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return count_tokens_approximately(messages)
|
||||
|
||||
else:
|
||||
system_msg = [request.system_message] if request.system_message else []
|
||||
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return request.model.get_num_tokens_from_messages(
|
||||
system_msg + list(messages), request.tools
|
||||
)
|
||||
|
||||
edited_messages = deepcopy(list(request.messages))
|
||||
for edit in self.edits:
|
||||
edit.apply(edited_messages, count_tokens=count_tokens)
|
||||
|
||||
pending = self._collect_pending()
|
||||
if pending:
|
||||
backend = self._resolve_backend(request)
|
||||
if backend is not None:
|
||||
try:
|
||||
await backend.aupload_files(pending)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Spill-to-backend upload failed (%d files); placeholders "
|
||||
"remain in messages but content is unrecoverable",
|
||||
len(pending),
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"SpillToBackendEdit produced %d pending spills but no backend "
|
||||
"resolver was configured; content is unrecoverable",
|
||||
len(pending),
|
||||
)
|
||||
|
||||
return await handler(request.override(messages=edited_messages))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_SPILL_PREFIX",
|
||||
"ClearToolUsesEdit",
|
||||
"SpillToBackendEdit",
|
||||
"SpillingContextEditingMiddleware",
|
||||
"_build_spill_placeholder",
|
||||
]
|
||||
Loading…
Add table
Add a link
Reference in a new issue