Merge pull request #1449 from CREDO23/feature-automations-v2

[Feat] [Automations] Event-Driven Trigger Type with document.entered_folder
This commit is contained in:
Rohan Verma 2026-05-29 19:07:21 -07:00 committed by GitHub
commit 7972901f15
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
74 changed files with 1681 additions and 234 deletions

2
.gitignore vendored
View file

@ -18,3 +18,5 @@ surfsense_web/test-results/
surfsense_web/blob-report/
content_research/
automation-design-plan.md
automation-frontend-builder-plan.md

View file

@ -0,0 +1,47 @@
"""Add 'event' to automation_trigger_type enum
Revision ID: 147
Revises: 146
Create Date: 2026-05-29
Adds the ``event`` value to the ``automation_trigger_type`` enum so automations
can be triggered by published domain events, alongside the existing
``schedule`` triggers.
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "147"
down_revision: str | None = "146"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
ENUM_NAME = "automation_trigger_type"
NEW_VALUE = "event"
def upgrade() -> None:
    """Add the new label to the trigger-type enum unless it already exists.

    The ``pg_enum`` lookup guards the ``ALTER TYPE`` statement, so running the
    migration against a database that already has the label changes nothing.
    """
    # SQL text is kept exactly as authored; only the call shape differs.
    add_value_sql = f"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_type t
JOIN pg_enum e ON t.oid = e.enumtypid
WHERE t.typname = '{ENUM_NAME}' AND e.enumlabel = '{NEW_VALUE}'
) THEN
ALTER TYPE {ENUM_NAME} ADD VALUE '{NEW_VALUE}';
END IF;
END
$$;
"""
    op.execute(add_value_sql)
def downgrade() -> None:
    """Intentionally do nothing.

    PostgreSQL does not support removing enum values, so the label added by
    ``upgrade`` stays in place when this revision is rolled back.
    """
    return None

View file

@ -43,6 +43,7 @@ from app.rate_limiter import get_real_client_ip, limiter
from app.routes import router as crud_router
from app.routes.auth_routes import router as auth_router
from app.schemas import UserCreate, UserRead, UserUpdate
from app.session_events import register_session_hooks
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
from app.utils.perf import log_system_snapshot
@ -588,6 +589,7 @@ async def lifespan(app: FastAPI):
"first real request will pay the full compile cost."
)
register_session_hooks()
log_system_snapshot("startup_complete")
yield

View file

@ -21,4 +21,4 @@ __all__ = [
]
# Built-in actions self-register at import time.
from . import agent_task # noqa: F401
from . import builtin # noqa: F401

View file

@ -0,0 +1,5 @@
"""Built-in action types — each in its own subpackage, self-registering at import."""
from __future__ import annotations
from . import agent_task # noqa: F401

View file

@ -2,8 +2,8 @@
from __future__ import annotations
from ..store import register_action
from ..types import ActionDefinition
from ...store import register_action
from ...types import ActionDefinition
from .factory import build_handler
from .params import AgentTaskActionParams

View file

@ -4,7 +4,7 @@ from __future__ import annotations
from typing import Any
from ..types import ActionContext, ActionHandler
from ...types import ActionContext, ActionHandler
from .invoke import run_agent_task
from .params import AgentTaskActionParams

View file

@ -16,7 +16,7 @@ from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in
from app.db import ChatVisibility, async_session_maker
from app.schemas.new_chat import MentionedDocumentInfo
from ..types import ActionContext
from ...types import ActionContext
from .auto_decide import build_auto_decisions
from .dependencies import build_dependencies
from .finalize import extract_final_assistant_message

View file

@ -3,6 +3,6 @@
from __future__ import annotations
from .errors import DispatchError
from .run import dispatch_run
from .launch import launch_run
__all__ = ["DispatchError", "dispatch_run"]
__all__ = ["DispatchError", "launch_run"]

View file

@ -0,0 +1,43 @@
"""Merge and validate the inputs a run starts with."""
from __future__ import annotations
from typing import Any
import jsonschema
from app.automations.persistence.models.trigger import AutomationTrigger
from app.automations.schemas.definition.envelope import AutomationDefinition
from .errors import DispatchError
def prepare_inputs(
    definition: AutomationDefinition,
    trigger: AutomationTrigger,
    runtime_inputs: dict[str, Any] | None,
) -> dict[str, Any]:
    """Combine runtime and static inputs, then validate the result.

    ``runtime_inputs`` form the base; ``trigger.static_inputs`` are applied on
    top, so a static value replaces a runtime value with the same key. Either
    side may be ``None`` and is then treated as empty. The combined dict is
    passed through ``validate_inputs`` and returned.
    """
    combined: dict[str, Any] = dict(runtime_inputs or {})
    # Applied second so the trigger's static values overwrite runtime ones.
    combined.update(trigger.static_inputs or {})
    return validate_inputs(definition, combined)
def validate_inputs(
    definition: AutomationDefinition, inputs: dict[str, Any]
) -> dict[str, Any]:
    """Validate ``inputs`` against the definition's optional declared schema.

    When no schema is declared, ``inputs`` pass through unchanged so runtime
    keys (``fired_at``, ``last_fired_at``, ...) still reach the template
    context.

    Returns:
        The same ``inputs`` dict, unmodified.

    Raises:
        DispatchError: the inputs violate the declared schema, or the declared
            schema is itself not a valid JSON Schema.
    """
    declared = definition.inputs
    if declared is None or not declared.schema_:
        return inputs
    try:
        jsonschema.validate(instance=inputs, schema=declared.schema_)
    except jsonschema.ValidationError as exc:
        raise DispatchError(f"inputs: {exc.message}") from exc
    except jsonschema.SchemaError as exc:
        # A broken declared schema is a dispatch failure too; without this it
        # would escape as a non-DispatchError exception.
        raise DispatchError(f"inputs schema: {exc.message}") from exc
    return inputs

View file

@ -0,0 +1,60 @@
"""Launch a run for a trigger that fired: resolve, validate, persist, enqueue.
The trigger-facing entry every selector calls. A selector builds the runtime
inputs and hands one trigger row here; this resolves and guards its automation,
snapshots the definition onto a PENDING run, and enqueues execution. The
snapshot makes the run immune to later edits of the automation.
"""
from __future__ import annotations
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.persistence.enums.run_status import RunStatus
from app.automations.persistence.models.run import AutomationRun
from app.automations.persistence.models.trigger import AutomationTrigger
from app.automations.schemas.definition.envelope import AutomationDefinition
from app.automations.tasks.execute_run import automation_run_execute
from .errors import DispatchError
from .inputs import prepare_inputs
from .resolve import resolve_active_automation
async def launch_run(
    *,
    session: AsyncSession,
    trigger: AutomationTrigger,
    runtime_inputs: dict[str, Any] | None = None,
) -> AutomationRun:
    """Create and enqueue a PENDING run for the automation behind ``trigger``.

    Steps, in order: resolve the active automation, parse its stored
    definition, merge and validate the inputs, persist the run together with a
    snapshot of the definition, then hand the run id to the execution task.

    Raises ``DispatchError`` when the stored definition does not parse; errors
    raised by ``resolve_active_automation`` and ``prepare_inputs`` propagate.
    """
    automation = await resolve_active_automation(session, trigger)

    try:
        definition = AutomationDefinition.model_validate(automation.definition)
    except Exception as exc:
        raise DispatchError(f"invalid automation definition: {exc}") from exc

    run_inputs = prepare_inputs(definition, trigger, runtime_inputs)
    frozen_definition = definition.model_dump(mode="json", by_alias=True)

    pending_run = AutomationRun(
        automation_id=automation.id,
        trigger_id=trigger.id,
        status=RunStatus.PENDING,
        definition_snapshot=frozen_definition,
        inputs=run_inputs,
        step_results=[],
        artifacts=[],
    )
    session.add(pending_run)
    await session.commit()
    # Reload the committed row before its id is handed to the task queue.
    await session.refresh(pending_run)

    timeout = definition.execution.timeout_seconds
    automation_run_execute.apply_async(args=[pending_run.id], time_limit=timeout)
    return pending_run

View file

@ -0,0 +1,40 @@
"""Resolve the automation behind a trigger and guard that it may run."""
from __future__ import annotations
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.persistence.enums.automation_status import AutomationStatus
from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.trigger import AutomationTrigger
from .errors import DispatchError
async def resolve_active_automation(
    session: AsyncSession, trigger: AutomationTrigger
) -> Automation:
    """Return the automation ``trigger`` belongs to, provided it is ``ACTIVE``.

    Raises ``DispatchError`` when no automation row exists for
    ``trigger.automation_id`` or when its status is anything but ``ACTIVE``.
    """
    automation = await _load_automation(session, trigger.automation_id)
    if automation is None:
        problem = (
            f"automation {trigger.automation_id} not found for trigger {trigger.id}"
        )
    elif automation.status != AutomationStatus.ACTIVE:
        problem = (
            f"automation {trigger.automation_id} is {automation.status.value}, not active"
        )
    else:
        return automation
    raise DispatchError(problem)
async def _load_automation(
    session: AsyncSession, automation_id: int
) -> Automation | None:
    """Fetch the automation with primary key ``automation_id``, or ``None``."""
    by_id = select(Automation).where(Automation.id == automation_id)
    result = await session.execute(by_id)
    return result.scalar_one_or_none()

View file

@ -1,83 +0,0 @@
"""Generic run dispatch: validate, snapshot, persist, enqueue. Shared by every trigger."""
from __future__ import annotations
from typing import Any
import jsonschema
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.persistence.enums.run_status import RunStatus
from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.run import AutomationRun
from app.automations.persistence.models.trigger import AutomationTrigger
from app.automations.schemas.definition.envelope import AutomationDefinition
from app.automations.tasks.execute_run import automation_run_execute
from .errors import DispatchError
async def dispatch_run(
    *,
    session: AsyncSession,
    automation: Automation,
    trigger: AutomationTrigger,
    runtime_inputs: dict[str, Any] | None = None,
) -> AutomationRun:
    """Validate, snapshot, persist and enqueue one ``AutomationRun``.

    The run's inputs are ``runtime_inputs`` overlaid with
    ``trigger.static_inputs`` (static values win on key collision), validated
    against the definition's declared input schema and stored on the run.

    Resolving ``automation`` and ``trigger`` and enforcing the ``ACTIVE`` /
    ``enabled`` guards is left to the trigger-specific caller; only the part
    shared by every trigger type lives here.

    Raises ``DispatchError`` when the stored definition does not parse.
    """
    try:
        definition = AutomationDefinition.model_validate(automation.definition)
    except Exception as exc:
        raise DispatchError(f"invalid automation definition: {exc}") from exc

    combined: dict[str, Any] = dict(runtime_inputs or {})
    # Applied second so static values overwrite runtime ones.
    combined.update(trigger.static_inputs or {})
    run_inputs = _validate_inputs(definition, combined)

    frozen_definition = definition.model_dump(mode="json", by_alias=True)

    pending_run = AutomationRun(
        automation_id=automation.id,
        trigger_id=trigger.id,
        status=RunStatus.PENDING,
        definition_snapshot=frozen_definition,
        inputs=run_inputs,
        step_results=[],
        artifacts=[],
    )
    session.add(pending_run)
    await session.commit()
    # Reload the committed row before its id is handed to the task queue.
    await session.refresh(pending_run)

    timeout = definition.execution.timeout_seconds
    automation_run_execute.apply_async(args=[pending_run.id], time_limit=timeout)
    return pending_run
def _validate_inputs(
    definition: AutomationDefinition, inputs: dict[str, Any]
) -> dict[str, Any]:
    """Validate merged inputs against the optional declared schema.

    When no schema is declared, ``inputs`` pass through unchanged: runtime
    inputs like ``fired_at`` / ``last_fired_at`` and trigger ``static_inputs``
    must still reach the template context. Returning ``{}`` here would strip
    them and make Jinja blow up on any ``{{ inputs.* }}`` reference.

    Returns:
        The same ``inputs`` dict, unmodified.

    Raises:
        DispatchError: the inputs violate the declared schema, or the declared
            schema is itself not a valid JSON Schema.
    """
    declared = definition.inputs
    if declared is None or not declared.schema_:
        return inputs
    try:
        jsonschema.validate(instance=inputs, schema=declared.schema_)
    except jsonschema.ValidationError as exc:
        raise DispatchError(f"inputs: {exc.message}") from exc
    except jsonschema.SchemaError as exc:
        # A broken declared schema is a dispatch failure too; without this it
        # would escape as a non-DispatchError exception.
        raise DispatchError(f"inputs schema: {exc.message}") from exc
    return inputs

View file

@ -1,8 +1,8 @@
"""Trigger-kind discriminator.
v1 only registers ``schedule``. ``manual`` is reserved in the enum (mirrors the
postgres enum) but is intentionally unregistered pending a redesign of the
"Run now" UX.
``schedule`` and ``event`` are registered. ``manual`` is reserved in the enum
(mirrors the postgres enum) but is intentionally unregistered pending a redesign
of the "Run now" UX.
"""
from __future__ import annotations
@ -12,4 +12,5 @@ from enum import StrEnum
class TriggerType(StrEnum):
    """Discriminator for the kind of trigger attached to an automation."""

    # Registered trigger kinds.
    SCHEDULE = "schedule"
    EVENT = "event"
    # Reserved in the enum but intentionally not registered as a trigger.
    MANUAL = "manual"

View file

@ -25,7 +25,7 @@ from app.automations.services.model_policy import (
get_automation_model_eligibility,
)
from app.automations.triggers import get_trigger
from app.automations.triggers.schedule import compute_next_fire_at
from app.automations.triggers.builtin.schedule import compute_next_fire_at
from app.db import Permission, SearchSpace, User, get_async_session
from app.users import current_active_user
from app.utils.rbac import check_permission

View file

@ -13,7 +13,7 @@ from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.trigger import AutomationTrigger
from app.automations.schemas.api import TriggerCreate, TriggerUpdate
from app.automations.triggers import get_trigger
from app.automations.triggers.schedule import compute_next_fire_at
from app.automations.triggers.builtin.schedule import compute_next_fire_at
from app.db import Permission, User, get_async_session
from app.users import current_active_user
from app.utils.rbac import check_permission

View file

@ -1,7 +1,6 @@
"""Triggers domain: registry surface + built-in trigger packages.
Each trigger lives in its own subpackage (``schedule/``, ...) and
self-registers at import time via its ``definition`` module.
Built-in trigger types live under ``builtin/`` and self-register at import time.
"""
from __future__ import annotations
@ -17,4 +16,4 @@ __all__ = [
]
# Built-in triggers self-register at import time.
from . import schedule # noqa: F401
from . import builtin # noqa: F401

View file

@ -0,0 +1,5 @@
"""Built-in trigger types — each in its own subpackage, self-registering at import."""
from __future__ import annotations
from . import event, schedule # noqa: F401

View file

@ -0,0 +1,29 @@
"""``event`` trigger: fire an automation when a matching domain event is published.
Subscribes to the event bus and matches events against a user-authored JSON
filter (see :mod:`.filter`).
"""
from __future__ import annotations
from app.event_bus import bus
from .filter import FilterError, matches
from .inputs import event_runtime_inputs
from .match import trigger_matches_event
from .params import EventTriggerParams
from .source import on_event
__all__ = [
"EventTriggerParams",
"FilterError",
"event_runtime_inputs",
"matches",
"trigger_matches_event",
]
# Side-effect: register on the triggers store.
from . import definition # noqa: F401
# Side-effect: react to published events.
bus.subscribe(on_event)

View file

@ -0,0 +1,16 @@
"""``event`` ``TriggerDefinition`` registration."""
from __future__ import annotations
from app.automations.triggers.store import register_trigger
from app.automations.triggers.types import TriggerDefinition
from .params import EventTriggerParams
# Definition of the ``event`` trigger kind; ``params_model`` names the pydantic
# model that describes this trigger's params.
EVENT_TRIGGER = TriggerDefinition(
    type="event",
    description="Fire when a matching domain event is published.",
    params_model=EventTriggerParams,
)
# Import-time side effect: hand the definition to the triggers store.
register_trigger(EVENT_TRIGGER)

View file

@ -0,0 +1,78 @@
"""Pure JSON filter grammar: ``matches(filter_expr, payload) -> bool``.
The ``event`` trigger uses it to decide whether an event fires the automation.
"""
from __future__ import annotations
import operator
from collections.abc import Callable
from typing import Any
class FilterError(ValueError):
    """A filter that cannot be evaluated: unknown operator or malformed operand.

    Raised (not silently false) so a bad filter fails at authoring time
    instead of quietly disabling the trigger.
    """


# Scalar comparison operators: (actual, operand) -> bool.
_COMPARATORS: dict[str, Callable[[Any, Any], bool]] = {
    "$eq": operator.eq,
    "$ne": operator.ne,
    "$gt": operator.gt,
    "$gte": operator.ge,
    "$lt": operator.lt,
    "$lte": operator.le,
    "$in": lambda actual, operand: actual in operand,
    "$nin": lambda actual, operand: actual not in operand,
}

# Sentinel for "the payload has no such field" — distinct from a present None.
_MISSING = object()


def matches(filter_expr: dict[str, Any], payload: dict[str, Any]) -> bool:
    """Return ``True`` when ``payload`` satisfies every constraint in ``filter_expr``.

    An empty filter expresses "no constraints" and matches every payload.
    Sibling keys (fields and logical operators alike) are ANDed together.

    Logical operators: ``$and`` / ``$or`` take a list of sub-filters, ``$not``
    takes a single sub-filter. Any other key is a payload field name.

    Raises:
        FilterError: an unknown ``$``-prefixed key or comparison operator, or a
            logical operator whose operand has the wrong shape.
    """
    for key, value in filter_expr.items():
        if key == "$and":
            if not all(
                matches(_sub_filter(key, sub), payload)
                for sub in _sub_list(key, value)
            ):
                return False
        elif key == "$or":
            if not any(
                matches(_sub_filter(key, sub), payload)
                for sub in _sub_list(key, value)
            ):
                return False
        elif key == "$not":
            if matches(_sub_filter(key, value), payload):
                return False
        elif key.startswith("$"):
            raise FilterError(f"unknown logical operator: {key}")
        elif not _match_condition(value, payload.get(key, _MISSING)):
            return False
    return True


def _sub_list(key: str, value: Any) -> Any:
    """Return ``value`` if it is a list of sub-filters, else raise ``FilterError``."""
    if not isinstance(value, (list, tuple)):
        raise FilterError(
            f"{key} expects a list of filters, got {type(value).__name__}"
        )
    return value


def _sub_filter(key: str, value: Any) -> dict[str, Any]:
    """Return ``value`` if it is a filter object, else raise ``FilterError``."""
    if not isinstance(value, dict):
        raise FilterError(
            f"{key} expects a filter object, got {type(value).__name__}"
        )
    return value


def _match_condition(condition: Any, actual: Any) -> bool:
    """Match one field's ``actual`` value against its ``condition``.

    A dict condition is an operator object (``{"$gt": 10}``); every operator in
    it must hold. Any other value is an implicit equality check. A field absent
    from the payload (``actual is _MISSING``) fails every constraint.
    """
    if actual is _MISSING:
        return False
    if isinstance(condition, dict):
        return all(
            _apply_operator(op, operand, actual)
            for op, operand in condition.items()
        )
    return actual == condition


def _apply_operator(op: str, operand: Any, actual: Any) -> bool:
    """Evaluate one comparison operator.

    Raises ``FilterError`` for an operator that is not in ``_COMPARATORS``.
    A comparison Python rejects with ``TypeError`` (for example ordering a
    string against a number, or ``$in`` against a non-container operand)
    counts as "does not match" rather than crashing the caller.
    """
    comparator = _COMPARATORS.get(op)
    if comparator is None:
        raise FilterError(f"unknown operator: {op}")
    try:
        return comparator(actual, operand)
    except TypeError:
        return False

View file

@ -0,0 +1,17 @@
"""Build run inputs from a published event."""
from __future__ import annotations
from typing import Any
from app.event_bus import Event
def event_runtime_inputs(event: Event) -> dict[str, Any]:
    """Turn a published event into the flat input dict a run starts with.

    The payload's keys are copied to the top level, then ``event_type``,
    ``event_id`` and ``occurred_at`` (ISO-8601 text) are written on top, so the
    metadata replaces any payload key with the same name.
    """
    inputs: dict[str, Any] = dict(event.payload)
    inputs.update(
        event_type=event.event_type,
        event_id=event.event_id,
        occurred_at=event.occurred_at.isoformat(),
    )
    return inputs

View file

@ -0,0 +1,16 @@
"""Pure predicate: does an event trigger fire for a given event?"""
from __future__ import annotations
from typing import Any
from app.event_bus import Event
from .filter import matches
def trigger_matches_event(params: dict[str, Any], event: Event) -> bool:
    """Decide whether a trigger configured with ``params`` fires for ``event``.

    The trigger's ``event_type`` must equal the event's type; if it does, the
    trigger's ``filter`` (a missing or empty one counts as ``{}``) is matched
    against the event payload.
    """
    wanted_type = params.get("event_type")
    if wanted_type == event.event_type:
        payload_filter = params.get("filter") or {}
        return matches(payload_filter, event.payload)
    return False

View file

@ -0,0 +1,23 @@
"""``EventTriggerParams`` — params for the ``event`` trigger type."""
from __future__ import annotations
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
# Params accepted by the ``event`` trigger type. Documented with comments
# rather than a class docstring so the model's generated schema is untouched.
class EventTriggerParams(BaseModel):
    # Keys other than the fields declared below are rejected.
    model_config = ConfigDict(extra="forbid")
    # Required, non-empty name of the event type to listen for.
    event_type: str = Field(
        ...,
        min_length=1,
        description="Event type to listen for.",
        examples=["document.indexed"],
    )
    # Optional filter matched against the event payload; defaults to an empty
    # dict (a fresh one per instance via ``default_factory``).
    filter: dict[str, Any] = Field(
        default_factory=dict,
        description="JSON filter matched against the event payload.",
        examples=[{"document_type": "FILE"}],
    )

View file

@ -0,0 +1,75 @@
"""Event selector (worker task): pick the triggers an event fires, start each.
The source enqueues this with a serialized event. Here we load the enabled
``event`` triggers for that event type, keep the ones whose filter matches the
payload, and start a run for each. Per-trigger failures are isolated.
"""
from __future__ import annotations
import logging
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.dispatch import launch_run
from app.automations.persistence.enums.trigger_type import TriggerType
from app.automations.persistence.models.trigger import AutomationTrigger
from app.celery_app import celery_app
from app.event_bus import Event
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
from .inputs import event_runtime_inputs
from .match import trigger_matches_event
from .source import TASK_NAME
logger = logging.getLogger(__name__)
@celery_app.task(name=TASK_NAME)
def automation_event_select(event: dict[str, Any]) -> None:
    """Select and start the runs an event fires.

    ``event`` is the serialized event dict. The async work lives in
    ``_select_and_start`` and is driven through ``run_async_celery_task``.
    """
    # Passed as a lambda so the coroutine is created only when the runner calls it.
    return run_async_celery_task(lambda: _select_and_start(event))
async def _select_and_start(event_dict: dict[str, Any]) -> None:
    """Rebuild the event from its dict form and start a run per matching trigger.

    A single session, opened from the Celery session maker, is shared by the
    trigger lookup and every run start.
    """
    parsed_event = Event.model_validate(event_dict)
    make_session = get_celery_session_maker()
    async with make_session() as session:
        matched = await _eligible(session, event=parsed_event)
        for trigger in matched:
            await _start_one(session, trigger=trigger, event=parsed_event)
async def _eligible(
    session: AsyncSession, *, event: Event
) -> list[AutomationTrigger]:
    """Enabled ``event`` triggers for this event type whose filter matches.

    The query narrows by trigger type, ``enabled`` and the ``event_type`` stored
    in ``params``; the payload filter is then evaluated in Python.

    A trigger whose stored filter cannot be evaluated is logged and left out,
    so one bad filter does not prevent the remaining triggers from firing.
    """
    stmt = select(AutomationTrigger).where(
        AutomationTrigger.type == TriggerType.EVENT,
        AutomationTrigger.enabled.is_(True),
        AutomationTrigger.params["event_type"].astext == event.event_type,
    )
    candidates = (await session.execute(stmt)).scalars().all()

    eligible: list[AutomationTrigger] = []
    for trigger in candidates:
        try:
            if trigger_matches_event(trigger.params, event):
                eligible.append(trigger)
        except Exception:
            # Isolation boundary: skip this trigger only, keep the rest.
            logger.exception("event filter failed for trigger %d", trigger.id)
    return eligible
async def _start_one(
    session: AsyncSession, *, trigger: AutomationTrigger, event: Event
) -> None:
    """Launch one run for ``trigger`` with inputs derived from ``event``.

    Any failure is logged with its traceback and the session is rolled back;
    nothing is re-raised, so the caller can continue with the next trigger.
    """
    try:
        inputs = event_runtime_inputs(event)
        run = await launch_run(
            session=session, trigger=trigger, runtime_inputs=inputs
        )
        logger.info(
            "event fire: trigger=%d automation=%d run=%d event=%s",
            trigger.id,
            trigger.automation_id,
            run.id,
            event.event_id,
        )
    except Exception:
        logger.exception("event fire failed for trigger %d", trigger.id)
        await session.rollback()

View file

@ -0,0 +1,19 @@
"""Event trigger source: the bus subscriber that enqueues the selector.
Runs in whatever process published the event, so it stays thin: it only hands
the event to a worker (the selector does the DB matching).
"""
from __future__ import annotations
from app.event_bus import Event
TASK_NAME = "automation_event_select"
async def on_event(event: Event) -> None:
    """Send the selector task for ``event``, passing the event in JSON-ready form."""
    # Lazy import: keeps app.celery_app out of the triggers-package import graph.
    from app.celery_app import celery_app

    serialized = event.model_dump(mode="json")
    celery_app.send_task(TASK_NAME, kwargs={"event": serialized})

View file

@ -3,14 +3,12 @@
from __future__ import annotations
from .cron import InvalidCronError, compute_next_fire_at, validate_cron
from .dispatch import dispatch_schedule_run
from .params import ScheduleTriggerParams
__all__ = [
"InvalidCronError",
"ScheduleTriggerParams",
"compute_next_fire_at",
"dispatch_schedule_run",
"validate_cron",
]

View file

@ -2,8 +2,9 @@
from __future__ import annotations
from ..store import register_trigger
from ..types import TriggerDefinition
from app.automations.triggers.store import register_trigger
from app.automations.triggers.types import TriggerDefinition
from .params import ScheduleTriggerParams
SCHEDULE_TRIGGER = TriggerDefinition(

View file

@ -0,0 +1,27 @@
"""Build run inputs from a schedule fire."""
from __future__ import annotations
from datetime import datetime
from typing import Any
def schedule_runtime_inputs(
    *,
    fired_at: datetime,
    scheduled_for: datetime,
    previous_last_fired_at: datetime | None,
) -> dict[str, Any]:
    """Build the calendar context a scheduled run receives as inputs.

    Every timestamp is rendered as ISO-8601 text:

    - ``fired_at``: actual fire time
    - ``scheduled_for``: cron-derived target time for this fire
    - ``last_fired_at``: previous fire time, or ``None`` on the first fire
    """
    context: dict[str, Any] = {
        "fired_at": fired_at.isoformat(),
        "scheduled_for": scheduled_for.isoformat(),
    }
    if previous_last_fired_at:
        context["last_fired_at"] = previous_last_fired_at.isoformat()
    else:
        context["last_fired_at"] = None
    return context

View file

@ -1,15 +1,12 @@
"""Celery Beat tick that fires due ``schedule`` triggers.
"""Schedule selector (worker task): claim due triggers and start each.
Runs every minute. Each tick performs two passes:
Beat ticks this every minute. Two passes:
1. **Self-heal**: enabled schedule triggers with NULL ``next_fire_at`` get
it computed from their ``cron`` + ``timezone`` (e.g. fresh inserts or
rows restored from backup).
2. **Claim & fire**: due rows are locked with ``FOR UPDATE SKIP LOCKED``,
their ``next_fire_at`` is advanced and ``last_fired_at`` is set, and
``dispatch_schedule_run`` is invoked for each. Dispatch errors are
logged; a missed fire stays missed (matches K8s CronJob / Airflow
``catchup=False`` semantics).
1. **Self-heal**: enabled schedule triggers with NULL ``next_fire_at`` get it
computed from their ``cron`` + ``timezone`` (fresh inserts, restored rows).
2. **Claim & start**: due rows are locked ``FOR UPDATE SKIP LOCKED``, their
``next_fire_at`` is advanced and ``last_fired_at`` set, and a run is started
for each. A missed fire stays missed (``catchup=False`` semantics).
"""
from __future__ import annotations
@ -21,19 +18,17 @@ from datetime import UTC, datetime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.dispatch import launch_run
from app.automations.persistence.enums.trigger_type import TriggerType
from app.automations.persistence.models.trigger import AutomationTrigger
from app.automations.triggers.schedule import (
InvalidCronError,
compute_next_fire_at,
dispatch_schedule_run,
)
from app.celery_app import celery_app
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
from .cron import InvalidCronError, compute_next_fire_at
from .inputs import schedule_runtime_inputs
from .source import TASK_NAME
TASK_NAME = "automation_schedule_tick"
logger = logging.getLogger(__name__)
# Cap rows touched per tick so a backlog of due triggers can't starve the
# worker; remaining rows fire on the next tick.
@ -50,8 +45,8 @@ class _Claim:
@celery_app.task(name=TASK_NAME)
def automation_schedule_tick() -> None:
"""Tick once: self-heal NULL next_fire_at, claim due rows, fire each."""
def automation_schedule_select() -> None:
"""Tick once: self-heal NULL next_fire_at, claim due rows, start each."""
return run_async_celery_task(_tick)
@ -67,7 +62,7 @@ async def _tick() -> None:
return
for claim in claims:
await _fire_one(session, claim=claim, fired_at=now)
await _start_one(session, claim=claim, fired_at=now)
async def _self_heal_null_next_fire(session: AsyncSession, *, now: datetime) -> None:
@ -155,21 +150,23 @@ async def _claim_due_triggers(session: AsyncSession, *, now: datetime) -> list[_
return claims
async def _fire_one(
async def _start_one(
session: AsyncSession, *, claim: _Claim, fired_at: datetime
) -> None:
"""Reload the trigger post-commit and dispatch a run for it."""
"""Reload the trigger post-commit and start a run for it."""
trigger = await session.get(AutomationTrigger, claim.trigger_id)
if trigger is None:
return
try:
run = await dispatch_schedule_run(
run = await launch_run(
session=session,
trigger=trigger,
fired_at=fired_at,
scheduled_for=claim.scheduled_for,
previous_last_fired_at=claim.previous_last_fired_at,
runtime_inputs=schedule_runtime_inputs(
fired_at=fired_at,
scheduled_for=claim.scheduled_for,
previous_last_fired_at=claim.previous_last_fired_at,
),
)
logger.info(
"scheduled fire: trigger=%d automation=%d run=%d",

View file

@ -0,0 +1,20 @@
"""Schedule trigger source: Celery Beat ticks the selector every minute.
``BEAT_SCHEDULE`` is merged into ``celery_app.conf.beat_schedule``. Per-row cron
math is precomputed (the ``next_fire_at`` column), so each tick is an indexed
lookup rather than N cron evaluations.
"""
from __future__ import annotations
from celery.schedules import crontab
TASK_NAME = "automation_schedule_select"
# Beat entry for the schedule selector task, keyed by the entry's name.
BEAT_SCHEDULE = {
    "automation-schedule-select": {
        "task": TASK_NAME,
        # Every minute: only ``minute`` is given, the other crontab fields keep
        # their defaults.
        "schedule": crontab(minute="*"),
        # NOTE(review): presumably seconds, so an unconsumed tick lapses before
        # the next one is due — confirm against Celery's ``expires`` option.
        "options": {"expires": 50},
    },
}

View file

@ -1,67 +0,0 @@
"""Schedule dispatch adapter: load + guard, then call generic dispatch."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.dispatch import DispatchError, dispatch_run
from app.automations.persistence.enums.automation_status import AutomationStatus
from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.run import AutomationRun
from app.automations.persistence.models.trigger import AutomationTrigger
async def dispatch_schedule_run(
    *,
    session: AsyncSession,
    trigger: AutomationTrigger,
    fired_at: datetime,
    scheduled_for: datetime,
    previous_last_fired_at: datetime | None,
) -> AutomationRun:
    """Fire one scheduled run for ``trigger``.

    The run receives calendar context as runtime inputs, each as ISO-8601 text:

    - ``fired_at``: actual fire time
    - ``scheduled_for``: cron-derived target time for this fire
    - ``last_fired_at``: fire time of the previous run, or ``None`` on first fire

    Selecting due triggers and advancing ``next_fire_at`` / ``last_fired_at``
    is the caller's (the schedule tick's) job and happens before this call.

    Raises ``DispatchError`` when the automation is missing or not ``ACTIVE``.
    """
    automation = await _load_automation(session, trigger.automation_id)
    if automation is None:
        problem = (
            f"automation {trigger.automation_id} not found for trigger {trigger.id}"
        )
        raise DispatchError(problem)
    if automation.status != AutomationStatus.ACTIVE:
        problem = (
            f"automation {trigger.automation_id} is {automation.status.value}, not active"
        )
        raise DispatchError(problem)

    calendar_context = {
        "fired_at": fired_at.isoformat(),
        "scheduled_for": scheduled_for.isoformat(),
    }
    if previous_last_fired_at:
        calendar_context["last_fired_at"] = previous_last_fired_at.isoformat()
    else:
        calendar_context["last_fired_at"] = None

    return await dispatch_run(
        session=session,
        automation=automation,
        trigger=trigger,
        runtime_inputs=calendar_context,
    )
async def _load_automation(
    session: AsyncSession, automation_id: int
) -> Automation | None:
    """Fetch the automation with primary key ``automation_id``, or ``None``."""
    by_id = select(Automation).where(Automation.id == automation_id)
    result = await session.execute(by_id)
    return result.scalar_one_or_none()

View file

@ -189,7 +189,8 @@ celery_app = Celery(
"app.tasks.celery_tasks.stale_notification_cleanup_task",
"app.tasks.celery_tasks.stripe_reconciliation_task",
"app.automations.tasks.execute_run",
"app.automations.tasks.schedule_tick",
"app.automations.triggers.builtin.schedule.selector",
"app.automations.triggers.builtin.event.selector",
],
)
@ -247,6 +248,12 @@ celery_app.conf.update(
},
)
# Imported late (after celery_app is built) to keep the automations triggers
# package out of this module's top-level import graph.
from app.automations.triggers.builtin.schedule.source import ( # noqa: E402
BEAT_SCHEDULE as SCHEDULE_BEAT_SCHEDULE,
)
# Configure Celery Beat schedule
# This uses a meta-scheduler pattern: instead of creating individual Beat schedules
# for each connector, we have ONE schedule that checks the database at the configured interval
@ -284,14 +291,7 @@ celery_app.conf.beat_schedule = {
"expires": 60,
},
},
# Fire due automation schedule triggers. Ticks every minute; per-row cron
# math is precomputed (next_fire_at column) so the tick is an indexed
# lookup, not N cron evaluations.
"automation-schedule-tick": {
"task": "automation_schedule_tick",
"schedule": crontab(minute="*"),
"options": {
"expires": 50,
},
},
# Fire due automation schedule triggers (Beat entry owned by the schedule
# trigger; see app.automations.triggers.builtin.schedule.source).
**SCHEDULE_BEAT_SCHEDULE,
}

View file

@ -0,0 +1,25 @@
"""In-process domain event bus.
Domain-agnostic pub/sub. Producers ``await bus.publish(...)``; subscribers
``bus.subscribe(...)``. Domain modules depend on it, never the reverse.
from app.event_bus import bus
await bus.publish("document.indexed", {"document_id": 42}, search_space_id=7)
"""
from __future__ import annotations
from . import events # noqa: F401 — populates the event-type catalog
from .bus import EventBus, Subscriber, bus
from .catalog import EventCatalog, EventType, catalog
from .event import Event
__all__ = [
"Event",
"EventBus",
"EventCatalog",
"EventType",
"Subscriber",
"bus",
"catalog",
]

View file

@ -0,0 +1,77 @@
"""In-process pub/sub. Streams :class:`Event` values from producers to listeners.
Boundary-crossing (Celery, DB, workers) is a subscriber's job — e.g. the
``event`` trigger enqueues its own task.
"""
from __future__ import annotations
import asyncio
import logging
from collections.abc import Awaitable, Callable
from typing import Any
from .event import Event
logger = logging.getLogger(__name__)
Subscriber = Callable[[Event], Awaitable[None]]
class EventBus:
"""An in-process pub/sub bus with a per-instance subscriber registry."""
def __init__(self) -> None:
self._subscribers: list[Subscriber] = []
def subscribe(self, handler: Subscriber) -> Subscriber:
"""Register ``handler`` for every event. Idempotent; returns the handler
so it works as a decorator."""
if handler not in self._subscribers:
self._subscribers.append(handler)
return handler
def subscribers(self) -> list[Subscriber]:
"""Defensive snapshot of the registered subscribers."""
return list(self._subscribers)
async def publish(
self,
event_type: str,
payload: dict[str, Any] | None = None,
*,
search_space_id: int,
) -> None:
"""Stamp an :class:`Event` and fan it out. Call after your commit."""
event = Event(
event_type=event_type,
payload=payload or {},
search_space_id=search_space_id,
)
await self.dispatch(event)
async def dispatch(self, event: Event) -> None:
"""Fan ``event`` out concurrently. Subscriber failures are logged and
isolated; never propagate."""
subscribers = self.subscribers()
if not subscribers:
return
results = await asyncio.gather(
*(handler(event) for handler in subscribers),
return_exceptions=True,
)
for handler, result in zip(subscribers, results, strict=True):
if isinstance(result, Exception):
logger.error(
"event subscriber %r failed for event %s (%s)",
getattr(handler, "__qualname__", handler),
event.event_id,
event.event_type,
exc_info=result,
)
# Process-wide bus. Producers publish to it; subscribers register on it.
bus = EventBus()

View file

@ -0,0 +1,48 @@
"""Event type catalog: the deliberate contract behind each event.
``EventType`` declares a dotted name and the shape of its payload.
``EventCatalog`` is the registry populated once at import by each event type
module. ``catalog`` is the process-wide singleton.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from pydantic import BaseModel
@dataclass(frozen=True, slots=True)
class EventType:
type: str
description: str
payload_model: type[BaseModel]
@property
def payload_schema(self) -> dict[str, Any]:
"""JSON Schema (draft 2020-12) derived from ``payload_model``."""
return self.payload_model.model_json_schema()
class EventCatalog:
    """Registry of known event types, keyed by their dotted ``type`` name.

    Intended to be filled while event modules are imported and only read
    afterwards.
    """

    def __init__(self) -> None:
        self._registry: dict[str, EventType] = {}

    def register(self, event_type: EventType) -> None:
        """Add ``event_type`` to the registry.

        Raises:
            ValueError: if an event type with the same name is already present.
        """
        name = event_type.type
        if name in self._registry:
            raise ValueError(f"Event type already registered: {name!r}")
        self._registry[name] = event_type

    def get(self, type_: str) -> EventType | None:
        """Look up an event type by name; ``None`` when it is unknown."""
        try:
            return self._registry[type_]
        except KeyError:
            return None

    def all(self) -> dict[str, EventType]:
        """Return a copy of the registry so callers cannot mutate it."""
        return self._registry.copy()


catalog = EventCatalog()

View file

@ -0,0 +1,38 @@
"""The ``Event`` value object — the only shape that crosses the bus.
An immutable fact: something named happened, with this payload, in this space,
at this time. JSON round-trippable so a subscriber can queue it to a worker.
"""
from __future__ import annotations
import uuid
from datetime import UTC, datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
def _new_event_id() -> str:
return uuid.uuid4().hex
def _now() -> datetime:
return datetime.now(UTC)
class Event(BaseModel):
    """A published domain fact.

    ``event_type`` is a dotted namespace (``document.indexed``, etc). ``payload`` is
    JSON-serializable. ``search_space_id`` scopes delivery. ``event_id`` and
    ``occurred_at`` are engine-stamped.
    """

    # frozen=True: instances are immutable once constructed.
    model_config = ConfigDict(frozen=True)

    # Caller-supplied facts.
    event_type: str
    payload: dict[str, Any] = Field(default_factory=dict)
    search_space_id: int
    # Stamped via default factories when the caller does not supply them:
    # a uuid4 hex string and the current UTC time respectively.
    event_id: str = Field(default_factory=_new_event_id)
    occurred_at: datetime = Field(default_factory=_now)

View file

@ -0,0 +1,5 @@
"""Domain event type definitions — each in its own module, self-registering at import."""
from __future__ import annotations
from . import document_entered_folder # noqa: F401

View file

@ -0,0 +1,86 @@
"""``document.entered_folder``: a document became a member of a folder.
Fires once per arrival, however the document got there (upload, AI sort, move).
The payload carries the fields a user can filter a trigger on.
"""
from __future__ import annotations
from pydantic import BaseModel, ConfigDict, computed_field
from app.event_bus.catalog import EventType, catalog
EVENT_TYPE = "document.entered_folder"
class DocumentEnteredFolderPayload(BaseModel):
    """Snapshot of the document at the moment it entered ``folder_id``.

    ``previous_folder_id`` is the folder it left, or ``None`` for a first
    placement. ``is_move`` derives from it and is emitted for filtering.
    """

    # extra="forbid": keys that are not declared below fail validation.
    model_config = ConfigDict(extra="forbid")

    document_id: int
    # Folder the document is now a member of.
    folder_id: int
    previous_folder_id: int | None = None
    document_type: str
    title: str
    connector_id: int | None = None
    created_by_id: str | None = None

    # Computed field, so it appears in ``model_dump()`` output alongside the
    # stored fields. Documented with a comment rather than a docstring so the
    # generated schema description is left exactly as it was.
    @computed_field
    @property
    def is_move(self) -> bool:
        return self.previous_folder_id is not None


# Self-registration: importing this module adds the event type to the
# process-wide catalog (register() raises if the name is already taken).
catalog.register(
    EventType(
        type=EVENT_TYPE,
        description="A document became a member of a folder.",
        payload_model=DocumentEnteredFolderPayload,
    )
)
def payload_if_entered_folder(
    *,
    document_id: int,
    search_space_id: int,
    new_folder_id: int | None,
    previous_folder_id: int | None,
    folder_id_changed: bool,
    status_state: str,
    document_type: str,
    title: str,
    connector_id: int | None,
    created_by_id: str | None,
) -> dict | None:
    """Build the publish arguments for a folder arrival, or return ``None``.

    An arrival requires all three of:

    * ``folder_id_changed`` is true. The caller derives it from SQLAlchemy
      attribute history, so it is true only when ``folder_id`` actually changed
      in this transaction; unrelated saves stay silent.
    * ``new_folder_id`` is not ``None`` (the document is in a folder now).
    * ``status_state`` equals ``"ready"``.

    Returns:
        A dict with ``event_type``, ``search_space_id`` and ``payload`` keys,
        or ``None`` when this change is not a folder arrival.
    """
    is_arrival = (
        folder_id_changed and new_folder_id is not None and status_state == "ready"
    )
    if not is_arrival:
        return None

    payload = {
        "document_id": document_id,
        "folder_id": new_folder_id,
        "previous_folder_id": previous_folder_id,
        "document_type": document_type,
        "title": title,
        "connector_id": connector_id,
        "created_by_id": created_by_id,
    }
    return {
        "event_type": EVENT_TYPE,
        "search_space_id": search_space_id,
        "payload": payload,
    }

View file

@ -525,11 +525,8 @@ async def bulk_move_documents(
detail="Cannot move documents to a folder in a different search space",
)
await session.execute(
Document.__table__.update()
.where(Document.id.in_(request.document_ids))
.values(folder_id=request.folder_id)
)
for doc in documents:
doc.folder_id = request.folder_id
await session.commit()
return {"message": f"{len(request.document_ids)} documents moved successfully"}

View file

@ -0,0 +1,101 @@
"""SQLAlchemy session event hooks — wired once at app startup.
Detects document folder arrivals across every ORM commit and publishes
``document.entered_folder`` events to the bus after the transaction is durable.
"""
from __future__ import annotations
import asyncio
import logging
from sqlalchemy import event
from sqlalchemy.orm import Session, attributes
from app.db import Document, DocumentStatus
from app.event_bus.bus import EventBus, bus as default_bus
from app.event_bus.events.document_entered_folder import payload_if_entered_folder
logger = logging.getLogger(__name__)
_PENDING_KEY = "_entered_folder_pending"
def _after_flush(session: Session, flush_context: object) -> None:
    """Collect folder-arrival candidates while attribute history is still available.

    A transaction may flush more than once before it commits (autoflush ahead
    of a query, explicit ``flush()`` calls). Candidates are therefore
    accumulated on the session across flushes instead of being replaced, so an
    arrival seen by an early flush is not lost when a later flush runs.

    Args:
        session: The session that just flushed.
        flush_context: Supplied by SQLAlchemy; unused here.
    """
    # Start from whatever earlier flushes in this transaction already queued.
    pending: list[dict] = list(getattr(session, _PENDING_KEY, None) or [])

    for obj in list(session.new) + list(session.dirty):
        if not isinstance(obj, Document):
            continue

        history = attributes.get_history(obj, "folder_id")
        if not history.added:
            # folder_id did not change in this flush.
            continue

        # folder_id changed (again) for this document: anything an earlier
        # flush queued for it is stale, so only the latest change can fire.
        pending = [
            item for item in pending if item["payload"]["document_id"] != obj.id
        ]

        new_folder_id = history.added[0]
        previous_folder_id = history.deleted[0] if history.deleted else None

        result = payload_if_entered_folder(
            document_id=obj.id,
            search_space_id=obj.search_space_id,
            new_folder_id=new_folder_id,
            previous_folder_id=previous_folder_id,
            folder_id_changed=True,
            status_state=DocumentStatus.get_state(obj.status) or "",
            document_type=obj.document_type.value if obj.document_type else "",
            title=obj.title or "",
            connector_id=obj.connector_id,
            created_by_id=str(obj.created_by_id) if obj.created_by_id else None,
        )
        if result is not None:
            pending.append(result)

    setattr(session, _PENDING_KEY, pending)
# Strong references to in-flight publish tasks. The event loop only keeps weak
# references to tasks, so without this a task could be garbage-collected
# before it finishes.
_BACKGROUND_TASKS: set[asyncio.Task[None]] = set()

# Bus the commit hook publishes to; re-pointed by ``register_session_hooks``.
_active_bus: EventBus = default_bus


def _on_publish_done(task: asyncio.Task[None]) -> None:
    """Release a finished publish task and log its failure, if any."""
    _BACKGROUND_TASKS.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("event publish failed: %s", exc, exc_info=exc)


def _after_commit(session: Session) -> None:
    """Publish collected events now that the transaction is durable.

    Publishing is scheduled on the running event loop and not awaited here.
    When no loop is running the events are dropped with a warning.
    """
    pending: list[dict] = getattr(session, _PENDING_KEY, [])
    if not pending:
        return
    # Reset first so the same events cannot be published twice.
    setattr(session, _PENDING_KEY, [])

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        logger.warning("No running event loop — skipping %d event(s)", len(pending))
        return

    for item in pending:
        task = loop.create_task(
            _active_bus.publish(
                item["event_type"],
                item["payload"],
                search_space_id=item["search_space_id"],
            )
        )
        _BACKGROUND_TASKS.add(task)
        task.add_done_callback(_on_publish_done)


def _after_rollback(session: Session) -> None:
    """Discard any pending events — the transaction did not commit."""
    setattr(session, _PENDING_KEY, [])


def register_session_hooks(bus: EventBus = default_bus) -> None:
    """Register document folder-arrival hooks on the SQLAlchemy Session class.

    Call once at application startup (e.g. in ``app.app`` lifespan). Safe to
    call repeatedly: each listener is only added when it is not already
    registered, so hooks never fire more than once per session event.

    Args:
        bus: Bus that folder-arrival events are published to. Defaults to the
            process-wide bus; the most recent call wins.
    """
    global _active_bus
    _active_bus = bus

    hooks = (
        ("after_flush", _after_flush),
        ("after_commit", _after_commit),
        ("after_rollback", _after_rollback),
    )
    for identifier, hook in hooks:
        if not event.contains(Session, identifier, hook):
            event.listen(Session, identifier, hook)

View file

@ -13,7 +13,7 @@ from typing import Any
import pytest
from app.automations.actions.agent_task.auto_decide import build_auto_decisions
from app.automations.actions.builtin.agent_task.auto_decide import build_auto_decisions
pytestmark = pytest.mark.unit

View file

@ -10,7 +10,9 @@ from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from app.automations.actions.agent_task.finalize import extract_final_assistant_message
from app.automations.actions.builtin.agent_task.finalize import (
extract_final_assistant_message,
)
pytestmark = pytest.mark.unit

View file

@ -1,10 +1,8 @@
"""Lock the input-validation contract used by ``dispatch_run``.
"""Lock the input-validation contract enforced before a run is enqueued.
``_validate_inputs`` is module-internal by convention (underscore), but it
encodes a real behavior contract the rest of the system depends on, and the
public alternative (``dispatch_run``) requires a real DB session. Tests
target the pure function directly; the contract not the symbol is what's
locked.
``validate_inputs`` is the pure schema check that ``enqueue_run`` runs against
merged inputs. ``enqueue_run`` itself needs a real DB session, so tests target
this pure function directly; the contract, not the symbol, is what's locked.
"""
from __future__ import annotations
@ -12,7 +10,7 @@ from __future__ import annotations
import pytest
from app.automations.dispatch.errors import DispatchError
from app.automations.dispatch.run import _validate_inputs
from app.automations.dispatch.inputs import validate_inputs
from app.automations.schemas.definition.envelope import AutomationDefinition
from app.automations.schemas.definition.inputs import Inputs
from app.automations.schemas.definition.plan_step import PlanStep
@ -42,7 +40,7 @@ def test_validate_inputs_passes_through_when_no_schema_is_declared() -> None:
"static_key": "value",
}
assert _validate_inputs(definition, runtime_inputs) == runtime_inputs
assert validate_inputs(definition, runtime_inputs) == runtime_inputs
def test_validate_inputs_returns_inputs_when_they_match_declared_schema() -> None:
@ -58,14 +56,13 @@ def test_validate_inputs_returns_inputs_when_they_match_declared_schema() -> Non
inputs = {"topic": "weekly report"}
assert _validate_inputs(definition, inputs) == inputs
assert validate_inputs(definition, inputs) == inputs
def test_validate_inputs_raises_dispatch_error_when_inputs_violate_schema() -> None:
"""Inputs that don't match the declared schema must surface as
``DispatchError`` (not the raw ``jsonschema.ValidationError``), so the
schedule tick and any other caller can handle one dispatch-domain
exception type uniformly."""
``DispatchError`` (not the raw ``jsonschema.ValidationError``), so every
caller can handle one dispatch-domain exception type uniformly."""
schema = {
"type": "object",
"properties": {"topic": {"type": "string"}},
@ -74,4 +71,4 @@ def test_validate_inputs_raises_dispatch_error_when_inputs_violate_schema() -> N
definition = _minimal_definition(inputs=Inputs(schema=schema))
with pytest.raises(DispatchError):
_validate_inputs(definition, {"topic": 42}) # type violates string
validate_inputs(definition, {"topic": 42}) # type violates string

View file

@ -39,7 +39,7 @@ def test_run_status_string_values_are_stable() -> None:
def test_trigger_type_keeps_manual_member_even_though_unregistered() -> None:
"""``MANUAL`` is reserved (mirrors the Postgres enum) but the trigger
store does not register it in v1. The enum must keep both members so
existing DB rows and the schema migration plan stay valid."""
assert {member.value for member in TriggerType} == {"schedule", "manual"}
"""``schedule`` and ``event`` are registered; ``MANUAL`` is reserved
(mirrors the Postgres enum) but the trigger store does not register it.
The enum must keep every member so DB rows and migrations stay valid."""
assert {member.value for member in TriggerType} == {"schedule", "event", "manual"}

View file

@ -0,0 +1,18 @@
"""The ``event`` trigger self-registers on the triggers store at import."""
from __future__ import annotations
import pytest
from app.automations.triggers import get_trigger
from app.automations.triggers.builtin.event.params import EventTriggerParams
pytestmark = pytest.mark.unit
def test_event_trigger_is_registered() -> None:
definition = get_trigger("event")
assert definition is not None
assert definition.type == "event"
assert definition.params_model is EventTriggerParams

View file

@ -0,0 +1,115 @@
"""Behavior tests for the ``matches`` filter grammar."""
from __future__ import annotations
import pytest
from app.automations.triggers.builtin.event.filter import FilterError, matches
pytestmark = pytest.mark.unit
def test_empty_filter_matches_any_payload() -> None:
assert matches({}, {"document_id": 42, "document_type": "FILE"}) is True
assert matches({}, {}) is True
def test_scalar_value_is_implicit_equality() -> None:
flt = {"document_type": "FILE"}
assert matches(flt, {"document_type": "FILE"}) is True
assert matches(flt, {"document_type": "WEBPAGE"}) is False
def test_multiple_fields_are_anded() -> None:
flt = {"document_type": "FILE", "search_space_id": 7}
assert matches(flt, {"document_type": "FILE", "search_space_id": 7}) is True
assert matches(flt, {"document_type": "FILE", "search_space_id": 9}) is False
def test_gt_operator_compares_greater_than() -> None:
flt = {"page_count": {"$gt": 10}}
assert matches(flt, {"page_count": 20}) is True
assert matches(flt, {"page_count": 10}) is False
assert matches(flt, {"page_count": 5}) is False
def test_remaining_comparison_operators() -> None:
assert matches({"n": {"$gte": 10}}, {"n": 10}) is True
assert matches({"n": {"$gte": 10}}, {"n": 9}) is False
assert matches({"n": {"$lt": 10}}, {"n": 9}) is True
assert matches({"n": {"$lt": 10}}, {"n": 10}) is False
assert matches({"n": {"$lte": 10}}, {"n": 10}) is True
assert matches({"n": {"$lte": 10}}, {"n": 11}) is False
assert matches({"s": {"$eq": "FILE"}}, {"s": "FILE"}) is True
assert matches({"s": {"$eq": "FILE"}}, {"s": "WEB"}) is False
assert matches({"s": {"$ne": "FILE"}}, {"s": "WEB"}) is True
assert matches({"s": {"$ne": "FILE"}}, {"s": "FILE"}) is False
def test_multiple_operators_on_one_field_are_anded() -> None:
flt = {"n": {"$gte": 10, "$lt": 20}}
assert matches(flt, {"n": 15}) is True
assert matches(flt, {"n": 10}) is True
assert matches(flt, {"n": 20}) is False
assert matches(flt, {"n": 5}) is False
def test_in_and_nin_membership_operators() -> None:
flt_in = {"document_type": {"$in": ["FILE", "WEBPAGE"]}}
assert matches(flt_in, {"document_type": "FILE"}) is True
assert matches(flt_in, {"document_type": "SLACK"}) is False
flt_nin = {"document_type": {"$nin": ["FILE", "WEBPAGE"]}}
assert matches(flt_nin, {"document_type": "SLACK"}) is True
assert matches(flt_nin, {"document_type": "FILE"}) is False
def test_or_matches_when_any_branch_holds() -> None:
flt = {"$or": [{"document_type": "FILE"}, {"document_type": "WEBPAGE"}]}
assert matches(flt, {"document_type": "WEBPAGE"}) is True
assert matches(flt, {"document_type": "SLACK"}) is False
def test_and_matches_when_every_branch_holds() -> None:
flt = {"$and": [{"n": {"$gt": 5}}, {"n": {"$lt": 10}}]}
assert matches(flt, {"n": 7}) is True
assert matches(flt, {"n": 12}) is False
def test_not_inverts_its_subexpression() -> None:
flt = {"$not": {"document_type": "FILE"}}
assert matches(flt, {"document_type": "WEBPAGE"}) is True
assert matches(flt, {"document_type": "FILE"}) is False
def test_missing_field_never_matches_and_never_raises() -> None:
# Conservative: an absent field fails the constraint, and comparisons must
# not raise on the missing value — including $ne (absence isn't "not equal").
assert matches({"document_type": "FILE"}, {}) is False
assert matches({"page_count": {"$gt": 5}}, {}) is False
assert matches({"document_type": {"$in": ["FILE"]}}, {}) is False
assert matches({"document_type": {"$ne": "FILE"}}, {}) is False
def test_logical_operators_compose_with_fields() -> None:
flt = {
"search_space_id": 7,
"$or": [{"document_type": "FILE"}, {"document_type": "WEBPAGE"}],
}
assert matches(flt, {"search_space_id": 7, "document_type": "FILE"}) is True
assert matches(flt, {"search_space_id": 9, "document_type": "FILE"}) is False
assert matches(flt, {"search_space_id": 7, "document_type": "SLACK"}) is False
def test_unknown_field_operator_raises_filter_error() -> None:
with pytest.raises(FilterError):
matches({"n": {"$regex": "x"}}, {"n": "xyz"})
def test_unknown_logical_operator_raises_filter_error() -> None:
with pytest.raises(FilterError):
matches({"$nor": [{"document_type": "FILE"}]}, {"document_type": "FILE"})

View file

@ -0,0 +1,26 @@
"""An event hands its payload + metadata to the run as inputs."""
from __future__ import annotations
import pytest
from app.automations.triggers.builtin.event.inputs import event_runtime_inputs
from app.event_bus import Event
pytestmark = pytest.mark.unit
def test_runtime_inputs_flatten_payload_with_event_metadata() -> None:
event = Event(
event_type="document.indexed",
payload={"document_id": 42, "document_type": "FILE"},
search_space_id=7,
)
inputs = event_runtime_inputs(event)
assert inputs["document_id"] == 42
assert inputs["document_type"] == "FILE"
assert inputs["event_type"] == "document.indexed"
assert inputs["event_id"] == event.event_id
assert inputs["occurred_at"] == event.occurred_at.isoformat()

View file

@ -0,0 +1,39 @@
"""Which triggers an event fires: event_type equality + filter match."""
from __future__ import annotations
import pytest
from app.automations.triggers.builtin.event.match import trigger_matches_event
from app.event_bus import Event
pytestmark = pytest.mark.unit
def _event(event_type: str = "document.indexed", **payload) -> Event:
return Event(event_type=event_type, payload=payload, search_space_id=7)
def test_matches_when_event_type_equal_and_filter_passes() -> None:
params = {"event_type": "document.indexed", "filter": {"document_type": "FILE"}}
assert trigger_matches_event(params, _event(document_type="FILE")) is True
def test_no_match_when_event_type_differs() -> None:
params = {"event_type": "document.indexed", "filter": {}}
assert trigger_matches_event(params, _event("podcast.generated")) is False
def test_no_match_when_filter_rejects_payload() -> None:
params = {"event_type": "document.indexed", "filter": {"document_type": "FILE"}}
assert trigger_matches_event(params, _event(document_type="WEBPAGE")) is False
def test_empty_filter_matches_any_payload_of_that_type() -> None:
params = {"event_type": "document.indexed", "filter": {}}
assert trigger_matches_event(params, _event(document_type="ANYTHING")) is True
def test_missing_filter_key_is_treated_as_empty() -> None:
params = {"event_type": "document.indexed"}
assert trigger_matches_event(params, _event(document_type="X")) is True

View file

@ -0,0 +1,40 @@
"""``EventTriggerParams`` contract: an event_type to listen for + an optional filter."""
from __future__ import annotations
import pytest
from app.automations.triggers.builtin.event.params import EventTriggerParams
pytestmark = pytest.mark.unit
def test_accepts_event_type_and_filter() -> None:
params = EventTriggerParams(
event_type="document.indexed",
filter={"document_type": "FILE"},
)
assert params.event_type == "document.indexed"
assert params.filter == {"document_type": "FILE"}
def test_filter_defaults_to_empty() -> None:
params = EventTriggerParams(event_type="document.indexed")
assert params.filter == {}
def test_event_type_is_required() -> None:
with pytest.raises(ValueError):
EventTriggerParams(filter={"x": 1})
def test_event_type_must_not_be_blank() -> None:
with pytest.raises(ValueError):
EventTriggerParams(event_type="")
def test_extra_keys_are_forbidden() -> None:
with pytest.raises(ValueError):
EventTriggerParams(event_type="document.indexed", typo=True)

View file

@ -6,7 +6,7 @@ from datetime import UTC, datetime
import pytest
from app.automations.triggers.schedule.cron import (
from app.automations.triggers.builtin.schedule.cron import (
InvalidCronError,
compute_next_fire_at,
validate_cron,

View file

@ -5,7 +5,7 @@ from __future__ import annotations
import pytest
from pydantic import ValidationError
from app.automations.triggers.schedule.params import ScheduleTriggerParams
from app.automations.triggers.builtin.schedule.params import ScheduleTriggerParams
pytestmark = pytest.mark.unit

View file

@ -0,0 +1,25 @@
"""Shared fixtures for the ``app.event_bus`` unit-test tree.
The event-type catalog is a module-level registry populated at import. Tests
that register their own event types (or assert on registry contents) snapshot
and restore it so state never leaks between tests.
"""
from __future__ import annotations
from collections.abc import Iterator
import pytest
from app.event_bus.catalog import catalog
@pytest.fixture
def isolated_event_catalog() -> Iterator[None]:
    """Run a test against the shared catalog, then put its contents back.

    The registry's contents are copied before the test and restored afterwards
    (even when the test fails), so event types registered by one test never
    leak into another.
    """
    saved = catalog._registry.copy()
    try:
        yield
    finally:
        # Restore in place so references to the registry dict stay valid.
        catalog._registry.clear()
        catalog._registry.update(saved)

View file

@ -0,0 +1,181 @@
"""``EventBus`` contract: subscribe, publish (stamp + fan out), dispatch.
Each test uses a fresh ``EventBus`` — no shared global state.
"""
from __future__ import annotations
import pytest
from app.event_bus import Event, EventBus
pytestmark = pytest.mark.unit
def _event() -> Event:
return Event(event_type="x.happened", payload={"k": "v"}, search_space_id=1)
async def _noop(_event: Event) -> None:
return None
async def _other(_event: Event) -> None:
return None
# --- registry -------------------------------------------------------------
def test_subscribe_then_subscribers_returns_the_handler() -> None:
bus = EventBus()
bus.subscribe(_noop)
assert _noop in bus.subscribers()
def test_subscribe_is_idempotent_for_the_same_handler() -> None:
"""Registering the same handler twice must not make it fire twice."""
bus = EventBus()
bus.subscribe(_noop)
bus.subscribe(_noop)
assert bus.subscribers().count(_noop) == 1
def test_distinct_handlers_both_register() -> None:
bus = EventBus()
bus.subscribe(_noop)
bus.subscribe(_other)
registered = bus.subscribers()
assert _noop in registered
assert _other in registered
def test_subscribers_returns_a_defensive_snapshot() -> None:
"""Mutating the returned list must not corrupt the registry."""
bus = EventBus()
bus.subscribe(_noop)
snapshot = bus.subscribers()
snapshot.clear()
assert _noop in bus.subscribers()
def test_subscribe_returns_handler_so_it_can_be_used_as_a_decorator() -> None:
bus = EventBus()
returned = bus.subscribe(_other)
assert returned is _other
def test_two_buses_do_not_share_subscribers() -> None:
"""The registry is per-instance, not global."""
a = EventBus()
b = EventBus()
a.subscribe(_noop)
assert _noop in a.subscribers()
assert _noop not in b.subscribers()
# --- dispatch -------------------------------------------------------------
async def test_dispatch_delivers_event_to_every_subscriber() -> None:
bus = EventBus()
seen: list[tuple[str, Event]] = []
async def first(event: Event) -> None:
seen.append(("first", event))
async def second(event: Event) -> None:
seen.append(("second", event))
bus.subscribe(first)
bus.subscribe(second)
event = _event()
await bus.dispatch(event)
assert ("first", event) in seen
assert ("second", event) in seen
async def test_dispatch_isolates_a_failing_subscriber() -> None:
"""A subscriber that raises must not stop a healthy one from running."""
bus = EventBus()
healthy_ran = False
async def boom(_event: Event) -> None:
raise RuntimeError("subscriber blew up")
async def healthy(_event: Event) -> None:
nonlocal healthy_ran
healthy_ran = True
bus.subscribe(boom)
bus.subscribe(healthy)
await bus.dispatch(_event())
assert healthy_ran is True
async def test_dispatch_never_propagates_subscriber_errors() -> None:
"""``dispatch`` itself must not raise even if every subscriber fails."""
bus = EventBus()
async def boom(_event: Event) -> None:
raise ValueError("nope")
bus.subscribe(boom)
await bus.dispatch(_event()) # must not raise
async def test_dispatch_with_no_subscribers_is_a_noop() -> None:
bus = EventBus()
await bus.dispatch(_event()) # must not raise
# --- publish --------------------------------------------------------------
async def test_publish_builds_a_stamped_event_and_fans_it_out() -> None:
bus = EventBus()
received: list[Event] = []
async def handler(event: Event) -> None:
received.append(event)
bus.subscribe(handler)
await bus.publish("document.indexed", {"document_id": 42}, search_space_id=7)
assert len(received) == 1
event = received[0]
assert event.event_type == "document.indexed"
assert event.payload == {"document_id": 42}
assert event.search_space_id == 7
# Engine-stamped identity/time on the way through.
assert event.event_id
assert event.occurred_at
async def test_publish_defaults_payload_to_empty_dict() -> None:
bus = EventBus()
received: list[Event] = []
async def handler(event: Event) -> None:
received.append(event)
bus.subscribe(handler)
await bus.publish("x.happened", search_space_id=1)
assert received[0].payload == {}
async def test_publish_with_no_subscribers_is_a_noop() -> None:
await EventBus().publish("x.happened", search_space_id=1) # must not raise

View file

@ -0,0 +1,73 @@
"""EventCatalog contract: register, look up, snapshot, derive schema."""
from __future__ import annotations
import pytest
from pydantic import BaseModel
from app.event_bus.catalog import EventCatalog, EventType
pytestmark = pytest.mark.unit
class _SamplePayload(BaseModel):
document_id: int
def _event_type(type_: str = "test.thing") -> EventType:
return EventType(
type=type_,
description="A thing happened.",
payload_model=_SamplePayload,
)
def test_register_then_get_returns_the_event_type(isolated_event_catalog: None) -> None:
from app.event_bus.catalog import catalog
catalog.register(_event_type())
assert catalog.get("test.thing") is not None
assert catalog.get("test.thing").type == "test.thing"
def test_get_unknown_type_returns_none(isolated_event_catalog: None) -> None:
from app.event_bus.catalog import catalog
assert catalog.get("does.not.exist") is None
def test_register_duplicate_type_raises(isolated_event_catalog: None) -> None:
"""A type is a contract; registering it twice is a bug, not an override."""
from app.event_bus.catalog import catalog
catalog.register(_event_type())
with pytest.raises(ValueError, match="already registered"):
catalog.register(_event_type())
def test_all_is_a_defensive_snapshot(isolated_event_catalog: None) -> None:
"""Mutating the returned dict must not corrupt the registry."""
from app.event_bus.catalog import catalog
catalog.register(_event_type())
snapshot = catalog.all()
snapshot.clear()
assert catalog.get("test.thing") is not None
def test_payload_schema_is_derived_from_the_payload_model() -> None:
"""The JSON Schema a UI/validator consumes comes from the payload model."""
event_type = _event_type()
assert event_type.payload_schema == _SamplePayload.model_json_schema()
def test_each_catalog_instance_has_its_own_registry() -> None:
"""Two EventCatalog instances are fully independent."""
a = EventCatalog()
b = EventCatalog()
a.register(_event_type())
assert a.get("test.thing") is not None
assert b.get("test.thing") is None

View file

@ -0,0 +1,56 @@
"""``document.entered_folder`` payload contract + catalog registration."""
from __future__ import annotations
import pytest
from app.event_bus.catalog import catalog
from app.event_bus.events.document_entered_folder import (
EVENT_TYPE,
DocumentEnteredFolderPayload,
)
pytestmark = pytest.mark.unit
def _payload(**overrides: object) -> DocumentEnteredFolderPayload:
base: dict[str, object] = {
"document_id": 42,
"folder_id": 7,
"document_type": "FILE",
"title": "Q3 report.pdf",
}
base.update(overrides)
return DocumentEnteredFolderPayload(**base)
def test_payload_carries_the_filterable_fields() -> None:
payload = _payload(connector_id=12, created_by_id="abc")
assert payload.document_id == 42
assert payload.folder_id == 7
assert payload.document_type == "FILE"
assert payload.connector_id == 12
def test_first_placement_is_not_a_move() -> None:
"""No previous folder (created or AI-sorted into place) → not a move."""
assert _payload(previous_folder_id=None).is_move is False
def test_change_between_folders_is_a_move() -> None:
assert _payload(previous_folder_id=3).is_move is True
def test_is_move_is_serialized_for_filtering() -> None:
"""Filters match against the dumped payload, so ``is_move`` must appear there."""
dumped = _payload(previous_folder_id=3).model_dump()
assert dumped["is_move"] is True
def test_event_type_is_registered_in_the_catalog() -> None:
registered = catalog.get(EVENT_TYPE)
assert registered is not None
assert registered.payload_model is DocumentEnteredFolderPayload

View file

@ -0,0 +1,58 @@
"""payload_if_entered_folder: decides whether a document commit warrants an event."""
from __future__ import annotations
from typing import Any
import pytest
from app.event_bus.events.document_entered_folder import payload_if_entered_folder
pytestmark = pytest.mark.unit
def _call(**overrides: Any) -> dict[str, Any] | None:
defaults: dict[str, Any] = {
"document_id": 1,
"search_space_id": 10,
"new_folder_id": 7,
"previous_folder_id": None,
"folder_id_changed": True,
"status_state": "ready",
"document_type": "FILE",
"title": "report.pdf",
"connector_id": None,
"created_by_id": None,
}
defaults.update(overrides)
return payload_if_entered_folder(**defaults)
def test_folder_set_ready_fires() -> None:
result = _call()
assert result is not None
assert result["event_type"] == "document.entered_folder"
assert result["search_space_id"] == 10
assert result["payload"]["folder_id"] == 7
assert result["payload"]["previous_folder_id"] is None
def test_no_folder_is_silent() -> None:
assert _call(new_folder_id=None) is None
def test_not_ready_is_silent() -> None:
assert _call(status_state="processing") is None
def test_folder_unchanged_is_silent() -> None:
assert _call(folder_id_changed=False) is None
def test_move_carries_previous_folder_id() -> None:
result = _call(previous_folder_id=3, new_folder_id=7)
assert result is not None
assert result["payload"]["previous_folder_id"] == 3
assert result["payload"]["folder_id"] == 7

View file

@ -0,0 +1,53 @@
"""``Event`` contract: carry caller facts + engine-stamped id/time, round-trip JSON."""
from __future__ import annotations
from datetime import datetime
import pytest
from app.event_bus.event import Event
pytestmark = pytest.mark.unit
def test_event_carries_caller_supplied_facts() -> None:
"""The three caller inputs are stored verbatim."""
event = Event(
event_type="document.indexed",
payload={"document_id": 42, "content_type": "pdf"},
search_space_id=7,
)
assert event.event_type == "document.indexed"
assert event.payload == {"document_id": 42, "content_type": "pdf"}
assert event.search_space_id == 7
def test_event_stamps_identity_and_time_when_not_supplied() -> None:
    """An event built without id/time still has a truthy id and a datetime stamp."""
    stamped = Event(
        event_type="x.happened",
        payload={},
        search_space_id=1,
    )

    assert stamped.event_id
    assert isinstance(stamped.occurred_at, datetime)
def test_event_ids_are_unique_per_instance() -> None:
    """Two events built from identical inputs still get different ids."""
    first_id, second_id = (
        Event(event_type="x.happened", payload={}, search_space_id=1).event_id
        for _ in range(2)
    )

    assert first_id != second_id
def test_event_survives_json_round_trip() -> None:
    """Dumping an event to JSON and validating it back yields an equal event."""
    source = Event(
        event_type="podcast.generated",
        payload={"podcast_id": 9, "duration_s": 123.5},
        search_space_id=3,
    )

    wire = source.model_dump_json()

    assert Event.model_validate_json(wire) == source

View file

@ -35,6 +35,23 @@ export interface JsonViewProps {
className?: string;
}
/** JSON number grammar (RFC 8259): optional minus, integer part without
 * leading zeros, optional fraction, optional exponent. */
const JSON_NUMBER_PATTERN = /^-?(?:0|[1-9]\d*)(?:\.\d+)?(?:[eE][+-]?\d+)?$/;

/** Recursively coerce string values that are valid JSON numbers back to numbers.
 *
 * react-json-view's text input always yields strings; this restores the
 * correct type so filters like ``{ "folder_id": 56 }`` survive editing.
 *
 * Only strings that (after trimming surrounding whitespace) match the JSON
 * number grammar and parse to a finite value are converted. Strings that
 * ``Number()`` would accept but JSON would not — hex/binary/octal literals
 * ("0x10"), "Infinity", or zero-padded text such as "007" — are left as
 * strings, so they are neither silently rewritten nor turned into a value
 * that serializes to ``null``.
 *
 * Arrays and plain objects are walked recursively; every other value is
 * returned unchanged.
 */
function coerceNumbers(value: unknown): unknown {
	if (typeof value === "string") {
		const candidate = value.trim();
		if (!JSON_NUMBER_PATTERN.test(candidate)) return value;
		const parsed = Number(candidate);
		// Literals such as "1e999" match the grammar but overflow to Infinity.
		return Number.isFinite(parsed) ? parsed : value;
	}
	if (Array.isArray(value)) return value.map(coerceNumbers);
	if (value && typeof value === "object") {
		return Object.fromEntries(
			Object.entries(value as Record<string, unknown>).map(([k, v]) => [k, coerceNumbers(v)])
		);
	}
	return value;
}
// Theme identifier used for dark mode — presumably a react-json-view built-in theme name; confirm against the component's `theme` prop.
const DARK_THEME = "monokai" as const;
// Theme identifier used for light mode — presumably react-json-view's default theme name; confirm against the component's `theme` prop.
const LIGHT_THEME = "rjv-default" as const;
@ -67,7 +84,7 @@ export function JsonView({
const handleChange = useCallback(
(interaction: InteractionProps) => {
onChange?.(interaction.updated_src);
onChange?.(coerceNumbers(interaction.updated_src));
return true;
},
[onChange]

View file

@ -7,7 +7,7 @@ import { z } from "zod";
export const automationStatus = z.enum(["active", "paused", "archived"]);
export type AutomationStatus = z.infer<typeof automationStatus>;
export const triggerType = z.enum(["schedule", "manual"]);
export const triggerType = z.enum(["schedule", "manual", "event"]);
export type TriggerType = z.infer<typeof triggerType>;
export const runStatus = z.enum([