mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-31 19:45:15 +02:00
Merge commit '7972901f15' into dev_mod
This commit is contained in:
commit
80daf46fbf
74 changed files with 1681 additions and 234 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -18,3 +18,5 @@ surfsense_web/test-results/
|
||||||
surfsense_web/blob-report/
|
surfsense_web/blob-report/
|
||||||
|
|
||||||
content_research/
|
content_research/
|
||||||
|
automation-design-plan.md
|
||||||
|
automation-frontend-builder-plan.md
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,47 @@
|
||||||
|
"""Add 'event' to automation_trigger_type enum
|
||||||
|
|
||||||
|
Revision ID: 147
|
||||||
|
Revises: 146
|
||||||
|
Create Date: 2026-05-29
|
||||||
|
|
||||||
|
Adds the ``event`` value to the ``automation_trigger_type`` enum so automations
|
||||||
|
can be triggered by published domain events, alongside the existing
|
||||||
|
``schedule`` triggers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "147"
|
||||||
|
down_revision: str | None = "146"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
ENUM_NAME = "automation_trigger_type"
|
||||||
|
NEW_VALUE = "event"
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Safely add 'event' to automation_trigger_type enum if missing."""
|
||||||
|
op.execute(
|
||||||
|
f"""
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM pg_type t
|
||||||
|
JOIN pg_enum e ON t.oid = e.enumtypid
|
||||||
|
WHERE t.typname = '{ENUM_NAME}' AND e.enumlabel = '{NEW_VALUE}'
|
||||||
|
) THEN
|
||||||
|
ALTER TYPE {ENUM_NAME} ADD VALUE '{NEW_VALUE}';
|
||||||
|
END IF;
|
||||||
|
END
|
||||||
|
$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""No-op: PostgreSQL does not support removing enum values."""
|
||||||
|
pass
|
||||||
|
|
@ -43,6 +43,7 @@ from app.rate_limiter import get_real_client_ip, limiter
|
||||||
from app.routes import router as crud_router
|
from app.routes import router as crud_router
|
||||||
from app.routes.auth_routes import router as auth_router
|
from app.routes.auth_routes import router as auth_router
|
||||||
from app.schemas import UserCreate, UserRead, UserUpdate
|
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||||
|
from app.session_events import register_session_hooks
|
||||||
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
|
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
|
||||||
from app.utils.perf import log_system_snapshot
|
from app.utils.perf import log_system_snapshot
|
||||||
|
|
||||||
|
|
@ -588,6 +589,7 @@ async def lifespan(app: FastAPI):
|
||||||
"first real request will pay the full compile cost."
|
"first real request will pay the full compile cost."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
register_session_hooks()
|
||||||
log_system_snapshot("startup_complete")
|
log_system_snapshot("startup_complete")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
|
||||||
|
|
@ -21,4 +21,4 @@ __all__ = [
|
||||||
]
|
]
|
||||||
|
|
||||||
# Built-in actions self-register at import time.
|
# Built-in actions self-register at import time.
|
||||||
from . import agent_task # noqa: F401
|
from . import builtin # noqa: F401
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
"""Built-in action types — each in its own subpackage, self-registering at import."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from . import agent_task # noqa: F401
|
||||||
|
|
@ -2,8 +2,8 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from ..store import register_action
|
from ...store import register_action
|
||||||
from ..types import ActionDefinition
|
from ...types import ActionDefinition
|
||||||
from .factory import build_handler
|
from .factory import build_handler
|
||||||
from .params import AgentTaskActionParams
|
from .params import AgentTaskActionParams
|
||||||
|
|
||||||
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from ..types import ActionContext, ActionHandler
|
from ...types import ActionContext, ActionHandler
|
||||||
from .invoke import run_agent_task
|
from .invoke import run_agent_task
|
||||||
from .params import AgentTaskActionParams
|
from .params import AgentTaskActionParams
|
||||||
|
|
||||||
|
|
@ -16,7 +16,7 @@ from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in
|
||||||
from app.db import ChatVisibility, async_session_maker
|
from app.db import ChatVisibility, async_session_maker
|
||||||
from app.schemas.new_chat import MentionedDocumentInfo
|
from app.schemas.new_chat import MentionedDocumentInfo
|
||||||
|
|
||||||
from ..types import ActionContext
|
from ...types import ActionContext
|
||||||
from .auto_decide import build_auto_decisions
|
from .auto_decide import build_auto_decisions
|
||||||
from .dependencies import build_dependencies
|
from .dependencies import build_dependencies
|
||||||
from .finalize import extract_final_assistant_message
|
from .finalize import extract_final_assistant_message
|
||||||
|
|
@ -3,6 +3,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .errors import DispatchError
|
from .errors import DispatchError
|
||||||
from .run import dispatch_run
|
from .launch import launch_run
|
||||||
|
|
||||||
__all__ = ["DispatchError", "dispatch_run"]
|
__all__ = ["DispatchError", "launch_run"]
|
||||||
|
|
|
||||||
43
surfsense_backend/app/automations/dispatch/inputs.py
Normal file
43
surfsense_backend/app/automations/dispatch/inputs.py
Normal file
|
|
@ -0,0 +1,43 @@
|
||||||
|
"""Merge and validate the inputs a run starts with."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import jsonschema
|
||||||
|
|
||||||
|
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||||
|
from app.automations.schemas.definition.envelope import AutomationDefinition
|
||||||
|
|
||||||
|
from .errors import DispatchError
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_inputs(
|
||||||
|
definition: AutomationDefinition,
|
||||||
|
trigger: AutomationTrigger,
|
||||||
|
runtime_inputs: dict[str, Any] | None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Merge ``trigger.static_inputs`` over ``runtime_inputs``, then validate.
|
||||||
|
|
||||||
|
Static inputs win on key collision.
|
||||||
|
"""
|
||||||
|
merged = {**(runtime_inputs or {}), **(trigger.static_inputs or {})}
|
||||||
|
return validate_inputs(definition, merged)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_inputs(
|
||||||
|
definition: AutomationDefinition, inputs: dict[str, Any]
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Validate ``inputs`` against the definition's optional declared schema.
|
||||||
|
|
||||||
|
No declared schema → pass through unchanged so runtime keys (``fired_at``,
|
||||||
|
``last_fired_at``, ...) still reach the template context. A declared schema
|
||||||
|
that the inputs violate is surfaced as ``DispatchError``.
|
||||||
|
"""
|
||||||
|
if definition.inputs is None or not definition.inputs.schema_:
|
||||||
|
return inputs
|
||||||
|
try:
|
||||||
|
jsonschema.validate(instance=inputs, schema=definition.inputs.schema_)
|
||||||
|
except jsonschema.ValidationError as exc:
|
||||||
|
raise DispatchError(f"inputs: {exc.message}") from exc
|
||||||
|
return inputs
|
||||||
60
surfsense_backend/app/automations/dispatch/launch.py
Normal file
60
surfsense_backend/app/automations/dispatch/launch.py
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
"""Launch a run for a trigger that fired: resolve, validate, persist, enqueue.
|
||||||
|
|
||||||
|
The trigger-facing entry every selector calls. A selector builds the runtime
|
||||||
|
inputs and hands one trigger row here; this resolves and guards its automation,
|
||||||
|
snapshots the definition onto a PENDING run, and enqueues execution. The
|
||||||
|
snapshot makes the run immune to later edits of the automation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.automations.persistence.enums.run_status import RunStatus
|
||||||
|
from app.automations.persistence.models.run import AutomationRun
|
||||||
|
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||||
|
from app.automations.schemas.definition.envelope import AutomationDefinition
|
||||||
|
from app.automations.tasks.execute_run import automation_run_execute
|
||||||
|
|
||||||
|
from .errors import DispatchError
|
||||||
|
from .inputs import prepare_inputs
|
||||||
|
from .resolve import resolve_active_automation
|
||||||
|
|
||||||
|
|
||||||
|
async def launch_run(
|
||||||
|
*,
|
||||||
|
session: AsyncSession,
|
||||||
|
trigger: AutomationTrigger,
|
||||||
|
runtime_inputs: dict[str, Any] | None = None,
|
||||||
|
) -> AutomationRun:
|
||||||
|
"""Resolve ``trigger``'s active automation and enqueue a PENDING run for it."""
|
||||||
|
automation = await resolve_active_automation(session, trigger)
|
||||||
|
|
||||||
|
try:
|
||||||
|
definition = AutomationDefinition.model_validate(automation.definition)
|
||||||
|
except Exception as exc:
|
||||||
|
raise DispatchError(f"invalid automation definition: {exc}") from exc
|
||||||
|
|
||||||
|
inputs = prepare_inputs(definition, trigger, runtime_inputs)
|
||||||
|
snapshot = definition.model_dump(mode="json", by_alias=True)
|
||||||
|
|
||||||
|
run = AutomationRun(
|
||||||
|
automation_id=automation.id,
|
||||||
|
trigger_id=trigger.id,
|
||||||
|
status=RunStatus.PENDING,
|
||||||
|
definition_snapshot=snapshot,
|
||||||
|
inputs=inputs,
|
||||||
|
step_results=[],
|
||||||
|
artifacts=[],
|
||||||
|
)
|
||||||
|
session.add(run)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(run)
|
||||||
|
|
||||||
|
automation_run_execute.apply_async(
|
||||||
|
args=[run.id],
|
||||||
|
time_limit=definition.execution.timeout_seconds,
|
||||||
|
)
|
||||||
|
return run
|
||||||
40
surfsense_backend/app/automations/dispatch/resolve.py
Normal file
40
surfsense_backend/app/automations/dispatch/resolve.py
Normal file
|
|
@ -0,0 +1,40 @@
|
||||||
|
"""Resolve the automation behind a trigger and guard that it may run."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.automations.persistence.enums.automation_status import AutomationStatus
|
||||||
|
from app.automations.persistence.models.automation import Automation
|
||||||
|
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||||
|
|
||||||
|
from .errors import DispatchError
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_active_automation(
|
||||||
|
session: AsyncSession, trigger: AutomationTrigger
|
||||||
|
) -> Automation:
|
||||||
|
"""Load ``trigger``'s automation and require it ``ACTIVE``.
|
||||||
|
|
||||||
|
Raises ``DispatchError`` if the automation is missing or not active.
|
||||||
|
"""
|
||||||
|
automation = await _load_automation(session, trigger.automation_id)
|
||||||
|
if automation is None:
|
||||||
|
raise DispatchError(
|
||||||
|
f"automation {trigger.automation_id} not found for trigger {trigger.id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if automation.status != AutomationStatus.ACTIVE:
|
||||||
|
raise DispatchError(
|
||||||
|
f"automation {trigger.automation_id} is {automation.status.value}, not active"
|
||||||
|
)
|
||||||
|
|
||||||
|
return automation
|
||||||
|
|
||||||
|
|
||||||
|
async def _load_automation(
|
||||||
|
session: AsyncSession, automation_id: int
|
||||||
|
) -> Automation | None:
|
||||||
|
stmt = select(Automation).where(Automation.id == automation_id)
|
||||||
|
return (await session.execute(stmt)).scalar_one_or_none()
|
||||||
|
|
@ -1,83 +0,0 @@
|
||||||
"""Generic run dispatch: validate, snapshot, persist, enqueue. Shared by every trigger."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import jsonschema
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.automations.persistence.enums.run_status import RunStatus
|
|
||||||
from app.automations.persistence.models.automation import Automation
|
|
||||||
from app.automations.persistence.models.run import AutomationRun
|
|
||||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
|
||||||
from app.automations.schemas.definition.envelope import AutomationDefinition
|
|
||||||
from app.automations.tasks.execute_run import automation_run_execute
|
|
||||||
|
|
||||||
from .errors import DispatchError
|
|
||||||
|
|
||||||
|
|
||||||
async def dispatch_run(
|
|
||||||
*,
|
|
||||||
session: AsyncSession,
|
|
||||||
automation: Automation,
|
|
||||||
trigger: AutomationTrigger,
|
|
||||||
runtime_inputs: dict[str, Any] | None = None,
|
|
||||||
) -> AutomationRun:
|
|
||||||
"""Validate, snapshot the definition, persist an ``AutomationRun``, enqueue execution.
|
|
||||||
|
|
||||||
Final inputs = ``trigger.static_inputs`` merged with ``runtime_inputs``,
|
|
||||||
static winning on key collision. The merged dict is validated against
|
|
||||||
``automation.definition.inputs.schema_`` and stored on the run.
|
|
||||||
|
|
||||||
Callers (trigger-specific adapters) are responsible for resolving
|
|
||||||
``automation`` and ``trigger`` and for the trigger-side ``ACTIVE`` /
|
|
||||||
``enabled`` guards. This function only handles what's identical across
|
|
||||||
every trigger type.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
definition = AutomationDefinition.model_validate(automation.definition)
|
|
||||||
except Exception as exc:
|
|
||||||
raise DispatchError(f"invalid automation definition: {exc}") from exc
|
|
||||||
|
|
||||||
merged_inputs = {**(runtime_inputs or {}), **(trigger.static_inputs or {})}
|
|
||||||
validated_inputs = _validate_inputs(definition, merged_inputs)
|
|
||||||
snapshot = definition.model_dump(mode="json", by_alias=True)
|
|
||||||
|
|
||||||
run = AutomationRun(
|
|
||||||
automation_id=automation.id,
|
|
||||||
trigger_id=trigger.id,
|
|
||||||
status=RunStatus.PENDING,
|
|
||||||
definition_snapshot=snapshot,
|
|
||||||
inputs=validated_inputs,
|
|
||||||
step_results=[],
|
|
||||||
artifacts=[],
|
|
||||||
)
|
|
||||||
session.add(run)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(run)
|
|
||||||
|
|
||||||
automation_run_execute.apply_async(
|
|
||||||
args=[run.id],
|
|
||||||
time_limit=definition.execution.timeout_seconds,
|
|
||||||
)
|
|
||||||
return run
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_inputs(
|
|
||||||
definition: AutomationDefinition, inputs: dict[str, Any]
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Validate merged inputs against the optional declared schema.
|
|
||||||
|
|
||||||
No declared schema → pass through (runtime inputs like ``fired_at`` /
|
|
||||||
``last_fired_at`` and trigger ``static_inputs`` must still reach the
|
|
||||||
template context). Returning ``{}`` here strips them and makes Jinja
|
|
||||||
blow up on any ``{{ inputs.* }}`` reference.
|
|
||||||
"""
|
|
||||||
if definition.inputs is None or not definition.inputs.schema_:
|
|
||||||
return inputs
|
|
||||||
try:
|
|
||||||
jsonschema.validate(instance=inputs, schema=definition.inputs.schema_)
|
|
||||||
except jsonschema.ValidationError as exc:
|
|
||||||
raise DispatchError(f"inputs: {exc.message}") from exc
|
|
||||||
return inputs
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
"""Trigger-kind discriminator.
|
"""Trigger-kind discriminator.
|
||||||
|
|
||||||
v1 only registers ``schedule``. ``manual`` is reserved in the enum (mirrors the
|
``schedule`` and ``event`` are registered. ``manual`` is reserved in the enum
|
||||||
postgres enum) but is intentionally unregistered pending a redesign of the
|
(mirrors the postgres enum) but is intentionally unregistered pending a redesign
|
||||||
"Run now" UX.
|
of the "Run now" UX.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -12,4 +12,5 @@ from enum import StrEnum
|
||||||
|
|
||||||
class TriggerType(StrEnum):
|
class TriggerType(StrEnum):
|
||||||
SCHEDULE = "schedule"
|
SCHEDULE = "schedule"
|
||||||
|
EVENT = "event"
|
||||||
MANUAL = "manual"
|
MANUAL = "manual"
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ from app.automations.services.model_policy import (
|
||||||
get_automation_model_eligibility,
|
get_automation_model_eligibility,
|
||||||
)
|
)
|
||||||
from app.automations.triggers import get_trigger
|
from app.automations.triggers import get_trigger
|
||||||
from app.automations.triggers.schedule import compute_next_fire_at
|
from app.automations.triggers.builtin.schedule import compute_next_fire_at
|
||||||
from app.db import Permission, SearchSpace, User, get_async_session
|
from app.db import Permission, SearchSpace, User, get_async_session
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
from app.utils.rbac import check_permission
|
from app.utils.rbac import check_permission
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from app.automations.persistence.models.automation import Automation
|
||||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||||
from app.automations.schemas.api import TriggerCreate, TriggerUpdate
|
from app.automations.schemas.api import TriggerCreate, TriggerUpdate
|
||||||
from app.automations.triggers import get_trigger
|
from app.automations.triggers import get_trigger
|
||||||
from app.automations.triggers.schedule import compute_next_fire_at
|
from app.automations.triggers.builtin.schedule import compute_next_fire_at
|
||||||
from app.db import Permission, User, get_async_session
|
from app.db import Permission, User, get_async_session
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
from app.utils.rbac import check_permission
|
from app.utils.rbac import check_permission
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
"""Triggers domain: registry surface + built-in trigger packages.
|
"""Triggers domain: registry surface + built-in trigger packages.
|
||||||
|
|
||||||
Each trigger lives in its own subpackage (``schedule/``, ...) and
|
Built-in trigger types live under ``builtin/`` and self-register at import time.
|
||||||
self-registers at import time via its ``definition`` module.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -17,4 +16,4 @@ __all__ = [
|
||||||
]
|
]
|
||||||
|
|
||||||
# Built-in triggers self-register at import time.
|
# Built-in triggers self-register at import time.
|
||||||
from . import schedule # noqa: F401
|
from . import builtin # noqa: F401
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
"""Built-in trigger types — each in its own subpackage, self-registering at import."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from . import event, schedule # noqa: F401
|
||||||
|
|
@ -0,0 +1,29 @@
|
||||||
|
"""``event`` trigger: fire an automation when a matching domain event is published.
|
||||||
|
|
||||||
|
Subscribes to the event bus and matches events against a user-authored JSON
|
||||||
|
filter (see :mod:`.filter`).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.event_bus import bus
|
||||||
|
|
||||||
|
from .filter import FilterError, matches
|
||||||
|
from .inputs import event_runtime_inputs
|
||||||
|
from .match import trigger_matches_event
|
||||||
|
from .params import EventTriggerParams
|
||||||
|
from .source import on_event
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"EventTriggerParams",
|
||||||
|
"FilterError",
|
||||||
|
"event_runtime_inputs",
|
||||||
|
"matches",
|
||||||
|
"trigger_matches_event",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Side-effect: register on the triggers store.
|
||||||
|
from . import definition # noqa: F401
|
||||||
|
|
||||||
|
# Side-effect: react to published events.
|
||||||
|
bus.subscribe(on_event)
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
"""``event`` ``TriggerDefinition`` registration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.automations.triggers.store import register_trigger
|
||||||
|
from app.automations.triggers.types import TriggerDefinition
|
||||||
|
|
||||||
|
from .params import EventTriggerParams
|
||||||
|
|
||||||
|
EVENT_TRIGGER = TriggerDefinition(
|
||||||
|
type="event",
|
||||||
|
description="Fire when a matching domain event is published.",
|
||||||
|
params_model=EventTriggerParams,
|
||||||
|
)
|
||||||
|
|
||||||
|
register_trigger(EVENT_TRIGGER)
|
||||||
|
|
@ -0,0 +1,78 @@
|
||||||
|
"""Pure JSON filter grammar: ``matches(filter_expr, payload) -> bool``.
|
||||||
|
|
||||||
|
The ``event`` trigger uses it to decide whether an event fires the automation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import operator
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class FilterError(ValueError):
|
||||||
|
"""Unknown operator in a filter. Raised (not silently false) so a bad filter
|
||||||
|
fails at authoring time instead of quietly disabling the trigger."""
|
||||||
|
|
||||||
|
|
||||||
|
# Scalar comparison operators: (actual, operand) -> bool.
|
||||||
|
_COMPARATORS: dict[str, Callable[[Any, Any], bool]] = {
|
||||||
|
"$eq": operator.eq,
|
||||||
|
"$ne": operator.ne,
|
||||||
|
"$gt": operator.gt,
|
||||||
|
"$gte": operator.ge,
|
||||||
|
"$lt": operator.lt,
|
||||||
|
"$lte": operator.le,
|
||||||
|
"$in": lambda actual, operand: actual in operand,
|
||||||
|
"$nin": lambda actual, operand: actual not in operand,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Sentinel for "the payload has no such field" — distinct from a present None.
|
||||||
|
_MISSING = object()
|
||||||
|
|
||||||
|
|
||||||
|
def matches(filter_expr: dict[str, Any], payload: dict[str, Any]) -> bool:
|
||||||
|
"""Return ``True`` when ``payload`` satisfies every constraint in ``filter_expr``.
|
||||||
|
|
||||||
|
An empty filter expresses "no constraints" and matches every payload.
|
||||||
|
Sibling keys (fields and logical operators alike) are ANDed together.
|
||||||
|
"""
|
||||||
|
for key, value in filter_expr.items():
|
||||||
|
if key == "$and":
|
||||||
|
if not all(matches(sub, payload) for sub in value):
|
||||||
|
return False
|
||||||
|
elif key == "$or":
|
||||||
|
if not any(matches(sub, payload) for sub in value):
|
||||||
|
return False
|
||||||
|
elif key == "$not":
|
||||||
|
if matches(value, payload):
|
||||||
|
return False
|
||||||
|
elif key.startswith("$"):
|
||||||
|
raise FilterError(f"unknown logical operator: {key}")
|
||||||
|
elif not _match_condition(value, payload.get(key, _MISSING)):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _match_condition(condition: Any, actual: Any) -> bool:
|
||||||
|
"""Match one field's ``actual`` value against its ``condition``.
|
||||||
|
|
||||||
|
A dict condition is an operator object (``{"$gt": 10}``); every operator in
|
||||||
|
it must hold. Any other value is an implicit equality check. A field absent
|
||||||
|
from the payload (``actual is _MISSING``) fails every constraint.
|
||||||
|
"""
|
||||||
|
if actual is _MISSING:
|
||||||
|
return False
|
||||||
|
if isinstance(condition, dict):
|
||||||
|
return all(
|
||||||
|
_apply_operator(op, operand, actual)
|
||||||
|
for op, operand in condition.items()
|
||||||
|
)
|
||||||
|
return actual == condition
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_operator(op: str, operand: Any, actual: Any) -> bool:
|
||||||
|
comparator = _COMPARATORS.get(op)
|
||||||
|
if comparator is not None:
|
||||||
|
return comparator(actual, operand)
|
||||||
|
raise FilterError(f"unknown operator: {op}")
|
||||||
|
|
@ -0,0 +1,17 @@
|
||||||
|
"""Build run inputs from a published event."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.event_bus import Event
|
||||||
|
|
||||||
|
|
||||||
|
def event_runtime_inputs(event: Event) -> dict[str, Any]:
|
||||||
|
"""Flatten the event payload and stamp event metadata as run inputs."""
|
||||||
|
return {
|
||||||
|
**event.payload,
|
||||||
|
"event_type": event.event_type,
|
||||||
|
"event_id": event.event_id,
|
||||||
|
"occurred_at": event.occurred_at.isoformat(),
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
"""Pure predicate: does an event trigger fire for a given event?"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.event_bus import Event
|
||||||
|
|
||||||
|
from .filter import matches
|
||||||
|
|
||||||
|
|
||||||
|
def trigger_matches_event(params: dict[str, Any], event: Event) -> bool:
|
||||||
|
"""True when an event trigger configured with ``params`` should fire for ``event``."""
|
||||||
|
if params.get("event_type") != event.event_type:
|
||||||
|
return False
|
||||||
|
return matches(params.get("filter") or {}, event.payload)
|
||||||
|
|
@ -0,0 +1,23 @@
|
||||||
|
"""``EventTriggerParams`` — params for the ``event`` trigger type."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
|
||||||
|
class EventTriggerParams(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
event_type: str = Field(
|
||||||
|
...,
|
||||||
|
min_length=1,
|
||||||
|
description="Event type to listen for.",
|
||||||
|
examples=["document.indexed"],
|
||||||
|
)
|
||||||
|
filter: dict[str, Any] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="JSON filter matched against the event payload.",
|
||||||
|
examples=[{"document_type": "FILE"}],
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,75 @@
|
||||||
|
"""Event selector (worker task): pick the triggers an event fires, start each.
|
||||||
|
|
||||||
|
The source enqueues this with a serialized event. Here we load the enabled
|
||||||
|
``event`` triggers for that event type, keep the ones whose filter matches the
|
||||||
|
payload, and start a run for each. Per-trigger failures are isolated.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.automations.dispatch import launch_run
|
||||||
|
from app.automations.persistence.enums.trigger_type import TriggerType
|
||||||
|
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||||
|
from app.celery_app import celery_app
|
||||||
|
from app.event_bus import Event
|
||||||
|
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||||
|
|
||||||
|
from .inputs import event_runtime_inputs
|
||||||
|
from .match import trigger_matches_event
|
||||||
|
from .source import TASK_NAME
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(name=TASK_NAME)
|
||||||
|
def automation_event_select(event: dict[str, Any]) -> None:
|
||||||
|
"""Select and start the runs an event fires."""
|
||||||
|
return run_async_celery_task(lambda: _select_and_start(event))
|
||||||
|
|
||||||
|
|
||||||
|
async def _select_and_start(event_dict: dict[str, Any]) -> None:
|
||||||
|
event = Event.model_validate(event_dict)
|
||||||
|
session_maker = get_celery_session_maker()
|
||||||
|
async with session_maker() as session:
|
||||||
|
for trigger in await _eligible(session, event=event):
|
||||||
|
await _start_one(session, trigger=trigger, event=event)
|
||||||
|
|
||||||
|
|
||||||
|
async def _eligible(
|
||||||
|
session: AsyncSession, *, event: Event
|
||||||
|
) -> list[AutomationTrigger]:
|
||||||
|
"""Enabled ``event`` triggers for this event type whose filter matches."""
|
||||||
|
stmt = select(AutomationTrigger).where(
|
||||||
|
AutomationTrigger.type == TriggerType.EVENT,
|
||||||
|
AutomationTrigger.enabled.is_(True),
|
||||||
|
AutomationTrigger.params["event_type"].astext == event.event_type,
|
||||||
|
)
|
||||||
|
triggers = (await session.execute(stmt)).scalars().all()
|
||||||
|
return [t for t in triggers if trigger_matches_event(t.params, event)]
|
||||||
|
|
||||||
|
|
||||||
|
async def _start_one(
|
||||||
|
session: AsyncSession, *, trigger: AutomationTrigger, event: Event
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
run = await launch_run(
|
||||||
|
session=session,
|
||||||
|
trigger=trigger,
|
||||||
|
runtime_inputs=event_runtime_inputs(event),
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"event fire: trigger=%d automation=%d run=%d event=%s",
|
||||||
|
trigger.id,
|
||||||
|
trigger.automation_id,
|
||||||
|
run.id,
|
||||||
|
event.event_id,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("event fire failed for trigger %d", trigger.id)
|
||||||
|
await session.rollback()
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
"""Event trigger source: the bus subscriber that enqueues the selector.
|
||||||
|
|
||||||
|
Runs in whatever process published the event, so it stays thin — it only hands
|
||||||
|
the event to a worker (the selector does the DB matching).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.event_bus import Event
|
||||||
|
|
||||||
|
TASK_NAME = "automation_event_select"
|
||||||
|
|
||||||
|
|
||||||
|
async def on_event(event: Event) -> None:
|
||||||
|
"""Enqueue the selector for ``event``."""
|
||||||
|
# Lazy import: keeps app.celery_app out of the triggers-package import graph.
|
||||||
|
from app.celery_app import celery_app
|
||||||
|
|
||||||
|
celery_app.send_task(TASK_NAME, kwargs={"event": event.model_dump(mode="json")})
|
||||||
|
|
@ -3,14 +3,12 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .cron import InvalidCronError, compute_next_fire_at, validate_cron
|
from .cron import InvalidCronError, compute_next_fire_at, validate_cron
|
||||||
from .dispatch import dispatch_schedule_run
|
|
||||||
from .params import ScheduleTriggerParams
|
from .params import ScheduleTriggerParams
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"InvalidCronError",
|
"InvalidCronError",
|
||||||
"ScheduleTriggerParams",
|
"ScheduleTriggerParams",
|
||||||
"compute_next_fire_at",
|
"compute_next_fire_at",
|
||||||
"dispatch_schedule_run",
|
|
||||||
"validate_cron",
|
"validate_cron",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -2,8 +2,9 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from ..store import register_trigger
|
from app.automations.triggers.store import register_trigger
|
||||||
from ..types import TriggerDefinition
|
from app.automations.triggers.types import TriggerDefinition
|
||||||
|
|
||||||
from .params import ScheduleTriggerParams
|
from .params import ScheduleTriggerParams
|
||||||
|
|
||||||
SCHEDULE_TRIGGER = TriggerDefinition(
|
SCHEDULE_TRIGGER = TriggerDefinition(
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
"""Build run inputs from a schedule fire."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def schedule_runtime_inputs(
|
||||||
|
*,
|
||||||
|
fired_at: datetime,
|
||||||
|
scheduled_for: datetime,
|
||||||
|
previous_last_fired_at: datetime | None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Calendar context for a scheduled run.
|
||||||
|
|
||||||
|
- ``fired_at`` — actual fire time
|
||||||
|
- ``scheduled_for`` — cron-derived target time for this fire
|
||||||
|
- ``last_fired_at`` — previous fire time, or null on first fire
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"fired_at": fired_at.isoformat(),
|
||||||
|
"scheduled_for": scheduled_for.isoformat(),
|
||||||
|
"last_fired_at": (
|
||||||
|
previous_last_fired_at.isoformat() if previous_last_fired_at else None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
@ -1,15 +1,12 @@
|
||||||
"""Celery Beat tick that fires due ``schedule`` triggers.
|
"""Schedule selector (worker task): claim due triggers and start each.
|
||||||
|
|
||||||
Runs every minute. Each tick performs two passes:
|
Beat ticks this every minute. Two passes:
|
||||||
|
|
||||||
1. **Self-heal**: enabled schedule triggers with NULL ``next_fire_at`` get
|
1. **Self-heal**: enabled schedule triggers with NULL ``next_fire_at`` get it
|
||||||
it computed from their ``cron`` + ``timezone`` (e.g. fresh inserts or
|
computed from their ``cron`` + ``timezone`` (fresh inserts, restored rows).
|
||||||
rows restored from backup).
|
2. **Claim & start**: due rows are locked ``FOR UPDATE SKIP LOCKED``, their
|
||||||
2. **Claim & fire**: due rows are locked with ``FOR UPDATE SKIP LOCKED``,
|
``next_fire_at`` is advanced and ``last_fired_at`` set, and a run is started
|
||||||
their ``next_fire_at`` is advanced and ``last_fired_at`` is set, and
|
for each. A missed fire stays missed (``catchup=False`` semantics).
|
||||||
``dispatch_schedule_run`` is invoked for each. Dispatch errors are
|
|
||||||
logged; a missed fire stays missed (matches K8s CronJob / Airflow
|
|
||||||
``catchup=False`` semantics).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -21,19 +18,17 @@ from datetime import UTC, datetime
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.automations.dispatch import launch_run
|
||||||
from app.automations.persistence.enums.trigger_type import TriggerType
|
from app.automations.persistence.enums.trigger_type import TriggerType
|
||||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||||
from app.automations.triggers.schedule import (
|
|
||||||
InvalidCronError,
|
|
||||||
compute_next_fire_at,
|
|
||||||
dispatch_schedule_run,
|
|
||||||
)
|
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
from .cron import InvalidCronError, compute_next_fire_at
|
||||||
|
from .inputs import schedule_runtime_inputs
|
||||||
|
from .source import TASK_NAME
|
||||||
|
|
||||||
TASK_NAME = "automation_schedule_tick"
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Cap rows touched per tick so a backlog of due triggers can't starve the
|
# Cap rows touched per tick so a backlog of due triggers can't starve the
|
||||||
# worker; remaining rows fire on the next tick.
|
# worker; remaining rows fire on the next tick.
|
||||||
|
|
@ -50,8 +45,8 @@ class _Claim:
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name=TASK_NAME)
|
@celery_app.task(name=TASK_NAME)
|
||||||
def automation_schedule_tick() -> None:
|
def automation_schedule_select() -> None:
|
||||||
"""Tick once: self-heal NULL next_fire_at, claim due rows, fire each."""
|
"""Tick once: self-heal NULL next_fire_at, claim due rows, start each."""
|
||||||
return run_async_celery_task(_tick)
|
return run_async_celery_task(_tick)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -67,7 +62,7 @@ async def _tick() -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
for claim in claims:
|
for claim in claims:
|
||||||
await _fire_one(session, claim=claim, fired_at=now)
|
await _start_one(session, claim=claim, fired_at=now)
|
||||||
|
|
||||||
|
|
||||||
async def _self_heal_null_next_fire(session: AsyncSession, *, now: datetime) -> None:
|
async def _self_heal_null_next_fire(session: AsyncSession, *, now: datetime) -> None:
|
||||||
|
|
@ -155,21 +150,23 @@ async def _claim_due_triggers(session: AsyncSession, *, now: datetime) -> list[_
|
||||||
return claims
|
return claims
|
||||||
|
|
||||||
|
|
||||||
async def _fire_one(
|
async def _start_one(
|
||||||
session: AsyncSession, *, claim: _Claim, fired_at: datetime
|
session: AsyncSession, *, claim: _Claim, fired_at: datetime
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Reload the trigger post-commit and dispatch a run for it."""
|
"""Reload the trigger post-commit and start a run for it."""
|
||||||
trigger = await session.get(AutomationTrigger, claim.trigger_id)
|
trigger = await session.get(AutomationTrigger, claim.trigger_id)
|
||||||
if trigger is None:
|
if trigger is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
run = await dispatch_schedule_run(
|
run = await launch_run(
|
||||||
session=session,
|
session=session,
|
||||||
trigger=trigger,
|
trigger=trigger,
|
||||||
fired_at=fired_at,
|
runtime_inputs=schedule_runtime_inputs(
|
||||||
scheduled_for=claim.scheduled_for,
|
fired_at=fired_at,
|
||||||
previous_last_fired_at=claim.previous_last_fired_at,
|
scheduled_for=claim.scheduled_for,
|
||||||
|
previous_last_fired_at=claim.previous_last_fired_at,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"scheduled fire: trigger=%d automation=%d run=%d",
|
"scheduled fire: trigger=%d automation=%d run=%d",
|
||||||
|
|
@ -0,0 +1,20 @@
|
||||||
|
"""Schedule trigger source: Celery Beat ticks the selector every minute.
|
||||||
|
|
||||||
|
``BEAT_SCHEDULE`` is merged into ``celery_app.conf.beat_schedule``. Per-row cron
|
||||||
|
math is precomputed (the ``next_fire_at`` column), so each tick is an indexed
|
||||||
|
lookup rather than N cron evaluations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from celery.schedules import crontab
|
||||||
|
|
||||||
|
TASK_NAME = "automation_schedule_select"
|
||||||
|
|
||||||
|
BEAT_SCHEDULE = {
|
||||||
|
"automation-schedule-select": {
|
||||||
|
"task": TASK_NAME,
|
||||||
|
"schedule": crontab(minute="*"),
|
||||||
|
"options": {"expires": 50},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
@ -1,67 +0,0 @@
|
||||||
"""Schedule dispatch adapter: load + guard, then call generic dispatch."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.automations.dispatch import DispatchError, dispatch_run
|
|
||||||
from app.automations.persistence.enums.automation_status import AutomationStatus
|
|
||||||
from app.automations.persistence.models.automation import Automation
|
|
||||||
from app.automations.persistence.models.run import AutomationRun
|
|
||||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
|
||||||
|
|
||||||
|
|
||||||
async def dispatch_schedule_run(
|
|
||||||
*,
|
|
||||||
session: AsyncSession,
|
|
||||||
trigger: AutomationTrigger,
|
|
||||||
fired_at: datetime,
|
|
||||||
scheduled_for: datetime,
|
|
||||||
previous_last_fired_at: datetime | None,
|
|
||||||
) -> AutomationRun:
|
|
||||||
"""Fire one scheduled run for ``trigger``.
|
|
||||||
|
|
||||||
Emits calendar context as runtime inputs:
|
|
||||||
|
|
||||||
- ``fired_at`` — actual fire time
|
|
||||||
- ``scheduled_for`` — cron-derived target time for this fire
|
|
||||||
- ``last_fired_at`` — fire time of the previous run, or null on first fire
|
|
||||||
|
|
||||||
The caller (the schedule tick) is responsible for selecting due triggers
|
|
||||||
and advancing ``next_fire_at`` / ``last_fired_at`` before invoking this.
|
|
||||||
"""
|
|
||||||
automation = await _load_automation(session, trigger.automation_id)
|
|
||||||
if automation is None:
|
|
||||||
raise DispatchError(
|
|
||||||
f"automation {trigger.automation_id} not found for trigger {trigger.id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if automation.status != AutomationStatus.ACTIVE:
|
|
||||||
raise DispatchError(
|
|
||||||
f"automation {trigger.automation_id} is {automation.status.value}, not active"
|
|
||||||
)
|
|
||||||
|
|
||||||
runtime_inputs = {
|
|
||||||
"fired_at": fired_at.isoformat(),
|
|
||||||
"scheduled_for": scheduled_for.isoformat(),
|
|
||||||
"last_fired_at": (
|
|
||||||
previous_last_fired_at.isoformat() if previous_last_fired_at else None
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
return await dispatch_run(
|
|
||||||
session=session,
|
|
||||||
automation=automation,
|
|
||||||
trigger=trigger,
|
|
||||||
runtime_inputs=runtime_inputs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _load_automation(
|
|
||||||
session: AsyncSession, automation_id: int
|
|
||||||
) -> Automation | None:
|
|
||||||
stmt = select(Automation).where(Automation.id == automation_id)
|
|
||||||
return (await session.execute(stmt)).scalar_one_or_none()
|
|
||||||
|
|
@ -189,7 +189,8 @@ celery_app = Celery(
|
||||||
"app.tasks.celery_tasks.stale_notification_cleanup_task",
|
"app.tasks.celery_tasks.stale_notification_cleanup_task",
|
||||||
"app.tasks.celery_tasks.stripe_reconciliation_task",
|
"app.tasks.celery_tasks.stripe_reconciliation_task",
|
||||||
"app.automations.tasks.execute_run",
|
"app.automations.tasks.execute_run",
|
||||||
"app.automations.tasks.schedule_tick",
|
"app.automations.triggers.builtin.schedule.selector",
|
||||||
|
"app.automations.triggers.builtin.event.selector",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -247,6 +248,12 @@ celery_app.conf.update(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Imported late (after celery_app is built) to keep the automations triggers
|
||||||
|
# package out of this module's top-level import graph.
|
||||||
|
from app.automations.triggers.builtin.schedule.source import ( # noqa: E402
|
||||||
|
BEAT_SCHEDULE as SCHEDULE_BEAT_SCHEDULE,
|
||||||
|
)
|
||||||
|
|
||||||
# Configure Celery Beat schedule
|
# Configure Celery Beat schedule
|
||||||
# This uses a meta-scheduler pattern: instead of creating individual Beat schedules
|
# This uses a meta-scheduler pattern: instead of creating individual Beat schedules
|
||||||
# for each connector, we have ONE schedule that checks the database at the configured interval
|
# for each connector, we have ONE schedule that checks the database at the configured interval
|
||||||
|
|
@ -284,14 +291,7 @@ celery_app.conf.beat_schedule = {
|
||||||
"expires": 60,
|
"expires": 60,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
# Fire due automation schedule triggers. Ticks every minute; per-row cron
|
# Fire due automation schedule triggers (Beat entry owned by the schedule
|
||||||
# math is precomputed (next_fire_at column) so the tick is an indexed
|
# trigger; see app.automations.triggers.builtin.schedule.source).
|
||||||
# lookup, not N cron evaluations.
|
**SCHEDULE_BEAT_SCHEDULE,
|
||||||
"automation-schedule-tick": {
|
|
||||||
"task": "automation_schedule_tick",
|
|
||||||
"schedule": crontab(minute="*"),
|
|
||||||
"options": {
|
|
||||||
"expires": 50,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
25
surfsense_backend/app/event_bus/__init__.py
Normal file
25
surfsense_backend/app/event_bus/__init__.py
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
"""In-process domain event bus.
|
||||||
|
|
||||||
|
Domain-agnostic pub/sub. Producers ``await bus.publish(...)``; subscribers
|
||||||
|
``bus.subscribe(...)``. Domain modules depend on it, never the reverse.
|
||||||
|
|
||||||
|
from app.event_bus import bus
|
||||||
|
await bus.publish("document.indexed", {"document_id": 42}, search_space_id=7)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from . import events # noqa: F401 — populates the event-type catalog
|
||||||
|
from .bus import EventBus, Subscriber, bus
|
||||||
|
from .catalog import EventCatalog, EventType, catalog
|
||||||
|
from .event import Event
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Event",
|
||||||
|
"EventBus",
|
||||||
|
"EventCatalog",
|
||||||
|
"EventType",
|
||||||
|
"Subscriber",
|
||||||
|
"bus",
|
||||||
|
"catalog",
|
||||||
|
]
|
||||||
77
surfsense_backend/app/event_bus/bus.py
Normal file
77
surfsense_backend/app/event_bus/bus.py
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
"""In-process pub/sub. Streams :class:`Event` values from producers to listeners.
|
||||||
|
|
||||||
|
Boundary-crossing (Celery, DB, workers) is a subscriber's job — e.g. the
|
||||||
|
``event`` trigger enqueues its own task.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .event import Event
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
Subscriber = Callable[[Event], Awaitable[None]]
|
||||||
|
|
||||||
|
|
||||||
|
class EventBus:
|
||||||
|
"""An in-process pub/sub bus with a per-instance subscriber registry."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._subscribers: list[Subscriber] = []
|
||||||
|
|
||||||
|
def subscribe(self, handler: Subscriber) -> Subscriber:
|
||||||
|
"""Register ``handler`` for every event. Idempotent; returns the handler
|
||||||
|
so it works as a decorator."""
|
||||||
|
if handler not in self._subscribers:
|
||||||
|
self._subscribers.append(handler)
|
||||||
|
return handler
|
||||||
|
|
||||||
|
def subscribers(self) -> list[Subscriber]:
|
||||||
|
"""Defensive snapshot of the registered subscribers."""
|
||||||
|
return list(self._subscribers)
|
||||||
|
|
||||||
|
async def publish(
|
||||||
|
self,
|
||||||
|
event_type: str,
|
||||||
|
payload: dict[str, Any] | None = None,
|
||||||
|
*,
|
||||||
|
search_space_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""Stamp an :class:`Event` and fan it out. Call after your commit."""
|
||||||
|
event = Event(
|
||||||
|
event_type=event_type,
|
||||||
|
payload=payload or {},
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
)
|
||||||
|
await self.dispatch(event)
|
||||||
|
|
||||||
|
async def dispatch(self, event: Event) -> None:
|
||||||
|
"""Fan ``event`` out concurrently. Subscriber failures are logged and
|
||||||
|
isolated; never propagate."""
|
||||||
|
subscribers = self.subscribers()
|
||||||
|
if not subscribers:
|
||||||
|
return
|
||||||
|
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*(handler(event) for handler in subscribers),
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
for handler, result in zip(subscribers, results, strict=True):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.error(
|
||||||
|
"event subscriber %r failed for event %s (%s)",
|
||||||
|
getattr(handler, "__qualname__", handler),
|
||||||
|
event.event_id,
|
||||||
|
event.event_type,
|
||||||
|
exc_info=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Process-wide bus. Producers publish to it; subscribers register on it.
|
||||||
|
bus = EventBus()
|
||||||
48
surfsense_backend/app/event_bus/catalog.py
Normal file
48
surfsense_backend/app/event_bus/catalog.py
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
"""Event type catalog: the deliberate contract behind each event.
|
||||||
|
|
||||||
|
``EventType`` declares a dotted name and the shape of its payload.
|
||||||
|
``EventCatalog`` is the registry — populated once at import by each event type
|
||||||
|
module. ``catalog`` is the process-wide singleton.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class EventType:
|
||||||
|
type: str
|
||||||
|
description: str
|
||||||
|
payload_model: type[BaseModel]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def payload_schema(self) -> dict[str, Any]:
|
||||||
|
"""JSON Schema (draft 2020-12) derived from ``payload_model``."""
|
||||||
|
return self.payload_model.model_json_schema()
|
||||||
|
|
||||||
|
|
||||||
|
class EventCatalog:
|
||||||
|
"""Registry of known event types. Populated at import; read at runtime."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._registry: dict[str, EventType] = {}
|
||||||
|
|
||||||
|
def register(self, event_type: EventType) -> None:
|
||||||
|
"""Register an event type. Raises on duplicate type."""
|
||||||
|
if event_type.type in self._registry:
|
||||||
|
raise ValueError(f"Event type already registered: {event_type.type!r}")
|
||||||
|
self._registry[event_type.type] = event_type
|
||||||
|
|
||||||
|
def get(self, type_: str) -> EventType | None:
|
||||||
|
return self._registry.get(type_)
|
||||||
|
|
||||||
|
def all(self) -> dict[str, EventType]:
|
||||||
|
"""Defensive snapshot of the registry."""
|
||||||
|
return dict(self._registry)
|
||||||
|
|
||||||
|
|
||||||
|
catalog = EventCatalog()
|
||||||
38
surfsense_backend/app/event_bus/event.py
Normal file
38
surfsense_backend/app/event_bus/event.py
Normal file
|
|
@ -0,0 +1,38 @@
|
||||||
|
"""The ``Event`` value object — the only shape that crosses the bus.
|
||||||
|
|
||||||
|
An immutable fact: something named happened, with this payload, in this space,
|
||||||
|
at this time. JSON round-trippable so a subscriber can queue it to a worker.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
|
||||||
|
def _new_event_id() -> str:
|
||||||
|
return uuid.uuid4().hex
|
||||||
|
|
||||||
|
|
||||||
|
def _now() -> datetime:
|
||||||
|
return datetime.now(UTC)
|
||||||
|
|
||||||
|
|
||||||
|
class Event(BaseModel):
|
||||||
|
"""A published domain fact.
|
||||||
|
|
||||||
|
``event_type`` is a dotted namespace (``document.indexed``, etc). ``payload`` is
|
||||||
|
JSON-serializable. ``search_space_id`` scopes delivery. ``event_id`` and
|
||||||
|
``occurred_at`` are engine-stamped.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(frozen=True)
|
||||||
|
|
||||||
|
event_type: str
|
||||||
|
payload: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
search_space_id: int
|
||||||
|
event_id: str = Field(default_factory=_new_event_id)
|
||||||
|
occurred_at: datetime = Field(default_factory=_now)
|
||||||
5
surfsense_backend/app/event_bus/events/__init__.py
Normal file
5
surfsense_backend/app/event_bus/events/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
"""Domain event type definitions — each in its own module, self-registering at import."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from . import document_entered_folder # noqa: F401
|
||||||
|
|
@ -0,0 +1,86 @@
|
||||||
|
"""``document.entered_folder``: a document became a member of a folder.
|
||||||
|
|
||||||
|
Fires once per arrival, however the document got there (upload, AI sort, move).
|
||||||
|
The payload carries the fields a user can filter a trigger on.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, computed_field
|
||||||
|
|
||||||
|
from app.event_bus.catalog import EventType, catalog
|
||||||
|
|
||||||
|
EVENT_TYPE = "document.entered_folder"
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentEnteredFolderPayload(BaseModel):
|
||||||
|
"""Snapshot of the document at the moment it entered ``folder_id``.
|
||||||
|
|
||||||
|
``previous_folder_id`` is the folder it left, or ``None`` for a first
|
||||||
|
placement. ``is_move`` derives from it and is emitted for filtering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
document_id: int
|
||||||
|
folder_id: int
|
||||||
|
previous_folder_id: int | None = None
|
||||||
|
document_type: str
|
||||||
|
title: str
|
||||||
|
connector_id: int | None = None
|
||||||
|
created_by_id: str | None = None
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def is_move(self) -> bool:
|
||||||
|
return self.previous_folder_id is not None
|
||||||
|
|
||||||
|
|
||||||
|
catalog.register(
|
||||||
|
EventType(
|
||||||
|
type=EVENT_TYPE,
|
||||||
|
description="A document became a member of a folder.",
|
||||||
|
payload_model=DocumentEnteredFolderPayload,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def payload_if_entered_folder(
|
||||||
|
*,
|
||||||
|
document_id: int,
|
||||||
|
search_space_id: int,
|
||||||
|
new_folder_id: int | None,
|
||||||
|
previous_folder_id: int | None,
|
||||||
|
folder_id_changed: bool,
|
||||||
|
status_state: str,
|
||||||
|
document_type: str,
|
||||||
|
title: str,
|
||||||
|
connector_id: int | None,
|
||||||
|
created_by_id: str | None,
|
||||||
|
) -> dict | None:
|
||||||
|
"""Return a publish payload if this commit represents a folder arrival, else None.
|
||||||
|
|
||||||
|
``folder_id_changed`` comes from SQLAlchemy attribute history — it is True
|
||||||
|
only when ``folder_id`` actually changed in this transaction, preventing
|
||||||
|
spurious events on unrelated saves.
|
||||||
|
"""
|
||||||
|
if not folder_id_changed:
|
||||||
|
return None
|
||||||
|
if new_folder_id is None:
|
||||||
|
return None
|
||||||
|
if status_state != "ready":
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"event_type": EVENT_TYPE,
|
||||||
|
"search_space_id": search_space_id,
|
||||||
|
"payload": {
|
||||||
|
"document_id": document_id,
|
||||||
|
"folder_id": new_folder_id,
|
||||||
|
"previous_folder_id": previous_folder_id,
|
||||||
|
"document_type": document_type,
|
||||||
|
"title": title,
|
||||||
|
"connector_id": connector_id,
|
||||||
|
"created_by_id": created_by_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
@ -525,11 +525,8 @@ async def bulk_move_documents(
|
||||||
detail="Cannot move documents to a folder in a different search space",
|
detail="Cannot move documents to a folder in a different search space",
|
||||||
)
|
)
|
||||||
|
|
||||||
await session.execute(
|
for doc in documents:
|
||||||
Document.__table__.update()
|
doc.folder_id = request.folder_id
|
||||||
.where(Document.id.in_(request.document_ids))
|
|
||||||
.values(folder_id=request.folder_id)
|
|
||||||
)
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return {"message": f"{len(request.document_ids)} documents moved successfully"}
|
return {"message": f"{len(request.document_ids)} documents moved successfully"}
|
||||||
|
|
||||||
|
|
|
||||||
101
surfsense_backend/app/session_events.py
Normal file
101
surfsense_backend/app/session_events.py
Normal file
|
|
@ -0,0 +1,101 @@
|
||||||
|
"""SQLAlchemy session event hooks — wired once at app startup.
|
||||||
|
|
||||||
|
Detects document folder arrivals across every ORM commit and publishes
|
||||||
|
``document.entered_folder`` events to the bus after the transaction is durable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from sqlalchemy import event
|
||||||
|
from sqlalchemy.orm import Session, attributes
|
||||||
|
|
||||||
|
from app.db import Document, DocumentStatus
|
||||||
|
from app.event_bus.bus import EventBus, bus as default_bus
|
||||||
|
from app.event_bus.events.document_entered_folder import payload_if_entered_folder
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_PENDING_KEY = "_entered_folder_pending"
|
||||||
|
|
||||||
|
|
||||||
|
def _after_flush(session: Session, flush_context: object) -> None:
|
||||||
|
"""Collect folder-arrival candidates while attribute history is still available."""
|
||||||
|
pending: list[dict] = []
|
||||||
|
|
||||||
|
for obj in list(session.new) + list(session.dirty):
|
||||||
|
if not isinstance(obj, Document):
|
||||||
|
continue
|
||||||
|
|
||||||
|
history = attributes.get_history(obj, "folder_id")
|
||||||
|
if not history.added:
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_folder_id = history.added[0]
|
||||||
|
previous_folder_id = history.deleted[0] if history.deleted else None
|
||||||
|
|
||||||
|
result = payload_if_entered_folder(
|
||||||
|
document_id=obj.id,
|
||||||
|
search_space_id=obj.search_space_id,
|
||||||
|
new_folder_id=new_folder_id,
|
||||||
|
previous_folder_id=previous_folder_id,
|
||||||
|
folder_id_changed=True,
|
||||||
|
status_state=DocumentStatus.get_state(obj.status) or "",
|
||||||
|
document_type=obj.document_type.value if obj.document_type else "",
|
||||||
|
title=obj.title or "",
|
||||||
|
connector_id=obj.connector_id,
|
||||||
|
created_by_id=str(obj.created_by_id) if obj.created_by_id else None,
|
||||||
|
)
|
||||||
|
if result is not None:
|
||||||
|
pending.append(result)
|
||||||
|
|
||||||
|
setattr(session, _PENDING_KEY, pending)
|
||||||
|
|
||||||
|
|
||||||
|
def _after_commit(session: Session) -> None:
|
||||||
|
"""Publish collected events now that the transaction is durable."""
|
||||||
|
pending: list[dict] = getattr(session, _PENDING_KEY, [])
|
||||||
|
if not pending:
|
||||||
|
return
|
||||||
|
setattr(session, _PENDING_KEY, [])
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
logger.warning("No running event loop — skipping %d event(s)", len(pending))
|
||||||
|
return
|
||||||
|
|
||||||
|
tasks = [
|
||||||
|
loop.create_task(
|
||||||
|
default_bus.publish(
|
||||||
|
item["event_type"],
|
||||||
|
item["payload"],
|
||||||
|
search_space_id=item["search_space_id"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for item in pending
|
||||||
|
]
|
||||||
|
for task in tasks:
|
||||||
|
task.add_done_callback(
|
||||||
|
lambda t: logger.error("event publish failed: %s", t.exception())
|
||||||
|
if not t.cancelled() and t.exception()
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _after_rollback(session: Session) -> None:
|
||||||
|
"""Discard any pending events — the transaction did not commit."""
|
||||||
|
setattr(session, _PENDING_KEY, [])
|
||||||
|
|
||||||
|
|
||||||
|
def register_session_hooks(bus: EventBus = default_bus) -> None:
|
||||||
|
"""Register document folder-arrival hooks on the SQLAlchemy Session class.
|
||||||
|
|
||||||
|
Call once at application startup (e.g. in ``app.app`` lifespan). Idempotent
|
||||||
|
— SQLAlchemy deduplicates identical listener registrations.
|
||||||
|
"""
|
||||||
|
event.listen(Session, "after_flush", _after_flush)
|
||||||
|
event.listen(Session, "after_commit", _after_commit)
|
||||||
|
event.listen(Session, "after_rollback", _after_rollback)
|
||||||
|
|
@ -13,7 +13,7 @@ from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.automations.actions.agent_task.auto_decide import build_auto_decisions
|
from app.automations.actions.builtin.agent_task.auto_decide import build_auto_decisions
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
@ -10,7 +10,9 @@ from __future__ import annotations
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||||
|
|
||||||
from app.automations.actions.agent_task.finalize import extract_final_assistant_message
|
from app.automations.actions.builtin.agent_task.finalize import (
|
||||||
|
extract_final_assistant_message,
|
||||||
|
)
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
"""Lock the input-validation contract used by ``dispatch_run``.
|
"""Lock the input-validation contract enforced before a run is enqueued.
|
||||||
|
|
||||||
``_validate_inputs`` is module-internal by convention (underscore), but it
|
``validate_inputs`` is the pure schema check that ``enqueue_run`` runs against
|
||||||
encodes a real behavior contract the rest of the system depends on, and the
|
merged inputs. ``enqueue_run`` itself needs a real DB session, so tests target
|
||||||
public alternative (``dispatch_run``) requires a real DB session. Tests
|
this pure function directly; the contract — not the symbol — is what's locked.
|
||||||
target the pure function directly; the contract — not the symbol — is what's
|
|
||||||
locked.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -12,7 +10,7 @@ from __future__ import annotations
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.automations.dispatch.errors import DispatchError
|
from app.automations.dispatch.errors import DispatchError
|
||||||
from app.automations.dispatch.run import _validate_inputs
|
from app.automations.dispatch.inputs import validate_inputs
|
||||||
from app.automations.schemas.definition.envelope import AutomationDefinition
|
from app.automations.schemas.definition.envelope import AutomationDefinition
|
||||||
from app.automations.schemas.definition.inputs import Inputs
|
from app.automations.schemas.definition.inputs import Inputs
|
||||||
from app.automations.schemas.definition.plan_step import PlanStep
|
from app.automations.schemas.definition.plan_step import PlanStep
|
||||||
|
|
@ -42,7 +40,7 @@ def test_validate_inputs_passes_through_when_no_schema_is_declared() -> None:
|
||||||
"static_key": "value",
|
"static_key": "value",
|
||||||
}
|
}
|
||||||
|
|
||||||
assert _validate_inputs(definition, runtime_inputs) == runtime_inputs
|
assert validate_inputs(definition, runtime_inputs) == runtime_inputs
|
||||||
|
|
||||||
|
|
||||||
def test_validate_inputs_returns_inputs_when_they_match_declared_schema() -> None:
|
def test_validate_inputs_returns_inputs_when_they_match_declared_schema() -> None:
|
||||||
|
|
@ -58,14 +56,13 @@ def test_validate_inputs_returns_inputs_when_they_match_declared_schema() -> Non
|
||||||
|
|
||||||
inputs = {"topic": "weekly report"}
|
inputs = {"topic": "weekly report"}
|
||||||
|
|
||||||
assert _validate_inputs(definition, inputs) == inputs
|
assert validate_inputs(definition, inputs) == inputs
|
||||||
|
|
||||||
|
|
||||||
def test_validate_inputs_raises_dispatch_error_when_inputs_violate_schema() -> None:
|
def test_validate_inputs_raises_dispatch_error_when_inputs_violate_schema() -> None:
|
||||||
"""Inputs that don't match the declared schema must surface as
|
"""Inputs that don't match the declared schema must surface as
|
||||||
``DispatchError`` (not the raw ``jsonschema.ValidationError``), so the
|
``DispatchError`` (not the raw ``jsonschema.ValidationError``), so every
|
||||||
schedule tick and any other caller can handle one dispatch-domain
|
caller can handle one dispatch-domain exception type uniformly."""
|
||||||
exception type uniformly."""
|
|
||||||
schema = {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {"topic": {"type": "string"}},
|
"properties": {"topic": {"type": "string"}},
|
||||||
|
|
@ -74,4 +71,4 @@ def test_validate_inputs_raises_dispatch_error_when_inputs_violate_schema() -> N
|
||||||
definition = _minimal_definition(inputs=Inputs(schema=schema))
|
definition = _minimal_definition(inputs=Inputs(schema=schema))
|
||||||
|
|
||||||
with pytest.raises(DispatchError):
|
with pytest.raises(DispatchError):
|
||||||
_validate_inputs(definition, {"topic": 42}) # type violates string
|
validate_inputs(definition, {"topic": 42}) # type violates string
|
||||||
|
|
@ -39,7 +39,7 @@ def test_run_status_string_values_are_stable() -> None:
|
||||||
|
|
||||||
|
|
||||||
def test_trigger_type_keeps_manual_member_even_though_unregistered() -> None:
|
def test_trigger_type_keeps_manual_member_even_though_unregistered() -> None:
|
||||||
"""``MANUAL`` is reserved (mirrors the Postgres enum) but the trigger
|
"""``schedule`` and ``event`` are registered; ``MANUAL`` is reserved
|
||||||
store does not register it in v1. The enum must keep both members so
|
(mirrors the Postgres enum) but the trigger store does not register it.
|
||||||
existing DB rows and the schema migration plan stay valid."""
|
The enum must keep every member so DB rows and migrations stay valid."""
|
||||||
assert {member.value for member in TriggerType} == {"schedule", "manual"}
|
assert {member.value for member in TriggerType} == {"schedule", "event", "manual"}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,18 @@
|
||||||
|
"""The ``event`` trigger self-registers on the triggers store at import."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.automations.triggers import get_trigger
|
||||||
|
from app.automations.triggers.builtin.event.params import EventTriggerParams
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def test_event_trigger_is_registered() -> None:
|
||||||
|
definition = get_trigger("event")
|
||||||
|
|
||||||
|
assert definition is not None
|
||||||
|
assert definition.type == "event"
|
||||||
|
assert definition.params_model is EventTriggerParams
|
||||||
|
|
@ -0,0 +1,115 @@
|
||||||
|
"""Behavior tests for the ``matches`` filter grammar."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.automations.triggers.builtin.event.filter import FilterError, matches
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_filter_matches_any_payload() -> None:
|
||||||
|
assert matches({}, {"document_id": 42, "document_type": "FILE"}) is True
|
||||||
|
assert matches({}, {}) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_scalar_value_is_implicit_equality() -> None:
|
||||||
|
flt = {"document_type": "FILE"}
|
||||||
|
assert matches(flt, {"document_type": "FILE"}) is True
|
||||||
|
assert matches(flt, {"document_type": "WEBPAGE"}) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_fields_are_anded() -> None:
|
||||||
|
flt = {"document_type": "FILE", "search_space_id": 7}
|
||||||
|
assert matches(flt, {"document_type": "FILE", "search_space_id": 7}) is True
|
||||||
|
assert matches(flt, {"document_type": "FILE", "search_space_id": 9}) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_gt_operator_compares_greater_than() -> None:
|
||||||
|
flt = {"page_count": {"$gt": 10}}
|
||||||
|
assert matches(flt, {"page_count": 20}) is True
|
||||||
|
assert matches(flt, {"page_count": 10}) is False
|
||||||
|
assert matches(flt, {"page_count": 5}) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_remaining_comparison_operators() -> None:
|
||||||
|
assert matches({"n": {"$gte": 10}}, {"n": 10}) is True
|
||||||
|
assert matches({"n": {"$gte": 10}}, {"n": 9}) is False
|
||||||
|
|
||||||
|
assert matches({"n": {"$lt": 10}}, {"n": 9}) is True
|
||||||
|
assert matches({"n": {"$lt": 10}}, {"n": 10}) is False
|
||||||
|
|
||||||
|
assert matches({"n": {"$lte": 10}}, {"n": 10}) is True
|
||||||
|
assert matches({"n": {"$lte": 10}}, {"n": 11}) is False
|
||||||
|
|
||||||
|
assert matches({"s": {"$eq": "FILE"}}, {"s": "FILE"}) is True
|
||||||
|
assert matches({"s": {"$eq": "FILE"}}, {"s": "WEB"}) is False
|
||||||
|
|
||||||
|
assert matches({"s": {"$ne": "FILE"}}, {"s": "WEB"}) is True
|
||||||
|
assert matches({"s": {"$ne": "FILE"}}, {"s": "FILE"}) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_operators_on_one_field_are_anded() -> None:
|
||||||
|
flt = {"n": {"$gte": 10, "$lt": 20}}
|
||||||
|
assert matches(flt, {"n": 15}) is True
|
||||||
|
assert matches(flt, {"n": 10}) is True
|
||||||
|
assert matches(flt, {"n": 20}) is False
|
||||||
|
assert matches(flt, {"n": 5}) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_in_and_nin_membership_operators() -> None:
|
||||||
|
flt_in = {"document_type": {"$in": ["FILE", "WEBPAGE"]}}
|
||||||
|
assert matches(flt_in, {"document_type": "FILE"}) is True
|
||||||
|
assert matches(flt_in, {"document_type": "SLACK"}) is False
|
||||||
|
|
||||||
|
flt_nin = {"document_type": {"$nin": ["FILE", "WEBPAGE"]}}
|
||||||
|
assert matches(flt_nin, {"document_type": "SLACK"}) is True
|
||||||
|
assert matches(flt_nin, {"document_type": "FILE"}) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_or_matches_when_any_branch_holds() -> None:
|
||||||
|
flt = {"$or": [{"document_type": "FILE"}, {"document_type": "WEBPAGE"}]}
|
||||||
|
assert matches(flt, {"document_type": "WEBPAGE"}) is True
|
||||||
|
assert matches(flt, {"document_type": "SLACK"}) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_and_matches_when_every_branch_holds() -> None:
|
||||||
|
flt = {"$and": [{"n": {"$gt": 5}}, {"n": {"$lt": 10}}]}
|
||||||
|
assert matches(flt, {"n": 7}) is True
|
||||||
|
assert matches(flt, {"n": 12}) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_not_inverts_its_subexpression() -> None:
|
||||||
|
flt = {"$not": {"document_type": "FILE"}}
|
||||||
|
assert matches(flt, {"document_type": "WEBPAGE"}) is True
|
||||||
|
assert matches(flt, {"document_type": "FILE"}) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_field_never_matches_and_never_raises() -> None:
|
||||||
|
# Conservative: an absent field fails the constraint, and comparisons must
|
||||||
|
# not raise on the missing value — including $ne (absence isn't "not equal").
|
||||||
|
assert matches({"document_type": "FILE"}, {}) is False
|
||||||
|
assert matches({"page_count": {"$gt": 5}}, {}) is False
|
||||||
|
assert matches({"document_type": {"$in": ["FILE"]}}, {}) is False
|
||||||
|
assert matches({"document_type": {"$ne": "FILE"}}, {}) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_logical_operators_compose_with_fields() -> None:
|
||||||
|
flt = {
|
||||||
|
"search_space_id": 7,
|
||||||
|
"$or": [{"document_type": "FILE"}, {"document_type": "WEBPAGE"}],
|
||||||
|
}
|
||||||
|
assert matches(flt, {"search_space_id": 7, "document_type": "FILE"}) is True
|
||||||
|
assert matches(flt, {"search_space_id": 9, "document_type": "FILE"}) is False
|
||||||
|
assert matches(flt, {"search_space_id": 7, "document_type": "SLACK"}) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_field_operator_raises_filter_error() -> None:
|
||||||
|
with pytest.raises(FilterError):
|
||||||
|
matches({"n": {"$regex": "x"}}, {"n": "xyz"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_logical_operator_raises_filter_error() -> None:
|
||||||
|
with pytest.raises(FilterError):
|
||||||
|
matches({"$nor": [{"document_type": "FILE"}]}, {"document_type": "FILE"})
|
||||||
|
|
@ -0,0 +1,26 @@
|
||||||
|
"""An event hands its payload + metadata to the run as inputs."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.automations.triggers.builtin.event.inputs import event_runtime_inputs
|
||||||
|
from app.event_bus import Event
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_inputs_flatten_payload_with_event_metadata() -> None:
|
||||||
|
event = Event(
|
||||||
|
event_type="document.indexed",
|
||||||
|
payload={"document_id": 42, "document_type": "FILE"},
|
||||||
|
search_space_id=7,
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs = event_runtime_inputs(event)
|
||||||
|
|
||||||
|
assert inputs["document_id"] == 42
|
||||||
|
assert inputs["document_type"] == "FILE"
|
||||||
|
assert inputs["event_type"] == "document.indexed"
|
||||||
|
assert inputs["event_id"] == event.event_id
|
||||||
|
assert inputs["occurred_at"] == event.occurred_at.isoformat()
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
"""Which triggers an event fires: event_type equality + filter match."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.automations.triggers.builtin.event.match import trigger_matches_event
|
||||||
|
from app.event_bus import Event
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _event(event_type: str = "document.indexed", **payload) -> Event:
|
||||||
|
return Event(event_type=event_type, payload=payload, search_space_id=7)
|
||||||
|
|
||||||
|
|
||||||
|
def test_matches_when_event_type_equal_and_filter_passes() -> None:
|
||||||
|
params = {"event_type": "document.indexed", "filter": {"document_type": "FILE"}}
|
||||||
|
assert trigger_matches_event(params, _event(document_type="FILE")) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_match_when_event_type_differs() -> None:
|
||||||
|
params = {"event_type": "document.indexed", "filter": {}}
|
||||||
|
assert trigger_matches_event(params, _event("podcast.generated")) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_match_when_filter_rejects_payload() -> None:
|
||||||
|
params = {"event_type": "document.indexed", "filter": {"document_type": "FILE"}}
|
||||||
|
assert trigger_matches_event(params, _event(document_type="WEBPAGE")) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_filter_matches_any_payload_of_that_type() -> None:
|
||||||
|
params = {"event_type": "document.indexed", "filter": {}}
|
||||||
|
assert trigger_matches_event(params, _event(document_type="ANYTHING")) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_filter_key_is_treated_as_empty() -> None:
|
||||||
|
params = {"event_type": "document.indexed"}
|
||||||
|
assert trigger_matches_event(params, _event(document_type="X")) is True
|
||||||
|
|
@ -0,0 +1,40 @@
|
||||||
|
"""``EventTriggerParams`` contract: an event_type to listen for + an optional filter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.automations.triggers.builtin.event.params import EventTriggerParams
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def test_accepts_event_type_and_filter() -> None:
|
||||||
|
params = EventTriggerParams(
|
||||||
|
event_type="document.indexed",
|
||||||
|
filter={"document_type": "FILE"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert params.event_type == "document.indexed"
|
||||||
|
assert params.filter == {"document_type": "FILE"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_defaults_to_empty() -> None:
|
||||||
|
params = EventTriggerParams(event_type="document.indexed")
|
||||||
|
|
||||||
|
assert params.filter == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_event_type_is_required() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
EventTriggerParams(filter={"x": 1})
|
||||||
|
|
||||||
|
|
||||||
|
def test_event_type_must_not_be_blank() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
EventTriggerParams(event_type="")
|
||||||
|
|
||||||
|
|
||||||
|
def test_extra_keys_are_forbidden() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
EventTriggerParams(event_type="document.indexed", typo=True)
|
||||||
|
|
@ -6,7 +6,7 @@ from datetime import UTC, datetime
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.automations.triggers.schedule.cron import (
|
from app.automations.triggers.builtin.schedule.cron import (
|
||||||
InvalidCronError,
|
InvalidCronError,
|
||||||
compute_next_fire_at,
|
compute_next_fire_at,
|
||||||
validate_cron,
|
validate_cron,
|
||||||
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from app.automations.triggers.schedule.params import ScheduleTriggerParams
|
from app.automations.triggers.builtin.schedule.params import ScheduleTriggerParams
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
0
surfsense_backend/tests/unit/event_bus/__init__.py
Normal file
0
surfsense_backend/tests/unit/event_bus/__init__.py
Normal file
25
surfsense_backend/tests/unit/event_bus/conftest.py
Normal file
25
surfsense_backend/tests/unit/event_bus/conftest.py
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
"""Shared fixtures for the ``app.event_bus`` unit-test tree.
|
||||||
|
|
||||||
|
The event-type catalog is a module-level registry populated at import. Tests
|
||||||
|
that register their own event types (or assert on registry contents) snapshot
|
||||||
|
and restore it so state never leaks between tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.event_bus.catalog import catalog
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def isolated_event_catalog() -> Iterator[None]:
|
||||||
|
"""Snapshot and restore the event-type catalog around a test."""
|
||||||
|
snapshot = dict(catalog._registry)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
catalog._registry.clear()
|
||||||
|
catalog._registry.update(snapshot)
|
||||||
181
surfsense_backend/tests/unit/event_bus/test_bus.py
Normal file
181
surfsense_backend/tests/unit/event_bus/test_bus.py
Normal file
|
|
@ -0,0 +1,181 @@
|
||||||
|
"""``EventBus`` contract: subscribe, publish (stamp + fan out), dispatch.
|
||||||
|
|
||||||
|
Each test uses a fresh ``EventBus`` — no shared global state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.event_bus import Event, EventBus
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _event() -> Event:
|
||||||
|
return Event(event_type="x.happened", payload={"k": "v"}, search_space_id=1)
|
||||||
|
|
||||||
|
|
||||||
|
async def _noop(_event: Event) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _other(_event: Event) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# --- registry -------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_subscribe_then_subscribers_returns_the_handler() -> None:
|
||||||
|
bus = EventBus()
|
||||||
|
bus.subscribe(_noop)
|
||||||
|
|
||||||
|
assert _noop in bus.subscribers()
|
||||||
|
|
||||||
|
|
||||||
|
def test_subscribe_is_idempotent_for_the_same_handler() -> None:
|
||||||
|
"""Registering the same handler twice must not make it fire twice."""
|
||||||
|
bus = EventBus()
|
||||||
|
bus.subscribe(_noop)
|
||||||
|
bus.subscribe(_noop)
|
||||||
|
|
||||||
|
assert bus.subscribers().count(_noop) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_distinct_handlers_both_register() -> None:
|
||||||
|
bus = EventBus()
|
||||||
|
bus.subscribe(_noop)
|
||||||
|
bus.subscribe(_other)
|
||||||
|
|
||||||
|
registered = bus.subscribers()
|
||||||
|
assert _noop in registered
|
||||||
|
assert _other in registered
|
||||||
|
|
||||||
|
|
||||||
|
def test_subscribers_returns_a_defensive_snapshot() -> None:
|
||||||
|
"""Mutating the returned list must not corrupt the registry."""
|
||||||
|
bus = EventBus()
|
||||||
|
bus.subscribe(_noop)
|
||||||
|
|
||||||
|
snapshot = bus.subscribers()
|
||||||
|
snapshot.clear()
|
||||||
|
|
||||||
|
assert _noop in bus.subscribers()
|
||||||
|
|
||||||
|
|
||||||
|
def test_subscribe_returns_handler_so_it_can_be_used_as_a_decorator() -> None:
|
||||||
|
bus = EventBus()
|
||||||
|
returned = bus.subscribe(_other)
|
||||||
|
|
||||||
|
assert returned is _other
|
||||||
|
|
||||||
|
|
||||||
|
def test_two_buses_do_not_share_subscribers() -> None:
|
||||||
|
"""The registry is per-instance, not global."""
|
||||||
|
a = EventBus()
|
||||||
|
b = EventBus()
|
||||||
|
a.subscribe(_noop)
|
||||||
|
|
||||||
|
assert _noop in a.subscribers()
|
||||||
|
assert _noop not in b.subscribers()
|
||||||
|
|
||||||
|
|
||||||
|
# --- dispatch -------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_dispatch_delivers_event_to_every_subscriber() -> None:
|
||||||
|
bus = EventBus()
|
||||||
|
seen: list[tuple[str, Event]] = []
|
||||||
|
|
||||||
|
async def first(event: Event) -> None:
|
||||||
|
seen.append(("first", event))
|
||||||
|
|
||||||
|
async def second(event: Event) -> None:
|
||||||
|
seen.append(("second", event))
|
||||||
|
|
||||||
|
bus.subscribe(first)
|
||||||
|
bus.subscribe(second)
|
||||||
|
|
||||||
|
event = _event()
|
||||||
|
await bus.dispatch(event)
|
||||||
|
|
||||||
|
assert ("first", event) in seen
|
||||||
|
assert ("second", event) in seen
|
||||||
|
|
||||||
|
|
||||||
|
async def test_dispatch_isolates_a_failing_subscriber() -> None:
|
||||||
|
"""A subscriber that raises must not stop a healthy one from running."""
|
||||||
|
bus = EventBus()
|
||||||
|
healthy_ran = False
|
||||||
|
|
||||||
|
async def boom(_event: Event) -> None:
|
||||||
|
raise RuntimeError("subscriber blew up")
|
||||||
|
|
||||||
|
async def healthy(_event: Event) -> None:
|
||||||
|
nonlocal healthy_ran
|
||||||
|
healthy_ran = True
|
||||||
|
|
||||||
|
bus.subscribe(boom)
|
||||||
|
bus.subscribe(healthy)
|
||||||
|
|
||||||
|
await bus.dispatch(_event())
|
||||||
|
|
||||||
|
assert healthy_ran is True
|
||||||
|
|
||||||
|
|
||||||
|
async def test_dispatch_never_propagates_subscriber_errors() -> None:
|
||||||
|
"""``dispatch`` itself must not raise even if every subscriber fails."""
|
||||||
|
bus = EventBus()
|
||||||
|
|
||||||
|
async def boom(_event: Event) -> None:
|
||||||
|
raise ValueError("nope")
|
||||||
|
|
||||||
|
bus.subscribe(boom)
|
||||||
|
|
||||||
|
await bus.dispatch(_event()) # must not raise
|
||||||
|
|
||||||
|
|
||||||
|
async def test_dispatch_with_no_subscribers_is_a_noop() -> None:
|
||||||
|
bus = EventBus()
|
||||||
|
await bus.dispatch(_event()) # must not raise
|
||||||
|
|
||||||
|
|
||||||
|
# --- publish --------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_publish_builds_a_stamped_event_and_fans_it_out() -> None:
|
||||||
|
bus = EventBus()
|
||||||
|
received: list[Event] = []
|
||||||
|
|
||||||
|
async def handler(event: Event) -> None:
|
||||||
|
received.append(event)
|
||||||
|
|
||||||
|
bus.subscribe(handler)
|
||||||
|
await bus.publish("document.indexed", {"document_id": 42}, search_space_id=7)
|
||||||
|
|
||||||
|
assert len(received) == 1
|
||||||
|
event = received[0]
|
||||||
|
assert event.event_type == "document.indexed"
|
||||||
|
assert event.payload == {"document_id": 42}
|
||||||
|
assert event.search_space_id == 7
|
||||||
|
# Engine-stamped identity/time on the way through.
|
||||||
|
assert event.event_id
|
||||||
|
assert event.occurred_at
|
||||||
|
|
||||||
|
|
||||||
|
async def test_publish_defaults_payload_to_empty_dict() -> None:
|
||||||
|
bus = EventBus()
|
||||||
|
received: list[Event] = []
|
||||||
|
|
||||||
|
async def handler(event: Event) -> None:
|
||||||
|
received.append(event)
|
||||||
|
|
||||||
|
bus.subscribe(handler)
|
||||||
|
await bus.publish("x.happened", search_space_id=1)
|
||||||
|
|
||||||
|
assert received[0].payload == {}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_publish_with_no_subscribers_is_a_noop() -> None:
|
||||||
|
await EventBus().publish("x.happened", search_space_id=1) # must not raise
|
||||||
73
surfsense_backend/tests/unit/event_bus/test_catalog.py
Normal file
73
surfsense_backend/tests/unit/event_bus/test_catalog.py
Normal file
|
|
@ -0,0 +1,73 @@
|
||||||
|
"""EventCatalog contract: register, look up, snapshot, derive schema."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.event_bus.catalog import EventCatalog, EventType
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class _SamplePayload(BaseModel):
|
||||||
|
document_id: int
|
||||||
|
|
||||||
|
|
||||||
|
def _event_type(type_: str = "test.thing") -> EventType:
|
||||||
|
return EventType(
|
||||||
|
type=type_,
|
||||||
|
description="A thing happened.",
|
||||||
|
payload_model=_SamplePayload,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_then_get_returns_the_event_type(isolated_event_catalog: None) -> None:
|
||||||
|
from app.event_bus.catalog import catalog
|
||||||
|
catalog.register(_event_type())
|
||||||
|
|
||||||
|
assert catalog.get("test.thing") is not None
|
||||||
|
assert catalog.get("test.thing").type == "test.thing"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_unknown_type_returns_none(isolated_event_catalog: None) -> None:
|
||||||
|
from app.event_bus.catalog import catalog
|
||||||
|
assert catalog.get("does.not.exist") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_duplicate_type_raises(isolated_event_catalog: None) -> None:
|
||||||
|
"""A type is a contract; registering it twice is a bug, not an override."""
|
||||||
|
from app.event_bus.catalog import catalog
|
||||||
|
catalog.register(_event_type())
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="already registered"):
|
||||||
|
catalog.register(_event_type())
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_is_a_defensive_snapshot(isolated_event_catalog: None) -> None:
|
||||||
|
"""Mutating the returned dict must not corrupt the registry."""
|
||||||
|
from app.event_bus.catalog import catalog
|
||||||
|
catalog.register(_event_type())
|
||||||
|
|
||||||
|
snapshot = catalog.all()
|
||||||
|
snapshot.clear()
|
||||||
|
|
||||||
|
assert catalog.get("test.thing") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_payload_schema_is_derived_from_the_payload_model() -> None:
|
||||||
|
"""The JSON Schema a UI/validator consumes comes from the payload model."""
|
||||||
|
event_type = _event_type()
|
||||||
|
|
||||||
|
assert event_type.payload_schema == _SamplePayload.model_json_schema()
|
||||||
|
|
||||||
|
|
||||||
|
def test_each_catalog_instance_has_its_own_registry() -> None:
|
||||||
|
"""Two EventCatalog instances are fully independent."""
|
||||||
|
a = EventCatalog()
|
||||||
|
b = EventCatalog()
|
||||||
|
|
||||||
|
a.register(_event_type())
|
||||||
|
|
||||||
|
assert a.get("test.thing") is not None
|
||||||
|
assert b.get("test.thing") is None
|
||||||
|
|
@ -0,0 +1,56 @@
|
||||||
|
"""``document.entered_folder`` payload contract + catalog registration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.event_bus.catalog import catalog
|
||||||
|
from app.event_bus.events.document_entered_folder import (
|
||||||
|
EVENT_TYPE,
|
||||||
|
DocumentEnteredFolderPayload,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _payload(**overrides: object) -> DocumentEnteredFolderPayload:
|
||||||
|
base: dict[str, object] = {
|
||||||
|
"document_id": 42,
|
||||||
|
"folder_id": 7,
|
||||||
|
"document_type": "FILE",
|
||||||
|
"title": "Q3 report.pdf",
|
||||||
|
}
|
||||||
|
base.update(overrides)
|
||||||
|
return DocumentEnteredFolderPayload(**base)
|
||||||
|
|
||||||
|
|
||||||
|
def test_payload_carries_the_filterable_fields() -> None:
|
||||||
|
payload = _payload(connector_id=12, created_by_id="abc")
|
||||||
|
|
||||||
|
assert payload.document_id == 42
|
||||||
|
assert payload.folder_id == 7
|
||||||
|
assert payload.document_type == "FILE"
|
||||||
|
assert payload.connector_id == 12
|
||||||
|
|
||||||
|
|
||||||
|
def test_first_placement_is_not_a_move() -> None:
|
||||||
|
"""No previous folder (created or AI-sorted into place) → not a move."""
|
||||||
|
assert _payload(previous_folder_id=None).is_move is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_change_between_folders_is_a_move() -> None:
|
||||||
|
assert _payload(previous_folder_id=3).is_move is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_move_is_serialized_for_filtering() -> None:
|
||||||
|
"""Filters match against the dumped payload, so ``is_move`` must appear there."""
|
||||||
|
dumped = _payload(previous_folder_id=3).model_dump()
|
||||||
|
|
||||||
|
assert dumped["is_move"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_event_type_is_registered_in_the_catalog() -> None:
|
||||||
|
registered = catalog.get(EVENT_TYPE)
|
||||||
|
|
||||||
|
assert registered is not None
|
||||||
|
assert registered.payload_model is DocumentEnteredFolderPayload
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
"""payload_if_entered_folder: decides whether a document commit warrants an event."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.event_bus.events.document_entered_folder import payload_if_entered_folder
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _call(**overrides: Any) -> dict[str, Any] | None:
|
||||||
|
defaults: dict[str, Any] = {
|
||||||
|
"document_id": 1,
|
||||||
|
"search_space_id": 10,
|
||||||
|
"new_folder_id": 7,
|
||||||
|
"previous_folder_id": None,
|
||||||
|
"folder_id_changed": True,
|
||||||
|
"status_state": "ready",
|
||||||
|
"document_type": "FILE",
|
||||||
|
"title": "report.pdf",
|
||||||
|
"connector_id": None,
|
||||||
|
"created_by_id": None,
|
||||||
|
}
|
||||||
|
defaults.update(overrides)
|
||||||
|
return payload_if_entered_folder(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
def test_folder_set_ready_fires() -> None:
|
||||||
|
result = _call()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["event_type"] == "document.entered_folder"
|
||||||
|
assert result["search_space_id"] == 10
|
||||||
|
assert result["payload"]["folder_id"] == 7
|
||||||
|
assert result["payload"]["previous_folder_id"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_folder_is_silent() -> None:
|
||||||
|
assert _call(new_folder_id=None) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_not_ready_is_silent() -> None:
|
||||||
|
assert _call(status_state="processing") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_folder_unchanged_is_silent() -> None:
|
||||||
|
assert _call(folder_id_changed=False) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_move_carries_previous_folder_id() -> None:
|
||||||
|
result = _call(previous_folder_id=3, new_folder_id=7)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["payload"]["previous_folder_id"] == 3
|
||||||
|
assert result["payload"]["folder_id"] == 7
|
||||||
53
surfsense_backend/tests/unit/event_bus/test_event.py
Normal file
53
surfsense_backend/tests/unit/event_bus/test_event.py
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
"""``Event`` contract: carry caller facts + engine-stamped id/time, round-trip JSON."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.event_bus.event import Event
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def test_event_carries_caller_supplied_facts() -> None:
|
||||||
|
"""The three caller inputs are stored verbatim."""
|
||||||
|
event = Event(
|
||||||
|
event_type="document.indexed",
|
||||||
|
payload={"document_id": 42, "content_type": "pdf"},
|
||||||
|
search_space_id=7,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == "document.indexed"
|
||||||
|
assert event.payload == {"document_id": 42, "content_type": "pdf"}
|
||||||
|
assert event.search_space_id == 7
|
||||||
|
|
||||||
|
|
||||||
|
def test_event_stamps_identity_and_time_when_not_supplied() -> None:
|
||||||
|
"""Engine stamps id + time so subscribers can dedup/order."""
|
||||||
|
event = Event(event_type="x.happened", payload={}, search_space_id=1)
|
||||||
|
|
||||||
|
assert event.event_id
|
||||||
|
assert isinstance(event.occurred_at, datetime)
|
||||||
|
|
||||||
|
|
||||||
|
def test_event_ids_are_unique_per_instance() -> None:
|
||||||
|
"""Two events published with identical content are still distinct facts."""
|
||||||
|
first = Event(event_type="x.happened", payload={}, search_space_id=1)
|
||||||
|
second = Event(event_type="x.happened", payload={}, search_space_id=1)
|
||||||
|
|
||||||
|
assert first.event_id != second.event_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_event_survives_json_round_trip() -> None:
|
||||||
|
"""Serialize → deserialize reproduces the event (subscribers queue it as JSON)."""
|
||||||
|
original = Event(
|
||||||
|
event_type="podcast.generated",
|
||||||
|
payload={"podcast_id": 9, "duration_s": 123.5},
|
||||||
|
search_space_id=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
restored = Event.model_validate_json(original.model_dump_json())
|
||||||
|
|
||||||
|
assert restored == original
|
||||||
|
|
@ -35,6 +35,23 @@ export interface JsonViewProps {
|
||||||
className?: string;
|
className?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Recursively coerce string values that are valid JSON numbers back to numbers.
|
||||||
|
* react-json-view's text input always yields strings; this restores the
|
||||||
|
* correct type so filters like ``{ "folder_id": 56 }`` survive editing. */
|
||||||
|
function coerceNumbers(value: unknown): unknown {
|
||||||
|
if (typeof value === "string") {
|
||||||
|
const n = Number(value);
|
||||||
|
return !Number.isNaN(n) && value.trim() !== "" ? n : value;
|
||||||
|
}
|
||||||
|
if (Array.isArray(value)) return value.map(coerceNumbers);
|
||||||
|
if (value && typeof value === "object") {
|
||||||
|
return Object.fromEntries(
|
||||||
|
Object.entries(value as Record<string, unknown>).map(([k, v]) => [k, coerceNumbers(v)])
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
const DARK_THEME = "monokai" as const;
|
const DARK_THEME = "monokai" as const;
|
||||||
const LIGHT_THEME = "rjv-default" as const;
|
const LIGHT_THEME = "rjv-default" as const;
|
||||||
|
|
||||||
|
|
@ -67,7 +84,7 @@ export function JsonView({
|
||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
(interaction: InteractionProps) => {
|
(interaction: InteractionProps) => {
|
||||||
onChange?.(interaction.updated_src);
|
onChange?.(coerceNumbers(interaction.updated_src));
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
[onChange]
|
[onChange]
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import { z } from "zod";
|
||||||
export const automationStatus = z.enum(["active", "paused", "archived"]);
|
export const automationStatus = z.enum(["active", "paused", "archived"]);
|
||||||
export type AutomationStatus = z.infer<typeof automationStatus>;
|
export type AutomationStatus = z.infer<typeof automationStatus>;
|
||||||
|
|
||||||
export const triggerType = z.enum(["schedule", "manual"]);
|
export const triggerType = z.enum(["schedule", "manual", "event"]);
|
||||||
export type TriggerType = z.infer<typeof triggerType>;
|
export type TriggerType = z.infer<typeof triggerType>;
|
||||||
|
|
||||||
export const runStatus = z.enum([
|
export const runStatus = z.enum([
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue