refactor(agents): move MAC middleware impls out of shared kernel

knowledge_search, memory_injection and scoped_model_fallback no longer
belong in the cross-agent kernel (app/agents/shared/middleware): they are
consumed only inside multi_agent_chat. Relocate each impl next to the
builder that uses it:

- knowledge_search.py -> multi_agent_chat/shared/middleware/ (genuinely
  shared: its _render_priority_message feeds kb_context_projection, used by
  both the main agent and the KB subagent)
- memory_injection.py -> multi_agent_chat/shared/middleware/ (beside its
  memory.py builder)
- scoped_model_fallback.py -> multi_agent_chat/shared/middleware/resilience/
  (beside fallback.py/bundle.py)

Impls moved verbatim (git rename). Builders/consumers now import the local
sibling; main_agent knowledge_priority imports the new shared path; shared
middleware barrel trimmed.

Tests: repoint imports; convert the knowledge_search monkeypatch targets
from brittle dotted-string form to object-based patching (monkeypatch.setattr
on the imported module), which is robust to import ordering. No behavior
change.
This commit is contained in:
CREDO23 2026-06-05 12:04:31 +02:00
parent 9493519c61
commit 8ae190a11d
12 changed files with 35 additions and 39 deletions

View file

@ -9,20 +9,12 @@ from app.agents.shared.middleware.kb_persistence import (
KnowledgeBasePersistenceMiddleware,
commit_staged_filesystem_state,
)
from app.agents.shared.middleware.knowledge_search import (
KnowledgePriorityMiddleware,
)
from app.agents.shared.middleware.memory_injection import (
MemoryInjectionMiddleware,
)
from app.agents.shared.middleware.permission import PermissionMiddleware
from app.agents.shared.middleware.retry_after import RetryAfterMiddleware
__all__ = [
"BusyMutexMiddleware",
"KnowledgeBasePersistenceMiddleware",
"KnowledgePriorityMiddleware",
"MemoryInjectionMiddleware",
"PermissionMiddleware",
"RetryAfterMiddleware",
"SurfSenseCompactionMiddleware",

View file

@ -1,160 +0,0 @@
"""Memory injection middleware for the SurfSense agent.
Injects memory markdown into the system prompt on every turn:
- Private threads: only personal memory (<user_memory>)
- Shared threads: only team memory (<team_memory>)
"""
from __future__ import annotations
import logging
import time
from typing import Any
from uuid import UUID
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.runtime import Runtime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import ChatVisibility, SearchSpace, User, shielded_async_session
from app.services.memory import MEMORY_HARD_LIMIT, MEMORY_SOFT_LIMIT
from app.utils.perf import get_perf_logger
logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Injects memory markdown into the conversation on every turn."""
tools = ()
def __init__(
self,
*,
user_id: str | UUID | None,
search_space_id: int,
thread_visibility: ChatVisibility | None = None,
) -> None:
self.user_id = UUID(user_id) if isinstance(user_id, str) else user_id
self.search_space_id = search_space_id
self.visibility = thread_visibility or ChatVisibility.PRIVATE
async def abefore_agent( # type: ignore[override]
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
del runtime
messages = state.get("messages") or []
if not messages:
return None
last_message = messages[-1]
if not isinstance(last_message, HumanMessage):
return None
start = time.perf_counter()
db_elapsed = 0.0
memory_blocks: list[str] = []
scope = "team" if self.visibility == ChatVisibility.SEARCH_SPACE else "user"
async with shielded_async_session() as session:
db_start = time.perf_counter()
if self.visibility == ChatVisibility.SEARCH_SPACE:
team_memory = await self._load_team_memory(session)
if team_memory:
chars = len(team_memory)
memory_blocks.append(
f'<team_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
f"{team_memory}\n"
f"</team_memory>"
)
if chars > MEMORY_SOFT_LIMIT:
memory_blocks.append(
f"<memory_warning>Team memory is at "
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
f"the hard limit. On your next update_memory call, consolidate "
f"by merging duplicates, removing outdated entries, and "
f"shortening descriptions before adding anything new."
f"</memory_warning>"
)
elif self.user_id is not None:
user_memory, display_name = await self._load_user_memory(session)
if display_name and display_name.strip():
first_name = display_name.strip().split()[0]
memory_blocks.append(f"<user_name>{first_name}</user_name>")
if user_memory:
chars = len(user_memory)
memory_blocks.append(
f'<user_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
f"{user_memory}\n"
f"</user_memory>"
)
if chars > MEMORY_SOFT_LIMIT:
memory_blocks.append(
f"<memory_warning>Your personal memory is at "
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
f"the hard limit. On your next update_memory call, consolidate "
f"by merging duplicates, removing outdated entries, and "
f"shortening descriptions before adding anything new."
f"</memory_warning>"
)
db_elapsed = time.perf_counter() - db_start
if not memory_blocks:
_perf_log.info(
"[memory_injection] scope=%s injected=0 db=%.3fs total=%.3fs",
scope,
db_elapsed,
time.perf_counter() - start,
)
return None
memory_text = "\n\n".join(memory_blocks)
memory_msg = SystemMessage(content=memory_text)
new_messages = list(messages)
insert_idx = 1 if len(new_messages) > 1 else 0
new_messages.insert(insert_idx, memory_msg)
_perf_log.info(
"[memory_injection] scope=%s injected=1 chars=%d db=%.3fs total=%.3fs",
scope,
len(memory_text),
db_elapsed,
time.perf_counter() - start,
)
return {"messages": new_messages}
async def _load_user_memory(
self, session: AsyncSession
) -> tuple[str | None, str | None]:
"""Return (memory_content, display_name)."""
try:
result = await session.execute(
select(User.memory_md, User.display_name).where(User.id == self.user_id)
)
row = result.one_or_none()
if row is None:
return None, None
return row.memory_md or None, row.display_name
except Exception:
logger.exception("Failed to load user memory")
return None, None
async def _load_team_memory(self, session: AsyncSession) -> str | None:
try:
result = await session.execute(
select(SearchSpace.shared_memory_md).where(
SearchSpace.id == self.search_space_id
)
)
row = result.scalar_one_or_none()
return row if row else None
except Exception:
logger.exception("Failed to load team memory")
return None

View file

@ -1,111 +0,0 @@
"""Fallback only on provider/network errors; let programming bugs raise."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from langchain.agents.middleware import ModelFallbackMiddleware
from app.observability import metrics as ot_metrics, otel as ot
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.messages import AIMessage
# Matched by class name across the MRO so we don't have to import every
# provider SDK (openai/anthropic/google/...). Extend as new providers ship.
_FALLBACK_ELIGIBLE_NAMES: frozenset[str] = frozenset(
{
"RateLimitError",
"APIStatusError",
"InternalServerError",
"ServiceUnavailableError",
"BadGatewayError",
"GatewayTimeoutError",
"APIConnectionError",
"APITimeoutError",
"ConnectError",
"ConnectTimeout",
"ReadTimeout",
"RemoteProtocolError",
"TimeoutError",
"TimeoutException",
}
)
def _is_fallback_eligible(exc: BaseException) -> bool:
return any(cls.__name__ in _FALLBACK_ELIGIBLE_NAMES for cls in type(exc).__mro__)
class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
"""Re-raise non-provider exceptions instead of walking the fallback chain."""
def wrap_model_call( # type: ignore[override]
self,
request: ModelRequest[Any],
handler: Callable[[ModelRequest[Any]], ModelResponse[Any]],
) -> ModelResponse[Any] | AIMessage:
last_exception: Exception
try:
return handler(request)
except Exception as e:
if not _is_fallback_eligible(e):
raise
last_exception = e
for attempt, fallback_model in enumerate(self.models, start=1):
ot.add_event(
"model.fallback",
{
"fallback.attempt": attempt,
"fallback.from": attempt - 1,
"fallback.to": attempt,
"fallback.reason": ot_metrics.categorize_exception(last_exception),
},
)
try:
return handler(request.override(model=fallback_model))
except Exception as e:
if not _is_fallback_eligible(e):
raise
last_exception = e
continue
raise last_exception
async def awrap_model_call( # type: ignore[override]
self,
request: ModelRequest[Any],
handler: Callable[[ModelRequest[Any]], Awaitable[ModelResponse[Any]]],
) -> ModelResponse[Any] | AIMessage:
last_exception: Exception
try:
return await handler(request)
except Exception as e:
if not _is_fallback_eligible(e):
raise
last_exception = e
for attempt, fallback_model in enumerate(self.models, start=1):
ot.add_event(
"model.fallback",
{
"fallback.attempt": attempt,
"fallback.from": attempt - 1,
"fallback.to": attempt,
"fallback.reason": ot_metrics.categorize_exception(last_exception),
},
)
try:
return await handler(request.override(model=fallback_model))
except Exception as e:
if not _is_fallback_eligible(e):
raise
last_exception = e
continue
raise last_exception