refactor(automations): bind action handlers via ActionContext factory

This commit is contained in:
CREDO23 2026-05-27 16:29:32 +02:00
parent f646b5cbab
commit 7ec3468113
6 changed files with 65 additions and 15 deletions

View file

@ -3,8 +3,10 @@
from __future__ import annotations from __future__ import annotations
from .actions import ( from .actions import (
ActionContext,
ActionDefinition, ActionDefinition,
ActionHandler, ActionHandler,
ActionHandlerFactory,
all_actions, all_actions,
get_action, get_action,
register_action, register_action,
@ -17,8 +19,10 @@ from .triggers import (
) )
__all__ = [ __all__ = [
"ActionContext",
"ActionDefinition", "ActionDefinition",
"ActionHandler", "ActionHandler",
"ActionHandlerFactory",
"TriggerDefinition", "TriggerDefinition",
"all_actions", "all_actions",
"all_triggers", "all_triggers",

View file

@ -3,11 +3,13 @@
from __future__ import annotations from __future__ import annotations
from .store import all_actions, get_action, register_action from .store import all_actions, get_action, register_action
from .types import ActionDefinition, ActionHandler from .types import ActionContext, ActionDefinition, ActionHandler, ActionHandlerFactory
__all__ = [ __all__ = [
"ActionContext",
"ActionDefinition", "ActionDefinition",
"ActionHandler", "ActionHandler",
"ActionHandlerFactory",
"all_actions", "all_actions",
"get_action", "get_action",
"register_action", "register_action",

View file

@ -7,13 +7,18 @@ from typing import Any
from app.automations.schemas.actions import AgentTaskActionParams from app.automations.schemas.actions import AgentTaskActionParams
from .store import register_action from .store import register_action
from .types import ActionDefinition from .types import ActionContext, ActionDefinition, ActionHandler
async def _handle_agent_task(args: dict[str, Any]) -> dict[str, Any]: def _build_handler(ctx: ActionContext) -> ActionHandler:
"""Stub. Validates params; real wiring lands with the executor.""" """Bind run/session context to the agent_task handler. Real wiring lands in Phase 4b."""
AgentTaskActionParams.model_validate(args) del ctx # ignored by the stub; real handler will consume it
return {"status": "stubbed"}
async def handle(params: dict[str, Any]) -> dict[str, Any]:
AgentTaskActionParams.model_validate(params)
return {"status": "stubbed"}
return handle
AGENT_TASK_ACTION = ActionDefinition( AGENT_TASK_ACTION = ActionDefinition(
@ -21,7 +26,7 @@ AGENT_TASK_ACTION = ActionDefinition(
name="Agent task", name="Agent task",
description="Run an agent task with a scoped tool allowlist.", description="Run an agent task with a scoped tool allowlist.",
params_schema=AgentTaskActionParams.model_json_schema(), params_schema=AgentTaskActionParams.model_json_schema(),
handler=_handle_agent_task, build_handler=_build_handler,
) )
register_action(AGENT_TASK_ACTION) register_action(AGENT_TASK_ACTION)

View file

@ -1,12 +1,28 @@
"""``ActionDefinition`` dataclass and handler signature.""" """``ActionDefinition``, ``ActionContext``, and handler/factory signatures."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
@dataclass(frozen=True, slots=True)
class ActionContext:
"""Per-invocation dependencies bound to an action handler at execute time."""
session: AsyncSession
run_id: int
step_id: str
search_space_id: int
creator_user_id: UUID | None
ActionHandler = Callable[[dict[str, Any]], Awaitable[Any]] ActionHandler = Callable[[dict[str, Any]], Awaitable[Any]]
ActionHandlerFactory = Callable[[ActionContext], ActionHandler]
@dataclass(frozen=True, slots=True) @dataclass(frozen=True, slots=True)
@ -15,4 +31,4 @@ class ActionDefinition:
name: str name: str
description: str description: str
params_schema: dict[str, Any] params_schema: dict[str, Any]
handler: ActionHandler build_handler: ActionHandlerFactory

View file

@ -8,7 +8,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.persistence.enums.run_status import RunStatus from app.automations.persistence.enums.run_status import RunStatus
from app.automations.persistence.models.run import AutomationRun from app.automations.persistence.models.run import AutomationRun
from app.automations.registries.actions.types import ActionContext
from app.automations.schemas.definition.envelope import AutomationDefinition from app.automations.schemas.definition.envelope import AutomationDefinition
from app.automations.schemas.definition.plan_step import PlanStep
from app.automations.templating import build_run_context from app.automations.templating import build_run_context
from . import repository from . import repository
@ -41,10 +43,12 @@ async def execute_run(session: AsyncSession, run_id: int) -> None:
step_outputs: dict[str, Any] = {} step_outputs: dict[str, Any] = {}
for step in definition.plan: for step in definition.plan:
ctx = _build_ctx(run, step_outputs) template_ctx = _build_template_ctx(run, step_outputs)
action_ctx = _build_action_ctx(session, run, step)
result = await execute_step( result = await execute_step(
step=step, step=step,
template_context=ctx, template_context=template_ctx,
action_context=action_ctx,
default_max_retries=definition.execution.max_retries, default_max_retries=definition.execution.max_retries,
default_retry_backoff=definition.execution.retry_backoff, default_retry_backoff=definition.execution.retry_backoff,
default_timeout_seconds=definition.execution.timeout_seconds, default_timeout_seconds=definition.execution.timeout_seconds,
@ -73,11 +77,13 @@ async def _run_on_failure(
"""Run the on_failure steps. Their failures don't recurse into more on_failure.""" """Run the on_failure steps. Their failures don't recurse into more on_failure."""
if not definition.execution.on_failure: if not definition.execution.on_failure:
return return
ctx = _build_ctx(run, step_outputs={}) template_ctx = _build_template_ctx(run, step_outputs={})
for step in definition.execution.on_failure: for step in definition.execution.on_failure:
action_ctx = _build_action_ctx(session, run, step)
result = await execute_step( result = await execute_step(
step=step, step=step,
template_context=ctx, template_context=template_ctx,
action_context=action_ctx,
default_max_retries=definition.execution.max_retries, default_max_retries=definition.execution.max_retries,
default_retry_backoff=definition.execution.retry_backoff, default_retry_backoff=definition.execution.retry_backoff,
default_timeout_seconds=definition.execution.timeout_seconds, default_timeout_seconds=definition.execution.timeout_seconds,
@ -86,7 +92,7 @@ async def _run_on_failure(
await session.commit() await session.commit()
def _build_ctx(run: AutomationRun, step_outputs: dict[str, Any]) -> dict[str, Any]: def _build_template_ctx(run: AutomationRun, step_outputs: dict[str, Any]) -> dict[str, Any]:
automation = run.automation automation = run.automation
trigger = run.trigger trigger = run.trigger
return build_run_context( return build_run_context(
@ -103,3 +109,16 @@ def _build_ctx(run: AutomationRun, step_outputs: dict[str, Any]) -> dict[str, An
resolved_inputs=run.resolved_inputs or {}, resolved_inputs=run.resolved_inputs or {},
step_outputs=step_outputs, step_outputs=step_outputs,
) )
def _build_action_ctx(
session: AsyncSession, run: AutomationRun, step: PlanStep
) -> ActionContext:
automation = run.automation
return ActionContext(
session=session,
run_id=run.id,
step_id=step.step_id,
search_space_id=automation.search_space_id,
creator_user_id=automation.created_by_user_id,
)

View file

@ -7,6 +7,7 @@ from datetime import UTC, datetime
from typing import Any from typing import Any
from app.automations.registries import get_action from app.automations.registries import get_action
from app.automations.registries.actions.types import ActionContext
from app.automations.schemas.definition.plan_step import PlanStep from app.automations.schemas.definition.plan_step import PlanStep
from app.automations.templating import evaluate_predicate, render_value from app.automations.templating import evaluate_predicate, render_value
@ -17,6 +18,7 @@ async def execute_step(
*, *,
step: PlanStep, step: PlanStep,
template_context: Mapping[str, Any], template_context: Mapping[str, Any],
action_context: ActionContext,
default_max_retries: int, default_max_retries: int,
default_retry_backoff: str, default_retry_backoff: str,
default_timeout_seconds: int, default_timeout_seconds: int,
@ -47,12 +49,14 @@ async def execute_step(
error={"message": f"action not registered: {step.action}", "type": "ActionNotFound"}, error={"message": f"action not registered: {step.action}", "type": "ActionNotFound"},
) )
handler = action.build_handler(action_context)
max_retries = step.max_retries if step.max_retries is not None else default_max_retries max_retries = step.max_retries if step.max_retries is not None else default_max_retries
timeout = step.timeout_seconds or default_timeout_seconds timeout = step.timeout_seconds or default_timeout_seconds
try: try:
result, attempts = await with_retries( result, attempts = await with_retries(
lambda: action.handler(resolved_params), lambda: handler(resolved_params),
max_retries=max_retries, max_retries=max_retries,
backoff=default_retry_backoff, backoff=default_retry_backoff,
timeout=timeout, timeout=timeout,