mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-31 19:45:15 +02:00
Merge commit '7972901f15' into dev_mod
This commit is contained in:
commit
80daf46fbf
74 changed files with 1681 additions and 234 deletions
|
|
@ -0,0 +1,47 @@
|
|||
"""Add 'event' to automation_trigger_type enum
|
||||
|
||||
Revision ID: 147
|
||||
Revises: 146
|
||||
Create Date: 2026-05-29
|
||||
|
||||
Adds the ``event`` value to the ``automation_trigger_type`` enum so automations
|
||||
can be triggered by published domain events, alongside the existing
|
||||
``schedule`` triggers.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "147"
|
||||
down_revision: str | None = "146"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
ENUM_NAME = "automation_trigger_type"
|
||||
NEW_VALUE = "event"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Safely add 'event' to automation_trigger_type enum if missing."""
|
||||
op.execute(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_type t
|
||||
JOIN pg_enum e ON t.oid = e.enumtypid
|
||||
WHERE t.typname = '{ENUM_NAME}' AND e.enumlabel = '{NEW_VALUE}'
|
||||
) THEN
|
||||
ALTER TYPE {ENUM_NAME} ADD VALUE '{NEW_VALUE}';
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""No-op: PostgreSQL does not support removing enum values."""
|
||||
pass
|
||||
|
|
@ -43,6 +43,7 @@ from app.rate_limiter import get_real_client_ip, limiter
|
|||
from app.routes import router as crud_router
|
||||
from app.routes.auth_routes import router as auth_router
|
||||
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||
from app.session_events import register_session_hooks
|
||||
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
|
||||
from app.utils.perf import log_system_snapshot
|
||||
|
||||
|
|
@ -588,6 +589,7 @@ async def lifespan(app: FastAPI):
|
|||
"first real request will pay the full compile cost."
|
||||
)
|
||||
|
||||
register_session_hooks()
|
||||
log_system_snapshot("startup_complete")
|
||||
|
||||
yield
|
||||
|
|
|
|||
|
|
@ -21,4 +21,4 @@ __all__ = [
|
|||
]
|
||||
|
||||
# Built-in actions self-register at import time.
|
||||
from . import agent_task # noqa: F401
|
||||
from . import builtin # noqa: F401
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
"""Built-in action types — each in its own subpackage, self-registering at import."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from . import agent_task # noqa: F401
|
||||
|
|
@ -2,8 +2,8 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from ..store import register_action
|
||||
from ..types import ActionDefinition
|
||||
from ...store import register_action
|
||||
from ...types import ActionDefinition
|
||||
from .factory import build_handler
|
||||
from .params import AgentTaskActionParams
|
||||
|
||||
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
from typing import Any
|
||||
|
||||
from ..types import ActionContext, ActionHandler
|
||||
from ...types import ActionContext, ActionHandler
|
||||
from .invoke import run_agent_task
|
||||
from .params import AgentTaskActionParams
|
||||
|
||||
|
|
@ -16,7 +16,7 @@ from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in
|
|||
from app.db import ChatVisibility, async_session_maker
|
||||
from app.schemas.new_chat import MentionedDocumentInfo
|
||||
|
||||
from ..types import ActionContext
|
||||
from ...types import ActionContext
|
||||
from .auto_decide import build_auto_decisions
|
||||
from .dependencies import build_dependencies
|
||||
from .finalize import extract_final_assistant_message
|
||||
|
|
@ -3,6 +3,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from .errors import DispatchError
|
||||
from .run import dispatch_run
|
||||
from .launch import launch_run
|
||||
|
||||
__all__ = ["DispatchError", "dispatch_run"]
|
||||
__all__ = ["DispatchError", "launch_run"]
|
||||
|
|
|
|||
43
surfsense_backend/app/automations/dispatch/inputs.py
Normal file
43
surfsense_backend/app/automations/dispatch/inputs.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
"""Merge and validate the inputs a run starts with."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import jsonschema
|
||||
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
from app.automations.schemas.definition.envelope import AutomationDefinition
|
||||
|
||||
from .errors import DispatchError
|
||||
|
||||
|
||||
def prepare_inputs(
|
||||
definition: AutomationDefinition,
|
||||
trigger: AutomationTrigger,
|
||||
runtime_inputs: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Merge ``trigger.static_inputs`` over ``runtime_inputs``, then validate.
|
||||
|
||||
Static inputs win on key collision.
|
||||
"""
|
||||
merged = {**(runtime_inputs or {}), **(trigger.static_inputs or {})}
|
||||
return validate_inputs(definition, merged)
|
||||
|
||||
|
||||
def validate_inputs(
|
||||
definition: AutomationDefinition, inputs: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Validate ``inputs`` against the definition's optional declared schema.
|
||||
|
||||
No declared schema → pass through unchanged so runtime keys (``fired_at``,
|
||||
``last_fired_at``, ...) still reach the template context. A declared schema
|
||||
that the inputs violate is surfaced as ``DispatchError``.
|
||||
"""
|
||||
if definition.inputs is None or not definition.inputs.schema_:
|
||||
return inputs
|
||||
try:
|
||||
jsonschema.validate(instance=inputs, schema=definition.inputs.schema_)
|
||||
except jsonschema.ValidationError as exc:
|
||||
raise DispatchError(f"inputs: {exc.message}") from exc
|
||||
return inputs
|
||||
60
surfsense_backend/app/automations/dispatch/launch.py
Normal file
60
surfsense_backend/app/automations/dispatch/launch.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
"""Launch a run for a trigger that fired: resolve, validate, persist, enqueue.
|
||||
|
||||
The trigger-facing entry every selector calls. A selector builds the runtime
|
||||
inputs and hands one trigger row here; this resolves and guards its automation,
|
||||
snapshots the definition onto a PENDING run, and enqueues execution. The
|
||||
snapshot makes the run immune to later edits of the automation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.automations.persistence.enums.run_status import RunStatus
|
||||
from app.automations.persistence.models.run import AutomationRun
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
from app.automations.schemas.definition.envelope import AutomationDefinition
|
||||
from app.automations.tasks.execute_run import automation_run_execute
|
||||
|
||||
from .errors import DispatchError
|
||||
from .inputs import prepare_inputs
|
||||
from .resolve import resolve_active_automation
|
||||
|
||||
|
||||
async def launch_run(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
trigger: AutomationTrigger,
|
||||
runtime_inputs: dict[str, Any] | None = None,
|
||||
) -> AutomationRun:
|
||||
"""Resolve ``trigger``'s active automation and enqueue a PENDING run for it."""
|
||||
automation = await resolve_active_automation(session, trigger)
|
||||
|
||||
try:
|
||||
definition = AutomationDefinition.model_validate(automation.definition)
|
||||
except Exception as exc:
|
||||
raise DispatchError(f"invalid automation definition: {exc}") from exc
|
||||
|
||||
inputs = prepare_inputs(definition, trigger, runtime_inputs)
|
||||
snapshot = definition.model_dump(mode="json", by_alias=True)
|
||||
|
||||
run = AutomationRun(
|
||||
automation_id=automation.id,
|
||||
trigger_id=trigger.id,
|
||||
status=RunStatus.PENDING,
|
||||
definition_snapshot=snapshot,
|
||||
inputs=inputs,
|
||||
step_results=[],
|
||||
artifacts=[],
|
||||
)
|
||||
session.add(run)
|
||||
await session.commit()
|
||||
await session.refresh(run)
|
||||
|
||||
automation_run_execute.apply_async(
|
||||
args=[run.id],
|
||||
time_limit=definition.execution.timeout_seconds,
|
||||
)
|
||||
return run
|
||||
40
surfsense_backend/app/automations/dispatch/resolve.py
Normal file
40
surfsense_backend/app/automations/dispatch/resolve.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
"""Resolve the automation behind a trigger and guard that it may run."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.automations.persistence.enums.automation_status import AutomationStatus
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
|
||||
from .errors import DispatchError
|
||||
|
||||
|
||||
async def resolve_active_automation(
|
||||
session: AsyncSession, trigger: AutomationTrigger
|
||||
) -> Automation:
|
||||
"""Load ``trigger``'s automation and require it ``ACTIVE``.
|
||||
|
||||
Raises ``DispatchError`` if the automation is missing or not active.
|
||||
"""
|
||||
automation = await _load_automation(session, trigger.automation_id)
|
||||
if automation is None:
|
||||
raise DispatchError(
|
||||
f"automation {trigger.automation_id} not found for trigger {trigger.id}"
|
||||
)
|
||||
|
||||
if automation.status != AutomationStatus.ACTIVE:
|
||||
raise DispatchError(
|
||||
f"automation {trigger.automation_id} is {automation.status.value}, not active"
|
||||
)
|
||||
|
||||
return automation
|
||||
|
||||
|
||||
async def _load_automation(
|
||||
session: AsyncSession, automation_id: int
|
||||
) -> Automation | None:
|
||||
stmt = select(Automation).where(Automation.id == automation_id)
|
||||
return (await session.execute(stmt)).scalar_one_or_none()
|
||||
|
|
@ -1,83 +0,0 @@
|
|||
"""Generic run dispatch: validate, snapshot, persist, enqueue. Shared by every trigger."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import jsonschema
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.automations.persistence.enums.run_status import RunStatus
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.run import AutomationRun
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
from app.automations.schemas.definition.envelope import AutomationDefinition
|
||||
from app.automations.tasks.execute_run import automation_run_execute
|
||||
|
||||
from .errors import DispatchError
|
||||
|
||||
|
||||
async def dispatch_run(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
automation: Automation,
|
||||
trigger: AutomationTrigger,
|
||||
runtime_inputs: dict[str, Any] | None = None,
|
||||
) -> AutomationRun:
|
||||
"""Validate, snapshot the definition, persist an ``AutomationRun``, enqueue execution.
|
||||
|
||||
Final inputs = ``trigger.static_inputs`` merged with ``runtime_inputs``,
|
||||
static winning on key collision. The merged dict is validated against
|
||||
``automation.definition.inputs.schema_`` and stored on the run.
|
||||
|
||||
Callers (trigger-specific adapters) are responsible for resolving
|
||||
``automation`` and ``trigger`` and for the trigger-side ``ACTIVE`` /
|
||||
``enabled`` guards. This function only handles what's identical across
|
||||
every trigger type.
|
||||
"""
|
||||
try:
|
||||
definition = AutomationDefinition.model_validate(automation.definition)
|
||||
except Exception as exc:
|
||||
raise DispatchError(f"invalid automation definition: {exc}") from exc
|
||||
|
||||
merged_inputs = {**(runtime_inputs or {}), **(trigger.static_inputs or {})}
|
||||
validated_inputs = _validate_inputs(definition, merged_inputs)
|
||||
snapshot = definition.model_dump(mode="json", by_alias=True)
|
||||
|
||||
run = AutomationRun(
|
||||
automation_id=automation.id,
|
||||
trigger_id=trigger.id,
|
||||
status=RunStatus.PENDING,
|
||||
definition_snapshot=snapshot,
|
||||
inputs=validated_inputs,
|
||||
step_results=[],
|
||||
artifacts=[],
|
||||
)
|
||||
session.add(run)
|
||||
await session.commit()
|
||||
await session.refresh(run)
|
||||
|
||||
automation_run_execute.apply_async(
|
||||
args=[run.id],
|
||||
time_limit=definition.execution.timeout_seconds,
|
||||
)
|
||||
return run
|
||||
|
||||
|
||||
def _validate_inputs(
|
||||
definition: AutomationDefinition, inputs: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Validate merged inputs against the optional declared schema.
|
||||
|
||||
No declared schema → pass through (runtime inputs like ``fired_at`` /
|
||||
``last_fired_at`` and trigger ``static_inputs`` must still reach the
|
||||
template context). Returning ``{}`` here strips them and makes Jinja
|
||||
blow up on any ``{{ inputs.* }}`` reference.
|
||||
"""
|
||||
if definition.inputs is None or not definition.inputs.schema_:
|
||||
return inputs
|
||||
try:
|
||||
jsonschema.validate(instance=inputs, schema=definition.inputs.schema_)
|
||||
except jsonschema.ValidationError as exc:
|
||||
raise DispatchError(f"inputs: {exc.message}") from exc
|
||||
return inputs
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
"""Trigger-kind discriminator.
|
||||
|
||||
v1 only registers ``schedule``. ``manual`` is reserved in the enum (mirrors the
|
||||
postgres enum) but is intentionally unregistered pending a redesign of the
|
||||
"Run now" UX.
|
||||
``schedule`` and ``event`` are registered. ``manual`` is reserved in the enum
|
||||
(mirrors the postgres enum) but is intentionally unregistered pending a redesign
|
||||
of the "Run now" UX.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -12,4 +12,5 @@ from enum import StrEnum
|
|||
|
||||
class TriggerType(StrEnum):
|
||||
SCHEDULE = "schedule"
|
||||
EVENT = "event"
|
||||
MANUAL = "manual"
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from app.automations.services.model_policy import (
|
|||
get_automation_model_eligibility,
|
||||
)
|
||||
from app.automations.triggers import get_trigger
|
||||
from app.automations.triggers.schedule import compute_next_fire_at
|
||||
from app.automations.triggers.builtin.schedule import compute_next_fire_at
|
||||
from app.db import Permission, SearchSpace, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from app.automations.persistence.models.automation import Automation
|
|||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
from app.automations.schemas.api import TriggerCreate, TriggerUpdate
|
||||
from app.automations.triggers import get_trigger
|
||||
from app.automations.triggers.schedule import compute_next_fire_at
|
||||
from app.automations.triggers.builtin.schedule import compute_next_fire_at
|
||||
from app.db import Permission, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
"""Triggers domain: registry surface + built-in trigger packages.
|
||||
|
||||
Each trigger lives in its own subpackage (``schedule/``, ...) and
|
||||
self-registers at import time via its ``definition`` module.
|
||||
Built-in trigger types live under ``builtin/`` and self-register at import time.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -17,4 +16,4 @@ __all__ = [
|
|||
]
|
||||
|
||||
# Built-in triggers self-register at import time.
|
||||
from . import schedule # noqa: F401
|
||||
from . import builtin # noqa: F401
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
"""Built-in trigger types — each in its own subpackage, self-registering at import."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from . import event, schedule # noqa: F401
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
"""``event`` trigger: fire an automation when a matching domain event is published.
|
||||
|
||||
Subscribes to the event bus and matches events against a user-authored JSON
|
||||
filter (see :mod:`.filter`).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.event_bus import bus
|
||||
|
||||
from .filter import FilterError, matches
|
||||
from .inputs import event_runtime_inputs
|
||||
from .match import trigger_matches_event
|
||||
from .params import EventTriggerParams
|
||||
from .source import on_event
|
||||
|
||||
__all__ = [
|
||||
"EventTriggerParams",
|
||||
"FilterError",
|
||||
"event_runtime_inputs",
|
||||
"matches",
|
||||
"trigger_matches_event",
|
||||
]
|
||||
|
||||
# Side-effect: register on the triggers store.
|
||||
from . import definition # noqa: F401
|
||||
|
||||
# Side-effect: react to published events.
|
||||
bus.subscribe(on_event)
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
"""``event`` ``TriggerDefinition`` registration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.automations.triggers.store import register_trigger
|
||||
from app.automations.triggers.types import TriggerDefinition
|
||||
|
||||
from .params import EventTriggerParams
|
||||
|
||||
EVENT_TRIGGER = TriggerDefinition(
|
||||
type="event",
|
||||
description="Fire when a matching domain event is published.",
|
||||
params_model=EventTriggerParams,
|
||||
)
|
||||
|
||||
register_trigger(EVENT_TRIGGER)
|
||||
|
|
@ -0,0 +1,78 @@
|
|||
"""Pure JSON filter grammar: ``matches(filter_expr, payload) -> bool``.
|
||||
|
||||
The ``event`` trigger uses it to decide whether an event fires the automation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
|
||||
class FilterError(ValueError):
|
||||
"""Unknown operator in a filter. Raised (not silently false) so a bad filter
|
||||
fails at authoring time instead of quietly disabling the trigger."""
|
||||
|
||||
|
||||
# Scalar comparison operators: (actual, operand) -> bool.
|
||||
_COMPARATORS: dict[str, Callable[[Any, Any], bool]] = {
|
||||
"$eq": operator.eq,
|
||||
"$ne": operator.ne,
|
||||
"$gt": operator.gt,
|
||||
"$gte": operator.ge,
|
||||
"$lt": operator.lt,
|
||||
"$lte": operator.le,
|
||||
"$in": lambda actual, operand: actual in operand,
|
||||
"$nin": lambda actual, operand: actual not in operand,
|
||||
}
|
||||
|
||||
# Sentinel for "the payload has no such field" — distinct from a present None.
|
||||
_MISSING = object()
|
||||
|
||||
|
||||
def matches(filter_expr: dict[str, Any], payload: dict[str, Any]) -> bool:
|
||||
"""Return ``True`` when ``payload`` satisfies every constraint in ``filter_expr``.
|
||||
|
||||
An empty filter expresses "no constraints" and matches every payload.
|
||||
Sibling keys (fields and logical operators alike) are ANDed together.
|
||||
"""
|
||||
for key, value in filter_expr.items():
|
||||
if key == "$and":
|
||||
if not all(matches(sub, payload) for sub in value):
|
||||
return False
|
||||
elif key == "$or":
|
||||
if not any(matches(sub, payload) for sub in value):
|
||||
return False
|
||||
elif key == "$not":
|
||||
if matches(value, payload):
|
||||
return False
|
||||
elif key.startswith("$"):
|
||||
raise FilterError(f"unknown logical operator: {key}")
|
||||
elif not _match_condition(value, payload.get(key, _MISSING)):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _match_condition(condition: Any, actual: Any) -> bool:
|
||||
"""Match one field's ``actual`` value against its ``condition``.
|
||||
|
||||
A dict condition is an operator object (``{"$gt": 10}``); every operator in
|
||||
it must hold. Any other value is an implicit equality check. A field absent
|
||||
from the payload (``actual is _MISSING``) fails every constraint.
|
||||
"""
|
||||
if actual is _MISSING:
|
||||
return False
|
||||
if isinstance(condition, dict):
|
||||
return all(
|
||||
_apply_operator(op, operand, actual)
|
||||
for op, operand in condition.items()
|
||||
)
|
||||
return actual == condition
|
||||
|
||||
|
||||
def _apply_operator(op: str, operand: Any, actual: Any) -> bool:
|
||||
comparator = _COMPARATORS.get(op)
|
||||
if comparator is not None:
|
||||
return comparator(actual, operand)
|
||||
raise FilterError(f"unknown operator: {op}")
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
"""Build run inputs from a published event."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.event_bus import Event
|
||||
|
||||
|
||||
def event_runtime_inputs(event: Event) -> dict[str, Any]:
|
||||
"""Flatten the event payload and stamp event metadata as run inputs."""
|
||||
return {
|
||||
**event.payload,
|
||||
"event_type": event.event_type,
|
||||
"event_id": event.event_id,
|
||||
"occurred_at": event.occurred_at.isoformat(),
|
||||
}
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
"""Pure predicate: does an event trigger fire for a given event?"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.event_bus import Event
|
||||
|
||||
from .filter import matches
|
||||
|
||||
|
||||
def trigger_matches_event(params: dict[str, Any], event: Event) -> bool:
|
||||
"""True when an event trigger configured with ``params`` should fire for ``event``."""
|
||||
if params.get("event_type") != event.event_type:
|
||||
return False
|
||||
return matches(params.get("filter") or {}, event.payload)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
"""``EventTriggerParams`` — params for the ``event`` trigger type."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class EventTriggerParams(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
event_type: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
description="Event type to listen for.",
|
||||
examples=["document.indexed"],
|
||||
)
|
||||
filter: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="JSON filter matched against the event payload.",
|
||||
examples=[{"document_type": "FILE"}],
|
||||
)
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
"""Event selector (worker task): pick the triggers an event fires, start each.
|
||||
|
||||
The source enqueues this with a serialized event. Here we load the enabled
|
||||
``event`` triggers for that event type, keep the ones whose filter matches the
|
||||
payload, and start a run for each. Per-trigger failures are isolated.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.automations.dispatch import launch_run
|
||||
from app.automations.persistence.enums.trigger_type import TriggerType
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
from app.celery_app import celery_app
|
||||
from app.event_bus import Event
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
from .inputs import event_runtime_inputs
|
||||
from .match import trigger_matches_event
|
||||
from .source import TASK_NAME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(name=TASK_NAME)
|
||||
def automation_event_select(event: dict[str, Any]) -> None:
|
||||
"""Select and start the runs an event fires."""
|
||||
return run_async_celery_task(lambda: _select_and_start(event))
|
||||
|
||||
|
||||
async def _select_and_start(event_dict: dict[str, Any]) -> None:
|
||||
event = Event.model_validate(event_dict)
|
||||
session_maker = get_celery_session_maker()
|
||||
async with session_maker() as session:
|
||||
for trigger in await _eligible(session, event=event):
|
||||
await _start_one(session, trigger=trigger, event=event)
|
||||
|
||||
|
||||
async def _eligible(
|
||||
session: AsyncSession, *, event: Event
|
||||
) -> list[AutomationTrigger]:
|
||||
"""Enabled ``event`` triggers for this event type whose filter matches."""
|
||||
stmt = select(AutomationTrigger).where(
|
||||
AutomationTrigger.type == TriggerType.EVENT,
|
||||
AutomationTrigger.enabled.is_(True),
|
||||
AutomationTrigger.params["event_type"].astext == event.event_type,
|
||||
)
|
||||
triggers = (await session.execute(stmt)).scalars().all()
|
||||
return [t for t in triggers if trigger_matches_event(t.params, event)]
|
||||
|
||||
|
||||
async def _start_one(
|
||||
session: AsyncSession, *, trigger: AutomationTrigger, event: Event
|
||||
) -> None:
|
||||
try:
|
||||
run = await launch_run(
|
||||
session=session,
|
||||
trigger=trigger,
|
||||
runtime_inputs=event_runtime_inputs(event),
|
||||
)
|
||||
logger.info(
|
||||
"event fire: trigger=%d automation=%d run=%d event=%s",
|
||||
trigger.id,
|
||||
trigger.automation_id,
|
||||
run.id,
|
||||
event.event_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("event fire failed for trigger %d", trigger.id)
|
||||
await session.rollback()
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
"""Event trigger source: the bus subscriber that enqueues the selector.
|
||||
|
||||
Runs in whatever process published the event, so it stays thin — it only hands
|
||||
the event to a worker (the selector does the DB matching).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.event_bus import Event
|
||||
|
||||
TASK_NAME = "automation_event_select"
|
||||
|
||||
|
||||
async def on_event(event: Event) -> None:
|
||||
"""Enqueue the selector for ``event``."""
|
||||
# Lazy import: keeps app.celery_app out of the triggers-package import graph.
|
||||
from app.celery_app import celery_app
|
||||
|
||||
celery_app.send_task(TASK_NAME, kwargs={"event": event.model_dump(mode="json")})
|
||||
|
|
@ -3,14 +3,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from .cron import InvalidCronError, compute_next_fire_at, validate_cron
|
||||
from .dispatch import dispatch_schedule_run
|
||||
from .params import ScheduleTriggerParams
|
||||
|
||||
__all__ = [
|
||||
"InvalidCronError",
|
||||
"ScheduleTriggerParams",
|
||||
"compute_next_fire_at",
|
||||
"dispatch_schedule_run",
|
||||
"validate_cron",
|
||||
]
|
||||
|
||||
|
|
@ -2,8 +2,9 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from ..store import register_trigger
|
||||
from ..types import TriggerDefinition
|
||||
from app.automations.triggers.store import register_trigger
|
||||
from app.automations.triggers.types import TriggerDefinition
|
||||
|
||||
from .params import ScheduleTriggerParams
|
||||
|
||||
SCHEDULE_TRIGGER = TriggerDefinition(
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
"""Build run inputs from a schedule fire."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
def schedule_runtime_inputs(
|
||||
*,
|
||||
fired_at: datetime,
|
||||
scheduled_for: datetime,
|
||||
previous_last_fired_at: datetime | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Calendar context for a scheduled run.
|
||||
|
||||
- ``fired_at`` — actual fire time
|
||||
- ``scheduled_for`` — cron-derived target time for this fire
|
||||
- ``last_fired_at`` — previous fire time, or null on first fire
|
||||
"""
|
||||
return {
|
||||
"fired_at": fired_at.isoformat(),
|
||||
"scheduled_for": scheduled_for.isoformat(),
|
||||
"last_fired_at": (
|
||||
previous_last_fired_at.isoformat() if previous_last_fired_at else None
|
||||
),
|
||||
}
|
||||
|
|
@ -1,15 +1,12 @@
|
|||
"""Celery Beat tick that fires due ``schedule`` triggers.
|
||||
"""Schedule selector (worker task): claim due triggers and start each.
|
||||
|
||||
Runs every minute. Each tick performs two passes:
|
||||
Beat ticks this every minute. Two passes:
|
||||
|
||||
1. **Self-heal**: enabled schedule triggers with NULL ``next_fire_at`` get
|
||||
it computed from their ``cron`` + ``timezone`` (e.g. fresh inserts or
|
||||
rows restored from backup).
|
||||
2. **Claim & fire**: due rows are locked with ``FOR UPDATE SKIP LOCKED``,
|
||||
their ``next_fire_at`` is advanced and ``last_fired_at`` is set, and
|
||||
``dispatch_schedule_run`` is invoked for each. Dispatch errors are
|
||||
logged; a missed fire stays missed (matches K8s CronJob / Airflow
|
||||
``catchup=False`` semantics).
|
||||
1. **Self-heal**: enabled schedule triggers with NULL ``next_fire_at`` get it
|
||||
computed from their ``cron`` + ``timezone`` (fresh inserts, restored rows).
|
||||
2. **Claim & start**: due rows are locked ``FOR UPDATE SKIP LOCKED``, their
|
||||
``next_fire_at`` is advanced and ``last_fired_at`` set, and a run is started
|
||||
for each. A missed fire stays missed (``catchup=False`` semantics).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -21,19 +18,17 @@ from datetime import UTC, datetime
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.automations.dispatch import launch_run
|
||||
from app.automations.persistence.enums.trigger_type import TriggerType
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
from app.automations.triggers.schedule import (
|
||||
InvalidCronError,
|
||||
compute_next_fire_at,
|
||||
dispatch_schedule_run,
|
||||
)
|
||||
from app.celery_app import celery_app
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from .cron import InvalidCronError, compute_next_fire_at
|
||||
from .inputs import schedule_runtime_inputs
|
||||
from .source import TASK_NAME
|
||||
|
||||
TASK_NAME = "automation_schedule_tick"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cap rows touched per tick so a backlog of due triggers can't starve the
|
||||
# worker; remaining rows fire on the next tick.
|
||||
|
|
@ -50,8 +45,8 @@ class _Claim:
|
|||
|
||||
|
||||
@celery_app.task(name=TASK_NAME)
|
||||
def automation_schedule_tick() -> None:
|
||||
"""Tick once: self-heal NULL next_fire_at, claim due rows, fire each."""
|
||||
def automation_schedule_select() -> None:
|
||||
"""Tick once: self-heal NULL next_fire_at, claim due rows, start each."""
|
||||
return run_async_celery_task(_tick)
|
||||
|
||||
|
||||
|
|
@ -67,7 +62,7 @@ async def _tick() -> None:
|
|||
return
|
||||
|
||||
for claim in claims:
|
||||
await _fire_one(session, claim=claim, fired_at=now)
|
||||
await _start_one(session, claim=claim, fired_at=now)
|
||||
|
||||
|
||||
async def _self_heal_null_next_fire(session: AsyncSession, *, now: datetime) -> None:
|
||||
|
|
@ -155,21 +150,23 @@ async def _claim_due_triggers(session: AsyncSession, *, now: datetime) -> list[_
|
|||
return claims
|
||||
|
||||
|
||||
async def _fire_one(
|
||||
async def _start_one(
|
||||
session: AsyncSession, *, claim: _Claim, fired_at: datetime
|
||||
) -> None:
|
||||
"""Reload the trigger post-commit and dispatch a run for it."""
|
||||
"""Reload the trigger post-commit and start a run for it."""
|
||||
trigger = await session.get(AutomationTrigger, claim.trigger_id)
|
||||
if trigger is None:
|
||||
return
|
||||
|
||||
try:
|
||||
run = await dispatch_schedule_run(
|
||||
run = await launch_run(
|
||||
session=session,
|
||||
trigger=trigger,
|
||||
fired_at=fired_at,
|
||||
scheduled_for=claim.scheduled_for,
|
||||
previous_last_fired_at=claim.previous_last_fired_at,
|
||||
runtime_inputs=schedule_runtime_inputs(
|
||||
fired_at=fired_at,
|
||||
scheduled_for=claim.scheduled_for,
|
||||
previous_last_fired_at=claim.previous_last_fired_at,
|
||||
),
|
||||
)
|
||||
logger.info(
|
||||
"scheduled fire: trigger=%d automation=%d run=%d",
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
"""Schedule trigger source: Celery Beat ticks the selector every minute.
|
||||
|
||||
``BEAT_SCHEDULE`` is merged into ``celery_app.conf.beat_schedule``. Per-row cron
|
||||
math is precomputed (the ``next_fire_at`` column), so each tick is an indexed
|
||||
lookup rather than N cron evaluations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from celery.schedules import crontab
|
||||
|
||||
TASK_NAME = "automation_schedule_select"
|
||||
|
||||
BEAT_SCHEDULE = {
|
||||
"automation-schedule-select": {
|
||||
"task": TASK_NAME,
|
||||
"schedule": crontab(minute="*"),
|
||||
"options": {"expires": 50},
|
||||
},
|
||||
}
|
||||
|
|
@ -1,67 +0,0 @@
|
|||
"""Schedule dispatch adapter: load + guard, then call generic dispatch."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.automations.dispatch import DispatchError, dispatch_run
|
||||
from app.automations.persistence.enums.automation_status import AutomationStatus
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.run import AutomationRun
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
|
||||
|
||||
async def dispatch_schedule_run(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
trigger: AutomationTrigger,
|
||||
fired_at: datetime,
|
||||
scheduled_for: datetime,
|
||||
previous_last_fired_at: datetime | None,
|
||||
) -> AutomationRun:
|
||||
"""Fire one scheduled run for ``trigger``.
|
||||
|
||||
Emits calendar context as runtime inputs:
|
||||
|
||||
- ``fired_at`` — actual fire time
|
||||
- ``scheduled_for`` — cron-derived target time for this fire
|
||||
- ``last_fired_at`` — fire time of the previous run, or null on first fire
|
||||
|
||||
The caller (the schedule tick) is responsible for selecting due triggers
|
||||
and advancing ``next_fire_at`` / ``last_fired_at`` before invoking this.
|
||||
"""
|
||||
automation = await _load_automation(session, trigger.automation_id)
|
||||
if automation is None:
|
||||
raise DispatchError(
|
||||
f"automation {trigger.automation_id} not found for trigger {trigger.id}"
|
||||
)
|
||||
|
||||
if automation.status != AutomationStatus.ACTIVE:
|
||||
raise DispatchError(
|
||||
f"automation {trigger.automation_id} is {automation.status.value}, not active"
|
||||
)
|
||||
|
||||
runtime_inputs = {
|
||||
"fired_at": fired_at.isoformat(),
|
||||
"scheduled_for": scheduled_for.isoformat(),
|
||||
"last_fired_at": (
|
||||
previous_last_fired_at.isoformat() if previous_last_fired_at else None
|
||||
),
|
||||
}
|
||||
|
||||
return await dispatch_run(
|
||||
session=session,
|
||||
automation=automation,
|
||||
trigger=trigger,
|
||||
runtime_inputs=runtime_inputs,
|
||||
)
|
||||
|
||||
|
||||
async def _load_automation(
|
||||
session: AsyncSession, automation_id: int
|
||||
) -> Automation | None:
|
||||
stmt = select(Automation).where(Automation.id == automation_id)
|
||||
return (await session.execute(stmt)).scalar_one_or_none()
|
||||
|
|
@ -189,7 +189,8 @@ celery_app = Celery(
|
|||
"app.tasks.celery_tasks.stale_notification_cleanup_task",
|
||||
"app.tasks.celery_tasks.stripe_reconciliation_task",
|
||||
"app.automations.tasks.execute_run",
|
||||
"app.automations.tasks.schedule_tick",
|
||||
"app.automations.triggers.builtin.schedule.selector",
|
||||
"app.automations.triggers.builtin.event.selector",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -247,6 +248,12 @@ celery_app.conf.update(
|
|||
},
|
||||
)
|
||||
|
||||
# Imported late (after celery_app is built) to keep the automations triggers
|
||||
# package out of this module's top-level import graph.
|
||||
from app.automations.triggers.builtin.schedule.source import ( # noqa: E402
|
||||
BEAT_SCHEDULE as SCHEDULE_BEAT_SCHEDULE,
|
||||
)
|
||||
|
||||
# Configure Celery Beat schedule
|
||||
# This uses a meta-scheduler pattern: instead of creating individual Beat schedules
|
||||
# for each connector, we have ONE schedule that checks the database at the configured interval
|
||||
|
|
@ -284,14 +291,7 @@ celery_app.conf.beat_schedule = {
|
|||
"expires": 60,
|
||||
},
|
||||
},
|
||||
# Fire due automation schedule triggers. Ticks every minute; per-row cron
|
||||
# math is precomputed (next_fire_at column) so the tick is an indexed
|
||||
# lookup, not N cron evaluations.
|
||||
"automation-schedule-tick": {
|
||||
"task": "automation_schedule_tick",
|
||||
"schedule": crontab(minute="*"),
|
||||
"options": {
|
||||
"expires": 50,
|
||||
},
|
||||
},
|
||||
# Fire due automation schedule triggers (Beat entry owned by the schedule
|
||||
# trigger; see app.automations.triggers.builtin.schedule.source).
|
||||
**SCHEDULE_BEAT_SCHEDULE,
|
||||
}
|
||||
|
|
|
|||
25
surfsense_backend/app/event_bus/__init__.py
Normal file
25
surfsense_backend/app/event_bus/__init__.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
"""In-process domain event bus.
|
||||
|
||||
Domain-agnostic pub/sub. Producers ``await bus.publish(...)``; subscribers
|
||||
``bus.subscribe(...)``. Domain modules depend on it, never the reverse.
|
||||
|
||||
from app.event_bus import bus
|
||||
await bus.publish("document.indexed", {"document_id": 42}, search_space_id=7)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from . import events # noqa: F401 — populates the event-type catalog
|
||||
from .bus import EventBus, Subscriber, bus
|
||||
from .catalog import EventCatalog, EventType, catalog
|
||||
from .event import Event
|
||||
|
||||
__all__ = [
|
||||
"Event",
|
||||
"EventBus",
|
||||
"EventCatalog",
|
||||
"EventType",
|
||||
"Subscriber",
|
||||
"bus",
|
||||
"catalog",
|
||||
]
|
||||
77
surfsense_backend/app/event_bus/bus.py
Normal file
77
surfsense_backend/app/event_bus/bus.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
"""In-process pub/sub. Streams :class:`Event` values from producers to listeners.
|
||||
|
||||
Boundary-crossing (Celery, DB, workers) is a subscriber's job — e.g. the
|
||||
``event`` trigger enqueues its own task.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from .event import Event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Subscriber = Callable[[Event], Awaitable[None]]
|
||||
|
||||
|
||||
class EventBus:
|
||||
"""An in-process pub/sub bus with a per-instance subscriber registry."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._subscribers: list[Subscriber] = []
|
||||
|
||||
def subscribe(self, handler: Subscriber) -> Subscriber:
|
||||
"""Register ``handler`` for every event. Idempotent; returns the handler
|
||||
so it works as a decorator."""
|
||||
if handler not in self._subscribers:
|
||||
self._subscribers.append(handler)
|
||||
return handler
|
||||
|
||||
def subscribers(self) -> list[Subscriber]:
|
||||
"""Defensive snapshot of the registered subscribers."""
|
||||
return list(self._subscribers)
|
||||
|
||||
async def publish(
|
||||
self,
|
||||
event_type: str,
|
||||
payload: dict[str, Any] | None = None,
|
||||
*,
|
||||
search_space_id: int,
|
||||
) -> None:
|
||||
"""Stamp an :class:`Event` and fan it out. Call after your commit."""
|
||||
event = Event(
|
||||
event_type=event_type,
|
||||
payload=payload or {},
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
await self.dispatch(event)
|
||||
|
||||
async def dispatch(self, event: Event) -> None:
|
||||
"""Fan ``event`` out concurrently. Subscriber failures are logged and
|
||||
isolated; never propagate."""
|
||||
subscribers = self.subscribers()
|
||||
if not subscribers:
|
||||
return
|
||||
|
||||
results = await asyncio.gather(
|
||||
*(handler(event) for handler in subscribers),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
for handler, result in zip(subscribers, results, strict=True):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(
|
||||
"event subscriber %r failed for event %s (%s)",
|
||||
getattr(handler, "__qualname__", handler),
|
||||
event.event_id,
|
||||
event.event_type,
|
||||
exc_info=result,
|
||||
)
|
||||
|
||||
|
||||
# Process-wide bus. Producers publish to it; subscribers register on it.
|
||||
bus = EventBus()
|
||||
48
surfsense_backend/app/event_bus/catalog.py
Normal file
48
surfsense_backend/app/event_bus/catalog.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
"""Event type catalog: the deliberate contract behind each event.
|
||||
|
||||
``EventType`` declares a dotted name and the shape of its payload.
|
||||
``EventCatalog`` is the registry — populated once at import by each event type
|
||||
module. ``catalog`` is the process-wide singleton.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EventType:
|
||||
type: str
|
||||
description: str
|
||||
payload_model: type[BaseModel]
|
||||
|
||||
@property
|
||||
def payload_schema(self) -> dict[str, Any]:
|
||||
"""JSON Schema (draft 2020-12) derived from ``payload_model``."""
|
||||
return self.payload_model.model_json_schema()
|
||||
|
||||
|
||||
class EventCatalog:
|
||||
"""Registry of known event types. Populated at import; read at runtime."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._registry: dict[str, EventType] = {}
|
||||
|
||||
def register(self, event_type: EventType) -> None:
|
||||
"""Register an event type. Raises on duplicate type."""
|
||||
if event_type.type in self._registry:
|
||||
raise ValueError(f"Event type already registered: {event_type.type!r}")
|
||||
self._registry[event_type.type] = event_type
|
||||
|
||||
def get(self, type_: str) -> EventType | None:
|
||||
return self._registry.get(type_)
|
||||
|
||||
def all(self) -> dict[str, EventType]:
|
||||
"""Defensive snapshot of the registry."""
|
||||
return dict(self._registry)
|
||||
|
||||
|
||||
catalog = EventCatalog()
|
||||
38
surfsense_backend/app/event_bus/event.py
Normal file
38
surfsense_backend/app/event_bus/event.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
"""The ``Event`` value object — the only shape that crosses the bus.
|
||||
|
||||
An immutable fact: something named happened, with this payload, in this space,
|
||||
at this time. JSON round-trippable so a subscriber can queue it to a worker.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
def _new_event_id() -> str:
|
||||
return uuid.uuid4().hex
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class Event(BaseModel):
|
||||
"""A published domain fact.
|
||||
|
||||
``event_type`` is a dotted namespace (``document.indexed``, etc). ``payload`` is
|
||||
JSON-serializable. ``search_space_id`` scopes delivery. ``event_id`` and
|
||||
``occurred_at`` are engine-stamped.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
event_type: str
|
||||
payload: dict[str, Any] = Field(default_factory=dict)
|
||||
search_space_id: int
|
||||
event_id: str = Field(default_factory=_new_event_id)
|
||||
occurred_at: datetime = Field(default_factory=_now)
|
||||
5
surfsense_backend/app/event_bus/events/__init__.py
Normal file
5
surfsense_backend/app/event_bus/events/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
"""Domain event type definitions — each in its own module, self-registering at import."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from . import document_entered_folder # noqa: F401
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
"""``document.entered_folder``: a document became a member of a folder.
|
||||
|
||||
Fires once per arrival, however the document got there (upload, AI sort, move).
|
||||
The payload carries the fields a user can filter a trigger on.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, computed_field
|
||||
|
||||
from app.event_bus.catalog import EventType, catalog
|
||||
|
||||
EVENT_TYPE = "document.entered_folder"
|
||||
|
||||
|
||||
class DocumentEnteredFolderPayload(BaseModel):
|
||||
"""Snapshot of the document at the moment it entered ``folder_id``.
|
||||
|
||||
``previous_folder_id`` is the folder it left, or ``None`` for a first
|
||||
placement. ``is_move`` derives from it and is emitted for filtering.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
document_id: int
|
||||
folder_id: int
|
||||
previous_folder_id: int | None = None
|
||||
document_type: str
|
||||
title: str
|
||||
connector_id: int | None = None
|
||||
created_by_id: str | None = None
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def is_move(self) -> bool:
|
||||
return self.previous_folder_id is not None
|
||||
|
||||
|
||||
catalog.register(
|
||||
EventType(
|
||||
type=EVENT_TYPE,
|
||||
description="A document became a member of a folder.",
|
||||
payload_model=DocumentEnteredFolderPayload,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def payload_if_entered_folder(
|
||||
*,
|
||||
document_id: int,
|
||||
search_space_id: int,
|
||||
new_folder_id: int | None,
|
||||
previous_folder_id: int | None,
|
||||
folder_id_changed: bool,
|
||||
status_state: str,
|
||||
document_type: str,
|
||||
title: str,
|
||||
connector_id: int | None,
|
||||
created_by_id: str | None,
|
||||
) -> dict | None:
|
||||
"""Return a publish payload if this commit represents a folder arrival, else None.
|
||||
|
||||
``folder_id_changed`` comes from SQLAlchemy attribute history — it is True
|
||||
only when ``folder_id`` actually changed in this transaction, preventing
|
||||
spurious events on unrelated saves.
|
||||
"""
|
||||
if not folder_id_changed:
|
||||
return None
|
||||
if new_folder_id is None:
|
||||
return None
|
||||
if status_state != "ready":
|
||||
return None
|
||||
|
||||
return {
|
||||
"event_type": EVENT_TYPE,
|
||||
"search_space_id": search_space_id,
|
||||
"payload": {
|
||||
"document_id": document_id,
|
||||
"folder_id": new_folder_id,
|
||||
"previous_folder_id": previous_folder_id,
|
||||
"document_type": document_type,
|
||||
"title": title,
|
||||
"connector_id": connector_id,
|
||||
"created_by_id": created_by_id,
|
||||
},
|
||||
}
|
||||
|
|
@ -525,11 +525,8 @@ async def bulk_move_documents(
|
|||
detail="Cannot move documents to a folder in a different search space",
|
||||
)
|
||||
|
||||
await session.execute(
|
||||
Document.__table__.update()
|
||||
.where(Document.id.in_(request.document_ids))
|
||||
.values(folder_id=request.folder_id)
|
||||
)
|
||||
for doc in documents:
|
||||
doc.folder_id = request.folder_id
|
||||
await session.commit()
|
||||
return {"message": f"{len(request.document_ids)} documents moved successfully"}
|
||||
|
||||
|
|
|
|||
101
surfsense_backend/app/session_events.py
Normal file
101
surfsense_backend/app/session_events.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
"""SQLAlchemy session event hooks — wired once at app startup.
|
||||
|
||||
Detects document folder arrivals across every ORM commit and publishes
|
||||
``document.entered_folder`` events to the bus after the transaction is durable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.orm import Session, attributes
|
||||
|
||||
from app.db import Document, DocumentStatus
|
||||
from app.event_bus.bus import EventBus, bus as default_bus
|
||||
from app.event_bus.events.document_entered_folder import payload_if_entered_folder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PENDING_KEY = "_entered_folder_pending"
|
||||
|
||||
|
||||
def _after_flush(session: Session, flush_context: object) -> None:
|
||||
"""Collect folder-arrival candidates while attribute history is still available."""
|
||||
pending: list[dict] = []
|
||||
|
||||
for obj in list(session.new) + list(session.dirty):
|
||||
if not isinstance(obj, Document):
|
||||
continue
|
||||
|
||||
history = attributes.get_history(obj, "folder_id")
|
||||
if not history.added:
|
||||
continue
|
||||
|
||||
new_folder_id = history.added[0]
|
||||
previous_folder_id = history.deleted[0] if history.deleted else None
|
||||
|
||||
result = payload_if_entered_folder(
|
||||
document_id=obj.id,
|
||||
search_space_id=obj.search_space_id,
|
||||
new_folder_id=new_folder_id,
|
||||
previous_folder_id=previous_folder_id,
|
||||
folder_id_changed=True,
|
||||
status_state=DocumentStatus.get_state(obj.status) or "",
|
||||
document_type=obj.document_type.value if obj.document_type else "",
|
||||
title=obj.title or "",
|
||||
connector_id=obj.connector_id,
|
||||
created_by_id=str(obj.created_by_id) if obj.created_by_id else None,
|
||||
)
|
||||
if result is not None:
|
||||
pending.append(result)
|
||||
|
||||
setattr(session, _PENDING_KEY, pending)
|
||||
|
||||
|
||||
def _after_commit(session: Session) -> None:
|
||||
"""Publish collected events now that the transaction is durable."""
|
||||
pending: list[dict] = getattr(session, _PENDING_KEY, [])
|
||||
if not pending:
|
||||
return
|
||||
setattr(session, _PENDING_KEY, [])
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
logger.warning("No running event loop — skipping %d event(s)", len(pending))
|
||||
return
|
||||
|
||||
tasks = [
|
||||
loop.create_task(
|
||||
default_bus.publish(
|
||||
item["event_type"],
|
||||
item["payload"],
|
||||
search_space_id=item["search_space_id"],
|
||||
)
|
||||
)
|
||||
for item in pending
|
||||
]
|
||||
for task in tasks:
|
||||
task.add_done_callback(
|
||||
lambda t: logger.error("event publish failed: %s", t.exception())
|
||||
if not t.cancelled() and t.exception()
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
def _after_rollback(session: Session) -> None:
|
||||
"""Discard any pending events — the transaction did not commit."""
|
||||
setattr(session, _PENDING_KEY, [])
|
||||
|
||||
|
||||
def register_session_hooks(bus: EventBus = default_bus) -> None:
|
||||
"""Register document folder-arrival hooks on the SQLAlchemy Session class.
|
||||
|
||||
Call once at application startup (e.g. in ``app.app`` lifespan). Idempotent
|
||||
— SQLAlchemy deduplicates identical listener registrations.
|
||||
"""
|
||||
event.listen(Session, "after_flush", _after_flush)
|
||||
event.listen(Session, "after_commit", _after_commit)
|
||||
event.listen(Session, "after_rollback", _after_rollback)
|
||||
|
|
@ -13,7 +13,7 @@ from typing import Any
|
|||
|
||||
import pytest
|
||||
|
||||
from app.automations.actions.agent_task.auto_decide import build_auto_decisions
|
||||
from app.automations.actions.builtin.agent_task.auto_decide import build_auto_decisions
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
|
@ -10,7 +10,9 @@ from __future__ import annotations
|
|||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
|
||||
from app.automations.actions.agent_task.finalize import extract_final_assistant_message
|
||||
from app.automations.actions.builtin.agent_task.finalize import (
|
||||
extract_final_assistant_message,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
|
@ -1,10 +1,8 @@
|
|||
"""Lock the input-validation contract used by ``dispatch_run``.
|
||||
"""Lock the input-validation contract enforced before a run is enqueued.
|
||||
|
||||
``_validate_inputs`` is module-internal by convention (underscore), but it
|
||||
encodes a real behavior contract the rest of the system depends on, and the
|
||||
public alternative (``dispatch_run``) requires a real DB session. Tests
|
||||
target the pure function directly; the contract — not the symbol — is what's
|
||||
locked.
|
||||
``validate_inputs`` is the pure schema check that ``enqueue_run`` runs against
|
||||
merged inputs. ``enqueue_run`` itself needs a real DB session, so tests target
|
||||
this pure function directly; the contract — not the symbol — is what's locked.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -12,7 +10,7 @@ from __future__ import annotations
|
|||
import pytest
|
||||
|
||||
from app.automations.dispatch.errors import DispatchError
|
||||
from app.automations.dispatch.run import _validate_inputs
|
||||
from app.automations.dispatch.inputs import validate_inputs
|
||||
from app.automations.schemas.definition.envelope import AutomationDefinition
|
||||
from app.automations.schemas.definition.inputs import Inputs
|
||||
from app.automations.schemas.definition.plan_step import PlanStep
|
||||
|
|
@ -42,7 +40,7 @@ def test_validate_inputs_passes_through_when_no_schema_is_declared() -> None:
|
|||
"static_key": "value",
|
||||
}
|
||||
|
||||
assert _validate_inputs(definition, runtime_inputs) == runtime_inputs
|
||||
assert validate_inputs(definition, runtime_inputs) == runtime_inputs
|
||||
|
||||
|
||||
def test_validate_inputs_returns_inputs_when_they_match_declared_schema() -> None:
|
||||
|
|
@ -58,14 +56,13 @@ def test_validate_inputs_returns_inputs_when_they_match_declared_schema() -> Non
|
|||
|
||||
inputs = {"topic": "weekly report"}
|
||||
|
||||
assert _validate_inputs(definition, inputs) == inputs
|
||||
assert validate_inputs(definition, inputs) == inputs
|
||||
|
||||
|
||||
def test_validate_inputs_raises_dispatch_error_when_inputs_violate_schema() -> None:
|
||||
"""Inputs that don't match the declared schema must surface as
|
||||
``DispatchError`` (not the raw ``jsonschema.ValidationError``), so the
|
||||
schedule tick and any other caller can handle one dispatch-domain
|
||||
exception type uniformly."""
|
||||
``DispatchError`` (not the raw ``jsonschema.ValidationError``), so every
|
||||
caller can handle one dispatch-domain exception type uniformly."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"topic": {"type": "string"}},
|
||||
|
|
@ -74,4 +71,4 @@ def test_validate_inputs_raises_dispatch_error_when_inputs_violate_schema() -> N
|
|||
definition = _minimal_definition(inputs=Inputs(schema=schema))
|
||||
|
||||
with pytest.raises(DispatchError):
|
||||
_validate_inputs(definition, {"topic": 42}) # type violates string
|
||||
validate_inputs(definition, {"topic": 42}) # type violates string
|
||||
|
|
@ -39,7 +39,7 @@ def test_run_status_string_values_are_stable() -> None:
|
|||
|
||||
|
||||
def test_trigger_type_keeps_manual_member_even_though_unregistered() -> None:
|
||||
"""``MANUAL`` is reserved (mirrors the Postgres enum) but the trigger
|
||||
store does not register it in v1. The enum must keep both members so
|
||||
existing DB rows and the schema migration plan stay valid."""
|
||||
assert {member.value for member in TriggerType} == {"schedule", "manual"}
|
||||
"""``schedule`` and ``event`` are registered; ``MANUAL`` is reserved
|
||||
(mirrors the Postgres enum) but the trigger store does not register it.
|
||||
The enum must keep every member so DB rows and migrations stay valid."""
|
||||
assert {member.value for member in TriggerType} == {"schedule", "event", "manual"}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,18 @@
|
|||
"""The ``event`` trigger self-registers on the triggers store at import."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.triggers import get_trigger
|
||||
from app.automations.triggers.builtin.event.params import EventTriggerParams
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_event_trigger_is_registered() -> None:
|
||||
definition = get_trigger("event")
|
||||
|
||||
assert definition is not None
|
||||
assert definition.type == "event"
|
||||
assert definition.params_model is EventTriggerParams
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
"""Behavior tests for the ``matches`` filter grammar."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.triggers.builtin.event.filter import FilterError, matches
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_empty_filter_matches_any_payload() -> None:
|
||||
assert matches({}, {"document_id": 42, "document_type": "FILE"}) is True
|
||||
assert matches({}, {}) is True
|
||||
|
||||
|
||||
def test_scalar_value_is_implicit_equality() -> None:
|
||||
flt = {"document_type": "FILE"}
|
||||
assert matches(flt, {"document_type": "FILE"}) is True
|
||||
assert matches(flt, {"document_type": "WEBPAGE"}) is False
|
||||
|
||||
|
||||
def test_multiple_fields_are_anded() -> None:
|
||||
flt = {"document_type": "FILE", "search_space_id": 7}
|
||||
assert matches(flt, {"document_type": "FILE", "search_space_id": 7}) is True
|
||||
assert matches(flt, {"document_type": "FILE", "search_space_id": 9}) is False
|
||||
|
||||
|
||||
def test_gt_operator_compares_greater_than() -> None:
|
||||
flt = {"page_count": {"$gt": 10}}
|
||||
assert matches(flt, {"page_count": 20}) is True
|
||||
assert matches(flt, {"page_count": 10}) is False
|
||||
assert matches(flt, {"page_count": 5}) is False
|
||||
|
||||
|
||||
def test_remaining_comparison_operators() -> None:
|
||||
assert matches({"n": {"$gte": 10}}, {"n": 10}) is True
|
||||
assert matches({"n": {"$gte": 10}}, {"n": 9}) is False
|
||||
|
||||
assert matches({"n": {"$lt": 10}}, {"n": 9}) is True
|
||||
assert matches({"n": {"$lt": 10}}, {"n": 10}) is False
|
||||
|
||||
assert matches({"n": {"$lte": 10}}, {"n": 10}) is True
|
||||
assert matches({"n": {"$lte": 10}}, {"n": 11}) is False
|
||||
|
||||
assert matches({"s": {"$eq": "FILE"}}, {"s": "FILE"}) is True
|
||||
assert matches({"s": {"$eq": "FILE"}}, {"s": "WEB"}) is False
|
||||
|
||||
assert matches({"s": {"$ne": "FILE"}}, {"s": "WEB"}) is True
|
||||
assert matches({"s": {"$ne": "FILE"}}, {"s": "FILE"}) is False
|
||||
|
||||
|
||||
def test_multiple_operators_on_one_field_are_anded() -> None:
|
||||
flt = {"n": {"$gte": 10, "$lt": 20}}
|
||||
assert matches(flt, {"n": 15}) is True
|
||||
assert matches(flt, {"n": 10}) is True
|
||||
assert matches(flt, {"n": 20}) is False
|
||||
assert matches(flt, {"n": 5}) is False
|
||||
|
||||
|
||||
def test_in_and_nin_membership_operators() -> None:
|
||||
flt_in = {"document_type": {"$in": ["FILE", "WEBPAGE"]}}
|
||||
assert matches(flt_in, {"document_type": "FILE"}) is True
|
||||
assert matches(flt_in, {"document_type": "SLACK"}) is False
|
||||
|
||||
flt_nin = {"document_type": {"$nin": ["FILE", "WEBPAGE"]}}
|
||||
assert matches(flt_nin, {"document_type": "SLACK"}) is True
|
||||
assert matches(flt_nin, {"document_type": "FILE"}) is False
|
||||
|
||||
|
||||
def test_or_matches_when_any_branch_holds() -> None:
|
||||
flt = {"$or": [{"document_type": "FILE"}, {"document_type": "WEBPAGE"}]}
|
||||
assert matches(flt, {"document_type": "WEBPAGE"}) is True
|
||||
assert matches(flt, {"document_type": "SLACK"}) is False
|
||||
|
||||
|
||||
def test_and_matches_when_every_branch_holds() -> None:
|
||||
flt = {"$and": [{"n": {"$gt": 5}}, {"n": {"$lt": 10}}]}
|
||||
assert matches(flt, {"n": 7}) is True
|
||||
assert matches(flt, {"n": 12}) is False
|
||||
|
||||
|
||||
def test_not_inverts_its_subexpression() -> None:
|
||||
flt = {"$not": {"document_type": "FILE"}}
|
||||
assert matches(flt, {"document_type": "WEBPAGE"}) is True
|
||||
assert matches(flt, {"document_type": "FILE"}) is False
|
||||
|
||||
|
||||
def test_missing_field_never_matches_and_never_raises() -> None:
|
||||
# Conservative: an absent field fails the constraint, and comparisons must
|
||||
# not raise on the missing value — including $ne (absence isn't "not equal").
|
||||
assert matches({"document_type": "FILE"}, {}) is False
|
||||
assert matches({"page_count": {"$gt": 5}}, {}) is False
|
||||
assert matches({"document_type": {"$in": ["FILE"]}}, {}) is False
|
||||
assert matches({"document_type": {"$ne": "FILE"}}, {}) is False
|
||||
|
||||
|
||||
def test_logical_operators_compose_with_fields() -> None:
|
||||
flt = {
|
||||
"search_space_id": 7,
|
||||
"$or": [{"document_type": "FILE"}, {"document_type": "WEBPAGE"}],
|
||||
}
|
||||
assert matches(flt, {"search_space_id": 7, "document_type": "FILE"}) is True
|
||||
assert matches(flt, {"search_space_id": 9, "document_type": "FILE"}) is False
|
||||
assert matches(flt, {"search_space_id": 7, "document_type": "SLACK"}) is False
|
||||
|
||||
|
||||
def test_unknown_field_operator_raises_filter_error() -> None:
|
||||
with pytest.raises(FilterError):
|
||||
matches({"n": {"$regex": "x"}}, {"n": "xyz"})
|
||||
|
||||
|
||||
def test_unknown_logical_operator_raises_filter_error() -> None:
|
||||
with pytest.raises(FilterError):
|
||||
matches({"$nor": [{"document_type": "FILE"}]}, {"document_type": "FILE"})
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
"""An event hands its payload + metadata to the run as inputs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.triggers.builtin.event.inputs import event_runtime_inputs
|
||||
from app.event_bus import Event
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_runtime_inputs_flatten_payload_with_event_metadata() -> None:
|
||||
event = Event(
|
||||
event_type="document.indexed",
|
||||
payload={"document_id": 42, "document_type": "FILE"},
|
||||
search_space_id=7,
|
||||
)
|
||||
|
||||
inputs = event_runtime_inputs(event)
|
||||
|
||||
assert inputs["document_id"] == 42
|
||||
assert inputs["document_type"] == "FILE"
|
||||
assert inputs["event_type"] == "document.indexed"
|
||||
assert inputs["event_id"] == event.event_id
|
||||
assert inputs["occurred_at"] == event.occurred_at.isoformat()
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
"""Which triggers an event fires: event_type equality + filter match."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.triggers.builtin.event.match import trigger_matches_event
|
||||
from app.event_bus import Event
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _event(event_type: str = "document.indexed", **payload) -> Event:
|
||||
return Event(event_type=event_type, payload=payload, search_space_id=7)
|
||||
|
||||
|
||||
def test_matches_when_event_type_equal_and_filter_passes() -> None:
|
||||
params = {"event_type": "document.indexed", "filter": {"document_type": "FILE"}}
|
||||
assert trigger_matches_event(params, _event(document_type="FILE")) is True
|
||||
|
||||
|
||||
def test_no_match_when_event_type_differs() -> None:
|
||||
params = {"event_type": "document.indexed", "filter": {}}
|
||||
assert trigger_matches_event(params, _event("podcast.generated")) is False
|
||||
|
||||
|
||||
def test_no_match_when_filter_rejects_payload() -> None:
|
||||
params = {"event_type": "document.indexed", "filter": {"document_type": "FILE"}}
|
||||
assert trigger_matches_event(params, _event(document_type="WEBPAGE")) is False
|
||||
|
||||
|
||||
def test_empty_filter_matches_any_payload_of_that_type() -> None:
|
||||
params = {"event_type": "document.indexed", "filter": {}}
|
||||
assert trigger_matches_event(params, _event(document_type="ANYTHING")) is True
|
||||
|
||||
|
||||
def test_missing_filter_key_is_treated_as_empty() -> None:
|
||||
params = {"event_type": "document.indexed"}
|
||||
assert trigger_matches_event(params, _event(document_type="X")) is True
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
"""``EventTriggerParams`` contract: an event_type to listen for + an optional filter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.triggers.builtin.event.params import EventTriggerParams
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_accepts_event_type_and_filter() -> None:
|
||||
params = EventTriggerParams(
|
||||
event_type="document.indexed",
|
||||
filter={"document_type": "FILE"},
|
||||
)
|
||||
|
||||
assert params.event_type == "document.indexed"
|
||||
assert params.filter == {"document_type": "FILE"}
|
||||
|
||||
|
||||
def test_filter_defaults_to_empty() -> None:
|
||||
params = EventTriggerParams(event_type="document.indexed")
|
||||
|
||||
assert params.filter == {}
|
||||
|
||||
|
||||
def test_event_type_is_required() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
EventTriggerParams(filter={"x": 1})
|
||||
|
||||
|
||||
def test_event_type_must_not_be_blank() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
EventTriggerParams(event_type="")
|
||||
|
||||
|
||||
def test_extra_keys_are_forbidden() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
EventTriggerParams(event_type="document.indexed", typo=True)
|
||||
|
|
@ -6,7 +6,7 @@ from datetime import UTC, datetime
|
|||
|
||||
import pytest
|
||||
|
||||
from app.automations.triggers.schedule.cron import (
|
||||
from app.automations.triggers.builtin.schedule.cron import (
|
||||
InvalidCronError,
|
||||
compute_next_fire_at,
|
||||
validate_cron,
|
||||
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.automations.triggers.schedule.params import ScheduleTriggerParams
|
||||
from app.automations.triggers.builtin.schedule.params import ScheduleTriggerParams
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
0
surfsense_backend/tests/unit/event_bus/__init__.py
Normal file
0
surfsense_backend/tests/unit/event_bus/__init__.py
Normal file
25
surfsense_backend/tests/unit/event_bus/conftest.py
Normal file
25
surfsense_backend/tests/unit/event_bus/conftest.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
"""Shared fixtures for the ``app.event_bus`` unit-test tree.
|
||||
|
||||
The event-type catalog is a module-level registry populated at import. Tests
|
||||
that register their own event types (or assert on registry contents) snapshot
|
||||
and restore it so state never leaks between tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
|
||||
from app.event_bus.catalog import catalog
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def isolated_event_catalog() -> Iterator[None]:
|
||||
"""Snapshot and restore the event-type catalog around a test."""
|
||||
snapshot = dict(catalog._registry)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
catalog._registry.clear()
|
||||
catalog._registry.update(snapshot)
|
||||
181
surfsense_backend/tests/unit/event_bus/test_bus.py
Normal file
181
surfsense_backend/tests/unit/event_bus/test_bus.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
"""``EventBus`` contract: subscribe, publish (stamp + fan out), dispatch.
|
||||
|
||||
Each test uses a fresh ``EventBus`` — no shared global state.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.event_bus import Event, EventBus
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _event() -> Event:
|
||||
return Event(event_type="x.happened", payload={"k": "v"}, search_space_id=1)
|
||||
|
||||
|
||||
async def _noop(_event: Event) -> None:
|
||||
return None
|
||||
|
||||
|
||||
async def _other(_event: Event) -> None:
|
||||
return None
|
||||
|
||||
|
||||
# --- registry -------------------------------------------------------------
|
||||
|
||||
|
||||
def test_subscribe_then_subscribers_returns_the_handler() -> None:
|
||||
bus = EventBus()
|
||||
bus.subscribe(_noop)
|
||||
|
||||
assert _noop in bus.subscribers()
|
||||
|
||||
|
||||
def test_subscribe_is_idempotent_for_the_same_handler() -> None:
|
||||
"""Registering the same handler twice must not make it fire twice."""
|
||||
bus = EventBus()
|
||||
bus.subscribe(_noop)
|
||||
bus.subscribe(_noop)
|
||||
|
||||
assert bus.subscribers().count(_noop) == 1
|
||||
|
||||
|
||||
def test_distinct_handlers_both_register() -> None:
|
||||
bus = EventBus()
|
||||
bus.subscribe(_noop)
|
||||
bus.subscribe(_other)
|
||||
|
||||
registered = bus.subscribers()
|
||||
assert _noop in registered
|
||||
assert _other in registered
|
||||
|
||||
|
||||
def test_subscribers_returns_a_defensive_snapshot() -> None:
|
||||
"""Mutating the returned list must not corrupt the registry."""
|
||||
bus = EventBus()
|
||||
bus.subscribe(_noop)
|
||||
|
||||
snapshot = bus.subscribers()
|
||||
snapshot.clear()
|
||||
|
||||
assert _noop in bus.subscribers()
|
||||
|
||||
|
||||
def test_subscribe_returns_handler_so_it_can_be_used_as_a_decorator() -> None:
|
||||
bus = EventBus()
|
||||
returned = bus.subscribe(_other)
|
||||
|
||||
assert returned is _other
|
||||
|
||||
|
||||
def test_two_buses_do_not_share_subscribers() -> None:
|
||||
"""The registry is per-instance, not global."""
|
||||
a = EventBus()
|
||||
b = EventBus()
|
||||
a.subscribe(_noop)
|
||||
|
||||
assert _noop in a.subscribers()
|
||||
assert _noop not in b.subscribers()
|
||||
|
||||
|
||||
# --- dispatch -------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_dispatch_delivers_event_to_every_subscriber() -> None:
|
||||
bus = EventBus()
|
||||
seen: list[tuple[str, Event]] = []
|
||||
|
||||
async def first(event: Event) -> None:
|
||||
seen.append(("first", event))
|
||||
|
||||
async def second(event: Event) -> None:
|
||||
seen.append(("second", event))
|
||||
|
||||
bus.subscribe(first)
|
||||
bus.subscribe(second)
|
||||
|
||||
event = _event()
|
||||
await bus.dispatch(event)
|
||||
|
||||
assert ("first", event) in seen
|
||||
assert ("second", event) in seen
|
||||
|
||||
|
||||
async def test_dispatch_isolates_a_failing_subscriber() -> None:
|
||||
"""A subscriber that raises must not stop a healthy one from running."""
|
||||
bus = EventBus()
|
||||
healthy_ran = False
|
||||
|
||||
async def boom(_event: Event) -> None:
|
||||
raise RuntimeError("subscriber blew up")
|
||||
|
||||
async def healthy(_event: Event) -> None:
|
||||
nonlocal healthy_ran
|
||||
healthy_ran = True
|
||||
|
||||
bus.subscribe(boom)
|
||||
bus.subscribe(healthy)
|
||||
|
||||
await bus.dispatch(_event())
|
||||
|
||||
assert healthy_ran is True
|
||||
|
||||
|
||||
async def test_dispatch_never_propagates_subscriber_errors() -> None:
|
||||
"""``dispatch`` itself must not raise even if every subscriber fails."""
|
||||
bus = EventBus()
|
||||
|
||||
async def boom(_event: Event) -> None:
|
||||
raise ValueError("nope")
|
||||
|
||||
bus.subscribe(boom)
|
||||
|
||||
await bus.dispatch(_event()) # must not raise
|
||||
|
||||
|
||||
async def test_dispatch_with_no_subscribers_is_a_noop() -> None:
|
||||
bus = EventBus()
|
||||
await bus.dispatch(_event()) # must not raise
|
||||
|
||||
|
||||
# --- publish --------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_publish_builds_a_stamped_event_and_fans_it_out() -> None:
|
||||
bus = EventBus()
|
||||
received: list[Event] = []
|
||||
|
||||
async def handler(event: Event) -> None:
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe(handler)
|
||||
await bus.publish("document.indexed", {"document_id": 42}, search_space_id=7)
|
||||
|
||||
assert len(received) == 1
|
||||
event = received[0]
|
||||
assert event.event_type == "document.indexed"
|
||||
assert event.payload == {"document_id": 42}
|
||||
assert event.search_space_id == 7
|
||||
# Engine-stamped identity/time on the way through.
|
||||
assert event.event_id
|
||||
assert event.occurred_at
|
||||
|
||||
|
||||
async def test_publish_defaults_payload_to_empty_dict() -> None:
|
||||
bus = EventBus()
|
||||
received: list[Event] = []
|
||||
|
||||
async def handler(event: Event) -> None:
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe(handler)
|
||||
await bus.publish("x.happened", search_space_id=1)
|
||||
|
||||
assert received[0].payload == {}
|
||||
|
||||
|
||||
async def test_publish_with_no_subscribers_is_a_noop() -> None:
|
||||
await EventBus().publish("x.happened", search_space_id=1) # must not raise
|
||||
73
surfsense_backend/tests/unit/event_bus/test_catalog.py
Normal file
73
surfsense_backend/tests/unit/event_bus/test_catalog.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
"""EventCatalog contract: register, look up, snapshot, derive schema."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.event_bus.catalog import EventCatalog, EventType
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _SamplePayload(BaseModel):
|
||||
document_id: int
|
||||
|
||||
|
||||
def _event_type(type_: str = "test.thing") -> EventType:
|
||||
return EventType(
|
||||
type=type_,
|
||||
description="A thing happened.",
|
||||
payload_model=_SamplePayload,
|
||||
)
|
||||
|
||||
|
||||
def test_register_then_get_returns_the_event_type(isolated_event_catalog: None) -> None:
|
||||
from app.event_bus.catalog import catalog
|
||||
catalog.register(_event_type())
|
||||
|
||||
assert catalog.get("test.thing") is not None
|
||||
assert catalog.get("test.thing").type == "test.thing"
|
||||
|
||||
|
||||
def test_get_unknown_type_returns_none(isolated_event_catalog: None) -> None:
|
||||
from app.event_bus.catalog import catalog
|
||||
assert catalog.get("does.not.exist") is None
|
||||
|
||||
|
||||
def test_register_duplicate_type_raises(isolated_event_catalog: None) -> None:
|
||||
"""A type is a contract; registering it twice is a bug, not an override."""
|
||||
from app.event_bus.catalog import catalog
|
||||
catalog.register(_event_type())
|
||||
|
||||
with pytest.raises(ValueError, match="already registered"):
|
||||
catalog.register(_event_type())
|
||||
|
||||
|
||||
def test_all_is_a_defensive_snapshot(isolated_event_catalog: None) -> None:
|
||||
"""Mutating the returned dict must not corrupt the registry."""
|
||||
from app.event_bus.catalog import catalog
|
||||
catalog.register(_event_type())
|
||||
|
||||
snapshot = catalog.all()
|
||||
snapshot.clear()
|
||||
|
||||
assert catalog.get("test.thing") is not None
|
||||
|
||||
|
||||
def test_payload_schema_is_derived_from_the_payload_model() -> None:
|
||||
"""The JSON Schema a UI/validator consumes comes from the payload model."""
|
||||
event_type = _event_type()
|
||||
|
||||
assert event_type.payload_schema == _SamplePayload.model_json_schema()
|
||||
|
||||
|
||||
def test_each_catalog_instance_has_its_own_registry() -> None:
|
||||
"""Two EventCatalog instances are fully independent."""
|
||||
a = EventCatalog()
|
||||
b = EventCatalog()
|
||||
|
||||
a.register(_event_type())
|
||||
|
||||
assert a.get("test.thing") is not None
|
||||
assert b.get("test.thing") is None
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
"""``document.entered_folder`` payload contract + catalog registration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.event_bus.catalog import catalog
|
||||
from app.event_bus.events.document_entered_folder import (
|
||||
EVENT_TYPE,
|
||||
DocumentEnteredFolderPayload,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _payload(**overrides: object) -> DocumentEnteredFolderPayload:
|
||||
base: dict[str, object] = {
|
||||
"document_id": 42,
|
||||
"folder_id": 7,
|
||||
"document_type": "FILE",
|
||||
"title": "Q3 report.pdf",
|
||||
}
|
||||
base.update(overrides)
|
||||
return DocumentEnteredFolderPayload(**base)
|
||||
|
||||
|
||||
def test_payload_carries_the_filterable_fields() -> None:
|
||||
payload = _payload(connector_id=12, created_by_id="abc")
|
||||
|
||||
assert payload.document_id == 42
|
||||
assert payload.folder_id == 7
|
||||
assert payload.document_type == "FILE"
|
||||
assert payload.connector_id == 12
|
||||
|
||||
|
||||
def test_first_placement_is_not_a_move() -> None:
|
||||
"""No previous folder (created or AI-sorted into place) → not a move."""
|
||||
assert _payload(previous_folder_id=None).is_move is False
|
||||
|
||||
|
||||
def test_change_between_folders_is_a_move() -> None:
|
||||
assert _payload(previous_folder_id=3).is_move is True
|
||||
|
||||
|
||||
def test_is_move_is_serialized_for_filtering() -> None:
|
||||
"""Filters match against the dumped payload, so ``is_move`` must appear there."""
|
||||
dumped = _payload(previous_folder_id=3).model_dump()
|
||||
|
||||
assert dumped["is_move"] is True
|
||||
|
||||
|
||||
def test_event_type_is_registered_in_the_catalog() -> None:
|
||||
registered = catalog.get(EVENT_TYPE)
|
||||
|
||||
assert registered is not None
|
||||
assert registered.payload_model is DocumentEnteredFolderPayload
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
"""payload_if_entered_folder: decides whether a document commit warrants an event."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from app.event_bus.events.document_entered_folder import payload_if_entered_folder
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _call(**overrides: Any) -> dict[str, Any] | None:
|
||||
defaults: dict[str, Any] = {
|
||||
"document_id": 1,
|
||||
"search_space_id": 10,
|
||||
"new_folder_id": 7,
|
||||
"previous_folder_id": None,
|
||||
"folder_id_changed": True,
|
||||
"status_state": "ready",
|
||||
"document_type": "FILE",
|
||||
"title": "report.pdf",
|
||||
"connector_id": None,
|
||||
"created_by_id": None,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return payload_if_entered_folder(**defaults)
|
||||
|
||||
|
||||
def test_folder_set_ready_fires() -> None:
|
||||
result = _call()
|
||||
|
||||
assert result is not None
|
||||
assert result["event_type"] == "document.entered_folder"
|
||||
assert result["search_space_id"] == 10
|
||||
assert result["payload"]["folder_id"] == 7
|
||||
assert result["payload"]["previous_folder_id"] is None
|
||||
|
||||
|
||||
def test_no_folder_is_silent() -> None:
|
||||
assert _call(new_folder_id=None) is None
|
||||
|
||||
|
||||
def test_not_ready_is_silent() -> None:
|
||||
assert _call(status_state="processing") is None
|
||||
|
||||
|
||||
def test_folder_unchanged_is_silent() -> None:
|
||||
assert _call(folder_id_changed=False) is None
|
||||
|
||||
|
||||
def test_move_carries_previous_folder_id() -> None:
|
||||
result = _call(previous_folder_id=3, new_folder_id=7)
|
||||
|
||||
assert result is not None
|
||||
assert result["payload"]["previous_folder_id"] == 3
|
||||
assert result["payload"]["folder_id"] == 7
|
||||
53
surfsense_backend/tests/unit/event_bus/test_event.py
Normal file
53
surfsense_backend/tests/unit/event_bus/test_event.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""``Event`` contract: carry caller facts + engine-stamped id/time, round-trip JSON."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from app.event_bus.event import Event
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_event_carries_caller_supplied_facts() -> None:
|
||||
"""The three caller inputs are stored verbatim."""
|
||||
event = Event(
|
||||
event_type="document.indexed",
|
||||
payload={"document_id": 42, "content_type": "pdf"},
|
||||
search_space_id=7,
|
||||
)
|
||||
|
||||
assert event.event_type == "document.indexed"
|
||||
assert event.payload == {"document_id": 42, "content_type": "pdf"}
|
||||
assert event.search_space_id == 7
|
||||
|
||||
|
||||
def test_event_stamps_identity_and_time_when_not_supplied() -> None:
|
||||
"""Engine stamps id + time so subscribers can dedup/order."""
|
||||
event = Event(event_type="x.happened", payload={}, search_space_id=1)
|
||||
|
||||
assert event.event_id
|
||||
assert isinstance(event.occurred_at, datetime)
|
||||
|
||||
|
||||
def test_event_ids_are_unique_per_instance() -> None:
|
||||
"""Two events published with identical content are still distinct facts."""
|
||||
first = Event(event_type="x.happened", payload={}, search_space_id=1)
|
||||
second = Event(event_type="x.happened", payload={}, search_space_id=1)
|
||||
|
||||
assert first.event_id != second.event_id
|
||||
|
||||
|
||||
def test_event_survives_json_round_trip() -> None:
|
||||
"""Serialize → deserialize reproduces the event (subscribers queue it as JSON)."""
|
||||
original = Event(
|
||||
event_type="podcast.generated",
|
||||
payload={"podcast_id": 9, "duration_s": 123.5},
|
||||
search_space_id=3,
|
||||
)
|
||||
|
||||
restored = Event.model_validate_json(original.model_dump_json())
|
||||
|
||||
assert restored == original
|
||||
Loading…
Add table
Add a link
Reference in a new issue