mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-07 23:02:39 +02:00
Merge upstream/dev into feature/multi-agent
This commit is contained in:
commit
5119915f4f
278 changed files with 34669 additions and 8970 deletions
|
|
@ -24,6 +24,9 @@ from app.agents.new_chat.middleware.file_intent import (
|
|||
from app.agents.new_chat.middleware.filesystem import (
|
||||
SurfSenseFilesystemMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.middleware.flatten_system import (
|
||||
FlattenSystemMessageMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.middleware.kb_persistence import (
|
||||
KnowledgeBasePersistenceMiddleware,
|
||||
commit_staged_filesystem_state,
|
||||
|
|
@ -61,6 +64,7 @@ __all__ = [
|
|||
"DedupHITLToolCallsMiddleware",
|
||||
"DoomLoopMiddleware",
|
||||
"FileIntentMiddleware",
|
||||
"FlattenSystemMessageMiddleware",
|
||||
"KnowledgeBasePersistenceMiddleware",
|
||||
"KnowledgeBaseSearchMiddleware",
|
||||
"KnowledgePriorityMiddleware",
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -58,6 +59,11 @@ class _ThreadLockManager:
|
|||
weakref.WeakValueDictionary()
|
||||
)
|
||||
self._cancel_events: dict[str, asyncio.Event] = {}
|
||||
self._cancel_requested_at_ms: dict[str, int] = {}
|
||||
self._cancel_attempt_count: dict[str, int] = {}
|
||||
# Monotonic per-thread epoch used to prevent stale middleware
|
||||
# teardown from releasing a newer turn's lock.
|
||||
self._turn_epoch: dict[str, int] = {}
|
||||
|
||||
def lock_for(self, thread_id: str) -> asyncio.Lock:
|
||||
lock = self._locks.get(thread_id)
|
||||
|
|
@ -76,14 +82,57 @@ class _ThreadLockManager:
|
|||
def request_cancel(self, thread_id: str) -> bool:
|
||||
event = self._cancel_events.get(thread_id)
|
||||
if event is None:
|
||||
return False
|
||||
event = asyncio.Event()
|
||||
self._cancel_events[thread_id] = event
|
||||
event.set()
|
||||
now_ms = int(time.time() * 1000)
|
||||
self._cancel_requested_at_ms[thread_id] = now_ms
|
||||
self._cancel_attempt_count[thread_id] = (
|
||||
self._cancel_attempt_count.get(thread_id, 0) + 1
|
||||
)
|
||||
return True
|
||||
|
||||
def is_cancel_requested(self, thread_id: str) -> bool:
|
||||
event = self._cancel_events.get(thread_id)
|
||||
return bool(event and event.is_set())
|
||||
|
||||
def cancel_state(self, thread_id: str) -> tuple[int, int] | None:
|
||||
if not self.is_cancel_requested(thread_id):
|
||||
return None
|
||||
attempts = self._cancel_attempt_count.get(thread_id, 1)
|
||||
requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0)
|
||||
return attempts, requested_at_ms
|
||||
|
||||
def reset(self, thread_id: str) -> None:
|
||||
event = self._cancel_events.get(thread_id)
|
||||
if event is not None:
|
||||
event.clear()
|
||||
self._cancel_requested_at_ms.pop(thread_id, None)
|
||||
self._cancel_attempt_count.pop(thread_id, None)
|
||||
|
||||
def bump_turn_epoch(self, thread_id: str) -> int:
|
||||
epoch = self._turn_epoch.get(thread_id, 0) + 1
|
||||
self._turn_epoch[thread_id] = epoch
|
||||
return epoch
|
||||
|
||||
def current_turn_epoch(self, thread_id: str) -> int:
|
||||
return self._turn_epoch.get(thread_id, 0)
|
||||
|
||||
def end_turn(self, thread_id: str) -> None:
|
||||
"""Best-effort terminal cleanup for a thread turn.
|
||||
|
||||
This is intentionally idempotent and safe to call from outer stream
|
||||
finally-blocks where middleware teardown might be skipped due to abort
|
||||
or disconnect edge-cases.
|
||||
"""
|
||||
# Invalidate any in-flight middleware holder first. This guarantees a
|
||||
# stale ``aafter_agent`` from an older attempt cannot unlock a newer
|
||||
# retry that already acquired the lock for the same thread.
|
||||
self.bump_turn_epoch(thread_id)
|
||||
lock = self._locks.get(thread_id)
|
||||
if lock is not None and lock.locked():
|
||||
lock.release()
|
||||
self.reset(thread_id)
|
||||
|
||||
def release(self, thread_id: str) -> bool:
|
||||
"""Force-release the per-thread lock; safety-net for turns that end before ``__end__``.
|
||||
|
|
@ -115,18 +164,28 @@ def get_cancel_event(thread_id: str) -> asyncio.Event:
|
|||
|
||||
|
||||
def request_cancel(thread_id: str) -> bool:
|
||||
"""Trip the cancel event for ``thread_id``. Returns True if found."""
|
||||
"""Trip the cancel event for ``thread_id``. Always returns True."""
|
||||
return manager.request_cancel(thread_id)
|
||||
|
||||
|
||||
def is_cancel_requested(thread_id: str) -> bool:
|
||||
"""Return whether ``thread_id`` currently has a pending cancel signal."""
|
||||
return manager.is_cancel_requested(thread_id)
|
||||
|
||||
|
||||
def get_cancel_state(thread_id: str) -> tuple[int, int] | None:
|
||||
"""Return ``(attempt_count, requested_at_ms)`` for pending cancel state."""
|
||||
return manager.cancel_state(thread_id)
|
||||
|
||||
|
||||
def reset_cancel(thread_id: str) -> None:
|
||||
"""Reset the cancel event for ``thread_id`` (called between turns)."""
|
||||
manager.reset(thread_id)
|
||||
|
||||
|
||||
def release_lock(thread_id: str) -> bool:
|
||||
"""Force-release the per-thread busy lock; safe to call when not held."""
|
||||
return manager.release(thread_id)
|
||||
def end_turn(thread_id: str) -> None:
|
||||
"""Force end-of-turn cleanup for lock + cancel state."""
|
||||
manager.end_turn(thread_id)
|
||||
|
||||
|
||||
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
|
|
@ -151,10 +210,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
|||
super().__init__()
|
||||
self._require_thread_id = require_thread_id
|
||||
self.tools = []
|
||||
# Per-call locks owned by this middleware. We track them as
|
||||
# an instance attribute so ``aafter_agent`` knows which lock
|
||||
# to release.
|
||||
self._held_locks: dict[str, asyncio.Lock] = {}
|
||||
# Per-call lock ownership tracked as (lock, epoch). ``aafter_agent``
|
||||
# only releases when its epoch still matches the manager's current
|
||||
# epoch for the thread, preventing stale unlock races.
|
||||
self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _thread_id(runtime: Runtime[ContextT]) -> str | None:
|
||||
|
|
@ -205,7 +264,8 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
|||
if lock.locked():
|
||||
raise BusyError(request_id=thread_id)
|
||||
await lock.acquire()
|
||||
self._held_locks[thread_id] = lock
|
||||
epoch = manager.bump_turn_epoch(thread_id)
|
||||
self._held_locks[thread_id] = (lock, epoch)
|
||||
# Reset the cancel event so this turn starts fresh
|
||||
reset_cancel(thread_id)
|
||||
return None
|
||||
|
|
@ -219,8 +279,15 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
|||
thread_id = self._thread_id(runtime)
|
||||
if thread_id is None:
|
||||
return None
|
||||
lock = self._held_locks.pop(thread_id, None)
|
||||
if lock is not None and lock.locked():
|
||||
held = self._held_locks.pop(thread_id, None)
|
||||
if held is None:
|
||||
return None
|
||||
lock, held_epoch = held
|
||||
if held_epoch != manager.current_turn_epoch(thread_id):
|
||||
# Stale teardown from an older attempt (e.g. runtime-recovery path
|
||||
# already advanced epoch). Do not touch current lock/cancel state.
|
||||
return None
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
# Always clear cancel event between turns so a stale signal
|
||||
# doesn't leak into the next request.
|
||||
|
|
@ -251,9 +318,11 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
|||
|
||||
__all__ = [
|
||||
"BusyMutexMiddleware",
|
||||
"end_turn",
|
||||
"get_cancel_event",
|
||||
"get_cancel_state",
|
||||
"is_cancel_requested",
|
||||
"manager",
|
||||
"release_lock",
|
||||
"request_cancel",
|
||||
"reset_cancel",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,233 @@
|
|||
r"""Coalesce multi-block system messages into a single text block.
|
||||
|
||||
Several middlewares in our deepagent stack each call
|
||||
``append_to_system_message`` on the way down to the model
|
||||
(``TodoListMiddleware``, ``SurfSenseFilesystemMiddleware``,
|
||||
``SkillsMiddleware``, ``SubAgentMiddleware`` …). By the time the
|
||||
request reaches the LLM, the system message has 5+ separate text blocks.
|
||||
|
||||
Anthropic enforces a hard cap of **4 ``cache_control`` blocks per
|
||||
request**, and we configure 2 injection points
|
||||
(``index: 0`` + ``index: -1``). With ``index: 0`` always targeting
|
||||
the prepended ``request.system_message``, this middleware is the
|
||||
defensive partner: it guarantees that "the system block" is *one*
|
||||
content block, so LiteLLM's ``AnthropicCacheControlHook`` and any
|
||||
OpenRouter→Anthropic transformer can never multiply our budget into
|
||||
several breakpoints by spreading ``cache_control`` across multiple
|
||||
text blocks of a multi-block system content.
|
||||
|
||||
Without flattening we used to see::
|
||||
|
||||
OpenrouterException - {"error":{"message":"Provider returned error",
|
||||
"code":400,"metadata":{"raw":"...A maximum of 4 blocks with
|
||||
cache_control may be provided. Found 5."}}}
|
||||
|
||||
(Same error class documented in
|
||||
https://github.com/BerriAI/litellm/issues/15696 and
|
||||
https://github.com/BerriAI/litellm/issues/20485 — the litellm-side fix
|
||||
in PR #15395 covers the litellm transformer but does not protect us
|
||||
when the OpenRouter SaaS itself does the redistribution.)
|
||||
|
||||
A separate fix in :mod:`app.agents.new_chat.prompt_caching` (switching
|
||||
the first injection point from ``role: system`` to ``index: 0``)
|
||||
neutralises the *primary* cause of the same 400 — multiple
|
||||
``SystemMessage``\ s injected by ``before_agent`` middlewares
|
||||
(priority/tree/memory/file-intent/anonymous-doc) accumulating across
|
||||
turns, each tagged with ``cache_control`` by the ``role: system``
|
||||
matcher. This middleware remains useful as defence-in-depth against
|
||||
the multi-block redistribution path.
|
||||
|
||||
Placement: innermost on the system-message-mutation chain, after every
|
||||
appender (``todo``/``filesystem``/``skills``/``subagents``) and after
|
||||
summarization, but before ``noop``/``retry``/``fallback`` so each retry
|
||||
attempt sees a flattened payload. See ``chat_deepagent.py``.
|
||||
|
||||
Idempotent: a string-content system message is left untouched. A list
|
||||
that contains anything other than plain text blocks (e.g. an image) is
|
||||
also left untouched — those are rare on system messages and we'd lose
|
||||
the non-text payload by joining.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.messages import SystemMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _flatten_text_blocks(content: list[Any]) -> str | None:
|
||||
"""Return joined text if every block is a plain ``{"type": "text"}``.
|
||||
|
||||
Returns ``None`` when the list contains anything that isn't a text
|
||||
block we can safely concatenate (image, audio, file, non-standard
|
||||
blocks, dicts with extra non-cache_control fields). The caller
|
||||
leaves the original content untouched in that case rather than
|
||||
silently dropping payload.
|
||||
|
||||
``cache_control`` on individual blocks is intentionally discarded —
|
||||
the whole point of flattening is to let LiteLLM's
|
||||
``cache_control_injection_points`` re-place a single breakpoint on
|
||||
the resulting one-block system content.
|
||||
"""
|
||||
chunks: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
chunks.append(block)
|
||||
continue
|
||||
if not isinstance(block, dict):
|
||||
return None
|
||||
if block.get("type") != "text":
|
||||
return None
|
||||
text = block.get("text")
|
||||
if not isinstance(text, str):
|
||||
return None
|
||||
chunks.append(text)
|
||||
return "\n\n".join(chunks)
|
||||
|
||||
|
||||
def _flattened_request(
    request: ModelRequest[ContextT],
) -> ModelRequest[ContextT] | None:
    """Return a request with system_message flattened, or ``None`` for no-op."""
    original = request.system_message
    # No-op for a missing system message, a plain-string one, or a
    # single-block list — those already carry at most one text block.
    if original is None:
        return None
    blocks = original.content
    if not isinstance(blocks, list) or len(blocks) <= 1:
        return None

    joined = _flatten_text_blocks(blocks)
    if joined is None:
        # Non-text payload present; leave the request untouched.
        return None

    replacement = SystemMessage(
        content=joined,
        additional_kwargs=dict(original.additional_kwargs),
        response_metadata=dict(original.response_metadata),
    )
    if original.id is not None:
        replacement.id = original.id
    return request.override(system_message=replacement)
|
||||
|
||||
|
||||
def _diagnostic_summary(request: ModelRequest[Any]) -> str:
    """One-line dump of cache_control-relevant request shape.

    Temporary diagnostic to prove where the ``Found N`` cache_control
    breakpoints are coming from when Anthropic 400s. Removed once the
    root cause is confirmed and a fix is in place.
    """
    sys_msg = request.system_message
    if sys_msg is None:
        sys_shape = "none"
    elif isinstance(sys_msg.content, str):
        sys_shape = f"str(len={len(sys_msg.content)})"
    elif isinstance(sys_msg.content, list):
        sys_shape = f"list(blocks={len(sys_msg.content)})"
    else:
        sys_shape = f"other({type(sys_msg.content).__name__})"

    role_hist: list[str] = []
    multi_block_msgs = 0
    msgs_with_cc = 0
    sys_msgs_in_history = 0
    for m in request.messages:
        mtype = getattr(m, "type", type(m).__name__)
        role_hist.append(mtype)
        if isinstance(m, SystemMessage):
            sys_msgs_in_history += 1
        c = getattr(m, "content", None)
        if isinstance(c, list):
            multi_block_msgs += 1
            for blk in c:
                if isinstance(blk, dict) and "cache_control" in blk:
                    msgs_with_cc += 1
                    break
        # Fix: ``in`` binds tighter than ``or``, so the original
        # ``"cache_control" in getattr(m, "additional_kwargs", {}) or {}``
        # parsed as ``(... in ...) or {}`` — the ``or {}`` fallback never
        # guarded the membership test, and a message whose
        # ``additional_kwargs`` attribute is ``None`` raised TypeError.
        if "cache_control" in (getattr(m, "additional_kwargs", None) or {}):
            msgs_with_cc += 1

    tools = request.tools or []
    tools_with_cc = 0
    for t in tools:
        if isinstance(t, dict) and (
            "cache_control" in t or "cache_control" in t.get("function", {})
        ):
            tools_with_cc += 1

    return (
        f"sys={sys_shape} msgs={len(request.messages)} "
        f"sys_msgs_in_history={sys_msgs_in_history} "
        f"multi_block_msgs={multi_block_msgs} pre_existing_msg_cc={msgs_with_cc} "
        f"tools={len(tools)} pre_existing_tool_cc={tools_with_cc} "
        f"roles={role_hist[-8:]}"
    )
|
||||
|
||||
|
||||
class FlattenSystemMessageMiddleware(
    AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
    """Collapse a multi-text-block system message to a single string.

    Sits innermost on the system-message-mutation chain so it observes
    every middleware's contribution. Has no other side effect — the
    body of every block is preserved, just joined with ``"\\n\\n"``.
    """

    def __init__(self) -> None:
        super().__init__()
        # This middleware contributes no tools of its own.
        self.tools = []

    @staticmethod
    def _prepared(request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
        """Log diagnostics and return the (possibly flattened) request.

        Consolidates the logic the original repeated verbatim in both
        ``wrap_model_call`` and ``awrap_model_call`` so the two entry
        points cannot drift apart.
        """
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
        flattened = _flattened_request(request)
        if flattened is None:
            # Nothing to flatten — forward the request untouched.
            return request
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "[flatten_system] collapsed %d system blocks to one",
                len(request.system_message.content),  # type: ignore[arg-type, union-attr]
            )
        return flattened

    def wrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
    ) -> Any:
        return handler(self._prepared(request))

    async def awrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest[ContextT],
        handler: Callable[
            [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
        ],
    ) -> Any:
        return await handler(self._prepared(request))
|
||||
|
||||
|
||||
# Public surface of the module. The leading-underscore helpers are
# exported deliberately — presumably so tests/callers can reach them via
# ``from ... import *``; confirm before trimming them from this list.
__all__ = [
    "FlattenSystemMessageMiddleware",
    "_flatten_text_blocks",
    "_flattened_request",
]
|
||||
|
|
@ -732,7 +732,6 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
del runtime
|
||||
if self.filesystem_mode != FilesystemMode.CLOUD:
|
||||
return None
|
||||
|
||||
|
|
@ -755,7 +754,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
if anon_doc:
|
||||
return self._anon_priority(state, anon_doc)
|
||||
|
||||
return await self._authenticated_priority(state, messages, user_text)
|
||||
return await self._authenticated_priority(state, messages, user_text, runtime)
|
||||
|
||||
def _anon_priority(
|
||||
self,
|
||||
|
|
@ -787,6 +786,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
state: AgentState,
|
||||
messages: Sequence[BaseMessage],
|
||||
user_text: str,
|
||||
runtime: Runtime[Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
t0 = asyncio.get_event_loop().time()
|
||||
(
|
||||
|
|
@ -799,13 +799,45 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
user_text=user_text,
|
||||
)
|
||||
|
||||
# Per-turn ``mentioned_document_ids`` flow:
|
||||
# 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the
|
||||
# streaming task supplies a fresh :class:`SurfSenseContextSchema`
|
||||
# on every ``astream_events`` call, so this list is naturally
|
||||
# scoped to the current turn. Allows cross-turn graph reuse via
|
||||
# ``agent_cache``.
|
||||
# 2. Legacy fallback (cache disabled / context not propagated): the
|
||||
# constructor-injected ``self.mentioned_document_ids`` list. We
|
||||
# drain it after the first read so a cached graph (no Phase 1.5
|
||||
# wiring) doesn't keep replaying the same mentions on every
|
||||
# turn.
|
||||
#
|
||||
# CRITICAL: distinguish "context absent" (legacy caller, no field at
|
||||
# all) from "context provided but empty" (turn with no mentions).
|
||||
# ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in
|
||||
# Python, so a naive ``if ctx_mentions:`` would fall through to the
|
||||
# legacy closure on every no-mention follow-up turn — replaying the
|
||||
# mentions baked in by turn 1's cache-miss build. Always drain the
|
||||
# closure once the runtime path has fired so a cached middleware
|
||||
# instance can never resurrect stale state.
|
||||
mention_ids: list[int] = []
|
||||
ctx = getattr(runtime, "context", None) if runtime is not None else None
|
||||
ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None
|
||||
if ctx_mentions is not None:
|
||||
# Runtime path is authoritative — even an empty list means
|
||||
# "this turn has no mentions", NOT "look at the closure".
|
||||
mention_ids = list(ctx_mentions)
|
||||
if self.mentioned_document_ids:
|
||||
self.mentioned_document_ids = []
|
||||
elif self.mentioned_document_ids:
|
||||
mention_ids = list(self.mentioned_document_ids)
|
||||
self.mentioned_document_ids = []
|
||||
|
||||
mentioned_results: list[dict[str, Any]] = []
|
||||
if self.mentioned_document_ids:
|
||||
if mention_ids:
|
||||
mentioned_results = await fetch_mentioned_documents(
|
||||
document_ids=self.mentioned_document_ids,
|
||||
document_ids=mention_ids,
|
||||
search_space_id=self.search_space_id,
|
||||
)
|
||||
self.mentioned_document_ids = []
|
||||
|
||||
if is_recency:
|
||||
doc_types = _resolve_search_types(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue