From 7ec3468113fe28ffb7c634aee790bfd91c625766 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 27 May 2026 16:29:32 +0200 Subject: [PATCH] refactor(automations): bind action handlers via ActionContext factory --- .../app/automations/registries/__init__.py | 4 +++ .../registries/actions/__init__.py | 4 ++- .../registries/actions/agent_task.py | 17 +++++++---- .../automations/registries/actions/types.py | 20 +++++++++++-- .../app/automations/runtime/executor.py | 29 +++++++++++++++---- .../app/automations/runtime/step.py | 6 +++- 6 files changed, 65 insertions(+), 15 deletions(-) diff --git a/surfsense_backend/app/automations/registries/__init__.py b/surfsense_backend/app/automations/registries/__init__.py index f497caf59..f6af3817b 100644 --- a/surfsense_backend/app/automations/registries/__init__.py +++ b/surfsense_backend/app/automations/registries/__init__.py @@ -3,8 +3,10 @@ from __future__ import annotations from .actions import ( + ActionContext, ActionDefinition, ActionHandler, + ActionHandlerFactory, all_actions, get_action, register_action, @@ -17,8 +19,10 @@ from .triggers import ( ) __all__ = [ + "ActionContext", "ActionDefinition", "ActionHandler", + "ActionHandlerFactory", "TriggerDefinition", "all_actions", "all_triggers", diff --git a/surfsense_backend/app/automations/registries/actions/__init__.py b/surfsense_backend/app/automations/registries/actions/__init__.py index 68e507133..b95c634f2 100644 --- a/surfsense_backend/app/automations/registries/actions/__init__.py +++ b/surfsense_backend/app/automations/registries/actions/__init__.py @@ -3,11 +3,13 @@ from __future__ import annotations from .store import all_actions, get_action, register_action -from .types import ActionDefinition, ActionHandler +from .types import ActionContext, ActionDefinition, ActionHandler, ActionHandlerFactory __all__ = [ + "ActionContext", "ActionDefinition", "ActionHandler", + "ActionHandlerFactory", "all_actions", "get_action", "register_action", diff --git a/surfsense_backend/app/automations/registries/actions/agent_task.py b/surfsense_backend/app/automations/registries/actions/agent_task.py index 9acc11c2c..beba455cc 100644 --- a/surfsense_backend/app/automations/registries/actions/agent_task.py +++ b/surfsense_backend/app/automations/registries/actions/agent_task.py @@ -7,13 +7,18 @@ from typing import Any from app.automations.schemas.actions import AgentTaskActionParams from .store import register_action -from .types import ActionDefinition +from .types import ActionContext, ActionDefinition, ActionHandler -async def _handle_agent_task(args: dict[str, Any]) -> dict[str, Any]: - """Stub. Validates params; real wiring lands with the executor.""" - AgentTaskActionParams.model_validate(args) - return {"status": "stubbed"} +def _build_handler(ctx: ActionContext) -> ActionHandler: + """Bind run/session context to the agent_task handler. Real wiring lands in Phase 4b.""" + del ctx # ignored by the stub; real handler will consume it + + async def handle(params: dict[str, Any]) -> dict[str, Any]: + AgentTaskActionParams.model_validate(params) + return {"status": "stubbed"} + + return handle AGENT_TASK_ACTION = ActionDefinition( @@ -21,7 +26,7 @@ AGENT_TASK_ACTION = ActionDefinition( name="Agent task", description="Run an agent task with a scoped tool allowlist.", params_schema=AgentTaskActionParams.model_json_schema(), - handler=_handle_agent_task, + build_handler=_build_handler, ) register_action(AGENT_TASK_ACTION) diff --git a/surfsense_backend/app/automations/registries/actions/types.py b/surfsense_backend/app/automations/registries/actions/types.py index 99f94ae7c..433c60841 100644 --- a/surfsense_backend/app/automations/registries/actions/types.py +++ b/surfsense_backend/app/automations/registries/actions/types.py @@ -1,12 +1,28 @@ -"""``ActionDefinition`` dataclass and handler signature.""" +"""``ActionDefinition``, ``ActionContext``, and handler/factory signatures.""" from __future__ import annotations from collections.abc import Awaitable, Callable from dataclasses import dataclass from typing import Any +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + + +@dataclass(frozen=True, slots=True) +class ActionContext: + """Per-invocation dependencies bound to an action handler at execute time.""" + + session: AsyncSession + run_id: int + step_id: str + search_space_id: int + creator_user_id: UUID | None + ActionHandler = Callable[[dict[str, Any]], Awaitable[Any]] +ActionHandlerFactory = Callable[[ActionContext], ActionHandler] @dataclass(frozen=True, slots=True) @@ -15,4 +31,4 @@ class ActionDefinition: name: str description: str params_schema: dict[str, Any] - handler: ActionHandler + build_handler: ActionHandlerFactory diff --git a/surfsense_backend/app/automations/runtime/executor.py b/surfsense_backend/app/automations/runtime/executor.py index 51c4417e3..e9e55b02d 100644 --- a/surfsense_backend/app/automations/runtime/executor.py +++ b/surfsense_backend/app/automations/runtime/executor.py @@ -8,7 +8,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.automations.persistence.enums.run_status import RunStatus from app.automations.persistence.models.run import AutomationRun +from app.automations.registries.actions.types import ActionContext from app.automations.schemas.definition.envelope import AutomationDefinition +from app.automations.schemas.definition.plan_step import PlanStep from app.automations.templating import build_run_context from . import repository @@ -41,10 +43,12 @@ async def execute_run(session: AsyncSession, run_id: int) -> None: step_outputs: dict[str, Any] = {} for step in definition.plan: - ctx = _build_ctx(run, step_outputs) + template_ctx = _build_template_ctx(run, step_outputs) + action_ctx = _build_action_ctx(session, run, step) result = await execute_step( step=step, - template_context=ctx, + template_context=template_ctx, + action_context=action_ctx, default_max_retries=definition.execution.max_retries, default_retry_backoff=definition.execution.retry_backoff, default_timeout_seconds=definition.execution.timeout_seconds, @@ -73,11 +77,13 @@ async def _run_on_failure( """Run the on_failure steps. Their failures don't recurse into more on_failure.""" if not definition.execution.on_failure: return - ctx = _build_ctx(run, step_outputs={}) + template_ctx = _build_template_ctx(run, step_outputs={}) for step in definition.execution.on_failure: + action_ctx = _build_action_ctx(session, run, step) result = await execute_step( step=step, - template_context=ctx, + template_context=template_ctx, + action_context=action_ctx, default_max_retries=definition.execution.max_retries, default_retry_backoff=definition.execution.retry_backoff, default_timeout_seconds=definition.execution.timeout_seconds, @@ -86,7 +92,7 @@ async def _run_on_failure( await session.commit() -def _build_ctx(run: AutomationRun, step_outputs: dict[str, Any]) -> dict[str, Any]: +def _build_template_ctx(run: AutomationRun, step_outputs: dict[str, Any]) -> dict[str, Any]: automation = run.automation trigger = run.trigger return build_run_context( @@ -103,3 +109,16 @@ def _build_ctx(run: AutomationRun, step_outputs: dict[str, Any]) -> dict[str, An resolved_inputs=run.resolved_inputs or {}, step_outputs=step_outputs, ) + + +def _build_action_ctx( + session: AsyncSession, run: AutomationRun, step: PlanStep +) -> ActionContext: + automation = run.automation + return ActionContext( + session=session, + run_id=run.id, + step_id=step.step_id, + search_space_id=automation.search_space_id, + creator_user_id=automation.created_by_user_id, + ) diff --git a/surfsense_backend/app/automations/runtime/step.py b/surfsense_backend/app/automations/runtime/step.py index 07b894a91..76e3ba171 100644 --- a/surfsense_backend/app/automations/runtime/step.py +++ b/surfsense_backend/app/automations/runtime/step.py @@ -7,6 +7,7 @@ from datetime import UTC, datetime from typing import Any from app.automations.registries import get_action +from app.automations.registries.actions.types import ActionContext from app.automations.schemas.definition.plan_step import PlanStep from app.automations.templating import evaluate_predicate, render_value @@ -17,6 +18,7 @@ async def execute_step( *, step: PlanStep, template_context: Mapping[str, Any], + action_context: ActionContext, default_max_retries: int, default_retry_backoff: str, default_timeout_seconds: int, @@ -47,12 +49,14 @@ async def execute_step( error={"message": f"action not registered: {step.action}", "type": "ActionNotFound"}, ) + handler = action.build_handler(action_context) + max_retries = step.max_retries if step.max_retries is not None else default_max_retries timeout = step.timeout_seconds or default_timeout_seconds try: result, attempts = await with_retries( - lambda: action.handler(resolved_params), + lambda: handler(resolved_params), max_retries=max_retries, backoff=default_retry_backoff, timeout=timeout,