From a3d05f6418c268bece1591eb2c8171f45e3fa70c Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 5 Jun 2026 17:39:38 +0200 Subject: [PATCH] docs(agents): tighten docstrings and comments across agent module Recursive pass over the agents module to make docstrings and inline comments concise and intent-oriented: drop narration that just restates the code, condense verbose module/function docstrings, and keep only the non-obvious "why" notes. No functional code changed. --- .../middleware/action_log/middleware.py | 34 +-- .../middleware/busy_mutex/middleware.py | 47 +--- .../task_tool.py | 142 ++++------- .../middleware/kb_persistence/middleware.py | 225 +++++------------ .../main_agent/tools/scrape_webpage.py | 25 +- .../multi_agent_chat/shared/feature_flags.py | 77 +----- .../shared/middleware/knowledge_search.py | 82 ++----- .../multi_agent_chat/shared/tools/hitl.py | 25 +- .../deliverables/tools/generate_image.py | 23 +- .../deliverables/tools/knowledge_base.py | 93 ++----- .../builtins/deliverables/tools/report.py | 92 ++----- .../builtins/research/tools/scrape_webpage.py | 25 +- .../app/agents/chat/runtime/llm_config.py | 231 ++++-------------- .../app/agents/chat/runtime/prompt_caching.py | 145 +++-------- .../chat/shared/middleware/compaction.py | 70 ++---- .../app/agents/podcaster/nodes.py | 38 +-- 16 files changed, 319 insertions(+), 1055 deletions(-) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/action_log/middleware.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/action_log/middleware.py index c383ae12f..789705d0e 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/action_log/middleware.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/action_log/middleware.py @@ -1,25 +1,15 @@ """Append-only action-log middleware for the SurfSense agent. 
-Wraps every tool call via :meth:`AgentMiddleware.awrap_tool_call` and writes -a row to :class:`~app.db.AgentActionLog` after the tool returns. Tools opt -into reversibility by declaring a ``reverse`` callable on their -:class:`ToolDefinition`; the rendered descriptor is persisted in -``reverse_descriptor`` for use by +Wraps every tool call and writes a row to :class:`~app.db.AgentActionLog` +after the tool returns. Tools opt into reversibility via a ``reverse`` +callable on their :class:`ToolDefinition`; the rendered descriptor powers ``/api/threads/{thread_id}/revert/{action_id}``. -Design points: - -* **Defensive.** Logging never blocks the agent. We catch every exception - on the DB write path and emit a warning; the tool's ``ToolMessage`` - result is always returned untouched. -* **Lightweight payload.** Only the tool ``name`` + ``args`` (capped) + - ``result_id`` + ``reverse_descriptor`` are stored. Tool output text - remains in the LangGraph checkpoint / spilled tool-output files. -* **Best-effort reversibility.** We invoke ``reverse(args, result_obj)`` - with the parsed JSON result when the tool's content is a JSON object; - otherwise the raw text is passed. Exceptions in the reverse callable - are swallowed and logged — a failed descriptor render simply means the - action is NOT marked reversible. +Logging is fully defensive — DB-write failures are swallowed so the tool's +result is always returned untouched. Only metadata (name, capped args, +result_id, reverse_descriptor) is stored; tool output stays in the +checkpoint. Reversibility is best-effort: a reverse callable that raises +just leaves the action non-reversible. """ from __future__ import annotations @@ -203,11 +193,9 @@ class ActionLogMiddleware(AgentMiddleware): ) return - # Surface a side-channel SSE event so the chat tool card can - # render a Revert button immediately after the row is durable. - # ``stream_new_chat`` translates this into a - # ``data-action-log`` SSE event. 
We DO NOT include the - # ``reverse_descriptor`` payload here; only a presence flag. + # Side-channel event (relayed by ``stream_new_chat`` as a + # ``data-action-log`` SSE) so the tool card can show a Revert button + # once the row is durable. Carries a presence flag, not the descriptor. try: await adispatch_custom_event( "action_log", diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/busy_mutex/middleware.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/busy_mutex/middleware.py index f90e2d179..7a82196d9 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/busy_mutex/middleware.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/busy_mutex/middleware.py @@ -1,32 +1,12 @@ -""" -BusyMutexMiddleware — per-thread asyncio lock + cancel token. +"""Per-thread asyncio lock + cooperative cancel token, keyed by ``thread_id``. -LangChain has no built-in concept of "this thread is already running a -turn — refuse the second concurrent request". Without it, a user -double-clicking "send" or refreshing the page mid-stream can spawn two -turns racing on the same checkpoint, producing duplicated tool calls -and mangled state. +Refuses a second concurrent turn on the same thread (e.g. double-clicked +"send") that would otherwise race on the same checkpoint and duplicate tool +calls. Also exposes a per-thread cancel event that long-running tools poll +via ``runtime.context.cancel_event.is_set()`` to abort cooperatively. -Ported from OpenCode's ``Stream.scoped(AbortController)`` pattern: a -single-process, in-memory lock + cooperative cancellation token keyed by -``thread_id``. For multi-worker deployments a distributed lock backend -(Redis or PostgreSQL advisory locks) is a phase-2 follow-up. 
- -What this provides: -- A ``WeakValueDictionary[str, asyncio.Lock]`` keyed by ``thread_id``; - acquiring the lock during ``before_agent`` blocks any concurrent - prompt on the same thread until release. -- A per-thread ``asyncio.Event`` (``cancel_event``) that long-running - tools can poll to abort cooperatively. The event is reset between - turns. Tools should check ``runtime.context.cancel_event.is_set()`` - in tight inner loops. -- A typed :class:`~app.agents.chat.runtime.errors.BusyError` raised when a - second turn arrives while the lock is held. - -Note: SurfSense's ``stream_new_chat`` is the call site that should -acquire/release. Wiring this as middleware means the contract is -explicit and the lock manager is shared with subagents that compile -their own ``create_agent`` runnables. +Process-local and in-memory; multi-worker deployments need a distributed lock +(Redis / PostgreSQL advisory locks) as a follow-up. """ from __future__ import annotations @@ -152,9 +132,8 @@ class _ThreadLockManager: return True -# Module-level singleton — process-local but reused across all agent -# instances built in this process. Subagents created in nested -# ``create_agent`` calls also get this so locks are coherent. +# Process-local singleton shared across all agents/subagents built in this +# process so per-thread locks stay coherent. manager = _ThreadLockManager() @@ -266,7 +245,6 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo await lock.acquire() epoch = manager.bump_turn_epoch(thread_id) self._held_locks[thread_id] = (lock, epoch) - # Reset the cancel event so this turn starts fresh reset_cancel(thread_id) return None @@ -289,17 +267,14 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo return None if lock.locked(): lock.release() - # Always clear cancel event between turns so a stale signal - # doesn't leak into the next request. 
+ # Clear cancel event so a stale signal doesn't leak into the next turn. reset_cancel(thread_id) return None - # Provide sync no-ops because the middleware base class allows them def before_agent( # type: ignore[override] self, state: AgentState[Any], runtime: Runtime[ContextT] ) -> dict[str, Any] | None: - # Sync path: no asyncio.Lock to acquire. Best we can do is reject - # if anyone else is in flight. + # Sync path can't await an asyncio.Lock; only reject if one is in flight. thread_id = self._thread_id(runtime) if thread_id is None: if self._require_thread_id: diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/task_tool.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/task_tool.py index fd303a60e..ab825501a 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/task_tool.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/checkpointed_subagent_middleware/task_tool.py @@ -82,13 +82,10 @@ _T = TypeVar("_T") async def _ainvoke_with_timeout[T]( coro: Awaitable[_T], *, subagent_type: str, started_at: float ) -> _T: - """Apply :data:`DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS` to ``coro``. + """Apply the subagent invoke timeout to ``coro`` (non-positive disables it). - A non-positive timeout disables the cap (configurable via the - ``SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS`` env var). On expiry the - underlying task is cancelled and :class:`SubagentInvokeTimeoutError` is - raised — the caller wraps it into a synthetic ToolMessage so the - orchestrator can decide what to do. + On expiry the task is cancelled and :class:`SubagentInvokeTimeoutError` is + raised for the caller to turn into a synthetic ToolMessage. 
""" timeout = DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS if timeout <= 0: @@ -151,12 +148,9 @@ def build_task_tool_with_parent_config( subagent_graphs: dict[str, Runnable] = { spec["name"]: spec["runnable"] for spec in subagents } - # Per-subagent context-hint providers (see ``SurfSenseSubagentSpec``). - # The mapping is sparse: only routes that opted in via ``pack_subagent`` - # appear here, and the value is invoked once per ``task(...)`` call to - # generate a short string prepended to the subagent's first - # ``HumanMessage``. Failures are logged and swallowed — a broken hint - # provider must never prevent the underlying task from running. + # Sparse map of opt-in context-hint providers; each runs once per task() + # call to prepend a string to the subagent's first HumanMessage. Failures + # are swallowed so a broken hint never blocks the task. subagent_hint_providers: dict[str, ContextHintProvider] = { spec["name"]: provider for spec in subagents @@ -178,24 +172,18 @@ def build_task_tool_with_parent_config( def _billable_call_update( subagent_type: str, runtime: ToolRuntime ) -> dict[str, Any]: - """Build the per-call ``billable_calls`` delta + an optional warning. + """Build the per-call ``billable_calls`` delta plus an optional soft-cap warning. - The orchestrator's ``billable_calls`` map is summed by - :func:`_int_counter_merge_reducer`, so we always emit - ``{subagent_type: 1}`` and let the reducer accumulate. If the - cumulative count *after* this call would cross the configured - threshold, we also slip a soft ``messages`` entry into the update - so the orchestrator can read it on its next step and self-limit. - Returning a plain ``dict`` (vs. an extra :class:`Command`) keeps - the helper composable with the existing single/batch return paths. + Always emits ``{subagent_type: 1}`` (a reducer accumulates it); when this + call would cross the threshold, also adds a soft ``messages`` entry so the + orchestrator self-limits on its next step. 
""" delta: dict[str, Any] = {"billable_calls": {subagent_type: 1}} threshold = DEFAULT_SUBAGENT_BILLABLE_THRESHOLD if threshold <= 0: return delta prior = runtime.state.get("billable_calls") or {} - # ``prior`` may be a plain dict or a reducer-managed mapping; only - # int values are counted so a malformed checkpoint can't crash us. + # Count int values only so a malformed checkpoint can't crash us. prior_total = sum(v for v in prior.values() if isinstance(v, int)) new_total = prior_total + 1 if prior_total < threshold <= new_total: @@ -214,8 +202,7 @@ def build_task_tool_with_parent_config( """Merge the per-call billable counter (and warning) into ``cmd``.""" delta = _billable_call_update(subagent_type, runtime) warn_text = delta.pop("_billable_warn_text", None) - # ``cmd.update`` may be a dict or LangGraph ``UpdateDict``; defensively - # copy so we don't mutate state shared across other tool returns. + # Copy so we don't mutate state shared with other tool returns. update = dict(getattr(cmd, "update", {}) or {}) for key, value in delta.items(): update[key] = value @@ -228,14 +215,10 @@ def build_task_tool_with_parent_config( return Command(update=update) def _safe_message_text(msg: Any) -> str: - """Pull text out of a BaseMessage without trusting the ``.text`` property. + """Pull text out of a BaseMessage without using the ``.text`` property. - ``BaseMessage.text`` walks ``content_blocks`` and crashes with - ``TypeError: 'NoneType' object is not iterable`` when ``content`` is - ``None`` (common for tool-call AIMessages whose payload is purely - structured). ``getattr(msg, "text", None)`` does not catch this - because Python evaluates the property body before falling back to - the default. Read ``content`` directly and coerce defensively. + ``.text`` crashes when ``content`` is ``None`` (common for tool-call + AIMessages), and ``getattr`` won't catch it, so read ``content`` directly. 
""" try: content = getattr(msg, "content", None) @@ -258,23 +241,18 @@ def build_task_tool_with_parent_config( return str(content) def _build_tool_trace(messages: list[Any]) -> list[dict[str, Any]]: - """Compress the subagent's message stream into a compact tool trace. + """Compress the subagent's messages into a compact tool trace. - Each entry is ``{"tool": , "status": "ok"|"error", "preview": - <≤120 chars>}`` so the orchestrator can show "this is what your - specialist actually did" without dumping the full message stream - back through the prompt. The list is attached to the returned - ToolMessage's ``additional_kwargs`` (under ``"surf_tool_trace"``); - the LLM never sees it, but UI / observability code can pluck it - out of the checkpoint. + Entries (``{tool, status, preview}``) ride on the ToolMessage's + ``additional_kwargs["surf_tool_trace"]`` for UI/observability; the LLM + never sees them. """ trace: list[dict[str, Any]] = [] for msg in messages: tool_name = getattr(msg, "name", None) tool_call_id_attr = getattr(msg, "tool_call_id", None) if not tool_name and not tool_call_id_attr: - # Only ToolMessages have either field; skip AIMessage / - # HumanMessage / SystemMessage frames. + # Only ToolMessages carry either field. continue status = getattr(msg, "status", None) or "ok" preview = _safe_message_text(msg).strip().replace("\n", " ") @@ -308,8 +286,7 @@ def build_task_tool_with_parent_config( ) raise ValueError(msg) message_text = _safe_message_text(messages[-1]).rstrip() - # Tool-trace is purely observability — wrap defensively so a single - # malformed frame never bubbles up and kills the whole user turn. + # Trace is observability-only; never let a bad frame kill the turn. 
try: tool_trace = _build_tool_trace(messages) except Exception: @@ -320,10 +297,7 @@ def build_task_tool_with_parent_config( tool_trace = [] tool_msg = ToolMessage(message_text, tool_call_id=tool_call_id) if tool_trace: - # ``additional_kwargs`` is a free-form dict on BaseMessage; using - # a ``surf_`` prefix avoids collision with provider-specific keys - # (e.g. Anthropic's ``cache_control``). The LLM doesn't see it; - # consumers (UI, observability) read it off the checkpoint. + # surf_ prefix avoids collision with provider keys (e.g. cache_control). tool_msg.additional_kwargs["surf_tool_trace"] = tool_trace return Command( update={ @@ -361,9 +335,7 @@ def build_task_tool_with_parent_config( } hint = _resolve_context_hint(subagent_type, description, runtime) if hint: - # Prepend as a tagged block so the subagent prompt can pattern-match - # on the section (and a future change can lift it into its own - # ``SystemMessage`` if needed). + # Tagged block so the subagent prompt can pattern-match the section. payload = f"\n{hint}\n\n\n{description}" else: payload = description @@ -374,16 +346,12 @@ def build_task_tool_with_parent_config( results: list[tuple[int, str, dict | str, dict | None]], runtime: ToolRuntime, ) -> Command: - """Combine per-child results into one Command with a combined ToolMessage. + """Combine per-child results into one Command with an aggregate ToolMessage. - ``results`` is a list of ``(task_index, subagent_type, - payload_or_error_text, child_state_update)`` tuples — preserving the - input order so the orchestrator can map each block back to the task - it dispatched. State updates are merged by reducer for keys outside - :data:`EXCLUDED_STATE_KEYS`; everything else (``messages``, ``todos``, - etc.) is replaced by the synthesized aggregate ToolMessage. Every - child also contributes a ``billable_calls`` increment so cost - accounting matches single-mode dispatch. 
+ ``results`` tuples are ``(task_index, subagent_type, payload_or_error, + child_state_update)``; output blocks are sorted by index so the LLM can + map them back to dispatch order, and each child contributes a + ``billable_calls`` increment to match single-mode accounting. """ results.sort(key=lambda r: r[0]) merged_state: dict[str, Any] = {} @@ -424,8 +392,8 @@ def build_task_tool_with_parent_config( } ) if state_update: - # Naive merge: later tasks win on scalar collisions; reducer-backed - # fields (``receipts``, ``files`` etc.) accumulate at apply time. + # Later tasks win on scalar collisions; reducer-backed fields + # accumulate at apply time. merged_state.update(state_update) aggregate = "\n\n".join(message_blocks) aggregate_msg = ToolMessage( @@ -469,11 +437,9 @@ def build_task_tool_with_parent_config( ) -> tuple[int, str, dict | str, dict | None]: """Run one child of a batched ``task`` call under the concurrency cap. - Errors are returned as plain text in slot 2 so a single child's - failure does not abort the whole batch. ``GraphInterrupt`` from a - batched child is currently treated as a hard failure for that child - only — batched HITL is intentionally out of scope for the v1 - rollout (see plan tier 2 item 4 risks). + Errors are returned as text (slot 2) so one child's failure doesn't abort + the batch. A child's ``GraphInterrupt`` is a hard failure for that child: + batched HITL is intentionally out of scope. """ async with semaphore: if subagent_type not in subagent_graphs: @@ -507,8 +473,7 @@ def build_task_tool_with_parent_config( ) return (task_index, subagent_type, str(exc), None) except GraphInterrupt: - # Batched HITL is unsupported in v1 — surface as a failure - # for this child so the rest of the batch still completes. + # Batched HITL unsupported; fail this child so the batch finishes. logger.warning( "Batch child %d (%s) raised GraphInterrupt; batched HITL " "is not supported. 
Re-dispatch this task as a single " @@ -545,14 +510,11 @@ def build_task_tool_with_parent_config( return (task_index, subagent_type, result, child_state_update) def _coerce_batch_arg(tasks: Any) -> list[dict] | str: - """Rescue common LLM-side malformations of the ``tasks`` argument. + """Rescue common LLM malformations of the ``tasks`` argument. - Some providers serialise an array argument as a JSON-encoded string, - and small models occasionally hand back a single ``{description, - subagent_type}`` dict instead of a one-element array. Both are - recovered here with a WARN log so the issue is visible in metrics - but the user's turn still completes; truly broken shapes return a - plain string that the caller surfaces as the tool error. + Recovers a JSON-encoded array string and a single dict (instead of a + 1-element array), logging a WARN. Unrecoverable shapes return a string + the caller surfaces as the tool error. """ if isinstance(tasks, list): return tasks @@ -587,13 +549,10 @@ def build_task_tool_with_parent_config( async def _adispatch_batch( tasks: list[dict], runtime: ToolRuntime ) -> Command | str: - """Fan-out helper for the ``tasks`` array shape. + """Fan out the ``tasks`` array (size- and concurrency-capped). - Bounded by :data:`MAX_SUBAGENT_BATCH_SIZE` and concurrency-capped - at :data:`DEFAULT_SUBAGENT_BATCH_CONCURRENCY`. Returns a single - :class:`Command` that the LLM sees as one ToolMessage per child, - prefixed with ``[task ]`` so it can map back to the input - order. + Returns one Command; the LLM sees one ``[task ]``-prefixed block + per child, in input order. """ if not tasks: return "tasks: array is empty; nothing to dispatch." @@ -703,17 +662,16 @@ def build_task_tool_with_parent_config( if pending_value is not None: resume_value = consume_surfsense_resume(runtime) if resume_value is None: - # Bridge invariant: a queued resume must accompany any pending - # subagent interrupt. 
Fall-through replay would silently re-prompt - # the user; raise so the streaming layer surfaces a clear error. + # A pending interrupt must have a queued resume; otherwise replay + # would silently re-prompt the user. Raise instead. raise RuntimeError( f"Subagent {subagent_type!r} has a pending interrupt but no " "surfsense_resume_value on config; resume bridge is broken." ) expected = hitlrequest_action_count(pending_value) resume_value = fan_out_decisions_to_match(resume_value, expected) - # Prevent the parent's resume payload from leaking into subagent - # interrupts via langgraph's parent_scratchpad fallback. + # Stop the parent's resume leaking into subagent interrupts via + # langgraph's parent_scratchpad fallback. drain_parent_null_resume(runtime) with ot.subagent_invoke_span( subagent_type=subagent_type, path=invoke_path @@ -829,10 +787,8 @@ def build_task_tool_with_parent_config( ] = None, ) -> str | Command: atask_start = time.perf_counter() - # Kill switch: when ops flips the spawn-paused flag for this - # workspace, every ``task(...)`` invocation (single- or batch-mode) - # short-circuits with a clear ToolMessage so the orchestrator can - # tell the user what happened and stop hammering downstream APIs. + # Ops kill switch: short-circuit every task() call for this workspace + # so the orchestrator stops hammering downstream APIs. if await is_spawn_paused(search_space_id): logger.warning( "[hitl_route] atask SPAWN_PAUSED: search_space_id=%s tool_call_id=%s", @@ -923,8 +879,8 @@ def build_task_tool_with_parent_config( ) expected = hitlrequest_action_count(pending_value) resume_value = fan_out_decisions_to_match(resume_value, expected) - # Prevent the parent's resume payload from leaking into subagent - # interrupts via langgraph's parent_scratchpad fallback. + # Stop the parent's resume leaking into subagent interrupts via + # langgraph's parent_scratchpad fallback. 
drain_parent_null_resume(runtime) with ot.subagent_invoke_span( subagent_type=subagent_type, path=invoke_path diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py index 9e2d9a8d5..747ddacd3 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py @@ -1,33 +1,19 @@ """End-of-turn persistence for the cloud-mode SurfSense filesystem. -This middleware runs ``aafter_agent`` once per turn (cloud only). It commits -all staged folder creations, file moves, content writes/edits, file deletes -(``rm``), and directory deletes (``rmdir``) to Postgres in a single ordered -pass: +Runs ``aafter_agent`` once per turn (cloud only), committing staged folder +creates, moves, writes/edits, and ``rm``/``rmdir`` to Postgres in one ordered +pass. Order matters: moves resolve before writes (so write-then-move lands at +the final path), and file deletes run before directory deletes (so a same-turn +``rm /a/x.md`` + ``rmdir /a`` works). -1. Materialize ``staged_dirs`` into ``Folder`` rows. -2. Apply ``pending_moves`` in order (chained moves resolved via - ``doc_id_by_path``). -3. Normalize ``dirty_paths`` through ``pending_moves`` so write-then-move - sequences commit at the final path. Paths queued for ``rm`` this turn - are dropped here so a write+rm sequence doesn't recreate the doc. -4. Commit content writes / edits for ``/documents/*`` paths, skipping - ``temp_*`` basenames. -5. Apply ``pending_deletes`` (``rm``) — file deletes run BEFORE directory - deletes so a same-turn ``rm /a/x.md`` + ``rmdir /a`` sequence works. -6. Apply ``pending_dir_deletes`` (``rmdir``); re-verifies emptiness against - the post-step-5 DB state. 
+When ``flags.enable_action_log`` is on, each destructive op also snapshots a +``DocumentRevision`` / ``FolderRevision`` for revert. For ``rm``/``rmdir`` the +snapshot and DELETE share a SAVEPOINT, so a failed snapshot aborts the delete +rather than making the data silently irreversible. -When ``flags.enable_action_log`` is on every destructive op also writes a -``DocumentRevision`` / ``FolderRevision`` snapshot bound to the -originating ``AgentActionLog`` row via ``tool_call_id``. ``rm``/``rmdir`` -share a single ``SAVEPOINT`` with their snapshot — if the snapshot fails -the DELETE rolls back and we surface the error rather than silently -making the data irreversible. - -The commit body is exposed as a free function ``commit_staged_filesystem_state`` -so the optional stream-task fallback (``stream_new_chat.py``) can call the -exact same routine when ``aafter_agent`` was skipped (e.g. client disconnect). +The commit body is a free function (``commit_staged_filesystem_state``) so the +stream-task fallback can run the identical routine when ``aafter_agent`` was +skipped (e.g. client disconnect). """ from __future__ import annotations @@ -216,11 +202,9 @@ async def _create_document( virtual_path, search_space_id, ) - # Filesystem-parity invariant: the only thing that *must* be unique is - # the path. Two notes can legitimately share content (e.g. ``cp a b``). - # Guard against the path-derived ``unique_identifier_hash`` constraint - # so we surface a clean ValueError instead of letting the INSERT poison - # the session with an IntegrityError. + # Pre-check the path-derived unique_identifier_hash so a duplicate path + # surfaces as a clean ValueError instead of an INSERT IntegrityError that + # poisons the session. Content is intentionally not unique (cp a b). 
path_collision = await session.execute( select(Document.id).where( Document.search_space_id == search_space_id, @@ -232,13 +216,6 @@ async def _create_document( f"a document already exists at path '{virtual_path}' " "(unique_identifier_hash collision)" ) - # ``content_hash`` is intentionally NOT checked for uniqueness here. - # In a real filesystem two files at different paths can hold identical - # bytes, and the agent's ``write_file`` path needs that semantic to - # support copy/duplicate operations. The hash remains useful as a - # change-detection hint for connector indexers, which still consult it - # via :func:`check_duplicate_document` but do so with a non-unique - # lookup (``.first()``). content_hash = generate_content_hash(content, search_space_id) doc = Document( title=title, @@ -435,15 +412,9 @@ async def _mark_action_reversible( ) -> None: """Flip ``agent_action_log.reversible = TRUE`` for ``action_id``. - Best-effort: caller may invoke from inside a SAVEPOINT and treat - failure as a soft demotion (snapshot persists, just no Revert button). - - Callers should also call ``_dispatch_reversibility_update`` (defined - below) AFTER the enclosing SAVEPOINT block exits successfully so the - chat tool card can light up its Revert button without - re-fetching ``GET /threads/.../actions``. Dispatching from inside the - SAVEPOINT would risk emitting "reversible=true" for rows whose - update gets rolled back if the surrounding destructive op fails. + Pair with ``_dispatch_reversibility_update`` *after* the enclosing + SAVEPOINT commits, so the UI never sees ``reversible=true`` for a row whose + update later rolls back. """ if action_id is None: return @@ -455,22 +426,11 @@ async def _mark_action_reversible( async def _dispatch_reversibility_update(action_id: int | None) -> None: - """Best-effort dispatch of an ``action_log_updated`` custom event. + """Emit an ``action_log_updated`` SSE event so the Revert button lights up. 
- Surfaces the post-SAVEPOINT reversibility flip to the SSE layer so - the chat tool card can flip its Revert button live. Defensive: - failures are logged at debug level and swallowed; the - REST endpoint ``GET /threads/.../actions`` is still authoritative. - - .. warning:: - Inside :func:`commit_staged_filesystem_state` we DEFER all - dispatches until the outer ``session.commit()`` succeeds — see - the ``deferred_dispatches`` queue in that function. Dispatching - from inside a SAVEPOINT block while the outer transaction is - still pending would emit ``reversible=true`` for rows whose - snapshots get rolled back if the outer commit fails. Direct - callers (e.g. the optional stream-task fallback) that own the - full session lifetime can still call this helper inline. + Best-effort (failures swallowed; the REST actions endpoint is + authoritative). Inside :func:`commit_staged_filesystem_state` this is + deferred until after the outer commit via ``deferred_dispatches``. """ if action_id is None: return @@ -489,12 +449,9 @@ async def _dispatch_reversibility_update(action_id: int | None) -> None: # --------------------------------------------------------------------------- # Snapshot helpers # --------------------------------------------------------------------------- -# -# Best-effort helpers swallow + log so a snapshot failure can never break -# the destructive op for non-destructive tools (write/edit/move/mkdir). -# Strict helpers run inside the SAME ``begin_nested()`` SAVEPOINT as the -# destructive DELETE — failure aborts the savepoint and leaves the doc / -# folder intact, so revertable ops never become irreversible silently. +# Best-effort variants (write/edit/move/mkdir) swallow failures. Strict +# variants (rm/rmdir) share the destructive op's SAVEPOINT so a snapshot +# failure aborts the delete instead of making it silently irreversible. 
def _doc_revision_payload( @@ -704,15 +661,9 @@ async def commit_staged_filesystem_state( ) -> dict[str, Any] | None: """Commit all staged filesystem changes; return the state delta for reducers. - Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent` - and the optional stream-task fallback. - - When ``flags.enable_action_log`` is on every destructive op also writes - a ``DocumentRevision`` / ``FolderRevision`` snapshot bound to the - originating ``AgentActionLog`` row via ``tool_call_id``. Snapshot - durability is best-effort for non-destructive ops and STRICT for - ``rm``/``rmdir`` (snapshot + DELETE share a SAVEPOINT — snapshot - failure aborts the delete). + Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent` and + the stream-task fallback. See the module docstring for ordering and the + action-log snapshot/revert semantics. """ if filesystem_mode != FilesystemMode.CLOUD: return None @@ -771,8 +722,7 @@ async def commit_staged_filesystem_state( flags = get_flags() snapshot_enabled = flags.enable_action_log - # De-duplicate pending deletes per-path while preserving the latest - # tool_call_id (the one the user is most likely to revert via the UI). + # De-dup deletes per-path, keeping the latest tool_call_id (likeliest revert). file_delete_paths: dict[str, str] = {} for entry in pending_deletes: if not isinstance(entry, dict): @@ -796,22 +746,14 @@ async def commit_staged_filesystem_state( applied_moves: list[dict[str, Any]] = [] doc_id_path_tombstones: dict[str, int | None] = {} tree_changed = False - # Reversibility-flip dispatches are deferred until AFTER the outer - # ``session.commit()`` succeeds. Dispatching from inside the - # SAVEPOINT chain while the outer transaction is still pending - # would emit ``reversible=true`` for rows whose snapshots get rolled - # back if the final commit raises. 
Snapshot helpers append on - # success; we drain this list after commit and silently abandon it - # on rollback so the UI stays consistent with durable state. + # Reversibility-flip dispatches are drained only after the outer commit + # succeeds (and abandoned on rollback), so the UI never sees reversible=true + # for a snapshot that didn't durably land. deferred_dispatches: list[int] = [] try: async with shielded_async_session() as session: - # ------------------------------------------------------------------ - # Resolve action-id bindings up front. One SELECT per turn for all - # tool_call_ids, NOT one per op — important because a turn that - # touches 50 paths would otherwise issue 50 lookups. - # ------------------------------------------------------------------ + # Resolve all action-id bindings in one SELECT per turn, not per op. action_id_by_call: dict[str, int] = {} if snapshot_enabled and thread_id is not None: tool_call_ids: set[str] = set() @@ -844,10 +786,7 @@ async def commit_staged_filesystem_state( next(iter(action_id_by_call), None) if action_id_by_call else None ) - # ------------------------------------------------------------------ - # 1. staged_dirs -> Folder rows. Snapshot post-flush so the new - # folder_id is available for the FK. - # ------------------------------------------------------------------ + # 1. staged_dirs -> Folder rows (snapshot post-flush for the FK). for folder_path in staged_dirs: if not isinstance(folder_path, str): continue @@ -868,7 +807,6 @@ async def commit_staged_filesystem_state( tcid = staged_dir_tool_calls.get(folder_path) action_id = _action_id_for(tcid) if action_id is not None: - # Re-read the folder for the snapshot. result = await session.execute( select(Folder).where(Folder.id == folder_id) ) @@ -883,16 +821,13 @@ async def commit_staged_filesystem_state( deferred_dispatches=deferred_dispatches, ) - # ------------------------------------------------------------------ - # 2. pending_moves. 
Snapshot pre-move (in-place restore on revert). - # ------------------------------------------------------------------ + # 2. pending_moves (snapshot pre-move for in-place restore on revert). for move in pending_moves: source = str(move.get("source") or "") if snapshot_enabled and source: tcid = str(move.get("tool_call_id") or "") action_id = _action_id_for(tcid) if action_id is not None: - # Resolve the doc to snapshot BEFORE we mutate it. doc_id_pre = doc_id_by_path.get(source) document_pre: Document | None = None if doc_id_pre is not None: @@ -942,10 +877,8 @@ async def commit_staged_filesystem_state( path = move_alias[path] return path - # ------------------------------------------------------------------ - # 3. dirty_paths -> writes/edits. Skip any path queued for ``rm`` - # this turn so a write+rm sequence doesn't recreate the doc. - # ------------------------------------------------------------------ + # 3. dirty_paths -> writes/edits. Paths queued for rm this turn are + # skipped so a write+rm sequence doesn't recreate the doc. kb_dirty_seen: set[str] = set() kb_dirty: list[str] = [] kb_dirty_origin: dict[str, str] = {} @@ -974,9 +907,7 @@ async def commit_staged_filesystem_state( continue content = "\n".join(file_data.get("content") or []) doc_id = doc_id_by_path.get(path) - # Path ↔ tool_call_id binding: the dirty_paths list dedupes via - # _add_unique_reducer, so we look up the latest tool_call_id by - # path (or by the un-renamed origin). + # Look up tool_call_id by final path or its pre-rename origin. origin = kb_dirty_origin.get(path, path) tcid = dirty_path_tool_calls.get(path) or dirty_path_tool_calls.get( origin @@ -984,12 +915,9 @@ async def commit_staged_filesystem_state( action_id = _action_id_for(tcid) if doc_id is None: - # The in-memory ``doc_id_by_path`` is per-thread and starts - # empty in every new chat. If the agent writes to a path - # that already exists in the DB (e.g. 
a previous chat's - # ``notes.md``), we must NOT try to INSERT — it would hit - # ``unique_identifier_hash`` (path-derived). Look up the - # existing doc and update it in place instead. + # doc_id_by_path is per-thread and empty in a new chat, so a + # write to a path already in the DB must update in place, not + # INSERT (which would hit the path-derived unique hash). existing = await virtual_path_to_doc( session, search_space_id=search_space_id, @@ -1038,12 +966,9 @@ async def commit_staged_filesystem_state( } ) else: - # Fresh create. Wrap each create in a SAVEPOINT so a - # residual ``IntegrityError`` (e.g. a deployment that - # hasn't run migration 133 yet, where - # ``documents.content_hash`` still carries its legacy - # global UNIQUE constraint) rolls back only this one - # create instead of poisoning the whole turn. + # Fresh create, wrapped in a SAVEPOINT so a residual + # IntegrityError (e.g. pre-migration-133 content_hash UNIQUE) + # rolls back only this create, not the whole turn. placeholder_revision_id: int | None = None if snapshot_enabled and action_id is not None: placeholder_revision_id = await _snapshot_document_pre_create( @@ -1066,8 +991,7 @@ async def commit_staged_filesystem_state( logger.warning( "kb_persistence: skipping %s create: %s", path, exc ) - # Roll back the placeholder revision since the create - # never happened. + # Create never happened; drop its placeholder revision. if placeholder_revision_id is not None: await session.execute( delete(DocumentRevision).where( @@ -1114,19 +1038,14 @@ async def commit_staged_filesystem_state( ) tree_changed = True - # ------------------------------------------------------------------ - # 4. pending_deletes -> ``rm``. STRICT durability: snapshot + DELETE - # share a SAVEPOINT. If the snapshot insert fails, the DELETE - # rolls back too and we surface the error rather than silently - # making the data irreversible. - # ------------------------------------------------------------------ + # 4. 
pending_deletes -> rm. Strict: snapshot + DELETE share a + # SAVEPOINT, so a failed snapshot rolls the delete back too. for raw_path, tcid in file_delete_paths.items(): final = _final_path(raw_path) if not final.startswith(DOCUMENTS_ROOT + "/"): continue action_id = _action_id_for(tcid) - # Resolve the doc. doc_id_for_delete = doc_id_by_path.get(final) document_to_delete: Document | None = None if doc_id_for_delete is not None: @@ -1155,7 +1074,6 @@ async def commit_staged_filesystem_state( try: async with session.begin_nested(): - # Strict: snapshot first; failure aborts the delete. if snapshot_enabled and action_id is not None: chunks = await _load_chunks_for_snapshot( session, doc_id=doc_pk @@ -1184,10 +1102,7 @@ async def commit_staged_filesystem_state( ) continue - # B1 — SAVEPOINT released. Defer the reversibility-flip - # dispatch until AFTER the outer commit succeeds so we - # never tell the UI a row is reversible if its snapshot - # gets rolled back. + # Defer the reversibility flip until after the outer commit. if snapshot_enabled and action_id is not None: deferred_dispatches.append(int(action_id)) @@ -1206,11 +1121,8 @@ async def commit_staged_filesystem_state( ) tree_changed = True - # ------------------------------------------------------------------ - # 5. pending_dir_deletes -> ``rmdir``. STRICT durability + final - # emptiness check (after step 4's deletes have run, an "empty - # mid-turn" directory really IS empty in DB now). - # ------------------------------------------------------------------ + # 5. pending_dir_deletes -> rmdir. Strict, and re-checks emptiness + # against post-step-4 DB state. for raw_path, tcid in dir_delete_paths.items(): final = _final_path(raw_path) if not final.startswith(DOCUMENTS_ROOT + "/"): @@ -1231,7 +1143,6 @@ async def commit_staged_filesystem_state( ) continue - # Re-check emptiness against in-DB state. 
docs_in_folder = await session.execute( select(Document.id) .where(Document.folder_id == folder_id) @@ -1296,10 +1207,7 @@ async def commit_staged_filesystem_state( ) continue - # B1 — SAVEPOINT released. Defer the reversibility-flip - # dispatch until AFTER the outer commit succeeds so we - # never tell the UI a row is reversible if its snapshot - # gets rolled back. + # Defer the reversibility flip until after the outer commit. if snapshot_enabled and action_id is not None: deferred_dispatches.append(int(action_id)) @@ -1319,18 +1227,13 @@ async def commit_staged_filesystem_state( logger.exception( "kb_persistence: commit failed (search_space=%s)", search_space_id ) - # Outer commit raised — every SAVEPOINT-released change above - # (snapshots + reversibility flips) is now rolled back. Drop - # the deferred SSE dispatches so the UI stays consistent with - # durable state. + # Outer commit raised: everything above rolled back, so drop the + # deferred dispatches. deferred_dispatches.clear() return None - # Outer commit succeeded; flush deferred reversibility-flip - # dispatches now so the chat tool card can light up its Revert - # button without re-fetching ``GET /threads/.../actions``. De-dup - # to avoid emitting the same id twice (e.g. write-then-rm in the - # same turn dispatches once for each snapshot site). + # Commit succeeded; flush deferred reversibility flips (de-duped, since + # write-then-rm in one turn appends an id per snapshot site). if deferred_dispatches and dispatch_events: for action_id in dict.fromkeys(deferred_dispatches): try: @@ -1376,9 +1279,8 @@ async def commit_staged_filesystem_state( p for p in files if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX) ] - # Tombstone every committed-delete path so a stale ``state["files"]`` entry - # (which als_info would otherwise interpret as content) cannot survive into - # the next turn and make a now-empty folder look non-empty. 
+ # Tombstone committed-delete paths so a stale state["files"] entry can't + # survive into the next turn and make a now-empty folder look non-empty. deleted_file_paths = [ str(payload.get("virtualPath") or "") for payload in committed_deletes @@ -1399,11 +1301,8 @@ async def commit_staged_filesystem_state( "dirty_path_tool_calls": {_CLEAR: True}, } - # Emit one Receipt per committed mutation, folded into ``state['receipts']`` - # via ``_list_append_reducer``. The receipts surface what actually committed - # (post-savepoint) rather than what the LLM intended; the orchestrator uses - # them as ground truth in the ```` teaching. KB writes do not - # have public verifiable URLs, so ``verifiable_url`` stays unset. + # One Receipt per committed mutation: ground truth (post-savepoint) for the + # orchestrator's teaching. KB writes have no public URL. receipts: list[Receipt] = [] def _kb_receipt( @@ -1444,8 +1343,6 @@ async def commit_staged_filesystem_state( external_id=payload.get("id"), ) for payload in applied_moves: - # ``applied_moves`` rows carry the destination ``virtualPath`` because - # the move has already landed in the DB by the time we reach this code. path = str(payload.get("virtualPath") or "") _kb_receipt( type="file", @@ -1485,9 +1382,7 @@ async def commit_staged_filesystem_state( if tree_changed: delta["tree_version"] = int(state_dict.get("tree_version") or 0) + 1 - # Avoid 'unused' lint when turn_id_for_revision was only useful for - # diagnostic purposes inside the SAVEPOINT chain above. 
- _ = turn_id_for_revision + _ = turn_id_for_revision # diagnostic-only; silence unused lint logger.info( "kb_persistence: commit (search_space=%s) creates=%d updates=%d " diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/scrape_webpage.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/scrape_webpage.py index 014126927..24a686da1 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/scrape_webpage.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/scrape_webpage.py @@ -29,7 +29,6 @@ def extract_domain(url: str) -> str: try: parsed = urlparse(url) domain = parsed.netloc - # Remove 'www.' prefix if present if domain.startswith("www."): domain = domain[4:] return domain @@ -53,14 +52,13 @@ def truncate_content(content: str, max_length: int = 50000) -> tuple[str, bool]: if len(content) <= max_length: return content, False - # Try to truncate at a sentence boundary + # Prefer truncating at a sentence/paragraph boundary. truncated = content[:max_length] last_period = truncated.rfind(".") last_newline = truncated.rfind("\n\n") - # Use the later of the two boundaries, or just truncate boundary = max(last_period, last_newline) - if boundary > max_length * 0.8: # Only use boundary if it's not too far back + if boundary > max_length * 0.8: # only if the boundary isn't too far back truncated = content[: boundary + 1] return truncated + "\n\n[Content truncated...]", True @@ -111,8 +109,8 @@ async def _scrape_youtube_video( http_client.proxies.update(residential_proxies) ytt_api = YouTubeTranscriptApi(http_client=http_client) - # List all available transcripts and pick the first one - # (the video's primary language) instead of defaulting to English + # Pick the first transcript (video's primary language) rather than + # defaulting to English. 
transcript_list = ytt_api.list(video_id) transcript = next(iter(transcript_list)) captions = transcript.fetch() @@ -134,10 +132,8 @@ async def _scrape_youtube_video( logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}") transcript_text = f"No captions available for this video. Error: {e!s}" - # Build combined content content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}" - # Truncate if needed content, was_truncated = truncate_content(content, max_length) word_count = len(content.split()) @@ -212,20 +208,16 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None): scrape_id = generate_scrape_id(url) domain = extract_domain(url) - # Validate and normalize URL if not url.startswith(("http://", "https://")): url = f"https://{url}" try: - # Check if this is a YouTube URL and use transcript API instead + # YouTube URLs use the transcript API instead of crawling. video_id = get_youtube_video_id(url) if video_id: return await _scrape_youtube_video(url, video_id, max_length) - # Create webcrawler connector connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key) - - # Crawl the URL result, error = await connector.crawl_url(url, formats=["markdown"]) if error: @@ -250,28 +242,21 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None): "error": "No content returned from crawler", } - # Extract content and metadata content = result.get("content", "") metadata = result.get("metadata", {}) - # Get title from metadata title = metadata.get("title", "") if not title: title = domain or url.split("/")[-1] or "Webpage" - # Get description from metadata description = metadata.get("description", "") if not description and content: - # Use first paragraph as description first_para = content.split("\n\n")[0] if content else "" description = ( first_para[:300] + "..." 
if len(first_para) > 300 else first_para ) - # Truncate content if needed content, was_truncated = truncate_content(content, max_length) - - # Calculate word count word_count = len(content.split()) return { diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/feature_flags.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/feature_flags.py index 27188fac3..9564bd195 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/feature_flags.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/feature_flags.py @@ -1,37 +1,9 @@ -""" -Feature flags for the SurfSense new_chat agent stack. +"""Feature flags for the SurfSense new_chat agent stack. -These flags gate the newer agent middleware (some ported from OpenCode, -some sourced from ``langchain.agents.middleware`` / ``deepagents``, some -SurfSense-native). Most shipped agent-stack upgrades default ON so Docker -image updates work even when older installs do not have newly introduced -environment variables. Risky/experimental integrations stay default OFF, -and the master kill-switch can still disable everything new. - -All new middleware checks its flag at agent build time. If the master -kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new -middleware is disabled regardless of its individual flag. This gives -operators a single switch to revert to pre-port behavior. - -Examples --------- - -Defaults: - - SURFSENSE_ENABLE_CONTEXT_EDITING=true - SURFSENSE_ENABLE_COMPACTION_V2=true - SURFSENSE_ENABLE_RETRY_AFTER=true - SURFSENSE_ENABLE_MODEL_FALLBACK=false - SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true - SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true - SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true - SURFSENSE_ENABLE_PERMISSION=true - SURFSENSE_ENABLE_DOOM_LOOP=true - SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call - -Master kill-switch (overrides everything else): - - SURFSENSE_DISABLE_NEW_AGENT_STACK=true +Flags are resolved at agent build time. 
Most upgrades default ON so Docker +updates work without operators adding new env vars; risky integrations stay +OFF. The master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` forces every +flag below to False for a one-switch rollback to pre-port behavior. """ from __future__ import annotations @@ -93,39 +65,14 @@ class AgentFeatureFlags: # Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT) enable_otel: bool = False - # Performance — compiled-agent cache (Phase 1 + Phase 2). - # When ON, ``create_surfsense_deep_agent`` reuses a previously-compiled - # graph if the cache key matches (LLM config + thread + tool surface + - # flags + system prompt + filesystem mode). Cuts per-turn agent-build - # wall clock from ~4-5s to <50µs on cache hits. - # - # SAFETY (Phase 2 unblocked this default-on): - # All connector mutation tools (``tools/notion``, ``tools/gmail``, - # ``tools/google_drive``, ``tools/dropbox``, ``tools/onedrive``, - # ``tools/google_calendar``, ``tools/confluence``, ``tools/discord``, - # ``tools/teams``, ``tools/luma``, ``connected_accounts``, - # ``update_memory``) now acquire fresh - # short-lived ``AsyncSession`` instances per call via - # :data:`async_session_maker`. The factory still accepts ``db_session`` - # for registry compatibility but ``del``'s it immediately — see any - # of those files' factory docstrings for the rationale. The ``llm`` - # closure is per-(provider, model, config_id) which is already in - # the cache key, so the LLM is safe to share across cached hits of - # the same key. The KB priority middleware reads - # ``mentioned_document_ids`` from ``runtime.context`` (Phase 1.5), - # not its constructor closure, so the same compiled agent serves - # turns with different mention lists correctly. - # - # Rollback: set ``SURFSENSE_ENABLE_AGENT_CACHE=false`` in the - # environment if a regression surfaces. The path is exercised by - # the ``tests/unit/agents/new_chat/test_agent_cache_*`` suite. 
+ # Performance — reuse a compiled agent graph when the cache key matches + # (~4-5s -> <50µs per turn). Safe to default-on because mutation tools take + # fresh short-lived sessions per call and per-turn context (mentions, etc.) + # is read from runtime.context, not the constructor closure. Rollback via + # SURFSENSE_ENABLE_AGENT_CACHE=false. enable_agent_cache: bool = True - # Phase 1 (deferred — measure first): pre-build & share the - # general-purpose subagent ``CompiledSubAgent`` across cold-cache - # misses. Only helps when the outer cache MISSES (cache hits already - # reuse the entire SubAgentMiddleware-compiled graph). Off by default - # until we have data showing cold misses are frequent enough to - # justify the extra global state. + # Deferred: only helps on outer-cache MISSES, so off until data shows cold + # misses are frequent enough to justify the extra global state. enable_agent_cache_share_gp_subagent: bool = False @classmethod diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py index cc716b00f..2714c6065 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py @@ -594,14 +594,9 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] inject_system_message: bool = True, # For backwards compatibility ) -> None: self.llm = llm - # The planner LLM handles short, structured internal tasks (query - # rewriting, date extraction, recency classification). When an - # operator marks a global config ``is_planner: true`` we route - # those calls to a cheap/fast model (e.g. gpt-4o-mini, Haiku, Azure - # gpt-5.x-nano) instead of the user's chat LLM — those classification - # tasks don't need frontier-tier capability. 
Falls back to the chat - # LLM when no planner config is wired up so deployments without one - # keep working unchanged. + # Cheap model for structured internal tasks (query rewrite, date + # extraction, recency classification) when one is configured; falls back + # to the chat LLM otherwise. self.planner_llm = planner_llm or llm self.search_space_id = search_space_id self.filesystem_mode = filesystem_mode @@ -610,26 +605,17 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] self.top_k = top_k self.mentioned_document_ids = mentioned_document_ids or [] self.inject_system_message = inject_system_message - # Build the kb-planner private Runnable ONCE here so we don't pay - # the ``create_agent`` compile cost (50-200ms) on every turn. - # Disabled by default behind ``enable_kb_planner_runnable``; when - # off the planner falls back to the legacy ``planner_llm.ainvoke`` - # path. + # Compiled lazily and memoized to avoid the per-turn create_agent cost. self._planner: Runnable | None = None self._planner_compile_failed = False def _build_kb_planner_runnable(self) -> Runnable | None: - """Compile the kb-planner private :class:`Runnable` once. + """Lazily compile and memoize the kb-planner Runnable. - Returns ``None`` when the feature flag is disabled, when the LLM is - unavailable, or when ``create_agent`` raises (we fall back to the - legacy ``planner_llm.ainvoke`` path in that case). Compilation happens - lazily on first call, then memoized via ``self._planner``. - - The compiled agent is constructed without tools — the planner's - contract is "answer with structured JSON" — but it inherits the - :class:`RetryAfterMiddleware` so transient rate-limit errors - from the planner LLM call don't fail the whole turn. + Returns ``None`` (and the caller falls back to ``planner_llm.ainvoke``) + when the flag is off, the LLM is missing, or ``create_agent`` raises. 
+ Built without tools but with RetryAfterMiddleware so a transient + rate-limit on the planner call doesn't fail the whole turn. """ if self._planner is not None or self._planner_compile_failed: return self._planner @@ -677,10 +663,8 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] loop = asyncio.get_running_loop() t0 = loop.time() - # Prefer the compiled-once planner Runnable when enabled; otherwise - # fall back to ``planner_llm.ainvoke``. The ``surfsense:internal`` - # tag is preserved on both paths so ``_stream_agent_events`` still - # suppresses the planner's intermediate events from the UI. + # Both paths tag surfsense:internal so the planner's intermediate + # events stay suppressed from the UI. planner = self._build_kb_planner_runnable() try: if planner is not None: @@ -819,32 +803,16 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] user_text=user_text, ) - # Per-turn ``mentioned_document_ids`` flow: - # 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the - # streaming task supplies a fresh :class:`SurfSenseContextSchema` - # on every ``astream_events`` call, so this list is naturally - # scoped to the current turn. Allows cross-turn graph reuse via - # ``agent_cache``. - # 2. Legacy fallback (cache disabled / context not propagated): the - # constructor-injected ``self.mentioned_document_ids`` list. We - # drain it after the first read so a cached graph (no Phase 1.5 - # wiring) doesn't keep replaying the same mentions on every - # turn. + # Prefer per-turn mentions from runtime.context (lets a cached graph + # serve different turns); fall back to the constructor closure, draining + # it after one read so stale mentions can't replay. # - # CRITICAL: distinguish "context absent" (legacy caller, no field at - # all) from "context provided but empty" (turn with no mentions). 
- # ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in - # Python, so a naive ``if ctx_mentions:`` would fall through to the - # legacy closure on every no-mention follow-up turn — replaying the - # mentions baked in by turn 1's cache-miss build. Always drain the - # closure once the runtime path has fired so a cached middleware - # instance can never resurrect stale state. + # CRITICAL: test ``ctx_mentions is not None``, not truthiness — an empty + # list means "this turn has no mentions", not "use the closure". mention_ids: list[int] = [] ctx = getattr(runtime, "context", None) if runtime is not None else None ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None if ctx_mentions is not None: - # Runtime path is authoritative — even an empty list means - # "this turn has no mentions", NOT "look at the closure". mention_ids = list(ctx_mentions) if self.mentioned_document_ids: self.mentioned_document_ids = [] @@ -852,12 +820,8 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] mention_ids = list(self.mentioned_document_ids) self.mentioned_document_ids = [] - # Folder mentions live alongside doc mentions on the runtime - # context. They never feed hybrid search (folders aren't - # embedded) — they're surfaced purely as ``[USER-MENTIONED]`` - # priority entries so the agent walks the folder with ``ls`` / - # ``find_documents`` instead of ignoring it. Cloud filesystem - # mode only. + # Folder mentions aren't embedded, so they skip hybrid search and are + # surfaced only as [USER-MENTIONED] entries. Cloud mode only. folder_mention_ids: list[int] = [] if ( ctx is not None @@ -939,14 +903,10 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] async def _materialize_folder_priority( self, folder_ids: list[int] ) -> list[dict[str, Any]]: - """Resolve user-mentioned folder ids to ```` entries. + """Resolve mentioned folder ids to canonical-path priority entries. 
- Each entry uses the canonical ``/documents/Folder/Sub/`` virtual - path (matching ``KnowledgeTreeMiddleware`` and the agent's - ``ls`` adapter) and is flagged ``mentioned=True`` so the - rendered line carries ``[USER-MENTIONED]``. ``score`` is left - ``None`` so the renderer prints ``n/a`` — folders aren't - ranked, the agent decides which children to read. + Flagged ``mentioned=True`` with ``score=None`` (folders aren't ranked; + the agent decides which children to read). """ if not folder_ids: return [] diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/hitl.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/hitl.py index f5023737a..9b16e1a4c 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/hitl.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/tools/hitl.py @@ -30,22 +30,11 @@ from langgraph.types import interrupt logger = logging.getLogger(__name__) -# Tools that mirror the safety profile of ``write_file`` against the -# SurfSense KB: each call creates ONE artifact in the user's own workspace -# with no external visibility (drafts aren't sent; new files aren't shared -# unless the user shares them later). These are auto-approved by default -# so the agent can compose drafts and seed scratch files without a popup -# on every call. -# -# Members of this set still call ``request_approval`` exactly as before; -# the function returns immediately with ``decision_type="auto_approved"`` -# and the original params untouched. This preserves the call-site shape -# (logging, metadata fetching, account fallbacks) so the only behavior -# change is "no interrupt fires". -# -# To re-enable prompting, the future per-search-space rules table -# (``agent_permission_rules``) takes precedence in the permission ruleset -# layering assembled by the agent factory. 
+# Low-stakes creation tools auto-approved by default: each creates one +# artifact in the user's own workspace with no external visibility (drafts +# aren't sent; new files aren't shared). They still call ``request_approval``, +# which returns ``decision_type="auto_approved"`` without firing an interrupt. +# Per-search-space ``agent_permission_rules`` can re-enable prompting. DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset( { "create_gmail_draft", @@ -150,10 +139,6 @@ def request_approval( return HITLResult(rejected=False, decision_type="trusted", params=dict(params)) if tool_name in DEFAULT_AUTO_APPROVED_TOOLS: - # Default policy: low-stakes creation tools (drafts + new-file - # creates) skip HITL because they're as recoverable as a local - # ``write_file`` against the SurfSense KB. The user can still - # delete the artifact in <30s if it's wrong. logger.info( "Tool '%s' is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL", tool_name, diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py index 5ed5f2ad6..7bb4a7c24 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py @@ -75,7 +75,7 @@ def create_generate_image_tool( captured model), use this config id instead of reading the search space's live ``image_generation_config_id``. """ - del db_session # use a fresh per-call session, see below + del db_session # tool uses a fresh per-call session instead @tool async def generate_image( @@ -140,17 +140,12 @@ def create_generate_image_tool( or IMAGE_GEN_AUTO_MODE_ID ) - # Build generation kwargs - # NOTE: size, quality, and style are intentionally NOT passed. 
- # Different models support different values for these params - # (e.g. DALL-E 3 wants "hd"/"standard" for quality while - # gpt-image-1 wants "high"/"medium"/"low"; size options also - # differ). Letting the model use its own defaults avoids errors. + # size/quality/style are intentionally omitted: valid values + # differ per model, so we let each model use its own defaults. gen_kwargs: dict[str, Any] = {} if n is not None and n > 1: gen_kwargs["n"] = n - # Call litellm based on config type if is_image_gen_auto_mode(config_id): if not ImageGenRouterService.is_initialized(): err = ( @@ -224,17 +219,13 @@ def create_generate_image_tool( prompt=prompt, model=model_string, **gen_kwargs ) - # Parse the response and store in DB response_dict = ( response.model_dump() if hasattr(response, "model_dump") else dict(response) ) - # Generate a random access token for this image access_token = generate_image_token() - - # Save to image_generations table for history db_image_gen = ImageGeneration( prompt=prompt, model=getattr(response, "_hidden_params", {}).get("model"), @@ -249,7 +240,6 @@ def create_generate_image_tool( await session.refresh(db_image_gen) db_image_gen_id = db_image_gen.id - # Extract image URLs from response images = response_dict.get("data", []) if not images: return _failed( @@ -260,11 +250,8 @@ def create_generate_image_tool( first_image = images[0] revised_prompt = first_image.get("revised_prompt", prompt) - # Resolve image URL: - # - If the API returned a URL, use it directly. - # - If the API returned b64_json (e.g. gpt-image-1), serve the - # image through our backend endpoint to avoid bloating the - # LLM context with megabytes of base64 data. + # b64_json (e.g. gpt-image-1) is served via our backend endpoint so + # megabytes of base64 don't bloat the LLM context. 
if first_image.get("url"): image_url = first_image["url"] elif first_image.get("b64_json"): diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py index a7c994c3f..e99e0291a 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py @@ -241,23 +241,12 @@ def _normalize_connectors( connectors_to_search: list[str] | None, available_connectors: list[str] | None = None, ) -> list[str]: + """Normalize model-supplied connectors to canonical ConnectorService types. + + Maps user-facing aliases (e.g. WEBCRAWLER_CONNECTOR), drops unknowns, and + constrains to ``available_connectors`` when given. Empty input defaults to + all available connectors (minus live-search ones). """ - Normalize connectors provided by the model. - - - Accepts user-facing enums like WEBCRAWLER_CONNECTOR and maps them to canonical - ConnectorService types. - - Drops unknown values. - - If available_connectors is provided, only includes connectors from that list. - - If connectors_to_search is None/empty, defaults to available_connectors or all. 
- - Args: - connectors_to_search: List of connectors requested by the model - available_connectors: List of connectors actually available in the search space - - Returns: - List of normalized connector strings to search - """ - # Determine the set of valid connectors to consider valid_set = ( set(available_connectors) if available_connectors else set(_ALL_CONNECTORS) ) @@ -276,18 +265,16 @@ def _normalize_connectors( c = (raw or "").strip().upper() if not c: continue - # Map user-facing aliases to canonical names if c == "WEBCRAWLER_CONNECTOR": c = "CRAWLED_URL" normalized.append(c) - # de-dupe while preserving order + filter to valid connectors + # De-dupe (order-preserving), keeping only known + available connectors. seen: set[str] = set() out: list[str] = [] for c in normalized: if c in seen: continue - # Only include if it's a known connector AND available if c not in _ALL_CONNECTORS: continue if c not in valid_set: @@ -295,7 +282,7 @@ def _normalize_connectors( seen.add(c) out.append(c) - # Fallback to all available if nothing matched + # Nothing matched: fall back to all available. if not out: base = ( list(available_connectors) @@ -377,39 +364,17 @@ def format_documents_for_context( max_chunk_chars: int = _MAX_CHUNK_CHARS, max_chunks_per_doc: int = 0, ) -> str: - """ - Format retrieved documents into a readable context string for the LLM. + """Format retrieved documents into an XML context string for the LLM. - Documents are added in order (highest relevance first) until the character - budget is reached. Individual chunks are capped at ``max_chunk_chars`` and - each document is limited to a dynamically computed chunk cap so a single - large document cannot monopolize the output while still maximising the use - of available context space. - - Args: - documents: List of document dictionaries from connector search - max_chars: Approximate character budget for the entire output. - max_chunk_chars: Per-chunk character cap (content is tail-truncated). 
- max_chunks_per_doc: Maximum chunks per document. ``0`` (default) means - auto-compute per document using a rank-adaptive formula so - higher-ranked documents receive more chunks. - - Returns: - Formatted string with document contents and metadata + Documents are emitted highest-relevance first until ``max_chars`` is hit. + ``max_chunks_per_doc=0`` auto-computes a rank-adaptive cap so top results get + more chunks and no single large document monopolizes the budget. """ if not documents: return "" - # Group chunks by document id (preferred) to produce the XML structure. - # - # IMPORTANT: ConnectorService returns **document-grouped** results of the form: - # { - # "document": {...}, - # "chunks": [{"chunk_id": 123, "content": "..."}, ...], - # "source": "NOTION_CONNECTOR" | "FILE" | ... - # } - # - # We must preserve chunk_id so citations like [citation:123] are possible. + # Group chunks by document id, preserving chunk_id so [citation:123] works. + # ConnectorService returns document-grouped results ({document, chunks, source}). grouped: dict[str, dict[str, Any]] = {} for doc in documents: @@ -430,7 +395,7 @@ def format_documents_for_context( or "UNKNOWN" ) - # Document identity (prefer document_id; otherwise fall back to type+title+url) + # Identity: prefer document_id, else type+title+url. document_id_val = document_info.get("id") title = ( document_info.get("title") or metadata.get("title") or "Untitled Document" @@ -460,7 +425,7 @@ def format_documents_for_context( "chunks": [], } - # Prefer document-grouped chunks if available + # Prefer document-grouped chunks when present. chunks_list = doc.get("chunks") if isinstance(doc, dict) else None if isinstance(chunks_list, list) and chunks_list: for ch in chunks_list: @@ -492,7 +457,6 @@ def format_documents_for_context( "BAIDU_SEARCH_API", } - # Render XML expected by citation instructions, respecting the char budget. 
parts: list[str] = [] total_chars = 0 total_docs = len(grouped) @@ -594,30 +558,11 @@ async def search_knowledge_base_async( available_document_types: list[str] | None = None, max_input_tokens: int | None = None, ) -> str: - """ - Search the user's knowledge base for relevant documents. + """Search the knowledge base across connectors and return formatted results. - This is the async implementation that searches across multiple connectors. - - Args: - query: The search query - search_space_id: The user's search space ID - db_session: Database session - connector_service: Initialized connector service - connectors_to_search: Optional list of connector types to search. If omitted, searches all. - top_k: Number of results per connector - start_date: Optional start datetime (UTC) for filtering documents - end_date: Optional end datetime (UTC) for filtering documents - available_connectors: Optional list of connectors actually available in the search space. - If provided, only these connectors will be searched. - available_document_types: Optional list of document types that actually have indexed - data. When provided, local connectors whose document type is - absent are skipped entirely (no embedding / DB round-trip). - max_input_tokens: Model context window size (tokens). Used to dynamically - size the output so it fits within the model's limits. - - Returns: - Formatted string with search results + ``available_document_types`` lets local connectors with no indexed data be + skipped (no embedding / DB round-trip), and ``max_input_tokens`` sizes the + output to the model's context window. 
""" perf = get_perf_logger() t0 = time.perf_counter() diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/report.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/report.py index d9a941021..24042d775 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/report.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/report.py @@ -196,13 +196,8 @@ def _strip_wrapping_code_fences(text: str) -> str: def _extract_metadata(content: str) -> dict[str, Any]: """Extract metadata from generated Markdown content.""" - # Count section headings headings = re.findall(r"^(#{1,6})\s+(.+)$", content, re.MULTILINE) - - # Word count word_count = len(content.split()) - - # Character count char_count = len(content) return { @@ -227,12 +222,11 @@ def _parse_sections(content: str) -> list[dict[str, str]]: in_code_block = False for line in lines: - # Track code blocks to avoid matching headings inside them + # Track fences so headings inside code blocks aren't treated as splits. stripped = line.strip() if stripped.startswith("```"): in_code_block = not in_code_block - # Only split on # or ## headings (not ### or deeper) and only outside code blocks is_section_heading = ( not in_code_block and re.match(r"^#{1,2}\s+", line) @@ -240,7 +234,6 @@ def _parse_sections(content: str) -> list[dict[str, str]]: ) if is_section_heading: - # Save previous section if current_heading or current_body_lines: sections.append( { @@ -253,7 +246,6 @@ def _parse_sections(content: str) -> list[dict[str, str]]: else: current_body_lines.append(line) - # Save last section if current_heading or current_body_lines: sections.append( { @@ -292,7 +284,6 @@ async def _revise_with_sections( Unchanged sections are kept byte-for-byte identical. Returns the revised content, or None to trigger full-document revision fallback. 
""" - # Parse report into sections sections = _parse_sections(parent_content) if len(sections) < 2: logger.info( @@ -300,7 +291,6 @@ async def _revise_with_sections( ) return None - # Build a sections listing for the LLM sections_listing = "" for i, sec in enumerate(sections): heading = sec["heading"] or "(preamble — content before first heading)" @@ -352,11 +342,9 @@ async def _revise_with_sections( ) return None - # Compute total operations for progress tracking total_ops = len(modify_indices) + len(add_sections) current_op = 0 - # Emit plan summary parts = [] if modify_indices: parts.append( @@ -394,7 +382,6 @@ async def _revise_with_sections( current_op += 1 sec = sections[idx] - # Extract plain section name (strip markdown heading markers) section_name = ( re.sub(r"^#+\s*", "", sec["heading"]).strip() if sec["heading"] @@ -412,7 +399,6 @@ async def _revise_with_sections( f"{sec['heading']}\n\n{sec['body']}" if sec["heading"] else sec["body"] ) - # Build context from surrounding sections context_parts = [] if idx > 0: prev = sections[idx - 1] @@ -442,7 +428,6 @@ async def _revise_with_sections( revised_text = resp.content if revised_text and isinstance(revised_text, str): revised_text = _strip_wrapping_code_fences(revised_text).strip() - # Parse the LLM output back into heading + body revised_parsed = _parse_sections(revised_text) if revised_parsed: revised_sections[idx] = revised_parsed[0] @@ -465,7 +450,6 @@ async def _revise_with_sections( heading = add_info.get("heading", "## New Section") description = add_info.get("description", "") - # Extract plain section name for progress display plain_heading = re.sub(r"^#+\s*", "", heading).strip() dispatch_custom_event( "report_progress", @@ -475,7 +459,6 @@ async def _revise_with_sections( }, ) - # Build context from the surrounding sections at the insertion point ctx_parts = [] if 0 <= after_idx < len(revised_sections): before_sec = revised_sections[after_idx] @@ -542,36 +525,13 @@ def 
create_generate_report_tool( available_connectors: list[str] | None = None, available_document_types: list[str] | None = None, ): - """ - Factory function to create the generate_report tool with injected dependencies. + """Create the generate_report tool with injected dependencies. - The tool generates a Markdown report inline using the search space's - document summary LLM, saves it to the database, and returns immediately. - - Uses short-lived database sessions for each DB operation so no connection - is held during the long LLM API call. - - Generation strategies: - - New reports: single-shot generation (1 LLM call) - - Revisions (targeted edits): section-level (unchanged sections preserved) - - Revisions (global changes): full-document revision fallback - - Source strategies: - - "provided"/"conversation": use only the supplied source_content - - "kb_search": search the knowledge base internally using targeted queries - - "auto": use source_content if sufficient, otherwise fall back to KB search - - Args: - search_space_id: The user's search space ID - thread_id: The chat thread ID for associating the report - connector_service: Optional connector service for internal KB search. - When provided, the tool can search the knowledge base internally - (used by the "kb_search" and "auto" source strategies). - available_connectors: Optional list of connector types available in the - search space (used to scope internal KB searches). - - Returns: - A configured tool function for generating reports + Uses short-lived DB sessions per operation so no connection is held during + the long LLM call. Generation: new reports are single-shot; revisions try + section-level first (unchanged sections preserved) and fall back to full-doc. + Source strategies: provided/conversation (use source_content), kb_search + (internal KB queries), auto (KB search only when source_content is thin). 
""" @tool @@ -693,7 +653,7 @@ def create_generate_report_tool( Returns: Dict with status, report_id, title, word_count, and message. """ - # Initialize version tracking variables (used by _save_failed_report closure) + # Shared with the _save_failed_report closure. parent_report_content: str | None = None report_group_id: int | None = None @@ -733,7 +693,7 @@ def create_generate_report_tool( session.add(failed_report) await session.commit() await session.refresh(failed_report) - # If this is a new group (v1 failed), set group to self + # New group (v1 failed): point the group at itself. if not failed_report.report_group_id: failed_report.report_group_id = failed_report.id await session.commit() @@ -749,8 +709,8 @@ def create_generate_report_tool( try: # ── Phase 1: READ (short-lived session) ────────────────────── - # Fetch parent report and LLM config, then close the session - # so no DB connection is held during the long LLM call. + # Fetch parent report + LLM config, then release the connection + # before the long LLM call. async with shielded_async_session() as read_session: if parent_report_id: parent_report = await read_session.get(Report, parent_report_id) @@ -768,7 +728,6 @@ def create_generate_report_tool( ) llm = await get_document_summary_llm(read_session, search_space_id) - # read_session closed — connection returned to pool if not llm: error_msg = ( @@ -785,7 +744,6 @@ def create_generate_report_tool( error=error_msg, ) - # Build the user instructions string user_instructions_section = "" if user_instructions: user_instructions_section = ( @@ -829,7 +787,7 @@ def create_generate_report_tool( try: from .knowledge_base import search_knowledge_base_async - # Run all queries in parallel, each with its own session + # Each query gets its own short-lived session. 
async def _run_single_query(q: str) -> str: async with shielded_async_session() as kb_session: kb_connector_svc = ConnectorService( @@ -849,7 +807,6 @@ def create_generate_report_tool( *[_run_single_query(q) for q in search_queries[:5]] ) - # Merge non-empty results into source_content kb_text_parts = [r for r in kb_results if r and r.strip()] if kb_text_parts: kb_combined = "\n\n---\n\n".join(kb_text_parts) @@ -903,9 +860,9 @@ def create_generate_report_tool( "provided. Using source_content as-is." ) - capped_source = effective_source[:100000] # Cap source content + capped_source = effective_source[:100000] - # Length constraint — only when user explicitly asks for brevity + # Length constraint only when the user explicitly asked for brevity. length_instruction = "" if report_style == "brief": length_instruction = ( @@ -920,11 +877,8 @@ def create_generate_report_tool( report_content: str | None = None if parent_report_content: - # ─── REVISION MODE ─────────────────────────────────────── - # Strategy: Try section-level revision first (preserves - # unchanged sections byte-for-byte). Falls back to full- - # document revision if section identification fails or if - # all sections need changes. + # Revision mode: section-level first (preserves untouched + # sections), falling back to full-doc revision. dispatch_custom_event( "report_progress", { @@ -946,7 +900,6 @@ def create_generate_report_tool( ) if report_content is None: - # Fallback: full-document revision dispatch_custom_event( "report_progress", {"phase": "writing", "message": "Rewriting your full report"}, @@ -969,9 +922,7 @@ def create_generate_report_tool( report_content = response.content else: - # ─── NEW REPORT MODE ───────────────────────────────────── - # Single-shot generation: one LLM call produces the full - # report. Fast, globally coherent, and cost-efficient. + # New report: single-shot generation (one LLM call). 
dispatch_custom_event( "report_progress", {"phase": "writing", "message": "Writing your report"}, @@ -991,8 +942,6 @@ def create_generate_report_tool( response = await llm.ainvoke([HumanMessage(content=prompt)]) report_content = response.content - # ── Validate LLM output ────────────────────────────────────── - if not report_content or not isinstance(report_content, str): error_msg = "LLM returned empty or invalid content" report_id = await _save_failed_report(error_msg) @@ -1029,14 +978,12 @@ def create_generate_report_tool( if report_content.rstrip().endswith("---"): report_content = report_content.rstrip()[:-3].rstrip() - # Append exactly one standard disclaimer + # Append exactly one standard footer. report_content += "\n\n---\n\n" + _REPORT_FOOTER - # Extract metadata (includes "status": "ready") metadata = _extract_metadata(report_content) # ── Phase 3: WRITE (short-lived session) ───────────────────── - # Save the report to the database, then close the session. async with shielded_async_session() as write_session: report = Report( title=topic, @@ -1051,14 +998,13 @@ def create_generate_report_tool( await write_session.commit() await write_session.refresh(report) - # If this is a brand-new report (v1), set report_group_id = own id + # Brand-new report (v1): point the group at itself. 
if not report.report_group_id: report.report_group_id = report.id await write_session.commit() saved_report_id = report.id saved_group_id = report.report_group_id - # write_session closed — connection returned to pool logger.info( f"[generate_report] Created report {saved_report_id} " diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/tools/scrape_webpage.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/tools/scrape_webpage.py index bb7c8e5a3..f4f109761 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/tools/scrape_webpage.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/research/tools/scrape_webpage.py @@ -23,7 +23,6 @@ def extract_domain(url: str) -> str: try: parsed = urlparse(url) domain = parsed.netloc - # Remove 'www.' prefix if present if domain.startswith("www."): domain = domain[4:] return domain @@ -47,14 +46,13 @@ def truncate_content(content: str, max_length: int = 50000) -> tuple[str, bool]: if len(content) <= max_length: return content, False - # Try to truncate at a sentence boundary + # Prefer truncating at a sentence/paragraph boundary. 
truncated = content[:max_length] last_period = truncated.rfind(".") last_newline = truncated.rfind("\n\n") - # Use the later of the two boundaries, or just truncate boundary = max(last_period, last_newline) - if boundary > max_length * 0.8: # Only use boundary if it's not too far back + if boundary > max_length * 0.8: # only if the boundary isn't too far back truncated = content[: boundary + 1] return truncated + "\n\n[Content truncated...]", True @@ -105,8 +103,8 @@ async def _scrape_youtube_video( http_client.proxies.update(residential_proxies) ytt_api = YouTubeTranscriptApi(http_client=http_client) - # List all available transcripts and pick the first one - # (the video's primary language) instead of defaulting to English + # Pick the first transcript (video's primary language) rather than + # defaulting to English. transcript_list = ytt_api.list(video_id) transcript = next(iter(transcript_list)) captions = transcript.fetch() @@ -128,10 +126,8 @@ async def _scrape_youtube_video( logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}") transcript_text = f"No captions available for this video. Error: {e!s}" - # Build combined content content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}" - # Truncate if needed content, was_truncated = truncate_content(content, max_length) word_count = len(content.split()) @@ -206,20 +202,16 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None): scrape_id = generate_scrape_id(url) domain = extract_domain(url) - # Validate and normalize URL if not url.startswith(("http://", "https://")): url = f"https://{url}" try: - # Check if this is a YouTube URL and use transcript API instead + # YouTube URLs use the transcript API instead of crawling. 
video_id = get_youtube_video_id(url) if video_id: return await _scrape_youtube_video(url, video_id, max_length) - # Create webcrawler connector connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key) - - # Crawl the URL result, error = await connector.crawl_url(url, formats=["markdown"]) if error: @@ -244,28 +236,21 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None): "error": "No content returned from crawler", } - # Extract content and metadata content = result.get("content", "") metadata = result.get("metadata", {}) - # Get title from metadata title = metadata.get("title", "") if not title: title = domain or url.split("/")[-1] or "Webpage" - # Get description from metadata description = metadata.get("description", "") if not description and content: - # Use first paragraph as description first_para = content.split("\n\n")[0] if content else "" description = ( first_para[:300] + "..." if len(first_para) > 300 else first_para ) - # Truncate content if needed content, was_truncated = truncate_content(content, max_length) - - # Calculate word count word_count = len(content.split()) return { diff --git a/surfsense_backend/app/agents/chat/runtime/llm_config.py b/surfsense_backend/app/agents/chat/runtime/llm_config.py index c5f929ec2..aad432edb 100644 --- a/surfsense_backend/app/agents/chat/runtime/llm_config.py +++ b/surfsense_backend/app/agents/chat/runtime/llm_config.py @@ -92,15 +92,9 @@ class SanitizedChatLiteLLM(ChatLiteLLM): yield chunk -# Provider mapping for LiteLLM model string construction. -# -# Single source of truth lives in -# :mod:`app.services.provider_capabilities` so the YAML loader (which -# runs during ``app.config`` class-body init) can resolve provider -# prefixes without dragging the agent / tools tree into module load -# order. Re-exported here under the historical ``PROVIDER_MAP`` name -# so existing callers (``llm_router_service``, ``image_gen_router_service``, -# tests) keep working unchanged. 
+# Re-exported under the historical name ``PROVIDER_MAP``. Source of truth lives +# in provider_capabilities so the YAML loader can resolve prefixes during +# app.config init without importing the agent/tools tree. from app.services.provider_capabilities import ( # noqa: E402 _PROVIDER_PREFIX_MAP as PROVIDER_MAP, ) @@ -157,25 +151,14 @@ class AgentConfig: anonymous_enabled: bool = False quota_reserve_tokens: int | None = None - # Capability flag: best-effort True for the chat selector / catalog. - # Resolved via :func:`provider_capabilities.derive_supports_image_input` - # which prefers OpenRouter's ``architecture.input_modalities`` and - # otherwise consults LiteLLM's authoritative model map. Default True - # is the conservative-allow stance — the streaming-task safety net - # (``is_known_text_only_chat_model``) is the *only* place a False - # actually blocks a request. Setting this to False here without an - # authoritative source would silently hide vision-capable models - # (the regression we're fixing). + # Selector/catalog capability flag, resolved via derive_supports_image_input. + # Default-allow: a False without an authoritative source would silently hide + # vision-capable models; only is_known_text_only_chat_model blocks requests. supports_image_input: bool = True @classmethod def from_auto_mode(cls) -> "AgentConfig": - """ - Create an AgentConfig for Auto mode (LiteLLM Router load balancing). - - Returns: - AgentConfig instance configured for Auto mode - """ + """Build an AgentConfig for Auto mode (LiteLLM Router load balancing).""" return cls( provider="AUTO", model_name="auto", @@ -193,27 +176,15 @@ class AgentConfig: is_premium=False, anonymous_enabled=False, quota_reserve_tokens=None, - # Auto routes across the configured pool, which usually - # contains at least one vision-capable deployment; the router - # will surface a 404 from a non-vision deployment as a normal - # ``allowed_fails`` event and fail over rather than blocking - # the request outright.
+ # Auto fails over across the pool, so a non-vision deployment's 404 + # is just an allowed_fails event rather than a hard block. supports_image_input=True, ) @classmethod def from_new_llm_config(cls, config) -> "AgentConfig": - """ - Create an AgentConfig from a NewLLMConfig database model. - - Args: - config: NewLLMConfig database model instance - - Returns: - AgentConfig instance - """ - # Lazy import to avoid pulling provider_capabilities (and its - # transitive litellm import) into module-init order. + """Build an AgentConfig from a NewLLMConfig database model.""" + # Lazy import: keeps provider_capabilities (and litellm) out of init order. from app.services.provider_capabilities import derive_supports_image_input provider_value = ( @@ -245,10 +216,8 @@ class AgentConfig: is_premium=False, anonymous_enabled=False, quota_reserve_tokens=None, - # BYOK rows have no operator-curated capability flag, so we - # ask LiteLLM (default-allow on unknown). The streaming - # safety net still blocks if the model is *explicitly* - # marked text-only. + # BYOK rows have no curated flag; ask LiteLLM (default-allow on + # unknown). The streaming safety net still blocks explicit text-only. supports_image_input=derive_supports_image_input( provider=provider_value, model_name=config.model_name, @@ -259,25 +228,14 @@ class AgentConfig: @classmethod def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig": + """Build an AgentConfig from a YAML configuration dictionary. + + Supports the same prompt fields as NewLLMConfig (system_instructions, + use_default_system_instructions, citations_enabled). """ - Create an AgentConfig from a YAML configuration dictionary. 
- - YAML configs now support the same prompt configuration fields as NewLLMConfig: - - system_instructions: Custom system instructions (empty string uses defaults) - - use_default_system_instructions: Whether to use default instructions - - citations_enabled: Whether citations are enabled - - Args: - yaml_config: Configuration dictionary from YAML file - - Returns: - AgentConfig instance - """ - # Lazy import to avoid pulling provider_capabilities (and its - # transitive litellm import) into module-init order. + # Lazy import: keeps provider_capabilities (and litellm) out of init order. from app.services.provider_capabilities import derive_supports_image_input - # Get system instructions from YAML, default to empty string system_instructions = yaml_config.get("system_instructions", "") provider = yaml_config.get("provider", "").upper() @@ -290,13 +248,8 @@ class AgentConfig: else None ) - # Explicit YAML override wins; otherwise derive from LiteLLM / - # OpenRouter modalities. The YAML loader already populates this - # field, but this method is also called from - # ``load_global_llm_config_by_id``'s file fallback (hot reload), - # so we re-derive here for safety. The bool() coercion preserves - # the loader's behaviour for explicit ``true`` / ``false`` - # strings that PyYAML may surface. + # Explicit YAML override wins; otherwise re-derive (the hot-reload file + # fallback reaches this method without the loader having populated it). 
if "supports_image_input" in yaml_config: supports_image_input = bool(yaml_config.get("supports_image_input")) else: @@ -314,7 +267,6 @@ class AgentConfig: api_base=yaml_config.get("api_base"), custom_provider=custom_provider, litellm_params=yaml_config.get("litellm_params"), - # Prompt configuration from YAML (with defaults for backwards compatibility) system_instructions=system_instructions if system_instructions else None, use_default_system_instructions=yaml_config.get( "use_default_system_instructions", True @@ -332,20 +284,10 @@ class AgentConfig: def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None: - """ - Load a specific LLM config from global_llm_config.yaml. - - Args: - llm_config_id: The id of the config to load (default: -1) - - Returns: - LLM config dict or None if not found - """ - # Get the config file path + """Load a specific LLM config from global_llm_config.yaml.""" base_dir = Path(__file__).resolve().parent.parent.parent.parent config_file = base_dir / "app" / "config" / "global_llm_config.yaml" - # Fallback to example file if main config doesn't exist if not config_file.exists(): config_file = base_dir / "app" / "config" / "global_llm_config.example.yaml" if not config_file.exists(): @@ -368,24 +310,17 @@ def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None: def load_global_llm_config_by_id(llm_config_id: int) -> dict | None: - """ - Load a global LLM config by ID, checking in-memory configs first. + """Load a global LLM config by ID, checking in-memory configs first. - This handles both static YAML configs and dynamically injected configs - (e.g. OpenRouter integration models that only exist in memory). - - Args: - llm_config_id: The negative ID of the global config to load - - Returns: - LLM config dict or None if not found + In-memory covers both static YAML and dynamically injected configs (e.g. + OpenRouter integration models that only exist in memory). 
""" from app.config import config as app_config for cfg in app_config.GLOBAL_LLM_CONFIGS: if cfg.get("id") == llm_config_id: return cfg - # Fallback to YAML file read (covers edge cases like hot-reload) + # Fallback to YAML file read (covers hot-reload edge cases). return load_llm_config_from_yaml(llm_config_id) @@ -393,17 +328,7 @@ async def load_new_llm_config_from_db( session: AsyncSession, config_id: int, ) -> "AgentConfig | None": - """ - Load a NewLLMConfig from the database by ID. - - Args: - session: AsyncSession for database access - config_id: The ID of the NewLLMConfig to load - - Returns: - AgentConfig instance or None if not found - """ - # Import here to avoid circular imports + """Load a NewLLMConfig from the database by ID.""" from app.db import NewLLMConfig try: @@ -426,26 +351,13 @@ async def load_agent_llm_config_for_search_space( session: AsyncSession, search_space_id: int, ) -> "AgentConfig | None": + """Load the agent LLM config for a search space via its agent_llm_id. + + Positive id -> DB; negative -> YAML; None -> first global config (-1). """ - Load the agent LLM configuration for a search space. 
- - This loads the LLM config based on the search space's agent_llm_id setting: - - Positive ID: Load from NewLLMConfig database table - - Negative ID: Load from YAML global configs - - None: Falls back to first global config (id=-1) - - Args: - session: AsyncSession for database access - search_space_id: The search space ID - - Returns: - AgentConfig instance or None if not found - """ - # Import here to avoid circular imports from app.db import SearchSpace try: - # Get the search space to check its agent_llm_id preference result = await session.execute( select(SearchSpace).filter(SearchSpace.id == search_space_id) ) @@ -455,12 +367,9 @@ async def load_agent_llm_config_for_search_space( print(f"Error: SearchSpace with id {search_space_id} not found") return None - # Use agent_llm_id from search space, fallback to -1 (first global config) config_id = ( search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 ) - - # Load the config using the unified loader return await load_agent_config(session, config_id, search_space_id) except Exception as e: print(f"Error loading agent LLM config for search space {search_space_id}: {e}") @@ -472,23 +381,7 @@ async def load_agent_config( config_id: int, search_space_id: int | None = None, ) -> "AgentConfig | None": - """ - Load an agent configuration, supporting Auto mode, YAML, and database configs. 
- - This is the main entry point for loading configurations: - - ID 0: Auto mode (uses LiteLLM Router for load balancing) - - Negative IDs: Load from YAML file (global configs) - - Positive IDs: Load from NewLLMConfig database table - - Args: - session: AsyncSession for database access - config_id: The config ID (0 for Auto, negative for YAML, positive for database) - search_space_id: Optional search space ID for context - - Returns: - AgentConfig instance or None if not found - """ - # Auto mode (ID 0) - use LiteLLM Router + """Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB.""" if is_auto_mode(config_id): if not LLMRouterService.is_initialized(): print("Error: Auto mode requested but LLM Router not initialized") @@ -496,33 +389,22 @@ async def load_agent_config( return AgentConfig.from_auto_mode() if config_id < 0: - # Check in-memory configs first (includes static YAML + dynamic OpenRouter) + # In-memory covers static YAML + dynamic OpenRouter configs. from app.config import config as app_config for cfg in app_config.GLOBAL_LLM_CONFIGS: if cfg.get("id") == config_id: return AgentConfig.from_yaml_config(cfg) - # Fallback to YAML file read for safety yaml_config = load_llm_config_from_yaml(config_id) if yaml_config: return AgentConfig.from_yaml_config(yaml_config) return None else: - # Load from database (NewLLMConfig) return await load_new_llm_config_from_db(session, config_id) def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: - """ - Create a ChatLiteLLM instance from a global LLM config dictionary. 
- - Args: - llm_config: LLM configuration dictionary from YAML - - Returns: - ChatLiteLLM instance or None on error - """ - # Build the model string + """Create a ChatLiteLLM instance from a global LLM config dictionary.""" if llm_config.get("custom_provider"): model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}" else: @@ -530,27 +412,20 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: provider_prefix = PROVIDER_MAP.get(provider, provider.lower()) model_string = f"{provider_prefix}/{llm_config['model_name']}" - # Create ChatLiteLLM instance with streaming enabled litellm_kwargs = { "model": model_string, "api_key": llm_config.get("api_key"), - "streaming": True, # Enable streaming for real-time token streaming + "streaming": True, } - - # Add optional parameters if llm_config.get("api_base"): litellm_kwargs["api_base"] = llm_config["api_base"] - - # Add any additional litellm parameters if llm_config.get("litellm_params"): litellm_kwargs.update(llm_config["litellm_params"]) llm = SanitizedChatLiteLLM(**litellm_kwargs) _attach_model_profile(llm, model_string) - # Configure LiteLLM-native prompt caching (cache_control_injection_points - # for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.). - # ``agent_config=None`` here — the YAML path doesn't have provider intent - # in a structured form, so we set only the universal injection points. + # agent_config=None: the YAML path lacks structured provider intent, so set + # only the universal cache_control_injection_points. apply_litellm_prompt_caching(llm) return llm @@ -558,19 +433,7 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: def create_chat_litellm_from_agent_config( agent_config: AgentConfig, ) -> ChatLiteLLM | ChatLiteLLMRouter | None: - """ - Create a ChatLiteLLM or ChatLiteLLMRouter instance from an AgentConfig. 
- - For Auto mode configs, returns a ChatLiteLLMRouter that uses LiteLLM Router - for automatic load balancing across available providers. - - Args: - agent_config: AgentConfig instance - - Returns: - ChatLiteLLM or ChatLiteLLMRouter instance, or None on error - """ - # Handle Auto mode - return ChatLiteLLMRouter + """Create a ChatLiteLLM (or, for Auto mode, a load-balancing router) from config.""" if agent_config.is_auto_mode: if not LLMRouterService.is_initialized(): print("Error: Auto mode requested but LLM Router not initialized") @@ -578,19 +441,14 @@ def create_chat_litellm_from_agent_config( try: router_llm = get_auto_mode_llm() if router_llm is not None: - # Universal cache_control_injection_points only — auto-mode - # fans out across providers, so OpenAI-only kwargs (e.g. - # ``prompt_cache_key``) are left off here. ``drop_params`` - # would strip them at the provider boundary anyway, but - # there's no point setting them when we don't know the - # destination. + # Universal injection points only: auto-mode fans out across + # providers, so provider-specific kwargs have no known target. 
apply_litellm_prompt_caching(router_llm, agent_config=agent_config) return router_llm except Exception as e: print(f"Error creating ChatLiteLLMRouter: {e}") return None - # Build the model string if agent_config.custom_provider: model_string = f"{agent_config.custom_provider}/{agent_config.model_name}" else: @@ -599,26 +457,19 @@ def create_chat_litellm_from_agent_config( ) model_string = f"{provider_prefix}/{agent_config.model_name}" - # Create ChatLiteLLM instance with streaming enabled litellm_kwargs = { "model": model_string, "api_key": agent_config.api_key, - "streaming": True, # Enable streaming for real-time token streaming + "streaming": True, } - - # Add optional parameters if agent_config.api_base: litellm_kwargs["api_base"] = agent_config.api_base - - # Add any additional litellm parameters if agent_config.litellm_params: litellm_kwargs.update(agent_config.litellm_params) llm = SanitizedChatLiteLLM(**litellm_kwargs) _attach_model_profile(llm, model_string) - # Build-time prompt caching: sets ``cache_control_injection_points`` for - # all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``. - # Per-thread ``prompt_cache_key`` is layered on later in - # ``create_surfsense_deep_agent`` once ``thread_id`` is known. + # Build-time caching only; the per-thread prompt_cache_key is layered on + # later in create_surfsense_deep_agent once thread_id is known. apply_litellm_prompt_caching(llm, agent_config=agent_config) return llm diff --git a/surfsense_backend/app/agents/chat/runtime/prompt_caching.py b/surfsense_backend/app/agents/chat/runtime/prompt_caching.py index da0007b1e..5a5fd7418 100644 --- a/surfsense_backend/app/agents/chat/runtime/prompt_caching.py +++ b/surfsense_backend/app/agents/chat/runtime/prompt_caching.py @@ -1,63 +1,28 @@ -r"""LiteLLM-native prompt caching configuration for SurfSense agents. +r"""LiteLLM-native prompt caching for SurfSense agents. 
-Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never -activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)`` -gate always failed) with LiteLLM's universal caching mechanism. +Replaces the legacy ``AnthropicPromptCachingMiddleware`` (its +``isinstance(model, ChatAnthropic)`` gate never matched our LiteLLM stack) +with LiteLLM's universal ``cache_control_injection_points`` mechanism, which +covers the Anthropic/Bedrock/Vertex/Gemini/OpenRouter/etc. marker-based +providers and the auto-caching OpenAI family. -Coverage: +Two breakpoints per request: -- Marker-based providers (need ``cache_control`` injection, which LiteLLM - performs automatically when ``cache_control_injection_points`` is set): - ``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``, - ``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/`` - (Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM). -- Auto-cached (LiteLLM strips the marker silently): ``openai/``, - ``deepseek/``, ``xai/`` — these caches automatically for prompts ≥1024 - tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``. +- ``index: 0`` pins the head-of-request system prompt. We use ``index: 0``, + NOT ``role: system``: ``before_agent`` injectors accumulate many + SystemMessages, and tagging all of them overflows Anthropic's 4-block cap + (upstream 400 via OpenRouter). +- ``index: -1`` pins the latest message so longest-prefix lookup compounds + multi-turn savings. -We inject **two** breakpoints per request: +OpenAI-family configs also get ``prompt_cache_key`` (per-thread routing hint) +and ``prompt_cache_retention="24h"``. Azure is excluded from the latter +because LiteLLM's Azure transformer drops it (see +``_PROMPT_CACHE_RETENTION_PROVIDERS``). -- ``index: 0`` — pins the SurfSense system prompt at the head of the - request (provider variant, citation rules, tool catalog, KB tree, - skills metadata). 
The langchain agent factory always prepends - ``request.system_message`` at index 0 (see ``factory.py`` - ``_execute_model_async``), so this targets exactly the main system - prompt regardless of how many other ``SystemMessage``\ s the - ``before_agent`` injectors (priority, tree, memory, file-intent, - anonymous-doc) have inserted into ``state["messages"]``. Using - ``role: system`` here would apply ``cache_control`` to **every** - system-role message and trip Anthropic's hard cap of 4 cache - breakpoints per request once the conversation accumulates enough - injected system messages — which surfaces as the upstream 400 - ``A maximum of 4 blocks with cache_control may be provided. Found N`` - via OpenRouter→Anthropic. -- ``index: -1`` — pins the latest message so multi-turn savings compound: - Anthropic-family providers use longest-matching-prefix lookup, so turn - N+1 still reads turn N's cache up to the shared prefix. - -For OpenAI-family configs we additionally pass: - -- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that - raises hit rate by sending requests with a shared prefix to the same - backend. Supported by ``openai/``, ``deepseek/``, ``xai/``, and - ``azure/`` (added to LiteLLM's Azure transformer in - https://github.com/BerriAI/litellm/pull/20989, Feb 2026; verified - against ``AzureOpenAIConfig.get_supported_openai_params`` in our - installed litellm 1.83.14 for ``azure/gpt-4o``, ``azure/gpt-4o-mini``, - ``azure/gpt-5.4``, ``azure/gpt-5.4-mini``). -- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default - 5-10 min in-memory cache. Set ONLY for OpenAI/DeepSeek/xAI: Azure's - server-side support landed in Microsoft's docs on 2026-05-13 but - LiteLLM 1.83.14's Azure transformer still omits it from its supported - params list, so it gets silently dropped by ``litellm.drop_params``. - Azure's default in-memory retention (5-10 min, max 1 h) already - bridges intra-conversation turns; revisit when LiteLLM bumps Azure. 
- -Safety net: ``litellm.drop_params=True`` is set globally in -``app.services.llm_service`` at module-load time. Any kwarg the destination -provider doesn't recognise is auto-stripped at the provider transformer -layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on -``prompt_cache_key`` etc. +Safety net: ``litellm.drop_params=True`` (set in ``app.services.llm_service``) +strips any kwarg the destination provider rejects, so an auto-mode fallback +can't 400 on these extras. """ from __future__ import annotations @@ -73,57 +38,29 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -# Two-breakpoint policy: head-of-request + latest message. See module -# docstring for rationale. Anthropic caps requests at 4 ``cache_control`` -# blocks; we use 2 here, leaving headroom for Phase-2 tool caching. -# -# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's -# ``before_agent`` middlewares (priority, tree, memory, anonymous-doc) -# insert ``SystemMessage`` instances into ``state["messages"]`` that -# accumulate across turns. With ``role: system`` the LiteLLM hook would -# tag *every* one of them with ``cache_control`` and overflow Anthropic's -# 4-block limit. ``index: 0`` always targets the langchain-prepended -# ``request.system_message``, giving us exactly one stable cache breakpoint. +# Head-of-request + latest message (see module docstring for the index:0 vs +# role:system rationale and Anthropic's 4-block cap). _DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = ( {"location": "message", "index": 0}, {"location": "message", "index": -1}, ) -# Providers (uppercase ``AgentConfig.provider`` values) that accept the -# OpenAI ``prompt_cache_key`` routing hint. 
Microsoft's Azure OpenAI docs -# (2026-05-13) confirm automatic prompt caching applies to every GPT-4o -# or newer Azure deployment at ≥1024 tokens with no configuration needed, -# and that ``prompt_cache_key`` is combined with the prefix hash to -# improve routing affinity and therefore cache hit rate. LiteLLM's Azure -# transformer ships ``prompt_cache_key`` in its supported params as of -# https://github.com/BerriAI/litellm/pull/20989. -# -# Strict whitelist — many other providers in ``PROVIDER_MAP`` route -# through litellm's ``openai`` prefix without implementing the OpenAI -# prompt-cache surface (e.g. MOONSHOT, ZHIPU, MINIMAX), so we can't infer -# family from the litellm prefix alone. +# Providers that accept the OpenAI ``prompt_cache_key`` routing hint. Strict +# whitelist: many providers route through litellm's ``openai`` prefix without +# the prompt-cache surface, so the prefix alone isn't enough to infer family. _PROMPT_CACHE_KEY_PROVIDERS: frozenset[str] = frozenset( {"OPENAI", "DEEPSEEK", "XAI", "AZURE", "AZURE_OPENAI"} ) -# Subset of ``_PROMPT_CACHE_KEY_PROVIDERS`` that also accept -# ``prompt_cache_retention="24h"``. Azure is excluded: see module -# docstring — LiteLLM 1.83.14's Azure transformer omits the param so -# ``drop_params`` silently strips it. Re-add Azure once a future LiteLLM -# release wires it into ``AzureOpenAIConfig.get_supported_openai_params``. +# Subset that also accepts ``prompt_cache_retention="24h"``. Azure is excluded +# because LiteLLM's Azure transformer omits the param (drop_params strips it). _PROMPT_CACHE_RETENTION_PROVIDERS: frozenset[str] = frozenset( {"OPENAI", "DEEPSEEK", "XAI"} ) def _is_router_llm(llm: BaseChatModel) -> bool: - """Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import. - - Importing ``app.services.llm_router_service`` at module-load time would - create a cycle via ``llm_config -> prompt_caching -> llm_router_service``. 
- Class-name comparison is sufficient since the class is defined in a - single place. - """ + """Detect ``ChatLiteLLMRouter`` by class name to avoid an import cycle.""" return type(llm).__name__ == "ChatLiteLLMRouter" @@ -188,21 +125,10 @@ def apply_litellm_prompt_caching( ) -> None: """Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter. - Idempotent — values already present in ``llm.model_kwargs`` (e.g. from - ``agent_config.litellm_params`` overrides) are preserved. Mutates - ``llm.model_kwargs`` in place; the kwargs flow to ``litellm.completion`` - via ``ChatLiteLLM._default_params`` and via ``self.model_kwargs`` merge - in our custom ``ChatLiteLLMRouter``. - - Args: - llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance. - agent_config: Optional ``AgentConfig`` driving provider-specific - behaviour. When omitted (or auto-mode), only the universal - ``cache_control_injection_points`` are set. - thread_id: Optional thread id used to construct a per-thread - ``prompt_cache_key`` for OpenAI-family providers. Caching still - works without it (server-side automatic), but the key improves - backend routing affinity and therefore hit rate. + Idempotent (existing ``model_kwargs`` values are preserved) and mutates + ``llm.model_kwargs`` in place. Without ``agent_config`` (or in auto-mode) + only the universal injection points are set; ``thread_id`` adds a per-thread + ``prompt_cache_key`` for OpenAI-family providers to improve routing affinity. """ model_kwargs = _get_or_init_model_kwargs(llm) if model_kwargs is None: @@ -217,11 +143,8 @@ def apply_litellm_prompt_caching( dict(point) for point in _DEFAULT_INJECTION_POINTS ] - # OpenAI-style extras only when we statically know the destination - # accepts them. Auto-mode router fans out across mixed providers so - # we can't safely set destination-specific kwargs there (drop_params - # would strip them but it's wasteful to set them in the first - # place). 
+ # OpenAI-style extras only when the destination is statically known. The + # auto-mode router fans out across mixed providers, so skip them there. if _is_router_llm(llm): return diff --git a/surfsense_backend/app/agents/chat/shared/middleware/compaction.py b/surfsense_backend/app/agents/chat/shared/middleware/compaction.py index 6a533be6b..f91af6a70 100644 --- a/surfsense_backend/app/agents/chat/shared/middleware/compaction.py +++ b/surfsense_backend/app/agents/chat/shared/middleware/compaction.py @@ -1,26 +1,13 @@ -""" -SurfSense compaction middleware. +"""SurfSense compaction middleware. -Subclasses :class:`deepagents.middleware.summarization.SummarizationMiddleware` -to add SurfSense-specific behavior: +Extends ``SummarizationMiddleware`` with three SurfSense behaviors: -1. **Structured summary template** (OpenCode-style ``## Goal / Constraints / - Progress / Key Decisions / Next Steps / Critical Context / Relevant Files``) - — see :data:`SURFSENSE_SUMMARY_PROMPT` below. The base - ``SummarizationMiddleware`` only ships a freeform "summarize this" - prompt; the structured template is ported from OpenCode's - ``compaction.ts``. -2. **Protect SurfSense-specific SystemMessages** so injected hints - (````, ````, ````, - ````, ````, ````, ````) - are *not* summarized away and are kept verbatim in the post-summary - message list. Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy - (some message types are part of the agent's contract and must survive - compaction unchanged). -3. **Sanitize ``content=None``** when feeding messages into ``get_buffer_string`` - (Azure OpenAI / LiteLLM defense — when a provider streams an AIMessage - containing only tool_calls and no text, ``content`` can be ``None`` and - ``get_buffer_string`` crashes iterating over ``None``). SurfSense-specific. +1. A structured summary template (:data:`SURFSENSE_SUMMARY_PROMPT`) instead of + the base freeform prompt. +2. 
Protected SystemMessages (injected hints like ````) are + kept verbatim instead of being summarized away. +3. ``content=None`` is sanitized before ``get_buffer_string`` (some providers + stream tool-only AIMessages with ``None`` content, which would crash it). """ from __future__ import annotations @@ -43,9 +30,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -# Structured summary template ported from OpenCode's -# ``opencode/packages/opencode/src/session/compaction.ts:40-75``. Kept as a -# module-level constant so unit tests can assert on its sections. +# Module-level constant so unit tests can assert on its sections. SURFSENSE_SUMMARY_PROMPT = """ SurfSense Conversation Compaction Assistant @@ -114,13 +99,10 @@ def _is_protected_system_message(msg: AnyMessage) -> bool: def _sanitize_message_content(msg: AnyMessage) -> AnyMessage: - """Return ``msg`` with ``content=None`` coerced to ``""``. + """Return a copy of ``msg`` with ``content=None`` coerced to ``""``. - Folds in the historical defense from ``safe_summarization.py`` — - ``get_buffer_string`` reads ``m.text`` which iterates ``self.content``, - so a ``None`` content (Azure OpenAI / LiteLLM streaming a tool-only - AIMessage) explodes. We return a copy with empty string content so - downstream consumers see an empty body without mutating the original. + ``get_buffer_string`` reads ``m.text`` (iterating ``content``), so a + tool-only AIMessage with ``None`` content would crash it. """ if getattr(msg, "content", "not-missing") is not None: return msg @@ -159,20 +141,11 @@ class SurfSenseCompactionMiddleware(SummarizationMiddleware): conversation_messages: list[AnyMessage], cutoff_index: int, ) -> tuple[list[AnyMessage], list[AnyMessage]]: - """Split messages but always preserve SurfSense protected SystemMessages. + """Split messages, always preserving protected SystemMessages. 
- Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy - (``opencode/packages/opencode/src/session/compaction.ts``): some - message types are always kept verbatim because they are part of the - agent's working contract, not transient output. - - Also opens a ``compaction.run`` OTel span (no-op when OTel is off) - so dashboards can count compaction events and message-volume - without having to instrument upstream callers. + Also opens a ``compaction.run`` OTel span (no-op when OTel is off) here, + since partitioning is the first call once summarization is decided. """ - # Opening a span here is appropriate because partitioning is the - # first call SummarizationMiddleware makes when it has decided to - # summarize; we record the volume and then close as a normal span. with ot.compaction_span( reason="auto", messages_in=len(conversation_messages), @@ -191,20 +164,15 @@ class SurfSenseCompactionMiddleware(SummarizationMiddleware): else: kept_for_summary.append(msg) - # Place protected blocks at the *front* of preserved_messages so - # they keep their original ordering relative to the summary - # HumanMessage that precedes the rest of the preserved tail. + # Protected blocks go at the front of preserved_messages to keep + # ordering relative to the summary HumanMessage. return kept_for_summary, [*protected, *preserved_messages] def _filter_summary_messages( # type: ignore[override] self, messages: list[AnyMessage] ) -> list[AnyMessage]: - """Filter previous summaries AND sanitize ``content=None``. - - Folds the ``safe_summarization.py`` defense in: when the buffer - builder iterates ``m.text`` over ``None`` it explodes; sanitizing - here covers both the sync and async offload paths. 
- """ + """Filter previous summaries and sanitize ``content=None`` (covers the + sync and async offload paths).""" filtered = super()._filter_summary_messages(messages) return [_sanitize_message_content(m) for m in filtered] diff --git a/surfsense_backend/app/agents/podcaster/nodes.py b/surfsense_backend/app/agents/podcaster/nodes.py index 517d900a3..b9fee57e7 100644 --- a/surfsense_backend/app/agents/podcaster/nodes.py +++ b/surfsense_backend/app/agents/podcaster/nodes.py @@ -24,14 +24,11 @@ from .utils import get_voice_for_provider async def create_podcast_transcript( state: State, config: RunnableConfig ) -> dict[str, Any]: - """Each node does work.""" - - # Get configuration from runnable config + """Generate the podcast transcript from the source content.""" configuration = Configuration.from_runnable_config(config) search_space_id = configuration.search_space_id user_prompt = configuration.user_prompt - # Get search space's document summary LLM llm = await get_agent_llm(state.db_session, search_space_id) if not llm: error_message = ( @@ -40,22 +37,16 @@ async def create_podcast_transcript( print(error_message) raise RuntimeError(error_message) - # Get the prompt prompt = get_podcast_generation_prompt(user_prompt) - - # Create the messages messages = [ SystemMessage(content=prompt), HumanMessage( content=f"{state.source_content}" ), ] - - # Generate the podcast transcript llm_response = await llm.ainvoke(messages) - # Reasoning models (e.g. Kimi K2.5) may return content as a list of - # blocks including 'reasoning' entries. Normalise to a plain string. + # Reasoning models may return content as blocks; normalise to a string. 
content = strip_markdown_fences(extract_text_content(llm_response.content)) try: @@ -89,17 +80,13 @@ async def create_merged_podcast_audio( state: State, config: RunnableConfig ) -> dict[str, Any]: """Generate audio for each transcript and merge them into a single podcast file.""" - - # configuration = Configuration.from_runnable_config(config) - starting_transcript = PodcastTranscriptEntry( speaker_id=1, dialog="Welcome to Surfsense Podcast." ) transcript = state.podcast_transcript - # Merge the starting transcript with the podcast transcript - # Check if transcript is a PodcastTranscripts object or already a list + # transcript may be a PodcastTranscripts object or already a list. if hasattr(transcript, "podcast_transcripts"): transcript_entries = transcript.podcast_transcripts else: @@ -107,20 +94,16 @@ async def create_merged_podcast_audio( merged_transcript = [starting_transcript, *transcript_entries] - # Create a temporary directory for audio files temp_dir = Path("temp_audio") temp_dir.mkdir(exist_ok=True) - # Generate a unique session ID for this podcast session_id = str(uuid.uuid4()) output_path = f"podcasts/{session_id}_podcast.mp3" os.makedirs("podcasts", exist_ok=True) - # Generate audio for each transcript segment audio_files = [] async def generate_speech_for_segment(segment, index): - # Handle both dictionary and PodcastTranscriptEntry objects if hasattr(segment, "speaker_id"): speaker_id = segment.speaker_id dialog = segment.dialog @@ -128,20 +111,15 @@ async def create_merged_podcast_audio( speaker_id = segment.get("speaker_id", 0) dialog = segment.get("dialog", "") - # Select voice based on speaker_id voice = get_voice_for_provider(app_config.TTS_SERVICE, speaker_id) - # Generate a unique filename for this segment if app_config.TTS_SERVICE == "local/kokoro": - # Kokoro generates WAV files filename = f"{temp_dir}/{session_id}_{index}.wav" else: - # Other services generate MP3 files filename = f"{temp_dir}/{session_id}_{index}.mp3" try: if 
app_config.TTS_SERVICE == "local/kokoro": - # Use Kokoro TTS service kokoro_service = await get_kokoro_tts_service( lang_code="a" ) # American English @@ -170,7 +148,6 @@ async def create_merged_podcast_audio( timeout=600, ) - # Save the audio to a file - use proper streaming method with open(filename, "wb") as f: f.write(response.content) @@ -179,23 +156,17 @@ async def create_merged_podcast_audio( print(f"Error generating speech for segment {index}: {e!s}") raise - # Generate all audio files concurrently tasks = [ generate_speech_for_segment(segment, i) for i, segment in enumerate(merged_transcript) ] audio_files = await asyncio.gather(*tasks) - # Merge audio files using ffmpeg try: - # Create FFmpeg instance with the first input ffmpeg = FFmpeg().option("y") - - # Add each audio file as input for audio_file in audio_files: ffmpeg = ffmpeg.input(audio_file) - # Configure the concatenation and output filter_complex = [] for i in range(len(audio_files)): filter_complex.append(f"[{i}:0]") @@ -205,8 +176,6 @@ async def create_merged_podcast_audio( ) ffmpeg = ffmpeg.option("filter_complex", filter_complex_str) ffmpeg = ffmpeg.output(output_path, map="[outa]") - - # Execute FFmpeg await ffmpeg.execute() print(f"Successfully created podcast audio: {output_path}") @@ -215,7 +184,6 @@ async def create_merged_podcast_audio( print(f"Error merging audio files: {e!s}") raise finally: - # Clean up temporary files for audio_file in audio_files: try: os.remove(audio_file)