Merge upstream/dev into feature/multi-agent

This commit is contained in:
CREDO23 2026-05-05 01:44:46 +02:00
commit 5119915f4f
278 changed files with 34669 additions and 8970 deletions

View file

@ -24,6 +24,9 @@ from app.agents.new_chat.middleware.file_intent import (
from app.agents.new_chat.middleware.filesystem import (
SurfSenseFilesystemMiddleware,
)
from app.agents.new_chat.middleware.flatten_system import (
FlattenSystemMessageMiddleware,
)
from app.agents.new_chat.middleware.kb_persistence import (
KnowledgeBasePersistenceMiddleware,
commit_staged_filesystem_state,
@ -61,6 +64,7 @@ __all__ = [
"DedupHITLToolCallsMiddleware",
"DoomLoopMiddleware",
"FileIntentMiddleware",
"FlattenSystemMessageMiddleware",
"KnowledgeBasePersistenceMiddleware",
"KnowledgeBaseSearchMiddleware",
"KnowledgePriorityMiddleware",

View file

@ -33,6 +33,7 @@ from __future__ import annotations
import asyncio
import logging
import time
import weakref
from typing import Any
@ -58,6 +59,11 @@ class _ThreadLockManager:
weakref.WeakValueDictionary()
)
self._cancel_events: dict[str, asyncio.Event] = {}
self._cancel_requested_at_ms: dict[str, int] = {}
self._cancel_attempt_count: dict[str, int] = {}
# Monotonic per-thread epoch used to prevent stale middleware
# teardown from releasing a newer turn's lock.
self._turn_epoch: dict[str, int] = {}
def lock_for(self, thread_id: str) -> asyncio.Lock:
lock = self._locks.get(thread_id)
@ -76,14 +82,57 @@ class _ThreadLockManager:
def request_cancel(self, thread_id: str) -> bool:
event = self._cancel_events.get(thread_id)
if event is None:
return False
event = asyncio.Event()
self._cancel_events[thread_id] = event
event.set()
now_ms = int(time.time() * 1000)
self._cancel_requested_at_ms[thread_id] = now_ms
self._cancel_attempt_count[thread_id] = (
self._cancel_attempt_count.get(thread_id, 0) + 1
)
return True
def is_cancel_requested(self, thread_id: str) -> bool:
event = self._cancel_events.get(thread_id)
return bool(event and event.is_set())
def cancel_state(self, thread_id: str) -> tuple[int, int] | None:
if not self.is_cancel_requested(thread_id):
return None
attempts = self._cancel_attempt_count.get(thread_id, 1)
requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0)
return attempts, requested_at_ms
def reset(self, thread_id: str) -> None:
event = self._cancel_events.get(thread_id)
if event is not None:
event.clear()
self._cancel_requested_at_ms.pop(thread_id, None)
self._cancel_attempt_count.pop(thread_id, None)
def bump_turn_epoch(self, thread_id: str) -> int:
epoch = self._turn_epoch.get(thread_id, 0) + 1
self._turn_epoch[thread_id] = epoch
return epoch
def current_turn_epoch(self, thread_id: str) -> int:
return self._turn_epoch.get(thread_id, 0)
def end_turn(self, thread_id: str) -> None:
"""Best-effort terminal cleanup for a thread turn.
This is intentionally idempotent and safe to call from outer stream
finally-blocks where middleware teardown might be skipped due to abort
or disconnect edge-cases.
"""
# Invalidate any in-flight middleware holder first. This guarantees a
# stale ``aafter_agent`` from an older attempt cannot unlock a newer
# retry that already acquired the lock for the same thread.
self.bump_turn_epoch(thread_id)
lock = self._locks.get(thread_id)
if lock is not None and lock.locked():
lock.release()
self.reset(thread_id)
def release(self, thread_id: str) -> bool:
"""Force-release the per-thread lock; safety-net for turns that end before ``__end__``.
@ -115,18 +164,28 @@ def get_cancel_event(thread_id: str) -> asyncio.Event:
def request_cancel(thread_id: str) -> bool:
"""Trip the cancel event for ``thread_id``. Returns True if found."""
"""Trip the cancel event for ``thread_id``. Always returns True."""
return manager.request_cancel(thread_id)
def is_cancel_requested(thread_id: str) -> bool:
"""Return whether ``thread_id`` currently has a pending cancel signal."""
return manager.is_cancel_requested(thread_id)
def get_cancel_state(thread_id: str) -> tuple[int, int] | None:
"""Return ``(attempt_count, requested_at_ms)`` for pending cancel state."""
return manager.cancel_state(thread_id)
def reset_cancel(thread_id: str) -> None:
"""Reset the cancel event for ``thread_id`` (called between turns)."""
manager.reset(thread_id)
def release_lock(thread_id: str) -> bool:
"""Force-release the per-thread busy lock; safe to call when not held."""
return manager.release(thread_id)
def end_turn(thread_id: str) -> None:
"""Force end-of-turn cleanup for lock + cancel state."""
manager.end_turn(thread_id)
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
@ -151,10 +210,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
super().__init__()
self._require_thread_id = require_thread_id
self.tools = []
# Per-call locks owned by this middleware. We track them as
# an instance attribute so ``aafter_agent`` knows which lock
# to release.
self._held_locks: dict[str, asyncio.Lock] = {}
# Per-call lock ownership tracked as (lock, epoch). ``aafter_agent``
# only releases when its epoch still matches the manager's current
# epoch for the thread, preventing stale unlock races.
self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {}
@staticmethod
def _thread_id(runtime: Runtime[ContextT]) -> str | None:
@ -205,7 +264,8 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
if lock.locked():
raise BusyError(request_id=thread_id)
await lock.acquire()
self._held_locks[thread_id] = lock
epoch = manager.bump_turn_epoch(thread_id)
self._held_locks[thread_id] = (lock, epoch)
# Reset the cancel event so this turn starts fresh
reset_cancel(thread_id)
return None
@ -219,8 +279,15 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
thread_id = self._thread_id(runtime)
if thread_id is None:
return None
lock = self._held_locks.pop(thread_id, None)
if lock is not None and lock.locked():
held = self._held_locks.pop(thread_id, None)
if held is None:
return None
lock, held_epoch = held
if held_epoch != manager.current_turn_epoch(thread_id):
# Stale teardown from an older attempt (e.g. runtime-recovery path
# already advanced epoch). Do not touch current lock/cancel state.
return None
if lock.locked():
lock.release()
# Always clear cancel event between turns so a stale signal
# doesn't leak into the next request.
@ -251,9 +318,11 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
__all__ = [
"BusyMutexMiddleware",
"end_turn",
"get_cancel_event",
"get_cancel_state",
"is_cancel_requested",
"manager",
"release_lock",
"request_cancel",
"reset_cancel",
]

View file

@ -0,0 +1,233 @@
r"""Coalesce multi-block system messages into a single text block.
Several middlewares in our deepagent stack each call
``append_to_system_message`` on the way down to the model
(``TodoListMiddleware``, ``SurfSenseFilesystemMiddleware``,
``SkillsMiddleware``, ``SubAgentMiddleware`` ). By the time the
request reaches the LLM, the system message has 5+ separate text blocks.
Anthropic enforces a hard cap of **4 ``cache_control`` blocks per
request**, and we configure 2 injection points
(``index: 0`` + ``index: -1``). With ``index: 0`` always targeting
the prepended ``request.system_message``, this middleware is the
defensive partner: it guarantees that "the system block" is *one*
content block, so LiteLLM's ``AnthropicCacheControlHook`` and any
OpenRouterAnthropic transformer can never multiply our budget into
several breakpoints by spreading ``cache_control`` across multiple
text blocks of a multi-block system content.
Without flattening we used to see::
OpenrouterException - {"error":{"message":"Provider returned error",
"code":400,"metadata":{"raw":"...A maximum of 4 blocks with
cache_control may be provided. Found 5."}}}
(Same error class documented in
https://github.com/BerriAI/litellm/issues/15696 and
https://github.com/BerriAI/litellm/issues/20485 the litellm-side fix
in PR #15395 covers the litellm transformer but does not protect us
when the OpenRouter SaaS itself does the redistribution.)
A separate fix in :mod:`app.agents.new_chat.prompt_caching` (switching
the first injection point from ``role: system`` to ``index: 0``)
neutralises the *primary* cause of the same 400 multiple
``SystemMessage``\ s injected by ``before_agent`` middlewares
(priority/tree/memory/file-intent/anonymous-doc) accumulating across
turns, each tagged with ``cache_control`` by the ``role: system``
matcher. This middleware remains useful as defence-in-depth against
the multi-block redistribution path.
Placement: innermost on the system-message-mutation chain, after every
appender (``todo``/``filesystem``/``skills``/``subagents``) and after
summarization, but before ``noop``/``retry``/``fallback`` so each retry
attempt sees a flattened payload. See ``chat_deepagent.py``.
Idempotent: a string-content system message is left untouched. A list
that contains anything other than plain text blocks (e.g. an image) is
also left untouched those are rare on system messages and we'd lose
the non-text payload by joining.
"""
from __future__ import annotations
import logging
from collections.abc import Awaitable, Callable
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ModelRequest,
ModelResponse,
ResponseT,
)
from langchain_core.messages import SystemMessage
logger = logging.getLogger(__name__)
def _flatten_text_blocks(content: list[Any]) -> str | None:
"""Return joined text if every block is a plain ``{"type": "text"}``.
Returns ``None`` when the list contains anything that isn't a text
block we can safely concatenate (image, audio, file, non-standard
blocks, dicts with extra non-cache_control fields). The caller
leaves the original content untouched in that case rather than
silently dropping payload.
``cache_control`` on individual blocks is intentionally discarded
the whole point of flattening is to let LiteLLM's
``cache_control_injection_points`` re-place a single breakpoint on
the resulting one-block system content.
"""
chunks: list[str] = []
for block in content:
if isinstance(block, str):
chunks.append(block)
continue
if not isinstance(block, dict):
return None
if block.get("type") != "text":
return None
text = block.get("text")
if not isinstance(text, str):
return None
chunks.append(text)
return "\n\n".join(chunks)
def _flattened_request(
request: ModelRequest[ContextT],
) -> ModelRequest[ContextT] | None:
"""Return a request with system_message flattened, or ``None`` for no-op."""
sys_msg = request.system_message
if sys_msg is None:
return None
content = sys_msg.content
if not isinstance(content, list) or len(content) <= 1:
return None
flattened = _flatten_text_blocks(content)
if flattened is None:
return None
new_sys = SystemMessage(
content=flattened,
additional_kwargs=dict(sys_msg.additional_kwargs),
response_metadata=dict(sys_msg.response_metadata),
)
if sys_msg.id is not None:
new_sys.id = sys_msg.id
return request.override(system_message=new_sys)
def _diagnostic_summary(request: ModelRequest[Any]) -> str:
"""One-line dump of cache_control-relevant request shape.
Temporary diagnostic to prove where the ``Found N`` cache_control
breakpoints are coming from when Anthropic 400s. Removed once the
root cause is confirmed and a fix is in place.
"""
sys_msg = request.system_message
if sys_msg is None:
sys_shape = "none"
elif isinstance(sys_msg.content, str):
sys_shape = f"str(len={len(sys_msg.content)})"
elif isinstance(sys_msg.content, list):
sys_shape = f"list(blocks={len(sys_msg.content)})"
else:
sys_shape = f"other({type(sys_msg.content).__name__})"
role_hist: list[str] = []
multi_block_msgs = 0
msgs_with_cc = 0
sys_msgs_in_history = 0
for m in request.messages:
mtype = getattr(m, "type", type(m).__name__)
role_hist.append(mtype)
if isinstance(m, SystemMessage):
sys_msgs_in_history += 1
c = getattr(m, "content", None)
if isinstance(c, list):
multi_block_msgs += 1
for blk in c:
if isinstance(blk, dict) and "cache_control" in blk:
msgs_with_cc += 1
break
if "cache_control" in getattr(m, "additional_kwargs", {}) or {}:
msgs_with_cc += 1
tools = request.tools or []
tools_with_cc = 0
for t in tools:
if isinstance(t, dict) and (
"cache_control" in t or "cache_control" in t.get("function", {})
):
tools_with_cc += 1
return (
f"sys={sys_shape} msgs={len(request.messages)} "
f"sys_msgs_in_history={sys_msgs_in_history} "
f"multi_block_msgs={multi_block_msgs} pre_existing_msg_cc={msgs_with_cc} "
f"tools={len(tools)} pre_existing_tool_cc={tools_with_cc} "
f"roles={role_hist[-8:]}"
)
class FlattenSystemMessageMiddleware(
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
"""Collapse a multi-text-block system message to a single string.
Sits innermost on the system-message-mutation chain so it observes
every middleware's contribution. Has no other side effect — the
body of every block is preserved, just joined with ``"\\n\\n"``.
"""
def __init__(self) -> None:
super().__init__()
self.tools = []
def wrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> Any:
if logger.isEnabledFor(logging.DEBUG):
logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
flattened = _flattened_request(request)
if flattened is not None:
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"[flatten_system] collapsed %d system blocks to one",
len(request.system_message.content), # type: ignore[arg-type, union-attr]
)
return handler(flattened)
return handler(request)
async def awrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
],
) -> Any:
if logger.isEnabledFor(logging.DEBUG):
logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
flattened = _flattened_request(request)
if flattened is not None:
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"[flatten_system] collapsed %d system blocks to one",
len(request.system_message.content), # type: ignore[arg-type, union-attr]
)
return await handler(flattened)
return await handler(request)
__all__ = [
"FlattenSystemMessageMiddleware",
"_flatten_text_blocks",
"_flattened_request",
]

View file

@ -732,7 +732,6 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
del runtime
if self.filesystem_mode != FilesystemMode.CLOUD:
return None
@ -755,7 +754,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
if anon_doc:
return self._anon_priority(state, anon_doc)
return await self._authenticated_priority(state, messages, user_text)
return await self._authenticated_priority(state, messages, user_text, runtime)
def _anon_priority(
self,
@ -787,6 +786,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
state: AgentState,
messages: Sequence[BaseMessage],
user_text: str,
runtime: Runtime[Any] | None = None,
) -> dict[str, Any]:
t0 = asyncio.get_event_loop().time()
(
@ -799,13 +799,45 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
user_text=user_text,
)
# Per-turn ``mentioned_document_ids`` flow:
# 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the
# streaming task supplies a fresh :class:`SurfSenseContextSchema`
# on every ``astream_events`` call, so this list is naturally
# scoped to the current turn. Allows cross-turn graph reuse via
# ``agent_cache``.
# 2. Legacy fallback (cache disabled / context not propagated): the
# constructor-injected ``self.mentioned_document_ids`` list. We
# drain it after the first read so a cached graph (no Phase 1.5
# wiring) doesn't keep replaying the same mentions on every
# turn.
#
# CRITICAL: distinguish "context absent" (legacy caller, no field at
# all) from "context provided but empty" (turn with no mentions).
# ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in
# Python, so a naive ``if ctx_mentions:`` would fall through to the
# legacy closure on every no-mention follow-up turn — replaying the
# mentions baked in by turn 1's cache-miss build. Always drain the
# closure once the runtime path has fired so a cached middleware
# instance can never resurrect stale state.
mention_ids: list[int] = []
ctx = getattr(runtime, "context", None) if runtime is not None else None
ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None
if ctx_mentions is not None:
# Runtime path is authoritative — even an empty list means
# "this turn has no mentions", NOT "look at the closure".
mention_ids = list(ctx_mentions)
if self.mentioned_document_ids:
self.mentioned_document_ids = []
elif self.mentioned_document_ids:
mention_ids = list(self.mentioned_document_ids)
self.mentioned_document_ids = []
mentioned_results: list[dict[str, Any]] = []
if self.mentioned_document_ids:
if mention_ids:
mentioned_results = await fetch_mentioned_documents(
document_ids=self.mentioned_document_ids,
document_ids=mention_ids,
search_space_id=self.search_space_id,
)
self.mentioned_document_ids = []
if is_recency:
doc_types = _resolve_search_types(