mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-04 21:32:39 +02:00
chore: linting
This commit is contained in:
parent
b9a66cb417
commit
ca9bbee06d
41 changed files with 314 additions and 244 deletions
|
|
@ -101,9 +101,7 @@ class ActionLogMiddleware(AgentMiddleware):
|
|||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[
|
||||
[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]
|
||||
],
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
if not self._enabled():
|
||||
return await handler(request)
|
||||
|
|
|
|||
|
|
@ -177,8 +177,8 @@ class SurfSenseCompactionMiddleware(SummarizationMiddleware):
|
|||
messages_in=len(conversation_messages),
|
||||
extra={"compaction.cutoff_index": int(cutoff_index)},
|
||||
):
|
||||
messages_to_summarize, preserved_messages = (
|
||||
super()._partition_messages(conversation_messages, cutoff_index)
|
||||
messages_to_summarize, preserved_messages = super()._partition_messages(
|
||||
conversation_messages, cutoff_index
|
||||
)
|
||||
|
||||
protected: list[AnyMessage] = []
|
||||
|
|
|
|||
|
|
@ -58,8 +58,7 @@ DEFAULT_SPILL_PREFIX = "/tool_outputs"
|
|||
def _build_spill_placeholder(spill_path: str) -> str:
|
||||
"""Build the user-facing placeholder text shown to the model."""
|
||||
return (
|
||||
f"[cleared — full output at {spill_path}; "
|
||||
f"ask the explore subagent to read it]"
|
||||
f"[cleared — full output at {spill_path}; ask the explore subagent to read it]"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -131,7 +130,9 @@ class SpillToBackendEdit(ContextEdit):
|
|||
return
|
||||
|
||||
candidates = [
|
||||
(idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage)
|
||||
(idx, msg)
|
||||
for idx, msg in enumerate(messages)
|
||||
if isinstance(msg, ToolMessage)
|
||||
]
|
||||
if self.keep >= len(candidates):
|
||||
return
|
||||
|
|
|
|||
|
|
@ -137,16 +137,21 @@ class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respon
|
|||
|
||||
triggered_call: dict[str, Any] | None = None
|
||||
for call in message.tool_calls:
|
||||
name = call.get("name") if isinstance(call, dict) else getattr(call, "name", None)
|
||||
args = call.get("args") if isinstance(call, dict) else getattr(call, "args", {})
|
||||
name = (
|
||||
call.get("name")
|
||||
if isinstance(call, dict)
|
||||
else getattr(call, "name", None)
|
||||
)
|
||||
args = (
|
||||
call.get("args")
|
||||
if isinstance(call, dict)
|
||||
else getattr(call, "args", {})
|
||||
)
|
||||
if not isinstance(name, str):
|
||||
continue
|
||||
sig = _signature(name, args)
|
||||
window.append(sig)
|
||||
if (
|
||||
len(window) >= self._threshold
|
||||
and len(set(window)) == 1
|
||||
):
|
||||
if len(window) >= self._threshold and len(set(window)) == 1:
|
||||
triggered_call = {"name": name, "params": args or {}}
|
||||
break
|
||||
|
||||
|
|
@ -209,7 +214,9 @@ class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respon
|
|||
# tool call proceeds. The frontend's exact reply names may differ —
|
||||
# we tolerate any shape that contains a string with "reject"/"cancel".
|
||||
if isinstance(decision, dict):
|
||||
kind = str(decision.get("decision_type") or decision.get("type") or "").lower()
|
||||
kind = str(
|
||||
decision.get("decision_type") or decision.get("type") or ""
|
||||
).lower()
|
||||
if "reject" in kind or "cancel" in kind:
|
||||
return {"jump_to": "end"}
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -552,7 +552,7 @@ def _render_priority_message(priority: list[dict[str, Any]]) -> SystemMessage:
|
|||
for entry in priority:
|
||||
score = entry.get("score")
|
||||
mentioned = entry.get("mentioned")
|
||||
score_str = f"{score:.3f}" if isinstance(score, (int, float)) else "n/a"
|
||||
score_str = f"{score:.3f}" if isinstance(score, int | float) else "n/a"
|
||||
mark = " [USER-MENTIONED]" if mentioned else ""
|
||||
lines.append(f"- {entry.get('path', '')} (score={score_str}){mark}")
|
||||
body = "\n".join(lines)
|
||||
|
|
@ -593,7 +593,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
self.top_k = top_k
|
||||
self.mentioned_document_ids = mentioned_document_ids or []
|
||||
# Tier 4.2: build the kb-planner private Runnable ONCE here so we
|
||||
# don't pay the create_agent compile cost (50–200ms) on every turn.
|
||||
# don't pay the create_agent compile cost (50-200ms) on every turn.
|
||||
# Disabled by default behind ``enable_kb_planner_runnable``; when off
|
||||
# the planner falls back to the legacy ``self.llm.ainvoke`` path.
|
||||
self._planner: Runnable | None = None
|
||||
|
|
@ -617,10 +617,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
if self.llm is None:
|
||||
return None
|
||||
flags = get_flags()
|
||||
if (
|
||||
not flags.enable_kb_planner_runnable
|
||||
or flags.disable_new_agent_stack
|
||||
):
|
||||
if not flags.enable_kb_planner_runnable or flags.disable_new_agent_stack:
|
||||
return None
|
||||
|
||||
from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware
|
||||
|
|
@ -920,7 +917,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
chunk_ids = doc.get("matched_chunk_ids") or []
|
||||
if chunk_ids:
|
||||
matched_chunk_ids[doc_id] = [
|
||||
int(cid) for cid in chunk_ids if isinstance(cid, (int, str))
|
||||
int(cid) for cid in chunk_ids if isinstance(cid, int | str)
|
||||
]
|
||||
return priority, matched_chunk_ids
|
||||
|
||||
|
|
|
|||
|
|
@ -35,9 +35,7 @@ from langchain_core.tools import tool
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
NOOP_TOOL_NAME = "_noop"
|
||||
NOOP_TOOL_DESCRIPTION = (
|
||||
"Do not call this tool. It exists only for API compatibility."
|
||||
)
|
||||
NOOP_TOOL_DESCRIPTION = "Do not call this tool. It exists only for API compatibility."
|
||||
|
||||
|
||||
@tool(name_or_callable=NOOP_TOOL_NAME, description=NOOP_TOOL_DESCRIPTION)
|
||||
|
|
@ -78,7 +76,9 @@ def _last_ai_has_tool_calls(messages: list[Any]) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
class NoopInjectionMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
class NoopInjectionMiddleware(
|
||||
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
|
||||
):
|
||||
"""Inject the ``_noop`` tool only when the provider would otherwise 400.
|
||||
|
||||
The check fires per model call, not at agent build time, because the
|
||||
|
|
@ -116,7 +116,9 @@ class NoopInjectionMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, R
|
|||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> Any:
|
||||
if self._should_inject(request):
|
||||
logger.debug("Injecting _noop tool for provider compatibility")
|
||||
|
|
|
|||
|
|
@ -56,9 +56,7 @@ class OtelSpanMiddleware(AgentMiddleware):
|
|||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[
|
||||
[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]
|
||||
],
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]],
|
||||
) -> ModelResponse | AIMessage | Any:
|
||||
if not ot.is_enabled():
|
||||
return await handler(request)
|
||||
|
|
@ -81,9 +79,7 @@ class OtelSpanMiddleware(AgentMiddleware):
|
|||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[
|
||||
[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]
|
||||
],
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
if not ot.is_enabled():
|
||||
return await handler(request)
|
||||
|
|
@ -187,7 +183,11 @@ def _annotate_model_response(span: Any, result: Any) -> None:
|
|||
def _annotate_tool_result(span: Any, result: Any) -> None:
|
||||
try:
|
||||
if isinstance(result, ToolMessage):
|
||||
content = result.content if isinstance(result.content, str) else repr(result.content)
|
||||
content = (
|
||||
result.content
|
||||
if isinstance(result.content, str)
|
||||
else repr(result.content)
|
||||
)
|
||||
span.set_attribute("tool.output.size", len(content))
|
||||
status = getattr(result, "status", None)
|
||||
if isinstance(status, str):
|
||||
|
|
|
|||
|
|
@ -145,7 +145,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
try:
|
||||
patterns = resolver(args or {})
|
||||
except Exception:
|
||||
logger.exception("Pattern resolver for %s raised; using bare name", tool_name)
|
||||
logger.exception(
|
||||
"Pattern resolver for %s raised; using bare name", tool_name
|
||||
)
|
||||
patterns = [tool_name]
|
||||
if not patterns:
|
||||
patterns = [tool_name]
|
||||
|
|
@ -198,11 +200,14 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
# Tier 3b: permission.asked + interrupt.raised spans (no-op when
|
||||
# OTel is disabled). Both fire here so dashboards can correlate
|
||||
# "we asked X" with "interrupt was actually delivered".
|
||||
with ot.permission_asked_span(
|
||||
permission=tool_name,
|
||||
pattern=patterns[0] if patterns else None,
|
||||
extra={"permission.patterns": list(patterns)},
|
||||
), ot.interrupt_span(interrupt_type="permission_ask"):
|
||||
with (
|
||||
ot.permission_asked_span(
|
||||
permission=tool_name,
|
||||
pattern=patterns[0] if patterns else None,
|
||||
extra={"permission.patterns": list(patterns)},
|
||||
),
|
||||
ot.interrupt_span(interrupt_type="permission_ask"),
|
||||
):
|
||||
decision = interrupt(payload)
|
||||
if isinstance(decision, dict):
|
||||
return decision
|
||||
|
|
@ -211,9 +216,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
return {"decision_type": decision}
|
||||
return {"decision_type": "reject"}
|
||||
|
||||
def _persist_always(
|
||||
self, tool_name: str, patterns: list[str]
|
||||
) -> None:
|
||||
def _persist_always(self, tool_name: str, patterns: list[str]) -> None:
|
||||
"""Promote ``always`` reply into runtime allow rules.
|
||||
|
||||
Persistence to ``agent_permission_rules`` is done by the
|
||||
|
|
@ -276,12 +279,16 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
any_change = False
|
||||
|
||||
for raw in last.tool_calls:
|
||||
call = dict(raw) if isinstance(raw, dict) else {
|
||||
"name": getattr(raw, "name", None),
|
||||
"args": getattr(raw, "args", {}),
|
||||
"id": getattr(raw, "id", None),
|
||||
"type": "tool_call",
|
||||
}
|
||||
call = (
|
||||
dict(raw)
|
||||
if isinstance(raw, dict)
|
||||
else {
|
||||
"name": getattr(raw, "name", None),
|
||||
"args": getattr(raw, "args", {}),
|
||||
"id": getattr(raw, "id", None),
|
||||
"type": "tool_call",
|
||||
}
|
||||
)
|
||||
name = call.get("name") or ""
|
||||
args = call.get("args") or {}
|
||||
action, patterns, rules = self._evaluate(name, args)
|
||||
|
|
@ -307,7 +314,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
feedback = decision.get("feedback")
|
||||
if isinstance(feedback, str) and feedback.strip():
|
||||
raise CorrectedError(feedback, tool=name)
|
||||
raise RejectedError(tool=name, pattern=patterns[0] if patterns else None)
|
||||
raise RejectedError(
|
||||
tool=name, pattern=patterns[0] if patterns else None
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Unknown permission decision %r; treating as reject", kind
|
||||
|
|
|
|||
|
|
@ -113,7 +113,9 @@ def _exponential_delay(
|
|||
jitter: bool,
|
||||
) -> float:
|
||||
"""Compute an exponential-backoff delay with optional ±25% jitter."""
|
||||
delay = initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay
|
||||
delay = (
|
||||
initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay
|
||||
)
|
||||
delay = min(delay, max_delay)
|
||||
if jitter and delay > 0:
|
||||
delay *= 1 + random.uniform(-0.25, 0.25)
|
||||
|
|
@ -201,7 +203,9 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp
|
|||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("dispatch_custom_event failed; suppressed", exc_info=True)
|
||||
logger.debug(
|
||||
"dispatch_custom_event failed; suppressed", exc_info=True
|
||||
)
|
||||
if delay > 0:
|
||||
time.sleep(delay)
|
||||
# Unreachable
|
||||
|
|
@ -210,7 +214,9 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp
|
|||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT] | AIMessage:
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ gives a clean failure mode if anything tries.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import replace
|
||||
|
|
@ -114,8 +115,10 @@ class BuiltinSkillsBackend(BackendProtocol):
|
|||
infos: list[FileInfo] = []
|
||||
# Build virtual paths anchored at "/" because CompositeBackend already
|
||||
# stripped the route prefix before calling us.
|
||||
target_virtual = "/" if target == self.root else (
|
||||
"/" + str(target.relative_to(self.root)).replace("\\", "/")
|
||||
target_virtual = (
|
||||
"/"
|
||||
if target == self.root
|
||||
else ("/" + str(target.relative_to(self.root)).replace("\\", "/"))
|
||||
)
|
||||
for child in sorted(target.iterdir()):
|
||||
child_virtual = (
|
||||
|
|
@ -128,10 +131,8 @@ class BuiltinSkillsBackend(BackendProtocol):
|
|||
"is_dir": child.is_dir(),
|
||||
}
|
||||
if child.is_file():
|
||||
try:
|
||||
with contextlib.suppress(OSError): # pragma: no cover - defensive
|
||||
info["size"] = child.stat().st_size
|
||||
except OSError: # pragma: no cover - defensive
|
||||
pass
|
||||
infos.append(info)
|
||||
return infos
|
||||
|
||||
|
|
@ -163,7 +164,9 @@ class BuiltinSkillsBackend(BackendProtocol):
|
|||
else:
|
||||
content = target.read_bytes()
|
||||
except PermissionError:
|
||||
responses.append(FileDownloadResponse(path=p, error="permission_denied"))
|
||||
responses.append(
|
||||
FileDownloadResponse(path=p, error="permission_denied")
|
||||
)
|
||||
continue
|
||||
except OSError as exc: # pragma: no cover - defensive
|
||||
logger.warning("Builtin skill read failed %s: %s", target, exc)
|
||||
|
|
@ -286,6 +289,7 @@ def build_skills_backend_factory(
|
|||
builtin = BuiltinSkillsBackend(builtin_root)
|
||||
|
||||
if search_space_id is None:
|
||||
|
||||
def _factory_builtin_only(runtime: ToolRuntime) -> BackendProtocol:
|
||||
# Default StateBackend is intentionally inert: any path outside the
|
||||
# ``/skills/builtin/`` route resolves to an empty per-runtime state
|
||||
|
|
@ -294,6 +298,7 @@ def build_skills_backend_factory(
|
|||
default=StateBackend(runtime),
|
||||
routes={SKILLS_BUILTIN_PREFIX: builtin},
|
||||
)
|
||||
|
||||
return _factory_builtin_only
|
||||
|
||||
def _factory_with_space(runtime: ToolRuntime) -> BackendProtocol:
|
||||
|
|
|
|||
|
|
@ -51,13 +51,15 @@ def _coerce_existing_tool_call(call: Any) -> dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
class ToolCallNameRepairMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
class ToolCallNameRepairMiddleware(
|
||||
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
|
||||
):
|
||||
"""Two-stage tool-name repair on the most recent ``AIMessage``.
|
||||
|
||||
Args:
|
||||
registered_tool_names: Set of canonically-registered tool names.
|
||||
``invalid`` should be in this set so the fallback dispatches.
|
||||
fuzzy_match_threshold: Optional ``difflib`` ratio (0–1) for the
|
||||
fuzzy_match_threshold: Optional ``difflib`` ratio (0-1) for the
|
||||
fuzzy-match step that runs *between* lowercase and invalid.
|
||||
Set to ``None`` to disable fuzzy matching (opencode parity).
|
||||
"""
|
||||
|
|
@ -77,9 +79,9 @@ class ToolCallNameRepairMiddleware(AgentMiddleware[AgentState[ResponseT], Contex
|
|||
def _registered_for_runtime(self, runtime: Runtime[ContextT]) -> set[str]:
|
||||
"""Allow runtime overrides to expand the set (e.g. dynamic MCP tools)."""
|
||||
ctx_tools = getattr(runtime.context, "registered_tool_names", None)
|
||||
if isinstance(ctx_tools, (set, frozenset)):
|
||||
if isinstance(ctx_tools, set | frozenset):
|
||||
return self._registered | set(ctx_tools)
|
||||
if isinstance(ctx_tools, (list, tuple)):
|
||||
if isinstance(ctx_tools, list | tuple):
|
||||
return self._registered | set(ctx_tools)
|
||||
return self._registered
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue