docs(agents): tighten docstrings and comments across agent module

Recursive pass over the agents module to make docstrings and inline
comments concise and intent-oriented: drop narration that just restates
the code, condense verbose module/function docstrings, and keep only the
non-obvious "why" notes. No functional code changed.
This commit is contained in:
CREDO23 2026-06-05 17:39:38 +02:00
parent 620c378254
commit a3d05f6418
16 changed files with 319 additions and 1055 deletions

View file

@ -1,25 +1,15 @@
"""Append-only action-log middleware for the SurfSense agent.
Wraps every tool call via :meth:`AgentMiddleware.awrap_tool_call` and writes
a row to :class:`~app.db.AgentActionLog` after the tool returns. Tools opt
into reversibility by declaring a ``reverse`` callable on their
:class:`ToolDefinition`; the rendered descriptor is persisted in
``reverse_descriptor`` for use by
Wraps every tool call and writes a row to :class:`~app.db.AgentActionLog`
after the tool returns. Tools opt into reversibility via a ``reverse``
callable on their :class:`ToolDefinition`; the rendered descriptor powers
``/api/threads/{thread_id}/revert/{action_id}``.
Design points:
* **Defensive.** Logging never blocks the agent. We catch every exception
on the DB write path and emit a warning; the tool's ``ToolMessage``
result is always returned untouched.
* **Lightweight payload.** Only the tool ``name`` + ``args`` (capped) +
``result_id`` + ``reverse_descriptor`` are stored. Tool output text
remains in the LangGraph checkpoint / spilled tool-output files.
* **Best-effort reversibility.** We invoke ``reverse(args, result_obj)``
with the parsed JSON result when the tool's content is a JSON object;
otherwise the raw text is passed. Exceptions in the reverse callable
are swallowed and logged — a failed descriptor render simply means the
action is NOT marked reversible.
Logging is fully defensive — DB-write failures are swallowed so the tool's
result is always returned untouched. Only metadata (name, capped args,
result_id, reverse_descriptor) is stored; tool output stays in the
checkpoint. Reversibility is best-effort: a reverse callable that raises
just leaves the action non-reversible.
"""
from __future__ import annotations
@ -203,11 +193,9 @@ class ActionLogMiddleware(AgentMiddleware):
)
return
# Surface a side-channel SSE event so the chat tool card can
# render a Revert button immediately after the row is durable.
# ``stream_new_chat`` translates this into a
# ``data-action-log`` SSE event. We DO NOT include the
# ``reverse_descriptor`` payload here; only a presence flag.
# Side-channel event (relayed by ``stream_new_chat`` as a
# ``data-action-log`` SSE) so the tool card can show a Revert button
# once the row is durable. Carries a presence flag, not the descriptor.
try:
await adispatch_custom_event(
"action_log",

View file

@ -1,32 +1,12 @@
"""
BusyMutexMiddleware — per-thread asyncio lock + cancel token.
"""Per-thread asyncio lock + cooperative cancel token, keyed by ``thread_id``.
LangChain has no built-in concept of "this thread is already running a
turn — refuse the second concurrent request". Without it, a user
double-clicking "send" or refreshing the page mid-stream can spawn two
turns racing on the same checkpoint, producing duplicated tool calls
and mangled state.
Refuses a second concurrent turn on the same thread (e.g. double-clicked
"send") that would otherwise race on the same checkpoint and duplicate tool
calls. Also exposes a per-thread cancel event that long-running tools poll
via ``runtime.context.cancel_event.is_set()`` to abort cooperatively.
Ported from OpenCode's ``Stream.scoped(AbortController)`` pattern: a
single-process, in-memory lock + cooperative cancellation token keyed by
``thread_id``. For multi-worker deployments a distributed lock backend
(Redis or PostgreSQL advisory locks) is a phase-2 follow-up.
What this provides:
- A ``WeakValueDictionary[str, asyncio.Lock]`` keyed by ``thread_id``;
acquiring the lock during ``before_agent`` blocks any concurrent
prompt on the same thread until release.
- A per-thread ``asyncio.Event`` (``cancel_event``) that long-running
tools can poll to abort cooperatively. The event is reset between
turns. Tools should check ``runtime.context.cancel_event.is_set()``
in tight inner loops.
- A typed :class:`~app.agents.chat.runtime.errors.BusyError` raised when a
second turn arrives while the lock is held.
Note: SurfSense's ``stream_new_chat`` is the call site that should
acquire/release. Wiring this as middleware means the contract is
explicit and the lock manager is shared with subagents that compile
their own ``create_agent`` runnables.
Process-local and in-memory; multi-worker deployments need a distributed lock
(Redis / PostgreSQL advisory locks) as a follow-up.
"""
from __future__ import annotations
@ -152,9 +132,8 @@ class _ThreadLockManager:
return True
# Module-level singleton — process-local but reused across all agent
# instances built in this process. Subagents created in nested
# ``create_agent`` calls also get this so locks are coherent.
# Process-local singleton shared across all agents/subagents built in this
# process so per-thread locks stay coherent.
manager = _ThreadLockManager()
@ -266,7 +245,6 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
await lock.acquire()
epoch = manager.bump_turn_epoch(thread_id)
self._held_locks[thread_id] = (lock, epoch)
# Reset the cancel event so this turn starts fresh
reset_cancel(thread_id)
return None
@ -289,17 +267,14 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
return None
if lock.locked():
lock.release()
# Always clear cancel event between turns so a stale signal
# doesn't leak into the next request.
# Clear cancel event so a stale signal doesn't leak into the next turn.
reset_cancel(thread_id)
return None
# Provide sync no-ops because the middleware base class allows them
def before_agent( # type: ignore[override]
self, state: AgentState[Any], runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
# Sync path: no asyncio.Lock to acquire. Best we can do is reject
# if anyone else is in flight.
# Sync path can't await an asyncio.Lock; only reject if one is in flight.
thread_id = self._thread_id(runtime)
if thread_id is None:
if self._require_thread_id:

View file

@ -82,13 +82,10 @@ _T = TypeVar("_T")
async def _ainvoke_with_timeout[T](
coro: Awaitable[_T], *, subagent_type: str, started_at: float
) -> _T:
"""Apply :data:`DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS` to ``coro``.
"""Apply the subagent invoke timeout to ``coro`` (non-positive disables it).
A non-positive timeout disables the cap (configurable via the
``SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS`` env var). On expiry the
underlying task is cancelled and :class:`SubagentInvokeTimeoutError` is
raised — the caller wraps it into a synthetic ToolMessage so the
orchestrator can decide what to do.
On expiry the task is cancelled and :class:`SubagentInvokeTimeoutError` is
raised for the caller to turn into a synthetic ToolMessage.
"""
timeout = DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS
if timeout <= 0:
@ -151,12 +148,9 @@ def build_task_tool_with_parent_config(
subagent_graphs: dict[str, Runnable] = {
spec["name"]: spec["runnable"] for spec in subagents
}
# Per-subagent context-hint providers (see ``SurfSenseSubagentSpec``).
# The mapping is sparse: only routes that opted in via ``pack_subagent``
# appear here, and the value is invoked once per ``task(...)`` call to
# generate a short string prepended to the subagent's first
# ``HumanMessage``. Failures are logged and swallowed — a broken hint
# provider must never prevent the underlying task from running.
# Sparse map of opt-in context-hint providers; each runs once per task()
# call to prepend a string to the subagent's first HumanMessage. Failures
# are swallowed so a broken hint never blocks the task.
subagent_hint_providers: dict[str, ContextHintProvider] = {
spec["name"]: provider
for spec in subagents
@ -178,24 +172,18 @@ def build_task_tool_with_parent_config(
def _billable_call_update(
subagent_type: str, runtime: ToolRuntime
) -> dict[str, Any]:
"""Build the per-call ``billable_calls`` delta + an optional warning.
"""Build the per-call ``billable_calls`` delta plus an optional soft-cap warning.
The orchestrator's ``billable_calls`` map is summed by
:func:`_int_counter_merge_reducer`, so we always emit
``{subagent_type: 1}`` and let the reducer accumulate. If the
cumulative count *after* this call would cross the configured
threshold, we also slip a soft ``messages`` entry into the update
so the orchestrator can read it on its next step and self-limit.
Returning a plain ``dict`` (vs. an extra :class:`Command`) keeps
the helper composable with the existing single/batch return paths.
Always emits ``{subagent_type: 1}`` (a reducer accumulates it); when this
call would cross the threshold, also adds a soft ``messages`` entry so the
orchestrator self-limits on its next step.
"""
delta: dict[str, Any] = {"billable_calls": {subagent_type: 1}}
threshold = DEFAULT_SUBAGENT_BILLABLE_THRESHOLD
if threshold <= 0:
return delta
prior = runtime.state.get("billable_calls") or {}
# ``prior`` may be a plain dict or a reducer-managed mapping; only
# int values are counted so a malformed checkpoint can't crash us.
# Count int values only so a malformed checkpoint can't crash us.
prior_total = sum(v for v in prior.values() if isinstance(v, int))
new_total = prior_total + 1
if prior_total < threshold <= new_total:
@ -214,8 +202,7 @@ def build_task_tool_with_parent_config(
"""Merge the per-call billable counter (and warning) into ``cmd``."""
delta = _billable_call_update(subagent_type, runtime)
warn_text = delta.pop("_billable_warn_text", None)
# ``cmd.update`` may be a dict or LangGraph ``UpdateDict``; defensively
# copy so we don't mutate state shared across other tool returns.
# Copy so we don't mutate state shared with other tool returns.
update = dict(getattr(cmd, "update", {}) or {})
for key, value in delta.items():
update[key] = value
@ -228,14 +215,10 @@ def build_task_tool_with_parent_config(
return Command(update=update)
def _safe_message_text(msg: Any) -> str:
"""Pull text out of a BaseMessage without trusting the ``.text`` property.
"""Pull text out of a BaseMessage without using the ``.text`` property.
``BaseMessage.text`` walks ``content_blocks`` and crashes with
``TypeError: 'NoneType' object is not iterable`` when ``content`` is
``None`` (common for tool-call AIMessages whose payload is purely
structured). ``getattr(msg, "text", None)`` does not catch this
because Python evaluates the property body before falling back to
the default. Read ``content`` directly and coerce defensively.
``.text`` crashes when ``content`` is ``None`` (common for tool-call
AIMessages), and ``getattr`` won't catch it, so read ``content`` directly.
"""
try:
content = getattr(msg, "content", None)
@ -258,23 +241,18 @@ def build_task_tool_with_parent_config(
return str(content)
def _build_tool_trace(messages: list[Any]) -> list[dict[str, Any]]:
"""Compress the subagent's message stream into a compact tool trace.
"""Compress the subagent's messages into a compact tool trace.
Each entry is ``{"tool": <name>, "status": "ok"|"error", "preview":
<120 chars>}`` so the orchestrator can show "this is what your
specialist actually did" without dumping the full message stream
back through the prompt. The list is attached to the returned
ToolMessage's ``additional_kwargs`` (under ``"surf_tool_trace"``);
the LLM never sees it, but UI / observability code can pluck it
out of the checkpoint.
Entries (``{tool, status, preview}``) ride on the ToolMessage's
``additional_kwargs["surf_tool_trace"]`` for UI/observability; the LLM
never sees them.
"""
trace: list[dict[str, Any]] = []
for msg in messages:
tool_name = getattr(msg, "name", None)
tool_call_id_attr = getattr(msg, "tool_call_id", None)
if not tool_name and not tool_call_id_attr:
# Only ToolMessages have either field; skip AIMessage /
# HumanMessage / SystemMessage frames.
# Only ToolMessages carry either field.
continue
status = getattr(msg, "status", None) or "ok"
preview = _safe_message_text(msg).strip().replace("\n", " ")
@ -308,8 +286,7 @@ def build_task_tool_with_parent_config(
)
raise ValueError(msg)
message_text = _safe_message_text(messages[-1]).rstrip()
# Tool-trace is purely observability — wrap defensively so a single
# malformed frame never bubbles up and kills the whole user turn.
# Trace is observability-only; never let a bad frame kill the turn.
try:
tool_trace = _build_tool_trace(messages)
except Exception:
@ -320,10 +297,7 @@ def build_task_tool_with_parent_config(
tool_trace = []
tool_msg = ToolMessage(message_text, tool_call_id=tool_call_id)
if tool_trace:
# ``additional_kwargs`` is a free-form dict on BaseMessage; using
# a ``surf_`` prefix avoids collision with provider-specific keys
# (e.g. Anthropic's ``cache_control``). The LLM doesn't see it;
# consumers (UI, observability) read it off the checkpoint.
# surf_ prefix avoids collision with provider keys (e.g. cache_control).
tool_msg.additional_kwargs["surf_tool_trace"] = tool_trace
return Command(
update={
@ -361,9 +335,7 @@ def build_task_tool_with_parent_config(
}
hint = _resolve_context_hint(subagent_type, description, runtime)
if hint:
# Prepend as a tagged block so the subagent prompt can pattern-match
# on the section (and a future change can lift it into its own
# ``SystemMessage`` if needed).
# Tagged block so the subagent prompt can pattern-match the section.
payload = f"<context_hint>\n{hint}\n</context_hint>\n\n{description}"
else:
payload = description
@ -374,16 +346,12 @@ def build_task_tool_with_parent_config(
results: list[tuple[int, str, dict | str, dict | None]],
runtime: ToolRuntime,
) -> Command:
"""Combine per-child results into one Command with a combined ToolMessage.
"""Combine per-child results into one Command with an aggregate ToolMessage.
``results`` is a list of ``(task_index, subagent_type,
payload_or_error_text, child_state_update)`` tuples preserving the
input order so the orchestrator can map each block back to the task
it dispatched. State updates are merged by reducer for keys outside
:data:`EXCLUDED_STATE_KEYS`; everything else (``messages``, ``todos``,
etc.) is replaced by the synthesized aggregate ToolMessage. Every
child also contributes a ``billable_calls`` increment so cost
accounting matches single-mode dispatch.
``results`` tuples are ``(task_index, subagent_type, payload_or_error,
child_state_update)``; output blocks are sorted by index so the LLM can
map them back to dispatch order, and each child contributes a
``billable_calls`` increment to match single-mode accounting.
"""
results.sort(key=lambda r: r[0])
merged_state: dict[str, Any] = {}
@ -424,8 +392,8 @@ def build_task_tool_with_parent_config(
}
)
if state_update:
# Naive merge: later tasks win on scalar collisions; reducer-backed
# fields (``receipts``, ``files`` etc.) accumulate at apply time.
# Later tasks win on scalar collisions; reducer-backed fields
# accumulate at apply time.
merged_state.update(state_update)
aggregate = "\n\n".join(message_blocks)
aggregate_msg = ToolMessage(
@ -469,11 +437,9 @@ def build_task_tool_with_parent_config(
) -> tuple[int, str, dict | str, dict | None]:
"""Run one child of a batched ``task`` call under the concurrency cap.
Errors are returned as plain text in slot 2 so a single child's
failure does not abort the whole batch. ``GraphInterrupt`` from a
batched child is currently treated as a hard failure for that child
only — batched HITL is intentionally out of scope for the v1
rollout (see plan tier 2 item 4 risks).
Errors are returned as text (slot 2) so one child's failure doesn't abort
the batch. A child's ``GraphInterrupt`` is a hard failure for that child:
batched HITL is intentionally out of scope.
"""
async with semaphore:
if subagent_type not in subagent_graphs:
@ -507,8 +473,7 @@ def build_task_tool_with_parent_config(
)
return (task_index, subagent_type, str(exc), None)
except GraphInterrupt:
# Batched HITL is unsupported in v1 — surface as a failure
# for this child so the rest of the batch still completes.
# Batched HITL unsupported; fail this child so the batch finishes.
logger.warning(
"Batch child %d (%s) raised GraphInterrupt; batched HITL "
"is not supported. Re-dispatch this task as a single "
@ -545,14 +510,11 @@ def build_task_tool_with_parent_config(
return (task_index, subagent_type, result, child_state_update)
def _coerce_batch_arg(tasks: Any) -> list[dict] | str:
"""Rescue common LLM-side malformations of the ``tasks`` argument.
"""Rescue common LLM malformations of the ``tasks`` argument.
Some providers serialise an array argument as a JSON-encoded string,
and small models occasionally hand back a single ``{description,
subagent_type}`` dict instead of a one-element array. Both are
recovered here with a WARN log so the issue is visible in metrics
but the user's turn still completes; truly broken shapes return a
plain string that the caller surfaces as the tool error.
Recovers a JSON-encoded array string and a single dict (instead of a
1-element array), logging a WARN. Unrecoverable shapes return a string
the caller surfaces as the tool error.
"""
if isinstance(tasks, list):
return tasks
@ -587,13 +549,10 @@ def build_task_tool_with_parent_config(
async def _adispatch_batch(
tasks: list[dict], runtime: ToolRuntime
) -> Command | str:
"""Fan-out helper for the ``tasks`` array shape.
"""Fan out the ``tasks`` array (size- and concurrency-capped).
Bounded by :data:`MAX_SUBAGENT_BATCH_SIZE` and concurrency-capped
at :data:`DEFAULT_SUBAGENT_BATCH_CONCURRENCY`. Returns a single
:class:`Command` that the LLM sees as one ToolMessage per child,
prefixed with ``[task <index>]`` so it can map back to the input
order.
Returns one Command; the LLM sees one ``[task <index>]``-prefixed block
per child, in input order.
"""
if not tasks:
return "tasks: array is empty; nothing to dispatch."
@ -703,17 +662,16 @@ def build_task_tool_with_parent_config(
if pending_value is not None:
resume_value = consume_surfsense_resume(runtime)
if resume_value is None:
# Bridge invariant: a queued resume must accompany any pending
# subagent interrupt. Fall-through replay would silently re-prompt
# the user; raise so the streaming layer surfaces a clear error.
# A pending interrupt must have a queued resume; otherwise replay
# would silently re-prompt the user. Raise instead.
raise RuntimeError(
f"Subagent {subagent_type!r} has a pending interrupt but no "
"surfsense_resume_value on config; resume bridge is broken."
)
expected = hitlrequest_action_count(pending_value)
resume_value = fan_out_decisions_to_match(resume_value, expected)
# Prevent the parent's resume payload from leaking into subagent
# interrupts via langgraph's parent_scratchpad fallback.
# Stop the parent's resume leaking into subagent interrupts via
# langgraph's parent_scratchpad fallback.
drain_parent_null_resume(runtime)
with ot.subagent_invoke_span(
subagent_type=subagent_type, path=invoke_path
@ -829,10 +787,8 @@ def build_task_tool_with_parent_config(
] = None,
) -> str | Command:
atask_start = time.perf_counter()
# Kill switch: when ops flips the spawn-paused flag for this
# workspace, every ``task(...)`` invocation (single- or batch-mode)
# short-circuits with a clear ToolMessage so the orchestrator can
# tell the user what happened and stop hammering downstream APIs.
# Ops kill switch: short-circuit every task() call for this workspace
# so the orchestrator stops hammering downstream APIs.
if await is_spawn_paused(search_space_id):
logger.warning(
"[hitl_route] atask SPAWN_PAUSED: search_space_id=%s tool_call_id=%s",
@ -923,8 +879,8 @@ def build_task_tool_with_parent_config(
)
expected = hitlrequest_action_count(pending_value)
resume_value = fan_out_decisions_to_match(resume_value, expected)
# Prevent the parent's resume payload from leaking into subagent
# interrupts via langgraph's parent_scratchpad fallback.
# Stop the parent's resume leaking into subagent interrupts via
# langgraph's parent_scratchpad fallback.
drain_parent_null_resume(runtime)
with ot.subagent_invoke_span(
subagent_type=subagent_type, path=invoke_path

View file

@ -1,33 +1,19 @@
"""End-of-turn persistence for the cloud-mode SurfSense filesystem.
This middleware runs ``aafter_agent`` once per turn (cloud only). It commits
all staged folder creations, file moves, content writes/edits, file deletes
(``rm``), and directory deletes (``rmdir``) to Postgres in a single ordered
pass:
Runs ``aafter_agent`` once per turn (cloud only), committing staged folder
creates, moves, writes/edits, and ``rm``/``rmdir`` to Postgres in one ordered
pass. Order matters: moves resolve before writes (so write-then-move lands at
the final path), and file deletes run before directory deletes (so a same-turn
``rm /a/x.md`` + ``rmdir /a`` works).
1. Materialize ``staged_dirs`` into ``Folder`` rows.
2. Apply ``pending_moves`` in order (chained moves resolved via
``doc_id_by_path``).
3. Normalize ``dirty_paths`` through ``pending_moves`` so write-then-move
sequences commit at the final path. Paths queued for ``rm`` this turn
are dropped here so a write+rm sequence doesn't recreate the doc.
4. Commit content writes / edits for ``/documents/*`` paths, skipping
``temp_*`` basenames.
5. Apply ``pending_deletes`` (``rm``) — file deletes run BEFORE directory
deletes so a same-turn ``rm /a/x.md`` + ``rmdir /a`` sequence works.
6. Apply ``pending_dir_deletes`` (``rmdir``); re-verifies emptiness against
the post-step-5 DB state.
When ``flags.enable_action_log`` is on, each destructive op also snapshots a
``DocumentRevision`` / ``FolderRevision`` for revert. For ``rm``/``rmdir`` the
snapshot and DELETE share a SAVEPOINT, so a failed snapshot aborts the delete
rather than making the data silently irreversible.
When ``flags.enable_action_log`` is on every destructive op also writes a
``DocumentRevision`` / ``FolderRevision`` snapshot bound to the
originating ``AgentActionLog`` row via ``tool_call_id``. ``rm``/``rmdir``
share a single ``SAVEPOINT`` with their snapshot — if the snapshot fails
the DELETE rolls back and we surface the error rather than silently
making the data irreversible.
The commit body is exposed as a free function ``commit_staged_filesystem_state``
so the optional stream-task fallback (``stream_new_chat.py``) can call the
exact same routine when ``aafter_agent`` was skipped (e.g. client disconnect).
The commit body is a free function (``commit_staged_filesystem_state``) so the
stream-task fallback can run the identical routine when ``aafter_agent`` was
skipped (e.g. client disconnect).
"""
from __future__ import annotations
@ -216,11 +202,9 @@ async def _create_document(
virtual_path,
search_space_id,
)
# Filesystem-parity invariant: the only thing that *must* be unique is
# the path. Two notes can legitimately share content (e.g. ``cp a b``).
# Guard against the path-derived ``unique_identifier_hash`` constraint
# so we surface a clean ValueError instead of letting the INSERT poison
# the session with an IntegrityError.
# Pre-check the path-derived unique_identifier_hash so a duplicate path
# surfaces as a clean ValueError instead of an INSERT IntegrityError that
# poisons the session. Content is intentionally not unique (cp a b).
path_collision = await session.execute(
select(Document.id).where(
Document.search_space_id == search_space_id,
@ -232,13 +216,6 @@ async def _create_document(
f"a document already exists at path '{virtual_path}' "
"(unique_identifier_hash collision)"
)
# ``content_hash`` is intentionally NOT checked for uniqueness here.
# In a real filesystem two files at different paths can hold identical
# bytes, and the agent's ``write_file`` path needs that semantic to
# support copy/duplicate operations. The hash remains useful as a
# change-detection hint for connector indexers, which still consult it
# via :func:`check_duplicate_document` but do so with a non-unique
# lookup (``.first()``).
content_hash = generate_content_hash(content, search_space_id)
doc = Document(
title=title,
@ -435,15 +412,9 @@ async def _mark_action_reversible(
) -> None:
"""Flip ``agent_action_log.reversible = TRUE`` for ``action_id``.
Best-effort: caller may invoke from inside a SAVEPOINT and treat
failure as a soft demotion (snapshot persists, just no Revert button).
Callers should also call ``_dispatch_reversibility_update`` (defined
below) AFTER the enclosing SAVEPOINT block exits successfully so the
chat tool card can light up its Revert button without
re-fetching ``GET /threads/.../actions``. Dispatching from inside the
SAVEPOINT would risk emitting "reversible=true" for rows whose
update gets rolled back if the surrounding destructive op fails.
Pair with ``_dispatch_reversibility_update`` *after* the enclosing
SAVEPOINT commits, so the UI never sees ``reversible=true`` for a row whose
update later rolls back.
"""
if action_id is None:
return
@ -455,22 +426,11 @@ async def _mark_action_reversible(
async def _dispatch_reversibility_update(action_id: int | None) -> None:
"""Best-effort dispatch of an ``action_log_updated`` custom event.
"""Emit an ``action_log_updated`` SSE event so the Revert button lights up.
Surfaces the post-SAVEPOINT reversibility flip to the SSE layer so
the chat tool card can flip its Revert button live. Defensive:
failures are logged at debug level and swallowed; the
REST endpoint ``GET /threads/.../actions`` is still authoritative.
.. warning::
Inside :func:`commit_staged_filesystem_state` we DEFER all
dispatches until the outer ``session.commit()`` succeeds — see
the ``deferred_dispatches`` queue in that function. Dispatching
from inside a SAVEPOINT block while the outer transaction is
still pending would emit ``reversible=true`` for rows whose
snapshots get rolled back if the outer commit fails. Direct
callers (e.g. the optional stream-task fallback) that own the
full session lifetime can still call this helper inline.
Best-effort (failures swallowed; the REST actions endpoint is
authoritative). Inside :func:`commit_staged_filesystem_state` this is
deferred until after the outer commit via ``deferred_dispatches``.
"""
if action_id is None:
return
@ -489,12 +449,9 @@ async def _dispatch_reversibility_update(action_id: int | None) -> None:
# ---------------------------------------------------------------------------
# Snapshot helpers
# ---------------------------------------------------------------------------
#
# Best-effort helpers swallow + log so a snapshot failure can never break
# the destructive op for non-destructive tools (write/edit/move/mkdir).
# Strict helpers run inside the SAME ``begin_nested()`` SAVEPOINT as the
# destructive DELETE — failure aborts the savepoint and leaves the doc /
# folder intact, so revertable ops never become irreversible silently.
# Best-effort variants (write/edit/move/mkdir) swallow failures. Strict
# variants (rm/rmdir) share the destructive op's SAVEPOINT so a snapshot
# failure aborts the delete instead of making it silently irreversible.
def _doc_revision_payload(
@ -704,15 +661,9 @@ async def commit_staged_filesystem_state(
) -> dict[str, Any] | None:
"""Commit all staged filesystem changes; return the state delta for reducers.
Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent`
and the optional stream-task fallback.
When ``flags.enable_action_log`` is on every destructive op also writes
a ``DocumentRevision`` / ``FolderRevision`` snapshot bound to the
originating ``AgentActionLog`` row via ``tool_call_id``. Snapshot
durability is best-effort for non-destructive ops and STRICT for
``rm``/``rmdir`` (snapshot + DELETE share a SAVEPOINT — snapshot
failure aborts the delete).
Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent` and
the stream-task fallback. See the module docstring for ordering and the
action-log snapshot/revert semantics.
"""
if filesystem_mode != FilesystemMode.CLOUD:
return None
@ -771,8 +722,7 @@ async def commit_staged_filesystem_state(
flags = get_flags()
snapshot_enabled = flags.enable_action_log
# De-duplicate pending deletes per-path while preserving the latest
# tool_call_id (the one the user is most likely to revert via the UI).
# De-dup deletes per-path, keeping the latest tool_call_id (likeliest revert).
file_delete_paths: dict[str, str] = {}
for entry in pending_deletes:
if not isinstance(entry, dict):
@ -796,22 +746,14 @@ async def commit_staged_filesystem_state(
applied_moves: list[dict[str, Any]] = []
doc_id_path_tombstones: dict[str, int | None] = {}
tree_changed = False
# Reversibility-flip dispatches are deferred until AFTER the outer
# ``session.commit()`` succeeds. Dispatching from inside the
# SAVEPOINT chain while the outer transaction is still pending
# would emit ``reversible=true`` for rows whose snapshots get rolled
# back if the final commit raises. Snapshot helpers append on
# success; we drain this list after commit and silently abandon it
# on rollback so the UI stays consistent with durable state.
# Reversibility-flip dispatches are drained only after the outer commit
# succeeds (and abandoned on rollback), so the UI never sees reversible=true
# for a snapshot that didn't durably land.
deferred_dispatches: list[int] = []
try:
async with shielded_async_session() as session:
# ------------------------------------------------------------------
# Resolve action-id bindings up front. One SELECT per turn for all
# tool_call_ids, NOT one per op — important because a turn that
# touches 50 paths would otherwise issue 50 lookups.
# ------------------------------------------------------------------
# Resolve all action-id bindings in one SELECT per turn, not per op.
action_id_by_call: dict[str, int] = {}
if snapshot_enabled and thread_id is not None:
tool_call_ids: set[str] = set()
@ -844,10 +786,7 @@ async def commit_staged_filesystem_state(
next(iter(action_id_by_call), None) if action_id_by_call else None
)
# ------------------------------------------------------------------
# 1. staged_dirs -> Folder rows. Snapshot post-flush so the new
# folder_id is available for the FK.
# ------------------------------------------------------------------
# 1. staged_dirs -> Folder rows (snapshot post-flush for the FK).
for folder_path in staged_dirs:
if not isinstance(folder_path, str):
continue
@ -868,7 +807,6 @@ async def commit_staged_filesystem_state(
tcid = staged_dir_tool_calls.get(folder_path)
action_id = _action_id_for(tcid)
if action_id is not None:
# Re-read the folder for the snapshot.
result = await session.execute(
select(Folder).where(Folder.id == folder_id)
)
@ -883,16 +821,13 @@ async def commit_staged_filesystem_state(
deferred_dispatches=deferred_dispatches,
)
# ------------------------------------------------------------------
# 2. pending_moves. Snapshot pre-move (in-place restore on revert).
# ------------------------------------------------------------------
# 2. pending_moves (snapshot pre-move for in-place restore on revert).
for move in pending_moves:
source = str(move.get("source") or "")
if snapshot_enabled and source:
tcid = str(move.get("tool_call_id") or "")
action_id = _action_id_for(tcid)
if action_id is not None:
# Resolve the doc to snapshot BEFORE we mutate it.
doc_id_pre = doc_id_by_path.get(source)
document_pre: Document | None = None
if doc_id_pre is not None:
@ -942,10 +877,8 @@ async def commit_staged_filesystem_state(
path = move_alias[path]
return path
# ------------------------------------------------------------------
# 3. dirty_paths -> writes/edits. Skip any path queued for ``rm``
# this turn so a write+rm sequence doesn't recreate the doc.
# ------------------------------------------------------------------
# 3. dirty_paths -> writes/edits. Paths queued for rm this turn are
# skipped so a write+rm sequence doesn't recreate the doc.
kb_dirty_seen: set[str] = set()
kb_dirty: list[str] = []
kb_dirty_origin: dict[str, str] = {}
@ -974,9 +907,7 @@ async def commit_staged_filesystem_state(
continue
content = "\n".join(file_data.get("content") or [])
doc_id = doc_id_by_path.get(path)
# Path ↔ tool_call_id binding: the dirty_paths list dedupes via
# _add_unique_reducer, so we look up the latest tool_call_id by
# path (or by the un-renamed origin).
# Look up tool_call_id by final path or its pre-rename origin.
origin = kb_dirty_origin.get(path, path)
tcid = dirty_path_tool_calls.get(path) or dirty_path_tool_calls.get(
origin
@ -984,12 +915,9 @@ async def commit_staged_filesystem_state(
action_id = _action_id_for(tcid)
if doc_id is None:
# The in-memory ``doc_id_by_path`` is per-thread and starts
# empty in every new chat. If the agent writes to a path
# that already exists in the DB (e.g. a previous chat's
# ``notes.md``), we must NOT try to INSERT — it would hit
# ``unique_identifier_hash`` (path-derived). Look up the
# existing doc and update it in place instead.
# doc_id_by_path is per-thread and empty in a new chat, so a
# write to a path already in the DB must update in place, not
# INSERT (which would hit the path-derived unique hash).
existing = await virtual_path_to_doc(
session,
search_space_id=search_space_id,
@ -1038,12 +966,9 @@ async def commit_staged_filesystem_state(
}
)
else:
# Fresh create. Wrap each create in a SAVEPOINT so a
# residual ``IntegrityError`` (e.g. a deployment that
# hasn't run migration 133 yet, where
# ``documents.content_hash`` still carries its legacy
# global UNIQUE constraint) rolls back only this one
# create instead of poisoning the whole turn.
# Fresh create, wrapped in a SAVEPOINT so a residual
# IntegrityError (e.g. pre-migration-133 content_hash UNIQUE)
# rolls back only this create, not the whole turn.
placeholder_revision_id: int | None = None
if snapshot_enabled and action_id is not None:
placeholder_revision_id = await _snapshot_document_pre_create(
@ -1066,8 +991,7 @@ async def commit_staged_filesystem_state(
logger.warning(
"kb_persistence: skipping %s create: %s", path, exc
)
# Roll back the placeholder revision since the create
# never happened.
# Create never happened; drop its placeholder revision.
if placeholder_revision_id is not None:
await session.execute(
delete(DocumentRevision).where(
@ -1114,19 +1038,14 @@ async def commit_staged_filesystem_state(
)
tree_changed = True
# ------------------------------------------------------------------
# 4. pending_deletes -> ``rm``. STRICT durability: snapshot + DELETE
# share a SAVEPOINT. If the snapshot insert fails, the DELETE
# rolls back too and we surface the error rather than silently
# making the data irreversible.
# ------------------------------------------------------------------
# 4. pending_deletes -> rm. Strict: snapshot + DELETE share a
# SAVEPOINT, so a failed snapshot rolls the delete back too.
for raw_path, tcid in file_delete_paths.items():
final = _final_path(raw_path)
if not final.startswith(DOCUMENTS_ROOT + "/"):
continue
action_id = _action_id_for(tcid)
# Resolve the doc.
doc_id_for_delete = doc_id_by_path.get(final)
document_to_delete: Document | None = None
if doc_id_for_delete is not None:
@ -1155,7 +1074,6 @@ async def commit_staged_filesystem_state(
try:
async with session.begin_nested():
# Strict: snapshot first; failure aborts the delete.
if snapshot_enabled and action_id is not None:
chunks = await _load_chunks_for_snapshot(
session, doc_id=doc_pk
@ -1184,10 +1102,7 @@ async def commit_staged_filesystem_state(
)
continue
# B1 — SAVEPOINT released. Defer the reversibility-flip
# dispatch until AFTER the outer commit succeeds so we
# never tell the UI a row is reversible if its snapshot
# gets rolled back.
# Defer the reversibility flip until after the outer commit.
if snapshot_enabled and action_id is not None:
deferred_dispatches.append(int(action_id))
@ -1206,11 +1121,8 @@ async def commit_staged_filesystem_state(
)
tree_changed = True
# ------------------------------------------------------------------
# 5. pending_dir_deletes -> ``rmdir``. STRICT durability + final
# emptiness check (after step 4's deletes have run, an "empty
# mid-turn" directory really IS empty in DB now).
# ------------------------------------------------------------------
# 5. pending_dir_deletes -> rmdir. Strict, and re-checks emptiness
# against post-step-4 DB state.
for raw_path, tcid in dir_delete_paths.items():
final = _final_path(raw_path)
if not final.startswith(DOCUMENTS_ROOT + "/"):
@ -1231,7 +1143,6 @@ async def commit_staged_filesystem_state(
)
continue
# Re-check emptiness against in-DB state.
docs_in_folder = await session.execute(
select(Document.id)
.where(Document.folder_id == folder_id)
@ -1296,10 +1207,7 @@ async def commit_staged_filesystem_state(
)
continue
# B1 — SAVEPOINT released. Defer the reversibility-flip
# dispatch until AFTER the outer commit succeeds so we
# never tell the UI a row is reversible if its snapshot
# gets rolled back.
# Defer the reversibility flip until after the outer commit.
if snapshot_enabled and action_id is not None:
deferred_dispatches.append(int(action_id))
@ -1319,18 +1227,13 @@ async def commit_staged_filesystem_state(
logger.exception(
"kb_persistence: commit failed (search_space=%s)", search_space_id
)
# Outer commit raised — every SAVEPOINT-released change above
# (snapshots + reversibility flips) is now rolled back. Drop
# the deferred SSE dispatches so the UI stays consistent with
# durable state.
# Outer commit raised: everything above rolled back, so drop the
# deferred dispatches.
deferred_dispatches.clear()
return None
# Outer commit succeeded; flush deferred reversibility-flip
# dispatches now so the chat tool card can light up its Revert
# button without re-fetching ``GET /threads/.../actions``. De-dup
# to avoid emitting the same id twice (e.g. write-then-rm in the
# same turn dispatches once for each snapshot site).
# Commit succeeded; flush deferred reversibility flips (de-duped, since
# write-then-rm in one turn appends an id per snapshot site).
if deferred_dispatches and dispatch_events:
for action_id in dict.fromkeys(deferred_dispatches):
try:
@ -1376,9 +1279,8 @@ async def commit_staged_filesystem_state(
p for p in files if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX)
]
# Tombstone every committed-delete path so a stale ``state["files"]`` entry
# (which als_info would otherwise interpret as content) cannot survive into
# the next turn and make a now-empty folder look non-empty.
# Tombstone committed-delete paths so a stale state["files"] entry can't
# survive into the next turn and make a now-empty folder look non-empty.
deleted_file_paths = [
str(payload.get("virtualPath") or "")
for payload in committed_deletes
@ -1399,11 +1301,8 @@ async def commit_staged_filesystem_state(
"dirty_path_tool_calls": {_CLEAR: True},
}
# Emit one Receipt per committed mutation, folded into ``state['receipts']``
# via ``_list_append_reducer``. The receipts surface what actually committed
# (post-savepoint) rather than what the LLM intended; the orchestrator uses
# them as ground truth in the ``<verification>`` teaching. KB writes do not
# have public verifiable URLs, so ``verifiable_url`` stays unset.
# One Receipt per committed mutation: ground truth (post-savepoint) for the
# orchestrator's <verification> teaching. KB writes have no public URL.
receipts: list[Receipt] = []
def _kb_receipt(
@ -1444,8 +1343,6 @@ async def commit_staged_filesystem_state(
external_id=payload.get("id"),
)
for payload in applied_moves:
# ``applied_moves`` rows carry the destination ``virtualPath`` because
# the move has already landed in the DB by the time we reach this code.
path = str(payload.get("virtualPath") or "")
_kb_receipt(
type="file",
@ -1485,9 +1382,7 @@ async def commit_staged_filesystem_state(
if tree_changed:
delta["tree_version"] = int(state_dict.get("tree_version") or 0) + 1
# Avoid 'unused' lint when turn_id_for_revision was only useful for
# diagnostic purposes inside the SAVEPOINT chain above.
_ = turn_id_for_revision
_ = turn_id_for_revision # diagnostic-only; silence unused lint
logger.info(
"kb_persistence: commit (search_space=%s) creates=%d updates=%d "

View file

@ -29,7 +29,6 @@ def extract_domain(url: str) -> str:
try:
parsed = urlparse(url)
domain = parsed.netloc
# Remove 'www.' prefix if present
if domain.startswith("www."):
domain = domain[4:]
return domain
@ -53,14 +52,13 @@ def truncate_content(content: str, max_length: int = 50000) -> tuple[str, bool]:
if len(content) <= max_length:
return content, False
# Try to truncate at a sentence boundary
# Prefer truncating at a sentence/paragraph boundary.
truncated = content[:max_length]
last_period = truncated.rfind(".")
last_newline = truncated.rfind("\n\n")
# Use the later of the two boundaries, or just truncate
boundary = max(last_period, last_newline)
if boundary > max_length * 0.8: # Only use boundary if it's not too far back
if boundary > max_length * 0.8: # only if the boundary isn't too far back
truncated = content[: boundary + 1]
return truncated + "\n\n[Content truncated...]", True
@ -111,8 +109,8 @@ async def _scrape_youtube_video(
http_client.proxies.update(residential_proxies)
ytt_api = YouTubeTranscriptApi(http_client=http_client)
# List all available transcripts and pick the first one
# (the video's primary language) instead of defaulting to English
# Pick the first transcript (video's primary language) rather than
# defaulting to English.
transcript_list = ytt_api.list(video_id)
transcript = next(iter(transcript_list))
captions = transcript.fetch()
@ -134,10 +132,8 @@ async def _scrape_youtube_video(
logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}")
transcript_text = f"No captions available for this video. Error: {e!s}"
# Build combined content
content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}"
# Truncate if needed
content, was_truncated = truncate_content(content, max_length)
word_count = len(content.split())
@ -212,20 +208,16 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
scrape_id = generate_scrape_id(url)
domain = extract_domain(url)
# Validate and normalize URL
if not url.startswith(("http://", "https://")):
url = f"https://{url}"
try:
# Check if this is a YouTube URL and use transcript API instead
# YouTube URLs use the transcript API instead of crawling.
video_id = get_youtube_video_id(url)
if video_id:
return await _scrape_youtube_video(url, video_id, max_length)
# Create webcrawler connector
connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key)
# Crawl the URL
result, error = await connector.crawl_url(url, formats=["markdown"])
if error:
@ -250,28 +242,21 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
"error": "No content returned from crawler",
}
# Extract content and metadata
content = result.get("content", "")
metadata = result.get("metadata", {})
# Get title from metadata
title = metadata.get("title", "")
if not title:
title = domain or url.split("/")[-1] or "Webpage"
# Get description from metadata
description = metadata.get("description", "")
if not description and content:
# Use first paragraph as description
first_para = content.split("\n\n")[0] if content else ""
description = (
first_para[:300] + "..." if len(first_para) > 300 else first_para
)
# Truncate content if needed
content, was_truncated = truncate_content(content, max_length)
# Calculate word count
word_count = len(content.split())
return {

View file

@ -1,37 +1,9 @@
"""
Feature flags for the SurfSense new_chat agent stack.
"""Feature flags for the SurfSense new_chat agent stack.
These flags gate the newer agent middleware (some ported from OpenCode,
some sourced from ``langchain.agents.middleware`` / ``deepagents``, some
SurfSense-native). Most shipped agent-stack upgrades default ON so Docker
image updates work even when older installs do not have newly introduced
environment variables. Risky/experimental integrations stay default OFF,
and the master kill-switch can still disable everything new.
All new middleware checks its flag at agent build time. If the master
kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new
middleware is disabled regardless of its individual flag. This gives
operators a single switch to revert to pre-port behavior.
Examples
--------
Defaults:
SURFSENSE_ENABLE_CONTEXT_EDITING=true
SURFSENSE_ENABLE_COMPACTION_V2=true
SURFSENSE_ENABLE_RETRY_AFTER=true
SURFSENSE_ENABLE_MODEL_FALLBACK=false
SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
SURFSENSE_ENABLE_PERMISSION=true
SURFSENSE_ENABLE_DOOM_LOOP=true
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
Master kill-switch (overrides everything else):
SURFSENSE_DISABLE_NEW_AGENT_STACK=true
Flags are resolved at agent build time. Most upgrades default ON so Docker
updates work without operators adding new env vars; risky integrations stay
OFF. The master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` forces every
flag below to False for a one-switch rollback to pre-port behavior.
"""
from __future__ import annotations
@ -93,39 +65,14 @@ class AgentFeatureFlags:
# Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT)
enable_otel: bool = False
# Performance — compiled-agent cache (Phase 1 + Phase 2).
# When ON, ``create_surfsense_deep_agent`` reuses a previously-compiled
# graph if the cache key matches (LLM config + thread + tool surface +
# flags + system prompt + filesystem mode). Cuts per-turn agent-build
# wall clock from ~4-5s to <50µs on cache hits.
#
# SAFETY (Phase 2 unblocked this default-on):
# All connector mutation tools (``tools/notion``, ``tools/gmail``,
# ``tools/google_drive``, ``tools/dropbox``, ``tools/onedrive``,
# ``tools/google_calendar``, ``tools/confluence``, ``tools/discord``,
# ``tools/teams``, ``tools/luma``, ``connected_accounts``,
# ``update_memory``) now acquire fresh
# short-lived ``AsyncSession`` instances per call via
# :data:`async_session_maker`. The factory still accepts ``db_session``
# for registry compatibility but ``del``'s it immediately — see any
# of those files' factory docstrings for the rationale. The ``llm``
# closure is per-(provider, model, config_id) which is already in
# the cache key, so the LLM is safe to share across cached hits of
# the same key. The KB priority middleware reads
# ``mentioned_document_ids`` from ``runtime.context`` (Phase 1.5),
# not its constructor closure, so the same compiled agent serves
# turns with different mention lists correctly.
#
# Rollback: set ``SURFSENSE_ENABLE_AGENT_CACHE=false`` in the
# environment if a regression surfaces. The path is exercised by
# the ``tests/unit/agents/new_chat/test_agent_cache_*`` suite.
# Performance — reuse a compiled agent graph when the cache key matches
# (~4-5s -> <50µs per turn). Safe to default-on because mutation tools take
# fresh short-lived sessions per call and per-turn context (mentions, etc.)
# is read from runtime.context, not the constructor closure. Rollback via
# SURFSENSE_ENABLE_AGENT_CACHE=false.
enable_agent_cache: bool = True
# Phase 1 (deferred — measure first): pre-build & share the
# general-purpose subagent ``CompiledSubAgent`` across cold-cache
# misses. Only helps when the outer cache MISSES (cache hits already
# reuse the entire SubAgentMiddleware-compiled graph). Off by default
# until we have data showing cold misses are frequent enough to
# justify the extra global state.
# Deferred: only helps on outer-cache MISSES, so off until data shows cold
# misses are frequent enough to justify the extra global state.
enable_agent_cache_share_gp_subagent: bool = False
@classmethod

View file

@ -594,14 +594,9 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
inject_system_message: bool = True, # For backwards compatibility
) -> None:
self.llm = llm
# The planner LLM handles short, structured internal tasks (query
# rewriting, date extraction, recency classification). When an
# operator marks a global config ``is_planner: true`` we route
# those calls to a cheap/fast model (e.g. gpt-4o-mini, Haiku, Azure
# gpt-5.x-nano) instead of the user's chat LLM — those classification
# tasks don't need frontier-tier capability. Falls back to the chat
# LLM when no planner config is wired up so deployments without one
# keep working unchanged.
# Cheap model for structured internal tasks (query rewrite, date
# extraction, recency classification) when one is configured; falls back
# to the chat LLM otherwise.
self.planner_llm = planner_llm or llm
self.search_space_id = search_space_id
self.filesystem_mode = filesystem_mode
@ -610,26 +605,17 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
self.top_k = top_k
self.mentioned_document_ids = mentioned_document_ids or []
self.inject_system_message = inject_system_message
# Build the kb-planner private Runnable ONCE here so we don't pay
# the ``create_agent`` compile cost (50-200ms) on every turn.
# Disabled by default behind ``enable_kb_planner_runnable``; when
# off the planner falls back to the legacy ``planner_llm.ainvoke``
# path.
# Compiled lazily and memoized to avoid the per-turn create_agent cost.
self._planner: Runnable | None = None
self._planner_compile_failed = False
def _build_kb_planner_runnable(self) -> Runnable | None:
"""Compile the kb-planner private :class:`Runnable` once.
"""Lazily compile and memoize the kb-planner Runnable.
Returns ``None`` when the feature flag is disabled, when the LLM is
unavailable, or when ``create_agent`` raises (we fall back to the
legacy ``planner_llm.ainvoke`` path in that case). Compilation happens
lazily on first call, then memoized via ``self._planner``.
The compiled agent is constructed without tools — the planner's
contract is "answer with structured JSON" — but it inherits the
:class:`RetryAfterMiddleware` so transient rate-limit errors
from the planner LLM call don't fail the whole turn.
Returns ``None`` (and the caller falls back to ``planner_llm.ainvoke``)
when the flag is off, the LLM is missing, or ``create_agent`` raises.
Built without tools but with RetryAfterMiddleware so a transient
rate-limit on the planner call doesn't fail the whole turn.
"""
if self._planner is not None or self._planner_compile_failed:
return self._planner
@ -677,10 +663,8 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
loop = asyncio.get_running_loop()
t0 = loop.time()
# Prefer the compiled-once planner Runnable when enabled; otherwise
# fall back to ``planner_llm.ainvoke``. The ``surfsense:internal``
# tag is preserved on both paths so ``_stream_agent_events`` still
# suppresses the planner's intermediate events from the UI.
# Both paths tag surfsense:internal so the planner's intermediate
# events stay suppressed from the UI.
planner = self._build_kb_planner_runnable()
try:
if planner is not None:
@ -819,32 +803,16 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
user_text=user_text,
)
# Per-turn ``mentioned_document_ids`` flow:
# 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the
# streaming task supplies a fresh :class:`SurfSenseContextSchema`
# on every ``astream_events`` call, so this list is naturally
# scoped to the current turn. Allows cross-turn graph reuse via
# ``agent_cache``.
# 2. Legacy fallback (cache disabled / context not propagated): the
# constructor-injected ``self.mentioned_document_ids`` list. We
# drain it after the first read so a cached graph (no Phase 1.5
# wiring) doesn't keep replaying the same mentions on every
# turn.
# Prefer per-turn mentions from runtime.context (lets a cached graph
# serve different turns); fall back to the constructor closure, draining
# it after one read so stale mentions can't replay.
#
# CRITICAL: distinguish "context absent" (legacy caller, no field at
# all) from "context provided but empty" (turn with no mentions).
# ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in
# Python, so a naive ``if ctx_mentions:`` would fall through to the
# legacy closure on every no-mention follow-up turn — replaying the
# mentions baked in by turn 1's cache-miss build. Always drain the
# closure once the runtime path has fired so a cached middleware
# instance can never resurrect stale state.
# CRITICAL: test ``ctx_mentions is not None``, not truthiness — an empty
# list means "this turn has no mentions", not "use the closure".
mention_ids: list[int] = []
ctx = getattr(runtime, "context", None) if runtime is not None else None
ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None
if ctx_mentions is not None:
# Runtime path is authoritative — even an empty list means
# "this turn has no mentions", NOT "look at the closure".
mention_ids = list(ctx_mentions)
if self.mentioned_document_ids:
self.mentioned_document_ids = []
@ -852,12 +820,8 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
mention_ids = list(self.mentioned_document_ids)
self.mentioned_document_ids = []
# Folder mentions live alongside doc mentions on the runtime
# context. They never feed hybrid search (folders aren't
# embedded) — they're surfaced purely as ``[USER-MENTIONED]``
# priority entries so the agent walks the folder with ``ls`` /
# ``find_documents`` instead of ignoring it. Cloud filesystem
# mode only.
# Folder mentions aren't embedded, so they skip hybrid search and are
# surfaced only as [USER-MENTIONED] entries. Cloud mode only.
folder_mention_ids: list[int] = []
if (
ctx is not None
@ -939,14 +903,10 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
async def _materialize_folder_priority(
self, folder_ids: list[int]
) -> list[dict[str, Any]]:
"""Resolve user-mentioned folder ids to ``<priority_documents>`` entries.
"""Resolve mentioned folder ids to canonical-path priority entries.
Each entry uses the canonical ``/documents/Folder/Sub/`` virtual
path (matching ``KnowledgeTreeMiddleware`` and the agent's
``ls`` adapter) and is flagged ``mentioned=True`` so the
rendered line carries ``[USER-MENTIONED]``. ``score`` is left
``None`` so the renderer prints ``n/a`` — folders aren't
ranked; the agent decides which children to read.
Flagged ``mentioned=True`` with ``score=None`` (folders aren't ranked;
the agent decides which children to read).
"""
if not folder_ids:
return []

View file

@ -30,22 +30,11 @@ from langgraph.types import interrupt
logger = logging.getLogger(__name__)
# Tools that mirror the safety profile of ``write_file`` against the
# SurfSense KB: each call creates ONE artifact in the user's own workspace
# with no external visibility (drafts aren't sent; new files aren't shared
# unless the user shares them later). These are auto-approved by default
# so the agent can compose drafts and seed scratch files without a popup
# on every call.
#
# Members of this set still call ``request_approval`` exactly as before;
# the function returns immediately with ``decision_type="auto_approved"``
# and the original params untouched. This preserves the call-site shape
# (logging, metadata fetching, account fallbacks) so the only behavior
# change is "no interrupt fires".
#
# To re-enable prompting, the future per-search-space rules table
# (``agent_permission_rules``) takes precedence in the permission ruleset
# layering assembled by the agent factory.
# Low-stakes creation tools auto-approved by default: each creates one
# artifact in the user's own workspace with no external visibility (drafts
# aren't sent; new files aren't shared). They still call ``request_approval``,
# which returns ``decision_type="auto_approved"`` without firing an interrupt.
# Per-search-space ``agent_permission_rules`` can re-enable prompting.
DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
{
"create_gmail_draft",
@ -150,10 +139,6 @@ def request_approval(
return HITLResult(rejected=False, decision_type="trusted", params=dict(params))
if tool_name in DEFAULT_AUTO_APPROVED_TOOLS:
# Default policy: low-stakes creation tools (drafts + new-file
# creates) skip HITL because they're as recoverable as a local
# ``write_file`` against the SurfSense KB. The user can still
# delete the artifact in <30s if it's wrong.
logger.info(
"Tool '%s' is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL",
tool_name,

View file

@ -75,7 +75,7 @@ def create_generate_image_tool(
captured model), use this config id instead of reading the search space's
live ``image_generation_config_id``.
"""
del db_session # use a fresh per-call session, see below
del db_session # tool uses a fresh per-call session instead
@tool
async def generate_image(
@ -140,17 +140,12 @@ def create_generate_image_tool(
or IMAGE_GEN_AUTO_MODE_ID
)
# Build generation kwargs
# NOTE: size, quality, and style are intentionally NOT passed.
# Different models support different values for these params
# (e.g. DALL-E 3 wants "hd"/"standard" for quality while
# gpt-image-1 wants "high"/"medium"/"low"; size options also
# differ). Letting the model use its own defaults avoids errors.
# size/quality/style are intentionally omitted: valid values
# differ per model, so we let each model use its own defaults.
gen_kwargs: dict[str, Any] = {}
if n is not None and n > 1:
gen_kwargs["n"] = n
# Call litellm based on config type
if is_image_gen_auto_mode(config_id):
if not ImageGenRouterService.is_initialized():
err = (
@ -224,17 +219,13 @@ def create_generate_image_tool(
prompt=prompt, model=model_string, **gen_kwargs
)
# Parse the response and store in DB
response_dict = (
response.model_dump()
if hasattr(response, "model_dump")
else dict(response)
)
# Generate a random access token for this image
access_token = generate_image_token()
# Save to image_generations table for history
db_image_gen = ImageGeneration(
prompt=prompt,
model=getattr(response, "_hidden_params", {}).get("model"),
@ -249,7 +240,6 @@ def create_generate_image_tool(
await session.refresh(db_image_gen)
db_image_gen_id = db_image_gen.id
# Extract image URLs from response
images = response_dict.get("data", [])
if not images:
return _failed(
@ -260,11 +250,8 @@ def create_generate_image_tool(
first_image = images[0]
revised_prompt = first_image.get("revised_prompt", prompt)
# Resolve image URL:
# - If the API returned a URL, use it directly.
# - If the API returned b64_json (e.g. gpt-image-1), serve the
# image through our backend endpoint to avoid bloating the
# LLM context with megabytes of base64 data.
# b64_json (e.g. gpt-image-1) is served via our backend endpoint so
# megabytes of base64 don't bloat the LLM context.
if first_image.get("url"):
image_url = first_image["url"]
elif first_image.get("b64_json"):

View file

@ -241,23 +241,12 @@ def _normalize_connectors(
connectors_to_search: list[str] | None,
available_connectors: list[str] | None = None,
) -> list[str]:
"""Normalize model-supplied connectors to canonical ConnectorService types.
Maps user-facing aliases (e.g. WEBCRAWLER_CONNECTOR), drops unknowns, and
constrains to ``available_connectors`` when given. Empty input defaults to
all available connectors (minus live-search ones).
"""
Normalize connectors provided by the model.
- Accepts user-facing enums like WEBCRAWLER_CONNECTOR and maps them to canonical
ConnectorService types.
- Drops unknown values.
- If available_connectors is provided, only includes connectors from that list.
- If connectors_to_search is None/empty, defaults to available_connectors or all.
Args:
connectors_to_search: List of connectors requested by the model
available_connectors: List of connectors actually available in the search space
Returns:
List of normalized connector strings to search
"""
# Determine the set of valid connectors to consider
valid_set = (
set(available_connectors) if available_connectors else set(_ALL_CONNECTORS)
)
@ -276,18 +265,16 @@ def _normalize_connectors(
c = (raw or "").strip().upper()
if not c:
continue
# Map user-facing aliases to canonical names
if c == "WEBCRAWLER_CONNECTOR":
c = "CRAWLED_URL"
normalized.append(c)
# de-dupe while preserving order + filter to valid connectors
# De-dupe (order-preserving), keeping only known + available connectors.
seen: set[str] = set()
out: list[str] = []
for c in normalized:
if c in seen:
continue
# Only include if it's a known connector AND available
if c not in _ALL_CONNECTORS:
continue
if c not in valid_set:
@ -295,7 +282,7 @@ def _normalize_connectors(
seen.add(c)
out.append(c)
# Fallback to all available if nothing matched
# Nothing matched: fall back to all available.
if not out:
base = (
list(available_connectors)
@ -377,39 +364,17 @@ def format_documents_for_context(
max_chunk_chars: int = _MAX_CHUNK_CHARS,
max_chunks_per_doc: int = 0,
) -> str:
"""
Format retrieved documents into a readable context string for the LLM.
"""Format retrieved documents into an XML context string for the LLM.
Documents are added in order (highest relevance first) until the character
budget is reached. Individual chunks are capped at ``max_chunk_chars`` and
each document is limited to a dynamically computed chunk cap so a single
large document cannot monopolize the output while still maximising the use
of available context space.
Args:
documents: List of document dictionaries from connector search
max_chars: Approximate character budget for the entire output.
max_chunk_chars: Per-chunk character cap (content is tail-truncated).
max_chunks_per_doc: Maximum chunks per document. ``0`` (default) means
auto-compute per document using a rank-adaptive formula so
higher-ranked documents receive more chunks.
Returns:
Formatted string with document contents and metadata
Documents are emitted highest-relevance first until ``max_chars`` is hit.
``max_chunks_per_doc=0`` auto-computes a rank-adaptive cap so top results get
more chunks and no single large document monopolizes the budget.
"""
if not documents:
return ""
# Group chunks by document id (preferred) to produce the XML structure.
#
# IMPORTANT: ConnectorService returns **document-grouped** results of the form:
# {
# "document": {...},
# "chunks": [{"chunk_id": 123, "content": "..."}, ...],
# "source": "NOTION_CONNECTOR" | "FILE" | ...
# }
#
# We must preserve chunk_id so citations like [citation:123] are possible.
# Group chunks by document id, preserving chunk_id so [citation:123] works.
# ConnectorService returns document-grouped results ({document, chunks, source}).
grouped: dict[str, dict[str, Any]] = {}
for doc in documents:
@ -430,7 +395,7 @@ def format_documents_for_context(
or "UNKNOWN"
)
# Document identity (prefer document_id; otherwise fall back to type+title+url)
# Identity: prefer document_id, else type+title+url.
document_id_val = document_info.get("id")
title = (
document_info.get("title") or metadata.get("title") or "Untitled Document"
@ -460,7 +425,7 @@ def format_documents_for_context(
"chunks": [],
}
# Prefer document-grouped chunks if available
# Prefer document-grouped chunks when present.
chunks_list = doc.get("chunks") if isinstance(doc, dict) else None
if isinstance(chunks_list, list) and chunks_list:
for ch in chunks_list:
@ -492,7 +457,6 @@ def format_documents_for_context(
"BAIDU_SEARCH_API",
}
# Render XML expected by citation instructions, respecting the char budget.
parts: list[str] = []
total_chars = 0
total_docs = len(grouped)
@ -594,30 +558,11 @@ async def search_knowledge_base_async(
available_document_types: list[str] | None = None,
max_input_tokens: int | None = None,
) -> str:
"""
Search the user's knowledge base for relevant documents.
"""Search the knowledge base across connectors and return formatted results.
This is the async implementation that searches across multiple connectors.
Args:
query: The search query
search_space_id: The user's search space ID
db_session: Database session
connector_service: Initialized connector service
connectors_to_search: Optional list of connector types to search. If omitted, searches all.
top_k: Number of results per connector
start_date: Optional start datetime (UTC) for filtering documents
end_date: Optional end datetime (UTC) for filtering documents
available_connectors: Optional list of connectors actually available in the search space.
If provided, only these connectors will be searched.
available_document_types: Optional list of document types that actually have indexed
data. When provided, local connectors whose document type is
absent are skipped entirely (no embedding / DB round-trip).
max_input_tokens: Model context window size (tokens). Used to dynamically
size the output so it fits within the model's limits.
Returns:
Formatted string with search results
``available_document_types`` lets local connectors with no indexed data be
skipped (no embedding / DB round-trip), and ``max_input_tokens`` sizes the
output to the model's context window.
"""
perf = get_perf_logger()
t0 = time.perf_counter()

View file

@ -196,13 +196,8 @@ def _strip_wrapping_code_fences(text: str) -> str:
def _extract_metadata(content: str) -> dict[str, Any]:
"""Extract metadata from generated Markdown content."""
# Count section headings
headings = re.findall(r"^(#{1,6})\s+(.+)$", content, re.MULTILINE)
# Word count
word_count = len(content.split())
# Character count
char_count = len(content)
return {
@ -227,12 +222,11 @@ def _parse_sections(content: str) -> list[dict[str, str]]:
in_code_block = False
for line in lines:
# Track code blocks to avoid matching headings inside them
# Track fences so headings inside code blocks aren't treated as splits.
stripped = line.strip()
if stripped.startswith("```"):
in_code_block = not in_code_block
# Only split on # or ## headings (not ### or deeper) and only outside code blocks
is_section_heading = (
not in_code_block
and re.match(r"^#{1,2}\s+", line)
@ -240,7 +234,6 @@ def _parse_sections(content: str) -> list[dict[str, str]]:
)
if is_section_heading:
# Save previous section
if current_heading or current_body_lines:
sections.append(
{
@ -253,7 +246,6 @@ def _parse_sections(content: str) -> list[dict[str, str]]:
else:
current_body_lines.append(line)
# Save last section
if current_heading or current_body_lines:
sections.append(
{
@ -292,7 +284,6 @@ async def _revise_with_sections(
Unchanged sections are kept byte-for-byte identical.
Returns the revised content, or None to trigger full-document revision fallback.
"""
# Parse report into sections
sections = _parse_sections(parent_content)
if len(sections) < 2:
logger.info(
@ -300,7 +291,6 @@ async def _revise_with_sections(
)
return None
# Build a sections listing for the LLM
sections_listing = ""
for i, sec in enumerate(sections):
heading = sec["heading"] or "(preamble — content before first heading)"
@ -352,11 +342,9 @@ async def _revise_with_sections(
)
return None
# Compute total operations for progress tracking
total_ops = len(modify_indices) + len(add_sections)
current_op = 0
# Emit plan summary
parts = []
if modify_indices:
parts.append(
@ -394,7 +382,6 @@ async def _revise_with_sections(
current_op += 1
sec = sections[idx]
# Extract plain section name (strip markdown heading markers)
section_name = (
re.sub(r"^#+\s*", "", sec["heading"]).strip()
if sec["heading"]
@ -412,7 +399,6 @@ async def _revise_with_sections(
f"{sec['heading']}\n\n{sec['body']}" if sec["heading"] else sec["body"]
)
# Build context from surrounding sections
context_parts = []
if idx > 0:
prev = sections[idx - 1]
@ -442,7 +428,6 @@ async def _revise_with_sections(
revised_text = resp.content
if revised_text and isinstance(revised_text, str):
revised_text = _strip_wrapping_code_fences(revised_text).strip()
# Parse the LLM output back into heading + body
revised_parsed = _parse_sections(revised_text)
if revised_parsed:
revised_sections[idx] = revised_parsed[0]
@ -465,7 +450,6 @@ async def _revise_with_sections(
heading = add_info.get("heading", "## New Section")
description = add_info.get("description", "")
# Extract plain section name for progress display
plain_heading = re.sub(r"^#+\s*", "", heading).strip()
dispatch_custom_event(
"report_progress",
@ -475,7 +459,6 @@ async def _revise_with_sections(
},
)
# Build context from the surrounding sections at the insertion point
ctx_parts = []
if 0 <= after_idx < len(revised_sections):
before_sec = revised_sections[after_idx]
@ -542,36 +525,13 @@ def create_generate_report_tool(
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
):
"""
Factory function to create the generate_report tool with injected dependencies.
"""Create the generate_report tool with injected dependencies.
The tool generates a Markdown report inline using the search space's
document summary LLM, saves it to the database, and returns immediately.
Uses short-lived database sessions for each DB operation so no connection
is held during the long LLM API call.
Generation strategies:
- New reports: single-shot generation (1 LLM call)
- Revisions (targeted edits): section-level (unchanged sections preserved)
- Revisions (global changes): full-document revision fallback
Source strategies:
- "provided"/"conversation": use only the supplied source_content
- "kb_search": search the knowledge base internally using targeted queries
- "auto": use source_content if sufficient, otherwise fall back to KB search
Args:
search_space_id: The user's search space ID
thread_id: The chat thread ID for associating the report
connector_service: Optional connector service for internal KB search.
When provided, the tool can search the knowledge base internally
(used by the "kb_search" and "auto" source strategies).
available_connectors: Optional list of connector types available in the
search space (used to scope internal KB searches).
Returns:
A configured tool function for generating reports
Uses short-lived DB sessions per operation so no connection is held during
the long LLM call. Generation: new reports are single-shot; revisions try
section-level first (unchanged sections preserved) and fall back to full-doc.
Source strategies: provided/conversation (use source_content), kb_search
(internal KB queries), auto (KB search only when source_content is thin).
"""
@tool
@ -693,7 +653,7 @@ def create_generate_report_tool(
Returns:
Dict with status, report_id, title, word_count, and message.
"""
# Initialize version tracking variables (used by _save_failed_report closure)
# Shared with the _save_failed_report closure.
parent_report_content: str | None = None
report_group_id: int | None = None
@ -733,7 +693,7 @@ def create_generate_report_tool(
session.add(failed_report)
await session.commit()
await session.refresh(failed_report)
# If this is a new group (v1 failed), set group to self
# New group (v1 failed): point the group at itself.
if not failed_report.report_group_id:
failed_report.report_group_id = failed_report.id
await session.commit()
@ -749,8 +709,8 @@ def create_generate_report_tool(
try:
# ── Phase 1: READ (short-lived session) ──────────────────────
# Fetch parent report and LLM config, then close the session
# so no DB connection is held during the long LLM call.
# Fetch parent report + LLM config, then release the connection
# before the long LLM call.
async with shielded_async_session() as read_session:
if parent_report_id:
parent_report = await read_session.get(Report, parent_report_id)
@ -768,7 +728,6 @@ def create_generate_report_tool(
)
llm = await get_document_summary_llm(read_session, search_space_id)
# read_session closed — connection returned to pool
if not llm:
error_msg = (
@ -785,7 +744,6 @@ def create_generate_report_tool(
error=error_msg,
)
# Build the user instructions string
user_instructions_section = ""
if user_instructions:
user_instructions_section = (
@ -829,7 +787,7 @@ def create_generate_report_tool(
try:
from .knowledge_base import search_knowledge_base_async
# Run all queries in parallel, each with its own session
# Each query gets its own short-lived session.
async def _run_single_query(q: str) -> str:
async with shielded_async_session() as kb_session:
kb_connector_svc = ConnectorService(
@ -849,7 +807,6 @@ def create_generate_report_tool(
*[_run_single_query(q) for q in search_queries[:5]]
)
# Merge non-empty results into source_content
kb_text_parts = [r for r in kb_results if r and r.strip()]
if kb_text_parts:
kb_combined = "\n\n---\n\n".join(kb_text_parts)
@ -903,9 +860,9 @@ def create_generate_report_tool(
"provided. Using source_content as-is."
)
capped_source = effective_source[:100000] # Cap source content
capped_source = effective_source[:100000]
# Length constraint — only when user explicitly asks for brevity
# Length constraint only when the user explicitly asked for brevity.
length_instruction = ""
if report_style == "brief":
length_instruction = (
@ -920,11 +877,8 @@ def create_generate_report_tool(
report_content: str | None = None
if parent_report_content:
# ─── REVISION MODE ───────────────────────────────────────
# Strategy: Try section-level revision first (preserves
# unchanged sections byte-for-byte). Falls back to full-
# document revision if section identification fails or if
# all sections need changes.
# Revision mode: section-level first (preserves untouched
# sections), falling back to full-doc revision.
dispatch_custom_event(
"report_progress",
{
@ -946,7 +900,6 @@ def create_generate_report_tool(
)
if report_content is None:
# Fallback: full-document revision
dispatch_custom_event(
"report_progress",
{"phase": "writing", "message": "Rewriting your full report"},
@ -969,9 +922,7 @@ def create_generate_report_tool(
report_content = response.content
else:
# ─── NEW REPORT MODE ─────────────────────────────────────
# Single-shot generation: one LLM call produces the full
# report. Fast, globally coherent, and cost-efficient.
# New report: single-shot generation (one LLM call).
dispatch_custom_event(
"report_progress",
{"phase": "writing", "message": "Writing your report"},
@ -991,8 +942,6 @@ def create_generate_report_tool(
response = await llm.ainvoke([HumanMessage(content=prompt)])
report_content = response.content
# ── Validate LLM output ──────────────────────────────────────
if not report_content or not isinstance(report_content, str):
error_msg = "LLM returned empty or invalid content"
report_id = await _save_failed_report(error_msg)
@ -1029,14 +978,12 @@ def create_generate_report_tool(
if report_content.rstrip().endswith("---"):
report_content = report_content.rstrip()[:-3].rstrip()
# Append exactly one standard disclaimer
# Append exactly one standard footer.
report_content += "\n\n---\n\n" + _REPORT_FOOTER
# Extract metadata (includes "status": "ready")
metadata = _extract_metadata(report_content)
# ── Phase 3: WRITE (short-lived session) ─────────────────────
# Save the report to the database, then close the session.
async with shielded_async_session() as write_session:
report = Report(
title=topic,
@ -1051,14 +998,13 @@ def create_generate_report_tool(
await write_session.commit()
await write_session.refresh(report)
# If this is a brand-new report (v1), set report_group_id = own id
# Brand-new report (v1): point the group at itself.
if not report.report_group_id:
report.report_group_id = report.id
await write_session.commit()
saved_report_id = report.id
saved_group_id = report.report_group_id
# write_session closed — connection returned to pool
logger.info(
f"[generate_report] Created report {saved_report_id} "

View file

@ -23,7 +23,6 @@ def extract_domain(url: str) -> str:
try:
parsed = urlparse(url)
domain = parsed.netloc
# Remove 'www.' prefix if present
if domain.startswith("www."):
domain = domain[4:]
return domain
@ -47,14 +46,13 @@ def truncate_content(content: str, max_length: int = 50000) -> tuple[str, bool]:
if len(content) <= max_length:
return content, False
# Try to truncate at a sentence boundary
# Prefer truncating at a sentence/paragraph boundary.
truncated = content[:max_length]
last_period = truncated.rfind(".")
last_newline = truncated.rfind("\n\n")
# Use the later of the two boundaries, or just truncate
boundary = max(last_period, last_newline)
if boundary > max_length * 0.8: # Only use boundary if it's not too far back
if boundary > max_length * 0.8: # only if the boundary isn't too far back
truncated = content[: boundary + 1]
return truncated + "\n\n[Content truncated...]", True
@ -105,8 +103,8 @@ async def _scrape_youtube_video(
http_client.proxies.update(residential_proxies)
ytt_api = YouTubeTranscriptApi(http_client=http_client)
# List all available transcripts and pick the first one
# (the video's primary language) instead of defaulting to English
# Pick the first transcript (video's primary language) rather than
# defaulting to English.
transcript_list = ytt_api.list(video_id)
transcript = next(iter(transcript_list))
captions = transcript.fetch()
@ -128,10 +126,8 @@ async def _scrape_youtube_video(
logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}")
transcript_text = f"No captions available for this video. Error: {e!s}"
# Build combined content
content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}"
# Truncate if needed
content, was_truncated = truncate_content(content, max_length)
word_count = len(content.split())
@ -206,20 +202,16 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
scrape_id = generate_scrape_id(url)
domain = extract_domain(url)
# Validate and normalize URL
if not url.startswith(("http://", "https://")):
url = f"https://{url}"
try:
# Check if this is a YouTube URL and use transcript API instead
# YouTube URLs use the transcript API instead of crawling.
video_id = get_youtube_video_id(url)
if video_id:
return await _scrape_youtube_video(url, video_id, max_length)
# Create webcrawler connector
connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key)
# Crawl the URL
result, error = await connector.crawl_url(url, formats=["markdown"])
if error:
@ -244,28 +236,21 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
"error": "No content returned from crawler",
}
# Extract content and metadata
content = result.get("content", "")
metadata = result.get("metadata", {})
# Get title from metadata
title = metadata.get("title", "")
if not title:
title = domain or url.split("/")[-1] or "Webpage"
# Get description from metadata
description = metadata.get("description", "")
if not description and content:
# Use first paragraph as description
first_para = content.split("\n\n")[0] if content else ""
description = (
first_para[:300] + "..." if len(first_para) > 300 else first_para
)
# Truncate content if needed
content, was_truncated = truncate_content(content, max_length)
# Calculate word count
word_count = len(content.split())
return {

View file

@ -92,15 +92,9 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
yield chunk
# Provider mapping for LiteLLM model string construction.
#
# Single source of truth lives in
# :mod:`app.services.provider_capabilities` so the YAML loader (which
# runs during ``app.config`` class-body init) can resolve provider
# prefixes without dragging the agent / tools tree into module load
# order. Re-exported here under the historical ``PROVIDER_MAP`` name
# so existing callers (``llm_router_service``, ``image_gen_router_service``,
# tests) keep working unchanged.
# Re-exported under the historical name ``PROVIDER_MAP``. Source of truth lives
# in provider_capabilities so the YAML loader can resolve prefixes during
# app.config init without importing the agent/tools tree.
from app.services.provider_capabilities import ( # noqa: E402
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
)
@ -157,25 +151,14 @@ class AgentConfig:
anonymous_enabled: bool = False
quota_reserve_tokens: int | None = None
# Capability flag: best-effort True for the chat selector / catalog.
# Resolved via :func:`provider_capabilities.derive_supports_image_input`
# which prefers OpenRouter's ``architecture.input_modalities`` and
# otherwise consults LiteLLM's authoritative model map. Default True
# is the conservative-allow stance — the streaming-task safety net
# (``is_known_text_only_chat_model``) is the *only* place a False
# actually blocks a request. Setting this to False here without an
# authoritative source would silently hide vision-capable models
# (the regression we're fixing).
# Default-allow: only the streaming safety net (is_known_text_only_chat_model)
# actually blocks on False, so defaulting False would silently hide
# vision-capable models. Resolved via derive_supports_image_input.
supports_image_input: bool = True
@classmethod
def from_auto_mode(cls) -> "AgentConfig":
"""
Create an AgentConfig for Auto mode (LiteLLM Router load balancing).
Returns:
AgentConfig instance configured for Auto mode
"""
"""Build an AgentConfig for Auto mode (LiteLLM Router load balancing)."""
return cls(
provider="AUTO",
model_name="auto",
@ -193,27 +176,15 @@ class AgentConfig:
is_premium=False,
anonymous_enabled=False,
quota_reserve_tokens=None,
# Auto routes across the configured pool, which usually
# contains at least one vision-capable deployment; the router
# will surface a 404 from a non-vision deployment as a normal
# ``allowed_fails`` event and fail over rather than blocking
# the request outright.
# Auto fails over across the pool, so a non-vision deployment's 404
# is just an allowed_fails event rather than a hard block.
supports_image_input=True,
)
@classmethod
def from_new_llm_config(cls, config) -> "AgentConfig":
"""
Create an AgentConfig from a NewLLMConfig database model.
Args:
config: NewLLMConfig database model instance
Returns:
AgentConfig instance
"""
# Lazy import to avoid pulling provider_capabilities (and its
# transitive litellm import) into module-init order.
"""Build an AgentConfig from a NewLLMConfig database model."""
# Lazy import: keeps provider_capabilities (and litellm) out of init order.
from app.services.provider_capabilities import derive_supports_image_input
provider_value = (
@ -245,10 +216,8 @@ class AgentConfig:
is_premium=False,
anonymous_enabled=False,
quota_reserve_tokens=None,
# BYOK rows have no operator-curated capability flag, so we
# ask LiteLLM (default-allow on unknown). The streaming
# safety net still blocks if the model is *explicitly*
# marked text-only.
# BYOK rows have no curated flag; ask LiteLLM (default-allow on
# unknown). The streaming safety net still blocks explicit text-only.
supports_image_input=derive_supports_image_input(
provider=provider_value,
model_name=config.model_name,
@ -259,25 +228,14 @@ class AgentConfig:
@classmethod
def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig":
"""Build an AgentConfig from a YAML configuration dictionary.
Supports the same prompt fields as NewLLMConfig (system_instructions,
use_default_system_instructions, citations_enabled).
"""
Create an AgentConfig from a YAML configuration dictionary.
YAML configs now support the same prompt configuration fields as NewLLMConfig:
- system_instructions: Custom system instructions (empty string uses defaults)
- use_default_system_instructions: Whether to use default instructions
- citations_enabled: Whether citations are enabled
Args:
yaml_config: Configuration dictionary from YAML file
Returns:
AgentConfig instance
"""
# Lazy import to avoid pulling provider_capabilities (and its
# transitive litellm import) into module-init order.
# Lazy import: keeps provider_capabilities (and litellm) out of init order.
from app.services.provider_capabilities import derive_supports_image_input
# Get system instructions from YAML, default to empty string
system_instructions = yaml_config.get("system_instructions", "")
provider = yaml_config.get("provider", "").upper()
@ -290,13 +248,8 @@ class AgentConfig:
else None
)
# Explicit YAML override wins; otherwise derive from LiteLLM /
# OpenRouter modalities. The YAML loader already populates this
# field, but this method is also called from
# ``load_global_llm_config_by_id``'s file fallback (hot reload),
# so we re-derive here for safety. The bool() coercion preserves
# the loader's behaviour for explicit ``true`` / ``false``
# strings that PyYAML may surface.
# Explicit YAML override wins; otherwise re-derive (the hot-reload file
# fallback reaches this method without the loader having populated it).
if "supports_image_input" in yaml_config:
supports_image_input = bool(yaml_config.get("supports_image_input"))
else:
@ -314,7 +267,6 @@ class AgentConfig:
api_base=yaml_config.get("api_base"),
custom_provider=custom_provider,
litellm_params=yaml_config.get("litellm_params"),
# Prompt configuration from YAML (with defaults for backwards compatibility)
system_instructions=system_instructions if system_instructions else None,
use_default_system_instructions=yaml_config.get(
"use_default_system_instructions", True
@ -332,20 +284,10 @@ class AgentConfig:
def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None:
"""
Load a specific LLM config from global_llm_config.yaml.
Args:
llm_config_id: The id of the config to load (default: -1)
Returns:
LLM config dict or None if not found
"""
# Get the config file path
"""Load a specific LLM config from global_llm_config.yaml."""
base_dir = Path(__file__).resolve().parent.parent.parent.parent
config_file = base_dir / "app" / "config" / "global_llm_config.yaml"
# Fallback to example file if main config doesn't exist
if not config_file.exists():
config_file = base_dir / "app" / "config" / "global_llm_config.example.yaml"
if not config_file.exists():
@ -368,24 +310,17 @@ def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None:
def load_global_llm_config_by_id(llm_config_id: int) -> dict | None:
"""
Load a global LLM config by ID, checking in-memory configs first.
"""Load a global LLM config by ID, checking in-memory configs first.
This handles both static YAML configs and dynamically injected configs
(e.g. OpenRouter integration models that only exist in memory).
Args:
llm_config_id: The negative ID of the global config to load
Returns:
LLM config dict or None if not found
In-memory covers both static YAML and dynamically injected configs (e.g.
OpenRouter integration models that only exist in memory).
"""
from app.config import config as app_config
for cfg in app_config.GLOBAL_LLM_CONFIGS:
if cfg.get("id") == llm_config_id:
return cfg
# Fallback to YAML file read (covers edge cases like hot-reload)
# Fallback to YAML file read (covers hot-reload edge cases).
return load_llm_config_from_yaml(llm_config_id)
@ -393,17 +328,7 @@ async def load_new_llm_config_from_db(
session: AsyncSession,
config_id: int,
) -> "AgentConfig | None":
"""
Load a NewLLMConfig from the database by ID.
Args:
session: AsyncSession for database access
config_id: The ID of the NewLLMConfig to load
Returns:
AgentConfig instance or None if not found
"""
# Import here to avoid circular imports
"""Load a NewLLMConfig from the database by ID."""
from app.db import NewLLMConfig
try:
@ -426,26 +351,13 @@ async def load_agent_llm_config_for_search_space(
session: AsyncSession,
search_space_id: int,
) -> "AgentConfig | None":
"""Load the agent LLM config for a search space via its agent_llm_id.
Positive id -> DB; negative -> YAML; None -> first global config (-1).
"""
Load the agent LLM configuration for a search space.
This loads the LLM config based on the search space's agent_llm_id setting:
- Positive ID: Load from NewLLMConfig database table
- Negative ID: Load from YAML global configs
- None: Falls back to first global config (id=-1)
Args:
session: AsyncSession for database access
search_space_id: The search space ID
Returns:
AgentConfig instance or None if not found
"""
# Import here to avoid circular imports
from app.db import SearchSpace
try:
# Get the search space to check its agent_llm_id preference
result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == search_space_id)
)
@ -455,12 +367,9 @@ async def load_agent_llm_config_for_search_space(
print(f"Error: SearchSpace with id {search_space_id} not found")
return None
# Use agent_llm_id from search space, fallback to -1 (first global config)
config_id = (
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
)
# Load the config using the unified loader
return await load_agent_config(session, config_id, search_space_id)
except Exception as e:
print(f"Error loading agent LLM config for search space {search_space_id}: {e}")
@ -472,23 +381,7 @@ async def load_agent_config(
config_id: int,
search_space_id: int | None = None,
) -> "AgentConfig | None":
"""
Load an agent configuration, supporting Auto mode, YAML, and database configs.
This is the main entry point for loading configurations:
- ID 0: Auto mode (uses LiteLLM Router for load balancing)
- Negative IDs: Load from YAML file (global configs)
- Positive IDs: Load from NewLLMConfig database table
Args:
session: AsyncSession for database access
config_id: The config ID (0 for Auto, negative for YAML, positive for database)
search_space_id: Optional search space ID for context
Returns:
AgentConfig instance or None if not found
"""
# Auto mode (ID 0) - use LiteLLM Router
"""Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB."""
if is_auto_mode(config_id):
if not LLMRouterService.is_initialized():
print("Error: Auto mode requested but LLM Router not initialized")
@ -496,33 +389,22 @@ async def load_agent_config(
return AgentConfig.from_auto_mode()
if config_id < 0:
# Check in-memory configs first (includes static YAML + dynamic OpenRouter)
# In-memory covers static YAML + dynamic OpenRouter configs.
from app.config import config as app_config
for cfg in app_config.GLOBAL_LLM_CONFIGS:
if cfg.get("id") == config_id:
return AgentConfig.from_yaml_config(cfg)
# Fallback to YAML file read for safety
yaml_config = load_llm_config_from_yaml(config_id)
if yaml_config:
return AgentConfig.from_yaml_config(yaml_config)
return None
else:
# Load from database (NewLLMConfig)
return await load_new_llm_config_from_db(session, config_id)
def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
"""
Create a ChatLiteLLM instance from a global LLM config dictionary.
Args:
llm_config: LLM configuration dictionary from YAML
Returns:
ChatLiteLLM instance or None on error
"""
# Build the model string
"""Create a ChatLiteLLM instance from a global LLM config dictionary."""
if llm_config.get("custom_provider"):
model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}"
else:
@ -530,27 +412,20 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{llm_config['model_name']}"
# Create ChatLiteLLM instance with streaming enabled
litellm_kwargs = {
"model": model_string,
"api_key": llm_config.get("api_key"),
"streaming": True, # Enable streaming for real-time token streaming
"streaming": True,
}
# Add optional parameters
if llm_config.get("api_base"):
litellm_kwargs["api_base"] = llm_config["api_base"]
# Add any additional litellm parameters
if llm_config.get("litellm_params"):
litellm_kwargs.update(llm_config["litellm_params"])
llm = SanitizedChatLiteLLM(**litellm_kwargs)
_attach_model_profile(llm, model_string)
# Configure LiteLLM-native prompt caching (cache_control_injection_points
# for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.).
# ``agent_config=None`` here — the YAML path doesn't have provider intent
# in a structured form, so we set only the universal injection points.
# agent_config=None: the YAML path lacks structured provider intent, so set
# only the universal cache_control_injection_points.
apply_litellm_prompt_caching(llm)
return llm
@ -558,19 +433,7 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
def create_chat_litellm_from_agent_config(
agent_config: AgentConfig,
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
"""
Create a ChatLiteLLM or ChatLiteLLMRouter instance from an AgentConfig.
For Auto mode configs, returns a ChatLiteLLMRouter that uses LiteLLM Router
for automatic load balancing across available providers.
Args:
agent_config: AgentConfig instance
Returns:
ChatLiteLLM or ChatLiteLLMRouter instance, or None on error
"""
# Handle Auto mode - return ChatLiteLLMRouter
"""Create a ChatLiteLLM (or, for Auto mode, a load-balancing router) from config."""
if agent_config.is_auto_mode:
if not LLMRouterService.is_initialized():
print("Error: Auto mode requested but LLM Router not initialized")
@ -578,19 +441,14 @@ def create_chat_litellm_from_agent_config(
try:
router_llm = get_auto_mode_llm()
if router_llm is not None:
# Universal cache_control_injection_points only — auto-mode
# fans out across providers, so OpenAI-only kwargs (e.g.
# ``prompt_cache_key``) are left off here. ``drop_params``
# would strip them at the provider boundary anyway, but
# there's no point setting them when we don't know the
# destination.
# Universal injection points only: auto-mode fans out across
# providers, so provider-specific kwargs have no known target.
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
return router_llm
except Exception as e:
print(f"Error creating ChatLiteLLMRouter: {e}")
return None
# Build the model string
if agent_config.custom_provider:
model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
else:
@ -599,26 +457,19 @@ def create_chat_litellm_from_agent_config(
)
model_string = f"{provider_prefix}/{agent_config.model_name}"
# Create ChatLiteLLM instance with streaming enabled
litellm_kwargs = {
"model": model_string,
"api_key": agent_config.api_key,
"streaming": True, # Enable streaming for real-time token streaming
"streaming": True,
}
# Add optional parameters
if agent_config.api_base:
litellm_kwargs["api_base"] = agent_config.api_base
# Add any additional litellm parameters
if agent_config.litellm_params:
litellm_kwargs.update(agent_config.litellm_params)
llm = SanitizedChatLiteLLM(**litellm_kwargs)
_attach_model_profile(llm, model_string)
# Build-time prompt caching: sets ``cache_control_injection_points`` for
# all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``.
# Per-thread ``prompt_cache_key`` is layered on later in
# ``create_surfsense_deep_agent`` once ``thread_id`` is known.
# Build-time caching only; the per-thread prompt_cache_key is layered on
# later in create_surfsense_deep_agent once thread_id is known.
apply_litellm_prompt_caching(llm, agent_config=agent_config)
return llm

View file

@ -1,63 +1,28 @@
r"""LiteLLM-native prompt caching configuration for SurfSense agents.
r"""LiteLLM-native prompt caching for SurfSense agents.
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)``
gate always failed) with LiteLLM's universal caching mechanism.
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (its
``isinstance(model, ChatAnthropic)`` gate never matched our LiteLLM stack)
with LiteLLM's universal ``cache_control_injection_points`` mechanism, which
covers the Anthropic/Bedrock/Vertex/Gemini/OpenRouter/etc. marker-based
providers and the auto-caching OpenAI family.
Coverage:
Two breakpoints per request:
- Marker-based providers (need ``cache_control`` injection, which LiteLLM
performs automatically when ``cache_control_injection_points`` is set):
``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``,
``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/``
(Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM).
- Auto-cached (LiteLLM strips the marker silently): ``openai/``,
``deepseek/``, ``xai/`` — these cache automatically for prompts ≥1024
tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``.
- ``index: 0`` pins the head-of-request system prompt. We use ``index: 0``,
NOT ``role: system``: ``before_agent`` injectors accumulate many
SystemMessages, and tagging all of them overflows Anthropic's 4-block cap
(upstream 400 via OpenRouter).
- ``index: -1`` pins the latest message so longest-prefix lookup compounds
multi-turn savings.
We inject **two** breakpoints per request:
OpenAI-family configs also get ``prompt_cache_key`` (per-thread routing hint)
and ``prompt_cache_retention="24h"``. Azure is excluded from the latter
because LiteLLM's Azure transformer drops it (see
``_PROMPT_CACHE_RETENTION_PROVIDERS``).
- ``index: 0`` pins the SurfSense system prompt at the head of the
request (provider variant, citation rules, tool catalog, KB tree,
skills metadata). The langchain agent factory always prepends
``request.system_message`` at index 0 (see ``factory.py``
``_execute_model_async``), so this targets exactly the main system
prompt regardless of how many other ``SystemMessage``\ s the
``before_agent`` injectors (priority, tree, memory, file-intent,
anonymous-doc) have inserted into ``state["messages"]``. Using
``role: system`` here would apply ``cache_control`` to **every**
system-role message and trip Anthropic's hard cap of 4 cache
breakpoints per request once the conversation accumulates enough
injected system messages — which surfaces as the upstream 400
``A maximum of 4 blocks with cache_control may be provided. Found N``
via OpenRouter→Anthropic.
- ``index: -1`` pins the latest message so multi-turn savings compound:
Anthropic-family providers use longest-matching-prefix lookup, so turn
N+1 still reads turn N's cache up to the shared prefix.
For OpenAI-family configs we additionally pass:
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
raises hit rate by sending requests with a shared prefix to the same
backend. Supported by ``openai/``, ``deepseek/``, ``xai/``, and
``azure/`` (added to LiteLLM's Azure transformer in
https://github.com/BerriAI/litellm/pull/20989, Feb 2026; verified
against ``AzureOpenAIConfig.get_supported_openai_params`` in our
installed litellm 1.83.14 for ``azure/gpt-4o``, ``azure/gpt-4o-mini``,
``azure/gpt-5.4``, ``azure/gpt-5.4-mini``).
- ``prompt_cache_retention="24h"`` extends cache TTL beyond the default
5-10 min in-memory cache. Set ONLY for OpenAI/DeepSeek/xAI: Azure's
server-side support landed in Microsoft's docs on 2026-05-13 but
LiteLLM 1.83.14's Azure transformer still omits it from its supported
params list, so it gets silently dropped by ``litellm.drop_params``.
Azure's default in-memory retention (5-10 min, max 1 h) already
bridges intra-conversation turns; revisit when LiteLLM bumps Azure.
Safety net: ``litellm.drop_params=True`` is set globally in
``app.services.llm_service`` at module-load time. Any kwarg the destination
provider doesn't recognise is auto-stripped at the provider transformer
layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on
``prompt_cache_key`` etc.
Safety net: ``litellm.drop_params=True`` (set in ``app.services.llm_service``)
strips any kwarg the destination provider rejects, so an auto-mode fallback
can't 400 on these extras.
"""
from __future__ import annotations
@ -73,57 +38,29 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Two-breakpoint policy: head-of-request + latest message. See module
# docstring for rationale. Anthropic caps requests at 4 ``cache_control``
# blocks; we use 2 here, leaving headroom for Phase-2 tool caching.
#
# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's
# ``before_agent`` middlewares (priority, tree, memory, anonymous-doc)
# insert ``SystemMessage`` instances into ``state["messages"]`` that
# accumulate across turns. With ``role: system`` the LiteLLM hook would
# tag *every* one of them with ``cache_control`` and overflow Anthropic's
# 4-block limit. ``index: 0`` always targets the langchain-prepended
# ``request.system_message``, giving us exactly one stable cache breakpoint.
# Head-of-request + latest message (see module docstring for the index:0 vs
# role:system rationale and Anthropic's 4-block cap).
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
{"location": "message", "index": 0},
{"location": "message", "index": -1},
)
# Providers (uppercase ``AgentConfig.provider`` values) that accept the
# OpenAI ``prompt_cache_key`` routing hint. Microsoft's Azure OpenAI docs
# (2026-05-13) confirm automatic prompt caching applies to every GPT-4o
# or newer Azure deployment at ≥1024 tokens with no configuration needed,
# and that ``prompt_cache_key`` is combined with the prefix hash to
# improve routing affinity and therefore cache hit rate. LiteLLM's Azure
# transformer ships ``prompt_cache_key`` in its supported params as of
# https://github.com/BerriAI/litellm/pull/20989.
#
# Strict whitelist — many other providers in ``PROVIDER_MAP`` route
# through litellm's ``openai`` prefix without implementing the OpenAI
# prompt-cache surface (e.g. MOONSHOT, ZHIPU, MINIMAX), so we can't infer
# family from the litellm prefix alone.
# Providers that accept the OpenAI ``prompt_cache_key`` routing hint. Strict
# whitelist: many providers route through litellm's ``openai`` prefix without
# the prompt-cache surface, so the prefix alone isn't enough to infer family.
_PROMPT_CACHE_KEY_PROVIDERS: frozenset[str] = frozenset(
{"OPENAI", "DEEPSEEK", "XAI", "AZURE", "AZURE_OPENAI"}
)
# Subset of ``_PROMPT_CACHE_KEY_PROVIDERS`` that also accept
# ``prompt_cache_retention="24h"``. Azure is excluded: see module
# docstring — LiteLLM 1.83.14's Azure transformer omits the param so
# ``drop_params`` silently strips it. Re-add Azure once a future LiteLLM
# release wires it into ``AzureOpenAIConfig.get_supported_openai_params``.
# Subset that also accepts ``prompt_cache_retention="24h"``. Azure is excluded
# because LiteLLM's Azure transformer omits the param (drop_params strips it).
_PROMPT_CACHE_RETENTION_PROVIDERS: frozenset[str] = frozenset(
{"OPENAI", "DEEPSEEK", "XAI"}
)
def _is_router_llm(llm: BaseChatModel) -> bool:
"""Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import.
Importing ``app.services.llm_router_service`` at module-load time would
create a cycle via ``llm_config -> prompt_caching -> llm_router_service``.
Class-name comparison is sufficient since the class is defined in a
single place.
"""
"""Detect ``ChatLiteLLMRouter`` by class name to avoid an import cycle."""
return type(llm).__name__ == "ChatLiteLLMRouter"
@ -188,21 +125,10 @@ def apply_litellm_prompt_caching(
) -> None:
"""Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter.
Idempotent — values already present in ``llm.model_kwargs`` (e.g. from
``agent_config.litellm_params`` overrides) are preserved. Mutates
``llm.model_kwargs`` in place; the kwargs flow to ``litellm.completion``
via ``ChatLiteLLM._default_params`` and via ``self.model_kwargs`` merge
in our custom ``ChatLiteLLMRouter``.
Args:
llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance.
agent_config: Optional ``AgentConfig`` driving provider-specific
behaviour. When omitted (or auto-mode), only the universal
``cache_control_injection_points`` are set.
thread_id: Optional thread id used to construct a per-thread
``prompt_cache_key`` for OpenAI-family providers. Caching still
works without it (server-side automatic), but the key improves
backend routing affinity and therefore hit rate.
Idempotent (existing ``model_kwargs`` values are preserved) and mutates
``llm.model_kwargs`` in place. Without ``agent_config`` (or in auto-mode)
only the universal injection points are set; ``thread_id`` adds a per-thread
``prompt_cache_key`` for OpenAI-family providers to improve routing affinity.
"""
model_kwargs = _get_or_init_model_kwargs(llm)
if model_kwargs is None:
@ -217,11 +143,8 @@ def apply_litellm_prompt_caching(
dict(point) for point in _DEFAULT_INJECTION_POINTS
]
# OpenAI-style extras only when we statically know the destination
# accepts them. Auto-mode router fans out across mixed providers so
# we can't safely set destination-specific kwargs there (drop_params
# would strip them but it's wasteful to set them in the first
# place).
# OpenAI-style extras only when the destination is statically known. The
# auto-mode router fans out across mixed providers, so skip them there.
if _is_router_llm(llm):
return

View file

@ -1,26 +1,13 @@
"""
SurfSense compaction middleware.
"""SurfSense compaction middleware.
Subclasses :class:`deepagents.middleware.summarization.SummarizationMiddleware`
to add SurfSense-specific behavior:
Extends ``SummarizationMiddleware`` with three SurfSense behaviors:
1. **Structured summary template** (OpenCode-style ``## Goal / Constraints /
Progress / Key Decisions / Next Steps / Critical Context / Relevant Files``)
— see :data:`SURFSENSE_SUMMARY_PROMPT` below. The base
``SummarizationMiddleware`` only ships a freeform "summarize this"
prompt; the structured template is ported from OpenCode's
``compaction.ts``.
2. **Protect SurfSense-specific SystemMessages** so injected hints
(``<priority_documents>``, ``<workspace_tree>``, ``<file_operation_contract>``,
``<user_memory>``, ``<team_memory>``, ``<user_name>``, ``<memory_warning>``)
are *not* summarized away and are kept verbatim in the post-summary
message list. Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy
(some message types are part of the agent's contract and must survive
compaction unchanged).
3. **Sanitize ``content=None``** when feeding messages into ``get_buffer_string``
(Azure OpenAI / LiteLLM defense — when a provider streams an AIMessage
containing only tool_calls and no text, ``content`` can be ``None`` and
``get_buffer_string`` crashes iterating over ``None``). SurfSense-specific.
1. A structured summary template (:data:`SURFSENSE_SUMMARY_PROMPT`) instead of
the base freeform prompt.
2. Protected SystemMessages (injected hints like ``<priority_documents>``) are
kept verbatim instead of being summarized away.
3. ``content=None`` is sanitized before ``get_buffer_string`` (some providers
stream tool-only AIMessages with ``None`` content, which would crash it).
"""
from __future__ import annotations
@ -43,9 +30,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Structured summary template ported from OpenCode's
# ``opencode/packages/opencode/src/session/compaction.ts:40-75``. Kept as a
# module-level constant so unit tests can assert on its sections.
# Module-level constant so unit tests can assert on its sections.
SURFSENSE_SUMMARY_PROMPT = """<role>
SurfSense Conversation Compaction Assistant
</role>
@ -114,13 +99,10 @@ def _is_protected_system_message(msg: AnyMessage) -> bool:
def _sanitize_message_content(msg: AnyMessage) -> AnyMessage:
"""Return ``msg`` with ``content=None`` coerced to ``""``.
"""Return a copy of ``msg`` with ``content=None`` coerced to ``""``.
Folds in the historical defense from ``safe_summarization.py`` —
``get_buffer_string`` reads ``m.text`` which iterates ``self.content``,
so a ``None`` content (Azure OpenAI / LiteLLM streaming a tool-only
AIMessage) explodes. We return a copy with empty string content so
downstream consumers see an empty body without mutating the original.
``get_buffer_string`` reads ``m.text`` (iterating ``content``), so a
tool-only AIMessage with ``None`` content would crash it.
"""
if getattr(msg, "content", "not-missing") is not None:
return msg
@ -159,20 +141,11 @@ class SurfSenseCompactionMiddleware(SummarizationMiddleware):
conversation_messages: list[AnyMessage],
cutoff_index: int,
) -> tuple[list[AnyMessage], list[AnyMessage]]:
"""Split messages but always preserve SurfSense protected SystemMessages.
"""Split messages, always preserving protected SystemMessages.
Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy
(``opencode/packages/opencode/src/session/compaction.ts``): some
message types are always kept verbatim because they are part of the
agent's working contract, not transient output.
Also opens a ``compaction.run`` OTel span (no-op when OTel is off)
so dashboards can count compaction events and message-volume
without having to instrument upstream callers.
Also opens a ``compaction.run`` OTel span (no-op when OTel is off) here,
since partitioning is the first call once summarization is decided.
"""
# Opening a span here is appropriate because partitioning is the
# first call SummarizationMiddleware makes when it has decided to
# summarize; we record the volume and then close as a normal span.
with ot.compaction_span(
reason="auto",
messages_in=len(conversation_messages),
@ -191,20 +164,15 @@ class SurfSenseCompactionMiddleware(SummarizationMiddleware):
else:
kept_for_summary.append(msg)
# Place protected blocks at the *front* of preserved_messages so
# they keep their original ordering relative to the summary
# HumanMessage that precedes the rest of the preserved tail.
# Protected blocks go at the front of preserved_messages to keep
# ordering relative to the summary HumanMessage.
return kept_for_summary, [*protected, *preserved_messages]
def _filter_summary_messages( # type: ignore[override]
self, messages: list[AnyMessage]
) -> list[AnyMessage]:
"""Filter previous summaries AND sanitize ``content=None``.
Folds the ``safe_summarization.py`` defense in: when the buffer
builder iterates ``m.text`` over ``None`` it explodes; sanitizing
here covers both the sync and async offload paths.
"""
"""Filter previous summaries and sanitize ``content=None`` (covers the
sync and async offload paths)."""
filtered = super()._filter_summary_messages(messages)
return [_sanitize_message_content(m) for m in filtered]

View file

@ -24,14 +24,11 @@ from .utils import get_voice_for_provider
async def create_podcast_transcript(
state: State, config: RunnableConfig
) -> dict[str, Any]:
"""Each node does work."""
# Get configuration from runnable config
"""Generate the podcast transcript from the source content."""
configuration = Configuration.from_runnable_config(config)
search_space_id = configuration.search_space_id
user_prompt = configuration.user_prompt
# Get search space's document summary LLM
llm = await get_agent_llm(state.db_session, search_space_id)
if not llm:
error_message = (
@ -40,22 +37,16 @@ async def create_podcast_transcript(
print(error_message)
raise RuntimeError(error_message)
# Get the prompt
prompt = get_podcast_generation_prompt(user_prompt)
# Create the messages
messages = [
SystemMessage(content=prompt),
HumanMessage(
content=f"<source_content>{state.source_content}</source_content>"
),
]
# Generate the podcast transcript
llm_response = await llm.ainvoke(messages)
# Reasoning models (e.g. Kimi K2.5) may return content as a list of
# blocks including 'reasoning' entries. Normalise to a plain string.
# Reasoning models may return content as blocks; normalise to a string.
content = strip_markdown_fences(extract_text_content(llm_response.content))
try:
@ -89,17 +80,13 @@ async def create_merged_podcast_audio(
state: State, config: RunnableConfig
) -> dict[str, Any]:
"""Generate audio for each transcript and merge them into a single podcast file."""
# configuration = Configuration.from_runnable_config(config)
starting_transcript = PodcastTranscriptEntry(
speaker_id=1, dialog="Welcome to Surfsense Podcast."
)
transcript = state.podcast_transcript
# Merge the starting transcript with the podcast transcript
# Check if transcript is a PodcastTranscripts object or already a list
# transcript may be a PodcastTranscripts object or already a list.
if hasattr(transcript, "podcast_transcripts"):
transcript_entries = transcript.podcast_transcripts
else:
@ -107,20 +94,16 @@ async def create_merged_podcast_audio(
merged_transcript = [starting_transcript, *transcript_entries]
# Create a temporary directory for audio files
temp_dir = Path("temp_audio")
temp_dir.mkdir(exist_ok=True)
# Generate a unique session ID for this podcast
session_id = str(uuid.uuid4())
output_path = f"podcasts/{session_id}_podcast.mp3"
os.makedirs("podcasts", exist_ok=True)
# Generate audio for each transcript segment
audio_files = []
async def generate_speech_for_segment(segment, index):
# Handle both dictionary and PodcastTranscriptEntry objects
if hasattr(segment, "speaker_id"):
speaker_id = segment.speaker_id
dialog = segment.dialog
@ -128,20 +111,15 @@ async def create_merged_podcast_audio(
speaker_id = segment.get("speaker_id", 0)
dialog = segment.get("dialog", "")
# Select voice based on speaker_id
voice = get_voice_for_provider(app_config.TTS_SERVICE, speaker_id)
# Generate a unique filename for this segment
if app_config.TTS_SERVICE == "local/kokoro":
# Kokoro generates WAV files
filename = f"{temp_dir}/{session_id}_{index}.wav"
else:
# Other services generate MP3 files
filename = f"{temp_dir}/{session_id}_{index}.mp3"
try:
if app_config.TTS_SERVICE == "local/kokoro":
# Use Kokoro TTS service
kokoro_service = await get_kokoro_tts_service(
lang_code="a"
) # American English
@ -170,7 +148,6 @@ async def create_merged_podcast_audio(
timeout=600,
)
# Save the audio to a file - use proper streaming method
with open(filename, "wb") as f:
f.write(response.content)
@ -179,23 +156,17 @@ async def create_merged_podcast_audio(
print(f"Error generating speech for segment {index}: {e!s}")
raise
# Generate all audio files concurrently
tasks = [
generate_speech_for_segment(segment, i)
for i, segment in enumerate(merged_transcript)
]
audio_files = await asyncio.gather(*tasks)
# Merge audio files using ffmpeg
try:
# Create FFmpeg instance with the first input
ffmpeg = FFmpeg().option("y")
# Add each audio file as input
for audio_file in audio_files:
ffmpeg = ffmpeg.input(audio_file)
# Configure the concatenation and output
filter_complex = []
for i in range(len(audio_files)):
filter_complex.append(f"[{i}:0]")
@ -205,8 +176,6 @@ async def create_merged_podcast_audio(
)
ffmpeg = ffmpeg.option("filter_complex", filter_complex_str)
ffmpeg = ffmpeg.output(output_path, map="[outa]")
# Execute FFmpeg
await ffmpeg.execute()
print(f"Successfully created podcast audio: {output_path}")
@ -215,7 +184,6 @@ async def create_merged_podcast_audio(
print(f"Error merging audio files: {e!s}")
raise
finally:
# Clean up temporary files
for audio_file in audio_files:
try:
os.remove(audio_file)