mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-27 19:25:15 +02:00
Merge remote-tracking branch 'upstream/dev' into feat/opentelemetry
This commit is contained in:
commit
862381f4e6
45 changed files with 1126 additions and 661 deletions
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from deepagents.backends.protocol import BackendFactory, BackendProtocol
|
from deepagents.backends.protocol import BackendFactory, BackendProtocol
|
||||||
|
|
@ -15,8 +16,12 @@ from langchain.agents import create_agent
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
from langgraph.types import Checkpointer
|
from langgraph.types import Checkpointer
|
||||||
|
|
||||||
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
from .task_tool import build_task_tool_with_parent_config
|
from .task_tool import build_task_tool_with_parent_config
|
||||||
|
|
||||||
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
|
|
||||||
class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
|
class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
|
||||||
"""``SubAgentMiddleware`` variant that compiles each subagent against the parent checkpointer."""
|
"""``SubAgentMiddleware`` variant that compiles each subagent against the parent checkpointer."""
|
||||||
|
|
@ -54,8 +59,11 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
|
||||||
def _surf_compile_subagent_graphs(self) -> list[dict[str, Any]]:
|
def _surf_compile_subagent_graphs(self) -> list[dict[str, Any]]:
|
||||||
"""Mirror of ``SubAgentMiddleware._get_subagents`` that threads the parent checkpointer."""
|
"""Mirror of ``SubAgentMiddleware._get_subagents`` that threads the parent checkpointer."""
|
||||||
specs: list[dict[str, Any]] = []
|
specs: list[dict[str, Any]] = []
|
||||||
|
loop_start = time.perf_counter()
|
||||||
|
timings: list[tuple[str, float, str]] = [] # (name, elapsed, source)
|
||||||
|
|
||||||
for spec in self._subagents:
|
for spec in self._subagents:
|
||||||
|
spec_start = time.perf_counter()
|
||||||
if "runnable" in spec:
|
if "runnable" in spec:
|
||||||
compiled = cast(CompiledSubAgent, spec)
|
compiled = cast(CompiledSubAgent, spec)
|
||||||
specs.append(
|
specs.append(
|
||||||
|
|
@ -65,6 +73,9 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
|
||||||
"runnable": compiled["runnable"],
|
"runnable": compiled["runnable"],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
timings.append(
|
||||||
|
(compiled["name"], time.perf_counter() - spec_start, "precompiled")
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if "model" not in spec:
|
if "model" not in spec:
|
||||||
|
|
@ -79,20 +90,44 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
|
||||||
model = init_chat_model(model)
|
model = init_chat_model(model)
|
||||||
|
|
||||||
middleware: list[Any] = list(spec.get("middleware", []))
|
middleware: list[Any] = list(spec.get("middleware", []))
|
||||||
|
tools_count = len(spec.get("tools") or [])
|
||||||
|
mw_count = len(middleware)
|
||||||
|
|
||||||
|
compile_start = time.perf_counter()
|
||||||
|
runnable = create_agent(
|
||||||
|
model,
|
||||||
|
system_prompt=spec["system_prompt"],
|
||||||
|
tools=spec["tools"],
|
||||||
|
middleware=middleware,
|
||||||
|
name=spec["name"],
|
||||||
|
checkpointer=self._surf_checkpointer,
|
||||||
|
)
|
||||||
|
compile_elapsed = time.perf_counter() - compile_start
|
||||||
specs.append(
|
specs.append(
|
||||||
{
|
{
|
||||||
"name": spec["name"],
|
"name": spec["name"],
|
||||||
"description": spec["description"],
|
"description": spec["description"],
|
||||||
"runnable": create_agent(
|
"runnable": runnable,
|
||||||
model,
|
|
||||||
system_prompt=spec["system_prompt"],
|
|
||||||
tools=spec["tools"],
|
|
||||||
middleware=middleware,
|
|
||||||
name=spec["name"],
|
|
||||||
checkpointer=self._surf_checkpointer,
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
timings.append(
|
||||||
|
(
|
||||||
|
spec["name"],
|
||||||
|
compile_elapsed,
|
||||||
|
f"compiled tools={tools_count} mw={mw_count}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
total_elapsed = time.perf_counter() - loop_start
|
||||||
|
per_subagent = ", ".join(
|
||||||
|
f"{name}={elapsed * 1000:.0f}ms[{source}]"
|
||||||
|
for name, elapsed, source in timings
|
||||||
|
)
|
||||||
|
_perf_log.info(
|
||||||
|
"[subagent_compile] total=%.3fs count=%d details=[%s]",
|
||||||
|
total_elapsed,
|
||||||
|
len(timings),
|
||||||
|
per_subagent,
|
||||||
|
)
|
||||||
|
|
||||||
return specs
|
return specs
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ re-raises any new pending interrupt back to the parent.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from typing import Annotated, Any, NoReturn
|
from typing import Annotated, Any, NoReturn
|
||||||
|
|
||||||
from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION
|
from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION
|
||||||
|
|
@ -19,6 +20,8 @@ from langchain_core.tools import StructuredTool
|
||||||
from langgraph.errors import GraphInterrupt
|
from langgraph.errors import GraphInterrupt
|
||||||
from langgraph.types import Command, Interrupt
|
from langgraph.types import Command, Interrupt
|
||||||
|
|
||||||
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
from .config import (
|
from .config import (
|
||||||
consume_surfsense_resume,
|
consume_surfsense_resume,
|
||||||
drain_parent_null_resume,
|
drain_parent_null_resume,
|
||||||
|
|
@ -35,6 +38,7 @@ from .resume import (
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
|
|
||||||
def _reraise_stamped_subagent_interrupt(
|
def _reraise_stamped_subagent_interrupt(
|
||||||
|
|
@ -209,6 +213,7 @@ def build_task_tool_with_parent_config(
|
||||||
],
|
],
|
||||||
runtime: ToolRuntime,
|
runtime: ToolRuntime,
|
||||||
) -> str | Command:
|
) -> str | Command:
|
||||||
|
atask_start = time.perf_counter()
|
||||||
logger.info(
|
logger.info(
|
||||||
"[hitl_route] atask ENTRY: subagent_type=%r tool_call_id=%s",
|
"[hitl_route] atask ENTRY: subagent_type=%r tool_call_id=%s",
|
||||||
subagent_type,
|
subagent_type,
|
||||||
|
|
@ -230,8 +235,10 @@ def build_task_tool_with_parent_config(
|
||||||
# Resume bridge — see ``task`` above.
|
# Resume bridge — see ``task`` above.
|
||||||
pending_id: str | None = None
|
pending_id: str | None = None
|
||||||
pending_value: Any = None
|
pending_value: Any = None
|
||||||
|
aget_state_elapsed = 0.0
|
||||||
aget_state = getattr(subagent, "aget_state", None)
|
aget_state = getattr(subagent, "aget_state", None)
|
||||||
if callable(aget_state):
|
if callable(aget_state):
|
||||||
|
aget_state_start = time.perf_counter()
|
||||||
try:
|
try:
|
||||||
snapshot = await aget_state(sub_config)
|
snapshot = await aget_state(sub_config)
|
||||||
pending_id, pending_value = get_first_pending_subagent_interrupt(
|
pending_id, pending_value = get_first_pending_subagent_interrupt(
|
||||||
|
|
@ -248,32 +255,78 @@ def build_task_tool_with_parent_config(
|
||||||
"Subagent aget_state failed; falling back to fresh ainvoke",
|
"Subagent aget_state failed; falling back to fresh ainvoke",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
aget_state_elapsed = time.perf_counter() - aget_state_start
|
||||||
|
|
||||||
if pending_value is not None:
|
invoke_path = "resume" if pending_value is not None else "fresh"
|
||||||
resume_value = consume_surfsense_resume(runtime)
|
ainvoke_start = time.perf_counter()
|
||||||
if resume_value is None:
|
ainvoke_outcome = "ok"
|
||||||
raise RuntimeError(
|
try:
|
||||||
f"Subagent {subagent_type!r} has a pending interrupt but no "
|
if pending_value is not None:
|
||||||
"surfsense_resume_value on config; resume bridge is broken."
|
resume_value = consume_surfsense_resume(runtime)
|
||||||
)
|
if resume_value is None:
|
||||||
expected = hitlrequest_action_count(pending_value)
|
raise RuntimeError(
|
||||||
resume_value = fan_out_decisions_to_match(resume_value, expected)
|
f"Subagent {subagent_type!r} has a pending interrupt but no "
|
||||||
# Prevent the parent's resume payload from leaking into subagent
|
"surfsense_resume_value on config; resume bridge is broken."
|
||||||
# interrupts via langgraph's parent_scratchpad fallback.
|
)
|
||||||
drain_parent_null_resume(runtime)
|
expected = hitlrequest_action_count(pending_value)
|
||||||
try:
|
resume_value = fan_out_decisions_to_match(resume_value, expected)
|
||||||
result = await subagent.ainvoke(
|
# Prevent the parent's resume payload from leaking into subagent
|
||||||
build_resume_command(resume_value, pending_id),
|
# interrupts via langgraph's parent_scratchpad fallback.
|
||||||
config=sub_config,
|
drain_parent_null_resume(runtime)
|
||||||
)
|
try:
|
||||||
except GraphInterrupt as gi:
|
result = await subagent.ainvoke(
|
||||||
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
|
build_resume_command(resume_value, pending_id),
|
||||||
else:
|
config=sub_config,
|
||||||
try:
|
)
|
||||||
result = await subagent.ainvoke(subagent_state, config=sub_config)
|
except GraphInterrupt as gi:
|
||||||
except GraphInterrupt as gi:
|
ainvoke_outcome = "interrupted"
|
||||||
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
|
_perf_log.info(
|
||||||
return _return_command_with_state_update(result, runtime.tool_call_id)
|
"[hitl_route] atask EXIT subagent_type=%r path=%s outcome=%s "
|
||||||
|
"aget_state=%.3fs ainvoke=%.3fs total=%.3fs",
|
||||||
|
subagent_type,
|
||||||
|
invoke_path,
|
||||||
|
ainvoke_outcome,
|
||||||
|
aget_state_elapsed,
|
||||||
|
time.perf_counter() - ainvoke_start,
|
||||||
|
time.perf_counter() - atask_start,
|
||||||
|
)
|
||||||
|
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
result = await subagent.ainvoke(subagent_state, config=sub_config)
|
||||||
|
except GraphInterrupt as gi:
|
||||||
|
ainvoke_outcome = "interrupted"
|
||||||
|
_perf_log.info(
|
||||||
|
"[hitl_route] atask EXIT subagent_type=%r path=%s outcome=%s "
|
||||||
|
"aget_state=%.3fs ainvoke=%.3fs total=%.3fs",
|
||||||
|
subagent_type,
|
||||||
|
invoke_path,
|
||||||
|
ainvoke_outcome,
|
||||||
|
aget_state_elapsed,
|
||||||
|
time.perf_counter() - ainvoke_start,
|
||||||
|
time.perf_counter() - atask_start,
|
||||||
|
)
|
||||||
|
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
|
||||||
|
ainvoke_elapsed = time.perf_counter() - ainvoke_start
|
||||||
|
except GraphInterrupt:
|
||||||
|
raise
|
||||||
|
|
||||||
|
merge_start = time.perf_counter()
|
||||||
|
cmd = _return_command_with_state_update(result, runtime.tool_call_id)
|
||||||
|
merge_elapsed = time.perf_counter() - merge_start
|
||||||
|
_perf_log.info(
|
||||||
|
"[hitl_route] atask EXIT subagent_type=%r path=%s outcome=%s "
|
||||||
|
"aget_state=%.3fs ainvoke=%.3fs merge=%.3fs total=%.3fs",
|
||||||
|
subagent_type,
|
||||||
|
invoke_path,
|
||||||
|
ainvoke_outcome,
|
||||||
|
aget_state_elapsed,
|
||||||
|
ainvoke_elapsed,
|
||||||
|
merge_elapsed,
|
||||||
|
time.perf_counter() - atask_start,
|
||||||
|
)
|
||||||
|
return cmd
|
||||||
|
|
||||||
return StructuredTool.from_function(
|
return StructuredTool.from_function(
|
||||||
name="task",
|
name="task",
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from langchain_core.language_models import BaseChatModel
|
||||||
|
|
||||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
from app.agents.new_chat.middleware import KnowledgePriorityMiddleware
|
from app.agents.new_chat.middleware import KnowledgePriorityMiddleware
|
||||||
|
from app.services.llm_service import get_planner_llm
|
||||||
|
|
||||||
|
|
||||||
def build_knowledge_priority_mw(
|
def build_knowledge_priority_mw(
|
||||||
|
|
@ -19,6 +20,7 @@ def build_knowledge_priority_mw(
|
||||||
) -> KnowledgePriorityMiddleware:
|
) -> KnowledgePriorityMiddleware:
|
||||||
return KnowledgePriorityMiddleware(
|
return KnowledgePriorityMiddleware(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
|
planner_llm=get_planner_llm(),
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
filesystem_mode=filesystem_mode,
|
filesystem_mode=filesystem_mode,
|
||||||
available_connectors=available_connectors,
|
available_connectors=available_connectors,
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
|
|
@ -10,6 +11,9 @@ from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
||||||
from app.agents.new_chat.middleware.knowledge_search import _render_priority_message
|
from app.agents.new_chat.middleware.knowledge_search import _render_priority_message
|
||||||
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
|
|
||||||
class KbContextProjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
class KbContextProjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
|
@ -30,17 +34,34 @@ class KbContextProjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
runtime: Runtime[Any],
|
runtime: Runtime[Any],
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
del runtime
|
del runtime
|
||||||
|
start = time.perf_counter()
|
||||||
tree_text = state.get("workspace_tree_text")
|
tree_text = state.get("workspace_tree_text")
|
||||||
priority = state.get("kb_priority")
|
priority = state.get("kb_priority")
|
||||||
if not tree_text and not priority:
|
if not tree_text and not priority:
|
||||||
|
_perf_log.info(
|
||||||
|
"[kb_context_projection] tree=0 priority=0 elapsed=%.3fs",
|
||||||
|
time.perf_counter() - start,
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
messages = list(state.get("messages") or [])
|
messages = list(state.get("messages") or [])
|
||||||
insert_at = max(len(messages) - 1, 0)
|
insert_at = max(len(messages) - 1, 0)
|
||||||
|
tree_chars = 0
|
||||||
if tree_text:
|
if tree_text:
|
||||||
|
tree_chars = len(tree_text)
|
||||||
messages.insert(insert_at, SystemMessage(content=tree_text))
|
messages.insert(insert_at, SystemMessage(content=tree_text))
|
||||||
|
priority_count = 0
|
||||||
if priority:
|
if priority:
|
||||||
|
priority_count = (
|
||||||
|
len(priority) if hasattr(priority, "__len__") else 1
|
||||||
|
)
|
||||||
messages.insert(insert_at, _render_priority_message(priority))
|
messages.insert(insert_at, _render_priority_message(priority))
|
||||||
|
_perf_log.info(
|
||||||
|
"[kb_context_projection] tree_chars=%d priority_items=%d elapsed=%.3fs",
|
||||||
|
tree_chars,
|
||||||
|
priority_count,
|
||||||
|
time.perf_counter() - start,
|
||||||
|
)
|
||||||
return {"messages": messages}
|
return {"messages": messages}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -118,5 +118,6 @@ Rules:
|
||||||
- `status=success` → `next_step=null`, `missing_fields=null`.
|
- `status=success` → `next_step=null`, `missing_fields=null`.
|
||||||
- `status=partial|blocked|error` → `next_step` must be non-null.
|
- `status=partial|blocked|error` → `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
||||||
|
- `evidence.content_excerpt`: max ~500 characters. Surface a short excerpt or a one-sentence summary, not the full file body. The supervisor already sees the tool's raw output.
|
||||||
|
|
||||||
Infer before you call; map every tool outcome faithfully.
|
Infer before you call; map every tool outcome faithfully.
|
||||||
|
|
|
||||||
|
|
@ -118,5 +118,6 @@ Rules:
|
||||||
- `status=success` → `next_step=null`, `missing_fields=null`.
|
- `status=success` → `next_step=null`, `missing_fields=null`.
|
||||||
- `status=partial|blocked|error` → `next_step` must be non-null.
|
- `status=partial|blocked|error` → `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
||||||
|
- `evidence.content_excerpt`: max ~500 characters. Surface a short excerpt or a one-sentence summary, not the full file body. The supervisor already sees the tool's raw output.
|
||||||
|
|
||||||
Infer before you call; map every tool outcome faithfully.
|
Infer before you call; map every tool outcome faithfully.
|
||||||
|
|
|
||||||
|
|
@ -50,4 +50,6 @@ Rules:
|
||||||
- `status=success` -> `next_step=null`, `missing_fields=null`.
|
- `status=success` -> `next_step=null`, `missing_fields=null`.
|
||||||
- `status=partial|blocked|error` -> `next_step` must be non-null.
|
- `status=partial|blocked|error` -> `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
|
||||||
|
- `evidence.findings`: max 10 entries, each a single sentence stating one distinct fact. Do not paste raw paragraphs, scraped pages, or quote blocks.
|
||||||
|
- `evidence.sources`: max 10 URLs, one per finding when applicable. List each URL once.
|
||||||
</output_contract>
|
</output_contract>
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ Supervisor: "List open tasks in the Project Tracker base."
|
||||||
2. List tables in that base → identify the Tasks table; capture its table ID.
|
2. List tables in that base → identify the Tasks table; capture its table ID.
|
||||||
3. Get table schema → identify the status field and the choice IDs that represent "open" states.
|
3. Get table schema → identify the status field and the choice IDs that represent "open" states.
|
||||||
4. List records with a typed filter on the status field for those choice IDs.
|
4. List records with a typed filter on the status field for those choice IDs.
|
||||||
5. Return `status=success` with the matched records in `evidence.items`.
|
5. Return `status=success` with `evidence.items` set to `{ "total": N }` and the matched records listed in `action_summary` (record id, primary-field value, and 1-2 most relevant fields; one line per record; up to 10 entries, then `"...and N more"`).
|
||||||
</example>
|
</example>
|
||||||
|
|
||||||
<example>
|
<example>
|
||||||
|
|
@ -97,7 +97,7 @@ Rules:
|
||||||
- `status=partial|blocked|error` → `next_step` must be non-null.
|
- `status=partial|blocked|error` → `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
||||||
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: base, table, field, choice, record, etc.).
|
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: base, table, field, choice, record, etc.).
|
||||||
- For discovery-only queries (lists), populate `evidence.items` with the structured list.
|
- For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (record id, primary-field value, and 1-2 most relevant fields; up to 10 entries, then `"...and N more"`).
|
||||||
</output_contract>
|
</output_contract>
|
||||||
|
|
||||||
Discover before you mutate; never guess identifiers, choice IDs, or required fields.
|
Discover before you mutate; never guess identifiers, choice IDs, or required fields.
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ You are a Google Calendar specialist for the user's connected calendar.
|
||||||
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
|
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
|
||||||
| tool raises / unknown | `error` | `"Calendar tool failed unexpectedly. Ask the user to retry shortly."` |
|
| tool raises / unknown | `error` | `"Calendar tool failed unexpectedly. Ask the user to retry shortly."` |
|
||||||
|
|
||||||
Surface the tool's `event_id`, `title` / `summary`, `start_at`, `end_at`, and `html_link` inside `evidence` when the tool returned them. For `search_calendar_events`, place the raw `events` array inside `evidence.items`. Never invent a field the tool did not return.
|
Surface the tool's `event_id`, `title` / `summary`, `start_at`, `end_at`, and `html_link` inside `evidence` when the tool returned them. For `search_calendar_events`, set `evidence.items` to `{ "total": N }` and list the matched events in `action_summary` (title, date, start time; one line per event; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
|
|
@ -115,7 +115,7 @@ Rules:
|
||||||
- `status=success` → `next_step=null`, `missing_fields=null`.
|
- `status=success` → `next_step=null`, `missing_fields=null`.
|
||||||
- `status=partial|blocked|error` → `next_step` must be non-null.
|
- `status=partial|blocked|error` → `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
||||||
- For `search_calendar_events` results, populate `evidence.items` with `{ "events": [...], "total": N }`.
|
- For `search_calendar_events` results, set `evidence.items` to `{ "total": N }` and list the matched events in `action_summary` (title, date, start time; up to 10 entries, then `"...and N more"`).
|
||||||
- For ambiguous matches across `update_calendar_event` / `delete_calendar_event`, populate `evidence.matched_candidates` with up to 5 options (`id` + `label`, where `label` should include the event title and start time for human readability).
|
- For ambiguous matches across `update_calendar_event` / `delete_calendar_event`, populate `evidence.matched_candidates` with up to 5 options (`id` + `label`, where `label` should include the event title and start time for human readability).
|
||||||
|
|
||||||
Infer before you call; map every tool outcome faithfully.
|
Infer before you call; map every tool outcome faithfully.
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ Failure handling:
|
||||||
<example>
|
<example>
|
||||||
Supervisor: "Find tasks about the homepage redesign."
|
Supervisor: "Find tasks about the homepage redesign."
|
||||||
1. Workspace search for "homepage redesign" → matched tasks.
|
1. Workspace search for "homepage redesign" → matched tasks.
|
||||||
2. Return `status=success` with the matched tasks in `evidence.items`.
|
2. Return `status=success` with `evidence.items` set to `{ "total": N }` and the matched tasks listed in `action_summary` (task id, title, status, assignees; one line per task; up to 10 entries, then `"...and N more"`).
|
||||||
</example>
|
</example>
|
||||||
|
|
||||||
<example>
|
<example>
|
||||||
|
|
@ -98,7 +98,7 @@ Rules:
|
||||||
- `status=partial|blocked|error` → `next_step` must be non-null.
|
- `status=partial|blocked|error` → `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
||||||
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: task, list, member, status, custom-field choice, etc.).
|
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: task, list, member, status, custom-field choice, etc.).
|
||||||
- For discovery-only queries (lists), populate `evidence.items` with the structured list.
|
- For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (task id, title, status, assignees; up to 10 entries, then `"...and N more"`).
|
||||||
</output_contract>
|
</output_contract>
|
||||||
|
|
||||||
Discover before you mutate; never guess identifiers, list statuses, or assignees.
|
Discover before you mutate; never guess identifiers, list statuses, or assignees.
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ You are a Discord specialist for the user's connected Discord server.
|
||||||
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
|
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
|
||||||
| tool raises / unknown | `error` | `"Discord tool failed unexpectedly. Ask the user to retry shortly."` |
|
| tool raises / unknown | `error` | `"Discord tool failed unexpectedly. Ask the user to retry shortly."` |
|
||||||
|
|
||||||
Surface the tool's `message`, `channel_id`, `message_id`, and the listed channels/messages payload inside `evidence` when the tool returned them. Never invent a field the tool did not return.
|
Surface the tool's `message`, `channel_id`, and `message_id` inside `evidence` when the tool returned them. For `list_discord_channels` and `read_discord_messages`, set `evidence.items` to `{ "total": N }` and list the matched entries in `action_summary` (channel name or sender + timestamp + short text snippet; one line per entry; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ You are a Gmail specialist for the user's connected Gmail mailbox.
|
||||||
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
|
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
|
||||||
| tool raises / unknown | `error` | `"Gmail tool failed unexpectedly. Ask the user to retry shortly."` |
|
| tool raises / unknown | `error` | `"Gmail tool failed unexpectedly. Ask the user to retry shortly."` |
|
||||||
|
|
||||||
Surface the tool's `message_id`, `thread_id`, `draft_id`, `subject`, and recipient fields inside `evidence` when the tool returned them. For `search_gmail`, place the raw `emails` array inside `evidence.items`. Never invent a field the tool did not return.
|
Surface the tool's `message_id`, `thread_id`, `draft_id`, `subject`, and recipient fields inside `evidence` when the tool returned them. For `search_gmail`, set `evidence.items` to `{ "total": N }` and list the matched emails in `action_summary` (sender, subject, date; one line per email; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
|
|
@ -114,7 +114,7 @@ Rules:
|
||||||
- `status=success` → `next_step=null`, `missing_fields=null`.
|
- `status=success` → `next_step=null`, `missing_fields=null`.
|
||||||
- `status=partial|blocked|error` → `next_step` must be non-null.
|
- `status=partial|blocked|error` → `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
||||||
- For `search_gmail` results, populate `evidence.items` with `{ "emails": [...], "total": N }`.
|
- For `search_gmail` results, set `evidence.items` to `{ "total": N }` and list the matched emails in `action_summary` (sender, subject, date; up to 10 entries, then `"...and N more"`).
|
||||||
- For ambiguous matches across `update_gmail_draft` / `trash_gmail_email` / `read_gmail_email`, populate `evidence.matched_candidates` with up to 5 options (`id` + `label`).
|
- For ambiguous matches across `update_gmail_draft` / `trash_gmail_email` / `read_gmail_email`, populate `evidence.matched_candidates` with up to 5 options (`id` + `label`).
|
||||||
|
|
||||||
Infer before you call; verify before you send; map every tool outcome faithfully.
|
Infer before you call; verify before you send; map every tool outcome faithfully.
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ Failure handling:
|
||||||
<example>
|
<example>
|
||||||
Supervisor: "Find issues assigned to me with status 'In Progress'."
|
Supervisor: "Find issues assigned to me with status 'In Progress'."
|
||||||
1. JQL search with `assignee = currentUser() AND status = "In Progress"`.
|
1. JQL search with `assignee = currentUser() AND status = "In Progress"`.
|
||||||
2. Return `status=success` with the matched issues in `evidence.items`.
|
2. Return `status=success` with `evidence.items` set to `{ "total": N }` and the matched issues listed in `action_summary` (issue key, summary, status, assignee; one line per issue; up to 10 entries, then `"...and N more"`).
|
||||||
</example>
|
</example>
|
||||||
|
|
||||||
<example>
|
<example>
|
||||||
|
|
@ -116,7 +116,7 @@ Rules:
|
||||||
- `status=partial|blocked|error` → `next_step` must be non-null.
|
- `status=partial|blocked|error` → `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
||||||
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: site, project, issue, user, transition, etc.).
|
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: site, project, issue, user, transition, etc.).
|
||||||
- For discovery-only queries (lists), populate `evidence.items` with the structured list.
|
- For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (issue key, summary, status, assignee; up to 10 entries, then `"...and N more"`).
|
||||||
</output_contract>
|
</output_contract>
|
||||||
|
|
||||||
Discover before you mutate; never guess identifiers, transitions, or required fields.
|
Discover before you mutate; never guess identifiers, transitions, or required fields.
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ Failure handling:
|
||||||
<example>
|
<example>
|
||||||
Supervisor: "Find issues assigned to me with priority Urgent."
|
Supervisor: "Find issues assigned to me with priority Urgent."
|
||||||
1. Discovery: list issues with filters `{assignee: "me", priority: 1}`.
|
1. Discovery: list issues with filters `{assignee: "me", priority: 1}`.
|
||||||
2. Return `status=success` with the matched issues in `evidence.items`.
|
2. Return `status=success` with `evidence.items` set to `{ "total": N }` and the matched issues listed in `action_summary` (identifier, title, state, assignee; one line per issue; up to 10 entries, then `"...and N more"`).
|
||||||
</example>
|
</example>
|
||||||
|
|
||||||
<example>
|
<example>
|
||||||
|
|
@ -106,7 +106,7 @@ Rules:
|
||||||
- `status=partial|blocked|error` → `next_step` must be non-null.
|
- `status=partial|blocked|error` → `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
||||||
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: issue, user, project, state, etc.).
|
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: issue, user, project, state, etc.).
|
||||||
- For discovery-only queries (lists), populate `evidence.items` with the structured list.
|
- For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (identifier, title, state, assignee; up to 10 entries, then `"...and N more"`).
|
||||||
</output_contract>
|
</output_contract>
|
||||||
|
|
||||||
Discover before you mutate; never guess identifiers.
|
Discover before you mutate; never guess identifiers.
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ You are a Luma specialist for the user's connected Luma account.
|
||||||
| `error` | `error` | Relay the tool's `message` verbatim as `next_step` (this covers Luma Plus 403s and other API errors). |
|
| `error` | `error` | Relay the tool's `message` verbatim as `next_step` (this covers Luma Plus 403s and other API errors). |
|
||||||
| tool raises / unknown | `error` | `"Luma tool failed unexpectedly. Ask the user to retry shortly."` |
|
| tool raises / unknown | `error` | `"Luma tool failed unexpectedly. Ask the user to retry shortly."` |
|
||||||
|
|
||||||
Surface the tool's `message`, `event_id`, `name`, `start_at`, and `url` inside `evidence` when the tool returned them. Never invent a field the tool did not return.
|
Surface the tool's `message`, `event_id`, `name`, `start_at`, and `url` inside `evidence` when the tool returned them. For `list_luma_events`, set `evidence.items` to `{ "total": N }` and list the matched events in `action_summary` (event name, start date/time, location if present; one line per event; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@ Failure handling:
|
||||||
Supervisor: "Summarize the latest discussion in #marketing."
|
Supervisor: "Summarize the latest discussion in #marketing."
|
||||||
1. Search channels for "marketing" → one strong match. Capture the channel ID.
|
1. Search channels for "marketing" → one strong match. Capture the channel ID.
|
||||||
2. Read that channel's recent message history.
|
2. Read that channel's recent message history.
|
||||||
3. Return `status=success` with the message list in `evidence.items`.
|
3. Return `status=success` with `evidence.items` set to `{ "total": N }` and the messages listed in `action_summary` (sender, timestamp, text snippet; one line per message; up to 10 entries, then `"...and N more"`).
|
||||||
</example>
|
</example>
|
||||||
|
|
||||||
<example>
|
<example>
|
||||||
|
|
@ -92,7 +92,7 @@ Rules:
|
||||||
- `status=partial|blocked|error` → `next_step` must be non-null.
|
- `status=partial|blocked|error` → `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
||||||
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: channel, user, message, thread).
|
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: channel, user, message, thread).
|
||||||
- For discovery-only queries (lists), populate `evidence.items` with the structured list.
|
- For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (channel/user, key identifier, timestamp, short snippet; up to 10 entries, then `"...and N more"`).
|
||||||
</output_contract>
|
</output_contract>
|
||||||
|
|
||||||
Discover before you post; never guess channel, user, or thread targets.
|
Discover before you post; never guess channel, user, or thread targets.
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ You are a Microsoft Teams specialist for the user's connected Teams account.
|
||||||
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
|
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
|
||||||
| tool raises / unknown | `error` | `"Teams tool failed unexpectedly. Ask the user to retry shortly."` |
|
| tool raises / unknown | `error` | `"Teams tool failed unexpectedly. Ask the user to retry shortly."` |
|
||||||
|
|
||||||
Surface the tool's `message`, `team_id`, `team_name`, `channel_id`, `channel_name`, and `message_id` inside `evidence` when the tool returned them. Never invent a field the tool did not return.
|
Surface the tool's `message`, `team_id`, `team_name`, `channel_id`, `channel_name`, and `message_id` inside `evidence` when the tool returned them. For `list_teams_channels` and `read_teams_messages`, set `evidence.items` to `{ "total": N }` and list the matched entries in `action_summary` (team › channel, or sender + timestamp + short text snippet; one line per entry; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -102,6 +102,7 @@ from app.agents.new_chat.tools.registry import (
|
||||||
)
|
)
|
||||||
from app.db import ChatVisibility
|
from app.db import ChatVisibility
|
||||||
from app.services.connector_service import ConnectorService
|
from app.services.connector_service import ConnectorService
|
||||||
|
from app.services.llm_service import get_planner_llm
|
||||||
from app.utils.perf import get_perf_logger
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
_perf_log = get_perf_logger()
|
_perf_log = get_perf_logger()
|
||||||
|
|
@ -1077,6 +1078,7 @@ def _build_compiled_agent_blocking(
|
||||||
else None,
|
else None,
|
||||||
KnowledgePriorityMiddleware(
|
KnowledgePriorityMiddleware(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
|
planner_llm=get_planner_llm(),
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
filesystem_mode=filesystem_mode,
|
filesystem_mode=filesystem_mode,
|
||||||
available_connectors=available_connectors,
|
available_connectors=available_connectors,
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@ exact same routine when ``aafter_agent`` was skipped (e.g. client disconnect).
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
@ -249,11 +250,11 @@ async def _create_document(
|
||||||
session.add(doc)
|
session.add(doc)
|
||||||
await session.flush()
|
await session.flush()
|
||||||
|
|
||||||
summary_embedding = embed_texts([content])[0]
|
summary_embedding = (await asyncio.to_thread(embed_texts, [content]))[0]
|
||||||
doc.embedding = summary_embedding
|
doc.embedding = summary_embedding
|
||||||
chunks = chunk_text(content)
|
chunks = chunk_text(content)
|
||||||
if chunks:
|
if chunks:
|
||||||
chunk_embeddings = embed_texts(chunks)
|
chunk_embeddings = await asyncio.to_thread(embed_texts, chunks)
|
||||||
session.add_all(
|
session.add_all(
|
||||||
[
|
[
|
||||||
Chunk(document_id=doc.id, content=text, embedding=embedding)
|
Chunk(document_id=doc.id, content=text, embedding=embedding)
|
||||||
|
|
@ -295,13 +296,13 @@ async def _update_document(
|
||||||
search_space_id,
|
search_space_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
summary_embedding = embed_texts([content])[0]
|
summary_embedding = (await asyncio.to_thread(embed_texts, [content]))[0]
|
||||||
document.embedding = summary_embedding
|
document.embedding = summary_embedding
|
||||||
|
|
||||||
await session.execute(delete(Chunk).where(Chunk.document_id == document.id))
|
await session.execute(delete(Chunk).where(Chunk.document_id == document.id))
|
||||||
chunks = chunk_text(content)
|
chunks = chunk_text(content)
|
||||||
if chunks:
|
if chunks:
|
||||||
chunk_embeddings = embed_texts(chunks)
|
chunk_embeddings = await asyncio.to_thread(embed_texts, chunks)
|
||||||
session.add_all(
|
session.add_all(
|
||||||
[
|
[
|
||||||
Chunk(document_id=document.id, content=text, embedding=embedding)
|
Chunk(document_id=document.id, content=text, embedding=embedding)
|
||||||
|
|
|
||||||
|
|
@ -457,7 +457,7 @@ async def search_knowledge_base(
|
||||||
if not query:
|
if not query:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
[embedding] = embed_texts([query])
|
[embedding] = await asyncio.to_thread(embed_texts, [query])
|
||||||
doc_types = _resolve_search_types(available_connectors, available_document_types)
|
doc_types = _resolve_search_types(available_connectors, available_document_types)
|
||||||
retriever_top_k = min(top_k * 3, 30)
|
retriever_top_k = min(top_k * 3, 30)
|
||||||
|
|
||||||
|
|
@ -579,6 +579,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
llm: BaseChatModel | None = None,
|
llm: BaseChatModel | None = None,
|
||||||
|
planner_llm: BaseChatModel | None = None,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
||||||
available_connectors: list[str] | None = None,
|
available_connectors: list[str] | None = None,
|
||||||
|
|
@ -588,6 +589,15 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
inject_system_message: bool = True, # For backwards compatibility
|
inject_system_message: bool = True, # For backwards compatibility
|
||||||
) -> None:
|
) -> None:
|
||||||
self.llm = llm
|
self.llm = llm
|
||||||
|
# The planner LLM handles short, structured internal tasks (query
|
||||||
|
# rewriting, date extraction, recency classification). When an
|
||||||
|
# operator marks a global config ``is_planner: true`` we route
|
||||||
|
# those calls to a cheap/fast model (e.g. gpt-4o-mini, Haiku, Azure
|
||||||
|
# gpt-5.x-nano) instead of the user's chat LLM — those classification
|
||||||
|
# tasks don't need frontier-tier capability. Falls back to the chat
|
||||||
|
# LLM when no planner config is wired up so deployments without one
|
||||||
|
# keep working unchanged.
|
||||||
|
self.planner_llm = planner_llm or llm
|
||||||
self.search_space_id = search_space_id
|
self.search_space_id = search_space_id
|
||||||
self.filesystem_mode = filesystem_mode
|
self.filesystem_mode = filesystem_mode
|
||||||
self.available_connectors = available_connectors
|
self.available_connectors = available_connectors
|
||||||
|
|
@ -598,7 +608,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
# Build the kb-planner private Runnable ONCE here so we don't pay
|
# Build the kb-planner private Runnable ONCE here so we don't pay
|
||||||
# the ``create_agent`` compile cost (50-200ms) on every turn.
|
# the ``create_agent`` compile cost (50-200ms) on every turn.
|
||||||
# Disabled by default behind ``enable_kb_planner_runnable``; when
|
# Disabled by default behind ``enable_kb_planner_runnable``; when
|
||||||
# off the planner falls back to the legacy ``self.llm.ainvoke``
|
# off the planner falls back to the legacy ``planner_llm.ainvoke``
|
||||||
# path.
|
# path.
|
||||||
self._planner: Runnable | None = None
|
self._planner: Runnable | None = None
|
||||||
self._planner_compile_failed = False
|
self._planner_compile_failed = False
|
||||||
|
|
@ -608,7 +618,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
|
||||||
Returns ``None`` when the feature flag is disabled, when the LLM is
|
Returns ``None`` when the feature flag is disabled, when the LLM is
|
||||||
unavailable, or when ``create_agent`` raises (we fall back to the
|
unavailable, or when ``create_agent`` raises (we fall back to the
|
||||||
legacy ``self.llm.ainvoke`` path in that case). Compilation happens
|
legacy ``planner_llm.ainvoke`` path in that case). Compilation happens
|
||||||
lazily on first call, then memoized via ``self._planner``.
|
lazily on first call, then memoized via ``self._planner``.
|
||||||
|
|
||||||
The compiled agent is constructed without tools — the planner's
|
The compiled agent is constructed without tools — the planner's
|
||||||
|
|
@ -618,7 +628,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
"""
|
"""
|
||||||
if self._planner is not None or self._planner_compile_failed:
|
if self._planner is not None or self._planner_compile_failed:
|
||||||
return self._planner
|
return self._planner
|
||||||
if self.llm is None:
|
if self.planner_llm is None:
|
||||||
return None
|
return None
|
||||||
flags = get_flags()
|
flags = get_flags()
|
||||||
if not flags.enable_kb_planner_runnable or flags.disable_new_agent_stack:
|
if not flags.enable_kb_planner_runnable or flags.disable_new_agent_stack:
|
||||||
|
|
@ -628,13 +638,13 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._planner = create_agent(
|
self._planner = create_agent(
|
||||||
self.llm,
|
self.planner_llm,
|
||||||
tools=[],
|
tools=[],
|
||||||
middleware=[RetryAfterMiddleware(max_retries=2)],
|
middleware=[RetryAfterMiddleware(max_retries=2)],
|
||||||
)
|
)
|
||||||
except Exception as exc: # pragma: no cover - defensive
|
except Exception as exc: # pragma: no cover - defensive
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"kb-planner Runnable compile failed; falling back to llm.ainvoke: %s",
|
"kb-planner Runnable compile failed; falling back to planner_llm.ainvoke: %s",
|
||||||
exc,
|
exc,
|
||||||
)
|
)
|
||||||
self._planner_compile_failed = True
|
self._planner_compile_failed = True
|
||||||
|
|
@ -647,12 +657,12 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
messages: Sequence[BaseMessage],
|
messages: Sequence[BaseMessage],
|
||||||
user_text: str,
|
user_text: str,
|
||||||
) -> tuple[str, datetime | None, datetime | None, bool]:
|
) -> tuple[str, datetime | None, datetime | None, bool]:
|
||||||
if self.llm is None:
|
if self.planner_llm is None:
|
||||||
return user_text, None, None, False
|
return user_text, None, None, False
|
||||||
|
|
||||||
recent_conversation = _render_recent_conversation(
|
recent_conversation = _render_recent_conversation(
|
||||||
messages,
|
messages,
|
||||||
llm=self.llm,
|
llm=self.planner_llm,
|
||||||
user_text=user_text,
|
user_text=user_text,
|
||||||
)
|
)
|
||||||
prompt = _build_kb_planner_prompt(
|
prompt = _build_kb_planner_prompt(
|
||||||
|
|
@ -663,8 +673,8 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
t0 = loop.time()
|
t0 = loop.time()
|
||||||
|
|
||||||
# Prefer the compiled-once planner Runnable when enabled; otherwise
|
# Prefer the compiled-once planner Runnable when enabled; otherwise
|
||||||
# fall back to ``self.llm.ainvoke``. The ``surfsense:internal`` tag
|
# fall back to ``planner_llm.ainvoke``. The ``surfsense:internal``
|
||||||
# is preserved on both paths so ``_stream_agent_events`` still
|
# tag is preserved on both paths so ``_stream_agent_events`` still
|
||||||
# suppresses the planner's intermediate events from the UI.
|
# suppresses the planner's intermediate events from the UI.
|
||||||
planner = self._build_kb_planner_runnable()
|
planner = self._build_kb_planner_runnable()
|
||||||
try:
|
try:
|
||||||
|
|
@ -684,7 +694,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
else AIMessage(content="")
|
else AIMessage(content="")
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = await self.llm.ainvoke(
|
response = await self.planner_llm.ainvoke(
|
||||||
[HumanMessage(content=prompt)],
|
[HumanMessage(content=prompt)],
|
||||||
config={"tags": ["surfsense:internal"]},
|
config={"tags": ["surfsense:internal"]},
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
|
|
@ -41,6 +42,9 @@ from app.agents.new_chat.path_resolver import (
|
||||||
doc_to_virtual_path,
|
doc_to_virtual_path,
|
||||||
)
|
)
|
||||||
from app.db import Document, shielded_async_session
|
from app.db import Document, shielded_async_session
|
||||||
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from litellm import token_counter
|
from litellm import token_counter
|
||||||
|
|
@ -124,6 +128,7 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
if self.filesystem_mode != FilesystemMode.CLOUD:
|
if self.filesystem_mode != FilesystemMode.CLOUD:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
start = time.perf_counter()
|
||||||
update: dict[str, Any] = {}
|
update: dict[str, Any] = {}
|
||||||
if not state.get("cwd"):
|
if not state.get("cwd"):
|
||||||
update["cwd"] = DOCUMENTS_ROOT
|
update["cwd"] = DOCUMENTS_ROOT
|
||||||
|
|
@ -131,7 +136,11 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
anon_doc = state.get("kb_anon_doc")
|
anon_doc = state.get("kb_anon_doc")
|
||||||
if anon_doc:
|
if anon_doc:
|
||||||
tree_msg = self._render_anon_tree(anon_doc)
|
tree_msg = self._render_anon_tree(anon_doc)
|
||||||
|
cache_outcome = "anon"
|
||||||
else:
|
else:
|
||||||
|
version = int(state.get("tree_version") or 0)
|
||||||
|
cache_key = (self.search_space_id, version, False)
|
||||||
|
cache_outcome = "hit" if cache_key in self._cache else "miss"
|
||||||
tree_msg = await self._render_kb_tree(state)
|
tree_msg = await self._render_kb_tree(state)
|
||||||
|
|
||||||
update["workspace_tree_text"] = tree_msg
|
update["workspace_tree_text"] = tree_msg
|
||||||
|
|
@ -141,6 +150,14 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
insert_at = max(len(messages) - 1, 0)
|
insert_at = max(len(messages) - 1, 0)
|
||||||
messages.insert(insert_at, SystemMessage(content=tree_msg))
|
messages.insert(insert_at, SystemMessage(content=tree_msg))
|
||||||
update["messages"] = messages
|
update["messages"] = messages
|
||||||
|
|
||||||
|
_perf_log.info(
|
||||||
|
"[knowledge_tree] cache=%s chars=%d elapsed=%.3fs space=%d",
|
||||||
|
cache_outcome,
|
||||||
|
len(tree_msg),
|
||||||
|
time.perf_counter() - start,
|
||||||
|
self.search_space_id,
|
||||||
|
)
|
||||||
return update
|
return update
|
||||||
|
|
||||||
def before_agent( # type: ignore[override]
|
def before_agent( # type: ignore[override]
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ Injects memory markdown into the system prompt on every turn:
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
@ -19,8 +20,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import ChatVisibility, SearchSpace, User, shielded_async_session
|
from app.db import ChatVisibility, SearchSpace, User, shielded_async_session
|
||||||
from app.services.memory import MEMORY_HARD_LIMIT, MEMORY_SOFT_LIMIT
|
from app.services.memory import MEMORY_HARD_LIMIT, MEMORY_SOFT_LIMIT
|
||||||
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
|
|
||||||
class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
|
@ -53,9 +56,13 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
if not isinstance(last_message, HumanMessage):
|
if not isinstance(last_message, HumanMessage):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
start = time.perf_counter()
|
||||||
|
db_elapsed = 0.0
|
||||||
memory_blocks: list[str] = []
|
memory_blocks: list[str] = []
|
||||||
|
scope = "team" if self.visibility == ChatVisibility.SEARCH_SPACE else "user"
|
||||||
|
|
||||||
async with shielded_async_session() as session:
|
async with shielded_async_session() as session:
|
||||||
|
db_start = time.perf_counter()
|
||||||
if self.visibility == ChatVisibility.SEARCH_SPACE:
|
if self.visibility == ChatVisibility.SEARCH_SPACE:
|
||||||
team_memory = await self._load_team_memory(session)
|
team_memory = await self._load_team_memory(session)
|
||||||
if team_memory:
|
if team_memory:
|
||||||
|
|
@ -96,7 +103,15 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
f"</memory_warning>"
|
f"</memory_warning>"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
db_elapsed = time.perf_counter() - db_start
|
||||||
|
|
||||||
if not memory_blocks:
|
if not memory_blocks:
|
||||||
|
_perf_log.info(
|
||||||
|
"[memory_injection] scope=%s injected=0 db=%.3fs total=%.3fs",
|
||||||
|
scope,
|
||||||
|
db_elapsed,
|
||||||
|
time.perf_counter() - start,
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
memory_text = "\n\n".join(memory_blocks)
|
memory_text = "\n\n".join(memory_blocks)
|
||||||
|
|
@ -106,6 +121,13 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
insert_idx = 1 if len(new_messages) > 1 else 0
|
insert_idx = 1 if len(new_messages) > 1 else 0
|
||||||
new_messages.insert(insert_idx, memory_msg)
|
new_messages.insert(insert_idx, memory_msg)
|
||||||
|
|
||||||
|
_perf_log.info(
|
||||||
|
"[memory_injection] scope=%s injected=1 chars=%d db=%.3fs total=%.3fs",
|
||||||
|
scope,
|
||||||
|
len(memory_text),
|
||||||
|
db_elapsed,
|
||||||
|
time.perf_counter() - start,
|
||||||
|
)
|
||||||
return {"messages": new_messages}
|
return {"messages": new_messages}
|
||||||
|
|
||||||
async def _load_user_memory(
|
async def _load_user_memory(
|
||||||
|
|
|
||||||
|
|
@ -39,9 +39,19 @@ For OpenAI-family configs we additionally pass:
|
||||||
|
|
||||||
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
|
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
|
||||||
raises hit rate by sending requests with a shared prefix to the same
|
raises hit rate by sending requests with a shared prefix to the same
|
||||||
backend.
|
backend. Supported by ``openai/``, ``deepseek/``, ``xai/``, and
|
||||||
|
``azure/`` (added to LiteLLM's Azure transformer in
|
||||||
|
https://github.com/BerriAI/litellm/pull/20989, Feb 2026; verified
|
||||||
|
against ``AzureOpenAIConfig.get_supported_openai_params`` in our
|
||||||
|
installed litellm 1.83.14 for ``azure/gpt-4o``, ``azure/gpt-4o-mini``,
|
||||||
|
``azure/gpt-5.4``, ``azure/gpt-5.4-mini``).
|
||||||
- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default
|
- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default
|
||||||
5-10 min in-memory cache.
|
5-10 min in-memory cache. Set ONLY for OpenAI/DeepSeek/xAI: Azure's
|
||||||
|
server-side support landed in Microsoft's docs on 2026-05-13 but
|
||||||
|
LiteLLM 1.83.14's Azure transformer still omits it from its supported
|
||||||
|
params list, so it gets silently dropped by ``litellm.drop_params``.
|
||||||
|
Azure's default in-memory retention (5-10 min, max 1 h) already
|
||||||
|
bridges intra-conversation turns; revisit when LiteLLM bumps Azure.
|
||||||
|
|
||||||
Safety net: ``litellm.drop_params=True`` is set globally in
|
Safety net: ``litellm.drop_params=True`` is set globally in
|
||||||
``app.services.llm_service`` at module-load time. Any kwarg the destination
|
``app.services.llm_service`` at module-load time. Any kwarg the destination
|
||||||
|
|
@ -81,13 +91,31 @@ _DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
|
||||||
{"location": "message", "index": -1},
|
{"location": "message", "index": -1},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Providers (uppercase ``AgentConfig.provider`` values) that natively expose
|
# Providers (uppercase ``AgentConfig.provider`` values) that accept the
|
||||||
# OpenAI-style automatic prompt caching with ``prompt_cache_key`` and
|
# OpenAI ``prompt_cache_key`` routing hint. Microsoft's Azure OpenAI docs
|
||||||
# ``prompt_cache_retention`` kwargs. Strict whitelist — many other providers
|
# (2026-05-13) confirm automatic prompt caching applies to every GPT-4o
|
||||||
# in ``PROVIDER_MAP`` route through litellm's ``openai`` prefix without
|
# or newer Azure deployment at ≥1024 tokens with no configuration needed,
|
||||||
# implementing the OpenAI prompt-cache surface (e.g. MOONSHOT, ZHIPU,
|
# and that ``prompt_cache_key`` is combined with the prefix hash to
|
||||||
# MINIMAX), so we can't infer family from the litellm prefix alone.
|
# improve routing affinity and therefore cache hit rate. LiteLLM's Azure
|
||||||
_OPENAI_FAMILY_PROVIDERS: frozenset[str] = frozenset({"OPENAI", "DEEPSEEK", "XAI"})
|
# transformer ships ``prompt_cache_key`` in its supported params as of
|
||||||
|
# https://github.com/BerriAI/litellm/pull/20989.
|
||||||
|
#
|
||||||
|
# Strict whitelist — many other providers in ``PROVIDER_MAP`` route
|
||||||
|
# through litellm's ``openai`` prefix without implementing the OpenAI
|
||||||
|
# prompt-cache surface (e.g. MOONSHOT, ZHIPU, MINIMAX), so we can't infer
|
||||||
|
# family from the litellm prefix alone.
|
||||||
|
_PROMPT_CACHE_KEY_PROVIDERS: frozenset[str] = frozenset(
|
||||||
|
{"OPENAI", "DEEPSEEK", "XAI", "AZURE", "AZURE_OPENAI"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Subset of ``_PROMPT_CACHE_KEY_PROVIDERS`` that also accept
|
||||||
|
# ``prompt_cache_retention="24h"``. Azure is excluded: see module
|
||||||
|
# docstring — LiteLLM 1.83.14's Azure transformer omits the param so
|
||||||
|
# ``drop_params`` silently strips it. Re-add Azure once a future LiteLLM
|
||||||
|
# release wires it into ``AzureOpenAIConfig.get_supported_openai_params``.
|
||||||
|
_PROMPT_CACHE_RETENTION_PROVIDERS: frozenset[str] = frozenset(
|
||||||
|
{"OPENAI", "DEEPSEEK", "XAI"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _is_router_llm(llm: BaseChatModel) -> bool:
|
def _is_router_llm(llm: BaseChatModel) -> bool:
|
||||||
|
|
@ -101,13 +129,13 @@ def _is_router_llm(llm: BaseChatModel) -> bool:
|
||||||
return type(llm).__name__ == "ChatLiteLLMRouter"
|
return type(llm).__name__ == "ChatLiteLLMRouter"
|
||||||
|
|
||||||
|
|
||||||
def _is_openai_family_config(agent_config: AgentConfig | None) -> bool:
|
def _provider_supports_prompt_cache_key(agent_config: AgentConfig | None) -> bool:
|
||||||
"""Whether the config targets an OpenAI-style prompt-cache surface.
|
"""Whether the config targets a provider that accepts ``prompt_cache_key``.
|
||||||
|
|
||||||
Strict — only returns True when the user explicitly chose OPENAI,
|
Strict — only returns True for explicitly chosen OPENAI, DEEPSEEK,
|
||||||
DEEPSEEK, or XAI as the provider in their ``NewLLMConfig`` /
|
XAI, AZURE, or AZURE_OPENAI providers. Auto-mode and custom
|
||||||
``YAMLConfig``. Auto-mode and custom providers return False because
|
providers return False because we can't statically know the
|
||||||
we can't statically know the destination.
|
destination and the router fans out across mixed providers.
|
||||||
"""
|
"""
|
||||||
if agent_config is None or not agent_config.provider:
|
if agent_config is None or not agent_config.provider:
|
||||||
return False
|
return False
|
||||||
|
|
@ -115,7 +143,25 @@ def _is_openai_family_config(agent_config: AgentConfig | None) -> bool:
|
||||||
return False
|
return False
|
||||||
if agent_config.custom_provider:
|
if agent_config.custom_provider:
|
||||||
return False
|
return False
|
||||||
return agent_config.provider.upper() in _OPENAI_FAMILY_PROVIDERS
|
return agent_config.provider.upper() in _PROMPT_CACHE_KEY_PROVIDERS
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_supports_prompt_cache_retention(
|
||||||
|
agent_config: AgentConfig | None,
|
||||||
|
) -> bool:
|
||||||
|
"""Whether the config targets a provider that accepts ``prompt_cache_retention``.
|
||||||
|
|
||||||
|
Tighter than :func:`_provider_supports_prompt_cache_key` — Azure
|
||||||
|
deployments are excluded until LiteLLM ships the param in its Azure
|
||||||
|
transformer (see module docstring).
|
||||||
|
"""
|
||||||
|
if agent_config is None or not agent_config.provider:
|
||||||
|
return False
|
||||||
|
if agent_config.is_auto_mode:
|
||||||
|
return False
|
||||||
|
if agent_config.custom_provider:
|
||||||
|
return False
|
||||||
|
return agent_config.provider.upper() in _PROMPT_CACHE_RETENTION_PROVIDERS
|
||||||
|
|
||||||
|
|
||||||
def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
|
def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
|
||||||
|
|
@ -173,16 +219,23 @@ def apply_litellm_prompt_caching(
|
||||||
dict(point) for point in _DEFAULT_INJECTION_POINTS
|
dict(point) for point in _DEFAULT_INJECTION_POINTS
|
||||||
]
|
]
|
||||||
|
|
||||||
# OpenAI-family extras only when we statically know the destination is
|
# OpenAI-style extras only when we statically know the destination
|
||||||
# OpenAI / DeepSeek / xAI. Auto-mode router fans out across providers
|
# accepts them. Auto-mode router fans out across mixed providers so
|
||||||
# so we can't safely set OpenAI-only kwargs there (drop_params would
|
# we can't safely set destination-specific kwargs there (drop_params
|
||||||
# strip them but it's wasteful to set them in the first place).
|
# would strip them but it's wasteful to set them in the first
|
||||||
|
# place).
|
||||||
if _is_router_llm(llm):
|
if _is_router_llm(llm):
|
||||||
return
|
return
|
||||||
if not _is_openai_family_config(agent_config):
|
|
||||||
return
|
|
||||||
|
|
||||||
if thread_id is not None and "prompt_cache_key" not in model_kwargs:
|
if (
|
||||||
|
thread_id is not None
|
||||||
|
and "prompt_cache_key" not in model_kwargs
|
||||||
|
and _provider_supports_prompt_cache_key(agent_config)
|
||||||
|
):
|
||||||
model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}"
|
model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}"
|
||||||
if "prompt_cache_retention" not in model_kwargs:
|
|
||||||
|
if (
|
||||||
|
"prompt_cache_retention" not in model_kwargs
|
||||||
|
and _provider_supports_prompt_cache_retention(agent_config)
|
||||||
|
):
|
||||||
model_kwargs["prompt_cache_retention"] = "24h"
|
model_kwargs["prompt_cache_retention"] = "24h"
|
||||||
|
|
|
||||||
|
|
@ -36,8 +36,16 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
|
from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.agents.new_chat.tools.mcp_client import MCPClient
|
from app.agents.new_chat.tools.mcp_client import MCPClient
|
||||||
|
from app.agents.new_chat.tools.mcp_tools_cache import (
|
||||||
|
CachedMCPTools,
|
||||||
|
read_cached_tools,
|
||||||
|
write_cached_tools,
|
||||||
|
)
|
||||||
from app.db import SearchSourceConnector
|
from app.db import SearchSourceConnector
|
||||||
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
|
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
|
||||||
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -293,15 +301,21 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
timeout: float = 60.0,
|
timeout: float = 60.0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Execute a single MCP HTTP call with the given headers."""
|
"""Execute a single MCP HTTP call with the given headers."""
|
||||||
|
call_start = time.perf_counter()
|
||||||
async with (
|
async with (
|
||||||
streamablehttp_client(url, headers=call_headers) as (read, write, _),
|
streamablehttp_client(url, headers=call_headers) as (read, write, _),
|
||||||
ClientSession(read, write) as session,
|
ClientSession(read, write) as session,
|
||||||
):
|
):
|
||||||
|
init_start = time.perf_counter()
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
|
init_elapsed = time.perf_counter() - init_start
|
||||||
|
|
||||||
|
tool_start = time.perf_counter()
|
||||||
response = await asyncio.wait_for(
|
response = await asyncio.wait_for(
|
||||||
session.call_tool(original_tool_name, arguments=call_kwargs),
|
session.call_tool(original_tool_name, arguments=call_kwargs),
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
tool_elapsed = time.perf_counter() - tool_start
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for content in response.content:
|
for content in response.content:
|
||||||
|
|
@ -312,7 +326,18 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
else:
|
else:
|
||||||
result.append(str(content))
|
result.append(str(content))
|
||||||
|
|
||||||
return "\n".join(result) if result else ""
|
payload = "\n".join(result) if result else ""
|
||||||
|
|
||||||
|
_perf_log.info(
|
||||||
|
"[mcp_http_call] connector=%s tool=%s init=%.3fs call=%.3fs total=%.3fs out_chars=%d",
|
||||||
|
connector_id,
|
||||||
|
original_tool_name,
|
||||||
|
init_elapsed,
|
||||||
|
tool_elapsed,
|
||||||
|
time.perf_counter() - call_start,
|
||||||
|
len(payload),
|
||||||
|
)
|
||||||
|
return payload
|
||||||
|
|
||||||
async def mcp_http_tool_call(**kwargs) -> str:
|
async def mcp_http_tool_call(**kwargs) -> str:
|
||||||
"""Execute the MCP tool call via HTTP transport."""
|
"""Execute the MCP tool call via HTTP transport."""
|
||||||
|
|
@ -496,6 +521,7 @@ async def _load_http_mcp_tools(
|
||||||
is_generic_mcp: bool = False,
|
is_generic_mcp: bool = False,
|
||||||
*,
|
*,
|
||||||
bypass_internal_hitl: bool = False,
|
bypass_internal_hitl: bool = False,
|
||||||
|
cached_tools: CachedMCPTools | None = None,
|
||||||
) -> list[StructuredTool]:
|
) -> list[StructuredTool]:
|
||||||
"""Load tools from an HTTP-based MCP server.
|
"""Load tools from an HTTP-based MCP server.
|
||||||
|
|
||||||
|
|
@ -506,6 +532,8 @@ async def _load_http_mcp_tools(
|
||||||
readonly_tools: Tool names that skip HITL approval (read-only operations).
|
readonly_tools: Tool names that skip HITL approval (read-only operations).
|
||||||
tool_name_prefix: If set, each tool name is prefixed for multi-account
|
tool_name_prefix: If set, each tool name is prefixed for multi-account
|
||||||
disambiguation (e.g. ``linear_25``).
|
disambiguation (e.g. ``linear_25``).
|
||||||
|
cached_tools: If provided, skip live discovery and rebuild wrappers
|
||||||
|
from the persisted definitions.
|
||||||
"""
|
"""
|
||||||
tools: list[StructuredTool] = []
|
tools: list[StructuredTool] = []
|
||||||
|
|
||||||
|
|
@ -529,15 +557,23 @@ async def _load_http_mcp_tools(
|
||||||
|
|
||||||
allowed_set = set(allowed_tools) if allowed_tools else None
|
allowed_set = set(allowed_tools) if allowed_tools else None
|
||||||
|
|
||||||
async def _discover(disc_headers: dict[str, str]) -> list[dict[str, Any]]:
|
async def _discover(
|
||||||
"""Connect, initialize, and list tools from the MCP server."""
|
disc_headers: dict[str, str],
|
||||||
|
) -> tuple[dict[str, str | None], list[dict[str, Any]]]:
|
||||||
|
"""Connect, initialize, and list tools — returns (serverInfo, tools)."""
|
||||||
async with (
|
async with (
|
||||||
streamablehttp_client(url, headers=disc_headers) as (read, write, _),
|
streamablehttp_client(url, headers=disc_headers) as (read, write, _),
|
||||||
ClientSession(read, write) as session,
|
ClientSession(read, write) as session,
|
||||||
):
|
):
|
||||||
await session.initialize()
|
init_result = await session.initialize()
|
||||||
|
server_info: dict[str, str | None] = {"name": None, "version": None}
|
||||||
|
si = getattr(init_result, "serverInfo", None)
|
||||||
|
if si is not None:
|
||||||
|
server_info["name"] = getattr(si, "name", None)
|
||||||
|
server_info["version"] = getattr(si, "version", None)
|
||||||
|
|
||||||
response = await session.list_tools()
|
response = await session.list_tools()
|
||||||
return [
|
return server_info, [
|
||||||
{
|
{
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
"description": tool.description or "",
|
"description": tool.description or "",
|
||||||
|
|
@ -548,47 +584,65 @@ async def _load_http_mcp_tools(
|
||||||
for tool in response.tools
|
for tool in response.tools
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
if cached_tools is not None:
|
||||||
tool_definitions = await _discover(headers)
|
tool_definitions = [
|
||||||
except Exception as first_err:
|
{
|
||||||
if not _is_auth_error(first_err) or connector_id is None:
|
"name": td.name,
|
||||||
logger.exception(
|
"description": td.description,
|
||||||
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
|
"input_schema": td.input_schema,
|
||||||
url,
|
}
|
||||||
connector_id,
|
for td in cached_tools.tools
|
||||||
first_err,
|
]
|
||||||
)
|
else:
|
||||||
return tools
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
|
|
||||||
connector_id,
|
|
||||||
)
|
|
||||||
fresh_headers = await _force_refresh_and_get_headers(connector_id)
|
|
||||||
if fresh_headers is None:
|
|
||||||
await _mark_connector_auth_expired(connector_id)
|
|
||||||
logger.error(
|
|
||||||
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
|
|
||||||
connector_id,
|
|
||||||
)
|
|
||||||
return tools
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tool_definitions = await _discover(fresh_headers)
|
server_info, tool_definitions = await _discover(headers)
|
||||||
headers = fresh_headers
|
except Exception as first_err:
|
||||||
logger.info(
|
if not _is_auth_error(first_err) or connector_id is None:
|
||||||
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
|
logger.exception(
|
||||||
|
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
|
||||||
|
url,
|
||||||
|
connector_id,
|
||||||
|
first_err,
|
||||||
|
)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
|
||||||
connector_id,
|
connector_id,
|
||||||
)
|
)
|
||||||
except Exception as retry_err:
|
fresh_headers = await _force_refresh_and_get_headers(connector_id)
|
||||||
logger.exception(
|
if fresh_headers is None:
|
||||||
"HTTP MCP discovery for connector %d still failing after refresh: %s",
|
|
||||||
connector_id,
|
|
||||||
retry_err,
|
|
||||||
)
|
|
||||||
if _is_auth_error(retry_err):
|
|
||||||
await _mark_connector_auth_expired(connector_id)
|
await _mark_connector_auth_expired(connector_id)
|
||||||
return tools
|
logger.error(
|
||||||
|
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
try:
|
||||||
|
server_info, tool_definitions = await _discover(fresh_headers)
|
||||||
|
headers = fresh_headers
|
||||||
|
logger.info(
|
||||||
|
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
except Exception as retry_err:
|
||||||
|
logger.exception(
|
||||||
|
"HTTP MCP discovery for connector %d still failing after refresh: %s",
|
||||||
|
connector_id,
|
||||||
|
retry_err,
|
||||||
|
)
|
||||||
|
if _is_auth_error(retry_err):
|
||||||
|
await _mark_connector_auth_expired(connector_id)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
await write_cached_tools(
|
||||||
|
connector_id,
|
||||||
|
tool_definitions,
|
||||||
|
server_name=server_info.get("name"),
|
||||||
|
server_version=server_info.get("version"),
|
||||||
|
transport=server_config.get("transport", "streamable-http"),
|
||||||
|
)
|
||||||
|
|
||||||
total_discovered = len(tool_definitions)
|
total_discovered = len(tool_definitions)
|
||||||
|
|
||||||
|
|
@ -792,14 +846,25 @@ async def _maybe_refresh_mcp_oauth_token(
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return server_config
|
return server_config
|
||||||
|
|
||||||
|
refresh_start = time.perf_counter()
|
||||||
try:
|
try:
|
||||||
new_access = await _refresh_connector_token(session, connector)
|
new_access = await _refresh_connector_token(session, connector)
|
||||||
if not new_access:
|
if not new_access:
|
||||||
|
_perf_log.info(
|
||||||
|
"[mcp_oauth_refresh] connector=%s elapsed=%.3fs outcome=no_token",
|
||||||
|
connector.id,
|
||||||
|
time.perf_counter() - refresh_start,
|
||||||
|
)
|
||||||
return server_config
|
return server_config
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Proactively refreshed MCP OAuth token for connector %s", connector.id
|
"Proactively refreshed MCP OAuth token for connector %s", connector.id
|
||||||
)
|
)
|
||||||
|
_perf_log.info(
|
||||||
|
"[mcp_oauth_refresh] connector=%s elapsed=%.3fs outcome=refreshed",
|
||||||
|
connector.id,
|
||||||
|
time.perf_counter() - refresh_start,
|
||||||
|
)
|
||||||
|
|
||||||
refreshed_config = dict(server_config)
|
refreshed_config = dict(server_config)
|
||||||
refreshed_config["headers"] = {
|
refreshed_config["headers"] = {
|
||||||
|
|
@ -809,6 +874,11 @@ async def _maybe_refresh_mcp_oauth_token(
|
||||||
return refreshed_config
|
return refreshed_config
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
|
_perf_log.info(
|
||||||
|
"[mcp_oauth_refresh] connector=%s elapsed=%.3fs outcome=failed",
|
||||||
|
connector.id,
|
||||||
|
time.perf_counter() - refresh_start,
|
||||||
|
)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to refresh MCP OAuth token for connector %s",
|
"Failed to refresh MCP OAuth token for connector %s",
|
||||||
connector.id,
|
connector.id,
|
||||||
|
|
@ -937,6 +1007,94 @@ def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
||||||
_mcp_tools_cache.clear()
|
_mcp_tools_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
async def discover_single_mcp_connector(connector_id: int) -> None:
|
||||||
|
"""Force live MCP discovery for one connector so its ``cached_tools`` row is fresh.
|
||||||
|
|
||||||
|
``_load_http_mcp_tools`` persists ``cached_tools`` as a side effect of any
|
||||||
|
live discovery; passing ``cached_tools=None`` here guarantees we go to the
|
||||||
|
network. The returned wrappers are discarded — the in-process LRU is
|
||||||
|
rebuilt lazily on the next user query. Stdio connectors are not cached and
|
||||||
|
are skipped.
|
||||||
|
"""
|
||||||
|
from app.db import async_session_maker
|
||||||
|
|
||||||
|
started = time.perf_counter()
|
||||||
|
try:
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
connector = await session.get(SearchSourceConnector, connector_id)
|
||||||
|
if connector is None:
|
||||||
|
logger.info(
|
||||||
|
"discover_single_mcp_connector: connector %d not found",
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
cfg = connector.config or {}
|
||||||
|
server_config = cfg.get("server_config", {})
|
||||||
|
if not server_config or not isinstance(server_config, dict):
|
||||||
|
return
|
||||||
|
|
||||||
|
transport = server_config.get("transport", "stdio")
|
||||||
|
if transport not in ("streamable-http", "http", "sse"):
|
||||||
|
return
|
||||||
|
|
||||||
|
if cfg.get("mcp_oauth"):
|
||||||
|
server_config = await _maybe_refresh_mcp_oauth_token(
|
||||||
|
session, connector, cfg, server_config
|
||||||
|
)
|
||||||
|
cfg = connector.config or {}
|
||||||
|
server_config = _inject_oauth_headers(cfg, server_config)
|
||||||
|
if server_config is None:
|
||||||
|
logger.info(
|
||||||
|
"discover_single_mcp_connector: OAuth token unavailable for connector %d",
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
ct = (
|
||||||
|
connector.connector_type.value
|
||||||
|
if hasattr(connector.connector_type, "value")
|
||||||
|
else str(connector.connector_type)
|
||||||
|
)
|
||||||
|
svc_cfg = get_service_by_connector_type(ct)
|
||||||
|
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
|
||||||
|
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
|
||||||
|
|
||||||
|
await asyncio.wait_for(
|
||||||
|
_load_http_mcp_tools(
|
||||||
|
connector.id,
|
||||||
|
connector.name,
|
||||||
|
server_config,
|
||||||
|
trusted_tools=cfg.get("trusted_tools", []),
|
||||||
|
allowed_tools=allowed_tools,
|
||||||
|
readonly_tools=readonly_tools,
|
||||||
|
tool_name_prefix=None,
|
||||||
|
is_generic_mcp=svc_cfg is None,
|
||||||
|
bypass_internal_hitl=True,
|
||||||
|
cached_tools=None,
|
||||||
|
),
|
||||||
|
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
|
||||||
|
_perf_log.info(
|
||||||
|
"[mcp_prefetch] connector=%s elapsed=%.3fs",
|
||||||
|
connector_id,
|
||||||
|
time.perf_counter() - started,
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
logger.warning(
|
||||||
|
"discover_single_mcp_connector: connector %d timed out after %ds",
|
||||||
|
connector_id,
|
||||||
|
_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"discover_single_mcp_connector: failed for connector %d",
|
||||||
|
connector_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def load_mcp_tools(
|
async def load_mcp_tools(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
|
|
@ -1063,6 +1221,7 @@ async def load_mcp_tools(
|
||||||
"tool_name_prefix": tool_name_prefix,
|
"tool_name_prefix": tool_name_prefix,
|
||||||
"transport": server_config.get("transport", "stdio"),
|
"transport": server_config.get("transport", "stdio"),
|
||||||
"is_generic_mcp": svc_cfg is None,
|
"is_generic_mcp": svc_cfg is None,
|
||||||
|
"cached_tools": read_cached_tools(connector),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1074,9 +1233,12 @@ async def load_mcp_tools(
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]:
|
async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]:
|
||||||
|
discover_start = time.perf_counter()
|
||||||
|
transport = task["transport"]
|
||||||
|
cached_tools = task.get("cached_tools")
|
||||||
try:
|
try:
|
||||||
if task["transport"] in ("streamable-http", "http", "sse"):
|
if transport in ("streamable-http", "http", "sse"):
|
||||||
return await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
_load_http_mcp_tools(
|
_load_http_mcp_tools(
|
||||||
task["connector_id"],
|
task["connector_id"],
|
||||||
task["connector_name"],
|
task["connector_name"],
|
||||||
|
|
@ -1087,11 +1249,12 @@ async def load_mcp_tools(
|
||||||
tool_name_prefix=task["tool_name_prefix"],
|
tool_name_prefix=task["tool_name_prefix"],
|
||||||
is_generic_mcp=task.get("is_generic_mcp", False),
|
is_generic_mcp=task.get("is_generic_mcp", False),
|
||||||
bypass_internal_hitl=bypass_internal_hitl,
|
bypass_internal_hitl=bypass_internal_hitl,
|
||||||
|
cached_tools=cached_tools,
|
||||||
),
|
),
|
||||||
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
_load_stdio_mcp_tools(
|
_load_stdio_mcp_tools(
|
||||||
task["connector_id"],
|
task["connector_id"],
|
||||||
task["connector_name"],
|
task["connector_name"],
|
||||||
|
|
@ -1101,7 +1264,24 @@ async def load_mcp_tools(
|
||||||
),
|
),
|
||||||
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||||
)
|
)
|
||||||
|
_perf_log.info(
|
||||||
|
"[mcp_discover] connector=%s name=%r transport=%s tools=%d elapsed=%.3fs cache=%s",
|
||||||
|
task["connector_id"],
|
||||||
|
task["connector_name"],
|
||||||
|
transport,
|
||||||
|
len(result),
|
||||||
|
time.perf_counter() - discover_start,
|
||||||
|
"hit" if cached_tools is not None else "miss",
|
||||||
|
)
|
||||||
|
return result
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
|
_perf_log.info(
|
||||||
|
"[mcp_discover] connector=%s name=%r transport=%s elapsed=%.3fs outcome=timeout",
|
||||||
|
task["connector_id"],
|
||||||
|
task["connector_name"],
|
||||||
|
transport,
|
||||||
|
time.perf_counter() - discover_start,
|
||||||
|
)
|
||||||
logger.error(
|
logger.error(
|
||||||
"MCP connector %d timed out after %ds during discovery",
|
"MCP connector %d timed out after %ds during discovery",
|
||||||
task["connector_id"],
|
task["connector_id"],
|
||||||
|
|
@ -1109,6 +1289,13 @@ async def load_mcp_tools(
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
_perf_log.info(
|
||||||
|
"[mcp_discover] connector=%s name=%r transport=%s elapsed=%.3fs outcome=error",
|
||||||
|
task["connector_id"],
|
||||||
|
task["connector_name"],
|
||||||
|
transport,
|
||||||
|
time.perf_counter() - discover_start,
|
||||||
|
)
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to load tools from MCP connector %d: %s",
|
"Failed to load tools from MCP connector %d: %s",
|
||||||
task["connector_id"],
|
task["connector_id"],
|
||||||
|
|
@ -1116,7 +1303,14 @@ async def load_mcp_tools(
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
gather_start = time.perf_counter()
|
||||||
results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks])
|
results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks])
|
||||||
|
_perf_log.info(
|
||||||
|
"[mcp_discover] gather_wall=%.3fs connectors=%d total_tools=%d",
|
||||||
|
time.perf_counter() - gather_start,
|
||||||
|
len(discovery_tasks),
|
||||||
|
sum(len(r) for r in results),
|
||||||
|
)
|
||||||
tools: list[StructuredTool] = [tool for sublist in results for tool in sublist]
|
tools: list[StructuredTool] = [tool for sublist in results for tool in sublist]
|
||||||
|
|
||||||
_mcp_tools_cache[cache_key] = (now, tools)
|
_mcp_tools_cache[cache_key] = (now, tools)
|
||||||
|
|
|
||||||
145
surfsense_backend/app/agents/new_chat/tools/mcp_tools_cache.py
Normal file
145
surfsense_backend/app/agents/new_chat/tools/mcp_tools_cache.py
Normal file
|
|
@ -0,0 +1,145 @@
|
||||||
|
"""Persist MCP ``list_tools`` results in ``SearchSourceConnector.config.cached_tools``."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
|
from app.db import SearchSourceConnector, async_session_maker
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_pending_prefetch_tasks: set[asyncio.Task[None]] = set()
|
||||||
|
|
||||||
|
|
||||||
|
class CachedMCPToolDef(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: str = ""
|
||||||
|
input_schema: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class CachedMCPTools(BaseModel):
|
||||||
|
discovered_at: datetime
|
||||||
|
server_version: str | None = None
|
||||||
|
server_name: str | None = None
|
||||||
|
transport: str | None = None
|
||||||
|
tools: list[CachedMCPToolDef]
|
||||||
|
|
||||||
|
|
||||||
|
def read_cached_tools(connector: SearchSourceConnector) -> CachedMCPTools | None:
|
||||||
|
"""Return parsed cached tools or ``None`` if missing / corrupt (caller falls back to live discovery)."""
|
||||||
|
cfg = connector.config or {}
|
||||||
|
raw = cfg.get("cached_tools")
|
||||||
|
if not raw or not isinstance(raw, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return CachedMCPTools.model_validate(raw)
|
||||||
|
except ValidationError as exc:
|
||||||
|
logger.warning(
|
||||||
|
"MCP connector %d has corrupt cached_tools — falling back to live discovery: %s",
|
||||||
|
connector.id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def write_cached_tools(
|
||||||
|
connector_id: int,
|
||||||
|
tool_definitions: list[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
server_name: str | None = None,
|
||||||
|
server_version: str | None = None,
|
||||||
|
transport: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Best-effort persist; uses its own session so a write failure cannot poison the caller's transaction."""
|
||||||
|
payload = CachedMCPTools(
|
||||||
|
discovered_at=datetime.now(UTC),
|
||||||
|
server_version=server_version,
|
||||||
|
server_name=server_name,
|
||||||
|
transport=transport,
|
||||||
|
tools=[CachedMCPToolDef.model_validate(td) for td in tool_definitions],
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.id == connector_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connector = result.scalars().first()
|
||||||
|
if connector is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
cfg = dict(connector.config or {})
|
||||||
|
cfg["cached_tools"] = payload.model_dump(mode="json")
|
||||||
|
connector.config = cfg
|
||||||
|
flag_modified(connector, "config")
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Persisted cached_tools for MCP connector %d (%d tools)",
|
||||||
|
connector_id,
|
||||||
|
len(payload.tools),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to persist cached_tools for MCP connector %d",
|
||||||
|
connector_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_mcp_tools_cache_for_connector(
|
||||||
|
connector_id: int,
|
||||||
|
search_space_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""Maintain the MCP tool cache after a single-connector lifecycle event.
|
||||||
|
|
||||||
|
Synchronously evicts the in-process LRU for the connector's search space
|
||||||
|
(LRU keys are per-space, so eviction cannot be scoped finer), then schedules
|
||||||
|
a background live discovery for this connector alone so its persisted
|
||||||
|
``cached_tools`` row is refreshed before the next user query.
|
||||||
|
|
||||||
|
Idempotent. Eviction is best-effort; prefetch is best-effort and only runs
|
||||||
|
when an event loop is available. Neither path raises.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
||||||
|
|
||||||
|
invalidate_mcp_tools_cache(search_space_id)
|
||||||
|
except Exception:
|
||||||
|
logger.debug(
|
||||||
|
"MCP in-process cache eviction skipped for space %d",
|
||||||
|
search_space_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
return
|
||||||
|
|
||||||
|
task = loop.create_task(_run_connector_prefetch(connector_id))
|
||||||
|
_pending_prefetch_tasks.add(task)
|
||||||
|
task.add_done_callback(_pending_prefetch_tasks.discard)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_connector_prefetch(connector_id: int) -> None:
|
||||||
|
from app.agents.new_chat.tools.mcp_tool import discover_single_mcp_connector
|
||||||
|
|
||||||
|
try:
|
||||||
|
await discover_single_mcp_connector(connector_id)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"MCP background prefetch failed for connector_id=%d",
|
||||||
|
connector_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
@ -110,6 +110,19 @@ def load_global_llm_configs():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to score global LLM configs: {e}")
|
print(f"Warning: Failed to score global LLM configs: {e}")
|
||||||
|
|
||||||
|
# Planner LLM is a singleton role. If an operator accidentally
|
||||||
|
# marks multiple configs ``is_planner: true``, only the first one
|
||||||
|
# is used at runtime — surface the others at startup so the
|
||||||
|
# mistake is caught before traffic, not silently buried.
|
||||||
|
planner_cfgs = [c for c in configs if c.get("is_planner") is True]
|
||||||
|
if len(planner_cfgs) > 1:
|
||||||
|
extra_ids = [c.get("id") for c in planner_cfgs[1:]]
|
||||||
|
print(
|
||||||
|
"Warning: Multiple global LLM configs marked is_planner=true "
|
||||||
|
f"(ids {[c.get('id') for c in planner_cfgs]}); using id "
|
||||||
|
f"{planner_cfgs[0].get('id')} and ignoring {extra_ids}"
|
||||||
|
)
|
||||||
|
|
||||||
return configs
|
return configs
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to load global LLM configs: {e}")
|
print(f"Warning: Failed to load global LLM configs: {e}")
|
||||||
|
|
|
||||||
|
|
@ -258,6 +258,45 @@ global_llm_configs:
|
||||||
use_default_system_instructions: true
|
use_default_system_instructions: true
|
||||||
citations_enabled: true
|
citations_enabled: true
|
||||||
|
|
||||||
|
# Example: Planner LLM - small, fast model used for internal utility tasks
|
||||||
|
#
|
||||||
|
# The PLANNER role handles short, structured internal calls (KB query
|
||||||
|
# rewriting, date extraction, recency classification, etc.) that don't
|
||||||
|
# need frontier-tier capability. Pointing the planner at a cheap+fast
|
||||||
|
# model (gpt-4o-mini, Claude Haiku, Azure gpt-5.x-nano, Groq Llama, ...)
|
||||||
|
# typically saves 500ms-1.5s per turn vs. routing those same internal
|
||||||
|
# calls through the user's chat model.
|
||||||
|
#
|
||||||
|
# Activation:
|
||||||
|
# - Mark EXACTLY ONE global config with ``is_planner: true``.
|
||||||
|
# - If multiple are marked, the first one wins and a WARNING is logged.
|
||||||
|
# - If none is marked, every internal call falls back to the user's
|
||||||
|
# chat LLM (same behavior as before this flag existed).
|
||||||
|
#
|
||||||
|
# This config is operator-only — it is NOT exposed in the user-facing
|
||||||
|
# model selector, never billed against premium quota, and the
|
||||||
|
# billing_tier / anonymous_enabled fields below are ignored.
|
||||||
|
- id: -9
|
||||||
|
name: "Global Planner (GPT-4o mini)"
|
||||||
|
description: "Internal-only planner LLM for query rewriting and classification"
|
||||||
|
is_planner: true
|
||||||
|
billing_tier: "free"
|
||||||
|
anonymous_enabled: false
|
||||||
|
seo_enabled: false
|
||||||
|
quota_reserve_tokens: 1000
|
||||||
|
provider: "OPENAI"
|
||||||
|
model_name: "gpt-4o-mini"
|
||||||
|
api_key: "sk-your-openai-api-key-here"
|
||||||
|
api_base: ""
|
||||||
|
rpm: 3500
|
||||||
|
tpm: 200000
|
||||||
|
litellm_params:
|
||||||
|
temperature: 0
|
||||||
|
max_tokens: 1000
|
||||||
|
system_instructions: ""
|
||||||
|
use_default_system_instructions: true
|
||||||
|
citations_enabled: false
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# OpenRouter Integration
|
# OpenRouter Integration
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
@ -493,6 +532,20 @@ global_vision_llm_configs:
|
||||||
# - Lower temperature (0.3) is recommended for accurate screenshot analysis
|
# - Lower temperature (0.3) is recommended for accurate screenshot analysis
|
||||||
# - Lower max_tokens (1000) is sufficient since autocomplete produces short suggestions
|
# - Lower max_tokens (1000) is sufficient since autocomplete produces short suggestions
|
||||||
#
|
#
|
||||||
|
# PLANNER LLM NOTES:
|
||||||
|
# - is_planner: true marks a config as the internal-only planner LLM (small,
|
||||||
|
# fast model used for KB query rewriting, date extraction, recency
|
||||||
|
# classification, etc.). Only one config may carry this flag — if
|
||||||
|
# multiple do, the first one wins and a startup WARNING is logged.
|
||||||
|
# - When no config is marked is_planner, every internal utility call falls
|
||||||
|
# back to the user's chat LLM (the historical behavior).
|
||||||
|
# - Planner configs are NOT shown in the user-facing model selector and
|
||||||
|
# are NOT billed against the user's premium quota. Their billing_tier,
|
||||||
|
# anonymous_enabled, seo_* fields are ignored.
|
||||||
|
# - Recommended models: gpt-4o-mini, claude-3-5-haiku, gemini-1.5-flash,
|
||||||
|
# azure gpt-5.x-nano, groq llama3-8b — anything <200ms p50 on a 1-2k
|
||||||
|
# prompt. Frontier models here defeat the purpose of the flag.
|
||||||
|
#
|
||||||
# TOKEN QUOTA & ANONYMOUS ACCESS NOTES:
|
# TOKEN QUOTA & ANONYMOUS ACCESS NOTES:
|
||||||
# - billing_tier: "free" or "premium". Controls whether registered users need premium token quota.
|
# - billing_tier: "free" or "premium". Controls whether registered users need premium token quota.
|
||||||
# - anonymous_enabled: true/false. Whether the model appears in the public no-login catalog.
|
# - anonymous_enabled: true/false. Whether the model appears in the public no-login catalog.
|
||||||
|
|
|
||||||
|
|
@ -428,7 +428,7 @@ async def mcp_oauth_callback(
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(db_connector)
|
await session.refresh(db_connector)
|
||||||
|
|
||||||
_invalidate_cache(space_id)
|
_refresh_mcp_cache(db_connector.id, space_id)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Re-authenticated %s MCP connector %s for user %s",
|
"Re-authenticated %s MCP connector %s for user %s",
|
||||||
|
|
@ -481,7 +481,7 @@ async def mcp_oauth_callback(
|
||||||
detail="A connector for this service already exists.",
|
detail="A connector for this service already exists.",
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
_invalidate_cache(space_id)
|
_refresh_mcp_cache(new_connector.id, space_id)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Created %s MCP connector %s for user %s in space %s",
|
"Created %s MCP connector %s for user %s in space %s",
|
||||||
|
|
@ -658,10 +658,17 @@ async def reauth_mcp_service(
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _invalidate_cache(space_id: int) -> None:
|
def _refresh_mcp_cache(connector_id: int, space_id: int) -> None:
|
||||||
try:
|
"""Evict the in-process MCP tool LRU and schedule background prefetch.
|
||||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
|
||||||
|
|
||||||
invalidate_mcp_tools_cache(space_id)
|
Wraps :func:`refresh_mcp_tools_cache_for_connector` so any failure is
|
||||||
|
isolated from the OAuth response flow.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from app.agents.new_chat.tools.mcp_tools_cache import (
|
||||||
|
refresh_mcp_tools_cache_for_connector,
|
||||||
|
)
|
||||||
|
|
||||||
|
refresh_mcp_tools_cache_for_connector(connector_id, space_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("MCP cache invalidation skipped", exc_info=True)
|
logger.debug("MCP cache refresh skipped", exc_info=True)
|
||||||
|
|
|
||||||
|
|
@ -2650,9 +2650,11 @@ async def create_mcp_connector(
|
||||||
f"for user {user.id} in search space {search_space_id}"
|
f"for user {user.id} in search space {search_space_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
from app.agents.new_chat.tools.mcp_tools_cache import (
|
||||||
|
refresh_mcp_tools_cache_for_connector,
|
||||||
|
)
|
||||||
|
|
||||||
invalidate_mcp_tools_cache(search_space_id)
|
refresh_mcp_tools_cache_for_connector(db_connector.id, search_space_id)
|
||||||
|
|
||||||
connector_read = SearchSourceConnectorRead.model_validate(db_connector)
|
connector_read = SearchSourceConnectorRead.model_validate(db_connector)
|
||||||
return MCPConnectorRead.from_connector(connector_read)
|
return MCPConnectorRead.from_connector(connector_read)
|
||||||
|
|
@ -2828,9 +2830,11 @@ async def update_mcp_connector(
|
||||||
|
|
||||||
logger.info(f"Updated MCP connector {connector_id}")
|
logger.info(f"Updated MCP connector {connector_id}")
|
||||||
|
|
||||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
from app.agents.new_chat.tools.mcp_tools_cache import (
|
||||||
|
refresh_mcp_tools_cache_for_connector,
|
||||||
|
)
|
||||||
|
|
||||||
invalidate_mcp_tools_cache(connector.search_space_id)
|
refresh_mcp_tools_cache_for_connector(connector.id, connector.search_space_id)
|
||||||
|
|
||||||
connector_read = SearchSourceConnectorRead.model_validate(connector)
|
connector_read = SearchSourceConnectorRead.model_validate(connector)
|
||||||
return MCPConnectorRead.from_connector(connector_read)
|
return MCPConnectorRead.from_connector(connector_read)
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
@ -100,7 +101,9 @@ class GmailKBSyncService:
|
||||||
else:
|
else:
|
||||||
logger.warning("No LLM configured -- using fallback summary")
|
logger.warning("No LLM configured -- using fallback summary")
|
||||||
summary_content = f"Gmail Message: {subject}\n\n{indexable_content}"
|
summary_content = f"Gmail Message: {subject}\n\n{indexable_content}"
|
||||||
summary_embedding = embed_text(summary_content)
|
summary_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, summary_content
|
||||||
|
)
|
||||||
|
|
||||||
chunks = await create_document_chunks(indexable_content)
|
chunks = await create_document_chunks(indexable_content)
|
||||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
|
||||||
|
|
@ -116,7 +116,9 @@ class GoogleCalendarKBSyncService:
|
||||||
summary_content = (
|
summary_content = (
|
||||||
f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
|
f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
|
||||||
)
|
)
|
||||||
summary_embedding = embed_text(summary_content)
|
summary_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, summary_content
|
||||||
|
)
|
||||||
|
|
||||||
chunks = await create_document_chunks(indexable_content)
|
chunks = await create_document_chunks(indexable_content)
|
||||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
@ -295,7 +297,9 @@ class GoogleCalendarKBSyncService:
|
||||||
summary_content = (
|
summary_content = (
|
||||||
f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
|
f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
|
||||||
)
|
)
|
||||||
summary_embedding = embed_text(summary_content)
|
summary_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, summary_content
|
||||||
|
)
|
||||||
|
|
||||||
chunks = await create_document_chunks(indexable_content)
|
chunks = await create_document_chunks(indexable_content)
|
||||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,9 @@ class JiraKBSyncService:
|
||||||
summary_content = (
|
summary_content = (
|
||||||
f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
|
f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
|
||||||
)
|
)
|
||||||
summary_embedding = embed_text(summary_content)
|
summary_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, summary_content
|
||||||
|
)
|
||||||
|
|
||||||
chunks = await create_document_chunks(issue_content)
|
chunks = await create_document_chunks(issue_content)
|
||||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
@ -212,7 +214,9 @@ class JiraKBSyncService:
|
||||||
summary_content = (
|
summary_content = (
|
||||||
f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
|
f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
|
||||||
)
|
)
|
||||||
summary_embedding = embed_text(summary_content)
|
summary_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, summary_content
|
||||||
|
)
|
||||||
|
|
||||||
chunks = await create_document_chunks(issue_content)
|
chunks = await create_document_chunks(issue_content)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -659,3 +659,36 @@ async def get_user_long_context_llm(
|
||||||
return await get_document_summary_llm(
|
return await get_document_summary_llm(
|
||||||
session, search_space_id, disable_streaming=disable_streaming
|
session, search_space_id, disable_streaming=disable_streaming
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_planner_llm() -> ChatLiteLLM | None:
|
||||||
|
"""Return a planner LLM instance from the first global config marked
|
||||||
|
``is_planner: true``, or ``None`` if no planner config is defined.
|
||||||
|
|
||||||
|
The planner role handles short, structured internal tasks (KB search
|
||||||
|
planning: query rewriting, date extraction, recency classification).
|
||||||
|
These tasks are well-served by small/fast models (e.g. gpt-4o-mini,
|
||||||
|
Claude Haiku, Azure gpt-5.x-nano) — using the user's chat LLM for them
|
||||||
|
is unnecessarily expensive and slow.
|
||||||
|
|
||||||
|
This helper reads from ``config.GLOBAL_LLM_CONFIGS`` (loaded at import
|
||||||
|
time from ``global_llm_config.yaml``) so it has no DB cost and can be
|
||||||
|
called synchronously from middleware/factory code. It returns the same
|
||||||
|
instance shape as the global path of ``get_search_space_llm_instance``.
|
||||||
|
|
||||||
|
Callers MUST fall back to their chat LLM when this returns ``None`` so
|
||||||
|
deployments without a planner config keep working unchanged.
|
||||||
|
"""
|
||||||
|
from app.agents.new_chat.llm_config import create_chat_litellm_from_config
|
||||||
|
|
||||||
|
planner_cfg = next(
|
||||||
|
(
|
||||||
|
cfg
|
||||||
|
for cfg in config.GLOBAL_LLM_CONFIGS
|
||||||
|
if cfg.get("is_planner") is True
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if not planner_cfg:
|
||||||
|
return None
|
||||||
|
return create_chat_litellm_from_config(planner_cfg)
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
@ -95,7 +96,9 @@ class OneDriveKBSyncService:
|
||||||
else:
|
else:
|
||||||
logger.warning("No LLM configured — using fallback summary")
|
logger.warning("No LLM configured — using fallback summary")
|
||||||
summary_content = f"OneDrive File: {file_name}\n\n{indexable_content}"
|
summary_content = f"OneDrive File: {file_name}\n\n{indexable_content}"
|
||||||
summary_embedding = embed_text(summary_content)
|
summary_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, summary_content
|
||||||
|
)
|
||||||
|
|
||||||
chunks = await create_document_chunks(indexable_content)
|
chunks = await create_document_chunks(indexable_content)
|
||||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ same trap waiting to happen).
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
@ -234,7 +235,7 @@ async def _restore_in_place_document(
|
||||||
if isinstance(c, dict) and isinstance(c.get("content"), str)
|
if isinstance(c, dict) and isinstance(c.get("content"), str)
|
||||||
]
|
]
|
||||||
if chunk_texts:
|
if chunk_texts:
|
||||||
chunk_embeddings = embed_texts(chunk_texts)
|
chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts)
|
||||||
session.add_all(
|
session.add_all(
|
||||||
[
|
[
|
||||||
Chunk(document_id=doc.id, content=text, embedding=embedding)
|
Chunk(document_id=doc.id, content=text, embedding=embedding)
|
||||||
|
|
@ -244,7 +245,9 @@ async def _restore_in_place_document(
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
if isinstance(revision.content_before, str):
|
if isinstance(revision.content_before, str):
|
||||||
doc.embedding = embed_texts([revision.content_before])[0]
|
doc.embedding = (
|
||||||
|
await asyncio.to_thread(embed_texts, [revision.content_before])
|
||||||
|
)[0]
|
||||||
|
|
||||||
doc.updated_at = datetime.now(UTC)
|
doc.updated_at = datetime.now(UTC)
|
||||||
return RevertOutcome(status="ok", message="Document restored from snapshot.")
|
return RevertOutcome(status="ok", message="Document restored from snapshot.")
|
||||||
|
|
@ -320,7 +323,7 @@ async def _reinsert_document_from_revision(
|
||||||
session.add(new_doc)
|
session.add(new_doc)
|
||||||
await session.flush()
|
await session.flush()
|
||||||
|
|
||||||
new_doc.embedding = embed_texts([content])[0]
|
new_doc.embedding = (await asyncio.to_thread(embed_texts, [content]))[0]
|
||||||
chunk_texts = []
|
chunk_texts = []
|
||||||
chunks_before = revision.chunks_before
|
chunks_before = revision.chunks_before
|
||||||
if isinstance(chunks_before, list):
|
if isinstance(chunks_before, list):
|
||||||
|
|
@ -330,7 +333,7 @@ async def _reinsert_document_from_revision(
|
||||||
if isinstance(c, dict) and isinstance(c.get("content"), str)
|
if isinstance(c, dict) and isinstance(c.get("content"), str)
|
||||||
]
|
]
|
||||||
if chunk_texts:
|
if chunk_texts:
|
||||||
chunk_embeddings = embed_texts(chunk_texts)
|
chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts)
|
||||||
session.add_all(
|
session.add_all(
|
||||||
[
|
[
|
||||||
Chunk(document_id=new_doc.id, content=text, embedding=embedding)
|
Chunk(document_id=new_doc.id, content=text, embedding=embedding)
|
||||||
|
|
|
||||||
|
|
@ -325,6 +325,24 @@ class TokenTrackingCallback(CustomLogger):
|
||||||
total_tokens = getattr(usage, "total_tokens", 0) or 0
|
total_tokens = getattr(usage, "total_tokens", 0) or 0
|
||||||
call_kind = "chat"
|
call_kind = "chat"
|
||||||
|
|
||||||
|
# Prompt-cache accounting. LiteLLM normalizes every provider's cache
|
||||||
|
# fields onto ``usage.prompt_tokens_details``:
|
||||||
|
# - ``cached_tokens`` — cache reads (OpenAI/Azure native, DeepSeek
|
||||||
|
# mapped from ``prompt_cache_hit_tokens``,
|
||||||
|
# Anthropic mapped from ``cache_read_input_tokens``).
|
||||||
|
# - ``cache_creation_tokens`` — cache writes (Anthropic only; OpenAI/Azure
|
||||||
|
# do not expose a write count).
|
||||||
|
# See ``litellm.types.utils.Usage.__init__`` for the mapping.
|
||||||
|
cached_tokens = 0
|
||||||
|
cache_creation_tokens = 0
|
||||||
|
if not is_image:
|
||||||
|
prompt_details = getattr(usage, "prompt_tokens_details", None)
|
||||||
|
if prompt_details is not None:
|
||||||
|
cached_tokens = getattr(prompt_details, "cached_tokens", 0) or 0
|
||||||
|
cache_creation_tokens = (
|
||||||
|
getattr(prompt_details, "cache_creation_tokens", 0) or 0
|
||||||
|
)
|
||||||
|
|
||||||
model = kwargs.get("model", "unknown")
|
model = kwargs.get("model", "unknown")
|
||||||
|
|
||||||
cost_usd = _extract_cost_usd(
|
cost_usd = _extract_cost_usd(
|
||||||
|
|
@ -357,9 +375,23 @@ class TokenTrackingCallback(CustomLogger):
|
||||||
cost_micros=cost_micros,
|
cost_micros=cost_micros,
|
||||||
call_kind=call_kind,
|
call_kind=call_kind,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Per-LLM-call wall-clock latency (LiteLLM passes datetime objects).
|
||||||
|
call_latency_s: float | None = None
|
||||||
|
try:
|
||||||
|
if start_time is not None and end_time is not None:
|
||||||
|
delta = end_time - start_time
|
||||||
|
call_latency_s = getattr(delta, "total_seconds", lambda: float(delta))()
|
||||||
|
except Exception:
|
||||||
|
call_latency_s = None
|
||||||
|
|
||||||
|
cache_hit_ratio: float | None = None
|
||||||
|
if prompt_tokens > 0 and (cached_tokens > 0 or cache_creation_tokens > 0):
|
||||||
|
cache_hit_ratio = cached_tokens / prompt_tokens
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
|
"[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
|
||||||
"cost=$%.6f (%d micros) (accumulator now has %d calls)",
|
"cost=$%.6f (%d micros) (accumulator now has %d calls)%s%s",
|
||||||
model,
|
model,
|
||||||
call_kind,
|
call_kind,
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
|
|
@ -368,6 +400,17 @@ class TokenTrackingCallback(CustomLogger):
|
||||||
cost_usd,
|
cost_usd,
|
||||||
cost_micros,
|
cost_micros,
|
||||||
len(acc.calls),
|
len(acc.calls),
|
||||||
|
f" latency={call_latency_s:.3f}s" if call_latency_s is not None else "",
|
||||||
|
(
|
||||||
|
f" cache_read={cached_tokens} cache_write={cache_creation_tokens}"
|
||||||
|
f" hit_ratio={cache_hit_ratio:.1%}"
|
||||||
|
if cache_hit_ratio is not None
|
||||||
|
else (
|
||||||
|
f" cache_read={cached_tokens} cache_write={cache_creation_tokens}"
|
||||||
|
if (cached_tokens or cache_creation_tokens)
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -60,8 +60,6 @@ from app.db import (
|
||||||
)
|
)
|
||||||
from app.prompts import TITLE_GENERATION_PROMPT
|
from app.prompts import TITLE_GENERATION_PROMPT
|
||||||
from app.services.auto_model_pin_service import (
|
from app.services.auto_model_pin_service import (
|
||||||
is_recently_healthy,
|
|
||||||
mark_healthy,
|
|
||||||
mark_runtime_cooldown,
|
mark_runtime_cooldown,
|
||||||
resolve_or_get_pinned_llm_config_id,
|
resolve_or_get_pinned_llm_config_id,
|
||||||
)
|
)
|
||||||
|
|
@ -501,54 +499,6 @@ def _is_provider_rate_limited(exc: BaseException) -> bool:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_PREFLIGHT_TIMEOUT_SEC: float = 2.5
|
|
||||||
_PREFLIGHT_MAX_TOKENS: int = 1
|
|
||||||
|
|
||||||
|
|
||||||
async def _preflight_llm(llm: Any) -> None:
|
|
||||||
"""Issue a minimal completion to confirm the pinned model isn't 429'ing.
|
|
||||||
|
|
||||||
Used before agent build / planner / classifier / title-gen so a known-bad
|
|
||||||
free OpenRouter deployment is detected and repinned before it cascades
|
|
||||||
into multiple wasted internal calls. The probe is intentionally cheap:
|
|
||||||
one token, low timeout, tagged ``surfsense:internal`` so token tracking
|
|
||||||
and SSE pipelines treat it as overhead rather than user output.
|
|
||||||
|
|
||||||
Raises the original exception when the provider responds with a
|
|
||||||
rate-limit-shaped error so the caller can drive the cooldown/repin
|
|
||||||
branch via :func:`_is_provider_rate_limited`. Other transient failures
|
|
||||||
are swallowed — the caller continues to the normal stream path and the
|
|
||||||
in-stream recovery loop remains the safety net.
|
|
||||||
"""
|
|
||||||
from litellm import acompletion
|
|
||||||
|
|
||||||
model = getattr(llm, "model", None)
|
|
||||||
if not model or model == "auto":
|
|
||||||
# Auto-mode router doesn't have a single deployment to ping; the
|
|
||||||
# router itself handles per-deployment rate-limit accounting.
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
await acompletion(
|
|
||||||
model=model,
|
|
||||||
messages=[{"role": "user", "content": "ping"}],
|
|
||||||
api_key=getattr(llm, "api_key", None),
|
|
||||||
api_base=getattr(llm, "api_base", None),
|
|
||||||
max_tokens=_PREFLIGHT_MAX_TOKENS,
|
|
||||||
timeout=_PREFLIGHT_TIMEOUT_SEC,
|
|
||||||
stream=False,
|
|
||||||
metadata={"tags": ["surfsense:internal", "auto-pin-preflight"]},
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
if _is_provider_rate_limited(exc):
|
|
||||||
raise
|
|
||||||
logging.getLogger(__name__).debug(
|
|
||||||
"auto_pin_preflight non_rate_limit_error model=%s err=%s",
|
|
||||||
model,
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _build_main_agent_for_thread(
|
async def _build_main_agent_for_thread(
|
||||||
agent_factory: Any,
|
agent_factory: Any,
|
||||||
*,
|
*,
|
||||||
|
|
@ -566,9 +516,9 @@ async def _build_main_agent_for_thread(
|
||||||
disabled_tools: list[str] | None = None,
|
disabled_tools: list[str] | None = None,
|
||||||
mentioned_document_ids: list[int] | None = None,
|
mentioned_document_ids: list[int] | None = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Single (re)build path so the agent factory cannot drift across
|
"""Single (re)build path so the agent factory cannot drift across the
|
||||||
initial build, preflight repin, and mid-stream 429 recovery for one
|
initial build and mid-stream 429 recovery for one ``thread_id``: a
|
||||||
``thread_id``: a graph swap mid-turn would corrupt checkpointer state."""
|
graph swap mid-turn would corrupt checkpointer state."""
|
||||||
return await agent_factory(
|
return await agent_factory(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
|
|
@ -586,29 +536,6 @@ async def _build_main_agent_for_thread(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None:
|
|
||||||
"""Wait for a discarded speculative agent build to release shared state.
|
|
||||||
|
|
||||||
Used by the parallel preflight + agent-build path. The speculative build
|
|
||||||
closes over the request-scoped ``AsyncSession`` (for the brief connector
|
|
||||||
discovery / tool-factory window before its CPU work moves into a worker
|
|
||||||
thread). If preflight reports a 429 we want to fall back to the original
|
|
||||||
repin → reload → rebuild path, but we MUST NOT touch ``session`` again
|
|
||||||
until any in-flight session work owned by the speculative build has
|
|
||||||
fully settled — :class:`sqlalchemy.ext.asyncio.AsyncSession` is not
|
|
||||||
concurrency-safe and the same hazard cost us a hard ``InvalidRequestError``
|
|
||||||
earlier in this PR (see ``connector_service`` parallel-gather revert).
|
|
||||||
|
|
||||||
We simply ``await`` the task and swallow any exception: in this path the
|
|
||||||
build's outcome is irrelevant — success populates the agent cache (a free
|
|
||||||
side effect), failure is discarded. The wasted CPU is acceptable since
|
|
||||||
429 fallbacks are rare and the original sequential code also paid the
|
|
||||||
full build cost on the same path.
|
|
||||||
"""
|
|
||||||
with contextlib.suppress(BaseException):
|
|
||||||
await task
|
|
||||||
|
|
||||||
|
|
||||||
def _classify_stream_exception(
|
def _classify_stream_exception(
|
||||||
exc: Exception,
|
exc: Exception,
|
||||||
*,
|
*,
|
||||||
|
|
@ -1236,39 +1163,6 @@ async def stream_new_chat(
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
return
|
return
|
||||||
|
|
||||||
# Auto-mode preflight ping. Runs ONLY for thread-pinned auto cfgs
|
|
||||||
# (negative ids selected via ``resolve_or_get_pinned_llm_config_id``)
|
|
||||||
# whose health hasn't already been confirmed within the TTL window.
|
|
||||||
# Detecting a 429 here lets us repin BEFORE the planner/classifier/
|
|
||||||
# title-generation LLM calls fan out and each independently hit the
|
|
||||||
# same upstream rate limit.
|
|
||||||
#
|
|
||||||
# PERF: preflight is a network round-trip to the LLM provider (~1-5s)
|
|
||||||
# and is independent of the agent build (CPU-bound, ~5-7s). They used
|
|
||||||
# to run sequentially → ``preflight + build`` on cold cache = 11.5s.
|
|
||||||
# We now kick off preflight as a background task FIRST, then run the
|
|
||||||
# synchronous setup work and the agent build in parallel. In the
|
|
||||||
# success path (the common case) total wall time drops to roughly
|
|
||||||
# ``max(preflight, build)`` — the preflight finishes during the
|
|
||||||
# agent compile and we just consume its result. In the rare 429
|
|
||||||
# path the speculative build is awaited to completion (so its
|
|
||||||
# session usage is fully released) via
|
|
||||||
# :func:`_settle_speculative_agent_build`, then discarded, and
|
|
||||||
# we fall back to the original repin-and-rebuild flow.
|
|
||||||
preflight_needed = (
|
|
||||||
requested_llm_config_id == 0
|
|
||||||
and llm_config_id < 0
|
|
||||||
and not is_recently_healthy(llm_config_id)
|
|
||||||
)
|
|
||||||
preflight_task: asyncio.Task[None] | None = None
|
|
||||||
_t_preflight = 0.0
|
|
||||||
if preflight_needed:
|
|
||||||
_t_preflight = time.perf_counter()
|
|
||||||
preflight_task = asyncio.create_task(
|
|
||||||
_preflight_llm(llm),
|
|
||||||
name=f"auto_pin_preflight:{llm_config_id}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create connector service
|
# Create connector service
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
||||||
|
|
@ -1302,136 +1196,26 @@ async def stream_new_chat(
|
||||||
if use_multi_agent
|
if use_multi_agent
|
||||||
else create_surfsense_deep_agent
|
else create_surfsense_deep_agent
|
||||||
)
|
)
|
||||||
# Speculative agent build — runs in parallel with the preflight
|
# Build the agent inline. Provider 429s surface through the
|
||||||
# task (if any). Built with the *current* ``llm`` / ``agent_config``;
|
# in-stream recovery loop below (``_is_provider_rate_limited``),
|
||||||
# if preflight reports 429 we will discard this future and rebuild
|
# which repins the thread to an eligible alternative config and
|
||||||
# against the freshly pinned config below.
|
# rebuilds the agent before the user sees any output.
|
||||||
agent_build_task = asyncio.create_task(
|
agent = await _build_main_agent_for_thread(
|
||||||
_build_main_agent_for_thread(
|
agent_factory,
|
||||||
agent_factory,
|
llm=llm,
|
||||||
llm=llm,
|
search_space_id=search_space_id,
|
||||||
search_space_id=search_space_id,
|
db_session=session,
|
||||||
db_session=session,
|
connector_service=connector_service,
|
||||||
connector_service=connector_service,
|
checkpointer=checkpointer,
|
||||||
checkpointer=checkpointer,
|
user_id=user_id,
|
||||||
user_id=user_id,
|
thread_id=chat_id,
|
||||||
thread_id=chat_id,
|
agent_config=agent_config,
|
||||||
agent_config=agent_config,
|
firecrawl_api_key=firecrawl_api_key,
|
||||||
firecrawl_api_key=firecrawl_api_key,
|
thread_visibility=visibility,
|
||||||
thread_visibility=visibility,
|
filesystem_selection=filesystem_selection,
|
||||||
filesystem_selection=filesystem_selection,
|
disabled_tools=disabled_tools,
|
||||||
disabled_tools=disabled_tools,
|
mentioned_document_ids=mentioned_document_ids,
|
||||||
mentioned_document_ids=mentioned_document_ids,
|
|
||||||
),
|
|
||||||
name="agent_build:stream_new_chat",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
agent: Any = None
|
|
||||||
if preflight_task is not None:
|
|
||||||
try:
|
|
||||||
await preflight_task
|
|
||||||
mark_healthy(llm_config_id)
|
|
||||||
_perf_log.info(
|
|
||||||
"[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)",
|
|
||||||
llm_config_id,
|
|
||||||
time.perf_counter() - _t_preflight,
|
|
||||||
)
|
|
||||||
except Exception as preflight_exc:
|
|
||||||
# Both branches below need the session: the non-429 path
|
|
||||||
# may unwind via cleanup that uses ``session``, and the
|
|
||||||
# 429 path explicitly calls ``resolve_or_get_pinned_llm_config_id``
|
|
||||||
# against it. Wait for the speculative build to release its
|
|
||||||
# session usage before we proceed.
|
|
||||||
await _settle_speculative_agent_build(agent_build_task)
|
|
||||||
if not _is_provider_rate_limited(preflight_exc):
|
|
||||||
raise
|
|
||||||
# 429: speculative agent is discarded; run the original
|
|
||||||
# repin → reload → rebuild path against the freshly
|
|
||||||
# pinned config.
|
|
||||||
previous_config_id = llm_config_id
|
|
||||||
mark_runtime_cooldown(
|
|
||||||
previous_config_id, reason="preflight_rate_limited"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
llm_config_id = (
|
|
||||||
await resolve_or_get_pinned_llm_config_id(
|
|
||||||
session,
|
|
||||||
thread_id=chat_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
selected_llm_config_id=0,
|
|
||||||
exclude_config_ids={previous_config_id},
|
|
||||||
requires_image_input=_requires_image_input,
|
|
||||||
)
|
|
||||||
).resolved_llm_config_id
|
|
||||||
except ValueError as pin_error:
|
|
||||||
yield _emit_stream_error(
|
|
||||||
message=str(pin_error),
|
|
||||||
error_kind="server_error",
|
|
||||||
error_code="SERVER_ERROR",
|
|
||||||
)
|
|
||||||
yield streaming_service.format_done()
|
|
||||||
return
|
|
||||||
|
|
||||||
llm, agent_config, llm_load_error = await _load_llm_bundle(
|
|
||||||
llm_config_id
|
|
||||||
)
|
|
||||||
if llm_load_error or not llm:
|
|
||||||
yield _emit_stream_error(
|
|
||||||
message=llm_load_error or "Failed to create LLM instance",
|
|
||||||
error_kind="server_error",
|
|
||||||
error_code="SERVER_ERROR",
|
|
||||||
)
|
|
||||||
yield streaming_service.format_done()
|
|
||||||
return
|
|
||||||
# Trust the freshly-resolved cfg for the remainder of this
|
|
||||||
# turn rather than recursing into another preflight; the
|
|
||||||
# in-stream 429 recovery loop is still in place as the
|
|
||||||
# safety net if even this fallback hits an upstream cap.
|
|
||||||
mark_healthy(llm_config_id)
|
|
||||||
_log_chat_stream_error(
|
|
||||||
flow=flow,
|
|
||||||
error_kind="rate_limited",
|
|
||||||
error_code="RATE_LIMITED",
|
|
||||||
severity="info",
|
|
||||||
is_expected=True,
|
|
||||||
request_id=request_id,
|
|
||||||
thread_id=chat_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
message=(
|
|
||||||
"Auto-pinned model failed preflight; switched to another "
|
|
||||||
"eligible model and continuing."
|
|
||||||
),
|
|
||||||
extra={
|
|
||||||
"auto_runtime_recover": True,
|
|
||||||
"preflight": True,
|
|
||||||
"previous_config_id": previous_config_id,
|
|
||||||
"fallback_config_id": llm_config_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# Rebuild against the new llm/agent_config. Sequential
|
|
||||||
# here because we no longer have anything to overlap with.
|
|
||||||
agent = await agent_factory(
|
|
||||||
llm=llm,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
db_session=session,
|
|
||||||
connector_service=connector_service,
|
|
||||||
checkpointer=checkpointer,
|
|
||||||
user_id=user_id,
|
|
||||||
thread_id=chat_id,
|
|
||||||
agent_config=agent_config,
|
|
||||||
firecrawl_api_key=firecrawl_api_key,
|
|
||||||
thread_visibility=visibility,
|
|
||||||
disabled_tools=disabled_tools,
|
|
||||||
mentioned_document_ids=mentioned_document_ids,
|
|
||||||
filesystem_selection=filesystem_selection,
|
|
||||||
)
|
|
||||||
|
|
||||||
if agent is None:
|
|
||||||
# Either no preflight was needed, or preflight succeeded —
|
|
||||||
# in both cases the speculative build is the agent we want.
|
|
||||||
agent = await agent_build_task
|
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
|
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
|
||||||
)
|
)
|
||||||
|
|
@ -2647,25 +2431,6 @@ async def stream_resume_chat(
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
return
|
return
|
||||||
|
|
||||||
# Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``:
|
|
||||||
# one cheap probe before the agent is rebuilt so a 429'd pin gets
|
|
||||||
# repinned without burning planner/classifier/title calls first.
|
|
||||||
# See ``stream_new_chat`` for the full rationale on the speculative
|
|
||||||
# parallel build pattern below.
|
|
||||||
preflight_needed = (
|
|
||||||
requested_llm_config_id == 0
|
|
||||||
and llm_config_id < 0
|
|
||||||
and not is_recently_healthy(llm_config_id)
|
|
||||||
)
|
|
||||||
preflight_task: asyncio.Task[None] | None = None
|
|
||||||
_t_preflight = 0.0
|
|
||||||
if preflight_needed:
|
|
||||||
_t_preflight = time.perf_counter()
|
|
||||||
preflight_task = asyncio.create_task(
|
|
||||||
_preflight_llm(llm),
|
|
||||||
name=f"auto_pin_preflight_resume:{llm_config_id}",
|
|
||||||
)
|
|
||||||
|
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
||||||
|
|
||||||
|
|
@ -2695,115 +2460,25 @@ async def stream_resume_chat(
|
||||||
if _app_config.MULTI_AGENT_CHAT_ENABLED
|
if _app_config.MULTI_AGENT_CHAT_ENABLED
|
||||||
else create_surfsense_deep_agent
|
else create_surfsense_deep_agent
|
||||||
)
|
)
|
||||||
agent_build_task = asyncio.create_task(
|
# Build the agent inline. Provider 429s are handled by the
|
||||||
_build_main_agent_for_thread(
|
# in-stream recovery loop, which repins to an eligible
|
||||||
agent_factory,
|
# alternative config and rebuilds the agent before the user sees
|
||||||
llm=llm,
|
# any output.
|
||||||
search_space_id=search_space_id,
|
agent = await _build_main_agent_for_thread(
|
||||||
db_session=session,
|
agent_factory,
|
||||||
connector_service=connector_service,
|
llm=llm,
|
||||||
checkpointer=checkpointer,
|
search_space_id=search_space_id,
|
||||||
user_id=user_id,
|
db_session=session,
|
||||||
thread_id=chat_id,
|
connector_service=connector_service,
|
||||||
agent_config=agent_config,
|
checkpointer=checkpointer,
|
||||||
firecrawl_api_key=firecrawl_api_key,
|
user_id=user_id,
|
||||||
thread_visibility=visibility,
|
thread_id=chat_id,
|
||||||
filesystem_selection=filesystem_selection,
|
agent_config=agent_config,
|
||||||
disabled_tools=disabled_tools,
|
firecrawl_api_key=firecrawl_api_key,
|
||||||
),
|
thread_visibility=visibility,
|
||||||
name="agent_build:stream_resume",
|
filesystem_selection=filesystem_selection,
|
||||||
|
disabled_tools=disabled_tools,
|
||||||
)
|
)
|
||||||
|
|
||||||
agent: Any = None
|
|
||||||
if preflight_task is not None:
|
|
||||||
try:
|
|
||||||
await preflight_task
|
|
||||||
mark_healthy(llm_config_id)
|
|
||||||
_perf_log.info(
|
|
||||||
"[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)",
|
|
||||||
llm_config_id,
|
|
||||||
time.perf_counter() - _t_preflight,
|
|
||||||
)
|
|
||||||
except Exception as preflight_exc:
|
|
||||||
# Same session-safety rationale as ``stream_new_chat``.
|
|
||||||
await _settle_speculative_agent_build(agent_build_task)
|
|
||||||
if not _is_provider_rate_limited(preflight_exc):
|
|
||||||
raise
|
|
||||||
previous_config_id = llm_config_id
|
|
||||||
mark_runtime_cooldown(
|
|
||||||
previous_config_id, reason="preflight_rate_limited"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
llm_config_id = (
|
|
||||||
await resolve_or_get_pinned_llm_config_id(
|
|
||||||
session,
|
|
||||||
thread_id=chat_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
selected_llm_config_id=0,
|
|
||||||
exclude_config_ids={previous_config_id},
|
|
||||||
)
|
|
||||||
).resolved_llm_config_id
|
|
||||||
except ValueError as pin_error:
|
|
||||||
yield _emit_stream_error(
|
|
||||||
message=str(pin_error),
|
|
||||||
error_kind="server_error",
|
|
||||||
error_code="SERVER_ERROR",
|
|
||||||
)
|
|
||||||
yield streaming_service.format_done()
|
|
||||||
return
|
|
||||||
|
|
||||||
llm, agent_config, llm_load_error = await _load_llm_bundle(
|
|
||||||
llm_config_id
|
|
||||||
)
|
|
||||||
if llm_load_error or not llm:
|
|
||||||
yield _emit_stream_error(
|
|
||||||
message=llm_load_error or "Failed to create LLM instance",
|
|
||||||
error_kind="server_error",
|
|
||||||
error_code="SERVER_ERROR",
|
|
||||||
)
|
|
||||||
yield streaming_service.format_done()
|
|
||||||
return
|
|
||||||
mark_healthy(llm_config_id)
|
|
||||||
_log_chat_stream_error(
|
|
||||||
flow="resume",
|
|
||||||
error_kind="rate_limited",
|
|
||||||
error_code="RATE_LIMITED",
|
|
||||||
severity="info",
|
|
||||||
is_expected=True,
|
|
||||||
request_id=request_id,
|
|
||||||
thread_id=chat_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
message=(
|
|
||||||
"Auto-pinned model failed preflight; switched to another "
|
|
||||||
"eligible model and continuing."
|
|
||||||
),
|
|
||||||
extra={
|
|
||||||
"auto_runtime_recover": True,
|
|
||||||
"preflight": True,
|
|
||||||
"previous_config_id": previous_config_id,
|
|
||||||
"fallback_config_id": llm_config_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
agent = await _build_main_agent_for_thread(
|
|
||||||
agent_factory,
|
|
||||||
llm=llm,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
db_session=session,
|
|
||||||
connector_service=connector_service,
|
|
||||||
checkpointer=checkpointer,
|
|
||||||
user_id=user_id,
|
|
||||||
thread_id=chat_id,
|
|
||||||
agent_config=agent_config,
|
|
||||||
firecrawl_api_key=firecrawl_api_key,
|
|
||||||
thread_visibility=visibility,
|
|
||||||
filesystem_selection=filesystem_selection,
|
|
||||||
disabled_tools=disabled_tools,
|
|
||||||
)
|
|
||||||
|
|
||||||
if agent is None:
|
|
||||||
agent = await agent_build_task
|
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
|
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -670,7 +670,9 @@ async def index_discord_messages(
|
||||||
|
|
||||||
# Heavy processing (embeddings, chunks)
|
# Heavy processing (embeddings, chunks)
|
||||||
chunks = await create_document_chunks(item["combined_document_string"])
|
chunks = await create_document_chunks(item["combined_document_string"])
|
||||||
doc_embedding = embed_text(item["combined_document_string"])
|
doc_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, item["combined_document_string"]
|
||||||
|
)
|
||||||
|
|
||||||
# Update document to READY with actual content
|
# Update document to READY with actual content
|
||||||
document.title = f"{item['guild_name']}#{item['channel_name']}"
|
document.title = f"{item['guild_name']}#{item['channel_name']}"
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ Implements 2-phase document status updates for real-time UI feedback:
|
||||||
- Phase 2: Process each event: pending → processing → ready/failed
|
- Phase 2: Process each event: pending → processing → ready/failed
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
@ -465,7 +466,9 @@ async def index_luma_events(
|
||||||
summary_content = (
|
summary_content = (
|
||||||
f"Luma Event: {item['event_name']}\n\n{item['event_markdown']}"
|
f"Luma Event: {item['event_name']}\n\n{item['event_markdown']}"
|
||||||
)
|
)
|
||||||
summary_embedding = embed_text(summary_content)
|
summary_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, summary_content
|
||||||
|
)
|
||||||
|
|
||||||
chunks = await create_document_chunks(item["event_markdown"])
|
chunks = await create_document_chunks(item["event_markdown"])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ Uses 2-phase document status updates for real-time UI feedback:
|
||||||
- Phase 2: Process each document: pending → processing → ready/failed
|
- Phase 2: Process each document: pending → processing → ready/failed
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
@ -581,7 +582,9 @@ async def index_teams_messages(
|
||||||
|
|
||||||
# Heavy processing (embeddings, chunks)
|
# Heavy processing (embeddings, chunks)
|
||||||
chunks = await create_document_chunks(item["combined_document_string"])
|
chunks = await create_document_chunks(item["combined_document_string"])
|
||||||
doc_embedding = embed_text(item["combined_document_string"])
|
doc_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, item["combined_document_string"]
|
||||||
|
)
|
||||||
|
|
||||||
# Update document to READY with actual content
|
# Update document to READY with actual content
|
||||||
document.title = f"{item['team_name']} - {item['channel_name']}"
|
document.title = f"{item['team_name']} - {item['channel_name']}"
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
Unified document save/update logic for file processors.
|
Unified document save/update logic for file processors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
@ -43,7 +44,7 @@ async def _generate_summary(
|
||||||
"""
|
"""
|
||||||
if not enable_summary:
|
if not enable_summary:
|
||||||
summary = f"File: {file_name}\n\n{markdown_content[:4000]}"
|
summary = f"File: {file_name}\n\n{markdown_content[:4000]}"
|
||||||
return summary, embed_text(summary)
|
return summary, await asyncio.to_thread(embed_text, summary)
|
||||||
|
|
||||||
if etl_service == "DOCLING":
|
if etl_service == "DOCLING":
|
||||||
from app.services.docling_service import create_docling_service
|
from app.services.docling_service import create_docling_service
|
||||||
|
|
@ -65,7 +66,7 @@ async def _generate_summary(
|
||||||
parts.append(f"**{formatted_key}:** {value}")
|
parts.append(f"**{formatted_key}:** {value}")
|
||||||
|
|
||||||
enhanced = "\n".join(parts) + "\n\n# DOCUMENT SUMMARY\n\n" + summary_text
|
enhanced = "\n".join(parts) + "\n\n# DOCUMENT SUMMARY\n\n" + summary_text
|
||||||
return enhanced, embed_text(enhanced)
|
return enhanced, await asyncio.to_thread(embed_text, enhanced)
|
||||||
|
|
||||||
# Standard summary (Unstructured / LlamaCloud / others)
|
# Standard summary (Unstructured / LlamaCloud / others)
|
||||||
meta = {
|
meta = {
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
|
|
@ -221,7 +222,9 @@ async def generate_document_summary(
|
||||||
else:
|
else:
|
||||||
enhanced_summary_content = summary_content
|
enhanced_summary_content = summary_content
|
||||||
|
|
||||||
summary_embedding = embed_text(enhanced_summary_content)
|
summary_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, enhanced_summary_content
|
||||||
|
)
|
||||||
|
|
||||||
return enhanced_summary_content, summary_embedding
|
return enhanced_summary_content, summary_embedding
|
||||||
|
|
||||||
|
|
@ -237,7 +240,7 @@ async def create_document_chunks(content: str) -> list[Chunk]:
|
||||||
List of Chunk objects with embeddings
|
List of Chunk objects with embeddings
|
||||||
"""
|
"""
|
||||||
chunk_texts = [c.text for c in config.chunker_instance.chunk(content)]
|
chunk_texts = [c.text for c in config.chunker_instance.chunk(content)]
|
||||||
chunk_embeddings = embed_texts(chunk_texts)
|
chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts)
|
||||||
return [
|
return [
|
||||||
Chunk(content=text, embedding=emb)
|
Chunk(content=text, embedding=emb)
|
||||||
for text, emb in zip(chunk_texts, chunk_embeddings, strict=False)
|
for text, emb in zip(chunk_texts, chunk_embeddings, strict=False)
|
||||||
|
|
|
||||||
|
|
@ -12,13 +12,19 @@ prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to
|
||||||
the deepagent stack accumulates multiple ``SystemMessage``\ s in
|
the deepagent stack accumulates multiple ``SystemMessage``\ s in
|
||||||
``state["messages"]`` and ``role: system`` would tag every one of
|
``state["messages"]`` and ``role: system`` would tag every one of
|
||||||
them, blowing past Anthropic's 4-block ``cache_control`` cap.
|
them, blowing past Anthropic's 4-block ``cache_control`` cap.
|
||||||
2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for
|
2. Adds ``prompt_cache_key`` for OPENAI/DEEPSEEK/XAI/AZURE/AZURE_OPENAI
|
||||||
single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic
|
configs (Microsoft's Azure transformer was added to LiteLLM in
|
||||||
prompt-cache surface is available).
|
https://github.com/BerriAI/litellm/pull/20989, Feb 2026).
|
||||||
3. Treats ``ChatLiteLLMRouter`` (auto-mode) as universal-only — no
|
3. Adds ``prompt_cache_retention="24h"`` ONLY for OPENAI/DEEPSEEK/XAI.
|
||||||
OpenAI-only kwargs because the router fans out across providers.
|
Azure's server-side support landed in Microsoft's docs on 2026-05-13
|
||||||
4. Idempotent: user-supplied values in ``model_kwargs`` are preserved.
|
but LiteLLM 1.83.14 hasn't wired it through yet, so we let Azure use
|
||||||
5. Defensive: LLMs without a writable ``model_kwargs`` are silently
|
its default in-memory retention rather than send a param that
|
||||||
|
``litellm.drop_params`` would silently strip.
|
||||||
|
4. Treats ``ChatLiteLLMRouter`` (auto-mode) as universal-only — no
|
||||||
|
destination-specific kwargs because the router fans out across
|
||||||
|
providers.
|
||||||
|
5. Idempotent: user-supplied values in ``model_kwargs`` are preserved.
|
||||||
|
6. Defensive: LLMs without a writable ``model_kwargs`` are silently
|
||||||
skipped rather than raising.
|
skipped rather than raising.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -191,9 +197,9 @@ def test_does_not_overwrite_user_supplied_prompt_cache_key() -> None:
|
||||||
|
|
||||||
@pytest.mark.parametrize("provider", ["OPENAI", "DEEPSEEK", "XAI"])
|
@pytest.mark.parametrize("provider", ["OPENAI", "DEEPSEEK", "XAI"])
|
||||||
def test_sets_openai_family_extras(provider: str) -> None:
|
def test_sets_openai_family_extras(provider: str) -> None:
|
||||||
"""OpenAI-style providers gain ``prompt_cache_key`` (raises hit rate
|
"""Native OpenAI-style providers gain ``prompt_cache_key`` (raises
|
||||||
via routing affinity) and ``prompt_cache_retention="24h"`` (extends
|
hit rate via routing affinity) and ``prompt_cache_retention="24h"``
|
||||||
cache TTL beyond the default 5-10 min)."""
|
(extends cache TTL beyond the default 5-10 min)."""
|
||||||
cfg = _make_cfg(provider=provider)
|
cfg = _make_cfg(provider=provider)
|
||||||
llm = _FakeLLM()
|
llm = _FakeLLM()
|
||||||
|
|
||||||
|
|
@ -203,6 +209,27 @@ def test_sets_openai_family_extras(provider: str) -> None:
|
||||||
assert llm.model_kwargs["prompt_cache_retention"] == "24h"
|
assert llm.model_kwargs["prompt_cache_retention"] == "24h"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("provider", ["AZURE", "AZURE_OPENAI"])
|
||||||
|
def test_azure_gets_prompt_cache_key_only(provider: str) -> None:
|
||||||
|
"""Azure configs gain ``prompt_cache_key`` for routing affinity
|
||||||
|
(Microsoft auto-caches every GPT-4o+ deployment at ≥1024 tokens;
|
||||||
|
the key clusters same-prefix requests on the same backend GPU pool
|
||||||
|
so hit rate climbs). They DO NOT get ``prompt_cache_retention``
|
||||||
|
because LiteLLM 1.83.14's Azure transformer omits it from its
|
||||||
|
supported params list — ``drop_params`` would silently strip it.
|
||||||
|
Azure's default in-memory retention (5-10 min, max 1 h) is already
|
||||||
|
enough to cover intra-conversation turns; revisit when LiteLLM
|
||||||
|
bumps Azure to match its OpenAI surface."""
|
||||||
|
cfg = _make_cfg(provider=provider, model_name="gpt-5.4")
|
||||||
|
llm = _FakeLLM(model="azure/gpt-5.4")
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
|
||||||
|
|
||||||
|
assert llm.model_kwargs["prompt_cache_key"] == "surfsense-thread-42"
|
||||||
|
assert "prompt_cache_retention" not in llm.model_kwargs
|
||||||
|
assert "cache_control_injection_points" in llm.model_kwargs
|
||||||
|
|
||||||
|
|
||||||
def test_skips_prompt_cache_key_when_no_thread_id() -> None:
|
def test_skips_prompt_cache_key_when_no_thread_id() -> None:
|
||||||
"""Without a thread id we can't construct a per-thread key. Retention
|
"""Without a thread id we can't construct a per-thread key. Retention
|
||||||
is still useful so we set it (it's free)."""
|
is still useful so we set it (it's free)."""
|
||||||
|
|
@ -215,12 +242,26 @@ def test_skips_prompt_cache_key_when_no_thread_id() -> None:
|
||||||
assert llm.model_kwargs["prompt_cache_retention"] == "24h"
|
assert llm.model_kwargs["prompt_cache_retention"] == "24h"
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_skips_prompt_cache_key_when_no_thread_id() -> None:
|
||||||
|
"""Azure without a thread id ends up with no extras (retention is
|
||||||
|
Azure-skipped, key needs a thread id) — universal injection points
|
||||||
|
still land."""
|
||||||
|
cfg = _make_cfg(provider="AZURE", model_name="gpt-5.4")
|
||||||
|
llm = _FakeLLM(model="azure/gpt-5.4")
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=None)
|
||||||
|
|
||||||
|
assert "prompt_cache_key" not in llm.model_kwargs
|
||||||
|
assert "prompt_cache_retention" not in llm.model_kwargs
|
||||||
|
assert "cache_control_injection_points" in llm.model_kwargs
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"provider",
|
"provider",
|
||||||
["ANTHROPIC", "BEDROCK", "VERTEX_AI", "GOOGLE_AI_STUDIO", "GROQ", "MOONSHOT"],
|
["ANTHROPIC", "BEDROCK", "VERTEX_AI", "GOOGLE_AI_STUDIO", "GROQ", "MOONSHOT"],
|
||||||
)
|
)
|
||||||
def test_no_openai_extras_for_other_providers(provider: str) -> None:
|
def test_no_openai_extras_for_other_providers(provider: str) -> None:
|
||||||
"""Non-OpenAI-family providers don't expose ``prompt_cache_key`` —
|
"""Non-OpenAI-style providers don't expose ``prompt_cache_key`` —
|
||||||
skip it. ``cache_control_injection_points`` is still set (universal)."""
|
skip it. ``cache_control_injection_points`` is still set (universal)."""
|
||||||
cfg = _make_cfg(provider=provider)
|
cfg = _make_cfg(provider=provider)
|
||||||
llm = _FakeLLM()
|
llm = _FakeLLM()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,130 @@
|
||||||
|
"""Unit tests for ``mcp_tools_cache``."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.mcp_tools_cache import (
|
||||||
|
CachedMCPToolDef,
|
||||||
|
CachedMCPTools,
|
||||||
|
read_cached_tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _make_connector(config: dict | None) -> SimpleNamespace:
|
||||||
|
return SimpleNamespace(id=42, config=config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_returns_none_when_config_is_none() -> None:
|
||||||
|
assert read_cached_tools(_make_connector(None)) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_returns_none_when_cached_tools_missing() -> None:
|
||||||
|
assert read_cached_tools(_make_connector({"server_config": {}})) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_returns_none_when_cached_tools_is_not_a_dict() -> None:
|
||||||
|
assert read_cached_tools(_make_connector({"cached_tools": []})) is None
|
||||||
|
assert read_cached_tools(_make_connector({"cached_tools": "stale"})) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_parses_minimal_valid_payload() -> None:
|
||||||
|
parsed = read_cached_tools(
|
||||||
|
_make_connector(
|
||||||
|
{
|
||||||
|
"cached_tools": {
|
||||||
|
"discovered_at": "2026-05-20T10:00:00+00:00",
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"name": "list_issues",
|
||||||
|
"description": "List Linear issues",
|
||||||
|
"input_schema": {"type": "object"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert parsed is not None
|
||||||
|
assert parsed.server_version is None
|
||||||
|
assert parsed.server_name is None
|
||||||
|
assert parsed.transport is None
|
||||||
|
assert len(parsed.tools) == 1
|
||||||
|
assert parsed.tools[0].name == "list_issues"
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_parses_full_payload_with_serverinfo() -> None:
|
||||||
|
parsed = read_cached_tools(
|
||||||
|
_make_connector(
|
||||||
|
{
|
||||||
|
"cached_tools": {
|
||||||
|
"discovered_at": "2026-05-20T10:00:00+00:00",
|
||||||
|
"server_version": "1.2.3",
|
||||||
|
"server_name": "atlassian-mcp",
|
||||||
|
"transport": "streamable-http",
|
||||||
|
"tools": [
|
||||||
|
{"name": "create_issue", "input_schema": {}},
|
||||||
|
{"name": "list_issues", "input_schema": {}},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert parsed is not None
|
||||||
|
assert parsed.server_version == "1.2.3"
|
||||||
|
assert parsed.server_name == "atlassian-mcp"
|
||||||
|
assert parsed.transport == "streamable-http"
|
||||||
|
assert [t.name for t in parsed.tools] == ["create_issue", "list_issues"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_returns_none_for_corrupt_payload(caplog) -> None:
|
||||||
|
parsed = read_cached_tools(
|
||||||
|
_make_connector(
|
||||||
|
{
|
||||||
|
"cached_tools": {
|
||||||
|
"discovered_at": "not-a-date",
|
||||||
|
"tools": "should-be-a-list",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert parsed is None
|
||||||
|
assert any("corrupt cached_tools" in r.getMessage() for r in caplog.records)
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_returns_none_when_tools_missing() -> None:
|
||||||
|
parsed = read_cached_tools(
|
||||||
|
_make_connector(
|
||||||
|
{"cached_tools": {"discovered_at": "2026-05-20T10:00:00+00:00"}}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert parsed is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_def_defaults_description_and_schema() -> None:
|
||||||
|
td = CachedMCPToolDef.model_validate({"name": "ping"})
|
||||||
|
assert td.description == ""
|
||||||
|
assert td.input_schema == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_dump_json_mode_is_round_trippable() -> None:
|
||||||
|
original = CachedMCPTools(
|
||||||
|
discovered_at=datetime(2026, 5, 20, 10, 0, 0, tzinfo=UTC),
|
||||||
|
server_version="1.2.3",
|
||||||
|
server_name="atlassian-mcp",
|
||||||
|
transport="streamable-http",
|
||||||
|
tools=[CachedMCPToolDef(name="list_issues")],
|
||||||
|
)
|
||||||
|
payload = original.model_dump(mode="json")
|
||||||
|
|
||||||
|
assert payload["discovered_at"] == "2026-05-20T10:00:00Z"
|
||||||
|
assert payload["tools"][0]["name"] == "list_issues"
|
||||||
|
|
||||||
|
reparsed = CachedMCPTools.model_validate(payload)
|
||||||
|
assert reparsed.discovered_at == original.discovered_at
|
||||||
|
assert reparsed.tools[0].name == "list_issues"
|
||||||
|
|
@ -209,128 +209,6 @@ def test_stream_exception_classifies_openrouter_429_payload():
|
||||||
assert extra is None
|
assert extra is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkeypatch):
|
|
||||||
"""``_preflight_llm`` is best-effort.
|
|
||||||
|
|
||||||
- On rate-limit shaped exceptions (provider 429) it MUST re-raise so the
|
|
||||||
caller can drive the cooldown/repin branch.
|
|
||||||
- On any other transient failure it MUST swallow the error so the normal
|
|
||||||
stream path continues without surfacing preflight noise to the user.
|
|
||||||
"""
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
from app.tasks.chat.stream_new_chat import _preflight_llm
|
|
||||||
|
|
||||||
class _RateLimitedError(Exception):
|
|
||||||
"""Class-name carries 'RateLimit' so _is_provider_rate_limited triggers."""
|
|
||||||
|
|
||||||
rate_calls: list[dict] = []
|
|
||||||
other_calls: list[dict] = []
|
|
||||||
|
|
||||||
async def _fake_acompletion_429(**kwargs):
|
|
||||||
rate_calls.append(kwargs)
|
|
||||||
raise _RateLimitedError("simulated 429")
|
|
||||||
|
|
||||||
async def _fake_acompletion_other(**kwargs):
|
|
||||||
other_calls.append(kwargs)
|
|
||||||
raise RuntimeError("some unrelated transient failure")
|
|
||||||
|
|
||||||
fake_llm = SimpleNamespace(
|
|
||||||
model="openrouter/google/gemma-4-31b-it:free",
|
|
||||||
api_key="test",
|
|
||||||
api_base=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
import litellm # type: ignore[import-not-found]
|
|
||||||
|
|
||||||
monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_429)
|
|
||||||
with pytest.raises(_RateLimitedError):
|
|
||||||
await _preflight_llm(fake_llm)
|
|
||||||
assert len(rate_calls) == 1
|
|
||||||
assert rate_calls[0]["max_tokens"] == 1
|
|
||||||
assert rate_calls[0]["stream"] is False
|
|
||||||
|
|
||||||
monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_other)
|
|
||||||
# MUST NOT raise: non-rate-limit failures are swallowed.
|
|
||||||
await _preflight_llm(fake_llm)
|
|
||||||
assert len(other_calls) == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_preflight_skipped_for_auto_router_model():
|
|
||||||
"""Router-mode ``model='auto'`` has no single deployment to ping; the
|
|
||||||
LiteLLM router itself owns per-deployment rate-limit accounting, so the
|
|
||||||
preflight helper must short-circuit instead of issuing a probe."""
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
from app.tasks.chat.stream_new_chat import _preflight_llm
|
|
||||||
|
|
||||||
fake_llm = SimpleNamespace(model="auto", api_key="x", api_base=None)
|
|
||||||
# Should return without raising or making any LiteLLM call.
|
|
||||||
await _preflight_llm(fake_llm)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_settle_speculative_agent_build_swallows_exceptions():
|
|
||||||
"""``_settle_speculative_agent_build`` MUST always return cleanly so the
|
|
||||||
caller can safely re-touch the request-scoped session afterwards.
|
|
||||||
|
|
||||||
The helper guards the parallel preflight + agent-build path: when the
|
|
||||||
speculative build is being discarded (429 or non-429 preflight failure)
|
|
||||||
we await it solely to release any in-flight ``AsyncSession`` usage —
|
|
||||||
the build's outcome is irrelevant. Any exception (including
|
|
||||||
``CancelledError``) leaking out would skip the caller's recovery flow
|
|
||||||
and re-introduce the very session-concurrency hazard the helper exists
|
|
||||||
to prevent.
|
|
||||||
"""
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build
|
|
||||||
|
|
||||||
async def _raises() -> None:
|
|
||||||
raise RuntimeError("speculative build crashed")
|
|
||||||
|
|
||||||
async def _succeeds() -> str:
|
|
||||||
return "agent"
|
|
||||||
|
|
||||||
async def _slow() -> None:
|
|
||||||
await asyncio.sleep(0.05)
|
|
||||||
|
|
||||||
for coro in (_raises(), _succeeds(), _slow()):
|
|
||||||
task = asyncio.create_task(coro)
|
|
||||||
await _settle_speculative_agent_build(task)
|
|
||||||
assert task.done()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_settle_speculative_agent_build_handles_already_done_task():
|
|
||||||
"""Done tasks (success or failure) must still be settled without raising."""
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build
|
|
||||||
|
|
||||||
async def _ok() -> str:
|
|
||||||
return "ok"
|
|
||||||
|
|
||||||
async def _bad() -> None:
|
|
||||||
raise ValueError("nope")
|
|
||||||
|
|
||||||
ok_task = asyncio.create_task(_ok())
|
|
||||||
bad_task = asyncio.create_task(_bad())
|
|
||||||
# Drive both to completion before settling.
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
|
|
||||||
await _settle_speculative_agent_build(ok_task)
|
|
||||||
await _settle_speculative_agent_build(bad_task)
|
|
||||||
assert ok_task.result() == "ok"
|
|
||||||
# ``bad_task`` exception was consumed by the settle helper; calling
|
|
||||||
# ``.exception()`` after the fact must still return the original error
|
|
||||||
# (the helper observes it but doesn't clear it).
|
|
||||||
assert isinstance(bad_task.exception(), ValueError)
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_exception_classifies_thread_busy():
|
def test_stream_exception_classifies_thread_busy():
|
||||||
exc = BusyError(request_id="thread-123")
|
exc = BusyError(request_id="thread-123")
|
||||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue