mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): introduce chat/ category; dissolve top-level agents/shared
Recursive shared-folder rule: a shared/ must be shared by ALL siblings at its
level. The kernel (context, compaction, retry_after, web_search) was shared by
only 2 of the agents -- anonymous_chat + multi_agent_chat -- never by podcaster
or video_presentation. Those 2 are the "chat" category, so their shared code
belongs in that category's shared/, not the top-level one.
app/agents/anonymous_chat/ -> app/agents/chat/anonymous_chat/
app/agents/multi_agent_chat/ -> app/agents/chat/multi_agent_chat/
app/agents/shared/ -> app/agents/chat/shared/ (anon<->mac kernel)
Top-level app/agents/shared/ is gone: nothing was shared across all three
categories (chat / podcaster / video_presentation).
~289 import sites rewritten (app.agents.{anonymous_chat,multi_agent_chat,shared}
-> app.agents.chat.*); all moves are git renames (history preserved).
app/agents/ now: chat/, podcaster/, video_presentation/, runtime/.
This commit is contained in:
parent
d59bb2b5aa
commit
24b62a63b4
570 changed files with 712 additions and 613 deletions
9
surfsense_backend/app/agents/chat/shared/__init__.py
Normal file
9
surfsense_backend/app/agents/chat/shared/__init__.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
"""Cross-package agent contracts.
|
||||
|
||||
Symbols here are intentionally framework-light (no LangGraph / deepagents
|
||||
internals) so they can be imported from both ``app.agents.new_chat`` and
|
||||
``app.agents.chat.multi_agent_chat`` without creating a circular dependency
|
||||
between the two packages. See ``receipt.py`` for the rationale.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
71
surfsense_backend/app/agents/chat/shared/context.py
Normal file
71
surfsense_backend/app/agents/chat/shared/context.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
"""
|
||||
Context schema definitions for SurfSense agents.
|
||||
|
||||
This module defines the per-invocation context object passed to the SurfSense
|
||||
deep agent via ``agent.astream_events(..., context=ctx)`` (LangGraph >= 0.6).
|
||||
|
||||
The agent's compiled graph is the same across invocations (and cached by
|
||||
``agent_cache``), so anything that varies per turn — the user mentions a
|
||||
specific document, the front-end issues a unique ``request_id``, etc. —
|
||||
MUST live on this context object instead of being captured into a
|
||||
middleware ``__init__`` closure. Middlewares read fields back via
|
||||
``runtime.context.<field>``; tools read them via ``runtime.context``.
|
||||
|
||||
This object is read inside both ``KnowledgePriorityMiddleware`` (for
|
||||
``mentioned_document_ids``) and any future middleware that needs
|
||||
per-request state without invalidating the compiled-agent cache.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class FileOperationContractState(TypedDict):
|
||||
intent: str
|
||||
confidence: float
|
||||
suggested_path: str
|
||||
timestamp: str
|
||||
turn_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class SurfSenseContextSchema:
|
||||
"""
|
||||
Per-invocation context for the SurfSense deep agent.
|
||||
|
||||
Defaults are chosen so the dataclass can be safely default-constructed
|
||||
(LangGraph's ``Runtime.context`` itself defaults to ``None`` if no
|
||||
context is supplied — see ``langgraph.runtime.Runtime``). All fields
|
||||
are optional; consumers must None-check before reading.
|
||||
|
||||
Phase 1.5 fields:
|
||||
search_space_id: Search space the request is scoped to.
|
||||
mentioned_document_ids: KB documents the user @-mentioned this turn.
|
||||
Read by ``KnowledgePriorityMiddleware`` to seed its priority
|
||||
list. Stays out of the compiled-agent cache key — that's the
|
||||
whole point of putting it here.
|
||||
mentioned_folder_ids: KB folders the user @-mentioned this turn
|
||||
(cloud filesystem mode). Surfaced as ``[USER-MENTIONED]``
|
||||
entries in ``<priority_documents>`` so the agent prioritises
|
||||
walking those folders with ``ls`` / ``find_documents``.
|
||||
file_operation_contract: One-shot file operation contract for the
|
||||
upcoming turn (reserved; not currently populated).
|
||||
turn_id / request_id: Correlation IDs surfaced by the streaming
|
||||
task; populated for telemetry.
|
||||
|
||||
Phase 2 will extend with: thread_id, user_id, visibility,
|
||||
filesystem_mode, anon_session_id, available_connectors,
|
||||
available_document_types, created_by_id (everything currently captured
|
||||
by middleware ``__init__`` closures).
|
||||
"""
|
||||
|
||||
search_space_id: int | None = None
|
||||
mentioned_document_ids: list[int] = field(default_factory=list)
|
||||
mentioned_folder_ids: list[int] = field(default_factory=list)
|
||||
mentioned_connector_ids: list[int] = field(default_factory=list)
|
||||
mentioned_connectors: list[dict[str, object]] = field(default_factory=list)
|
||||
file_operation_contract: FileOperationContractState | None = None
|
||||
turn_id: str | None = None
|
||||
request_id: str | None = None
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
"""Shared middleware components for the SurfSense chat agents."""
|
||||
|
||||
from app.agents.chat.shared.middleware.compaction import (
|
||||
SurfSenseCompactionMiddleware,
|
||||
create_surfsense_compaction_middleware,
|
||||
)
|
||||
from app.agents.chat.shared.middleware.retry_after import RetryAfterMiddleware
|
||||
|
||||
__all__ = [
|
||||
"RetryAfterMiddleware",
|
||||
"SurfSenseCompactionMiddleware",
|
||||
"create_surfsense_compaction_middleware",
|
||||
]
|
||||
|
|
@ -0,0 +1,255 @@
|
|||
"""
|
||||
SurfSense compaction middleware.
|
||||
|
||||
Subclasses :class:`deepagents.middleware.summarization.SummarizationMiddleware`
|
||||
to add SurfSense-specific behavior:
|
||||
|
||||
1. **Structured summary template** (OpenCode-style ``## Goal / Constraints /
|
||||
Progress / Key Decisions / Next Steps / Critical Context / Relevant Files``)
|
||||
— see :data:`SURFSENSE_SUMMARY_PROMPT` below. The base
|
||||
``SummarizationMiddleware`` only ships a freeform "summarize this"
|
||||
prompt; the structured template is ported from OpenCode's
|
||||
``compaction.ts``.
|
||||
2. **Protect SurfSense-specific SystemMessages** so injected hints
|
||||
(``<priority_documents>``, ``<workspace_tree>``, ``<file_operation_contract>``,
|
||||
``<user_memory>``, ``<team_memory>``, ``<user_name>``, ``<memory_warning>``)
|
||||
are *not* summarized away and are kept verbatim in the post-summary
|
||||
message list. Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy
|
||||
(some message types are part of the agent's contract and must survive
|
||||
compaction unchanged).
|
||||
3. **Sanitize ``content=None``** when feeding messages into ``get_buffer_string``
|
||||
(Azure OpenAI / LiteLLM defense — when a provider streams an AIMessage
|
||||
containing only tool_calls and no text, ``content`` can be ``None`` and
|
||||
``get_buffer_string`` crashes iterating over ``None``). SurfSense-specific.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from deepagents.middleware.summarization import (
|
||||
SummarizationMiddleware,
|
||||
compute_summarization_defaults,
|
||||
)
|
||||
from langchain_core.messages import SystemMessage
|
||||
|
||||
from app.observability import metrics as ot_metrics, otel as ot
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deepagents.backends.protocol import BACKEND_TYPES
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AnyMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Structured summary template ported from OpenCode's
|
||||
# ``opencode/packages/opencode/src/session/compaction.ts:40-75``. Kept as a
|
||||
# module-level constant so unit tests can assert on its sections.
|
||||
SURFSENSE_SUMMARY_PROMPT = """<role>
|
||||
SurfSense Conversation Compaction Assistant
|
||||
</role>
|
||||
|
||||
<primary_objective>
|
||||
Extract the most important context from the conversation history below into a structured summary that will replace the older messages.
|
||||
</primary_objective>
|
||||
|
||||
<instructions>
|
||||
You are running because the conversation has grown beyond the model's input window. The conversation history below will be summarized and replaced with your output. Use the structured template that follows; keep each section concise but comprehensive enough that the agent can resume work without losing context. Each section is a checklist — populate it with relevant content or write "None" if there is nothing to report.
|
||||
|
||||
## Goal
|
||||
What is the user's primary goal or request? State it in one or two sentences.
|
||||
|
||||
## Constraints
|
||||
What boundaries must the agent respect (citations rules, visibility scope, allowed tools, user-imposed style, deadlines, deny-listed topics)?
|
||||
|
||||
## Progress
|
||||
What has the agent already accomplished? List each completed step succinctly. Do not reproduce tool output; just record the conclusion.
|
||||
|
||||
## Key Decisions
|
||||
What choices were made and why? Include rejected alternatives and the reasoning behind selecting the current path.
|
||||
|
||||
## Next Steps
|
||||
What specific tasks remain to achieve the goal? Order them by dependency.
|
||||
|
||||
## Critical Context
|
||||
What facts, IDs, document titles, query keywords, error messages, or partial answers must persist into the next turn? Include verbatim quotes only when the exact wording matters (e.g. a precise filter clause or a literal name).
|
||||
|
||||
## Relevant Files
|
||||
What documents or paths in the SurfSense knowledge base are in play? Use ``/documents/...`` paths exactly as they appeared in the workspace tree.
|
||||
</instructions>
|
||||
|
||||
<messages>
|
||||
Messages to summarize:
|
||||
{messages}
|
||||
</messages>
|
||||
|
||||
Respond ONLY with the structured summary. Do not include any text before or after.
|
||||
"""
|
||||
|
||||
# SystemMessage prefixes that must NOT be summarized away. They are
|
||||
# re-injected on every turn by the corresponding middleware, but the
|
||||
# compaction step happens *before* re-injection in some paths, so we
|
||||
# must preserve them verbatim across the cutoff.
|
||||
PROTECTED_SYSTEM_PREFIXES: tuple[str, ...] = (
|
||||
"<priority_documents>", # KnowledgePriorityMiddleware
|
||||
"<workspace_tree>", # KnowledgeTreeMiddleware
|
||||
"<file_operation_contract>", # reserved file-operation contract prefix
|
||||
"<user_memory>", # MemoryInjectionMiddleware
|
||||
"<team_memory>", # MemoryInjectionMiddleware
|
||||
"<user_name>", # MemoryInjectionMiddleware
|
||||
"<memory_warning>", # MemoryInjectionMiddleware
|
||||
)
|
||||
|
||||
|
||||
def _is_protected_system_message(msg: AnyMessage) -> bool:
|
||||
"""Return True if ``msg`` is a SystemMessage we must not summarize."""
|
||||
if not isinstance(msg, SystemMessage):
|
||||
return False
|
||||
content = msg.content
|
||||
if not isinstance(content, str):
|
||||
return False
|
||||
stripped = content.lstrip()
|
||||
return any(stripped.startswith(prefix) for prefix in PROTECTED_SYSTEM_PREFIXES)
|
||||
|
||||
|
||||
def _sanitize_message_content(msg: AnyMessage) -> AnyMessage:
|
||||
"""Return ``msg`` with ``content=None`` coerced to ``""``.
|
||||
|
||||
Folds in the historical defense from ``safe_summarization.py`` —
|
||||
``get_buffer_string`` reads ``m.text`` which iterates ``self.content``,
|
||||
so a ``None`` content (Azure OpenAI / LiteLLM streaming a tool-only
|
||||
AIMessage) explodes. We return a copy with empty string content so
|
||||
downstream consumers see an empty body without mutating the original.
|
||||
"""
|
||||
if getattr(msg, "content", "not-missing") is not None:
|
||||
return msg
|
||||
try:
|
||||
return msg.model_copy(update={"content": ""})
|
||||
except AttributeError:
|
||||
import copy
|
||||
|
||||
new_msg = copy.copy(msg)
|
||||
try:
|
||||
new_msg.content = ""
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Could not sanitize content=None on message of type %s",
|
||||
type(msg).__name__,
|
||||
)
|
||||
return msg
|
||||
return new_msg
|
||||
|
||||
|
||||
class SurfSenseCompactionMiddleware(SummarizationMiddleware):
|
||||
"""SummarizationMiddleware tuned for SurfSense.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- Overrides :meth:`_partition_messages` so protected SystemMessages
|
||||
survive into the ``preserved_messages`` half regardless of cutoff.
|
||||
- Overrides :meth:`_filter_summary_messages` so the buffer-string path
|
||||
never iterates ``None`` content.
|
||||
- Inherits everything else (auto-trigger, backend offload,
|
||||
``_summarization_event`` plumbing, ``ContextOverflowError`` fallback).
|
||||
"""
|
||||
|
||||
def _partition_messages( # type: ignore[override]
|
||||
self,
|
||||
conversation_messages: list[AnyMessage],
|
||||
cutoff_index: int,
|
||||
) -> tuple[list[AnyMessage], list[AnyMessage]]:
|
||||
"""Split messages but always preserve SurfSense protected SystemMessages.
|
||||
|
||||
Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy
|
||||
(``opencode/packages/opencode/src/session/compaction.ts``): some
|
||||
message types are always kept verbatim because they are part of the
|
||||
agent's working contract, not transient output.
|
||||
|
||||
Also opens a ``compaction.run`` OTel span (no-op when OTel is off)
|
||||
so dashboards can count compaction events and message-volume
|
||||
without having to instrument upstream callers.
|
||||
"""
|
||||
# Opening a span here is appropriate because partitioning is the
|
||||
# first call SummarizationMiddleware makes when it has decided to
|
||||
# summarize; we record the volume and then close as a normal span.
|
||||
with ot.compaction_span(
|
||||
reason="auto",
|
||||
messages_in=len(conversation_messages),
|
||||
extra={"compaction.cutoff_index": int(cutoff_index)},
|
||||
):
|
||||
ot_metrics.record_compaction_run(reason="auto")
|
||||
messages_to_summarize, preserved_messages = super()._partition_messages(
|
||||
conversation_messages, cutoff_index
|
||||
)
|
||||
|
||||
protected: list[AnyMessage] = []
|
||||
kept_for_summary: list[AnyMessage] = []
|
||||
for msg in messages_to_summarize:
|
||||
if _is_protected_system_message(msg):
|
||||
protected.append(msg)
|
||||
else:
|
||||
kept_for_summary.append(msg)
|
||||
|
||||
# Place protected blocks at the *front* of preserved_messages so
|
||||
# they keep their original ordering relative to the summary
|
||||
# HumanMessage that precedes the rest of the preserved tail.
|
||||
return kept_for_summary, [*protected, *preserved_messages]
|
||||
|
||||
def _filter_summary_messages( # type: ignore[override]
|
||||
self, messages: list[AnyMessage]
|
||||
) -> list[AnyMessage]:
|
||||
"""Filter previous summaries AND sanitize ``content=None``.
|
||||
|
||||
Folds the ``safe_summarization.py`` defense in: when the buffer
|
||||
builder iterates ``m.text`` over ``None`` it explodes; sanitizing
|
||||
here covers both the sync and async offload paths.
|
||||
"""
|
||||
filtered = super()._filter_summary_messages(messages)
|
||||
return [_sanitize_message_content(m) for m in filtered]
|
||||
|
||||
|
||||
def create_surfsense_compaction_middleware(
|
||||
model: BaseChatModel,
|
||||
backend: BACKEND_TYPES,
|
||||
*,
|
||||
summary_prompt: str | None = None,
|
||||
history_path_prefix: str = "/conversation_history",
|
||||
**overrides: Any,
|
||||
) -> SurfSenseCompactionMiddleware:
|
||||
"""Build a :class:`SurfSenseCompactionMiddleware` with sensible defaults.
|
||||
|
||||
Pulls profile-aware ``trigger`` / ``keep`` / ``truncate_args_settings``
|
||||
via :func:`deepagents.middleware.summarization.compute_summarization_defaults`
|
||||
so callers get the same behavior as ``create_summarization_middleware``
|
||||
plus our overrides.
|
||||
|
||||
Args:
|
||||
model: Chat model to call for summary generation.
|
||||
backend: Backend instance or factory for offloading conversation history.
|
||||
summary_prompt: Optional override; defaults to :data:`SURFSENSE_SUMMARY_PROMPT`.
|
||||
history_path_prefix: Path prefix for offloaded conversation history.
|
||||
**overrides: Forwarded to :class:`SurfSenseCompactionMiddleware`.
|
||||
"""
|
||||
defaults = compute_summarization_defaults(model)
|
||||
return SurfSenseCompactionMiddleware(
|
||||
model=model,
|
||||
backend=backend,
|
||||
trigger=overrides.pop("trigger", defaults["trigger"]),
|
||||
keep=overrides.pop("keep", defaults["keep"]),
|
||||
trim_tokens_to_summarize=overrides.pop("trim_tokens_to_summarize", None),
|
||||
truncate_args_settings=overrides.pop(
|
||||
"truncate_args_settings", defaults["truncate_args_settings"]
|
||||
),
|
||||
summary_prompt=summary_prompt or SURFSENSE_SUMMARY_PROMPT,
|
||||
history_path_prefix=history_path_prefix,
|
||||
**overrides,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PROTECTED_SYSTEM_PREFIXES",
|
||||
"SURFSENSE_SUMMARY_PROMPT",
|
||||
"SurfSenseCompactionMiddleware",
|
||||
"create_surfsense_compaction_middleware",
|
||||
]
|
||||
|
|
@ -0,0 +1,277 @@
|
|||
"""
|
||||
RetryAfterMiddleware — Header-aware retry with custom backoff and SSE eventing.
|
||||
|
||||
LangChain's :class:`ModelRetryMiddleware` retries on exceptions but ignores
|
||||
the ``Retry-After`` HTTP header — it just runs its own exponential backoff.
|
||||
That wastes time when a provider has explicitly told us how long to wait.
|
||||
This middleware honors the header (mirroring OpenCode's
|
||||
``packages/opencode/src/session/llm.ts`` retry pathway) and emits an SSE
|
||||
event so the UI can show "rate-limited, retrying in Ns".
|
||||
|
||||
We can't subclass ``ModelRetryMiddleware`` cleanly because its loop calls a
|
||||
module-level ``calculate_delay`` inline (no overridable
|
||||
``_calculate_delay`` hook), so this is a standalone implementation.
|
||||
|
||||
Behaviour:
|
||||
- Extracts ``Retry-After`` / ``retry-after-ms`` from
|
||||
``litellm.exceptions.RateLimitError.response.headers`` (or any exception
|
||||
exposing a similar shape).
|
||||
- Sleeps ``max(exponential_backoff, header_delay)`` between retries.
|
||||
- Returns ``False`` from ``retry_on`` for ``ContextWindowExceededError`` /
|
||||
``ContextOverflowError`` so :class:`SurfSenseCompactionMiddleware` (or
|
||||
the LangChain summarization fallback path) handles those instead.
|
||||
- Emits ``surfsense.retrying`` via ``adispatch_custom_event`` on each retry
|
||||
so ``stream_new_chat`` can forward it to clients as an SSE event.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.observability import metrics as ot_metrics, otel as ot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Names of exception classes for which a retry would not help — context
|
||||
# overflow needs compaction, auth needs human intervention, etc. Detected
|
||||
# by class-name substring so we don't have to import LiteLLM/Anthropic
|
||||
# here (which would tie this module to optional deps).
|
||||
_NON_RETRYABLE_NAME_HINTS: tuple[str, ...] = (
|
||||
"ContextWindowExceeded",
|
||||
"ContextOverflow",
|
||||
"AuthenticationError",
|
||||
"InvalidRequestError",
|
||||
"PermissionDenied",
|
||||
"InvalidApiKey",
|
||||
"ContextLimit",
|
||||
)
|
||||
|
||||
|
||||
def _is_non_retryable(exc: BaseException) -> bool:
|
||||
name = type(exc).__name__
|
||||
return any(hint in name for hint in _NON_RETRYABLE_NAME_HINTS)
|
||||
|
||||
|
||||
def _extract_retry_after_seconds(exc: BaseException) -> float | None:
|
||||
"""Return seconds-to-wait suggested by the provider, if any.
|
||||
|
||||
Looks at ``exc.response.headers`` or ``exc.headers`` for the standard
|
||||
HTTP ``Retry-After`` header (in seconds) or its millisecond cousin
|
||||
``retry-after-ms`` (sometimes used by Anthropic / OpenAI). Falls back
|
||||
to a regex on the exception message for shapes like
|
||||
``"Please retry after 30s"``.
|
||||
"""
|
||||
headers: dict[str, Any] | None = None
|
||||
response = getattr(exc, "response", None)
|
||||
if response is not None:
|
||||
headers = getattr(response, "headers", None)
|
||||
if headers is None:
|
||||
headers = getattr(exc, "headers", None)
|
||||
|
||||
if isinstance(headers, dict):
|
||||
# Normalize keys to lowercase for case-insensitive matching
|
||||
norm = {str(k).lower(): v for k, v in headers.items()}
|
||||
ms = norm.get("retry-after-ms")
|
||||
if ms is not None:
|
||||
try:
|
||||
return float(ms) / 1000.0
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
seconds = norm.get("retry-after")
|
||||
if seconds is not None:
|
||||
try:
|
||||
return float(seconds)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# Last resort: scan the message for "retry after Xs" or "X seconds"
|
||||
msg = str(exc)
|
||||
match = re.search(r"retry\s+after\s+([0-9]+(?:\.[0-9]+)?)", msg, re.IGNORECASE)
|
||||
if match:
|
||||
try:
|
||||
return float(match.group(1))
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _exponential_delay(
|
||||
attempt: int,
|
||||
*,
|
||||
initial_delay: float,
|
||||
backoff_factor: float,
|
||||
max_delay: float,
|
||||
jitter: bool,
|
||||
) -> float:
|
||||
"""Compute an exponential-backoff delay with optional ±25% jitter."""
|
||||
delay = (
|
||||
initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay
|
||||
)
|
||||
delay = min(delay, max_delay)
|
||||
if jitter and delay > 0:
|
||||
delay *= 1 + random.uniform(-0.25, 0.25)
|
||||
return max(delay, 0.0)
|
||||
|
||||
|
||||
class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
"""Retry middleware that honors provider-issued Retry-After hints.
|
||||
|
||||
Drop-in replacement for :class:`langchain.agents.middleware.ModelRetryMiddleware`
|
||||
when working with LiteLLM/Anthropic/OpenAI providers that surface
|
||||
rate-limit hints in headers. Always emits ``surfsense.retrying`` SSE
|
||||
events so the UI can show a friendly "rate limited, retrying in Xs"
|
||||
indicator.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum retries after the initial attempt (default 3).
|
||||
initial_delay: Initial backoff delay in seconds.
|
||||
backoff_factor: Exponential growth factor for backoff.
|
||||
max_delay: Cap on per-attempt delay in seconds.
|
||||
jitter: Whether to add ±25% jitter.
|
||||
retry_on: Optional callable that returns True for retryable
|
||||
exceptions. The default retries everything except known
|
||||
non-retryable classes (context overflow, auth, etc.).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
max_retries: int = 3,
|
||||
initial_delay: float = 1.0,
|
||||
backoff_factor: float = 2.0,
|
||||
max_delay: float = 60.0,
|
||||
jitter: bool = True,
|
||||
retry_on: Callable[[BaseException], bool] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.max_retries = max_retries
|
||||
self.initial_delay = initial_delay
|
||||
self.backoff_factor = backoff_factor
|
||||
self.max_delay = max_delay
|
||||
self.jitter = jitter
|
||||
self._retry_on: Callable[[BaseException], bool] = retry_on or (
|
||||
lambda exc: not _is_non_retryable(exc)
|
||||
)
|
||||
|
||||
def _should_retry(self, exc: BaseException) -> bool:
|
||||
try:
|
||||
return bool(self._retry_on(exc))
|
||||
except Exception:
|
||||
logger.exception("retry_on callable raised; defaulting to False")
|
||||
return False
|
||||
|
||||
def _delay_for_attempt(self, attempt: int, exc: BaseException) -> float:
|
||||
backoff = _exponential_delay(
|
||||
attempt,
|
||||
initial_delay=self.initial_delay,
|
||||
backoff_factor=self.backoff_factor,
|
||||
max_delay=self.max_delay,
|
||||
jitter=self.jitter,
|
||||
)
|
||||
header = _extract_retry_after_seconds(exc) or 0.0
|
||||
return max(backoff, header)
|
||||
|
||||
def wrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
|
||||
) -> ModelResponse[ResponseT] | AIMessage:
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception as exc:
|
||||
if not self._should_retry(exc) or attempt >= self.max_retries:
|
||||
raise
|
||||
delay = self._delay_for_attempt(attempt, exc)
|
||||
ot.add_event(
|
||||
"model.retry.scheduled",
|
||||
{
|
||||
"retry.attempt": attempt + 1,
|
||||
"retry.max": self.max_retries,
|
||||
"retry.delay_ms": int(delay * 1000),
|
||||
"retry.reason": ot_metrics.categorize_exception(exc),
|
||||
},
|
||||
)
|
||||
try:
|
||||
dispatch_custom_event(
|
||||
"surfsense.retrying",
|
||||
{
|
||||
"attempt": attempt + 1,
|
||||
"max_retries": self.max_retries,
|
||||
"delay_ms": int(delay * 1000),
|
||||
"reason": type(exc).__name__,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"dispatch_custom_event failed; suppressed", exc_info=True
|
||||
)
|
||||
if delay > 0:
|
||||
time.sleep(delay)
|
||||
# Unreachable
|
||||
raise RuntimeError("RetryAfterMiddleware: retry loop exited without resolution")
|
||||
|
||||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT] | AIMessage:
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return await handler(request)
|
||||
except Exception as exc:
|
||||
if not self._should_retry(exc) or attempt >= self.max_retries:
|
||||
raise
|
||||
delay = self._delay_for_attempt(attempt, exc)
|
||||
ot.add_event(
|
||||
"model.retry.scheduled",
|
||||
{
|
||||
"retry.attempt": attempt + 1,
|
||||
"retry.max": self.max_retries,
|
||||
"retry.delay_ms": int(delay * 1000),
|
||||
"retry.reason": ot_metrics.categorize_exception(exc),
|
||||
},
|
||||
)
|
||||
try:
|
||||
await adispatch_custom_event(
|
||||
"surfsense.retrying",
|
||||
{
|
||||
"attempt": attempt + 1,
|
||||
"max_retries": self.max_retries,
|
||||
"delay_ms": int(delay * 1000),
|
||||
"reason": type(exc).__name__,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"adispatch_custom_event failed; suppressed", exc_info=True
|
||||
)
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
raise RuntimeError("RetryAfterMiddleware: retry loop exited without resolution")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RetryAfterMiddleware",
|
||||
"_extract_retry_after_seconds",
|
||||
"_is_non_retryable",
|
||||
]
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
"""Cross-agent shared tools.
|
||||
|
||||
Only genuinely cross-agent tool code lives here (currently web_search, imported
|
||||
directly from its module).
|
||||
"""
|
||||
247
surfsense_backend/app/agents/chat/shared/tools/web_search.py
Normal file
247
surfsense_backend/app/agents/chat/shared/tools/web_search.py
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
"""
|
||||
Web search tool for the SurfSense agent.
|
||||
|
||||
Provides a unified tool for real-time web searches that dispatches to all
|
||||
configured search engines: the platform SearXNG instance (always available)
|
||||
plus any user-configured live-search connectors (Tavily, Linkup, Baidu).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.db import shielded_async_session
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_LIVE_SEARCH_CONNECTORS: set[str] = {
|
||||
"TAVILY_API",
|
||||
"LINKUP_API",
|
||||
"BAIDU_SEARCH_API",
|
||||
}
|
||||
|
||||
_LIVE_CONNECTOR_SPECS: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
|
||||
"TAVILY_API": ("search_tavily", False, True, {}),
|
||||
"LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}),
|
||||
"BAIDU_SEARCH_API": ("search_baidu", False, True, {}),
|
||||
}
|
||||
|
||||
_CONNECTOR_LABELS: dict[str, str] = {
|
||||
"TAVILY_API": "Tavily",
|
||||
"LINKUP_API": "Linkup",
|
||||
"BAIDU_SEARCH_API": "Baidu",
|
||||
}
|
||||
|
||||
|
||||
class WebSearchInput(BaseModel):
|
||||
"""Input schema for the web_search tool."""
|
||||
|
||||
query: str = Field(
|
||||
description="The search query to look up on the web. Use specific, descriptive terms.",
|
||||
)
|
||||
top_k: int = Field(
|
||||
default=10,
|
||||
description="Number of results to retrieve (default: 10, max: 50).",
|
||||
)
|
||||
|
||||
|
||||
def _format_web_results(
|
||||
documents: list[dict[str, Any]],
|
||||
*,
|
||||
max_chars: int = 50_000,
|
||||
) -> str:
|
||||
"""Format web search results into XML suitable for the LLM context."""
|
||||
if not documents:
|
||||
return "No web search results found."
|
||||
|
||||
parts: list[str] = []
|
||||
total_chars = 0
|
||||
|
||||
for doc in documents:
|
||||
doc_info = doc.get("document") or {}
|
||||
metadata = doc_info.get("metadata") or {}
|
||||
title = doc_info.get("title") or "Web Result"
|
||||
url = metadata.get("url") or ""
|
||||
content = (doc.get("content") or "").strip()
|
||||
source = metadata.get("document_type") or doc.get("source") or "WEB_SEARCH"
|
||||
if not content:
|
||||
continue
|
||||
|
||||
metadata_json = json.dumps(metadata, ensure_ascii=False)
|
||||
doc_xml = "\n".join(
|
||||
[
|
||||
"<document>",
|
||||
"<document_metadata>",
|
||||
f" <document_type>{source}</document_type>",
|
||||
f" <title><![CDATA[{title}]]></title>",
|
||||
f" <url><![CDATA[{url}]]></url>",
|
||||
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
|
||||
"</document_metadata>",
|
||||
"<document_content>",
|
||||
f" <chunk id='{url}'><![CDATA[{content}]]></chunk>",
|
||||
"</document_content>",
|
||||
"</document>",
|
||||
"",
|
||||
]
|
||||
)
|
||||
|
||||
if total_chars + len(doc_xml) > max_chars:
|
||||
parts.append("<!-- Output truncated to fit context window -->")
|
||||
break
|
||||
|
||||
parts.append(doc_xml)
|
||||
total_chars += len(doc_xml)
|
||||
|
||||
return "\n".join(parts).strip() or "No web search results found."
|
||||
|
||||
|
||||
async def _search_live_connector(
|
||||
connector: str,
|
||||
query: str,
|
||||
search_space_id: int,
|
||||
top_k: int,
|
||||
semaphore: asyncio.Semaphore,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Dispatch a single live-search connector (Tavily / Linkup / Baidu)."""
|
||||
perf = get_perf_logger()
|
||||
spec = _LIVE_CONNECTOR_SPECS.get(connector)
|
||||
if spec is None:
|
||||
return []
|
||||
|
||||
method_name, _includes_date_range, includes_top_k, extra_kwargs = spec
|
||||
kwargs: dict[str, Any] = {
|
||||
"user_query": query,
|
||||
"search_space_id": search_space_id,
|
||||
**extra_kwargs,
|
||||
}
|
||||
if includes_top_k:
|
||||
kwargs["top_k"] = top_k
|
||||
|
||||
try:
|
||||
t0 = time.perf_counter()
|
||||
async with semaphore, shielded_async_session() as session:
|
||||
svc = ConnectorService(session, search_space_id)
|
||||
_, chunks = await getattr(svc, method_name)(**kwargs)
|
||||
perf.info(
|
||||
"[web_search] connector=%s results=%d in %.3fs",
|
||||
connector,
|
||||
len(chunks),
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
return chunks
|
||||
except Exception as e:
|
||||
perf.warning("[web_search] connector=%s FAILED: %s", connector, e)
|
||||
return []
|
||||
|
||||
|
||||
def create_web_search_tool(
|
||||
search_space_id: int | None = None,
|
||||
available_connectors: list[str] | None = None,
|
||||
) -> StructuredTool:
|
||||
"""Factory for the ``web_search`` tool.
|
||||
|
||||
Dispatches in parallel to the platform SearXNG instance and any
|
||||
user-configured live-search connectors (Tavily, Linkup, Baidu).
|
||||
"""
|
||||
active_live_connectors: list[str] = []
|
||||
if available_connectors:
|
||||
active_live_connectors = [
|
||||
c for c in available_connectors if c in _LIVE_SEARCH_CONNECTORS
|
||||
]
|
||||
|
||||
engine_names = ["SearXNG (platform default)"]
|
||||
engine_names.extend(_CONNECTOR_LABELS.get(c, c) for c in active_live_connectors)
|
||||
engines_summary = ", ".join(engine_names)
|
||||
|
||||
description = (
|
||||
"Search the web for real-time information. "
|
||||
"Use this for current events, news, prices, weather, public facts, or any "
|
||||
"question that requires up-to-date information from the internet.\n\n"
|
||||
f"Active search engines: {engines_summary}.\n"
|
||||
"All configured engines are queried in parallel and results are merged."
|
||||
)
|
||||
|
||||
_search_space_id = search_space_id
|
||||
_active_live = active_live_connectors
|
||||
|
||||
async def _web_search_impl(query: str, top_k: int = 10) -> str:
|
||||
from app.services import web_search_service
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
clamped_top_k = min(max(1, top_k), 50)
|
||||
|
||||
semaphore = asyncio.Semaphore(4)
|
||||
tasks: list[asyncio.Task[list[dict[str, Any]]]] = []
|
||||
|
||||
if web_search_service.is_available():
|
||||
|
||||
async def _searxng() -> list[dict[str, Any]]:
|
||||
async with semaphore:
|
||||
_result_obj, docs = await web_search_service.search(
|
||||
query=query,
|
||||
top_k=clamped_top_k,
|
||||
)
|
||||
return docs
|
||||
|
||||
tasks.append(asyncio.ensure_future(_searxng()))
|
||||
|
||||
if _search_space_id is not None:
|
||||
for connector in _active_live:
|
||||
tasks.append(
|
||||
asyncio.ensure_future(
|
||||
_search_live_connector(
|
||||
connector=connector,
|
||||
query=query,
|
||||
search_space_id=_search_space_id,
|
||||
top_k=clamped_top_k,
|
||||
semaphore=semaphore,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if not tasks:
|
||||
return "Web search is not available — no search engines are configured."
|
||||
|
||||
results_lists = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
all_documents: list[dict[str, Any]] = []
|
||||
for result in results_lists:
|
||||
if isinstance(result, BaseException):
|
||||
perf.warning("[web_search] a search engine failed: %s", result)
|
||||
continue
|
||||
all_documents.extend(result)
|
||||
|
||||
seen_urls: set[str] = set()
|
||||
deduplicated: list[dict[str, Any]] = []
|
||||
for doc in all_documents:
|
||||
url = ((doc.get("document") or {}).get("metadata") or {}).get("url", "")
|
||||
if url and url in seen_urls:
|
||||
continue
|
||||
if url:
|
||||
seen_urls.add(url)
|
||||
deduplicated.append(doc)
|
||||
|
||||
formatted = _format_web_results(deduplicated)
|
||||
|
||||
perf.info(
|
||||
"[web_search] query=%r engines=%d results=%d deduped=%d chars=%d in %.3fs",
|
||||
query[:60],
|
||||
len(tasks),
|
||||
len(all_documents),
|
||||
len(deduplicated),
|
||||
len(formatted),
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
return formatted
|
||||
|
||||
return StructuredTool(
|
||||
name="web_search",
|
||||
description=description,
|
||||
coroutine=_web_search_impl,
|
||||
args_schema=WebSearchInput,
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue