mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-04 13:22:41 +02:00
chore: linting
This commit is contained in:
parent
b9a66cb417
commit
ca9bbee06d
41 changed files with 314 additions and 244 deletions
|
|
@ -88,7 +88,5 @@ def upgrade() -> None:
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
op.drop_index(
|
op.drop_index("ix_agent_action_log_thread_created", table_name="agent_action_log")
|
||||||
"ix_agent_action_log_thread_created", table_name="agent_action_log"
|
|
||||||
)
|
|
||||||
op.drop_table("agent_action_log")
|
op.drop_table("agent_action_log")
|
||||||
|
|
|
||||||
|
|
@ -51,9 +51,7 @@ def upgrade() -> None:
|
||||||
# implicit-unique-index variant SQLAlchemy may emit need draining.
|
# implicit-unique-index variant SQLAlchemy may emit need draining.
|
||||||
constraints = _existing_constraint_names(bind, "documents")
|
constraints = _existing_constraint_names(bind, "documents")
|
||||||
if "uq_documents_content_hash" in constraints:
|
if "uq_documents_content_hash" in constraints:
|
||||||
op.drop_constraint(
|
op.drop_constraint("uq_documents_content_hash", "documents", type_="unique")
|
||||||
"uq_documents_content_hash", "documents", type_="unique"
|
|
||||||
)
|
|
||||||
|
|
||||||
indexes = _existing_index_names(bind, "documents")
|
indexes = _existing_index_names(bind, "documents")
|
||||||
# Some Postgres versions surface the unique constraint via a unique
|
# Some Postgres versions surface the unique constraint via a unique
|
||||||
|
|
|
||||||
|
|
@ -416,10 +416,10 @@ async def create_surfsense_deep_agent(
|
||||||
# cheap to build. ``SubAgentMiddleware.__init__`` calls ``create_agent``
|
# cheap to build. ``SubAgentMiddleware.__init__`` calls ``create_agent``
|
||||||
# synchronously to compile the general-purpose subagent's full state graph
|
# synchronously to compile the general-purpose subagent's full state graph
|
||||||
# (every tool + every middleware → pydantic schemas + langgraph compile).
|
# (every tool + every middleware → pydantic schemas + langgraph compile).
|
||||||
# On gpt-5.x agents that's roughly 1.5–2s of pure CPU work. If we run it
|
# On gpt-5.x agents that's roughly 1.5-2s of pure CPU work. If we run it
|
||||||
# directly here it blocks the asyncio event loop for the whole streaming
|
# directly here it blocks the asyncio event loop for the whole streaming
|
||||||
# task (and any other coroutine sharing this loop), which is why
|
# task (and any other coroutine sharing this loop), which is why
|
||||||
# "agent creation" wall-clock time used to stretch to ~3–4s. Move the
|
# "agent creation" wall-clock time used to stretch to ~3-4s. Move the
|
||||||
# entire middleware build + main-graph compile into a single
|
# entire middleware build + main-graph compile into a single
|
||||||
# ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the
|
# ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the
|
||||||
# event loop stays responsive.
|
# event loop stays responsive.
|
||||||
|
|
@ -587,10 +587,7 @@ def _build_compiled_agent_blocking(
|
||||||
# by name. Off by default until the flag flips so existing deployments
|
# by name. Off by default until the flag flips so existing deployments
|
||||||
# don't see new agent types in the task tool description.
|
# don't see new agent types in the task tool description.
|
||||||
specialized_subagents: list[SubAgent] = []
|
specialized_subagents: list[SubAgent] = []
|
||||||
if (
|
if flags.enable_specialized_subagents and not flags.disable_new_agent_stack:
|
||||||
flags.enable_specialized_subagents
|
|
||||||
and not flags.disable_new_agent_stack
|
|
||||||
):
|
|
||||||
try:
|
try:
|
||||||
# Specialized subagents share the parent's filesystem +
|
# Specialized subagents share the parent's filesystem +
|
||||||
# todo view so their system prompts (which promise
|
# todo view so their system prompts (which promise
|
||||||
|
|
@ -696,7 +693,9 @@ def _build_compiled_agent_blocking(
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
tool_call_limit_mw = (
|
tool_call_limit_mw = (
|
||||||
ToolCallLimitMiddleware(thread_limit=300, run_limit=80, exit_behavior="continue")
|
ToolCallLimitMiddleware(
|
||||||
|
thread_limit=300, run_limit=80, exit_behavior="continue"
|
||||||
|
)
|
||||||
if flags.enable_tool_call_limit and not flags.disable_new_agent_stack
|
if flags.enable_tool_call_limit and not flags.disable_new_agent_stack
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
@ -879,7 +878,11 @@ def _build_compiled_agent_blocking(
|
||||||
max_tools=12,
|
max_tools=12,
|
||||||
always_include=[
|
always_include=[
|
||||||
name
|
name
|
||||||
for name in ("update_memory", "get_connected_accounts", "scrape_webpage")
|
for name in (
|
||||||
|
"update_memory",
|
||||||
|
"get_connected_accounts",
|
||||||
|
"scrape_webpage",
|
||||||
|
)
|
||||||
if name in {t.name for t in tools}
|
if name in {t.name for t in tools}
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,9 @@ class AgentFeatureFlags:
|
||||||
enable_model_call_limit: bool = False
|
enable_model_call_limit: bool = False
|
||||||
enable_tool_call_limit: bool = False
|
enable_tool_call_limit: bool = False
|
||||||
enable_tool_call_repair: bool = False
|
enable_tool_call_repair: bool = False
|
||||||
enable_doom_loop: bool = False # Default OFF until UI handles permission='doom_loop'
|
enable_doom_loop: bool = (
|
||||||
|
False # Default OFF until UI handles permission='doom_loop'
|
||||||
|
)
|
||||||
|
|
||||||
# Tier 2 — Safety
|
# Tier 2 — Safety
|
||||||
enable_permission: bool = False # Default OFF for first deploy
|
enable_permission: bool = False # Default OFF for first deploy
|
||||||
|
|
@ -79,7 +81,9 @@ class AgentFeatureFlags:
|
||||||
|
|
||||||
# Tier 5 — Snapshot / revert
|
# Tier 5 — Snapshot / revert
|
||||||
enable_action_log: bool = False
|
enable_action_log: bool = False
|
||||||
enable_revert_route: bool = False # Backend ships before UI; route returns 503 until this flips
|
enable_revert_route: bool = (
|
||||||
|
False # Backend ships before UI; route returns 503 until this flips
|
||||||
|
)
|
||||||
|
|
||||||
# Tier 6 — Plugins
|
# Tier 6 — Plugins
|
||||||
enable_plugin_loader: bool = False
|
enable_plugin_loader: bool = False
|
||||||
|
|
@ -109,14 +113,20 @@ class AgentFeatureFlags:
|
||||||
enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False),
|
enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False),
|
||||||
enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False),
|
enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False),
|
||||||
enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False),
|
enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False),
|
||||||
enable_model_call_limit=_env_bool("SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False),
|
enable_model_call_limit=_env_bool(
|
||||||
|
"SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False
|
||||||
|
),
|
||||||
enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False),
|
enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False),
|
||||||
enable_tool_call_repair=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False),
|
enable_tool_call_repair=_env_bool(
|
||||||
|
"SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False
|
||||||
|
),
|
||||||
enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False),
|
enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False),
|
||||||
# Tier 2
|
# Tier 2
|
||||||
enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False),
|
enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False),
|
||||||
enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False),
|
enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False),
|
||||||
enable_llm_tool_selector=_env_bool("SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False),
|
enable_llm_tool_selector=_env_bool(
|
||||||
|
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False
|
||||||
|
),
|
||||||
# Tier 4
|
# Tier 4
|
||||||
enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False),
|
enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False),
|
||||||
enable_specialized_subagents=_env_bool(
|
enable_specialized_subagents=_env_bool(
|
||||||
|
|
|
||||||
|
|
@ -101,9 +101,7 @@ class ActionLogMiddleware(AgentMiddleware):
|
||||||
async def awrap_tool_call(
|
async def awrap_tool_call(
|
||||||
self,
|
self,
|
||||||
request: ToolCallRequest,
|
request: ToolCallRequest,
|
||||||
handler: Callable[
|
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||||
[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]
|
|
||||||
],
|
|
||||||
) -> ToolMessage | Command[Any]:
|
) -> ToolMessage | Command[Any]:
|
||||||
if not self._enabled():
|
if not self._enabled():
|
||||||
return await handler(request)
|
return await handler(request)
|
||||||
|
|
|
||||||
|
|
@ -177,8 +177,8 @@ class SurfSenseCompactionMiddleware(SummarizationMiddleware):
|
||||||
messages_in=len(conversation_messages),
|
messages_in=len(conversation_messages),
|
||||||
extra={"compaction.cutoff_index": int(cutoff_index)},
|
extra={"compaction.cutoff_index": int(cutoff_index)},
|
||||||
):
|
):
|
||||||
messages_to_summarize, preserved_messages = (
|
messages_to_summarize, preserved_messages = super()._partition_messages(
|
||||||
super()._partition_messages(conversation_messages, cutoff_index)
|
conversation_messages, cutoff_index
|
||||||
)
|
)
|
||||||
|
|
||||||
protected: list[AnyMessage] = []
|
protected: list[AnyMessage] = []
|
||||||
|
|
|
||||||
|
|
@ -58,8 +58,7 @@ DEFAULT_SPILL_PREFIX = "/tool_outputs"
|
||||||
def _build_spill_placeholder(spill_path: str) -> str:
|
def _build_spill_placeholder(spill_path: str) -> str:
|
||||||
"""Build the user-facing placeholder text shown to the model."""
|
"""Build the user-facing placeholder text shown to the model."""
|
||||||
return (
|
return (
|
||||||
f"[cleared — full output at {spill_path}; "
|
f"[cleared — full output at {spill_path}; ask the explore subagent to read it]"
|
||||||
f"ask the explore subagent to read it]"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -131,7 +130,9 @@ class SpillToBackendEdit(ContextEdit):
|
||||||
return
|
return
|
||||||
|
|
||||||
candidates = [
|
candidates = [
|
||||||
(idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage)
|
(idx, msg)
|
||||||
|
for idx, msg in enumerate(messages)
|
||||||
|
if isinstance(msg, ToolMessage)
|
||||||
]
|
]
|
||||||
if self.keep >= len(candidates):
|
if self.keep >= len(candidates):
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -137,16 +137,21 @@ class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respon
|
||||||
|
|
||||||
triggered_call: dict[str, Any] | None = None
|
triggered_call: dict[str, Any] | None = None
|
||||||
for call in message.tool_calls:
|
for call in message.tool_calls:
|
||||||
name = call.get("name") if isinstance(call, dict) else getattr(call, "name", None)
|
name = (
|
||||||
args = call.get("args") if isinstance(call, dict) else getattr(call, "args", {})
|
call.get("name")
|
||||||
|
if isinstance(call, dict)
|
||||||
|
else getattr(call, "name", None)
|
||||||
|
)
|
||||||
|
args = (
|
||||||
|
call.get("args")
|
||||||
|
if isinstance(call, dict)
|
||||||
|
else getattr(call, "args", {})
|
||||||
|
)
|
||||||
if not isinstance(name, str):
|
if not isinstance(name, str):
|
||||||
continue
|
continue
|
||||||
sig = _signature(name, args)
|
sig = _signature(name, args)
|
||||||
window.append(sig)
|
window.append(sig)
|
||||||
if (
|
if len(window) >= self._threshold and len(set(window)) == 1:
|
||||||
len(window) >= self._threshold
|
|
||||||
and len(set(window)) == 1
|
|
||||||
):
|
|
||||||
triggered_call = {"name": name, "params": args or {}}
|
triggered_call = {"name": name, "params": args or {}}
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
@ -209,7 +214,9 @@ class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respon
|
||||||
# tool call proceeds. The frontend's exact reply names may differ —
|
# tool call proceeds. The frontend's exact reply names may differ —
|
||||||
# we tolerate any shape that contains a string with "reject"/"cancel".
|
# we tolerate any shape that contains a string with "reject"/"cancel".
|
||||||
if isinstance(decision, dict):
|
if isinstance(decision, dict):
|
||||||
kind = str(decision.get("decision_type") or decision.get("type") or "").lower()
|
kind = str(
|
||||||
|
decision.get("decision_type") or decision.get("type") or ""
|
||||||
|
).lower()
|
||||||
if "reject" in kind or "cancel" in kind:
|
if "reject" in kind or "cancel" in kind:
|
||||||
return {"jump_to": "end"}
|
return {"jump_to": "end"}
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
@ -552,7 +552,7 @@ def _render_priority_message(priority: list[dict[str, Any]]) -> SystemMessage:
|
||||||
for entry in priority:
|
for entry in priority:
|
||||||
score = entry.get("score")
|
score = entry.get("score")
|
||||||
mentioned = entry.get("mentioned")
|
mentioned = entry.get("mentioned")
|
||||||
score_str = f"{score:.3f}" if isinstance(score, (int, float)) else "n/a"
|
score_str = f"{score:.3f}" if isinstance(score, int | float) else "n/a"
|
||||||
mark = " [USER-MENTIONED]" if mentioned else ""
|
mark = " [USER-MENTIONED]" if mentioned else ""
|
||||||
lines.append(f"- {entry.get('path', '')} (score={score_str}){mark}")
|
lines.append(f"- {entry.get('path', '')} (score={score_str}){mark}")
|
||||||
body = "\n".join(lines)
|
body = "\n".join(lines)
|
||||||
|
|
@ -593,7 +593,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.mentioned_document_ids = mentioned_document_ids or []
|
self.mentioned_document_ids = mentioned_document_ids or []
|
||||||
# Tier 4.2: build the kb-planner private Runnable ONCE here so we
|
# Tier 4.2: build the kb-planner private Runnable ONCE here so we
|
||||||
# don't pay the create_agent compile cost (50–200ms) on every turn.
|
# don't pay the create_agent compile cost (50-200ms) on every turn.
|
||||||
# Disabled by default behind ``enable_kb_planner_runnable``; when off
|
# Disabled by default behind ``enable_kb_planner_runnable``; when off
|
||||||
# the planner falls back to the legacy ``self.llm.ainvoke`` path.
|
# the planner falls back to the legacy ``self.llm.ainvoke`` path.
|
||||||
self._planner: Runnable | None = None
|
self._planner: Runnable | None = None
|
||||||
|
|
@ -617,10 +617,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
if self.llm is None:
|
if self.llm is None:
|
||||||
return None
|
return None
|
||||||
flags = get_flags()
|
flags = get_flags()
|
||||||
if (
|
if not flags.enable_kb_planner_runnable or flags.disable_new_agent_stack:
|
||||||
not flags.enable_kb_planner_runnable
|
|
||||||
or flags.disable_new_agent_stack
|
|
||||||
):
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware
|
from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware
|
||||||
|
|
@ -920,7 +917,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
chunk_ids = doc.get("matched_chunk_ids") or []
|
chunk_ids = doc.get("matched_chunk_ids") or []
|
||||||
if chunk_ids:
|
if chunk_ids:
|
||||||
matched_chunk_ids[doc_id] = [
|
matched_chunk_ids[doc_id] = [
|
||||||
int(cid) for cid in chunk_ids if isinstance(cid, (int, str))
|
int(cid) for cid in chunk_ids if isinstance(cid, int | str)
|
||||||
]
|
]
|
||||||
return priority, matched_chunk_ids
|
return priority, matched_chunk_ids
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -35,9 +35,7 @@ from langchain_core.tools import tool
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
NOOP_TOOL_NAME = "_noop"
|
NOOP_TOOL_NAME = "_noop"
|
||||||
NOOP_TOOL_DESCRIPTION = (
|
NOOP_TOOL_DESCRIPTION = "Do not call this tool. It exists only for API compatibility."
|
||||||
"Do not call this tool. It exists only for API compatibility."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool(name_or_callable=NOOP_TOOL_NAME, description=NOOP_TOOL_DESCRIPTION)
|
@tool(name_or_callable=NOOP_TOOL_NAME, description=NOOP_TOOL_DESCRIPTION)
|
||||||
|
|
@ -78,7 +76,9 @@ def _last_ai_has_tool_calls(messages: list[Any]) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class NoopInjectionMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
class NoopInjectionMiddleware(
|
||||||
|
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
|
||||||
|
):
|
||||||
"""Inject the ``_noop`` tool only when the provider would otherwise 400.
|
"""Inject the ``_noop`` tool only when the provider would otherwise 400.
|
||||||
|
|
||||||
The check fires per model call, not at agent build time, because the
|
The check fires per model call, not at agent build time, because the
|
||||||
|
|
@ -116,7 +116,9 @@ class NoopInjectionMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, R
|
||||||
async def awrap_model_call( # type: ignore[override]
|
async def awrap_model_call( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
request: ModelRequest[ContextT],
|
request: ModelRequest[ContextT],
|
||||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
|
handler: Callable[
|
||||||
|
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||||
|
],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if self._should_inject(request):
|
if self._should_inject(request):
|
||||||
logger.debug("Injecting _noop tool for provider compatibility")
|
logger.debug("Injecting _noop tool for provider compatibility")
|
||||||
|
|
|
||||||
|
|
@ -56,9 +56,7 @@ class OtelSpanMiddleware(AgentMiddleware):
|
||||||
async def awrap_model_call(
|
async def awrap_model_call(
|
||||||
self,
|
self,
|
||||||
request: ModelRequest,
|
request: ModelRequest,
|
||||||
handler: Callable[
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]],
|
||||||
[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]
|
|
||||||
],
|
|
||||||
) -> ModelResponse | AIMessage | Any:
|
) -> ModelResponse | AIMessage | Any:
|
||||||
if not ot.is_enabled():
|
if not ot.is_enabled():
|
||||||
return await handler(request)
|
return await handler(request)
|
||||||
|
|
@ -81,9 +79,7 @@ class OtelSpanMiddleware(AgentMiddleware):
|
||||||
async def awrap_tool_call(
|
async def awrap_tool_call(
|
||||||
self,
|
self,
|
||||||
request: ToolCallRequest,
|
request: ToolCallRequest,
|
||||||
handler: Callable[
|
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||||
[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]
|
|
||||||
],
|
|
||||||
) -> ToolMessage | Command[Any]:
|
) -> ToolMessage | Command[Any]:
|
||||||
if not ot.is_enabled():
|
if not ot.is_enabled():
|
||||||
return await handler(request)
|
return await handler(request)
|
||||||
|
|
@ -187,7 +183,11 @@ def _annotate_model_response(span: Any, result: Any) -> None:
|
||||||
def _annotate_tool_result(span: Any, result: Any) -> None:
|
def _annotate_tool_result(span: Any, result: Any) -> None:
|
||||||
try:
|
try:
|
||||||
if isinstance(result, ToolMessage):
|
if isinstance(result, ToolMessage):
|
||||||
content = result.content if isinstance(result.content, str) else repr(result.content)
|
content = (
|
||||||
|
result.content
|
||||||
|
if isinstance(result.content, str)
|
||||||
|
else repr(result.content)
|
||||||
|
)
|
||||||
span.set_attribute("tool.output.size", len(content))
|
span.set_attribute("tool.output.size", len(content))
|
||||||
status = getattr(result, "status", None)
|
status = getattr(result, "status", None)
|
||||||
if isinstance(status, str):
|
if isinstance(status, str):
|
||||||
|
|
|
||||||
|
|
@ -145,7 +145,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
try:
|
try:
|
||||||
patterns = resolver(args or {})
|
patterns = resolver(args or {})
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Pattern resolver for %s raised; using bare name", tool_name)
|
logger.exception(
|
||||||
|
"Pattern resolver for %s raised; using bare name", tool_name
|
||||||
|
)
|
||||||
patterns = [tool_name]
|
patterns = [tool_name]
|
||||||
if not patterns:
|
if not patterns:
|
||||||
patterns = [tool_name]
|
patterns = [tool_name]
|
||||||
|
|
@ -198,11 +200,14 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
# Tier 3b: permission.asked + interrupt.raised spans (no-op when
|
# Tier 3b: permission.asked + interrupt.raised spans (no-op when
|
||||||
# OTel is disabled). Both fire here so dashboards can correlate
|
# OTel is disabled). Both fire here so dashboards can correlate
|
||||||
# "we asked X" with "interrupt was actually delivered".
|
# "we asked X" with "interrupt was actually delivered".
|
||||||
with ot.permission_asked_span(
|
with (
|
||||||
|
ot.permission_asked_span(
|
||||||
permission=tool_name,
|
permission=tool_name,
|
||||||
pattern=patterns[0] if patterns else None,
|
pattern=patterns[0] if patterns else None,
|
||||||
extra={"permission.patterns": list(patterns)},
|
extra={"permission.patterns": list(patterns)},
|
||||||
), ot.interrupt_span(interrupt_type="permission_ask"):
|
),
|
||||||
|
ot.interrupt_span(interrupt_type="permission_ask"),
|
||||||
|
):
|
||||||
decision = interrupt(payload)
|
decision = interrupt(payload)
|
||||||
if isinstance(decision, dict):
|
if isinstance(decision, dict):
|
||||||
return decision
|
return decision
|
||||||
|
|
@ -211,9 +216,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
return {"decision_type": decision}
|
return {"decision_type": decision}
|
||||||
return {"decision_type": "reject"}
|
return {"decision_type": "reject"}
|
||||||
|
|
||||||
def _persist_always(
|
def _persist_always(self, tool_name: str, patterns: list[str]) -> None:
|
||||||
self, tool_name: str, patterns: list[str]
|
|
||||||
) -> None:
|
|
||||||
"""Promote ``always`` reply into runtime allow rules.
|
"""Promote ``always`` reply into runtime allow rules.
|
||||||
|
|
||||||
Persistence to ``agent_permission_rules`` is done by the
|
Persistence to ``agent_permission_rules`` is done by the
|
||||||
|
|
@ -276,12 +279,16 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
any_change = False
|
any_change = False
|
||||||
|
|
||||||
for raw in last.tool_calls:
|
for raw in last.tool_calls:
|
||||||
call = dict(raw) if isinstance(raw, dict) else {
|
call = (
|
||||||
|
dict(raw)
|
||||||
|
if isinstance(raw, dict)
|
||||||
|
else {
|
||||||
"name": getattr(raw, "name", None),
|
"name": getattr(raw, "name", None),
|
||||||
"args": getattr(raw, "args", {}),
|
"args": getattr(raw, "args", {}),
|
||||||
"id": getattr(raw, "id", None),
|
"id": getattr(raw, "id", None),
|
||||||
"type": "tool_call",
|
"type": "tool_call",
|
||||||
}
|
}
|
||||||
|
)
|
||||||
name = call.get("name") or ""
|
name = call.get("name") or ""
|
||||||
args = call.get("args") or {}
|
args = call.get("args") or {}
|
||||||
action, patterns, rules = self._evaluate(name, args)
|
action, patterns, rules = self._evaluate(name, args)
|
||||||
|
|
@ -307,7 +314,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
feedback = decision.get("feedback")
|
feedback = decision.get("feedback")
|
||||||
if isinstance(feedback, str) and feedback.strip():
|
if isinstance(feedback, str) and feedback.strip():
|
||||||
raise CorrectedError(feedback, tool=name)
|
raise CorrectedError(feedback, tool=name)
|
||||||
raise RejectedError(tool=name, pattern=patterns[0] if patterns else None)
|
raise RejectedError(
|
||||||
|
tool=name, pattern=patterns[0] if patterns else None
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Unknown permission decision %r; treating as reject", kind
|
"Unknown permission decision %r; treating as reject", kind
|
||||||
|
|
|
||||||
|
|
@ -113,7 +113,9 @@ def _exponential_delay(
|
||||||
jitter: bool,
|
jitter: bool,
|
||||||
) -> float:
|
) -> float:
|
||||||
"""Compute an exponential-backoff delay with optional ±25% jitter."""
|
"""Compute an exponential-backoff delay with optional ±25% jitter."""
|
||||||
delay = initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay
|
delay = (
|
||||||
|
initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay
|
||||||
|
)
|
||||||
delay = min(delay, max_delay)
|
delay = min(delay, max_delay)
|
||||||
if jitter and delay > 0:
|
if jitter and delay > 0:
|
||||||
delay *= 1 + random.uniform(-0.25, 0.25)
|
delay *= 1 + random.uniform(-0.25, 0.25)
|
||||||
|
|
@ -201,7 +203,9 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("dispatch_custom_event failed; suppressed", exc_info=True)
|
logger.debug(
|
||||||
|
"dispatch_custom_event failed; suppressed", exc_info=True
|
||||||
|
)
|
||||||
if delay > 0:
|
if delay > 0:
|
||||||
time.sleep(delay)
|
time.sleep(delay)
|
||||||
# Unreachable
|
# Unreachable
|
||||||
|
|
@ -210,7 +214,9 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp
|
||||||
async def awrap_model_call( # type: ignore[override]
|
async def awrap_model_call( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
request: ModelRequest[ContextT],
|
request: ModelRequest[ContextT],
|
||||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
|
handler: Callable[
|
||||||
|
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||||
|
],
|
||||||
) -> ModelResponse[ResponseT] | AIMessage:
|
) -> ModelResponse[ResponseT] | AIMessage:
|
||||||
for attempt in range(self.max_retries + 1):
|
for attempt in range(self.max_retries + 1):
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ gives a clean failure mode if anything tries.
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
|
|
@ -114,8 +115,10 @@ class BuiltinSkillsBackend(BackendProtocol):
|
||||||
infos: list[FileInfo] = []
|
infos: list[FileInfo] = []
|
||||||
# Build virtual paths anchored at "/" because CompositeBackend already
|
# Build virtual paths anchored at "/" because CompositeBackend already
|
||||||
# stripped the route prefix before calling us.
|
# stripped the route prefix before calling us.
|
||||||
target_virtual = "/" if target == self.root else (
|
target_virtual = (
|
||||||
"/" + str(target.relative_to(self.root)).replace("\\", "/")
|
"/"
|
||||||
|
if target == self.root
|
||||||
|
else ("/" + str(target.relative_to(self.root)).replace("\\", "/"))
|
||||||
)
|
)
|
||||||
for child in sorted(target.iterdir()):
|
for child in sorted(target.iterdir()):
|
||||||
child_virtual = (
|
child_virtual = (
|
||||||
|
|
@ -128,10 +131,8 @@ class BuiltinSkillsBackend(BackendProtocol):
|
||||||
"is_dir": child.is_dir(),
|
"is_dir": child.is_dir(),
|
||||||
}
|
}
|
||||||
if child.is_file():
|
if child.is_file():
|
||||||
try:
|
with contextlib.suppress(OSError): # pragma: no cover - defensive
|
||||||
info["size"] = child.stat().st_size
|
info["size"] = child.stat().st_size
|
||||||
except OSError: # pragma: no cover - defensive
|
|
||||||
pass
|
|
||||||
infos.append(info)
|
infos.append(info)
|
||||||
return infos
|
return infos
|
||||||
|
|
||||||
|
|
@ -163,7 +164,9 @@ class BuiltinSkillsBackend(BackendProtocol):
|
||||||
else:
|
else:
|
||||||
content = target.read_bytes()
|
content = target.read_bytes()
|
||||||
except PermissionError:
|
except PermissionError:
|
||||||
responses.append(FileDownloadResponse(path=p, error="permission_denied"))
|
responses.append(
|
||||||
|
FileDownloadResponse(path=p, error="permission_denied")
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
except OSError as exc: # pragma: no cover - defensive
|
except OSError as exc: # pragma: no cover - defensive
|
||||||
logger.warning("Builtin skill read failed %s: %s", target, exc)
|
logger.warning("Builtin skill read failed %s: %s", target, exc)
|
||||||
|
|
@ -286,6 +289,7 @@ def build_skills_backend_factory(
|
||||||
builtin = BuiltinSkillsBackend(builtin_root)
|
builtin = BuiltinSkillsBackend(builtin_root)
|
||||||
|
|
||||||
if search_space_id is None:
|
if search_space_id is None:
|
||||||
|
|
||||||
def _factory_builtin_only(runtime: ToolRuntime) -> BackendProtocol:
|
def _factory_builtin_only(runtime: ToolRuntime) -> BackendProtocol:
|
||||||
# Default StateBackend is intentionally inert: any path outside the
|
# Default StateBackend is intentionally inert: any path outside the
|
||||||
# ``/skills/builtin/`` route resolves to an empty per-runtime state
|
# ``/skills/builtin/`` route resolves to an empty per-runtime state
|
||||||
|
|
@ -294,6 +298,7 @@ def build_skills_backend_factory(
|
||||||
default=StateBackend(runtime),
|
default=StateBackend(runtime),
|
||||||
routes={SKILLS_BUILTIN_PREFIX: builtin},
|
routes={SKILLS_BUILTIN_PREFIX: builtin},
|
||||||
)
|
)
|
||||||
|
|
||||||
return _factory_builtin_only
|
return _factory_builtin_only
|
||||||
|
|
||||||
def _factory_with_space(runtime: ToolRuntime) -> BackendProtocol:
|
def _factory_with_space(runtime: ToolRuntime) -> BackendProtocol:
|
||||||
|
|
|
||||||
|
|
@ -51,13 +51,15 @@ def _coerce_existing_tool_call(call: Any) -> dict[str, Any]:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ToolCallNameRepairMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
class ToolCallNameRepairMiddleware(
|
||||||
|
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
|
||||||
|
):
|
||||||
"""Two-stage tool-name repair on the most recent ``AIMessage``.
|
"""Two-stage tool-name repair on the most recent ``AIMessage``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
registered_tool_names: Set of canonically-registered tool names.
|
registered_tool_names: Set of canonically-registered tool names.
|
||||||
``invalid`` should be in this set so the fallback dispatches.
|
``invalid`` should be in this set so the fallback dispatches.
|
||||||
fuzzy_match_threshold: Optional ``difflib`` ratio (0–1) for the
|
fuzzy_match_threshold: Optional ``difflib`` ratio (0-1) for the
|
||||||
fuzzy-match step that runs *between* lowercase and invalid.
|
fuzzy-match step that runs *between* lowercase and invalid.
|
||||||
Set to ``None`` to disable fuzzy matching (opencode parity).
|
Set to ``None`` to disable fuzzy matching (opencode parity).
|
||||||
"""
|
"""
|
||||||
|
|
@ -77,9 +79,9 @@ class ToolCallNameRepairMiddleware(AgentMiddleware[AgentState[ResponseT], Contex
|
||||||
def _registered_for_runtime(self, runtime: Runtime[ContextT]) -> set[str]:
|
def _registered_for_runtime(self, runtime: Runtime[ContextT]) -> set[str]:
|
||||||
"""Allow runtime overrides to expand the set (e.g. dynamic MCP tools)."""
|
"""Allow runtime overrides to expand the set (e.g. dynamic MCP tools)."""
|
||||||
ctx_tools = getattr(runtime.context, "registered_tool_names", None)
|
ctx_tools = getattr(runtime.context, "registered_tool_names", None)
|
||||||
if isinstance(ctx_tools, (set, frozenset)):
|
if isinstance(ctx_tools, set | frozenset):
|
||||||
return self._registered | set(ctx_tools)
|
return self._registered | set(ctx_tools)
|
||||||
if isinstance(ctx_tools, (list, tuple)):
|
if isinstance(ctx_tools, list | tuple):
|
||||||
return self._registered | set(ctx_tools)
|
return self._registered | set(ctx_tools)
|
||||||
return self._registered
|
return self._registered
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -52,16 +52,17 @@ class _YearSubstituterMiddleware(AgentMiddleware):
|
||||||
async def awrap_tool_call(
|
async def awrap_tool_call(
|
||||||
self,
|
self,
|
||||||
request: ToolCallRequest,
|
request: ToolCallRequest,
|
||||||
handler: Callable[
|
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||||
[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]
|
|
||||||
],
|
|
||||||
) -> ToolMessage | Command[Any]:
|
) -> ToolMessage | Command[Any]:
|
||||||
result = await handler(request)
|
result = await handler(request)
|
||||||
try:
|
try:
|
||||||
from langchain_core.messages import ToolMessage
|
from langchain_core.messages import ToolMessage
|
||||||
|
|
||||||
if isinstance(result, ToolMessage) and isinstance(result.content, str):
|
if (
|
||||||
if "{{year}}" in result.content:
|
isinstance(result, ToolMessage)
|
||||||
|
and isinstance(result.content, str)
|
||||||
|
and "{{year}}" in result.content
|
||||||
|
):
|
||||||
new_text = result.content.replace("{{year}}", self._year)
|
new_text = result.content.replace("{{year}}", self._year)
|
||||||
result = ToolMessage(
|
result = ToolMessage(
|
||||||
content=new_text,
|
content=new_text,
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,9 @@ ProviderVariant = str
|
||||||
# More specific patterns must come first (e.g. ``codex`` before
|
# More specific patterns must come first (e.g. ``codex`` before
|
||||||
# ``openai_reasoning`` because codex model ids contain ``gpt``).
|
# ``openai_reasoning`` because codex model ids contain ``gpt``).
|
||||||
|
|
||||||
_OPENAI_CODEX_RE = re.compile(r"\b(gpt-codex|codex-mini|gpt-[\d.]+-codex)\b", re.IGNORECASE)
|
_OPENAI_CODEX_RE = re.compile(
|
||||||
|
r"\b(gpt-codex|codex-mini|gpt-[\d.]+-codex)\b", re.IGNORECASE
|
||||||
|
)
|
||||||
_OPENAI_REASONING_RE = re.compile(r"\b(gpt-5|o\d|o-)", re.IGNORECASE)
|
_OPENAI_REASONING_RE = re.compile(r"\b(gpt-5|o\d|o-)", re.IGNORECASE)
|
||||||
_OPENAI_CLASSIC_RE = re.compile(r"\bgpt-4", re.IGNORECASE)
|
_OPENAI_CLASSIC_RE = re.compile(r"\bgpt-4", re.IGNORECASE)
|
||||||
_ANTHROPIC_RE = re.compile(r"\bclaude\b", re.IGNORECASE)
|
_ANTHROPIC_RE = re.compile(r"\bclaude\b", re.IGNORECASE)
|
||||||
|
|
@ -257,9 +259,7 @@ def _build_tools_section(
|
||||||
)
|
)
|
||||||
if known_disabled:
|
if known_disabled:
|
||||||
disabled_list = ", ".join(
|
disabled_list = ", ".join(
|
||||||
_format_tool_label(n)
|
_format_tool_label(n) for n in ALL_TOOL_NAMES_ORDERED if n in known_disabled
|
||||||
for n in ALL_TOOL_NAMES_ORDERED
|
|
||||||
if n in known_disabled
|
|
||||||
)
|
)
|
||||||
parts.append(
|
parts.append(
|
||||||
"\n"
|
"\n"
|
||||||
|
|
|
||||||
|
|
@ -279,9 +279,7 @@ def build_explore_subagent(
|
||||||
|
|
||||||
selected_tools = _filter_tools(tools, EXPLORE_READ_TOOLS)
|
selected_tools = _filter_tools(tools, EXPLORE_READ_TOOLS)
|
||||||
deny_rules = _read_only_deny_rules()
|
deny_rules = _read_only_deny_rules()
|
||||||
permission_mw = _build_permission_middleware(
|
permission_mw = _build_permission_middleware(deny_rules, origin="subagent_explore")
|
||||||
deny_rules, origin="subagent_explore"
|
|
||||||
)
|
|
||||||
|
|
||||||
spec: dict = {
|
spec: dict = {
|
||||||
"name": "explore",
|
"name": "explore",
|
||||||
|
|
|
||||||
|
|
@ -111,6 +111,8 @@ from .update_memory import create_update_memory_tool, create_update_team_memory_
|
||||||
from .video_presentation import create_generate_video_presentation_tool
|
from .video_presentation import create_generate_video_presentation_tool
|
||||||
from .web_search import create_web_search_tool
|
from .web_search import create_web_search_tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Tool Definition
|
# Tool Definition
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ Goals
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
|
|
@ -154,18 +155,14 @@ def span(
|
||||||
|
|
||||||
with tracer.start_as_current_span(name) as sp:
|
with tracer.start_as_current_span(name) as sp:
|
||||||
if attributes:
|
if attributes:
|
||||||
try:
|
with contextlib.suppress(Exception): # pragma: no cover — defensive
|
||||||
sp.set_attributes(attributes)
|
sp.set_attributes(attributes)
|
||||||
except Exception: # pragma: no cover — defensive
|
|
||||||
pass
|
|
||||||
try:
|
try:
|
||||||
yield sp
|
yield sp
|
||||||
except BaseException as exc:
|
except BaseException as exc:
|
||||||
try:
|
with contextlib.suppress(Exception): # pragma: no cover — defensive
|
||||||
sp.record_exception(exc)
|
sp.record_exception(exc)
|
||||||
sp.set_status(_OtStatus(_OtStatusCode.ERROR, str(exc)))
|
sp.set_status(_OtStatus(_OtStatusCode.ERROR, str(exc)))
|
||||||
except Exception: # pragma: no cover — defensive
|
|
||||||
pass
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,7 @@ class AgentFeatureFlagsRead(BaseModel):
|
||||||
enable_otel: bool
|
enable_otel: bool
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_flags(cls, flags: AgentFeatureFlags) -> "AgentFeatureFlagsRead":
|
def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead:
|
||||||
# asdict() avoids missing-field bugs when AgentFeatureFlags grows.
|
# asdict() avoids missing-field bugs when AgentFeatureFlags grows.
|
||||||
return cls(**asdict(flags))
|
return cls(**asdict(flags))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -210,7 +210,7 @@ async def create_rule(
|
||||||
session.add(row)
|
session.add(row)
|
||||||
try:
|
try:
|
||||||
await session.commit()
|
await session.commit()
|
||||||
except IntegrityError:
|
except IntegrityError as err:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=409,
|
status_code=409,
|
||||||
|
|
@ -218,7 +218,7 @@ async def create_rule(
|
||||||
"An identical rule already exists for this scope. Update the "
|
"An identical rule already exists for this scope. Update the "
|
||||||
"existing rule instead."
|
"existing rule instead."
|
||||||
),
|
),
|
||||||
)
|
) from err
|
||||||
await session.refresh(row)
|
await session.refresh(row)
|
||||||
return _to_read(row)
|
return _to_read(row)
|
||||||
|
|
||||||
|
|
@ -248,12 +248,12 @@ async def update_rule(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await session.commit()
|
await session.commit()
|
||||||
except IntegrityError:
|
except IntegrityError as err:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=409,
|
status_code=409,
|
||||||
detail="Update would create a duplicate rule for this scope.",
|
detail="Update would create a duplicate rule for this scope.",
|
||||||
)
|
) from err
|
||||||
await session.refresh(row)
|
await session.refresh(row)
|
||||||
return _to_read(row)
|
return _to_read(row)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -97,10 +97,12 @@ async def revert_agent_action(
|
||||||
action=action,
|
action=action,
|
||||||
requester_user_id=str(user.id) if user is not None else None,
|
requester_user_id=str(user.id) if user is not None else None,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception as err:
|
||||||
logger.exception("Revert dispatch raised for action_id=%s", action_id)
|
logger.exception("Revert dispatch raised for action_id=%s", action_id)
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(status_code=500, detail="Internal error during revert.")
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Internal error during revert."
|
||||||
|
) from err
|
||||||
|
|
||||||
if outcome.status == "ok":
|
if outcome.status == "ok":
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
|
||||||
|
|
@ -1242,7 +1242,9 @@ async def handle_new_chat(
|
||||||
await session.close()
|
await session.close()
|
||||||
|
|
||||||
image_urls = (
|
image_urls = (
|
||||||
[p.as_data_url() for p in request.user_images] if request.user_images else None
|
[p.as_data_url() for p in request.user_images]
|
||||||
|
if request.user_images
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
|
|
|
||||||
|
|
@ -79,9 +79,7 @@ async def load_action(
|
||||||
return result.scalars().first()
|
return result.scalars().first()
|
||||||
|
|
||||||
|
|
||||||
async def load_thread(
|
async def load_thread(session: AsyncSession, *, thread_id: int) -> NewChatThread | None:
|
||||||
session: AsyncSession, *, thread_id: int
|
|
||||||
) -> NewChatThread | None:
|
|
||||||
stmt = select(NewChatThread).where(NewChatThread.id == thread_id)
|
stmt = select(NewChatThread).where(NewChatThread.id == thread_id)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return result.scalars().first()
|
return result.scalars().first()
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,9 @@ import binascii
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def build_human_message_content(final_query: str, image_data_urls: list[str]) -> str | list[dict[str, Any]]:
|
def build_human_message_content(
|
||||||
|
final_query: str, image_data_urls: list[str]
|
||||||
|
) -> str | list[dict[str, Any]]:
|
||||||
if not image_data_urls:
|
if not image_data_urls:
|
||||||
return final_query
|
return final_query
|
||||||
parts: list[dict[str, Any]] = [{"type": "text", "text": final_query}]
|
parts: list[dict[str, Any]] = [{"type": "text", "text": final_query}]
|
||||||
|
|
|
||||||
|
|
@ -90,9 +90,7 @@ class TestCompose:
|
||||||
assert "<citation_instructions>" in prompt
|
assert "<citation_instructions>" in prompt
|
||||||
assert "[citation:chunk_id]" in prompt
|
assert "[citation:chunk_id]" in prompt
|
||||||
|
|
||||||
def test_team_visibility_uses_team_variants(
|
def test_team_visibility_uses_team_variants(self, fixed_today: datetime) -> None:
|
||||||
self, fixed_today: datetime
|
|
||||||
) -> None:
|
|
||||||
prompt = compose_system_prompt(
|
prompt = compose_system_prompt(
|
||||||
today=fixed_today,
|
today=fixed_today,
|
||||||
thread_visibility=ChatVisibility.SEARCH_SPACE,
|
thread_visibility=ChatVisibility.SEARCH_SPACE,
|
||||||
|
|
@ -145,9 +143,7 @@ class TestCompose:
|
||||||
assert "Generate Image" in prompt
|
assert "Generate Image" in prompt
|
||||||
assert "Generate Podcast" in prompt
|
assert "Generate Podcast" in prompt
|
||||||
|
|
||||||
def test_mcp_routing_block_emits_when_provided(
|
def test_mcp_routing_block_emits_when_provided(self, fixed_today: datetime) -> None:
|
||||||
self, fixed_today: datetime
|
|
||||||
) -> None:
|
|
||||||
prompt = compose_system_prompt(
|
prompt = compose_system_prompt(
|
||||||
today=fixed_today,
|
today=fixed_today,
|
||||||
mcp_connector_tools={"My GitLab": ["gitlab_search", "gitlab_create_mr"]},
|
mcp_connector_tools={"My GitLab": ["gitlab_search", "gitlab_create_mr"]},
|
||||||
|
|
@ -162,9 +158,7 @@ class TestCompose:
|
||||||
prompt = compose_system_prompt(today=fixed_today, mcp_connector_tools={})
|
prompt = compose_system_prompt(today=fixed_today, mcp_connector_tools={})
|
||||||
assert "<mcp_tool_routing>" not in prompt
|
assert "<mcp_tool_routing>" not in prompt
|
||||||
|
|
||||||
def test_provider_block_renders_when_anthropic(
|
def test_provider_block_renders_when_anthropic(self, fixed_today: datetime) -> None:
|
||||||
self, fixed_today: datetime
|
|
||||||
) -> None:
|
|
||||||
prompt = compose_system_prompt(
|
prompt = compose_system_prompt(
|
||||||
today=fixed_today, model_name="anthropic:claude-3-5-sonnet"
|
today=fixed_today, model_name="anthropic:claude-3-5-sonnet"
|
||||||
)
|
)
|
||||||
|
|
@ -267,7 +261,10 @@ class TestStableOrderingForCacheStability:
|
||||||
)
|
)
|
||||||
b = compose_system_prompt(
|
b = compose_system_prompt(
|
||||||
today=fixed_today,
|
today=fixed_today,
|
||||||
enabled_tool_names={"scrape_webpage", "web_search"}, # set order shouldn't matter
|
enabled_tool_names={
|
||||||
|
"scrape_webpage",
|
||||||
|
"web_search",
|
||||||
|
}, # set order shouldn't matter
|
||||||
mcp_connector_tools={"X": ["x_a", "x_b"]},
|
mcp_connector_tools={"X": ["x_a", "x_b"]},
|
||||||
)
|
)
|
||||||
assert a == b
|
assert a == b
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,11 @@ class TestActionLogMiddlewareDisabled:
|
||||||
async def test_no_op_when_flag_off(self, patch_get_flags) -> None:
|
async def test_no_op_when_flag_off(self, patch_get_flags) -> None:
|
||||||
mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
|
mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
|
||||||
request = _FakeRequest(
|
request = _FakeRequest(
|
||||||
tool_call={"name": "make_widget", "args": {"color": "red", "size": 1}, "id": "tc1"}
|
tool_call={
|
||||||
|
"name": "make_widget",
|
||||||
|
"args": {"color": "red", "size": 1},
|
||||||
|
"id": "tc1",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
|
handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
|
||||||
with patch_get_flags(_disabled_flags()):
|
with patch_get_flags(_disabled_flags()):
|
||||||
|
|
@ -117,13 +121,12 @@ class TestActionLogMiddlewarePersistence:
|
||||||
"id": "tc-abc",
|
"id": "tc-abc",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
result_msg = ToolMessage(
|
result_msg = ToolMessage(content="ok", tool_call_id="tc-abc", id="msg-1")
|
||||||
content="ok", tool_call_id="tc-abc", id="msg-1"
|
|
||||||
)
|
|
||||||
handler = AsyncMock(return_value=result_msg)
|
handler = AsyncMock(return_value=result_msg)
|
||||||
|
|
||||||
with patch_get_flags(_enabled_flags()), patch(
|
with (
|
||||||
"app.db.shielded_async_session", side_effect=lambda: factory()
|
patch_get_flags(_enabled_flags()),
|
||||||
|
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
|
||||||
):
|
):
|
||||||
result = await mw.awrap_tool_call(request, handler)
|
result = await mw.awrap_tool_call(request, handler)
|
||||||
|
|
||||||
|
|
@ -151,9 +154,11 @@ class TestActionLogMiddlewarePersistence:
|
||||||
)
|
)
|
||||||
handler = AsyncMock(side_effect=ValueError("boom"))
|
handler = AsyncMock(side_effect=ValueError("boom"))
|
||||||
|
|
||||||
with patch_get_flags(_enabled_flags()), patch(
|
with (
|
||||||
"app.db.shielded_async_session", side_effect=lambda: factory()
|
patch_get_flags(_enabled_flags()),
|
||||||
), pytest.raises(ValueError, match="boom"):
|
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
|
||||||
|
pytest.raises(ValueError, match="boom"),
|
||||||
|
):
|
||||||
await mw.awrap_tool_call(request, handler)
|
await mw.awrap_tool_call(request, handler)
|
||||||
|
|
||||||
assert len(captured["rows"]) == 1
|
assert len(captured["rows"]) == 1
|
||||||
|
|
@ -177,8 +182,9 @@ class TestActionLogMiddlewarePersistence:
|
||||||
def _exploding_session():
|
def _exploding_session():
|
||||||
raise RuntimeError("DB is down")
|
raise RuntimeError("DB is down")
|
||||||
|
|
||||||
with patch_get_flags(_enabled_flags()), patch(
|
with (
|
||||||
"app.db.shielded_async_session", side_effect=_exploding_session
|
patch_get_flags(_enabled_flags()),
|
||||||
|
patch("app.db.shielded_async_session", side_effect=_exploding_session),
|
||||||
):
|
):
|
||||||
result = await mw.awrap_tool_call(request, handler)
|
result = await mw.awrap_tool_call(request, handler)
|
||||||
assert result is result_msg
|
assert result is result_msg
|
||||||
|
|
@ -218,8 +224,9 @@ class TestReverseDescriptor:
|
||||||
)
|
)
|
||||||
handler = AsyncMock(return_value=result_msg)
|
handler = AsyncMock(return_value=result_msg)
|
||||||
|
|
||||||
with patch_get_flags(_enabled_flags()), patch(
|
with (
|
||||||
"app.db.shielded_async_session", side_effect=lambda: factory()
|
patch_get_flags(_enabled_flags()),
|
||||||
|
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
|
||||||
):
|
):
|
||||||
await mw.awrap_tool_call(request, handler)
|
await mw.awrap_tool_call(request, handler)
|
||||||
|
|
||||||
|
|
@ -257,8 +264,9 @@ class TestReverseDescriptor:
|
||||||
result_msg = ToolMessage(content="ok", tool_call_id="tc1")
|
result_msg = ToolMessage(content="ok", tool_call_id="tc1")
|
||||||
handler = AsyncMock(return_value=result_msg)
|
handler = AsyncMock(return_value=result_msg)
|
||||||
|
|
||||||
with patch_get_flags(_enabled_flags()), patch(
|
with (
|
||||||
"app.db.shielded_async_session", side_effect=lambda: factory()
|
patch_get_flags(_enabled_flags()),
|
||||||
|
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
|
||||||
):
|
):
|
||||||
await mw.awrap_tool_call(request, handler)
|
await mw.awrap_tool_call(request, handler)
|
||||||
|
|
||||||
|
|
@ -275,11 +283,10 @@ class TestReverseDescriptor:
|
||||||
request = _FakeRequest(
|
request = _FakeRequest(
|
||||||
tool_call={"name": "unknown_tool", "args": {}, "id": "tc1"}
|
tool_call={"name": "unknown_tool", "args": {}, "id": "tc1"}
|
||||||
)
|
)
|
||||||
handler = AsyncMock(
|
handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
|
||||||
return_value=ToolMessage(content="ok", tool_call_id="tc1")
|
with (
|
||||||
)
|
patch_get_flags(_enabled_flags()),
|
||||||
with patch_get_flags(_enabled_flags()), patch(
|
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
|
||||||
"app.db.shielded_async_session", side_effect=lambda: factory()
|
|
||||||
):
|
):
|
||||||
await mw.awrap_tool_call(request, handler)
|
await mw.awrap_tool_call(request, handler)
|
||||||
row = captured["rows"][0]
|
row = captured["rows"][0]
|
||||||
|
|
@ -298,11 +305,10 @@ class TestArgsTruncation:
|
||||||
request = _FakeRequest(
|
request = _FakeRequest(
|
||||||
tool_call={"name": "make_widget", "args": {"blob": huge}, "id": "tc1"},
|
tool_call={"name": "make_widget", "args": {"blob": huge}, "id": "tc1"},
|
||||||
)
|
)
|
||||||
handler = AsyncMock(
|
handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
|
||||||
return_value=ToolMessage(content="ok", tool_call_id="tc1")
|
with (
|
||||||
)
|
patch_get_flags(_enabled_flags()),
|
||||||
with patch_get_flags(_enabled_flags()), patch(
|
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
|
||||||
"app.db.shielded_async_session", side_effect=lambda: factory()
|
|
||||||
):
|
):
|
||||||
await mw.awrap_tool_call(request, handler)
|
await mw.awrap_tool_call(request, handler)
|
||||||
row = captured["rows"][0]
|
row = captured["rows"][0]
|
||||||
|
|
|
||||||
|
|
@ -26,10 +26,16 @@ class TestIsProtectedSystemMessage:
|
||||||
assert _is_protected_system_message(msg) is True
|
assert _is_protected_system_message(msg) is True
|
||||||
|
|
||||||
def test_unprotected_system_message(self) -> None:
|
def test_unprotected_system_message(self) -> None:
|
||||||
assert _is_protected_system_message(SystemMessage(content="random instructions")) is False
|
assert (
|
||||||
|
_is_protected_system_message(SystemMessage(content="random instructions"))
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
|
||||||
def test_human_message_never_protected(self) -> None:
|
def test_human_message_never_protected(self) -> None:
|
||||||
assert _is_protected_system_message(HumanMessage(content="<workspace_tree>...")) is False
|
assert (
|
||||||
|
_is_protected_system_message(HumanMessage(content="<workspace_tree>..."))
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
|
||||||
def test_tolerates_leading_whitespace(self) -> None:
|
def test_tolerates_leading_whitespace(self) -> None:
|
||||||
msg = SystemMessage(content=" \n<priority_documents>\n...")
|
msg = SystemMessage(content=" \n<priority_documents>\n...")
|
||||||
|
|
@ -97,11 +103,17 @@ class TestPartitionMessages:
|
||||||
assert protected not in to_summary
|
assert protected not in to_summary
|
||||||
assert protected in preserved
|
assert protected in preserved
|
||||||
# The non-protected old messages remain in to_summary
|
# The non-protected old messages remain in to_summary
|
||||||
assert any(isinstance(m, HumanMessage) and m.content == "old human" for m in to_summary)
|
assert any(
|
||||||
|
isinstance(m, HumanMessage) and m.content == "old human" for m in to_summary
|
||||||
|
)
|
||||||
|
|
||||||
def test_unprotected_messages_unaffected(self) -> None:
|
def test_unprotected_messages_unaffected(self) -> None:
|
||||||
partitioner = self._build_partitioner()
|
partitioner = self._build_partitioner()
|
||||||
msgs = [HumanMessage(content="a"), HumanMessage(content="b"), HumanMessage(content="c")]
|
msgs = [
|
||||||
|
HumanMessage(content="a"),
|
||||||
|
HumanMessage(content="b"),
|
||||||
|
HumanMessage(content="c"),
|
||||||
|
]
|
||||||
to_summary, preserved = partitioner._partition_messages(msgs, 2)
|
to_summary, preserved = partitioner._partition_messages(msgs, 2)
|
||||||
assert [m.content for m in to_summary] == ["a", "b"]
|
assert [m.content for m in to_summary] == ["a", "b"]
|
||||||
assert [m.content for m in preserved] == ["c"]
|
assert [m.content for m in preserved] == ["c"]
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,8 @@ class TestSpillEdit:
|
||||||
|
|
||||||
# Earlier ToolMessages should now contain the placeholder text
|
# Earlier ToolMessages should now contain the placeholder text
|
||||||
cleared = [
|
cleared = [
|
||||||
m for m in tool_messages
|
m
|
||||||
|
for m in tool_messages
|
||||||
if isinstance(m.content, str) and m.content.startswith("[cleared")
|
if isinstance(m.content, str) and m.content.startswith("[cleared")
|
||||||
]
|
]
|
||||||
assert len(cleared) >= 1
|
assert len(cleared) >= 1
|
||||||
|
|
|
||||||
|
|
@ -46,9 +46,21 @@ def test_callable_dedup_key_takes_priority() -> None:
|
||||||
state = {
|
state = {
|
||||||
"messages": [
|
"messages": [
|
||||||
_msg(
|
_msg(
|
||||||
{"name": "create_doc", "args": {"parent_id": "x", "title": "y"}, "id": "1"},
|
{
|
||||||
{"name": "create_doc", "args": {"parent_id": "x", "title": "y"}, "id": "2"},
|
"name": "create_doc",
|
||||||
{"name": "create_doc", "args": {"parent_id": "x", "title": "z"}, "id": "3"},
|
"args": {"parent_id": "x", "title": "y"},
|
||||||
|
"id": "1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "create_doc",
|
||||||
|
"args": {"parent_id": "x", "title": "y"},
|
||||||
|
"id": "2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "create_doc",
|
||||||
|
"args": {"parent_id": "x", "title": "z"},
|
||||||
|
"id": "3",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -84,9 +84,7 @@ class TestConnectorDenyOverridesDefaultAllow:
|
||||||
Rule(permission="linear_create_issue", pattern="*", action="deny")
|
Rule(permission="linear_create_issue", pattern="*", action="deny")
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
rules = evaluate_many(
|
rules = evaluate_many("linear_create_issue", ["linear_create_issue"], *rulesets)
|
||||||
"linear_create_issue", ["linear_create_issue"], *rulesets
|
|
||||||
)
|
|
||||||
assert aggregate_action(rules) == "deny"
|
assert aggregate_action(rules) == "deny"
|
||||||
|
|
||||||
def test_default_allow_still_applies_to_other_tools(self) -> None:
|
def test_default_allow_still_applies_to_other_tools(self) -> None:
|
||||||
|
|
@ -124,5 +122,7 @@ class TestUserRuleOverridesDefault:
|
||||||
rules=[Rule(permission="send_*", pattern="*", action="deny")],
|
rules=[Rule(permission="send_*", pattern="*", action="deny")],
|
||||||
origin="user",
|
origin="user",
|
||||||
)
|
)
|
||||||
rules = evaluate_many("send_gmail_email", ["send_gmail_email"], defaults, user_ruleset)
|
rules = evaluate_many(
|
||||||
|
"send_gmail_email", ["send_gmail_email"], defaults, user_ruleset
|
||||||
|
)
|
||||||
assert aggregate_action(rules) == "deny"
|
assert aggregate_action(rules) == "deny"
|
||||||
|
|
|
||||||
|
|
@ -64,22 +64,17 @@ def test_threshold_triggers_after_n_identical_calls() -> None:
|
||||||
runtime,
|
runtime,
|
||||||
)
|
)
|
||||||
name = type(excinfo.value).__name__.lower()
|
name = type(excinfo.value).__name__.lower()
|
||||||
assert (
|
assert "interrupt" in name or "runtimeerror" in name, (
|
||||||
"interrupt" in name
|
f"Expected an interrupt-style exception, got {name}"
|
||||||
or "runtimeerror" in name
|
)
|
||||||
), f"Expected an interrupt-style exception, got {name}"
|
|
||||||
|
|
||||||
|
|
||||||
def test_does_not_trigger_when_args_differ() -> None:
|
def test_does_not_trigger_when_args_differ() -> None:
|
||||||
mw = DoomLoopMiddleware(threshold=2)
|
mw = DoomLoopMiddleware(threshold=2)
|
||||||
runtime = _FakeRuntime()
|
runtime = _FakeRuntime()
|
||||||
out = mw.after_model(
|
out = mw.after_model({"messages": [_msg_calling("repeat", {"x": 1}, "1")]}, runtime)
|
||||||
{"messages": [_msg_calling("repeat", {"x": 1}, "1")]}, runtime
|
|
||||||
)
|
|
||||||
assert out is None
|
assert out is None
|
||||||
out = mw.after_model(
|
out = mw.after_model({"messages": [_msg_calling("repeat", {"x": 2}, "2")]}, runtime)
|
||||||
{"messages": [_msg_calling("repeat", {"x": 2}, "2")]}, runtime
|
|
||||||
)
|
|
||||||
assert out is None
|
assert out is None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -91,7 +91,9 @@ class TestShouldInject:
|
||||||
mw = NoopInjectionMiddleware()
|
mw = NoopInjectionMiddleware()
|
||||||
req = _FakeRequest(
|
req = _FakeRequest(
|
||||||
tools=[object()],
|
tools=[object()],
|
||||||
messages=[AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])],
|
messages=[
|
||||||
|
AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])
|
||||||
|
],
|
||||||
model=_LiteLLMModel(),
|
model=_LiteLLMModel(),
|
||||||
)
|
)
|
||||||
assert mw._should_inject(req) is False
|
assert mw._should_inject(req) is False
|
||||||
|
|
@ -109,7 +111,9 @@ class TestShouldInject:
|
||||||
mw = NoopInjectionMiddleware()
|
mw = NoopInjectionMiddleware()
|
||||||
req = _FakeRequest(
|
req = _FakeRequest(
|
||||||
tools=[],
|
tools=[],
|
||||||
messages=[AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])],
|
messages=[
|
||||||
|
AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])
|
||||||
|
],
|
||||||
model=_OpenAIModel(),
|
model=_OpenAIModel(),
|
||||||
)
|
)
|
||||||
assert mw._should_inject(req) is False
|
assert mw._should_inject(req) is False
|
||||||
|
|
|
||||||
|
|
@ -111,6 +111,4 @@ class TestAsk:
|
||||||
assert out is None # call kept
|
assert out is None # call kept
|
||||||
# Runtime ruleset got the always-allow rule
|
# Runtime ruleset got the always-allow rule
|
||||||
new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"]
|
new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"]
|
||||||
assert any(
|
assert any(r.permission == "send_email" for r in new_rules)
|
||||||
r.permission == "send_email" for r in new_rules
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -69,7 +69,9 @@ class TestPluginLoaderBasics:
|
||||||
"app.agents.new_chat.plugin_loader.entry_points",
|
"app.agents.new_chat.plugin_loader.entry_points",
|
||||||
return_value=[ep],
|
return_value=[ep],
|
||||||
):
|
):
|
||||||
result = load_plugin_middlewares(_ctx(), allowed_plugin_names=["allowed_only"])
|
result = load_plugin_middlewares(
|
||||||
|
_ctx(), allowed_plugin_names=["allowed_only"]
|
||||||
|
)
|
||||||
assert result == []
|
assert result == []
|
||||||
assert not called
|
assert not called
|
||||||
|
|
||||||
|
|
@ -135,9 +137,7 @@ class TestPluginLoaderIsolation:
|
||||||
_FakeEntryPoint("crashing", crashing_factory),
|
_FakeEntryPoint("crashing", crashing_factory),
|
||||||
_FakeEntryPoint("ok", year_substituter_factory),
|
_FakeEntryPoint("ok", year_substituter_factory),
|
||||||
]
|
]
|
||||||
with patch(
|
with patch("app.agents.new_chat.plugin_loader.entry_points", return_value=eps):
|
||||||
"app.agents.new_chat.plugin_loader.entry_points", return_value=eps
|
|
||||||
):
|
|
||||||
result = load_plugin_middlewares(
|
result = load_plugin_middlewares(
|
||||||
_ctx(), allowed_plugin_names={"crashing", "ok"}
|
_ctx(), allowed_plugin_names={"crashing", "ok"}
|
||||||
)
|
)
|
||||||
|
|
@ -151,9 +151,7 @@ class TestAllowlistEnv:
|
||||||
assert load_allowed_plugin_names_from_env() == set()
|
assert load_allowed_plugin_names_from_env() == set()
|
||||||
|
|
||||||
def test_parses_comma_separated_value(self, monkeypatch) -> None:
|
def test_parses_comma_separated_value(self, monkeypatch) -> None:
|
||||||
monkeypatch.setenv(
|
monkeypatch.setenv("SURFSENSE_ALLOWED_PLUGINS", " year_substituter , noisy , ")
|
||||||
"SURFSENSE_ALLOWED_PLUGINS", " year_substituter , noisy , "
|
|
||||||
)
|
|
||||||
assert load_allowed_plugin_names_from_env() == {
|
assert load_allowed_plugin_names_from_env() == {
|
||||||
"year_substituter",
|
"year_substituter",
|
||||||
"noisy",
|
"noisy",
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ class _FakeResponse:
|
||||||
self.headers = headers
|
self.headers = headers
|
||||||
|
|
||||||
|
|
||||||
class _FakeRateLimit(Exception):
|
class _FakeRateLimitError(Exception):
|
||||||
def __init__(self, msg: str, headers: dict[str, str] | None = None) -> None:
|
def __init__(self, msg: str, headers: dict[str, str] | None = None) -> None:
|
||||||
super().__init__(msg)
|
super().__init__(msg)
|
||||||
if headers is not None:
|
if headers is not None:
|
||||||
|
|
@ -27,15 +27,15 @@ class _FakeRateLimit(Exception):
|
||||||
|
|
||||||
class TestExtractRetryAfter:
|
class TestExtractRetryAfter:
|
||||||
def test_seconds_header(self) -> None:
|
def test_seconds_header(self) -> None:
|
||||||
exc = _FakeRateLimit("rate", {"Retry-After": "30"})
|
exc = _FakeRateLimitError("rate", {"Retry-After": "30"})
|
||||||
assert _extract_retry_after_seconds(exc) == 30.0
|
assert _extract_retry_after_seconds(exc) == 30.0
|
||||||
|
|
||||||
def test_milliseconds_header_overrides_seconds(self) -> None:
|
def test_milliseconds_header_overrides_seconds(self) -> None:
|
||||||
exc = _FakeRateLimit("rate", {"retry-after-ms": "1500"})
|
exc = _FakeRateLimitError("rate", {"retry-after-ms": "1500"})
|
||||||
assert _extract_retry_after_seconds(exc) == 1.5
|
assert _extract_retry_after_seconds(exc) == 1.5
|
||||||
|
|
||||||
def test_case_insensitive(self) -> None:
|
def test_case_insensitive(self) -> None:
|
||||||
exc = _FakeRateLimit("rate", {"RETRY-AFTER": "12"})
|
exc = _FakeRateLimitError("rate", {"RETRY-AFTER": "12"})
|
||||||
assert _extract_retry_after_seconds(exc) == 12.0
|
assert _extract_retry_after_seconds(exc) == 12.0
|
||||||
|
|
||||||
def test_falls_back_to_message_regex(self) -> None:
|
def test_falls_back_to_message_regex(self) -> None:
|
||||||
|
|
@ -67,7 +67,7 @@ class TestIsNonRetryable:
|
||||||
class TestDelayCalculation:
|
class TestDelayCalculation:
|
||||||
def test_takes_max_of_backoff_and_header(self) -> None:
|
def test_takes_max_of_backoff_and_header(self) -> None:
|
||||||
mw = RetryAfterMiddleware(max_retries=3, initial_delay=1.0, jitter=False)
|
mw = RetryAfterMiddleware(max_retries=3, initial_delay=1.0, jitter=False)
|
||||||
exc = _FakeRateLimit("rl", {"retry-after": "10"})
|
exc = _FakeRateLimitError("rl", {"retry-after": "10"})
|
||||||
delay = mw._delay_for_attempt(0, exc)
|
delay = mw._delay_for_attempt(0, exc)
|
||||||
assert delay == pytest.approx(10.0)
|
assert delay == pytest.approx(10.0)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -122,7 +122,9 @@ class TestExploreSubagent:
|
||||||
def test_includes_permission_middleware_with_deny_rules(self) -> None:
|
def test_includes_permission_middleware_with_deny_rules(self) -> None:
|
||||||
spec = build_explore_subagent(tools=ALL_TOOLS)
|
spec = build_explore_subagent(tools=ALL_TOOLS)
|
||||||
permission_mws = [
|
permission_mws = [
|
||||||
m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index]
|
m
|
||||||
|
for m in spec["middleware"]
|
||||||
|
if isinstance(m, PermissionMiddleware) # type: ignore[index]
|
||||||
]
|
]
|
||||||
assert len(permission_mws) == 1
|
assert len(permission_mws) == 1
|
||||||
ruleset = permission_mws[0]._static_rulesets[0]
|
ruleset = permission_mws[0]._static_rulesets[0]
|
||||||
|
|
@ -164,7 +166,9 @@ class TestReportWriterSubagent:
|
||||||
def test_deny_rules_block_writes_but_allow_generate_report(self) -> None:
|
def test_deny_rules_block_writes_but_allow_generate_report(self) -> None:
|
||||||
spec = build_report_writer_subagent(tools=ALL_TOOLS)
|
spec = build_report_writer_subagent(tools=ALL_TOOLS)
|
||||||
permission_mws = [
|
permission_mws = [
|
||||||
m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index]
|
m
|
||||||
|
for m in spec["middleware"]
|
||||||
|
if isinstance(m, PermissionMiddleware) # type: ignore[index]
|
||||||
]
|
]
|
||||||
ruleset = permission_mws[0]._static_rulesets[0]
|
ruleset = permission_mws[0]._static_rulesets[0]
|
||||||
deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"}
|
deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"}
|
||||||
|
|
@ -194,17 +198,15 @@ class TestConnectorNegotiatorSubagent:
|
||||||
def test_deny_ruleset_blocks_mutating_connector_tools(self) -> None:
|
def test_deny_ruleset_blocks_mutating_connector_tools(self) -> None:
|
||||||
spec = build_connector_negotiator_subagent(tools=ALL_TOOLS)
|
spec = build_connector_negotiator_subagent(tools=ALL_TOOLS)
|
||||||
permission_mws = [
|
permission_mws = [
|
||||||
m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index]
|
m
|
||||||
|
for m in spec["middleware"]
|
||||||
|
if isinstance(m, PermissionMiddleware) # type: ignore[index]
|
||||||
]
|
]
|
||||||
ruleset = permission_mws[0]._static_rulesets[0]
|
ruleset = permission_mws[0]._static_rulesets[0]
|
||||||
deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"}
|
deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"}
|
||||||
# `linear_create_issue` matches the `*_create` deny pattern.
|
# `linear_create_issue` matches the `*_create` deny pattern.
|
||||||
assert any(
|
assert any(_wildcard_matches(p, "linear_create_issue") for p in deny_patterns)
|
||||||
_wildcard_matches(p, "linear_create_issue") for p in deny_patterns
|
assert any(_wildcard_matches(p, "slack_send_message") for p in deny_patterns)
|
||||||
)
|
|
||||||
assert any(
|
|
||||||
_wildcard_matches(p, "slack_send_message") for p in deny_patterns
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestBuildSpecializedSubagents:
|
class TestBuildSpecializedSubagents:
|
||||||
|
|
@ -242,8 +244,7 @@ class TestBuildSpecializedSubagents:
|
||||||
# order: extra → custom → patch → dedup.
|
# order: extra → custom → patch → dedup.
|
||||||
sentinel_idx = mws.index(sentinel)
|
sentinel_idx = mws.index(sentinel)
|
||||||
perm_idx = next(
|
perm_idx = next(
|
||||||
(i for i, m in enumerate(mws)
|
(i for i, m in enumerate(mws) if isinstance(m, PermissionMiddleware)),
|
||||||
if isinstance(m, PermissionMiddleware)),
|
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
assert perm_idx is not None
|
assert perm_idx is not None
|
||||||
|
|
@ -259,7 +260,9 @@ class TestFilterToolsWarningSuppression:
|
||||||
|
|
||||||
from app.agents.new_chat.subagents.config import _filter_tools
|
from app.agents.new_chat.subagents.config import _filter_tools
|
||||||
|
|
||||||
with caplog.at_level(logging.INFO, logger="app.agents.new_chat.subagents.config"):
|
with caplog.at_level(
|
||||||
|
logging.INFO, logger="app.agents.new_chat.subagents.config"
|
||||||
|
):
|
||||||
# Allowed set asks for two registry tools (one present, one
|
# Allowed set asks for two registry tools (one present, one
|
||||||
# not) plus a bunch of middleware-provided names.
|
# not) plus a bunch of middleware-provided names.
|
||||||
_filter_tools(
|
_filter_tools(
|
||||||
|
|
@ -275,9 +278,7 @@ class TestFilterToolsWarningSuppression:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
warnings = [
|
warnings = [r.message for r in caplog.records if r.levelno >= logging.INFO]
|
||||||
r.message for r in caplog.records if r.levelno >= logging.INFO
|
|
||||||
]
|
|
||||||
# Exactly one warning, and it should mention scrape_webpage but not
|
# Exactly one warning, and it should mention scrape_webpage but not
|
||||||
# any middleware-provided name. Inspect the rendered "missing"
|
# any middleware-provided name. Inspect the rendered "missing"
|
||||||
# list (between the brackets) so we don't false-match substrings
|
# list (between the brackets) so we don't false-match substrings
|
||||||
|
|
|
||||||
|
|
@ -27,9 +27,12 @@ class TestRepair:
|
||||||
mw = ToolCallNameRepairMiddleware(
|
mw = ToolCallNameRepairMiddleware(
|
||||||
registered_tool_names={"echo"}, fuzzy_match_threshold=None
|
registered_tool_names={"echo"}, fuzzy_match_threshold=None
|
||||||
)
|
)
|
||||||
msg = AIMessage(content="", tool_calls=[
|
msg = AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
{"name": "echo", "args": {}, "id": "1"},
|
{"name": "echo", "args": {}, "id": "1"},
|
||||||
])
|
],
|
||||||
|
)
|
||||||
out = mw.after_model(_make_state(msg), _FakeRuntime())
|
out = mw.after_model(_make_state(msg), _FakeRuntime())
|
||||||
assert out is None # no change
|
assert out is None # no change
|
||||||
|
|
||||||
|
|
@ -37,9 +40,12 @@ class TestRepair:
|
||||||
mw = ToolCallNameRepairMiddleware(
|
mw = ToolCallNameRepairMiddleware(
|
||||||
registered_tool_names={"echo"}, fuzzy_match_threshold=None
|
registered_tool_names={"echo"}, fuzzy_match_threshold=None
|
||||||
)
|
)
|
||||||
msg = AIMessage(content="", tool_calls=[
|
msg = AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
{"name": "Echo", "args": {"x": 1}, "id": "1"},
|
{"name": "Echo", "args": {"x": 1}, "id": "1"},
|
||||||
])
|
],
|
||||||
|
)
|
||||||
out = mw.after_model(_make_state(msg), _FakeRuntime())
|
out = mw.after_model(_make_state(msg), _FakeRuntime())
|
||||||
assert out is not None
|
assert out is not None
|
||||||
repaired = out["messages"][0]
|
repaired = out["messages"][0]
|
||||||
|
|
@ -50,9 +56,12 @@ class TestRepair:
|
||||||
registered_tool_names={"echo", INVALID_TOOL_NAME},
|
registered_tool_names={"echo", INVALID_TOOL_NAME},
|
||||||
fuzzy_match_threshold=None,
|
fuzzy_match_threshold=None,
|
||||||
)
|
)
|
||||||
msg = AIMessage(content="", tool_calls=[
|
msg = AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
{"name": "totally_different_name", "args": {"k": "v"}, "id": "1"},
|
{"name": "totally_different_name", "args": {"k": "v"}, "id": "1"},
|
||||||
])
|
],
|
||||||
|
)
|
||||||
out = mw.after_model(_make_state(msg), _FakeRuntime())
|
out = mw.after_model(_make_state(msg), _FakeRuntime())
|
||||||
assert out is not None
|
assert out is not None
|
||||||
repaired_call = out["messages"][0].tool_calls[0]
|
repaired_call = out["messages"][0].tool_calls[0]
|
||||||
|
|
@ -64,9 +73,12 @@ class TestRepair:
|
||||||
mw = ToolCallNameRepairMiddleware(
|
mw = ToolCallNameRepairMiddleware(
|
||||||
registered_tool_names={"echo"}, fuzzy_match_threshold=None
|
registered_tool_names={"echo"}, fuzzy_match_threshold=None
|
||||||
)
|
)
|
||||||
msg = AIMessage(content="", tool_calls=[
|
msg = AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
{"name": "unknown", "args": {}, "id": "1"},
|
{"name": "unknown", "args": {}, "id": "1"},
|
||||||
])
|
],
|
||||||
|
)
|
||||||
out = mw.after_model(_make_state(msg), _FakeRuntime())
|
out = mw.after_model(_make_state(msg), _FakeRuntime())
|
||||||
# No repair available; original returned unchanged (no update)
|
# No repair available; original returned unchanged (no update)
|
||||||
assert out is None
|
assert out is None
|
||||||
|
|
@ -76,9 +88,12 @@ class TestRepair:
|
||||||
registered_tool_names={"search_documents"},
|
registered_tool_names={"search_documents"},
|
||||||
fuzzy_match_threshold=0.7,
|
fuzzy_match_threshold=0.7,
|
||||||
)
|
)
|
||||||
msg = AIMessage(content="", tool_calls=[
|
msg = AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
{"name": "search_docments", "args": {}, "id": "1"},
|
{"name": "search_docments", "args": {}, "id": "1"},
|
||||||
])
|
],
|
||||||
|
)
|
||||||
out = mw.after_model(_make_state(msg), _FakeRuntime())
|
out = mw.after_model(_make_state(msg), _FakeRuntime())
|
||||||
assert out is not None
|
assert out is not None
|
||||||
assert out["messages"][0].tool_calls[0]["name"] == "search_documents"
|
assert out["messages"][0].tool_calls[0]["name"] == "search_documents"
|
||||||
|
|
@ -94,9 +109,12 @@ class TestRepair:
|
||||||
mw = ToolCallNameRepairMiddleware(
|
mw = ToolCallNameRepairMiddleware(
|
||||||
registered_tool_names={"echo"}, fuzzy_match_threshold=None
|
registered_tool_names={"echo"}, fuzzy_match_threshold=None
|
||||||
)
|
)
|
||||||
msg = AIMessage(content="", tool_calls=[
|
msg = AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
{"name": "DynamicTool", "args": {}, "id": "1"},
|
{"name": "DynamicTool", "args": {}, "id": "1"},
|
||||||
])
|
],
|
||||||
|
)
|
||||||
runtime = _FakeRuntime(SimpleNamespace(registered_tool_names=["dynamictool"]))
|
runtime = _FakeRuntime(SimpleNamespace(registered_tool_names=["dynamictool"]))
|
||||||
out = mw.after_model(_make_state(msg), runtime)
|
out = mw.after_model(_make_state(msg), runtime)
|
||||||
assert out is not None
|
assert out is not None
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ through :class:`KnowledgeBasePersistenceMiddleware` without losing the copy.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
|
||||||
|
|
@ -16,9 +16,7 @@ class _FakeAction:
|
||||||
class TestCanRevert:
|
class TestCanRevert:
|
||||||
def test_owner_can_revert_their_own_action(self) -> None:
|
def test_owner_can_revert_their_own_action(self) -> None:
|
||||||
action = _FakeAction(user_id="user-123")
|
action = _FakeAction(user_id="user-123")
|
||||||
assert can_revert(
|
assert can_revert(requester_user_id="user-123", action=action, is_admin=False)
|
||||||
requester_user_id="user-123", action=action, is_admin=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_other_user_cannot_revert(self) -> None:
|
def test_other_user_cannot_revert(self) -> None:
|
||||||
action = _FakeAction(user_id="user-123")
|
action = _FakeAction(user_id="user-123")
|
||||||
|
|
@ -28,21 +26,15 @@ class TestCanRevert:
|
||||||
|
|
||||||
def test_admin_always_allowed(self) -> None:
|
def test_admin_always_allowed(self) -> None:
|
||||||
action = _FakeAction(user_id="user-123")
|
action = _FakeAction(user_id="user-123")
|
||||||
assert can_revert(
|
assert can_revert(requester_user_id="anybody", action=action, is_admin=True)
|
||||||
requester_user_id="anybody", action=action, is_admin=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_admin_can_revert_anonymous_action(self) -> None:
|
def test_admin_can_revert_anonymous_action(self) -> None:
|
||||||
action = _FakeAction(user_id=None)
|
action = _FakeAction(user_id=None)
|
||||||
assert can_revert(
|
assert can_revert(requester_user_id="admin", action=action, is_admin=True)
|
||||||
requester_user_id="admin", action=action, is_admin=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_anonymous_action_blocks_non_admin(self) -> None:
|
def test_anonymous_action_blocks_non_admin(self) -> None:
|
||||||
action = _FakeAction(user_id=None)
|
action = _FakeAction(user_id=None)
|
||||||
assert not can_revert(
|
assert not can_revert(requester_user_id="user-1", action=action, is_admin=False)
|
||||||
requester_user_id="user-1", action=action, is_admin=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_uuid_string_normalization(self) -> None:
|
def test_uuid_string_normalization(self) -> None:
|
||||||
"""``user_id`` may be a UUID object; comparison should still work."""
|
"""``user_id`` may be a UUID object; comparison should still work."""
|
||||||
|
|
@ -51,6 +43,4 @@ class TestCanRevert:
|
||||||
u = uuid.uuid4()
|
u = uuid.uuid4()
|
||||||
action = _FakeAction(user_id=u)
|
action = _FakeAction(user_id=u)
|
||||||
# Same UUID, passed as string from the requesting side.
|
# Same UUID, passed as string from the requesting side.
|
||||||
assert can_revert(
|
assert can_revert(requester_user_id=str(u), action=action, is_admin=False)
|
||||||
requester_user_id=str(u), action=action, is_admin=False
|
|
||||||
)
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue