mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): move MAC middleware impls out of shared kernel
knowledge_search, memory_injection and scoped_model_fallback no longer belong in the cross-agent kernel (app/agents/shared/middleware): they are consumed only inside multi_agent_chat. Relocate each impl next to the builder that uses it: - knowledge_search.py -> multi_agent_chat/shared/middleware/ (genuinely shared: its _render_priority_message feeds kb_context_projection, used by both the main agent and the KB subagent) - memory_injection.py -> multi_agent_chat/shared/middleware/ (beside its memory.py builder) - scoped_model_fallback.py -> multi_agent_chat/shared/middleware/resilience/ (beside fallback.py/bundle.py) Impls moved verbatim (git rename). Builders/consumers now import the local sibling; main_agent knowledge_priority imports the new shared path; shared middleware barrel trimmed. Tests: repoint imports; convert the knowledge_search monkeypatch targets from brittle dotted-string form to object-based patching (monkeypatch.setattr on the imported module), which is robust to import ordering. No behavior change.
This commit is contained in:
parent
9493519c61
commit
8ae190a11d
12 changed files with 35 additions and 39 deletions
|
|
@ -9,20 +9,12 @@ from app.agents.shared.middleware.kb_persistence import (
|
|||
KnowledgeBasePersistenceMiddleware,
|
||||
commit_staged_filesystem_state,
|
||||
)
|
||||
from app.agents.shared.middleware.knowledge_search import (
|
||||
KnowledgePriorityMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.memory_injection import (
|
||||
MemoryInjectionMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.permission import PermissionMiddleware
|
||||
from app.agents.shared.middleware.retry_after import RetryAfterMiddleware
|
||||
|
||||
__all__ = [
|
||||
"BusyMutexMiddleware",
|
||||
"KnowledgeBasePersistenceMiddleware",
|
||||
"KnowledgePriorityMiddleware",
|
||||
"MemoryInjectionMiddleware",
|
||||
"PermissionMiddleware",
|
||||
"RetryAfterMiddleware",
|
||||
"SurfSenseCompactionMiddleware",
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,160 +0,0 @@
|
|||
"""Memory injection middleware for the SurfSense agent.
|
||||
|
||||
Injects memory markdown into the system prompt on every turn:
|
||||
- Private threads: only personal memory (<user_memory>)
|
||||
- Shared threads: only team memory (<team_memory>)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import ChatVisibility, SearchSpace, User, shielded_async_session
|
||||
from app.services.memory import MEMORY_HARD_LIMIT, MEMORY_SOFT_LIMIT
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Injects memory markdown into the conversation on every turn."""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
user_id: str | UUID | None,
|
||||
search_space_id: int,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
) -> None:
|
||||
self.user_id = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
self.search_space_id = search_space_id
|
||||
self.visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
|
||||
async def abefore_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
del runtime
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_message = messages[-1]
|
||||
if not isinstance(last_message, HumanMessage):
|
||||
return None
|
||||
|
||||
start = time.perf_counter()
|
||||
db_elapsed = 0.0
|
||||
memory_blocks: list[str] = []
|
||||
scope = "team" if self.visibility == ChatVisibility.SEARCH_SPACE else "user"
|
||||
|
||||
async with shielded_async_session() as session:
|
||||
db_start = time.perf_counter()
|
||||
if self.visibility == ChatVisibility.SEARCH_SPACE:
|
||||
team_memory = await self._load_team_memory(session)
|
||||
if team_memory:
|
||||
chars = len(team_memory)
|
||||
memory_blocks.append(
|
||||
f'<team_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
|
||||
f"{team_memory}\n"
|
||||
f"</team_memory>"
|
||||
)
|
||||
if chars > MEMORY_SOFT_LIMIT:
|
||||
memory_blocks.append(
|
||||
f"<memory_warning>Team memory is at "
|
||||
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
|
||||
f"the hard limit. On your next update_memory call, consolidate "
|
||||
f"by merging duplicates, removing outdated entries, and "
|
||||
f"shortening descriptions before adding anything new."
|
||||
f"</memory_warning>"
|
||||
)
|
||||
elif self.user_id is not None:
|
||||
user_memory, display_name = await self._load_user_memory(session)
|
||||
if display_name and display_name.strip():
|
||||
first_name = display_name.strip().split()[0]
|
||||
memory_blocks.append(f"<user_name>{first_name}</user_name>")
|
||||
if user_memory:
|
||||
chars = len(user_memory)
|
||||
memory_blocks.append(
|
||||
f'<user_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
|
||||
f"{user_memory}\n"
|
||||
f"</user_memory>"
|
||||
)
|
||||
if chars > MEMORY_SOFT_LIMIT:
|
||||
memory_blocks.append(
|
||||
f"<memory_warning>Your personal memory is at "
|
||||
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
|
||||
f"the hard limit. On your next update_memory call, consolidate "
|
||||
f"by merging duplicates, removing outdated entries, and "
|
||||
f"shortening descriptions before adding anything new."
|
||||
f"</memory_warning>"
|
||||
)
|
||||
|
||||
db_elapsed = time.perf_counter() - db_start
|
||||
|
||||
if not memory_blocks:
|
||||
_perf_log.info(
|
||||
"[memory_injection] scope=%s injected=0 db=%.3fs total=%.3fs",
|
||||
scope,
|
||||
db_elapsed,
|
||||
time.perf_counter() - start,
|
||||
)
|
||||
return None
|
||||
|
||||
memory_text = "\n\n".join(memory_blocks)
|
||||
memory_msg = SystemMessage(content=memory_text)
|
||||
|
||||
new_messages = list(messages)
|
||||
insert_idx = 1 if len(new_messages) > 1 else 0
|
||||
new_messages.insert(insert_idx, memory_msg)
|
||||
|
||||
_perf_log.info(
|
||||
"[memory_injection] scope=%s injected=1 chars=%d db=%.3fs total=%.3fs",
|
||||
scope,
|
||||
len(memory_text),
|
||||
db_elapsed,
|
||||
time.perf_counter() - start,
|
||||
)
|
||||
return {"messages": new_messages}
|
||||
|
||||
async def _load_user_memory(
|
||||
self, session: AsyncSession
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Return (memory_content, display_name)."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(User.memory_md, User.display_name).where(User.id == self.user_id)
|
||||
)
|
||||
row = result.one_or_none()
|
||||
if row is None:
|
||||
return None, None
|
||||
return row.memory_md or None, row.display_name
|
||||
except Exception:
|
||||
logger.exception("Failed to load user memory")
|
||||
return None, None
|
||||
|
||||
async def _load_team_memory(self, session: AsyncSession) -> str | None:
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSpace.shared_memory_md).where(
|
||||
SearchSpace.id == self.search_space_id
|
||||
)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
return row if row else None
|
||||
except Exception:
|
||||
logger.exception("Failed to load team memory")
|
||||
return None
|
||||
|
|
@ -1,111 +0,0 @@
|
|||
"""Fallback only on provider/network errors; let programming bugs raise."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware import ModelFallbackMiddleware
|
||||
|
||||
from app.observability import metrics as ot_metrics, otel as ot
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
# Matched by class name across the MRO so we don't have to import every
|
||||
# provider SDK (openai/anthropic/google/...). Extend as new providers ship.
|
||||
_FALLBACK_ELIGIBLE_NAMES: frozenset[str] = frozenset(
|
||||
{
|
||||
"RateLimitError",
|
||||
"APIStatusError",
|
||||
"InternalServerError",
|
||||
"ServiceUnavailableError",
|
||||
"BadGatewayError",
|
||||
"GatewayTimeoutError",
|
||||
"APIConnectionError",
|
||||
"APITimeoutError",
|
||||
"ConnectError",
|
||||
"ConnectTimeout",
|
||||
"ReadTimeout",
|
||||
"RemoteProtocolError",
|
||||
"TimeoutError",
|
||||
"TimeoutException",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _is_fallback_eligible(exc: BaseException) -> bool:
|
||||
return any(cls.__name__ in _FALLBACK_ELIGIBLE_NAMES for cls in type(exc).__mro__)
|
||||
|
||||
|
||||
class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
|
||||
"""Re-raise non-provider exceptions instead of walking the fallback chain."""
|
||||
|
||||
def wrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[Any],
|
||||
handler: Callable[[ModelRequest[Any]], ModelResponse[Any]],
|
||||
) -> ModelResponse[Any] | AIMessage:
|
||||
last_exception: Exception
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
|
||||
for attempt, fallback_model in enumerate(self.models, start=1):
|
||||
ot.add_event(
|
||||
"model.fallback",
|
||||
{
|
||||
"fallback.attempt": attempt,
|
||||
"fallback.from": attempt - 1,
|
||||
"fallback.to": attempt,
|
||||
"fallback.reason": ot_metrics.categorize_exception(last_exception),
|
||||
},
|
||||
)
|
||||
try:
|
||||
return handler(request.override(model=fallback_model))
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
continue
|
||||
|
||||
raise last_exception
|
||||
|
||||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[Any],
|
||||
handler: Callable[[ModelRequest[Any]], Awaitable[ModelResponse[Any]]],
|
||||
) -> ModelResponse[Any] | AIMessage:
|
||||
last_exception: Exception
|
||||
try:
|
||||
return await handler(request)
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
|
||||
for attempt, fallback_model in enumerate(self.models, start=1):
|
||||
ot.add_event(
|
||||
"model.fallback",
|
||||
{
|
||||
"fallback.attempt": attempt,
|
||||
"fallback.from": attempt - 1,
|
||||
"fallback.to": attempt,
|
||||
"fallback.reason": ot_metrics.categorize_exception(last_exception),
|
||||
},
|
||||
)
|
||||
try:
|
||||
return await handler(request.override(model=fallback_model))
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
continue
|
||||
|
||||
raise last_exception
|
||||
Loading…
Add table
Add a link
Reference in a new issue