Merge pull request #1422 from CREDO23/improvement-agent-speed

[Improvement] Agent: faster turns and lower LLM cost
This commit is contained in:
Rohan Verma 2026-05-20 14:57:19 -07:00 committed by GitHub
commit 5c4da79da4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
45 changed files with 1126 additions and 661 deletions

View file

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import time
from typing import Any, cast from typing import Any, cast
from deepagents.backends.protocol import BackendFactory, BackendProtocol from deepagents.backends.protocol import BackendFactory, BackendProtocol
@ -15,8 +16,12 @@ from langchain.agents import create_agent
from langchain.chat_models import init_chat_model from langchain.chat_models import init_chat_model
from langgraph.types import Checkpointer from langgraph.types import Checkpointer
from app.utils.perf import get_perf_logger
from .task_tool import build_task_tool_with_parent_config from .task_tool import build_task_tool_with_parent_config
_perf_log = get_perf_logger()
class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware): class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
"""``SubAgentMiddleware`` variant that compiles each subagent against the parent checkpointer.""" """``SubAgentMiddleware`` variant that compiles each subagent against the parent checkpointer."""
@ -54,8 +59,11 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
def _surf_compile_subagent_graphs(self) -> list[dict[str, Any]]: def _surf_compile_subagent_graphs(self) -> list[dict[str, Any]]:
"""Mirror of ``SubAgentMiddleware._get_subagents`` that threads the parent checkpointer.""" """Mirror of ``SubAgentMiddleware._get_subagents`` that threads the parent checkpointer."""
specs: list[dict[str, Any]] = [] specs: list[dict[str, Any]] = []
loop_start = time.perf_counter()
timings: list[tuple[str, float, str]] = [] # (name, elapsed, source)
for spec in self._subagents: for spec in self._subagents:
spec_start = time.perf_counter()
if "runnable" in spec: if "runnable" in spec:
compiled = cast(CompiledSubAgent, spec) compiled = cast(CompiledSubAgent, spec)
specs.append( specs.append(
@ -65,6 +73,9 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
"runnable": compiled["runnable"], "runnable": compiled["runnable"],
} }
) )
timings.append(
(compiled["name"], time.perf_counter() - spec_start, "precompiled")
)
continue continue
if "model" not in spec: if "model" not in spec:
@ -79,20 +90,44 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
model = init_chat_model(model) model = init_chat_model(model)
middleware: list[Any] = list(spec.get("middleware", [])) middleware: list[Any] = list(spec.get("middleware", []))
tools_count = len(spec.get("tools") or [])
mw_count = len(middleware)
compile_start = time.perf_counter()
runnable = create_agent(
model,
system_prompt=spec["system_prompt"],
tools=spec["tools"],
middleware=middleware,
name=spec["name"],
checkpointer=self._surf_checkpointer,
)
compile_elapsed = time.perf_counter() - compile_start
specs.append( specs.append(
{ {
"name": spec["name"], "name": spec["name"],
"description": spec["description"], "description": spec["description"],
"runnable": create_agent( "runnable": runnable,
model,
system_prompt=spec["system_prompt"],
tools=spec["tools"],
middleware=middleware,
name=spec["name"],
checkpointer=self._surf_checkpointer,
),
} }
) )
timings.append(
(
spec["name"],
compile_elapsed,
f"compiled tools={tools_count} mw={mw_count}",
)
)
total_elapsed = time.perf_counter() - loop_start
per_subagent = ", ".join(
f"{name}={elapsed * 1000:.0f}ms[{source}]"
for name, elapsed, source in timings
)
_perf_log.info(
"[subagent_compile] total=%.3fs count=%d details=[%s]",
total_elapsed,
len(timings),
per_subagent,
)
return specs return specs

View file

@ -9,6 +9,7 @@ re-raises any new pending interrupt back to the parent.
from __future__ import annotations from __future__ import annotations
import logging import logging
import time
from typing import Annotated, Any, NoReturn from typing import Annotated, Any, NoReturn
from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION
@ -19,6 +20,8 @@ from langchain_core.tools import StructuredTool
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt
from langgraph.types import Command, Interrupt from langgraph.types import Command, Interrupt
from app.utils.perf import get_perf_logger
from .config import ( from .config import (
consume_surfsense_resume, consume_surfsense_resume,
drain_parent_null_resume, drain_parent_null_resume,
@ -35,6 +38,7 @@ from .resume import (
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
def _reraise_stamped_subagent_interrupt( def _reraise_stamped_subagent_interrupt(
@ -209,6 +213,7 @@ def build_task_tool_with_parent_config(
], ],
runtime: ToolRuntime, runtime: ToolRuntime,
) -> str | Command: ) -> str | Command:
atask_start = time.perf_counter()
logger.info( logger.info(
"[hitl_route] atask ENTRY: subagent_type=%r tool_call_id=%s", "[hitl_route] atask ENTRY: subagent_type=%r tool_call_id=%s",
subagent_type, subagent_type,
@ -230,8 +235,10 @@ def build_task_tool_with_parent_config(
# Resume bridge — see ``task`` above. # Resume bridge — see ``task`` above.
pending_id: str | None = None pending_id: str | None = None
pending_value: Any = None pending_value: Any = None
aget_state_elapsed = 0.0
aget_state = getattr(subagent, "aget_state", None) aget_state = getattr(subagent, "aget_state", None)
if callable(aget_state): if callable(aget_state):
aget_state_start = time.perf_counter()
try: try:
snapshot = await aget_state(sub_config) snapshot = await aget_state(sub_config)
pending_id, pending_value = get_first_pending_subagent_interrupt( pending_id, pending_value = get_first_pending_subagent_interrupt(
@ -248,32 +255,78 @@ def build_task_tool_with_parent_config(
"Subagent aget_state failed; falling back to fresh ainvoke", "Subagent aget_state failed; falling back to fresh ainvoke",
exc_info=True, exc_info=True,
) )
finally:
aget_state_elapsed = time.perf_counter() - aget_state_start
if pending_value is not None: invoke_path = "resume" if pending_value is not None else "fresh"
resume_value = consume_surfsense_resume(runtime) ainvoke_start = time.perf_counter()
if resume_value is None: ainvoke_outcome = "ok"
raise RuntimeError( try:
f"Subagent {subagent_type!r} has a pending interrupt but no " if pending_value is not None:
"surfsense_resume_value on config; resume bridge is broken." resume_value = consume_surfsense_resume(runtime)
) if resume_value is None:
expected = hitlrequest_action_count(pending_value) raise RuntimeError(
resume_value = fan_out_decisions_to_match(resume_value, expected) f"Subagent {subagent_type!r} has a pending interrupt but no "
# Prevent the parent's resume payload from leaking into subagent "surfsense_resume_value on config; resume bridge is broken."
# interrupts via langgraph's parent_scratchpad fallback. )
drain_parent_null_resume(runtime) expected = hitlrequest_action_count(pending_value)
try: resume_value = fan_out_decisions_to_match(resume_value, expected)
result = await subagent.ainvoke( # Prevent the parent's resume payload from leaking into subagent
build_resume_command(resume_value, pending_id), # interrupts via langgraph's parent_scratchpad fallback.
config=sub_config, drain_parent_null_resume(runtime)
) try:
except GraphInterrupt as gi: result = await subagent.ainvoke(
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id) build_resume_command(resume_value, pending_id),
else: config=sub_config,
try: )
result = await subagent.ainvoke(subagent_state, config=sub_config) except GraphInterrupt as gi:
except GraphInterrupt as gi: ainvoke_outcome = "interrupted"
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id) _perf_log.info(
return _return_command_with_state_update(result, runtime.tool_call_id) "[hitl_route] atask EXIT subagent_type=%r path=%s outcome=%s "
"aget_state=%.3fs ainvoke=%.3fs total=%.3fs",
subagent_type,
invoke_path,
ainvoke_outcome,
aget_state_elapsed,
time.perf_counter() - ainvoke_start,
time.perf_counter() - atask_start,
)
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
else:
try:
result = await subagent.ainvoke(subagent_state, config=sub_config)
except GraphInterrupt as gi:
ainvoke_outcome = "interrupted"
_perf_log.info(
"[hitl_route] atask EXIT subagent_type=%r path=%s outcome=%s "
"aget_state=%.3fs ainvoke=%.3fs total=%.3fs",
subagent_type,
invoke_path,
ainvoke_outcome,
aget_state_elapsed,
time.perf_counter() - ainvoke_start,
time.perf_counter() - atask_start,
)
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
ainvoke_elapsed = time.perf_counter() - ainvoke_start
except GraphInterrupt:
raise
merge_start = time.perf_counter()
cmd = _return_command_with_state_update(result, runtime.tool_call_id)
merge_elapsed = time.perf_counter() - merge_start
_perf_log.info(
"[hitl_route] atask EXIT subagent_type=%r path=%s outcome=%s "
"aget_state=%.3fs ainvoke=%.3fs merge=%.3fs total=%.3fs",
subagent_type,
invoke_path,
ainvoke_outcome,
aget_state_elapsed,
ainvoke_elapsed,
merge_elapsed,
time.perf_counter() - atask_start,
)
return cmd
return StructuredTool.from_function( return StructuredTool.from_function(
name="task", name="task",

View file

@ -6,6 +6,7 @@ from langchain_core.language_models import BaseChatModel
from app.agents.new_chat.filesystem_selection import FilesystemMode from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import KnowledgePriorityMiddleware from app.agents.new_chat.middleware import KnowledgePriorityMiddleware
from app.services.llm_service import get_planner_llm
def build_knowledge_priority_mw( def build_knowledge_priority_mw(
@ -19,6 +20,7 @@ def build_knowledge_priority_mw(
) -> KnowledgePriorityMiddleware: ) -> KnowledgePriorityMiddleware:
return KnowledgePriorityMiddleware( return KnowledgePriorityMiddleware(
llm=llm, llm=llm,
planner_llm=get_planner_llm(),
search_space_id=search_space_id, search_space_id=search_space_id,
filesystem_mode=filesystem_mode, filesystem_mode=filesystem_mode,
available_connectors=available_connectors, available_connectors=available_connectors,

View file

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import time
from typing import Any from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState from langchain.agents.middleware import AgentMiddleware, AgentState
@ -10,6 +11,9 @@ from langgraph.runtime import Runtime
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
from app.agents.new_chat.middleware.knowledge_search import _render_priority_message from app.agents.new_chat.middleware.knowledge_search import _render_priority_message
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
class KbContextProjectionMiddleware(AgentMiddleware): # type: ignore[type-arg] class KbContextProjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
@ -30,17 +34,34 @@ class KbContextProjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
runtime: Runtime[Any], runtime: Runtime[Any],
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
del runtime del runtime
start = time.perf_counter()
tree_text = state.get("workspace_tree_text") tree_text = state.get("workspace_tree_text")
priority = state.get("kb_priority") priority = state.get("kb_priority")
if not tree_text and not priority: if not tree_text and not priority:
_perf_log.info(
"[kb_context_projection] tree=0 priority=0 elapsed=%.3fs",
time.perf_counter() - start,
)
return None return None
messages = list(state.get("messages") or []) messages = list(state.get("messages") or [])
insert_at = max(len(messages) - 1, 0) insert_at = max(len(messages) - 1, 0)
tree_chars = 0
if tree_text: if tree_text:
tree_chars = len(tree_text)
messages.insert(insert_at, SystemMessage(content=tree_text)) messages.insert(insert_at, SystemMessage(content=tree_text))
priority_count = 0
if priority: if priority:
priority_count = (
len(priority) if hasattr(priority, "__len__") else 1
)
messages.insert(insert_at, _render_priority_message(priority)) messages.insert(insert_at, _render_priority_message(priority))
_perf_log.info(
"[kb_context_projection] tree_chars=%d priority_items=%d elapsed=%.3fs",
tree_chars,
priority_count,
time.perf_counter() - start,
)
return {"messages": messages} return {"messages": messages}

View file

@ -118,5 +118,6 @@ Rules:
- `status=success` → `next_step=null`, `missing_fields=null`. - `status=success` → `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` → `next_step` must be non-null. - `status=partial|blocked|error` → `next_step` must be non-null.
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. - `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
- `evidence.content_excerpt`: max ~500 characters. Surface a short excerpt or a one-sentence summary, not the full file body. The supervisor already sees the tool's raw output.
Infer before you call; map every tool outcome faithfully. Infer before you call; map every tool outcome faithfully.

View file

@ -118,5 +118,6 @@ Rules:
- `status=success` → `next_step=null`, `missing_fields=null`. - `status=success` → `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` → `next_step` must be non-null. - `status=partial|blocked|error` → `next_step` must be non-null.
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. - `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
- `evidence.content_excerpt`: max ~500 characters. Surface a short excerpt or a one-sentence summary, not the full file body. The supervisor already sees the tool's raw output.
Infer before you call; map every tool outcome faithfully. Infer before you call; map every tool outcome faithfully.

View file

@ -50,4 +50,6 @@ Rules:
- `status=success` -> `next_step=null`, `missing_fields=null`. - `status=success` -> `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` -> `next_step` must be non-null. - `status=partial|blocked|error` -> `next_step` must be non-null.
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null. - `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
- `evidence.findings`: max 10 entries, each a single sentence stating one distinct fact. Do not paste raw paragraphs, scraped pages, or quote blocks.
- `evidence.sources`: max 10 URLs, one per finding when applicable. List each URL once.
</output_contract> </output_contract>

View file

@ -38,7 +38,7 @@ Supervisor: "List open tasks in the Project Tracker base."
2. List tables in that base → identify the Tasks table; capture its table ID. 2. List tables in that base → identify the Tasks table; capture its table ID.
3. Get table schema → identify the status field and the choice IDs that represent "open" states. 3. Get table schema → identify the status field and the choice IDs that represent "open" states.
4. List records with a typed filter on the status field for those choice IDs. 4. List records with a typed filter on the status field for those choice IDs.
5. Return `status=success` with the matched records in `evidence.items`. 5. Return `status=success` with `evidence.items` set to `{ "total": N }` and the matched records listed in `action_summary` (record id, primary-field value, and 1-2 most relevant fields; one line per record; up to 10 entries, then `"...and N more"`).
</example> </example>
<example> <example>
@ -97,7 +97,7 @@ Rules:
- `status=partial|blocked|error` → `next_step` must be non-null. - `status=partial|blocked|error` → `next_step` must be non-null.
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. - `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: base, table, field, choice, record, etc.). - For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: base, table, field, choice, record, etc.).
- For discovery-only queries (lists), populate `evidence.items` with the structured list. - For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (record id, primary-field value, and 1-2 most relevant fields; up to 10 entries, then `"...and N more"`).
</output_contract> </output_contract>
Discover before you mutate; never guess identifiers, choice IDs, or required fields. Discover before you mutate; never guess identifiers, choice IDs, or required fields.

View file

@ -29,7 +29,7 @@ You are a Google Calendar specialist for the user's connected calendar.
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. | | `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
| tool raises / unknown | `error` | `"Calendar tool failed unexpectedly. Ask the user to retry shortly."` | | tool raises / unknown | `error` | `"Calendar tool failed unexpectedly. Ask the user to retry shortly."` |
Surface the tool's `event_id`, `title` / `summary`, `start_at`, `end_at`, and `html_link` inside `evidence` when the tool returned them. For `search_calendar_events`, place the raw `events` array inside `evidence.items`. Never invent a field the tool did not return. Surface the tool's `event_id`, `title` / `summary`, `start_at`, `end_at`, and `html_link` inside `evidence` when the tool returned them. For `search_calendar_events`, set `evidence.items` to `{ "total": N }` and list the matched events in `action_summary` (title, date, start time; one line per event; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
## Examples ## Examples
@ -115,7 +115,7 @@ Rules:
- `status=success` → `next_step=null`, `missing_fields=null`. - `status=success` → `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` → `next_step` must be non-null. - `status=partial|blocked|error` → `next_step` must be non-null.
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. - `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
- For `search_calendar_events` results, populate `evidence.items` with `{ "events": [...], "total": N }`. - For `search_calendar_events` results, set `evidence.items` to `{ "total": N }` and list the matched events in `action_summary` (title, date, start time; up to 10 entries, then `"...and N more"`).
- For ambiguous matches across `update_calendar_event` / `delete_calendar_event`, populate `evidence.matched_candidates` with up to 5 options (`id` + `label`, where `label` should include the event title and start time for human readability). - For ambiguous matches across `update_calendar_event` / `delete_calendar_event`, populate `evidence.matched_candidates` with up to 5 options (`id` + `label`, where `label` should include the event title and start time for human readability).
Infer before you call; map every tool outcome faithfully. Infer before you call; map every tool outcome faithfully.

View file

@ -36,7 +36,7 @@ Failure handling:
<example> <example>
Supervisor: "Find tasks about the homepage redesign." Supervisor: "Find tasks about the homepage redesign."
1. Workspace search for "homepage redesign" → matched tasks. 1. Workspace search for "homepage redesign" → matched tasks.
2. Return `status=success` with the matched tasks in `evidence.items`. 2. Return `status=success` with `evidence.items` set to `{ "total": N }` and the matched tasks listed in `action_summary` (task id, title, status, assignees; one line per task; up to 10 entries, then `"...and N more"`).
</example> </example>
<example> <example>
@ -98,7 +98,7 @@ Rules:
- `status=partial|blocked|error` → `next_step` must be non-null. - `status=partial|blocked|error` → `next_step` must be non-null.
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. - `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: task, list, member, status, custom-field choice, etc.). - For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: task, list, member, status, custom-field choice, etc.).
- For discovery-only queries (lists), populate `evidence.items` with the structured list. - For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (task id, title, status, assignees; up to 10 entries, then `"...and N more"`).
</output_contract> </output_contract>
Discover before you mutate; never guess identifiers, list statuses, or assignees. Discover before you mutate; never guess identifiers, list statuses, or assignees.

View file

@ -24,7 +24,7 @@ You are a Discord specialist for the user's connected Discord server.
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. | | `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
| tool raises / unknown | `error` | `"Discord tool failed unexpectedly. Ask the user to retry shortly."` | | tool raises / unknown | `error` | `"Discord tool failed unexpectedly. Ask the user to retry shortly."` |
Surface the tool's `message`, `channel_id`, `message_id`, and the listed channels/messages payload inside `evidence` when the tool returned them. Never invent a field the tool did not return. Surface the tool's `message`, `channel_id`, and `message_id` inside `evidence` when the tool returned them. For `list_discord_channels` and `read_discord_messages`, set `evidence.items` to `{ "total": N }` and list the matched entries in `action_summary` (channel name or sender + timestamp + short text snippet; one line per entry; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
## Examples ## Examples

View file

@ -33,7 +33,7 @@ You are a Gmail specialist for the user's connected Gmail mailbox.
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. | | `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
| tool raises / unknown | `error` | `"Gmail tool failed unexpectedly. Ask the user to retry shortly."` | | tool raises / unknown | `error` | `"Gmail tool failed unexpectedly. Ask the user to retry shortly."` |
Surface the tool's `message_id`, `thread_id`, `draft_id`, `subject`, and recipient fields inside `evidence` when the tool returned them. For `search_gmail`, place the raw `emails` array inside `evidence.items`. Never invent a field the tool did not return. Surface the tool's `message_id`, `thread_id`, `draft_id`, `subject`, and recipient fields inside `evidence` when the tool returned them. For `search_gmail`, set `evidence.items` to `{ "total": N }` and list the matched emails in `action_summary` (sender, subject, date; one line per email; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
## Examples ## Examples
@ -114,7 +114,7 @@ Rules:
- `status=success` → `next_step=null`, `missing_fields=null`. - `status=success` → `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` → `next_step` must be non-null. - `status=partial|blocked|error` → `next_step` must be non-null.
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. - `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
- For `search_gmail` results, populate `evidence.items` with `{ "emails": [...], "total": N }`. - For `search_gmail` results, set `evidence.items` to `{ "total": N }` and list the matched emails in `action_summary` (sender, subject, date; up to 10 entries, then `"...and N more"`).
- For ambiguous matches across `update_gmail_draft` / `trash_gmail_email` / `read_gmail_email`, populate `evidence.matched_candidates` with up to 5 options (`id` + `label`). - For ambiguous matches across `update_gmail_draft` / `trash_gmail_email` / `read_gmail_email`, populate `evidence.matched_candidates` with up to 5 options (`id` + `label`).
Infer before you call; verify before you send; map every tool outcome faithfully. Infer before you call; verify before you send; map every tool outcome faithfully.

View file

@ -39,7 +39,7 @@ Failure handling:
<example> <example>
Supervisor: "Find issues assigned to me with status 'In Progress'." Supervisor: "Find issues assigned to me with status 'In Progress'."
1. JQL search with `assignee = currentUser() AND status = "In Progress"`. 1. JQL search with `assignee = currentUser() AND status = "In Progress"`.
2. Return `status=success` with the matched issues in `evidence.items`. 2. Return `status=success` with `evidence.items` set to `{ "total": N }` and the matched issues listed in `action_summary` (issue key, summary, status, assignee; one line per issue; up to 10 entries, then `"...and N more"`).
</example> </example>
<example> <example>
@ -116,7 +116,7 @@ Rules:
- `status=partial|blocked|error` → `next_step` must be non-null. - `status=partial|blocked|error` → `next_step` must be non-null.
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. - `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: site, project, issue, user, transition, etc.). - For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: site, project, issue, user, transition, etc.).
- For discovery-only queries (lists), populate `evidence.items` with the structured list. - For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (issue key, summary, status, assignee; up to 10 entries, then `"...and N more"`).
</output_contract> </output_contract>
Discover before you mutate; never guess identifiers, transitions, or required fields. Discover before you mutate; never guess identifiers, transitions, or required fields.

View file

@ -32,7 +32,7 @@ Failure handling:
<example> <example>
Supervisor: "Find issues assigned to me with priority Urgent." Supervisor: "Find issues assigned to me with priority Urgent."
1. Discovery: list issues with filters `{assignee: "me", priority: 1}`. 1. Discovery: list issues with filters `{assignee: "me", priority: 1}`.
2. Return `status=success` with the matched issues in `evidence.items`. 2. Return `status=success` with `evidence.items` set to `{ "total": N }` and the matched issues listed in `action_summary` (identifier, title, state, assignee; one line per issue; up to 10 entries, then `"...and N more"`).
</example> </example>
<example> <example>
@ -106,7 +106,7 @@ Rules:
- `status=partial|blocked|error` → `next_step` must be non-null. - `status=partial|blocked|error` → `next_step` must be non-null.
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. - `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: issue, user, project, state, etc.). - For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: issue, user, project, state, etc.).
- For discovery-only queries (lists), populate `evidence.items` with the structured list. - For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (identifier, title, state, assignee; up to 10 entries, then `"...and N more"`).
</output_contract> </output_contract>
Discover before you mutate; never guess identifiers. Discover before you mutate; never guess identifiers.

View file

@ -26,7 +26,7 @@ You are a Luma specialist for the user's connected Luma account.
| `error` | `error` | Relay the tool's `message` verbatim as `next_step` (this covers Luma Plus 403s and other API errors). | | `error` | `error` | Relay the tool's `message` verbatim as `next_step` (this covers Luma Plus 403s and other API errors). |
| tool raises / unknown | `error` | `"Luma tool failed unexpectedly. Ask the user to retry shortly."` | | tool raises / unknown | `error` | `"Luma tool failed unexpectedly. Ask the user to retry shortly."` |
Surface the tool's `message`, `event_id`, `name`, `start_at`, and `url` inside `evidence` when the tool returned them. Never invent a field the tool did not return. Surface the tool's `message`, `event_id`, `name`, `start_at`, and `url` inside `evidence` when the tool returned them. For `list_luma_events`, set `evidence.items` to `{ "total": N }` and list the matched events in `action_summary` (event name, start date/time, location if present; one line per event; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
## Examples ## Examples

View file

@ -37,7 +37,7 @@ Failure handling:
Supervisor: "Summarize the latest discussion in #marketing." Supervisor: "Summarize the latest discussion in #marketing."
1. Search channels for "marketing" → one strong match. Capture the channel ID. 1. Search channels for "marketing" → one strong match. Capture the channel ID.
2. Read that channel's recent message history. 2. Read that channel's recent message history.
3. Return `status=success` with the message list in `evidence.items`. 3. Return `status=success` with `evidence.items` set to `{ "total": N }` and the messages listed in `action_summary` (sender, timestamp, text snippet; one line per message; up to 10 entries, then `"...and N more"`).
</example> </example>
<example> <example>
@ -92,7 +92,7 @@ Rules:
- `status=partial|blocked|error` → `next_step` must be non-null. - `status=partial|blocked|error` → `next_step` must be non-null.
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. - `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: channel, user, message, thread). - For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: channel, user, message, thread).
- For discovery-only queries (lists), populate `evidence.items` with the structured list. - For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (channel/user, key identifier, timestamp, short snippet; up to 10 entries, then `"...and N more"`).
</output_contract> </output_contract>
Discover before you post; never guess channel, user, or thread targets. Discover before you post; never guess channel, user, or thread targets.

View file

@ -26,7 +26,7 @@ You are a Microsoft Teams specialist for the user's connected Teams account.
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. | | `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
| tool raises / unknown | `error` | `"Teams tool failed unexpectedly. Ask the user to retry shortly."` | | tool raises / unknown | `error` | `"Teams tool failed unexpectedly. Ask the user to retry shortly."` |
Surface the tool's `message`, `team_id`, `team_name`, `channel_id`, `channel_name`, and `message_id` inside `evidence` when the tool returned them. Never invent a field the tool did not return. Surface the tool's `message`, `team_id`, `team_name`, `channel_id`, `channel_name`, and `message_id` inside `evidence` when the tool returned them. For `list_teams_channels` and `read_teams_messages`, set `evidence.items` to `{ "total": N }` and list the matched entries in `action_summary` (team channel, or sender + timestamp + short text snippet; one line per entry; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
## Examples ## Examples

View file

@ -102,6 +102,7 @@ from app.agents.new_chat.tools.registry import (
) )
from app.db import ChatVisibility from app.db import ChatVisibility
from app.services.connector_service import ConnectorService from app.services.connector_service import ConnectorService
from app.services.llm_service import get_planner_llm
from app.utils.perf import get_perf_logger from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger() _perf_log = get_perf_logger()
@ -1077,6 +1078,7 @@ def _build_compiled_agent_blocking(
else None, else None,
KnowledgePriorityMiddleware( KnowledgePriorityMiddleware(
llm=llm, llm=llm,
planner_llm=get_planner_llm(),
search_space_id=search_space_id, search_space_id=search_space_id,
filesystem_mode=filesystem_mode, filesystem_mode=filesystem_mode,
available_connectors=available_connectors, available_connectors=available_connectors,

View file

@ -32,6 +32,7 @@ exact same routine when ``aafter_agent`` was skipped (e.g. client disconnect).
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any
@ -249,11 +250,11 @@ async def _create_document(
session.add(doc) session.add(doc)
await session.flush() await session.flush()
summary_embedding = embed_texts([content])[0] summary_embedding = (await asyncio.to_thread(embed_texts, [content]))[0]
doc.embedding = summary_embedding doc.embedding = summary_embedding
chunks = chunk_text(content) chunks = chunk_text(content)
if chunks: if chunks:
chunk_embeddings = embed_texts(chunks) chunk_embeddings = await asyncio.to_thread(embed_texts, chunks)
session.add_all( session.add_all(
[ [
Chunk(document_id=doc.id, content=text, embedding=embedding) Chunk(document_id=doc.id, content=text, embedding=embedding)
@ -295,13 +296,13 @@ async def _update_document(
search_space_id, search_space_id,
) )
summary_embedding = embed_texts([content])[0] summary_embedding = (await asyncio.to_thread(embed_texts, [content]))[0]
document.embedding = summary_embedding document.embedding = summary_embedding
await session.execute(delete(Chunk).where(Chunk.document_id == document.id)) await session.execute(delete(Chunk).where(Chunk.document_id == document.id))
chunks = chunk_text(content) chunks = chunk_text(content)
if chunks: if chunks:
chunk_embeddings = embed_texts(chunks) chunk_embeddings = await asyncio.to_thread(embed_texts, chunks)
session.add_all( session.add_all(
[ [
Chunk(document_id=document.id, content=text, embedding=embedding) Chunk(document_id=document.id, content=text, embedding=embedding)

View file

@ -457,7 +457,7 @@ async def search_knowledge_base(
if not query: if not query:
return [] return []
[embedding] = embed_texts([query]) [embedding] = await asyncio.to_thread(embed_texts, [query])
doc_types = _resolve_search_types(available_connectors, available_document_types) doc_types = _resolve_search_types(available_connectors, available_document_types)
retriever_top_k = min(top_k * 3, 30) retriever_top_k = min(top_k * 3, 30)
@ -579,6 +579,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
self, self,
*, *,
llm: BaseChatModel | None = None, llm: BaseChatModel | None = None,
planner_llm: BaseChatModel | None = None,
search_space_id: int, search_space_id: int,
filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
available_connectors: list[str] | None = None, available_connectors: list[str] | None = None,
@ -588,6 +589,15 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
inject_system_message: bool = True, # For backwards compatibility inject_system_message: bool = True, # For backwards compatibility
) -> None: ) -> None:
self.llm = llm self.llm = llm
# The planner LLM handles short, structured internal tasks (query
# rewriting, date extraction, recency classification). When an
# operator marks a global config ``is_planner: true`` we route
# those calls to a cheap/fast model (e.g. gpt-4o-mini, Haiku, Azure
# gpt-5.x-nano) instead of the user's chat LLM — those classification
# tasks don't need frontier-tier capability. Falls back to the chat
# LLM when no planner config is wired up so deployments without one
# keep working unchanged.
self.planner_llm = planner_llm or llm
self.search_space_id = search_space_id self.search_space_id = search_space_id
self.filesystem_mode = filesystem_mode self.filesystem_mode = filesystem_mode
self.available_connectors = available_connectors self.available_connectors = available_connectors
@ -598,7 +608,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
# Build the kb-planner private Runnable ONCE here so we don't pay # Build the kb-planner private Runnable ONCE here so we don't pay
# the ``create_agent`` compile cost (50-200ms) on every turn. # the ``create_agent`` compile cost (50-200ms) on every turn.
# Disabled by default behind ``enable_kb_planner_runnable``; when # Disabled by default behind ``enable_kb_planner_runnable``; when
# off the planner falls back to the legacy ``self.llm.ainvoke`` # off the planner falls back to the legacy ``planner_llm.ainvoke``
# path. # path.
self._planner: Runnable | None = None self._planner: Runnable | None = None
self._planner_compile_failed = False self._planner_compile_failed = False
@ -608,7 +618,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
Returns ``None`` when the feature flag is disabled, when the LLM is Returns ``None`` when the feature flag is disabled, when the LLM is
unavailable, or when ``create_agent`` raises (we fall back to the unavailable, or when ``create_agent`` raises (we fall back to the
legacy ``self.llm.ainvoke`` path in that case). Compilation happens legacy ``planner_llm.ainvoke`` path in that case). Compilation happens
lazily on first call, then memoized via ``self._planner``. lazily on first call, then memoized via ``self._planner``.
The compiled agent is constructed without tools — the planner's The compiled agent is constructed without tools — the planner's
@ -618,7 +628,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
""" """
if self._planner is not None or self._planner_compile_failed: if self._planner is not None or self._planner_compile_failed:
return self._planner return self._planner
if self.llm is None: if self.planner_llm is None:
return None return None
flags = get_flags() flags = get_flags()
if not flags.enable_kb_planner_runnable or flags.disable_new_agent_stack: if not flags.enable_kb_planner_runnable or flags.disable_new_agent_stack:
@ -628,13 +638,13 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
try: try:
self._planner = create_agent( self._planner = create_agent(
self.llm, self.planner_llm,
tools=[], tools=[],
middleware=[RetryAfterMiddleware(max_retries=2)], middleware=[RetryAfterMiddleware(max_retries=2)],
) )
except Exception as exc: # pragma: no cover - defensive except Exception as exc: # pragma: no cover - defensive
logger.warning( logger.warning(
"kb-planner Runnable compile failed; falling back to llm.ainvoke: %s", "kb-planner Runnable compile failed; falling back to planner_llm.ainvoke: %s",
exc, exc,
) )
self._planner_compile_failed = True self._planner_compile_failed = True
@ -647,12 +657,12 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
messages: Sequence[BaseMessage], messages: Sequence[BaseMessage],
user_text: str, user_text: str,
) -> tuple[str, datetime | None, datetime | None, bool]: ) -> tuple[str, datetime | None, datetime | None, bool]:
if self.llm is None: if self.planner_llm is None:
return user_text, None, None, False return user_text, None, None, False
recent_conversation = _render_recent_conversation( recent_conversation = _render_recent_conversation(
messages, messages,
llm=self.llm, llm=self.planner_llm,
user_text=user_text, user_text=user_text,
) )
prompt = _build_kb_planner_prompt( prompt = _build_kb_planner_prompt(
@ -663,8 +673,8 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
t0 = loop.time() t0 = loop.time()
# Prefer the compiled-once planner Runnable when enabled; otherwise # Prefer the compiled-once planner Runnable when enabled; otherwise
# fall back to ``self.llm.ainvoke``. The ``surfsense:internal`` tag # fall back to ``planner_llm.ainvoke``. The ``surfsense:internal``
# is preserved on both paths so ``_stream_agent_events`` still # tag is preserved on both paths so ``_stream_agent_events`` still
# suppresses the planner's intermediate events from the UI. # suppresses the planner's intermediate events from the UI.
planner = self._build_kb_planner_runnable() planner = self._build_kb_planner_runnable()
try: try:
@ -684,7 +694,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
else AIMessage(content="") else AIMessage(content="")
) )
else: else:
response = await self.llm.ainvoke( response = await self.planner_llm.ainvoke(
[HumanMessage(content=prompt)], [HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal"]}, config={"tags": ["surfsense:internal"]},
) )

View file

@ -24,6 +24,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import time
from typing import Any from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState from langchain.agents.middleware import AgentMiddleware, AgentState
@ -41,6 +42,9 @@ from app.agents.new_chat.path_resolver import (
doc_to_virtual_path, doc_to_virtual_path,
) )
from app.db import Document, shielded_async_session from app.db import Document, shielded_async_session
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
try: try:
from litellm import token_counter from litellm import token_counter
@ -124,6 +128,7 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
if self.filesystem_mode != FilesystemMode.CLOUD: if self.filesystem_mode != FilesystemMode.CLOUD:
return None return None
start = time.perf_counter()
update: dict[str, Any] = {} update: dict[str, Any] = {}
if not state.get("cwd"): if not state.get("cwd"):
update["cwd"] = DOCUMENTS_ROOT update["cwd"] = DOCUMENTS_ROOT
@ -131,7 +136,11 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
anon_doc = state.get("kb_anon_doc") anon_doc = state.get("kb_anon_doc")
if anon_doc: if anon_doc:
tree_msg = self._render_anon_tree(anon_doc) tree_msg = self._render_anon_tree(anon_doc)
cache_outcome = "anon"
else: else:
version = int(state.get("tree_version") or 0)
cache_key = (self.search_space_id, version, False)
cache_outcome = "hit" if cache_key in self._cache else "miss"
tree_msg = await self._render_kb_tree(state) tree_msg = await self._render_kb_tree(state)
update["workspace_tree_text"] = tree_msg update["workspace_tree_text"] = tree_msg
@ -141,6 +150,14 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
insert_at = max(len(messages) - 1, 0) insert_at = max(len(messages) - 1, 0)
messages.insert(insert_at, SystemMessage(content=tree_msg)) messages.insert(insert_at, SystemMessage(content=tree_msg))
update["messages"] = messages update["messages"] = messages
_perf_log.info(
"[knowledge_tree] cache=%s chars=%d elapsed=%.3fs space=%d",
cache_outcome,
len(tree_msg),
time.perf_counter() - start,
self.search_space_id,
)
return update return update
def before_agent( # type: ignore[override] def before_agent( # type: ignore[override]

View file

@ -8,6 +8,7 @@ Injects memory markdown into the system prompt on every turn:
from __future__ import annotations from __future__ import annotations
import logging import logging
import time
from typing import Any from typing import Any
from uuid import UUID from uuid import UUID
@ -19,8 +20,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db import ChatVisibility, SearchSpace, User, shielded_async_session from app.db import ChatVisibility, SearchSpace, User, shielded_async_session
from app.services.memory import MEMORY_HARD_LIMIT, MEMORY_SOFT_LIMIT from app.services.memory import MEMORY_HARD_LIMIT, MEMORY_SOFT_LIMIT
from app.utils.perf import get_perf_logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg] class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
@ -53,9 +56,13 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
if not isinstance(last_message, HumanMessage): if not isinstance(last_message, HumanMessage):
return None return None
start = time.perf_counter()
db_elapsed = 0.0
memory_blocks: list[str] = [] memory_blocks: list[str] = []
scope = "team" if self.visibility == ChatVisibility.SEARCH_SPACE else "user"
async with shielded_async_session() as session: async with shielded_async_session() as session:
db_start = time.perf_counter()
if self.visibility == ChatVisibility.SEARCH_SPACE: if self.visibility == ChatVisibility.SEARCH_SPACE:
team_memory = await self._load_team_memory(session) team_memory = await self._load_team_memory(session)
if team_memory: if team_memory:
@ -96,7 +103,15 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
f"</memory_warning>" f"</memory_warning>"
) )
db_elapsed = time.perf_counter() - db_start
if not memory_blocks: if not memory_blocks:
_perf_log.info(
"[memory_injection] scope=%s injected=0 db=%.3fs total=%.3fs",
scope,
db_elapsed,
time.perf_counter() - start,
)
return None return None
memory_text = "\n\n".join(memory_blocks) memory_text = "\n\n".join(memory_blocks)
@ -106,6 +121,13 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
insert_idx = 1 if len(new_messages) > 1 else 0 insert_idx = 1 if len(new_messages) > 1 else 0
new_messages.insert(insert_idx, memory_msg) new_messages.insert(insert_idx, memory_msg)
_perf_log.info(
"[memory_injection] scope=%s injected=1 chars=%d db=%.3fs total=%.3fs",
scope,
len(memory_text),
db_elapsed,
time.perf_counter() - start,
)
return {"messages": new_messages} return {"messages": new_messages}
async def _load_user_memory( async def _load_user_memory(

View file

@ -39,9 +39,19 @@ For OpenAI-family configs we additionally pass:
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that - ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
raises hit rate by sending requests with a shared prefix to the same raises hit rate by sending requests with a shared prefix to the same
backend. backend. Supported by ``openai/``, ``deepseek/``, ``xai/``, and
``azure/`` (added to LiteLLM's Azure transformer in
https://github.com/BerriAI/litellm/pull/20989, Feb 2026; verified
against ``AzureOpenAIConfig.get_supported_openai_params`` in our
installed litellm 1.83.14 for ``azure/gpt-4o``, ``azure/gpt-4o-mini``,
``azure/gpt-5.4``, ``azure/gpt-5.4-mini``).
- ``prompt_cache_retention="24h"`` extends cache TTL beyond the default - ``prompt_cache_retention="24h"`` extends cache TTL beyond the default
5-10 min in-memory cache. 5-10 min in-memory cache. Set ONLY for OpenAI/DeepSeek/xAI: Azure's
server-side support landed in Microsoft's docs on 2026-05-13 but
LiteLLM 1.83.14's Azure transformer still omits it from its supported
params list, so it gets silently dropped by ``litellm.drop_params``.
Azure's default in-memory retention (5-10 min, max 1 h) already
bridges intra-conversation turns; revisit when LiteLLM bumps Azure.
Safety net: ``litellm.drop_params=True`` is set globally in Safety net: ``litellm.drop_params=True`` is set globally in
``app.services.llm_service`` at module-load time. Any kwarg the destination ``app.services.llm_service`` at module-load time. Any kwarg the destination
@ -81,13 +91,31 @@ _DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
{"location": "message", "index": -1}, {"location": "message", "index": -1},
) )
# Providers (uppercase ``AgentConfig.provider`` values) that natively expose # Providers (uppercase ``AgentConfig.provider`` values) that accept the
# OpenAI-style automatic prompt caching with ``prompt_cache_key`` and # OpenAI ``prompt_cache_key`` routing hint. Microsoft's Azure OpenAI docs
# ``prompt_cache_retention`` kwargs. Strict whitelist — many other providers # (2026-05-13) confirm automatic prompt caching applies to every GPT-4o
# in ``PROVIDER_MAP`` route through litellm's ``openai`` prefix without # or newer Azure deployment at ≥1024 tokens with no configuration needed,
# implementing the OpenAI prompt-cache surface (e.g. MOONSHOT, ZHIPU, # and that ``prompt_cache_key`` is combined with the prefix hash to
# MINIMAX), so we can't infer family from the litellm prefix alone. # improve routing affinity and therefore cache hit rate. LiteLLM's Azure
_OPENAI_FAMILY_PROVIDERS: frozenset[str] = frozenset({"OPENAI", "DEEPSEEK", "XAI"}) # transformer ships ``prompt_cache_key`` in its supported params as of
# https://github.com/BerriAI/litellm/pull/20989.
#
# Strict whitelist — many other providers in ``PROVIDER_MAP`` route
# through litellm's ``openai`` prefix without implementing the OpenAI
# prompt-cache surface (e.g. MOONSHOT, ZHIPU, MINIMAX), so we can't infer
# family from the litellm prefix alone.
_PROMPT_CACHE_KEY_PROVIDERS: frozenset[str] = frozenset(
{"OPENAI", "DEEPSEEK", "XAI", "AZURE", "AZURE_OPENAI"}
)
# Subset of ``_PROMPT_CACHE_KEY_PROVIDERS`` that also accept
# ``prompt_cache_retention="24h"``. Azure is excluded: see module
# docstring — LiteLLM 1.83.14's Azure transformer omits the param so
# ``drop_params`` silently strips it. Re-add Azure once a future LiteLLM
# release wires it into ``AzureOpenAIConfig.get_supported_openai_params``.
_PROMPT_CACHE_RETENTION_PROVIDERS: frozenset[str] = frozenset(
{"OPENAI", "DEEPSEEK", "XAI"}
)
def _is_router_llm(llm: BaseChatModel) -> bool: def _is_router_llm(llm: BaseChatModel) -> bool:
@ -101,13 +129,13 @@ def _is_router_llm(llm: BaseChatModel) -> bool:
return type(llm).__name__ == "ChatLiteLLMRouter" return type(llm).__name__ == "ChatLiteLLMRouter"
def _is_openai_family_config(agent_config: AgentConfig | None) -> bool: def _provider_supports_prompt_cache_key(agent_config: AgentConfig | None) -> bool:
"""Whether the config targets an OpenAI-style prompt-cache surface. """Whether the config targets a provider that accepts ``prompt_cache_key``.
Strict — only returns True when the user explicitly chose OPENAI, Strict — only returns True for explicitly chosen OPENAI, DEEPSEEK,
DEEPSEEK, or XAI as the provider in their ``NewLLMConfig`` / XAI, AZURE, or AZURE_OPENAI providers. Auto-mode and custom
``YAMLConfig``. Auto-mode and custom providers return False because providers return False because we can't statically know the
we can't statically know the destination. destination and the router fans out across mixed providers.
""" """
if agent_config is None or not agent_config.provider: if agent_config is None or not agent_config.provider:
return False return False
@ -115,7 +143,25 @@ def _is_openai_family_config(agent_config: AgentConfig | None) -> bool:
return False return False
if agent_config.custom_provider: if agent_config.custom_provider:
return False return False
return agent_config.provider.upper() in _OPENAI_FAMILY_PROVIDERS return agent_config.provider.upper() in _PROMPT_CACHE_KEY_PROVIDERS
def _provider_supports_prompt_cache_retention(
agent_config: AgentConfig | None,
) -> bool:
"""Whether the config targets a provider that accepts ``prompt_cache_retention``.
Tighter than :func:`_provider_supports_prompt_cache_key` — Azure
deployments are excluded until LiteLLM ships the param in its Azure
transformer (see module docstring).
"""
if agent_config is None or not agent_config.provider:
return False
if agent_config.is_auto_mode:
return False
if agent_config.custom_provider:
return False
return agent_config.provider.upper() in _PROMPT_CACHE_RETENTION_PROVIDERS
def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None: def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
@ -173,16 +219,23 @@ def apply_litellm_prompt_caching(
dict(point) for point in _DEFAULT_INJECTION_POINTS dict(point) for point in _DEFAULT_INJECTION_POINTS
] ]
# OpenAI-family extras only when we statically know the destination is # OpenAI-style extras only when we statically know the destination
# OpenAI / DeepSeek / xAI. Auto-mode router fans out across providers # accepts them. Auto-mode router fans out across mixed providers so
# so we can't safely set OpenAI-only kwargs there (drop_params would # we can't safely set destination-specific kwargs there (drop_params
# strip them but it's wasteful to set them in the first place). # would strip them but it's wasteful to set them in the first
# place).
if _is_router_llm(llm): if _is_router_llm(llm):
return return
if not _is_openai_family_config(agent_config):
return
if thread_id is not None and "prompt_cache_key" not in model_kwargs: if (
thread_id is not None
and "prompt_cache_key" not in model_kwargs
and _provider_supports_prompt_cache_key(agent_config)
):
model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}" model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}"
if "prompt_cache_retention" not in model_kwargs:
if (
"prompt_cache_retention" not in model_kwargs
and _provider_supports_prompt_cache_retention(agent_config)
):
model_kwargs["prompt_cache_retention"] = "24h" model_kwargs["prompt_cache_retention"] = "24h"

View file

@ -36,8 +36,16 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.agents.new_chat.tools.mcp_client import MCPClient from app.agents.new_chat.tools.mcp_client import MCPClient
from app.agents.new_chat.tools.mcp_tools_cache import (
CachedMCPTools,
read_cached_tools,
write_cached_tools,
)
from app.db import SearchSourceConnector from app.db import SearchSourceConnector
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -293,15 +301,21 @@ async def _create_mcp_tool_from_definition_http(
timeout: float = 60.0, timeout: float = 60.0,
) -> str: ) -> str:
"""Execute a single MCP HTTP call with the given headers.""" """Execute a single MCP HTTP call with the given headers."""
call_start = time.perf_counter()
async with ( async with (
streamablehttp_client(url, headers=call_headers) as (read, write, _), streamablehttp_client(url, headers=call_headers) as (read, write, _),
ClientSession(read, write) as session, ClientSession(read, write) as session,
): ):
init_start = time.perf_counter()
await session.initialize() await session.initialize()
init_elapsed = time.perf_counter() - init_start
tool_start = time.perf_counter()
response = await asyncio.wait_for( response = await asyncio.wait_for(
session.call_tool(original_tool_name, arguments=call_kwargs), session.call_tool(original_tool_name, arguments=call_kwargs),
timeout=timeout, timeout=timeout,
) )
tool_elapsed = time.perf_counter() - tool_start
result = [] result = []
for content in response.content: for content in response.content:
@ -312,7 +326,18 @@ async def _create_mcp_tool_from_definition_http(
else: else:
result.append(str(content)) result.append(str(content))
return "\n".join(result) if result else "" payload = "\n".join(result) if result else ""
_perf_log.info(
"[mcp_http_call] connector=%s tool=%s init=%.3fs call=%.3fs total=%.3fs out_chars=%d",
connector_id,
original_tool_name,
init_elapsed,
tool_elapsed,
time.perf_counter() - call_start,
len(payload),
)
return payload
async def mcp_http_tool_call(**kwargs) -> str: async def mcp_http_tool_call(**kwargs) -> str:
"""Execute the MCP tool call via HTTP transport.""" """Execute the MCP tool call via HTTP transport."""
@ -496,6 +521,7 @@ async def _load_http_mcp_tools(
is_generic_mcp: bool = False, is_generic_mcp: bool = False,
*, *,
bypass_internal_hitl: bool = False, bypass_internal_hitl: bool = False,
cached_tools: CachedMCPTools | None = None,
) -> list[StructuredTool]: ) -> list[StructuredTool]:
"""Load tools from an HTTP-based MCP server. """Load tools from an HTTP-based MCP server.
@ -506,6 +532,8 @@ async def _load_http_mcp_tools(
readonly_tools: Tool names that skip HITL approval (read-only operations). readonly_tools: Tool names that skip HITL approval (read-only operations).
tool_name_prefix: If set, each tool name is prefixed for multi-account tool_name_prefix: If set, each tool name is prefixed for multi-account
disambiguation (e.g. ``linear_25``). disambiguation (e.g. ``linear_25``).
cached_tools: If provided, skip live discovery and rebuild wrappers
from the persisted definitions.
""" """
tools: list[StructuredTool] = [] tools: list[StructuredTool] = []
@ -529,15 +557,23 @@ async def _load_http_mcp_tools(
allowed_set = set(allowed_tools) if allowed_tools else None allowed_set = set(allowed_tools) if allowed_tools else None
async def _discover(disc_headers: dict[str, str]) -> list[dict[str, Any]]: async def _discover(
"""Connect, initialize, and list tools from the MCP server.""" disc_headers: dict[str, str],
) -> tuple[dict[str, str | None], list[dict[str, Any]]]:
"""Connect, initialize, and list tools — returns (serverInfo, tools)."""
async with ( async with (
streamablehttp_client(url, headers=disc_headers) as (read, write, _), streamablehttp_client(url, headers=disc_headers) as (read, write, _),
ClientSession(read, write) as session, ClientSession(read, write) as session,
): ):
await session.initialize() init_result = await session.initialize()
server_info: dict[str, str | None] = {"name": None, "version": None}
si = getattr(init_result, "serverInfo", None)
if si is not None:
server_info["name"] = getattr(si, "name", None)
server_info["version"] = getattr(si, "version", None)
response = await session.list_tools() response = await session.list_tools()
return [ return server_info, [
{ {
"name": tool.name, "name": tool.name,
"description": tool.description or "", "description": tool.description or "",
@ -548,47 +584,65 @@ async def _load_http_mcp_tools(
for tool in response.tools for tool in response.tools
] ]
try: if cached_tools is not None:
tool_definitions = await _discover(headers) tool_definitions = [
except Exception as first_err: {
if not _is_auth_error(first_err) or connector_id is None: "name": td.name,
logger.exception( "description": td.description,
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s", "input_schema": td.input_schema,
url, }
connector_id, for td in cached_tools.tools
first_err, ]
) else:
return tools
logger.warning(
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
connector_id,
)
fresh_headers = await _force_refresh_and_get_headers(connector_id)
if fresh_headers is None:
await _mark_connector_auth_expired(connector_id)
logger.error(
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
connector_id,
)
return tools
try: try:
tool_definitions = await _discover(fresh_headers) server_info, tool_definitions = await _discover(headers)
headers = fresh_headers except Exception as first_err:
logger.info( if not _is_auth_error(first_err) or connector_id is None:
"HTTP MCP discovery for connector %d succeeded after 401 recovery", logger.exception(
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
url,
connector_id,
first_err,
)
return tools
logger.warning(
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
connector_id, connector_id,
) )
except Exception as retry_err: fresh_headers = await _force_refresh_and_get_headers(connector_id)
logger.exception( if fresh_headers is None:
"HTTP MCP discovery for connector %d still failing after refresh: %s",
connector_id,
retry_err,
)
if _is_auth_error(retry_err):
await _mark_connector_auth_expired(connector_id) await _mark_connector_auth_expired(connector_id)
return tools logger.error(
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
connector_id,
)
return tools
try:
server_info, tool_definitions = await _discover(fresh_headers)
headers = fresh_headers
logger.info(
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
connector_id,
)
except Exception as retry_err:
logger.exception(
"HTTP MCP discovery for connector %d still failing after refresh: %s",
connector_id,
retry_err,
)
if _is_auth_error(retry_err):
await _mark_connector_auth_expired(connector_id)
return tools
await write_cached_tools(
connector_id,
tool_definitions,
server_name=server_info.get("name"),
server_version=server_info.get("version"),
transport=server_config.get("transport", "streamable-http"),
)
total_discovered = len(tool_definitions) total_discovered = len(tool_definitions)
@ -792,14 +846,25 @@ async def _maybe_refresh_mcp_oauth_token(
except (ValueError, TypeError): except (ValueError, TypeError):
return server_config return server_config
refresh_start = time.perf_counter()
try: try:
new_access = await _refresh_connector_token(session, connector) new_access = await _refresh_connector_token(session, connector)
if not new_access: if not new_access:
_perf_log.info(
"[mcp_oauth_refresh] connector=%s elapsed=%.3fs outcome=no_token",
connector.id,
time.perf_counter() - refresh_start,
)
return server_config return server_config
logger.info( logger.info(
"Proactively refreshed MCP OAuth token for connector %s", connector.id "Proactively refreshed MCP OAuth token for connector %s", connector.id
) )
_perf_log.info(
"[mcp_oauth_refresh] connector=%s elapsed=%.3fs outcome=refreshed",
connector.id,
time.perf_counter() - refresh_start,
)
refreshed_config = dict(server_config) refreshed_config = dict(server_config)
refreshed_config["headers"] = { refreshed_config["headers"] = {
@ -809,6 +874,11 @@ async def _maybe_refresh_mcp_oauth_token(
return refreshed_config return refreshed_config
except Exception: except Exception:
_perf_log.info(
"[mcp_oauth_refresh] connector=%s elapsed=%.3fs outcome=failed",
connector.id,
time.perf_counter() - refresh_start,
)
logger.warning( logger.warning(
"Failed to refresh MCP OAuth token for connector %s", "Failed to refresh MCP OAuth token for connector %s",
connector.id, connector.id,
@ -937,6 +1007,94 @@ def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
_mcp_tools_cache.clear() _mcp_tools_cache.clear()
async def discover_single_mcp_connector(connector_id: int) -> None:
    """Force live MCP discovery for one connector so its ``cached_tools`` row is fresh.

    ``_load_http_mcp_tools`` persists ``cached_tools`` as a side effect of any
    live discovery; passing ``cached_tools=None`` here guarantees we go to the
    network. The returned wrappers are discarded — the in-process LRU is
    rebuilt lazily on the next user query. Stdio connectors are not cached and
    are skipped.

    Args:
        connector_id: Primary key of the ``SearchSourceConnector`` to refresh.

    Returns:
        ``None``. A missing connector, an unusable ``server_config``, a
        non-HTTP transport or an unavailable OAuth token all end the call
        early; timeouts and other exceptions are logged as warnings and are
        not re-raised.
    """
    from app.db import async_session_maker

    started = time.perf_counter()
    try:
        async with async_session_maker() as session:
            connector = await session.get(SearchSourceConnector, connector_id)
            if connector is None:
                logger.info(
                    "discover_single_mcp_connector: connector %d not found",
                    connector_id,
                )
                return

            cfg = connector.config or {}
            server_config = cfg.get("server_config", {})
            if not server_config or not isinstance(server_config, dict):
                return

            # Only HTTP-style transports are discovered here; anything else
            # (including the "stdio" default) returns without doing work.
            transport = server_config.get("transport", "stdio")
            if transport not in ("streamable-http", "http", "sse"):
                return

            if cfg.get("mcp_oauth"):
                server_config = await _maybe_refresh_mcp_oauth_token(
                    session, connector, cfg, server_config
                )
                # Re-read the config after the refresh attempt — presumably the
                # refresh can rewrite ``connector.config``; confirm against
                # ``_refresh_connector_token``.
                cfg = connector.config or {}
                server_config = _inject_oauth_headers(cfg, server_config)
                if server_config is None:
                    logger.info(
                        "discover_single_mcp_connector: OAuth token unavailable for connector %d",
                        connector_id,
                    )
                    return

            # ``connector_type`` may be an enum member or a plain value;
            # normalise it to the string the service registry is keyed on.
            ct = (
                connector.connector_type.value
                if hasattr(connector.connector_type, "value")
                else str(connector.connector_type)
            )
            svc_cfg = get_service_by_connector_type(ct)
            allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
            readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()

            # The return value is intentionally dropped: this call is made for
            # its side effect of persisting ``cached_tools``.
            await asyncio.wait_for(
                _load_http_mcp_tools(
                    connector.id,
                    connector.name,
                    server_config,
                    trusted_tools=cfg.get("trusted_tools", []),
                    allowed_tools=allowed_tools,
                    readonly_tools=readonly_tools,
                    tool_name_prefix=None,
                    # No registered service config means a generic MCP server.
                    is_generic_mcp=svc_cfg is None,
                    bypass_internal_hitl=True,
                    cached_tools=None,
                ),
                timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
            )
            _perf_log.info(
                "[mcp_prefetch] connector=%s elapsed=%.3fs",
                connector_id,
                time.perf_counter() - started,
            )
    except TimeoutError:
        logger.warning(
            "discover_single_mcp_connector: connector %d timed out after %ds",
            connector_id,
            _MCP_DISCOVERY_TIMEOUT_SECONDS,
        )
    except Exception:
        logger.warning(
            "discover_single_mcp_connector: failed for connector %d",
            connector_id,
            exc_info=True,
        )
async def load_mcp_tools( async def load_mcp_tools(
session: AsyncSession, session: AsyncSession,
search_space_id: int, search_space_id: int,
@ -1063,6 +1221,7 @@ async def load_mcp_tools(
"tool_name_prefix": tool_name_prefix, "tool_name_prefix": tool_name_prefix,
"transport": server_config.get("transport", "stdio"), "transport": server_config.get("transport", "stdio"),
"is_generic_mcp": svc_cfg is None, "is_generic_mcp": svc_cfg is None,
"cached_tools": read_cached_tools(connector),
} }
) )
@ -1074,9 +1233,12 @@ async def load_mcp_tools(
) )
async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]: async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]:
discover_start = time.perf_counter()
transport = task["transport"]
cached_tools = task.get("cached_tools")
try: try:
if task["transport"] in ("streamable-http", "http", "sse"): if transport in ("streamable-http", "http", "sse"):
return await asyncio.wait_for( result = await asyncio.wait_for(
_load_http_mcp_tools( _load_http_mcp_tools(
task["connector_id"], task["connector_id"],
task["connector_name"], task["connector_name"],
@ -1087,11 +1249,12 @@ async def load_mcp_tools(
tool_name_prefix=task["tool_name_prefix"], tool_name_prefix=task["tool_name_prefix"],
is_generic_mcp=task.get("is_generic_mcp", False), is_generic_mcp=task.get("is_generic_mcp", False),
bypass_internal_hitl=bypass_internal_hitl, bypass_internal_hitl=bypass_internal_hitl,
cached_tools=cached_tools,
), ),
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS, timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
) )
else: else:
return await asyncio.wait_for( result = await asyncio.wait_for(
_load_stdio_mcp_tools( _load_stdio_mcp_tools(
task["connector_id"], task["connector_id"],
task["connector_name"], task["connector_name"],
@ -1101,7 +1264,24 @@ async def load_mcp_tools(
), ),
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS, timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
) )
_perf_log.info(
"[mcp_discover] connector=%s name=%r transport=%s tools=%d elapsed=%.3fs cache=%s",
task["connector_id"],
task["connector_name"],
transport,
len(result),
time.perf_counter() - discover_start,
"hit" if cached_tools is not None else "miss",
)
return result
except TimeoutError: except TimeoutError:
_perf_log.info(
"[mcp_discover] connector=%s name=%r transport=%s elapsed=%.3fs outcome=timeout",
task["connector_id"],
task["connector_name"],
transport,
time.perf_counter() - discover_start,
)
logger.error( logger.error(
"MCP connector %d timed out after %ds during discovery", "MCP connector %d timed out after %ds during discovery",
task["connector_id"], task["connector_id"],
@ -1109,6 +1289,13 @@ async def load_mcp_tools(
) )
return [] return []
except Exception as e: except Exception as e:
_perf_log.info(
"[mcp_discover] connector=%s name=%r transport=%s elapsed=%.3fs outcome=error",
task["connector_id"],
task["connector_name"],
transport,
time.perf_counter() - discover_start,
)
logger.exception( logger.exception(
"Failed to load tools from MCP connector %d: %s", "Failed to load tools from MCP connector %d: %s",
task["connector_id"], task["connector_id"],
@ -1116,7 +1303,14 @@ async def load_mcp_tools(
) )
return [] return []
gather_start = time.perf_counter()
results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks]) results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks])
_perf_log.info(
"[mcp_discover] gather_wall=%.3fs connectors=%d total_tools=%d",
time.perf_counter() - gather_start,
len(discovery_tasks),
sum(len(r) for r in results),
)
tools: list[StructuredTool] = [tool for sublist in results for tool in sublist] tools: list[StructuredTool] = [tool for sublist in results for tool in sublist]
_mcp_tools_cache[cache_key] = (now, tools) _mcp_tools_cache[cache_key] = (now, tools)

View file

@ -0,0 +1,145 @@
"""Persist MCP ``list_tools`` results in ``SearchSourceConnector.config.cached_tools``."""
from __future__ import annotations
import asyncio
import logging
from datetime import UTC, datetime
from typing import Any
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.orm.attributes import flag_modified
from app.db import SearchSourceConnector, async_session_maker
logger = logging.getLogger(__name__)
_pending_prefetch_tasks: set[asyncio.Task[None]] = set()
class CachedMCPToolDef(BaseModel):
    """One MCP tool definition as stored inside the persisted ``cached_tools`` payload."""

    # Tool name; the only required field.
    name: str
    # Free-text description; defaults to an empty string when absent.
    description: str = ""
    # Schema dict for the tool's input; defaults to a fresh empty dict per instance.
    input_schema: dict[str, Any] = Field(default_factory=dict)
class CachedMCPTools(BaseModel):
    """Envelope persisted under ``config["cached_tools"]`` for one MCP connector."""

    # When the tool list was captured (``write_cached_tools`` stamps it with the current UTC time).
    discovered_at: datetime
    # Optional server metadata supplied by the caller of ``write_cached_tools``.
    server_version: str | None = None
    server_name: str | None = None
    # Transport name supplied by the caller, if any.
    transport: str | None = None
    # The cached tool definitions themselves; required.
    tools: list[CachedMCPToolDef]
def read_cached_tools(connector: SearchSourceConnector) -> CachedMCPTools | None:
    """Parse the connector's persisted ``cached_tools`` entry.

    Returns the validated ``CachedMCPTools`` model, or ``None`` when the entry
    is missing, empty, not a dict, or fails validation — in which case the
    caller is expected to fall back to live discovery.
    """
    raw = (connector.config or {}).get("cached_tools")
    if not (raw and isinstance(raw, dict)):
        return None

    try:
        parsed = CachedMCPTools.model_validate(raw)
    except ValidationError as exc:
        # A corrupt cache entry is not fatal: report it and behave as a miss.
        logger.warning(
            "MCP connector %d has corrupt cached_tools — falling back to live discovery: %s",
            connector.id,
            exc,
        )
        return None
    return parsed
async def write_cached_tools(
    connector_id: int,
    tool_definitions: list[dict[str, Any]],
    *,
    server_name: str | None = None,
    server_version: str | None = None,
    transport: str | None = None,
) -> None:
    """Persist discovered tool definitions under the connector's ``config["cached_tools"]``.

    Best-effort: the write uses its own session so a failure cannot poison the
    caller's transaction, and both malformed tool definitions and database
    errors are logged as a warning instead of being raised to the caller.

    Args:
        connector_id: Id of the ``SearchSourceConnector`` row to update. If no
            such row exists the call is a no-op.
        tool_definitions: Raw tool dicts; each is validated as a
            ``CachedMCPToolDef``.
        server_name: Optional server name stored alongside the tools.
        server_version: Optional server version stored alongside the tools.
        transport: Optional transport name stored alongside the tools.
    """
    try:
        # Building the payload validates every tool definition. It lives inside
        # the ``try`` so that a ValidationError is treated like any other failed
        # cache write (logged, skipped) rather than aborting the discovery that
        # called us.
        payload = CachedMCPTools(
            discovered_at=datetime.now(UTC),
            server_version=server_version,
            server_name=server_name,
            transport=transport,
            tools=[CachedMCPToolDef.model_validate(td) for td in tool_definitions],
        )
        async with async_session_maker() as session:
            result = await session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == connector_id,
                )
            )
            connector = result.scalars().first()
            if connector is None:
                return

            # Work on a copy and reassign, then flag the attribute explicitly so
            # the change to the ``config`` column is picked up on commit.
            cfg = dict(connector.config or {})
            cfg["cached_tools"] = payload.model_dump(mode="json")
            connector.config = cfg
            flag_modified(connector, "config")
            await session.commit()
            logger.info(
                "Persisted cached_tools for MCP connector %d (%d tools)",
                connector_id,
                len(payload.tools),
            )
    except Exception:
        logger.warning(
            "Failed to persist cached_tools for MCP connector %d",
            connector_id,
            exc_info=True,
        )
def refresh_mcp_tools_cache_for_connector(
    connector_id: int,
    search_space_id: int,
) -> None:
    """Evict and re-warm the MCP tool cache after a change to one connector.

    Two steps, both best-effort:

    1. Synchronously invalidate the in-process cache for ``search_space_id``
       (the cache is keyed per search space, so eviction cannot be narrower).
    2. If an event loop is running, schedule a background live discovery for
       ``connector_id`` so its persisted ``cached_tools`` is refreshed.

    Failures during eviction are logged at debug level; when no event loop is
    running the prefetch is silently skipped.
    """
    try:
        from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache

        invalidate_mcp_tools_cache(search_space_id)
    except Exception:
        logger.debug(
            "MCP in-process cache eviction skipped for space %d",
            search_space_id,
            exc_info=True,
        )

    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        # Called from synchronous code with no loop: nothing to schedule on.
        return

    prefetch = running_loop.create_task(_run_connector_prefetch(connector_id))
    # Hold a strong reference until the task finishes so it is not garbage
    # collected mid-flight; the done-callback drops it again.
    _pending_prefetch_tasks.add(prefetch)
    prefetch.add_done_callback(_pending_prefetch_tasks.discard)
async def _run_connector_prefetch(connector_id: int) -> None:
    """Background-task body: run live discovery for one connector, logging any failure."""
    # Imported lazily, at call time rather than at module import.
    from app.agents.new_chat.tools.mcp_tool import (
        discover_single_mcp_connector as _discover_connector,
    )

    try:
        await _discover_connector(connector_id)
    except Exception:
        logger.warning(
            "MCP background prefetch failed for connector_id=%d",
            connector_id,
            exc_info=True,
        )

View file

@ -110,6 +110,19 @@ def load_global_llm_configs():
except Exception as e: except Exception as e:
print(f"Warning: Failed to score global LLM configs: {e}") print(f"Warning: Failed to score global LLM configs: {e}")
# Planner LLM is a singleton role. If an operator accidentally
# marks multiple configs ``is_planner: true``, only the first one
# is used at runtime — surface the others at startup so the
# mistake is caught before traffic, not silently buried.
planner_cfgs = [c for c in configs if c.get("is_planner") is True]
if len(planner_cfgs) > 1:
extra_ids = [c.get("id") for c in planner_cfgs[1:]]
print(
"Warning: Multiple global LLM configs marked is_planner=true "
f"(ids {[c.get('id') for c in planner_cfgs]}); using id "
f"{planner_cfgs[0].get('id')} and ignoring {extra_ids}"
)
return configs return configs
except Exception as e: except Exception as e:
print(f"Warning: Failed to load global LLM configs: {e}") print(f"Warning: Failed to load global LLM configs: {e}")

View file

@ -258,6 +258,45 @@ global_llm_configs:
use_default_system_instructions: true use_default_system_instructions: true
citations_enabled: true citations_enabled: true
# Example: Planner LLM - small, fast model used for internal utility tasks
#
# The PLANNER role handles short, structured internal calls (KB query
# rewriting, date extraction, recency classification, etc.) that don't
# need frontier-tier capability. Pointing the planner at a cheap+fast
# model (gpt-4o-mini, Claude Haiku, Azure gpt-5.x-nano, Groq Llama, ...)
# typically saves 500ms-1.5s per turn vs. routing those same internal
# calls through the user's chat model.
#
# Activation:
# - Mark EXACTLY ONE global config with ``is_planner: true``.
#   - If multiple are marked, the first one wins and a warning is printed at startup.
# - If none is marked, every internal call falls back to the user's
# chat LLM (same behavior as before this flag existed).
#
# This config is operator-only — it is NOT exposed in the user-facing
# model selector, never billed against premium quota, and the
# billing_tier / anonymous_enabled fields below are ignored.
- id: -9
name: "Global Planner (GPT-4o mini)"
description: "Internal-only planner LLM for query rewriting and classification"
is_planner: true
billing_tier: "free"
anonymous_enabled: false
seo_enabled: false
quota_reserve_tokens: 1000
provider: "OPENAI"
model_name: "gpt-4o-mini"
api_key: "sk-your-openai-api-key-here"
api_base: ""
rpm: 3500
tpm: 200000
litellm_params:
temperature: 0
max_tokens: 1000
system_instructions: ""
use_default_system_instructions: true
citations_enabled: false
# ============================================================================= # =============================================================================
# OpenRouter Integration # OpenRouter Integration
# ============================================================================= # =============================================================================
@ -493,6 +532,20 @@ global_vision_llm_configs:
# - Lower temperature (0.3) is recommended for accurate screenshot analysis # - Lower temperature (0.3) is recommended for accurate screenshot analysis
# - Lower max_tokens (1000) is sufficient since autocomplete produces short suggestions # - Lower max_tokens (1000) is sufficient since autocomplete produces short suggestions
# #
# PLANNER LLM NOTES:
# - is_planner: true marks a config as the internal-only planner LLM (small,
# fast model used for KB query rewriting, date extraction, recency
# classification, etc.). Only one config may carry this flag — if
# multiple do, the first one wins and a startup WARNING is logged.
# - When no config is marked is_planner, every internal utility call falls
# back to the user's chat LLM (the historical behavior).
# - Planner configs are NOT shown in the user-facing model selector and
# are NOT billed against the user's premium quota. Their billing_tier,
# anonymous_enabled, seo_* fields are ignored.
# - Recommended models: gpt-4o-mini, claude-3-5-haiku, gemini-1.5-flash,
#   azure gpt-5.x-nano, groq llama3-8b — any model with a p50 latency under
#   200ms on a 1-2k token prompt. Using a frontier model here defeats the
#   purpose of the flag.
#
# TOKEN QUOTA & ANONYMOUS ACCESS NOTES: # TOKEN QUOTA & ANONYMOUS ACCESS NOTES:
# - billing_tier: "free" or "premium". Controls whether registered users need premium token quota. # - billing_tier: "free" or "premium". Controls whether registered users need premium token quota.
# - anonymous_enabled: true/false. Whether the model appears in the public no-login catalog. # - anonymous_enabled: true/false. Whether the model appears in the public no-login catalog.

View file

@ -428,7 +428,7 @@ async def mcp_oauth_callback(
await session.commit() await session.commit()
await session.refresh(db_connector) await session.refresh(db_connector)
_invalidate_cache(space_id) _refresh_mcp_cache(db_connector.id, space_id)
logger.info( logger.info(
"Re-authenticated %s MCP connector %s for user %s", "Re-authenticated %s MCP connector %s for user %s",
@ -481,7 +481,7 @@ async def mcp_oauth_callback(
detail="A connector for this service already exists.", detail="A connector for this service already exists.",
) from e ) from e
_invalidate_cache(space_id) _refresh_mcp_cache(new_connector.id, space_id)
logger.info( logger.info(
"Created %s MCP connector %s for user %s in space %s", "Created %s MCP connector %s for user %s in space %s",
@ -658,10 +658,17 @@ async def reauth_mcp_service(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _invalidate_cache(space_id: int) -> None: def _refresh_mcp_cache(connector_id: int, space_id: int) -> None:
try: """Evict the in-process MCP tool LRU and schedule background prefetch.
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
invalidate_mcp_tools_cache(space_id) Wraps :func:`refresh_mcp_tools_cache_for_connector` so any failure is
isolated from the OAuth response flow.
"""
try:
from app.agents.new_chat.tools.mcp_tools_cache import (
refresh_mcp_tools_cache_for_connector,
)
refresh_mcp_tools_cache_for_connector(connector_id, space_id)
except Exception: except Exception:
logger.debug("MCP cache invalidation skipped", exc_info=True) logger.debug("MCP cache refresh skipped", exc_info=True)

View file

@ -2650,9 +2650,11 @@ async def create_mcp_connector(
f"for user {user.id} in search space {search_space_id}" f"for user {user.id} in search space {search_space_id}"
) )
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache from app.agents.new_chat.tools.mcp_tools_cache import (
refresh_mcp_tools_cache_for_connector,
)
invalidate_mcp_tools_cache(search_space_id) refresh_mcp_tools_cache_for_connector(db_connector.id, search_space_id)
connector_read = SearchSourceConnectorRead.model_validate(db_connector) connector_read = SearchSourceConnectorRead.model_validate(db_connector)
return MCPConnectorRead.from_connector(connector_read) return MCPConnectorRead.from_connector(connector_read)
@ -2828,9 +2830,11 @@ async def update_mcp_connector(
logger.info(f"Updated MCP connector {connector_id}") logger.info(f"Updated MCP connector {connector_id}")
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache from app.agents.new_chat.tools.mcp_tools_cache import (
refresh_mcp_tools_cache_for_connector,
)
invalidate_mcp_tools_cache(connector.search_space_id) refresh_mcp_tools_cache_for_connector(connector.id, connector.search_space_id)
connector_read = SearchSourceConnectorRead.model_validate(connector) connector_read = SearchSourceConnectorRead.model_validate(connector)
return MCPConnectorRead.from_connector(connector_read) return MCPConnectorRead.from_connector(connector_read)

View file

@ -1,3 +1,4 @@
import asyncio
import logging import logging
from datetime import datetime from datetime import datetime
@ -100,7 +101,9 @@ class GmailKBSyncService:
else: else:
logger.warning("No LLM configured -- using fallback summary") logger.warning("No LLM configured -- using fallback summary")
summary_content = f"Gmail Message: {subject}\n\n{indexable_content}" summary_content = f"Gmail Message: {subject}\n\n{indexable_content}"
summary_embedding = embed_text(summary_content) summary_embedding = await asyncio.to_thread(
embed_text, summary_content
)
chunks = await create_document_chunks(indexable_content) chunks = await create_document_chunks(indexable_content)
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

View file

@ -116,7 +116,9 @@ class GoogleCalendarKBSyncService:
summary_content = ( summary_content = (
f"Google Calendar Event: {event_summary}\n\n{indexable_content}" f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
) )
summary_embedding = embed_text(summary_content) summary_embedding = await asyncio.to_thread(
embed_text, summary_content
)
chunks = await create_document_chunks(indexable_content) chunks = await create_document_chunks(indexable_content)
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@ -295,7 +297,9 @@ class GoogleCalendarKBSyncService:
summary_content = ( summary_content = (
f"Google Calendar Event: {event_summary}\n\n{indexable_content}" f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
) )
summary_embedding = embed_text(summary_content) summary_embedding = await asyncio.to_thread(
embed_text, summary_content
)
chunks = await create_document_chunks(indexable_content) chunks = await create_document_chunks(indexable_content)
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

View file

@ -98,7 +98,9 @@ class JiraKBSyncService:
summary_content = ( summary_content = (
f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}" f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
) )
summary_embedding = embed_text(summary_content) summary_embedding = await asyncio.to_thread(
embed_text, summary_content
)
chunks = await create_document_chunks(issue_content) chunks = await create_document_chunks(issue_content)
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@ -212,7 +214,9 @@ class JiraKBSyncService:
summary_content = ( summary_content = (
f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}" f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
) )
summary_embedding = embed_text(summary_content) summary_embedding = await asyncio.to_thread(
embed_text, summary_content
)
chunks = await create_document_chunks(issue_content) chunks = await create_document_chunks(issue_content)

View file

@ -659,3 +659,36 @@ async def get_user_long_context_llm(
return await get_document_summary_llm( return await get_document_summary_llm(
session, search_space_id, disable_streaming=disable_streaming session, search_space_id, disable_streaming=disable_streaming
) )
def get_planner_llm() -> ChatLiteLLM | None:
    """Build the planner LLM from the first global config flagged ``is_planner: true``.

    The planner role handles short, structured internal tasks (KB search
    planning: query rewriting, date extraction, recency classification).
    These tasks are well-served by small/fast models (e.g. gpt-4o-mini,
    Claude Haiku, Azure gpt-5.x-nano) — using the user's chat LLM for them
    is unnecessarily expensive and slow.

    The lookup reads ``config.GLOBAL_LLM_CONFIGS`` only, so it involves no
    database access and can be called synchronously from middleware/factory
    code.

    Returns:
        A chat model built by ``create_chat_litellm_from_config`` for the
        first matching config, or ``None`` when no config is flagged. Callers
        MUST fall back to their chat LLM on ``None`` so deployments without a
        planner config keep working unchanged.
    """
    from app.agents.new_chat.llm_config import create_chat_litellm_from_config

    for candidate in config.GLOBAL_LLM_CONFIGS:
        # Strict identity check: only a literal ``True`` activates the role.
        if candidate.get("is_planner") is True:
            if not candidate:
                return None
            return create_chat_litellm_from_config(candidate)
    return None

View file

@ -1,3 +1,4 @@
import asyncio
import logging import logging
from datetime import datetime from datetime import datetime
@ -95,7 +96,9 @@ class OneDriveKBSyncService:
else: else:
logger.warning("No LLM configured — using fallback summary") logger.warning("No LLM configured — using fallback summary")
summary_content = f"OneDrive File: {file_name}\n\n{indexable_content}" summary_content = f"OneDrive File: {file_name}\n\n{indexable_content}"
summary_embedding = embed_text(summary_content) summary_embedding = await asyncio.to_thread(
embed_text, summary_content
)
chunks = await create_document_chunks(indexable_content) chunks = await create_document_chunks(indexable_content)
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

View file

@ -29,6 +29,7 @@ same trap waiting to happen).
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from datetime import UTC, datetime from datetime import UTC, datetime
@ -234,7 +235,7 @@ async def _restore_in_place_document(
if isinstance(c, dict) and isinstance(c.get("content"), str) if isinstance(c, dict) and isinstance(c.get("content"), str)
] ]
if chunk_texts: if chunk_texts:
chunk_embeddings = embed_texts(chunk_texts) chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts)
session.add_all( session.add_all(
[ [
Chunk(document_id=doc.id, content=text, embedding=embedding) Chunk(document_id=doc.id, content=text, embedding=embedding)
@ -244,7 +245,9 @@ async def _restore_in_place_document(
] ]
) )
if isinstance(revision.content_before, str): if isinstance(revision.content_before, str):
doc.embedding = embed_texts([revision.content_before])[0] doc.embedding = (
await asyncio.to_thread(embed_texts, [revision.content_before])
)[0]
doc.updated_at = datetime.now(UTC) doc.updated_at = datetime.now(UTC)
return RevertOutcome(status="ok", message="Document restored from snapshot.") return RevertOutcome(status="ok", message="Document restored from snapshot.")
@ -320,7 +323,7 @@ async def _reinsert_document_from_revision(
session.add(new_doc) session.add(new_doc)
await session.flush() await session.flush()
new_doc.embedding = embed_texts([content])[0] new_doc.embedding = (await asyncio.to_thread(embed_texts, [content]))[0]
chunk_texts = [] chunk_texts = []
chunks_before = revision.chunks_before chunks_before = revision.chunks_before
if isinstance(chunks_before, list): if isinstance(chunks_before, list):
@ -330,7 +333,7 @@ async def _reinsert_document_from_revision(
if isinstance(c, dict) and isinstance(c.get("content"), str) if isinstance(c, dict) and isinstance(c.get("content"), str)
] ]
if chunk_texts: if chunk_texts:
chunk_embeddings = embed_texts(chunk_texts) chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts)
session.add_all( session.add_all(
[ [
Chunk(document_id=new_doc.id, content=text, embedding=embedding) Chunk(document_id=new_doc.id, content=text, embedding=embedding)

View file

@ -325,6 +325,24 @@ class TokenTrackingCallback(CustomLogger):
total_tokens = getattr(usage, "total_tokens", 0) or 0 total_tokens = getattr(usage, "total_tokens", 0) or 0
call_kind = "chat" call_kind = "chat"
# Prompt-cache accounting. LiteLLM normalizes every provider's cache
# fields onto ``usage.prompt_tokens_details``:
# - ``cached_tokens`` — cache reads (OpenAI/Azure native, DeepSeek
# mapped from ``prompt_cache_hit_tokens``,
# Anthropic mapped from ``cache_read_input_tokens``).
# - ``cache_creation_tokens`` — cache writes (Anthropic only; OpenAI/Azure
# do not expose a write count).
# See ``litellm.types.utils.Usage.__init__`` for the mapping.
cached_tokens = 0
cache_creation_tokens = 0
if not is_image:
prompt_details = getattr(usage, "prompt_tokens_details", None)
if prompt_details is not None:
cached_tokens = getattr(prompt_details, "cached_tokens", 0) or 0
cache_creation_tokens = (
getattr(prompt_details, "cache_creation_tokens", 0) or 0
)
model = kwargs.get("model", "unknown") model = kwargs.get("model", "unknown")
cost_usd = _extract_cost_usd( cost_usd = _extract_cost_usd(
@ -357,9 +375,23 @@ class TokenTrackingCallback(CustomLogger):
cost_micros=cost_micros, cost_micros=cost_micros,
call_kind=call_kind, call_kind=call_kind,
) )
# Per-LLM-call wall-clock latency (LiteLLM passes datetime objects).
call_latency_s: float | None = None
try:
if start_time is not None and end_time is not None:
delta = end_time - start_time
call_latency_s = getattr(delta, "total_seconds", lambda: float(delta))()
except Exception:
call_latency_s = None
cache_hit_ratio: float | None = None
if prompt_tokens > 0 and (cached_tokens > 0 or cache_creation_tokens > 0):
cache_hit_ratio = cached_tokens / prompt_tokens
logger.info( logger.info(
"[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d " "[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
"cost=$%.6f (%d micros) (accumulator now has %d calls)", "cost=$%.6f (%d micros) (accumulator now has %d calls)%s%s",
model, model,
call_kind, call_kind,
prompt_tokens, prompt_tokens,
@ -368,6 +400,17 @@ class TokenTrackingCallback(CustomLogger):
cost_usd, cost_usd,
cost_micros, cost_micros,
len(acc.calls), len(acc.calls),
f" latency={call_latency_s:.3f}s" if call_latency_s is not None else "",
(
f" cache_read={cached_tokens} cache_write={cache_creation_tokens}"
f" hit_ratio={cache_hit_ratio:.1%}"
if cache_hit_ratio is not None
else (
f" cache_read={cached_tokens} cache_write={cache_creation_tokens}"
if (cached_tokens or cache_creation_tokens)
else ""
)
),
) )

View file

@ -60,8 +60,6 @@ from app.db import (
) )
from app.prompts import TITLE_GENERATION_PROMPT from app.prompts import TITLE_GENERATION_PROMPT
from app.services.auto_model_pin_service import ( from app.services.auto_model_pin_service import (
is_recently_healthy,
mark_healthy,
mark_runtime_cooldown, mark_runtime_cooldown,
resolve_or_get_pinned_llm_config_id, resolve_or_get_pinned_llm_config_id,
) )
@ -501,54 +499,6 @@ def _is_provider_rate_limited(exc: BaseException) -> bool:
) )
_PREFLIGHT_TIMEOUT_SEC: float = 2.5
_PREFLIGHT_MAX_TOKENS: int = 1
async def _preflight_llm(llm: Any) -> None:
    """Send a one-token probe completion to check the pinned model is not rate limited.

    The probe is deliberately cheap: ``_PREFLIGHT_MAX_TOKENS`` tokens, a
    ``_PREFLIGHT_TIMEOUT_SEC`` timeout, non-streaming, and tagged
    ``surfsense:internal`` / ``auto-pin-preflight`` in its metadata.

    Raises:
        The original exception when ``_is_provider_rate_limited`` classifies
        it as a rate-limit error, so the caller can drive its cooldown/repin
        branch. Any other failure is logged at debug level and swallowed.
    """
    from litellm import acompletion

    model = getattr(llm, "model", None)
    if not model or model == "auto":
        # Auto-mode router doesn't have a single deployment to ping; the
        # router itself handles per-deployment rate-limit accounting.
        return

    try:
        probe_kwargs = {
            "model": model,
            "messages": [{"role": "user", "content": "ping"}],
            "api_key": getattr(llm, "api_key", None),
            "api_base": getattr(llm, "api_base", None),
            "max_tokens": _PREFLIGHT_MAX_TOKENS,
            "timeout": _PREFLIGHT_TIMEOUT_SEC,
            "stream": False,
            "metadata": {"tags": ["surfsense:internal", "auto-pin-preflight"]},
        }
        await acompletion(**probe_kwargs)
    except Exception as exc:
        if not _is_provider_rate_limited(exc):
            logging.getLogger(__name__).debug(
                "auto_pin_preflight non_rate_limit_error model=%s err=%s",
                model,
                exc,
            )
            return
        # Rate-limit-shaped failure: surface it unchanged to the caller.
        raise
async def _build_main_agent_for_thread( async def _build_main_agent_for_thread(
agent_factory: Any, agent_factory: Any,
*, *,
@ -566,9 +516,9 @@ async def _build_main_agent_for_thread(
disabled_tools: list[str] | None = None, disabled_tools: list[str] | None = None,
mentioned_document_ids: list[int] | None = None, mentioned_document_ids: list[int] | None = None,
) -> Any: ) -> Any:
"""Single (re)build path so the agent factory cannot drift across """Single (re)build path so the agent factory cannot drift across the
initial build, preflight repin, and mid-stream 429 recovery for one initial build and mid-stream 429 recovery for one ``thread_id``: a
``thread_id``: a graph swap mid-turn would corrupt checkpointer state.""" graph swap mid-turn would corrupt checkpointer state."""
return await agent_factory( return await agent_factory(
llm=llm, llm=llm,
search_space_id=search_space_id, search_space_id=search_space_id,
@ -586,29 +536,6 @@ async def _build_main_agent_for_thread(
) )
async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None:
"""Wait for a discarded speculative agent build to release shared state.
Used by the parallel preflight + agent-build path. The speculative build
closes over the request-scoped ``AsyncSession`` (for the brief connector
discovery / tool-factory window before its CPU work moves into a worker
thread). If preflight reports a 429 we want to fall back to the original
repin reload rebuild path, but we MUST NOT touch ``session`` again
until any in-flight session work owned by the speculative build has
fully settled :class:`sqlalchemy.ext.asyncio.AsyncSession` is not
concurrency-safe and the same hazard cost us a hard ``InvalidRequestError``
earlier in this PR (see ``connector_service`` parallel-gather revert).
We simply ``await`` the task and swallow any exception: in this path the
build's outcome is irrelevant — success populates the agent cache (a free
side effect), failure is discarded. The wasted CPU is acceptable since
429 fallbacks are rare and the original sequential code also paid the
full build cost on the same path.
"""
with contextlib.suppress(BaseException):
await task
def _classify_stream_exception( def _classify_stream_exception(
exc: Exception, exc: Exception,
*, *,
@ -1236,39 +1163,6 @@ async def stream_new_chat(
yield streaming_service.format_done() yield streaming_service.format_done()
return return
# Auto-mode preflight ping. Runs ONLY for thread-pinned auto cfgs
# (negative ids selected via ``resolve_or_get_pinned_llm_config_id``)
# whose health hasn't already been confirmed within the TTL window.
# Detecting a 429 here lets us repin BEFORE the planner/classifier/
# title-generation LLM calls fan out and each independently hit the
# same upstream rate limit.
#
# PERF: preflight is a network round-trip to the LLM provider (~1-5s)
# and is independent of the agent build (CPU-bound, ~5-7s). They used
# to run sequentially → ``preflight + build`` on cold cache = 11.5s.
# We now kick off preflight as a background task FIRST, then run the
# synchronous setup work and the agent build in parallel. In the
# success path (the common case) total wall time drops to roughly
# ``max(preflight, build)`` — the preflight finishes during the
# agent compile and we just consume its result. In the rare 429
# path the speculative build is awaited to completion (so its
# session usage is fully released) via
# :func:`_settle_speculative_agent_build`, then discarded, and
# we fall back to the original repin-and-rebuild flow.
preflight_needed = (
requested_llm_config_id == 0
and llm_config_id < 0
and not is_recently_healthy(llm_config_id)
)
preflight_task: asyncio.Task[None] | None = None
_t_preflight = 0.0
if preflight_needed:
_t_preflight = time.perf_counter()
preflight_task = asyncio.create_task(
_preflight_llm(llm),
name=f"auto_pin_preflight:{llm_config_id}",
)
# Create connector service # Create connector service
_t0 = time.perf_counter() _t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id) connector_service = ConnectorService(session, search_space_id=search_space_id)
@ -1302,136 +1196,26 @@ async def stream_new_chat(
if use_multi_agent if use_multi_agent
else create_surfsense_deep_agent else create_surfsense_deep_agent
) )
# Speculative agent build — runs in parallel with the preflight # Build the agent inline. Provider 429s surface through the
# task (if any). Built with the *current* ``llm`` / ``agent_config``; # in-stream recovery loop below (``_is_provider_rate_limited``),
# if preflight reports 429 we will discard this future and rebuild # which repins the thread to an eligible alternative config and
# against the freshly pinned config below. # rebuilds the agent before the user sees any output.
agent_build_task = asyncio.create_task( agent = await _build_main_agent_for_thread(
_build_main_agent_for_thread( agent_factory,
agent_factory, llm=llm,
llm=llm, search_space_id=search_space_id,
search_space_id=search_space_id, db_session=session,
db_session=session, connector_service=connector_service,
connector_service=connector_service, checkpointer=checkpointer,
checkpointer=checkpointer, user_id=user_id,
user_id=user_id, thread_id=chat_id,
thread_id=chat_id, agent_config=agent_config,
agent_config=agent_config, firecrawl_api_key=firecrawl_api_key,
firecrawl_api_key=firecrawl_api_key, thread_visibility=visibility,
thread_visibility=visibility, filesystem_selection=filesystem_selection,
filesystem_selection=filesystem_selection, disabled_tools=disabled_tools,
disabled_tools=disabled_tools, mentioned_document_ids=mentioned_document_ids,
mentioned_document_ids=mentioned_document_ids,
),
name="agent_build:stream_new_chat",
) )
agent: Any = None
if preflight_task is not None:
try:
await preflight_task
mark_healthy(llm_config_id)
_perf_log.info(
"[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)",
llm_config_id,
time.perf_counter() - _t_preflight,
)
except Exception as preflight_exc:
# Both branches below need the session: the non-429 path
# may unwind via cleanup that uses ``session``, and the
# 429 path explicitly calls ``resolve_or_get_pinned_llm_config_id``
# against it. Wait for the speculative build to release its
# session usage before we proceed.
await _settle_speculative_agent_build(agent_build_task)
if not _is_provider_rate_limited(preflight_exc):
raise
# 429: speculative agent is discarded; run the original
# repin → reload → rebuild path against the freshly
# pinned config.
previous_config_id = llm_config_id
mark_runtime_cooldown(
previous_config_id, reason="preflight_rate_limited"
)
try:
llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=0,
exclude_config_ids={previous_config_id},
requires_image_input=_requires_image_input,
)
).resolved_llm_config_id
except ValueError as pin_error:
yield _emit_stream_error(
message=str(pin_error),
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
llm, agent_config, llm_load_error = await _load_llm_bundle(
llm_config_id
)
if llm_load_error or not llm:
yield _emit_stream_error(
message=llm_load_error or "Failed to create LLM instance",
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
# Trust the freshly-resolved cfg for the remainder of this
# turn rather than recursing into another preflight; the
# in-stream 429 recovery loop is still in place as the
# safety net if even this fallback hits an upstream cap.
mark_healthy(llm_config_id)
_log_chat_stream_error(
flow=flow,
error_kind="rate_limited",
error_code="RATE_LIMITED",
severity="info",
is_expected=True,
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
message=(
"Auto-pinned model failed preflight; switched to another "
"eligible model and continuing."
),
extra={
"auto_runtime_recover": True,
"preflight": True,
"previous_config_id": previous_config_id,
"fallback_config_id": llm_config_id,
},
)
# Rebuild against the new llm/agent_config. Sequential
# here because we no longer have anything to overlap with.
agent = await agent_factory(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
filesystem_selection=filesystem_selection,
)
if agent is None:
# Either no preflight was needed, or preflight succeeded —
# in both cases the speculative build is the agent we want.
agent = await agent_build_task
_perf_log.info( _perf_log.info(
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0 "[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
) )
@ -2647,25 +2431,6 @@ async def stream_resume_chat(
yield streaming_service.format_done() yield streaming_service.format_done()
return return
# Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``:
# one cheap probe before the agent is rebuilt so a 429'd pin gets
# repinned without burning planner/classifier/title calls first.
# See ``stream_new_chat`` for the full rationale on the speculative
# parallel build pattern below.
preflight_needed = (
requested_llm_config_id == 0
and llm_config_id < 0
and not is_recently_healthy(llm_config_id)
)
preflight_task: asyncio.Task[None] | None = None
_t_preflight = 0.0
if preflight_needed:
_t_preflight = time.perf_counter()
preflight_task = asyncio.create_task(
_preflight_llm(llm),
name=f"auto_pin_preflight_resume:{llm_config_id}",
)
_t0 = time.perf_counter() _t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id) connector_service = ConnectorService(session, search_space_id=search_space_id)
@ -2695,115 +2460,25 @@ async def stream_resume_chat(
if _app_config.MULTI_AGENT_CHAT_ENABLED if _app_config.MULTI_AGENT_CHAT_ENABLED
else create_surfsense_deep_agent else create_surfsense_deep_agent
) )
agent_build_task = asyncio.create_task( # Build the agent inline. Provider 429s are handled by the
_build_main_agent_for_thread( # in-stream recovery loop, which repins to an eligible
agent_factory, # alternative config and rebuilds the agent before the user sees
llm=llm, # any output.
search_space_id=search_space_id, agent = await _build_main_agent_for_thread(
db_session=session, agent_factory,
connector_service=connector_service, llm=llm,
checkpointer=checkpointer, search_space_id=search_space_id,
user_id=user_id, db_session=session,
thread_id=chat_id, connector_service=connector_service,
agent_config=agent_config, checkpointer=checkpointer,
firecrawl_api_key=firecrawl_api_key, user_id=user_id,
thread_visibility=visibility, thread_id=chat_id,
filesystem_selection=filesystem_selection, agent_config=agent_config,
disabled_tools=disabled_tools, firecrawl_api_key=firecrawl_api_key,
), thread_visibility=visibility,
name="agent_build:stream_resume", filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
) )
agent: Any = None
if preflight_task is not None:
try:
await preflight_task
mark_healthy(llm_config_id)
_perf_log.info(
"[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)",
llm_config_id,
time.perf_counter() - _t_preflight,
)
except Exception as preflight_exc:
# Same session-safety rationale as ``stream_new_chat``.
await _settle_speculative_agent_build(agent_build_task)
if not _is_provider_rate_limited(preflight_exc):
raise
previous_config_id = llm_config_id
mark_runtime_cooldown(
previous_config_id, reason="preflight_rate_limited"
)
try:
llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=0,
exclude_config_ids={previous_config_id},
)
).resolved_llm_config_id
except ValueError as pin_error:
yield _emit_stream_error(
message=str(pin_error),
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
llm, agent_config, llm_load_error = await _load_llm_bundle(
llm_config_id
)
if llm_load_error or not llm:
yield _emit_stream_error(
message=llm_load_error or "Failed to create LLM instance",
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
mark_healthy(llm_config_id)
_log_chat_stream_error(
flow="resume",
error_kind="rate_limited",
error_code="RATE_LIMITED",
severity="info",
is_expected=True,
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
message=(
"Auto-pinned model failed preflight; switched to another "
"eligible model and continuing."
),
extra={
"auto_runtime_recover": True,
"preflight": True,
"previous_config_id": previous_config_id,
"fallback_config_id": llm_config_id,
},
)
agent = await _build_main_agent_for_thread(
agent_factory,
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
)
if agent is None:
agent = await agent_build_task
_perf_log.info( _perf_log.info(
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0 "[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
) )

View file

@ -670,7 +670,9 @@ async def index_discord_messages(
# Heavy processing (embeddings, chunks) # Heavy processing (embeddings, chunks)
chunks = await create_document_chunks(item["combined_document_string"]) chunks = await create_document_chunks(item["combined_document_string"])
doc_embedding = embed_text(item["combined_document_string"]) doc_embedding = await asyncio.to_thread(
embed_text, item["combined_document_string"]
)
# Update document to READY with actual content # Update document to READY with actual content
document.title = f"{item['guild_name']}#{item['channel_name']}" document.title = f"{item['guild_name']}#{item['channel_name']}"

View file

@ -6,6 +6,7 @@ Implements 2-phase document status updates for real-time UI feedback:
- Phase 2: Process each event: pending → processing → ready/failed - Phase 2: Process each event: pending → processing → ready/failed
""" """
import asyncio
import time import time
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -465,7 +466,9 @@ async def index_luma_events(
summary_content = ( summary_content = (
f"Luma Event: {item['event_name']}\n\n{item['event_markdown']}" f"Luma Event: {item['event_name']}\n\n{item['event_markdown']}"
) )
summary_embedding = embed_text(summary_content) summary_embedding = await asyncio.to_thread(
embed_text, summary_content
)
chunks = await create_document_chunks(item["event_markdown"]) chunks = await create_document_chunks(item["event_markdown"])

View file

@ -9,6 +9,7 @@ Uses 2-phase document status updates for real-time UI feedback:
- Phase 2: Process each document: pending → processing → ready/failed - Phase 2: Process each document: pending → processing → ready/failed
""" """
import asyncio
import time import time
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from datetime import UTC, datetime from datetime import UTC, datetime
@ -581,7 +582,9 @@ async def index_teams_messages(
# Heavy processing (embeddings, chunks) # Heavy processing (embeddings, chunks)
chunks = await create_document_chunks(item["combined_document_string"]) chunks = await create_document_chunks(item["combined_document_string"])
doc_embedding = embed_text(item["combined_document_string"]) doc_embedding = await asyncio.to_thread(
embed_text, item["combined_document_string"]
)
# Update document to READY with actual content # Update document to READY with actual content
document.title = f"{item['team_name']} - {item['channel_name']}" document.title = f"{item['team_name']} - {item['channel_name']}"

View file

@ -2,6 +2,7 @@
Unified document save/update logic for file processors. Unified document save/update logic for file processors.
""" """
import asyncio
import logging import logging
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
@ -43,7 +44,7 @@ async def _generate_summary(
""" """
if not enable_summary: if not enable_summary:
summary = f"File: {file_name}\n\n{markdown_content[:4000]}" summary = f"File: {file_name}\n\n{markdown_content[:4000]}"
return summary, embed_text(summary) return summary, await asyncio.to_thread(embed_text, summary)
if etl_service == "DOCLING": if etl_service == "DOCLING":
from app.services.docling_service import create_docling_service from app.services.docling_service import create_docling_service
@ -65,7 +66,7 @@ async def _generate_summary(
parts.append(f"**{formatted_key}:** {value}") parts.append(f"**{formatted_key}:** {value}")
enhanced = "\n".join(parts) + "\n\n# DOCUMENT SUMMARY\n\n" + summary_text enhanced = "\n".join(parts) + "\n\n# DOCUMENT SUMMARY\n\n" + summary_text
return enhanced, embed_text(enhanced) return enhanced, await asyncio.to_thread(embed_text, enhanced)
# Standard summary (Unstructured / LlamaCloud / others) # Standard summary (Unstructured / LlamaCloud / others)
meta = { meta = {

View file

@ -1,3 +1,4 @@
import asyncio
import hashlib import hashlib
import logging import logging
import threading import threading
@ -221,7 +222,9 @@ async def generate_document_summary(
else: else:
enhanced_summary_content = summary_content enhanced_summary_content = summary_content
summary_embedding = embed_text(enhanced_summary_content) summary_embedding = await asyncio.to_thread(
embed_text, enhanced_summary_content
)
return enhanced_summary_content, summary_embedding return enhanced_summary_content, summary_embedding
@ -237,7 +240,7 @@ async def create_document_chunks(content: str) -> list[Chunk]:
List of Chunk objects with embeddings List of Chunk objects with embeddings
""" """
chunk_texts = [c.text for c in config.chunker_instance.chunk(content)] chunk_texts = [c.text for c in config.chunker_instance.chunk(content)]
chunk_embeddings = embed_texts(chunk_texts) chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts)
return [ return [
Chunk(content=text, embedding=emb) Chunk(content=text, embedding=emb)
for text, emb in zip(chunk_texts, chunk_embeddings, strict=False) for text, emb in zip(chunk_texts, chunk_embeddings, strict=False)

View file

@ -12,13 +12,19 @@ prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to
the deepagent stack accumulates multiple ``SystemMessage``\ s in the deepagent stack accumulates multiple ``SystemMessage``\ s in
``state["messages"]`` and ``role: system`` would tag every one of ``state["messages"]`` and ``role: system`` would tag every one of
them, blowing past Anthropic's 4-block ``cache_control`` cap. them, blowing past Anthropic's 4-block ``cache_control`` cap.
2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for 2. Adds ``prompt_cache_key`` for OPENAI/DEEPSEEK/XAI/AZURE/AZURE_OPENAI
single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic configs (Microsoft's Azure transformer was added to LiteLLM in
prompt-cache surface is available). https://github.com/BerriAI/litellm/pull/20989, Feb 2026).
3. Treats ``ChatLiteLLMRouter`` (auto-mode) as universal-only no 3. Adds ``prompt_cache_retention="24h"`` ONLY for OPENAI/DEEPSEEK/XAI.
OpenAI-only kwargs because the router fans out across providers. Azure's server-side support landed in Microsoft's docs on 2026-05-13
4. Idempotent: user-supplied values in ``model_kwargs`` are preserved. but LiteLLM 1.83.14 hasn't wired it through yet, so we let Azure use
5. Defensive: LLMs without a writable ``model_kwargs`` are silently its default in-memory retention rather than send a param that
``litellm.drop_params`` would silently strip.
4. Treats ``ChatLiteLLMRouter`` (auto-mode) as universal-only no
destination-specific kwargs because the router fans out across
providers.
5. Idempotent: user-supplied values in ``model_kwargs`` are preserved.
6. Defensive: LLMs without a writable ``model_kwargs`` are silently
skipped rather than raising. skipped rather than raising.
""" """
@ -191,9 +197,9 @@ def test_does_not_overwrite_user_supplied_prompt_cache_key() -> None:
@pytest.mark.parametrize("provider", ["OPENAI", "DEEPSEEK", "XAI"]) @pytest.mark.parametrize("provider", ["OPENAI", "DEEPSEEK", "XAI"])
def test_sets_openai_family_extras(provider: str) -> None: def test_sets_openai_family_extras(provider: str) -> None:
"""OpenAI-style providers gain ``prompt_cache_key`` (raises hit rate """Native OpenAI-style providers gain ``prompt_cache_key`` (raises
via routing affinity) and ``prompt_cache_retention="24h"`` (extends hit rate via routing affinity) and ``prompt_cache_retention="24h"``
cache TTL beyond the default 5-10 min).""" (extends cache TTL beyond the default 5-10 min)."""
cfg = _make_cfg(provider=provider) cfg = _make_cfg(provider=provider)
llm = _FakeLLM() llm = _FakeLLM()
@ -203,6 +209,27 @@ def test_sets_openai_family_extras(provider: str) -> None:
assert llm.model_kwargs["prompt_cache_retention"] == "24h" assert llm.model_kwargs["prompt_cache_retention"] == "24h"
@pytest.mark.parametrize("provider", ["AZURE", "AZURE_OPENAI"])
def test_azure_gets_prompt_cache_key_only(provider: str) -> None:
    """Azure-style configs receive the per-thread ``prompt_cache_key``
    but never ``prompt_cache_retention``.

    Rationale carried over from the original test: the key gives routing
    affinity for same-prefix requests, while the retention parameter is
    not in the Azure transformer's supported-params list in the pinned
    LiteLLM release, so ``drop_params`` would silently strip it. The
    universal ``cache_control_injection_points`` must still be set.
    """
    azure_cfg = _make_cfg(provider=provider, model_name="gpt-5.4")
    azure_llm = _FakeLLM(model="azure/gpt-5.4")

    apply_litellm_prompt_caching(azure_llm, agent_config=azure_cfg, thread_id=42)

    applied = azure_llm.model_kwargs
    assert applied["prompt_cache_key"] == "surfsense-thread-42"
    assert "prompt_cache_retention" not in applied
    assert "cache_control_injection_points" in applied
def test_skips_prompt_cache_key_when_no_thread_id() -> None: def test_skips_prompt_cache_key_when_no_thread_id() -> None:
"""Without a thread id we can't construct a per-thread key. Retention """Without a thread id we can't construct a per-thread key. Retention
is still useful so we set it (it's free).""" is still useful so we set it (it's free)."""
@ -215,12 +242,26 @@ def test_skips_prompt_cache_key_when_no_thread_id() -> None:
assert llm.model_kwargs["prompt_cache_retention"] == "24h" assert llm.model_kwargs["prompt_cache_retention"] == "24h"
def test_azure_skips_prompt_cache_key_when_no_thread_id() -> None:
    """With no thread id, an Azure config ends up with neither extra.

    The key cannot be built without a thread id and retention is skipped
    for Azure, so only the universal injection points are expected.
    """
    azure_cfg = _make_cfg(provider="AZURE", model_name="gpt-5.4")
    azure_llm = _FakeLLM(model="azure/gpt-5.4")

    apply_litellm_prompt_caching(azure_llm, agent_config=azure_cfg, thread_id=None)

    applied = azure_llm.model_kwargs
    for absent_key in ("prompt_cache_key", "prompt_cache_retention"):
        assert absent_key not in applied
    assert "cache_control_injection_points" in applied
@pytest.mark.parametrize( @pytest.mark.parametrize(
"provider", "provider",
["ANTHROPIC", "BEDROCK", "VERTEX_AI", "GOOGLE_AI_STUDIO", "GROQ", "MOONSHOT"], ["ANTHROPIC", "BEDROCK", "VERTEX_AI", "GOOGLE_AI_STUDIO", "GROQ", "MOONSHOT"],
) )
def test_no_openai_extras_for_other_providers(provider: str) -> None: def test_no_openai_extras_for_other_providers(provider: str) -> None:
"""Non-OpenAI-family providers don't expose ``prompt_cache_key`` — """Non-OpenAI-style providers don't expose ``prompt_cache_key`` —
skip it. ``cache_control_injection_points`` is still set (universal).""" skip it. ``cache_control_injection_points`` is still set (universal)."""
cfg = _make_cfg(provider=provider) cfg = _make_cfg(provider=provider)
llm = _FakeLLM() llm = _FakeLLM()

View file

@ -0,0 +1,130 @@
"""Unit tests for ``mcp_tools_cache``."""
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
import pytest
from app.agents.new_chat.tools.mcp_tools_cache import (
CachedMCPToolDef,
CachedMCPTools,
read_cached_tools,
)
pytestmark = pytest.mark.unit
def _make_connector(config: dict | None) -> SimpleNamespace:
return SimpleNamespace(id=42, config=config)
def test_read_returns_none_when_config_is_none() -> None:
    """A connector whose ``config`` is ``None`` yields no cached tools."""
    connector = _make_connector(None)
    result = read_cached_tools(connector)
    assert result is None
def test_read_returns_none_when_cached_tools_missing() -> None:
    """A config without a ``cached_tools`` key yields no cached tools."""
    config_without_cache = {"server_config": {}}
    result = read_cached_tools(_make_connector(config_without_cache))
    assert result is None
def test_read_returns_none_when_cached_tools_is_not_a_dict() -> None:
    """Non-dict ``cached_tools`` values (a list, a string) are rejected."""
    for bad_value in ([], "stale"):
        connector = _make_connector({"cached_tools": bad_value})
        assert read_cached_tools(connector) is None
def test_read_parses_minimal_valid_payload() -> None:
    """A payload carrying only ``discovered_at`` and ``tools`` parses, and
    the optional server metadata fields come back as ``None``."""
    tool_entry = {
        "name": "list_issues",
        "description": "List Linear issues",
        "input_schema": {"type": "object"},
    }
    cached_tools = {
        "discovered_at": "2026-05-20T10:00:00+00:00",
        "tools": [tool_entry],
    }

    parsed = read_cached_tools(_make_connector({"cached_tools": cached_tools}))

    assert parsed is not None
    for optional_field in ("server_version", "server_name", "transport"):
        assert getattr(parsed, optional_field) is None
    assert len(parsed.tools) == 1
    assert parsed.tools[0].name == "list_issues"
def test_read_parses_full_payload_with_serverinfo() -> None:
    """A payload that includes server metadata round-trips every field and
    keeps the tools in their stored order."""
    cached_tools = {
        "discovered_at": "2026-05-20T10:00:00+00:00",
        "server_version": "1.2.3",
        "server_name": "atlassian-mcp",
        "transport": "streamable-http",
        "tools": [
            {"name": "create_issue", "input_schema": {}},
            {"name": "list_issues", "input_schema": {}},
        ],
    }

    parsed = read_cached_tools(_make_connector({"cached_tools": cached_tools}))

    assert parsed is not None
    assert parsed.server_version == "1.2.3"
    assert parsed.server_name == "atlassian-mcp"
    assert parsed.transport == "streamable-http"
    tool_names = [tool.name for tool in parsed.tools]
    assert tool_names == ["create_issue", "list_issues"]
def test_read_returns_none_for_corrupt_payload(caplog) -> None:
    """A structurally invalid payload yields ``None`` and leaves a log
    record mentioning ``corrupt cached_tools``."""
    corrupt_cache = {
        "discovered_at": "not-a-date",
        "tools": "should-be-a-list",
    }

    parsed = read_cached_tools(_make_connector({"cached_tools": corrupt_cache}))

    assert parsed is None
    logged = [record.getMessage() for record in caplog.records]
    assert any("corrupt cached_tools" in message for message in logged)
def test_read_returns_none_when_tools_missing() -> None:
    """A ``cached_tools`` dict without a ``tools`` entry is rejected."""
    cache_without_tools = {"discovered_at": "2026-05-20T10:00:00+00:00"}
    connector = _make_connector({"cached_tools": cache_without_tools})
    assert read_cached_tools(connector) is None
def test_tool_def_defaults_description_and_schema() -> None:
    """Validating a tool definition that only has a ``name`` fills in an
    empty description and an empty input schema."""
    name_only = {"name": "ping"}
    tool_def = CachedMCPToolDef.model_validate(name_only)
    assert tool_def.description == ""
    assert tool_def.input_schema == {}
def test_model_dump_json_mode_is_round_trippable() -> None:
    """Dumping in JSON mode and validating the dump again preserves the
    timestamp and the tool list."""
    discovered = datetime(2026, 5, 20, 10, 0, 0, tzinfo=UTC)
    original = CachedMCPTools(
        discovered_at=discovered,
        server_version="1.2.3",
        server_name="atlassian-mcp",
        transport="streamable-http",
        tools=[CachedMCPToolDef(name="list_issues")],
    )

    dumped = original.model_dump(mode="json")
    assert dumped["discovered_at"] == "2026-05-20T10:00:00Z"
    assert dumped["tools"][0]["name"] == "list_issues"

    restored = CachedMCPTools.model_validate(dumped)
    assert restored.discovered_at == original.discovered_at
    assert restored.tools[0].name == "list_issues"

View file

@ -209,128 +209,6 @@ def test_stream_exception_classifies_openrouter_429_payload():
assert extra is None assert extra is None
@pytest.mark.asyncio
async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkeypatch):
    """``_preflight_llm`` is best-effort.

    - A rate-limit shaped exception (provider 429) MUST propagate so the
      caller can take its cooldown/repin branch.
    - Any other failure MUST be swallowed so the normal stream path
      continues without surfacing preflight noise.
    """
    from types import SimpleNamespace

    from app.tasks.chat.stream_new_chat import _preflight_llm

    class _RateLimitedError(Exception):
        """Class-name carries 'RateLimit' so _is_provider_rate_limited triggers."""

    def _recording_failure(calls: list[dict], error: Exception):
        # Build an ``acompletion`` stand-in that records its kwargs and
        # then fails with ``error``.
        async def _fake_acompletion(**kwargs):
            calls.append(kwargs)
            raise error

        return _fake_acompletion

    rate_calls: list[dict] = []
    other_calls: list[dict] = []
    fake_llm = SimpleNamespace(
        model="openrouter/google/gemma-4-31b-it:free",
        api_key="test",
        api_base=None,
    )

    import litellm  # type: ignore[import-not-found]

    # Rate-limited probe: the error must reach the caller.
    monkeypatch.setattr(
        litellm,
        "acompletion",
        _recording_failure(rate_calls, _RateLimitedError("simulated 429")),
    )
    with pytest.raises(_RateLimitedError):
        await _preflight_llm(fake_llm)
    assert len(rate_calls) == 1
    probe_kwargs = rate_calls[0]
    assert probe_kwargs["max_tokens"] == 1
    assert probe_kwargs["stream"] is False

    # Unrelated failure: MUST NOT raise.
    monkeypatch.setattr(
        litellm,
        "acompletion",
        _recording_failure(
            other_calls, RuntimeError("some unrelated transient failure")
        ),
    )
    await _preflight_llm(fake_llm)
    assert len(other_calls) == 1
@pytest.mark.asyncio
async def test_preflight_skipped_for_auto_router_model():
    """Router-mode ``model='auto'`` has no single deployment to probe, so
    the preflight helper is expected to return without raising."""
    from types import SimpleNamespace

    from app.tasks.chat.stream_new_chat import _preflight_llm

    router_llm = SimpleNamespace(model="auto", api_key="x", api_base=None)

    # Passing means the call completed without an exception.
    await _preflight_llm(router_llm)
@pytest.mark.asyncio
async def test_settle_speculative_agent_build_swallows_exceptions():
    """``_settle_speculative_agent_build`` MUST always return cleanly.

    The helper is awaited only so a discarded speculative build stops
    using the request-scoped session; the build's outcome does not
    matter. Whether the task fails, succeeds, or is still running when
    the helper is called, the helper must return and the task must be
    done afterwards.
    """
    import asyncio

    from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build

    async def _raises() -> None:
        raise RuntimeError("speculative build crashed")

    async def _succeeds() -> str:
        return "agent"

    async def _slow() -> None:
        await asyncio.sleep(0.05)

    for make_coro in (_raises, _succeeds, _slow):
        pending = asyncio.create_task(make_coro())
        await _settle_speculative_agent_build(pending)
        assert pending.done()
@pytest.mark.asyncio
async def test_settle_speculative_agent_build_handles_already_done_task():
    """Tasks that already finished (success or failure) settle without raising."""
    import asyncio

    from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build

    async def _ok() -> str:
        return "ok"

    async def _bad() -> None:
        raise ValueError("nope")

    ok_task = asyncio.create_task(_ok())
    bad_task = asyncio.create_task(_bad())

    # Yield to the loop twice so both tasks run to completion first.
    for _ in range(2):
        await asyncio.sleep(0)

    for finished in (ok_task, bad_task):
        await _settle_speculative_agent_build(finished)

    assert ok_task.result() == "ok"
    # The helper observes the failure but does not clear it, so the
    # original error is still retrievable from the task.
    assert isinstance(bad_task.exception(), ValueError)
def test_stream_exception_classifies_thread_busy(): def test_stream_exception_classifies_thread_busy():
exc = BusyError(request_id="thread-123") exc = BusyError(request_id="thread-123")
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(