mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
docs(agents): tighten docstrings and comments across agent module
Recursive pass over the agents module to make docstrings and inline comments concise and intent-oriented: drop narration that just restates the code, condense verbose module/function docstrings, and keep only the non-obvious "why" notes. No functional code changed.
This commit is contained in:
parent
620c378254
commit
a3d05f6418
16 changed files with 319 additions and 1055 deletions
|
|
@ -1,25 +1,15 @@
|
||||||
"""Append-only action-log middleware for the SurfSense agent.
|
"""Append-only action-log middleware for the SurfSense agent.
|
||||||
|
|
||||||
Wraps every tool call via :meth:`AgentMiddleware.awrap_tool_call` and writes
|
Wraps every tool call and writes a row to :class:`~app.db.AgentActionLog`
|
||||||
a row to :class:`~app.db.AgentActionLog` after the tool returns. Tools opt
|
after the tool returns. Tools opt into reversibility via a ``reverse``
|
||||||
into reversibility by declaring a ``reverse`` callable on their
|
callable on their :class:`ToolDefinition`; the rendered descriptor powers
|
||||||
:class:`ToolDefinition`; the rendered descriptor is persisted in
|
|
||||||
``reverse_descriptor`` for use by
|
|
||||||
``/api/threads/{thread_id}/revert/{action_id}``.
|
``/api/threads/{thread_id}/revert/{action_id}``.
|
||||||
|
|
||||||
Design points:
|
Logging is fully defensive — DB-write failures are swallowed so the tool's
|
||||||
|
result is always returned untouched. Only metadata (name, capped args,
|
||||||
* **Defensive.** Logging never blocks the agent. We catch every exception
|
result_id, reverse_descriptor) is stored; tool output stays in the
|
||||||
on the DB write path and emit a warning; the tool's ``ToolMessage``
|
checkpoint. Reversibility is best-effort: a reverse callable that raises
|
||||||
result is always returned untouched.
|
just leaves the action non-reversible.
|
||||||
* **Lightweight payload.** Only the tool ``name`` + ``args`` (capped) +
|
|
||||||
``result_id`` + ``reverse_descriptor`` are stored. Tool output text
|
|
||||||
remains in the LangGraph checkpoint / spilled tool-output files.
|
|
||||||
* **Best-effort reversibility.** We invoke ``reverse(args, result_obj)``
|
|
||||||
with the parsed JSON result when the tool's content is a JSON object;
|
|
||||||
otherwise the raw text is passed. Exceptions in the reverse callable
|
|
||||||
are swallowed and logged — a failed descriptor render simply means the
|
|
||||||
action is NOT marked reversible.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -203,11 +193,9 @@ class ActionLogMiddleware(AgentMiddleware):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Surface a side-channel SSE event so the chat tool card can
|
# Side-channel event (relayed by ``stream_new_chat`` as a
|
||||||
# render a Revert button immediately after the row is durable.
|
# ``data-action-log`` SSE) so the tool card can show a Revert button
|
||||||
# ``stream_new_chat`` translates this into a
|
# once the row is durable. Carries a presence flag, not the descriptor.
|
||||||
# ``data-action-log`` SSE event. We DO NOT include the
|
|
||||||
# ``reverse_descriptor`` payload here; only a presence flag.
|
|
||||||
try:
|
try:
|
||||||
await adispatch_custom_event(
|
await adispatch_custom_event(
|
||||||
"action_log",
|
"action_log",
|
||||||
|
|
|
||||||
|
|
@ -1,32 +1,12 @@
|
||||||
"""
|
"""Per-thread asyncio lock + cooperative cancel token, keyed by ``thread_id``.
|
||||||
BusyMutexMiddleware — per-thread asyncio lock + cancel token.
|
|
||||||
|
|
||||||
LangChain has no built-in concept of "this thread is already running a
|
Refuses a second concurrent turn on the same thread (e.g. double-clicked
|
||||||
turn — refuse the second concurrent request". Without it, a user
|
"send") that would otherwise race on the same checkpoint and duplicate tool
|
||||||
double-clicking "send" or refreshing the page mid-stream can spawn two
|
calls. Also exposes a per-thread cancel event that long-running tools poll
|
||||||
turns racing on the same checkpoint, producing duplicated tool calls
|
via ``runtime.context.cancel_event.is_set()`` to abort cooperatively.
|
||||||
and mangled state.
|
|
||||||
|
|
||||||
Ported from OpenCode's ``Stream.scoped(AbortController)`` pattern: a
|
Process-local and in-memory; multi-worker deployments need a distributed lock
|
||||||
single-process, in-memory lock + cooperative cancellation token keyed by
|
(Redis / PostgreSQL advisory locks) as a follow-up.
|
||||||
``thread_id``. For multi-worker deployments a distributed lock backend
|
|
||||||
(Redis or PostgreSQL advisory locks) is a phase-2 follow-up.
|
|
||||||
|
|
||||||
What this provides:
|
|
||||||
- A ``WeakValueDictionary[str, asyncio.Lock]`` keyed by ``thread_id``;
|
|
||||||
acquiring the lock during ``before_agent`` blocks any concurrent
|
|
||||||
prompt on the same thread until release.
|
|
||||||
- A per-thread ``asyncio.Event`` (``cancel_event``) that long-running
|
|
||||||
tools can poll to abort cooperatively. The event is reset between
|
|
||||||
turns. Tools should check ``runtime.context.cancel_event.is_set()``
|
|
||||||
in tight inner loops.
|
|
||||||
- A typed :class:`~app.agents.chat.runtime.errors.BusyError` raised when a
|
|
||||||
second turn arrives while the lock is held.
|
|
||||||
|
|
||||||
Note: SurfSense's ``stream_new_chat`` is the call site that should
|
|
||||||
acquire/release. Wiring this as middleware means the contract is
|
|
||||||
explicit and the lock manager is shared with subagents that compile
|
|
||||||
their own ``create_agent`` runnables.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -152,9 +132,8 @@ class _ThreadLockManager:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton — process-local but reused across all agent
|
# Process-local singleton shared across all agents/subagents built in this
|
||||||
# instances built in this process. Subagents created in nested
|
# process so per-thread locks stay coherent.
|
||||||
# ``create_agent`` calls also get this so locks are coherent.
|
|
||||||
manager = _ThreadLockManager()
|
manager = _ThreadLockManager()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -266,7 +245,6 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
||||||
await lock.acquire()
|
await lock.acquire()
|
||||||
epoch = manager.bump_turn_epoch(thread_id)
|
epoch = manager.bump_turn_epoch(thread_id)
|
||||||
self._held_locks[thread_id] = (lock, epoch)
|
self._held_locks[thread_id] = (lock, epoch)
|
||||||
# Reset the cancel event so this turn starts fresh
|
|
||||||
reset_cancel(thread_id)
|
reset_cancel(thread_id)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -289,17 +267,14 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
||||||
return None
|
return None
|
||||||
if lock.locked():
|
if lock.locked():
|
||||||
lock.release()
|
lock.release()
|
||||||
# Always clear cancel event between turns so a stale signal
|
# Clear cancel event so a stale signal doesn't leak into the next turn.
|
||||||
# doesn't leak into the next request.
|
|
||||||
reset_cancel(thread_id)
|
reset_cancel(thread_id)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Provide sync no-ops because the middleware base class allows them
|
|
||||||
def before_agent( # type: ignore[override]
|
def before_agent( # type: ignore[override]
|
||||||
self, state: AgentState[Any], runtime: Runtime[ContextT]
|
self, state: AgentState[Any], runtime: Runtime[ContextT]
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
# Sync path: no asyncio.Lock to acquire. Best we can do is reject
|
# Sync path can't await an asyncio.Lock; the best it can do is reject when another turn is in flight.
|
||||||
# if anyone else is in flight.
|
|
||||||
thread_id = self._thread_id(runtime)
|
thread_id = self._thread_id(runtime)
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
if self._require_thread_id:
|
if self._require_thread_id:
|
||||||
|
|
|
||||||
|
|
@ -82,13 +82,10 @@ _T = TypeVar("_T")
|
||||||
async def _ainvoke_with_timeout[T](
|
async def _ainvoke_with_timeout[T](
|
||||||
coro: Awaitable[_T], *, subagent_type: str, started_at: float
|
coro: Awaitable[_T], *, subagent_type: str, started_at: float
|
||||||
) -> _T:
|
) -> _T:
|
||||||
"""Apply :data:`DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS` to ``coro``.
|
"""Apply the subagent invoke timeout to ``coro`` (non-positive disables it).
|
||||||
|
|
||||||
A non-positive timeout disables the cap (configurable via the
|
On expiry the task is cancelled and :class:`SubagentInvokeTimeoutError` is
|
||||||
``SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS`` env var). On expiry the
|
raised for the caller to turn into a synthetic ToolMessage.
|
||||||
underlying task is cancelled and :class:`SubagentInvokeTimeoutError` is
|
|
||||||
raised — the caller wraps it into a synthetic ToolMessage so the
|
|
||||||
orchestrator can decide what to do.
|
|
||||||
"""
|
"""
|
||||||
timeout = DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS
|
timeout = DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS
|
||||||
if timeout <= 0:
|
if timeout <= 0:
|
||||||
|
|
@ -151,12 +148,9 @@ def build_task_tool_with_parent_config(
|
||||||
subagent_graphs: dict[str, Runnable] = {
|
subagent_graphs: dict[str, Runnable] = {
|
||||||
spec["name"]: spec["runnable"] for spec in subagents
|
spec["name"]: spec["runnable"] for spec in subagents
|
||||||
}
|
}
|
||||||
# Per-subagent context-hint providers (see ``SurfSenseSubagentSpec``).
|
# Sparse map of opt-in context-hint providers; each runs once per task()
|
||||||
# The mapping is sparse: only routes that opted in via ``pack_subagent``
|
# call to prepend a string to the subagent's first HumanMessage. Failures
|
||||||
# appear here, and the value is invoked once per ``task(...)`` call to
|
# are swallowed so a broken hint never blocks the task.
|
||||||
# generate a short string prepended to the subagent's first
|
|
||||||
# ``HumanMessage``. Failures are logged and swallowed — a broken hint
|
|
||||||
# provider must never prevent the underlying task from running.
|
|
||||||
subagent_hint_providers: dict[str, ContextHintProvider] = {
|
subagent_hint_providers: dict[str, ContextHintProvider] = {
|
||||||
spec["name"]: provider
|
spec["name"]: provider
|
||||||
for spec in subagents
|
for spec in subagents
|
||||||
|
|
@ -178,24 +172,18 @@ def build_task_tool_with_parent_config(
|
||||||
def _billable_call_update(
|
def _billable_call_update(
|
||||||
subagent_type: str, runtime: ToolRuntime
|
subagent_type: str, runtime: ToolRuntime
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Build the per-call ``billable_calls`` delta + an optional warning.
|
"""Build the per-call ``billable_calls`` delta plus an optional soft-cap warning.
|
||||||
|
|
||||||
The orchestrator's ``billable_calls`` map is summed by
|
Always emits ``{subagent_type: 1}`` (a reducer accumulates it); when this
|
||||||
:func:`_int_counter_merge_reducer`, so we always emit
|
call would cross the threshold, also adds a soft ``messages`` entry so the
|
||||||
``{subagent_type: 1}`` and let the reducer accumulate. If the
|
orchestrator self-limits on its next step.
|
||||||
cumulative count *after* this call would cross the configured
|
|
||||||
threshold, we also slip a soft ``messages`` entry into the update
|
|
||||||
so the orchestrator can read it on its next step and self-limit.
|
|
||||||
Returning a plain ``dict`` (vs. an extra :class:`Command`) keeps
|
|
||||||
the helper composable with the existing single/batch return paths.
|
|
||||||
"""
|
"""
|
||||||
delta: dict[str, Any] = {"billable_calls": {subagent_type: 1}}
|
delta: dict[str, Any] = {"billable_calls": {subagent_type: 1}}
|
||||||
threshold = DEFAULT_SUBAGENT_BILLABLE_THRESHOLD
|
threshold = DEFAULT_SUBAGENT_BILLABLE_THRESHOLD
|
||||||
if threshold <= 0:
|
if threshold <= 0:
|
||||||
return delta
|
return delta
|
||||||
prior = runtime.state.get("billable_calls") or {}
|
prior = runtime.state.get("billable_calls") or {}
|
||||||
# ``prior`` may be a plain dict or a reducer-managed mapping; only
|
# Count int values only so a malformed checkpoint can't crash us.
|
||||||
# int values are counted so a malformed checkpoint can't crash us.
|
|
||||||
prior_total = sum(v for v in prior.values() if isinstance(v, int))
|
prior_total = sum(v for v in prior.values() if isinstance(v, int))
|
||||||
new_total = prior_total + 1
|
new_total = prior_total + 1
|
||||||
if prior_total < threshold <= new_total:
|
if prior_total < threshold <= new_total:
|
||||||
|
|
@ -214,8 +202,7 @@ def build_task_tool_with_parent_config(
|
||||||
"""Merge the per-call billable counter (and warning) into ``cmd``."""
|
"""Merge the per-call billable counter (and warning) into ``cmd``."""
|
||||||
delta = _billable_call_update(subagent_type, runtime)
|
delta = _billable_call_update(subagent_type, runtime)
|
||||||
warn_text = delta.pop("_billable_warn_text", None)
|
warn_text = delta.pop("_billable_warn_text", None)
|
||||||
# ``cmd.update`` may be a dict or LangGraph ``UpdateDict``; defensively
|
# Copy so we don't mutate state shared with other tool returns.
|
||||||
# copy so we don't mutate state shared across other tool returns.
|
|
||||||
update = dict(getattr(cmd, "update", {}) or {})
|
update = dict(getattr(cmd, "update", {}) or {})
|
||||||
for key, value in delta.items():
|
for key, value in delta.items():
|
||||||
update[key] = value
|
update[key] = value
|
||||||
|
|
@ -228,14 +215,10 @@ def build_task_tool_with_parent_config(
|
||||||
return Command(update=update)
|
return Command(update=update)
|
||||||
|
|
||||||
def _safe_message_text(msg: Any) -> str:
|
def _safe_message_text(msg: Any) -> str:
|
||||||
"""Pull text out of a BaseMessage without trusting the ``.text`` property.
|
"""Pull text out of a BaseMessage without using the ``.text`` property.
|
||||||
|
|
||||||
``BaseMessage.text`` walks ``content_blocks`` and crashes with
|
``.text`` crashes when ``content`` is ``None`` (common for tool-call
|
||||||
``TypeError: 'NoneType' object is not iterable`` when ``content`` is
|
AIMessages), and ``getattr(msg, "text", None)`` won't catch the ``TypeError``, so read ``content`` directly.
|
||||||
``None`` (common for tool-call AIMessages whose payload is purely
|
|
||||||
structured). ``getattr(msg, "text", None)`` does not catch this
|
|
||||||
because Python evaluates the property body before falling back to
|
|
||||||
the default. Read ``content`` directly and coerce defensively.
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
content = getattr(msg, "content", None)
|
content = getattr(msg, "content", None)
|
||||||
|
|
@ -258,23 +241,18 @@ def build_task_tool_with_parent_config(
|
||||||
return str(content)
|
return str(content)
|
||||||
|
|
||||||
def _build_tool_trace(messages: list[Any]) -> list[dict[str, Any]]:
|
def _build_tool_trace(messages: list[Any]) -> list[dict[str, Any]]:
|
||||||
"""Compress the subagent's message stream into a compact tool trace.
|
"""Compress the subagent's messages into a compact tool trace.
|
||||||
|
|
||||||
Each entry is ``{"tool": <name>, "status": "ok"|"error", "preview":
|
Entries (``{tool, status, preview}``) ride on the ToolMessage's
|
||||||
<≤120 chars>}`` so the orchestrator can show "this is what your
|
``additional_kwargs["surf_tool_trace"]`` for UI/observability; the LLM
|
||||||
specialist actually did" without dumping the full message stream
|
never sees them.
|
||||||
back through the prompt. The list is attached to the returned
|
|
||||||
ToolMessage's ``additional_kwargs`` (under ``"surf_tool_trace"``);
|
|
||||||
the LLM never sees it, but UI / observability code can pluck it
|
|
||||||
out of the checkpoint.
|
|
||||||
"""
|
"""
|
||||||
trace: list[dict[str, Any]] = []
|
trace: list[dict[str, Any]] = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
tool_name = getattr(msg, "name", None)
|
tool_name = getattr(msg, "name", None)
|
||||||
tool_call_id_attr = getattr(msg, "tool_call_id", None)
|
tool_call_id_attr = getattr(msg, "tool_call_id", None)
|
||||||
if not tool_name and not tool_call_id_attr:
|
if not tool_name and not tool_call_id_attr:
|
||||||
# Only ToolMessages have either field; skip AIMessage /
|
# Only ToolMessages carry either field.
|
||||||
# HumanMessage / SystemMessage frames.
|
|
||||||
continue
|
continue
|
||||||
status = getattr(msg, "status", None) or "ok"
|
status = getattr(msg, "status", None) or "ok"
|
||||||
preview = _safe_message_text(msg).strip().replace("\n", " ")
|
preview = _safe_message_text(msg).strip().replace("\n", " ")
|
||||||
|
|
@ -308,8 +286,7 @@ def build_task_tool_with_parent_config(
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
message_text = _safe_message_text(messages[-1]).rstrip()
|
message_text = _safe_message_text(messages[-1]).rstrip()
|
||||||
# Tool-trace is purely observability — wrap defensively so a single
|
# Trace is observability-only; never let a bad frame kill the turn.
|
||||||
# malformed frame never bubbles up and kills the whole user turn.
|
|
||||||
try:
|
try:
|
||||||
tool_trace = _build_tool_trace(messages)
|
tool_trace = _build_tool_trace(messages)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -320,10 +297,7 @@ def build_task_tool_with_parent_config(
|
||||||
tool_trace = []
|
tool_trace = []
|
||||||
tool_msg = ToolMessage(message_text, tool_call_id=tool_call_id)
|
tool_msg = ToolMessage(message_text, tool_call_id=tool_call_id)
|
||||||
if tool_trace:
|
if tool_trace:
|
||||||
# ``additional_kwargs`` is a free-form dict on BaseMessage; using
|
# ``surf_`` prefix avoids collision with provider keys (e.g. Anthropic's ``cache_control``).
|
||||||
# a ``surf_`` prefix avoids collision with provider-specific keys
|
|
||||||
# (e.g. Anthropic's ``cache_control``). The LLM doesn't see it;
|
|
||||||
# consumers (UI, observability) read it off the checkpoint.
|
|
||||||
tool_msg.additional_kwargs["surf_tool_trace"] = tool_trace
|
tool_msg.additional_kwargs["surf_tool_trace"] = tool_trace
|
||||||
return Command(
|
return Command(
|
||||||
update={
|
update={
|
||||||
|
|
@ -361,9 +335,7 @@ def build_task_tool_with_parent_config(
|
||||||
}
|
}
|
||||||
hint = _resolve_context_hint(subagent_type, description, runtime)
|
hint = _resolve_context_hint(subagent_type, description, runtime)
|
||||||
if hint:
|
if hint:
|
||||||
# Prepend as a tagged block so the subagent prompt can pattern-match
|
# Tagged block so the subagent prompt can pattern-match the section.
|
||||||
# on the section (and a future change can lift it into its own
|
|
||||||
# ``SystemMessage`` if needed).
|
|
||||||
payload = f"<context_hint>\n{hint}\n</context_hint>\n\n{description}"
|
payload = f"<context_hint>\n{hint}\n</context_hint>\n\n{description}"
|
||||||
else:
|
else:
|
||||||
payload = description
|
payload = description
|
||||||
|
|
@ -374,16 +346,12 @@ def build_task_tool_with_parent_config(
|
||||||
results: list[tuple[int, str, dict | str, dict | None]],
|
results: list[tuple[int, str, dict | str, dict | None]],
|
||||||
runtime: ToolRuntime,
|
runtime: ToolRuntime,
|
||||||
) -> Command:
|
) -> Command:
|
||||||
"""Combine per-child results into one Command with a combined ToolMessage.
|
"""Combine per-child results into one Command with an aggregate ToolMessage.
|
||||||
|
|
||||||
``results`` is a list of ``(task_index, subagent_type,
|
``results`` tuples are ``(task_index, subagent_type, payload_or_error,
|
||||||
payload_or_error_text, child_state_update)`` tuples — preserving the
|
child_state_update)``; output blocks are sorted by index so the LLM can
|
||||||
input order so the orchestrator can map each block back to the task
|
map them back to dispatch order, and each child contributes a
|
||||||
it dispatched. State updates are merged by reducer for keys outside
|
``billable_calls`` increment to match single-mode accounting.
|
||||||
:data:`EXCLUDED_STATE_KEYS`; everything else (``messages``, ``todos``,
|
|
||||||
etc.) is replaced by the synthesized aggregate ToolMessage. Every
|
|
||||||
child also contributes a ``billable_calls`` increment so cost
|
|
||||||
accounting matches single-mode dispatch.
|
|
||||||
"""
|
"""
|
||||||
results.sort(key=lambda r: r[0])
|
results.sort(key=lambda r: r[0])
|
||||||
merged_state: dict[str, Any] = {}
|
merged_state: dict[str, Any] = {}
|
||||||
|
|
@ -424,8 +392,8 @@ def build_task_tool_with_parent_config(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if state_update:
|
if state_update:
|
||||||
# Naive merge: later tasks win on scalar collisions; reducer-backed
|
# Later tasks win on scalar collisions; reducer-backed fields
|
||||||
# fields (``receipts``, ``files`` etc.) accumulate at apply time.
|
# accumulate at apply time.
|
||||||
merged_state.update(state_update)
|
merged_state.update(state_update)
|
||||||
aggregate = "\n\n".join(message_blocks)
|
aggregate = "\n\n".join(message_blocks)
|
||||||
aggregate_msg = ToolMessage(
|
aggregate_msg = ToolMessage(
|
||||||
|
|
@ -469,11 +437,9 @@ def build_task_tool_with_parent_config(
|
||||||
) -> tuple[int, str, dict | str, dict | None]:
|
) -> tuple[int, str, dict | str, dict | None]:
|
||||||
"""Run one child of a batched ``task`` call under the concurrency cap.
|
"""Run one child of a batched ``task`` call under the concurrency cap.
|
||||||
|
|
||||||
Errors are returned as plain text in slot 2 so a single child's
|
Errors are returned as text (slot 2) so one child's failure doesn't abort
|
||||||
failure does not abort the whole batch. ``GraphInterrupt`` from a
|
the batch. A child's ``GraphInterrupt`` is a hard failure for that child:
|
||||||
batched child is currently treated as a hard failure for that child
|
batched HITL is intentionally out of scope.
|
||||||
only — batched HITL is intentionally out of scope for the v1
|
|
||||||
rollout (see plan tier 2 item 4 risks).
|
|
||||||
"""
|
"""
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
if subagent_type not in subagent_graphs:
|
if subagent_type not in subagent_graphs:
|
||||||
|
|
@ -507,8 +473,7 @@ def build_task_tool_with_parent_config(
|
||||||
)
|
)
|
||||||
return (task_index, subagent_type, str(exc), None)
|
return (task_index, subagent_type, str(exc), None)
|
||||||
except GraphInterrupt:
|
except GraphInterrupt:
|
||||||
# Batched HITL is unsupported in v1 — surface as a failure
|
# Batched HITL unsupported; fail this child so the batch finishes.
|
||||||
# for this child so the rest of the batch still completes.
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Batch child %d (%s) raised GraphInterrupt; batched HITL "
|
"Batch child %d (%s) raised GraphInterrupt; batched HITL "
|
||||||
"is not supported. Re-dispatch this task as a single "
|
"is not supported. Re-dispatch this task as a single "
|
||||||
|
|
@ -545,14 +510,11 @@ def build_task_tool_with_parent_config(
|
||||||
return (task_index, subagent_type, result, child_state_update)
|
return (task_index, subagent_type, result, child_state_update)
|
||||||
|
|
||||||
def _coerce_batch_arg(tasks: Any) -> list[dict] | str:
|
def _coerce_batch_arg(tasks: Any) -> list[dict] | str:
|
||||||
"""Rescue common LLM-side malformations of the ``tasks`` argument.
|
"""Rescue common LLM malformations of the ``tasks`` argument.
|
||||||
|
|
||||||
Some providers serialise an array argument as a JSON-encoded string,
|
Recovers a JSON-encoded array string and a single dict (instead of a
|
||||||
and small models occasionally hand back a single ``{description,
|
1-element array), logging a WARN. Unrecoverable shapes return a string
|
||||||
subagent_type}`` dict instead of a one-element array. Both are
|
the caller surfaces as the tool error.
|
||||||
recovered here with a WARN log so the issue is visible in metrics
|
|
||||||
but the user's turn still completes; truly broken shapes return a
|
|
||||||
plain string that the caller surfaces as the tool error.
|
|
||||||
"""
|
"""
|
||||||
if isinstance(tasks, list):
|
if isinstance(tasks, list):
|
||||||
return tasks
|
return tasks
|
||||||
|
|
@ -587,13 +549,10 @@ def build_task_tool_with_parent_config(
|
||||||
async def _adispatch_batch(
|
async def _adispatch_batch(
|
||||||
tasks: list[dict], runtime: ToolRuntime
|
tasks: list[dict], runtime: ToolRuntime
|
||||||
) -> Command | str:
|
) -> Command | str:
|
||||||
"""Fan-out helper for the ``tasks`` array shape.
|
"""Fan out the ``tasks`` array (size- and concurrency-capped).
|
||||||
|
|
||||||
Bounded by :data:`MAX_SUBAGENT_BATCH_SIZE` and concurrency-capped
|
Returns one Command; the LLM sees one ``[task <index>]``-prefixed block
|
||||||
at :data:`DEFAULT_SUBAGENT_BATCH_CONCURRENCY`. Returns a single
|
per child, in input order.
|
||||||
:class:`Command` that the LLM sees as one ToolMessage per child,
|
|
||||||
prefixed with ``[task <index>]`` so it can map back to the input
|
|
||||||
order.
|
|
||||||
"""
|
"""
|
||||||
if not tasks:
|
if not tasks:
|
||||||
return "tasks: array is empty; nothing to dispatch."
|
return "tasks: array is empty; nothing to dispatch."
|
||||||
|
|
@ -703,17 +662,16 @@ def build_task_tool_with_parent_config(
|
||||||
if pending_value is not None:
|
if pending_value is not None:
|
||||||
resume_value = consume_surfsense_resume(runtime)
|
resume_value = consume_surfsense_resume(runtime)
|
||||||
if resume_value is None:
|
if resume_value is None:
|
||||||
# Bridge invariant: a queued resume must accompany any pending
|
# A pending interrupt must have a queued resume; otherwise replay
|
||||||
# subagent interrupt. Fall-through replay would silently re-prompt
|
# would silently re-prompt the user. Raise instead.
|
||||||
# the user; raise so the streaming layer surfaces a clear error.
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Subagent {subagent_type!r} has a pending interrupt but no "
|
f"Subagent {subagent_type!r} has a pending interrupt but no "
|
||||||
"surfsense_resume_value on config; resume bridge is broken."
|
"surfsense_resume_value on config; resume bridge is broken."
|
||||||
)
|
)
|
||||||
expected = hitlrequest_action_count(pending_value)
|
expected = hitlrequest_action_count(pending_value)
|
||||||
resume_value = fan_out_decisions_to_match(resume_value, expected)
|
resume_value = fan_out_decisions_to_match(resume_value, expected)
|
||||||
# Prevent the parent's resume payload from leaking into subagent
|
# Stop the parent's resume leaking into subagent interrupts via
|
||||||
# interrupts via langgraph's parent_scratchpad fallback.
|
# langgraph's parent_scratchpad fallback.
|
||||||
drain_parent_null_resume(runtime)
|
drain_parent_null_resume(runtime)
|
||||||
with ot.subagent_invoke_span(
|
with ot.subagent_invoke_span(
|
||||||
subagent_type=subagent_type, path=invoke_path
|
subagent_type=subagent_type, path=invoke_path
|
||||||
|
|
@ -829,10 +787,8 @@ def build_task_tool_with_parent_config(
|
||||||
] = None,
|
] = None,
|
||||||
) -> str | Command:
|
) -> str | Command:
|
||||||
atask_start = time.perf_counter()
|
atask_start = time.perf_counter()
|
||||||
# Kill switch: when ops flips the spawn-paused flag for this
|
# Ops kill switch: short-circuit every task() call for this workspace
|
||||||
# workspace, every ``task(...)`` invocation (single- or batch-mode)
|
# so the orchestrator stops hammering downstream APIs.
|
||||||
# short-circuits with a clear ToolMessage so the orchestrator can
|
|
||||||
# tell the user what happened and stop hammering downstream APIs.
|
|
||||||
if await is_spawn_paused(search_space_id):
|
if await is_spawn_paused(search_space_id):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"[hitl_route] atask SPAWN_PAUSED: search_space_id=%s tool_call_id=%s",
|
"[hitl_route] atask SPAWN_PAUSED: search_space_id=%s tool_call_id=%s",
|
||||||
|
|
@ -923,8 +879,8 @@ def build_task_tool_with_parent_config(
|
||||||
)
|
)
|
||||||
expected = hitlrequest_action_count(pending_value)
|
expected = hitlrequest_action_count(pending_value)
|
||||||
resume_value = fan_out_decisions_to_match(resume_value, expected)
|
resume_value = fan_out_decisions_to_match(resume_value, expected)
|
||||||
# Prevent the parent's resume payload from leaking into subagent
|
# Stop the parent's resume leaking into subagent interrupts via
|
||||||
# interrupts via langgraph's parent_scratchpad fallback.
|
# langgraph's parent_scratchpad fallback.
|
||||||
drain_parent_null_resume(runtime)
|
drain_parent_null_resume(runtime)
|
||||||
with ot.subagent_invoke_span(
|
with ot.subagent_invoke_span(
|
||||||
subagent_type=subagent_type, path=invoke_path
|
subagent_type=subagent_type, path=invoke_path
|
||||||
|
|
|
||||||
|
|
@ -1,33 +1,19 @@
|
||||||
"""End-of-turn persistence for the cloud-mode SurfSense filesystem.
|
"""End-of-turn persistence for the cloud-mode SurfSense filesystem.
|
||||||
|
|
||||||
This middleware runs ``aafter_agent`` once per turn (cloud only). It commits
|
Runs ``aafter_agent`` once per turn (cloud only), committing staged folder
|
||||||
all staged folder creations, file moves, content writes/edits, file deletes
|
creates, moves, writes/edits, and ``rm``/``rmdir`` to Postgres in one ordered
|
||||||
(``rm``), and directory deletes (``rmdir``) to Postgres in a single ordered
|
pass. Order matters: moves resolve before writes (so write-then-move lands at
|
||||||
pass:
|
the final path), and file deletes run before directory deletes (so a same-turn
|
||||||
|
``rm /a/x.md`` + ``rmdir /a`` works).
|
||||||
|
|
||||||
1. Materialize ``staged_dirs`` into ``Folder`` rows.
|
When ``flags.enable_action_log`` is on, each destructive op also snapshots a
|
||||||
2. Apply ``pending_moves`` in order (chained moves resolved via
|
``DocumentRevision`` / ``FolderRevision`` for revert. For ``rm``/``rmdir`` the
|
||||||
``doc_id_by_path``).
|
snapshot and DELETE share a SAVEPOINT, so a failed snapshot aborts the delete
|
||||||
3. Normalize ``dirty_paths`` through ``pending_moves`` so write-then-move
|
rather than making the data silently irreversible.
|
||||||
sequences commit at the final path. Paths queued for ``rm`` this turn
|
|
||||||
are dropped here so a write+rm sequence doesn't recreate the doc.
|
|
||||||
4. Commit content writes / edits for ``/documents/*`` paths, skipping
|
|
||||||
``temp_*`` basenames.
|
|
||||||
5. Apply ``pending_deletes`` (``rm``) — file deletes run BEFORE directory
|
|
||||||
deletes so a same-turn ``rm /a/x.md`` + ``rmdir /a`` sequence works.
|
|
||||||
6. Apply ``pending_dir_deletes`` (``rmdir``); re-verifies emptiness against
|
|
||||||
the post-step-5 DB state.
|
|
||||||
|
|
||||||
When ``flags.enable_action_log`` is on every destructive op also writes a
|
The commit body is a free function (``commit_staged_filesystem_state``) so the
|
||||||
``DocumentRevision`` / ``FolderRevision`` snapshot bound to the
|
stream-task fallback can run the identical routine when ``aafter_agent`` was
|
||||||
originating ``AgentActionLog`` row via ``tool_call_id``. ``rm``/``rmdir``
|
skipped (e.g. client disconnect).
|
||||||
share a single ``SAVEPOINT`` with their snapshot — if the snapshot fails
|
|
||||||
the DELETE rolls back and we surface the error rather than silently
|
|
||||||
making the data irreversible.
|
|
||||||
|
|
||||||
The commit body is exposed as a free function ``commit_staged_filesystem_state``
|
|
||||||
so the optional stream-task fallback (``stream_new_chat.py``) can call the
|
|
||||||
exact same routine when ``aafter_agent`` was skipped (e.g. client disconnect).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -216,11 +202,9 @@ async def _create_document(
|
||||||
virtual_path,
|
virtual_path,
|
||||||
search_space_id,
|
search_space_id,
|
||||||
)
|
)
|
||||||
# Filesystem-parity invariant: the only thing that *must* be unique is
|
# Pre-check the path-derived unique_identifier_hash so a duplicate path
|
||||||
# the path. Two notes can legitimately share content (e.g. ``cp a b``).
|
# surfaces as a clean ValueError instead of an INSERT IntegrityError that
|
||||||
# Guard against the path-derived ``unique_identifier_hash`` constraint
|
# poisons the session. Content is intentionally not unique (e.g. ``cp a b``).
|
||||||
# so we surface a clean ValueError instead of letting the INSERT poison
|
|
||||||
# the session with an IntegrityError.
|
|
||||||
path_collision = await session.execute(
|
path_collision = await session.execute(
|
||||||
select(Document.id).where(
|
select(Document.id).where(
|
||||||
Document.search_space_id == search_space_id,
|
Document.search_space_id == search_space_id,
|
||||||
|
|
@ -232,13 +216,6 @@ async def _create_document(
|
||||||
f"a document already exists at path '{virtual_path}' "
|
f"a document already exists at path '{virtual_path}' "
|
||||||
"(unique_identifier_hash collision)"
|
"(unique_identifier_hash collision)"
|
||||||
)
|
)
|
||||||
# ``content_hash`` is intentionally NOT checked for uniqueness here.
|
|
||||||
# In a real filesystem two files at different paths can hold identical
|
|
||||||
# bytes, and the agent's ``write_file`` path needs that semantic to
|
|
||||||
# support copy/duplicate operations. The hash remains useful as a
|
|
||||||
# change-detection hint for connector indexers, which still consult it
|
|
||||||
# via :func:`check_duplicate_document` but do so with a non-unique
|
|
||||||
# lookup (``.first()``).
|
|
||||||
content_hash = generate_content_hash(content, search_space_id)
|
content_hash = generate_content_hash(content, search_space_id)
|
||||||
doc = Document(
|
doc = Document(
|
||||||
title=title,
|
title=title,
|
||||||
|
|
@ -435,15 +412,9 @@ async def _mark_action_reversible(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Flip ``agent_action_log.reversible = TRUE`` for ``action_id``.
|
"""Flip ``agent_action_log.reversible = TRUE`` for ``action_id``.
|
||||||
|
|
||||||
Best-effort: caller may invoke from inside a SAVEPOINT and treat
|
Pair with ``_dispatch_reversibility_update`` *after* the enclosing
|
||||||
failure as a soft demotion (snapshot persists, just no Revert button).
|
SAVEPOINT commits, so the UI never sees ``reversible=true`` for a row whose
|
||||||
|
update later rolls back.
|
||||||
Callers should also call ``_dispatch_reversibility_update`` (defined
|
|
||||||
below) AFTER the enclosing SAVEPOINT block exits successfully so the
|
|
||||||
chat tool card can light up its Revert button without
|
|
||||||
re-fetching ``GET /threads/.../actions``. Dispatching from inside the
|
|
||||||
SAVEPOINT would risk emitting "reversible=true" for rows whose
|
|
||||||
update gets rolled back if the surrounding destructive op fails.
|
|
||||||
"""
|
"""
|
||||||
if action_id is None:
|
if action_id is None:
|
||||||
return
|
return
|
||||||
|
|
@ -455,22 +426,11 @@ async def _mark_action_reversible(
|
||||||
|
|
||||||
|
|
||||||
async def _dispatch_reversibility_update(action_id: int | None) -> None:
|
async def _dispatch_reversibility_update(action_id: int | None) -> None:
|
||||||
"""Best-effort dispatch of an ``action_log_updated`` custom event.
|
"""Emit an ``action_log_updated`` SSE event so the Revert button lights up.
|
||||||
|
|
||||||
Surfaces the post-SAVEPOINT reversibility flip to the SSE layer so
|
Best-effort (failures swallowed; the REST actions endpoint is
|
||||||
the chat tool card can flip its Revert button live. Defensive:
|
authoritative). Inside :func:`commit_staged_filesystem_state` this is
|
||||||
failures are logged at debug level and swallowed; the
|
deferred until after the outer commit via ``deferred_dispatches``.
|
||||||
REST endpoint ``GET /threads/.../actions`` is still authoritative.
|
|
||||||
|
|
||||||
.. warning::
|
|
||||||
Inside :func:`commit_staged_filesystem_state` we DEFER all
|
|
||||||
dispatches until the outer ``session.commit()`` succeeds — see
|
|
||||||
the ``deferred_dispatches`` queue in that function. Dispatching
|
|
||||||
from inside a SAVEPOINT block while the outer transaction is
|
|
||||||
still pending would emit ``reversible=true`` for rows whose
|
|
||||||
snapshots get rolled back if the outer commit fails. Direct
|
|
||||||
callers (e.g. the optional stream-task fallback) that own the
|
|
||||||
full session lifetime can still call this helper inline.
|
|
||||||
"""
|
"""
|
||||||
if action_id is None:
|
if action_id is None:
|
||||||
return
|
return
|
||||||
|
|
@ -489,12 +449,9 @@ async def _dispatch_reversibility_update(action_id: int | None) -> None:
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Snapshot helpers
|
# Snapshot helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
#
|
# Best-effort variants (write/edit/move/mkdir) swallow failures. Strict
|
||||||
# Best-effort helpers swallow + log so a snapshot failure can never break
|
# variants (rm/rmdir) share the destructive op's SAVEPOINT so a snapshot
|
||||||
# the destructive op for non-destructive tools (write/edit/move/mkdir).
|
# failure aborts the delete instead of making it silently irreversible.
|
||||||
# Strict helpers run inside the SAME ``begin_nested()`` SAVEPOINT as the
|
|
||||||
# destructive DELETE — failure aborts the savepoint and leaves the doc /
|
|
||||||
# folder intact, so revertable ops never become irreversible silently.
|
|
||||||
|
|
||||||
|
|
||||||
def _doc_revision_payload(
|
def _doc_revision_payload(
|
||||||
|
|
@ -704,15 +661,9 @@ async def commit_staged_filesystem_state(
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Commit all staged filesystem changes; return the state delta for reducers.
|
"""Commit all staged filesystem changes; return the state delta for reducers.
|
||||||
|
|
||||||
Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent`
|
Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent` and
|
||||||
and the optional stream-task fallback.
|
the stream-task fallback. See the module docstring for ordering and the
|
||||||
|
action-log snapshot/revert semantics.
|
||||||
When ``flags.enable_action_log`` is on every destructive op also writes
|
|
||||||
a ``DocumentRevision`` / ``FolderRevision`` snapshot bound to the
|
|
||||||
originating ``AgentActionLog`` row via ``tool_call_id``. Snapshot
|
|
||||||
durability is best-effort for non-destructive ops and STRICT for
|
|
||||||
``rm``/``rmdir`` (snapshot + DELETE share a SAVEPOINT — snapshot
|
|
||||||
failure aborts the delete).
|
|
||||||
"""
|
"""
|
||||||
if filesystem_mode != FilesystemMode.CLOUD:
|
if filesystem_mode != FilesystemMode.CLOUD:
|
||||||
return None
|
return None
|
||||||
|
|
@ -771,8 +722,7 @@ async def commit_staged_filesystem_state(
|
||||||
flags = get_flags()
|
flags = get_flags()
|
||||||
snapshot_enabled = flags.enable_action_log
|
snapshot_enabled = flags.enable_action_log
|
||||||
|
|
||||||
# De-duplicate pending deletes per-path while preserving the latest
|
# De-dup deletes per-path, keeping the latest tool_call_id (the one the user is most likely to revert).
|
||||||
# tool_call_id (the one the user is most likely to revert via the UI).
|
|
||||||
file_delete_paths: dict[str, str] = {}
|
file_delete_paths: dict[str, str] = {}
|
||||||
for entry in pending_deletes:
|
for entry in pending_deletes:
|
||||||
if not isinstance(entry, dict):
|
if not isinstance(entry, dict):
|
||||||
|
|
@ -796,22 +746,14 @@ async def commit_staged_filesystem_state(
|
||||||
applied_moves: list[dict[str, Any]] = []
|
applied_moves: list[dict[str, Any]] = []
|
||||||
doc_id_path_tombstones: dict[str, int | None] = {}
|
doc_id_path_tombstones: dict[str, int | None] = {}
|
||||||
tree_changed = False
|
tree_changed = False
|
||||||
# Reversibility-flip dispatches are deferred until AFTER the outer
|
# Reversibility-flip dispatches are drained only after the outer commit
|
||||||
# ``session.commit()`` succeeds. Dispatching from inside the
|
# succeeds (and abandoned on rollback), so the UI never sees reversible=true
|
||||||
# SAVEPOINT chain while the outer transaction is still pending
|
# for a snapshot that didn't durably land.
|
||||||
# would emit ``reversible=true`` for rows whose snapshots get rolled
|
|
||||||
# back if the final commit raises. Snapshot helpers append on
|
|
||||||
# success; we drain this list after commit and silently abandon it
|
|
||||||
# on rollback so the UI stays consistent with durable state.
|
|
||||||
deferred_dispatches: list[int] = []
|
deferred_dispatches: list[int] = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with shielded_async_session() as session:
|
async with shielded_async_session() as session:
|
||||||
# ------------------------------------------------------------------
|
# Resolve all action-id bindings in one SELECT per turn, not per op.
|
||||||
# Resolve action-id bindings up front. One SELECT per turn for all
|
|
||||||
# tool_call_ids, NOT one per op — important because a turn that
|
|
||||||
# touches 50 paths would otherwise issue 50 lookups.
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
action_id_by_call: dict[str, int] = {}
|
action_id_by_call: dict[str, int] = {}
|
||||||
if snapshot_enabled and thread_id is not None:
|
if snapshot_enabled and thread_id is not None:
|
||||||
tool_call_ids: set[str] = set()
|
tool_call_ids: set[str] = set()
|
||||||
|
|
@ -844,10 +786,7 @@ async def commit_staged_filesystem_state(
|
||||||
next(iter(action_id_by_call), None) if action_id_by_call else None
|
next(iter(action_id_by_call), None) if action_id_by_call else None
|
||||||
)
|
)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# 1. staged_dirs -> Folder rows (snapshot post-flush for the FK).
|
||||||
# 1. staged_dirs -> Folder rows. Snapshot post-flush so the new
|
|
||||||
# folder_id is available for the FK.
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
for folder_path in staged_dirs:
|
for folder_path in staged_dirs:
|
||||||
if not isinstance(folder_path, str):
|
if not isinstance(folder_path, str):
|
||||||
continue
|
continue
|
||||||
|
|
@ -868,7 +807,6 @@ async def commit_staged_filesystem_state(
|
||||||
tcid = staged_dir_tool_calls.get(folder_path)
|
tcid = staged_dir_tool_calls.get(folder_path)
|
||||||
action_id = _action_id_for(tcid)
|
action_id = _action_id_for(tcid)
|
||||||
if action_id is not None:
|
if action_id is not None:
|
||||||
# Re-read the folder for the snapshot.
|
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(Folder).where(Folder.id == folder_id)
|
select(Folder).where(Folder.id == folder_id)
|
||||||
)
|
)
|
||||||
|
|
@ -883,16 +821,13 @@ async def commit_staged_filesystem_state(
|
||||||
deferred_dispatches=deferred_dispatches,
|
deferred_dispatches=deferred_dispatches,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# 2. pending_moves (snapshot pre-move for in-place restore on revert).
|
||||||
# 2. pending_moves. Snapshot pre-move (in-place restore on revert).
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
for move in pending_moves:
|
for move in pending_moves:
|
||||||
source = str(move.get("source") or "")
|
source = str(move.get("source") or "")
|
||||||
if snapshot_enabled and source:
|
if snapshot_enabled and source:
|
||||||
tcid = str(move.get("tool_call_id") or "")
|
tcid = str(move.get("tool_call_id") or "")
|
||||||
action_id = _action_id_for(tcid)
|
action_id = _action_id_for(tcid)
|
||||||
if action_id is not None:
|
if action_id is not None:
|
||||||
# Resolve the doc to snapshot BEFORE we mutate it.
|
|
||||||
doc_id_pre = doc_id_by_path.get(source)
|
doc_id_pre = doc_id_by_path.get(source)
|
||||||
document_pre: Document | None = None
|
document_pre: Document | None = None
|
||||||
if doc_id_pre is not None:
|
if doc_id_pre is not None:
|
||||||
|
|
@ -942,10 +877,8 @@ async def commit_staged_filesystem_state(
|
||||||
path = move_alias[path]
|
path = move_alias[path]
|
||||||
return path
|
return path
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# 3. dirty_paths -> writes/edits. Paths queued for rm this turn are
|
||||||
# 3. dirty_paths -> writes/edits. Skip any path queued for ``rm``
|
# skipped so a write+rm sequence doesn't recreate the doc.
|
||||||
# this turn so a write+rm sequence doesn't recreate the doc.
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
kb_dirty_seen: set[str] = set()
|
kb_dirty_seen: set[str] = set()
|
||||||
kb_dirty: list[str] = []
|
kb_dirty: list[str] = []
|
||||||
kb_dirty_origin: dict[str, str] = {}
|
kb_dirty_origin: dict[str, str] = {}
|
||||||
|
|
@ -974,9 +907,7 @@ async def commit_staged_filesystem_state(
|
||||||
continue
|
continue
|
||||||
content = "\n".join(file_data.get("content") or [])
|
content = "\n".join(file_data.get("content") or [])
|
||||||
doc_id = doc_id_by_path.get(path)
|
doc_id = doc_id_by_path.get(path)
|
||||||
# Path ↔ tool_call_id binding: the dirty_paths list dedupes via
|
# Look up tool_call_id by final path or its pre-rename origin.
|
||||||
# _add_unique_reducer, so we look up the latest tool_call_id by
|
|
||||||
# path (or by the un-renamed origin).
|
|
||||||
origin = kb_dirty_origin.get(path, path)
|
origin = kb_dirty_origin.get(path, path)
|
||||||
tcid = dirty_path_tool_calls.get(path) or dirty_path_tool_calls.get(
|
tcid = dirty_path_tool_calls.get(path) or dirty_path_tool_calls.get(
|
||||||
origin
|
origin
|
||||||
|
|
@ -984,12 +915,9 @@ async def commit_staged_filesystem_state(
|
||||||
action_id = _action_id_for(tcid)
|
action_id = _action_id_for(tcid)
|
||||||
|
|
||||||
if doc_id is None:
|
if doc_id is None:
|
||||||
# The in-memory ``doc_id_by_path`` is per-thread and starts
|
# doc_id_by_path is per-thread and empty in a new chat, so a
|
||||||
# empty in every new chat. If the agent writes to a path
|
# write to a path already in the DB must update in place, not
|
||||||
# that already exists in the DB (e.g. a previous chat's
|
# INSERT (which would hit the path-derived unique hash).
|
||||||
# ``notes.md``), we must NOT try to INSERT — it would hit
|
|
||||||
# ``unique_identifier_hash`` (path-derived). Look up the
|
|
||||||
# existing doc and update it in place instead.
|
|
||||||
existing = await virtual_path_to_doc(
|
existing = await virtual_path_to_doc(
|
||||||
session,
|
session,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
|
|
@ -1038,12 +966,9 @@ async def commit_staged_filesystem_state(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Fresh create. Wrap each create in a SAVEPOINT so a
|
# Fresh create, wrapped in a SAVEPOINT so a residual
|
||||||
# residual ``IntegrityError`` (e.g. a deployment that
|
# IntegrityError (e.g. pre-migration-133 content_hash UNIQUE)
|
||||||
# hasn't run migration 133 yet, where
|
# rolls back only this create, not the whole turn.
|
||||||
# ``documents.content_hash`` still carries its legacy
|
|
||||||
# global UNIQUE constraint) rolls back only this one
|
|
||||||
# create instead of poisoning the whole turn.
|
|
||||||
placeholder_revision_id: int | None = None
|
placeholder_revision_id: int | None = None
|
||||||
if snapshot_enabled and action_id is not None:
|
if snapshot_enabled and action_id is not None:
|
||||||
placeholder_revision_id = await _snapshot_document_pre_create(
|
placeholder_revision_id = await _snapshot_document_pre_create(
|
||||||
|
|
@ -1066,8 +991,7 @@ async def commit_staged_filesystem_state(
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"kb_persistence: skipping %s create: %s", path, exc
|
"kb_persistence: skipping %s create: %s", path, exc
|
||||||
)
|
)
|
||||||
# Roll back the placeholder revision since the create
|
# Create never happened; drop its placeholder revision.
|
||||||
# never happened.
|
|
||||||
if placeholder_revision_id is not None:
|
if placeholder_revision_id is not None:
|
||||||
await session.execute(
|
await session.execute(
|
||||||
delete(DocumentRevision).where(
|
delete(DocumentRevision).where(
|
||||||
|
|
@ -1114,19 +1038,14 @@ async def commit_staged_filesystem_state(
|
||||||
)
|
)
|
||||||
tree_changed = True
|
tree_changed = True
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# 4. pending_deletes -> rm. Strict: snapshot + DELETE share a
|
||||||
# 4. pending_deletes -> ``rm``. STRICT durability: snapshot + DELETE
|
# SAVEPOINT, so a failed snapshot rolls the delete back too.
|
||||||
# share a SAVEPOINT. If the snapshot insert fails, the DELETE
|
|
||||||
# rolls back too and we surface the error rather than silently
|
|
||||||
# making the data irreversible.
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
for raw_path, tcid in file_delete_paths.items():
|
for raw_path, tcid in file_delete_paths.items():
|
||||||
final = _final_path(raw_path)
|
final = _final_path(raw_path)
|
||||||
if not final.startswith(DOCUMENTS_ROOT + "/"):
|
if not final.startswith(DOCUMENTS_ROOT + "/"):
|
||||||
continue
|
continue
|
||||||
action_id = _action_id_for(tcid)
|
action_id = _action_id_for(tcid)
|
||||||
|
|
||||||
# Resolve the doc.
|
|
||||||
doc_id_for_delete = doc_id_by_path.get(final)
|
doc_id_for_delete = doc_id_by_path.get(final)
|
||||||
document_to_delete: Document | None = None
|
document_to_delete: Document | None = None
|
||||||
if doc_id_for_delete is not None:
|
if doc_id_for_delete is not None:
|
||||||
|
|
@ -1155,7 +1074,6 @@ async def commit_staged_filesystem_state(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with session.begin_nested():
|
async with session.begin_nested():
|
||||||
# Strict: snapshot first; failure aborts the delete.
|
|
||||||
if snapshot_enabled and action_id is not None:
|
if snapshot_enabled and action_id is not None:
|
||||||
chunks = await _load_chunks_for_snapshot(
|
chunks = await _load_chunks_for_snapshot(
|
||||||
session, doc_id=doc_pk
|
session, doc_id=doc_pk
|
||||||
|
|
@ -1184,10 +1102,7 @@ async def commit_staged_filesystem_state(
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# B1 — SAVEPOINT released. Defer the reversibility-flip
|
# Defer the reversibility flip until after the outer commit.
|
||||||
# dispatch until AFTER the outer commit succeeds so we
|
|
||||||
# never tell the UI a row is reversible if its snapshot
|
|
||||||
# gets rolled back.
|
|
||||||
if snapshot_enabled and action_id is not None:
|
if snapshot_enabled and action_id is not None:
|
||||||
deferred_dispatches.append(int(action_id))
|
deferred_dispatches.append(int(action_id))
|
||||||
|
|
||||||
|
|
@ -1206,11 +1121,8 @@ async def commit_staged_filesystem_state(
|
||||||
)
|
)
|
||||||
tree_changed = True
|
tree_changed = True
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# 5. pending_dir_deletes -> rmdir. Strict, and re-checks emptiness
|
||||||
# 5. pending_dir_deletes -> ``rmdir``. STRICT durability + final
|
# against post-step-4 DB state.
|
||||||
# emptiness check (after step 4's deletes have run, an "empty
|
|
||||||
# mid-turn" directory really IS empty in DB now).
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
for raw_path, tcid in dir_delete_paths.items():
|
for raw_path, tcid in dir_delete_paths.items():
|
||||||
final = _final_path(raw_path)
|
final = _final_path(raw_path)
|
||||||
if not final.startswith(DOCUMENTS_ROOT + "/"):
|
if not final.startswith(DOCUMENTS_ROOT + "/"):
|
||||||
|
|
@ -1231,7 +1143,6 @@ async def commit_staged_filesystem_state(
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Re-check emptiness against in-DB state.
|
|
||||||
docs_in_folder = await session.execute(
|
docs_in_folder = await session.execute(
|
||||||
select(Document.id)
|
select(Document.id)
|
||||||
.where(Document.folder_id == folder_id)
|
.where(Document.folder_id == folder_id)
|
||||||
|
|
@ -1296,10 +1207,7 @@ async def commit_staged_filesystem_state(
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# B1 — SAVEPOINT released. Defer the reversibility-flip
|
# Defer the reversibility flip until after the outer commit.
|
||||||
# dispatch until AFTER the outer commit succeeds so we
|
|
||||||
# never tell the UI a row is reversible if its snapshot
|
|
||||||
# gets rolled back.
|
|
||||||
if snapshot_enabled and action_id is not None:
|
if snapshot_enabled and action_id is not None:
|
||||||
deferred_dispatches.append(int(action_id))
|
deferred_dispatches.append(int(action_id))
|
||||||
|
|
||||||
|
|
@ -1319,18 +1227,13 @@ async def commit_staged_filesystem_state(
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"kb_persistence: commit failed (search_space=%s)", search_space_id
|
"kb_persistence: commit failed (search_space=%s)", search_space_id
|
||||||
)
|
)
|
||||||
# Outer commit raised — every SAVEPOINT-released change above
|
# Outer commit raised: everything above rolled back, so drop the
|
||||||
# (snapshots + reversibility flips) is now rolled back. Drop
|
# deferred dispatches.
|
||||||
# the deferred SSE dispatches so the UI stays consistent with
|
|
||||||
# durable state.
|
|
||||||
deferred_dispatches.clear()
|
deferred_dispatches.clear()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Outer commit succeeded; flush deferred reversibility-flip
|
# Commit succeeded; flush deferred reversibility flips (de-duped, since
|
||||||
# dispatches now so the chat tool card can light up its Revert
|
# write-then-rm in one turn appends an id per snapshot site).
|
||||||
# button without re-fetching ``GET /threads/.../actions``. De-dup
|
|
||||||
# to avoid emitting the same id twice (e.g. write-then-rm in the
|
|
||||||
# same turn dispatches once for each snapshot site).
|
|
||||||
if deferred_dispatches and dispatch_events:
|
if deferred_dispatches and dispatch_events:
|
||||||
for action_id in dict.fromkeys(deferred_dispatches):
|
for action_id in dict.fromkeys(deferred_dispatches):
|
||||||
try:
|
try:
|
||||||
|
|
@ -1376,9 +1279,8 @@ async def commit_staged_filesystem_state(
|
||||||
p for p in files if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX)
|
p for p in files if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Tombstone every committed-delete path so a stale ``state["files"]`` entry
|
# Tombstone committed-delete paths so a stale state["files"] entry can't
|
||||||
# (which als_info would otherwise interpret as content) cannot survive into
|
# survive into the next turn and make a now-empty folder look non-empty.
|
||||||
# the next turn and make a now-empty folder look non-empty.
|
|
||||||
deleted_file_paths = [
|
deleted_file_paths = [
|
||||||
str(payload.get("virtualPath") or "")
|
str(payload.get("virtualPath") or "")
|
||||||
for payload in committed_deletes
|
for payload in committed_deletes
|
||||||
|
|
@ -1399,11 +1301,8 @@ async def commit_staged_filesystem_state(
|
||||||
"dirty_path_tool_calls": {_CLEAR: True},
|
"dirty_path_tool_calls": {_CLEAR: True},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Emit one Receipt per committed mutation, folded into ``state['receipts']``
|
# One Receipt per committed mutation: ground truth (post-savepoint) for the
|
||||||
# via ``_list_append_reducer``. The receipts surface what actually committed
|
# orchestrator's <verification> teaching. KB writes have no public URL.
|
||||||
# (post-savepoint) rather than what the LLM intended; the orchestrator uses
|
|
||||||
# them as ground truth in the ``<verification>`` teaching. KB writes do not
|
|
||||||
# have public verifiable URLs, so ``verifiable_url`` stays unset.
|
|
||||||
receipts: list[Receipt] = []
|
receipts: list[Receipt] = []
|
||||||
|
|
||||||
def _kb_receipt(
|
def _kb_receipt(
|
||||||
|
|
@ -1444,8 +1343,6 @@ async def commit_staged_filesystem_state(
|
||||||
external_id=payload.get("id"),
|
external_id=payload.get("id"),
|
||||||
)
|
)
|
||||||
for payload in applied_moves:
|
for payload in applied_moves:
|
||||||
# ``applied_moves`` rows carry the destination ``virtualPath`` because
|
|
||||||
# the move has already landed in the DB by the time we reach this code.
|
|
||||||
path = str(payload.get("virtualPath") or "")
|
path = str(payload.get("virtualPath") or "")
|
||||||
_kb_receipt(
|
_kb_receipt(
|
||||||
type="file",
|
type="file",
|
||||||
|
|
@ -1485,9 +1382,7 @@ async def commit_staged_filesystem_state(
|
||||||
if tree_changed:
|
if tree_changed:
|
||||||
delta["tree_version"] = int(state_dict.get("tree_version") or 0) + 1
|
delta["tree_version"] = int(state_dict.get("tree_version") or 0) + 1
|
||||||
|
|
||||||
# Avoid 'unused' lint when turn_id_for_revision was only useful for
|
_ = turn_id_for_revision # diagnostic-only; silence unused lint
|
||||||
# diagnostic purposes inside the SAVEPOINT chain above.
|
|
||||||
_ = turn_id_for_revision
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"kb_persistence: commit (search_space=%s) creates=%d updates=%d "
|
"kb_persistence: commit (search_space=%s) creates=%d updates=%d "
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,6 @@ def extract_domain(url: str) -> str:
|
||||||
try:
|
try:
|
||||||
parsed = urlparse(url)
|
parsed = urlparse(url)
|
||||||
domain = parsed.netloc
|
domain = parsed.netloc
|
||||||
# Remove 'www.' prefix if present
|
|
||||||
if domain.startswith("www."):
|
if domain.startswith("www."):
|
||||||
domain = domain[4:]
|
domain = domain[4:]
|
||||||
return domain
|
return domain
|
||||||
|
|
@ -53,14 +52,13 @@ def truncate_content(content: str, max_length: int = 50000) -> tuple[str, bool]:
|
||||||
if len(content) <= max_length:
|
if len(content) <= max_length:
|
||||||
return content, False
|
return content, False
|
||||||
|
|
||||||
# Try to truncate at a sentence boundary
|
# Prefer truncating at a sentence/paragraph boundary.
|
||||||
truncated = content[:max_length]
|
truncated = content[:max_length]
|
||||||
last_period = truncated.rfind(".")
|
last_period = truncated.rfind(".")
|
||||||
last_newline = truncated.rfind("\n\n")
|
last_newline = truncated.rfind("\n\n")
|
||||||
|
|
||||||
# Use the later of the two boundaries, or just truncate
|
|
||||||
boundary = max(last_period, last_newline)
|
boundary = max(last_period, last_newline)
|
||||||
if boundary > max_length * 0.8: # Only use boundary if it's not too far back
|
if boundary > max_length * 0.8: # only if the boundary isn't too far back
|
||||||
truncated = content[: boundary + 1]
|
truncated = content[: boundary + 1]
|
||||||
|
|
||||||
return truncated + "\n\n[Content truncated...]", True
|
return truncated + "\n\n[Content truncated...]", True
|
||||||
|
|
@ -111,8 +109,8 @@ async def _scrape_youtube_video(
|
||||||
http_client.proxies.update(residential_proxies)
|
http_client.proxies.update(residential_proxies)
|
||||||
ytt_api = YouTubeTranscriptApi(http_client=http_client)
|
ytt_api = YouTubeTranscriptApi(http_client=http_client)
|
||||||
|
|
||||||
# List all available transcripts and pick the first one
|
# Pick the first transcript (video's primary language) rather than
|
||||||
# (the video's primary language) instead of defaulting to English
|
# defaulting to English.
|
||||||
transcript_list = ytt_api.list(video_id)
|
transcript_list = ytt_api.list(video_id)
|
||||||
transcript = next(iter(transcript_list))
|
transcript = next(iter(transcript_list))
|
||||||
captions = transcript.fetch()
|
captions = transcript.fetch()
|
||||||
|
|
@ -134,10 +132,8 @@ async def _scrape_youtube_video(
|
||||||
logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}")
|
logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}")
|
||||||
transcript_text = f"No captions available for this video. Error: {e!s}"
|
transcript_text = f"No captions available for this video. Error: {e!s}"
|
||||||
|
|
||||||
# Build combined content
|
|
||||||
content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}"
|
content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}"
|
||||||
|
|
||||||
# Truncate if needed
|
|
||||||
content, was_truncated = truncate_content(content, max_length)
|
content, was_truncated = truncate_content(content, max_length)
|
||||||
word_count = len(content.split())
|
word_count = len(content.split())
|
||||||
|
|
||||||
|
|
@ -212,20 +208,16 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
|
||||||
scrape_id = generate_scrape_id(url)
|
scrape_id = generate_scrape_id(url)
|
||||||
domain = extract_domain(url)
|
domain = extract_domain(url)
|
||||||
|
|
||||||
# Validate and normalize URL
|
|
||||||
if not url.startswith(("http://", "https://")):
|
if not url.startswith(("http://", "https://")):
|
||||||
url = f"https://{url}"
|
url = f"https://{url}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Check if this is a YouTube URL and use transcript API instead
|
# YouTube URLs use the transcript API instead of crawling.
|
||||||
video_id = get_youtube_video_id(url)
|
video_id = get_youtube_video_id(url)
|
||||||
if video_id:
|
if video_id:
|
||||||
return await _scrape_youtube_video(url, video_id, max_length)
|
return await _scrape_youtube_video(url, video_id, max_length)
|
||||||
|
|
||||||
# Create webcrawler connector
|
|
||||||
connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key)
|
connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key)
|
||||||
|
|
||||||
# Crawl the URL
|
|
||||||
result, error = await connector.crawl_url(url, formats=["markdown"])
|
result, error = await connector.crawl_url(url, formats=["markdown"])
|
||||||
|
|
||||||
if error:
|
if error:
|
||||||
|
|
@ -250,28 +242,21 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
|
||||||
"error": "No content returned from crawler",
|
"error": "No content returned from crawler",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Extract content and metadata
|
|
||||||
content = result.get("content", "")
|
content = result.get("content", "")
|
||||||
metadata = result.get("metadata", {})
|
metadata = result.get("metadata", {})
|
||||||
|
|
||||||
# Get title from metadata
|
|
||||||
title = metadata.get("title", "")
|
title = metadata.get("title", "")
|
||||||
if not title:
|
if not title:
|
||||||
title = domain or url.split("/")[-1] or "Webpage"
|
title = domain or url.split("/")[-1] or "Webpage"
|
||||||
|
|
||||||
# Get description from metadata
|
|
||||||
description = metadata.get("description", "")
|
description = metadata.get("description", "")
|
||||||
if not description and content:
|
if not description and content:
|
||||||
# Use first paragraph as description
|
|
||||||
first_para = content.split("\n\n")[0] if content else ""
|
first_para = content.split("\n\n")[0] if content else ""
|
||||||
description = (
|
description = (
|
||||||
first_para[:300] + "..." if len(first_para) > 300 else first_para
|
first_para[:300] + "..." if len(first_para) > 300 else first_para
|
||||||
)
|
)
|
||||||
|
|
||||||
# Truncate content if needed
|
|
||||||
content, was_truncated = truncate_content(content, max_length)
|
content, was_truncated = truncate_content(content, max_length)
|
||||||
|
|
||||||
# Calculate word count
|
|
||||||
word_count = len(content.split())
|
word_count = len(content.split())
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -1,37 +1,9 @@
|
||||||
"""
|
"""Feature flags for the SurfSense new_chat agent stack.
|
||||||
Feature flags for the SurfSense new_chat agent stack.
|
|
||||||
|
|
||||||
These flags gate the newer agent middleware (some ported from OpenCode,
|
Flags are resolved at agent build time. Most upgrades default ON so Docker
|
||||||
some sourced from ``langchain.agents.middleware`` / ``deepagents``, some
|
updates work without operators adding new env vars; risky integrations stay
|
||||||
SurfSense-native). Most shipped agent-stack upgrades default ON so Docker
|
OFF. The master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` forces every
|
||||||
image updates work even when older installs do not have newly introduced
|
flag below to False for a one-switch rollback to pre-port behavior.
|
||||||
environment variables. Risky/experimental integrations stay default OFF,
|
|
||||||
and the master kill-switch can still disable everything new.
|
|
||||||
|
|
||||||
All new middleware checks its flag at agent build time. If the master
|
|
||||||
kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new
|
|
||||||
middleware is disabled regardless of its individual flag. This gives
|
|
||||||
operators a single switch to revert to pre-port behavior.
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
|
|
||||||
Defaults:
|
|
||||||
|
|
||||||
SURFSENSE_ENABLE_CONTEXT_EDITING=true
|
|
||||||
SURFSENSE_ENABLE_COMPACTION_V2=true
|
|
||||||
SURFSENSE_ENABLE_RETRY_AFTER=true
|
|
||||||
SURFSENSE_ENABLE_MODEL_FALLBACK=false
|
|
||||||
SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
|
|
||||||
SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
|
|
||||||
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
|
|
||||||
SURFSENSE_ENABLE_PERMISSION=true
|
|
||||||
SURFSENSE_ENABLE_DOOM_LOOP=true
|
|
||||||
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
|
|
||||||
|
|
||||||
Master kill-switch (overrides everything else):
|
|
||||||
|
|
||||||
SURFSENSE_DISABLE_NEW_AGENT_STACK=true
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -93,39 +65,14 @@ class AgentFeatureFlags:
|
||||||
# Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT)
|
# Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT)
|
||||||
enable_otel: bool = False
|
enable_otel: bool = False
|
||||||
|
|
||||||
# Performance — compiled-agent cache (Phase 1 + Phase 2).
|
# Performance — reuse a compiled agent graph when the cache key matches
|
||||||
# When ON, ``create_surfsense_deep_agent`` reuses a previously-compiled
|
# (~4-5s -> <50µs per turn). Safe to default-on because mutation tools take
|
||||||
# graph if the cache key matches (LLM config + thread + tool surface +
|
# fresh short-lived sessions per call and per-turn context (mentions, etc.)
|
||||||
# flags + system prompt + filesystem mode). Cuts per-turn agent-build
|
# is read from runtime.context, not the constructor closure. Rollback via
|
||||||
# wall clock from ~4-5s to <50µs on cache hits.
|
# SURFSENSE_ENABLE_AGENT_CACHE=false.
|
||||||
#
|
|
||||||
# SAFETY (Phase 2 unblocked this default-on):
|
|
||||||
# All connector mutation tools (``tools/notion``, ``tools/gmail``,
|
|
||||||
# ``tools/google_drive``, ``tools/dropbox``, ``tools/onedrive``,
|
|
||||||
# ``tools/google_calendar``, ``tools/confluence``, ``tools/discord``,
|
|
||||||
# ``tools/teams``, ``tools/luma``, ``connected_accounts``,
|
|
||||||
# ``update_memory``) now acquire fresh
|
|
||||||
# short-lived ``AsyncSession`` instances per call via
|
|
||||||
# :data:`async_session_maker`. The factory still accepts ``db_session``
|
|
||||||
# for registry compatibility but ``del``'s it immediately — see any
|
|
||||||
# of those files' factory docstrings for the rationale. The ``llm``
|
|
||||||
# closure is per-(provider, model, config_id) which is already in
|
|
||||||
# the cache key, so the LLM is safe to share across cached hits of
|
|
||||||
# the same key. The KB priority middleware reads
|
|
||||||
# ``mentioned_document_ids`` from ``runtime.context`` (Phase 1.5),
|
|
||||||
# not its constructor closure, so the same compiled agent serves
|
|
||||||
# turns with different mention lists correctly.
|
|
||||||
#
|
|
||||||
# Rollback: set ``SURFSENSE_ENABLE_AGENT_CACHE=false`` in the
|
|
||||||
# environment if a regression surfaces. The path is exercised by
|
|
||||||
# the ``tests/unit/agents/new_chat/test_agent_cache_*`` suite.
|
|
||||||
enable_agent_cache: bool = True
|
enable_agent_cache: bool = True
|
||||||
# Phase 1 (deferred — measure first): pre-build & share the
|
# Deferred: only helps on outer-cache MISSES, so off until data shows cold
|
||||||
# general-purpose subagent ``CompiledSubAgent`` across cold-cache
|
# misses are frequent enough to justify the extra global state.
|
||||||
# misses. Only helps when the outer cache MISSES (cache hits already
|
|
||||||
# reuse the entire SubAgentMiddleware-compiled graph). Off by default
|
|
||||||
# until we have data showing cold misses are frequent enough to
|
|
||||||
# justify the extra global state.
|
|
||||||
enable_agent_cache_share_gp_subagent: bool = False
|
enable_agent_cache_share_gp_subagent: bool = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -594,14 +594,9 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
inject_system_message: bool = True, # For backwards compatibility
|
inject_system_message: bool = True, # For backwards compatibility
|
||||||
) -> None:
|
) -> None:
|
||||||
self.llm = llm
|
self.llm = llm
|
||||||
# The planner LLM handles short, structured internal tasks (query
|
# Cheap model for structured internal tasks (query rewrite, date
|
||||||
# rewriting, date extraction, recency classification). When an
|
# extraction, recency classification) when one is configured; falls back
|
||||||
# operator marks a global config ``is_planner: true`` we route
|
# to the chat LLM otherwise.
|
||||||
# those calls to a cheap/fast model (e.g. gpt-4o-mini, Haiku, Azure
|
|
||||||
# gpt-5.x-nano) instead of the user's chat LLM — those classification
|
|
||||||
# tasks don't need frontier-tier capability. Falls back to the chat
|
|
||||||
# LLM when no planner config is wired up so deployments without one
|
|
||||||
# keep working unchanged.
|
|
||||||
self.planner_llm = planner_llm or llm
|
self.planner_llm = planner_llm or llm
|
||||||
self.search_space_id = search_space_id
|
self.search_space_id = search_space_id
|
||||||
self.filesystem_mode = filesystem_mode
|
self.filesystem_mode = filesystem_mode
|
||||||
|
|
@ -610,26 +605,17 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.mentioned_document_ids = mentioned_document_ids or []
|
self.mentioned_document_ids = mentioned_document_ids or []
|
||||||
self.inject_system_message = inject_system_message
|
self.inject_system_message = inject_system_message
|
||||||
# Build the kb-planner private Runnable ONCE here so we don't pay
|
# Compiled lazily and memoized to avoid the per-turn create_agent cost.
|
||||||
# the ``create_agent`` compile cost (50-200ms) on every turn.
|
|
||||||
# Disabled by default behind ``enable_kb_planner_runnable``; when
|
|
||||||
# off the planner falls back to the legacy ``planner_llm.ainvoke``
|
|
||||||
# path.
|
|
||||||
self._planner: Runnable | None = None
|
self._planner: Runnable | None = None
|
||||||
self._planner_compile_failed = False
|
self._planner_compile_failed = False
|
||||||
|
|
||||||
def _build_kb_planner_runnable(self) -> Runnable | None:
|
def _build_kb_planner_runnable(self) -> Runnable | None:
|
||||||
"""Compile the kb-planner private :class:`Runnable` once.
|
"""Lazily compile and memoize the kb-planner Runnable.
|
||||||
|
|
||||||
Returns ``None`` when the feature flag is disabled, when the LLM is
|
Returns ``None`` (and the caller falls back to ``planner_llm.ainvoke``)
|
||||||
unavailable, or when ``create_agent`` raises (we fall back to the
|
when the flag is off, the LLM is missing, or ``create_agent`` raises.
|
||||||
legacy ``planner_llm.ainvoke`` path in that case). Compilation happens
|
Built without tools but with RetryAfterMiddleware so a transient
|
||||||
lazily on first call, then memoized via ``self._planner``.
|
rate-limit on the planner call doesn't fail the whole turn.
|
||||||
|
|
||||||
The compiled agent is constructed without tools — the planner's
|
|
||||||
contract is "answer with structured JSON" — but it inherits the
|
|
||||||
:class:`RetryAfterMiddleware` so transient rate-limit errors
|
|
||||||
from the planner LLM call don't fail the whole turn.
|
|
||||||
"""
|
"""
|
||||||
if self._planner is not None or self._planner_compile_failed:
|
if self._planner is not None or self._planner_compile_failed:
|
||||||
return self._planner
|
return self._planner
|
||||||
|
|
@ -677,10 +663,8 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
t0 = loop.time()
|
t0 = loop.time()
|
||||||
|
|
||||||
# Prefer the compiled-once planner Runnable when enabled; otherwise
|
# Both paths tag surfsense:internal so the planner's intermediate
|
||||||
# fall back to ``planner_llm.ainvoke``. The ``surfsense:internal``
|
# events stay suppressed from the UI.
|
||||||
# tag is preserved on both paths so ``_stream_agent_events`` still
|
|
||||||
# suppresses the planner's intermediate events from the UI.
|
|
||||||
planner = self._build_kb_planner_runnable()
|
planner = self._build_kb_planner_runnable()
|
||||||
try:
|
try:
|
||||||
if planner is not None:
|
if planner is not None:
|
||||||
|
|
@ -819,32 +803,16 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
user_text=user_text,
|
user_text=user_text,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Per-turn ``mentioned_document_ids`` flow:
|
# Prefer per-turn mentions from runtime.context (lets a cached graph
|
||||||
# 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the
|
# serve different turns); fall back to the constructor closure, draining
|
||||||
# streaming task supplies a fresh :class:`SurfSenseContextSchema`
|
# it after one read so stale mentions can't replay.
|
||||||
# on every ``astream_events`` call, so this list is naturally
|
|
||||||
# scoped to the current turn. Allows cross-turn graph reuse via
|
|
||||||
# ``agent_cache``.
|
|
||||||
# 2. Legacy fallback (cache disabled / context not propagated): the
|
|
||||||
# constructor-injected ``self.mentioned_document_ids`` list. We
|
|
||||||
# drain it after the first read so a cached graph (no Phase 1.5
|
|
||||||
# wiring) doesn't keep replaying the same mentions on every
|
|
||||||
# turn.
|
|
||||||
#
|
#
|
||||||
# CRITICAL: distinguish "context absent" (legacy caller, no field at
|
# CRITICAL: test ``ctx_mentions is not None``, not truthiness — an empty
|
||||||
# all) from "context provided but empty" (turn with no mentions).
|
# list means "this turn has no mentions", not "use the closure".
|
||||||
# ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in
|
|
||||||
# Python, so a naive ``if ctx_mentions:`` would fall through to the
|
|
||||||
# legacy closure on every no-mention follow-up turn — replaying the
|
|
||||||
# mentions baked in by turn 1's cache-miss build. Always drain the
|
|
||||||
# closure once the runtime path has fired so a cached middleware
|
|
||||||
# instance can never resurrect stale state.
|
|
||||||
mention_ids: list[int] = []
|
mention_ids: list[int] = []
|
||||||
ctx = getattr(runtime, "context", None) if runtime is not None else None
|
ctx = getattr(runtime, "context", None) if runtime is not None else None
|
||||||
ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None
|
ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None
|
||||||
if ctx_mentions is not None:
|
if ctx_mentions is not None:
|
||||||
# Runtime path is authoritative — even an empty list means
|
|
||||||
# "this turn has no mentions", NOT "look at the closure".
|
|
||||||
mention_ids = list(ctx_mentions)
|
mention_ids = list(ctx_mentions)
|
||||||
if self.mentioned_document_ids:
|
if self.mentioned_document_ids:
|
||||||
self.mentioned_document_ids = []
|
self.mentioned_document_ids = []
|
||||||
|
|
@ -852,12 +820,8 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
mention_ids = list(self.mentioned_document_ids)
|
mention_ids = list(self.mentioned_document_ids)
|
||||||
self.mentioned_document_ids = []
|
self.mentioned_document_ids = []
|
||||||
|
|
||||||
# Folder mentions live alongside doc mentions on the runtime
|
# Folder mentions aren't embedded, so they skip hybrid search and are
|
||||||
# context. They never feed hybrid search (folders aren't
|
# surfaced only as [USER-MENTIONED] entries. Cloud mode only.
|
||||||
# embedded) — they're surfaced purely as ``[USER-MENTIONED]``
|
|
||||||
# priority entries so the agent walks the folder with ``ls`` /
|
|
||||||
# ``find_documents`` instead of ignoring it. Cloud filesystem
|
|
||||||
# mode only.
|
|
||||||
folder_mention_ids: list[int] = []
|
folder_mention_ids: list[int] = []
|
||||||
if (
|
if (
|
||||||
ctx is not None
|
ctx is not None
|
||||||
|
|
@ -939,14 +903,10 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
async def _materialize_folder_priority(
|
async def _materialize_folder_priority(
|
||||||
self, folder_ids: list[int]
|
self, folder_ids: list[int]
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""Resolve user-mentioned folder ids to ``<priority_documents>`` entries.
|
"""Resolve mentioned folder ids to canonical-path priority entries.
|
||||||
|
|
||||||
Each entry uses the canonical ``/documents/Folder/Sub/`` virtual
|
Flagged ``mentioned=True`` with ``score=None`` (folders aren't ranked;
|
||||||
path (matching ``KnowledgeTreeMiddleware`` and the agent's
|
the agent decides which children to read).
|
||||||
``ls`` adapter) and is flagged ``mentioned=True`` so the
|
|
||||||
rendered line carries ``[USER-MENTIONED]``. ``score`` is left
|
|
||||||
``None`` so the renderer prints ``n/a`` — folders aren't
|
|
||||||
ranked, the agent decides which children to read.
|
|
||||||
"""
|
"""
|
||||||
if not folder_ids:
|
if not folder_ids:
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -30,22 +30,11 @@ from langgraph.types import interrupt
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Tools that mirror the safety profile of ``write_file`` against the
|
# Low-stakes creation tools auto-approved by default: each creates one
|
||||||
# SurfSense KB: each call creates ONE artifact in the user's own workspace
|
# artifact in the user's own workspace with no external visibility (drafts
|
||||||
# with no external visibility (drafts aren't sent; new files aren't shared
|
# aren't sent; new files aren't shared). They still call ``request_approval``,
|
||||||
# unless the user shares them later). These are auto-approved by default
|
# which returns ``decision_type="auto_approved"`` without firing an interrupt.
|
||||||
# so the agent can compose drafts and seed scratch files without a popup
|
# Per-search-space ``agent_permission_rules`` can re-enable prompting.
|
||||||
# on every call.
|
|
||||||
#
|
|
||||||
# Members of this set still call ``request_approval`` exactly as before;
|
|
||||||
# the function returns immediately with ``decision_type="auto_approved"``
|
|
||||||
# and the original params untouched. This preserves the call-site shape
|
|
||||||
# (logging, metadata fetching, account fallbacks) so the only behavior
|
|
||||||
# change is "no interrupt fires".
|
|
||||||
#
|
|
||||||
# To re-enable prompting, the future per-search-space rules table
|
|
||||||
# (``agent_permission_rules``) takes precedence in the permission ruleset
|
|
||||||
# layering assembled by the agent factory.
|
|
||||||
DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
|
DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
|
||||||
{
|
{
|
||||||
"create_gmail_draft",
|
"create_gmail_draft",
|
||||||
|
|
@ -150,10 +139,6 @@ def request_approval(
|
||||||
return HITLResult(rejected=False, decision_type="trusted", params=dict(params))
|
return HITLResult(rejected=False, decision_type="trusted", params=dict(params))
|
||||||
|
|
||||||
if tool_name in DEFAULT_AUTO_APPROVED_TOOLS:
|
if tool_name in DEFAULT_AUTO_APPROVED_TOOLS:
|
||||||
# Default policy: low-stakes creation tools (drafts + new-file
|
|
||||||
# creates) skip HITL because they're as recoverable as a local
|
|
||||||
# ``write_file`` against the SurfSense KB. The user can still
|
|
||||||
# delete the artifact in <30s if it's wrong.
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Tool '%s' is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL",
|
"Tool '%s' is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL",
|
||||||
tool_name,
|
tool_name,
|
||||||
|
|
|
||||||
|
|
@ -75,7 +75,7 @@ def create_generate_image_tool(
|
||||||
captured model), use this config id instead of reading the search space's
|
captured model), use this config id instead of reading the search space's
|
||||||
live ``image_generation_config_id``.
|
live ``image_generation_config_id``.
|
||||||
"""
|
"""
|
||||||
del db_session # use a fresh per-call session, see below
|
del db_session # tool uses a fresh per-call session instead
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def generate_image(
|
async def generate_image(
|
||||||
|
|
@ -140,17 +140,12 @@ def create_generate_image_tool(
|
||||||
or IMAGE_GEN_AUTO_MODE_ID
|
or IMAGE_GEN_AUTO_MODE_ID
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build generation kwargs
|
# size/quality/style are intentionally omitted: valid values
|
||||||
# NOTE: size, quality, and style are intentionally NOT passed.
|
# differ per model, so we let each model use its own defaults.
|
||||||
# Different models support different values for these params
|
|
||||||
# (e.g. DALL-E 3 wants "hd"/"standard" for quality while
|
|
||||||
# gpt-image-1 wants "high"/"medium"/"low"; size options also
|
|
||||||
# differ). Letting the model use its own defaults avoids errors.
|
|
||||||
gen_kwargs: dict[str, Any] = {}
|
gen_kwargs: dict[str, Any] = {}
|
||||||
if n is not None and n > 1:
|
if n is not None and n > 1:
|
||||||
gen_kwargs["n"] = n
|
gen_kwargs["n"] = n
|
||||||
|
|
||||||
# Call litellm based on config type
|
|
||||||
if is_image_gen_auto_mode(config_id):
|
if is_image_gen_auto_mode(config_id):
|
||||||
if not ImageGenRouterService.is_initialized():
|
if not ImageGenRouterService.is_initialized():
|
||||||
err = (
|
err = (
|
||||||
|
|
@ -224,17 +219,13 @@ def create_generate_image_tool(
|
||||||
prompt=prompt, model=model_string, **gen_kwargs
|
prompt=prompt, model=model_string, **gen_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse the response and store in DB
|
|
||||||
response_dict = (
|
response_dict = (
|
||||||
response.model_dump()
|
response.model_dump()
|
||||||
if hasattr(response, "model_dump")
|
if hasattr(response, "model_dump")
|
||||||
else dict(response)
|
else dict(response)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate a random access token for this image
|
|
||||||
access_token = generate_image_token()
|
access_token = generate_image_token()
|
||||||
|
|
||||||
# Save to image_generations table for history
|
|
||||||
db_image_gen = ImageGeneration(
|
db_image_gen = ImageGeneration(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model=getattr(response, "_hidden_params", {}).get("model"),
|
model=getattr(response, "_hidden_params", {}).get("model"),
|
||||||
|
|
@ -249,7 +240,6 @@ def create_generate_image_tool(
|
||||||
await session.refresh(db_image_gen)
|
await session.refresh(db_image_gen)
|
||||||
db_image_gen_id = db_image_gen.id
|
db_image_gen_id = db_image_gen.id
|
||||||
|
|
||||||
# Extract image URLs from response
|
|
||||||
images = response_dict.get("data", [])
|
images = response_dict.get("data", [])
|
||||||
if not images:
|
if not images:
|
||||||
return _failed(
|
return _failed(
|
||||||
|
|
@ -260,11 +250,8 @@ def create_generate_image_tool(
|
||||||
first_image = images[0]
|
first_image = images[0]
|
||||||
revised_prompt = first_image.get("revised_prompt", prompt)
|
revised_prompt = first_image.get("revised_prompt", prompt)
|
||||||
|
|
||||||
# Resolve image URL:
|
# b64_json (e.g. gpt-image-1) is served via our backend endpoint so
|
||||||
# - If the API returned a URL, use it directly.
|
# megabytes of base64 don't bloat the LLM context.
|
||||||
# - If the API returned b64_json (e.g. gpt-image-1), serve the
|
|
||||||
# image through our backend endpoint to avoid bloating the
|
|
||||||
# LLM context with megabytes of base64 data.
|
|
||||||
if first_image.get("url"):
|
if first_image.get("url"):
|
||||||
image_url = first_image["url"]
|
image_url = first_image["url"]
|
||||||
elif first_image.get("b64_json"):
|
elif first_image.get("b64_json"):
|
||||||
|
|
|
||||||
|
|
@ -241,23 +241,12 @@ def _normalize_connectors(
|
||||||
connectors_to_search: list[str] | None,
|
connectors_to_search: list[str] | None,
|
||||||
available_connectors: list[str] | None = None,
|
available_connectors: list[str] | None = None,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
|
"""Normalize model-supplied connectors to canonical ConnectorService types.
|
||||||
|
|
||||||
|
Maps user-facing aliases (e.g. WEBCRAWLER_CONNECTOR), drops unknowns, and
|
||||||
|
constrains to ``available_connectors`` when given. Empty input defaults to
|
||||||
|
all available connectors (minus live-search ones).
|
||||||
"""
|
"""
|
||||||
Normalize connectors provided by the model.
|
|
||||||
|
|
||||||
- Accepts user-facing enums like WEBCRAWLER_CONNECTOR and maps them to canonical
|
|
||||||
ConnectorService types.
|
|
||||||
- Drops unknown values.
|
|
||||||
- If available_connectors is provided, only includes connectors from that list.
|
|
||||||
- If connectors_to_search is None/empty, defaults to available_connectors or all.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
connectors_to_search: List of connectors requested by the model
|
|
||||||
available_connectors: List of connectors actually available in the search space
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of normalized connector strings to search
|
|
||||||
"""
|
|
||||||
# Determine the set of valid connectors to consider
|
|
||||||
valid_set = (
|
valid_set = (
|
||||||
set(available_connectors) if available_connectors else set(_ALL_CONNECTORS)
|
set(available_connectors) if available_connectors else set(_ALL_CONNECTORS)
|
||||||
)
|
)
|
||||||
|
|
@ -276,18 +265,16 @@ def _normalize_connectors(
|
||||||
c = (raw or "").strip().upper()
|
c = (raw or "").strip().upper()
|
||||||
if not c:
|
if not c:
|
||||||
continue
|
continue
|
||||||
# Map user-facing aliases to canonical names
|
|
||||||
if c == "WEBCRAWLER_CONNECTOR":
|
if c == "WEBCRAWLER_CONNECTOR":
|
||||||
c = "CRAWLED_URL"
|
c = "CRAWLED_URL"
|
||||||
normalized.append(c)
|
normalized.append(c)
|
||||||
|
|
||||||
# de-dupe while preserving order + filter to valid connectors
|
# De-dupe (order-preserving), keeping only known + available connectors.
|
||||||
seen: set[str] = set()
|
seen: set[str] = set()
|
||||||
out: list[str] = []
|
out: list[str] = []
|
||||||
for c in normalized:
|
for c in normalized:
|
||||||
if c in seen:
|
if c in seen:
|
||||||
continue
|
continue
|
||||||
# Only include if it's a known connector AND available
|
|
||||||
if c not in _ALL_CONNECTORS:
|
if c not in _ALL_CONNECTORS:
|
||||||
continue
|
continue
|
||||||
if c not in valid_set:
|
if c not in valid_set:
|
||||||
|
|
@ -295,7 +282,7 @@ def _normalize_connectors(
|
||||||
seen.add(c)
|
seen.add(c)
|
||||||
out.append(c)
|
out.append(c)
|
||||||
|
|
||||||
# Fallback to all available if nothing matched
|
# Nothing matched: fall back to all available.
|
||||||
if not out:
|
if not out:
|
||||||
base = (
|
base = (
|
||||||
list(available_connectors)
|
list(available_connectors)
|
||||||
|
|
@ -377,39 +364,17 @@ def format_documents_for_context(
|
||||||
max_chunk_chars: int = _MAX_CHUNK_CHARS,
|
max_chunk_chars: int = _MAX_CHUNK_CHARS,
|
||||||
max_chunks_per_doc: int = 0,
|
max_chunks_per_doc: int = 0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""Format retrieved documents into an XML context string for the LLM.
|
||||||
Format retrieved documents into a readable context string for the LLM.
|
|
||||||
|
|
||||||
Documents are added in order (highest relevance first) until the character
|
Documents are emitted highest-relevance first until ``max_chars`` is hit.
|
||||||
budget is reached. Individual chunks are capped at ``max_chunk_chars`` and
|
``max_chunks_per_doc=0`` auto-computes a rank-adaptive cap so top results get
|
||||||
each document is limited to a dynamically computed chunk cap so a single
|
more chunks and no single large document monopolizes the budget.
|
||||||
large document cannot monopolize the output while still maximising the use
|
|
||||||
of available context space.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
documents: List of document dictionaries from connector search
|
|
||||||
max_chars: Approximate character budget for the entire output.
|
|
||||||
max_chunk_chars: Per-chunk character cap (content is tail-truncated).
|
|
||||||
max_chunks_per_doc: Maximum chunks per document. ``0`` (default) means
|
|
||||||
auto-compute per document using a rank-adaptive formula so
|
|
||||||
higher-ranked documents receive more chunks.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted string with document contents and metadata
|
|
||||||
"""
|
"""
|
||||||
if not documents:
|
if not documents:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Group chunks by document id (preferred) to produce the XML structure.
|
# Group chunks by document id, preserving chunk_id so [citation:123] works.
|
||||||
#
|
# ConnectorService returns document-grouped results ({document, chunks, source}).
|
||||||
# IMPORTANT: ConnectorService returns **document-grouped** results of the form:
|
|
||||||
# {
|
|
||||||
# "document": {...},
|
|
||||||
# "chunks": [{"chunk_id": 123, "content": "..."}, ...],
|
|
||||||
# "source": "NOTION_CONNECTOR" | "FILE" | ...
|
|
||||||
# }
|
|
||||||
#
|
|
||||||
# We must preserve chunk_id so citations like [citation:123] are possible.
|
|
||||||
grouped: dict[str, dict[str, Any]] = {}
|
grouped: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
|
|
@ -430,7 +395,7 @@ def format_documents_for_context(
|
||||||
or "UNKNOWN"
|
or "UNKNOWN"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Document identity (prefer document_id; otherwise fall back to type+title+url)
|
# Identity: prefer document_id, else type+title+url.
|
||||||
document_id_val = document_info.get("id")
|
document_id_val = document_info.get("id")
|
||||||
title = (
|
title = (
|
||||||
document_info.get("title") or metadata.get("title") or "Untitled Document"
|
document_info.get("title") or metadata.get("title") or "Untitled Document"
|
||||||
|
|
@ -460,7 +425,7 @@ def format_documents_for_context(
|
||||||
"chunks": [],
|
"chunks": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Prefer document-grouped chunks if available
|
# Prefer document-grouped chunks when present.
|
||||||
chunks_list = doc.get("chunks") if isinstance(doc, dict) else None
|
chunks_list = doc.get("chunks") if isinstance(doc, dict) else None
|
||||||
if isinstance(chunks_list, list) and chunks_list:
|
if isinstance(chunks_list, list) and chunks_list:
|
||||||
for ch in chunks_list:
|
for ch in chunks_list:
|
||||||
|
|
@ -492,7 +457,6 @@ def format_documents_for_context(
|
||||||
"BAIDU_SEARCH_API",
|
"BAIDU_SEARCH_API",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Render XML expected by citation instructions, respecting the char budget.
|
|
||||||
parts: list[str] = []
|
parts: list[str] = []
|
||||||
total_chars = 0
|
total_chars = 0
|
||||||
total_docs = len(grouped)
|
total_docs = len(grouped)
|
||||||
|
|
@ -594,30 +558,11 @@ async def search_knowledge_base_async(
|
||||||
available_document_types: list[str] | None = None,
|
available_document_types: list[str] | None = None,
|
||||||
max_input_tokens: int | None = None,
|
max_input_tokens: int | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""Search the knowledge base across connectors and return formatted results.
|
||||||
Search the user's knowledge base for relevant documents.
|
|
||||||
|
|
||||||
This is the async implementation that searches across multiple connectors.
|
``available_document_types`` lets local connectors with no indexed data be
|
||||||
|
skipped (no embedding / DB round-trip), and ``max_input_tokens`` sizes the
|
||||||
Args:
|
output to the model's context window.
|
||||||
query: The search query
|
|
||||||
search_space_id: The user's search space ID
|
|
||||||
db_session: Database session
|
|
||||||
connector_service: Initialized connector service
|
|
||||||
connectors_to_search: Optional list of connector types to search. If omitted, searches all.
|
|
||||||
top_k: Number of results per connector
|
|
||||||
start_date: Optional start datetime (UTC) for filtering documents
|
|
||||||
end_date: Optional end datetime (UTC) for filtering documents
|
|
||||||
available_connectors: Optional list of connectors actually available in the search space.
|
|
||||||
If provided, only these connectors will be searched.
|
|
||||||
available_document_types: Optional list of document types that actually have indexed
|
|
||||||
data. When provided, local connectors whose document type is
|
|
||||||
absent are skipped entirely (no embedding / DB round-trip).
|
|
||||||
max_input_tokens: Model context window size (tokens). Used to dynamically
|
|
||||||
size the output so it fits within the model's limits.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted string with search results
|
|
||||||
"""
|
"""
|
||||||
perf = get_perf_logger()
|
perf = get_perf_logger()
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
|
|
|
||||||
|
|
@ -196,13 +196,8 @@ def _strip_wrapping_code_fences(text: str) -> str:
|
||||||
|
|
||||||
def _extract_metadata(content: str) -> dict[str, Any]:
|
def _extract_metadata(content: str) -> dict[str, Any]:
|
||||||
"""Extract metadata from generated Markdown content."""
|
"""Extract metadata from generated Markdown content."""
|
||||||
# Count section headings
|
|
||||||
headings = re.findall(r"^(#{1,6})\s+(.+)$", content, re.MULTILINE)
|
headings = re.findall(r"^(#{1,6})\s+(.+)$", content, re.MULTILINE)
|
||||||
|
|
||||||
# Word count
|
|
||||||
word_count = len(content.split())
|
word_count = len(content.split())
|
||||||
|
|
||||||
# Character count
|
|
||||||
char_count = len(content)
|
char_count = len(content)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -227,12 +222,11 @@ def _parse_sections(content: str) -> list[dict[str, str]]:
|
||||||
in_code_block = False
|
in_code_block = False
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
# Track code blocks to avoid matching headings inside them
|
# Track fences so headings inside code blocks aren't treated as splits.
|
||||||
stripped = line.strip()
|
stripped = line.strip()
|
||||||
if stripped.startswith("```"):
|
if stripped.startswith("```"):
|
||||||
in_code_block = not in_code_block
|
in_code_block = not in_code_block
|
||||||
|
|
||||||
# Only split on # or ## headings (not ### or deeper) and only outside code blocks
|
|
||||||
is_section_heading = (
|
is_section_heading = (
|
||||||
not in_code_block
|
not in_code_block
|
||||||
and re.match(r"^#{1,2}\s+", line)
|
and re.match(r"^#{1,2}\s+", line)
|
||||||
|
|
@ -240,7 +234,6 @@ def _parse_sections(content: str) -> list[dict[str, str]]:
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_section_heading:
|
if is_section_heading:
|
||||||
# Save previous section
|
|
||||||
if current_heading or current_body_lines:
|
if current_heading or current_body_lines:
|
||||||
sections.append(
|
sections.append(
|
||||||
{
|
{
|
||||||
|
|
@ -253,7 +246,6 @@ def _parse_sections(content: str) -> list[dict[str, str]]:
|
||||||
else:
|
else:
|
||||||
current_body_lines.append(line)
|
current_body_lines.append(line)
|
||||||
|
|
||||||
# Save last section
|
|
||||||
if current_heading or current_body_lines:
|
if current_heading or current_body_lines:
|
||||||
sections.append(
|
sections.append(
|
||||||
{
|
{
|
||||||
|
|
@ -292,7 +284,6 @@ async def _revise_with_sections(
|
||||||
Unchanged sections are kept byte-for-byte identical.
|
Unchanged sections are kept byte-for-byte identical.
|
||||||
Returns the revised content, or None to trigger full-document revision fallback.
|
Returns the revised content, or None to trigger full-document revision fallback.
|
||||||
"""
|
"""
|
||||||
# Parse report into sections
|
|
||||||
sections = _parse_sections(parent_content)
|
sections = _parse_sections(parent_content)
|
||||||
if len(sections) < 2:
|
if len(sections) < 2:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -300,7 +291,6 @@ async def _revise_with_sections(
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Build a sections listing for the LLM
|
|
||||||
sections_listing = ""
|
sections_listing = ""
|
||||||
for i, sec in enumerate(sections):
|
for i, sec in enumerate(sections):
|
||||||
heading = sec["heading"] or "(preamble — content before first heading)"
|
heading = sec["heading"] or "(preamble — content before first heading)"
|
||||||
|
|
@ -352,11 +342,9 @@ async def _revise_with_sections(
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Compute total operations for progress tracking
|
|
||||||
total_ops = len(modify_indices) + len(add_sections)
|
total_ops = len(modify_indices) + len(add_sections)
|
||||||
current_op = 0
|
current_op = 0
|
||||||
|
|
||||||
# Emit plan summary
|
|
||||||
parts = []
|
parts = []
|
||||||
if modify_indices:
|
if modify_indices:
|
||||||
parts.append(
|
parts.append(
|
||||||
|
|
@ -394,7 +382,6 @@ async def _revise_with_sections(
|
||||||
current_op += 1
|
current_op += 1
|
||||||
sec = sections[idx]
|
sec = sections[idx]
|
||||||
|
|
||||||
# Extract plain section name (strip markdown heading markers)
|
|
||||||
section_name = (
|
section_name = (
|
||||||
re.sub(r"^#+\s*", "", sec["heading"]).strip()
|
re.sub(r"^#+\s*", "", sec["heading"]).strip()
|
||||||
if sec["heading"]
|
if sec["heading"]
|
||||||
|
|
@ -412,7 +399,6 @@ async def _revise_with_sections(
|
||||||
f"{sec['heading']}\n\n{sec['body']}" if sec["heading"] else sec["body"]
|
f"{sec['heading']}\n\n{sec['body']}" if sec["heading"] else sec["body"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build context from surrounding sections
|
|
||||||
context_parts = []
|
context_parts = []
|
||||||
if idx > 0:
|
if idx > 0:
|
||||||
prev = sections[idx - 1]
|
prev = sections[idx - 1]
|
||||||
|
|
@ -442,7 +428,6 @@ async def _revise_with_sections(
|
||||||
revised_text = resp.content
|
revised_text = resp.content
|
||||||
if revised_text and isinstance(revised_text, str):
|
if revised_text and isinstance(revised_text, str):
|
||||||
revised_text = _strip_wrapping_code_fences(revised_text).strip()
|
revised_text = _strip_wrapping_code_fences(revised_text).strip()
|
||||||
# Parse the LLM output back into heading + body
|
|
||||||
revised_parsed = _parse_sections(revised_text)
|
revised_parsed = _parse_sections(revised_text)
|
||||||
if revised_parsed:
|
if revised_parsed:
|
||||||
revised_sections[idx] = revised_parsed[0]
|
revised_sections[idx] = revised_parsed[0]
|
||||||
|
|
@ -465,7 +450,6 @@ async def _revise_with_sections(
|
||||||
heading = add_info.get("heading", "## New Section")
|
heading = add_info.get("heading", "## New Section")
|
||||||
description = add_info.get("description", "")
|
description = add_info.get("description", "")
|
||||||
|
|
||||||
# Extract plain section name for progress display
|
|
||||||
plain_heading = re.sub(r"^#+\s*", "", heading).strip()
|
plain_heading = re.sub(r"^#+\s*", "", heading).strip()
|
||||||
dispatch_custom_event(
|
dispatch_custom_event(
|
||||||
"report_progress",
|
"report_progress",
|
||||||
|
|
@ -475,7 +459,6 @@ async def _revise_with_sections(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build context from the surrounding sections at the insertion point
|
|
||||||
ctx_parts = []
|
ctx_parts = []
|
||||||
if 0 <= after_idx < len(revised_sections):
|
if 0 <= after_idx < len(revised_sections):
|
||||||
before_sec = revised_sections[after_idx]
|
before_sec = revised_sections[after_idx]
|
||||||
|
|
@ -542,36 +525,13 @@ def create_generate_report_tool(
|
||||||
available_connectors: list[str] | None = None,
|
available_connectors: list[str] | None = None,
|
||||||
available_document_types: list[str] | None = None,
|
available_document_types: list[str] | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""Create the generate_report tool with injected dependencies.
|
||||||
Factory function to create the generate_report tool with injected dependencies.
|
|
||||||
|
|
||||||
The tool generates a Markdown report inline using the search space's
|
Uses short-lived DB sessions per operation so no connection is held during
|
||||||
document summary LLM, saves it to the database, and returns immediately.
|
the long LLM call. Generation: new reports are single-shot; revisions try
|
||||||
|
section-level first (unchanged sections preserved) and fall back to full-doc.
|
||||||
Uses short-lived database sessions for each DB operation so no connection
|
Source strategies: provided/conversation (use source_content), kb_search
|
||||||
is held during the long LLM API call.
|
(internal KB queries), auto (KB search only when source_content is thin).
|
||||||
|
|
||||||
Generation strategies:
|
|
||||||
- New reports: single-shot generation (1 LLM call)
|
|
||||||
- Revisions (targeted edits): section-level (unchanged sections preserved)
|
|
||||||
- Revisions (global changes): full-document revision fallback
|
|
||||||
|
|
||||||
Source strategies:
|
|
||||||
- "provided"/"conversation": use only the supplied source_content
|
|
||||||
- "kb_search": search the knowledge base internally using targeted queries
|
|
||||||
- "auto": use source_content if sufficient, otherwise fall back to KB search
|
|
||||||
|
|
||||||
Args:
|
|
||||||
search_space_id: The user's search space ID
|
|
||||||
thread_id: The chat thread ID for associating the report
|
|
||||||
connector_service: Optional connector service for internal KB search.
|
|
||||||
When provided, the tool can search the knowledge base internally
|
|
||||||
(used by the "kb_search" and "auto" source strategies).
|
|
||||||
available_connectors: Optional list of connector types available in the
|
|
||||||
search space (used to scope internal KB searches).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A configured tool function for generating reports
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
|
|
@ -693,7 +653,7 @@ def create_generate_report_tool(
|
||||||
Returns:
|
Returns:
|
||||||
Dict with status, report_id, title, word_count, and message.
|
Dict with status, report_id, title, word_count, and message.
|
||||||
"""
|
"""
|
||||||
# Initialize version tracking variables (used by _save_failed_report closure)
|
# Shared with the _save_failed_report closure.
|
||||||
parent_report_content: str | None = None
|
parent_report_content: str | None = None
|
||||||
report_group_id: int | None = None
|
report_group_id: int | None = None
|
||||||
|
|
||||||
|
|
@ -733,7 +693,7 @@ def create_generate_report_tool(
|
||||||
session.add(failed_report)
|
session.add(failed_report)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(failed_report)
|
await session.refresh(failed_report)
|
||||||
# If this is a new group (v1 failed), set group to self
|
# New group (v1 failed): point the group at itself.
|
||||||
if not failed_report.report_group_id:
|
if not failed_report.report_group_id:
|
||||||
failed_report.report_group_id = failed_report.id
|
failed_report.report_group_id = failed_report.id
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
@ -749,8 +709,8 @@ def create_generate_report_tool(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# ── Phase 1: READ (short-lived session) ──────────────────────
|
# ── Phase 1: READ (short-lived session) ──────────────────────
|
||||||
# Fetch parent report and LLM config, then close the session
|
# Fetch parent report + LLM config, then release the connection
|
||||||
# so no DB connection is held during the long LLM call.
|
# before the long LLM call.
|
||||||
async with shielded_async_session() as read_session:
|
async with shielded_async_session() as read_session:
|
||||||
if parent_report_id:
|
if parent_report_id:
|
||||||
parent_report = await read_session.get(Report, parent_report_id)
|
parent_report = await read_session.get(Report, parent_report_id)
|
||||||
|
|
@ -768,7 +728,6 @@ def create_generate_report_tool(
|
||||||
)
|
)
|
||||||
|
|
||||||
llm = await get_document_summary_llm(read_session, search_space_id)
|
llm = await get_document_summary_llm(read_session, search_space_id)
|
||||||
# read_session closed — connection returned to pool
|
|
||||||
|
|
||||||
if not llm:
|
if not llm:
|
||||||
error_msg = (
|
error_msg = (
|
||||||
|
|
@ -785,7 +744,6 @@ def create_generate_report_tool(
|
||||||
error=error_msg,
|
error=error_msg,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build the user instructions string
|
|
||||||
user_instructions_section = ""
|
user_instructions_section = ""
|
||||||
if user_instructions:
|
if user_instructions:
|
||||||
user_instructions_section = (
|
user_instructions_section = (
|
||||||
|
|
@ -829,7 +787,7 @@ def create_generate_report_tool(
|
||||||
try:
|
try:
|
||||||
from .knowledge_base import search_knowledge_base_async
|
from .knowledge_base import search_knowledge_base_async
|
||||||
|
|
||||||
# Run all queries in parallel, each with its own session
|
# Each query gets its own short-lived session.
|
||||||
async def _run_single_query(q: str) -> str:
|
async def _run_single_query(q: str) -> str:
|
||||||
async with shielded_async_session() as kb_session:
|
async with shielded_async_session() as kb_session:
|
||||||
kb_connector_svc = ConnectorService(
|
kb_connector_svc = ConnectorService(
|
||||||
|
|
@ -849,7 +807,6 @@ def create_generate_report_tool(
|
||||||
*[_run_single_query(q) for q in search_queries[:5]]
|
*[_run_single_query(q) for q in search_queries[:5]]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Merge non-empty results into source_content
|
|
||||||
kb_text_parts = [r for r in kb_results if r and r.strip()]
|
kb_text_parts = [r for r in kb_results if r and r.strip()]
|
||||||
if kb_text_parts:
|
if kb_text_parts:
|
||||||
kb_combined = "\n\n---\n\n".join(kb_text_parts)
|
kb_combined = "\n\n---\n\n".join(kb_text_parts)
|
||||||
|
|
@ -903,9 +860,9 @@ def create_generate_report_tool(
|
||||||
"provided. Using source_content as-is."
|
"provided. Using source_content as-is."
|
||||||
)
|
)
|
||||||
|
|
||||||
capped_source = effective_source[:100000] # Cap source content
|
capped_source = effective_source[:100000]
|
||||||
|
|
||||||
# Length constraint — only when user explicitly asks for brevity
|
# Length constraint only when the user explicitly asked for brevity.
|
||||||
length_instruction = ""
|
length_instruction = ""
|
||||||
if report_style == "brief":
|
if report_style == "brief":
|
||||||
length_instruction = (
|
length_instruction = (
|
||||||
|
|
@ -920,11 +877,8 @@ def create_generate_report_tool(
|
||||||
report_content: str | None = None
|
report_content: str | None = None
|
||||||
|
|
||||||
if parent_report_content:
|
if parent_report_content:
|
||||||
# ─── REVISION MODE ───────────────────────────────────────
|
# Revision mode: section-level first (preserves untouched
|
||||||
# Strategy: Try section-level revision first (preserves
|
# sections), falling back to full-doc revision.
|
||||||
# unchanged sections byte-for-byte). Falls back to full-
|
|
||||||
# document revision if section identification fails or if
|
|
||||||
# all sections need changes.
|
|
||||||
dispatch_custom_event(
|
dispatch_custom_event(
|
||||||
"report_progress",
|
"report_progress",
|
||||||
{
|
{
|
||||||
|
|
@ -946,7 +900,6 @@ def create_generate_report_tool(
|
||||||
)
|
)
|
||||||
|
|
||||||
if report_content is None:
|
if report_content is None:
|
||||||
# Fallback: full-document revision
|
|
||||||
dispatch_custom_event(
|
dispatch_custom_event(
|
||||||
"report_progress",
|
"report_progress",
|
||||||
{"phase": "writing", "message": "Rewriting your full report"},
|
{"phase": "writing", "message": "Rewriting your full report"},
|
||||||
|
|
@ -969,9 +922,7 @@ def create_generate_report_tool(
|
||||||
report_content = response.content
|
report_content = response.content
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# ─── NEW REPORT MODE ─────────────────────────────────────
|
# New report: single-shot generation (one LLM call).
|
||||||
# Single-shot generation: one LLM call produces the full
|
|
||||||
# report. Fast, globally coherent, and cost-efficient.
|
|
||||||
dispatch_custom_event(
|
dispatch_custom_event(
|
||||||
"report_progress",
|
"report_progress",
|
||||||
{"phase": "writing", "message": "Writing your report"},
|
{"phase": "writing", "message": "Writing your report"},
|
||||||
|
|
@ -991,8 +942,6 @@ def create_generate_report_tool(
|
||||||
response = await llm.ainvoke([HumanMessage(content=prompt)])
|
response = await llm.ainvoke([HumanMessage(content=prompt)])
|
||||||
report_content = response.content
|
report_content = response.content
|
||||||
|
|
||||||
# ── Validate LLM output ──────────────────────────────────────
|
|
||||||
|
|
||||||
if not report_content or not isinstance(report_content, str):
|
if not report_content or not isinstance(report_content, str):
|
||||||
error_msg = "LLM returned empty or invalid content"
|
error_msg = "LLM returned empty or invalid content"
|
||||||
report_id = await _save_failed_report(error_msg)
|
report_id = await _save_failed_report(error_msg)
|
||||||
|
|
@ -1029,14 +978,12 @@ def create_generate_report_tool(
|
||||||
if report_content.rstrip().endswith("---"):
|
if report_content.rstrip().endswith("---"):
|
||||||
report_content = report_content.rstrip()[:-3].rstrip()
|
report_content = report_content.rstrip()[:-3].rstrip()
|
||||||
|
|
||||||
# Append exactly one standard disclaimer
|
# Append exactly one standard footer.
|
||||||
report_content += "\n\n---\n\n" + _REPORT_FOOTER
|
report_content += "\n\n---\n\n" + _REPORT_FOOTER
|
||||||
|
|
||||||
# Extract metadata (includes "status": "ready")
|
|
||||||
metadata = _extract_metadata(report_content)
|
metadata = _extract_metadata(report_content)
|
||||||
|
|
||||||
# ── Phase 3: WRITE (short-lived session) ─────────────────────
|
# ── Phase 3: WRITE (short-lived session) ─────────────────────
|
||||||
# Save the report to the database, then close the session.
|
|
||||||
async with shielded_async_session() as write_session:
|
async with shielded_async_session() as write_session:
|
||||||
report = Report(
|
report = Report(
|
||||||
title=topic,
|
title=topic,
|
||||||
|
|
@ -1051,14 +998,13 @@ def create_generate_report_tool(
|
||||||
await write_session.commit()
|
await write_session.commit()
|
||||||
await write_session.refresh(report)
|
await write_session.refresh(report)
|
||||||
|
|
||||||
# If this is a brand-new report (v1), set report_group_id = own id
|
# Brand-new report (v1): point the group at itself.
|
||||||
if not report.report_group_id:
|
if not report.report_group_id:
|
||||||
report.report_group_id = report.id
|
report.report_group_id = report.id
|
||||||
await write_session.commit()
|
await write_session.commit()
|
||||||
|
|
||||||
saved_report_id = report.id
|
saved_report_id = report.id
|
||||||
saved_group_id = report.report_group_id
|
saved_group_id = report.report_group_id
|
||||||
# write_session closed — connection returned to pool
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[generate_report] Created report {saved_report_id} "
|
f"[generate_report] Created report {saved_report_id} "
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,6 @@ def extract_domain(url: str) -> str:
|
||||||
try:
|
try:
|
||||||
parsed = urlparse(url)
|
parsed = urlparse(url)
|
||||||
domain = parsed.netloc
|
domain = parsed.netloc
|
||||||
# Remove 'www.' prefix if present
|
|
||||||
if domain.startswith("www."):
|
if domain.startswith("www."):
|
||||||
domain = domain[4:]
|
domain = domain[4:]
|
||||||
return domain
|
return domain
|
||||||
|
|
@ -47,14 +46,13 @@ def truncate_content(content: str, max_length: int = 50000) -> tuple[str, bool]:
|
||||||
if len(content) <= max_length:
|
if len(content) <= max_length:
|
||||||
return content, False
|
return content, False
|
||||||
|
|
||||||
# Try to truncate at a sentence boundary
|
# Prefer truncating at a sentence/paragraph boundary.
|
||||||
truncated = content[:max_length]
|
truncated = content[:max_length]
|
||||||
last_period = truncated.rfind(".")
|
last_period = truncated.rfind(".")
|
||||||
last_newline = truncated.rfind("\n\n")
|
last_newline = truncated.rfind("\n\n")
|
||||||
|
|
||||||
# Use the later of the two boundaries, or just truncate
|
|
||||||
boundary = max(last_period, last_newline)
|
boundary = max(last_period, last_newline)
|
||||||
if boundary > max_length * 0.8: # Only use boundary if it's not too far back
|
if boundary > max_length * 0.8: # only if the boundary isn't too far back
|
||||||
truncated = content[: boundary + 1]
|
truncated = content[: boundary + 1]
|
||||||
|
|
||||||
return truncated + "\n\n[Content truncated...]", True
|
return truncated + "\n\n[Content truncated...]", True
|
||||||
|
|
@ -105,8 +103,8 @@ async def _scrape_youtube_video(
|
||||||
http_client.proxies.update(residential_proxies)
|
http_client.proxies.update(residential_proxies)
|
||||||
ytt_api = YouTubeTranscriptApi(http_client=http_client)
|
ytt_api = YouTubeTranscriptApi(http_client=http_client)
|
||||||
|
|
||||||
# List all available transcripts and pick the first one
|
# Pick the first transcript (video's primary language) rather than
|
||||||
# (the video's primary language) instead of defaulting to English
|
# defaulting to English.
|
||||||
transcript_list = ytt_api.list(video_id)
|
transcript_list = ytt_api.list(video_id)
|
||||||
transcript = next(iter(transcript_list))
|
transcript = next(iter(transcript_list))
|
||||||
captions = transcript.fetch()
|
captions = transcript.fetch()
|
||||||
|
|
@ -128,10 +126,8 @@ async def _scrape_youtube_video(
|
||||||
logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}")
|
logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}")
|
||||||
transcript_text = f"No captions available for this video. Error: {e!s}"
|
transcript_text = f"No captions available for this video. Error: {e!s}"
|
||||||
|
|
||||||
# Build combined content
|
|
||||||
content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}"
|
content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}"
|
||||||
|
|
||||||
# Truncate if needed
|
|
||||||
content, was_truncated = truncate_content(content, max_length)
|
content, was_truncated = truncate_content(content, max_length)
|
||||||
word_count = len(content.split())
|
word_count = len(content.split())
|
||||||
|
|
||||||
|
|
@ -206,20 +202,16 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
|
||||||
scrape_id = generate_scrape_id(url)
|
scrape_id = generate_scrape_id(url)
|
||||||
domain = extract_domain(url)
|
domain = extract_domain(url)
|
||||||
|
|
||||||
# Validate and normalize URL
|
|
||||||
if not url.startswith(("http://", "https://")):
|
if not url.startswith(("http://", "https://")):
|
||||||
url = f"https://{url}"
|
url = f"https://{url}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Check if this is a YouTube URL and use transcript API instead
|
# YouTube URLs use the transcript API instead of crawling.
|
||||||
video_id = get_youtube_video_id(url)
|
video_id = get_youtube_video_id(url)
|
||||||
if video_id:
|
if video_id:
|
||||||
return await _scrape_youtube_video(url, video_id, max_length)
|
return await _scrape_youtube_video(url, video_id, max_length)
|
||||||
|
|
||||||
# Create webcrawler connector
|
|
||||||
connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key)
|
connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key)
|
||||||
|
|
||||||
# Crawl the URL
|
|
||||||
result, error = await connector.crawl_url(url, formats=["markdown"])
|
result, error = await connector.crawl_url(url, formats=["markdown"])
|
||||||
|
|
||||||
if error:
|
if error:
|
||||||
|
|
@ -244,28 +236,21 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
|
||||||
"error": "No content returned from crawler",
|
"error": "No content returned from crawler",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Extract content and metadata
|
|
||||||
content = result.get("content", "")
|
content = result.get("content", "")
|
||||||
metadata = result.get("metadata", {})
|
metadata = result.get("metadata", {})
|
||||||
|
|
||||||
# Get title from metadata
|
|
||||||
title = metadata.get("title", "")
|
title = metadata.get("title", "")
|
||||||
if not title:
|
if not title:
|
||||||
title = domain or url.split("/")[-1] or "Webpage"
|
title = domain or url.split("/")[-1] or "Webpage"
|
||||||
|
|
||||||
# Get description from metadata
|
|
||||||
description = metadata.get("description", "")
|
description = metadata.get("description", "")
|
||||||
if not description and content:
|
if not description and content:
|
||||||
# Use first paragraph as description
|
|
||||||
first_para = content.split("\n\n")[0] if content else ""
|
first_para = content.split("\n\n")[0] if content else ""
|
||||||
description = (
|
description = (
|
||||||
first_para[:300] + "..." if len(first_para) > 300 else first_para
|
first_para[:300] + "..." if len(first_para) > 300 else first_para
|
||||||
)
|
)
|
||||||
|
|
||||||
# Truncate content if needed
|
|
||||||
content, was_truncated = truncate_content(content, max_length)
|
content, was_truncated = truncate_content(content, max_length)
|
||||||
|
|
||||||
# Calculate word count
|
|
||||||
word_count = len(content.split())
|
word_count = len(content.split())
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -92,15 +92,9 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
# Provider mapping for LiteLLM model string construction.
|
# Re-exported under the historical name ``PROVIDER_MAP``. Source of truth lives
|
||||||
#
|
# in provider_capabilities so the YAML loader can resolve prefixes during
|
||||||
# Single source of truth lives in
|
# app.config init without importing the agent/tools tree.
|
||||||
# :mod:`app.services.provider_capabilities` so the YAML loader (which
|
|
||||||
# runs during ``app.config`` class-body init) can resolve provider
|
|
||||||
# prefixes without dragging the agent / tools tree into module load
|
|
||||||
# order. Re-exported here under the historical ``PROVIDER_MAP`` name
|
|
||||||
# so existing callers (``llm_router_service``, ``image_gen_router_service``,
|
|
||||||
# tests) keep working unchanged.
|
|
||||||
from app.services.provider_capabilities import ( # noqa: E402
|
from app.services.provider_capabilities import ( # noqa: E402
|
||||||
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
|
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
|
||||||
)
|
)
|
||||||
|
|
@ -157,25 +151,14 @@ class AgentConfig:
|
||||||
anonymous_enabled: bool = False
|
anonymous_enabled: bool = False
|
||||||
quota_reserve_tokens: int | None = None
|
quota_reserve_tokens: int | None = None
|
||||||
|
|
||||||
# Capability flag: best-effort True for the chat selector / catalog.
|
# Default-allow: only the streaming safety net (is_known_text_only_chat_model)
|
||||||
# Resolved via :func:`provider_capabilities.derive_supports_image_input`
|
# actually blocks on False, so defaulting False would silently hide
|
||||||
# which prefers OpenRouter's ``architecture.input_modalities`` and
|
# vision-capable models. Resolved via derive_supports_image_input.
|
||||||
# otherwise consults LiteLLM's authoritative model map. Default True
|
|
||||||
# is the conservative-allow stance — the streaming-task safety net
|
|
||||||
# (``is_known_text_only_chat_model``) is the *only* place a False
|
|
||||||
# actually blocks a request. Setting this to False here without an
|
|
||||||
# authoritative source would silently hide vision-capable models
|
|
||||||
# (the regression we're fixing).
|
|
||||||
supports_image_input: bool = True
|
supports_image_input: bool = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_auto_mode(cls) -> "AgentConfig":
|
def from_auto_mode(cls) -> "AgentConfig":
|
||||||
"""
|
"""Build an AgentConfig for Auto mode (LiteLLM Router load balancing)."""
|
||||||
Create an AgentConfig for Auto mode (LiteLLM Router load balancing).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AgentConfig instance configured for Auto mode
|
|
||||||
"""
|
|
||||||
return cls(
|
return cls(
|
||||||
provider="AUTO",
|
provider="AUTO",
|
||||||
model_name="auto",
|
model_name="auto",
|
||||||
|
|
@ -193,27 +176,15 @@ class AgentConfig:
|
||||||
is_premium=False,
|
is_premium=False,
|
||||||
anonymous_enabled=False,
|
anonymous_enabled=False,
|
||||||
quota_reserve_tokens=None,
|
quota_reserve_tokens=None,
|
||||||
# Auto routes across the configured pool, which usually
|
# Auto fails over across the pool, so a non-vision deployment's 404
|
||||||
# contains at least one vision-capable deployment; the router
|
# is just an allowed_fails event rather than a hard block.
|
||||||
# will surface a 404 from a non-vision deployment as a normal
|
|
||||||
# ``allowed_fails`` event and fail over rather than blocking
|
|
||||||
# the request outright.
|
|
||||||
supports_image_input=True,
|
supports_image_input=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_new_llm_config(cls, config) -> "AgentConfig":
|
def from_new_llm_config(cls, config) -> "AgentConfig":
|
||||||
"""
|
"""Build an AgentConfig from a NewLLMConfig database model."""
|
||||||
Create an AgentConfig from a NewLLMConfig database model.
|
# Lazy import: keeps provider_capabilities (and litellm) out of init order.
|
||||||
|
|
||||||
Args:
|
|
||||||
config: NewLLMConfig database model instance
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AgentConfig instance
|
|
||||||
"""
|
|
||||||
# Lazy import to avoid pulling provider_capabilities (and its
|
|
||||||
# transitive litellm import) into module-init order.
|
|
||||||
from app.services.provider_capabilities import derive_supports_image_input
|
from app.services.provider_capabilities import derive_supports_image_input
|
||||||
|
|
||||||
provider_value = (
|
provider_value = (
|
||||||
|
|
@ -245,10 +216,8 @@ class AgentConfig:
|
||||||
is_premium=False,
|
is_premium=False,
|
||||||
anonymous_enabled=False,
|
anonymous_enabled=False,
|
||||||
quota_reserve_tokens=None,
|
quota_reserve_tokens=None,
|
||||||
# BYOK rows have no operator-curated capability flag, so we
|
# BYOK rows have no curated flag; ask LiteLLM (default-allow on
|
||||||
# ask LiteLLM (default-allow on unknown). The streaming
|
# unknown). The streaming safety net still blocks explicit text-only.
|
||||||
# safety net still blocks if the model is *explicitly*
|
|
||||||
# marked text-only.
|
|
||||||
supports_image_input=derive_supports_image_input(
|
supports_image_input=derive_supports_image_input(
|
||||||
provider=provider_value,
|
provider=provider_value,
|
||||||
model_name=config.model_name,
|
model_name=config.model_name,
|
||||||
|
|
@ -259,25 +228,14 @@ class AgentConfig:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig":
|
def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig":
|
||||||
|
"""Build an AgentConfig from a YAML configuration dictionary.
|
||||||
|
|
||||||
|
Supports the same prompt fields as NewLLMConfig (system_instructions,
|
||||||
|
use_default_system_instructions, citations_enabled).
|
||||||
"""
|
"""
|
||||||
Create an AgentConfig from a YAML configuration dictionary.
|
# Lazy import: keeps provider_capabilities (and litellm) out of init order.
|
||||||
|
|
||||||
YAML configs now support the same prompt configuration fields as NewLLMConfig:
|
|
||||||
- system_instructions: Custom system instructions (empty string uses defaults)
|
|
||||||
- use_default_system_instructions: Whether to use default instructions
|
|
||||||
- citations_enabled: Whether citations are enabled
|
|
||||||
|
|
||||||
Args:
|
|
||||||
yaml_config: Configuration dictionary from YAML file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AgentConfig instance
|
|
||||||
"""
|
|
||||||
# Lazy import to avoid pulling provider_capabilities (and its
|
|
||||||
# transitive litellm import) into module-init order.
|
|
||||||
from app.services.provider_capabilities import derive_supports_image_input
|
from app.services.provider_capabilities import derive_supports_image_input
|
||||||
|
|
||||||
# Get system instructions from YAML, default to empty string
|
|
||||||
system_instructions = yaml_config.get("system_instructions", "")
|
system_instructions = yaml_config.get("system_instructions", "")
|
||||||
|
|
||||||
provider = yaml_config.get("provider", "").upper()
|
provider = yaml_config.get("provider", "").upper()
|
||||||
|
|
@ -290,13 +248,8 @@ class AgentConfig:
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Explicit YAML override wins; otherwise derive from LiteLLM /
|
# Explicit YAML override wins; otherwise re-derive (the hot-reload file
|
||||||
# OpenRouter modalities. The YAML loader already populates this
|
# fallback reaches this method without the loader having populated it).
|
||||||
# field, but this method is also called from
|
|
||||||
# ``load_global_llm_config_by_id``'s file fallback (hot reload),
|
|
||||||
# so we re-derive here for safety. The bool() coercion preserves
|
|
||||||
# the loader's behaviour for explicit ``true`` / ``false``
|
|
||||||
# strings that PyYAML may surface.
|
|
||||||
if "supports_image_input" in yaml_config:
|
if "supports_image_input" in yaml_config:
|
||||||
supports_image_input = bool(yaml_config.get("supports_image_input"))
|
supports_image_input = bool(yaml_config.get("supports_image_input"))
|
||||||
else:
|
else:
|
||||||
|
|
@ -314,7 +267,6 @@ class AgentConfig:
|
||||||
api_base=yaml_config.get("api_base"),
|
api_base=yaml_config.get("api_base"),
|
||||||
custom_provider=custom_provider,
|
custom_provider=custom_provider,
|
||||||
litellm_params=yaml_config.get("litellm_params"),
|
litellm_params=yaml_config.get("litellm_params"),
|
||||||
# Prompt configuration from YAML (with defaults for backwards compatibility)
|
|
||||||
system_instructions=system_instructions if system_instructions else None,
|
system_instructions=system_instructions if system_instructions else None,
|
||||||
use_default_system_instructions=yaml_config.get(
|
use_default_system_instructions=yaml_config.get(
|
||||||
"use_default_system_instructions", True
|
"use_default_system_instructions", True
|
||||||
|
|
@ -332,20 +284,10 @@ class AgentConfig:
|
||||||
|
|
||||||
|
|
||||||
def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None:
|
def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None:
|
||||||
"""
|
"""Load a specific LLM config from global_llm_config.yaml."""
|
||||||
Load a specific LLM config from global_llm_config.yaml.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
llm_config_id: The id of the config to load (default: -1)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
LLM config dict or None if not found
|
|
||||||
"""
|
|
||||||
# Get the config file path
|
|
||||||
base_dir = Path(__file__).resolve().parent.parent.parent.parent
|
base_dir = Path(__file__).resolve().parent.parent.parent.parent
|
||||||
config_file = base_dir / "app" / "config" / "global_llm_config.yaml"
|
config_file = base_dir / "app" / "config" / "global_llm_config.yaml"
|
||||||
|
|
||||||
# Fallback to example file if main config doesn't exist
|
|
||||||
if not config_file.exists():
|
if not config_file.exists():
|
||||||
config_file = base_dir / "app" / "config" / "global_llm_config.example.yaml"
|
config_file = base_dir / "app" / "config" / "global_llm_config.example.yaml"
|
||||||
if not config_file.exists():
|
if not config_file.exists():
|
||||||
|
|
@ -368,24 +310,17 @@ def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None:
|
||||||
|
|
||||||
|
|
||||||
def load_global_llm_config_by_id(llm_config_id: int) -> dict | None:
|
def load_global_llm_config_by_id(llm_config_id: int) -> dict | None:
|
||||||
"""
|
"""Load a global LLM config by ID, checking in-memory configs first.
|
||||||
Load a global LLM config by ID, checking in-memory configs first.
|
|
||||||
|
|
||||||
This handles both static YAML configs and dynamically injected configs
|
In-memory covers both static YAML and dynamically injected configs (e.g.
|
||||||
(e.g. OpenRouter integration models that only exist in memory).
|
OpenRouter integration models that only exist in memory).
|
||||||
|
|
||||||
Args:
|
|
||||||
llm_config_id: The negative ID of the global config to load
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
LLM config dict or None if not found
|
|
||||||
"""
|
"""
|
||||||
from app.config import config as app_config
|
from app.config import config as app_config
|
||||||
|
|
||||||
for cfg in app_config.GLOBAL_LLM_CONFIGS:
|
for cfg in app_config.GLOBAL_LLM_CONFIGS:
|
||||||
if cfg.get("id") == llm_config_id:
|
if cfg.get("id") == llm_config_id:
|
||||||
return cfg
|
return cfg
|
||||||
# Fallback to YAML file read (covers edge cases like hot-reload)
|
# Fallback to YAML file read (covers hot-reload edge cases).
|
||||||
return load_llm_config_from_yaml(llm_config_id)
|
return load_llm_config_from_yaml(llm_config_id)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -393,17 +328,7 @@ async def load_new_llm_config_from_db(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
config_id: int,
|
config_id: int,
|
||||||
) -> "AgentConfig | None":
|
) -> "AgentConfig | None":
|
||||||
"""
|
"""Load a NewLLMConfig from the database by ID."""
|
||||||
Load a NewLLMConfig from the database by ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: AsyncSession for database access
|
|
||||||
config_id: The ID of the NewLLMConfig to load
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AgentConfig instance or None if not found
|
|
||||||
"""
|
|
||||||
# Import here to avoid circular imports
|
|
||||||
from app.db import NewLLMConfig
|
from app.db import NewLLMConfig
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -426,26 +351,13 @@ async def load_agent_llm_config_for_search_space(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
) -> "AgentConfig | None":
|
) -> "AgentConfig | None":
|
||||||
|
"""Load the agent LLM config for a search space via its agent_llm_id.
|
||||||
|
|
||||||
|
Positive id -> DB; negative -> YAML; None -> first global config (-1).
|
||||||
"""
|
"""
|
||||||
Load the agent LLM configuration for a search space.
|
|
||||||
|
|
||||||
This loads the LLM config based on the search space's agent_llm_id setting:
|
|
||||||
- Positive ID: Load from NewLLMConfig database table
|
|
||||||
- Negative ID: Load from YAML global configs
|
|
||||||
- None: Falls back to first global config (id=-1)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: AsyncSession for database access
|
|
||||||
search_space_id: The search space ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AgentConfig instance or None if not found
|
|
||||||
"""
|
|
||||||
# Import here to avoid circular imports
|
|
||||||
from app.db import SearchSpace
|
from app.db import SearchSpace
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get the search space to check its agent_llm_id preference
|
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||||
)
|
)
|
||||||
|
|
@ -455,12 +367,9 @@ async def load_agent_llm_config_for_search_space(
|
||||||
print(f"Error: SearchSpace with id {search_space_id} not found")
|
print(f"Error: SearchSpace with id {search_space_id} not found")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Use agent_llm_id from search space, fallback to -1 (first global config)
|
|
||||||
config_id = (
|
config_id = (
|
||||||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load the config using the unified loader
|
|
||||||
return await load_agent_config(session, config_id, search_space_id)
|
return await load_agent_config(session, config_id, search_space_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error loading agent LLM config for search space {search_space_id}: {e}")
|
print(f"Error loading agent LLM config for search space {search_space_id}: {e}")
|
||||||
|
|
@ -472,23 +381,7 @@ async def load_agent_config(
|
||||||
config_id: int,
|
config_id: int,
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
) -> "AgentConfig | None":
|
) -> "AgentConfig | None":
|
||||||
"""
|
"""Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB."""
|
||||||
Load an agent configuration, supporting Auto mode, YAML, and database configs.
|
|
||||||
|
|
||||||
This is the main entry point for loading configurations:
|
|
||||||
- ID 0: Auto mode (uses LiteLLM Router for load balancing)
|
|
||||||
- Negative IDs: Load from YAML file (global configs)
|
|
||||||
- Positive IDs: Load from NewLLMConfig database table
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: AsyncSession for database access
|
|
||||||
config_id: The config ID (0 for Auto, negative for YAML, positive for database)
|
|
||||||
search_space_id: Optional search space ID for context
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AgentConfig instance or None if not found
|
|
||||||
"""
|
|
||||||
# Auto mode (ID 0) - use LiteLLM Router
|
|
||||||
if is_auto_mode(config_id):
|
if is_auto_mode(config_id):
|
||||||
if not LLMRouterService.is_initialized():
|
if not LLMRouterService.is_initialized():
|
||||||
print("Error: Auto mode requested but LLM Router not initialized")
|
print("Error: Auto mode requested but LLM Router not initialized")
|
||||||
|
|
@ -496,33 +389,22 @@ async def load_agent_config(
|
||||||
return AgentConfig.from_auto_mode()
|
return AgentConfig.from_auto_mode()
|
||||||
|
|
||||||
if config_id < 0:
|
if config_id < 0:
|
||||||
# Check in-memory configs first (includes static YAML + dynamic OpenRouter)
|
# In-memory covers static YAML + dynamic OpenRouter configs.
|
||||||
from app.config import config as app_config
|
from app.config import config as app_config
|
||||||
|
|
||||||
for cfg in app_config.GLOBAL_LLM_CONFIGS:
|
for cfg in app_config.GLOBAL_LLM_CONFIGS:
|
||||||
if cfg.get("id") == config_id:
|
if cfg.get("id") == config_id:
|
||||||
return AgentConfig.from_yaml_config(cfg)
|
return AgentConfig.from_yaml_config(cfg)
|
||||||
# Fallback to YAML file read for safety
|
|
||||||
yaml_config = load_llm_config_from_yaml(config_id)
|
yaml_config = load_llm_config_from_yaml(config_id)
|
||||||
if yaml_config:
|
if yaml_config:
|
||||||
return AgentConfig.from_yaml_config(yaml_config)
|
return AgentConfig.from_yaml_config(yaml_config)
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
# Load from database (NewLLMConfig)
|
|
||||||
return await load_new_llm_config_from_db(session, config_id)
|
return await load_new_llm_config_from_db(session, config_id)
|
||||||
|
|
||||||
|
|
||||||
def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
||||||
"""
|
"""Create a ChatLiteLLM instance from a global LLM config dictionary."""
|
||||||
Create a ChatLiteLLM instance from a global LLM config dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
llm_config: LLM configuration dictionary from YAML
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ChatLiteLLM instance or None on error
|
|
||||||
"""
|
|
||||||
# Build the model string
|
|
||||||
if llm_config.get("custom_provider"):
|
if llm_config.get("custom_provider"):
|
||||||
model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}"
|
model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}"
|
||||||
else:
|
else:
|
||||||
|
|
@ -530,27 +412,20 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
||||||
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
|
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
|
||||||
model_string = f"{provider_prefix}/{llm_config['model_name']}"
|
model_string = f"{provider_prefix}/{llm_config['model_name']}"
|
||||||
|
|
||||||
# Create ChatLiteLLM instance with streaming enabled
|
|
||||||
litellm_kwargs = {
|
litellm_kwargs = {
|
||||||
"model": model_string,
|
"model": model_string,
|
||||||
"api_key": llm_config.get("api_key"),
|
"api_key": llm_config.get("api_key"),
|
||||||
"streaming": True, # Enable streaming for real-time token streaming
|
"streaming": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add optional parameters
|
|
||||||
if llm_config.get("api_base"):
|
if llm_config.get("api_base"):
|
||||||
litellm_kwargs["api_base"] = llm_config["api_base"]
|
litellm_kwargs["api_base"] = llm_config["api_base"]
|
||||||
|
|
||||||
# Add any additional litellm parameters
|
|
||||||
if llm_config.get("litellm_params"):
|
if llm_config.get("litellm_params"):
|
||||||
litellm_kwargs.update(llm_config["litellm_params"])
|
litellm_kwargs.update(llm_config["litellm_params"])
|
||||||
|
|
||||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
_attach_model_profile(llm, model_string)
|
_attach_model_profile(llm, model_string)
|
||||||
# Configure LiteLLM-native prompt caching (cache_control_injection_points
|
# agent_config=None: the YAML path lacks structured provider intent, so set
|
||||||
# for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.).
|
# only the universal cache_control_injection_points.
|
||||||
# ``agent_config=None`` here — the YAML path doesn't have provider intent
|
|
||||||
# in a structured form, so we set only the universal injection points.
|
|
||||||
apply_litellm_prompt_caching(llm)
|
apply_litellm_prompt_caching(llm)
|
||||||
return llm
|
return llm
|
||||||
|
|
||||||
|
|
@ -558,19 +433,7 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
||||||
def create_chat_litellm_from_agent_config(
|
def create_chat_litellm_from_agent_config(
|
||||||
agent_config: AgentConfig,
|
agent_config: AgentConfig,
|
||||||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||||
"""
|
"""Create a ChatLiteLLM (or, for Auto mode, a load-balancing router) from config."""
|
||||||
Create a ChatLiteLLM or ChatLiteLLMRouter instance from an AgentConfig.
|
|
||||||
|
|
||||||
For Auto mode configs, returns a ChatLiteLLMRouter that uses LiteLLM Router
|
|
||||||
for automatic load balancing across available providers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_config: AgentConfig instance
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ChatLiteLLM or ChatLiteLLMRouter instance, or None on error
|
|
||||||
"""
|
|
||||||
# Handle Auto mode - return ChatLiteLLMRouter
|
|
||||||
if agent_config.is_auto_mode:
|
if agent_config.is_auto_mode:
|
||||||
if not LLMRouterService.is_initialized():
|
if not LLMRouterService.is_initialized():
|
||||||
print("Error: Auto mode requested but LLM Router not initialized")
|
print("Error: Auto mode requested but LLM Router not initialized")
|
||||||
|
|
@ -578,19 +441,14 @@ def create_chat_litellm_from_agent_config(
|
||||||
try:
|
try:
|
||||||
router_llm = get_auto_mode_llm()
|
router_llm = get_auto_mode_llm()
|
||||||
if router_llm is not None:
|
if router_llm is not None:
|
||||||
# Universal cache_control_injection_points only — auto-mode
|
# Universal injection points only: auto-mode fans out across
|
||||||
# fans out across providers, so OpenAI-only kwargs (e.g.
|
# providers, so provider-specific kwargs have no known target.
|
||||||
# ``prompt_cache_key``) are left off here. ``drop_params``
|
|
||||||
# would strip them at the provider boundary anyway, but
|
|
||||||
# there's no point setting them when we don't know the
|
|
||||||
# destination.
|
|
||||||
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
|
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
|
||||||
return router_llm
|
return router_llm
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error creating ChatLiteLLMRouter: {e}")
|
print(f"Error creating ChatLiteLLMRouter: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Build the model string
|
|
||||||
if agent_config.custom_provider:
|
if agent_config.custom_provider:
|
||||||
model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
|
model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
|
||||||
else:
|
else:
|
||||||
|
|
@ -599,26 +457,19 @@ def create_chat_litellm_from_agent_config(
|
||||||
)
|
)
|
||||||
model_string = f"{provider_prefix}/{agent_config.model_name}"
|
model_string = f"{provider_prefix}/{agent_config.model_name}"
|
||||||
|
|
||||||
# Create ChatLiteLLM instance with streaming enabled
|
|
||||||
litellm_kwargs = {
|
litellm_kwargs = {
|
||||||
"model": model_string,
|
"model": model_string,
|
||||||
"api_key": agent_config.api_key,
|
"api_key": agent_config.api_key,
|
||||||
"streaming": True, # Enable streaming for real-time token streaming
|
"streaming": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add optional parameters
|
|
||||||
if agent_config.api_base:
|
if agent_config.api_base:
|
||||||
litellm_kwargs["api_base"] = agent_config.api_base
|
litellm_kwargs["api_base"] = agent_config.api_base
|
||||||
|
|
||||||
# Add any additional litellm parameters
|
|
||||||
if agent_config.litellm_params:
|
if agent_config.litellm_params:
|
||||||
litellm_kwargs.update(agent_config.litellm_params)
|
litellm_kwargs.update(agent_config.litellm_params)
|
||||||
|
|
||||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
_attach_model_profile(llm, model_string)
|
_attach_model_profile(llm, model_string)
|
||||||
# Build-time prompt caching: sets ``cache_control_injection_points`` for
|
# Build-time caching only; the per-thread prompt_cache_key is layered on
|
||||||
# all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``.
|
# later in create_surfsense_deep_agent once thread_id is known.
|
||||||
# Per-thread ``prompt_cache_key`` is layered on later in
|
|
||||||
# ``create_surfsense_deep_agent`` once ``thread_id`` is known.
|
|
||||||
apply_litellm_prompt_caching(llm, agent_config=agent_config)
|
apply_litellm_prompt_caching(llm, agent_config=agent_config)
|
||||||
return llm
|
return llm
|
||||||
|
|
|
||||||
|
|
@ -1,63 +1,28 @@
|
||||||
r"""LiteLLM-native prompt caching configuration for SurfSense agents.
|
r"""LiteLLM-native prompt caching for SurfSense agents.
|
||||||
|
|
||||||
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
|
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (its
|
||||||
activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)``
|
``isinstance(model, ChatAnthropic)`` gate never matched our LiteLLM stack)
|
||||||
gate always failed) with LiteLLM's universal caching mechanism.
|
with LiteLLM's universal ``cache_control_injection_points`` mechanism, which
|
||||||
|
covers the Anthropic/Bedrock/Vertex/Gemini/OpenRouter/etc. marker-based
|
||||||
|
providers and the auto-caching OpenAI family.
|
||||||
|
|
||||||
Coverage:
|
Two breakpoints per request:
|
||||||
|
|
||||||
- Marker-based providers (need ``cache_control`` injection, which LiteLLM
|
- ``index: 0`` pins the head-of-request system prompt. We use ``index: 0``,
|
||||||
performs automatically when ``cache_control_injection_points`` is set):
|
NOT ``role: system``: ``before_agent`` injectors accumulate many
|
||||||
``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``,
|
SystemMessages, and tagging all of them overflows Anthropic's 4-block cap
|
||||||
``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/``
|
(upstream 400 via OpenRouter).
|
||||||
(Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM).
|
- ``index: -1`` pins the latest message so longest-prefix lookup compounds
|
||||||
- Auto-cached (LiteLLM strips the marker silently): ``openai/``,
|
multi-turn savings.
|
||||||
``deepseek/``, ``xai/`` — these caches automatically for prompts ≥1024
|
|
||||||
tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``.
|
|
||||||
|
|
||||||
We inject **two** breakpoints per request:
|
OpenAI-family configs also get ``prompt_cache_key`` (per-thread routing hint)
|
||||||
|
and ``prompt_cache_retention="24h"``. Azure is excluded from the latter
|
||||||
|
because LiteLLM's Azure transformer drops it (see
|
||||||
|
``_PROMPT_CACHE_RETENTION_PROVIDERS``).
|
||||||
|
|
||||||
- ``index: 0`` — pins the SurfSense system prompt at the head of the
|
Safety net: ``litellm.drop_params=True`` (set in ``app.services.llm_service``)
|
||||||
request (provider variant, citation rules, tool catalog, KB tree,
|
strips any kwarg the destination provider rejects, so an auto-mode fallback
|
||||||
skills metadata). The langchain agent factory always prepends
|
can't 400 on these extras.
|
||||||
``request.system_message`` at index 0 (see ``factory.py``
|
|
||||||
``_execute_model_async``), so this targets exactly the main system
|
|
||||||
prompt regardless of how many other ``SystemMessage``\ s the
|
|
||||||
``before_agent`` injectors (priority, tree, memory, file-intent,
|
|
||||||
anonymous-doc) have inserted into ``state["messages"]``. Using
|
|
||||||
``role: system`` here would apply ``cache_control`` to **every**
|
|
||||||
system-role message and trip Anthropic's hard cap of 4 cache
|
|
||||||
breakpoints per request once the conversation accumulates enough
|
|
||||||
injected system messages — which surfaces as the upstream 400
|
|
||||||
``A maximum of 4 blocks with cache_control may be provided. Found N``
|
|
||||||
via OpenRouter→Anthropic.
|
|
||||||
- ``index: -1`` — pins the latest message so multi-turn savings compound:
|
|
||||||
Anthropic-family providers use longest-matching-prefix lookup, so turn
|
|
||||||
N+1 still reads turn N's cache up to the shared prefix.
|
|
||||||
|
|
||||||
For OpenAI-family configs we additionally pass:
|
|
||||||
|
|
||||||
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
|
|
||||||
raises hit rate by sending requests with a shared prefix to the same
|
|
||||||
backend. Supported by ``openai/``, ``deepseek/``, ``xai/``, and
|
|
||||||
``azure/`` (added to LiteLLM's Azure transformer in
|
|
||||||
https://github.com/BerriAI/litellm/pull/20989, Feb 2026; verified
|
|
||||||
against ``AzureOpenAIConfig.get_supported_openai_params`` in our
|
|
||||||
installed litellm 1.83.14 for ``azure/gpt-4o``, ``azure/gpt-4o-mini``,
|
|
||||||
``azure/gpt-5.4``, ``azure/gpt-5.4-mini``).
|
|
||||||
- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default
|
|
||||||
5-10 min in-memory cache. Set ONLY for OpenAI/DeepSeek/xAI: Azure's
|
|
||||||
server-side support landed in Microsoft's docs on 2026-05-13 but
|
|
||||||
LiteLLM 1.83.14's Azure transformer still omits it from its supported
|
|
||||||
params list, so it gets silently dropped by ``litellm.drop_params``.
|
|
||||||
Azure's default in-memory retention (5-10 min, max 1 h) already
|
|
||||||
bridges intra-conversation turns; revisit when LiteLLM bumps Azure.
|
|
||||||
|
|
||||||
Safety net: ``litellm.drop_params=True`` is set globally in
|
|
||||||
``app.services.llm_service`` at module-load time. Any kwarg the destination
|
|
||||||
provider doesn't recognise is auto-stripped at the provider transformer
|
|
||||||
layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on
|
|
||||||
``prompt_cache_key`` etc.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -73,57 +38,29 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Two-breakpoint policy: head-of-request + latest message. See module
|
# Head-of-request + latest message (see module docstring for the index:0 vs
|
||||||
# docstring for rationale. Anthropic caps requests at 4 ``cache_control``
|
# role:system rationale and Anthropic's 4-block cap).
|
||||||
# blocks; we use 2 here, leaving headroom for Phase-2 tool caching.
|
|
||||||
#
|
|
||||||
# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's
|
|
||||||
# ``before_agent`` middlewares (priority, tree, memory, anonymous-doc)
|
|
||||||
# insert ``SystemMessage`` instances into ``state["messages"]`` that
|
|
||||||
# accumulate across turns. With ``role: system`` the LiteLLM hook would
|
|
||||||
# tag *every* one of them with ``cache_control`` and overflow Anthropic's
|
|
||||||
# 4-block limit. ``index: 0`` always targets the langchain-prepended
|
|
||||||
# ``request.system_message``, giving us exactly one stable cache breakpoint.
|
|
||||||
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
|
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
|
||||||
{"location": "message", "index": 0},
|
{"location": "message", "index": 0},
|
||||||
{"location": "message", "index": -1},
|
{"location": "message", "index": -1},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Providers (uppercase ``AgentConfig.provider`` values) that accept the
|
# Providers that accept the OpenAI ``prompt_cache_key`` routing hint. Strict
|
||||||
# OpenAI ``prompt_cache_key`` routing hint. Microsoft's Azure OpenAI docs
|
# whitelist: many providers route through litellm's ``openai`` prefix without
|
||||||
# (2026-05-13) confirm automatic prompt caching applies to every GPT-4o
|
# the prompt-cache surface, so the prefix alone isn't enough to infer family.
|
||||||
# or newer Azure deployment at ≥1024 tokens with no configuration needed,
|
|
||||||
# and that ``prompt_cache_key`` is combined with the prefix hash to
|
|
||||||
# improve routing affinity and therefore cache hit rate. LiteLLM's Azure
|
|
||||||
# transformer ships ``prompt_cache_key`` in its supported params as of
|
|
||||||
# https://github.com/BerriAI/litellm/pull/20989.
|
|
||||||
#
|
|
||||||
# Strict whitelist — many other providers in ``PROVIDER_MAP`` route
|
|
||||||
# through litellm's ``openai`` prefix without implementing the OpenAI
|
|
||||||
# prompt-cache surface (e.g. MOONSHOT, ZHIPU, MINIMAX), so we can't infer
|
|
||||||
# family from the litellm prefix alone.
|
|
||||||
_PROMPT_CACHE_KEY_PROVIDERS: frozenset[str] = frozenset(
|
_PROMPT_CACHE_KEY_PROVIDERS: frozenset[str] = frozenset(
|
||||||
{"OPENAI", "DEEPSEEK", "XAI", "AZURE", "AZURE_OPENAI"}
|
{"OPENAI", "DEEPSEEK", "XAI", "AZURE", "AZURE_OPENAI"}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Subset of ``_PROMPT_CACHE_KEY_PROVIDERS`` that also accept
|
# Subset that also accepts ``prompt_cache_retention="24h"``. Azure is excluded
|
||||||
# ``prompt_cache_retention="24h"``. Azure is excluded: see module
|
# because LiteLLM's Azure transformer omits the param (drop_params strips it).
|
||||||
# docstring — LiteLLM 1.83.14's Azure transformer omits the param so
|
|
||||||
# ``drop_params`` silently strips it. Re-add Azure once a future LiteLLM
|
|
||||||
# release wires it into ``AzureOpenAIConfig.get_supported_openai_params``.
|
|
||||||
_PROMPT_CACHE_RETENTION_PROVIDERS: frozenset[str] = frozenset(
|
_PROMPT_CACHE_RETENTION_PROVIDERS: frozenset[str] = frozenset(
|
||||||
{"OPENAI", "DEEPSEEK", "XAI"}
|
{"OPENAI", "DEEPSEEK", "XAI"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _is_router_llm(llm: BaseChatModel) -> bool:
|
def _is_router_llm(llm: BaseChatModel) -> bool:
|
||||||
"""Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import.
|
"""Detect ``ChatLiteLLMRouter`` by class name to avoid an import cycle."""
|
||||||
|
|
||||||
Importing ``app.services.llm_router_service`` at module-load time would
|
|
||||||
create a cycle via ``llm_config -> prompt_caching -> llm_router_service``.
|
|
||||||
Class-name comparison is sufficient since the class is defined in a
|
|
||||||
single place.
|
|
||||||
"""
|
|
||||||
return type(llm).__name__ == "ChatLiteLLMRouter"
|
return type(llm).__name__ == "ChatLiteLLMRouter"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -188,21 +125,10 @@ def apply_litellm_prompt_caching(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter.
|
"""Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter.
|
||||||
|
|
||||||
Idempotent — values already present in ``llm.model_kwargs`` (e.g. from
|
Idempotent (existing ``model_kwargs`` values are preserved) and mutates
|
||||||
``agent_config.litellm_params`` overrides) are preserved. Mutates
|
``llm.model_kwargs`` in place. Without ``agent_config`` (or in auto-mode)
|
||||||
``llm.model_kwargs`` in place; the kwargs flow to ``litellm.completion``
|
only the universal injection points are set; ``thread_id`` adds a per-thread
|
||||||
via ``ChatLiteLLM._default_params`` and via ``self.model_kwargs`` merge
|
``prompt_cache_key`` for OpenAI-family providers to improve routing affinity.
|
||||||
in our custom ``ChatLiteLLMRouter``.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance.
|
|
||||||
agent_config: Optional ``AgentConfig`` driving provider-specific
|
|
||||||
behaviour. When omitted (or auto-mode), only the universal
|
|
||||||
``cache_control_injection_points`` are set.
|
|
||||||
thread_id: Optional thread id used to construct a per-thread
|
|
||||||
``prompt_cache_key`` for OpenAI-family providers. Caching still
|
|
||||||
works without it (server-side automatic), but the key improves
|
|
||||||
backend routing affinity and therefore hit rate.
|
|
||||||
"""
|
"""
|
||||||
model_kwargs = _get_or_init_model_kwargs(llm)
|
model_kwargs = _get_or_init_model_kwargs(llm)
|
||||||
if model_kwargs is None:
|
if model_kwargs is None:
|
||||||
|
|
@ -217,11 +143,8 @@ def apply_litellm_prompt_caching(
|
||||||
dict(point) for point in _DEFAULT_INJECTION_POINTS
|
dict(point) for point in _DEFAULT_INJECTION_POINTS
|
||||||
]
|
]
|
||||||
|
|
||||||
# OpenAI-style extras only when we statically know the destination
|
# OpenAI-style extras only when the destination is statically known. The
|
||||||
# accepts them. Auto-mode router fans out across mixed providers so
|
# auto-mode router fans out across mixed providers, so skip them there.
|
||||||
# we can't safely set destination-specific kwargs there (drop_params
|
|
||||||
# would strip them but it's wasteful to set them in the first
|
|
||||||
# place).
|
|
||||||
if _is_router_llm(llm):
|
if _is_router_llm(llm):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,26 +1,13 @@
|
||||||
"""
|
"""SurfSense compaction middleware.
|
||||||
SurfSense compaction middleware.
|
|
||||||
|
|
||||||
Subclasses :class:`deepagents.middleware.summarization.SummarizationMiddleware`
|
Extends ``SummarizationMiddleware`` with three SurfSense behaviors:
|
||||||
to add SurfSense-specific behavior:
|
|
||||||
|
|
||||||
1. **Structured summary template** (OpenCode-style ``## Goal / Constraints /
|
1. A structured summary template (:data:`SURFSENSE_SUMMARY_PROMPT`) instead of
|
||||||
Progress / Key Decisions / Next Steps / Critical Context / Relevant Files``)
|
the base freeform prompt.
|
||||||
— see :data:`SURFSENSE_SUMMARY_PROMPT` below. The base
|
2. Protected SystemMessages (injected hints like ``<priority_documents>``) are
|
||||||
``SummarizationMiddleware`` only ships a freeform "summarize this"
|
kept verbatim instead of being summarized away.
|
||||||
prompt; the structured template is ported from OpenCode's
|
3. ``content=None`` is sanitized before ``get_buffer_string`` (some providers
|
||||||
``compaction.ts``.
|
stream tool-only AIMessages with ``None`` content, which would crash it).
|
||||||
2. **Protect SurfSense-specific SystemMessages** so injected hints
|
|
||||||
(``<priority_documents>``, ``<workspace_tree>``, ``<file_operation_contract>``,
|
|
||||||
``<user_memory>``, ``<team_memory>``, ``<user_name>``, ``<memory_warning>``)
|
|
||||||
are *not* summarized away and are kept verbatim in the post-summary
|
|
||||||
message list. Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy
|
|
||||||
(some message types are part of the agent's contract and must survive
|
|
||||||
compaction unchanged).
|
|
||||||
3. **Sanitize ``content=None``** when feeding messages into ``get_buffer_string``
|
|
||||||
(Azure OpenAI / LiteLLM defense — when a provider streams an AIMessage
|
|
||||||
containing only tool_calls and no text, ``content`` can be ``None`` and
|
|
||||||
``get_buffer_string`` crashes iterating over ``None``). SurfSense-specific.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -43,9 +30,7 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Structured summary template ported from OpenCode's
|
# Module-level constant so unit tests can assert on its sections.
|
||||||
# ``opencode/packages/opencode/src/session/compaction.ts:40-75``. Kept as a
|
|
||||||
# module-level constant so unit tests can assert on its sections.
|
|
||||||
SURFSENSE_SUMMARY_PROMPT = """<role>
|
SURFSENSE_SUMMARY_PROMPT = """<role>
|
||||||
SurfSense Conversation Compaction Assistant
|
SurfSense Conversation Compaction Assistant
|
||||||
</role>
|
</role>
|
||||||
|
|
@ -114,13 +99,10 @@ def _is_protected_system_message(msg: AnyMessage) -> bool:
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_message_content(msg: AnyMessage) -> AnyMessage:
|
def _sanitize_message_content(msg: AnyMessage) -> AnyMessage:
|
||||||
"""Return ``msg`` with ``content=None`` coerced to ``""``.
|
"""Return a copy of ``msg`` with ``content=None`` coerced to ``""``.
|
||||||
|
|
||||||
Folds in the historical defense from ``safe_summarization.py`` —
|
``get_buffer_string`` reads ``m.text`` (iterating ``content``), so a
|
||||||
``get_buffer_string`` reads ``m.text`` which iterates ``self.content``,
|
tool-only AIMessage with ``None`` content would crash it.
|
||||||
so a ``None`` content (Azure OpenAI / LiteLLM streaming a tool-only
|
|
||||||
AIMessage) explodes. We return a copy with empty string content so
|
|
||||||
downstream consumers see an empty body without mutating the original.
|
|
||||||
"""
|
"""
|
||||||
if getattr(msg, "content", "not-missing") is not None:
|
if getattr(msg, "content", "not-missing") is not None:
|
||||||
return msg
|
return msg
|
||||||
|
|
@ -159,20 +141,11 @@ class SurfSenseCompactionMiddleware(SummarizationMiddleware):
|
||||||
conversation_messages: list[AnyMessage],
|
conversation_messages: list[AnyMessage],
|
||||||
cutoff_index: int,
|
cutoff_index: int,
|
||||||
) -> tuple[list[AnyMessage], list[AnyMessage]]:
|
) -> tuple[list[AnyMessage], list[AnyMessage]]:
|
||||||
"""Split messages but always preserve SurfSense protected SystemMessages.
|
"""Split messages, always preserving protected SystemMessages.
|
||||||
|
|
||||||
Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy
|
Also opens a ``compaction.run`` OTel span (no-op when OTel is off) here,
|
||||||
(``opencode/packages/opencode/src/session/compaction.ts``): some
|
since partitioning is the first call once summarization is decided.
|
||||||
message types are always kept verbatim because they are part of the
|
|
||||||
agent's working contract, not transient output.
|
|
||||||
|
|
||||||
Also opens a ``compaction.run`` OTel span (no-op when OTel is off)
|
|
||||||
so dashboards can count compaction events and message-volume
|
|
||||||
without having to instrument upstream callers.
|
|
||||||
"""
|
"""
|
||||||
# Opening a span here is appropriate because partitioning is the
|
|
||||||
# first call SummarizationMiddleware makes when it has decided to
|
|
||||||
# summarize; we record the volume and then close as a normal span.
|
|
||||||
with ot.compaction_span(
|
with ot.compaction_span(
|
||||||
reason="auto",
|
reason="auto",
|
||||||
messages_in=len(conversation_messages),
|
messages_in=len(conversation_messages),
|
||||||
|
|
@ -191,20 +164,15 @@ class SurfSenseCompactionMiddleware(SummarizationMiddleware):
|
||||||
else:
|
else:
|
||||||
kept_for_summary.append(msg)
|
kept_for_summary.append(msg)
|
||||||
|
|
||||||
# Place protected blocks at the *front* of preserved_messages so
|
# Protected blocks go at the front of preserved_messages to keep
|
||||||
# they keep their original ordering relative to the summary
|
# ordering relative to the summary HumanMessage.
|
||||||
# HumanMessage that precedes the rest of the preserved tail.
|
|
||||||
return kept_for_summary, [*protected, *preserved_messages]
|
return kept_for_summary, [*protected, *preserved_messages]
|
||||||
|
|
||||||
def _filter_summary_messages( # type: ignore[override]
|
def _filter_summary_messages( # type: ignore[override]
|
||||||
self, messages: list[AnyMessage]
|
self, messages: list[AnyMessage]
|
||||||
) -> list[AnyMessage]:
|
) -> list[AnyMessage]:
|
||||||
"""Filter previous summaries AND sanitize ``content=None``.
|
"""Filter previous summaries and sanitize ``content=None`` (covers the
|
||||||
|
sync and async offload paths)."""
|
||||||
Folds the ``safe_summarization.py`` defense in: when the buffer
|
|
||||||
builder iterates ``m.text`` over ``None`` it explodes; sanitizing
|
|
||||||
here covers both the sync and async offload paths.
|
|
||||||
"""
|
|
||||||
filtered = super()._filter_summary_messages(messages)
|
filtered = super()._filter_summary_messages(messages)
|
||||||
return [_sanitize_message_content(m) for m in filtered]
|
return [_sanitize_message_content(m) for m in filtered]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,14 +24,11 @@ from .utils import get_voice_for_provider
|
||||||
async def create_podcast_transcript(
|
async def create_podcast_transcript(
|
||||||
state: State, config: RunnableConfig
|
state: State, config: RunnableConfig
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Each node does work."""
|
"""Generate the podcast transcript from the source content."""
|
||||||
|
|
||||||
# Get configuration from runnable config
|
|
||||||
configuration = Configuration.from_runnable_config(config)
|
configuration = Configuration.from_runnable_config(config)
|
||||||
search_space_id = configuration.search_space_id
|
search_space_id = configuration.search_space_id
|
||||||
user_prompt = configuration.user_prompt
|
user_prompt = configuration.user_prompt
|
||||||
|
|
||||||
# Get search space's document summary LLM
|
|
||||||
llm = await get_agent_llm(state.db_session, search_space_id)
|
llm = await get_agent_llm(state.db_session, search_space_id)
|
||||||
if not llm:
|
if not llm:
|
||||||
error_message = (
|
error_message = (
|
||||||
|
|
@ -40,22 +37,16 @@ async def create_podcast_transcript(
|
||||||
print(error_message)
|
print(error_message)
|
||||||
raise RuntimeError(error_message)
|
raise RuntimeError(error_message)
|
||||||
|
|
||||||
# Get the prompt
|
|
||||||
prompt = get_podcast_generation_prompt(user_prompt)
|
prompt = get_podcast_generation_prompt(user_prompt)
|
||||||
|
|
||||||
# Create the messages
|
|
||||||
messages = [
|
messages = [
|
||||||
SystemMessage(content=prompt),
|
SystemMessage(content=prompt),
|
||||||
HumanMessage(
|
HumanMessage(
|
||||||
content=f"<source_content>{state.source_content}</source_content>"
|
content=f"<source_content>{state.source_content}</source_content>"
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Generate the podcast transcript
|
|
||||||
llm_response = await llm.ainvoke(messages)
|
llm_response = await llm.ainvoke(messages)
|
||||||
|
|
||||||
# Reasoning models (e.g. Kimi K2.5) may return content as a list of
|
# Reasoning models may return content as blocks; normalise to a string.
|
||||||
# blocks including 'reasoning' entries. Normalise to a plain string.
|
|
||||||
content = strip_markdown_fences(extract_text_content(llm_response.content))
|
content = strip_markdown_fences(extract_text_content(llm_response.content))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -89,17 +80,13 @@ async def create_merged_podcast_audio(
|
||||||
state: State, config: RunnableConfig
|
state: State, config: RunnableConfig
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Generate audio for each transcript and merge them into a single podcast file."""
|
"""Generate audio for each transcript and merge them into a single podcast file."""
|
||||||
|
|
||||||
# configuration = Configuration.from_runnable_config(config)
|
|
||||||
|
|
||||||
starting_transcript = PodcastTranscriptEntry(
|
starting_transcript = PodcastTranscriptEntry(
|
||||||
speaker_id=1, dialog="Welcome to Surfsense Podcast."
|
speaker_id=1, dialog="Welcome to Surfsense Podcast."
|
||||||
)
|
)
|
||||||
|
|
||||||
transcript = state.podcast_transcript
|
transcript = state.podcast_transcript
|
||||||
|
|
||||||
# Merge the starting transcript with the podcast transcript
|
# transcript may be a PodcastTranscripts object or already a list.
|
||||||
# Check if transcript is a PodcastTranscripts object or already a list
|
|
||||||
if hasattr(transcript, "podcast_transcripts"):
|
if hasattr(transcript, "podcast_transcripts"):
|
||||||
transcript_entries = transcript.podcast_transcripts
|
transcript_entries = transcript.podcast_transcripts
|
||||||
else:
|
else:
|
||||||
|
|
@ -107,20 +94,16 @@ async def create_merged_podcast_audio(
|
||||||
|
|
||||||
merged_transcript = [starting_transcript, *transcript_entries]
|
merged_transcript = [starting_transcript, *transcript_entries]
|
||||||
|
|
||||||
# Create a temporary directory for audio files
|
|
||||||
temp_dir = Path("temp_audio")
|
temp_dir = Path("temp_audio")
|
||||||
temp_dir.mkdir(exist_ok=True)
|
temp_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
# Generate a unique session ID for this podcast
|
|
||||||
session_id = str(uuid.uuid4())
|
session_id = str(uuid.uuid4())
|
||||||
output_path = f"podcasts/{session_id}_podcast.mp3"
|
output_path = f"podcasts/{session_id}_podcast.mp3"
|
||||||
os.makedirs("podcasts", exist_ok=True)
|
os.makedirs("podcasts", exist_ok=True)
|
||||||
|
|
||||||
# Generate audio for each transcript segment
|
|
||||||
audio_files = []
|
audio_files = []
|
||||||
|
|
||||||
async def generate_speech_for_segment(segment, index):
|
async def generate_speech_for_segment(segment, index):
|
||||||
# Handle both dictionary and PodcastTranscriptEntry objects
|
|
||||||
if hasattr(segment, "speaker_id"):
|
if hasattr(segment, "speaker_id"):
|
||||||
speaker_id = segment.speaker_id
|
speaker_id = segment.speaker_id
|
||||||
dialog = segment.dialog
|
dialog = segment.dialog
|
||||||
|
|
@ -128,20 +111,15 @@ async def create_merged_podcast_audio(
|
||||||
speaker_id = segment.get("speaker_id", 0)
|
speaker_id = segment.get("speaker_id", 0)
|
||||||
dialog = segment.get("dialog", "")
|
dialog = segment.get("dialog", "")
|
||||||
|
|
||||||
# Select voice based on speaker_id
|
|
||||||
voice = get_voice_for_provider(app_config.TTS_SERVICE, speaker_id)
|
voice = get_voice_for_provider(app_config.TTS_SERVICE, speaker_id)
|
||||||
|
|
||||||
# Generate a unique filename for this segment
|
|
||||||
if app_config.TTS_SERVICE == "local/kokoro":
|
if app_config.TTS_SERVICE == "local/kokoro":
|
||||||
# Kokoro generates WAV files
|
|
||||||
filename = f"{temp_dir}/{session_id}_{index}.wav"
|
filename = f"{temp_dir}/{session_id}_{index}.wav"
|
||||||
else:
|
else:
|
||||||
# Other services generate MP3 files
|
|
||||||
filename = f"{temp_dir}/{session_id}_{index}.mp3"
|
filename = f"{temp_dir}/{session_id}_{index}.mp3"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if app_config.TTS_SERVICE == "local/kokoro":
|
if app_config.TTS_SERVICE == "local/kokoro":
|
||||||
# Use Kokoro TTS service
|
|
||||||
kokoro_service = await get_kokoro_tts_service(
|
kokoro_service = await get_kokoro_tts_service(
|
||||||
lang_code="a"
|
lang_code="a"
|
||||||
) # American English
|
) # American English
|
||||||
|
|
@ -170,7 +148,6 @@ async def create_merged_podcast_audio(
|
||||||
timeout=600,
|
timeout=600,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save the audio to a file - use proper streaming method
|
|
||||||
with open(filename, "wb") as f:
|
with open(filename, "wb") as f:
|
||||||
f.write(response.content)
|
f.write(response.content)
|
||||||
|
|
||||||
|
|
@ -179,23 +156,17 @@ async def create_merged_podcast_audio(
|
||||||
print(f"Error generating speech for segment {index}: {e!s}")
|
print(f"Error generating speech for segment {index}: {e!s}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Generate all audio files concurrently
|
|
||||||
tasks = [
|
tasks = [
|
||||||
generate_speech_for_segment(segment, i)
|
generate_speech_for_segment(segment, i)
|
||||||
for i, segment in enumerate(merged_transcript)
|
for i, segment in enumerate(merged_transcript)
|
||||||
]
|
]
|
||||||
audio_files = await asyncio.gather(*tasks)
|
audio_files = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
# Merge audio files using ffmpeg
|
|
||||||
try:
|
try:
|
||||||
# Create FFmpeg instance with the first input
|
|
||||||
ffmpeg = FFmpeg().option("y")
|
ffmpeg = FFmpeg().option("y")
|
||||||
|
|
||||||
# Add each audio file as input
|
|
||||||
for audio_file in audio_files:
|
for audio_file in audio_files:
|
||||||
ffmpeg = ffmpeg.input(audio_file)
|
ffmpeg = ffmpeg.input(audio_file)
|
||||||
|
|
||||||
# Configure the concatenation and output
|
|
||||||
filter_complex = []
|
filter_complex = []
|
||||||
for i in range(len(audio_files)):
|
for i in range(len(audio_files)):
|
||||||
filter_complex.append(f"[{i}:0]")
|
filter_complex.append(f"[{i}:0]")
|
||||||
|
|
@ -205,8 +176,6 @@ async def create_merged_podcast_audio(
|
||||||
)
|
)
|
||||||
ffmpeg = ffmpeg.option("filter_complex", filter_complex_str)
|
ffmpeg = ffmpeg.option("filter_complex", filter_complex_str)
|
||||||
ffmpeg = ffmpeg.output(output_path, map="[outa]")
|
ffmpeg = ffmpeg.output(output_path, map="[outa]")
|
||||||
|
|
||||||
# Execute FFmpeg
|
|
||||||
await ffmpeg.execute()
|
await ffmpeg.execute()
|
||||||
|
|
||||||
print(f"Successfully created podcast audio: {output_path}")
|
print(f"Successfully created podcast audio: {output_path}")
|
||||||
|
|
@ -215,7 +184,6 @@ async def create_merged_podcast_audio(
|
||||||
print(f"Error merging audio files: {e!s}")
|
print(f"Error merging audio files: {e!s}")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
# Clean up temporary files
|
|
||||||
for audio_file in audio_files:
|
for audio_file in audio_files:
|
||||||
try:
|
try:
|
||||||
os.remove(audio_file)
|
os.remove(audio_file)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue