mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-29 19:35:20 +02:00
refactor(automations): bind action handlers via ActionContext factory
This commit is contained in:
parent
f646b5cbab
commit
7ec3468113
6 changed files with 65 additions and 15 deletions
|
|
@ -3,8 +3,10 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .actions import (
|
from .actions import (
|
||||||
|
ActionContext,
|
||||||
ActionDefinition,
|
ActionDefinition,
|
||||||
ActionHandler,
|
ActionHandler,
|
||||||
|
ActionHandlerFactory,
|
||||||
all_actions,
|
all_actions,
|
||||||
get_action,
|
get_action,
|
||||||
register_action,
|
register_action,
|
||||||
|
|
@ -17,8 +19,10 @@ from .triggers import (
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"ActionContext",
|
||||||
"ActionDefinition",
|
"ActionDefinition",
|
||||||
"ActionHandler",
|
"ActionHandler",
|
||||||
|
"ActionHandlerFactory",
|
||||||
"TriggerDefinition",
|
"TriggerDefinition",
|
||||||
"all_actions",
|
"all_actions",
|
||||||
"all_triggers",
|
"all_triggers",
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,13 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .store import all_actions, get_action, register_action
|
from .store import all_actions, get_action, register_action
|
||||||
from .types import ActionDefinition, ActionHandler
|
from .types import ActionContext, ActionDefinition, ActionHandler, ActionHandlerFactory
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"ActionContext",
|
||||||
"ActionDefinition",
|
"ActionDefinition",
|
||||||
"ActionHandler",
|
"ActionHandler",
|
||||||
|
"ActionHandlerFactory",
|
||||||
"all_actions",
|
"all_actions",
|
||||||
"get_action",
|
"get_action",
|
||||||
"register_action",
|
"register_action",
|
||||||
|
|
|
||||||
|
|
@ -7,13 +7,18 @@ from typing import Any
|
||||||
from app.automations.schemas.actions import AgentTaskActionParams
|
from app.automations.schemas.actions import AgentTaskActionParams
|
||||||
|
|
||||||
from .store import register_action
|
from .store import register_action
|
||||||
from .types import ActionDefinition
|
from .types import ActionContext, ActionDefinition, ActionHandler
|
||||||
|
|
||||||
|
|
||||||
async def _handle_agent_task(args: dict[str, Any]) -> dict[str, Any]:
|
def _build_handler(ctx: ActionContext) -> ActionHandler:
|
||||||
"""Stub. Validates params; real wiring lands with the executor."""
|
"""Bind run/session context to the agent_task handler. Real wiring lands in Phase 4b."""
|
||||||
AgentTaskActionParams.model_validate(args)
|
del ctx # ignored by the stub; real handler will consume it
|
||||||
return {"status": "stubbed"}
|
|
||||||
|
async def handle(params: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
AgentTaskActionParams.model_validate(params)
|
||||||
|
return {"status": "stubbed"}
|
||||||
|
|
||||||
|
return handle
|
||||||
|
|
||||||
|
|
||||||
AGENT_TASK_ACTION = ActionDefinition(
|
AGENT_TASK_ACTION = ActionDefinition(
|
||||||
|
|
@ -21,7 +26,7 @@ AGENT_TASK_ACTION = ActionDefinition(
|
||||||
name="Agent task",
|
name="Agent task",
|
||||||
description="Run an agent task with a scoped tool allowlist.",
|
description="Run an agent task with a scoped tool allowlist.",
|
||||||
params_schema=AgentTaskActionParams.model_json_schema(),
|
params_schema=AgentTaskActionParams.model_json_schema(),
|
||||||
handler=_handle_agent_task,
|
build_handler=_build_handler,
|
||||||
)
|
)
|
||||||
|
|
||||||
register_action(AGENT_TASK_ACTION)
|
register_action(AGENT_TASK_ACTION)
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,28 @@
|
||||||
"""``ActionDefinition`` dataclass and handler signature."""
|
"""``ActionDefinition``, ``ActionContext``, and handler/factory signatures."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class ActionContext:
|
||||||
|
"""Per-invocation dependencies bound to an action handler at execute time."""
|
||||||
|
|
||||||
|
session: AsyncSession
|
||||||
|
run_id: int
|
||||||
|
step_id: str
|
||||||
|
search_space_id: int
|
||||||
|
creator_user_id: UUID | None
|
||||||
|
|
||||||
|
|
||||||
ActionHandler = Callable[[dict[str, Any]], Awaitable[Any]]
|
ActionHandler = Callable[[dict[str, Any]], Awaitable[Any]]
|
||||||
|
ActionHandlerFactory = Callable[[ActionContext], ActionHandler]
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
|
|
@ -15,4 +31,4 @@ class ActionDefinition:
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
params_schema: dict[str, Any]
|
params_schema: dict[str, Any]
|
||||||
handler: ActionHandler
|
build_handler: ActionHandlerFactory
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.automations.persistence.enums.run_status import RunStatus
|
from app.automations.persistence.enums.run_status import RunStatus
|
||||||
from app.automations.persistence.models.run import AutomationRun
|
from app.automations.persistence.models.run import AutomationRun
|
||||||
|
from app.automations.registries.actions.types import ActionContext
|
||||||
from app.automations.schemas.definition.envelope import AutomationDefinition
|
from app.automations.schemas.definition.envelope import AutomationDefinition
|
||||||
|
from app.automations.schemas.definition.plan_step import PlanStep
|
||||||
from app.automations.templating import build_run_context
|
from app.automations.templating import build_run_context
|
||||||
|
|
||||||
from . import repository
|
from . import repository
|
||||||
|
|
@ -41,10 +43,12 @@ async def execute_run(session: AsyncSession, run_id: int) -> None:
|
||||||
step_outputs: dict[str, Any] = {}
|
step_outputs: dict[str, Any] = {}
|
||||||
|
|
||||||
for step in definition.plan:
|
for step in definition.plan:
|
||||||
ctx = _build_ctx(run, step_outputs)
|
template_ctx = _build_template_ctx(run, step_outputs)
|
||||||
|
action_ctx = _build_action_ctx(session, run, step)
|
||||||
result = await execute_step(
|
result = await execute_step(
|
||||||
step=step,
|
step=step,
|
||||||
template_context=ctx,
|
template_context=template_ctx,
|
||||||
|
action_context=action_ctx,
|
||||||
default_max_retries=definition.execution.max_retries,
|
default_max_retries=definition.execution.max_retries,
|
||||||
default_retry_backoff=definition.execution.retry_backoff,
|
default_retry_backoff=definition.execution.retry_backoff,
|
||||||
default_timeout_seconds=definition.execution.timeout_seconds,
|
default_timeout_seconds=definition.execution.timeout_seconds,
|
||||||
|
|
@ -73,11 +77,13 @@ async def _run_on_failure(
|
||||||
"""Run the on_failure steps. Their failures don't recurse into more on_failure."""
|
"""Run the on_failure steps. Their failures don't recurse into more on_failure."""
|
||||||
if not definition.execution.on_failure:
|
if not definition.execution.on_failure:
|
||||||
return
|
return
|
||||||
ctx = _build_ctx(run, step_outputs={})
|
template_ctx = _build_template_ctx(run, step_outputs={})
|
||||||
for step in definition.execution.on_failure:
|
for step in definition.execution.on_failure:
|
||||||
|
action_ctx = _build_action_ctx(session, run, step)
|
||||||
result = await execute_step(
|
result = await execute_step(
|
||||||
step=step,
|
step=step,
|
||||||
template_context=ctx,
|
template_context=template_ctx,
|
||||||
|
action_context=action_ctx,
|
||||||
default_max_retries=definition.execution.max_retries,
|
default_max_retries=definition.execution.max_retries,
|
||||||
default_retry_backoff=definition.execution.retry_backoff,
|
default_retry_backoff=definition.execution.retry_backoff,
|
||||||
default_timeout_seconds=definition.execution.timeout_seconds,
|
default_timeout_seconds=definition.execution.timeout_seconds,
|
||||||
|
|
@ -86,7 +92,7 @@ async def _run_on_failure(
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
def _build_ctx(run: AutomationRun, step_outputs: dict[str, Any]) -> dict[str, Any]:
|
def _build_template_ctx(run: AutomationRun, step_outputs: dict[str, Any]) -> dict[str, Any]:
|
||||||
automation = run.automation
|
automation = run.automation
|
||||||
trigger = run.trigger
|
trigger = run.trigger
|
||||||
return build_run_context(
|
return build_run_context(
|
||||||
|
|
@ -103,3 +109,16 @@ def _build_ctx(run: AutomationRun, step_outputs: dict[str, Any]) -> dict[str, An
|
||||||
resolved_inputs=run.resolved_inputs or {},
|
resolved_inputs=run.resolved_inputs or {},
|
||||||
step_outputs=step_outputs,
|
step_outputs=step_outputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_action_ctx(
|
||||||
|
session: AsyncSession, run: AutomationRun, step: PlanStep
|
||||||
|
) -> ActionContext:
|
||||||
|
automation = run.automation
|
||||||
|
return ActionContext(
|
||||||
|
session=session,
|
||||||
|
run_id=run.id,
|
||||||
|
step_id=step.step_id,
|
||||||
|
search_space_id=automation.search_space_id,
|
||||||
|
creator_user_id=automation.created_by_user_id,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.automations.registries import get_action
|
from app.automations.registries import get_action
|
||||||
|
from app.automations.registries.actions.types import ActionContext
|
||||||
from app.automations.schemas.definition.plan_step import PlanStep
|
from app.automations.schemas.definition.plan_step import PlanStep
|
||||||
from app.automations.templating import evaluate_predicate, render_value
|
from app.automations.templating import evaluate_predicate, render_value
|
||||||
|
|
||||||
|
|
@ -17,6 +18,7 @@ async def execute_step(
|
||||||
*,
|
*,
|
||||||
step: PlanStep,
|
step: PlanStep,
|
||||||
template_context: Mapping[str, Any],
|
template_context: Mapping[str, Any],
|
||||||
|
action_context: ActionContext,
|
||||||
default_max_retries: int,
|
default_max_retries: int,
|
||||||
default_retry_backoff: str,
|
default_retry_backoff: str,
|
||||||
default_timeout_seconds: int,
|
default_timeout_seconds: int,
|
||||||
|
|
@ -47,12 +49,14 @@ async def execute_step(
|
||||||
error={"message": f"action not registered: {step.action}", "type": "ActionNotFound"},
|
error={"message": f"action not registered: {step.action}", "type": "ActionNotFound"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
handler = action.build_handler(action_context)
|
||||||
|
|
||||||
max_retries = step.max_retries if step.max_retries is not None else default_max_retries
|
max_retries = step.max_retries if step.max_retries is not None else default_max_retries
|
||||||
timeout = step.timeout_seconds or default_timeout_seconds
|
timeout = step.timeout_seconds or default_timeout_seconds
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result, attempts = await with_retries(
|
result, attempts = await with_retries(
|
||||||
lambda: action.handler(resolved_params),
|
lambda: handler(resolved_params),
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
backoff=default_retry_backoff,
|
backoff=default_retry_backoff,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue