refactor(automations): move agent_task to builtin and restructure dispatch

This commit is contained in:
CREDO23 2026-05-29 18:13:09 +02:00
parent f356e304e8
commit 30fff9e52f
22 changed files with 142 additions and 133 deletions

View file

@ -21,4 +21,4 @@ __all__ = [
] ]
# Built-in actions self-register at import time. # Built-in actions self-register at import time.
from . import agent_task # noqa: F401 from . import builtin # noqa: F401

View file

@ -0,0 +1,5 @@
"""Built-in action types — each in its own subpackage, self-registering at import."""
from __future__ import annotations
from . import agent_task # noqa: F401

View file

@ -2,8 +2,8 @@
from __future__ import annotations from __future__ import annotations
from ..store import register_action from ...store import register_action
from ..types import ActionDefinition from ...types import ActionDefinition
from .factory import build_handler from .factory import build_handler
from .params import AgentTaskActionParams from .params import AgentTaskActionParams

View file

@ -4,7 +4,7 @@ from __future__ import annotations
from typing import Any from typing import Any
from ..types import ActionContext, ActionHandler from ...types import ActionContext, ActionHandler
from .invoke import run_agent_task from .invoke import run_agent_task
from .params import AgentTaskActionParams from .params import AgentTaskActionParams

View file

@ -16,7 +16,7 @@ from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in
from app.db import ChatVisibility, async_session_maker from app.db import ChatVisibility, async_session_maker
from app.schemas.new_chat import MentionedDocumentInfo from app.schemas.new_chat import MentionedDocumentInfo
from ..types import ActionContext from ...types import ActionContext
from .auto_decide import build_auto_decisions from .auto_decide import build_auto_decisions
from .dependencies import build_dependencies from .dependencies import build_dependencies
from .finalize import extract_final_assistant_message from .finalize import extract_final_assistant_message

View file

@ -3,7 +3,6 @@
from __future__ import annotations from __future__ import annotations
from .errors import DispatchError from .errors import DispatchError
from .run import dispatch_run from .launch import launch_run
from .start import start_run
__all__ = ["DispatchError", "dispatch_run", "start_run"] __all__ = ["DispatchError", "launch_run"]

View file

@ -0,0 +1,43 @@
"""Merge and validate the inputs a run starts with."""
from __future__ import annotations
from typing import Any
import jsonschema
from app.automations.persistence.models.trigger import AutomationTrigger
from app.automations.schemas.definition.envelope import AutomationDefinition
from .errors import DispatchError
def prepare_inputs(
definition: AutomationDefinition,
trigger: AutomationTrigger,
runtime_inputs: dict[str, Any] | None,
) -> dict[str, Any]:
"""Merge ``trigger.static_inputs`` over ``runtime_inputs``, then validate.
Static inputs win on key collision.
"""
merged = {**(runtime_inputs or {}), **(trigger.static_inputs or {})}
return validate_inputs(definition, merged)
def validate_inputs(
definition: AutomationDefinition, inputs: dict[str, Any]
) -> dict[str, Any]:
"""Validate ``inputs`` against the definition's optional declared schema.
No declared schema pass through unchanged so runtime keys (``fired_at``,
``last_fired_at``, ...) still reach the template context. A declared schema
that the inputs violate is surfaced as ``DispatchError``.
"""
if definition.inputs is None or not definition.inputs.schema_:
return inputs
try:
jsonschema.validate(instance=inputs, schema=definition.inputs.schema_)
except jsonschema.ValidationError as exc:
raise DispatchError(f"inputs: {exc.message}") from exc
return inputs

View file

@ -0,0 +1,60 @@
"""Launch a run for a trigger that fired: resolve, validate, persist, enqueue.
The trigger-facing entry every selector calls. A selector builds the runtime
inputs and hands one trigger row here; this resolves and guards its automation,
snapshots the definition onto a PENDING run, and enqueues execution. The
snapshot makes the run immune to later edits of the automation.
"""
from __future__ import annotations
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.persistence.enums.run_status import RunStatus
from app.automations.persistence.models.run import AutomationRun
from app.automations.persistence.models.trigger import AutomationTrigger
from app.automations.schemas.definition.envelope import AutomationDefinition
from app.automations.tasks.execute_run import automation_run_execute
from .errors import DispatchError
from .inputs import prepare_inputs
from .resolve import resolve_active_automation
async def launch_run(
*,
session: AsyncSession,
trigger: AutomationTrigger,
runtime_inputs: dict[str, Any] | None = None,
) -> AutomationRun:
"""Resolve ``trigger``'s active automation and enqueue a PENDING run for it."""
automation = await resolve_active_automation(session, trigger)
try:
definition = AutomationDefinition.model_validate(automation.definition)
except Exception as exc:
raise DispatchError(f"invalid automation definition: {exc}") from exc
inputs = prepare_inputs(definition, trigger, runtime_inputs)
snapshot = definition.model_dump(mode="json", by_alias=True)
run = AutomationRun(
automation_id=automation.id,
trigger_id=trigger.id,
status=RunStatus.PENDING,
definition_snapshot=snapshot,
inputs=inputs,
step_results=[],
artifacts=[],
)
session.add(run)
await session.commit()
await session.refresh(run)
automation_run_execute.apply_async(
args=[run.id],
time_limit=definition.execution.timeout_seconds,
)
return run

View file

@ -1,33 +1,24 @@
"""Start one run for a trigger: resolve its automation, guard ``ACTIVE``, dispatch. """Resolve the automation behind a trigger and guard that it may run."""
Shared by every trigger type. A type's selector builds the runtime inputs and
hands one trigger row here; this resolves and guards the automation, then calls
the generic ``dispatch_run``.
"""
from __future__ import annotations from __future__ import annotations
from typing import Any
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.persistence.enums.automation_status import AutomationStatus from app.automations.persistence.enums.automation_status import AutomationStatus
from app.automations.persistence.models.automation import Automation from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.run import AutomationRun
from app.automations.persistence.models.trigger import AutomationTrigger from app.automations.persistence.models.trigger import AutomationTrigger
from .errors import DispatchError from .errors import DispatchError
from .run import dispatch_run
async def start_run( async def resolve_active_automation(
*, session: AsyncSession, trigger: AutomationTrigger
session: AsyncSession, ) -> Automation:
trigger: AutomationTrigger, """Load ``trigger``'s automation and require it ``ACTIVE``.
runtime_inputs: dict[str, Any] | None = None,
) -> AutomationRun: Raises ``DispatchError`` if the automation is missing or not active.
"""Resolve ``trigger``'s automation, require it ``ACTIVE``, dispatch a run.""" """
automation = await _load_automation(session, trigger.automation_id) automation = await _load_automation(session, trigger.automation_id)
if automation is None: if automation is None:
raise DispatchError( raise DispatchError(
@ -39,12 +30,7 @@ async def start_run(
f"automation {trigger.automation_id} is {automation.status.value}, not active" f"automation {trigger.automation_id} is {automation.status.value}, not active"
) )
return await dispatch_run( return automation
session=session,
automation=automation,
trigger=trigger,
runtime_inputs=runtime_inputs,
)
async def _load_automation( async def _load_automation(

View file

@ -1,83 +0,0 @@
"""Generic run dispatch: validate, snapshot, persist, enqueue. Shared by every trigger."""
from __future__ import annotations
from typing import Any
import jsonschema
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.persistence.enums.run_status import RunStatus
from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.run import AutomationRun
from app.automations.persistence.models.trigger import AutomationTrigger
from app.automations.schemas.definition.envelope import AutomationDefinition
from app.automations.tasks.execute_run import automation_run_execute
from .errors import DispatchError
async def dispatch_run(
*,
session: AsyncSession,
automation: Automation,
trigger: AutomationTrigger,
runtime_inputs: dict[str, Any] | None = None,
) -> AutomationRun:
"""Validate, snapshot the definition, persist an ``AutomationRun``, enqueue execution.
Final inputs = ``trigger.static_inputs`` merged with ``runtime_inputs``,
static winning on key collision. The merged dict is validated against
``automation.definition.inputs.schema_`` and stored on the run.
Callers (trigger-specific adapters) are responsible for resolving
``automation`` and ``trigger`` and for the trigger-side ``ACTIVE`` /
``enabled`` guards. This function only handles what's identical across
every trigger type.
"""
try:
definition = AutomationDefinition.model_validate(automation.definition)
except Exception as exc:
raise DispatchError(f"invalid automation definition: {exc}") from exc
merged_inputs = {**(runtime_inputs or {}), **(trigger.static_inputs or {})}
validated_inputs = _validate_inputs(definition, merged_inputs)
snapshot = definition.model_dump(mode="json", by_alias=True)
run = AutomationRun(
automation_id=automation.id,
trigger_id=trigger.id,
status=RunStatus.PENDING,
definition_snapshot=snapshot,
inputs=validated_inputs,
step_results=[],
artifacts=[],
)
session.add(run)
await session.commit()
await session.refresh(run)
automation_run_execute.apply_async(
args=[run.id],
time_limit=definition.execution.timeout_seconds,
)
return run
def _validate_inputs(
definition: AutomationDefinition, inputs: dict[str, Any]
) -> dict[str, Any]:
"""Validate merged inputs against the optional declared schema.
No declared schema pass through (runtime inputs like ``fired_at`` /
``last_fired_at`` and trigger ``static_inputs`` must still reach the
template context). Returning ``{}`` here strips them and makes Jinja
blow up on any ``{{ inputs.* }}`` reference.
"""
if definition.inputs is None or not definition.inputs.schema_:
return inputs
try:
jsonschema.validate(instance=inputs, schema=definition.inputs.schema_)
except jsonschema.ValidationError as exc:
raise DispatchError(f"inputs: {exc.message}") from exc
return inputs

View file

@ -13,7 +13,7 @@ from typing import Any
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.dispatch import start_run from app.automations.dispatch import launch_run
from app.automations.persistence.enums.trigger_type import TriggerType from app.automations.persistence.enums.trigger_type import TriggerType
from app.automations.persistence.models.trigger import AutomationTrigger from app.automations.persistence.models.trigger import AutomationTrigger
from app.celery_app import celery_app from app.celery_app import celery_app
@ -58,7 +58,7 @@ async def _start_one(
session: AsyncSession, *, trigger: AutomationTrigger, event: Event session: AsyncSession, *, trigger: AutomationTrigger, event: Event
) -> None: ) -> None:
try: try:
run = await start_run( run = await launch_run(
session=session, session=session,
trigger=trigger, trigger=trigger,
runtime_inputs=event_runtime_inputs(event), runtime_inputs=event_runtime_inputs(event),

View file

@ -18,7 +18,7 @@ from datetime import UTC, datetime
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.dispatch import start_run from app.automations.dispatch import launch_run
from app.automations.persistence.enums.trigger_type import TriggerType from app.automations.persistence.enums.trigger_type import TriggerType
from app.automations.persistence.models.trigger import AutomationTrigger from app.automations.persistence.models.trigger import AutomationTrigger
from app.celery_app import celery_app from app.celery_app import celery_app
@ -159,7 +159,7 @@ async def _start_one(
return return
try: try:
run = await start_run( run = await launch_run(
session=session, session=session,
trigger=trigger, trigger=trigger,
runtime_inputs=schedule_runtime_inputs( runtime_inputs=schedule_runtime_inputs(

View file

@ -13,7 +13,7 @@ from typing import Any
import pytest import pytest
from app.automations.actions.agent_task.auto_decide import build_auto_decisions from app.automations.actions.builtin.agent_task.auto_decide import build_auto_decisions
pytestmark = pytest.mark.unit pytestmark = pytest.mark.unit

View file

@ -10,7 +10,9 @@ from __future__ import annotations
import pytest import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from app.automations.actions.agent_task.finalize import extract_final_assistant_message from app.automations.actions.builtin.agent_task.finalize import (
extract_final_assistant_message,
)
pytestmark = pytest.mark.unit pytestmark = pytest.mark.unit

View file

@ -1,10 +1,8 @@
"""Lock the input-validation contract used by ``dispatch_run``. """Lock the input-validation contract enforced before a run is enqueued.
``_validate_inputs`` is module-internal by convention (underscore), but it ``validate_inputs`` is the pure schema check that ``enqueue_run`` runs against
encodes a real behavior contract the rest of the system depends on, and the merged inputs. ``enqueue_run`` itself needs a real DB session, so tests target
public alternative (``dispatch_run``) requires a real DB session. Tests this pure function directly; the contract not the symbol is what's locked.
target the pure function directly; the contract not the symbol is what's
locked.
""" """
from __future__ import annotations from __future__ import annotations
@ -12,7 +10,7 @@ from __future__ import annotations
import pytest import pytest
from app.automations.dispatch.errors import DispatchError from app.automations.dispatch.errors import DispatchError
from app.automations.dispatch.run import _validate_inputs from app.automations.dispatch.inputs import validate_inputs
from app.automations.schemas.definition.envelope import AutomationDefinition from app.automations.schemas.definition.envelope import AutomationDefinition
from app.automations.schemas.definition.inputs import Inputs from app.automations.schemas.definition.inputs import Inputs
from app.automations.schemas.definition.plan_step import PlanStep from app.automations.schemas.definition.plan_step import PlanStep
@ -42,7 +40,7 @@ def test_validate_inputs_passes_through_when_no_schema_is_declared() -> None:
"static_key": "value", "static_key": "value",
} }
assert _validate_inputs(definition, runtime_inputs) == runtime_inputs assert validate_inputs(definition, runtime_inputs) == runtime_inputs
def test_validate_inputs_returns_inputs_when_they_match_declared_schema() -> None: def test_validate_inputs_returns_inputs_when_they_match_declared_schema() -> None:
@ -58,14 +56,13 @@ def test_validate_inputs_returns_inputs_when_they_match_declared_schema() -> Non
inputs = {"topic": "weekly report"} inputs = {"topic": "weekly report"}
assert _validate_inputs(definition, inputs) == inputs assert validate_inputs(definition, inputs) == inputs
def test_validate_inputs_raises_dispatch_error_when_inputs_violate_schema() -> None: def test_validate_inputs_raises_dispatch_error_when_inputs_violate_schema() -> None:
"""Inputs that don't match the declared schema must surface as """Inputs that don't match the declared schema must surface as
``DispatchError`` (not the raw ``jsonschema.ValidationError``), so the ``DispatchError`` (not the raw ``jsonschema.ValidationError``), so every
schedule tick and any other caller can handle one dispatch-domain caller can handle one dispatch-domain exception type uniformly."""
exception type uniformly."""
schema = { schema = {
"type": "object", "type": "object",
"properties": {"topic": {"type": "string"}}, "properties": {"topic": {"type": "string"}},
@ -74,4 +71,4 @@ def test_validate_inputs_raises_dispatch_error_when_inputs_violate_schema() -> N
definition = _minimal_definition(inputs=Inputs(schema=schema)) definition = _minimal_definition(inputs=Inputs(schema=schema))
with pytest.raises(DispatchError): with pytest.raises(DispatchError):
_validate_inputs(definition, {"topic": 42}) # type violates string validate_inputs(definition, {"topic": 42}) # type violates string