chore: linting

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-05-28 19:21:29 -07:00
parent 4dda02c06c
commit 94e834134f
80 changed files with 443 additions and 404 deletions

View file

@ -98,9 +98,7 @@ def upgrade() -> None:
op.execute(
"CREATE INDEX ix_automation_triggers_automation_id ON automation_triggers(automation_id);"
)
op.execute(
"CREATE INDEX ix_automation_triggers_type ON automation_triggers(type);"
)
op.execute("CREATE INDEX ix_automation_triggers_type ON automation_triggers(type);")
op.execute(
"CREATE INDEX ix_automation_triggers_enabled ON automation_triggers(enabled);"
)

View file

@ -28,7 +28,6 @@ from __future__ import annotations
from datetime import UTC, datetime
_HEADER = """\
You are the SurfSense automation drafter. Convert the user intent below
into a SINGLE JSON object matching the AutomationCreate schema. Output

View file

@ -404,9 +404,7 @@ def build_task_tool_with_parent_config(
continue
messages = payload.get("messages") or []
last_text = _safe_message_text(messages[-1]).rstrip() if messages else ""
message_blocks.append(
f"[task {task_index}] {last_text or '<empty>'}"
)
message_blocks.append(f"[task {task_index}] {last_text or '<empty>'}")
try:
child_trace = _build_tool_trace(messages)
except Exception:

View file

@ -117,9 +117,7 @@ def create_generate_podcast_tool(
"podcast_id": podcast_id,
"title": podcast_title,
"file_location": file_location,
"message": (
"Podcast generated and saved to your podcast panel."
),
"message": ("Podcast generated and saved to your podcast panel."),
}
return with_receipt(
payload=payload,

View file

@ -126,8 +126,7 @@ def create_generate_video_presentation_tool(
elapsed,
)
err = (
"Background worker reported FAILED status for this "
"video presentation."
"Background worker reported FAILED status for this video presentation."
)
payload = {
"status": VideoPresentationStatus.FAILED.value,
@ -151,9 +150,7 @@ def create_generate_video_presentation_tool(
except Exception as e:
error_message = str(e)
logger.exception(
"[generate_video_presentation] Error: %s", error_message
)
logger.exception("[generate_video_presentation] Error: %s", error_message)
payload = {
"status": VideoPresentationStatus.FAILED.value,
"error": error_message,

View file

@ -131,9 +131,7 @@ def create_generate_podcast_tool(
"podcast_id": podcast_id,
"title": podcast_title,
"file_location": file_location,
"message": (
"Podcast generated and saved to your podcast panel."
),
"message": ("Podcast generated and saved to your podcast panel."),
}
# Only other terminal state is FAILED.
@ -146,9 +144,7 @@ def create_generate_podcast_tool(
"status": PodcastStatus.FAILED.value,
"podcast_id": podcast_id,
"title": podcast_title,
"error": (
"Background worker reported FAILED status for this podcast."
),
"error": ("Background worker reported FAILED status for this podcast."),
}
except Exception as e:

View file

@ -127,9 +127,7 @@ def create_generate_video_presentation_tool(
except Exception as e:
error_message = str(e)
logger.exception(
"[generate_video_presentation] Error: %s", error_message
)
logger.exception("[generate_video_presentation] Error: %s", error_message)
return {
"status": VideoPresentationStatus.FAILED.value,
"error": error_message,

View file

@ -21,4 +21,4 @@ __all__ = [
]
# Built-in actions self-register at import time.
from . import agent_task # noqa: E402, F401
from . import agent_task # noqa: F401

View file

@ -12,4 +12,4 @@ from .params import AgentTaskActionParams
__all__ = ["AgentTaskActionParams", "build_handler"]
# Side-effect: register on the actions store.
from . import definition # noqa: E402, F401
from . import definition # noqa: F401

View file

@ -13,7 +13,6 @@ from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent
from app.db import ChatVisibility, async_session_maker
from ..types import ActionContext
from .auto_decide import build_auto_decisions
from .dependencies import build_dependencies
from .finalize import extract_final_assistant_message

View file

@ -50,7 +50,7 @@ class AutomationRun(BaseModel, TimestampMixin):
definition_snapshot = Column(JSONB, nullable=False)
# merged & validated inputs the run was dispatched with
# (trigger.static_inputs producer runtime data, static wins on collision)
# (trigger.static_inputs union producer runtime data, static wins on collision)
inputs = Column(JSONB, nullable=False, server_default="{}")
# one entry per executed step; agent_task entries carry their own
# `agent_session_id` inside their entry

View file

@ -6,9 +6,9 @@ from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.actions.types import ActionContext
from app.automations.persistence.enums.run_status import RunStatus
from app.automations.persistence.models.run import AutomationRun
from app.automations.actions.types import ActionContext
from app.automations.schemas.definition.envelope import AutomationDefinition
from app.automations.schemas.definition.plan_step import PlanStep
from app.automations.templating import build_run_context
@ -32,7 +32,10 @@ async def execute_run(session: AsyncSession, run_id: int) -> None:
await repository.mark_failed(
session,
run,
{"message": f"definition_snapshot invalid: {exc}", "type": type(exc).__name__},
{
"message": f"definition_snapshot invalid: {exc}",
"type": type(exc).__name__,
},
)
await session.commit()
return
@ -92,7 +95,9 @@ async def _run_on_failure(
await session.commit()
def _build_template_ctx(run: AutomationRun, step_outputs: dict[str, Any]) -> dict[str, Any]:
def _build_template_ctx(
run: AutomationRun, step_outputs: dict[str, Any]
) -> dict[str, Any]:
automation = run.automation
trigger = run.trigger
return build_run_context(

View file

@ -30,14 +30,18 @@ async def execute_step(
try:
should_run = evaluate_predicate(step.when, template_context)
except Exception as exc:
return _result(step, "failed", started_at, attempts=0, error=_error(exc, "when"))
return _result(
step, "failed", started_at, attempts=0, error=_error(exc, "when")
)
if not should_run:
return _result(step, "skipped", started_at, attempts=0)
try:
resolved_params = render_value(step.params, template_context)
except Exception as exc:
return _result(step, "failed", started_at, attempts=0, error=_error(exc, "render"))
return _result(
step, "failed", started_at, attempts=0, error=_error(exc, "render")
)
action = get_action(step.action)
if action is None:
@ -46,12 +50,17 @@ async def execute_step(
"failed",
started_at,
attempts=0,
error={"message": f"action not registered: {step.action}", "type": "ActionNotFound"},
error={
"message": f"action not registered: {step.action}",
"type": "ActionNotFound",
},
)
handler = action.build_handler(action_context)
max_retries = step.max_retries if step.max_retries is not None else default_max_retries
max_retries = (
step.max_retries if step.max_retries is not None else default_max_retries
)
timeout = step.timeout_seconds or default_timeout_seconds
try:
@ -62,7 +71,9 @@ async def execute_step(
timeout=timeout,
)
except Exception as exc:
return _result(step, "failed", started_at, attempts=max_retries + 1, error=_error(exc))
return _result(
step, "failed", started_at, attempts=max_retries + 1, error=_error(exc)
)
return _result(step, "succeeded", started_at, attempts=attempts, result=result)

View file

@ -12,7 +12,9 @@ from .plan_step import PlanStep
class Execution(BaseModel):
model_config = ConfigDict(extra="forbid")
timeout_seconds: int = Field(default=600, gt=0, description="Wall-clock cap for the run.")
timeout_seconds: int = Field(
default=600, gt=0, description="Wall-clock cap for the run."
)
max_retries: int = Field(default=2, ge=0, description="Per-step retry budget.")
retry_backoff: Literal["exponential", "linear", "none"] = "exponential"
concurrency: Literal["drop_if_running", "queue", "always"] = "drop_if_running"

View file

@ -11,7 +11,9 @@ class PlanStep(BaseModel):
model_config = ConfigDict(extra="forbid")
step_id: str = Field(..., min_length=1, description="Unique within the plan.")
action: str = Field(..., min_length=1, description="Action type; resolved via registry.")
action: str = Field(
..., min_length=1, description="Action type; resolved via registry."
)
when: str | None = Field(
default=None,
description="Optional predicate; step is skipped when falsy.",

View file

@ -10,7 +10,9 @@ from pydantic import BaseModel, ConfigDict, Field
class TriggerSpec(BaseModel):
model_config = ConfigDict(extra="forbid")
type: str = Field(..., min_length=1, description="Trigger type; resolved via registry.")
type: str = Field(
..., min_length=1, description="Trigger type; resolved via registry."
)
params: dict[str, Any] = Field(
default_factory=dict,
description="Type-specific params; validated against the trigger's schema.",

View file

@ -10,14 +10,14 @@ from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.automations.persistence.enums.trigger_type import TriggerType
from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.trigger import AutomationTrigger
from app.automations.schemas.api import (
AutomationCreate,
AutomationUpdate,
TriggerCreate,
)
from app.automations.persistence.enums.trigger_type import TriggerType
from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.trigger import AutomationTrigger
from app.automations.triggers import get_trigger
from app.automations.triggers.schedule import compute_next_fire_at
from app.db import Permission, User, get_async_session
@ -34,7 +34,9 @@ class AutomationService:
async def create(self, payload: AutomationCreate) -> Automation:
"""Create an automation and its initial triggers in one transaction."""
await self._authorize(payload.search_space_id, Permission.AUTOMATIONS_CREATE.value)
await self._authorize(
payload.search_space_id, Permission.AUTOMATIONS_CREATE.value
)
automation = Automation(
search_space_id=payload.search_space_id,
@ -67,22 +69,32 @@ class AutomationService:
)
rows = (
await self.session.execute(
base.order_by(Automation.created_at.desc()).limit(limit).offset(offset)
(
await self.session.execute(
base.order_by(Automation.created_at.desc())
.limit(limit)
.offset(offset)
)
)
).scalars().all()
.scalars()
.all()
)
return list(rows), int(total or 0)
async def get(self, automation_id: int) -> Automation:
"""Get an automation with its triggers loaded."""
automation = await self._get_with_triggers_or_raise(automation_id)
await self._authorize(automation.search_space_id, Permission.AUTOMATIONS_READ.value)
await self._authorize(
automation.search_space_id, Permission.AUTOMATIONS_READ.value
)
return automation
async def update(self, automation_id: int, patch: AutomationUpdate) -> Automation:
"""Patch fields. Bumps ``version`` when ``definition`` changes."""
automation = await self._get_with_triggers_or_raise(automation_id)
await self._authorize(automation.search_space_id, Permission.AUTOMATIONS_UPDATE.value)
await self._authorize(
automation.search_space_id, Permission.AUTOMATIONS_UPDATE.value
)
data = patch.model_dump(exclude_unset=True)
@ -93,7 +105,9 @@ class AutomationService:
if "status" in data:
automation.status = data["status"]
if "definition" in data:
automation.definition = patch.definition.model_dump(mode="json", by_alias=True)
automation.definition = patch.definition.model_dump(
mode="json", by_alias=True
)
automation.version += 1
await self.session.commit()
@ -102,7 +116,9 @@ class AutomationService:
async def delete(self, automation_id: int) -> None:
"""Delete an automation; FK cascades remove triggers and runs."""
automation = await self._get_or_raise(automation_id)
await self._authorize(automation.search_space_id, Permission.AUTOMATIONS_DELETE.value)
await self._authorize(
automation.search_space_id, Permission.AUTOMATIONS_DELETE.value
)
await self.session.delete(automation)
await self.session.commit()
@ -141,7 +157,9 @@ def _build_trigger(spec: TriggerCreate) -> AutomationTrigger:
"""Validate trigger params via its registered Pydantic model and build the ORM row."""
definition = get_trigger(spec.type.value)
if definition is None:
raise HTTPException(status_code=422, detail=f"unknown trigger type {spec.type.value!r}")
raise HTTPException(
status_code=422, detail=f"unknown trigger type {spec.type.value!r}"
)
try:
validated = definition.params_model.model_validate(spec.params)

View file

@ -36,10 +36,16 @@ class RunService:
)
rows = (
await self.session.execute(
base.order_by(AutomationRun.created_at.desc()).limit(limit).offset(offset)
(
await self.session.execute(
base.order_by(AutomationRun.created_at.desc())
.limit(limit)
.offset(offset)
)
)
).scalars().all()
.scalars()
.all()
)
return list(rows), int(total or 0)
async def get(self, *, automation_id: int, run_id: int) -> AutomationRun:

View file

@ -8,10 +8,10 @@ from fastapi import Depends, HTTPException
from pydantic import ValidationError
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.schemas.api import TriggerCreate, TriggerUpdate
from app.automations.persistence.enums.trigger_type import TriggerType
from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.trigger import AutomationTrigger
from app.automations.schemas.api import TriggerCreate, TriggerUpdate
from app.automations.triggers import get_trigger
from app.automations.triggers.schedule import compute_next_fire_at
from app.db import Permission, User, get_async_session
@ -40,7 +40,9 @@ class TriggerService:
params=validated_params,
static_inputs=payload.static_inputs,
enabled=payload.enabled,
next_fire_at=_initial_next_fire(payload.type, validated_params, payload.enabled),
next_fire_at=_initial_next_fire(
payload.type, validated_params, payload.enabled
),
)
self.session.add(trigger)
await self.session.commit()
@ -54,7 +56,9 @@ class TriggerService:
trigger_id: int,
patch: TriggerUpdate,
) -> AutomationTrigger:
await self._authorize_automation(automation_id, Permission.AUTOMATIONS_UPDATE.value)
await self._authorize_automation(
automation_id, Permission.AUTOMATIONS_UPDATE.value
)
trigger = await self._get_trigger_or_raise(automation_id, trigger_id)
data = patch.model_dump(exclude_unset=True)
@ -80,7 +84,9 @@ class TriggerService:
return trigger
async def remove(self, *, automation_id: int, trigger_id: int) -> None:
await self._authorize_automation(automation_id, Permission.AUTOMATIONS_UPDATE.value)
await self._authorize_automation(
automation_id, Permission.AUTOMATIONS_UPDATE.value
)
trigger = await self._get_trigger_or_raise(automation_id, trigger_id)
await self.session.delete(trigger)
await self.session.commit()

View file

@ -17,7 +17,7 @@ TASK_NAME = "automation_run_execute"
@celery_app.task(name=TASK_NAME, bind=True)
def automation_run_execute(self, run_id: int) -> None: # noqa: ARG001 — Celery bind
def automation_run_execute(self, run_id: int) -> None:
"""Execute one ``AutomationRun``. Idempotent: terminal runs no-op."""
return run_async_celery_task(lambda: _impl(run_id))

View file

@ -103,9 +103,7 @@ async def _self_heal_null_next_fire(session: AsyncSession, *, now: datetime) ->
await session.commit()
async def _claim_due_triggers(
session: AsyncSession, *, now: datetime
) -> list[_Claim]:
async def _claim_due_triggers(session: AsyncSession, *, now: datetime) -> list[_Claim]:
"""Lock and advance due rows; return per-trigger fire context."""
stmt = (
select(AutomationTrigger)

View file

@ -17,4 +17,4 @@ __all__ = [
]
# Built-in triggers self-register at import time.
from . import schedule # noqa: E402, F401
from . import schedule # noqa: F401

View file

@ -15,4 +15,4 @@ __all__ = [
]
# Side-effect: register on the triggers store.
from . import definition # noqa: E402, F401
from . import definition # noqa: F401

View file

@ -32,6 +32,10 @@ def compute_next_fire_at(cron: str, timezone: str, *, after: datetime) -> dateti
given timezone before evaluation so DST and IANA rules apply correctly.
"""
tz = ZoneInfo(timezone)
base = after.astimezone(tz) if after.tzinfo else after.replace(tzinfo=UTC).astimezone(tz)
base = (
after.astimezone(tz)
if after.tzinfo
else after.replace(tzinfo=UTC).astimezone(tz)
)
nxt: datetime = croniter(cron, base).get_next(datetime)
return nxt.astimezone(UTC)

View file

@ -10,7 +10,9 @@ from .cron import InvalidCronError, validate_cron
class ScheduleTriggerParams(BaseModel):
model_config = ConfigDict(extra="forbid")
cron: str = Field(..., description="Five-field cron expression.", examples=["0 9 * * 1-5"])
cron: str = Field(
..., description="Five-field cron expression.", examples=["0 9 * * 1-5"]
)
timezone: str = Field(..., description="IANA timezone.", examples=["Africa/Kigali"])
@model_validator(mode="after")

View file

@ -2605,7 +2605,6 @@ from app.automations.persistence import ( # noqa: E402, F401
AutomationTrigger,
)
engine = create_async_engine(
DATABASE_URL,
pool_size=30,

View file

@ -1,5 +1,7 @@
from fastapi import APIRouter
from app.automations.api import router as automations_router
from .agent_action_log_route import router as agent_action_log_router
from .agent_flags_route import router as agent_flags_router
from .agent_permissions_route import router as agent_permissions_router
@ -7,7 +9,6 @@ from .agent_revert_route import router as agent_revert_router
from .airtable_add_connector_route import (
router as airtable_add_connector_router,
)
from app.automations.api import router as automations_router
from .chat_comments_routes import router as chat_comments_router
from .circleback_webhook_route import router as circleback_webhook_router
from .clickup_add_connector_route import router as clickup_add_connector_router

View file

@ -19,9 +19,7 @@ def extract_todos_from_deepagents(command_output: Any) -> dict:
elif isinstance(command_output, dict):
if "todos" in command_output:
todos_data = command_output.get("todos", [])
elif "update" in command_output and isinstance(
command_output["update"], dict
):
elif "update" in command_output and isinstance(command_output["update"], dict):
todos_data = command_output["update"].get("todos", [])
return {"todos": todos_data}

View file

@ -69,17 +69,13 @@ async def resolve_initial_auto_pin(
"pin.requires_image_input": requires_image_input,
},
)
return AutoPinResult(
llm_config_id=pinned.resolved_llm_config_id, error=None
)
return AutoPinResult(llm_config_id=pinned.resolved_llm_config_id, error=None)
except ValueError as pin_error:
# The "no vision-capable cfg" path raises a ValueError whose message
# we map to the friendly image-input SSE error so the user sees the
# same message regardless of whether the gate fired in the resolver or
# in ``llm_capability.assert_vision_capability_for_image_turn``.
is_vision_failure = (
requires_image_input and "vision-capable" in str(pin_error)
)
is_vision_failure = requires_image_input and "vision-capable" in str(pin_error)
error_code = (
"MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
if is_vision_failure

View file

@ -207,9 +207,7 @@ async def _resolve_mentions_for_query(
try:
chip_objs.append(MentionedDocumentInfo.model_validate(raw))
except Exception:
logger.debug(
"stream_new_chat: dropping malformed mention chip %r", raw
)
logger.debug("stream_new_chat: dropping malformed mention chip %r", raw)
resolved = await resolve_mentions(
session,

View file

@ -48,9 +48,7 @@ def check_image_input_capability(
return None
model_label = agent_config.config_name or agent_config.model_name or "model"
ot.add_event(
"quota.denied", {"quota.code": "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"}
)
ot.add_event("quota.denied", {"quota.code": "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"})
return (
(
f"The selected model ({model_label}) does not support "

View file

@ -259,7 +259,8 @@ async def stream_new_chat(
if needs_premium_quota(agent_config, user_id):
premium_reservation = await reserve_premium(
agent_config=agent_config, user_id=user_id # type: ignore[arg-type]
agent_config=agent_config,
user_id=user_id, # type: ignore[arg-type]
)
if not premium_reservation.allowed:
ot.add_event("quota.denied", {"quota.code": "PREMIUM_QUOTA_EXHAUSTED"})
@ -492,7 +493,9 @@ async def stream_new_chat(
# --- Block 4: First SSE frames ---
for sse in iter_initial_frames(streaming_service, turn_id=stream_result.turn_id):
for sse in iter_initial_frames(
streaming_service, turn_id=stream_result.turn_id
):
yield sse
# --- Block 5: Persistence join + message-id frames ---
@ -693,7 +696,9 @@ async def stream_new_chat(
fallback_commit_search_space_id=search_space_id,
fallback_commit_created_by_id=user_id,
fallback_commit_filesystem_mode=(
filesystem_selection.mode if filesystem_selection else FilesystemMode.CLOUD
filesystem_selection.mode
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
runtime_context=runtime_context,
@ -715,11 +720,7 @@ async def stream_new_chat(
title_emitted = True
# Account for the case where the task completed but produced no
# title — flip the flag anyway so we don't keep checking it.
if (
title_task is not None
and title_task.done()
and not title_emitted
):
if title_task is not None and title_task.done() and not title_emitted:
title_emitted = True
_perf_log.info(
@ -811,9 +812,7 @@ async def stream_new_chat(
end_turn(str(chat_id))
if premium_reservation is not None and user_id:
await release_premium(
reservation=premium_reservation, user_id=user_id
)
await release_premium(reservation=premium_reservation, user_id=user_id)
await close_session_and_clear_ai_responding(session, chat_id)
@ -852,9 +851,9 @@ async def stream_new_chat(
# Break circular refs held by the agent graph, tools, and LLM
# wrappers so the GC can reclaim them in a single pass.
agent = llm = connector_service = None # noqa: F841
input_state = stream_result = None # noqa: F841
session = None # noqa: F841
agent = llm = connector_service = None
input_state = stream_result = None
session = None
run_gc_pass(log_prefix="stream_new_chat", chat_id=chat_id)
close_chat_request_span(

View file

@ -30,9 +30,7 @@ def build_new_chat_runtime_context(
return SurfSenseContextSchema(
search_space_id=search_space_id,
mentioned_document_ids=list(mentioned_document_ids or []),
mentioned_folder_ids=list(
accepted_folder_ids or mentioned_folder_ids or []
),
mentioned_folder_ids=list(accepted_folder_ids or mentioned_folder_ids or []),
request_id=request_id,
turn_id=turn_id,
)

View file

@ -133,12 +133,8 @@ async def _generate_title(
# inherited Azure endpoint — see ``provider_api_base`` for the
# same bug repro on the image-gen / vision paths.
raw_model = getattr(llm, "model", "") or ""
provider_prefix = (
raw_model.split("/", 1)[0] if "/" in raw_model else None
)
provider_value = (
agent_config.provider if agent_config is not None else None
)
provider_prefix = raw_model.split("/", 1)[0] if "/" in raw_model else None
provider_value = agent_config.provider if agent_config is not None else None
title_api_base = resolve_api_base(
provider=provider_value,
provider_prefix=provider_prefix,

View file

@ -15,14 +15,10 @@ building blocks under ``flows/shared/``. Mirrors ``stream_new_chat`` but:
from __future__ import annotations
import contextlib
import gc
import logging
import sys
import time
import uuid as _uuid
from collections.abc import AsyncGenerator
from functools import partial
from typing import Any
from uuid import UUID
import anyio
@ -32,7 +28,7 @@ from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
from app.agents.new_chat.middleware.busy_mutex import end_turn
from app.config import config as _app_config
from app.db import ChatVisibility, async_session_maker, shielded_async_session
from app.db import ChatVisibility, async_session_maker
from app.observability import otel as ot
from app.services.chat_session_state_service import set_ai_responding
from app.services.new_streaming_service import VercelStreamingService
@ -89,7 +85,7 @@ from app.tasks.chat.streaming.flows.shared.terminal_error import (
)
from app.tasks.chat.streaming.shared.stream_result import StreamResult
from app.tasks.chat.streaming.shared.utils import resume_step_prefix
from app.utils.perf import get_perf_logger, log_system_snapshot
from app.utils.perf import get_perf_logger
logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
@ -217,12 +213,11 @@ async def stream_resume_chat(
if needs_premium_quota(agent_config, user_id):
premium_reservation = await reserve_premium(
agent_config=agent_config, user_id=user_id # type: ignore[arg-type]
agent_config=agent_config,
user_id=user_id, # type: ignore[arg-type]
)
if not premium_reservation.allowed:
ot.add_event(
"quota.denied", {"quota.code": "PREMIUM_QUOTA_EXHAUSTED"}
)
ot.add_event("quota.denied", {"quota.code": "PREMIUM_QUOTA_EXHAUSTED"})
if requested_llm_config_id == 0:
try:
pinned_fb = await resolve_or_get_pinned_llm_config_id(
@ -396,7 +391,9 @@ async def stream_resume_chat(
# --- First SSE frames ---
for sse in iter_initial_frames(streaming_service, turn_id=stream_result.turn_id):
for sse in iter_initial_frames(
streaming_service, turn_id=stream_result.turn_id
):
yield sse
# --- Assistant-shell persistence + id frame ---
@ -517,7 +514,9 @@ async def stream_resume_chat(
fallback_commit_search_space_id=search_space_id,
fallback_commit_created_by_id=user_id,
fallback_commit_filesystem_mode=(
filesystem_selection.mode if filesystem_selection else FilesystemMode.CLOUD
filesystem_selection.mode
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
runtime_context=runtime_context,
@ -589,9 +588,7 @@ async def stream_resume_chat(
end_turn(str(chat_id))
if premium_reservation is not None and user_id:
await release_premium(
reservation=premium_reservation, user_id=user_id
)
await release_premium(reservation=premium_reservation, user_id=user_id)
await close_session_and_clear_ai_responding(session, chat_id)
@ -609,13 +606,11 @@ async def stream_resume_chat(
if not busy_error_raised:
with contextlib.suppress(Exception):
end_turn(str(chat_id))
_perf_log.info(
"[stream_resume] end_turn cleanup (chat_id=%s)", chat_id
)
_perf_log.info("[stream_resume] end_turn cleanup (chat_id=%s)", chat_id)
agent = llm = connector_service = None # noqa: F841
stream_result = None # noqa: F841
session = None # noqa: F841
agent = llm = connector_service = None
stream_result = None
session = None
run_gc_pass(log_prefix="stream_resume", chat_id=chat_id)
close_chat_request_span(

View file

@ -47,9 +47,7 @@ async def build_resume_routing(
slice_decisions_by_tool_call,
)
parent_state = await agent.aget_state(
{"configurable": {"thread_id": str(chat_id)}}
)
parent_state = await agent.aget_state({"configurable": {"thread_id": str(chat_id)}})
pending = collect_pending_tool_calls(parent_state)
_perf_log.info(
"[hitl_route] resume_entry chat_id=%s decisions=%d pending_subagents=%d",

View file

@ -49,9 +49,7 @@ async def finalize_assistant_message(
was never assigned.
"""
if not (
stream_result
and stream_result.turn_id
and stream_result.assistant_message_id
stream_result and stream_result.turn_id and stream_result.assistant_message_id
):
return

View file

@ -39,9 +39,7 @@ async def close_session_and_clear_ai_responding(
async with shielded_async_session() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
except Exception:
logger.warning(
"Failed to clear AI responding state for thread %s", chat_id
)
logger.warning("Failed to clear AI responding state for thread %s", chat_id)
with contextlib.suppress(Exception):
session.expunge_all()

View file

@ -41,9 +41,7 @@ class PremiumReservation:
allowed: bool
def needs_premium_quota(
agent_config: AgentConfig | None, user_id: str | None
) -> bool:
def needs_premium_quota(agent_config: AgentConfig | None, user_id: str | None) -> bool:
return bool(agent_config is not None and user_id and agent_config.is_premium)
@ -61,8 +59,10 @@ async def reserve_premium(
request_id = _uuid.uuid4().hex[:16]
litellm_params = agent_config.litellm_params or {}
base_model = (
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
) or agent_config.model_name or ""
(litellm_params.get("base_model") if isinstance(litellm_params, dict) else None)
or agent_config.model_name
or ""
)
reserve_amount_micros = estimate_call_reserve_micros(
base_model=base_model,
quota_reserve_tokens=agent_config.quota_reserve_tokens,

View file

@ -6,8 +6,7 @@ import contextlib
import sys
from typing import Any, Literal
from app.observability import metrics as ot_metrics
from app.observability import otel as ot
from app.observability import metrics as ot_metrics, otel as ot
def open_chat_request_span(

View file

@ -15,8 +15,7 @@ from collections.abc import Iterator
from typing import Any, Literal
from app.agents.new_chat.errors import BusyError
from app.observability import metrics as ot_metrics
from app.observability import otel as ot
from app.observability import metrics as ot_metrics, otel as ot
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.streaming.errors.classifier import classify_stream_exception
from app.tasks.chat.streaming.errors.emitter import emit_stream_terminal_error

View file

@ -72,7 +72,11 @@ def test_extract_returns_none_when_no_assistant_text_is_present() -> None:
anything?" rather than guess whether ``""`` means silence or empty
output. Empty-string contents are normalized to ``None`` too."""
no_ai = {"messages": [HumanMessage(content="just a question")]}
only_tools = {"messages": [AIMessage(content=[{"type": "tool_use", "name": "x", "input": {}}])]}
only_tools = {
"messages": [
AIMessage(content=[{"type": "tool_use", "name": "x", "input": {}}])
]
}
empty_string = {"messages": [AIMessage(content=" ")]}
assert extract_final_assistant_message(no_ai) is None

View file

@ -33,7 +33,9 @@ async def test_with_retries_returns_result_and_attempts_one_on_first_success() -
assert calls == 1
async def test_with_retries_returns_attempt_count_when_succeeding_after_failures() -> None:
async def test_with_retries_returns_attempt_count_when_succeeding_after_failures() -> (
None
):
"""A coroutine that fails twice then succeeds returns ``attempts=3``
(the actual attempt that produced the result). Locks the contract
that the caller can distinguish first-try success from a recovery."""

View file

@ -11,7 +11,9 @@ from app.automations.schemas.definition.plan_step import PlanStep
pytestmark = pytest.mark.unit
def test_automation_definition_accepts_minimal_valid_input_with_sensible_defaults() -> None:
def test_automation_definition_accepts_minimal_valid_input_with_sensible_defaults() -> (
None
):
"""A definition with just ``name`` + a one-step ``plan`` is valid and
fills in the rest with safe defaults so users don't have to write
out every section to get started."""

View file

@ -32,7 +32,9 @@ def test_environment_finalizes_datetime_output_to_iso_string() -> None:
when emitting ``inputs.fired_at`` and other datetime values."""
dt = datetime(2026, 5, 28, 14, 30, tzinfo=UTC)
assert render_template("{{ moment }}", {"moment": dt}) == "2026-05-28T14:30:00+00:00"
assert (
render_template("{{ moment }}", {"moment": dt}) == "2026-05-28T14:30:00+00:00"
)
def test_environment_finalizes_none_output_to_empty_string() -> None:

View file

@ -31,7 +31,7 @@ def test_action_definition_params_schema_reflects_params_model() -> None:
name="N",
description="D",
params_model=_Topic,
build_handler=lambda _ctx: (lambda _p: {}), # type: ignore[arg-type,return-value]
build_handler=lambda _ctx: lambda _p: {}, # type: ignore[arg-type,return-value]
)
schema = definition.params_schema

View file

@ -29,7 +29,9 @@ class _Params(BaseModel):
def _trigger(type_: str = "test_trigger") -> TriggerDefinition:
return TriggerDefinition(type=type_, description="Test trigger.", params_model=_Params)
return TriggerDefinition(
type=type_, description="Test trigger.", params_model=_Params
)
def _action(type_: str = "test_action") -> ActionDefinition:
@ -38,7 +40,7 @@ def _action(type_: str = "test_action") -> ActionDefinition:
name="Test",
description="Test action.",
params_model=_Params,
build_handler=lambda _ctx: (lambda _p: {}), # type: ignore[arg-type,return-value]
build_handler=lambda _ctx: lambda _p: {}, # type: ignore[arg-type,return-value]
)
@ -112,4 +114,4 @@ def test_all_triggers_returns_defensive_snapshot(
snapshot = all_triggers()
snapshot.pop("snapshot_test")
assert get_trigger("snapshot_test") is not None
assert get_trigger("snapshot_test") is not None

View file

@ -45,8 +45,12 @@ def test_compute_next_fire_at_respects_dst_offset_change() -> None:
winter_after = datetime(2026, 2, 15, 0, 0, tzinfo=UTC)
summer_after = datetime(2026, 4, 15, 0, 0, tzinfo=UTC)
winter_fire = compute_next_fire_at("0 9 * * *", "America/New_York", after=winter_after)
summer_fire = compute_next_fire_at("0 9 * * *", "America/New_York", after=summer_after)
winter_fire = compute_next_fire_at(
"0 9 * * *", "America/New_York", after=winter_after
)
summer_fire = compute_next_fire_at(
"0 9 * * *", "America/New_York", after=summer_after
)
assert winter_fire == datetime(2026, 2, 15, 14, 0, tzinfo=UTC)
assert summer_fire == datetime(2026, 4, 15, 13, 0, tzinfo=UTC)

View file

@ -33,7 +33,6 @@ import pytest
from app.agents.new_chat.context import SurfSenseContextSchema
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.stream_new_chat import (
stream_new_chat as old_stream_new_chat,
stream_resume_chat as old_stream_resume_chat,
@ -152,7 +151,13 @@ class _FakeSurfsenseDoc:
"user_query, image_urls, docs, expected_title, expected_action",
[
("hello world", None, [], "Understanding your request", "Processing"),
("", ["data:image/png;base64,AAA"], [], "Understanding your request", "Processing"),
(
"",
["data:image/png;base64,AAA"],
[],
"Understanding your request",
"Processing",
),
("", None, [], "Understanding your request", "Processing"),
(
"doc question",
@ -209,9 +214,10 @@ def test_initial_thinking_step_collapses_many_doc_names() -> None:
def test_image_capability_passes_without_images() -> None:
assert check_image_input_capability(
user_image_data_urls=None, agent_config=None
) is None
assert (
check_image_input_capability(user_image_data_urls=None, agent_config=None)
is None
)
def test_image_capability_passes_when_capability_unknown() -> None:
@ -500,9 +506,7 @@ def test_can_recover_provider_rate_limit_rejects_non_rate_limit_exception() -> N
def test_spawn_set_ai_responding_bg_noop_without_user_id() -> None:
async def _run() -> set[asyncio.Task]:
background: set[asyncio.Task] = set()
spawn_set_ai_responding_bg(
chat_id=1, user_id=None, background_tasks=background
)
spawn_set_ai_responding_bg(chat_id=1, user_id=None, background_tasks=background)
return background
bg = asyncio.run(_run())