chore: linting

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-04-28 21:37:51 -07:00
parent b9a66cb417
commit ca9bbee06d
41 changed files with 314 additions and 244 deletions

View file

@ -101,9 +101,7 @@ class ActionLogMiddleware(AgentMiddleware):
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[
[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]
],
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
if not self._enabled():
return await handler(request)

View file

@ -177,8 +177,8 @@ class SurfSenseCompactionMiddleware(SummarizationMiddleware):
messages_in=len(conversation_messages),
extra={"compaction.cutoff_index": int(cutoff_index)},
):
messages_to_summarize, preserved_messages = (
super()._partition_messages(conversation_messages, cutoff_index)
messages_to_summarize, preserved_messages = super()._partition_messages(
conversation_messages, cutoff_index
)
protected: list[AnyMessage] = []

View file

@ -58,8 +58,7 @@ DEFAULT_SPILL_PREFIX = "/tool_outputs"
def _build_spill_placeholder(spill_path: str) -> str:
"""Build the user-facing placeholder text shown to the model."""
return (
f"[cleared — full output at {spill_path}; "
f"ask the explore subagent to read it]"
f"[cleared — full output at {spill_path}; ask the explore subagent to read it]"
)
@ -131,7 +130,9 @@ class SpillToBackendEdit(ContextEdit):
return
candidates = [
(idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage)
(idx, msg)
for idx, msg in enumerate(messages)
if isinstance(msg, ToolMessage)
]
if self.keep >= len(candidates):
return

View file

@ -137,16 +137,21 @@ class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respon
triggered_call: dict[str, Any] | None = None
for call in message.tool_calls:
name = call.get("name") if isinstance(call, dict) else getattr(call, "name", None)
args = call.get("args") if isinstance(call, dict) else getattr(call, "args", {})
name = (
call.get("name")
if isinstance(call, dict)
else getattr(call, "name", None)
)
args = (
call.get("args")
if isinstance(call, dict)
else getattr(call, "args", {})
)
if not isinstance(name, str):
continue
sig = _signature(name, args)
window.append(sig)
if (
len(window) >= self._threshold
and len(set(window)) == 1
):
if len(window) >= self._threshold and len(set(window)) == 1:
triggered_call = {"name": name, "params": args or {}}
break
@ -209,7 +214,9 @@ class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respon
# tool call proceeds. The frontend's exact reply names may differ —
# we tolerate any shape that contains a string with "reject"/"cancel".
if isinstance(decision, dict):
kind = str(decision.get("decision_type") or decision.get("type") or "").lower()
kind = str(
decision.get("decision_type") or decision.get("type") or ""
).lower()
if "reject" in kind or "cancel" in kind:
return {"jump_to": "end"}
return None

View file

@ -552,7 +552,7 @@ def _render_priority_message(priority: list[dict[str, Any]]) -> SystemMessage:
for entry in priority:
score = entry.get("score")
mentioned = entry.get("mentioned")
score_str = f"{score:.3f}" if isinstance(score, (int, float)) else "n/a"
score_str = f"{score:.3f}" if isinstance(score, int | float) else "n/a"
mark = " [USER-MENTIONED]" if mentioned else ""
lines.append(f"- {entry.get('path', '')} (score={score_str}){mark}")
body = "\n".join(lines)
@ -593,7 +593,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
self.top_k = top_k
self.mentioned_document_ids = mentioned_document_ids or []
# Tier 4.2: build the kb-planner private Runnable ONCE here so we
# don't pay the create_agent compile cost (50200ms) on every turn.
# don't pay the create_agent compile cost (50-200ms) on every turn.
# Disabled by default behind ``enable_kb_planner_runnable``; when off
# the planner falls back to the legacy ``self.llm.ainvoke`` path.
self._planner: Runnable | None = None
@ -617,10 +617,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
if self.llm is None:
return None
flags = get_flags()
if (
not flags.enable_kb_planner_runnable
or flags.disable_new_agent_stack
):
if not flags.enable_kb_planner_runnable or flags.disable_new_agent_stack:
return None
from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware
@ -920,7 +917,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
chunk_ids = doc.get("matched_chunk_ids") or []
if chunk_ids:
matched_chunk_ids[doc_id] = [
int(cid) for cid in chunk_ids if isinstance(cid, (int, str))
int(cid) for cid in chunk_ids if isinstance(cid, int | str)
]
return priority, matched_chunk_ids

View file

@ -35,9 +35,7 @@ from langchain_core.tools import tool
logger = logging.getLogger(__name__)
NOOP_TOOL_NAME = "_noop"
NOOP_TOOL_DESCRIPTION = (
"Do not call this tool. It exists only for API compatibility."
)
NOOP_TOOL_DESCRIPTION = "Do not call this tool. It exists only for API compatibility."
@tool(name_or_callable=NOOP_TOOL_NAME, description=NOOP_TOOL_DESCRIPTION)
@ -78,7 +76,9 @@ def _last_ai_has_tool_calls(messages: list[Any]) -> bool:
return False
class NoopInjectionMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
class NoopInjectionMiddleware(
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
"""Inject the ``_noop`` tool only when the provider would otherwise 400.
The check fires per model call, not at agent build time, because the
@ -116,7 +116,9 @@ class NoopInjectionMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, R
async def awrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
handler: Callable[
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
],
) -> Any:
if self._should_inject(request):
logger.debug("Injecting _noop tool for provider compatibility")

View file

@ -56,9 +56,7 @@ class OtelSpanMiddleware(AgentMiddleware):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[
[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]
],
handler: Callable[[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]],
) -> ModelResponse | AIMessage | Any:
if not ot.is_enabled():
return await handler(request)
@ -81,9 +79,7 @@ class OtelSpanMiddleware(AgentMiddleware):
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[
[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]
],
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
if not ot.is_enabled():
return await handler(request)
@ -187,7 +183,11 @@ def _annotate_model_response(span: Any, result: Any) -> None:
def _annotate_tool_result(span: Any, result: Any) -> None:
try:
if isinstance(result, ToolMessage):
content = result.content if isinstance(result.content, str) else repr(result.content)
content = (
result.content
if isinstance(result.content, str)
else repr(result.content)
)
span.set_attribute("tool.output.size", len(content))
status = getattr(result, "status", None)
if isinstance(status, str):

View file

@ -145,7 +145,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
try:
patterns = resolver(args or {})
except Exception:
logger.exception("Pattern resolver for %s raised; using bare name", tool_name)
logger.exception(
"Pattern resolver for %s raised; using bare name", tool_name
)
patterns = [tool_name]
if not patterns:
patterns = [tool_name]
@ -198,11 +200,14 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
# Tier 3b: permission.asked + interrupt.raised spans (no-op when
# OTel is disabled). Both fire here so dashboards can correlate
# "we asked X" with "interrupt was actually delivered".
with ot.permission_asked_span(
permission=tool_name,
pattern=patterns[0] if patterns else None,
extra={"permission.patterns": list(patterns)},
), ot.interrupt_span(interrupt_type="permission_ask"):
with (
ot.permission_asked_span(
permission=tool_name,
pattern=patterns[0] if patterns else None,
extra={"permission.patterns": list(patterns)},
),
ot.interrupt_span(interrupt_type="permission_ask"),
):
decision = interrupt(payload)
if isinstance(decision, dict):
return decision
@ -211,9 +216,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
return {"decision_type": decision}
return {"decision_type": "reject"}
def _persist_always(
self, tool_name: str, patterns: list[str]
) -> None:
def _persist_always(self, tool_name: str, patterns: list[str]) -> None:
"""Promote ``always`` reply into runtime allow rules.
Persistence to ``agent_permission_rules`` is done by the
@ -276,12 +279,16 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
any_change = False
for raw in last.tool_calls:
call = dict(raw) if isinstance(raw, dict) else {
"name": getattr(raw, "name", None),
"args": getattr(raw, "args", {}),
"id": getattr(raw, "id", None),
"type": "tool_call",
}
call = (
dict(raw)
if isinstance(raw, dict)
else {
"name": getattr(raw, "name", None),
"args": getattr(raw, "args", {}),
"id": getattr(raw, "id", None),
"type": "tool_call",
}
)
name = call.get("name") or ""
args = call.get("args") or {}
action, patterns, rules = self._evaluate(name, args)
@ -307,7 +314,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
feedback = decision.get("feedback")
if isinstance(feedback, str) and feedback.strip():
raise CorrectedError(feedback, tool=name)
raise RejectedError(tool=name, pattern=patterns[0] if patterns else None)
raise RejectedError(
tool=name, pattern=patterns[0] if patterns else None
)
else:
logger.warning(
"Unknown permission decision %r; treating as reject", kind

View file

@ -113,7 +113,9 @@ def _exponential_delay(
jitter: bool,
) -> float:
"""Compute an exponential-backoff delay with optional ±25% jitter."""
delay = initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay
delay = (
initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay
)
delay = min(delay, max_delay)
if jitter and delay > 0:
delay *= 1 + random.uniform(-0.25, 0.25)
@ -201,7 +203,9 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp
},
)
except Exception:
logger.debug("dispatch_custom_event failed; suppressed", exc_info=True)
logger.debug(
"dispatch_custom_event failed; suppressed", exc_info=True
)
if delay > 0:
time.sleep(delay)
# Unreachable
@ -210,7 +214,9 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp
async def awrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
handler: Callable[
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
],
) -> ModelResponse[ResponseT] | AIMessage:
for attempt in range(self.max_retries + 1):
try:

View file

@ -29,6 +29,7 @@ gives a clean failure mode if anything tries.
from __future__ import annotations
import contextlib
import logging
from collections.abc import Callable
from dataclasses import replace
@ -114,8 +115,10 @@ class BuiltinSkillsBackend(BackendProtocol):
infos: list[FileInfo] = []
# Build virtual paths anchored at "/" because CompositeBackend already
# stripped the route prefix before calling us.
target_virtual = "/" if target == self.root else (
"/" + str(target.relative_to(self.root)).replace("\\", "/")
target_virtual = (
"/"
if target == self.root
else ("/" + str(target.relative_to(self.root)).replace("\\", "/"))
)
for child in sorted(target.iterdir()):
child_virtual = (
@ -128,10 +131,8 @@ class BuiltinSkillsBackend(BackendProtocol):
"is_dir": child.is_dir(),
}
if child.is_file():
try:
with contextlib.suppress(OSError): # pragma: no cover - defensive
info["size"] = child.stat().st_size
except OSError: # pragma: no cover - defensive
pass
infos.append(info)
return infos
@ -163,7 +164,9 @@ class BuiltinSkillsBackend(BackendProtocol):
else:
content = target.read_bytes()
except PermissionError:
responses.append(FileDownloadResponse(path=p, error="permission_denied"))
responses.append(
FileDownloadResponse(path=p, error="permission_denied")
)
continue
except OSError as exc: # pragma: no cover - defensive
logger.warning("Builtin skill read failed %s: %s", target, exc)
@ -286,6 +289,7 @@ def build_skills_backend_factory(
builtin = BuiltinSkillsBackend(builtin_root)
if search_space_id is None:
def _factory_builtin_only(runtime: ToolRuntime) -> BackendProtocol:
# Default StateBackend is intentionally inert: any path outside the
# ``/skills/builtin/`` route resolves to an empty per-runtime state
@ -294,6 +298,7 @@ def build_skills_backend_factory(
default=StateBackend(runtime),
routes={SKILLS_BUILTIN_PREFIX: builtin},
)
return _factory_builtin_only
def _factory_with_space(runtime: ToolRuntime) -> BackendProtocol:

View file

@ -51,13 +51,15 @@ def _coerce_existing_tool_call(call: Any) -> dict[str, Any]:
}
class ToolCallNameRepairMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
class ToolCallNameRepairMiddleware(
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
"""Two-stage tool-name repair on the most recent ``AIMessage``.
Args:
registered_tool_names: Set of canonically-registered tool names.
``invalid`` should be in this set so the fallback dispatches.
fuzzy_match_threshold: Optional ``difflib`` ratio (01) for the
fuzzy_match_threshold: Optional ``difflib`` ratio (0-1) for the
fuzzy-match step that runs *between* lowercase and invalid.
Set to ``None`` to disable fuzzy matching (opencode parity).
"""
@ -77,9 +79,9 @@ class ToolCallNameRepairMiddleware(AgentMiddleware[AgentState[ResponseT], Contex
def _registered_for_runtime(self, runtime: Runtime[ContextT]) -> set[str]:
"""Allow runtime overrides to expand the set (e.g. dynamic MCP tools)."""
ctx_tools = getattr(runtime.context, "registered_tool_names", None)
if isinstance(ctx_tools, (set, frozenset)):
if isinstance(ctx_tools, set | frozenset):
return self._registered | set(ctx_tools)
if isinstance(ctx_tools, (list, tuple)):
if isinstance(ctx_tools, list | tuple):
return self._registered | set(ctx_tools)
return self._registered