mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-29 19:35:20 +02:00
chore: linting
This commit is contained in:
parent
4dda02c06c
commit
94e834134f
80 changed files with 443 additions and 404 deletions
|
|
@ -98,9 +98,7 @@ def upgrade() -> None:
|
|||
op.execute(
|
||||
"CREATE INDEX ix_automation_triggers_automation_id ON automation_triggers(automation_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX ix_automation_triggers_type ON automation_triggers(type);"
|
||||
)
|
||||
op.execute("CREATE INDEX ix_automation_triggers_type ON automation_triggers(type);")
|
||||
op.execute(
|
||||
"CREATE INDEX ix_automation_triggers_enabled ON automation_triggers(enabled);"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ from __future__ import annotations
|
|||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
||||
_HEADER = """\
|
||||
You are the SurfSense automation drafter. Convert the user intent below
|
||||
into a SINGLE JSON object matching the AutomationCreate schema. Output
|
||||
|
|
|
|||
|
|
@ -404,9 +404,7 @@ def build_task_tool_with_parent_config(
|
|||
continue
|
||||
messages = payload.get("messages") or []
|
||||
last_text = _safe_message_text(messages[-1]).rstrip() if messages else ""
|
||||
message_blocks.append(
|
||||
f"[task {task_index}] {last_text or '<empty>'}"
|
||||
)
|
||||
message_blocks.append(f"[task {task_index}] {last_text or '<empty>'}")
|
||||
try:
|
||||
child_trace = _build_tool_trace(messages)
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -117,9 +117,7 @@ def create_generate_podcast_tool(
|
|||
"podcast_id": podcast_id,
|
||||
"title": podcast_title,
|
||||
"file_location": file_location,
|
||||
"message": (
|
||||
"Podcast generated and saved to your podcast panel."
|
||||
),
|
||||
"message": ("Podcast generated and saved to your podcast panel."),
|
||||
}
|
||||
return with_receipt(
|
||||
payload=payload,
|
||||
|
|
|
|||
|
|
@ -126,8 +126,7 @@ def create_generate_video_presentation_tool(
|
|||
elapsed,
|
||||
)
|
||||
err = (
|
||||
"Background worker reported FAILED status for this "
|
||||
"video presentation."
|
||||
"Background worker reported FAILED status for this video presentation."
|
||||
)
|
||||
payload = {
|
||||
"status": VideoPresentationStatus.FAILED.value,
|
||||
|
|
@ -151,9 +150,7 @@ def create_generate_video_presentation_tool(
|
|||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.exception(
|
||||
"[generate_video_presentation] Error: %s", error_message
|
||||
)
|
||||
logger.exception("[generate_video_presentation] Error: %s", error_message)
|
||||
payload = {
|
||||
"status": VideoPresentationStatus.FAILED.value,
|
||||
"error": error_message,
|
||||
|
|
|
|||
|
|
@ -131,9 +131,7 @@ def create_generate_podcast_tool(
|
|||
"podcast_id": podcast_id,
|
||||
"title": podcast_title,
|
||||
"file_location": file_location,
|
||||
"message": (
|
||||
"Podcast generated and saved to your podcast panel."
|
||||
),
|
||||
"message": ("Podcast generated and saved to your podcast panel."),
|
||||
}
|
||||
|
||||
# Only other terminal state is FAILED.
|
||||
|
|
@ -146,9 +144,7 @@ def create_generate_podcast_tool(
|
|||
"status": PodcastStatus.FAILED.value,
|
||||
"podcast_id": podcast_id,
|
||||
"title": podcast_title,
|
||||
"error": (
|
||||
"Background worker reported FAILED status for this podcast."
|
||||
),
|
||||
"error": ("Background worker reported FAILED status for this podcast."),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -127,9 +127,7 @@ def create_generate_video_presentation_tool(
|
|||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.exception(
|
||||
"[generate_video_presentation] Error: %s", error_message
|
||||
)
|
||||
logger.exception("[generate_video_presentation] Error: %s", error_message)
|
||||
return {
|
||||
"status": VideoPresentationStatus.FAILED.value,
|
||||
"error": error_message,
|
||||
|
|
|
|||
|
|
@ -21,4 +21,4 @@ __all__ = [
|
|||
]
|
||||
|
||||
# Built-in actions self-register at import time.
|
||||
from . import agent_task # noqa: E402, F401
|
||||
from . import agent_task # noqa: F401
|
||||
|
|
|
|||
|
|
@ -12,4 +12,4 @@ from .params import AgentTaskActionParams
|
|||
__all__ = ["AgentTaskActionParams", "build_handler"]
|
||||
|
||||
# Side-effect: register on the actions store.
|
||||
from . import definition # noqa: E402, F401
|
||||
from . import definition # noqa: F401
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent
|
|||
from app.db import ChatVisibility, async_session_maker
|
||||
|
||||
from ..types import ActionContext
|
||||
|
||||
from .auto_decide import build_auto_decisions
|
||||
from .dependencies import build_dependencies
|
||||
from .finalize import extract_final_assistant_message
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ class AutomationRun(BaseModel, TimestampMixin):
|
|||
definition_snapshot = Column(JSONB, nullable=False)
|
||||
|
||||
# merged & validated inputs the run was dispatched with
|
||||
# (trigger.static_inputs ∪ producer runtime data, static wins on collision)
|
||||
# (trigger.static_inputs union producer runtime data, static wins on collision)
|
||||
inputs = Column(JSONB, nullable=False, server_default="{}")
|
||||
# one entry per executed step; agent_task entries carry their own
|
||||
# `agent_session_id` inside their entry
|
||||
|
|
|
|||
|
|
@ -6,9 +6,9 @@ from typing import Any
|
|||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.automations.actions.types import ActionContext
|
||||
from app.automations.persistence.enums.run_status import RunStatus
|
||||
from app.automations.persistence.models.run import AutomationRun
|
||||
from app.automations.actions.types import ActionContext
|
||||
from app.automations.schemas.definition.envelope import AutomationDefinition
|
||||
from app.automations.schemas.definition.plan_step import PlanStep
|
||||
from app.automations.templating import build_run_context
|
||||
|
|
@ -32,7 +32,10 @@ async def execute_run(session: AsyncSession, run_id: int) -> None:
|
|||
await repository.mark_failed(
|
||||
session,
|
||||
run,
|
||||
{"message": f"definition_snapshot invalid: {exc}", "type": type(exc).__name__},
|
||||
{
|
||||
"message": f"definition_snapshot invalid: {exc}",
|
||||
"type": type(exc).__name__,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
return
|
||||
|
|
@ -92,7 +95,9 @@ async def _run_on_failure(
|
|||
await session.commit()
|
||||
|
||||
|
||||
def _build_template_ctx(run: AutomationRun, step_outputs: dict[str, Any]) -> dict[str, Any]:
|
||||
def _build_template_ctx(
|
||||
run: AutomationRun, step_outputs: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
automation = run.automation
|
||||
trigger = run.trigger
|
||||
return build_run_context(
|
||||
|
|
|
|||
|
|
@ -30,14 +30,18 @@ async def execute_step(
|
|||
try:
|
||||
should_run = evaluate_predicate(step.when, template_context)
|
||||
except Exception as exc:
|
||||
return _result(step, "failed", started_at, attempts=0, error=_error(exc, "when"))
|
||||
return _result(
|
||||
step, "failed", started_at, attempts=0, error=_error(exc, "when")
|
||||
)
|
||||
if not should_run:
|
||||
return _result(step, "skipped", started_at, attempts=0)
|
||||
|
||||
try:
|
||||
resolved_params = render_value(step.params, template_context)
|
||||
except Exception as exc:
|
||||
return _result(step, "failed", started_at, attempts=0, error=_error(exc, "render"))
|
||||
return _result(
|
||||
step, "failed", started_at, attempts=0, error=_error(exc, "render")
|
||||
)
|
||||
|
||||
action = get_action(step.action)
|
||||
if action is None:
|
||||
|
|
@ -46,12 +50,17 @@ async def execute_step(
|
|||
"failed",
|
||||
started_at,
|
||||
attempts=0,
|
||||
error={"message": f"action not registered: {step.action}", "type": "ActionNotFound"},
|
||||
error={
|
||||
"message": f"action not registered: {step.action}",
|
||||
"type": "ActionNotFound",
|
||||
},
|
||||
)
|
||||
|
||||
handler = action.build_handler(action_context)
|
||||
|
||||
max_retries = step.max_retries if step.max_retries is not None else default_max_retries
|
||||
max_retries = (
|
||||
step.max_retries if step.max_retries is not None else default_max_retries
|
||||
)
|
||||
timeout = step.timeout_seconds or default_timeout_seconds
|
||||
|
||||
try:
|
||||
|
|
@ -62,7 +71,9 @@ async def execute_step(
|
|||
timeout=timeout,
|
||||
)
|
||||
except Exception as exc:
|
||||
return _result(step, "failed", started_at, attempts=max_retries + 1, error=_error(exc))
|
||||
return _result(
|
||||
step, "failed", started_at, attempts=max_retries + 1, error=_error(exc)
|
||||
)
|
||||
|
||||
return _result(step, "succeeded", started_at, attempts=attempts, result=result)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,9 @@ from .plan_step import PlanStep
|
|||
class Execution(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
timeout_seconds: int = Field(default=600, gt=0, description="Wall-clock cap for the run.")
|
||||
timeout_seconds: int = Field(
|
||||
default=600, gt=0, description="Wall-clock cap for the run."
|
||||
)
|
||||
max_retries: int = Field(default=2, ge=0, description="Per-step retry budget.")
|
||||
retry_backoff: Literal["exponential", "linear", "none"] = "exponential"
|
||||
concurrency: Literal["drop_if_running", "queue", "always"] = "drop_if_running"
|
||||
|
|
|
|||
|
|
@ -11,7 +11,9 @@ class PlanStep(BaseModel):
|
|||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
step_id: str = Field(..., min_length=1, description="Unique within the plan.")
|
||||
action: str = Field(..., min_length=1, description="Action type; resolved via registry.")
|
||||
action: str = Field(
|
||||
..., min_length=1, description="Action type; resolved via registry."
|
||||
)
|
||||
when: str | None = Field(
|
||||
default=None,
|
||||
description="Optional predicate; step is skipped when falsy.",
|
||||
|
|
|
|||
|
|
@ -10,7 +10,9 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||
class TriggerSpec(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
type: str = Field(..., min_length=1, description="Trigger type; resolved via registry.")
|
||||
type: str = Field(
|
||||
..., min_length=1, description="Trigger type; resolved via registry."
|
||||
)
|
||||
params: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Type-specific params; validated against the trigger's schema.",
|
||||
|
|
|
|||
|
|
@ -10,14 +10,14 @@ from sqlalchemy import func, select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.automations.persistence.enums.trigger_type import TriggerType
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
from app.automations.schemas.api import (
|
||||
AutomationCreate,
|
||||
AutomationUpdate,
|
||||
TriggerCreate,
|
||||
)
|
||||
from app.automations.persistence.enums.trigger_type import TriggerType
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
from app.automations.triggers import get_trigger
|
||||
from app.automations.triggers.schedule import compute_next_fire_at
|
||||
from app.db import Permission, User, get_async_session
|
||||
|
|
@ -34,7 +34,9 @@ class AutomationService:
|
|||
|
||||
async def create(self, payload: AutomationCreate) -> Automation:
|
||||
"""Create an automation and its initial triggers in one transaction."""
|
||||
await self._authorize(payload.search_space_id, Permission.AUTOMATIONS_CREATE.value)
|
||||
await self._authorize(
|
||||
payload.search_space_id, Permission.AUTOMATIONS_CREATE.value
|
||||
)
|
||||
|
||||
automation = Automation(
|
||||
search_space_id=payload.search_space_id,
|
||||
|
|
@ -67,22 +69,32 @@ class AutomationService:
|
|||
)
|
||||
|
||||
rows = (
|
||||
await self.session.execute(
|
||||
base.order_by(Automation.created_at.desc()).limit(limit).offset(offset)
|
||||
(
|
||||
await self.session.execute(
|
||||
base.order_by(Automation.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
)
|
||||
).scalars().all()
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
return list(rows), int(total or 0)
|
||||
|
||||
async def get(self, automation_id: int) -> Automation:
|
||||
"""Get an automation with its triggers loaded."""
|
||||
automation = await self._get_with_triggers_or_raise(automation_id)
|
||||
await self._authorize(automation.search_space_id, Permission.AUTOMATIONS_READ.value)
|
||||
await self._authorize(
|
||||
automation.search_space_id, Permission.AUTOMATIONS_READ.value
|
||||
)
|
||||
return automation
|
||||
|
||||
async def update(self, automation_id: int, patch: AutomationUpdate) -> Automation:
|
||||
"""Patch fields. Bumps ``version`` when ``definition`` changes."""
|
||||
automation = await self._get_with_triggers_or_raise(automation_id)
|
||||
await self._authorize(automation.search_space_id, Permission.AUTOMATIONS_UPDATE.value)
|
||||
await self._authorize(
|
||||
automation.search_space_id, Permission.AUTOMATIONS_UPDATE.value
|
||||
)
|
||||
|
||||
data = patch.model_dump(exclude_unset=True)
|
||||
|
||||
|
|
@ -93,7 +105,9 @@ class AutomationService:
|
|||
if "status" in data:
|
||||
automation.status = data["status"]
|
||||
if "definition" in data:
|
||||
automation.definition = patch.definition.model_dump(mode="json", by_alias=True)
|
||||
automation.definition = patch.definition.model_dump(
|
||||
mode="json", by_alias=True
|
||||
)
|
||||
automation.version += 1
|
||||
|
||||
await self.session.commit()
|
||||
|
|
@ -102,7 +116,9 @@ class AutomationService:
|
|||
async def delete(self, automation_id: int) -> None:
|
||||
"""Delete an automation; FK cascades remove triggers and runs."""
|
||||
automation = await self._get_or_raise(automation_id)
|
||||
await self._authorize(automation.search_space_id, Permission.AUTOMATIONS_DELETE.value)
|
||||
await self._authorize(
|
||||
automation.search_space_id, Permission.AUTOMATIONS_DELETE.value
|
||||
)
|
||||
await self.session.delete(automation)
|
||||
await self.session.commit()
|
||||
|
||||
|
|
@ -141,7 +157,9 @@ def _build_trigger(spec: TriggerCreate) -> AutomationTrigger:
|
|||
"""Validate trigger params via its registered Pydantic model and build the ORM row."""
|
||||
definition = get_trigger(spec.type.value)
|
||||
if definition is None:
|
||||
raise HTTPException(status_code=422, detail=f"unknown trigger type {spec.type.value!r}")
|
||||
raise HTTPException(
|
||||
status_code=422, detail=f"unknown trigger type {spec.type.value!r}"
|
||||
)
|
||||
|
||||
try:
|
||||
validated = definition.params_model.model_validate(spec.params)
|
||||
|
|
|
|||
|
|
@ -36,10 +36,16 @@ class RunService:
|
|||
)
|
||||
|
||||
rows = (
|
||||
await self.session.execute(
|
||||
base.order_by(AutomationRun.created_at.desc()).limit(limit).offset(offset)
|
||||
(
|
||||
await self.session.execute(
|
||||
base.order_by(AutomationRun.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
)
|
||||
).scalars().all()
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
return list(rows), int(total or 0)
|
||||
|
||||
async def get(self, *, automation_id: int, run_id: int) -> AutomationRun:
|
||||
|
|
|
|||
|
|
@ -8,10 +8,10 @@ from fastapi import Depends, HTTPException
|
|||
from pydantic import ValidationError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.automations.schemas.api import TriggerCreate, TriggerUpdate
|
||||
from app.automations.persistence.enums.trigger_type import TriggerType
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
from app.automations.schemas.api import TriggerCreate, TriggerUpdate
|
||||
from app.automations.triggers import get_trigger
|
||||
from app.automations.triggers.schedule import compute_next_fire_at
|
||||
from app.db import Permission, User, get_async_session
|
||||
|
|
@ -40,7 +40,9 @@ class TriggerService:
|
|||
params=validated_params,
|
||||
static_inputs=payload.static_inputs,
|
||||
enabled=payload.enabled,
|
||||
next_fire_at=_initial_next_fire(payload.type, validated_params, payload.enabled),
|
||||
next_fire_at=_initial_next_fire(
|
||||
payload.type, validated_params, payload.enabled
|
||||
),
|
||||
)
|
||||
self.session.add(trigger)
|
||||
await self.session.commit()
|
||||
|
|
@ -54,7 +56,9 @@ class TriggerService:
|
|||
trigger_id: int,
|
||||
patch: TriggerUpdate,
|
||||
) -> AutomationTrigger:
|
||||
await self._authorize_automation(automation_id, Permission.AUTOMATIONS_UPDATE.value)
|
||||
await self._authorize_automation(
|
||||
automation_id, Permission.AUTOMATIONS_UPDATE.value
|
||||
)
|
||||
trigger = await self._get_trigger_or_raise(automation_id, trigger_id)
|
||||
|
||||
data = patch.model_dump(exclude_unset=True)
|
||||
|
|
@ -80,7 +84,9 @@ class TriggerService:
|
|||
return trigger
|
||||
|
||||
async def remove(self, *, automation_id: int, trigger_id: int) -> None:
|
||||
await self._authorize_automation(automation_id, Permission.AUTOMATIONS_UPDATE.value)
|
||||
await self._authorize_automation(
|
||||
automation_id, Permission.AUTOMATIONS_UPDATE.value
|
||||
)
|
||||
trigger = await self._get_trigger_or_raise(automation_id, trigger_id)
|
||||
await self.session.delete(trigger)
|
||||
await self.session.commit()
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ TASK_NAME = "automation_run_execute"
|
|||
|
||||
|
||||
@celery_app.task(name=TASK_NAME, bind=True)
|
||||
def automation_run_execute(self, run_id: int) -> None: # noqa: ARG001 — Celery bind
|
||||
def automation_run_execute(self, run_id: int) -> None:
|
||||
"""Execute one ``AutomationRun``. Idempotent: terminal runs no-op."""
|
||||
return run_async_celery_task(lambda: _impl(run_id))
|
||||
|
||||
|
|
|
|||
|
|
@ -103,9 +103,7 @@ async def _self_heal_null_next_fire(session: AsyncSession, *, now: datetime) ->
|
|||
await session.commit()
|
||||
|
||||
|
||||
async def _claim_due_triggers(
|
||||
session: AsyncSession, *, now: datetime
|
||||
) -> list[_Claim]:
|
||||
async def _claim_due_triggers(session: AsyncSession, *, now: datetime) -> list[_Claim]:
|
||||
"""Lock and advance due rows; return per-trigger fire context."""
|
||||
stmt = (
|
||||
select(AutomationTrigger)
|
||||
|
|
|
|||
|
|
@ -17,4 +17,4 @@ __all__ = [
|
|||
]
|
||||
|
||||
# Built-in triggers self-register at import time.
|
||||
from . import schedule # noqa: E402, F401
|
||||
from . import schedule # noqa: F401
|
||||
|
|
|
|||
|
|
@ -15,4 +15,4 @@ __all__ = [
|
|||
]
|
||||
|
||||
# Side-effect: register on the triggers store.
|
||||
from . import definition # noqa: E402, F401
|
||||
from . import definition # noqa: F401
|
||||
|
|
|
|||
|
|
@ -32,6 +32,10 @@ def compute_next_fire_at(cron: str, timezone: str, *, after: datetime) -> dateti
|
|||
given timezone before evaluation so DST and IANA rules apply correctly.
|
||||
"""
|
||||
tz = ZoneInfo(timezone)
|
||||
base = after.astimezone(tz) if after.tzinfo else after.replace(tzinfo=UTC).astimezone(tz)
|
||||
base = (
|
||||
after.astimezone(tz)
|
||||
if after.tzinfo
|
||||
else after.replace(tzinfo=UTC).astimezone(tz)
|
||||
)
|
||||
nxt: datetime = croniter(cron, base).get_next(datetime)
|
||||
return nxt.astimezone(UTC)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,9 @@ from .cron import InvalidCronError, validate_cron
|
|||
class ScheduleTriggerParams(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
cron: str = Field(..., description="Five-field cron expression.", examples=["0 9 * * 1-5"])
|
||||
cron: str = Field(
|
||||
..., description="Five-field cron expression.", examples=["0 9 * * 1-5"]
|
||||
)
|
||||
timezone: str = Field(..., description="IANA timezone.", examples=["Africa/Kigali"])
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
|
|
|||
|
|
@ -2605,7 +2605,6 @@ from app.automations.persistence import ( # noqa: E402, F401
|
|||
AutomationTrigger,
|
||||
)
|
||||
|
||||
|
||||
engine = create_async_engine(
|
||||
DATABASE_URL,
|
||||
pool_size=30,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from fastapi import APIRouter
|
||||
|
||||
from app.automations.api import router as automations_router
|
||||
|
||||
from .agent_action_log_route import router as agent_action_log_router
|
||||
from .agent_flags_route import router as agent_flags_router
|
||||
from .agent_permissions_route import router as agent_permissions_router
|
||||
|
|
@ -7,7 +9,6 @@ from .agent_revert_route import router as agent_revert_router
|
|||
from .airtable_add_connector_route import (
|
||||
router as airtable_add_connector_router,
|
||||
)
|
||||
from app.automations.api import router as automations_router
|
||||
from .chat_comments_routes import router as chat_comments_router
|
||||
from .circleback_webhook_route import router as circleback_webhook_router
|
||||
from .clickup_add_connector_route import router as clickup_add_connector_router
|
||||
|
|
|
|||
|
|
@ -19,9 +19,7 @@ def extract_todos_from_deepagents(command_output: Any) -> dict:
|
|||
elif isinstance(command_output, dict):
|
||||
if "todos" in command_output:
|
||||
todos_data = command_output.get("todos", [])
|
||||
elif "update" in command_output and isinstance(
|
||||
command_output["update"], dict
|
||||
):
|
||||
elif "update" in command_output and isinstance(command_output["update"], dict):
|
||||
todos_data = command_output["update"].get("todos", [])
|
||||
|
||||
return {"todos": todos_data}
|
||||
|
|
|
|||
|
|
@ -69,17 +69,13 @@ async def resolve_initial_auto_pin(
|
|||
"pin.requires_image_input": requires_image_input,
|
||||
},
|
||||
)
|
||||
return AutoPinResult(
|
||||
llm_config_id=pinned.resolved_llm_config_id, error=None
|
||||
)
|
||||
return AutoPinResult(llm_config_id=pinned.resolved_llm_config_id, error=None)
|
||||
except ValueError as pin_error:
|
||||
# The "no vision-capable cfg" path raises a ValueError whose message
|
||||
# we map to the friendly image-input SSE error so the user sees the
|
||||
# same message regardless of whether the gate fired in the resolver or
|
||||
# in ``llm_capability.assert_vision_capability_for_image_turn``.
|
||||
is_vision_failure = (
|
||||
requires_image_input and "vision-capable" in str(pin_error)
|
||||
)
|
||||
is_vision_failure = requires_image_input and "vision-capable" in str(pin_error)
|
||||
error_code = (
|
||||
"MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
|
||||
if is_vision_failure
|
||||
|
|
|
|||
|
|
@ -207,9 +207,7 @@ async def _resolve_mentions_for_query(
|
|||
try:
|
||||
chip_objs.append(MentionedDocumentInfo.model_validate(raw))
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"stream_new_chat: dropping malformed mention chip %r", raw
|
||||
)
|
||||
logger.debug("stream_new_chat: dropping malformed mention chip %r", raw)
|
||||
|
||||
resolved = await resolve_mentions(
|
||||
session,
|
||||
|
|
|
|||
|
|
@ -48,9 +48,7 @@ def check_image_input_capability(
|
|||
return None
|
||||
|
||||
model_label = agent_config.config_name or agent_config.model_name or "model"
|
||||
ot.add_event(
|
||||
"quota.denied", {"quota.code": "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"}
|
||||
)
|
||||
ot.add_event("quota.denied", {"quota.code": "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"})
|
||||
return (
|
||||
(
|
||||
f"The selected model ({model_label}) does not support "
|
||||
|
|
|
|||
|
|
@ -259,7 +259,8 @@ async def stream_new_chat(
|
|||
|
||||
if needs_premium_quota(agent_config, user_id):
|
||||
premium_reservation = await reserve_premium(
|
||||
agent_config=agent_config, user_id=user_id # type: ignore[arg-type]
|
||||
agent_config=agent_config,
|
||||
user_id=user_id, # type: ignore[arg-type]
|
||||
)
|
||||
if not premium_reservation.allowed:
|
||||
ot.add_event("quota.denied", {"quota.code": "PREMIUM_QUOTA_EXHAUSTED"})
|
||||
|
|
@ -492,7 +493,9 @@ async def stream_new_chat(
|
|||
|
||||
# --- Block 4: First SSE frames ---
|
||||
|
||||
for sse in iter_initial_frames(streaming_service, turn_id=stream_result.turn_id):
|
||||
for sse in iter_initial_frames(
|
||||
streaming_service, turn_id=stream_result.turn_id
|
||||
):
|
||||
yield sse
|
||||
|
||||
# --- Block 5: Persistence join + message-id frames ---
|
||||
|
|
@ -693,7 +696,9 @@ async def stream_new_chat(
|
|||
fallback_commit_search_space_id=search_space_id,
|
||||
fallback_commit_created_by_id=user_id,
|
||||
fallback_commit_filesystem_mode=(
|
||||
filesystem_selection.mode if filesystem_selection else FilesystemMode.CLOUD
|
||||
filesystem_selection.mode
|
||||
if filesystem_selection
|
||||
else FilesystemMode.CLOUD
|
||||
),
|
||||
fallback_commit_thread_id=chat_id,
|
||||
runtime_context=runtime_context,
|
||||
|
|
@ -715,11 +720,7 @@ async def stream_new_chat(
|
|||
title_emitted = True
|
||||
# Account for the case where the task completed but produced no
|
||||
# title — flip the flag anyway so we don't keep checking it.
|
||||
if (
|
||||
title_task is not None
|
||||
and title_task.done()
|
||||
and not title_emitted
|
||||
):
|
||||
if title_task is not None and title_task.done() and not title_emitted:
|
||||
title_emitted = True
|
||||
|
||||
_perf_log.info(
|
||||
|
|
@ -811,9 +812,7 @@ async def stream_new_chat(
|
|||
end_turn(str(chat_id))
|
||||
|
||||
if premium_reservation is not None and user_id:
|
||||
await release_premium(
|
||||
reservation=premium_reservation, user_id=user_id
|
||||
)
|
||||
await release_premium(reservation=premium_reservation, user_id=user_id)
|
||||
|
||||
await close_session_and_clear_ai_responding(session, chat_id)
|
||||
|
||||
|
|
@ -852,9 +851,9 @@ async def stream_new_chat(
|
|||
|
||||
# Break circular refs held by the agent graph, tools, and LLM
|
||||
# wrappers so the GC can reclaim them in a single pass.
|
||||
agent = llm = connector_service = None # noqa: F841
|
||||
input_state = stream_result = None # noqa: F841
|
||||
session = None # noqa: F841
|
||||
agent = llm = connector_service = None
|
||||
input_state = stream_result = None
|
||||
session = None
|
||||
|
||||
run_gc_pass(log_prefix="stream_new_chat", chat_id=chat_id)
|
||||
close_chat_request_span(
|
||||
|
|
|
|||
|
|
@ -30,9 +30,7 @@ def build_new_chat_runtime_context(
|
|||
return SurfSenseContextSchema(
|
||||
search_space_id=search_space_id,
|
||||
mentioned_document_ids=list(mentioned_document_ids or []),
|
||||
mentioned_folder_ids=list(
|
||||
accepted_folder_ids or mentioned_folder_ids or []
|
||||
),
|
||||
mentioned_folder_ids=list(accepted_folder_ids or mentioned_folder_ids or []),
|
||||
request_id=request_id,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -133,12 +133,8 @@ async def _generate_title(
|
|||
# inherited Azure endpoint — see ``provider_api_base`` for the
|
||||
# same bug repro on the image-gen / vision paths.
|
||||
raw_model = getattr(llm, "model", "") or ""
|
||||
provider_prefix = (
|
||||
raw_model.split("/", 1)[0] if "/" in raw_model else None
|
||||
)
|
||||
provider_value = (
|
||||
agent_config.provider if agent_config is not None else None
|
||||
)
|
||||
provider_prefix = raw_model.split("/", 1)[0] if "/" in raw_model else None
|
||||
provider_value = agent_config.provider if agent_config is not None else None
|
||||
title_api_base = resolve_api_base(
|
||||
provider=provider_value,
|
||||
provider_prefix=provider_prefix,
|
||||
|
|
|
|||
|
|
@ -15,14 +15,10 @@ building blocks under ``flows/shared/``. Mirrors ``stream_new_chat`` but:
|
|||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import gc
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import uuid as _uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import anyio
|
||||
|
|
@ -32,7 +28,7 @@ from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
|||
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||
from app.agents.new_chat.middleware.busy_mutex import end_turn
|
||||
from app.config import config as _app_config
|
||||
from app.db import ChatVisibility, async_session_maker, shielded_async_session
|
||||
from app.db import ChatVisibility, async_session_maker
|
||||
from app.observability import otel as ot
|
||||
from app.services.chat_session_state_service import set_ai_responding
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
|
@ -89,7 +85,7 @@ from app.tasks.chat.streaming.flows.shared.terminal_error import (
|
|||
)
|
||||
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
||||
from app.tasks.chat.streaming.shared.utils import resume_step_prefix
|
||||
from app.utils.perf import get_perf_logger, log_system_snapshot
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_perf_log = get_perf_logger()
|
||||
|
|
@ -217,12 +213,11 @@ async def stream_resume_chat(
|
|||
|
||||
if needs_premium_quota(agent_config, user_id):
|
||||
premium_reservation = await reserve_premium(
|
||||
agent_config=agent_config, user_id=user_id # type: ignore[arg-type]
|
||||
agent_config=agent_config,
|
||||
user_id=user_id, # type: ignore[arg-type]
|
||||
)
|
||||
if not premium_reservation.allowed:
|
||||
ot.add_event(
|
||||
"quota.denied", {"quota.code": "PREMIUM_QUOTA_EXHAUSTED"}
|
||||
)
|
||||
ot.add_event("quota.denied", {"quota.code": "PREMIUM_QUOTA_EXHAUSTED"})
|
||||
if requested_llm_config_id == 0:
|
||||
try:
|
||||
pinned_fb = await resolve_or_get_pinned_llm_config_id(
|
||||
|
|
@ -396,7 +391,9 @@ async def stream_resume_chat(
|
|||
|
||||
# --- First SSE frames ---
|
||||
|
||||
for sse in iter_initial_frames(streaming_service, turn_id=stream_result.turn_id):
|
||||
for sse in iter_initial_frames(
|
||||
streaming_service, turn_id=stream_result.turn_id
|
||||
):
|
||||
yield sse
|
||||
|
||||
# --- Assistant-shell persistence + id frame ---
|
||||
|
|
@ -517,7 +514,9 @@ async def stream_resume_chat(
|
|||
fallback_commit_search_space_id=search_space_id,
|
||||
fallback_commit_created_by_id=user_id,
|
||||
fallback_commit_filesystem_mode=(
|
||||
filesystem_selection.mode if filesystem_selection else FilesystemMode.CLOUD
|
||||
filesystem_selection.mode
|
||||
if filesystem_selection
|
||||
else FilesystemMode.CLOUD
|
||||
),
|
||||
fallback_commit_thread_id=chat_id,
|
||||
runtime_context=runtime_context,
|
||||
|
|
@ -589,9 +588,7 @@ async def stream_resume_chat(
|
|||
end_turn(str(chat_id))
|
||||
|
||||
if premium_reservation is not None and user_id:
|
||||
await release_premium(
|
||||
reservation=premium_reservation, user_id=user_id
|
||||
)
|
||||
await release_premium(reservation=premium_reservation, user_id=user_id)
|
||||
|
||||
await close_session_and_clear_ai_responding(session, chat_id)
|
||||
|
||||
|
|
@ -609,13 +606,11 @@ async def stream_resume_chat(
|
|||
if not busy_error_raised:
|
||||
with contextlib.suppress(Exception):
|
||||
end_turn(str(chat_id))
|
||||
_perf_log.info(
|
||||
"[stream_resume] end_turn cleanup (chat_id=%s)", chat_id
|
||||
)
|
||||
_perf_log.info("[stream_resume] end_turn cleanup (chat_id=%s)", chat_id)
|
||||
|
||||
agent = llm = connector_service = None # noqa: F841
|
||||
stream_result = None # noqa: F841
|
||||
session = None # noqa: F841
|
||||
agent = llm = connector_service = None
|
||||
stream_result = None
|
||||
session = None
|
||||
|
||||
run_gc_pass(log_prefix="stream_resume", chat_id=chat_id)
|
||||
close_chat_request_span(
|
||||
|
|
|
|||
|
|
@ -47,9 +47,7 @@ async def build_resume_routing(
|
|||
slice_decisions_by_tool_call,
|
||||
)
|
||||
|
||||
parent_state = await agent.aget_state(
|
||||
{"configurable": {"thread_id": str(chat_id)}}
|
||||
)
|
||||
parent_state = await agent.aget_state({"configurable": {"thread_id": str(chat_id)}})
|
||||
pending = collect_pending_tool_calls(parent_state)
|
||||
_perf_log.info(
|
||||
"[hitl_route] resume_entry chat_id=%s decisions=%d pending_subagents=%d",
|
||||
|
|
|
|||
|
|
@ -49,9 +49,7 @@ async def finalize_assistant_message(
|
|||
was never assigned.
|
||||
"""
|
||||
if not (
|
||||
stream_result
|
||||
and stream_result.turn_id
|
||||
and stream_result.assistant_message_id
|
||||
stream_result and stream_result.turn_id and stream_result.assistant_message_id
|
||||
):
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -39,9 +39,7 @@ async def close_session_and_clear_ai_responding(
|
|||
async with shielded_async_session() as fresh_session:
|
||||
await clear_ai_responding(fresh_session, chat_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to clear AI responding state for thread %s", chat_id
|
||||
)
|
||||
logger.warning("Failed to clear AI responding state for thread %s", chat_id)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
session.expunge_all()
|
||||
|
|
|
|||
|
|
@ -41,9 +41,7 @@ class PremiumReservation:
|
|||
allowed: bool
|
||||
|
||||
|
||||
def needs_premium_quota(
|
||||
agent_config: AgentConfig | None, user_id: str | None
|
||||
) -> bool:
|
||||
def needs_premium_quota(agent_config: AgentConfig | None, user_id: str | None) -> bool:
|
||||
return bool(agent_config is not None and user_id and agent_config.is_premium)
|
||||
|
||||
|
||||
|
|
@ -61,8 +59,10 @@ async def reserve_premium(
|
|||
request_id = _uuid.uuid4().hex[:16]
|
||||
litellm_params = agent_config.litellm_params or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
|
||||
) or agent_config.model_name or ""
|
||||
(litellm_params.get("base_model") if isinstance(litellm_params, dict) else None)
|
||||
or agent_config.model_name
|
||||
or ""
|
||||
)
|
||||
reserve_amount_micros = estimate_call_reserve_micros(
|
||||
base_model=base_model,
|
||||
quota_reserve_tokens=agent_config.quota_reserve_tokens,
|
||||
|
|
|
|||
|
|
@ -6,8 +6,7 @@ import contextlib
|
|||
import sys
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.observability import metrics as ot_metrics
|
||||
from app.observability import otel as ot
|
||||
from app.observability import metrics as ot_metrics, otel as ot
|
||||
|
||||
|
||||
def open_chat_request_span(
|
||||
|
|
|
|||
|
|
@ -15,8 +15,7 @@ from collections.abc import Iterator
|
|||
from typing import Any, Literal
|
||||
|
||||
from app.agents.new_chat.errors import BusyError
|
||||
from app.observability import metrics as ot_metrics
|
||||
from app.observability import otel as ot
|
||||
from app.observability import metrics as ot_metrics, otel as ot
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.streaming.errors.classifier import classify_stream_exception
|
||||
from app.tasks.chat.streaming.errors.emitter import emit_stream_terminal_error
|
||||
|
|
|
|||
|
|
@ -72,7 +72,11 @@ def test_extract_returns_none_when_no_assistant_text_is_present() -> None:
|
|||
anything?" rather than guess whether ``""`` means silence or empty
|
||||
output. Empty-string contents are normalized to ``None`` too."""
|
||||
no_ai = {"messages": [HumanMessage(content="just a question")]}
|
||||
only_tools = {"messages": [AIMessage(content=[{"type": "tool_use", "name": "x", "input": {}}])]}
|
||||
only_tools = {
|
||||
"messages": [
|
||||
AIMessage(content=[{"type": "tool_use", "name": "x", "input": {}}])
|
||||
]
|
||||
}
|
||||
empty_string = {"messages": [AIMessage(content=" ")]}
|
||||
|
||||
assert extract_final_assistant_message(no_ai) is None
|
||||
|
|
|
|||
|
|
@ -33,7 +33,9 @@ async def test_with_retries_returns_result_and_attempts_one_on_first_success() -
|
|||
assert calls == 1
|
||||
|
||||
|
||||
async def test_with_retries_returns_attempt_count_when_succeeding_after_failures() -> None:
|
||||
async def test_with_retries_returns_attempt_count_when_succeeding_after_failures() -> (
|
||||
None
|
||||
):
|
||||
"""A coroutine that fails twice then succeeds returns ``attempts=3``
|
||||
(the actual attempt that produced the result). Locks the contract
|
||||
that the caller can distinguish first-try success from a recovery."""
|
||||
|
|
|
|||
|
|
@ -11,7 +11,9 @@ from app.automations.schemas.definition.plan_step import PlanStep
|
|||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_automation_definition_accepts_minimal_valid_input_with_sensible_defaults() -> None:
|
||||
def test_automation_definition_accepts_minimal_valid_input_with_sensible_defaults() -> (
|
||||
None
|
||||
):
|
||||
"""A definition with just ``name`` + a one-step ``plan`` is valid and
|
||||
fills in the rest with safe defaults so users don't have to write
|
||||
out every section to get started."""
|
||||
|
|
|
|||
|
|
@ -32,7 +32,9 @@ def test_environment_finalizes_datetime_output_to_iso_string() -> None:
|
|||
when emitting ``inputs.fired_at`` and other datetime values."""
|
||||
dt = datetime(2026, 5, 28, 14, 30, tzinfo=UTC)
|
||||
|
||||
assert render_template("{{ moment }}", {"moment": dt}) == "2026-05-28T14:30:00+00:00"
|
||||
assert (
|
||||
render_template("{{ moment }}", {"moment": dt}) == "2026-05-28T14:30:00+00:00"
|
||||
)
|
||||
|
||||
|
||||
def test_environment_finalizes_none_output_to_empty_string() -> None:
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ def test_action_definition_params_schema_reflects_params_model() -> None:
|
|||
name="N",
|
||||
description="D",
|
||||
params_model=_Topic,
|
||||
build_handler=lambda _ctx: (lambda _p: {}), # type: ignore[arg-type,return-value]
|
||||
build_handler=lambda _ctx: lambda _p: {}, # type: ignore[arg-type,return-value]
|
||||
)
|
||||
|
||||
schema = definition.params_schema
|
||||
|
|
|
|||
|
|
@ -29,7 +29,9 @@ class _Params(BaseModel):
|
|||
|
||||
|
||||
def _trigger(type_: str = "test_trigger") -> TriggerDefinition:
|
||||
return TriggerDefinition(type=type_, description="Test trigger.", params_model=_Params)
|
||||
return TriggerDefinition(
|
||||
type=type_, description="Test trigger.", params_model=_Params
|
||||
)
|
||||
|
||||
|
||||
def _action(type_: str = "test_action") -> ActionDefinition:
|
||||
|
|
@ -38,7 +40,7 @@ def _action(type_: str = "test_action") -> ActionDefinition:
|
|||
name="Test",
|
||||
description="Test action.",
|
||||
params_model=_Params,
|
||||
build_handler=lambda _ctx: (lambda _p: {}), # type: ignore[arg-type,return-value]
|
||||
build_handler=lambda _ctx: lambda _p: {}, # type: ignore[arg-type,return-value]
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -112,4 +114,4 @@ def test_all_triggers_returns_defensive_snapshot(
|
|||
snapshot = all_triggers()
|
||||
snapshot.pop("snapshot_test")
|
||||
|
||||
assert get_trigger("snapshot_test") is not None
|
||||
assert get_trigger("snapshot_test") is not None
|
||||
|
|
|
|||
|
|
@ -45,8 +45,12 @@ def test_compute_next_fire_at_respects_dst_offset_change() -> None:
|
|||
winter_after = datetime(2026, 2, 15, 0, 0, tzinfo=UTC)
|
||||
summer_after = datetime(2026, 4, 15, 0, 0, tzinfo=UTC)
|
||||
|
||||
winter_fire = compute_next_fire_at("0 9 * * *", "America/New_York", after=winter_after)
|
||||
summer_fire = compute_next_fire_at("0 9 * * *", "America/New_York", after=summer_after)
|
||||
winter_fire = compute_next_fire_at(
|
||||
"0 9 * * *", "America/New_York", after=winter_after
|
||||
)
|
||||
summer_fire = compute_next_fire_at(
|
||||
"0 9 * * *", "America/New_York", after=summer_after
|
||||
)
|
||||
|
||||
assert winter_fire == datetime(2026, 2, 15, 14, 0, tzinfo=UTC)
|
||||
assert summer_fire == datetime(2026, 4, 15, 13, 0, tzinfo=UTC)
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ import pytest
|
|||
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
from app.tasks.chat.stream_new_chat import (
|
||||
stream_new_chat as old_stream_new_chat,
|
||||
stream_resume_chat as old_stream_resume_chat,
|
||||
|
|
@ -152,7 +151,13 @@ class _FakeSurfsenseDoc:
|
|||
"user_query, image_urls, docs, expected_title, expected_action",
|
||||
[
|
||||
("hello world", None, [], "Understanding your request", "Processing"),
|
||||
("", ["data:image/png;base64,AAA"], [], "Understanding your request", "Processing"),
|
||||
(
|
||||
"",
|
||||
["data:image/png;base64,AAA"],
|
||||
[],
|
||||
"Understanding your request",
|
||||
"Processing",
|
||||
),
|
||||
("", None, [], "Understanding your request", "Processing"),
|
||||
(
|
||||
"doc question",
|
||||
|
|
@ -209,9 +214,10 @@ def test_initial_thinking_step_collapses_many_doc_names() -> None:
|
|||
|
||||
|
||||
def test_image_capability_passes_without_images() -> None:
|
||||
assert check_image_input_capability(
|
||||
user_image_data_urls=None, agent_config=None
|
||||
) is None
|
||||
assert (
|
||||
check_image_input_capability(user_image_data_urls=None, agent_config=None)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
def test_image_capability_passes_when_capability_unknown() -> None:
|
||||
|
|
@ -500,9 +506,7 @@ def test_can_recover_provider_rate_limit_rejects_non_rate_limit_exception() -> N
|
|||
def test_spawn_set_ai_responding_bg_noop_without_user_id() -> None:
|
||||
async def _run() -> set[asyncio.Task]:
|
||||
background: set[asyncio.Task] = set()
|
||||
spawn_set_ai_responding_bg(
|
||||
chat_id=1, user_id=None, background_tasks=background
|
||||
)
|
||||
spawn_set_ai_responding_bg(chat_id=1, user_id=None, background_tasks=background)
|
||||
return background
|
||||
|
||||
bg = asyncio.run(_run())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue