chore: linting

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-04-28 21:37:51 -07:00
parent b9a66cb417
commit ca9bbee06d
41 changed files with 314 additions and 244 deletions

View file

@ -88,7 +88,5 @@ def upgrade() -> None:
def downgrade() -> None:
op.drop_index(
"ix_agent_action_log_thread_created", table_name="agent_action_log"
)
op.drop_index("ix_agent_action_log_thread_created", table_name="agent_action_log")
op.drop_table("agent_action_log")

View file

@ -51,9 +51,7 @@ def upgrade() -> None:
# implicit-unique-index variant SQLAlchemy may emit need draining.
constraints = _existing_constraint_names(bind, "documents")
if "uq_documents_content_hash" in constraints:
op.drop_constraint(
"uq_documents_content_hash", "documents", type_="unique"
)
op.drop_constraint("uq_documents_content_hash", "documents", type_="unique")
indexes = _existing_index_names(bind, "documents")
# Some Postgres versions surface the unique constraint via a unique

View file

@ -416,10 +416,10 @@ async def create_surfsense_deep_agent(
# cheap to build. ``SubAgentMiddleware.__init__`` calls ``create_agent``
# synchronously to compile the general-purpose subagent's full state graph
# (every tool + every middleware → pydantic schemas + langgraph compile).
# On gpt-5.x agents that's roughly 1.52s of pure CPU work. If we run it
# On gpt-5.x agents that's roughly 1.5-2s of pure CPU work. If we run it
# directly here it blocks the asyncio event loop for the whole streaming
# task (and any other coroutine sharing this loop), which is why
# "agent creation" wall-clock time used to stretch to ~34s. Move the
# "agent creation" wall-clock time used to stretch to ~3-4s. Move the
# entire middleware build + main-graph compile into a single
# ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the
# event loop stays responsive.
@ -587,10 +587,7 @@ def _build_compiled_agent_blocking(
# by name. Off by default until the flag flips so existing deployments
# don't see new agent types in the task tool description.
specialized_subagents: list[SubAgent] = []
if (
flags.enable_specialized_subagents
and not flags.disable_new_agent_stack
):
if flags.enable_specialized_subagents and not flags.disable_new_agent_stack:
try:
# Specialized subagents share the parent's filesystem +
# todo view so their system prompts (which promise
@ -696,7 +693,9 @@ def _build_compiled_agent_blocking(
else None
)
tool_call_limit_mw = (
ToolCallLimitMiddleware(thread_limit=300, run_limit=80, exit_behavior="continue")
ToolCallLimitMiddleware(
thread_limit=300, run_limit=80, exit_behavior="continue"
)
if flags.enable_tool_call_limit and not flags.disable_new_agent_stack
else None
)
@ -879,7 +878,11 @@ def _build_compiled_agent_blocking(
max_tools=12,
always_include=[
name
for name in ("update_memory", "get_connected_accounts", "scrape_webpage")
for name in (
"update_memory",
"get_connected_accounts",
"scrape_webpage",
)
if name in {t.name for t in tools}
],
)

View file

@ -65,7 +65,9 @@ class AgentFeatureFlags:
enable_model_call_limit: bool = False
enable_tool_call_limit: bool = False
enable_tool_call_repair: bool = False
enable_doom_loop: bool = False # Default OFF until UI handles permission='doom_loop'
enable_doom_loop: bool = (
False # Default OFF until UI handles permission='doom_loop'
)
# Tier 2 — Safety
enable_permission: bool = False # Default OFF for first deploy
@ -79,7 +81,9 @@ class AgentFeatureFlags:
# Tier 5 — Snapshot / revert
enable_action_log: bool = False
enable_revert_route: bool = False # Backend ships before UI; route returns 503 until this flips
enable_revert_route: bool = (
False # Backend ships before UI; route returns 503 until this flips
)
# Tier 6 — Plugins
enable_plugin_loader: bool = False
@ -109,14 +113,20 @@ class AgentFeatureFlags:
enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False),
enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False),
enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False),
enable_model_call_limit=_env_bool("SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False),
enable_model_call_limit=_env_bool(
"SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False
),
enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False),
enable_tool_call_repair=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False),
enable_tool_call_repair=_env_bool(
"SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False
),
enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False),
# Tier 2
enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False),
enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False),
enable_llm_tool_selector=_env_bool("SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False),
enable_llm_tool_selector=_env_bool(
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False
),
# Tier 4
enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False),
enable_specialized_subagents=_env_bool(

View file

@ -101,9 +101,7 @@ class ActionLogMiddleware(AgentMiddleware):
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[
[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]
],
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
if not self._enabled():
return await handler(request)

View file

@ -177,8 +177,8 @@ class SurfSenseCompactionMiddleware(SummarizationMiddleware):
messages_in=len(conversation_messages),
extra={"compaction.cutoff_index": int(cutoff_index)},
):
messages_to_summarize, preserved_messages = (
super()._partition_messages(conversation_messages, cutoff_index)
messages_to_summarize, preserved_messages = super()._partition_messages(
conversation_messages, cutoff_index
)
protected: list[AnyMessage] = []

View file

@ -58,8 +58,7 @@ DEFAULT_SPILL_PREFIX = "/tool_outputs"
def _build_spill_placeholder(spill_path: str) -> str:
"""Build the user-facing placeholder text shown to the model."""
return (
f"[cleared — full output at {spill_path}; "
f"ask the explore subagent to read it]"
f"[cleared — full output at {spill_path}; ask the explore subagent to read it]"
)
@ -131,7 +130,9 @@ class SpillToBackendEdit(ContextEdit):
return
candidates = [
(idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage)
(idx, msg)
for idx, msg in enumerate(messages)
if isinstance(msg, ToolMessage)
]
if self.keep >= len(candidates):
return

View file

@ -137,16 +137,21 @@ class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respon
triggered_call: dict[str, Any] | None = None
for call in message.tool_calls:
name = call.get("name") if isinstance(call, dict) else getattr(call, "name", None)
args = call.get("args") if isinstance(call, dict) else getattr(call, "args", {})
name = (
call.get("name")
if isinstance(call, dict)
else getattr(call, "name", None)
)
args = (
call.get("args")
if isinstance(call, dict)
else getattr(call, "args", {})
)
if not isinstance(name, str):
continue
sig = _signature(name, args)
window.append(sig)
if (
len(window) >= self._threshold
and len(set(window)) == 1
):
if len(window) >= self._threshold and len(set(window)) == 1:
triggered_call = {"name": name, "params": args or {}}
break
@ -209,7 +214,9 @@ class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respon
# tool call proceeds. The frontend's exact reply names may differ —
# we tolerate any shape that contains a string with "reject"/"cancel".
if isinstance(decision, dict):
kind = str(decision.get("decision_type") or decision.get("type") or "").lower()
kind = str(
decision.get("decision_type") or decision.get("type") or ""
).lower()
if "reject" in kind or "cancel" in kind:
return {"jump_to": "end"}
return None

View file

@ -552,7 +552,7 @@ def _render_priority_message(priority: list[dict[str, Any]]) -> SystemMessage:
for entry in priority:
score = entry.get("score")
mentioned = entry.get("mentioned")
score_str = f"{score:.3f}" if isinstance(score, (int, float)) else "n/a"
score_str = f"{score:.3f}" if isinstance(score, int | float) else "n/a"
mark = " [USER-MENTIONED]" if mentioned else ""
lines.append(f"- {entry.get('path', '')} (score={score_str}){mark}")
body = "\n".join(lines)
@ -593,7 +593,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
self.top_k = top_k
self.mentioned_document_ids = mentioned_document_ids or []
# Tier 4.2: build the kb-planner private Runnable ONCE here so we
# don't pay the create_agent compile cost (50200ms) on every turn.
# don't pay the create_agent compile cost (50-200ms) on every turn.
# Disabled by default behind ``enable_kb_planner_runnable``; when off
# the planner falls back to the legacy ``self.llm.ainvoke`` path.
self._planner: Runnable | None = None
@ -617,10 +617,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
if self.llm is None:
return None
flags = get_flags()
if (
not flags.enable_kb_planner_runnable
or flags.disable_new_agent_stack
):
if not flags.enable_kb_planner_runnable or flags.disable_new_agent_stack:
return None
from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware
@ -920,7 +917,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
chunk_ids = doc.get("matched_chunk_ids") or []
if chunk_ids:
matched_chunk_ids[doc_id] = [
int(cid) for cid in chunk_ids if isinstance(cid, (int, str))
int(cid) for cid in chunk_ids if isinstance(cid, int | str)
]
return priority, matched_chunk_ids

View file

@ -35,9 +35,7 @@ from langchain_core.tools import tool
logger = logging.getLogger(__name__)
NOOP_TOOL_NAME = "_noop"
NOOP_TOOL_DESCRIPTION = (
"Do not call this tool. It exists only for API compatibility."
)
NOOP_TOOL_DESCRIPTION = "Do not call this tool. It exists only for API compatibility."
@tool(name_or_callable=NOOP_TOOL_NAME, description=NOOP_TOOL_DESCRIPTION)
@ -78,7 +76,9 @@ def _last_ai_has_tool_calls(messages: list[Any]) -> bool:
return False
class NoopInjectionMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
class NoopInjectionMiddleware(
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
"""Inject the ``_noop`` tool only when the provider would otherwise 400.
The check fires per model call, not at agent build time, because the
@ -116,7 +116,9 @@ class NoopInjectionMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, R
async def awrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
handler: Callable[
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
],
) -> Any:
if self._should_inject(request):
logger.debug("Injecting _noop tool for provider compatibility")

View file

@ -56,9 +56,7 @@ class OtelSpanMiddleware(AgentMiddleware):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[
[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]
],
handler: Callable[[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]],
) -> ModelResponse | AIMessage | Any:
if not ot.is_enabled():
return await handler(request)
@ -81,9 +79,7 @@ class OtelSpanMiddleware(AgentMiddleware):
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[
[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]
],
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
if not ot.is_enabled():
return await handler(request)
@ -187,7 +183,11 @@ def _annotate_model_response(span: Any, result: Any) -> None:
def _annotate_tool_result(span: Any, result: Any) -> None:
try:
if isinstance(result, ToolMessage):
content = result.content if isinstance(result.content, str) else repr(result.content)
content = (
result.content
if isinstance(result.content, str)
else repr(result.content)
)
span.set_attribute("tool.output.size", len(content))
status = getattr(result, "status", None)
if isinstance(status, str):

View file

@ -145,7 +145,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
try:
patterns = resolver(args or {})
except Exception:
logger.exception("Pattern resolver for %s raised; using bare name", tool_name)
logger.exception(
"Pattern resolver for %s raised; using bare name", tool_name
)
patterns = [tool_name]
if not patterns:
patterns = [tool_name]
@ -198,11 +200,14 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
# Tier 3b: permission.asked + interrupt.raised spans (no-op when
# OTel is disabled). Both fire here so dashboards can correlate
# "we asked X" with "interrupt was actually delivered".
with ot.permission_asked_span(
permission=tool_name,
pattern=patterns[0] if patterns else None,
extra={"permission.patterns": list(patterns)},
), ot.interrupt_span(interrupt_type="permission_ask"):
with (
ot.permission_asked_span(
permission=tool_name,
pattern=patterns[0] if patterns else None,
extra={"permission.patterns": list(patterns)},
),
ot.interrupt_span(interrupt_type="permission_ask"),
):
decision = interrupt(payload)
if isinstance(decision, dict):
return decision
@ -211,9 +216,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
return {"decision_type": decision}
return {"decision_type": "reject"}
def _persist_always(
self, tool_name: str, patterns: list[str]
) -> None:
def _persist_always(self, tool_name: str, patterns: list[str]) -> None:
"""Promote ``always`` reply into runtime allow rules.
Persistence to ``agent_permission_rules`` is done by the
@ -276,12 +279,16 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
any_change = False
for raw in last.tool_calls:
call = dict(raw) if isinstance(raw, dict) else {
"name": getattr(raw, "name", None),
"args": getattr(raw, "args", {}),
"id": getattr(raw, "id", None),
"type": "tool_call",
}
call = (
dict(raw)
if isinstance(raw, dict)
else {
"name": getattr(raw, "name", None),
"args": getattr(raw, "args", {}),
"id": getattr(raw, "id", None),
"type": "tool_call",
}
)
name = call.get("name") or ""
args = call.get("args") or {}
action, patterns, rules = self._evaluate(name, args)
@ -307,7 +314,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
feedback = decision.get("feedback")
if isinstance(feedback, str) and feedback.strip():
raise CorrectedError(feedback, tool=name)
raise RejectedError(tool=name, pattern=patterns[0] if patterns else None)
raise RejectedError(
tool=name, pattern=patterns[0] if patterns else None
)
else:
logger.warning(
"Unknown permission decision %r; treating as reject", kind

View file

@ -113,7 +113,9 @@ def _exponential_delay(
jitter: bool,
) -> float:
"""Compute an exponential-backoff delay with optional ±25% jitter."""
delay = initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay
delay = (
initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay
)
delay = min(delay, max_delay)
if jitter and delay > 0:
delay *= 1 + random.uniform(-0.25, 0.25)
@ -201,7 +203,9 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp
},
)
except Exception:
logger.debug("dispatch_custom_event failed; suppressed", exc_info=True)
logger.debug(
"dispatch_custom_event failed; suppressed", exc_info=True
)
if delay > 0:
time.sleep(delay)
# Unreachable
@ -210,7 +214,9 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp
async def awrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
handler: Callable[
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
],
) -> ModelResponse[ResponseT] | AIMessage:
for attempt in range(self.max_retries + 1):
try:

View file

@ -29,6 +29,7 @@ gives a clean failure mode if anything tries.
from __future__ import annotations
import contextlib
import logging
from collections.abc import Callable
from dataclasses import replace
@ -114,8 +115,10 @@ class BuiltinSkillsBackend(BackendProtocol):
infos: list[FileInfo] = []
# Build virtual paths anchored at "/" because CompositeBackend already
# stripped the route prefix before calling us.
target_virtual = "/" if target == self.root else (
"/" + str(target.relative_to(self.root)).replace("\\", "/")
target_virtual = (
"/"
if target == self.root
else ("/" + str(target.relative_to(self.root)).replace("\\", "/"))
)
for child in sorted(target.iterdir()):
child_virtual = (
@ -128,10 +131,8 @@ class BuiltinSkillsBackend(BackendProtocol):
"is_dir": child.is_dir(),
}
if child.is_file():
try:
with contextlib.suppress(OSError): # pragma: no cover - defensive
info["size"] = child.stat().st_size
except OSError: # pragma: no cover - defensive
pass
infos.append(info)
return infos
@ -163,7 +164,9 @@ class BuiltinSkillsBackend(BackendProtocol):
else:
content = target.read_bytes()
except PermissionError:
responses.append(FileDownloadResponse(path=p, error="permission_denied"))
responses.append(
FileDownloadResponse(path=p, error="permission_denied")
)
continue
except OSError as exc: # pragma: no cover - defensive
logger.warning("Builtin skill read failed %s: %s", target, exc)
@ -286,6 +289,7 @@ def build_skills_backend_factory(
builtin = BuiltinSkillsBackend(builtin_root)
if search_space_id is None:
def _factory_builtin_only(runtime: ToolRuntime) -> BackendProtocol:
# Default StateBackend is intentionally inert: any path outside the
# ``/skills/builtin/`` route resolves to an empty per-runtime state
@ -294,6 +298,7 @@ def build_skills_backend_factory(
default=StateBackend(runtime),
routes={SKILLS_BUILTIN_PREFIX: builtin},
)
return _factory_builtin_only
def _factory_with_space(runtime: ToolRuntime) -> BackendProtocol:

View file

@ -51,13 +51,15 @@ def _coerce_existing_tool_call(call: Any) -> dict[str, Any]:
}
class ToolCallNameRepairMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
class ToolCallNameRepairMiddleware(
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
"""Two-stage tool-name repair on the most recent ``AIMessage``.
Args:
registered_tool_names: Set of canonically-registered tool names.
``invalid`` should be in this set so the fallback dispatches.
fuzzy_match_threshold: Optional ``difflib`` ratio (01) for the
fuzzy_match_threshold: Optional ``difflib`` ratio (0-1) for the
fuzzy-match step that runs *between* lowercase and invalid.
Set to ``None`` to disable fuzzy matching (opencode parity).
"""
@ -77,9 +79,9 @@ class ToolCallNameRepairMiddleware(AgentMiddleware[AgentState[ResponseT], Contex
def _registered_for_runtime(self, runtime: Runtime[ContextT]) -> set[str]:
"""Allow runtime overrides to expand the set (e.g. dynamic MCP tools)."""
ctx_tools = getattr(runtime.context, "registered_tool_names", None)
if isinstance(ctx_tools, (set, frozenset)):
if isinstance(ctx_tools, set | frozenset):
return self._registered | set(ctx_tools)
if isinstance(ctx_tools, (list, tuple)):
if isinstance(ctx_tools, list | tuple):
return self._registered | set(ctx_tools)
return self._registered

View file

@ -52,25 +52,26 @@ class _YearSubstituterMiddleware(AgentMiddleware):
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[
[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]
],
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
result = await handler(request)
try:
from langchain_core.messages import ToolMessage
if isinstance(result, ToolMessage) and isinstance(result.content, str):
if "{{year}}" in result.content:
new_text = result.content.replace("{{year}}", self._year)
result = ToolMessage(
content=new_text,
tool_call_id=result.tool_call_id,
id=result.id,
name=result.name,
status=result.status,
artifact=result.artifact,
)
if (
isinstance(result, ToolMessage)
and isinstance(result.content, str)
and "{{year}}" in result.content
):
new_text = result.content.replace("{{year}}", self._year)
result = ToolMessage(
content=new_text,
tool_call_id=result.tool_call_id,
id=result.id,
name=result.name,
status=result.status,
artifact=result.artifact,
)
except Exception: # pragma: no cover - defensive
logger.exception("year_substituter plugin failed; passing original result")
return result

View file

@ -62,7 +62,9 @@ ProviderVariant = str
# More specific patterns must come first (e.g. ``codex`` before
# ``openai_reasoning`` because codex model ids contain ``gpt``).
_OPENAI_CODEX_RE = re.compile(r"\b(gpt-codex|codex-mini|gpt-[\d.]+-codex)\b", re.IGNORECASE)
_OPENAI_CODEX_RE = re.compile(
r"\b(gpt-codex|codex-mini|gpt-[\d.]+-codex)\b", re.IGNORECASE
)
_OPENAI_REASONING_RE = re.compile(r"\b(gpt-5|o\d|o-)", re.IGNORECASE)
_OPENAI_CLASSIC_RE = re.compile(r"\bgpt-4", re.IGNORECASE)
_ANTHROPIC_RE = re.compile(r"\bclaude\b", re.IGNORECASE)
@ -257,9 +259,7 @@ def _build_tools_section(
)
if known_disabled:
disabled_list = ", ".join(
_format_tool_label(n)
for n in ALL_TOOL_NAMES_ORDERED
if n in known_disabled
_format_tool_label(n) for n in ALL_TOOL_NAMES_ORDERED if n in known_disabled
)
parts.append(
"\n"

View file

@ -279,9 +279,7 @@ def build_explore_subagent(
selected_tools = _filter_tools(tools, EXPLORE_READ_TOOLS)
deny_rules = _read_only_deny_rules()
permission_mw = _build_permission_middleware(
deny_rules, origin="subagent_explore"
)
permission_mw = _build_permission_middleware(deny_rules, origin="subagent_explore")
spec: dict = {
"name": "explore",

View file

@ -111,6 +111,8 @@ from .update_memory import create_update_memory_tool, create_update_team_memory_
from .video_presentation import create_generate_video_presentation_tool
from .web_search import create_web_search_tool
logger = logging.getLogger(__name__)
# =============================================================================
# Tool Definition
# =============================================================================

View file

@ -22,6 +22,7 @@ Goals
from __future__ import annotations
import contextlib
import logging
import os
from collections.abc import Iterator
@ -154,18 +155,14 @@ def span(
with tracer.start_as_current_span(name) as sp:
if attributes:
try:
with contextlib.suppress(Exception): # pragma: no cover — defensive
sp.set_attributes(attributes)
except Exception: # pragma: no cover — defensive
pass
try:
yield sp
except BaseException as exc:
try:
with contextlib.suppress(Exception): # pragma: no cover — defensive
sp.record_exception(exc)
sp.set_status(_OtStatus(_OtStatusCode.ERROR, str(exc)))
except Exception: # pragma: no cover — defensive
pass
raise

View file

@ -59,7 +59,7 @@ class AgentFeatureFlagsRead(BaseModel):
enable_otel: bool
@classmethod
def from_flags(cls, flags: AgentFeatureFlags) -> "AgentFeatureFlagsRead":
def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead:
# asdict() avoids missing-field bugs when AgentFeatureFlags grows.
return cls(**asdict(flags))

View file

@ -210,7 +210,7 @@ async def create_rule(
session.add(row)
try:
await session.commit()
except IntegrityError:
except IntegrityError as err:
await session.rollback()
raise HTTPException(
status_code=409,
@ -218,7 +218,7 @@ async def create_rule(
"An identical rule already exists for this scope. Update the "
"existing rule instead."
),
)
) from err
await session.refresh(row)
return _to_read(row)
@ -248,12 +248,12 @@ async def update_rule(
try:
await session.commit()
except IntegrityError:
except IntegrityError as err:
await session.rollback()
raise HTTPException(
status_code=409,
detail="Update would create a duplicate rule for this scope.",
)
) from err
await session.refresh(row)
return _to_read(row)

View file

@ -97,10 +97,12 @@ async def revert_agent_action(
action=action,
requester_user_id=str(user.id) if user is not None else None,
)
except Exception:
except Exception as err:
logger.exception("Revert dispatch raised for action_id=%s", action_id)
await session.rollback()
raise HTTPException(status_code=500, detail="Internal error during revert.")
raise HTTPException(
status_code=500, detail="Internal error during revert."
) from err
if outcome.status == "ok":
await session.commit()

View file

@ -1242,7 +1242,9 @@ async def handle_new_chat(
await session.close()
image_urls = (
[p.as_data_url() for p in request.user_images] if request.user_images else None
[p.as_data_url() for p in request.user_images]
if request.user_images
else None
)
return StreamingResponse(

View file

@ -79,9 +79,7 @@ async def load_action(
return result.scalars().first()
async def load_thread(
session: AsyncSession, *, thread_id: int
) -> NewChatThread | None:
async def load_thread(session: AsyncSession, *, thread_id: int) -> NewChatThread | None:
stmt = select(NewChatThread).where(NewChatThread.id == thread_id)
result = await session.execute(stmt)
return result.scalars().first()

View file

@ -7,7 +7,9 @@ import binascii
from typing import Any
def build_human_message_content(final_query: str, image_data_urls: list[str]) -> str | list[dict[str, Any]]:
def build_human_message_content(
final_query: str, image_data_urls: list[str]
) -> str | list[dict[str, Any]]:
if not image_data_urls:
return final_query
parts: list[dict[str, Any]] = [{"type": "text", "text": final_query}]

View file

@ -90,9 +90,7 @@ class TestCompose:
assert "<citation_instructions>" in prompt
assert "[citation:chunk_id]" in prompt
def test_team_visibility_uses_team_variants(
self, fixed_today: datetime
) -> None:
def test_team_visibility_uses_team_variants(self, fixed_today: datetime) -> None:
prompt = compose_system_prompt(
today=fixed_today,
thread_visibility=ChatVisibility.SEARCH_SPACE,
@ -145,9 +143,7 @@ class TestCompose:
assert "Generate Image" in prompt
assert "Generate Podcast" in prompt
def test_mcp_routing_block_emits_when_provided(
self, fixed_today: datetime
) -> None:
def test_mcp_routing_block_emits_when_provided(self, fixed_today: datetime) -> None:
prompt = compose_system_prompt(
today=fixed_today,
mcp_connector_tools={"My GitLab": ["gitlab_search", "gitlab_create_mr"]},
@ -162,9 +158,7 @@ class TestCompose:
prompt = compose_system_prompt(today=fixed_today, mcp_connector_tools={})
assert "<mcp_tool_routing>" not in prompt
def test_provider_block_renders_when_anthropic(
self, fixed_today: datetime
) -> None:
def test_provider_block_renders_when_anthropic(self, fixed_today: datetime) -> None:
prompt = compose_system_prompt(
today=fixed_today, model_name="anthropic:claude-3-5-sonnet"
)
@ -267,7 +261,10 @@ class TestStableOrderingForCacheStability:
)
b = compose_system_prompt(
today=fixed_today,
enabled_tool_names={"scrape_webpage", "web_search"}, # set order shouldn't matter
enabled_tool_names={
"scrape_webpage",
"web_search",
}, # set order shouldn't matter
mcp_connector_tools={"X": ["x_a", "x_b"]},
)
assert a == b

View file

@ -83,7 +83,11 @@ class TestActionLogMiddlewareDisabled:
async def test_no_op_when_flag_off(self, patch_get_flags) -> None:
mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
request = _FakeRequest(
tool_call={"name": "make_widget", "args": {"color": "red", "size": 1}, "id": "tc1"}
tool_call={
"name": "make_widget",
"args": {"color": "red", "size": 1},
"id": "tc1",
}
)
handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
with patch_get_flags(_disabled_flags()):
@ -117,13 +121,12 @@ class TestActionLogMiddlewarePersistence:
"id": "tc-abc",
},
)
result_msg = ToolMessage(
content="ok", tool_call_id="tc-abc", id="msg-1"
)
result_msg = ToolMessage(content="ok", tool_call_id="tc-abc", id="msg-1")
handler = AsyncMock(return_value=result_msg)
with patch_get_flags(_enabled_flags()), patch(
"app.db.shielded_async_session", side_effect=lambda: factory()
with (
patch_get_flags(_enabled_flags()),
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
):
result = await mw.awrap_tool_call(request, handler)
@ -151,9 +154,11 @@ class TestActionLogMiddlewarePersistence:
)
handler = AsyncMock(side_effect=ValueError("boom"))
with patch_get_flags(_enabled_flags()), patch(
"app.db.shielded_async_session", side_effect=lambda: factory()
), pytest.raises(ValueError, match="boom"):
with (
patch_get_flags(_enabled_flags()),
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
pytest.raises(ValueError, match="boom"),
):
await mw.awrap_tool_call(request, handler)
assert len(captured["rows"]) == 1
@ -177,8 +182,9 @@ class TestActionLogMiddlewarePersistence:
def _exploding_session():
raise RuntimeError("DB is down")
with patch_get_flags(_enabled_flags()), patch(
"app.db.shielded_async_session", side_effect=_exploding_session
with (
patch_get_flags(_enabled_flags()),
patch("app.db.shielded_async_session", side_effect=_exploding_session),
):
result = await mw.awrap_tool_call(request, handler)
assert result is result_msg
@ -218,8 +224,9 @@ class TestReverseDescriptor:
)
handler = AsyncMock(return_value=result_msg)
with patch_get_flags(_enabled_flags()), patch(
"app.db.shielded_async_session", side_effect=lambda: factory()
with (
patch_get_flags(_enabled_flags()),
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
):
await mw.awrap_tool_call(request, handler)
@ -257,8 +264,9 @@ class TestReverseDescriptor:
result_msg = ToolMessage(content="ok", tool_call_id="tc1")
handler = AsyncMock(return_value=result_msg)
with patch_get_flags(_enabled_flags()), patch(
"app.db.shielded_async_session", side_effect=lambda: factory()
with (
patch_get_flags(_enabled_flags()),
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
):
await mw.awrap_tool_call(request, handler)
@ -275,11 +283,10 @@ class TestReverseDescriptor:
request = _FakeRequest(
tool_call={"name": "unknown_tool", "args": {}, "id": "tc1"}
)
handler = AsyncMock(
return_value=ToolMessage(content="ok", tool_call_id="tc1")
)
with patch_get_flags(_enabled_flags()), patch(
"app.db.shielded_async_session", side_effect=lambda: factory()
handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
with (
patch_get_flags(_enabled_flags()),
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
):
await mw.awrap_tool_call(request, handler)
row = captured["rows"][0]
@ -298,11 +305,10 @@ class TestArgsTruncation:
request = _FakeRequest(
tool_call={"name": "make_widget", "args": {"blob": huge}, "id": "tc1"},
)
handler = AsyncMock(
return_value=ToolMessage(content="ok", tool_call_id="tc1")
)
with patch_get_flags(_enabled_flags()), patch(
"app.db.shielded_async_session", side_effect=lambda: factory()
handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
with (
patch_get_flags(_enabled_flags()),
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
):
await mw.awrap_tool_call(request, handler)
row = captured["rows"][0]

View file

@ -26,10 +26,16 @@ class TestIsProtectedSystemMessage:
assert _is_protected_system_message(msg) is True
def test_unprotected_system_message(self) -> None:
assert _is_protected_system_message(SystemMessage(content="random instructions")) is False
assert (
_is_protected_system_message(SystemMessage(content="random instructions"))
is False
)
def test_human_message_never_protected(self) -> None:
assert _is_protected_system_message(HumanMessage(content="<workspace_tree>...")) is False
assert (
_is_protected_system_message(HumanMessage(content="<workspace_tree>..."))
is False
)
def test_tolerates_leading_whitespace(self) -> None:
msg = SystemMessage(content=" \n<priority_documents>\n...")
@ -97,11 +103,17 @@ class TestPartitionMessages:
assert protected not in to_summary
assert protected in preserved
# The non-protected old messages remain in to_summary
assert any(isinstance(m, HumanMessage) and m.content == "old human" for m in to_summary)
assert any(
isinstance(m, HumanMessage) and m.content == "old human" for m in to_summary
)
def test_unprotected_messages_unaffected(self) -> None:
partitioner = self._build_partitioner()
msgs = [HumanMessage(content="a"), HumanMessage(content="b"), HumanMessage(content="c")]
msgs = [
HumanMessage(content="a"),
HumanMessage(content="b"),
HumanMessage(content="c"),
]
to_summary, preserved = partitioner._partition_messages(msgs, 2)
assert [m.content for m in to_summary] == ["a", "b"]
assert [m.content for m in preserved] == ["c"]

View file

@ -70,7 +70,8 @@ class TestSpillEdit:
# Earlier ToolMessages should now contain the placeholder text
cleared = [
m for m in tool_messages
m
for m in tool_messages
if isinstance(m.content, str) and m.content.startswith("[cleared")
]
assert len(cleared) >= 1

View file

@ -46,9 +46,21 @@ def test_callable_dedup_key_takes_priority() -> None:
state = {
"messages": [
_msg(
{"name": "create_doc", "args": {"parent_id": "x", "title": "y"}, "id": "1"},
{"name": "create_doc", "args": {"parent_id": "x", "title": "y"}, "id": "2"},
{"name": "create_doc", "args": {"parent_id": "x", "title": "z"}, "id": "3"},
{
"name": "create_doc",
"args": {"parent_id": "x", "title": "y"},
"id": "1",
},
{
"name": "create_doc",
"args": {"parent_id": "x", "title": "y"},
"id": "2",
},
{
"name": "create_doc",
"args": {"parent_id": "x", "title": "z"},
"id": "3",
},
)
]
}

View file

@ -84,9 +84,7 @@ class TestConnectorDenyOverridesDefaultAllow:
Rule(permission="linear_create_issue", pattern="*", action="deny")
]
)
rules = evaluate_many(
"linear_create_issue", ["linear_create_issue"], *rulesets
)
rules = evaluate_many("linear_create_issue", ["linear_create_issue"], *rulesets)
assert aggregate_action(rules) == "deny"
def test_default_allow_still_applies_to_other_tools(self) -> None:
@ -124,5 +122,7 @@ class TestUserRuleOverridesDefault:
rules=[Rule(permission="send_*", pattern="*", action="deny")],
origin="user",
)
rules = evaluate_many("send_gmail_email", ["send_gmail_email"], defaults, user_ruleset)
rules = evaluate_many(
"send_gmail_email", ["send_gmail_email"], defaults, user_ruleset
)
assert aggregate_action(rules) == "deny"

View file

@ -64,22 +64,17 @@ def test_threshold_triggers_after_n_identical_calls() -> None:
runtime,
)
name = type(excinfo.value).__name__.lower()
assert (
"interrupt" in name
or "runtimeerror" in name
), f"Expected an interrupt-style exception, got {name}"
assert "interrupt" in name or "runtimeerror" in name, (
f"Expected an interrupt-style exception, got {name}"
)
def test_does_not_trigger_when_args_differ() -> None:
mw = DoomLoopMiddleware(threshold=2)
runtime = _FakeRuntime()
out = mw.after_model(
{"messages": [_msg_calling("repeat", {"x": 1}, "1")]}, runtime
)
out = mw.after_model({"messages": [_msg_calling("repeat", {"x": 1}, "1")]}, runtime)
assert out is None
out = mw.after_model(
{"messages": [_msg_calling("repeat", {"x": 2}, "2")]}, runtime
)
out = mw.after_model({"messages": [_msg_calling("repeat", {"x": 2}, "2")]}, runtime)
assert out is None

View file

@ -91,7 +91,9 @@ class TestShouldInject:
mw = NoopInjectionMiddleware()
req = _FakeRequest(
tools=[object()],
messages=[AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])],
messages=[
AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])
],
model=_LiteLLMModel(),
)
assert mw._should_inject(req) is False
@ -109,7 +111,9 @@ class TestShouldInject:
mw = NoopInjectionMiddleware()
req = _FakeRequest(
tools=[],
messages=[AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])],
messages=[
AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])
],
model=_OpenAIModel(),
)
assert mw._should_inject(req) is False

View file

@ -111,6 +111,4 @@ class TestAsk:
assert out is None # call kept
# Runtime ruleset got the always-allow rule
new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"]
assert any(
r.permission == "send_email" for r in new_rules
)
assert any(r.permission == "send_email" for r in new_rules)

View file

@ -69,7 +69,9 @@ class TestPluginLoaderBasics:
"app.agents.new_chat.plugin_loader.entry_points",
return_value=[ep],
):
result = load_plugin_middlewares(_ctx(), allowed_plugin_names=["allowed_only"])
result = load_plugin_middlewares(
_ctx(), allowed_plugin_names=["allowed_only"]
)
assert result == []
assert not called
@ -135,9 +137,7 @@ class TestPluginLoaderIsolation:
_FakeEntryPoint("crashing", crashing_factory),
_FakeEntryPoint("ok", year_substituter_factory),
]
with patch(
"app.agents.new_chat.plugin_loader.entry_points", return_value=eps
):
with patch("app.agents.new_chat.plugin_loader.entry_points", return_value=eps):
result = load_plugin_middlewares(
_ctx(), allowed_plugin_names={"crashing", "ok"}
)
@ -151,9 +151,7 @@ class TestAllowlistEnv:
assert load_allowed_plugin_names_from_env() == set()
def test_parses_comma_separated_value(self, monkeypatch) -> None:
monkeypatch.setenv(
"SURFSENSE_ALLOWED_PLUGINS", " year_substituter , noisy , "
)
monkeypatch.setenv("SURFSENSE_ALLOWED_PLUGINS", " year_substituter , noisy , ")
assert load_allowed_plugin_names_from_env() == {
"year_substituter",
"noisy",

View file

@ -18,7 +18,7 @@ class _FakeResponse:
self.headers = headers
class _FakeRateLimit(Exception):
class _FakeRateLimitError(Exception):
def __init__(self, msg: str, headers: dict[str, str] | None = None) -> None:
super().__init__(msg)
if headers is not None:
@ -27,15 +27,15 @@ class _FakeRateLimit(Exception):
class TestExtractRetryAfter:
def test_seconds_header(self) -> None:
exc = _FakeRateLimit("rate", {"Retry-After": "30"})
exc = _FakeRateLimitError("rate", {"Retry-After": "30"})
assert _extract_retry_after_seconds(exc) == 30.0
def test_milliseconds_header_overrides_seconds(self) -> None:
exc = _FakeRateLimit("rate", {"retry-after-ms": "1500"})
exc = _FakeRateLimitError("rate", {"retry-after-ms": "1500"})
assert _extract_retry_after_seconds(exc) == 1.5
def test_case_insensitive(self) -> None:
exc = _FakeRateLimit("rate", {"RETRY-AFTER": "12"})
exc = _FakeRateLimitError("rate", {"RETRY-AFTER": "12"})
assert _extract_retry_after_seconds(exc) == 12.0
def test_falls_back_to_message_regex(self) -> None:
@ -67,7 +67,7 @@ class TestIsNonRetryable:
class TestDelayCalculation:
def test_takes_max_of_backoff_and_header(self) -> None:
mw = RetryAfterMiddleware(max_retries=3, initial_delay=1.0, jitter=False)
exc = _FakeRateLimit("rl", {"retry-after": "10"})
exc = _FakeRateLimitError("rl", {"retry-after": "10"})
delay = mw._delay_for_attempt(0, exc)
assert delay == pytest.approx(10.0)

View file

@ -122,7 +122,9 @@ class TestExploreSubagent:
def test_includes_permission_middleware_with_deny_rules(self) -> None:
spec = build_explore_subagent(tools=ALL_TOOLS)
permission_mws = [
m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index]
m
for m in spec["middleware"]
if isinstance(m, PermissionMiddleware) # type: ignore[index]
]
assert len(permission_mws) == 1
ruleset = permission_mws[0]._static_rulesets[0]
@ -164,7 +166,9 @@ class TestReportWriterSubagent:
def test_deny_rules_block_writes_but_allow_generate_report(self) -> None:
spec = build_report_writer_subagent(tools=ALL_TOOLS)
permission_mws = [
m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index]
m
for m in spec["middleware"]
if isinstance(m, PermissionMiddleware) # type: ignore[index]
]
ruleset = permission_mws[0]._static_rulesets[0]
deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"}
@ -194,17 +198,15 @@ class TestConnectorNegotiatorSubagent:
def test_deny_ruleset_blocks_mutating_connector_tools(self) -> None:
spec = build_connector_negotiator_subagent(tools=ALL_TOOLS)
permission_mws = [
m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index]
m
for m in spec["middleware"]
if isinstance(m, PermissionMiddleware) # type: ignore[index]
]
ruleset = permission_mws[0]._static_rulesets[0]
deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"}
# `linear_create_issue` matches the `*_create` deny pattern.
assert any(
_wildcard_matches(p, "linear_create_issue") for p in deny_patterns
)
assert any(
_wildcard_matches(p, "slack_send_message") for p in deny_patterns
)
assert any(_wildcard_matches(p, "linear_create_issue") for p in deny_patterns)
assert any(_wildcard_matches(p, "slack_send_message") for p in deny_patterns)
class TestBuildSpecializedSubagents:
@ -242,8 +244,7 @@ class TestBuildSpecializedSubagents:
# order: extra → custom → patch → dedup.
sentinel_idx = mws.index(sentinel)
perm_idx = next(
(i for i, m in enumerate(mws)
if isinstance(m, PermissionMiddleware)),
(i for i, m in enumerate(mws) if isinstance(m, PermissionMiddleware)),
None,
)
assert perm_idx is not None
@ -259,7 +260,9 @@ class TestFilterToolsWarningSuppression:
from app.agents.new_chat.subagents.config import _filter_tools
with caplog.at_level(logging.INFO, logger="app.agents.new_chat.subagents.config"):
with caplog.at_level(
logging.INFO, logger="app.agents.new_chat.subagents.config"
):
# Allowed set asks for two registry tools (one present, one
# not) plus a bunch of middleware-provided names.
_filter_tools(
@ -275,9 +278,7 @@ class TestFilterToolsWarningSuppression:
},
)
warnings = [
r.message for r in caplog.records if r.levelno >= logging.INFO
]
warnings = [r.message for r in caplog.records if r.levelno >= logging.INFO]
# Exactly one warning, and it should mention scrape_webpage but not
# any middleware-provided name. Inspect the rendered "missing"
# list (between the brackets) so we don't false-match substrings

View file

@ -27,9 +27,12 @@ class TestRepair:
mw = ToolCallNameRepairMiddleware(
registered_tool_names={"echo"}, fuzzy_match_threshold=None
)
msg = AIMessage(content="", tool_calls=[
{"name": "echo", "args": {}, "id": "1"},
])
msg = AIMessage(
content="",
tool_calls=[
{"name": "echo", "args": {}, "id": "1"},
],
)
out = mw.after_model(_make_state(msg), _FakeRuntime())
assert out is None # no change
@ -37,9 +40,12 @@ class TestRepair:
mw = ToolCallNameRepairMiddleware(
registered_tool_names={"echo"}, fuzzy_match_threshold=None
)
msg = AIMessage(content="", tool_calls=[
{"name": "Echo", "args": {"x": 1}, "id": "1"},
])
msg = AIMessage(
content="",
tool_calls=[
{"name": "Echo", "args": {"x": 1}, "id": "1"},
],
)
out = mw.after_model(_make_state(msg), _FakeRuntime())
assert out is not None
repaired = out["messages"][0]
@ -50,9 +56,12 @@ class TestRepair:
registered_tool_names={"echo", INVALID_TOOL_NAME},
fuzzy_match_threshold=None,
)
msg = AIMessage(content="", tool_calls=[
{"name": "totally_different_name", "args": {"k": "v"}, "id": "1"},
])
msg = AIMessage(
content="",
tool_calls=[
{"name": "totally_different_name", "args": {"k": "v"}, "id": "1"},
],
)
out = mw.after_model(_make_state(msg), _FakeRuntime())
assert out is not None
repaired_call = out["messages"][0].tool_calls[0]
@ -64,9 +73,12 @@ class TestRepair:
mw = ToolCallNameRepairMiddleware(
registered_tool_names={"echo"}, fuzzy_match_threshold=None
)
msg = AIMessage(content="", tool_calls=[
{"name": "unknown", "args": {}, "id": "1"},
])
msg = AIMessage(
content="",
tool_calls=[
{"name": "unknown", "args": {}, "id": "1"},
],
)
out = mw.after_model(_make_state(msg), _FakeRuntime())
# No repair available; original returned unchanged (no update)
assert out is None
@ -76,9 +88,12 @@ class TestRepair:
registered_tool_names={"search_documents"},
fuzzy_match_threshold=0.7,
)
msg = AIMessage(content="", tool_calls=[
{"name": "search_docments", "args": {}, "id": "1"},
])
msg = AIMessage(
content="",
tool_calls=[
{"name": "search_docments", "args": {}, "id": "1"},
],
)
out = mw.after_model(_make_state(msg), _FakeRuntime())
assert out is not None
assert out["messages"][0].tool_calls[0]["name"] == "search_documents"
@ -94,9 +109,12 @@ class TestRepair:
mw = ToolCallNameRepairMiddleware(
registered_tool_names={"echo"}, fuzzy_match_threshold=None
)
msg = AIMessage(content="", tool_calls=[
{"name": "DynamicTool", "args": {}, "id": "1"},
])
msg = AIMessage(
content="",
tool_calls=[
{"name": "DynamicTool", "args": {}, "id": "1"},
],
)
runtime = _FakeRuntime(SimpleNamespace(registered_tool_names=["dynamictool"]))
out = mw.after_model(_make_state(msg), runtime)
assert out is not None

View file

@ -10,7 +10,7 @@ through :class:`KnowledgeBasePersistenceMiddleware` without losing the copy.
from __future__ import annotations
from typing import Any
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock
import numpy as np
import pytest

View file

@ -16,9 +16,7 @@ class _FakeAction:
class TestCanRevert:
def test_owner_can_revert_their_own_action(self) -> None:
action = _FakeAction(user_id="user-123")
assert can_revert(
requester_user_id="user-123", action=action, is_admin=False
)
assert can_revert(requester_user_id="user-123", action=action, is_admin=False)
def test_other_user_cannot_revert(self) -> None:
action = _FakeAction(user_id="user-123")
@ -28,21 +26,15 @@ class TestCanRevert:
def test_admin_always_allowed(self) -> None:
action = _FakeAction(user_id="user-123")
assert can_revert(
requester_user_id="anybody", action=action, is_admin=True
)
assert can_revert(requester_user_id="anybody", action=action, is_admin=True)
def test_admin_can_revert_anonymous_action(self) -> None:
action = _FakeAction(user_id=None)
assert can_revert(
requester_user_id="admin", action=action, is_admin=True
)
assert can_revert(requester_user_id="admin", action=action, is_admin=True)
def test_anonymous_action_blocks_non_admin(self) -> None:
action = _FakeAction(user_id=None)
assert not can_revert(
requester_user_id="user-1", action=action, is_admin=False
)
assert not can_revert(requester_user_id="user-1", action=action, is_admin=False)
def test_uuid_string_normalization(self) -> None:
"""``user_id`` may be a UUID object; comparison should still work."""
@ -51,6 +43,4 @@ class TestCanRevert:
u = uuid.uuid4()
action = _FakeAction(user_id=u)
# Same UUID, passed as string from the requesting side.
assert can_revert(
requester_user_id=str(u), action=action, is_admin=False
)
assert can_revert(requester_user_id=str(u), action=action, is_admin=False)