docs(agents): tighten docstrings and comments across agent module

Recursive pass over the agents module to make docstrings and inline
comments concise and intent-oriented: drop narration that just restates
the code, condense verbose module/function docstrings, and keep only the
non-obvious "why" notes. No functional code changed.
This commit is contained in:
CREDO23 2026-06-05 17:39:38 +02:00
parent 620c378254
commit a3d05f6418
16 changed files with 319 additions and 1055 deletions

View file

@ -1,25 +1,15 @@
"""Append-only action-log middleware for the SurfSense agent. """Append-only action-log middleware for the SurfSense agent.
Wraps every tool call via :meth:`AgentMiddleware.awrap_tool_call` and writes Wraps every tool call and writes a row to :class:`~app.db.AgentActionLog`
a row to :class:`~app.db.AgentActionLog` after the tool returns. Tools opt after the tool returns. Tools opt into reversibility via a ``reverse``
into reversibility by declaring a ``reverse`` callable on their callable on their :class:`ToolDefinition`; the rendered descriptor powers
:class:`ToolDefinition`; the rendered descriptor is persisted in
``reverse_descriptor`` for use by
``/api/threads/{thread_id}/revert/{action_id}``. ``/api/threads/{thread_id}/revert/{action_id}``.
Design points: Logging is fully defensive — DB-write failures are swallowed so the tool's
result is always returned untouched. Only metadata (name, capped args,
* **Defensive.** Logging never blocks the agent. We catch every exception result_id, reverse_descriptor) is stored; tool output stays in the
on the DB write path and emit a warning; the tool's ``ToolMessage`` checkpoint. Reversibility is best-effort: a reverse callable that raises
result is always returned untouched. just leaves the action non-reversible.
* **Lightweight payload.** Only the tool ``name`` + ``args`` (capped) +
``result_id`` + ``reverse_descriptor`` are stored. Tool output text
remains in the LangGraph checkpoint / spilled tool-output files.
* **Best-effort reversibility.** We invoke ``reverse(args, result_obj)``
with the parsed JSON result when the tool's content is a JSON object;
otherwise the raw text is passed. Exceptions in the reverse callable
are swallowed and logged — a failed descriptor render simply means the
action is NOT marked reversible.
""" """
from __future__ import annotations from __future__ import annotations
@ -203,11 +193,9 @@ class ActionLogMiddleware(AgentMiddleware):
) )
return return
# Surface a side-channel SSE event so the chat tool card can # Side-channel event (relayed by ``stream_new_chat`` as a
# render a Revert button immediately after the row is durable. # ``data-action-log`` SSE) so the tool card can show a Revert button
# ``stream_new_chat`` translates this into a # once the row is durable. Carries a presence flag, not the descriptor.
# ``data-action-log`` SSE event. We DO NOT include the
# ``reverse_descriptor`` payload here; only a presence flag.
try: try:
await adispatch_custom_event( await adispatch_custom_event(
"action_log", "action_log",

View file

@ -1,32 +1,12 @@
""" """Per-thread asyncio lock + cooperative cancel token, keyed by ``thread_id``.
BusyMutexMiddleware — per-thread asyncio lock + cancel token.
LangChain has no built-in concept of "this thread is already running a Refuses a second concurrent turn on the same thread (e.g. double-clicked
turn — refuse the second concurrent request". Without it, a user "send") that would otherwise race on the same checkpoint and duplicate tool
double-clicking "send" or refreshing the page mid-stream can spawn two calls. Also exposes a per-thread cancel event that long-running tools poll
turns racing on the same checkpoint, producing duplicated tool calls via ``runtime.context.cancel_event.is_set()`` to abort cooperatively.
and mangled state.
Ported from OpenCode's ``Stream.scoped(AbortController)`` pattern: a Process-local and in-memory; multi-worker deployments need a distributed lock
single-process, in-memory lock + cooperative cancellation token keyed by (Redis / PostgreSQL advisory locks) as a follow-up.
``thread_id``. For multi-worker deployments a distributed lock backend
(Redis or PostgreSQL advisory locks) is a phase-2 follow-up.
What this provides:
- A ``WeakValueDictionary[str, asyncio.Lock]`` keyed by ``thread_id``;
acquiring the lock during ``before_agent`` blocks any concurrent
prompt on the same thread until release.
- A per-thread ``asyncio.Event`` (``cancel_event``) that long-running
tools can poll to abort cooperatively. The event is reset between
turns. Tools should check ``runtime.context.cancel_event.is_set()``
in tight inner loops.
- A typed :class:`~app.agents.chat.runtime.errors.BusyError` raised when a
second turn arrives while the lock is held.
Note: SurfSense's ``stream_new_chat`` is the call site that should
acquire/release. Wiring this as middleware means the contract is
explicit and the lock manager is shared with subagents that compile
their own ``create_agent`` runnables.
""" """
from __future__ import annotations from __future__ import annotations
@ -152,9 +132,8 @@ class _ThreadLockManager:
return True return True
# Module-level singleton — process-local but reused across all agent # Process-local singleton shared across all agents/subagents built in this
# instances built in this process. Subagents created in nested # process so per-thread locks stay coherent.
# ``create_agent`` calls also get this so locks are coherent.
manager = _ThreadLockManager() manager = _ThreadLockManager()
@ -266,7 +245,6 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
await lock.acquire() await lock.acquire()
epoch = manager.bump_turn_epoch(thread_id) epoch = manager.bump_turn_epoch(thread_id)
self._held_locks[thread_id] = (lock, epoch) self._held_locks[thread_id] = (lock, epoch)
# Reset the cancel event so this turn starts fresh
reset_cancel(thread_id) reset_cancel(thread_id)
return None return None
@ -289,17 +267,14 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
return None return None
if lock.locked(): if lock.locked():
lock.release() lock.release()
# Always clear cancel event between turns so a stale signal # Clear cancel event so a stale signal doesn't leak into the next turn.
# doesn't leak into the next request.
reset_cancel(thread_id) reset_cancel(thread_id)
return None return None
# Provide sync no-ops because the middleware base class allows them
def before_agent( # type: ignore[override] def before_agent( # type: ignore[override]
self, state: AgentState[Any], runtime: Runtime[ContextT] self, state: AgentState[Any], runtime: Runtime[ContextT]
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
# Sync path: no asyncio.Lock to acquire. Best we can do is reject # Sync path can't await an asyncio.Lock; only reject if one is in flight.
# if anyone else is in flight.
thread_id = self._thread_id(runtime) thread_id = self._thread_id(runtime)
if thread_id is None: if thread_id is None:
if self._require_thread_id: if self._require_thread_id:

View file

@ -82,13 +82,10 @@ _T = TypeVar("_T")
async def _ainvoke_with_timeout[T]( async def _ainvoke_with_timeout[T](
coro: Awaitable[_T], *, subagent_type: str, started_at: float coro: Awaitable[_T], *, subagent_type: str, started_at: float
) -> _T: ) -> _T:
"""Apply :data:`DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS` to ``coro``. """Apply the subagent invoke timeout to ``coro`` (non-positive disables it).
A non-positive timeout disables the cap (configurable via the On expiry the task is cancelled and :class:`SubagentInvokeTimeoutError` is
``SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS`` env var). On expiry the raised for the caller to turn into a synthetic ToolMessage.
underlying task is cancelled and :class:`SubagentInvokeTimeoutError` is
raised — the caller wraps it into a synthetic ToolMessage so the
orchestrator can decide what to do.
""" """
timeout = DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS timeout = DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS
if timeout <= 0: if timeout <= 0:
@ -151,12 +148,9 @@ def build_task_tool_with_parent_config(
subagent_graphs: dict[str, Runnable] = { subagent_graphs: dict[str, Runnable] = {
spec["name"]: spec["runnable"] for spec in subagents spec["name"]: spec["runnable"] for spec in subagents
} }
# Per-subagent context-hint providers (see ``SurfSenseSubagentSpec``). # Sparse map of opt-in context-hint providers; each runs once per task()
# The mapping is sparse: only routes that opted in via ``pack_subagent`` # call to prepend a string to the subagent's first HumanMessage. Failures
# appear here, and the value is invoked once per ``task(...)`` call to # are swallowed so a broken hint never blocks the task.
# generate a short string prepended to the subagent's first
# ``HumanMessage``. Failures are logged and swallowed — a broken hint
# provider must never prevent the underlying task from running.
subagent_hint_providers: dict[str, ContextHintProvider] = { subagent_hint_providers: dict[str, ContextHintProvider] = {
spec["name"]: provider spec["name"]: provider
for spec in subagents for spec in subagents
@ -178,24 +172,18 @@ def build_task_tool_with_parent_config(
def _billable_call_update( def _billable_call_update(
subagent_type: str, runtime: ToolRuntime subagent_type: str, runtime: ToolRuntime
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Build the per-call ``billable_calls`` delta + an optional warning. """Build the per-call ``billable_calls`` delta plus an optional soft-cap warning.
The orchestrator's ``billable_calls`` map is summed by Always emits ``{subagent_type: 1}`` (a reducer accumulates it); when this
:func:`_int_counter_merge_reducer`, so we always emit call would cross the threshold, also adds a soft ``messages`` entry so the
``{subagent_type: 1}`` and let the reducer accumulate. If the orchestrator self-limits on its next step.
cumulative count *after* this call would cross the configured
threshold, we also slip a soft ``messages`` entry into the update
so the orchestrator can read it on its next step and self-limit.
Returning a plain ``dict`` (vs. an extra :class:`Command`) keeps
the helper composable with the existing single/batch return paths.
""" """
delta: dict[str, Any] = {"billable_calls": {subagent_type: 1}} delta: dict[str, Any] = {"billable_calls": {subagent_type: 1}}
threshold = DEFAULT_SUBAGENT_BILLABLE_THRESHOLD threshold = DEFAULT_SUBAGENT_BILLABLE_THRESHOLD
if threshold <= 0: if threshold <= 0:
return delta return delta
prior = runtime.state.get("billable_calls") or {} prior = runtime.state.get("billable_calls") or {}
# ``prior`` may be a plain dict or a reducer-managed mapping; only # Count int values only so a malformed checkpoint can't crash us.
# int values are counted so a malformed checkpoint can't crash us.
prior_total = sum(v for v in prior.values() if isinstance(v, int)) prior_total = sum(v for v in prior.values() if isinstance(v, int))
new_total = prior_total + 1 new_total = prior_total + 1
if prior_total < threshold <= new_total: if prior_total < threshold <= new_total:
@ -214,8 +202,7 @@ def build_task_tool_with_parent_config(
"""Merge the per-call billable counter (and warning) into ``cmd``.""" """Merge the per-call billable counter (and warning) into ``cmd``."""
delta = _billable_call_update(subagent_type, runtime) delta = _billable_call_update(subagent_type, runtime)
warn_text = delta.pop("_billable_warn_text", None) warn_text = delta.pop("_billable_warn_text", None)
# ``cmd.update`` may be a dict or LangGraph ``UpdateDict``; defensively # Copy so we don't mutate state shared with other tool returns.
# copy so we don't mutate state shared across other tool returns.
update = dict(getattr(cmd, "update", {}) or {}) update = dict(getattr(cmd, "update", {}) or {})
for key, value in delta.items(): for key, value in delta.items():
update[key] = value update[key] = value
@ -228,14 +215,10 @@ def build_task_tool_with_parent_config(
return Command(update=update) return Command(update=update)
def _safe_message_text(msg: Any) -> str: def _safe_message_text(msg: Any) -> str:
"""Pull text out of a BaseMessage without trusting the ``.text`` property. """Pull text out of a BaseMessage without using the ``.text`` property.
``BaseMessage.text`` walks ``content_blocks`` and crashes with ``.text`` crashes when ``content`` is ``None`` (common for tool-call
``TypeError: 'NoneType' object is not iterable`` when ``content`` is AIMessages), and ``getattr`` won't catch it, so read ``content`` directly.
``None`` (common for tool-call AIMessages whose payload is purely
structured). ``getattr(msg, "text", None)`` does not catch this
because Python evaluates the property body before falling back to
the default. Read ``content`` directly and coerce defensively.
""" """
try: try:
content = getattr(msg, "content", None) content = getattr(msg, "content", None)
@ -258,23 +241,18 @@ def build_task_tool_with_parent_config(
return str(content) return str(content)
def _build_tool_trace(messages: list[Any]) -> list[dict[str, Any]]: def _build_tool_trace(messages: list[Any]) -> list[dict[str, Any]]:
"""Compress the subagent's message stream into a compact tool trace. """Compress the subagent's messages into a compact tool trace.
Each entry is ``{"tool": <name>, "status": "ok"|"error", "preview": Entries (``{tool, status, preview}``) ride on the ToolMessage's
<120 chars>}`` so the orchestrator can show "this is what your ``additional_kwargs["surf_tool_trace"]`` for UI/observability; the LLM
specialist actually did" without dumping the full message stream never sees them.
back through the prompt. The list is attached to the returned
ToolMessage's ``additional_kwargs`` (under ``"surf_tool_trace"``);
the LLM never sees it, but UI / observability code can pluck it
out of the checkpoint.
""" """
trace: list[dict[str, Any]] = [] trace: list[dict[str, Any]] = []
for msg in messages: for msg in messages:
tool_name = getattr(msg, "name", None) tool_name = getattr(msg, "name", None)
tool_call_id_attr = getattr(msg, "tool_call_id", None) tool_call_id_attr = getattr(msg, "tool_call_id", None)
if not tool_name and not tool_call_id_attr: if not tool_name and not tool_call_id_attr:
# Only ToolMessages have either field; skip AIMessage / # Only ToolMessages carry either field.
# HumanMessage / SystemMessage frames.
continue continue
status = getattr(msg, "status", None) or "ok" status = getattr(msg, "status", None) or "ok"
preview = _safe_message_text(msg).strip().replace("\n", " ") preview = _safe_message_text(msg).strip().replace("\n", " ")
@ -308,8 +286,7 @@ def build_task_tool_with_parent_config(
) )
raise ValueError(msg) raise ValueError(msg)
message_text = _safe_message_text(messages[-1]).rstrip() message_text = _safe_message_text(messages[-1]).rstrip()
# Tool-trace is purely observability — wrap defensively so a single # Trace is observability-only; never let a bad frame kill the turn.
# malformed frame never bubbles up and kills the whole user turn.
try: try:
tool_trace = _build_tool_trace(messages) tool_trace = _build_tool_trace(messages)
except Exception: except Exception:
@ -320,10 +297,7 @@ def build_task_tool_with_parent_config(
tool_trace = [] tool_trace = []
tool_msg = ToolMessage(message_text, tool_call_id=tool_call_id) tool_msg = ToolMessage(message_text, tool_call_id=tool_call_id)
if tool_trace: if tool_trace:
# ``additional_kwargs`` is a free-form dict on BaseMessage; using # surf_ prefix avoids collision with provider keys (e.g. cache_control).
# a ``surf_`` prefix avoids collision with provider-specific keys
# (e.g. Anthropic's ``cache_control``). The LLM doesn't see it;
# consumers (UI, observability) read it off the checkpoint.
tool_msg.additional_kwargs["surf_tool_trace"] = tool_trace tool_msg.additional_kwargs["surf_tool_trace"] = tool_trace
return Command( return Command(
update={ update={
@ -361,9 +335,7 @@ def build_task_tool_with_parent_config(
} }
hint = _resolve_context_hint(subagent_type, description, runtime) hint = _resolve_context_hint(subagent_type, description, runtime)
if hint: if hint:
# Prepend as a tagged block so the subagent prompt can pattern-match # Tagged block so the subagent prompt can pattern-match the section.
# on the section (and a future change can lift it into its own
# ``SystemMessage`` if needed).
payload = f"<context_hint>\n{hint}\n</context_hint>\n\n{description}" payload = f"<context_hint>\n{hint}\n</context_hint>\n\n{description}"
else: else:
payload = description payload = description
@ -374,16 +346,12 @@ def build_task_tool_with_parent_config(
results: list[tuple[int, str, dict | str, dict | None]], results: list[tuple[int, str, dict | str, dict | None]],
runtime: ToolRuntime, runtime: ToolRuntime,
) -> Command: ) -> Command:
"""Combine per-child results into one Command with a combined ToolMessage. """Combine per-child results into one Command with an aggregate ToolMessage.
``results`` is a list of ``(task_index, subagent_type, ``results`` tuples are ``(task_index, subagent_type, payload_or_error,
payload_or_error_text, child_state_update)`` tuples preserving the child_state_update)``; output blocks are sorted by index so the LLM can
input order so the orchestrator can map each block back to the task map them back to dispatch order, and each child contributes a
it dispatched. State updates are merged by reducer for keys outside ``billable_calls`` increment to match single-mode accounting.
:data:`EXCLUDED_STATE_KEYS`; everything else (``messages``, ``todos``,
etc.) is replaced by the synthesized aggregate ToolMessage. Every
child also contributes a ``billable_calls`` increment so cost
accounting matches single-mode dispatch.
""" """
results.sort(key=lambda r: r[0]) results.sort(key=lambda r: r[0])
merged_state: dict[str, Any] = {} merged_state: dict[str, Any] = {}
@ -424,8 +392,8 @@ def build_task_tool_with_parent_config(
} }
) )
if state_update: if state_update:
# Naive merge: later tasks win on scalar collisions; reducer-backed # Later tasks win on scalar collisions; reducer-backed fields
# fields (``receipts``, ``files`` etc.) accumulate at apply time. # accumulate at apply time.
merged_state.update(state_update) merged_state.update(state_update)
aggregate = "\n\n".join(message_blocks) aggregate = "\n\n".join(message_blocks)
aggregate_msg = ToolMessage( aggregate_msg = ToolMessage(
@ -469,11 +437,9 @@ def build_task_tool_with_parent_config(
) -> tuple[int, str, dict | str, dict | None]: ) -> tuple[int, str, dict | str, dict | None]:
"""Run one child of a batched ``task`` call under the concurrency cap. """Run one child of a batched ``task`` call under the concurrency cap.
Errors are returned as plain text in slot 2 so a single child's Errors are returned as text (slot 2) so one child's failure doesn't abort
failure does not abort the whole batch. ``GraphInterrupt`` from a the batch. A child's ``GraphInterrupt`` is a hard failure for that child:
batched child is currently treated as a hard failure for that child batched HITL is intentionally out of scope.
only — batched HITL is intentionally out of scope for the v1
rollout (see plan tier 2 item 4 risks).
""" """
async with semaphore: async with semaphore:
if subagent_type not in subagent_graphs: if subagent_type not in subagent_graphs:
@ -507,8 +473,7 @@ def build_task_tool_with_parent_config(
) )
return (task_index, subagent_type, str(exc), None) return (task_index, subagent_type, str(exc), None)
except GraphInterrupt: except GraphInterrupt:
# Batched HITL is unsupported in v1 — surface as a failure # Batched HITL unsupported; fail this child so the batch finishes.
# for this child so the rest of the batch still completes.
logger.warning( logger.warning(
"Batch child %d (%s) raised GraphInterrupt; batched HITL " "Batch child %d (%s) raised GraphInterrupt; batched HITL "
"is not supported. Re-dispatch this task as a single " "is not supported. Re-dispatch this task as a single "
@ -545,14 +510,11 @@ def build_task_tool_with_parent_config(
return (task_index, subagent_type, result, child_state_update) return (task_index, subagent_type, result, child_state_update)
def _coerce_batch_arg(tasks: Any) -> list[dict] | str: def _coerce_batch_arg(tasks: Any) -> list[dict] | str:
"""Rescue common LLM-side malformations of the ``tasks`` argument. """Rescue common LLM malformations of the ``tasks`` argument.
Some providers serialise an array argument as a JSON-encoded string, Recovers a JSON-encoded array string and a single dict (instead of a
and small models occasionally hand back a single ``{description, 1-element array), logging a WARN. Unrecoverable shapes return a string
subagent_type}`` dict instead of a one-element array. Both are the caller surfaces as the tool error.
recovered here with a WARN log so the issue is visible in metrics
but the user's turn still completes; truly broken shapes return a
plain string that the caller surfaces as the tool error.
""" """
if isinstance(tasks, list): if isinstance(tasks, list):
return tasks return tasks
@ -587,13 +549,10 @@ def build_task_tool_with_parent_config(
async def _adispatch_batch( async def _adispatch_batch(
tasks: list[dict], runtime: ToolRuntime tasks: list[dict], runtime: ToolRuntime
) -> Command | str: ) -> Command | str:
"""Fan-out helper for the ``tasks`` array shape. """Fan out the ``tasks`` array (size- and concurrency-capped).
Bounded by :data:`MAX_SUBAGENT_BATCH_SIZE` and concurrency-capped Returns one Command; the LLM sees one ``[task <index>]``-prefixed block
at :data:`DEFAULT_SUBAGENT_BATCH_CONCURRENCY`. Returns a single per child, in input order.
:class:`Command` that the LLM sees as one ToolMessage per child,
prefixed with ``[task <index>]`` so it can map back to the input
order.
""" """
if not tasks: if not tasks:
return "tasks: array is empty; nothing to dispatch." return "tasks: array is empty; nothing to dispatch."
@ -703,17 +662,16 @@ def build_task_tool_with_parent_config(
if pending_value is not None: if pending_value is not None:
resume_value = consume_surfsense_resume(runtime) resume_value = consume_surfsense_resume(runtime)
if resume_value is None: if resume_value is None:
# Bridge invariant: a queued resume must accompany any pending # A pending interrupt must have a queued resume; otherwise replay
# subagent interrupt. Fall-through replay would silently re-prompt # would silently re-prompt the user. Raise instead.
# the user; raise so the streaming layer surfaces a clear error.
raise RuntimeError( raise RuntimeError(
f"Subagent {subagent_type!r} has a pending interrupt but no " f"Subagent {subagent_type!r} has a pending interrupt but no "
"surfsense_resume_value on config; resume bridge is broken." "surfsense_resume_value on config; resume bridge is broken."
) )
expected = hitlrequest_action_count(pending_value) expected = hitlrequest_action_count(pending_value)
resume_value = fan_out_decisions_to_match(resume_value, expected) resume_value = fan_out_decisions_to_match(resume_value, expected)
# Prevent the parent's resume payload from leaking into subagent # Stop the parent's resume leaking into subagent interrupts via
# interrupts via langgraph's parent_scratchpad fallback. # langgraph's parent_scratchpad fallback.
drain_parent_null_resume(runtime) drain_parent_null_resume(runtime)
with ot.subagent_invoke_span( with ot.subagent_invoke_span(
subagent_type=subagent_type, path=invoke_path subagent_type=subagent_type, path=invoke_path
@ -829,10 +787,8 @@ def build_task_tool_with_parent_config(
] = None, ] = None,
) -> str | Command: ) -> str | Command:
atask_start = time.perf_counter() atask_start = time.perf_counter()
# Kill switch: when ops flips the spawn-paused flag for this # Ops kill switch: short-circuit every task() call for this workspace
# workspace, every ``task(...)`` invocation (single- or batch-mode) # so the orchestrator stops hammering downstream APIs.
# short-circuits with a clear ToolMessage so the orchestrator can
# tell the user what happened and stop hammering downstream APIs.
if await is_spawn_paused(search_space_id): if await is_spawn_paused(search_space_id):
logger.warning( logger.warning(
"[hitl_route] atask SPAWN_PAUSED: search_space_id=%s tool_call_id=%s", "[hitl_route] atask SPAWN_PAUSED: search_space_id=%s tool_call_id=%s",
@ -923,8 +879,8 @@ def build_task_tool_with_parent_config(
) )
expected = hitlrequest_action_count(pending_value) expected = hitlrequest_action_count(pending_value)
resume_value = fan_out_decisions_to_match(resume_value, expected) resume_value = fan_out_decisions_to_match(resume_value, expected)
# Prevent the parent's resume payload from leaking into subagent # Stop the parent's resume leaking into subagent interrupts via
# interrupts via langgraph's parent_scratchpad fallback. # langgraph's parent_scratchpad fallback.
drain_parent_null_resume(runtime) drain_parent_null_resume(runtime)
with ot.subagent_invoke_span( with ot.subagent_invoke_span(
subagent_type=subagent_type, path=invoke_path subagent_type=subagent_type, path=invoke_path

View file

@ -1,33 +1,19 @@
"""End-of-turn persistence for the cloud-mode SurfSense filesystem. """End-of-turn persistence for the cloud-mode SurfSense filesystem.
This middleware runs ``aafter_agent`` once per turn (cloud only). It commits Runs ``aafter_agent`` once per turn (cloud only), committing staged folder
all staged folder creations, file moves, content writes/edits, file deletes creates, moves, writes/edits, and ``rm``/``rmdir`` to Postgres in one ordered
(``rm``), and directory deletes (``rmdir``) to Postgres in a single ordered pass. Order matters: moves resolve before writes (so write-then-move lands at
pass: the final path), and file deletes run before directory deletes (so a same-turn
``rm /a/x.md`` + ``rmdir /a`` works).
1. Materialize ``staged_dirs`` into ``Folder`` rows. When ``flags.enable_action_log`` is on, each destructive op also snapshots a
2. Apply ``pending_moves`` in order (chained moves resolved via ``DocumentRevision`` / ``FolderRevision`` for revert. For ``rm``/``rmdir`` the
``doc_id_by_path``). snapshot and DELETE share a SAVEPOINT, so a failed snapshot aborts the delete
3. Normalize ``dirty_paths`` through ``pending_moves`` so write-then-move rather than making the data silently irreversible.
sequences commit at the final path. Paths queued for ``rm`` this turn
are dropped here so a write+rm sequence doesn't recreate the doc.
4. Commit content writes / edits for ``/documents/*`` paths, skipping
``temp_*`` basenames.
5. Apply ``pending_deletes`` (``rm``) — file deletes run BEFORE directory
deletes so a same-turn ``rm /a/x.md`` + ``rmdir /a`` sequence works.
6. Apply ``pending_dir_deletes`` (``rmdir``); re-verifies emptiness against
the post-step-5 DB state.
When ``flags.enable_action_log`` is on every destructive op also writes a The commit body is a free function (``commit_staged_filesystem_state``) so the
``DocumentRevision`` / ``FolderRevision`` snapshot bound to the stream-task fallback can run the identical routine when ``aafter_agent`` was
originating ``AgentActionLog`` row via ``tool_call_id``. ``rm``/``rmdir`` skipped (e.g. client disconnect).
share a single ``SAVEPOINT`` with their snapshot — if the snapshot fails
the DELETE rolls back and we surface the error rather than silently
making the data irreversible.
The commit body is exposed as a free function ``commit_staged_filesystem_state``
so the optional stream-task fallback (``stream_new_chat.py``) can call the
exact same routine when ``aafter_agent`` was skipped (e.g. client disconnect).
""" """
from __future__ import annotations from __future__ import annotations
@ -216,11 +202,9 @@ async def _create_document(
virtual_path, virtual_path,
search_space_id, search_space_id,
) )
# Filesystem-parity invariant: the only thing that *must* be unique is # Pre-check the path-derived unique_identifier_hash so a duplicate path
# the path. Two notes can legitimately share content (e.g. ``cp a b``). # surfaces as a clean ValueError instead of an INSERT IntegrityError that
# Guard against the path-derived ``unique_identifier_hash`` constraint # poisons the session. Content is intentionally not unique (cp a b).
# so we surface a clean ValueError instead of letting the INSERT poison
# the session with an IntegrityError.
path_collision = await session.execute( path_collision = await session.execute(
select(Document.id).where( select(Document.id).where(
Document.search_space_id == search_space_id, Document.search_space_id == search_space_id,
@ -232,13 +216,6 @@ async def _create_document(
f"a document already exists at path '{virtual_path}' " f"a document already exists at path '{virtual_path}' "
"(unique_identifier_hash collision)" "(unique_identifier_hash collision)"
) )
# ``content_hash`` is intentionally NOT checked for uniqueness here.
# In a real filesystem two files at different paths can hold identical
# bytes, and the agent's ``write_file`` path needs that semantic to
# support copy/duplicate operations. The hash remains useful as a
# change-detection hint for connector indexers, which still consult it
# via :func:`check_duplicate_document` but do so with a non-unique
# lookup (``.first()``).
content_hash = generate_content_hash(content, search_space_id) content_hash = generate_content_hash(content, search_space_id)
doc = Document( doc = Document(
title=title, title=title,
@ -435,15 +412,9 @@ async def _mark_action_reversible(
) -> None: ) -> None:
"""Flip ``agent_action_log.reversible = TRUE`` for ``action_id``. """Flip ``agent_action_log.reversible = TRUE`` for ``action_id``.
Best-effort: caller may invoke from inside a SAVEPOINT and treat Pair with ``_dispatch_reversibility_update`` *after* the enclosing
failure as a soft demotion (snapshot persists, just no Revert button). SAVEPOINT commits, so the UI never sees ``reversible=true`` for a row whose
update later rolls back.
Callers should also call ``_dispatch_reversibility_update`` (defined
below) AFTER the enclosing SAVEPOINT block exits successfully so the
chat tool card can light up its Revert button without
re-fetching ``GET /threads/.../actions``. Dispatching from inside the
SAVEPOINT would risk emitting "reversible=true" for rows whose
update gets rolled back if the surrounding destructive op fails.
""" """
if action_id is None: if action_id is None:
return return
@ -455,22 +426,11 @@ async def _mark_action_reversible(
async def _dispatch_reversibility_update(action_id: int | None) -> None: async def _dispatch_reversibility_update(action_id: int | None) -> None:
"""Best-effort dispatch of an ``action_log_updated`` custom event. """Emit an ``action_log_updated`` SSE event so the Revert button lights up.
Surfaces the post-SAVEPOINT reversibility flip to the SSE layer so Best-effort (failures swallowed; the REST actions endpoint is
the chat tool card can flip its Revert button live. Defensive: authoritative). Inside :func:`commit_staged_filesystem_state` this is
failures are logged at debug level and swallowed; the deferred until after the outer commit via ``deferred_dispatches``.
REST endpoint ``GET /threads/.../actions`` is still authoritative.
.. warning::
Inside :func:`commit_staged_filesystem_state` we DEFER all
dispatches until the outer ``session.commit()`` succeeds — see
the ``deferred_dispatches`` queue in that function. Dispatching
from inside a SAVEPOINT block while the outer transaction is
still pending would emit ``reversible=true`` for rows whose
snapshots get rolled back if the outer commit fails. Direct
callers (e.g. the optional stream-task fallback) that own the
full session lifetime can still call this helper inline.
""" """
if action_id is None: if action_id is None:
return return
@ -489,12 +449,9 @@ async def _dispatch_reversibility_update(action_id: int | None) -> None:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Snapshot helpers # Snapshot helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# # Best-effort variants (write/edit/move/mkdir) swallow failures. Strict
# Best-effort helpers swallow + log so a snapshot failure can never break # variants (rm/rmdir) share the destructive op's SAVEPOINT so a snapshot
# the destructive op for non-destructive tools (write/edit/move/mkdir). # failure aborts the delete instead of making it silently irreversible.
# Strict helpers run inside the SAME ``begin_nested()`` SAVEPOINT as the
# destructive DELETE — failure aborts the savepoint and leaves the doc /
# folder intact, so revertable ops never become irreversible silently.
def _doc_revision_payload( def _doc_revision_payload(
@ -704,15 +661,9 @@ async def commit_staged_filesystem_state(
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Commit all staged filesystem changes; return the state delta for reducers. """Commit all staged filesystem changes; return the state delta for reducers.
Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent` Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent` and
and the optional stream-task fallback. the stream-task fallback. See the module docstring for ordering and the
action-log snapshot/revert semantics.
When ``flags.enable_action_log`` is on every destructive op also writes
a ``DocumentRevision`` / ``FolderRevision`` snapshot bound to the
originating ``AgentActionLog`` row via ``tool_call_id``. Snapshot
durability is best-effort for non-destructive ops and STRICT for
``rm``/``rmdir`` (snapshot + DELETE share a SAVEPOINT — snapshot
failure aborts the delete).
""" """
if filesystem_mode != FilesystemMode.CLOUD: if filesystem_mode != FilesystemMode.CLOUD:
return None return None
@ -771,8 +722,7 @@ async def commit_staged_filesystem_state(
flags = get_flags() flags = get_flags()
snapshot_enabled = flags.enable_action_log snapshot_enabled = flags.enable_action_log
# De-duplicate pending deletes per-path while preserving the latest # De-dup deletes per-path, keeping the latest tool_call_id (likeliest revert).
# tool_call_id (the one the user is most likely to revert via the UI).
file_delete_paths: dict[str, str] = {} file_delete_paths: dict[str, str] = {}
for entry in pending_deletes: for entry in pending_deletes:
if not isinstance(entry, dict): if not isinstance(entry, dict):
@ -796,22 +746,14 @@ async def commit_staged_filesystem_state(
applied_moves: list[dict[str, Any]] = [] applied_moves: list[dict[str, Any]] = []
doc_id_path_tombstones: dict[str, int | None] = {} doc_id_path_tombstones: dict[str, int | None] = {}
tree_changed = False tree_changed = False
# Reversibility-flip dispatches are deferred until AFTER the outer # Reversibility-flip dispatches are drained only after the outer commit
# ``session.commit()`` succeeds. Dispatching from inside the # succeeds (and abandoned on rollback), so the UI never sees reversible=true
# SAVEPOINT chain while the outer transaction is still pending # for a snapshot that didn't durably land.
# would emit ``reversible=true`` for rows whose snapshots get rolled
# back if the final commit raises. Snapshot helpers append on
# success; we drain this list after commit and silently abandon it
# on rollback so the UI stays consistent with durable state.
deferred_dispatches: list[int] = [] deferred_dispatches: list[int] = []
try: try:
async with shielded_async_session() as session: async with shielded_async_session() as session:
# ------------------------------------------------------------------ # Resolve all action-id bindings in one SELECT per turn, not per op.
# Resolve action-id bindings up front. One SELECT per turn for all
# tool_call_ids, NOT one per op — important because a turn that
# touches 50 paths would otherwise issue 50 lookups.
# ------------------------------------------------------------------
action_id_by_call: dict[str, int] = {} action_id_by_call: dict[str, int] = {}
if snapshot_enabled and thread_id is not None: if snapshot_enabled and thread_id is not None:
tool_call_ids: set[str] = set() tool_call_ids: set[str] = set()
@ -844,10 +786,7 @@ async def commit_staged_filesystem_state(
next(iter(action_id_by_call), None) if action_id_by_call else None next(iter(action_id_by_call), None) if action_id_by_call else None
) )
# ------------------------------------------------------------------ # 1. staged_dirs -> Folder rows (snapshot post-flush for the FK).
# 1. staged_dirs -> Folder rows. Snapshot post-flush so the new
# folder_id is available for the FK.
# ------------------------------------------------------------------
for folder_path in staged_dirs: for folder_path in staged_dirs:
if not isinstance(folder_path, str): if not isinstance(folder_path, str):
continue continue
@ -868,7 +807,6 @@ async def commit_staged_filesystem_state(
tcid = staged_dir_tool_calls.get(folder_path) tcid = staged_dir_tool_calls.get(folder_path)
action_id = _action_id_for(tcid) action_id = _action_id_for(tcid)
if action_id is not None: if action_id is not None:
# Re-read the folder for the snapshot.
result = await session.execute( result = await session.execute(
select(Folder).where(Folder.id == folder_id) select(Folder).where(Folder.id == folder_id)
) )
@ -883,16 +821,13 @@ async def commit_staged_filesystem_state(
deferred_dispatches=deferred_dispatches, deferred_dispatches=deferred_dispatches,
) )
# ------------------------------------------------------------------ # 2. pending_moves (snapshot pre-move for in-place restore on revert).
# 2. pending_moves. Snapshot pre-move (in-place restore on revert).
# ------------------------------------------------------------------
for move in pending_moves: for move in pending_moves:
source = str(move.get("source") or "") source = str(move.get("source") or "")
if snapshot_enabled and source: if snapshot_enabled and source:
tcid = str(move.get("tool_call_id") or "") tcid = str(move.get("tool_call_id") or "")
action_id = _action_id_for(tcid) action_id = _action_id_for(tcid)
if action_id is not None: if action_id is not None:
# Resolve the doc to snapshot BEFORE we mutate it.
doc_id_pre = doc_id_by_path.get(source) doc_id_pre = doc_id_by_path.get(source)
document_pre: Document | None = None document_pre: Document | None = None
if doc_id_pre is not None: if doc_id_pre is not None:
@ -942,10 +877,8 @@ async def commit_staged_filesystem_state(
path = move_alias[path] path = move_alias[path]
return path return path
# ------------------------------------------------------------------ # 3. dirty_paths -> writes/edits. Paths queued for rm this turn are
# 3. dirty_paths -> writes/edits. Skip any path queued for ``rm`` # skipped so a write+rm sequence doesn't recreate the doc.
# this turn so a write+rm sequence doesn't recreate the doc.
# ------------------------------------------------------------------
kb_dirty_seen: set[str] = set() kb_dirty_seen: set[str] = set()
kb_dirty: list[str] = [] kb_dirty: list[str] = []
kb_dirty_origin: dict[str, str] = {} kb_dirty_origin: dict[str, str] = {}
@ -974,9 +907,7 @@ async def commit_staged_filesystem_state(
continue continue
content = "\n".join(file_data.get("content") or []) content = "\n".join(file_data.get("content") or [])
doc_id = doc_id_by_path.get(path) doc_id = doc_id_by_path.get(path)
# Path ↔ tool_call_id binding: the dirty_paths list dedupes via # Look up tool_call_id by final path or its pre-rename origin.
# _add_unique_reducer, so we look up the latest tool_call_id by
# path (or by the un-renamed origin).
origin = kb_dirty_origin.get(path, path) origin = kb_dirty_origin.get(path, path)
tcid = dirty_path_tool_calls.get(path) or dirty_path_tool_calls.get( tcid = dirty_path_tool_calls.get(path) or dirty_path_tool_calls.get(
origin origin
@ -984,12 +915,9 @@ async def commit_staged_filesystem_state(
action_id = _action_id_for(tcid) action_id = _action_id_for(tcid)
if doc_id is None: if doc_id is None:
# The in-memory ``doc_id_by_path`` is per-thread and starts # doc_id_by_path is per-thread and empty in a new chat, so a
# empty in every new chat. If the agent writes to a path # write to a path already in the DB must update in place, not
# that already exists in the DB (e.g. a previous chat's # INSERT (which would hit the path-derived unique hash).
# ``notes.md``), we must NOT try to INSERT — it would hit
# ``unique_identifier_hash`` (path-derived). Look up the
# existing doc and update it in place instead.
existing = await virtual_path_to_doc( existing = await virtual_path_to_doc(
session, session,
search_space_id=search_space_id, search_space_id=search_space_id,
@ -1038,12 +966,9 @@ async def commit_staged_filesystem_state(
} }
) )
else: else:
# Fresh create. Wrap each create in a SAVEPOINT so a # Fresh create, wrapped in a SAVEPOINT so a residual
# residual ``IntegrityError`` (e.g. a deployment that # IntegrityError (e.g. pre-migration-133 content_hash UNIQUE)
# hasn't run migration 133 yet, where # rolls back only this create, not the whole turn.
# ``documents.content_hash`` still carries its legacy
# global UNIQUE constraint) rolls back only this one
# create instead of poisoning the whole turn.
placeholder_revision_id: int | None = None placeholder_revision_id: int | None = None
if snapshot_enabled and action_id is not None: if snapshot_enabled and action_id is not None:
placeholder_revision_id = await _snapshot_document_pre_create( placeholder_revision_id = await _snapshot_document_pre_create(
@ -1066,8 +991,7 @@ async def commit_staged_filesystem_state(
logger.warning( logger.warning(
"kb_persistence: skipping %s create: %s", path, exc "kb_persistence: skipping %s create: %s", path, exc
) )
# Roll back the placeholder revision since the create # Create never happened; drop its placeholder revision.
# never happened.
if placeholder_revision_id is not None: if placeholder_revision_id is not None:
await session.execute( await session.execute(
delete(DocumentRevision).where( delete(DocumentRevision).where(
@ -1114,19 +1038,14 @@ async def commit_staged_filesystem_state(
) )
tree_changed = True tree_changed = True
# ------------------------------------------------------------------ # 4. pending_deletes -> rm. Strict: snapshot + DELETE share a
# 4. pending_deletes -> ``rm``. STRICT durability: snapshot + DELETE # SAVEPOINT, so a failed snapshot rolls the delete back too.
# share a SAVEPOINT. If the snapshot insert fails, the DELETE
# rolls back too and we surface the error rather than silently
# making the data irreversible.
# ------------------------------------------------------------------
for raw_path, tcid in file_delete_paths.items(): for raw_path, tcid in file_delete_paths.items():
final = _final_path(raw_path) final = _final_path(raw_path)
if not final.startswith(DOCUMENTS_ROOT + "/"): if not final.startswith(DOCUMENTS_ROOT + "/"):
continue continue
action_id = _action_id_for(tcid) action_id = _action_id_for(tcid)
# Resolve the doc.
doc_id_for_delete = doc_id_by_path.get(final) doc_id_for_delete = doc_id_by_path.get(final)
document_to_delete: Document | None = None document_to_delete: Document | None = None
if doc_id_for_delete is not None: if doc_id_for_delete is not None:
@ -1155,7 +1074,6 @@ async def commit_staged_filesystem_state(
try: try:
async with session.begin_nested(): async with session.begin_nested():
# Strict: snapshot first; failure aborts the delete.
if snapshot_enabled and action_id is not None: if snapshot_enabled and action_id is not None:
chunks = await _load_chunks_for_snapshot( chunks = await _load_chunks_for_snapshot(
session, doc_id=doc_pk session, doc_id=doc_pk
@ -1184,10 +1102,7 @@ async def commit_staged_filesystem_state(
) )
continue continue
# B1 — SAVEPOINT released. Defer the reversibility-flip # Defer the reversibility flip until after the outer commit.
# dispatch until AFTER the outer commit succeeds so we
# never tell the UI a row is reversible if its snapshot
# gets rolled back.
if snapshot_enabled and action_id is not None: if snapshot_enabled and action_id is not None:
deferred_dispatches.append(int(action_id)) deferred_dispatches.append(int(action_id))
@ -1206,11 +1121,8 @@ async def commit_staged_filesystem_state(
) )
tree_changed = True tree_changed = True
# ------------------------------------------------------------------ # 5. pending_dir_deletes -> rmdir. Strict, and re-checks emptiness
# 5. pending_dir_deletes -> ``rmdir``. STRICT durability + final # against post-step-4 DB state.
# emptiness check (after step 4's deletes have run, an "empty
# mid-turn" directory really IS empty in DB now).
# ------------------------------------------------------------------
for raw_path, tcid in dir_delete_paths.items(): for raw_path, tcid in dir_delete_paths.items():
final = _final_path(raw_path) final = _final_path(raw_path)
if not final.startswith(DOCUMENTS_ROOT + "/"): if not final.startswith(DOCUMENTS_ROOT + "/"):
@ -1231,7 +1143,6 @@ async def commit_staged_filesystem_state(
) )
continue continue
# Re-check emptiness against in-DB state.
docs_in_folder = await session.execute( docs_in_folder = await session.execute(
select(Document.id) select(Document.id)
.where(Document.folder_id == folder_id) .where(Document.folder_id == folder_id)
@ -1296,10 +1207,7 @@ async def commit_staged_filesystem_state(
) )
continue continue
# B1 — SAVEPOINT released. Defer the reversibility-flip # Defer the reversibility flip until after the outer commit.
# dispatch until AFTER the outer commit succeeds so we
# never tell the UI a row is reversible if its snapshot
# gets rolled back.
if snapshot_enabled and action_id is not None: if snapshot_enabled and action_id is not None:
deferred_dispatches.append(int(action_id)) deferred_dispatches.append(int(action_id))
@ -1319,18 +1227,13 @@ async def commit_staged_filesystem_state(
logger.exception( logger.exception(
"kb_persistence: commit failed (search_space=%s)", search_space_id "kb_persistence: commit failed (search_space=%s)", search_space_id
) )
# Outer commit raised — every SAVEPOINT-released change above # Outer commit raised: everything above rolled back, so drop the
# (snapshots + reversibility flips) is now rolled back. Drop # deferred dispatches.
# the deferred SSE dispatches so the UI stays consistent with
# durable state.
deferred_dispatches.clear() deferred_dispatches.clear()
return None return None
# Outer commit succeeded; flush deferred reversibility-flip # Commit succeeded; flush deferred reversibility flips (de-duped, since
# dispatches now so the chat tool card can light up its Revert # write-then-rm in one turn appends an id per snapshot site).
# button without re-fetching ``GET /threads/.../actions``. De-dup
# to avoid emitting the same id twice (e.g. write-then-rm in the
# same turn dispatches once for each snapshot site).
if deferred_dispatches and dispatch_events: if deferred_dispatches and dispatch_events:
for action_id in dict.fromkeys(deferred_dispatches): for action_id in dict.fromkeys(deferred_dispatches):
try: try:
@ -1376,9 +1279,8 @@ async def commit_staged_filesystem_state(
p for p in files if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX) p for p in files if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX)
] ]
# Tombstone every committed-delete path so a stale ``state["files"]`` entry # Tombstone committed-delete paths so a stale state["files"] entry can't
# (which als_info would otherwise interpret as content) cannot survive into # survive into the next turn and make a now-empty folder look non-empty.
# the next turn and make a now-empty folder look non-empty.
deleted_file_paths = [ deleted_file_paths = [
str(payload.get("virtualPath") or "") str(payload.get("virtualPath") or "")
for payload in committed_deletes for payload in committed_deletes
@ -1399,11 +1301,8 @@ async def commit_staged_filesystem_state(
"dirty_path_tool_calls": {_CLEAR: True}, "dirty_path_tool_calls": {_CLEAR: True},
} }
# Emit one Receipt per committed mutation, folded into ``state['receipts']`` # One Receipt per committed mutation: ground truth (post-savepoint) for the
# via ``_list_append_reducer``. The receipts surface what actually committed # orchestrator's <verification> teaching. KB writes have no public URL.
# (post-savepoint) rather than what the LLM intended; the orchestrator uses
# them as ground truth in the ``<verification>`` teaching. KB writes do not
# have public verifiable URLs, so ``verifiable_url`` stays unset.
receipts: list[Receipt] = [] receipts: list[Receipt] = []
def _kb_receipt( def _kb_receipt(
@ -1444,8 +1343,6 @@ async def commit_staged_filesystem_state(
external_id=payload.get("id"), external_id=payload.get("id"),
) )
for payload in applied_moves: for payload in applied_moves:
# ``applied_moves`` rows carry the destination ``virtualPath`` because
# the move has already landed in the DB by the time we reach this code.
path = str(payload.get("virtualPath") or "") path = str(payload.get("virtualPath") or "")
_kb_receipt( _kb_receipt(
type="file", type="file",
@ -1485,9 +1382,7 @@ async def commit_staged_filesystem_state(
if tree_changed: if tree_changed:
delta["tree_version"] = int(state_dict.get("tree_version") or 0) + 1 delta["tree_version"] = int(state_dict.get("tree_version") or 0) + 1
# Avoid 'unused' lint when turn_id_for_revision was only useful for _ = turn_id_for_revision # diagnostic-only; silence unused lint
# diagnostic purposes inside the SAVEPOINT chain above.
_ = turn_id_for_revision
logger.info( logger.info(
"kb_persistence: commit (search_space=%s) creates=%d updates=%d " "kb_persistence: commit (search_space=%s) creates=%d updates=%d "

View file

@ -29,7 +29,6 @@ def extract_domain(url: str) -> str:
try: try:
parsed = urlparse(url) parsed = urlparse(url)
domain = parsed.netloc domain = parsed.netloc
# Remove 'www.' prefix if present
if domain.startswith("www."): if domain.startswith("www."):
domain = domain[4:] domain = domain[4:]
return domain return domain
@ -53,14 +52,13 @@ def truncate_content(content: str, max_length: int = 50000) -> tuple[str, bool]:
if len(content) <= max_length: if len(content) <= max_length:
return content, False return content, False
# Try to truncate at a sentence boundary # Prefer truncating at a sentence/paragraph boundary.
truncated = content[:max_length] truncated = content[:max_length]
last_period = truncated.rfind(".") last_period = truncated.rfind(".")
last_newline = truncated.rfind("\n\n") last_newline = truncated.rfind("\n\n")
# Use the later of the two boundaries, or just truncate
boundary = max(last_period, last_newline) boundary = max(last_period, last_newline)
if boundary > max_length * 0.8: # Only use boundary if it's not too far back if boundary > max_length * 0.8: # only if the boundary isn't too far back
truncated = content[: boundary + 1] truncated = content[: boundary + 1]
return truncated + "\n\n[Content truncated...]", True return truncated + "\n\n[Content truncated...]", True
@ -111,8 +109,8 @@ async def _scrape_youtube_video(
http_client.proxies.update(residential_proxies) http_client.proxies.update(residential_proxies)
ytt_api = YouTubeTranscriptApi(http_client=http_client) ytt_api = YouTubeTranscriptApi(http_client=http_client)
# List all available transcripts and pick the first one # Pick the first transcript (video's primary language) rather than
# (the video's primary language) instead of defaulting to English # defaulting to English.
transcript_list = ytt_api.list(video_id) transcript_list = ytt_api.list(video_id)
transcript = next(iter(transcript_list)) transcript = next(iter(transcript_list))
captions = transcript.fetch() captions = transcript.fetch()
@ -134,10 +132,8 @@ async def _scrape_youtube_video(
logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}") logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}")
transcript_text = f"No captions available for this video. Error: {e!s}" transcript_text = f"No captions available for this video. Error: {e!s}"
# Build combined content
content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}" content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}"
# Truncate if needed
content, was_truncated = truncate_content(content, max_length) content, was_truncated = truncate_content(content, max_length)
word_count = len(content.split()) word_count = len(content.split())
@ -212,20 +208,16 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
scrape_id = generate_scrape_id(url) scrape_id = generate_scrape_id(url)
domain = extract_domain(url) domain = extract_domain(url)
# Validate and normalize URL
if not url.startswith(("http://", "https://")): if not url.startswith(("http://", "https://")):
url = f"https://{url}" url = f"https://{url}"
try: try:
# Check if this is a YouTube URL and use transcript API instead # YouTube URLs use the transcript API instead of crawling.
video_id = get_youtube_video_id(url) video_id = get_youtube_video_id(url)
if video_id: if video_id:
return await _scrape_youtube_video(url, video_id, max_length) return await _scrape_youtube_video(url, video_id, max_length)
# Create webcrawler connector
connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key) connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key)
# Crawl the URL
result, error = await connector.crawl_url(url, formats=["markdown"]) result, error = await connector.crawl_url(url, formats=["markdown"])
if error: if error:
@ -250,28 +242,21 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
"error": "No content returned from crawler", "error": "No content returned from crawler",
} }
# Extract content and metadata
content = result.get("content", "") content = result.get("content", "")
metadata = result.get("metadata", {}) metadata = result.get("metadata", {})
# Get title from metadata
title = metadata.get("title", "") title = metadata.get("title", "")
if not title: if not title:
title = domain or url.split("/")[-1] or "Webpage" title = domain or url.split("/")[-1] or "Webpage"
# Get description from metadata
description = metadata.get("description", "") description = metadata.get("description", "")
if not description and content: if not description and content:
# Use first paragraph as description
first_para = content.split("\n\n")[0] if content else "" first_para = content.split("\n\n")[0] if content else ""
description = ( description = (
first_para[:300] + "..." if len(first_para) > 300 else first_para first_para[:300] + "..." if len(first_para) > 300 else first_para
) )
# Truncate content if needed
content, was_truncated = truncate_content(content, max_length) content, was_truncated = truncate_content(content, max_length)
# Calculate word count
word_count = len(content.split()) word_count = len(content.split())
return { return {

View file

@ -1,37 +1,9 @@
""" """Feature flags for the SurfSense new_chat agent stack.
Feature flags for the SurfSense new_chat agent stack.
These flags gate the newer agent middleware (some ported from OpenCode, Flags are resolved at agent build time. Most upgrades default ON so Docker
some sourced from ``langchain.agents.middleware`` / ``deepagents``, some updates work without operators adding new env vars; risky integrations stay
SurfSense-native). Most shipped agent-stack upgrades default ON so Docker OFF. The master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` forces every
image updates work even when older installs do not have newly introduced flag below to False for a one-switch rollback to pre-port behavior.
environment variables. Risky/experimental integrations stay default OFF,
and the master kill-switch can still disable everything new.
All new middleware checks its flag at agent build time. If the master
kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new
middleware is disabled regardless of its individual flag. This gives
operators a single switch to revert to pre-port behavior.
Examples
--------
Defaults:
SURFSENSE_ENABLE_CONTEXT_EDITING=true
SURFSENSE_ENABLE_COMPACTION_V2=true
SURFSENSE_ENABLE_RETRY_AFTER=true
SURFSENSE_ENABLE_MODEL_FALLBACK=false
SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
SURFSENSE_ENABLE_PERMISSION=true
SURFSENSE_ENABLE_DOOM_LOOP=true
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
Master kill-switch (overrides everything else):
SURFSENSE_DISABLE_NEW_AGENT_STACK=true
""" """
from __future__ import annotations from __future__ import annotations
@ -93,39 +65,14 @@ class AgentFeatureFlags:
# Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT) # Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT)
enable_otel: bool = False enable_otel: bool = False
# Performance — compiled-agent cache (Phase 1 + Phase 2). # Performance — reuse a compiled agent graph when the cache key matches
# When ON, ``create_surfsense_deep_agent`` reuses a previously-compiled # (~4-5s -> <50µs per turn). Safe to default-on because mutation tools take
# graph if the cache key matches (LLM config + thread + tool surface + # fresh short-lived sessions per call and per-turn context (mentions, etc.)
# flags + system prompt + filesystem mode). Cuts per-turn agent-build # is read from runtime.context, not the constructor closure. Rollback via
# wall clock from ~4-5s to <50µs on cache hits. # SURFSENSE_ENABLE_AGENT_CACHE=false.
#
# SAFETY (Phase 2 unblocked this default-on):
# All connector mutation tools (``tools/notion``, ``tools/gmail``,
# ``tools/google_drive``, ``tools/dropbox``, ``tools/onedrive``,
# ``tools/google_calendar``, ``tools/confluence``, ``tools/discord``,
# ``tools/teams``, ``tools/luma``, ``connected_accounts``,
# ``update_memory``) now acquire fresh
# short-lived ``AsyncSession`` instances per call via
# :data:`async_session_maker`. The factory still accepts ``db_session``
# for registry compatibility but ``del``'s it immediately — see any
# of those files' factory docstrings for the rationale. The ``llm``
# closure is per-(provider, model, config_id) which is already in
# the cache key, so the LLM is safe to share across cached hits of
# the same key. The KB priority middleware reads
# ``mentioned_document_ids`` from ``runtime.context`` (Phase 1.5),
# not its constructor closure, so the same compiled agent serves
# turns with different mention lists correctly.
#
# Rollback: set ``SURFSENSE_ENABLE_AGENT_CACHE=false`` in the
# environment if a regression surfaces. The path is exercised by
# the ``tests/unit/agents/new_chat/test_agent_cache_*`` suite.
enable_agent_cache: bool = True enable_agent_cache: bool = True
# Phase 1 (deferred — measure first): pre-build & share the # Deferred: only helps on outer-cache MISSES, so off until data shows cold
# general-purpose subagent ``CompiledSubAgent`` across cold-cache # misses are frequent enough to justify the extra global state.
# misses. Only helps when the outer cache MISSES (cache hits already
# reuse the entire SubAgentMiddleware-compiled graph). Off by default
# until we have data showing cold misses are frequent enough to
# justify the extra global state.
enable_agent_cache_share_gp_subagent: bool = False enable_agent_cache_share_gp_subagent: bool = False
@classmethod @classmethod

View file

@ -594,14 +594,9 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
inject_system_message: bool = True, # For backwards compatibility inject_system_message: bool = True, # For backwards compatibility
) -> None: ) -> None:
self.llm = llm self.llm = llm
# The planner LLM handles short, structured internal tasks (query # Cheap model for structured internal tasks (query rewrite, date
# rewriting, date extraction, recency classification). When an # extraction, recency classification) when one is configured; falls back
# operator marks a global config ``is_planner: true`` we route # to the chat LLM otherwise.
# those calls to a cheap/fast model (e.g. gpt-4o-mini, Haiku, Azure
# gpt-5.x-nano) instead of the user's chat LLM — those classification
# tasks don't need frontier-tier capability. Falls back to the chat
# LLM when no planner config is wired up so deployments without one
# keep working unchanged.
self.planner_llm = planner_llm or llm self.planner_llm = planner_llm or llm
self.search_space_id = search_space_id self.search_space_id = search_space_id
self.filesystem_mode = filesystem_mode self.filesystem_mode = filesystem_mode
@ -610,26 +605,17 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
self.top_k = top_k self.top_k = top_k
self.mentioned_document_ids = mentioned_document_ids or [] self.mentioned_document_ids = mentioned_document_ids or []
self.inject_system_message = inject_system_message self.inject_system_message = inject_system_message
# Build the kb-planner private Runnable ONCE here so we don't pay # Compiled lazily and memoized to avoid the per-turn create_agent cost.
# the ``create_agent`` compile cost (50-200ms) on every turn.
# Disabled by default behind ``enable_kb_planner_runnable``; when
# off the planner falls back to the legacy ``planner_llm.ainvoke``
# path.
self._planner: Runnable | None = None self._planner: Runnable | None = None
self._planner_compile_failed = False self._planner_compile_failed = False
def _build_kb_planner_runnable(self) -> Runnable | None: def _build_kb_planner_runnable(self) -> Runnable | None:
"""Compile the kb-planner private :class:`Runnable` once. """Lazily compile and memoize the kb-planner Runnable.
Returns ``None`` when the feature flag is disabled, when the LLM is Returns ``None`` (and the caller falls back to ``planner_llm.ainvoke``)
unavailable, or when ``create_agent`` raises (we fall back to the when the flag is off, the LLM is missing, or ``create_agent`` raises.
legacy ``planner_llm.ainvoke`` path in that case). Compilation happens Built without tools but with RetryAfterMiddleware so a transient
lazily on first call, then memoized via ``self._planner``. rate-limit on the planner call doesn't fail the whole turn.
The compiled agent is constructed without tools — the planner's
contract is "answer with structured JSON" — but it inherits the
:class:`RetryAfterMiddleware` so transient rate-limit errors
from the planner LLM call don't fail the whole turn.
""" """
if self._planner is not None or self._planner_compile_failed: if self._planner is not None or self._planner_compile_failed:
return self._planner return self._planner
@ -677,10 +663,8 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
t0 = loop.time() t0 = loop.time()
# Prefer the compiled-once planner Runnable when enabled; otherwise # Both paths tag surfsense:internal so the planner's intermediate
# fall back to ``planner_llm.ainvoke``. The ``surfsense:internal`` # events stay suppressed from the UI.
# tag is preserved on both paths so ``_stream_agent_events`` still
# suppresses the planner's intermediate events from the UI.
planner = self._build_kb_planner_runnable() planner = self._build_kb_planner_runnable()
try: try:
if planner is not None: if planner is not None:
@ -819,32 +803,16 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
user_text=user_text, user_text=user_text,
) )
# Per-turn ``mentioned_document_ids`` flow: # Prefer per-turn mentions from runtime.context (lets a cached graph
# 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the # serve different turns); fall back to the constructor closure, draining
# streaming task supplies a fresh :class:`SurfSenseContextSchema` # it after one read so stale mentions can't replay.
# on every ``astream_events`` call, so this list is naturally
# scoped to the current turn. Allows cross-turn graph reuse via
# ``agent_cache``.
# 2. Legacy fallback (cache disabled / context not propagated): the
# constructor-injected ``self.mentioned_document_ids`` list. We
# drain it after the first read so a cached graph (no Phase 1.5
# wiring) doesn't keep replaying the same mentions on every
# turn.
# #
# CRITICAL: distinguish "context absent" (legacy caller, no field at # CRITICAL: test ``ctx_mentions is not None``, not truthiness — an empty
# all) from "context provided but empty" (turn with no mentions). # list means "this turn has no mentions", not "use the closure".
# ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in
# Python, so a naive ``if ctx_mentions:`` would fall through to the
# legacy closure on every no-mention follow-up turn — replaying the
# mentions baked in by turn 1's cache-miss build. Always drain the
# closure once the runtime path has fired so a cached middleware
# instance can never resurrect stale state.
mention_ids: list[int] = [] mention_ids: list[int] = []
ctx = getattr(runtime, "context", None) if runtime is not None else None ctx = getattr(runtime, "context", None) if runtime is not None else None
ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None
if ctx_mentions is not None: if ctx_mentions is not None:
# Runtime path is authoritative — even an empty list means
# "this turn has no mentions", NOT "look at the closure".
mention_ids = list(ctx_mentions) mention_ids = list(ctx_mentions)
if self.mentioned_document_ids: if self.mentioned_document_ids:
self.mentioned_document_ids = [] self.mentioned_document_ids = []
@ -852,12 +820,8 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
mention_ids = list(self.mentioned_document_ids) mention_ids = list(self.mentioned_document_ids)
self.mentioned_document_ids = [] self.mentioned_document_ids = []
# Folder mentions live alongside doc mentions on the runtime # Folder mentions aren't embedded, so they skip hybrid search and are
# context. They never feed hybrid search (folders aren't # surfaced only as [USER-MENTIONED] entries. Cloud mode only.
# embedded) — they're surfaced purely as ``[USER-MENTIONED]``
# priority entries so the agent walks the folder with ``ls`` /
# ``find_documents`` instead of ignoring it. Cloud filesystem
# mode only.
folder_mention_ids: list[int] = [] folder_mention_ids: list[int] = []
if ( if (
ctx is not None ctx is not None
@ -939,14 +903,10 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
async def _materialize_folder_priority( async def _materialize_folder_priority(
self, folder_ids: list[int] self, folder_ids: list[int]
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Resolve user-mentioned folder ids to ``<priority_documents>`` entries. """Resolve mentioned folder ids to canonical-path priority entries.
Each entry uses the canonical ``/documents/Folder/Sub/`` virtual Flagged ``mentioned=True`` with ``score=None`` (folders aren't ranked;
path (matching ``KnowledgeTreeMiddleware`` and the agent's the agent decides which children to read).
``ls`` adapter) and is flagged ``mentioned=True`` so the
rendered line carries ``[USER-MENTIONED]``. ``score`` is left
``None`` so the renderer prints ``n/a`` — folders aren't
ranked, the agent decides which children to read.
""" """
if not folder_ids: if not folder_ids:
return [] return []

View file

@ -30,22 +30,11 @@ from langgraph.types import interrupt
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Tools that mirror the safety profile of ``write_file`` against the # Low-stakes creation tools auto-approved by default: each creates one
# SurfSense KB: each call creates ONE artifact in the user's own workspace # artifact in the user's own workspace with no external visibility (drafts
# with no external visibility (drafts aren't sent; new files aren't shared # aren't sent; new files aren't shared). They still call ``request_approval``,
# unless the user shares them later). These are auto-approved by default # which returns ``decision_type="auto_approved"`` without firing an interrupt.
# so the agent can compose drafts and seed scratch files without a popup # Per-search-space ``agent_permission_rules`` can re-enable prompting.
# on every call.
#
# Members of this set still call ``request_approval`` exactly as before;
# the function returns immediately with ``decision_type="auto_approved"``
# and the original params untouched. This preserves the call-site shape
# (logging, metadata fetching, account fallbacks) so the only behavior
# change is "no interrupt fires".
#
# To re-enable prompting, the future per-search-space rules table
# (``agent_permission_rules``) takes precedence in the permission ruleset
# layering assembled by the agent factory.
DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset( DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
{ {
"create_gmail_draft", "create_gmail_draft",
@ -150,10 +139,6 @@ def request_approval(
return HITLResult(rejected=False, decision_type="trusted", params=dict(params)) return HITLResult(rejected=False, decision_type="trusted", params=dict(params))
if tool_name in DEFAULT_AUTO_APPROVED_TOOLS: if tool_name in DEFAULT_AUTO_APPROVED_TOOLS:
# Default policy: low-stakes creation tools (drafts + new-file
# creates) skip HITL because they're as recoverable as a local
# ``write_file`` against the SurfSense KB. The user can still
# delete the artifact in <30s if it's wrong.
logger.info( logger.info(
"Tool '%s' is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL", "Tool '%s' is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL",
tool_name, tool_name,

View file

@ -75,7 +75,7 @@ def create_generate_image_tool(
captured model), use this config id instead of reading the search space's captured model), use this config id instead of reading the search space's
live ``image_generation_config_id``. live ``image_generation_config_id``.
""" """
del db_session # use a fresh per-call session, see below del db_session # tool uses a fresh per-call session instead
@tool @tool
async def generate_image( async def generate_image(
@ -140,17 +140,12 @@ def create_generate_image_tool(
or IMAGE_GEN_AUTO_MODE_ID or IMAGE_GEN_AUTO_MODE_ID
) )
# Build generation kwargs # size/quality/style are intentionally omitted: valid values
# NOTE: size, quality, and style are intentionally NOT passed. # differ per model, so we let each model use its own defaults.
# Different models support different values for these params
# (e.g. DALL-E 3 wants "hd"/"standard" for quality while
# gpt-image-1 wants "high"/"medium"/"low"; size options also
# differ). Letting the model use its own defaults avoids errors.
gen_kwargs: dict[str, Any] = {} gen_kwargs: dict[str, Any] = {}
if n is not None and n > 1: if n is not None and n > 1:
gen_kwargs["n"] = n gen_kwargs["n"] = n
# Call litellm based on config type
if is_image_gen_auto_mode(config_id): if is_image_gen_auto_mode(config_id):
if not ImageGenRouterService.is_initialized(): if not ImageGenRouterService.is_initialized():
err = ( err = (
@ -224,17 +219,13 @@ def create_generate_image_tool(
prompt=prompt, model=model_string, **gen_kwargs prompt=prompt, model=model_string, **gen_kwargs
) )
# Parse the response and store in DB
response_dict = ( response_dict = (
response.model_dump() response.model_dump()
if hasattr(response, "model_dump") if hasattr(response, "model_dump")
else dict(response) else dict(response)
) )
# Generate a random access token for this image
access_token = generate_image_token() access_token = generate_image_token()
# Save to image_generations table for history
db_image_gen = ImageGeneration( db_image_gen = ImageGeneration(
prompt=prompt, prompt=prompt,
model=getattr(response, "_hidden_params", {}).get("model"), model=getattr(response, "_hidden_params", {}).get("model"),
@ -249,7 +240,6 @@ def create_generate_image_tool(
await session.refresh(db_image_gen) await session.refresh(db_image_gen)
db_image_gen_id = db_image_gen.id db_image_gen_id = db_image_gen.id
# Extract image URLs from response
images = response_dict.get("data", []) images = response_dict.get("data", [])
if not images: if not images:
return _failed( return _failed(
@ -260,11 +250,8 @@ def create_generate_image_tool(
first_image = images[0] first_image = images[0]
revised_prompt = first_image.get("revised_prompt", prompt) revised_prompt = first_image.get("revised_prompt", prompt)
# Resolve image URL: # b64_json (e.g. gpt-image-1) is served via our backend endpoint so
# - If the API returned a URL, use it directly. # megabytes of base64 don't bloat the LLM context.
# - If the API returned b64_json (e.g. gpt-image-1), serve the
# image through our backend endpoint to avoid bloating the
# LLM context with megabytes of base64 data.
if first_image.get("url"): if first_image.get("url"):
image_url = first_image["url"] image_url = first_image["url"]
elif first_image.get("b64_json"): elif first_image.get("b64_json"):

View file

@ -241,23 +241,12 @@ def _normalize_connectors(
connectors_to_search: list[str] | None, connectors_to_search: list[str] | None,
available_connectors: list[str] | None = None, available_connectors: list[str] | None = None,
) -> list[str]: ) -> list[str]:
"""Normalize model-supplied connectors to canonical ConnectorService types.
Maps user-facing aliases (e.g. WEBCRAWLER_CONNECTOR), drops unknowns, and
constrains to ``available_connectors`` when given. Empty input defaults to
all available connectors (minus live-search ones).
""" """
Normalize connectors provided by the model.
- Accepts user-facing enums like WEBCRAWLER_CONNECTOR and maps them to canonical
ConnectorService types.
- Drops unknown values.
- If available_connectors is provided, only includes connectors from that list.
- If connectors_to_search is None/empty, defaults to available_connectors or all.
Args:
connectors_to_search: List of connectors requested by the model
available_connectors: List of connectors actually available in the search space
Returns:
List of normalized connector strings to search
"""
# Determine the set of valid connectors to consider
valid_set = ( valid_set = (
set(available_connectors) if available_connectors else set(_ALL_CONNECTORS) set(available_connectors) if available_connectors else set(_ALL_CONNECTORS)
) )
@ -276,18 +265,16 @@ def _normalize_connectors(
c = (raw or "").strip().upper() c = (raw or "").strip().upper()
if not c: if not c:
continue continue
# Map user-facing aliases to canonical names
if c == "WEBCRAWLER_CONNECTOR": if c == "WEBCRAWLER_CONNECTOR":
c = "CRAWLED_URL" c = "CRAWLED_URL"
normalized.append(c) normalized.append(c)
# de-dupe while preserving order + filter to valid connectors # De-dupe (order-preserving), keeping only known + available connectors.
seen: set[str] = set() seen: set[str] = set()
out: list[str] = [] out: list[str] = []
for c in normalized: for c in normalized:
if c in seen: if c in seen:
continue continue
# Only include if it's a known connector AND available
if c not in _ALL_CONNECTORS: if c not in _ALL_CONNECTORS:
continue continue
if c not in valid_set: if c not in valid_set:
@ -295,7 +282,7 @@ def _normalize_connectors(
seen.add(c) seen.add(c)
out.append(c) out.append(c)
# Fallback to all available if nothing matched # Nothing matched: fall back to all available.
if not out: if not out:
base = ( base = (
list(available_connectors) list(available_connectors)
@ -377,39 +364,17 @@ def format_documents_for_context(
max_chunk_chars: int = _MAX_CHUNK_CHARS, max_chunk_chars: int = _MAX_CHUNK_CHARS,
max_chunks_per_doc: int = 0, max_chunks_per_doc: int = 0,
) -> str: ) -> str:
""" """Format retrieved documents into an XML context string for the LLM.
Format retrieved documents into a readable context string for the LLM.
Documents are added in order (highest relevance first) until the character Documents are emitted highest-relevance first until ``max_chars`` is hit.
budget is reached. Individual chunks are capped at ``max_chunk_chars`` and ``max_chunks_per_doc=0`` auto-computes a rank-adaptive cap so top results get
each document is limited to a dynamically computed chunk cap so a single more chunks and no single large document monopolizes the budget.
large document cannot monopolize the output while still maximising the use
of available context space.
Args:
documents: List of document dictionaries from connector search
max_chars: Approximate character budget for the entire output.
max_chunk_chars: Per-chunk character cap (content is tail-truncated).
max_chunks_per_doc: Maximum chunks per document. ``0`` (default) means
auto-compute per document using a rank-adaptive formula so
higher-ranked documents receive more chunks.
Returns:
Formatted string with document contents and metadata
""" """
if not documents: if not documents:
return "" return ""
# Group chunks by document id (preferred) to produce the XML structure. # Group chunks by document id, preserving chunk_id so [citation:123] works.
# # ConnectorService returns document-grouped results ({document, chunks, source}).
# IMPORTANT: ConnectorService returns **document-grouped** results of the form:
# {
# "document": {...},
# "chunks": [{"chunk_id": 123, "content": "..."}, ...],
# "source": "NOTION_CONNECTOR" | "FILE" | ...
# }
#
# We must preserve chunk_id so citations like [citation:123] are possible.
grouped: dict[str, dict[str, Any]] = {} grouped: dict[str, dict[str, Any]] = {}
for doc in documents: for doc in documents:
@ -430,7 +395,7 @@ def format_documents_for_context(
or "UNKNOWN" or "UNKNOWN"
) )
# Document identity (prefer document_id; otherwise fall back to type+title+url) # Identity: prefer document_id, else type+title+url.
document_id_val = document_info.get("id") document_id_val = document_info.get("id")
title = ( title = (
document_info.get("title") or metadata.get("title") or "Untitled Document" document_info.get("title") or metadata.get("title") or "Untitled Document"
@ -460,7 +425,7 @@ def format_documents_for_context(
"chunks": [], "chunks": [],
} }
# Prefer document-grouped chunks if available # Prefer document-grouped chunks when present.
chunks_list = doc.get("chunks") if isinstance(doc, dict) else None chunks_list = doc.get("chunks") if isinstance(doc, dict) else None
if isinstance(chunks_list, list) and chunks_list: if isinstance(chunks_list, list) and chunks_list:
for ch in chunks_list: for ch in chunks_list:
@ -492,7 +457,6 @@ def format_documents_for_context(
"BAIDU_SEARCH_API", "BAIDU_SEARCH_API",
} }
# Render XML expected by citation instructions, respecting the char budget.
parts: list[str] = [] parts: list[str] = []
total_chars = 0 total_chars = 0
total_docs = len(grouped) total_docs = len(grouped)
@ -594,30 +558,11 @@ async def search_knowledge_base_async(
available_document_types: list[str] | None = None, available_document_types: list[str] | None = None,
max_input_tokens: int | None = None, max_input_tokens: int | None = None,
) -> str: ) -> str:
""" """Search the knowledge base across connectors and return formatted results.
Search the user's knowledge base for relevant documents.
This is the async implementation that searches across multiple connectors. ``available_document_types`` lets local connectors with no indexed data be
skipped (no embedding / DB round-trip), and ``max_input_tokens`` sizes the
Args: output to the model's context window.
query: The search query
search_space_id: The user's search space ID
db_session: Database session
connector_service: Initialized connector service
connectors_to_search: Optional list of connector types to search. If omitted, searches all.
top_k: Number of results per connector
start_date: Optional start datetime (UTC) for filtering documents
end_date: Optional end datetime (UTC) for filtering documents
available_connectors: Optional list of connectors actually available in the search space.
If provided, only these connectors will be searched.
available_document_types: Optional list of document types that actually have indexed
data. When provided, local connectors whose document type is
absent are skipped entirely (no embedding / DB round-trip).
max_input_tokens: Model context window size (tokens). Used to dynamically
size the output so it fits within the model's limits.
Returns:
Formatted string with search results
""" """
perf = get_perf_logger() perf = get_perf_logger()
t0 = time.perf_counter() t0 = time.perf_counter()

View file

@ -196,13 +196,8 @@ def _strip_wrapping_code_fences(text: str) -> str:
def _extract_metadata(content: str) -> dict[str, Any]: def _extract_metadata(content: str) -> dict[str, Any]:
"""Extract metadata from generated Markdown content.""" """Extract metadata from generated Markdown content."""
# Count section headings
headings = re.findall(r"^(#{1,6})\s+(.+)$", content, re.MULTILINE) headings = re.findall(r"^(#{1,6})\s+(.+)$", content, re.MULTILINE)
# Word count
word_count = len(content.split()) word_count = len(content.split())
# Character count
char_count = len(content) char_count = len(content)
return { return {
@ -227,12 +222,11 @@ def _parse_sections(content: str) -> list[dict[str, str]]:
in_code_block = False in_code_block = False
for line in lines: for line in lines:
# Track code blocks to avoid matching headings inside them # Track fences so headings inside code blocks aren't treated as splits.
stripped = line.strip() stripped = line.strip()
if stripped.startswith("```"): if stripped.startswith("```"):
in_code_block = not in_code_block in_code_block = not in_code_block
# Only split on # or ## headings (not ### or deeper) and only outside code blocks
is_section_heading = ( is_section_heading = (
not in_code_block not in_code_block
and re.match(r"^#{1,2}\s+", line) and re.match(r"^#{1,2}\s+", line)
@ -240,7 +234,6 @@ def _parse_sections(content: str) -> list[dict[str, str]]:
) )
if is_section_heading: if is_section_heading:
# Save previous section
if current_heading or current_body_lines: if current_heading or current_body_lines:
sections.append( sections.append(
{ {
@ -253,7 +246,6 @@ def _parse_sections(content: str) -> list[dict[str, str]]:
else: else:
current_body_lines.append(line) current_body_lines.append(line)
# Save last section
if current_heading or current_body_lines: if current_heading or current_body_lines:
sections.append( sections.append(
{ {
@ -292,7 +284,6 @@ async def _revise_with_sections(
Unchanged sections are kept byte-for-byte identical. Unchanged sections are kept byte-for-byte identical.
Returns the revised content, or None to trigger full-document revision fallback. Returns the revised content, or None to trigger full-document revision fallback.
""" """
# Parse report into sections
sections = _parse_sections(parent_content) sections = _parse_sections(parent_content)
if len(sections) < 2: if len(sections) < 2:
logger.info( logger.info(
@ -300,7 +291,6 @@ async def _revise_with_sections(
) )
return None return None
# Build a sections listing for the LLM
sections_listing = "" sections_listing = ""
for i, sec in enumerate(sections): for i, sec in enumerate(sections):
heading = sec["heading"] or "(preamble — content before first heading)" heading = sec["heading"] or "(preamble — content before first heading)"
@ -352,11 +342,9 @@ async def _revise_with_sections(
) )
return None return None
# Compute total operations for progress tracking
total_ops = len(modify_indices) + len(add_sections) total_ops = len(modify_indices) + len(add_sections)
current_op = 0 current_op = 0
# Emit plan summary
parts = [] parts = []
if modify_indices: if modify_indices:
parts.append( parts.append(
@ -394,7 +382,6 @@ async def _revise_with_sections(
current_op += 1 current_op += 1
sec = sections[idx] sec = sections[idx]
# Extract plain section name (strip markdown heading markers)
section_name = ( section_name = (
re.sub(r"^#+\s*", "", sec["heading"]).strip() re.sub(r"^#+\s*", "", sec["heading"]).strip()
if sec["heading"] if sec["heading"]
@ -412,7 +399,6 @@ async def _revise_with_sections(
f"{sec['heading']}\n\n{sec['body']}" if sec["heading"] else sec["body"] f"{sec['heading']}\n\n{sec['body']}" if sec["heading"] else sec["body"]
) )
# Build context from surrounding sections
context_parts = [] context_parts = []
if idx > 0: if idx > 0:
prev = sections[idx - 1] prev = sections[idx - 1]
@ -442,7 +428,6 @@ async def _revise_with_sections(
revised_text = resp.content revised_text = resp.content
if revised_text and isinstance(revised_text, str): if revised_text and isinstance(revised_text, str):
revised_text = _strip_wrapping_code_fences(revised_text).strip() revised_text = _strip_wrapping_code_fences(revised_text).strip()
# Parse the LLM output back into heading + body
revised_parsed = _parse_sections(revised_text) revised_parsed = _parse_sections(revised_text)
if revised_parsed: if revised_parsed:
revised_sections[idx] = revised_parsed[0] revised_sections[idx] = revised_parsed[0]
@ -465,7 +450,6 @@ async def _revise_with_sections(
heading = add_info.get("heading", "## New Section") heading = add_info.get("heading", "## New Section")
description = add_info.get("description", "") description = add_info.get("description", "")
# Extract plain section name for progress display
plain_heading = re.sub(r"^#+\s*", "", heading).strip() plain_heading = re.sub(r"^#+\s*", "", heading).strip()
dispatch_custom_event( dispatch_custom_event(
"report_progress", "report_progress",
@ -475,7 +459,6 @@ async def _revise_with_sections(
}, },
) )
# Build context from the surrounding sections at the insertion point
ctx_parts = [] ctx_parts = []
if 0 <= after_idx < len(revised_sections): if 0 <= after_idx < len(revised_sections):
before_sec = revised_sections[after_idx] before_sec = revised_sections[after_idx]
@ -542,36 +525,13 @@ def create_generate_report_tool(
available_connectors: list[str] | None = None, available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None, available_document_types: list[str] | None = None,
): ):
""" """Create the generate_report tool with injected dependencies.
Factory function to create the generate_report tool with injected dependencies.
The tool generates a Markdown report inline using the search space's Uses short-lived DB sessions per operation so no connection is held during
document summary LLM, saves it to the database, and returns immediately. the long LLM call. Generation: new reports are single-shot; revisions try
section-level first (unchanged sections preserved) and fall back to full-doc.
Uses short-lived database sessions for each DB operation so no connection Source strategies: provided/conversation (use source_content), kb_search
is held during the long LLM API call. (internal KB queries), auto (KB search only when source_content is thin).
Generation strategies:
- New reports: single-shot generation (1 LLM call)
- Revisions (targeted edits): section-level (unchanged sections preserved)
- Revisions (global changes): full-document revision fallback
Source strategies:
- "provided"/"conversation": use only the supplied source_content
- "kb_search": search the knowledge base internally using targeted queries
- "auto": use source_content if sufficient, otherwise fall back to KB search
Args:
search_space_id: The user's search space ID
thread_id: The chat thread ID for associating the report
connector_service: Optional connector service for internal KB search.
When provided, the tool can search the knowledge base internally
(used by the "kb_search" and "auto" source strategies).
available_connectors: Optional list of connector types available in the
search space (used to scope internal KB searches).
Returns:
A configured tool function for generating reports
""" """
@tool @tool
@ -693,7 +653,7 @@ def create_generate_report_tool(
Returns: Returns:
Dict with status, report_id, title, word_count, and message. Dict with status, report_id, title, word_count, and message.
""" """
# Initialize version tracking variables (used by _save_failed_report closure) # Shared with the _save_failed_report closure.
parent_report_content: str | None = None parent_report_content: str | None = None
report_group_id: int | None = None report_group_id: int | None = None
@ -733,7 +693,7 @@ def create_generate_report_tool(
session.add(failed_report) session.add(failed_report)
await session.commit() await session.commit()
await session.refresh(failed_report) await session.refresh(failed_report)
# If this is a new group (v1 failed), set group to self # New group (v1 failed): point the group at itself.
if not failed_report.report_group_id: if not failed_report.report_group_id:
failed_report.report_group_id = failed_report.id failed_report.report_group_id = failed_report.id
await session.commit() await session.commit()
@ -749,8 +709,8 @@ def create_generate_report_tool(
try: try:
# ── Phase 1: READ (short-lived session) ────────────────────── # ── Phase 1: READ (short-lived session) ──────────────────────
# Fetch parent report and LLM config, then close the session # Fetch parent report + LLM config, then release the connection
# so no DB connection is held during the long LLM call. # before the long LLM call.
async with shielded_async_session() as read_session: async with shielded_async_session() as read_session:
if parent_report_id: if parent_report_id:
parent_report = await read_session.get(Report, parent_report_id) parent_report = await read_session.get(Report, parent_report_id)
@ -768,7 +728,6 @@ def create_generate_report_tool(
) )
llm = await get_document_summary_llm(read_session, search_space_id) llm = await get_document_summary_llm(read_session, search_space_id)
# read_session closed — connection returned to pool
if not llm: if not llm:
error_msg = ( error_msg = (
@ -785,7 +744,6 @@ def create_generate_report_tool(
error=error_msg, error=error_msg,
) )
# Build the user instructions string
user_instructions_section = "" user_instructions_section = ""
if user_instructions: if user_instructions:
user_instructions_section = ( user_instructions_section = (
@ -829,7 +787,7 @@ def create_generate_report_tool(
try: try:
from .knowledge_base import search_knowledge_base_async from .knowledge_base import search_knowledge_base_async
# Run all queries in parallel, each with its own session # Each query gets its own short-lived session.
async def _run_single_query(q: str) -> str: async def _run_single_query(q: str) -> str:
async with shielded_async_session() as kb_session: async with shielded_async_session() as kb_session:
kb_connector_svc = ConnectorService( kb_connector_svc = ConnectorService(
@ -849,7 +807,6 @@ def create_generate_report_tool(
*[_run_single_query(q) for q in search_queries[:5]] *[_run_single_query(q) for q in search_queries[:5]]
) )
# Merge non-empty results into source_content
kb_text_parts = [r for r in kb_results if r and r.strip()] kb_text_parts = [r for r in kb_results if r and r.strip()]
if kb_text_parts: if kb_text_parts:
kb_combined = "\n\n---\n\n".join(kb_text_parts) kb_combined = "\n\n---\n\n".join(kb_text_parts)
@ -903,9 +860,9 @@ def create_generate_report_tool(
"provided. Using source_content as-is." "provided. Using source_content as-is."
) )
capped_source = effective_source[:100000] # Cap source content capped_source = effective_source[:100000]
# Length constraint — only when user explicitly asks for brevity # Length constraint only when the user explicitly asked for brevity.
length_instruction = "" length_instruction = ""
if report_style == "brief": if report_style == "brief":
length_instruction = ( length_instruction = (
@ -920,11 +877,8 @@ def create_generate_report_tool(
report_content: str | None = None report_content: str | None = None
if parent_report_content: if parent_report_content:
# ─── REVISION MODE ─────────────────────────────────────── # Revision mode: section-level first (preserves untouched
# Strategy: Try section-level revision first (preserves # sections), falling back to full-doc revision.
# unchanged sections byte-for-byte). Falls back to full-
# document revision if section identification fails or if
# all sections need changes.
dispatch_custom_event( dispatch_custom_event(
"report_progress", "report_progress",
{ {
@ -946,7 +900,6 @@ def create_generate_report_tool(
) )
if report_content is None: if report_content is None:
# Fallback: full-document revision
dispatch_custom_event( dispatch_custom_event(
"report_progress", "report_progress",
{"phase": "writing", "message": "Rewriting your full report"}, {"phase": "writing", "message": "Rewriting your full report"},
@ -969,9 +922,7 @@ def create_generate_report_tool(
report_content = response.content report_content = response.content
else: else:
# ─── NEW REPORT MODE ───────────────────────────────────── # New report: single-shot generation (one LLM call).
# Single-shot generation: one LLM call produces the full
# report. Fast, globally coherent, and cost-efficient.
dispatch_custom_event( dispatch_custom_event(
"report_progress", "report_progress",
{"phase": "writing", "message": "Writing your report"}, {"phase": "writing", "message": "Writing your report"},
@ -991,8 +942,6 @@ def create_generate_report_tool(
response = await llm.ainvoke([HumanMessage(content=prompt)]) response = await llm.ainvoke([HumanMessage(content=prompt)])
report_content = response.content report_content = response.content
# ── Validate LLM output ──────────────────────────────────────
if not report_content or not isinstance(report_content, str): if not report_content or not isinstance(report_content, str):
error_msg = "LLM returned empty or invalid content" error_msg = "LLM returned empty or invalid content"
report_id = await _save_failed_report(error_msg) report_id = await _save_failed_report(error_msg)
@ -1029,14 +978,12 @@ def create_generate_report_tool(
if report_content.rstrip().endswith("---"): if report_content.rstrip().endswith("---"):
report_content = report_content.rstrip()[:-3].rstrip() report_content = report_content.rstrip()[:-3].rstrip()
# Append exactly one standard disclaimer # Append exactly one standard footer.
report_content += "\n\n---\n\n" + _REPORT_FOOTER report_content += "\n\n---\n\n" + _REPORT_FOOTER
# Extract metadata (includes "status": "ready")
metadata = _extract_metadata(report_content) metadata = _extract_metadata(report_content)
# ── Phase 3: WRITE (short-lived session) ───────────────────── # ── Phase 3: WRITE (short-lived session) ─────────────────────
# Save the report to the database, then close the session.
async with shielded_async_session() as write_session: async with shielded_async_session() as write_session:
report = Report( report = Report(
title=topic, title=topic,
@ -1051,14 +998,13 @@ def create_generate_report_tool(
await write_session.commit() await write_session.commit()
await write_session.refresh(report) await write_session.refresh(report)
# If this is a brand-new report (v1), set report_group_id = own id # Brand-new report (v1): point the group at itself.
if not report.report_group_id: if not report.report_group_id:
report.report_group_id = report.id report.report_group_id = report.id
await write_session.commit() await write_session.commit()
saved_report_id = report.id saved_report_id = report.id
saved_group_id = report.report_group_id saved_group_id = report.report_group_id
# write_session closed — connection returned to pool
logger.info( logger.info(
f"[generate_report] Created report {saved_report_id} " f"[generate_report] Created report {saved_report_id} "

View file

@ -23,7 +23,6 @@ def extract_domain(url: str) -> str:
try: try:
parsed = urlparse(url) parsed = urlparse(url)
domain = parsed.netloc domain = parsed.netloc
# Remove 'www.' prefix if present
if domain.startswith("www."): if domain.startswith("www."):
domain = domain[4:] domain = domain[4:]
return domain return domain
@ -47,14 +46,13 @@ def truncate_content(content: str, max_length: int = 50000) -> tuple[str, bool]:
if len(content) <= max_length: if len(content) <= max_length:
return content, False return content, False
# Try to truncate at a sentence boundary # Prefer truncating at a sentence/paragraph boundary.
truncated = content[:max_length] truncated = content[:max_length]
last_period = truncated.rfind(".") last_period = truncated.rfind(".")
last_newline = truncated.rfind("\n\n") last_newline = truncated.rfind("\n\n")
# Use the later of the two boundaries, or just truncate
boundary = max(last_period, last_newline) boundary = max(last_period, last_newline)
if boundary > max_length * 0.8: # Only use boundary if it's not too far back if boundary > max_length * 0.8: # only if the boundary isn't too far back
truncated = content[: boundary + 1] truncated = content[: boundary + 1]
return truncated + "\n\n[Content truncated...]", True return truncated + "\n\n[Content truncated...]", True
@ -105,8 +103,8 @@ async def _scrape_youtube_video(
http_client.proxies.update(residential_proxies) http_client.proxies.update(residential_proxies)
ytt_api = YouTubeTranscriptApi(http_client=http_client) ytt_api = YouTubeTranscriptApi(http_client=http_client)
# List all available transcripts and pick the first one # Pick the first transcript (video's primary language) rather than
# (the video's primary language) instead of defaulting to English # defaulting to English.
transcript_list = ytt_api.list(video_id) transcript_list = ytt_api.list(video_id)
transcript = next(iter(transcript_list)) transcript = next(iter(transcript_list))
captions = transcript.fetch() captions = transcript.fetch()
@ -128,10 +126,8 @@ async def _scrape_youtube_video(
logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}") logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}")
transcript_text = f"No captions available for this video. Error: {e!s}" transcript_text = f"No captions available for this video. Error: {e!s}"
# Build combined content
content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}" content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}"
# Truncate if needed
content, was_truncated = truncate_content(content, max_length) content, was_truncated = truncate_content(content, max_length)
word_count = len(content.split()) word_count = len(content.split())
@ -206,20 +202,16 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
scrape_id = generate_scrape_id(url) scrape_id = generate_scrape_id(url)
domain = extract_domain(url) domain = extract_domain(url)
# Validate and normalize URL
if not url.startswith(("http://", "https://")): if not url.startswith(("http://", "https://")):
url = f"https://{url}" url = f"https://{url}"
try: try:
# Check if this is a YouTube URL and use transcript API instead # YouTube URLs use the transcript API instead of crawling.
video_id = get_youtube_video_id(url) video_id = get_youtube_video_id(url)
if video_id: if video_id:
return await _scrape_youtube_video(url, video_id, max_length) return await _scrape_youtube_video(url, video_id, max_length)
# Create webcrawler connector
connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key) connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key)
# Crawl the URL
result, error = await connector.crawl_url(url, formats=["markdown"]) result, error = await connector.crawl_url(url, formats=["markdown"])
if error: if error:
@ -244,28 +236,21 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
"error": "No content returned from crawler", "error": "No content returned from crawler",
} }
# Extract content and metadata
content = result.get("content", "") content = result.get("content", "")
metadata = result.get("metadata", {}) metadata = result.get("metadata", {})
# Get title from metadata
title = metadata.get("title", "") title = metadata.get("title", "")
if not title: if not title:
title = domain or url.split("/")[-1] or "Webpage" title = domain or url.split("/")[-1] or "Webpage"
# Get description from metadata
description = metadata.get("description", "") description = metadata.get("description", "")
if not description and content: if not description and content:
# Use first paragraph as description
first_para = content.split("\n\n")[0] if content else "" first_para = content.split("\n\n")[0] if content else ""
description = ( description = (
first_para[:300] + "..." if len(first_para) > 300 else first_para first_para[:300] + "..." if len(first_para) > 300 else first_para
) )
# Truncate content if needed
content, was_truncated = truncate_content(content, max_length) content, was_truncated = truncate_content(content, max_length)
# Calculate word count
word_count = len(content.split()) word_count = len(content.split())
return { return {

View file

@ -92,15 +92,9 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
yield chunk yield chunk
# Provider mapping for LiteLLM model string construction. # Re-exported under the historical name ``PROVIDER_MAP``. Source of truth lives
# # in provider_capabilities so the YAML loader can resolve prefixes during
# Single source of truth lives in # app.config init without importing the agent/tools tree.
# :mod:`app.services.provider_capabilities` so the YAML loader (which
# runs during ``app.config`` class-body init) can resolve provider
# prefixes without dragging the agent / tools tree into module load
# order. Re-exported here under the historical ``PROVIDER_MAP`` name
# so existing callers (``llm_router_service``, ``image_gen_router_service``,
# tests) keep working unchanged.
from app.services.provider_capabilities import ( # noqa: E402 from app.services.provider_capabilities import ( # noqa: E402
_PROVIDER_PREFIX_MAP as PROVIDER_MAP, _PROVIDER_PREFIX_MAP as PROVIDER_MAP,
) )
@ -157,25 +151,14 @@ class AgentConfig:
anonymous_enabled: bool = False anonymous_enabled: bool = False
quota_reserve_tokens: int | None = None quota_reserve_tokens: int | None = None
# Capability flag: best-effort True for the chat selector / catalog. # Default-allow: only the streaming safety net (is_known_text_only_chat_model)
# Resolved via :func:`provider_capabilities.derive_supports_image_input` # actually blocks on False, so defaulting False would silently hide
# which prefers OpenRouter's ``architecture.input_modalities`` and # vision-capable models. Resolved via derive_supports_image_input.
# otherwise consults LiteLLM's authoritative model map. Default True
# is the conservative-allow stance — the streaming-task safety net
# (``is_known_text_only_chat_model``) is the *only* place a False
# actually blocks a request. Setting this to False here without an
# authoritative source would silently hide vision-capable models
# (the regression we're fixing).
supports_image_input: bool = True supports_image_input: bool = True
@classmethod @classmethod
def from_auto_mode(cls) -> "AgentConfig": def from_auto_mode(cls) -> "AgentConfig":
""" """Build an AgentConfig for Auto mode (LiteLLM Router load balancing)."""
Create an AgentConfig for Auto mode (LiteLLM Router load balancing).
Returns:
AgentConfig instance configured for Auto mode
"""
return cls( return cls(
provider="AUTO", provider="AUTO",
model_name="auto", model_name="auto",
@ -193,27 +176,15 @@ class AgentConfig:
is_premium=False, is_premium=False,
anonymous_enabled=False, anonymous_enabled=False,
quota_reserve_tokens=None, quota_reserve_tokens=None,
# Auto routes across the configured pool, which usually # Auto fails over across the pool, so a non-vision deployment's 404
# contains at least one vision-capable deployment; the router # is just an allowed_fails event rather than a hard block.
# will surface a 404 from a non-vision deployment as a normal
# ``allowed_fails`` event and fail over rather than blocking
# the request outright.
supports_image_input=True, supports_image_input=True,
) )
@classmethod @classmethod
def from_new_llm_config(cls, config) -> "AgentConfig": def from_new_llm_config(cls, config) -> "AgentConfig":
""" """Build an AgentConfig from a NewLLMConfig database model."""
Create an AgentConfig from a NewLLMConfig database model. # Lazy import: keeps provider_capabilities (and litellm) out of init order.
Args:
config: NewLLMConfig database model instance
Returns:
AgentConfig instance
"""
# Lazy import to avoid pulling provider_capabilities (and its
# transitive litellm import) into module-init order.
from app.services.provider_capabilities import derive_supports_image_input from app.services.provider_capabilities import derive_supports_image_input
provider_value = ( provider_value = (
@ -245,10 +216,8 @@ class AgentConfig:
is_premium=False, is_premium=False,
anonymous_enabled=False, anonymous_enabled=False,
quota_reserve_tokens=None, quota_reserve_tokens=None,
# BYOK rows have no operator-curated capability flag, so we # BYOK rows have no curated flag; ask LiteLLM (default-allow on
# ask LiteLLM (default-allow on unknown). The streaming # unknown). The streaming safety net still blocks explicit text-only.
# safety net still blocks if the model is *explicitly*
# marked text-only.
supports_image_input=derive_supports_image_input( supports_image_input=derive_supports_image_input(
provider=provider_value, provider=provider_value,
model_name=config.model_name, model_name=config.model_name,
@ -259,25 +228,14 @@ class AgentConfig:
@classmethod @classmethod
def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig": def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig":
"""Build an AgentConfig from a YAML configuration dictionary.
Supports the same prompt fields as NewLLMConfig (system_instructions,
use_default_system_instructions, citations_enabled).
""" """
Create an AgentConfig from a YAML configuration dictionary. # Lazy import: keeps provider_capabilities (and litellm) out of init order.
YAML configs now support the same prompt configuration fields as NewLLMConfig:
- system_instructions: Custom system instructions (empty string uses defaults)
- use_default_system_instructions: Whether to use default instructions
- citations_enabled: Whether citations are enabled
Args:
yaml_config: Configuration dictionary from YAML file
Returns:
AgentConfig instance
"""
# Lazy import to avoid pulling provider_capabilities (and its
# transitive litellm import) into module-init order.
from app.services.provider_capabilities import derive_supports_image_input from app.services.provider_capabilities import derive_supports_image_input
# Get system instructions from YAML, default to empty string
system_instructions = yaml_config.get("system_instructions", "") system_instructions = yaml_config.get("system_instructions", "")
provider = yaml_config.get("provider", "").upper() provider = yaml_config.get("provider", "").upper()
@ -290,13 +248,8 @@ class AgentConfig:
else None else None
) )
# Explicit YAML override wins; otherwise derive from LiteLLM / # Explicit YAML override wins; otherwise re-derive (the hot-reload file
# OpenRouter modalities. The YAML loader already populates this # fallback reaches this method without the loader having populated it).
# field, but this method is also called from
# ``load_global_llm_config_by_id``'s file fallback (hot reload),
# so we re-derive here for safety. The bool() coercion preserves
# the loader's behaviour for explicit ``true`` / ``false``
# strings that PyYAML may surface.
if "supports_image_input" in yaml_config: if "supports_image_input" in yaml_config:
supports_image_input = bool(yaml_config.get("supports_image_input")) supports_image_input = bool(yaml_config.get("supports_image_input"))
else: else:
@ -314,7 +267,6 @@ class AgentConfig:
api_base=yaml_config.get("api_base"), api_base=yaml_config.get("api_base"),
custom_provider=custom_provider, custom_provider=custom_provider,
litellm_params=yaml_config.get("litellm_params"), litellm_params=yaml_config.get("litellm_params"),
# Prompt configuration from YAML (with defaults for backwards compatibility)
system_instructions=system_instructions if system_instructions else None, system_instructions=system_instructions if system_instructions else None,
use_default_system_instructions=yaml_config.get( use_default_system_instructions=yaml_config.get(
"use_default_system_instructions", True "use_default_system_instructions", True
@ -332,20 +284,10 @@ class AgentConfig:
def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None: def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None:
""" """Load a specific LLM config from global_llm_config.yaml."""
Load a specific LLM config from global_llm_config.yaml.
Args:
llm_config_id: The id of the config to load (default: -1)
Returns:
LLM config dict or None if not found
"""
# Get the config file path
base_dir = Path(__file__).resolve().parent.parent.parent.parent base_dir = Path(__file__).resolve().parent.parent.parent.parent
config_file = base_dir / "app" / "config" / "global_llm_config.yaml" config_file = base_dir / "app" / "config" / "global_llm_config.yaml"
# Fallback to example file if main config doesn't exist
if not config_file.exists(): if not config_file.exists():
config_file = base_dir / "app" / "config" / "global_llm_config.example.yaml" config_file = base_dir / "app" / "config" / "global_llm_config.example.yaml"
if not config_file.exists(): if not config_file.exists():
@ -368,24 +310,17 @@ def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None:
def load_global_llm_config_by_id(llm_config_id: int) -> dict | None: def load_global_llm_config_by_id(llm_config_id: int) -> dict | None:
""" """Load a global LLM config by ID, checking in-memory configs first.
Load a global LLM config by ID, checking in-memory configs first.
This handles both static YAML configs and dynamically injected configs In-memory covers both static YAML and dynamically injected configs (e.g.
(e.g. OpenRouter integration models that only exist in memory). OpenRouter integration models that only exist in memory).
Args:
llm_config_id: The negative ID of the global config to load
Returns:
LLM config dict or None if not found
""" """
from app.config import config as app_config from app.config import config as app_config
for cfg in app_config.GLOBAL_LLM_CONFIGS: for cfg in app_config.GLOBAL_LLM_CONFIGS:
if cfg.get("id") == llm_config_id: if cfg.get("id") == llm_config_id:
return cfg return cfg
# Fallback to YAML file read (covers edge cases like hot-reload) # Fallback to YAML file read (covers hot-reload edge cases).
return load_llm_config_from_yaml(llm_config_id) return load_llm_config_from_yaml(llm_config_id)
@ -393,17 +328,7 @@ async def load_new_llm_config_from_db(
session: AsyncSession, session: AsyncSession,
config_id: int, config_id: int,
) -> "AgentConfig | None": ) -> "AgentConfig | None":
""" """Load a NewLLMConfig from the database by ID."""
Load a NewLLMConfig from the database by ID.
Args:
session: AsyncSession for database access
config_id: The ID of the NewLLMConfig to load
Returns:
AgentConfig instance or None if not found
"""
# Import here to avoid circular imports
from app.db import NewLLMConfig from app.db import NewLLMConfig
try: try:
@ -426,26 +351,13 @@ async def load_agent_llm_config_for_search_space(
session: AsyncSession, session: AsyncSession,
search_space_id: int, search_space_id: int,
) -> "AgentConfig | None": ) -> "AgentConfig | None":
"""Load the agent LLM config for a search space via its agent_llm_id.
Positive id -> DB; negative -> YAML; None -> first global config (-1).
""" """
Load the agent LLM configuration for a search space.
This loads the LLM config based on the search space's agent_llm_id setting:
- Positive ID: Load from NewLLMConfig database table
- Negative ID: Load from YAML global configs
- None: Falls back to first global config (id=-1)
Args:
session: AsyncSession for database access
search_space_id: The search space ID
Returns:
AgentConfig instance or None if not found
"""
# Import here to avoid circular imports
from app.db import SearchSpace from app.db import SearchSpace
try: try:
# Get the search space to check its agent_llm_id preference
result = await session.execute( result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == search_space_id) select(SearchSpace).filter(SearchSpace.id == search_space_id)
) )
@ -455,12 +367,9 @@ async def load_agent_llm_config_for_search_space(
print(f"Error: SearchSpace with id {search_space_id} not found") print(f"Error: SearchSpace with id {search_space_id} not found")
return None return None
# Use agent_llm_id from search space, fallback to -1 (first global config)
config_id = ( config_id = (
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
) )
# Load the config using the unified loader
return await load_agent_config(session, config_id, search_space_id) return await load_agent_config(session, config_id, search_space_id)
except Exception as e: except Exception as e:
print(f"Error loading agent LLM config for search space {search_space_id}: {e}") print(f"Error loading agent LLM config for search space {search_space_id}: {e}")
@ -472,23 +381,7 @@ async def load_agent_config(
config_id: int, config_id: int,
search_space_id: int | None = None, search_space_id: int | None = None,
) -> "AgentConfig | None": ) -> "AgentConfig | None":
""" """Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB."""
Load an agent configuration, supporting Auto mode, YAML, and database configs.
This is the main entry point for loading configurations:
- ID 0: Auto mode (uses LiteLLM Router for load balancing)
- Negative IDs: Load from YAML file (global configs)
- Positive IDs: Load from NewLLMConfig database table
Args:
session: AsyncSession for database access
config_id: The config ID (0 for Auto, negative for YAML, positive for database)
search_space_id: Optional search space ID for context
Returns:
AgentConfig instance or None if not found
"""
# Auto mode (ID 0) - use LiteLLM Router
if is_auto_mode(config_id): if is_auto_mode(config_id):
if not LLMRouterService.is_initialized(): if not LLMRouterService.is_initialized():
print("Error: Auto mode requested but LLM Router not initialized") print("Error: Auto mode requested but LLM Router not initialized")
@ -496,33 +389,22 @@ async def load_agent_config(
return AgentConfig.from_auto_mode() return AgentConfig.from_auto_mode()
if config_id < 0: if config_id < 0:
# Check in-memory configs first (includes static YAML + dynamic OpenRouter) # In-memory covers static YAML + dynamic OpenRouter configs.
from app.config import config as app_config from app.config import config as app_config
for cfg in app_config.GLOBAL_LLM_CONFIGS: for cfg in app_config.GLOBAL_LLM_CONFIGS:
if cfg.get("id") == config_id: if cfg.get("id") == config_id:
return AgentConfig.from_yaml_config(cfg) return AgentConfig.from_yaml_config(cfg)
# Fallback to YAML file read for safety
yaml_config = load_llm_config_from_yaml(config_id) yaml_config = load_llm_config_from_yaml(config_id)
if yaml_config: if yaml_config:
return AgentConfig.from_yaml_config(yaml_config) return AgentConfig.from_yaml_config(yaml_config)
return None return None
else: else:
# Load from database (NewLLMConfig)
return await load_new_llm_config_from_db(session, config_id) return await load_new_llm_config_from_db(session, config_id)
def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
""" """Create a ChatLiteLLM instance from a global LLM config dictionary."""
Create a ChatLiteLLM instance from a global LLM config dictionary.
Args:
llm_config: LLM configuration dictionary from YAML
Returns:
ChatLiteLLM instance or None on error
"""
# Build the model string
if llm_config.get("custom_provider"): if llm_config.get("custom_provider"):
model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}" model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}"
else: else:
@ -530,27 +412,20 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
provider_prefix = PROVIDER_MAP.get(provider, provider.lower()) provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{llm_config['model_name']}" model_string = f"{provider_prefix}/{llm_config['model_name']}"
# Create ChatLiteLLM instance with streaming enabled
litellm_kwargs = { litellm_kwargs = {
"model": model_string, "model": model_string,
"api_key": llm_config.get("api_key"), "api_key": llm_config.get("api_key"),
"streaming": True, # Enable streaming for real-time token streaming "streaming": True,
} }
# Add optional parameters
if llm_config.get("api_base"): if llm_config.get("api_base"):
litellm_kwargs["api_base"] = llm_config["api_base"] litellm_kwargs["api_base"] = llm_config["api_base"]
# Add any additional litellm parameters
if llm_config.get("litellm_params"): if llm_config.get("litellm_params"):
litellm_kwargs.update(llm_config["litellm_params"]) litellm_kwargs.update(llm_config["litellm_params"])
llm = SanitizedChatLiteLLM(**litellm_kwargs) llm = SanitizedChatLiteLLM(**litellm_kwargs)
_attach_model_profile(llm, model_string) _attach_model_profile(llm, model_string)
# Configure LiteLLM-native prompt caching (cache_control_injection_points # agent_config=None: the YAML path lacks structured provider intent, so set
# for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.). # only the universal cache_control_injection_points.
# ``agent_config=None`` here — the YAML path doesn't have provider intent
# in a structured form, so we set only the universal injection points.
apply_litellm_prompt_caching(llm) apply_litellm_prompt_caching(llm)
return llm return llm
@ -558,19 +433,7 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
def create_chat_litellm_from_agent_config( def create_chat_litellm_from_agent_config(
agent_config: AgentConfig, agent_config: AgentConfig,
) -> ChatLiteLLM | ChatLiteLLMRouter | None: ) -> ChatLiteLLM | ChatLiteLLMRouter | None:
""" """Create a ChatLiteLLM (or, for Auto mode, a load-balancing router) from config."""
Create a ChatLiteLLM or ChatLiteLLMRouter instance from an AgentConfig.
For Auto mode configs, returns a ChatLiteLLMRouter that uses LiteLLM Router
for automatic load balancing across available providers.
Args:
agent_config: AgentConfig instance
Returns:
ChatLiteLLM or ChatLiteLLMRouter instance, or None on error
"""
# Handle Auto mode - return ChatLiteLLMRouter
if agent_config.is_auto_mode: if agent_config.is_auto_mode:
if not LLMRouterService.is_initialized(): if not LLMRouterService.is_initialized():
print("Error: Auto mode requested but LLM Router not initialized") print("Error: Auto mode requested but LLM Router not initialized")
@ -578,19 +441,14 @@ def create_chat_litellm_from_agent_config(
try: try:
router_llm = get_auto_mode_llm() router_llm = get_auto_mode_llm()
if router_llm is not None: if router_llm is not None:
# Universal cache_control_injection_points only — auto-mode # Universal injection points only: auto-mode fans out across
# fans out across providers, so OpenAI-only kwargs (e.g. # providers, so provider-specific kwargs have no known target.
# ``prompt_cache_key``) are left off here. ``drop_params``
# would strip them at the provider boundary anyway, but
# there's no point setting them when we don't know the
# destination.
apply_litellm_prompt_caching(router_llm, agent_config=agent_config) apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
return router_llm return router_llm
except Exception as e: except Exception as e:
print(f"Error creating ChatLiteLLMRouter: {e}") print(f"Error creating ChatLiteLLMRouter: {e}")
return None return None
# Build the model string
if agent_config.custom_provider: if agent_config.custom_provider:
model_string = f"{agent_config.custom_provider}/{agent_config.model_name}" model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
else: else:
@ -599,26 +457,19 @@ def create_chat_litellm_from_agent_config(
) )
model_string = f"{provider_prefix}/{agent_config.model_name}" model_string = f"{provider_prefix}/{agent_config.model_name}"
# Create ChatLiteLLM instance with streaming enabled
litellm_kwargs = { litellm_kwargs = {
"model": model_string, "model": model_string,
"api_key": agent_config.api_key, "api_key": agent_config.api_key,
"streaming": True, # Enable streaming for real-time token streaming "streaming": True,
} }
# Add optional parameters
if agent_config.api_base: if agent_config.api_base:
litellm_kwargs["api_base"] = agent_config.api_base litellm_kwargs["api_base"] = agent_config.api_base
# Add any additional litellm parameters
if agent_config.litellm_params: if agent_config.litellm_params:
litellm_kwargs.update(agent_config.litellm_params) litellm_kwargs.update(agent_config.litellm_params)
llm = SanitizedChatLiteLLM(**litellm_kwargs) llm = SanitizedChatLiteLLM(**litellm_kwargs)
_attach_model_profile(llm, model_string) _attach_model_profile(llm, model_string)
# Build-time prompt caching: sets ``cache_control_injection_points`` for # Build-time caching only; the per-thread prompt_cache_key is layered on
# all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``. # later in create_surfsense_deep_agent once thread_id is known.
# Per-thread ``prompt_cache_key`` is layered on later in
# ``create_surfsense_deep_agent`` once ``thread_id`` is known.
apply_litellm_prompt_caching(llm, agent_config=agent_config) apply_litellm_prompt_caching(llm, agent_config=agent_config)
return llm return llm

View file

@ -1,63 +1,28 @@
r"""LiteLLM-native prompt caching configuration for SurfSense agents. r"""LiteLLM-native prompt caching for SurfSense agents.
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never Replaces the legacy ``AnthropicPromptCachingMiddleware`` (its
activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)`` ``isinstance(model, ChatAnthropic)`` gate never matched our LiteLLM stack)
gate always failed) with LiteLLM's universal caching mechanism. with LiteLLM's universal ``cache_control_injection_points`` mechanism, which
covers the Anthropic/Bedrock/Vertex/Gemini/OpenRouter/etc. marker-based
providers and the auto-caching OpenAI family.
Coverage: Two breakpoints per request:
- Marker-based providers (need ``cache_control`` injection, which LiteLLM - ``index: 0`` pins the head-of-request system prompt. We use ``index: 0``,
performs automatically when ``cache_control_injection_points`` is set): NOT ``role: system``: ``before_agent`` injectors accumulate many
``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``, SystemMessages, and tagging all of them overflows Anthropic's 4-block cap
``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/`` (upstream 400 via OpenRouter).
(Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM). - ``index: -1`` pins the latest message so longest-prefix lookup compounds
- Auto-cached (LiteLLM strips the marker silently): ``openai/``, multi-turn savings.
``deepseek/``, ``xai/`` — these caches automatically for prompts ≥1024
tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``.
We inject **two** breakpoints per request: OpenAI-family configs also get ``prompt_cache_key`` (per-thread routing hint)
and ``prompt_cache_retention="24h"``. Azure is excluded from the latter
because LiteLLM's Azure transformer drops it (see
``_PROMPT_CACHE_RETENTION_PROVIDERS``).
- ``index: 0`` pins the SurfSense system prompt at the head of the Safety net: ``litellm.drop_params=True`` (set in ``app.services.llm_service``)
request (provider variant, citation rules, tool catalog, KB tree, strips any kwarg the destination provider rejects, so an auto-mode fallback
skills metadata). The langchain agent factory always prepends can't 400 on these extras.
``request.system_message`` at index 0 (see ``factory.py``
``_execute_model_async``), so this targets exactly the main system
prompt regardless of how many other ``SystemMessage``\ s the
``before_agent`` injectors (priority, tree, memory, file-intent,
anonymous-doc) have inserted into ``state["messages"]``. Using
``role: system`` here would apply ``cache_control`` to **every**
system-role message and trip Anthropic's hard cap of 4 cache
breakpoints per request once the conversation accumulates enough
injected system messages — which surfaces as the upstream 400
``A maximum of 4 blocks with cache_control may be provided. Found N``
via OpenRouter→Anthropic.
- ``index: -1`` pins the latest message so multi-turn savings compound:
Anthropic-family providers use longest-matching-prefix lookup, so turn
N+1 still reads turn N's cache up to the shared prefix.
For OpenAI-family configs we additionally pass:
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
raises hit rate by sending requests with a shared prefix to the same
backend. Supported by ``openai/``, ``deepseek/``, ``xai/``, and
``azure/`` (added to LiteLLM's Azure transformer in
https://github.com/BerriAI/litellm/pull/20989, Feb 2026; verified
against ``AzureOpenAIConfig.get_supported_openai_params`` in our
installed litellm 1.83.14 for ``azure/gpt-4o``, ``azure/gpt-4o-mini``,
``azure/gpt-5.4``, ``azure/gpt-5.4-mini``).
- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default
5-10 min in-memory cache. Set ONLY for OpenAI/DeepSeek/xAI: Azure's
server-side support landed in Microsoft's docs on 2026-05-13 but
LiteLLM 1.83.14's Azure transformer still omits it from its supported
params list, so it gets silently dropped by ``litellm.drop_params``.
Azure's default in-memory retention (5-10 min, max 1 h) already
bridges intra-conversation turns; revisit when LiteLLM bumps Azure.
Safety net: ``litellm.drop_params=True`` is set globally in
``app.services.llm_service`` at module-load time. Any kwarg the destination
provider doesn't recognise is auto-stripped at the provider transformer
layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on
``prompt_cache_key`` etc.
""" """
from __future__ import annotations from __future__ import annotations
@ -73,57 +38,29 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Two cache breakpoints per request: the head of the request (index 0) and
# the latest message (index -1). Anthropic caps a request at 4
# ``cache_control`` blocks, so using 2 leaves headroom.
#
# ``index: 0`` is deliberate (not ``role: system``): the ``before_agent``
# middlewares insert ``SystemMessage`` instances into ``state["messages"]``
# that accumulate across turns, and ``role: system`` would tag every one of
# them with ``cache_control`` and overflow the 4-block limit.
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
    {"location": "message", "index": 0},
    {"location": "message", "index": -1},
)

# Uppercase ``AgentConfig.provider`` values that accept the OpenAI
# ``prompt_cache_key`` routing hint. This is a strict whitelist: many other
# providers route through litellm's ``openai`` prefix without implementing
# the OpenAI prompt-cache surface, so the litellm prefix alone cannot be used
# to infer the provider family.
_PROMPT_CACHE_KEY_PROVIDERS: frozenset[str] = frozenset(
    ("OPENAI", "DEEPSEEK", "XAI", "AZURE", "AZURE_OPENAI")
)

# Subset of ``_PROMPT_CACHE_KEY_PROVIDERS`` that also accepts
# ``prompt_cache_retention="24h"``. Azure is excluded because LiteLLM's Azure
# transformer omits the param, so ``drop_params`` silently strips it.
_PROMPT_CACHE_RETENTION_PROVIDERS: frozenset[str] = frozenset(
    ("OPENAI", "DEEPSEEK", "XAI")
)
def _is_router_llm(llm: BaseChatModel) -> bool:
    """Detect ``ChatLiteLLMRouter`` (auto-mode) by class name, without importing it.

    Importing ``app.services.llm_router_service`` at module-load time would
    create a cycle via ``llm_config -> prompt_caching -> llm_router_service``,
    so the check compares class names instead of using ``isinstance``. The
    whole MRO is scanned so that a subclass of the router is still recognised
    as a router rather than being treated as a single-destination model.

    Args:
        llm: Chat model instance to classify.

    Returns:
        ``True`` when the class of ``llm``, or any of its base classes, is
        named ``ChatLiteLLMRouter``; ``False`` otherwise.
    """
    return any(cls.__name__ == "ChatLiteLLMRouter" for cls in type(llm).__mro__)
@ -188,21 +125,10 @@ def apply_litellm_prompt_caching(
) -> None: ) -> None:
"""Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter. """Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter.
Idempotent values already present in ``llm.model_kwargs`` (e.g. from Idempotent (existing ``model_kwargs`` values are preserved) and mutates
``agent_config.litellm_params`` overrides) are preserved. Mutates ``llm.model_kwargs`` in place. Without ``agent_config`` (or in auto-mode)
``llm.model_kwargs`` in place; the kwargs flow to ``litellm.completion`` only the universal injection points are set; ``thread_id`` adds a per-thread
via ``ChatLiteLLM._default_params`` and via ``self.model_kwargs`` merge ``prompt_cache_key`` for OpenAI-family providers to improve routing affinity.
in our custom ``ChatLiteLLMRouter``.
Args:
llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance.
agent_config: Optional ``AgentConfig`` driving provider-specific
behaviour. When omitted (or auto-mode), only the universal
``cache_control_injection_points`` are set.
thread_id: Optional thread id used to construct a per-thread
``prompt_cache_key`` for OpenAI-family providers. Caching still
works without it (server-side automatic), but the key improves
backend routing affinity and therefore hit rate.
""" """
model_kwargs = _get_or_init_model_kwargs(llm) model_kwargs = _get_or_init_model_kwargs(llm)
if model_kwargs is None: if model_kwargs is None:
@ -217,11 +143,8 @@ def apply_litellm_prompt_caching(
dict(point) for point in _DEFAULT_INJECTION_POINTS dict(point) for point in _DEFAULT_INJECTION_POINTS
] ]
# OpenAI-style extras only when we statically know the destination # OpenAI-style extras only when the destination is statically known. The
# accepts them. Auto-mode router fans out across mixed providers so # auto-mode router fans out across mixed providers, so skip them there.
# we can't safely set destination-specific kwargs there (drop_params
# would strip them but it's wasteful to set them in the first
# place).
if _is_router_llm(llm): if _is_router_llm(llm):
return return

View file

@ -1,26 +1,13 @@
""" """SurfSense compaction middleware.
SurfSense compaction middleware.
Subclasses :class:`deepagents.middleware.summarization.SummarizationMiddleware` Extends ``SummarizationMiddleware`` with three SurfSense behaviors:
to add SurfSense-specific behavior:
1. **Structured summary template** (OpenCode-style ``## Goal / Constraints / 1. A structured summary template (:data:`SURFSENSE_SUMMARY_PROMPT`) instead of
Progress / Key Decisions / Next Steps / Critical Context / Relevant Files``) the base freeform prompt.
see :data:`SURFSENSE_SUMMARY_PROMPT` below. The base 2. Protected SystemMessages (injected hints like ``<priority_documents>``) are
``SummarizationMiddleware`` only ships a freeform "summarize this" kept verbatim instead of being summarized away.
prompt; the structured template is ported from OpenCode's 3. ``content=None`` is sanitized before ``get_buffer_string`` (some providers
``compaction.ts``. stream tool-only AIMessages with ``None`` content, which would crash it).
2. **Protect SurfSense-specific SystemMessages** so injected hints
(``<priority_documents>``, ``<workspace_tree>``, ``<file_operation_contract>``,
``<user_memory>``, ``<team_memory>``, ``<user_name>``, ``<memory_warning>``)
are *not* summarized away and are kept verbatim in the post-summary
message list. Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy
(some message types are part of the agent's contract and must survive
compaction unchanged).
3. **Sanitize ``content=None``** when feeding messages into ``get_buffer_string``
(Azure OpenAI / LiteLLM defense when a provider streams an AIMessage
containing only tool_calls and no text, ``content`` can be ``None`` and
``get_buffer_string`` crashes iterating over ``None``). SurfSense-specific.
""" """
from __future__ import annotations from __future__ import annotations
@ -43,9 +30,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Structured summary template ported from OpenCode's # Module-level constant so unit tests can assert on its sections.
# ``opencode/packages/opencode/src/session/compaction.ts:40-75``. Kept as a
# module-level constant so unit tests can assert on its sections.
SURFSENSE_SUMMARY_PROMPT = """<role> SURFSENSE_SUMMARY_PROMPT = """<role>
SurfSense Conversation Compaction Assistant SurfSense Conversation Compaction Assistant
</role> </role>
@ -114,13 +99,10 @@ def _is_protected_system_message(msg: AnyMessage) -> bool:
def _sanitize_message_content(msg: AnyMessage) -> AnyMessage: def _sanitize_message_content(msg: AnyMessage) -> AnyMessage:
"""Return ``msg`` with ``content=None`` coerced to ``""``. """Return a copy of ``msg`` with ``content=None`` coerced to ``""``.
Folds in the historical defense from ``safe_summarization.py`` ``get_buffer_string`` reads ``m.text`` (iterating ``content``), so a
``get_buffer_string`` reads ``m.text`` which iterates ``self.content``, tool-only AIMessage with ``None`` content would crash it.
so a ``None`` content (Azure OpenAI / LiteLLM streaming a tool-only
AIMessage) explodes. We return a copy with empty string content so
downstream consumers see an empty body without mutating the original.
""" """
if getattr(msg, "content", "not-missing") is not None: if getattr(msg, "content", "not-missing") is not None:
return msg return msg
@ -159,20 +141,11 @@ class SurfSenseCompactionMiddleware(SummarizationMiddleware):
conversation_messages: list[AnyMessage], conversation_messages: list[AnyMessage],
cutoff_index: int, cutoff_index: int,
) -> tuple[list[AnyMessage], list[AnyMessage]]: ) -> tuple[list[AnyMessage], list[AnyMessage]]:
"""Split messages but always preserve SurfSense protected SystemMessages. """Split messages, always preserving protected SystemMessages.
Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy Also opens a ``compaction.run`` OTel span (no-op when OTel is off) here,
(``opencode/packages/opencode/src/session/compaction.ts``): some since partitioning is the first call once summarization is decided.
message types are always kept verbatim because they are part of the
agent's working contract, not transient output.
Also opens a ``compaction.run`` OTel span (no-op when OTel is off)
so dashboards can count compaction events and message-volume
without having to instrument upstream callers.
""" """
# Opening a span here is appropriate because partitioning is the
# first call SummarizationMiddleware makes when it has decided to
# summarize; we record the volume and then close as a normal span.
with ot.compaction_span( with ot.compaction_span(
reason="auto", reason="auto",
messages_in=len(conversation_messages), messages_in=len(conversation_messages),
@ -191,20 +164,15 @@ class SurfSenseCompactionMiddleware(SummarizationMiddleware):
else: else:
kept_for_summary.append(msg) kept_for_summary.append(msg)
# Place protected blocks at the *front* of preserved_messages so # Protected blocks go at the front of preserved_messages to keep
# they keep their original ordering relative to the summary # ordering relative to the summary HumanMessage.
# HumanMessage that precedes the rest of the preserved tail.
return kept_for_summary, [*protected, *preserved_messages] return kept_for_summary, [*protected, *preserved_messages]
def _filter_summary_messages(  # type: ignore[override]
    self, messages: list[AnyMessage]
) -> list[AnyMessage]:
    """Filter previous summaries via the base class, then sanitize ``content=None``.

    Sanitizing at this point covers both the sync and async offload paths.
    """
    without_summaries = super()._filter_summary_messages(messages)
    return list(map(_sanitize_message_content, without_summaries))

View file

@ -24,14 +24,11 @@ from .utils import get_voice_for_provider
async def create_podcast_transcript( async def create_podcast_transcript(
state: State, config: RunnableConfig state: State, config: RunnableConfig
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Each node does work.""" """Generate the podcast transcript from the source content."""
# Get configuration from runnable config
configuration = Configuration.from_runnable_config(config) configuration = Configuration.from_runnable_config(config)
search_space_id = configuration.search_space_id search_space_id = configuration.search_space_id
user_prompt = configuration.user_prompt user_prompt = configuration.user_prompt
# Get search space's document summary LLM
llm = await get_agent_llm(state.db_session, search_space_id) llm = await get_agent_llm(state.db_session, search_space_id)
if not llm: if not llm:
error_message = ( error_message = (
@ -40,22 +37,16 @@ async def create_podcast_transcript(
print(error_message) print(error_message)
raise RuntimeError(error_message) raise RuntimeError(error_message)
# Get the prompt
prompt = get_podcast_generation_prompt(user_prompt) prompt = get_podcast_generation_prompt(user_prompt)
# Create the messages
messages = [ messages = [
SystemMessage(content=prompt), SystemMessage(content=prompt),
HumanMessage( HumanMessage(
content=f"<source_content>{state.source_content}</source_content>" content=f"<source_content>{state.source_content}</source_content>"
), ),
] ]
# Generate the podcast transcript
llm_response = await llm.ainvoke(messages) llm_response = await llm.ainvoke(messages)
# Reasoning models (e.g. Kimi K2.5) may return content as a list of # Reasoning models may return content as blocks; normalise to a string.
# blocks including 'reasoning' entries. Normalise to a plain string.
content = strip_markdown_fences(extract_text_content(llm_response.content)) content = strip_markdown_fences(extract_text_content(llm_response.content))
try: try:
@ -89,17 +80,13 @@ async def create_merged_podcast_audio(
state: State, config: RunnableConfig state: State, config: RunnableConfig
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Generate audio for each transcript and merge them into a single podcast file.""" """Generate audio for each transcript and merge them into a single podcast file."""
# configuration = Configuration.from_runnable_config(config)
starting_transcript = PodcastTranscriptEntry( starting_transcript = PodcastTranscriptEntry(
speaker_id=1, dialog="Welcome to Surfsense Podcast." speaker_id=1, dialog="Welcome to Surfsense Podcast."
) )
transcript = state.podcast_transcript transcript = state.podcast_transcript
# Merge the starting transcript with the podcast transcript # transcript may be a PodcastTranscripts object or already a list.
# Check if transcript is a PodcastTranscripts object or already a list
if hasattr(transcript, "podcast_transcripts"): if hasattr(transcript, "podcast_transcripts"):
transcript_entries = transcript.podcast_transcripts transcript_entries = transcript.podcast_transcripts
else: else:
@ -107,20 +94,16 @@ async def create_merged_podcast_audio(
merged_transcript = [starting_transcript, *transcript_entries] merged_transcript = [starting_transcript, *transcript_entries]
# Create a temporary directory for audio files
temp_dir = Path("temp_audio") temp_dir = Path("temp_audio")
temp_dir.mkdir(exist_ok=True) temp_dir.mkdir(exist_ok=True)
# Generate a unique session ID for this podcast
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
output_path = f"podcasts/{session_id}_podcast.mp3" output_path = f"podcasts/{session_id}_podcast.mp3"
os.makedirs("podcasts", exist_ok=True) os.makedirs("podcasts", exist_ok=True)
# Generate audio for each transcript segment
audio_files = [] audio_files = []
async def generate_speech_for_segment(segment, index): async def generate_speech_for_segment(segment, index):
# Handle both dictionary and PodcastTranscriptEntry objects
if hasattr(segment, "speaker_id"): if hasattr(segment, "speaker_id"):
speaker_id = segment.speaker_id speaker_id = segment.speaker_id
dialog = segment.dialog dialog = segment.dialog
@ -128,20 +111,15 @@ async def create_merged_podcast_audio(
speaker_id = segment.get("speaker_id", 0) speaker_id = segment.get("speaker_id", 0)
dialog = segment.get("dialog", "") dialog = segment.get("dialog", "")
# Select voice based on speaker_id
voice = get_voice_for_provider(app_config.TTS_SERVICE, speaker_id) voice = get_voice_for_provider(app_config.TTS_SERVICE, speaker_id)
# Generate a unique filename for this segment
if app_config.TTS_SERVICE == "local/kokoro": if app_config.TTS_SERVICE == "local/kokoro":
# Kokoro generates WAV files
filename = f"{temp_dir}/{session_id}_{index}.wav" filename = f"{temp_dir}/{session_id}_{index}.wav"
else: else:
# Other services generate MP3 files
filename = f"{temp_dir}/{session_id}_{index}.mp3" filename = f"{temp_dir}/{session_id}_{index}.mp3"
try: try:
if app_config.TTS_SERVICE == "local/kokoro": if app_config.TTS_SERVICE == "local/kokoro":
# Use Kokoro TTS service
kokoro_service = await get_kokoro_tts_service( kokoro_service = await get_kokoro_tts_service(
lang_code="a" lang_code="a"
) # American English ) # American English
@ -170,7 +148,6 @@ async def create_merged_podcast_audio(
timeout=600, timeout=600,
) )
# Save the audio to a file - use proper streaming method
with open(filename, "wb") as f: with open(filename, "wb") as f:
f.write(response.content) f.write(response.content)
@ -179,23 +156,17 @@ async def create_merged_podcast_audio(
print(f"Error generating speech for segment {index}: {e!s}") print(f"Error generating speech for segment {index}: {e!s}")
raise raise
# Generate all audio files concurrently
tasks = [ tasks = [
generate_speech_for_segment(segment, i) generate_speech_for_segment(segment, i)
for i, segment in enumerate(merged_transcript) for i, segment in enumerate(merged_transcript)
] ]
audio_files = await asyncio.gather(*tasks) audio_files = await asyncio.gather(*tasks)
# Merge audio files using ffmpeg
try: try:
# Create FFmpeg instance with the first input
ffmpeg = FFmpeg().option("y") ffmpeg = FFmpeg().option("y")
# Add each audio file as input
for audio_file in audio_files: for audio_file in audio_files:
ffmpeg = ffmpeg.input(audio_file) ffmpeg = ffmpeg.input(audio_file)
# Configure the concatenation and output
filter_complex = [] filter_complex = []
for i in range(len(audio_files)): for i in range(len(audio_files)):
filter_complex.append(f"[{i}:0]") filter_complex.append(f"[{i}:0]")
@ -205,8 +176,6 @@ async def create_merged_podcast_audio(
) )
ffmpeg = ffmpeg.option("filter_complex", filter_complex_str) ffmpeg = ffmpeg.option("filter_complex", filter_complex_str)
ffmpeg = ffmpeg.output(output_path, map="[outa]") ffmpeg = ffmpeg.output(output_path, map="[outa]")
# Execute FFmpeg
await ffmpeg.execute() await ffmpeg.execute()
print(f"Successfully created podcast audio: {output_path}") print(f"Successfully created podcast audio: {output_path}")
@ -215,7 +184,6 @@ async def create_merged_podcast_audio(
print(f"Error merging audio files: {e!s}") print(f"Error merging audio files: {e!s}")
raise raise
finally: finally:
# Clean up temporary files
for audio_file in audio_files: for audio_file in audio_files:
try: try:
os.remove(audio_file) os.remove(audio_file)