chore: linting

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-04-28 21:37:51 -07:00
parent b9a66cb417
commit ca9bbee06d
41 changed files with 314 additions and 244 deletions

View file

@ -88,7 +88,5 @@ def upgrade() -> None:
def downgrade() -> None: def downgrade() -> None:
op.drop_index( op.drop_index("ix_agent_action_log_thread_created", table_name="agent_action_log")
"ix_agent_action_log_thread_created", table_name="agent_action_log"
)
op.drop_table("agent_action_log") op.drop_table("agent_action_log")

View file

@ -51,9 +51,7 @@ def upgrade() -> None:
# implicit-unique-index variant SQLAlchemy may emit need draining. # implicit-unique-index variant SQLAlchemy may emit need draining.
constraints = _existing_constraint_names(bind, "documents") constraints = _existing_constraint_names(bind, "documents")
if "uq_documents_content_hash" in constraints: if "uq_documents_content_hash" in constraints:
op.drop_constraint( op.drop_constraint("uq_documents_content_hash", "documents", type_="unique")
"uq_documents_content_hash", "documents", type_="unique"
)
indexes = _existing_index_names(bind, "documents") indexes = _existing_index_names(bind, "documents")
# Some Postgres versions surface the unique constraint via a unique # Some Postgres versions surface the unique constraint via a unique

View file

@ -416,10 +416,10 @@ async def create_surfsense_deep_agent(
# cheap to build. ``SubAgentMiddleware.__init__`` calls ``create_agent`` # cheap to build. ``SubAgentMiddleware.__init__`` calls ``create_agent``
# synchronously to compile the general-purpose subagent's full state graph # synchronously to compile the general-purpose subagent's full state graph
# (every tool + every middleware → pydantic schemas + langgraph compile). # (every tool + every middleware → pydantic schemas + langgraph compile).
# On gpt-5.x agents that's roughly 1.52s of pure CPU work. If we run it # On gpt-5.x agents that's roughly 1.5-2s of pure CPU work. If we run it
# directly here it blocks the asyncio event loop for the whole streaming # directly here it blocks the asyncio event loop for the whole streaming
# task (and any other coroutine sharing this loop), which is why # task (and any other coroutine sharing this loop), which is why
# "agent creation" wall-clock time used to stretch to ~34s. Move the # "agent creation" wall-clock time used to stretch to ~3-4s. Move the
# entire middleware build + main-graph compile into a single # entire middleware build + main-graph compile into a single
# ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the # ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the
# event loop stays responsive. # event loop stays responsive.
@ -587,10 +587,7 @@ def _build_compiled_agent_blocking(
# by name. Off by default until the flag flips so existing deployments # by name. Off by default until the flag flips so existing deployments
# don't see new agent types in the task tool description. # don't see new agent types in the task tool description.
specialized_subagents: list[SubAgent] = [] specialized_subagents: list[SubAgent] = []
if ( if flags.enable_specialized_subagents and not flags.disable_new_agent_stack:
flags.enable_specialized_subagents
and not flags.disable_new_agent_stack
):
try: try:
# Specialized subagents share the parent's filesystem + # Specialized subagents share the parent's filesystem +
# todo view so their system prompts (which promise # todo view so their system prompts (which promise
@ -696,7 +693,9 @@ def _build_compiled_agent_blocking(
else None else None
) )
tool_call_limit_mw = ( tool_call_limit_mw = (
ToolCallLimitMiddleware(thread_limit=300, run_limit=80, exit_behavior="continue") ToolCallLimitMiddleware(
thread_limit=300, run_limit=80, exit_behavior="continue"
)
if flags.enable_tool_call_limit and not flags.disable_new_agent_stack if flags.enable_tool_call_limit and not flags.disable_new_agent_stack
else None else None
) )
@ -879,7 +878,11 @@ def _build_compiled_agent_blocking(
max_tools=12, max_tools=12,
always_include=[ always_include=[
name name
for name in ("update_memory", "get_connected_accounts", "scrape_webpage") for name in (
"update_memory",
"get_connected_accounts",
"scrape_webpage",
)
if name in {t.name for t in tools} if name in {t.name for t in tools}
], ],
) )

View file

@ -65,7 +65,9 @@ class AgentFeatureFlags:
enable_model_call_limit: bool = False enable_model_call_limit: bool = False
enable_tool_call_limit: bool = False enable_tool_call_limit: bool = False
enable_tool_call_repair: bool = False enable_tool_call_repair: bool = False
enable_doom_loop: bool = False # Default OFF until UI handles permission='doom_loop' enable_doom_loop: bool = (
False # Default OFF until UI handles permission='doom_loop'
)
# Tier 2 — Safety # Tier 2 — Safety
enable_permission: bool = False # Default OFF for first deploy enable_permission: bool = False # Default OFF for first deploy
@ -79,7 +81,9 @@ class AgentFeatureFlags:
# Tier 5 — Snapshot / revert # Tier 5 — Snapshot / revert
enable_action_log: bool = False enable_action_log: bool = False
enable_revert_route: bool = False # Backend ships before UI; route returns 503 until this flips enable_revert_route: bool = (
False # Backend ships before UI; route returns 503 until this flips
)
# Tier 6 — Plugins # Tier 6 — Plugins
enable_plugin_loader: bool = False enable_plugin_loader: bool = False
@ -109,14 +113,20 @@ class AgentFeatureFlags:
enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False), enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False),
enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False), enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False),
enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False), enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False),
enable_model_call_limit=_env_bool("SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False), enable_model_call_limit=_env_bool(
"SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False
),
enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False), enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False),
enable_tool_call_repair=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False), enable_tool_call_repair=_env_bool(
"SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False
),
enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False), enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False),
# Tier 2 # Tier 2
enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False), enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False),
enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False), enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False),
enable_llm_tool_selector=_env_bool("SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False), enable_llm_tool_selector=_env_bool(
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False
),
# Tier 4 # Tier 4
enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False), enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False),
enable_specialized_subagents=_env_bool( enable_specialized_subagents=_env_bool(

View file

@ -101,9 +101,7 @@ class ActionLogMiddleware(AgentMiddleware):
async def awrap_tool_call( async def awrap_tool_call(
self, self,
request: ToolCallRequest, request: ToolCallRequest,
handler: Callable[ handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]
],
) -> ToolMessage | Command[Any]: ) -> ToolMessage | Command[Any]:
if not self._enabled(): if not self._enabled():
return await handler(request) return await handler(request)

View file

@ -177,8 +177,8 @@ class SurfSenseCompactionMiddleware(SummarizationMiddleware):
messages_in=len(conversation_messages), messages_in=len(conversation_messages),
extra={"compaction.cutoff_index": int(cutoff_index)}, extra={"compaction.cutoff_index": int(cutoff_index)},
): ):
messages_to_summarize, preserved_messages = ( messages_to_summarize, preserved_messages = super()._partition_messages(
super()._partition_messages(conversation_messages, cutoff_index) conversation_messages, cutoff_index
) )
protected: list[AnyMessage] = [] protected: list[AnyMessage] = []

View file

@ -58,8 +58,7 @@ DEFAULT_SPILL_PREFIX = "/tool_outputs"
def _build_spill_placeholder(spill_path: str) -> str: def _build_spill_placeholder(spill_path: str) -> str:
"""Build the user-facing placeholder text shown to the model.""" """Build the user-facing placeholder text shown to the model."""
return ( return (
f"[cleared — full output at {spill_path}; " f"[cleared — full output at {spill_path}; ask the explore subagent to read it]"
f"ask the explore subagent to read it]"
) )
@ -131,7 +130,9 @@ class SpillToBackendEdit(ContextEdit):
return return
candidates = [ candidates = [
(idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage) (idx, msg)
for idx, msg in enumerate(messages)
if isinstance(msg, ToolMessage)
] ]
if self.keep >= len(candidates): if self.keep >= len(candidates):
return return

View file

@ -137,16 +137,21 @@ class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respon
triggered_call: dict[str, Any] | None = None triggered_call: dict[str, Any] | None = None
for call in message.tool_calls: for call in message.tool_calls:
name = call.get("name") if isinstance(call, dict) else getattr(call, "name", None) name = (
args = call.get("args") if isinstance(call, dict) else getattr(call, "args", {}) call.get("name")
if isinstance(call, dict)
else getattr(call, "name", None)
)
args = (
call.get("args")
if isinstance(call, dict)
else getattr(call, "args", {})
)
if not isinstance(name, str): if not isinstance(name, str):
continue continue
sig = _signature(name, args) sig = _signature(name, args)
window.append(sig) window.append(sig)
if ( if len(window) >= self._threshold and len(set(window)) == 1:
len(window) >= self._threshold
and len(set(window)) == 1
):
triggered_call = {"name": name, "params": args or {}} triggered_call = {"name": name, "params": args or {}}
break break
@ -209,7 +214,9 @@ class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respon
# tool call proceeds. The frontend's exact reply names may differ — # tool call proceeds. The frontend's exact reply names may differ —
# we tolerate any shape that contains a string with "reject"/"cancel". # we tolerate any shape that contains a string with "reject"/"cancel".
if isinstance(decision, dict): if isinstance(decision, dict):
kind = str(decision.get("decision_type") or decision.get("type") or "").lower() kind = str(
decision.get("decision_type") or decision.get("type") or ""
).lower()
if "reject" in kind or "cancel" in kind: if "reject" in kind or "cancel" in kind:
return {"jump_to": "end"} return {"jump_to": "end"}
return None return None

View file

@ -552,7 +552,7 @@ def _render_priority_message(priority: list[dict[str, Any]]) -> SystemMessage:
for entry in priority: for entry in priority:
score = entry.get("score") score = entry.get("score")
mentioned = entry.get("mentioned") mentioned = entry.get("mentioned")
score_str = f"{score:.3f}" if isinstance(score, (int, float)) else "n/a" score_str = f"{score:.3f}" if isinstance(score, int | float) else "n/a"
mark = " [USER-MENTIONED]" if mentioned else "" mark = " [USER-MENTIONED]" if mentioned else ""
lines.append(f"- {entry.get('path', '')} (score={score_str}){mark}") lines.append(f"- {entry.get('path', '')} (score={score_str}){mark}")
body = "\n".join(lines) body = "\n".join(lines)
@ -593,7 +593,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
self.top_k = top_k self.top_k = top_k
self.mentioned_document_ids = mentioned_document_ids or [] self.mentioned_document_ids = mentioned_document_ids or []
# Tier 4.2: build the kb-planner private Runnable ONCE here so we # Tier 4.2: build the kb-planner private Runnable ONCE here so we
# don't pay the create_agent compile cost (50200ms) on every turn. # don't pay the create_agent compile cost (50-200ms) on every turn.
# Disabled by default behind ``enable_kb_planner_runnable``; when off # Disabled by default behind ``enable_kb_planner_runnable``; when off
# the planner falls back to the legacy ``self.llm.ainvoke`` path. # the planner falls back to the legacy ``self.llm.ainvoke`` path.
self._planner: Runnable | None = None self._planner: Runnable | None = None
@ -617,10 +617,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
if self.llm is None: if self.llm is None:
return None return None
flags = get_flags() flags = get_flags()
if ( if not flags.enable_kb_planner_runnable or flags.disable_new_agent_stack:
not flags.enable_kb_planner_runnable
or flags.disable_new_agent_stack
):
return None return None
from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware
@ -920,7 +917,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
chunk_ids = doc.get("matched_chunk_ids") or [] chunk_ids = doc.get("matched_chunk_ids") or []
if chunk_ids: if chunk_ids:
matched_chunk_ids[doc_id] = [ matched_chunk_ids[doc_id] = [
int(cid) for cid in chunk_ids if isinstance(cid, (int, str)) int(cid) for cid in chunk_ids if isinstance(cid, int | str)
] ]
return priority, matched_chunk_ids return priority, matched_chunk_ids

View file

@ -35,9 +35,7 @@ from langchain_core.tools import tool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
NOOP_TOOL_NAME = "_noop" NOOP_TOOL_NAME = "_noop"
NOOP_TOOL_DESCRIPTION = ( NOOP_TOOL_DESCRIPTION = "Do not call this tool. It exists only for API compatibility."
"Do not call this tool. It exists only for API compatibility."
)
@tool(name_or_callable=NOOP_TOOL_NAME, description=NOOP_TOOL_DESCRIPTION) @tool(name_or_callable=NOOP_TOOL_NAME, description=NOOP_TOOL_DESCRIPTION)
@ -78,7 +76,9 @@ def _last_ai_has_tool_calls(messages: list[Any]) -> bool:
return False return False
class NoopInjectionMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): class NoopInjectionMiddleware(
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
"""Inject the ``_noop`` tool only when the provider would otherwise 400. """Inject the ``_noop`` tool only when the provider would otherwise 400.
The check fires per model call, not at agent build time, because the The check fires per model call, not at agent build time, because the
@ -116,7 +116,9 @@ class NoopInjectionMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, R
async def awrap_model_call( # type: ignore[override] async def awrap_model_call( # type: ignore[override]
self, self,
request: ModelRequest[ContextT], request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], handler: Callable[
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
],
) -> Any: ) -> Any:
if self._should_inject(request): if self._should_inject(request):
logger.debug("Injecting _noop tool for provider compatibility") logger.debug("Injecting _noop tool for provider compatibility")

View file

@ -56,9 +56,7 @@ class OtelSpanMiddleware(AgentMiddleware):
async def awrap_model_call( async def awrap_model_call(
self, self,
request: ModelRequest, request: ModelRequest,
handler: Callable[ handler: Callable[[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]],
[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]
],
) -> ModelResponse | AIMessage | Any: ) -> ModelResponse | AIMessage | Any:
if not ot.is_enabled(): if not ot.is_enabled():
return await handler(request) return await handler(request)
@ -81,9 +79,7 @@ class OtelSpanMiddleware(AgentMiddleware):
async def awrap_tool_call( async def awrap_tool_call(
self, self,
request: ToolCallRequest, request: ToolCallRequest,
handler: Callable[ handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]
],
) -> ToolMessage | Command[Any]: ) -> ToolMessage | Command[Any]:
if not ot.is_enabled(): if not ot.is_enabled():
return await handler(request) return await handler(request)
@ -187,7 +183,11 @@ def _annotate_model_response(span: Any, result: Any) -> None:
def _annotate_tool_result(span: Any, result: Any) -> None: def _annotate_tool_result(span: Any, result: Any) -> None:
try: try:
if isinstance(result, ToolMessage): if isinstance(result, ToolMessage):
content = result.content if isinstance(result.content, str) else repr(result.content) content = (
result.content
if isinstance(result.content, str)
else repr(result.content)
)
span.set_attribute("tool.output.size", len(content)) span.set_attribute("tool.output.size", len(content))
status = getattr(result, "status", None) status = getattr(result, "status", None)
if isinstance(status, str): if isinstance(status, str):

View file

@ -145,7 +145,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
try: try:
patterns = resolver(args or {}) patterns = resolver(args or {})
except Exception: except Exception:
logger.exception("Pattern resolver for %s raised; using bare name", tool_name) logger.exception(
"Pattern resolver for %s raised; using bare name", tool_name
)
patterns = [tool_name] patterns = [tool_name]
if not patterns: if not patterns:
patterns = [tool_name] patterns = [tool_name]
@ -198,11 +200,14 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
# Tier 3b: permission.asked + interrupt.raised spans (no-op when # Tier 3b: permission.asked + interrupt.raised spans (no-op when
# OTel is disabled). Both fire here so dashboards can correlate # OTel is disabled). Both fire here so dashboards can correlate
# "we asked X" with "interrupt was actually delivered". # "we asked X" with "interrupt was actually delivered".
with ot.permission_asked_span( with (
ot.permission_asked_span(
permission=tool_name, permission=tool_name,
pattern=patterns[0] if patterns else None, pattern=patterns[0] if patterns else None,
extra={"permission.patterns": list(patterns)}, extra={"permission.patterns": list(patterns)},
), ot.interrupt_span(interrupt_type="permission_ask"): ),
ot.interrupt_span(interrupt_type="permission_ask"),
):
decision = interrupt(payload) decision = interrupt(payload)
if isinstance(decision, dict): if isinstance(decision, dict):
return decision return decision
@ -211,9 +216,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
return {"decision_type": decision} return {"decision_type": decision}
return {"decision_type": "reject"} return {"decision_type": "reject"}
def _persist_always( def _persist_always(self, tool_name: str, patterns: list[str]) -> None:
self, tool_name: str, patterns: list[str]
) -> None:
"""Promote ``always`` reply into runtime allow rules. """Promote ``always`` reply into runtime allow rules.
Persistence to ``agent_permission_rules`` is done by the Persistence to ``agent_permission_rules`` is done by the
@ -276,12 +279,16 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
any_change = False any_change = False
for raw in last.tool_calls: for raw in last.tool_calls:
call = dict(raw) if isinstance(raw, dict) else { call = (
dict(raw)
if isinstance(raw, dict)
else {
"name": getattr(raw, "name", None), "name": getattr(raw, "name", None),
"args": getattr(raw, "args", {}), "args": getattr(raw, "args", {}),
"id": getattr(raw, "id", None), "id": getattr(raw, "id", None),
"type": "tool_call", "type": "tool_call",
} }
)
name = call.get("name") or "" name = call.get("name") or ""
args = call.get("args") or {} args = call.get("args") or {}
action, patterns, rules = self._evaluate(name, args) action, patterns, rules = self._evaluate(name, args)
@ -307,7 +314,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
feedback = decision.get("feedback") feedback = decision.get("feedback")
if isinstance(feedback, str) and feedback.strip(): if isinstance(feedback, str) and feedback.strip():
raise CorrectedError(feedback, tool=name) raise CorrectedError(feedback, tool=name)
raise RejectedError(tool=name, pattern=patterns[0] if patterns else None) raise RejectedError(
tool=name, pattern=patterns[0] if patterns else None
)
else: else:
logger.warning( logger.warning(
"Unknown permission decision %r; treating as reject", kind "Unknown permission decision %r; treating as reject", kind

View file

@ -113,7 +113,9 @@ def _exponential_delay(
jitter: bool, jitter: bool,
) -> float: ) -> float:
"""Compute an exponential-backoff delay with optional ±25% jitter.""" """Compute an exponential-backoff delay with optional ±25% jitter."""
delay = initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay delay = (
initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay
)
delay = min(delay, max_delay) delay = min(delay, max_delay)
if jitter and delay > 0: if jitter and delay > 0:
delay *= 1 + random.uniform(-0.25, 0.25) delay *= 1 + random.uniform(-0.25, 0.25)
@ -201,7 +203,9 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp
}, },
) )
except Exception: except Exception:
logger.debug("dispatch_custom_event failed; suppressed", exc_info=True) logger.debug(
"dispatch_custom_event failed; suppressed", exc_info=True
)
if delay > 0: if delay > 0:
time.sleep(delay) time.sleep(delay)
# Unreachable # Unreachable
@ -210,7 +214,9 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp
async def awrap_model_call( # type: ignore[override] async def awrap_model_call( # type: ignore[override]
self, self,
request: ModelRequest[ContextT], request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], handler: Callable[
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
],
) -> ModelResponse[ResponseT] | AIMessage: ) -> ModelResponse[ResponseT] | AIMessage:
for attempt in range(self.max_retries + 1): for attempt in range(self.max_retries + 1):
try: try:

View file

@ -29,6 +29,7 @@ gives a clean failure mode if anything tries.
from __future__ import annotations from __future__ import annotations
import contextlib
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from dataclasses import replace from dataclasses import replace
@ -114,8 +115,10 @@ class BuiltinSkillsBackend(BackendProtocol):
infos: list[FileInfo] = [] infos: list[FileInfo] = []
# Build virtual paths anchored at "/" because CompositeBackend already # Build virtual paths anchored at "/" because CompositeBackend already
# stripped the route prefix before calling us. # stripped the route prefix before calling us.
target_virtual = "/" if target == self.root else ( target_virtual = (
"/" + str(target.relative_to(self.root)).replace("\\", "/") "/"
if target == self.root
else ("/" + str(target.relative_to(self.root)).replace("\\", "/"))
) )
for child in sorted(target.iterdir()): for child in sorted(target.iterdir()):
child_virtual = ( child_virtual = (
@ -128,10 +131,8 @@ class BuiltinSkillsBackend(BackendProtocol):
"is_dir": child.is_dir(), "is_dir": child.is_dir(),
} }
if child.is_file(): if child.is_file():
try: with contextlib.suppress(OSError): # pragma: no cover - defensive
info["size"] = child.stat().st_size info["size"] = child.stat().st_size
except OSError: # pragma: no cover - defensive
pass
infos.append(info) infos.append(info)
return infos return infos
@ -163,7 +164,9 @@ class BuiltinSkillsBackend(BackendProtocol):
else: else:
content = target.read_bytes() content = target.read_bytes()
except PermissionError: except PermissionError:
responses.append(FileDownloadResponse(path=p, error="permission_denied")) responses.append(
FileDownloadResponse(path=p, error="permission_denied")
)
continue continue
except OSError as exc: # pragma: no cover - defensive except OSError as exc: # pragma: no cover - defensive
logger.warning("Builtin skill read failed %s: %s", target, exc) logger.warning("Builtin skill read failed %s: %s", target, exc)
@ -286,6 +289,7 @@ def build_skills_backend_factory(
builtin = BuiltinSkillsBackend(builtin_root) builtin = BuiltinSkillsBackend(builtin_root)
if search_space_id is None: if search_space_id is None:
def _factory_builtin_only(runtime: ToolRuntime) -> BackendProtocol: def _factory_builtin_only(runtime: ToolRuntime) -> BackendProtocol:
# Default StateBackend is intentionally inert: any path outside the # Default StateBackend is intentionally inert: any path outside the
# ``/skills/builtin/`` route resolves to an empty per-runtime state # ``/skills/builtin/`` route resolves to an empty per-runtime state
@ -294,6 +298,7 @@ def build_skills_backend_factory(
default=StateBackend(runtime), default=StateBackend(runtime),
routes={SKILLS_BUILTIN_PREFIX: builtin}, routes={SKILLS_BUILTIN_PREFIX: builtin},
) )
return _factory_builtin_only return _factory_builtin_only
def _factory_with_space(runtime: ToolRuntime) -> BackendProtocol: def _factory_with_space(runtime: ToolRuntime) -> BackendProtocol:

View file

@ -51,13 +51,15 @@ def _coerce_existing_tool_call(call: Any) -> dict[str, Any]:
} }
class ToolCallNameRepairMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): class ToolCallNameRepairMiddleware(
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
"""Two-stage tool-name repair on the most recent ``AIMessage``. """Two-stage tool-name repair on the most recent ``AIMessage``.
Args: Args:
registered_tool_names: Set of canonically-registered tool names. registered_tool_names: Set of canonically-registered tool names.
``invalid`` should be in this set so the fallback dispatches. ``invalid`` should be in this set so the fallback dispatches.
fuzzy_match_threshold: Optional ``difflib`` ratio (01) for the fuzzy_match_threshold: Optional ``difflib`` ratio (0-1) for the
fuzzy-match step that runs *between* lowercase and invalid. fuzzy-match step that runs *between* lowercase and invalid.
Set to ``None`` to disable fuzzy matching (opencode parity). Set to ``None`` to disable fuzzy matching (opencode parity).
""" """
@ -77,9 +79,9 @@ class ToolCallNameRepairMiddleware(AgentMiddleware[AgentState[ResponseT], Contex
def _registered_for_runtime(self, runtime: Runtime[ContextT]) -> set[str]: def _registered_for_runtime(self, runtime: Runtime[ContextT]) -> set[str]:
"""Allow runtime overrides to expand the set (e.g. dynamic MCP tools).""" """Allow runtime overrides to expand the set (e.g. dynamic MCP tools)."""
ctx_tools = getattr(runtime.context, "registered_tool_names", None) ctx_tools = getattr(runtime.context, "registered_tool_names", None)
if isinstance(ctx_tools, (set, frozenset)): if isinstance(ctx_tools, set | frozenset):
return self._registered | set(ctx_tools) return self._registered | set(ctx_tools)
if isinstance(ctx_tools, (list, tuple)): if isinstance(ctx_tools, list | tuple):
return self._registered | set(ctx_tools) return self._registered | set(ctx_tools)
return self._registered return self._registered

View file

@ -52,16 +52,17 @@ class _YearSubstituterMiddleware(AgentMiddleware):
async def awrap_tool_call( async def awrap_tool_call(
self, self,
request: ToolCallRequest, request: ToolCallRequest,
handler: Callable[ handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]
],
) -> ToolMessage | Command[Any]: ) -> ToolMessage | Command[Any]:
result = await handler(request) result = await handler(request)
try: try:
from langchain_core.messages import ToolMessage from langchain_core.messages import ToolMessage
if isinstance(result, ToolMessage) and isinstance(result.content, str): if (
if "{{year}}" in result.content: isinstance(result, ToolMessage)
and isinstance(result.content, str)
and "{{year}}" in result.content
):
new_text = result.content.replace("{{year}}", self._year) new_text = result.content.replace("{{year}}", self._year)
result = ToolMessage( result = ToolMessage(
content=new_text, content=new_text,

View file

@ -62,7 +62,9 @@ ProviderVariant = str
# More specific patterns must come first (e.g. ``codex`` before # More specific patterns must come first (e.g. ``codex`` before
# ``openai_reasoning`` because codex model ids contain ``gpt``). # ``openai_reasoning`` because codex model ids contain ``gpt``).
_OPENAI_CODEX_RE = re.compile(r"\b(gpt-codex|codex-mini|gpt-[\d.]+-codex)\b", re.IGNORECASE) _OPENAI_CODEX_RE = re.compile(
r"\b(gpt-codex|codex-mini|gpt-[\d.]+-codex)\b", re.IGNORECASE
)
_OPENAI_REASONING_RE = re.compile(r"\b(gpt-5|o\d|o-)", re.IGNORECASE) _OPENAI_REASONING_RE = re.compile(r"\b(gpt-5|o\d|o-)", re.IGNORECASE)
_OPENAI_CLASSIC_RE = re.compile(r"\bgpt-4", re.IGNORECASE) _OPENAI_CLASSIC_RE = re.compile(r"\bgpt-4", re.IGNORECASE)
_ANTHROPIC_RE = re.compile(r"\bclaude\b", re.IGNORECASE) _ANTHROPIC_RE = re.compile(r"\bclaude\b", re.IGNORECASE)
@ -257,9 +259,7 @@ def _build_tools_section(
) )
if known_disabled: if known_disabled:
disabled_list = ", ".join( disabled_list = ", ".join(
_format_tool_label(n) _format_tool_label(n) for n in ALL_TOOL_NAMES_ORDERED if n in known_disabled
for n in ALL_TOOL_NAMES_ORDERED
if n in known_disabled
) )
parts.append( parts.append(
"\n" "\n"

View file

@ -279,9 +279,7 @@ def build_explore_subagent(
selected_tools = _filter_tools(tools, EXPLORE_READ_TOOLS) selected_tools = _filter_tools(tools, EXPLORE_READ_TOOLS)
deny_rules = _read_only_deny_rules() deny_rules = _read_only_deny_rules()
permission_mw = _build_permission_middleware( permission_mw = _build_permission_middleware(deny_rules, origin="subagent_explore")
deny_rules, origin="subagent_explore"
)
spec: dict = { spec: dict = {
"name": "explore", "name": "explore",

View file

@ -111,6 +111,8 @@ from .update_memory import create_update_memory_tool, create_update_team_memory_
from .video_presentation import create_generate_video_presentation_tool from .video_presentation import create_generate_video_presentation_tool
from .web_search import create_web_search_tool from .web_search import create_web_search_tool
logger = logging.getLogger(__name__)
# ============================================================================= # =============================================================================
# Tool Definition # Tool Definition
# ============================================================================= # =============================================================================

View file

@ -22,6 +22,7 @@ Goals
from __future__ import annotations from __future__ import annotations
import contextlib
import logging import logging
import os import os
from collections.abc import Iterator from collections.abc import Iterator
@ -154,18 +155,14 @@ def span(
with tracer.start_as_current_span(name) as sp: with tracer.start_as_current_span(name) as sp:
if attributes: if attributes:
try: with contextlib.suppress(Exception): # pragma: no cover — defensive
sp.set_attributes(attributes) sp.set_attributes(attributes)
except Exception: # pragma: no cover — defensive
pass
try: try:
yield sp yield sp
except BaseException as exc: except BaseException as exc:
try: with contextlib.suppress(Exception): # pragma: no cover — defensive
sp.record_exception(exc) sp.record_exception(exc)
sp.set_status(_OtStatus(_OtStatusCode.ERROR, str(exc))) sp.set_status(_OtStatus(_OtStatusCode.ERROR, str(exc)))
except Exception: # pragma: no cover — defensive
pass
raise raise

View file

@ -59,7 +59,7 @@ class AgentFeatureFlagsRead(BaseModel):
enable_otel: bool enable_otel: bool
@classmethod @classmethod
def from_flags(cls, flags: AgentFeatureFlags) -> "AgentFeatureFlagsRead": def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead:
# asdict() avoids missing-field bugs when AgentFeatureFlags grows. # asdict() avoids missing-field bugs when AgentFeatureFlags grows.
return cls(**asdict(flags)) return cls(**asdict(flags))

View file

@ -210,7 +210,7 @@ async def create_rule(
session.add(row) session.add(row)
try: try:
await session.commit() await session.commit()
except IntegrityError: except IntegrityError as err:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=409, status_code=409,
@ -218,7 +218,7 @@ async def create_rule(
"An identical rule already exists for this scope. Update the " "An identical rule already exists for this scope. Update the "
"existing rule instead." "existing rule instead."
), ),
) ) from err
await session.refresh(row) await session.refresh(row)
return _to_read(row) return _to_read(row)
@ -248,12 +248,12 @@ async def update_rule(
try: try:
await session.commit() await session.commit()
except IntegrityError: except IntegrityError as err:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=409, status_code=409,
detail="Update would create a duplicate rule for this scope.", detail="Update would create a duplicate rule for this scope.",
) ) from err
await session.refresh(row) await session.refresh(row)
return _to_read(row) return _to_read(row)

View file

@ -97,10 +97,12 @@ async def revert_agent_action(
action=action, action=action,
requester_user_id=str(user.id) if user is not None else None, requester_user_id=str(user.id) if user is not None else None,
) )
except Exception: except Exception as err:
logger.exception("Revert dispatch raised for action_id=%s", action_id) logger.exception("Revert dispatch raised for action_id=%s", action_id)
await session.rollback() await session.rollback()
raise HTTPException(status_code=500, detail="Internal error during revert.") raise HTTPException(
status_code=500, detail="Internal error during revert."
) from err
if outcome.status == "ok": if outcome.status == "ok":
await session.commit() await session.commit()

View file

@ -1242,7 +1242,9 @@ async def handle_new_chat(
await session.close() await session.close()
image_urls = ( image_urls = (
[p.as_data_url() for p in request.user_images] if request.user_images else None [p.as_data_url() for p in request.user_images]
if request.user_images
else None
) )
return StreamingResponse( return StreamingResponse(

View file

@ -79,9 +79,7 @@ async def load_action(
return result.scalars().first() return result.scalars().first()
async def load_thread( async def load_thread(session: AsyncSession, *, thread_id: int) -> NewChatThread | None:
session: AsyncSession, *, thread_id: int
) -> NewChatThread | None:
stmt = select(NewChatThread).where(NewChatThread.id == thread_id) stmt = select(NewChatThread).where(NewChatThread.id == thread_id)
result = await session.execute(stmt) result = await session.execute(stmt)
return result.scalars().first() return result.scalars().first()

View file

@ -7,7 +7,9 @@ import binascii
from typing import Any from typing import Any
def build_human_message_content(final_query: str, image_data_urls: list[str]) -> str | list[dict[str, Any]]: def build_human_message_content(
final_query: str, image_data_urls: list[str]
) -> str | list[dict[str, Any]]:
if not image_data_urls: if not image_data_urls:
return final_query return final_query
parts: list[dict[str, Any]] = [{"type": "text", "text": final_query}] parts: list[dict[str, Any]] = [{"type": "text", "text": final_query}]

View file

@ -90,9 +90,7 @@ class TestCompose:
assert "<citation_instructions>" in prompt assert "<citation_instructions>" in prompt
assert "[citation:chunk_id]" in prompt assert "[citation:chunk_id]" in prompt
def test_team_visibility_uses_team_variants( def test_team_visibility_uses_team_variants(self, fixed_today: datetime) -> None:
self, fixed_today: datetime
) -> None:
prompt = compose_system_prompt( prompt = compose_system_prompt(
today=fixed_today, today=fixed_today,
thread_visibility=ChatVisibility.SEARCH_SPACE, thread_visibility=ChatVisibility.SEARCH_SPACE,
@ -145,9 +143,7 @@ class TestCompose:
assert "Generate Image" in prompt assert "Generate Image" in prompt
assert "Generate Podcast" in prompt assert "Generate Podcast" in prompt
def test_mcp_routing_block_emits_when_provided( def test_mcp_routing_block_emits_when_provided(self, fixed_today: datetime) -> None:
self, fixed_today: datetime
) -> None:
prompt = compose_system_prompt( prompt = compose_system_prompt(
today=fixed_today, today=fixed_today,
mcp_connector_tools={"My GitLab": ["gitlab_search", "gitlab_create_mr"]}, mcp_connector_tools={"My GitLab": ["gitlab_search", "gitlab_create_mr"]},
@ -162,9 +158,7 @@ class TestCompose:
prompt = compose_system_prompt(today=fixed_today, mcp_connector_tools={}) prompt = compose_system_prompt(today=fixed_today, mcp_connector_tools={})
assert "<mcp_tool_routing>" not in prompt assert "<mcp_tool_routing>" not in prompt
def test_provider_block_renders_when_anthropic( def test_provider_block_renders_when_anthropic(self, fixed_today: datetime) -> None:
self, fixed_today: datetime
) -> None:
prompt = compose_system_prompt( prompt = compose_system_prompt(
today=fixed_today, model_name="anthropic:claude-3-5-sonnet" today=fixed_today, model_name="anthropic:claude-3-5-sonnet"
) )
@ -267,7 +261,10 @@ class TestStableOrderingForCacheStability:
) )
b = compose_system_prompt( b = compose_system_prompt(
today=fixed_today, today=fixed_today,
enabled_tool_names={"scrape_webpage", "web_search"}, # set order shouldn't matter enabled_tool_names={
"scrape_webpage",
"web_search",
}, # set order shouldn't matter
mcp_connector_tools={"X": ["x_a", "x_b"]}, mcp_connector_tools={"X": ["x_a", "x_b"]},
) )
assert a == b assert a == b

View file

@ -83,7 +83,11 @@ class TestActionLogMiddlewareDisabled:
async def test_no_op_when_flag_off(self, patch_get_flags) -> None: async def test_no_op_when_flag_off(self, patch_get_flags) -> None:
mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None) mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
request = _FakeRequest( request = _FakeRequest(
tool_call={"name": "make_widget", "args": {"color": "red", "size": 1}, "id": "tc1"} tool_call={
"name": "make_widget",
"args": {"color": "red", "size": 1},
"id": "tc1",
}
) )
handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1")) handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
with patch_get_flags(_disabled_flags()): with patch_get_flags(_disabled_flags()):
@ -117,13 +121,12 @@ class TestActionLogMiddlewarePersistence:
"id": "tc-abc", "id": "tc-abc",
}, },
) )
result_msg = ToolMessage( result_msg = ToolMessage(content="ok", tool_call_id="tc-abc", id="msg-1")
content="ok", tool_call_id="tc-abc", id="msg-1"
)
handler = AsyncMock(return_value=result_msg) handler = AsyncMock(return_value=result_msg)
with patch_get_flags(_enabled_flags()), patch( with (
"app.db.shielded_async_session", side_effect=lambda: factory() patch_get_flags(_enabled_flags()),
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
): ):
result = await mw.awrap_tool_call(request, handler) result = await mw.awrap_tool_call(request, handler)
@ -151,9 +154,11 @@ class TestActionLogMiddlewarePersistence:
) )
handler = AsyncMock(side_effect=ValueError("boom")) handler = AsyncMock(side_effect=ValueError("boom"))
with patch_get_flags(_enabled_flags()), patch( with (
"app.db.shielded_async_session", side_effect=lambda: factory() patch_get_flags(_enabled_flags()),
), pytest.raises(ValueError, match="boom"): patch("app.db.shielded_async_session", side_effect=lambda: factory()),
pytest.raises(ValueError, match="boom"),
):
await mw.awrap_tool_call(request, handler) await mw.awrap_tool_call(request, handler)
assert len(captured["rows"]) == 1 assert len(captured["rows"]) == 1
@ -177,8 +182,9 @@ class TestActionLogMiddlewarePersistence:
def _exploding_session(): def _exploding_session():
raise RuntimeError("DB is down") raise RuntimeError("DB is down")
with patch_get_flags(_enabled_flags()), patch( with (
"app.db.shielded_async_session", side_effect=_exploding_session patch_get_flags(_enabled_flags()),
patch("app.db.shielded_async_session", side_effect=_exploding_session),
): ):
result = await mw.awrap_tool_call(request, handler) result = await mw.awrap_tool_call(request, handler)
assert result is result_msg assert result is result_msg
@ -218,8 +224,9 @@ class TestReverseDescriptor:
) )
handler = AsyncMock(return_value=result_msg) handler = AsyncMock(return_value=result_msg)
with patch_get_flags(_enabled_flags()), patch( with (
"app.db.shielded_async_session", side_effect=lambda: factory() patch_get_flags(_enabled_flags()),
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
): ):
await mw.awrap_tool_call(request, handler) await mw.awrap_tool_call(request, handler)
@ -257,8 +264,9 @@ class TestReverseDescriptor:
result_msg = ToolMessage(content="ok", tool_call_id="tc1") result_msg = ToolMessage(content="ok", tool_call_id="tc1")
handler = AsyncMock(return_value=result_msg) handler = AsyncMock(return_value=result_msg)
with patch_get_flags(_enabled_flags()), patch( with (
"app.db.shielded_async_session", side_effect=lambda: factory() patch_get_flags(_enabled_flags()),
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
): ):
await mw.awrap_tool_call(request, handler) await mw.awrap_tool_call(request, handler)
@ -275,11 +283,10 @@ class TestReverseDescriptor:
request = _FakeRequest( request = _FakeRequest(
tool_call={"name": "unknown_tool", "args": {}, "id": "tc1"} tool_call={"name": "unknown_tool", "args": {}, "id": "tc1"}
) )
handler = AsyncMock( handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
return_value=ToolMessage(content="ok", tool_call_id="tc1") with (
) patch_get_flags(_enabled_flags()),
with patch_get_flags(_enabled_flags()), patch( patch("app.db.shielded_async_session", side_effect=lambda: factory()),
"app.db.shielded_async_session", side_effect=lambda: factory()
): ):
await mw.awrap_tool_call(request, handler) await mw.awrap_tool_call(request, handler)
row = captured["rows"][0] row = captured["rows"][0]
@ -298,11 +305,10 @@ class TestArgsTruncation:
request = _FakeRequest( request = _FakeRequest(
tool_call={"name": "make_widget", "args": {"blob": huge}, "id": "tc1"}, tool_call={"name": "make_widget", "args": {"blob": huge}, "id": "tc1"},
) )
handler = AsyncMock( handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
return_value=ToolMessage(content="ok", tool_call_id="tc1") with (
) patch_get_flags(_enabled_flags()),
with patch_get_flags(_enabled_flags()), patch( patch("app.db.shielded_async_session", side_effect=lambda: factory()),
"app.db.shielded_async_session", side_effect=lambda: factory()
): ):
await mw.awrap_tool_call(request, handler) await mw.awrap_tool_call(request, handler)
row = captured["rows"][0] row = captured["rows"][0]

View file

@ -26,10 +26,16 @@ class TestIsProtectedSystemMessage:
assert _is_protected_system_message(msg) is True assert _is_protected_system_message(msg) is True
def test_unprotected_system_message(self) -> None: def test_unprotected_system_message(self) -> None:
assert _is_protected_system_message(SystemMessage(content="random instructions")) is False assert (
_is_protected_system_message(SystemMessage(content="random instructions"))
is False
)
def test_human_message_never_protected(self) -> None: def test_human_message_never_protected(self) -> None:
assert _is_protected_system_message(HumanMessage(content="<workspace_tree>...")) is False assert (
_is_protected_system_message(HumanMessage(content="<workspace_tree>..."))
is False
)
def test_tolerates_leading_whitespace(self) -> None: def test_tolerates_leading_whitespace(self) -> None:
msg = SystemMessage(content=" \n<priority_documents>\n...") msg = SystemMessage(content=" \n<priority_documents>\n...")
@ -97,11 +103,17 @@ class TestPartitionMessages:
assert protected not in to_summary assert protected not in to_summary
assert protected in preserved assert protected in preserved
# The non-protected old messages remain in to_summary # The non-protected old messages remain in to_summary
assert any(isinstance(m, HumanMessage) and m.content == "old human" for m in to_summary) assert any(
isinstance(m, HumanMessage) and m.content == "old human" for m in to_summary
)
def test_unprotected_messages_unaffected(self) -> None: def test_unprotected_messages_unaffected(self) -> None:
partitioner = self._build_partitioner() partitioner = self._build_partitioner()
msgs = [HumanMessage(content="a"), HumanMessage(content="b"), HumanMessage(content="c")] msgs = [
HumanMessage(content="a"),
HumanMessage(content="b"),
HumanMessage(content="c"),
]
to_summary, preserved = partitioner._partition_messages(msgs, 2) to_summary, preserved = partitioner._partition_messages(msgs, 2)
assert [m.content for m in to_summary] == ["a", "b"] assert [m.content for m in to_summary] == ["a", "b"]
assert [m.content for m in preserved] == ["c"] assert [m.content for m in preserved] == ["c"]

View file

@ -70,7 +70,8 @@ class TestSpillEdit:
# Earlier ToolMessages should now contain the placeholder text # Earlier ToolMessages should now contain the placeholder text
cleared = [ cleared = [
m for m in tool_messages m
for m in tool_messages
if isinstance(m.content, str) and m.content.startswith("[cleared") if isinstance(m.content, str) and m.content.startswith("[cleared")
] ]
assert len(cleared) >= 1 assert len(cleared) >= 1

View file

@ -46,9 +46,21 @@ def test_callable_dedup_key_takes_priority() -> None:
state = { state = {
"messages": [ "messages": [
_msg( _msg(
{"name": "create_doc", "args": {"parent_id": "x", "title": "y"}, "id": "1"}, {
{"name": "create_doc", "args": {"parent_id": "x", "title": "y"}, "id": "2"}, "name": "create_doc",
{"name": "create_doc", "args": {"parent_id": "x", "title": "z"}, "id": "3"}, "args": {"parent_id": "x", "title": "y"},
"id": "1",
},
{
"name": "create_doc",
"args": {"parent_id": "x", "title": "y"},
"id": "2",
},
{
"name": "create_doc",
"args": {"parent_id": "x", "title": "z"},
"id": "3",
},
) )
] ]
} }

View file

@ -84,9 +84,7 @@ class TestConnectorDenyOverridesDefaultAllow:
Rule(permission="linear_create_issue", pattern="*", action="deny") Rule(permission="linear_create_issue", pattern="*", action="deny")
] ]
) )
rules = evaluate_many( rules = evaluate_many("linear_create_issue", ["linear_create_issue"], *rulesets)
"linear_create_issue", ["linear_create_issue"], *rulesets
)
assert aggregate_action(rules) == "deny" assert aggregate_action(rules) == "deny"
def test_default_allow_still_applies_to_other_tools(self) -> None: def test_default_allow_still_applies_to_other_tools(self) -> None:
@ -124,5 +122,7 @@ class TestUserRuleOverridesDefault:
rules=[Rule(permission="send_*", pattern="*", action="deny")], rules=[Rule(permission="send_*", pattern="*", action="deny")],
origin="user", origin="user",
) )
rules = evaluate_many("send_gmail_email", ["send_gmail_email"], defaults, user_ruleset) rules = evaluate_many(
"send_gmail_email", ["send_gmail_email"], defaults, user_ruleset
)
assert aggregate_action(rules) == "deny" assert aggregate_action(rules) == "deny"

View file

@ -64,22 +64,17 @@ def test_threshold_triggers_after_n_identical_calls() -> None:
runtime, runtime,
) )
name = type(excinfo.value).__name__.lower() name = type(excinfo.value).__name__.lower()
assert ( assert "interrupt" in name or "runtimeerror" in name, (
"interrupt" in name f"Expected an interrupt-style exception, got {name}"
or "runtimeerror" in name )
), f"Expected an interrupt-style exception, got {name}"
def test_does_not_trigger_when_args_differ() -> None: def test_does_not_trigger_when_args_differ() -> None:
mw = DoomLoopMiddleware(threshold=2) mw = DoomLoopMiddleware(threshold=2)
runtime = _FakeRuntime() runtime = _FakeRuntime()
out = mw.after_model( out = mw.after_model({"messages": [_msg_calling("repeat", {"x": 1}, "1")]}, runtime)
{"messages": [_msg_calling("repeat", {"x": 1}, "1")]}, runtime
)
assert out is None assert out is None
out = mw.after_model( out = mw.after_model({"messages": [_msg_calling("repeat", {"x": 2}, "2")]}, runtime)
{"messages": [_msg_calling("repeat", {"x": 2}, "2")]}, runtime
)
assert out is None assert out is None

View file

@ -91,7 +91,9 @@ class TestShouldInject:
mw = NoopInjectionMiddleware() mw = NoopInjectionMiddleware()
req = _FakeRequest( req = _FakeRequest(
tools=[object()], tools=[object()],
messages=[AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])], messages=[
AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])
],
model=_LiteLLMModel(), model=_LiteLLMModel(),
) )
assert mw._should_inject(req) is False assert mw._should_inject(req) is False
@ -109,7 +111,9 @@ class TestShouldInject:
mw = NoopInjectionMiddleware() mw = NoopInjectionMiddleware()
req = _FakeRequest( req = _FakeRequest(
tools=[], tools=[],
messages=[AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])], messages=[
AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])
],
model=_OpenAIModel(), model=_OpenAIModel(),
) )
assert mw._should_inject(req) is False assert mw._should_inject(req) is False

View file

@ -111,6 +111,4 @@ class TestAsk:
assert out is None # call kept assert out is None # call kept
# Runtime ruleset got the always-allow rule # Runtime ruleset got the always-allow rule
new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"] new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"]
assert any( assert any(r.permission == "send_email" for r in new_rules)
r.permission == "send_email" for r in new_rules
)

View file

@ -69,7 +69,9 @@ class TestPluginLoaderBasics:
"app.agents.new_chat.plugin_loader.entry_points", "app.agents.new_chat.plugin_loader.entry_points",
return_value=[ep], return_value=[ep],
): ):
result = load_plugin_middlewares(_ctx(), allowed_plugin_names=["allowed_only"]) result = load_plugin_middlewares(
_ctx(), allowed_plugin_names=["allowed_only"]
)
assert result == [] assert result == []
assert not called assert not called
@ -135,9 +137,7 @@ class TestPluginLoaderIsolation:
_FakeEntryPoint("crashing", crashing_factory), _FakeEntryPoint("crashing", crashing_factory),
_FakeEntryPoint("ok", year_substituter_factory), _FakeEntryPoint("ok", year_substituter_factory),
] ]
with patch( with patch("app.agents.new_chat.plugin_loader.entry_points", return_value=eps):
"app.agents.new_chat.plugin_loader.entry_points", return_value=eps
):
result = load_plugin_middlewares( result = load_plugin_middlewares(
_ctx(), allowed_plugin_names={"crashing", "ok"} _ctx(), allowed_plugin_names={"crashing", "ok"}
) )
@ -151,9 +151,7 @@ class TestAllowlistEnv:
assert load_allowed_plugin_names_from_env() == set() assert load_allowed_plugin_names_from_env() == set()
def test_parses_comma_separated_value(self, monkeypatch) -> None: def test_parses_comma_separated_value(self, monkeypatch) -> None:
monkeypatch.setenv( monkeypatch.setenv("SURFSENSE_ALLOWED_PLUGINS", " year_substituter , noisy , ")
"SURFSENSE_ALLOWED_PLUGINS", " year_substituter , noisy , "
)
assert load_allowed_plugin_names_from_env() == { assert load_allowed_plugin_names_from_env() == {
"year_substituter", "year_substituter",
"noisy", "noisy",

View file

@ -18,7 +18,7 @@ class _FakeResponse:
self.headers = headers self.headers = headers
class _FakeRateLimit(Exception): class _FakeRateLimitError(Exception):
def __init__(self, msg: str, headers: dict[str, str] | None = None) -> None: def __init__(self, msg: str, headers: dict[str, str] | None = None) -> None:
super().__init__(msg) super().__init__(msg)
if headers is not None: if headers is not None:
@ -27,15 +27,15 @@ class _FakeRateLimit(Exception):
class TestExtractRetryAfter: class TestExtractRetryAfter:
def test_seconds_header(self) -> None: def test_seconds_header(self) -> None:
exc = _FakeRateLimit("rate", {"Retry-After": "30"}) exc = _FakeRateLimitError("rate", {"Retry-After": "30"})
assert _extract_retry_after_seconds(exc) == 30.0 assert _extract_retry_after_seconds(exc) == 30.0
def test_milliseconds_header_overrides_seconds(self) -> None: def test_milliseconds_header_overrides_seconds(self) -> None:
exc = _FakeRateLimit("rate", {"retry-after-ms": "1500"}) exc = _FakeRateLimitError("rate", {"retry-after-ms": "1500"})
assert _extract_retry_after_seconds(exc) == 1.5 assert _extract_retry_after_seconds(exc) == 1.5
def test_case_insensitive(self) -> None: def test_case_insensitive(self) -> None:
exc = _FakeRateLimit("rate", {"RETRY-AFTER": "12"}) exc = _FakeRateLimitError("rate", {"RETRY-AFTER": "12"})
assert _extract_retry_after_seconds(exc) == 12.0 assert _extract_retry_after_seconds(exc) == 12.0
def test_falls_back_to_message_regex(self) -> None: def test_falls_back_to_message_regex(self) -> None:
@ -67,7 +67,7 @@ class TestIsNonRetryable:
class TestDelayCalculation: class TestDelayCalculation:
def test_takes_max_of_backoff_and_header(self) -> None: def test_takes_max_of_backoff_and_header(self) -> None:
mw = RetryAfterMiddleware(max_retries=3, initial_delay=1.0, jitter=False) mw = RetryAfterMiddleware(max_retries=3, initial_delay=1.0, jitter=False)
exc = _FakeRateLimit("rl", {"retry-after": "10"}) exc = _FakeRateLimitError("rl", {"retry-after": "10"})
delay = mw._delay_for_attempt(0, exc) delay = mw._delay_for_attempt(0, exc)
assert delay == pytest.approx(10.0) assert delay == pytest.approx(10.0)

View file

@ -122,7 +122,9 @@ class TestExploreSubagent:
def test_includes_permission_middleware_with_deny_rules(self) -> None: def test_includes_permission_middleware_with_deny_rules(self) -> None:
spec = build_explore_subagent(tools=ALL_TOOLS) spec = build_explore_subagent(tools=ALL_TOOLS)
permission_mws = [ permission_mws = [
m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index] m
for m in spec["middleware"]
if isinstance(m, PermissionMiddleware) # type: ignore[index]
] ]
assert len(permission_mws) == 1 assert len(permission_mws) == 1
ruleset = permission_mws[0]._static_rulesets[0] ruleset = permission_mws[0]._static_rulesets[0]
@ -164,7 +166,9 @@ class TestReportWriterSubagent:
def test_deny_rules_block_writes_but_allow_generate_report(self) -> None: def test_deny_rules_block_writes_but_allow_generate_report(self) -> None:
spec = build_report_writer_subagent(tools=ALL_TOOLS) spec = build_report_writer_subagent(tools=ALL_TOOLS)
permission_mws = [ permission_mws = [
m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index] m
for m in spec["middleware"]
if isinstance(m, PermissionMiddleware) # type: ignore[index]
] ]
ruleset = permission_mws[0]._static_rulesets[0] ruleset = permission_mws[0]._static_rulesets[0]
deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"} deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"}
@ -194,17 +198,15 @@ class TestConnectorNegotiatorSubagent:
def test_deny_ruleset_blocks_mutating_connector_tools(self) -> None: def test_deny_ruleset_blocks_mutating_connector_tools(self) -> None:
spec = build_connector_negotiator_subagent(tools=ALL_TOOLS) spec = build_connector_negotiator_subagent(tools=ALL_TOOLS)
permission_mws = [ permission_mws = [
m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index] m
for m in spec["middleware"]
if isinstance(m, PermissionMiddleware) # type: ignore[index]
] ]
ruleset = permission_mws[0]._static_rulesets[0] ruleset = permission_mws[0]._static_rulesets[0]
deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"} deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"}
# `linear_create_issue` matches the `*_create` deny pattern. # `linear_create_issue` matches the `*_create` deny pattern.
assert any( assert any(_wildcard_matches(p, "linear_create_issue") for p in deny_patterns)
_wildcard_matches(p, "linear_create_issue") for p in deny_patterns assert any(_wildcard_matches(p, "slack_send_message") for p in deny_patterns)
)
assert any(
_wildcard_matches(p, "slack_send_message") for p in deny_patterns
)
class TestBuildSpecializedSubagents: class TestBuildSpecializedSubagents:
@ -242,8 +244,7 @@ class TestBuildSpecializedSubagents:
# order: extra → custom → patch → dedup. # order: extra → custom → patch → dedup.
sentinel_idx = mws.index(sentinel) sentinel_idx = mws.index(sentinel)
perm_idx = next( perm_idx = next(
(i for i, m in enumerate(mws) (i for i, m in enumerate(mws) if isinstance(m, PermissionMiddleware)),
if isinstance(m, PermissionMiddleware)),
None, None,
) )
assert perm_idx is not None assert perm_idx is not None
@ -259,7 +260,9 @@ class TestFilterToolsWarningSuppression:
from app.agents.new_chat.subagents.config import _filter_tools from app.agents.new_chat.subagents.config import _filter_tools
with caplog.at_level(logging.INFO, logger="app.agents.new_chat.subagents.config"): with caplog.at_level(
logging.INFO, logger="app.agents.new_chat.subagents.config"
):
# Allowed set asks for two registry tools (one present, one # Allowed set asks for two registry tools (one present, one
# not) plus a bunch of middleware-provided names. # not) plus a bunch of middleware-provided names.
_filter_tools( _filter_tools(
@ -275,9 +278,7 @@ class TestFilterToolsWarningSuppression:
}, },
) )
warnings = [ warnings = [r.message for r in caplog.records if r.levelno >= logging.INFO]
r.message for r in caplog.records if r.levelno >= logging.INFO
]
# Exactly one warning, and it should mention scrape_webpage but not # Exactly one warning, and it should mention scrape_webpage but not
# any middleware-provided name. Inspect the rendered "missing" # any middleware-provided name. Inspect the rendered "missing"
# list (between the brackets) so we don't false-match substrings # list (between the brackets) so we don't false-match substrings

View file

@ -27,9 +27,12 @@ class TestRepair:
mw = ToolCallNameRepairMiddleware( mw = ToolCallNameRepairMiddleware(
registered_tool_names={"echo"}, fuzzy_match_threshold=None registered_tool_names={"echo"}, fuzzy_match_threshold=None
) )
msg = AIMessage(content="", tool_calls=[ msg = AIMessage(
content="",
tool_calls=[
{"name": "echo", "args": {}, "id": "1"}, {"name": "echo", "args": {}, "id": "1"},
]) ],
)
out = mw.after_model(_make_state(msg), _FakeRuntime()) out = mw.after_model(_make_state(msg), _FakeRuntime())
assert out is None # no change assert out is None # no change
@ -37,9 +40,12 @@ class TestRepair:
mw = ToolCallNameRepairMiddleware( mw = ToolCallNameRepairMiddleware(
registered_tool_names={"echo"}, fuzzy_match_threshold=None registered_tool_names={"echo"}, fuzzy_match_threshold=None
) )
msg = AIMessage(content="", tool_calls=[ msg = AIMessage(
content="",
tool_calls=[
{"name": "Echo", "args": {"x": 1}, "id": "1"}, {"name": "Echo", "args": {"x": 1}, "id": "1"},
]) ],
)
out = mw.after_model(_make_state(msg), _FakeRuntime()) out = mw.after_model(_make_state(msg), _FakeRuntime())
assert out is not None assert out is not None
repaired = out["messages"][0] repaired = out["messages"][0]
@ -50,9 +56,12 @@ class TestRepair:
registered_tool_names={"echo", INVALID_TOOL_NAME}, registered_tool_names={"echo", INVALID_TOOL_NAME},
fuzzy_match_threshold=None, fuzzy_match_threshold=None,
) )
msg = AIMessage(content="", tool_calls=[ msg = AIMessage(
content="",
tool_calls=[
{"name": "totally_different_name", "args": {"k": "v"}, "id": "1"}, {"name": "totally_different_name", "args": {"k": "v"}, "id": "1"},
]) ],
)
out = mw.after_model(_make_state(msg), _FakeRuntime()) out = mw.after_model(_make_state(msg), _FakeRuntime())
assert out is not None assert out is not None
repaired_call = out["messages"][0].tool_calls[0] repaired_call = out["messages"][0].tool_calls[0]
@ -64,9 +73,12 @@ class TestRepair:
mw = ToolCallNameRepairMiddleware( mw = ToolCallNameRepairMiddleware(
registered_tool_names={"echo"}, fuzzy_match_threshold=None registered_tool_names={"echo"}, fuzzy_match_threshold=None
) )
msg = AIMessage(content="", tool_calls=[ msg = AIMessage(
content="",
tool_calls=[
{"name": "unknown", "args": {}, "id": "1"}, {"name": "unknown", "args": {}, "id": "1"},
]) ],
)
out = mw.after_model(_make_state(msg), _FakeRuntime()) out = mw.after_model(_make_state(msg), _FakeRuntime())
# No repair available; original returned unchanged (no update) # No repair available; original returned unchanged (no update)
assert out is None assert out is None
@ -76,9 +88,12 @@ class TestRepair:
registered_tool_names={"search_documents"}, registered_tool_names={"search_documents"},
fuzzy_match_threshold=0.7, fuzzy_match_threshold=0.7,
) )
msg = AIMessage(content="", tool_calls=[ msg = AIMessage(
content="",
tool_calls=[
{"name": "search_docments", "args": {}, "id": "1"}, {"name": "search_docments", "args": {}, "id": "1"},
]) ],
)
out = mw.after_model(_make_state(msg), _FakeRuntime()) out = mw.after_model(_make_state(msg), _FakeRuntime())
assert out is not None assert out is not None
assert out["messages"][0].tool_calls[0]["name"] == "search_documents" assert out["messages"][0].tool_calls[0]["name"] == "search_documents"
@ -94,9 +109,12 @@ class TestRepair:
mw = ToolCallNameRepairMiddleware( mw = ToolCallNameRepairMiddleware(
registered_tool_names={"echo"}, fuzzy_match_threshold=None registered_tool_names={"echo"}, fuzzy_match_threshold=None
) )
msg = AIMessage(content="", tool_calls=[ msg = AIMessage(
content="",
tool_calls=[
{"name": "DynamicTool", "args": {}, "id": "1"}, {"name": "DynamicTool", "args": {}, "id": "1"},
]) ],
)
runtime = _FakeRuntime(SimpleNamespace(registered_tool_names=["dynamictool"])) runtime = _FakeRuntime(SimpleNamespace(registered_tool_names=["dynamictool"]))
out = mw.after_model(_make_state(msg), runtime) out = mw.after_model(_make_state(msg), runtime)
assert out is not None assert out is not None

View file

@ -10,7 +10,7 @@ through :class:`KnowledgeBasePersistenceMiddleware` without losing the copy.
from __future__ import annotations from __future__ import annotations
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock
import numpy as np import numpy as np
import pytest import pytest

View file

@ -16,9 +16,7 @@ class _FakeAction:
class TestCanRevert: class TestCanRevert:
def test_owner_can_revert_their_own_action(self) -> None: def test_owner_can_revert_their_own_action(self) -> None:
action = _FakeAction(user_id="user-123") action = _FakeAction(user_id="user-123")
assert can_revert( assert can_revert(requester_user_id="user-123", action=action, is_admin=False)
requester_user_id="user-123", action=action, is_admin=False
)
def test_other_user_cannot_revert(self) -> None: def test_other_user_cannot_revert(self) -> None:
action = _FakeAction(user_id="user-123") action = _FakeAction(user_id="user-123")
@ -28,21 +26,15 @@ class TestCanRevert:
def test_admin_always_allowed(self) -> None: def test_admin_always_allowed(self) -> None:
action = _FakeAction(user_id="user-123") action = _FakeAction(user_id="user-123")
assert can_revert( assert can_revert(requester_user_id="anybody", action=action, is_admin=True)
requester_user_id="anybody", action=action, is_admin=True
)
def test_admin_can_revert_anonymous_action(self) -> None: def test_admin_can_revert_anonymous_action(self) -> None:
action = _FakeAction(user_id=None) action = _FakeAction(user_id=None)
assert can_revert( assert can_revert(requester_user_id="admin", action=action, is_admin=True)
requester_user_id="admin", action=action, is_admin=True
)
def test_anonymous_action_blocks_non_admin(self) -> None: def test_anonymous_action_blocks_non_admin(self) -> None:
action = _FakeAction(user_id=None) action = _FakeAction(user_id=None)
assert not can_revert( assert not can_revert(requester_user_id="user-1", action=action, is_admin=False)
requester_user_id="user-1", action=action, is_admin=False
)
def test_uuid_string_normalization(self) -> None: def test_uuid_string_normalization(self) -> None:
"""``user_id`` may be a UUID object; comparison should still work.""" """``user_id`` may be a UUID object; comparison should still work."""
@ -51,6 +43,4 @@ class TestCanRevert:
u = uuid.uuid4() u = uuid.uuid4()
action = _FakeAction(user_id=u) action = _FakeAction(user_id=u)
# Same UUID, passed as string from the requesting side. # Same UUID, passed as string from the requesting side.
assert can_revert( assert can_revert(requester_user_id=str(u), action=action, is_admin=False)
requester_user_id=str(u), action=action, is_admin=False
)