refactor(agents): introduce chat/ category; dissolve top-level agents/shared

Recursive shared-folder rule: a shared/ must be shared by ALL siblings at its
level. The kernel (context, compaction, retry_after, web_search) was shared by
only 2 of the agents -- anonymous_chat + multi_agent_chat -- never by podcaster
or video_presentation. Those 2 are the "chat" category, so their shared code
belongs in that category's shared/, not the top-level one.

  app/agents/anonymous_chat/   -> app/agents/chat/anonymous_chat/
  app/agents/multi_agent_chat/ -> app/agents/chat/multi_agent_chat/
  app/agents/shared/           -> app/agents/chat/shared/   (anon<->mac kernel)

Top-level app/agents/shared/ is gone: nothing was shared across all three
categories (chat / podcaster / video_presentation).

~289 import sites rewritten (app.agents.{anonymous_chat,multi_agent_chat,shared}
-> app.agents.chat.*); all moves are git renames (history preserved).
app/agents/ now: chat/, podcaster/, video_presentation/, runtime/.
This commit is contained in:
CREDO23 2026-06-05 12:54:02 +02:00
parent d59bb2b5aa
commit 24b62a63b4
570 changed files with 712 additions and 613 deletions

View file

@ -0,0 +1,9 @@
"""Cross-package agent contracts.
Symbols here are intentionally framework-light (no LangGraph / deepagents
internals) so they can be imported from both ``app.agents.new_chat`` and
``app.agents.chat.multi_agent_chat`` without creating a circular dependency
between the two packages. See ``receipt.py`` for the rationale.
"""
from __future__ import annotations

View file

@ -0,0 +1,71 @@
"""
Context schema definitions for SurfSense agents.
This module defines the per-invocation context object passed to the SurfSense
deep agent via ``agent.astream_events(..., context=ctx)`` (LangGraph >= 0.6).
The agent's compiled graph is the same across invocations (and cached by
``agent_cache``), so anything that varies per turn the user mentions a
specific document, the front-end issues a unique ``request_id``, etc.
MUST live on this context object instead of being captured into a
middleware ``__init__`` closure. Middlewares read fields back via
``runtime.context.<field>``; tools read them via ``runtime.context``.
This object is read inside both ``KnowledgePriorityMiddleware`` (for
``mentioned_document_ids``) and any future middleware that needs
per-request state without invalidating the compiled-agent cache.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TypedDict
class FileOperationContractState(TypedDict):
intent: str
confidence: float
suggested_path: str
timestamp: str
turn_id: str
@dataclass
class SurfSenseContextSchema:
"""
Per-invocation context for the SurfSense deep agent.
Defaults are chosen so the dataclass can be safely default-constructed
(LangGraph's ``Runtime.context`` itself defaults to ``None`` if no
context is supplied see ``langgraph.runtime.Runtime``). All fields
are optional; consumers must None-check before reading.
Phase 1.5 fields:
search_space_id: Search space the request is scoped to.
mentioned_document_ids: KB documents the user @-mentioned this turn.
Read by ``KnowledgePriorityMiddleware`` to seed its priority
list. Stays out of the compiled-agent cache key that's the
whole point of putting it here.
mentioned_folder_ids: KB folders the user @-mentioned this turn
(cloud filesystem mode). Surfaced as ``[USER-MENTIONED]``
entries in ``<priority_documents>`` so the agent prioritises
walking those folders with ``ls`` / ``find_documents``.
file_operation_contract: One-shot file operation contract for the
upcoming turn (reserved; not currently populated).
turn_id / request_id: Correlation IDs surfaced by the streaming
task; populated for telemetry.
Phase 2 will extend with: thread_id, user_id, visibility,
filesystem_mode, anon_session_id, available_connectors,
available_document_types, created_by_id (everything currently captured
by middleware ``__init__`` closures).
"""
search_space_id: int | None = None
mentioned_document_ids: list[int] = field(default_factory=list)
mentioned_folder_ids: list[int] = field(default_factory=list)
mentioned_connector_ids: list[int] = field(default_factory=list)
mentioned_connectors: list[dict[str, object]] = field(default_factory=list)
file_operation_contract: FileOperationContractState | None = None
turn_id: str | None = None
request_id: str | None = None

View file

@ -0,0 +1,13 @@
"""Shared middleware components for the SurfSense chat agents."""
from app.agents.chat.shared.middleware.compaction import (
SurfSenseCompactionMiddleware,
create_surfsense_compaction_middleware,
)
from app.agents.chat.shared.middleware.retry_after import RetryAfterMiddleware
__all__ = [
"RetryAfterMiddleware",
"SurfSenseCompactionMiddleware",
"create_surfsense_compaction_middleware",
]

View file

@ -0,0 +1,255 @@
"""
SurfSense compaction middleware.
Subclasses :class:`deepagents.middleware.summarization.SummarizationMiddleware`
to add SurfSense-specific behavior:
1. **Structured summary template** (OpenCode-style ``## Goal / Constraints /
Progress / Key Decisions / Next Steps / Critical Context / Relevant Files``)
see :data:`SURFSENSE_SUMMARY_PROMPT` below. The base
``SummarizationMiddleware`` only ships a freeform "summarize this"
prompt; the structured template is ported from OpenCode's
``compaction.ts``.
2. **Protect SurfSense-specific SystemMessages** so injected hints
(``<priority_documents>``, ``<workspace_tree>``, ``<file_operation_contract>``,
``<user_memory>``, ``<team_memory>``, ``<user_name>``, ``<memory_warning>``)
are *not* summarized away and are kept verbatim in the post-summary
message list. Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy
(some message types are part of the agent's contract and must survive
compaction unchanged).
3. **Sanitize ``content=None``** when feeding messages into ``get_buffer_string``
(Azure OpenAI / LiteLLM defense when a provider streams an AIMessage
containing only tool_calls and no text, ``content`` can be ``None`` and
``get_buffer_string`` crashes iterating over ``None``). SurfSense-specific.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from deepagents.middleware.summarization import (
SummarizationMiddleware,
compute_summarization_defaults,
)
from langchain_core.messages import SystemMessage
from app.observability import metrics as ot_metrics, otel as ot
if TYPE_CHECKING:
from deepagents.backends.protocol import BACKEND_TYPES
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AnyMessage
logger = logging.getLogger(__name__)
# Structured summary template ported from OpenCode's
# ``opencode/packages/opencode/src/session/compaction.ts:40-75``. Kept as a
# module-level constant so unit tests can assert on its sections.
SURFSENSE_SUMMARY_PROMPT = """<role>
SurfSense Conversation Compaction Assistant
</role>
<primary_objective>
Extract the most important context from the conversation history below into a structured summary that will replace the older messages.
</primary_objective>
<instructions>
You are running because the conversation has grown beyond the model's input window. The conversation history below will be summarized and replaced with your output. Use the structured template that follows; keep each section concise but comprehensive enough that the agent can resume work without losing context. Each section is a checklist — populate it with relevant content or write "None" if there is nothing to report.
## Goal
What is the user's primary goal or request? State it in one or two sentences.
## Constraints
What boundaries must the agent respect (citations rules, visibility scope, allowed tools, user-imposed style, deadlines, deny-listed topics)?
## Progress
What has the agent already accomplished? List each completed step succinctly. Do not reproduce tool output; just record the conclusion.
## Key Decisions
What choices were made and why? Include rejected alternatives and the reasoning behind selecting the current path.
## Next Steps
What specific tasks remain to achieve the goal? Order them by dependency.
## Critical Context
What facts, IDs, document titles, query keywords, error messages, or partial answers must persist into the next turn? Include verbatim quotes only when the exact wording matters (e.g. a precise filter clause or a literal name).
## Relevant Files
What documents or paths in the SurfSense knowledge base are in play? Use ``/documents/...`` paths exactly as they appeared in the workspace tree.
</instructions>
<messages>
Messages to summarize:
{messages}
</messages>
Respond ONLY with the structured summary. Do not include any text before or after.
"""
# SystemMessage prefixes that must NOT be summarized away. They are
# re-injected on every turn by the corresponding middleware, but the
# compaction step happens *before* re-injection in some paths, so we
# must preserve them verbatim across the cutoff.
PROTECTED_SYSTEM_PREFIXES: tuple[str, ...] = (
"<priority_documents>", # KnowledgePriorityMiddleware
"<workspace_tree>", # KnowledgeTreeMiddleware
"<file_operation_contract>", # reserved file-operation contract prefix
"<user_memory>", # MemoryInjectionMiddleware
"<team_memory>", # MemoryInjectionMiddleware
"<user_name>", # MemoryInjectionMiddleware
"<memory_warning>", # MemoryInjectionMiddleware
)
def _is_protected_system_message(msg: AnyMessage) -> bool:
"""Return True if ``msg`` is a SystemMessage we must not summarize."""
if not isinstance(msg, SystemMessage):
return False
content = msg.content
if not isinstance(content, str):
return False
stripped = content.lstrip()
return any(stripped.startswith(prefix) for prefix in PROTECTED_SYSTEM_PREFIXES)
def _sanitize_message_content(msg: AnyMessage) -> AnyMessage:
"""Return ``msg`` with ``content=None`` coerced to ``""``.
Folds in the historical defense from ``safe_summarization.py``
``get_buffer_string`` reads ``m.text`` which iterates ``self.content``,
so a ``None`` content (Azure OpenAI / LiteLLM streaming a tool-only
AIMessage) explodes. We return a copy with empty string content so
downstream consumers see an empty body without mutating the original.
"""
if getattr(msg, "content", "not-missing") is not None:
return msg
try:
return msg.model_copy(update={"content": ""})
except AttributeError:
import copy
new_msg = copy.copy(msg)
try:
new_msg.content = ""
except Exception:
logger.debug(
"Could not sanitize content=None on message of type %s",
type(msg).__name__,
)
return msg
return new_msg
class SurfSenseCompactionMiddleware(SummarizationMiddleware):
"""SummarizationMiddleware tuned for SurfSense.
Notes
-----
- Overrides :meth:`_partition_messages` so protected SystemMessages
survive into the ``preserved_messages`` half regardless of cutoff.
- Overrides :meth:`_filter_summary_messages` so the buffer-string path
never iterates ``None`` content.
- Inherits everything else (auto-trigger, backend offload,
``_summarization_event`` plumbing, ``ContextOverflowError`` fallback).
"""
def _partition_messages( # type: ignore[override]
self,
conversation_messages: list[AnyMessage],
cutoff_index: int,
) -> tuple[list[AnyMessage], list[AnyMessage]]:
"""Split messages but always preserve SurfSense protected SystemMessages.
Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy
(``opencode/packages/opencode/src/session/compaction.ts``): some
message types are always kept verbatim because they are part of the
agent's working contract, not transient output.
Also opens a ``compaction.run`` OTel span (no-op when OTel is off)
so dashboards can count compaction events and message-volume
without having to instrument upstream callers.
"""
# Opening a span here is appropriate because partitioning is the
# first call SummarizationMiddleware makes when it has decided to
# summarize; we record the volume and then close as a normal span.
with ot.compaction_span(
reason="auto",
messages_in=len(conversation_messages),
extra={"compaction.cutoff_index": int(cutoff_index)},
):
ot_metrics.record_compaction_run(reason="auto")
messages_to_summarize, preserved_messages = super()._partition_messages(
conversation_messages, cutoff_index
)
protected: list[AnyMessage] = []
kept_for_summary: list[AnyMessage] = []
for msg in messages_to_summarize:
if _is_protected_system_message(msg):
protected.append(msg)
else:
kept_for_summary.append(msg)
# Place protected blocks at the *front* of preserved_messages so
# they keep their original ordering relative to the summary
# HumanMessage that precedes the rest of the preserved tail.
return kept_for_summary, [*protected, *preserved_messages]
def _filter_summary_messages( # type: ignore[override]
self, messages: list[AnyMessage]
) -> list[AnyMessage]:
"""Filter previous summaries AND sanitize ``content=None``.
Folds the ``safe_summarization.py`` defense in: when the buffer
builder iterates ``m.text`` over ``None`` it explodes; sanitizing
here covers both the sync and async offload paths.
"""
filtered = super()._filter_summary_messages(messages)
return [_sanitize_message_content(m) for m in filtered]
def create_surfsense_compaction_middleware(
model: BaseChatModel,
backend: BACKEND_TYPES,
*,
summary_prompt: str | None = None,
history_path_prefix: str = "/conversation_history",
**overrides: Any,
) -> SurfSenseCompactionMiddleware:
"""Build a :class:`SurfSenseCompactionMiddleware` with sensible defaults.
Pulls profile-aware ``trigger`` / ``keep`` / ``truncate_args_settings``
via :func:`deepagents.middleware.summarization.compute_summarization_defaults`
so callers get the same behavior as ``create_summarization_middleware``
plus our overrides.
Args:
model: Chat model to call for summary generation.
backend: Backend instance or factory for offloading conversation history.
summary_prompt: Optional override; defaults to :data:`SURFSENSE_SUMMARY_PROMPT`.
history_path_prefix: Path prefix for offloaded conversation history.
**overrides: Forwarded to :class:`SurfSenseCompactionMiddleware`.
"""
defaults = compute_summarization_defaults(model)
return SurfSenseCompactionMiddleware(
model=model,
backend=backend,
trigger=overrides.pop("trigger", defaults["trigger"]),
keep=overrides.pop("keep", defaults["keep"]),
trim_tokens_to_summarize=overrides.pop("trim_tokens_to_summarize", None),
truncate_args_settings=overrides.pop(
"truncate_args_settings", defaults["truncate_args_settings"]
),
summary_prompt=summary_prompt or SURFSENSE_SUMMARY_PROMPT,
history_path_prefix=history_path_prefix,
**overrides,
)
__all__ = [
"PROTECTED_SYSTEM_PREFIXES",
"SURFSENSE_SUMMARY_PROMPT",
"SurfSenseCompactionMiddleware",
"create_surfsense_compaction_middleware",
]

View file

@ -0,0 +1,277 @@
"""
RetryAfterMiddleware Header-aware retry with custom backoff and SSE eventing.
LangChain's :class:`ModelRetryMiddleware` retries on exceptions but ignores
the ``Retry-After`` HTTP header it just runs its own exponential backoff.
That wastes time when a provider has explicitly told us how long to wait.
This middleware honors the header (mirroring OpenCode's
``packages/opencode/src/session/llm.ts`` retry pathway) and emits an SSE
event so the UI can show "rate-limited, retrying in Ns".
We can't subclass ``ModelRetryMiddleware`` cleanly because its loop calls a
module-level ``calculate_delay`` inline (no overridable
``_calculate_delay`` hook), so this is a standalone implementation.
Behaviour:
- Extracts ``Retry-After`` / ``retry-after-ms`` from
``litellm.exceptions.RateLimitError.response.headers`` (or any exception
exposing a similar shape).
- Sleeps ``max(exponential_backoff, header_delay)`` between retries.
- Returns ``False`` from ``retry_on`` for ``ContextWindowExceededError`` /
``ContextOverflowError`` so :class:`SurfSenseCompactionMiddleware` (or
the LangChain summarization fallback path) handles those instead.
- Emits ``surfsense.retrying`` via ``adispatch_custom_event`` on each retry
so ``stream_new_chat`` can forward it to clients as an SSE event.
"""
from __future__ import annotations
import asyncio
import logging
import random
import re
import time
from collections.abc import Awaitable, Callable
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ModelRequest,
ModelResponse,
ResponseT,
)
from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event
from langchain_core.messages import AIMessage
from app.observability import metrics as ot_metrics, otel as ot
logger = logging.getLogger(__name__)
# Names of exception classes for which a retry would not help — context
# overflow needs compaction, auth needs human intervention, etc. Detected
# by class-name substring so we don't have to import LiteLLM/Anthropic
# here (which would tie this module to optional deps).
_NON_RETRYABLE_NAME_HINTS: tuple[str, ...] = (
"ContextWindowExceeded",
"ContextOverflow",
"AuthenticationError",
"InvalidRequestError",
"PermissionDenied",
"InvalidApiKey",
"ContextLimit",
)
def _is_non_retryable(exc: BaseException) -> bool:
name = type(exc).__name__
return any(hint in name for hint in _NON_RETRYABLE_NAME_HINTS)
def _extract_retry_after_seconds(exc: BaseException) -> float | None:
"""Return seconds-to-wait suggested by the provider, if any.
Looks at ``exc.response.headers`` or ``exc.headers`` for the standard
HTTP ``Retry-After`` header (in seconds) or its millisecond cousin
``retry-after-ms`` (sometimes used by Anthropic / OpenAI). Falls back
to a regex on the exception message for shapes like
``"Please retry after 30s"``.
"""
headers: dict[str, Any] | None = None
response = getattr(exc, "response", None)
if response is not None:
headers = getattr(response, "headers", None)
if headers is None:
headers = getattr(exc, "headers", None)
if isinstance(headers, dict):
# Normalize keys to lowercase for case-insensitive matching
norm = {str(k).lower(): v for k, v in headers.items()}
ms = norm.get("retry-after-ms")
if ms is not None:
try:
return float(ms) / 1000.0
except (TypeError, ValueError):
pass
seconds = norm.get("retry-after")
if seconds is not None:
try:
return float(seconds)
except (TypeError, ValueError):
pass
# Last resort: scan the message for "retry after Xs" or "X seconds"
msg = str(exc)
match = re.search(r"retry\s+after\s+([0-9]+(?:\.[0-9]+)?)", msg, re.IGNORECASE)
if match:
try:
return float(match.group(1))
except ValueError:
return None
return None
def _exponential_delay(
attempt: int,
*,
initial_delay: float,
backoff_factor: float,
max_delay: float,
jitter: bool,
) -> float:
"""Compute an exponential-backoff delay with optional ±25% jitter."""
delay = (
initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay
)
delay = min(delay, max_delay)
if jitter and delay > 0:
delay *= 1 + random.uniform(-0.25, 0.25)
return max(delay, 0.0)
class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Retry middleware that honors provider-issued Retry-After hints.
Drop-in replacement for :class:`langchain.agents.middleware.ModelRetryMiddleware`
when working with LiteLLM/Anthropic/OpenAI providers that surface
rate-limit hints in headers. Always emits ``surfsense.retrying`` SSE
events so the UI can show a friendly "rate limited, retrying in Xs"
indicator.
Args:
max_retries: Maximum retries after the initial attempt (default 3).
initial_delay: Initial backoff delay in seconds.
backoff_factor: Exponential growth factor for backoff.
max_delay: Cap on per-attempt delay in seconds.
jitter: Whether to add ±25% jitter.
retry_on: Optional callable that returns True for retryable
exceptions. The default retries everything except known
non-retryable classes (context overflow, auth, etc.).
"""
def __init__(
self,
*,
max_retries: int = 3,
initial_delay: float = 1.0,
backoff_factor: float = 2.0,
max_delay: float = 60.0,
jitter: bool = True,
retry_on: Callable[[BaseException], bool] | None = None,
) -> None:
super().__init__()
self.max_retries = max_retries
self.initial_delay = initial_delay
self.backoff_factor = backoff_factor
self.max_delay = max_delay
self.jitter = jitter
self._retry_on: Callable[[BaseException], bool] = retry_on or (
lambda exc: not _is_non_retryable(exc)
)
def _should_retry(self, exc: BaseException) -> bool:
try:
return bool(self._retry_on(exc))
except Exception:
logger.exception("retry_on callable raised; defaulting to False")
return False
def _delay_for_attempt(self, attempt: int, exc: BaseException) -> float:
backoff = _exponential_delay(
attempt,
initial_delay=self.initial_delay,
backoff_factor=self.backoff_factor,
max_delay=self.max_delay,
jitter=self.jitter,
)
header = _extract_retry_after_seconds(exc) or 0.0
return max(backoff, header)
def wrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> ModelResponse[ResponseT] | AIMessage:
for attempt in range(self.max_retries + 1):
try:
return handler(request)
except Exception as exc:
if not self._should_retry(exc) or attempt >= self.max_retries:
raise
delay = self._delay_for_attempt(attempt, exc)
ot.add_event(
"model.retry.scheduled",
{
"retry.attempt": attempt + 1,
"retry.max": self.max_retries,
"retry.delay_ms": int(delay * 1000),
"retry.reason": ot_metrics.categorize_exception(exc),
},
)
try:
dispatch_custom_event(
"surfsense.retrying",
{
"attempt": attempt + 1,
"max_retries": self.max_retries,
"delay_ms": int(delay * 1000),
"reason": type(exc).__name__,
},
)
except Exception:
logger.debug(
"dispatch_custom_event failed; suppressed", exc_info=True
)
if delay > 0:
time.sleep(delay)
# Unreachable
raise RuntimeError("RetryAfterMiddleware: retry loop exited without resolution")
async def awrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
],
) -> ModelResponse[ResponseT] | AIMessage:
for attempt in range(self.max_retries + 1):
try:
return await handler(request)
except Exception as exc:
if not self._should_retry(exc) or attempt >= self.max_retries:
raise
delay = self._delay_for_attempt(attempt, exc)
ot.add_event(
"model.retry.scheduled",
{
"retry.attempt": attempt + 1,
"retry.max": self.max_retries,
"retry.delay_ms": int(delay * 1000),
"retry.reason": ot_metrics.categorize_exception(exc),
},
)
try:
await adispatch_custom_event(
"surfsense.retrying",
{
"attempt": attempt + 1,
"max_retries": self.max_retries,
"delay_ms": int(delay * 1000),
"reason": type(exc).__name__,
},
)
except Exception:
logger.debug(
"adispatch_custom_event failed; suppressed", exc_info=True
)
if delay > 0:
await asyncio.sleep(delay)
raise RuntimeError("RetryAfterMiddleware: retry loop exited without resolution")
__all__ = [
"RetryAfterMiddleware",
"_extract_retry_after_seconds",
"_is_non_retryable",
]

View file

@ -0,0 +1,5 @@
"""Cross-agent shared tools.
Only genuinely cross-agent tool code lives here (currently web_search, imported
directly from its module).
"""

View file

@ -0,0 +1,247 @@
"""
Web search tool for the SurfSense agent.
Provides a unified tool for real-time web searches that dispatches to all
configured search engines: the platform SearXNG instance (always available)
plus any user-configured live-search connectors (Tavily, Linkup, Baidu).
"""
import asyncio
import json
import time
from typing import Any
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from app.db import shielded_async_session
from app.services.connector_service import ConnectorService
from app.utils.perf import get_perf_logger
_LIVE_SEARCH_CONNECTORS: set[str] = {
"TAVILY_API",
"LINKUP_API",
"BAIDU_SEARCH_API",
}
_LIVE_CONNECTOR_SPECS: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
"TAVILY_API": ("search_tavily", False, True, {}),
"LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}),
"BAIDU_SEARCH_API": ("search_baidu", False, True, {}),
}
_CONNECTOR_LABELS: dict[str, str] = {
"TAVILY_API": "Tavily",
"LINKUP_API": "Linkup",
"BAIDU_SEARCH_API": "Baidu",
}
class WebSearchInput(BaseModel):
"""Input schema for the web_search tool."""
query: str = Field(
description="The search query to look up on the web. Use specific, descriptive terms.",
)
top_k: int = Field(
default=10,
description="Number of results to retrieve (default: 10, max: 50).",
)
def _format_web_results(
documents: list[dict[str, Any]],
*,
max_chars: int = 50_000,
) -> str:
"""Format web search results into XML suitable for the LLM context."""
if not documents:
return "No web search results found."
parts: list[str] = []
total_chars = 0
for doc in documents:
doc_info = doc.get("document") or {}
metadata = doc_info.get("metadata") or {}
title = doc_info.get("title") or "Web Result"
url = metadata.get("url") or ""
content = (doc.get("content") or "").strip()
source = metadata.get("document_type") or doc.get("source") or "WEB_SEARCH"
if not content:
continue
metadata_json = json.dumps(metadata, ensure_ascii=False)
doc_xml = "\n".join(
[
"<document>",
"<document_metadata>",
f" <document_type>{source}</document_type>",
f" <title><![CDATA[{title}]]></title>",
f" <url><![CDATA[{url}]]></url>",
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
"</document_metadata>",
"<document_content>",
f" <chunk id='{url}'><![CDATA[{content}]]></chunk>",
"</document_content>",
"</document>",
"",
]
)
if total_chars + len(doc_xml) > max_chars:
parts.append("<!-- Output truncated to fit context window -->")
break
parts.append(doc_xml)
total_chars += len(doc_xml)
return "\n".join(parts).strip() or "No web search results found."
async def _search_live_connector(
connector: str,
query: str,
search_space_id: int,
top_k: int,
semaphore: asyncio.Semaphore,
) -> list[dict[str, Any]]:
"""Dispatch a single live-search connector (Tavily / Linkup / Baidu)."""
perf = get_perf_logger()
spec = _LIVE_CONNECTOR_SPECS.get(connector)
if spec is None:
return []
method_name, _includes_date_range, includes_top_k, extra_kwargs = spec
kwargs: dict[str, Any] = {
"user_query": query,
"search_space_id": search_space_id,
**extra_kwargs,
}
if includes_top_k:
kwargs["top_k"] = top_k
try:
t0 = time.perf_counter()
async with semaphore, shielded_async_session() as session:
svc = ConnectorService(session, search_space_id)
_, chunks = await getattr(svc, method_name)(**kwargs)
perf.info(
"[web_search] connector=%s results=%d in %.3fs",
connector,
len(chunks),
time.perf_counter() - t0,
)
return chunks
except Exception as e:
perf.warning("[web_search] connector=%s FAILED: %s", connector, e)
return []
def create_web_search_tool(
search_space_id: int | None = None,
available_connectors: list[str] | None = None,
) -> StructuredTool:
"""Factory for the ``web_search`` tool.
Dispatches in parallel to the platform SearXNG instance and any
user-configured live-search connectors (Tavily, Linkup, Baidu).
"""
active_live_connectors: list[str] = []
if available_connectors:
active_live_connectors = [
c for c in available_connectors if c in _LIVE_SEARCH_CONNECTORS
]
engine_names = ["SearXNG (platform default)"]
engine_names.extend(_CONNECTOR_LABELS.get(c, c) for c in active_live_connectors)
engines_summary = ", ".join(engine_names)
description = (
"Search the web for real-time information. "
"Use this for current events, news, prices, weather, public facts, or any "
"question that requires up-to-date information from the internet.\n\n"
f"Active search engines: {engines_summary}.\n"
"All configured engines are queried in parallel and results are merged."
)
_search_space_id = search_space_id
_active_live = active_live_connectors
async def _web_search_impl(query: str, top_k: int = 10) -> str:
from app.services import web_search_service
perf = get_perf_logger()
t0 = time.perf_counter()
clamped_top_k = min(max(1, top_k), 50)
semaphore = asyncio.Semaphore(4)
tasks: list[asyncio.Task[list[dict[str, Any]]]] = []
if web_search_service.is_available():
async def _searxng() -> list[dict[str, Any]]:
async with semaphore:
_result_obj, docs = await web_search_service.search(
query=query,
top_k=clamped_top_k,
)
return docs
tasks.append(asyncio.ensure_future(_searxng()))
if _search_space_id is not None:
for connector in _active_live:
tasks.append(
asyncio.ensure_future(
_search_live_connector(
connector=connector,
query=query,
search_space_id=_search_space_id,
top_k=clamped_top_k,
semaphore=semaphore,
)
)
)
if not tasks:
return "Web search is not available — no search engines are configured."
results_lists = await asyncio.gather(*tasks, return_exceptions=True)
all_documents: list[dict[str, Any]] = []
for result in results_lists:
if isinstance(result, BaseException):
perf.warning("[web_search] a search engine failed: %s", result)
continue
all_documents.extend(result)
seen_urls: set[str] = set()
deduplicated: list[dict[str, Any]] = []
for doc in all_documents:
url = ((doc.get("document") or {}).get("metadata") or {}).get("url", "")
if url and url in seen_urls:
continue
if url:
seen_urls.add(url)
deduplicated.append(doc)
formatted = _format_web_results(deduplicated)
perf.info(
"[web_search] query=%r engines=%d results=%d deduped=%d chars=%d in %.3fs",
query[:60],
len(tasks),
len(all_documents),
len(deduplicated),
len(formatted),
time.perf_counter() - t0,
)
return formatted
return StructuredTool(
name="web_search",
description=description,
coroutine=_web_search_impl,
args_schema=WebSearchInput,
)