mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-29 19:35:20 +02:00
feat(automations): static_inputs on triggers + vertical-slice api/services
This commit is contained in:
parent
84d99f19a2
commit
27ab367a13
27 changed files with 915 additions and 356 deletions
|
|
@ -87,6 +87,7 @@ def upgrade() -> None:
|
|||
REFERENCES automations(id) ON DELETE CASCADE,
|
||||
type automation_trigger_type NOT NULL,
|
||||
params JSONB NOT NULL,
|
||||
static_inputs JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||
last_fired_at TIMESTAMP WITH TIME ZONE,
|
||||
next_fire_at TIMESTAMP WITH TIME ZONE,
|
||||
|
|
@ -129,8 +130,7 @@ def upgrade() -> None:
|
|||
REFERENCES automation_triggers(id) ON DELETE SET NULL,
|
||||
status automation_run_status NOT NULL DEFAULT 'pending',
|
||||
definition_snapshot JSONB NOT NULL,
|
||||
trigger_payload JSONB,
|
||||
resolved_inputs JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
inputs JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
step_results JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
output JSONB,
|
||||
artifacts JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ AGENT_TASK_ACTION = ActionDefinition(
|
|||
type="agent_task",
|
||||
name="Agent task",
|
||||
description="Run a multi_agent_chat turn from an automation step.",
|
||||
params_schema=AgentTaskActionParams.model_json_schema(),
|
||||
params_model=AgentTaskActionParams,
|
||||
build_handler=build_handler,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from dataclasses import dataclass
|
|||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
|
|
@ -30,5 +31,10 @@ class ActionDefinition:
|
|||
type: str
|
||||
name: str
|
||||
description: str
|
||||
params_schema: dict[str, Any]
|
||||
params_model: type[BaseModel]
|
||||
build_handler: ActionHandlerFactory
|
||||
|
||||
@property
|
||||
def params_schema(self) -> dict[str, Any]:
|
||||
"""JSON Schema (draft 2020-12) derived from ``params_model``."""
|
||||
return self.params_model.model_json_schema()
|
||||
|
|
|
|||
|
|
@ -5,8 +5,12 @@ from __future__ import annotations
|
|||
from fastapi import APIRouter
|
||||
|
||||
from .automation import router as automation_router
|
||||
from .run import router as run_router
|
||||
from .trigger import router as trigger_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(automation_router)
|
||||
router.include_router(trigger_router)
|
||||
router.include_router(run_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
|
|
|
|||
|
|
@ -1,23 +1,80 @@
|
|||
"""Routes for the ``Automation`` resource."""
|
||||
"""HTTP routes for the ``Automation`` resource."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
|
||||
from app.automations.api.schemas import RunDispatched
|
||||
from app.automations.schemas.api import (
|
||||
AutomationCreate,
|
||||
AutomationDetail,
|
||||
AutomationList,
|
||||
AutomationSummary,
|
||||
AutomationUpdate,
|
||||
)
|
||||
from app.automations.services import AutomationService, get_automation_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/automations/{automation_id}/run", response_model=RunDispatched)
|
||||
async def run_automation_now(
|
||||
automation_id: int,
|
||||
payload: dict[str, Any] | None = Body(default=None),
|
||||
@router.post(
|
||||
"/automations",
|
||||
response_model=AutomationDetail,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def create_automation(
|
||||
payload: AutomationCreate,
|
||||
service: AutomationService = Depends(get_automation_service),
|
||||
) -> RunDispatched:
|
||||
"""Fire a manual run."""
|
||||
run = await service.run_now(automation_id=automation_id, payload=payload)
|
||||
return RunDispatched(run_id=run.id, status=run.status)
|
||||
) -> AutomationDetail:
|
||||
"""Create an automation, optionally with initial triggers (atomic)."""
|
||||
automation = await service.create(payload)
|
||||
return AutomationDetail.model_validate(automation)
|
||||
|
||||
|
||||
@router.get("/automations", response_model=AutomationList)
|
||||
async def list_automations(
|
||||
search_space_id: int = Query(...),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
service: AutomationService = Depends(get_automation_service),
|
||||
) -> AutomationList:
|
||||
"""List automations in a search space."""
|
||||
items, total = await service.list(
|
||||
search_space_id=search_space_id, limit=limit, offset=offset
|
||||
)
|
||||
return AutomationList(
|
||||
items=[AutomationSummary.model_validate(a) for a in items],
|
||||
total=total,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/automations/{automation_id}", response_model=AutomationDetail)
|
||||
async def get_automation(
|
||||
automation_id: int,
|
||||
service: AutomationService = Depends(get_automation_service),
|
||||
) -> AutomationDetail:
|
||||
"""Get one automation with its definition and triggers."""
|
||||
automation = await service.get(automation_id)
|
||||
return AutomationDetail.model_validate(automation)
|
||||
|
||||
|
||||
@router.patch("/automations/{automation_id}", response_model=AutomationDetail)
|
||||
async def update_automation(
|
||||
automation_id: int,
|
||||
patch: AutomationUpdate,
|
||||
service: AutomationService = Depends(get_automation_service),
|
||||
) -> AutomationDetail:
|
||||
"""Partially update an automation. Triggers are managed separately."""
|
||||
automation = await service.update(automation_id, patch)
|
||||
return AutomationDetail.model_validate(automation)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/automations/{automation_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
)
|
||||
async def delete_automation(
|
||||
automation_id: int,
|
||||
service: AutomationService = Depends(get_automation_service),
|
||||
) -> None:
|
||||
"""Delete an automation; triggers and runs are removed by FK cascade."""
|
||||
await service.delete(automation_id)
|
||||
|
|
|
|||
71
surfsense_backend/app/automations/api/run.py
Normal file
71
surfsense_backend/app/automations/api/run.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
"""HTTP routes for automation runs (dispatch + history)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Query, status
|
||||
|
||||
from app.automations.schemas.api import (
|
||||
RunDetail,
|
||||
RunDispatched,
|
||||
RunList,
|
||||
RunSummary,
|
||||
)
|
||||
from app.automations.services import RunService, get_run_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/automations/{automation_id}/run",
|
||||
response_model=RunDispatched,
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
)
|
||||
async def run_automation_now(
|
||||
automation_id: int,
|
||||
inputs: dict[str, Any] | None = Body(default=None),
|
||||
service: RunService = Depends(get_run_service),
|
||||
) -> RunDispatched:
|
||||
"""Fire a manual run.
|
||||
|
||||
``inputs`` is the runtime payload supplied by the caller; it is merged with
|
||||
the manual trigger's ``static_inputs`` (static wins) and validated against
|
||||
the automation's input schema.
|
||||
"""
|
||||
run = await service.dispatch_manual(automation_id=automation_id, runtime_inputs=inputs)
|
||||
return RunDispatched(run_id=run.id, status=run.status)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/automations/{automation_id}/runs",
|
||||
response_model=RunList,
|
||||
)
|
||||
async def list_runs(
|
||||
automation_id: int,
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
service: RunService = Depends(get_run_service),
|
||||
) -> RunList:
|
||||
"""List run history for an automation, newest first."""
|
||||
items, total = await service.list(
|
||||
automation_id=automation_id, limit=limit, offset=offset
|
||||
)
|
||||
return RunList(
|
||||
items=[RunSummary.model_validate(r) for r in items],
|
||||
total=total,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/automations/{automation_id}/runs/{run_id}",
|
||||
response_model=RunDetail,
|
||||
)
|
||||
async def get_run(
|
||||
automation_id: int,
|
||||
run_id: int,
|
||||
service: RunService = Depends(get_run_service),
|
||||
) -> RunDetail:
|
||||
"""Get the full record of a single run, including step results and artifacts."""
|
||||
run = await service.get(automation_id=automation_id, run_id=run_id)
|
||||
return RunDetail.model_validate(run)
|
||||
55
surfsense_backend/app/automations/api/trigger.py
Normal file
55
surfsense_backend/app/automations/api/trigger.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
"""HTTP routes for triggers attached to an automation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
|
||||
from app.automations.schemas.api import TriggerCreate, TriggerDetail, TriggerUpdate
|
||||
from app.automations.services import TriggerService, get_trigger_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/automations/{automation_id}/triggers",
|
||||
response_model=TriggerDetail,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def add_trigger(
|
||||
automation_id: int,
|
||||
payload: TriggerCreate,
|
||||
service: TriggerService = Depends(get_trigger_service),
|
||||
) -> TriggerDetail:
|
||||
"""Attach a new trigger to an automation."""
|
||||
trigger = await service.add(automation_id=automation_id, payload=payload)
|
||||
return TriggerDetail.model_validate(trigger)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/automations/{automation_id}/triggers/{trigger_id}",
|
||||
response_model=TriggerDetail,
|
||||
)
|
||||
async def update_trigger(
|
||||
automation_id: int,
|
||||
trigger_id: int,
|
||||
patch: TriggerUpdate,
|
||||
service: TriggerService = Depends(get_trigger_service),
|
||||
) -> TriggerDetail:
|
||||
"""Toggle ``enabled`` or replace ``params``. Trigger type is immutable."""
|
||||
trigger = await service.update(
|
||||
automation_id=automation_id, trigger_id=trigger_id, patch=patch
|
||||
)
|
||||
return TriggerDetail.model_validate(trigger)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/automations/{automation_id}/triggers/{trigger_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
)
|
||||
async def remove_trigger(
|
||||
automation_id: int,
|
||||
trigger_id: int,
|
||||
service: TriggerService = Depends(get_trigger_service),
|
||||
) -> None:
|
||||
"""Detach a trigger from an automation."""
|
||||
await service.remove(automation_id=automation_id, trigger_id=trigger_id)
|
||||
|
|
@ -22,10 +22,14 @@ async def dispatch_run(
|
|||
session: AsyncSession,
|
||||
automation: Automation,
|
||||
trigger: AutomationTrigger,
|
||||
payload: dict[str, Any] | None,
|
||||
runtime_inputs: dict[str, Any] | None = None,
|
||||
) -> AutomationRun:
|
||||
"""Validate, snapshot the definition, persist an ``AutomationRun``, enqueue execution.
|
||||
|
||||
Final inputs = ``trigger.static_inputs`` merged with ``runtime_inputs``,
|
||||
static winning on key collision. The merged dict is validated against
|
||||
``automation.definition.inputs.schema_`` and stored on the run.
|
||||
|
||||
Callers (trigger-specific adapters) are responsible for resolving
|
||||
``automation`` and ``trigger`` and for the trigger-side ``ACTIVE`` /
|
||||
``enabled`` guards. This function only handles what's identical across
|
||||
|
|
@ -36,7 +40,8 @@ async def dispatch_run(
|
|||
except Exception as exc:
|
||||
raise DispatchError(f"invalid automation definition: {exc}") from exc
|
||||
|
||||
resolved_inputs = _validate_inputs(definition, payload or {})
|
||||
merged_inputs = {**(runtime_inputs or {}), **(trigger.static_inputs or {})}
|
||||
validated_inputs = _validate_inputs(definition, merged_inputs)
|
||||
snapshot = definition.model_dump(mode="json", by_alias=True)
|
||||
|
||||
run = AutomationRun(
|
||||
|
|
@ -44,8 +49,7 @@ async def dispatch_run(
|
|||
trigger_id=trigger.id,
|
||||
status=RunStatus.PENDING,
|
||||
definition_snapshot=snapshot,
|
||||
trigger_payload=payload,
|
||||
resolved_inputs=resolved_inputs,
|
||||
inputs=validated_inputs,
|
||||
step_results=[],
|
||||
artifacts=[],
|
||||
)
|
||||
|
|
@ -61,12 +65,12 @@ async def dispatch_run(
|
|||
|
||||
|
||||
def _validate_inputs(
|
||||
definition: AutomationDefinition, payload: dict[str, Any]
|
||||
definition: AutomationDefinition, inputs: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
if definition.inputs is None or not definition.inputs.schema_:
|
||||
return {}
|
||||
try:
|
||||
jsonschema.validate(instance=payload, schema=definition.inputs.schema_)
|
||||
jsonschema.validate(instance=inputs, schema=definition.inputs.schema_)
|
||||
except jsonschema.ValidationError as exc:
|
||||
raise DispatchError(f"inputs: {exc.message}") from exc
|
||||
return payload
|
||||
return inputs
|
||||
|
|
|
|||
|
|
@ -45,8 +45,9 @@ class AutomationRun(BaseModel, TimestampMixin):
|
|||
# locked at fire time so historical runs always show the exact code path
|
||||
definition_snapshot = Column(JSONB, nullable=False)
|
||||
|
||||
trigger_payload = Column(JSONB, nullable=True)
|
||||
resolved_inputs = Column(JSONB, nullable=False, server_default="{}")
|
||||
# merged & validated inputs the run was dispatched with
|
||||
# (trigger.static_inputs ∪ producer runtime data, static wins on collision)
|
||||
inputs = Column(JSONB, nullable=False, server_default="{}")
|
||||
# one entry per executed step; agent_task entries carry their own
|
||||
# `agent_session_id` inside their entry
|
||||
step_results = Column(JSONB, nullable=False, server_default="[]")
|
||||
|
|
|
|||
|
|
@ -36,6 +36,10 @@ class AutomationTrigger(BaseModel, TimestampMixin):
|
|||
|
||||
params = Column(JSONB, nullable=False)
|
||||
|
||||
# Per-attachment domain values merged into every dispatched run's inputs.
|
||||
# Static wins over runtime data on key collision.
|
||||
static_inputs = Column(JSONB, nullable=False, server_default="{}")
|
||||
|
||||
enabled = Column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ def _build_template_ctx(run: AutomationRun, step_outputs: dict[str, Any]) -> dic
|
|||
trigger_type=trigger.type.value if trigger else None,
|
||||
started_at=run.started_at,
|
||||
attempt=1,
|
||||
resolved_inputs=run.resolved_inputs or {},
|
||||
inputs=run.inputs or {},
|
||||
step_outputs=step_outputs,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -28,8 +28,7 @@ class RunDetail(RunSummary):
|
|||
"""Full run view including snapshot, results and artifacts."""
|
||||
|
||||
definition_snapshot: dict[str, Any]
|
||||
trigger_payload: dict[str, Any] | None = None
|
||||
resolved_inputs: dict[str, Any]
|
||||
inputs: dict[str, Any]
|
||||
step_results: list[dict[str, Any]]
|
||||
output: dict[str, Any] | None = None
|
||||
artifacts: list[dict[str, Any]]
|
||||
|
|
@ -17,6 +17,7 @@ class TriggerCreate(BaseModel):
|
|||
|
||||
type: TriggerType
|
||||
params: dict[str, Any] = Field(default_factory=dict)
|
||||
static_inputs: dict[str, Any] = Field(default_factory=dict)
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
|
|
@ -27,6 +28,7 @@ class TriggerUpdate(BaseModel):
|
|||
|
||||
enabled: bool | None = None
|
||||
params: dict[str, Any] | None = None
|
||||
static_inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TriggerDetail(BaseModel):
|
||||
|
|
@ -37,6 +39,7 @@ class TriggerDetail(BaseModel):
|
|||
id: int
|
||||
type: TriggerType
|
||||
params: dict[str, Any]
|
||||
static_inputs: dict[str, Any]
|
||||
enabled: bool
|
||||
last_fired_at: datetime | None = None
|
||||
next_fire_at: datetime | None = None
|
||||
|
|
@ -1,7 +1,16 @@
|
|||
"""Service layer for the automations feature."""
|
||||
"""Services for the automations HTTP layer (one service per resource)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .automation import AutomationService, get_automation_service
|
||||
from .run import RunService, get_run_service
|
||||
from .trigger import TriggerService, get_trigger_service
|
||||
|
||||
__all__ = ["AutomationService", "get_automation_service"]
|
||||
__all__ = [
|
||||
"AutomationService",
|
||||
"RunService",
|
||||
"TriggerService",
|
||||
"get_automation_service",
|
||||
"get_run_service",
|
||||
"get_trigger_service",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2,54 +2,111 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.automations.dispatch import DispatchError
|
||||
from app.automations.schemas.api import (
|
||||
AutomationCreate,
|
||||
AutomationUpdate,
|
||||
TriggerCreate,
|
||||
)
|
||||
from app.automations.persistence.enums.trigger_type import TriggerType
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.run import AutomationRun
|
||||
from app.automations.triggers.manual import dispatch_manual_run
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
from app.automations.triggers import get_trigger
|
||||
from app.automations.triggers.schedule import compute_next_fire_at
|
||||
from app.db import Permission, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
||||
class AutomationService:
|
||||
"""Service for the ``Automation`` resource."""
|
||||
"""Lifecycle of the ``Automation`` resource."""
|
||||
|
||||
def __init__(self, *, session: AsyncSession, user: User) -> None:
|
||||
self.session = session
|
||||
self.user = user
|
||||
|
||||
async def run_now(
|
||||
async def create(self, payload: AutomationCreate) -> Automation:
|
||||
"""Create an automation and its initial triggers in one transaction."""
|
||||
await self._authorize(payload.search_space_id, Permission.AUTOMATIONS_CREATE.value)
|
||||
|
||||
automation = Automation(
|
||||
search_space_id=payload.search_space_id,
|
||||
created_by_user_id=self.user.id,
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
definition=payload.definition.model_dump(mode="json", by_alias=True),
|
||||
version=1,
|
||||
)
|
||||
for spec in payload.triggers:
|
||||
automation.triggers.append(_build_trigger(spec))
|
||||
|
||||
self.session.add(automation)
|
||||
await self.session.commit()
|
||||
return await self._get_with_triggers_or_raise(automation.id)
|
||||
|
||||
async def list(
|
||||
self,
|
||||
*,
|
||||
automation_id: int,
|
||||
payload: dict[str, Any] | None,
|
||||
) -> AutomationRun:
|
||||
"""Fire a manual run for ``automation_id``."""
|
||||
automation = await self._get_automation_or_raise(automation_id)
|
||||
await check_permission(
|
||||
self.session,
|
||||
self.user,
|
||||
automation.search_space_id,
|
||||
Permission.AUTOMATIONS_EXECUTE.value,
|
||||
"You don't have permission to execute automations in this search space",
|
||||
search_space_id: int,
|
||||
limit: int,
|
||||
offset: int,
|
||||
) -> tuple[list[Automation], int]:
|
||||
"""Return a page of automations and the total count."""
|
||||
await self._authorize(search_space_id, Permission.AUTOMATIONS_READ.value)
|
||||
|
||||
base = select(Automation).where(Automation.search_space_id == search_space_id)
|
||||
total = await self.session.scalar(
|
||||
select(func.count()).select_from(base.subquery())
|
||||
)
|
||||
|
||||
try:
|
||||
return await dispatch_manual_run(
|
||||
session=self.session,
|
||||
automation_id=automation_id,
|
||||
payload=payload,
|
||||
rows = (
|
||||
await self.session.execute(
|
||||
base.order_by(Automation.created_at.desc()).limit(limit).offset(offset)
|
||||
)
|
||||
except DispatchError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
).scalars().all()
|
||||
return list(rows), int(total or 0)
|
||||
|
||||
async def _get_automation_or_raise(self, automation_id: int) -> Automation:
|
||||
"""Get the automation by id; 404 if missing."""
|
||||
async def get(self, automation_id: int) -> Automation:
|
||||
"""Get an automation with its triggers loaded."""
|
||||
automation = await self._get_with_triggers_or_raise(automation_id)
|
||||
await self._authorize(automation.search_space_id, Permission.AUTOMATIONS_READ.value)
|
||||
return automation
|
||||
|
||||
async def update(self, automation_id: int, patch: AutomationUpdate) -> Automation:
|
||||
"""Patch fields. Bumps ``version`` when ``definition`` changes."""
|
||||
automation = await self._get_with_triggers_or_raise(automation_id)
|
||||
await self._authorize(automation.search_space_id, Permission.AUTOMATIONS_UPDATE.value)
|
||||
|
||||
data = patch.model_dump(exclude_unset=True)
|
||||
|
||||
if "name" in data:
|
||||
automation.name = data["name"]
|
||||
if "description" in data:
|
||||
automation.description = data["description"]
|
||||
if "status" in data:
|
||||
automation.status = data["status"]
|
||||
if "definition" in data:
|
||||
automation.definition = patch.definition.model_dump(mode="json", by_alias=True)
|
||||
automation.version += 1
|
||||
|
||||
await self.session.commit()
|
||||
return await self._get_with_triggers_or_raise(automation_id)
|
||||
|
||||
async def delete(self, automation_id: int) -> None:
|
||||
"""Delete an automation; FK cascades remove triggers and runs."""
|
||||
automation = await self._get_or_raise(automation_id)
|
||||
await self._authorize(automation.search_space_id, Permission.AUTOMATIONS_DELETE.value)
|
||||
await self.session.delete(automation)
|
||||
await self.session.commit()
|
||||
|
||||
async def _get_or_raise(self, automation_id: int) -> Automation:
|
||||
automation = await self.session.get(Automation, automation_id)
|
||||
if automation is None:
|
||||
raise HTTPException(
|
||||
|
|
@ -57,6 +114,56 @@ class AutomationService:
|
|||
)
|
||||
return automation
|
||||
|
||||
async def _get_with_triggers_or_raise(self, automation_id: int) -> Automation:
|
||||
stmt = (
|
||||
select(Automation)
|
||||
.where(Automation.id == automation_id)
|
||||
.options(selectinload(Automation.triggers))
|
||||
)
|
||||
automation = (await self.session.execute(stmt)).scalar_one_or_none()
|
||||
if automation is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"automation {automation_id} not found"
|
||||
)
|
||||
return automation
|
||||
|
||||
async def _authorize(self, search_space_id: int, permission: str) -> None:
|
||||
await check_permission(
|
||||
self.session,
|
||||
self.user,
|
||||
search_space_id,
|
||||
permission,
|
||||
f"You don't have permission to {permission.split(':')[1]} automations in this search space",
|
||||
)
|
||||
|
||||
|
||||
def _build_trigger(spec: TriggerCreate) -> AutomationTrigger:
|
||||
"""Validate trigger params via its registered Pydantic model and build the ORM row."""
|
||||
definition = get_trigger(spec.type.value)
|
||||
if definition is None:
|
||||
raise HTTPException(status_code=422, detail=f"unknown trigger type {spec.type.value!r}")
|
||||
|
||||
try:
|
||||
validated = definition.params_model.model_validate(spec.params)
|
||||
except ValidationError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
|
||||
params = validated.model_dump(mode="json")
|
||||
|
||||
next_fire_at = None
|
||||
if spec.type == TriggerType.SCHEDULE and spec.enabled:
|
||||
next_fire_at = compute_next_fire_at(
|
||||
params["cron"], params["timezone"], after=datetime.now(UTC)
|
||||
)
|
||||
|
||||
return AutomationTrigger(
|
||||
type=spec.type,
|
||||
params=params,
|
||||
static_inputs=spec.static_inputs,
|
||||
enabled=spec.enabled,
|
||||
next_fire_at=next_fire_at,
|
||||
)
|
||||
|
||||
|
||||
def get_automation_service(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
|
|
|
|||
93
surfsense_backend/app/automations/services/run.py
Normal file
93
surfsense_backend/app/automations/services/run.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
"""``RunService`` — dispatch and history of automation runs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.automations.dispatch import DispatchError
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.run import AutomationRun
|
||||
from app.automations.triggers.manual import dispatch_manual_run
|
||||
from app.db import Permission, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
||||
class RunService:
|
||||
"""Lifecycle of the ``AutomationRun`` resource."""
|
||||
|
||||
def __init__(self, *, session: AsyncSession, user: User) -> None:
|
||||
self.session = session
|
||||
self.user = user
|
||||
|
||||
async def dispatch_manual(
|
||||
self,
|
||||
*,
|
||||
automation_id: int,
|
||||
runtime_inputs: dict[str, Any] | None,
|
||||
) -> AutomationRun:
|
||||
"""Fire a manual run via the registered manual trigger."""
|
||||
await self._authorize(automation_id, Permission.AUTOMATIONS_EXECUTE.value)
|
||||
try:
|
||||
return await dispatch_manual_run(
|
||||
session=self.session,
|
||||
automation_id=automation_id,
|
||||
runtime_inputs=runtime_inputs,
|
||||
)
|
||||
except DispatchError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
|
||||
async def list(
|
||||
self,
|
||||
*,
|
||||
automation_id: int,
|
||||
limit: int,
|
||||
offset: int,
|
||||
) -> tuple[list[AutomationRun], int]:
|
||||
"""Return a page of runs for an automation, newest first."""
|
||||
await self._authorize(automation_id, Permission.AUTOMATIONS_READ.value)
|
||||
|
||||
base = select(AutomationRun).where(AutomationRun.automation_id == automation_id)
|
||||
total = await self.session.scalar(
|
||||
select(func.count()).select_from(base.subquery())
|
||||
)
|
||||
|
||||
rows = (
|
||||
await self.session.execute(
|
||||
base.order_by(AutomationRun.created_at.desc()).limit(limit).offset(offset)
|
||||
)
|
||||
).scalars().all()
|
||||
return list(rows), int(total or 0)
|
||||
|
||||
async def get(self, *, automation_id: int, run_id: int) -> AutomationRun:
|
||||
await self._authorize(automation_id, Permission.AUTOMATIONS_READ.value)
|
||||
run = await self.session.get(AutomationRun, run_id)
|
||||
if run is None or run.automation_id != automation_id:
|
||||
raise HTTPException(status_code=404, detail=f"run {run_id} not found")
|
||||
return run
|
||||
|
||||
async def _authorize(self, automation_id: int, permission: str) -> Automation:
|
||||
automation = await self.session.get(Automation, automation_id)
|
||||
if automation is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"automation {automation_id} not found"
|
||||
)
|
||||
await check_permission(
|
||||
self.session,
|
||||
self.user,
|
||||
automation.search_space_id,
|
||||
permission,
|
||||
f"You don't have permission to {permission.split(':')[1]} automations in this search space",
|
||||
)
|
||||
return automation
|
||||
|
||||
|
||||
def get_run_service(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
) -> RunService:
|
||||
return RunService(session=session, user=user)
|
||||
143
surfsense_backend/app/automations/services/trigger.py
Normal file
143
surfsense_backend/app/automations/services/trigger.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
"""``TriggerService`` — lifecycle of triggers attached to an automation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.automations.schemas.api import TriggerCreate, TriggerUpdate
|
||||
from app.automations.persistence.enums.trigger_type import TriggerType
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
from app.automations.triggers import get_trigger
|
||||
from app.automations.triggers.schedule import compute_next_fire_at
|
||||
from app.db import Permission, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
||||
class TriggerService:
|
||||
"""Lifecycle of the ``AutomationTrigger`` sub-resource."""
|
||||
|
||||
def __init__(self, *, session: AsyncSession, user: User) -> None:
|
||||
self.session = session
|
||||
self.user = user
|
||||
|
||||
async def add(
|
||||
self, *, automation_id: int, payload: TriggerCreate
|
||||
) -> AutomationTrigger:
|
||||
automation = await self._authorize_automation(
|
||||
automation_id, Permission.AUTOMATIONS_UPDATE.value
|
||||
)
|
||||
|
||||
validated_params = _validate_params(payload.type, payload.params)
|
||||
trigger = AutomationTrigger(
|
||||
automation_id=automation.id,
|
||||
type=payload.type,
|
||||
params=validated_params,
|
||||
static_inputs=payload.static_inputs,
|
||||
enabled=payload.enabled,
|
||||
next_fire_at=_initial_next_fire(payload.type, validated_params, payload.enabled),
|
||||
)
|
||||
self.session.add(trigger)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(trigger)
|
||||
return trigger
|
||||
|
||||
async def update(
|
||||
self,
|
||||
*,
|
||||
automation_id: int,
|
||||
trigger_id: int,
|
||||
patch: TriggerUpdate,
|
||||
) -> AutomationTrigger:
|
||||
await self._authorize_automation(automation_id, Permission.AUTOMATIONS_UPDATE.value)
|
||||
trigger = await self._get_trigger_or_raise(automation_id, trigger_id)
|
||||
|
||||
data = patch.model_dump(exclude_unset=True)
|
||||
|
||||
if "params" in data:
|
||||
trigger.params = _validate_params(trigger.type, data["params"])
|
||||
|
||||
if "static_inputs" in data:
|
||||
trigger.static_inputs = data["static_inputs"]
|
||||
|
||||
if "enabled" in data:
|
||||
trigger.enabled = data["enabled"]
|
||||
|
||||
# Recompute next_fire_at when schedule timing changed or the trigger was
|
||||
# toggled back on. Manual triggers always have NULL next_fire_at.
|
||||
if trigger.type == TriggerType.SCHEDULE:
|
||||
trigger.next_fire_at = _initial_next_fire(
|
||||
trigger.type, trigger.params, trigger.enabled
|
||||
)
|
||||
|
||||
await self.session.commit()
|
||||
await self.session.refresh(trigger)
|
||||
return trigger
|
||||
|
||||
async def remove(self, *, automation_id: int, trigger_id: int) -> None:
|
||||
await self._authorize_automation(automation_id, Permission.AUTOMATIONS_UPDATE.value)
|
||||
trigger = await self._get_trigger_or_raise(automation_id, trigger_id)
|
||||
await self.session.delete(trigger)
|
||||
await self.session.commit()
|
||||
|
||||
async def _authorize_automation(
|
||||
self, automation_id: int, permission: str
|
||||
) -> Automation:
|
||||
automation = await self.session.get(Automation, automation_id)
|
||||
if automation is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"automation {automation_id} not found"
|
||||
)
|
||||
await check_permission(
|
||||
self.session,
|
||||
self.user,
|
||||
automation.search_space_id,
|
||||
permission,
|
||||
f"You don't have permission to {permission.split(':')[1]} automations in this search space",
|
||||
)
|
||||
return automation
|
||||
|
||||
async def _get_trigger_or_raise(
|
||||
self, automation_id: int, trigger_id: int
|
||||
) -> AutomationTrigger:
|
||||
trigger = await self.session.get(AutomationTrigger, trigger_id)
|
||||
if trigger is None or trigger.automation_id != automation_id:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"trigger {trigger_id} not found"
|
||||
)
|
||||
return trigger
|
||||
|
||||
|
||||
def _validate_params(trigger_type: TriggerType, raw: dict) -> dict:
|
||||
definition = get_trigger(trigger_type.value)
|
||||
if definition is None:
|
||||
raise HTTPException(
|
||||
status_code=422, detail=f"unknown trigger type {trigger_type.value!r}"
|
||||
)
|
||||
try:
|
||||
validated = definition.params_model.model_validate(raw)
|
||||
except ValidationError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
return validated.model_dump(mode="json")
|
||||
|
||||
|
||||
def _initial_next_fire(
|
||||
trigger_type: TriggerType, params: dict, enabled: bool
|
||||
) -> datetime | None:
|
||||
if trigger_type != TriggerType.SCHEDULE or not enabled:
|
||||
return None
|
||||
return compute_next_fire_at(
|
||||
params["cron"], params["timezone"], after=datetime.now(UTC)
|
||||
)
|
||||
|
||||
|
||||
def get_trigger_service(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
) -> TriggerService:
|
||||
return TriggerService(session=session, user=user)
|
||||
|
|
@ -15,6 +15,7 @@ Runs every minute. Each tick performs two passes:
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
|
|
@ -39,6 +40,15 @@ TASK_NAME = "automation_schedule_tick"
|
|||
_TICK_BATCH = 200
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _Claim:
|
||||
"""Per-trigger fire context captured before row state is mutated."""
|
||||
|
||||
trigger_id: int
|
||||
scheduled_for: datetime
|
||||
previous_last_fired_at: datetime | None
|
||||
|
||||
|
||||
@celery_app.task(name=TASK_NAME)
|
||||
def automation_schedule_tick() -> None:
|
||||
"""Tick once: self-heal NULL next_fire_at, claim due rows, fire each."""
|
||||
|
|
@ -52,12 +62,12 @@ async def _tick() -> None:
|
|||
|
||||
await _self_heal_null_next_fire(session, now=now)
|
||||
|
||||
claimed_ids = await _claim_due_triggers(session, now=now)
|
||||
if not claimed_ids:
|
||||
claims = await _claim_due_triggers(session, now=now)
|
||||
if not claims:
|
||||
return
|
||||
|
||||
for trigger_id in claimed_ids:
|
||||
await _fire_one(session, trigger_id=trigger_id)
|
||||
for claim in claims:
|
||||
await _fire_one(session, claim=claim, fired_at=now)
|
||||
|
||||
|
||||
async def _self_heal_null_next_fire(session: AsyncSession, *, now: datetime) -> None:
|
||||
|
|
@ -95,8 +105,8 @@ async def _self_heal_null_next_fire(session: AsyncSession, *, now: datetime) ->
|
|||
|
||||
async def _claim_due_triggers(
|
||||
session: AsyncSession, *, now: datetime
|
||||
) -> list[int]:
|
||||
"""Lock and advance due rows; return claimed trigger ids."""
|
||||
) -> list[_Claim]:
|
||||
"""Lock and advance due rows; return per-trigger fire context."""
|
||||
stmt = (
|
||||
select(AutomationTrigger)
|
||||
.where(
|
||||
|
|
@ -113,8 +123,12 @@ async def _claim_due_triggers(
|
|||
if not triggers:
|
||||
return []
|
||||
|
||||
claimed: list[int] = []
|
||||
claims: list[_Claim] = []
|
||||
for trigger in triggers:
|
||||
# Snapshot fire-context BEFORE we advance the row.
|
||||
scheduled_for = trigger.next_fire_at
|
||||
previous_last_fired_at = trigger.last_fired_at
|
||||
|
||||
try:
|
||||
trigger.next_fire_at = compute_next_fire_at(
|
||||
trigger.params["cron"],
|
||||
|
|
@ -131,29 +145,43 @@ async def _claim_due_triggers(
|
|||
continue
|
||||
|
||||
trigger.last_fired_at = now
|
||||
claimed.append(trigger.id)
|
||||
claims.append(
|
||||
_Claim(
|
||||
trigger_id=trigger.id,
|
||||
scheduled_for=scheduled_for,
|
||||
previous_last_fired_at=previous_last_fired_at,
|
||||
)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
return claimed
|
||||
return claims
|
||||
|
||||
|
||||
async def _fire_one(session: AsyncSession, *, trigger_id: int) -> None:
|
||||
async def _fire_one(
|
||||
session: AsyncSession, *, claim: _Claim, fired_at: datetime
|
||||
) -> None:
|
||||
"""Reload the trigger post-commit and dispatch a run for it."""
|
||||
trigger = await session.get(AutomationTrigger, trigger_id)
|
||||
trigger = await session.get(AutomationTrigger, claim.trigger_id)
|
||||
if trigger is None:
|
||||
return
|
||||
|
||||
try:
|
||||
run = await dispatch_schedule_run(session=session, trigger=trigger)
|
||||
run = await dispatch_schedule_run(
|
||||
session=session,
|
||||
trigger=trigger,
|
||||
fired_at=fired_at,
|
||||
scheduled_for=claim.scheduled_for,
|
||||
previous_last_fired_at=claim.previous_last_fired_at,
|
||||
)
|
||||
logger.info(
|
||||
"scheduled fire: trigger=%d automation=%d run=%d",
|
||||
trigger_id,
|
||||
claim.trigger_id,
|
||||
trigger.automation_id,
|
||||
run.id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"scheduled fire failed for trigger %d (next attempt at next match)",
|
||||
trigger_id,
|
||||
claim.trigger_id,
|
||||
)
|
||||
await session.rollback()
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ def build_run_context(
|
|||
trigger_type: str | None,
|
||||
started_at: datetime | None,
|
||||
attempt: int,
|
||||
resolved_inputs: Mapping[str, Any],
|
||||
inputs: Mapping[str, Any],
|
||||
step_outputs: Mapping[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Build the ``{run, inputs, steps}`` namespace exposed to every template."""
|
||||
|
|
@ -36,6 +36,6 @@ def build_run_context(
|
|||
"started_at": started_at,
|
||||
"attempt": attempt,
|
||||
},
|
||||
"inputs": dict(resolved_inputs),
|
||||
"inputs": dict(inputs),
|
||||
"steps": dict(step_outputs),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,8 +9,7 @@ from .params import ManualTriggerParams
|
|||
MANUAL_TRIGGER = TriggerDefinition(
|
||||
type="manual",
|
||||
description="Fire on a user-initiated 'Run now' invocation.",
|
||||
params_schema=ManualTriggerParams.model_json_schema(),
|
||||
payload_schema={"type": "object"},
|
||||
params_model=ManualTriggerParams,
|
||||
)
|
||||
|
||||
register_trigger(MANUAL_TRIGGER)
|
||||
|
|
|
|||
|
|
@ -19,9 +19,14 @@ async def dispatch_manual_run(
|
|||
*,
|
||||
session: AsyncSession,
|
||||
automation_id: int,
|
||||
payload: dict[str, Any] | None,
|
||||
runtime_inputs: dict[str, Any] | None,
|
||||
) -> AutomationRun:
|
||||
"""Find the automation + its enabled manual trigger, then run the generic dispatch."""
|
||||
"""Find the automation + its enabled manual trigger, then run the generic dispatch.
|
||||
|
||||
``runtime_inputs`` is the caller-supplied payload (e.g. an HTTP body for a
|
||||
"Run now" API call); it is merged with the trigger's ``static_inputs`` by
|
||||
the generic dispatcher, with static winning on key collision.
|
||||
"""
|
||||
automation = await _load_automation(session, automation_id)
|
||||
if automation is None:
|
||||
raise DispatchError(f"automation {automation_id} not found")
|
||||
|
|
@ -41,7 +46,7 @@ async def dispatch_manual_run(
|
|||
session=session,
|
||||
automation=automation,
|
||||
trigger=trigger,
|
||||
payload=payload,
|
||||
runtime_inputs=runtime_inputs,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,12 +9,7 @@ from .params import ScheduleTriggerParams
|
|||
SCHEDULE_TRIGGER = TriggerDefinition(
|
||||
type="schedule",
|
||||
description="Fire on a cron schedule in a given timezone.",
|
||||
params_schema=ScheduleTriggerParams.model_json_schema(),
|
||||
payload_schema={
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
"properties": {},
|
||||
},
|
||||
params_model=ScheduleTriggerParams,
|
||||
)
|
||||
|
||||
register_trigger(SCHEDULE_TRIGGER)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
|
@ -16,9 +18,18 @@ async def dispatch_schedule_run(
|
|||
*,
|
||||
session: AsyncSession,
|
||||
trigger: AutomationTrigger,
|
||||
fired_at: datetime,
|
||||
scheduled_for: datetime,
|
||||
previous_last_fired_at: datetime | None,
|
||||
) -> AutomationRun:
|
||||
"""Fire one scheduled run for ``trigger``.
|
||||
|
||||
Emits calendar context as runtime inputs:
|
||||
|
||||
- ``fired_at`` — actual fire time
|
||||
- ``scheduled_for`` — cron-derived target time for this fire
|
||||
- ``last_fired_at`` — fire time of the previous run, or null on first fire
|
||||
|
||||
The caller (the schedule tick) is responsible for selecting due triggers
|
||||
and advancing ``next_fire_at`` / ``last_fired_at`` before invoking this.
|
||||
"""
|
||||
|
|
@ -33,11 +44,19 @@ async def dispatch_schedule_run(
|
|||
f"automation {trigger.automation_id} is {automation.status.value}, not active"
|
||||
)
|
||||
|
||||
runtime_inputs = {
|
||||
"fired_at": fired_at.isoformat(),
|
||||
"scheduled_for": scheduled_for.isoformat(),
|
||||
"last_fired_at": (
|
||||
previous_last_fired_at.isoformat() if previous_last_fired_at else None
|
||||
),
|
||||
}
|
||||
|
||||
return await dispatch_run(
|
||||
session=session,
|
||||
automation=automation,
|
||||
trigger=trigger,
|
||||
payload=None,
|
||||
runtime_inputs=runtime_inputs,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,10 +5,16 @@ from __future__ import annotations
|
|||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class TriggerDefinition:
|
||||
type: str
|
||||
description: str
|
||||
params_schema: dict[str, Any]
|
||||
payload_schema: dict[str, Any]
|
||||
params_model: type[BaseModel]
|
||||
|
||||
@property
|
||||
def params_schema(self) -> dict[str, Any]:
|
||||
"""JSON Schema (draft 2020-12) derived from ``params_model``."""
|
||||
return self.params_model.model_json_schema()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue