Merge remote-tracking branch 'upstream/dev' into feat/whatsapp-gateway-integration

This commit is contained in:
Anish Sarkar 2026-06-02 00:29:32 +05:30
commit e3de7c4667
465 changed files with 29171 additions and 6994 deletions

View file

@ -60,7 +60,6 @@ class TestReadOnlyToolsAllowed:
"glob",
"web_search",
"scrape_webpage",
"search_surfsense_docs",
"get_connected_accounts",
"write_todos",
"task",

View file

@ -22,12 +22,6 @@ from app.agents.new_chat.subagents.config import (
# ---------------------------------------------------------------------------
@tool
def search_surfsense_docs(query: str) -> str:
"""Search the user's KB."""
return ""
@tool
def web_search(query: str) -> str:
"""Search the public web."""
@ -95,7 +89,6 @@ def generate_report(topic: str) -> str:
ALL_TOOLS = [
search_surfsense_docs,
web_search,
scrape_webpage,
read_file,
@ -161,7 +154,7 @@ class TestReportWriterSubagent:
names = {t.name for t in spec["tools"]} # type: ignore[index]
assert names == REPORT_WRITER_TOOLS & {t.name for t in ALL_TOOLS}
assert "generate_report" in names
assert "search_surfsense_docs" in names
assert "read_file" in names
def test_deny_rules_block_writes_but_allow_generate_report(self) -> None:
spec = build_report_writer_subagent(tools=ALL_TOOLS)
@ -272,9 +265,9 @@ class TestFilterToolsWarningSuppression:
# Allowed set asks for two registry tools (one present, one
# not) plus a bunch of middleware-provided names.
_filter_tools(
[search_surfsense_docs],
[web_search],
allowed_names={
"search_surfsense_docs",
"web_search",
"scrape_webpage", # legitimately missing → should warn
"read_file", # mw-provided → suppressed
"ls",
@ -322,7 +315,6 @@ class TestDenyPatternsCoverage:
def test_deny_patterns_do_not_match_safe_read_tools(self) -> None:
canonical_reads = [
"search_surfsense_docs",
"read_file",
"ls_tree",
"grep",

View file

@ -0,0 +1,73 @@
"""Lock ``build_auto_decisions`` — the HITL auto-approve/reject wire mapper.
``build_auto_decisions`` walks ``state.interrupts`` (duck-typed) and produces
two parallel resume maps: one keyed by LangGraph ``Interrupt.id`` and one
keyed by ``tool_call_id`` for the subagent middleware bridge. Both carry
the same decision payload.
"""
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
import pytest
from app.automations.actions.builtin.agent_task.auto_decide import build_auto_decisions
pytestmark = pytest.mark.unit
def _state(interrupts: list[Any]) -> SimpleNamespace:
"""Build a duck-typed LangGraph state stub carrying ``interrupts``."""
return SimpleNamespace(interrupts=interrupts)
def _interrupt(*, id_: str, value: Any) -> SimpleNamespace:
"""Build a duck-typed interrupt with the canonical ``(id, value)`` shape."""
return SimpleNamespace(id=id_, value=value)
def test_build_auto_decisions_produces_one_decision_per_action_request() -> None:
"""An interrupt carrying N ``action_requests`` produces N decisions of
the requested type in both maps. This is the canonical batched-HITL
wire shape losing a decision would leave a pending action stuck."""
interrupt = _interrupt(
id_="lg-1",
value={
"tool_call_id": "tc-1",
"action_requests": [{"id": "a"}, {"id": "b"}],
},
)
lg_map, routed = build_auto_decisions(_state([interrupt]), "approve")
assert lg_map == {"lg-1": {"decisions": [{"type": "approve"}, {"type": "approve"}]}}
assert routed == {"tc-1": {"decisions": [{"type": "approve"}, {"type": "approve"}]}}
def test_build_auto_decisions_defaults_to_one_decision_for_scalar_interrupt() -> None:
"""When an interrupt's value has no ``action_requests`` list, the
function defaults to a single decision. Locks compatibility with
older single-action interrupt shapes still emitted by some tools."""
interrupt = _interrupt(id_="lg-2", value={"tool_call_id": "tc-2"})
lg_map, routed = build_auto_decisions(_state([interrupt]), "reject")
assert lg_map == {"lg-2": {"decisions": [{"type": "reject"}]}}
assert routed == {"tc-2": {"decisions": [{"type": "reject"}]}}
def test_build_auto_decisions_skips_interrupts_with_invalid_shape() -> None:
"""Interrupts missing the canonical ``(str id, dict value)`` shape are
skipped silently rather than crashing the resume loop. Locks the
resilience contract a malformed interrupt from a misbehaving tool
shouldn't take down the whole agent_task step."""
good = _interrupt(id_="lg-good", value={"tool_call_id": "tc-good"})
bad_value = _interrupt(id_="lg-bad-value", value="not a dict")
bad_id = _interrupt(id_=None, value={"tool_call_id": "tc-bad-id"}) # type: ignore[arg-type]
lg_map, routed = build_auto_decisions(_state([good, bad_value, bad_id]), "approve")
assert lg_map == {"lg-good": {"decisions": [{"type": "approve"}]}}
assert routed == {"tc-good": {"decisions": [{"type": "approve"}]}}

View file

@ -0,0 +1,174 @@
"""Lock the runtime model-policy backstop in ``build_dependencies``.
Automations resolve their LLM from the *captured* ``agent_llm_id`` snapshot (so
runs are insulated from later chat/search-space model changes), and the model
policy is re-checked at run time so a captured model that is no longer billable
fails the run clearly. When no snapshot is present, resolution falls back to the
live search space.
"""
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
import pytest
import app.automations.actions.agent_task.dependencies as deps_mod
from app.automations.actions.agent_task.dependencies import (
DependencyError,
build_dependencies,
)
from app.automations.services.model_policy import AutomationModelPolicyError
pytestmark = pytest.mark.unit
class _FakeSession:
"""Minimal async session whose ``get`` returns a preset search space."""
def __init__(self, search_space: Any) -> None:
self._search_space = search_space
async def get(self, _model: Any, _pk: int) -> Any:
return self._search_space
@pytest.fixture
def patched_side_effects(monkeypatch: pytest.MonkeyPatch):
"""Stub the connector setup + checkpointer so only policy/LLM logic runs."""
async def _fake_setup(_session, *, search_space_id):
return (SimpleNamespace(name="connector"), "fc-key")
monkeypatch.setattr(deps_mod, "setup_connector_and_firecrawl", _fake_setup)
return None
async def test_build_dependencies_resolves_captured_agent_llm_id(
monkeypatch: pytest.MonkeyPatch, patched_side_effects
) -> None:
"""The bundle loads with the *captured* ``agent_llm_id``, not the live search space."""
captured: dict[str, Any] = {}
async def _fake_load(_session, *, config_id, search_space_id):
captured["config_id"] = config_id
captured["search_space_id"] = search_space_id
return (SimpleNamespace(name="llm"), SimpleNamespace(name="agent_config"), None)
monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load)
# Captured path validates the explicit ids; passes for this test.
monkeypatch.setattr(deps_mod, "assert_models_billable", lambda **_kw: None)
# A different value on the live search space proves we ignore it when a
# snapshot is supplied.
monkeypatch.setattr(
deps_mod,
"assert_automation_models_billable",
lambda _ss: pytest.fail("search-space policy should not run on captured path"),
)
search_space = SimpleNamespace(agent_llm_id=-99)
result = await build_dependencies(
session=_FakeSession(search_space),
search_space_id=42,
agent_llm_id=-7,
image_generation_config_id=5,
vision_llm_config_id=-1,
)
assert captured == {"config_id": -7, "search_space_id": 42}
assert result.llm.name == "llm"
assert result.firecrawl_api_key == "fc-key"
async def test_build_dependencies_validates_captured_ids(
monkeypatch: pytest.MonkeyPatch, patched_side_effects
) -> None:
"""The captured ids (not the search space) are what gets policy-checked."""
seen: dict[str, Any] = {}
def _capture(**kwargs):
seen.update(kwargs)
monkeypatch.setattr(deps_mod, "assert_models_billable", _capture)
async def _fake_load(_session, *, config_id, search_space_id):
return (SimpleNamespace(name="llm"), SimpleNamespace(name="agent_config"), None)
monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load)
await build_dependencies(
session=_FakeSession(SimpleNamespace(agent_llm_id=0)),
search_space_id=42,
agent_llm_id=-7,
image_generation_config_id=5,
vision_llm_config_id=-1,
)
assert seen == {
"agent_llm_id": -7,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
}
async def test_build_dependencies_raises_on_captured_policy_violation(
monkeypatch: pytest.MonkeyPatch, patched_side_effects
) -> None:
"""A blocked captured model raises ``DependencyError`` so the step fails clearly."""
def _raise(**_kw):
raise AutomationModelPolicyError(
[{"kind": "image", "config_id": -2, "reason": "free model"}]
)
monkeypatch.setattr(deps_mod, "assert_models_billable", _raise)
monkeypatch.setattr(
deps_mod,
"load_llm_bundle",
lambda *a, **k: pytest.fail("load_llm_bundle should not be called"),
)
with pytest.raises(DependencyError):
await build_dependencies(
session=_FakeSession(SimpleNamespace(agent_llm_id=-7)),
search_space_id=42,
agent_llm_id=-7,
image_generation_config_id=-2,
vision_llm_config_id=-1,
)
async def test_build_dependencies_falls_back_to_search_space(
monkeypatch: pytest.MonkeyPatch, patched_side_effects
) -> None:
"""With no captured snapshot, resolve + validate the live search space."""
captured: dict[str, Any] = {}
async def _fake_load(_session, *, config_id, search_space_id):
captured["config_id"] = config_id
return (SimpleNamespace(name="llm"), SimpleNamespace(name="agent_config"), None)
monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load)
monkeypatch.setattr(deps_mod, "assert_automation_models_billable", lambda _ss: None)
monkeypatch.setattr(
deps_mod,
"assert_models_billable",
lambda **_kw: pytest.fail("captured policy should not run on fallback path"),
)
search_space = SimpleNamespace(agent_llm_id=-7)
result = await build_dependencies(
session=_FakeSession(search_space), search_space_id=42
)
assert captured == {"config_id": -7}
assert result.llm.name == "llm"
async def test_build_dependencies_raises_when_search_space_missing(
patched_side_effects,
) -> None:
"""A missing search space (fallback path) surfaces as a ``DependencyError``."""
with pytest.raises(DependencyError):
await build_dependencies(session=_FakeSession(None), search_space_id=999)

View file

@ -0,0 +1,86 @@
"""Lock ``extract_final_assistant_message`` — what surfaces in run output.
Each scenario is one shape the agent runtime is observed to produce.
Locking these means we can refactor the extractor without losing
backwards compatibility with already-stored ``run.output`` payloads.
"""
from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from app.automations.actions.builtin.agent_task.finalize import (
extract_final_assistant_message,
)
pytestmark = pytest.mark.unit
def test_extract_returns_last_ai_message_string_content() -> None:
"""The canonical shape: the agent's final ``AIMessage`` carries a
plain string. That string is returned verbatim, trimmed."""
result = {
"messages": [
HumanMessage(content="ask"),
AIMessage(content="the answer"),
]
}
assert extract_final_assistant_message(result) == "the answer"
def test_extract_concatenates_text_parts_and_skips_non_text_parts() -> None:
"""Multi-part AIMessage content (Anthropic / OpenAI list shape) joins
its ``text`` parts in order; non-text parts (tool_use, images, ...)
are skipped. Locks the wire shape used when the model emits tool
calls alongside narrative text in the same turn."""
result = {
"messages": [
AIMessage(
content=[
{"type": "text", "text": "Hello "},
{"type": "tool_use", "name": "search", "input": {}},
{"type": "text", "text": "world"},
]
)
]
}
assert extract_final_assistant_message(result) == "Hello world"
def test_extract_returns_last_ai_message_skipping_tool_messages() -> None:
"""When the transcript ends with tool calls and tool results, the
extractor still walks back to the **last** ``AIMessage`` (the agent's
final narrative answer). Locks resilience against trailing
``ToolMessage`` payloads in the transcript."""
result = {
"messages": [
HumanMessage(content="ask"),
AIMessage(content="thinking..."),
ToolMessage(content="tool output", tool_call_id="tc-1"),
AIMessage(content="final answer"),
ToolMessage(content="trailing tool noise", tool_call_id="tc-2"),
]
}
assert extract_final_assistant_message(result) == "final answer"
def test_extract_returns_none_when_no_assistant_text_is_present() -> None:
"""No ``AIMessage`` with extractable text → ``None`` rather than the
empty string. Lets callers branch on "did the agent actually say
anything?" rather than guess whether ``""`` means silence or empty
output. Empty-string contents are normalized to ``None`` too."""
no_ai = {"messages": [HumanMessage(content="just a question")]}
only_tools = {
"messages": [
AIMessage(content=[{"type": "tool_use", "name": "x", "input": {}}])
]
}
empty_string = {"messages": [AIMessage(content=" ")]}
assert extract_final_assistant_message(no_ai) is None
assert extract_final_assistant_message(only_tools) is None
assert extract_final_assistant_message(empty_string) is None

View file

@ -0,0 +1,39 @@
"""Shared fixtures for the ``app.automations`` unit-test tree.
Provides registry isolation: the built-in ``schedule`` trigger and
``agent_task`` action self-register at import time. Tests that register
additional triggers/actions (or assert on the registry contents) must
not leak that state to other tests. These fixtures snapshot and restore
the module-level registry dicts.
"""
from __future__ import annotations
from collections.abc import Iterator
import pytest
from app.automations.actions import store as action_store
from app.automations.triggers import store as trigger_store
@pytest.fixture
def isolated_action_registry() -> Iterator[None]:
"""Snapshot and restore the action registry around a test."""
snapshot = dict(action_store._REGISTRY)
try:
yield
finally:
action_store._REGISTRY.clear()
action_store._REGISTRY.update(snapshot)
@pytest.fixture
def isolated_trigger_registry() -> Iterator[None]:
"""Snapshot and restore the trigger registry around a test."""
snapshot = dict(trigger_store._REGISTRY)
try:
yield
finally:
trigger_store._REGISTRY.clear()
trigger_store._REGISTRY.update(snapshot)

View file

@ -0,0 +1,28 @@
"""Lock the ``DispatchError`` exception contract.
``DispatchError`` is the uniform exception type the dispatch layer raises
for any "cannot turn this fire request into a run" condition. Other
modules (templates of error envelopes, run records) compare on
``isinstance(exc, DispatchError)``, so the inheritance is the contract.
"""
from __future__ import annotations
import pytest
from app.automations.dispatch.errors import DispatchError
pytestmark = pytest.mark.unit
def test_dispatch_error_is_exception_subclass_and_carries_message() -> None:
"""Lifting a string into ``DispatchError`` preserves the message and
behaves as a regular ``Exception`` for ``isinstance`` / ``raise`` /
``except`` consumers."""
error = DispatchError("missing trigger")
assert isinstance(error, Exception)
assert str(error) == "missing trigger"
with pytest.raises(DispatchError):
raise error

View file

@ -0,0 +1,74 @@
"""Lock the input-validation contract enforced before a run is enqueued.
``validate_inputs`` is the pure schema check that ``enqueue_run`` runs against
merged inputs. ``enqueue_run`` itself needs a real DB session, so tests target
this pure function directly; the contract not the symbol is what's locked.
"""
from __future__ import annotations
import pytest
from app.automations.dispatch.errors import DispatchError
from app.automations.dispatch.inputs import validate_inputs
from app.automations.schemas.definition.envelope import AutomationDefinition
from app.automations.schemas.definition.inputs import Inputs
from app.automations.schemas.definition.plan_step import PlanStep
pytestmark = pytest.mark.unit
def _minimal_definition(*, inputs: Inputs | None = None) -> AutomationDefinition:
"""One-step definition with an optional declared input schema."""
return AutomationDefinition(
name="test",
inputs=inputs,
plan=[PlanStep(step_id="s1", action="agent_task")],
)
def test_validate_inputs_passes_through_when_no_schema_is_declared() -> None:
"""When the definition declares no input schema, runtime inputs reach
the template context **unchanged**. Regression site: previously this
branch returned ``{}``, which stripped runtime keys like ``fired_at``
and ``last_fired_at`` and made Jinja blow up on ``{{ inputs.* }}``.
"""
definition = _minimal_definition(inputs=None)
runtime_inputs = {
"fired_at": "2026-01-01T00:00:00+00:00",
"last_fired_at": None,
"static_key": "value",
}
assert validate_inputs(definition, runtime_inputs) == runtime_inputs
def test_validate_inputs_returns_inputs_when_they_match_declared_schema() -> None:
"""With a declared JSON schema, inputs that satisfy it pass through
unchanged (validation succeeds; the function does not coerce or
strip extra fields not mentioned in the schema)."""
schema = {
"type": "object",
"properties": {"topic": {"type": "string"}},
"required": ["topic"],
}
definition = _minimal_definition(inputs=Inputs(schema=schema))
inputs = {"topic": "weekly report"}
assert validate_inputs(definition, inputs) == inputs
def test_validate_inputs_raises_dispatch_error_when_inputs_violate_schema() -> None:
"""Inputs that don't match the declared schema must surface as
``DispatchError`` (not the raw ``jsonschema.ValidationError``), so every
caller can handle one dispatch-domain exception type uniformly."""
schema = {
"type": "object",
"properties": {"topic": {"type": "string"}},
"required": ["topic"],
}
definition = _minimal_definition(inputs=Inputs(schema=schema))
with pytest.raises(DispatchError):
validate_inputs(definition, {"topic": 42}) # type violates string

View file

@ -0,0 +1,272 @@
"""Lock the ``execute_step`` orchestration contract.
Covers the pure step-execution logic: predicate gate, params rendering,
action lookup, retry budget, error shaping. The ``ActionContext.session``
is never touched by ``execute_step`` itself (it's only forwarded to the
handler), so unit tests pass ``None`` cast to the type.
"""
from __future__ import annotations
from typing import Any, cast
import pytest
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.actions.store import register_action
from app.automations.actions.types import ActionContext, ActionDefinition
from app.automations.runtime.step import execute_step
from app.automations.schemas.definition.plan_step import PlanStep
pytestmark = pytest.mark.unit
class _AnyParams(BaseModel):
"""Open params model used by test actions — they never validate."""
model_config = {"extra": "allow"}
def _action_context() -> ActionContext:
"""Minimal context: session is unused by ``execute_step``, only forwarded."""
return ActionContext(
session=cast(AsyncSession, None),
run_id=1,
step_id="s1",
search_space_id=1,
creator_user_id=None,
)
async def test_execute_step_runs_registered_action_handler_and_wraps_result(
isolated_action_registry: None,
) -> None:
"""A step pointing at a registered action runs its handler with the
step's params and returns a ``succeeded`` entry carrying the handler's
output plus ``attempts=1`` (one try, no retries triggered)."""
invocations: list[dict[str, Any]] = []
async def echo(params: dict[str, Any]) -> dict[str, Any]:
invocations.append(params)
return {"echoed": params["value"]}
register_action(
ActionDefinition(
type="test_echo",
name="Echo",
description="Test action.",
params_model=_AnyParams,
build_handler=lambda _ctx: echo,
)
)
step = PlanStep(step_id="s1", action="test_echo", params={"value": "hello"})
result = await execute_step(
step=step,
template_context={},
action_context=_action_context(),
default_max_retries=0,
default_retry_backoff="none",
default_timeout_seconds=30,
)
assert result["status"] == "succeeded"
assert result["step_id"] == "s1"
assert result["action"] == "test_echo"
assert result["attempts"] == 1
assert result["result"] == {"echoed": "hello"}
assert invocations == [{"value": "hello"}]
async def test_execute_step_skips_step_when_predicate_is_falsy(
isolated_action_registry: None,
) -> None:
"""If ``step.when`` evaluates to falsy in the template context, the
handler is **not** invoked, the result entry has ``status=skipped``
and ``attempts=0``, and no ``result`` key is present."""
invoked = False
async def must_not_run(_params: dict[str, Any]) -> dict[str, Any]:
nonlocal invoked
invoked = True
return {}
register_action(
ActionDefinition(
type="test_guarded",
name="Guarded",
description="Test action that should not run.",
params_model=_AnyParams,
build_handler=lambda _ctx: must_not_run,
)
)
step = PlanStep(
step_id="s1",
action="test_guarded",
when="inputs.enabled",
params={},
)
result = await execute_step(
step=step,
template_context={"inputs": {"enabled": False}},
action_context=_action_context(),
default_max_retries=0,
default_retry_backoff="none",
default_timeout_seconds=30,
)
assert result["status"] == "skipped"
assert result["attempts"] == 0
assert "result" not in result
assert invoked is False
async def test_execute_step_fails_when_step_references_an_unknown_action(
isolated_action_registry: None,
) -> None:
"""A step pointing at an action that isn't in the registry must fail
with ``ActionNotFound`` rather than crashing. Catches typos in the
plan and removed actions without the run going off the rails."""
step = PlanStep(step_id="s1", action="no_such_action", params={})
result = await execute_step(
step=step,
template_context={},
action_context=_action_context(),
default_max_retries=0,
default_retry_backoff="none",
default_timeout_seconds=30,
)
assert result["status"] == "failed"
assert result["attempts"] == 0
assert result["error"]["type"] == "ActionNotFound"
assert "no_such_action" in result["error"]["message"]
async def test_execute_step_retries_failing_handler_up_to_default_budget(
isolated_action_registry: None,
) -> None:
"""A handler that raises on every attempt consumes the retry budget
(1 initial try + ``default_max_retries`` retries) and the step ends
``failed`` with the exception's type and message surfaced through
the error envelope."""
calls = 0
async def always_fails(_params: dict[str, Any]) -> dict[str, Any]:
nonlocal calls
calls += 1
raise RuntimeError("boom")
register_action(
ActionDefinition(
type="test_fails",
name="Fails",
description="Always raises.",
params_model=_AnyParams,
build_handler=lambda _ctx: always_fails,
)
)
step = PlanStep(step_id="s1", action="test_fails", params={})
result = await execute_step(
step=step,
template_context={},
action_context=_action_context(),
default_max_retries=2,
default_retry_backoff="none",
default_timeout_seconds=30,
)
assert result["status"] == "failed"
assert result["attempts"] == 3
assert calls == 3
assert result["error"]["type"] == "RuntimeError"
assert "boom" in result["error"]["message"]
async def test_execute_step_succeeds_when_handler_recovers_within_retry_budget(
isolated_action_registry: None,
) -> None:
"""A handler that fails the first N times and then succeeds yields a
``succeeded`` entry with ``attempts == N + 1``. Locks that retries
can actually recover (not just exhaust)."""
calls = 0
async def flaky(_params: dict[str, Any]) -> dict[str, Any]:
nonlocal calls
calls += 1
if calls < 3:
raise RuntimeError("transient")
return {"ok": True}
register_action(
ActionDefinition(
type="test_flaky",
name="Flaky",
description="Fails twice, succeeds third time.",
params_model=_AnyParams,
build_handler=lambda _ctx: flaky,
)
)
step = PlanStep(step_id="s1", action="test_flaky", params={})
result = await execute_step(
step=step,
template_context={},
action_context=_action_context(),
default_max_retries=2,
default_retry_backoff="none",
default_timeout_seconds=30,
)
assert result["status"] == "succeeded"
assert result["attempts"] == 3
assert result["result"] == {"ok": True}
assert calls == 3
async def test_execute_step_renders_step_params_through_template_engine(
isolated_action_registry: None,
) -> None:
"""Step params are rendered against the template context before the
handler is invoked. String values containing Jinja expressions get
substituted from ``inputs`` and ``steps`` in the run context."""
received: list[dict[str, Any]] = []
async def capture(params: dict[str, Any]) -> dict[str, Any]:
received.append(params)
return {}
register_action(
ActionDefinition(
type="test_capture",
name="Capture",
description="Captures the params passed in.",
params_model=_AnyParams,
build_handler=lambda _ctx: capture,
)
)
step = PlanStep(
step_id="s1",
action="test_capture",
params={"message": "Hello {{ inputs.name }}"},
)
await execute_step(
step=step,
template_context={"inputs": {"name": "World"}, "steps": {}},
action_context=_action_context(),
default_max_retries=0,
default_retry_backoff="none",
default_timeout_seconds=30,
)
assert received == [{"message": "Hello World"}]

View file

@ -0,0 +1,59 @@
"""Lock that the executor propagates the captured model snapshot into the
``ActionContext``, so runs resolve their own model (insulated from chat /
search-space changes) and not the live search space.
"""
from __future__ import annotations
from types import SimpleNamespace
from typing import cast
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.runtime.executor import _build_action_ctx
from app.automations.schemas.definition.envelope import AutomationModels
from app.automations.schemas.definition.plan_step import PlanStep
pytestmark = pytest.mark.unit
def _run() -> SimpleNamespace:
return SimpleNamespace(
id=1,
automation=SimpleNamespace(search_space_id=42, created_by_user_id="u-1"),
)
def test_build_action_ctx_propagates_captured_models() -> None:
"""``definition.models`` flows onto the ActionContext model fields."""
models = AutomationModels(
agent_llm_id=-1,
image_generation_config_id=5,
vision_llm_config_id=-1,
)
ctx = _build_action_ctx(
cast(AsyncSession, None),
_run(),
PlanStep(step_id="s1", action="agent_task"),
models,
)
assert ctx.search_space_id == 42
assert ctx.agent_llm_id == -1
assert ctx.image_generation_config_id == 5
assert ctx.vision_llm_config_id == -1
def test_build_action_ctx_none_models_leaves_fields_none() -> None:
"""No captured snapshot → model fields are None (defensive fallback path)."""
ctx = _build_action_ctx(
cast(AsyncSession, None),
_run(),
PlanStep(step_id="s1", action="agent_task"),
None,
)
assert ctx.agent_llm_id is None
assert ctx.image_generation_config_id is None
assert ctx.vision_llm_config_id is None

View file

@ -0,0 +1,74 @@
"""Lock the ``with_retries`` policy: budget, recovery, exhaustion, timeout, backoff.
Tests with ``backoff="none"`` to keep wall-clock time zero. Backoff sleep
values themselves are observed by monkeypatching ``asyncio.sleep`` so we
don't introduce flakiness via real timing.
"""
from __future__ import annotations
import pytest
from app.automations.runtime.retries import with_retries
pytestmark = pytest.mark.unit
async def test_with_retries_returns_result_and_attempts_one_on_first_success() -> None:
"""A coroutine that succeeds on the first call returns its result
paired with ``attempts=1`` no retry consumed."""
calls = 0
async def succeed() -> str:
nonlocal calls
calls += 1
return "ok"
result, attempts = await with_retries(
succeed, max_retries=2, backoff="none", timeout=None
)
assert result == "ok"
assert attempts == 1
assert calls == 1
async def test_with_retries_returns_attempt_count_when_succeeding_after_failures() -> (
None
):
"""A coroutine that fails twice then succeeds returns ``attempts=3``
(the actual attempt that produced the result). Locks the contract
that the caller can distinguish first-try success from a recovery."""
calls = 0
async def flaky() -> str:
nonlocal calls
calls += 1
if calls < 3:
raise RuntimeError("transient")
return "ok"
result, attempts = await with_retries(
flaky, max_retries=5, backoff="none", timeout=None
)
assert result == "ok"
assert attempts == 3
assert calls == 3
async def test_with_retries_reraises_after_exhausting_the_budget() -> None:
"""When the coroutine raises on every attempt within
``1 + max_retries`` tries, the last exception propagates and the
handler is called exactly ``1 + max_retries`` times."""
calls = 0
async def always_fails() -> str:
nonlocal calls
calls += 1
raise RuntimeError(f"boom-{calls}")
with pytest.raises(RuntimeError, match="boom-3"):
await with_retries(always_fails, max_retries=2, backoff="none", timeout=None)
assert calls == 3 # 1 initial + 2 retries

View file

@ -0,0 +1,82 @@
"""Lock the request-side automation API schemas — the public validation gate."""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from app.automations.schemas.api.automation import AutomationCreate, AutomationUpdate
pytestmark = pytest.mark.unit
_VALID_DEFINITION = {
"name": "Test",
"plan": [{"step_id": "s1", "action": "agent_task"}],
}
def test_automation_create_accepts_valid_minimal_payload() -> None:
"""Happy path: just search_space_id, name, and a valid definition.
Triggers default to ``[]`` so users can attach them later."""
payload = AutomationCreate.model_validate(
{
"search_space_id": 1,
"name": "Daily digest",
"definition": _VALID_DEFINITION,
}
)
assert payload.name == "Daily digest"
assert payload.description is None
assert payload.triggers == []
def test_automation_create_cascades_validation_into_nested_definition() -> None:
"""A bad ``definition`` (e.g. empty plan) fails at the API boundary,
not at the DB layer. Locks the cascade so corrupt definitions can't
sneak through a misshapen wire payload."""
with pytest.raises(ValidationError):
AutomationCreate.model_validate(
{
"search_space_id": 1,
"name": "Bad",
"definition": {"name": "X", "plan": []}, # empty plan
}
)
def test_automation_create_rejects_unknown_top_level_field() -> None:
"""``extra='forbid'`` catches typos in API payloads at the boundary."""
with pytest.raises(ValidationError):
AutomationCreate.model_validate(
{
"search_space_id": 1,
"name": "X",
"definition": _VALID_DEFINITION,
"owner": "tg", # not allowed
}
)
def test_automation_create_rejects_empty_name() -> None:
"""Name is required and constrained to 1..200 chars."""
with pytest.raises(ValidationError):
AutomationCreate.model_validate(
{
"search_space_id": 1,
"name": "",
"definition": _VALID_DEFINITION,
}
)
def test_automation_update_accepts_partial_payload_with_no_fields() -> None:
"""All fields on ``AutomationUpdate`` are optional. An empty body is
a valid no-op update (the service layer decides what to do with it)."""
update = AutomationUpdate.model_validate({})
assert update.name is None
assert update.description is None
assert update.status is None
assert update.definition is None

View file

@ -0,0 +1,47 @@
"""Lock the request-side trigger API schemas."""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from app.automations.persistence.enums.trigger_type import TriggerType
from app.automations.schemas.api.trigger import TriggerCreate, TriggerUpdate
pytestmark = pytest.mark.unit
def test_trigger_create_uses_safe_defaults_for_optional_fields() -> None:
"""Defaults: empty ``params`` and ``static_inputs``, ``enabled=True``.
These let callers create a trigger with just ``type`` + the params
the trigger requires."""
trigger = TriggerCreate(type=TriggerType.SCHEDULE) # type: ignore[arg-type]
assert trigger.type is TriggerType.SCHEDULE
assert trigger.params == {}
assert trigger.static_inputs == {}
assert trigger.enabled is True
def test_trigger_create_rejects_unknown_trigger_type_string() -> None:
"""``type`` is a ``TriggerType`` enum, so any string outside the
enum's known values fails validation at the boundary."""
with pytest.raises(ValidationError):
TriggerCreate.model_validate({"type": "webhook"}) # not in TriggerType
def test_trigger_create_rejects_unknown_field() -> None:
"""``extra='forbid'`` catches typos in trigger payloads."""
with pytest.raises(ValidationError):
TriggerCreate.model_validate(
{"type": "schedule", "param": {}} # typo: param vs params
)
def test_trigger_update_accepts_partial_payload_with_no_fields() -> None:
"""``TriggerUpdate`` is fully optional — empty body is valid (no-op)."""
update = TriggerUpdate()
assert update.enabled is None
assert update.params is None
assert update.static_inputs is None

View file

@ -0,0 +1,90 @@
"""Lock the ``AutomationDefinition`` envelope contract."""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from app.automations.schemas.definition.envelope import (
AutomationDefinition,
AutomationModels,
)
from app.automations.schemas.definition.plan_step import PlanStep
pytestmark = pytest.mark.unit
def test_automation_definition_accepts_minimal_valid_input_with_sensible_defaults() -> (
None
):
"""A definition with just ``name`` + a one-step ``plan`` is valid and
fills in the rest with safe defaults so users don't have to write
out every section to get started."""
definition = AutomationDefinition(
name="Daily digest",
plan=[PlanStep(step_id="s1", action="agent_task")],
)
assert definition.name == "Daily digest"
assert definition.schema_version == "1.0"
assert definition.goal is None
assert definition.inputs is None
assert definition.triggers == []
# ``models`` is optional (populated server-side at create()).
assert definition.models is None
def test_automation_definition_models_round_trip() -> None:
"""The captured ``models`` snapshot survives a model_dump/validate round-trip."""
definition = AutomationDefinition(
name="Daily digest",
plan=[PlanStep(step_id="s1", action="agent_task")],
models=AutomationModels(
agent_llm_id=-1,
image_generation_config_id=5,
vision_llm_config_id=-1,
),
)
dumped = definition.model_dump(mode="json", by_alias=True)
assert dumped["models"] == {
"agent_llm_id": -1,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
}
restored = AutomationDefinition.model_validate(dumped)
assert restored.models is not None
assert restored.models.agent_llm_id == -1
assert restored.models.image_generation_config_id == 5
assert restored.models.vision_llm_config_id == -1
def test_automation_definition_rejects_unknown_top_level_field() -> None:
"""``extra='forbid'`` catches typos at validation time (e.g. ``pln``
instead of ``plan``) before the bad definition reaches storage."""
with pytest.raises(ValidationError):
AutomationDefinition.model_validate(
{
"name": "X",
"plan": [{"step_id": "s1", "action": "agent_task"}],
"extra_field": "unexpected",
}
)
def test_automation_definition_rejects_empty_plan() -> None:
"""An automation with no plan steps has nothing to execute and must
be rejected at validation time."""
with pytest.raises(ValidationError):
AutomationDefinition(name="X", plan=[])
def test_automation_definition_rejects_empty_name() -> None:
"""Name is required and must be non-empty so list views and audit
logs have something meaningful to display."""
with pytest.raises(ValidationError):
AutomationDefinition(
name="",
plan=[PlanStep(step_id="s1", action="agent_task")],
)

View file

@ -0,0 +1,49 @@
"""Lock the ``Execution`` defaults + literal-constraint contract.
These defaults control production behavior of every automation that
doesn't override them; the defaults *are* the contract.
"""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from app.automations.schemas.definition.execution import Execution
pytestmark = pytest.mark.unit
def test_execution_uses_production_defaults_when_no_overrides_provided() -> None:
"""The defaults shipped to prod: 10-minute wall clock, 2 retries
per step, exponential backoff, drop overlapping runs. Changing any
of these is a behavioral release-note change."""
execution = Execution()
assert execution.timeout_seconds == 600
assert execution.max_retries == 2
assert execution.retry_backoff == "exponential"
assert execution.concurrency == "drop_if_running"
assert execution.on_failure == []
def test_execution_rejects_unknown_retry_backoff_strategy() -> None:
"""``retry_backoff`` is constrained to a closed set — typos like
``"expontential"`` must fail validation, not silently coerce."""
with pytest.raises(ValidationError):
Execution(retry_backoff="expontential") # type: ignore[arg-type]
def test_execution_rejects_unknown_concurrency_strategy() -> None:
"""Same closed-set constraint on ``concurrency``."""
with pytest.raises(ValidationError):
Execution(concurrency="parallel") # type: ignore[arg-type]
def test_execution_rejects_invalid_numeric_bounds() -> None:
"""``timeout_seconds > 0`` and ``max_retries >= 0``. Zero or negative
values would produce nonsensical run behavior."""
with pytest.raises(ValidationError):
Execution(timeout_seconds=0)
with pytest.raises(ValidationError):
Execution(max_retries=-1)

View file

@ -0,0 +1,39 @@
"""Lock the ``Inputs`` JSON ``schema``-alias roundtrip.
The field is ``schema_`` in Python (``schema`` shadows a Pydantic builtin)
but is wire-named ``schema``. Locking the roundtrip means JSON definitions
authored anywhere (UI raw editor, NL drafter, CLI export) speak the same
wire shape.
"""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from app.automations.schemas.definition.inputs import Inputs
pytestmark = pytest.mark.unit
def test_inputs_parses_wire_field_named_schema_into_schema_attribute() -> None:
"""JSON payloads use ``schema`` (the convention). The model stores it
on the Python attribute ``schema_`` without shadowing the builtin."""
parsed = Inputs.model_validate({"schema": {"type": "object"}})
assert parsed.schema_ == {"type": "object"}
def test_inputs_serializes_schema_attribute_back_to_wire_field_named_schema() -> None:
"""Round-trip: serializing emits ``schema`` (alias), not ``schema_``.
Locks the consumer-visible JSON shape regardless of the Python name."""
inputs = Inputs(schema={"type": "object"}) # type: ignore[call-arg]
assert inputs.model_dump() == {"schema": {"type": "object"}}
def test_inputs_rejects_unknown_field() -> None:
"""``extra='forbid'`` catches typos like ``shema`` so bad definitions
don't silently lose their input declaration."""
with pytest.raises(ValidationError):
Inputs.model_validate({"schema": {}, "extra": "x"})

View file

@ -0,0 +1,37 @@
"""Lock the ``Metadata`` ``extra='allow'`` contract — the only schema
that does. Free-form annotations on definitions (e.g. ``owner``,
``project``, ``created_by_ai``) need to round-trip through the envelope
without being rejected.
"""
from __future__ import annotations
import pytest
from app.automations.schemas.definition.metadata import Metadata
pytestmark = pytest.mark.unit
def test_metadata_preserves_unknown_keys() -> None:
"""Unlike every other definition sub-schema, ``Metadata`` allows
extra keys and round-trips them that's its purpose."""
metadata = Metadata.model_validate(
{
"tags": ["weekly", "report"],
"owner": "tg",
"created_by_ai": True,
}
)
dumped = metadata.model_dump()
assert dumped["tags"] == ["weekly", "report"]
assert dumped["owner"] == "tg"
assert dumped["created_by_ai"] is True
def test_metadata_defaults_tags_to_empty_list() -> None:
"""No tags is the common case; the default is the empty list so
callers can append without a None check."""
assert Metadata().tags == []

View file

@ -0,0 +1,52 @@
"""Lock the ``PlanStep`` validation contract."""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from app.automations.schemas.definition.plan_step import PlanStep
pytestmark = pytest.mark.unit
def test_plan_step_accepts_minimal_input_with_safe_defaults() -> None:
"""A step with just ``step_id`` + ``action`` is valid. Defaults
(no when, empty params, no output_as override, no retry/timeout
override) let the run inherit automation-wide defaults."""
step = PlanStep(step_id="s1", action="agent_task")
assert step.step_id == "s1"
assert step.action == "agent_task"
assert step.when is None
assert step.params == {}
assert step.output_as is None
assert step.max_retries is None
assert step.timeout_seconds is None
def test_plan_step_rejects_empty_step_id_and_action() -> None:
"""``step_id`` and ``action`` are addressing primitives — empty
strings would silently break runtime lookups."""
with pytest.raises(ValidationError):
PlanStep(step_id="", action="agent_task")
with pytest.raises(ValidationError):
PlanStep(step_id="s1", action="")
def test_plan_step_rejects_negative_max_retries_and_non_positive_timeout() -> None:
"""Numeric constraints: ``max_retries >= 0`` and ``timeout_seconds > 0``.
Negative budgets or zero timeouts produce nonsensical run behavior."""
with pytest.raises(ValidationError):
PlanStep(step_id="s1", action="agent_task", max_retries=-1)
with pytest.raises(ValidationError):
PlanStep(step_id="s1", action="agent_task", timeout_seconds=0)
def test_plan_step_rejects_unknown_field() -> None:
"""``extra='forbid'`` catches typos like ``actoin`` (instead of
``action``) before the bad step reaches storage."""
with pytest.raises(ValidationError):
PlanStep.model_validate(
{"step_id": "s1", "action": "agent_task", "actoin": "agent_task"}
)

View file

@ -0,0 +1,33 @@
"""Lock the ``TriggerSpec`` validation contract — the entry shape used
inside an automation's ``triggers[]`` array.
"""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from app.automations.schemas.definition.trigger_spec import TriggerSpec
pytestmark = pytest.mark.unit
def test_trigger_spec_accepts_type_with_default_empty_params() -> None:
"""``type`` is required; ``params`` defaults to ``{}`` so triggers
that take no params don't need an explicit body."""
spec = TriggerSpec(type="schedule")
assert spec.type == "schedule"
assert spec.params == {}
def test_trigger_spec_rejects_empty_type() -> None:
"""``type`` is the registry lookup key — empty would silently miss."""
with pytest.raises(ValidationError):
TriggerSpec(type="")
def test_trigger_spec_rejects_unknown_field() -> None:
"""``extra='forbid'`` catches typos at definition-validation time."""
with pytest.raises(ValidationError):
TriggerSpec.model_validate({"type": "schedule", "paramz": {}})

View file

@ -0,0 +1,493 @@
"""Lock creation-time model-policy enforcement in ``AutomationService``.
Creation (REST + manual builder) rejects search spaces whose models aren't
billable for automations with HTTP 422, mirroring the runtime backstop. These
tests isolate the new ``_assert_models_billable`` / ``model_eligibility`` paths
without touching the DB commit.
"""
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
import pytest
from fastapi import HTTPException
import app.automations.services.automation as automation_mod
from app.automations.schemas.api import AutomationCreate, AutomationUpdate
from app.automations.schemas.definition.envelope import (
AutomationDefinition,
AutomationModels,
)
from app.automations.schemas.definition.plan_step import PlanStep
from app.automations.services.automation import AutomationService
from app.automations.services.model_policy import AutomationModelPolicyError
pytestmark = pytest.mark.unit
class _FakeSession:
def __init__(self, search_space: Any) -> None:
self._search_space = search_space
self.added: list[Any] = []
self.commits = 0
async def get(self, _model: Any, _pk: int) -> Any:
return self._search_space
def add(self, obj: Any) -> None:
self.added.append(obj)
async def commit(self) -> None:
self.commits += 1
def _service(search_space: Any) -> AutomationService:
return AutomationService(
session=_FakeSession(search_space), user=SimpleNamespace(id="u-1")
)
def _definition(**kwargs: Any) -> AutomationDefinition:
return AutomationDefinition(
name="A",
plan=[PlanStep(step_id="s1", action="agent_task")],
**kwargs,
)
async def test_assert_models_billable_raises_422_on_violation(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A blocked model maps the policy error to HTTP 422."""
def _raise(_ss):
raise AutomationModelPolicyError(
[{"kind": "llm", "config_id": 0, "reason": "Auto mode"}]
)
monkeypatch.setattr(automation_mod, "assert_automation_models_billable", _raise)
service = _service(SimpleNamespace(agent_llm_id=0))
with pytest.raises(HTTPException) as exc_info:
await service._assert_models_billable(1)
assert exc_info.value.status_code == 422
async def test_assert_models_billable_raises_404_when_missing(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A missing search space is a 404, not a policy error."""
monkeypatch.setattr(
automation_mod, "assert_automation_models_billable", lambda _ss: None
)
service = _service(None)
with pytest.raises(HTTPException) as exc_info:
await service._assert_models_billable(999)
assert exc_info.value.status_code == 404
async def test_assert_models_billable_returns_search_space_when_ok(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When the policy accepts, the loaded search space is returned for reuse."""
monkeypatch.setattr(
automation_mod, "assert_automation_models_billable", lambda _ss: None
)
search_space = SimpleNamespace(agent_llm_id=-1)
service = _service(search_space)
assert await service._assert_models_billable(1) is search_space
async def test_create_injects_captured_models_from_search_space(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""create() snapshots the search space's model prefs onto the definition."""
monkeypatch.setattr(
automation_mod, "assert_automation_models_billable", lambda _ss: None
)
async def _noop_authorize(self, *_a, **_k):
return None
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
async def _return_added(self, _aid):
return self.session.added[-1]
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
search_space = SimpleNamespace(
agent_llm_id=-1,
image_generation_config_id=5,
vision_llm_config_id=-1,
)
service = _service(search_space)
payload = AutomationCreate(
search_space_id=1,
name="A",
definition=_definition(),
)
automation = await service.create(payload)
assert automation.definition["models"] == {
"agent_llm_id": -1,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
}
async def test_create_treats_unset_prefs_as_auto_zero(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""``None`` search-space prefs are captured as ``0`` (Auto) ids."""
monkeypatch.setattr(
automation_mod, "assert_automation_models_billable", lambda _ss: None
)
async def _noop_authorize(self, *_a, **_k):
return None
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
async def _return_added(self, _aid):
return self.session.added[-1]
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
search_space = SimpleNamespace(
agent_llm_id=None,
image_generation_config_id=None,
vision_llm_config_id=None,
)
service = _service(search_space)
payload = AutomationCreate(search_space_id=1, name="A", definition=_definition())
automation = await service.create(payload)
assert automation.definition["models"] == {
"agent_llm_id": 0,
"image_generation_config_id": 0,
"vision_llm_config_id": 0,
}
async def test_create_honors_selected_models_when_provided(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When the payload carries ``definition.models`` they are validated + kept.
The search-space snapshot path is bypassed entirely (no
``assert_automation_models_billable`` call).
"""
def _fail_snapshot(_ss):
raise AssertionError("snapshot path should not run when models are provided")
monkeypatch.setattr(
automation_mod, "assert_automation_models_billable", _fail_snapshot
)
validated: dict[str, Any] = {}
def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
validated["ids"] = (
agent_llm_id,
image_generation_config_id,
vision_llm_config_id,
)
monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok)
async def _noop_authorize(self, *_a, **_k):
return None
async def _return_added(self, _aid):
return self.session.added[-1]
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
service = _service(SimpleNamespace(agent_llm_id=-99))
payload = AutomationCreate(
search_space_id=1,
name="A",
definition=_definition(
models=AutomationModels(
agent_llm_id=-1,
image_generation_config_id=7,
vision_llm_config_id=-2,
)
),
)
automation = await service.create(payload)
assert validated["ids"] == (-1, 7, -2)
assert automation.definition["models"] == {
"agent_llm_id": -1,
"image_generation_config_id": 7,
"vision_llm_config_id": -2,
}
async def test_create_rejects_unbillable_selected_models(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A non-billable explicit selection maps the policy error to HTTP 422."""
def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
raise AutomationModelPolicyError(
[{"kind": "llm", "config_id": -3, "reason": "free model"}]
)
monkeypatch.setattr(automation_mod, "assert_models_billable", _raise)
async def _noop_authorize(self, *_a, **_k):
return None
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
service = _service(SimpleNamespace(agent_llm_id=-3))
payload = AutomationCreate(
search_space_id=1,
name="A",
definition=_definition(
models=AutomationModels(
agent_llm_id=-3,
image_generation_config_id=7,
vision_llm_config_id=-2,
)
),
)
with pytest.raises(HTTPException) as exc_info:
await service.create(payload)
assert exc_info.value.status_code == 422
async def test_update_preserves_captured_models(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A definition edit carries over the previously captured ``models``."""
captured = {
"agent_llm_id": -1,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
}
existing = SimpleNamespace(
search_space_id=1,
definition={"name": "A", "plan": [], "models": captured},
version=3,
)
async def _noop_authorize(self, *_a, **_k):
return None
async def _return_existing(self, _aid):
return existing
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
monkeypatch.setattr(
AutomationService, "_get_with_triggers_or_raise", _return_existing
)
service = _service(SimpleNamespace())
# The incoming patch definition has no ``models`` (frontend strips it).
patch = AutomationUpdate(definition=_definition())
result = await service.update(7, patch)
assert result.definition["models"] == captured
assert result.version == 4
async def test_update_honors_changed_models_when_valid(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A definition edit with a *changed* models block validates + keeps it."""
existing = SimpleNamespace(
search_space_id=1,
definition={
"name": "A",
"plan": [],
"models": {
"agent_llm_id": -1,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
},
},
version=3,
)
validated: dict[str, Any] = {}
def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
validated["ids"] = (
agent_llm_id,
image_generation_config_id,
vision_llm_config_id,
)
monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok)
async def _noop_authorize(self, *_a, **_k):
return None
async def _return_existing(self, _aid):
return existing
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
monkeypatch.setattr(
AutomationService, "_get_with_triggers_or_raise", _return_existing
)
service = _service(SimpleNamespace())
patch = AutomationUpdate(
definition=_definition(
models=AutomationModels(
agent_llm_id=-2,
image_generation_config_id=9,
vision_llm_config_id=-2,
)
)
)
result = await service.update(7, patch)
assert validated["ids"] == (-2, 9, -2)
assert result.definition["models"] == {
"agent_llm_id": -2,
"image_generation_config_id": 9,
"vision_llm_config_id": -2,
}
assert result.version == 4
async def test_update_rejects_changed_unbillable_models(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A *changed* non-billable models block is rejected with HTTP 422."""
existing = SimpleNamespace(
search_space_id=1,
definition={
"name": "A",
"plan": [],
"models": {
"agent_llm_id": -1,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
},
},
version=3,
)
def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
raise AutomationModelPolicyError(
[{"kind": "llm", "config_id": -7, "reason": "free model"}]
)
monkeypatch.setattr(automation_mod, "assert_models_billable", _raise)
async def _noop_authorize(self, *_a, **_k):
return None
async def _return_existing(self, _aid):
return existing
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
monkeypatch.setattr(
AutomationService, "_get_with_triggers_or_raise", _return_existing
)
service = _service(SimpleNamespace())
patch = AutomationUpdate(
definition=_definition(
models=AutomationModels(
agent_llm_id=-7,
image_generation_config_id=5,
vision_llm_config_id=-1,
)
)
)
with pytest.raises(HTTPException) as exc_info:
await service.update(7, patch)
assert exc_info.value.status_code == 422
async def test_update_keeps_unchanged_models_without_revalidation(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""An unchanged models block is kept as-is and is NOT re-validated.
Lets users edit an automation whose captured model later drifted out of
premium without an unrelated edit tripping the policy check.
"""
captured = {
"agent_llm_id": -1,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
}
existing = SimpleNamespace(
search_space_id=1,
definition={"name": "A", "plan": [], "models": captured},
version=3,
)
def _fail(*_a, **_k):
raise AssertionError("unchanged models must not be re-validated")
monkeypatch.setattr(automation_mod, "assert_models_billable", _fail)
async def _noop_authorize(self, *_a, **_k):
return None
async def _return_existing(self, _aid):
return existing
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
monkeypatch.setattr(
AutomationService, "_get_with_triggers_or_raise", _return_existing
)
service = _service(SimpleNamespace())
patch = AutomationUpdate(
definition=_definition(models=AutomationModels(**captured))
)
result = await service.update(7, patch)
assert result.definition["models"] == captured
assert result.version == 4
async def test_model_eligibility_authorizes_and_returns_payload(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""``model_eligibility`` checks read access then returns the eligibility dict."""
authorized: dict[str, Any] = {}
async def _fake_check_permission(_session, _user, ss_id, permission, _msg):
authorized["search_space_id"] = ss_id
authorized["permission"] = permission
monkeypatch.setattr(automation_mod, "check_permission", _fake_check_permission)
monkeypatch.setattr(
automation_mod,
"get_automation_model_eligibility",
lambda _ss: {"allowed": False, "violations": [{"kind": "image"}]},
)
service = _service(SimpleNamespace(agent_llm_id=-2))
result = await service.model_eligibility(search_space_id=5)
assert result == {"allowed": False, "violations": [{"kind": "image"}]}
assert authorized["search_space_id"] == 5
assert authorized["permission"] == "automations:read"

View file

@ -0,0 +1,196 @@
"""Lock the automation model-billing policy.
Automations may only run on billable models: premium global configs
(``billing_tier == "premium"``) or user BYOK configs (positive id). Free
globals and Auto mode (id == 0 / None) are blocked. These tests pin that rule
across all three model slots (chat LLM, image, vision).
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
import app.automations.services.model_policy as model_policy
from app.automations.services.model_policy import (
AutomationModelPolicyError,
assert_automation_models_billable,
assert_models_billable,
get_automation_model_eligibility,
get_model_eligibility,
)
pytestmark = pytest.mark.unit
def _search_space(*, llm: int | None, image: int | None, vision: int | None):
"""Minimal stand-in for the ``SearchSpace`` ORM row the policy reads."""
return SimpleNamespace(
agent_llm_id=llm,
image_generation_config_id=image,
vision_llm_config_id=vision,
)
@pytest.fixture
def patched_globals(monkeypatch: pytest.MonkeyPatch):
"""Stub the global config sources the policy consults for negative ids.
Negative ids: -1 is premium, -2 is free, for each of llm/image/vision.
"""
llm_configs = {
-1: {"id": -1, "billing_tier": "premium"},
-2: {"id": -2, "billing_tier": "free"},
}
monkeypatch.setattr(
"app.agents.new_chat.llm_config.load_global_llm_config_by_id",
lambda cid: llm_configs.get(cid),
)
from app.config import config as app_config
monkeypatch.setattr(
app_config,
"GLOBAL_IMAGE_GEN_CONFIGS",
[
{"id": -1, "billing_tier": "premium"},
{"id": -2, "billing_tier": "free"},
],
raising=False,
)
monkeypatch.setattr(
app_config,
"GLOBAL_VISION_LLM_CONFIGS",
[
{"id": -1, "billing_tier": "premium"},
{"id": -2, "billing_tier": "free"},
],
raising=False,
)
return None
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None:
"""A positive config id is a user-owned BYOK model — always billable."""
allowed, reason = model_policy._classify(kind, 7)
assert allowed is True
assert reason == ""
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
@pytest.mark.parametrize("config_id", [0, None])
def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None:
"""Auto mode (id 0) and an unset slot (None) are blocked."""
allowed, reason = model_policy._classify(kind, config_id)
assert allowed is False
assert "Auto mode" in reason
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
def test_premium_global_is_allowed(kind: str, patched_globals) -> None:
"""A negative (global) id with premium billing tier is allowed."""
allowed, reason = model_policy._classify(kind, -1)
assert allowed is True
assert reason == ""
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
def test_free_global_is_blocked(kind: str, patched_globals) -> None:
"""A negative (global) id with a free billing tier is blocked."""
allowed, reason = model_policy._classify(kind, -2)
assert allowed is False
assert "free model" in reason
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
def test_unknown_global_id_is_blocked(kind: str, patched_globals) -> None:
"""A negative id that resolves to no config is treated as not premium."""
allowed, _ = model_policy._classify(kind, -999)
assert allowed is False
def test_eligibility_all_billable(patched_globals) -> None:
"""Premium LLM + BYOK image + premium vision → allowed, no violations."""
search_space = _search_space(llm=-1, image=5, vision=-1)
result = get_automation_model_eligibility(search_space)
assert result == {"allowed": True, "violations": []}
def test_eligibility_reports_each_violation(patched_globals) -> None:
"""A free LLM, Auto image, and free vision each produce a violation."""
search_space = _search_space(llm=-2, image=0, vision=-2)
result = get_automation_model_eligibility(search_space)
assert result["allowed"] is False
kinds = {v["kind"] for v in result["violations"]}
assert kinds == {"llm", "image", "vision"}
# config_id is echoed back for the UI / settings deep-link.
by_kind = {v["kind"]: v["config_id"] for v in result["violations"]}
assert by_kind == {"llm": -2, "image": 0, "vision": -2}
def test_assert_raises_with_violations(patched_globals) -> None:
"""``assert_automation_models_billable`` raises when any slot is blocked."""
search_space = _search_space(llm=0, image=5, vision=-1)
with pytest.raises(AutomationModelPolicyError) as exc_info:
assert_automation_models_billable(search_space)
assert len(exc_info.value.violations) == 1
assert exc_info.value.violations[0]["kind"] == "llm"
def test_assert_passes_when_all_billable(patched_globals) -> None:
"""No exception when every slot is premium or BYOK."""
search_space = _search_space(llm=3, image=-1, vision=4)
assert assert_automation_models_billable(search_space) is None
# --- ID-based core (used by the runtime backstop against captured snapshots) ---
def test_get_model_eligibility_all_billable(patched_globals) -> None:
"""Premium LLM + BYOK image + premium vision (explicit ids) → allowed."""
result = get_model_eligibility(
agent_llm_id=-1, image_generation_config_id=5, vision_llm_config_id=-1
)
assert result == {"allowed": True, "violations": []}
def test_get_model_eligibility_reports_each_violation(patched_globals) -> None:
"""Free LLM, Auto image, free vision (explicit ids) each produce a violation."""
result = get_model_eligibility(
agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2
)
assert result["allowed"] is False
by_kind = {v["kind"]: v["config_id"] for v in result["violations"]}
assert by_kind == {"llm": -2, "image": 0, "vision": -2}
def test_assert_models_billable_raises(patched_globals) -> None:
"""``assert_models_billable`` raises when any explicit id is blocked."""
with pytest.raises(AutomationModelPolicyError) as exc_info:
assert_models_billable(
agent_llm_id=0, image_generation_config_id=5, vision_llm_config_id=-1
)
assert len(exc_info.value.violations) == 1
assert exc_info.value.violations[0]["kind"] == "llm"
def test_assert_models_billable_passes(patched_globals) -> None:
"""No exception when every explicit id is premium or BYOK."""
assert (
assert_models_billable(
agent_llm_id=3, image_generation_config_id=-1, vision_llm_config_id=4
)
is None
)
def test_search_space_wrapper_delegates_to_core(patched_globals) -> None:
"""The search-space wrapper produces the same result as the ID core."""
search_space = _search_space(llm=-2, image=0, vision=-2)
assert get_automation_model_eligibility(search_space) == get_model_eligibility(
agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2
)

View file

@ -0,0 +1,53 @@
"""Lock the ``{run, inputs, steps}`` namespace exposed to every template."""
from __future__ import annotations
from datetime import UTC, datetime
from uuid import UUID
import pytest
from app.automations.templating.context import build_run_context
pytestmark = pytest.mark.unit
def test_build_run_context_exposes_run_inputs_and_steps_namespaces() -> None:
"""The namespace handed to templates groups run metadata under ``run``,
runtime + static inputs under ``inputs``, and step outputs (keyed by
``output_as`` / ``step_id``) under ``steps``. Locks the contract that
every plan template body relies on."""
creator = UUID("00000000-0000-0000-0000-000000000001")
started = datetime(2026, 5, 28, 14, 30, tzinfo=UTC)
ctx = build_run_context(
run_id=42,
automation_id=7,
automation_name="Weekly digest",
automation_version=3,
search_space_id=1,
creator_id=creator,
trigger_id=11,
trigger_type="schedule",
started_at=started,
attempt=2,
inputs={"topic": "weekly"},
step_outputs={"summarize": {"text": "ok"}},
)
assert ctx == {
"run": {
"id": 42,
"automation_id": 7,
"automation_name": "Weekly digest",
"automation_version": 3,
"search_space_id": 1,
"creator_id": creator,
"trigger_id": 11,
"trigger_type": "schedule",
"started_at": started,
"attempt": 2,
},
"inputs": {"topic": "weekly"},
"steps": {"summarize": {"text": "ok"}},
}

View file

@ -0,0 +1,53 @@
"""Lock the sandbox boundary: disallowed filters/tests reject, finalize coerces non-strings.
These behaviors live in ``environment.py`` but are observed through the
public ``render_template`` surface the same surface every step uses.
"""
from __future__ import annotations
from datetime import UTC, datetime
import pytest
from jinja2.exceptions import TemplateError
from app.automations.templating.render import render_template
pytestmark = pytest.mark.unit
def test_environment_rejects_filters_not_in_the_allowlist() -> None:
"""A template that pipes through a Jinja built-in **not** in the
allowlist (e.g. ``pprint``) must fail rather than rendering. Locks
the sandbox surface against accidental re-introduction of removed
filters."""
with pytest.raises(TemplateError):
render_template("{{ value | pprint }}", {"value": {"k": 1}})
def test_environment_finalizes_datetime_output_to_iso_string() -> None:
"""A datetime that lands directly at an output site is stringified
via ``isoformat()`` rather than producing ``str(datetime)`` (which
has a space separator). Locks the wire shape templates produce
when emitting ``inputs.fired_at`` and other datetime values."""
dt = datetime(2026, 5, 28, 14, 30, tzinfo=UTC)
assert (
render_template("{{ moment }}", {"moment": dt}) == "2026-05-28T14:30:00+00:00"
)
def test_environment_finalizes_none_output_to_empty_string() -> None:
"""A ``None`` at an output site becomes the empty string. Lets
templates write ``{{ inputs.last_fired_at }}`` unconditionally on
the first run without exploding on the null."""
assert render_template("{{ missing }}", {"missing": None}) == ""
def test_environment_finalizes_dict_output_to_json() -> None:
"""A dict at an output site is JSON-serialized. Same for lists.
Locks the wire shape so users embedding structured values into
prompts get deterministic, parseable output."""
rendered = render_template("{{ payload }}", {"payload": {"a": 1, "b": [2, 3]}})
assert rendered == '{"a": 1, "b": [2, 3]}'

View file

@ -0,0 +1,42 @@
"""Lock the custom Jinja filters: ``date`` and ``slugify``."""
from __future__ import annotations
from datetime import UTC, datetime
import pytest
from app.automations.templating.filters import filter_date, filter_slugify
pytestmark = pytest.mark.unit
def test_filter_slugify_produces_url_safe_slug_from_typical_title() -> None:
"""``filter_slugify`` lowercases, replaces non-alphanumerics with
hyphens, collapses repeats, and trims edge hyphens the standard
URL-slug contract users expect when piping titles into paths."""
assert filter_slugify("Hello, World! 2026") == "hello-world-2026"
def test_filter_date_formats_datetime_with_strftime_format() -> None:
"""``filter_date`` calls ``strftime`` on datetime-like values with the
provided format. Default format yields ISO date (YYYY-MM-DD)."""
dt = datetime(2026, 5, 28, 14, 30, tzinfo=UTC)
assert filter_date(dt) == "2026-05-28"
assert filter_date(dt, "%Y/%m/%d %H:%M") == "2026/05/28 14:30"
def test_filter_date_returns_empty_string_for_none() -> None:
"""``None`` (e.g., a never-fired ``last_fired_at``) renders as the
empty string rather than the literal ``"None"`` or raising. This is
what lets templates write ``{{ inputs.last_fired_at | date }}``
unconditionally on the first run."""
assert filter_date(None) == ""
def test_filter_date_passes_strings_through_unchanged() -> None:
"""Already-formatted ISO strings (the JSON-serialized shape of
runtime inputs like ``fired_at``) pass through unchanged so callers
don't have to special-case the type."""
assert filter_date("2026-05-28T14:30:00+00:00") == "2026-05-28T14:30:00+00:00"

View file

@ -0,0 +1,59 @@
"""Lock the public template-rendering surface: render, predicate, recursive."""
from __future__ import annotations
import pytest
from jinja2 import UndefinedError
from app.automations.templating.render import (
evaluate_predicate,
render_template,
render_value,
)
pytestmark = pytest.mark.unit
def test_render_template_substitutes_context_variables() -> None:
"""A template referencing a context variable produces the substituted
string. Most basic contract of the template engine."""
result = render_template("Hello {{ name }}!", {"name": "World"})
assert result == "Hello World!"
def test_render_template_raises_on_undefined_variable() -> None:
"""Referencing a variable that isn't in the context raises rather than
rendering the empty string. Locks the StrictUndefined safety net so
template typos surface as run failures instead of silent corruption."""
with pytest.raises(UndefinedError):
render_template("Hello {{ missing }}!", {})
def test_evaluate_predicate_returns_truthy_outcome_of_expression() -> None:
"""``evaluate_predicate`` compiles a Jinja **expression** (not template
body) and coerces the value to ``bool``. Drives ``step.when`` gating."""
assert evaluate_predicate("inputs.count > 0", {"inputs": {"count": 3}}) is True
assert evaluate_predicate("inputs.count > 0", {"inputs": {"count": 0}}) is False
def test_render_value_renders_strings_recursively_through_dicts_and_lists() -> None:
"""``render_value`` walks dicts and lists, renders string leaves through
the template engine, and leaves non-strings untouched. This is the
primitive ``execute_step`` uses to render step params at run time."""
context = {"inputs": {"name": "World"}, "topic": "weekly"}
rendered = render_value(
{
"greeting": "Hello {{ inputs.name }}",
"tags": ["{{ topic }}", "static"],
"config": {"retries": 3, "label": "{{ topic }}-{{ inputs.name }}"},
},
context,
)
assert rendered == {
"greeting": "Hello World",
"tags": ["weekly", "static"],
"config": {"retries": 3, "label": "weekly-World"},
}

View file

@ -0,0 +1,56 @@
"""Lock the ``params_schema`` derivation on action + trigger definitions.
Both definition dataclasses expose ``params_schema`` as the JSON Schema
of their ``params_model``. This is what the registry endpoints surface
to the UI as the "what shape do these params take?" contract.
"""
from __future__ import annotations
import pytest
from pydantic import BaseModel
from app.automations.actions.types import ActionDefinition
from app.automations.triggers.types import TriggerDefinition
pytestmark = pytest.mark.unit
class _Topic(BaseModel):
"""Model with one required string field — minimal schema fingerprint."""
topic: str
def test_action_definition_params_schema_reflects_params_model() -> None:
"""``ActionDefinition.params_schema`` returns a JSON Schema derived
from the Pydantic ``params_model`` required fields and types are
visible to clients consuming the registry endpoint."""
definition = ActionDefinition(
type="t",
name="N",
description="D",
params_model=_Topic,
build_handler=lambda _ctx: lambda _p: {}, # type: ignore[arg-type,return-value]
)
schema = definition.params_schema
assert schema["type"] == "object"
assert schema["properties"]["topic"]["type"] == "string"
assert "topic" in schema["required"]
def test_trigger_definition_params_schema_reflects_params_model() -> None:
"""Same JSON-Schema derivation contract on the trigger side."""
definition = TriggerDefinition(
type="t",
description="D",
params_model=_Topic,
)
schema = definition.params_schema
assert schema["type"] == "object"
assert schema["properties"]["topic"]["type"] == "string"
assert "topic" in schema["required"]

View file

@ -0,0 +1,37 @@
"""Lock the bundled import side-effects.
Importing ``app.automations`` (the package) registers the v1 bundled
action (``agent_task``) and the v1 bundled trigger (``schedule``). If the
import chain breaks (e.g. someone removes ``from . import definition``
in a sub-package ``__init__``), the system would silently launch with an
empty registry. These tests are the canary.
"""
from __future__ import annotations
import pytest
import app.automations # noqa: F401 (force the package import + its side-effects)
from app.automations.actions.store import get_action
from app.automations.persistence.enums.trigger_type import TriggerType
from app.automations.triggers.store import get_trigger
pytestmark = pytest.mark.unit
def test_bundled_agent_task_action_is_registered_after_package_import() -> None:
"""``agent_task`` — the v1 default action — must be discoverable in
the registry after the package is imported."""
definition = get_action("agent_task")
assert definition is not None
assert definition.type == "agent_task"
def test_bundled_schedule_trigger_is_registered_after_package_import() -> None:
"""``schedule`` — the only v1 trigger — must be discoverable in the
registry after the package is imported."""
definition = get_trigger(TriggerType.SCHEDULE.value)
assert definition is not None
assert definition.type == TriggerType.SCHEDULE.value

View file

@ -0,0 +1,45 @@
"""Lock the persistence enum string values + members.
These enums are mirrored by Postgres enum types, embedded in stored DB
rows, and surfaced in the JSON API. Renaming a value (or removing a
member) silently breaks production data and previously-issued API
responses, so the strings + the set of members are the contract.
"""
from __future__ import annotations
import pytest
from app.automations.persistence.enums.automation_status import AutomationStatus
from app.automations.persistence.enums.run_status import RunStatus
from app.automations.persistence.enums.trigger_type import TriggerType
pytestmark = pytest.mark.unit
def test_automation_status_string_values_are_stable() -> None:
"""The exact strings persisted to Postgres and served in API JSON."""
assert {member.value for member in AutomationStatus} == {
"active",
"paused",
"archived",
}
def test_run_status_string_values_are_stable() -> None:
"""Run lifecycle states embedded in the ``automation_runs`` table."""
assert {member.value for member in RunStatus} == {
"pending",
"running",
"succeeded",
"failed",
"cancelled",
"timed_out",
}
def test_trigger_type_keeps_manual_member_even_though_unregistered() -> None:
"""``schedule`` and ``event`` are registered; ``MANUAL`` is reserved
(mirrors the Postgres enum) but the trigger store does not register it.
The enum must keep every member so DB rows and migrations stay valid."""
assert {member.value for member in TriggerType} == {"schedule", "event", "manual"}

View file

@ -0,0 +1,117 @@
"""Lock the trigger + action registry contracts.
Both stores share the same API shape (register/get/all + duplicate-raise),
so they're tested together to keep the contract visible side-by-side.
"""
from __future__ import annotations
import pytest
from pydantic import BaseModel
from app.automations.actions.store import (
get_action,
register_action,
)
from app.automations.actions.types import ActionDefinition
from app.automations.triggers.store import (
all_triggers,
get_trigger,
register_trigger,
)
from app.automations.triggers.types import TriggerDefinition
pytestmark = pytest.mark.unit
class _Params(BaseModel):
"""Empty params model used by test-only registrations."""
def _trigger(type_: str = "test_trigger") -> TriggerDefinition:
return TriggerDefinition(
type=type_, description="Test trigger.", params_model=_Params
)
def _action(type_: str = "test_action") -> ActionDefinition:
return ActionDefinition(
type=type_,
name="Test",
description="Test action.",
params_model=_Params,
build_handler=lambda _ctx: lambda _p: {}, # type: ignore[arg-type,return-value]
)
def test_register_trigger_then_get_trigger_returns_the_same_definition(
isolated_trigger_registry: None,
) -> None:
"""The canonical round-trip: register, look up by type, get the same
definition back. Locks the basic registry contract."""
definition = _trigger()
register_trigger(definition)
assert get_trigger("test_trigger") is definition
def test_register_action_then_get_action_returns_the_same_definition(
isolated_action_registry: None,
) -> None:
"""Same round-trip contract for the action registry."""
definition = _action()
register_action(definition)
assert get_action("test_action") is definition
def test_get_trigger_returns_none_for_unknown_type(
isolated_trigger_registry: None,
) -> None:
"""An unknown type returns ``None`` (not raises). Lets callers like
the dispatcher branch on "is this trigger still registered?" without
try/except."""
assert get_trigger("never_registered") is None
def test_get_action_returns_none_for_unknown_type(
isolated_action_registry: None,
) -> None:
"""Same ``None``-not-raise contract on the action side."""
assert get_action("never_registered") is None
def test_register_trigger_rejects_duplicate_type(
isolated_trigger_registry: None,
) -> None:
"""Re-registering the same ``type`` raises rather than silently
overwriting. Locks the safety net against accidental double-import
(e.g., circular imports re-running the registration block)."""
register_trigger(_trigger())
with pytest.raises(ValueError, match="test_trigger"):
register_trigger(_trigger())
def test_register_action_rejects_duplicate_type(
isolated_action_registry: None,
) -> None:
"""Same duplicate-rejection contract on the action side."""
register_action(_action())
with pytest.raises(ValueError, match="test_action"):
register_action(_action())
def test_all_triggers_returns_defensive_snapshot(
isolated_trigger_registry: None,
) -> None:
"""``all_triggers()`` returns a copy: mutating the returned dict does
not corrupt the internal registry. Locks the snapshot contract that
UI/listing endpoints rely on."""
register_trigger(_trigger("snapshot_test"))
snapshot = all_triggers()
snapshot.pop("snapshot_test")
assert get_trigger("snapshot_test") is not None

View file

@ -0,0 +1,18 @@
"""The ``event`` trigger self-registers on the triggers store at import."""
from __future__ import annotations
import pytest
from app.automations.triggers import get_trigger
from app.automations.triggers.builtin.event.params import EventTriggerParams
pytestmark = pytest.mark.unit
def test_event_trigger_is_registered() -> None:
definition = get_trigger("event")
assert definition is not None
assert definition.type == "event"
assert definition.params_model is EventTriggerParams

View file

@ -0,0 +1,115 @@
"""Behavior tests for the ``matches`` filter grammar."""
from __future__ import annotations
import pytest
from app.automations.triggers.builtin.event.filter import FilterError, matches
pytestmark = pytest.mark.unit
def test_empty_filter_matches_any_payload() -> None:
assert matches({}, {"document_id": 42, "document_type": "FILE"}) is True
assert matches({}, {}) is True
def test_scalar_value_is_implicit_equality() -> None:
flt = {"document_type": "FILE"}
assert matches(flt, {"document_type": "FILE"}) is True
assert matches(flt, {"document_type": "WEBPAGE"}) is False
def test_multiple_fields_are_anded() -> None:
flt = {"document_type": "FILE", "search_space_id": 7}
assert matches(flt, {"document_type": "FILE", "search_space_id": 7}) is True
assert matches(flt, {"document_type": "FILE", "search_space_id": 9}) is False
def test_gt_operator_compares_greater_than() -> None:
flt = {"page_count": {"$gt": 10}}
assert matches(flt, {"page_count": 20}) is True
assert matches(flt, {"page_count": 10}) is False
assert matches(flt, {"page_count": 5}) is False
def test_remaining_comparison_operators() -> None:
assert matches({"n": {"$gte": 10}}, {"n": 10}) is True
assert matches({"n": {"$gte": 10}}, {"n": 9}) is False
assert matches({"n": {"$lt": 10}}, {"n": 9}) is True
assert matches({"n": {"$lt": 10}}, {"n": 10}) is False
assert matches({"n": {"$lte": 10}}, {"n": 10}) is True
assert matches({"n": {"$lte": 10}}, {"n": 11}) is False
assert matches({"s": {"$eq": "FILE"}}, {"s": "FILE"}) is True
assert matches({"s": {"$eq": "FILE"}}, {"s": "WEB"}) is False
assert matches({"s": {"$ne": "FILE"}}, {"s": "WEB"}) is True
assert matches({"s": {"$ne": "FILE"}}, {"s": "FILE"}) is False
def test_multiple_operators_on_one_field_are_anded() -> None:
flt = {"n": {"$gte": 10, "$lt": 20}}
assert matches(flt, {"n": 15}) is True
assert matches(flt, {"n": 10}) is True
assert matches(flt, {"n": 20}) is False
assert matches(flt, {"n": 5}) is False
def test_in_and_nin_membership_operators() -> None:
flt_in = {"document_type": {"$in": ["FILE", "WEBPAGE"]}}
assert matches(flt_in, {"document_type": "FILE"}) is True
assert matches(flt_in, {"document_type": "SLACK"}) is False
flt_nin = {"document_type": {"$nin": ["FILE", "WEBPAGE"]}}
assert matches(flt_nin, {"document_type": "SLACK"}) is True
assert matches(flt_nin, {"document_type": "FILE"}) is False
def test_or_matches_when_any_branch_holds() -> None:
flt = {"$or": [{"document_type": "FILE"}, {"document_type": "WEBPAGE"}]}
assert matches(flt, {"document_type": "WEBPAGE"}) is True
assert matches(flt, {"document_type": "SLACK"}) is False
def test_and_matches_when_every_branch_holds() -> None:
flt = {"$and": [{"n": {"$gt": 5}}, {"n": {"$lt": 10}}]}
assert matches(flt, {"n": 7}) is True
assert matches(flt, {"n": 12}) is False
def test_not_inverts_its_subexpression() -> None:
flt = {"$not": {"document_type": "FILE"}}
assert matches(flt, {"document_type": "WEBPAGE"}) is True
assert matches(flt, {"document_type": "FILE"}) is False
def test_missing_field_never_matches_and_never_raises() -> None:
# Conservative: an absent field fails the constraint, and comparisons must
# not raise on the missing value — including $ne (absence isn't "not equal").
assert matches({"document_type": "FILE"}, {}) is False
assert matches({"page_count": {"$gt": 5}}, {}) is False
assert matches({"document_type": {"$in": ["FILE"]}}, {}) is False
assert matches({"document_type": {"$ne": "FILE"}}, {}) is False
def test_logical_operators_compose_with_fields() -> None:
flt = {
"search_space_id": 7,
"$or": [{"document_type": "FILE"}, {"document_type": "WEBPAGE"}],
}
assert matches(flt, {"search_space_id": 7, "document_type": "FILE"}) is True
assert matches(flt, {"search_space_id": 9, "document_type": "FILE"}) is False
assert matches(flt, {"search_space_id": 7, "document_type": "SLACK"}) is False
def test_unknown_field_operator_raises_filter_error() -> None:
with pytest.raises(FilterError):
matches({"n": {"$regex": "x"}}, {"n": "xyz"})
def test_unknown_logical_operator_raises_filter_error() -> None:
with pytest.raises(FilterError):
matches({"$nor": [{"document_type": "FILE"}]}, {"document_type": "FILE"})

View file

@ -0,0 +1,26 @@
"""An event hands its payload + metadata to the run as inputs."""
from __future__ import annotations
import pytest
from app.automations.triggers.builtin.event.inputs import event_runtime_inputs
from app.event_bus import Event
pytestmark = pytest.mark.unit
def test_runtime_inputs_flatten_payload_with_event_metadata() -> None:
event = Event(
event_type="document.indexed",
payload={"document_id": 42, "document_type": "FILE"},
search_space_id=7,
)
inputs = event_runtime_inputs(event)
assert inputs["document_id"] == 42
assert inputs["document_type"] == "FILE"
assert inputs["event_type"] == "document.indexed"
assert inputs["event_id"] == event.event_id
assert inputs["occurred_at"] == event.occurred_at.isoformat()

View file

@ -0,0 +1,39 @@
"""Which triggers an event fires: event_type equality + filter match."""
from __future__ import annotations
import pytest
from app.automations.triggers.builtin.event.match import trigger_matches_event
from app.event_bus import Event
pytestmark = pytest.mark.unit
def _event(event_type: str = "document.indexed", **payload) -> Event:
return Event(event_type=event_type, payload=payload, search_space_id=7)
def test_matches_when_event_type_equal_and_filter_passes() -> None:
params = {"event_type": "document.indexed", "filter": {"document_type": "FILE"}}
assert trigger_matches_event(params, _event(document_type="FILE")) is True
def test_no_match_when_event_type_differs() -> None:
params = {"event_type": "document.indexed", "filter": {}}
assert trigger_matches_event(params, _event("podcast.generated")) is False
def test_no_match_when_filter_rejects_payload() -> None:
params = {"event_type": "document.indexed", "filter": {"document_type": "FILE"}}
assert trigger_matches_event(params, _event(document_type="WEBPAGE")) is False
def test_empty_filter_matches_any_payload_of_that_type() -> None:
params = {"event_type": "document.indexed", "filter": {}}
assert trigger_matches_event(params, _event(document_type="ANYTHING")) is True
def test_missing_filter_key_is_treated_as_empty() -> None:
params = {"event_type": "document.indexed"}
assert trigger_matches_event(params, _event(document_type="X")) is True

View file

@ -0,0 +1,40 @@
"""``EventTriggerParams`` contract: an event_type to listen for + an optional filter."""
from __future__ import annotations
import pytest
from app.automations.triggers.builtin.event.params import EventTriggerParams
pytestmark = pytest.mark.unit
def test_accepts_event_type_and_filter() -> None:
params = EventTriggerParams(
event_type="document.indexed",
filter={"document_type": "FILE"},
)
assert params.event_type == "document.indexed"
assert params.filter == {"document_type": "FILE"}
def test_filter_defaults_to_empty() -> None:
params = EventTriggerParams(event_type="document.indexed")
assert params.filter == {}
def test_event_type_is_required() -> None:
with pytest.raises(ValueError):
EventTriggerParams(filter={"x": 1})
def test_event_type_must_not_be_blank() -> None:
with pytest.raises(ValueError):
EventTriggerParams(event_type="")
def test_extra_keys_are_forbidden() -> None:
with pytest.raises(ValueError):
EventTriggerParams(event_type="document.indexed", typo=True)

View file

@ -0,0 +1,86 @@
"""Lock the cron + timezone + UTC normalization contract."""
from __future__ import annotations
from datetime import UTC, datetime
import pytest
from app.automations.triggers.builtin.schedule.cron import (
InvalidCronError,
compute_next_fire_at,
validate_cron,
)
pytestmark = pytest.mark.unit
def test_compute_next_fire_at_returns_next_match_normalized_to_utc() -> None:
"""``compute_next_fire_at`` evaluates the cron in the given IANA timezone
and returns the next strictly-later match expressed in UTC.
Setup: ``0 9 * * 1-5`` (09:00 Monday-Friday) in ``Africa/Kigali``
(UTC+2, no DST). With ``after`` = Tuesday 05:00 UTC (= 07:00 local),
the next fire is the same Tuesday at 09:00 local = 07:00 UTC.
"""
after = datetime(2026, 5, 26, 5, 0, tzinfo=UTC) # Tue 07:00 Kigali
next_fire = compute_next_fire_at("0 9 * * 1-5", "Africa/Kigali", after=after)
assert next_fire == datetime(2026, 5, 26, 7, 0, tzinfo=UTC)
def test_compute_next_fire_at_respects_dst_offset_change() -> None:
"""A daily cron in a DST-observing tz fires at the same local hour
across the DST boundary, which produces a different UTC offset on
either side of the transition.
Setup: ``0 9 * * *`` (09:00 every day) in ``America/New_York``.
NY is UTC-5 in winter (EST), UTC-4 in summer (EDT). Evaluating from
each side of the spring-forward in 2026 (Sun Mar 8 at 02:00 03:00):
- winter: ``after`` = 2026-02-15 (EST, UTC-5) next 09:00 EST = 14:00 UTC
- summer: ``after`` = 2026-04-15 (EDT, UTC-4) next 09:00 EDT = 13:00 UTC
"""
winter_after = datetime(2026, 2, 15, 0, 0, tzinfo=UTC)
summer_after = datetime(2026, 4, 15, 0, 0, tzinfo=UTC)
winter_fire = compute_next_fire_at(
"0 9 * * *", "America/New_York", after=winter_after
)
summer_fire = compute_next_fire_at(
"0 9 * * *", "America/New_York", after=summer_after
)
assert winter_fire == datetime(2026, 2, 15, 14, 0, tzinfo=UTC)
assert summer_fire == datetime(2026, 4, 15, 13, 0, tzinfo=UTC)
def test_compute_next_fire_at_is_strictly_after_when_after_equals_a_match() -> None:
"""When ``after`` lands exactly on a cron match, the result jumps to the
next match never the same instant. Required so the schedule-tick
can pass ``next_fire_at`` itself as ``after`` to advance to the
following slot without double-firing.
Setup: weekday 09:00 Kigali. ``after`` = Mon 09:00 Kigali = 07:00 UTC
(an exact match) next fire must be Tue 09:00 Kigali = next day 07:00 UTC.
"""
after = datetime(2026, 5, 25, 7, 0, tzinfo=UTC) # Mon 09:00 Kigali — exact match
next_fire = compute_next_fire_at("0 9 * * 1-5", "Africa/Kigali", after=after)
assert next_fire == datetime(2026, 5, 26, 7, 0, tzinfo=UTC) # Tue 09:00 Kigali
def test_validate_cron_rejects_malformed_cron_expression() -> None:
"""A syntactically invalid cron must be rejected at validation time so
bad triggers can't reach storage and explode at fire time."""
with pytest.raises(InvalidCronError):
validate_cron("this is not cron", "UTC")
def test_validate_cron_rejects_unknown_timezone() -> None:
"""A non-IANA timezone string must be rejected at validation time —
the same protective gate as the cron expression itself."""
with pytest.raises(InvalidCronError):
validate_cron("0 9 * * *", "Mars/Olympus_Mons")

View file

@ -0,0 +1,34 @@
"""Lock the ``ScheduleTriggerParams`` validation contract."""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from app.automations.triggers.builtin.schedule.params import ScheduleTriggerParams
pytestmark = pytest.mark.unit
def test_schedule_params_accept_valid_cron_and_iana_timezone() -> None:
"""A well-formed cron + IANA timezone yields a populated model.
Locks the round-trip path users go through when creating a trigger."""
params = ScheduleTriggerParams(cron="0 9 * * 1-5", timezone="Africa/Kigali")
assert params.cron == "0 9 * * 1-5"
assert params.timezone == "Africa/Kigali"
def test_schedule_params_reject_malformed_cron_with_validation_error() -> None:
"""``InvalidCronError`` from ``validate_cron`` must surface as
Pydantic ``ValidationError`` so the FastAPI layer returns 422 instead
of letting the bad value reach storage."""
with pytest.raises(ValidationError):
ScheduleTriggerParams(cron="not cron", timezone="UTC")
def test_schedule_params_reject_unknown_timezone_with_validation_error() -> None:
"""An unknown timezone is rejected at the API boundary — same gate
as the cron expression itself."""
with pytest.raises(ValidationError):
ScheduleTriggerParams(cron="0 9 * * *", timezone="Mars/Olympus_Mons")

View file

@ -0,0 +1,25 @@
"""Shared fixtures for the ``app.event_bus`` unit-test tree.
The event-type catalog is a module-level registry populated at import. Tests
that register their own event types (or assert on registry contents) snapshot
and restore it so state never leaks between tests.
"""
from __future__ import annotations
from collections.abc import Iterator
import pytest
from app.event_bus.catalog import catalog
@pytest.fixture
def isolated_event_catalog() -> Iterator[None]:
"""Snapshot and restore the event-type catalog around a test."""
snapshot = dict(catalog._registry)
try:
yield
finally:
catalog._registry.clear()
catalog._registry.update(snapshot)

View file

@ -0,0 +1,181 @@
"""``EventBus`` contract: subscribe, publish (stamp + fan out), dispatch.
Each test uses a fresh ``EventBus`` no shared global state.
"""
from __future__ import annotations
import pytest
from app.event_bus import Event, EventBus
pytestmark = pytest.mark.unit
def _event() -> Event:
return Event(event_type="x.happened", payload={"k": "v"}, search_space_id=1)
async def _noop(_event: Event) -> None:
return None
async def _other(_event: Event) -> None:
return None
# --- registry -------------------------------------------------------------
def test_subscribe_then_subscribers_returns_the_handler() -> None:
bus = EventBus()
bus.subscribe(_noop)
assert _noop in bus.subscribers()
def test_subscribe_is_idempotent_for_the_same_handler() -> None:
"""Registering the same handler twice must not make it fire twice."""
bus = EventBus()
bus.subscribe(_noop)
bus.subscribe(_noop)
assert bus.subscribers().count(_noop) == 1
def test_distinct_handlers_both_register() -> None:
bus = EventBus()
bus.subscribe(_noop)
bus.subscribe(_other)
registered = bus.subscribers()
assert _noop in registered
assert _other in registered
def test_subscribers_returns_a_defensive_snapshot() -> None:
"""Mutating the returned list must not corrupt the registry."""
bus = EventBus()
bus.subscribe(_noop)
snapshot = bus.subscribers()
snapshot.clear()
assert _noop in bus.subscribers()
def test_subscribe_returns_handler_so_it_can_be_used_as_a_decorator() -> None:
bus = EventBus()
returned = bus.subscribe(_other)
assert returned is _other
def test_two_buses_do_not_share_subscribers() -> None:
"""The registry is per-instance, not global."""
a = EventBus()
b = EventBus()
a.subscribe(_noop)
assert _noop in a.subscribers()
assert _noop not in b.subscribers()
# --- dispatch -------------------------------------------------------------
async def test_dispatch_delivers_event_to_every_subscriber() -> None:
bus = EventBus()
seen: list[tuple[str, Event]] = []
async def first(event: Event) -> None:
seen.append(("first", event))
async def second(event: Event) -> None:
seen.append(("second", event))
bus.subscribe(first)
bus.subscribe(second)
event = _event()
await bus.dispatch(event)
assert ("first", event) in seen
assert ("second", event) in seen
async def test_dispatch_isolates_a_failing_subscriber() -> None:
"""A subscriber that raises must not stop a healthy one from running."""
bus = EventBus()
healthy_ran = False
async def boom(_event: Event) -> None:
raise RuntimeError("subscriber blew up")
async def healthy(_event: Event) -> None:
nonlocal healthy_ran
healthy_ran = True
bus.subscribe(boom)
bus.subscribe(healthy)
await bus.dispatch(_event())
assert healthy_ran is True
async def test_dispatch_never_propagates_subscriber_errors() -> None:
"""``dispatch`` itself must not raise even if every subscriber fails."""
bus = EventBus()
async def boom(_event: Event) -> None:
raise ValueError("nope")
bus.subscribe(boom)
await bus.dispatch(_event()) # must not raise
async def test_dispatch_with_no_subscribers_is_a_noop() -> None:
bus = EventBus()
await bus.dispatch(_event()) # must not raise
# --- publish --------------------------------------------------------------
async def test_publish_builds_a_stamped_event_and_fans_it_out() -> None:
bus = EventBus()
received: list[Event] = []
async def handler(event: Event) -> None:
received.append(event)
bus.subscribe(handler)
await bus.publish("document.indexed", {"document_id": 42}, search_space_id=7)
assert len(received) == 1
event = received[0]
assert event.event_type == "document.indexed"
assert event.payload == {"document_id": 42}
assert event.search_space_id == 7
# Engine-stamped identity/time on the way through.
assert event.event_id
assert event.occurred_at
async def test_publish_defaults_payload_to_empty_dict() -> None:
bus = EventBus()
received: list[Event] = []
async def handler(event: Event) -> None:
received.append(event)
bus.subscribe(handler)
await bus.publish("x.happened", search_space_id=1)
assert received[0].payload == {}
async def test_publish_with_no_subscribers_is_a_noop() -> None:
await EventBus().publish("x.happened", search_space_id=1) # must not raise

View file

@ -0,0 +1,77 @@
"""EventCatalog contract: register, look up, snapshot, derive schema."""
from __future__ import annotations
import pytest
from pydantic import BaseModel
from app.event_bus.catalog import EventCatalog, EventType
pytestmark = pytest.mark.unit
class _SamplePayload(BaseModel):
document_id: int
def _event_type(type_: str = "test.thing") -> EventType:
return EventType(
type=type_,
description="A thing happened.",
payload_model=_SamplePayload,
)
def test_register_then_get_returns_the_event_type(isolated_event_catalog: None) -> None:
from app.event_bus.catalog import catalog
catalog.register(_event_type())
assert catalog.get("test.thing") is not None
assert catalog.get("test.thing").type == "test.thing"
def test_get_unknown_type_returns_none(isolated_event_catalog: None) -> None:
from app.event_bus.catalog import catalog
assert catalog.get("does.not.exist") is None
def test_register_duplicate_type_raises(isolated_event_catalog: None) -> None:
"""A type is a contract; registering it twice is a bug, not an override."""
from app.event_bus.catalog import catalog
catalog.register(_event_type())
with pytest.raises(ValueError, match="already registered"):
catalog.register(_event_type())
def test_all_is_a_defensive_snapshot(isolated_event_catalog: None) -> None:
"""Mutating the returned dict must not corrupt the registry."""
from app.event_bus.catalog import catalog
catalog.register(_event_type())
snapshot = catalog.all()
snapshot.clear()
assert catalog.get("test.thing") is not None
def test_payload_schema_is_derived_from_the_payload_model() -> None:
"""The JSON Schema a UI/validator consumes comes from the payload model."""
event_type = _event_type()
assert event_type.payload_schema == _SamplePayload.model_json_schema()
def test_each_catalog_instance_has_its_own_registry() -> None:
"""Two EventCatalog instances are fully independent."""
a = EventCatalog()
b = EventCatalog()
a.register(_event_type())
assert a.get("test.thing") is not None
assert b.get("test.thing") is None

View file

@ -0,0 +1,56 @@
"""``document.entered_folder`` payload contract + catalog registration."""
from __future__ import annotations
import pytest
from app.event_bus.catalog import catalog
from app.event_bus.events.document_entered_folder import (
EVENT_TYPE,
DocumentEnteredFolderPayload,
)
pytestmark = pytest.mark.unit
def _payload(**overrides: object) -> DocumentEnteredFolderPayload:
base: dict[str, object] = {
"document_id": 42,
"folder_id": 7,
"document_type": "FILE",
"title": "Q3 report.pdf",
}
base.update(overrides)
return DocumentEnteredFolderPayload(**base)
def test_payload_carries_the_filterable_fields() -> None:
payload = _payload(connector_id=12, created_by_id="abc")
assert payload.document_id == 42
assert payload.folder_id == 7
assert payload.document_type == "FILE"
assert payload.connector_id == 12
def test_first_placement_is_not_a_move() -> None:
"""No previous folder (created or AI-sorted into place) → not a move."""
assert _payload(previous_folder_id=None).is_move is False
def test_change_between_folders_is_a_move() -> None:
assert _payload(previous_folder_id=3).is_move is True
def test_is_move_is_serialized_for_filtering() -> None:
"""Filters match against the dumped payload, so ``is_move`` must appear there."""
dumped = _payload(previous_folder_id=3).model_dump()
assert dumped["is_move"] is True
def test_event_type_is_registered_in_the_catalog() -> None:
registered = catalog.get(EVENT_TYPE)
assert registered is not None
assert registered.payload_model is DocumentEnteredFolderPayload

View file

@ -0,0 +1,58 @@
"""payload_if_entered_folder: decides whether a document commit warrants an event."""
from __future__ import annotations
from typing import Any
import pytest
from app.event_bus.events.document_entered_folder import payload_if_entered_folder
pytestmark = pytest.mark.unit
def _call(**overrides: Any) -> dict[str, Any] | None:
defaults: dict[str, Any] = {
"document_id": 1,
"search_space_id": 10,
"new_folder_id": 7,
"previous_folder_id": None,
"folder_id_changed": True,
"status_state": "ready",
"document_type": "FILE",
"title": "report.pdf",
"connector_id": None,
"created_by_id": None,
}
defaults.update(overrides)
return payload_if_entered_folder(**defaults)
def test_folder_set_ready_fires() -> None:
result = _call()
assert result is not None
assert result["event_type"] == "document.entered_folder"
assert result["search_space_id"] == 10
assert result["payload"]["folder_id"] == 7
assert result["payload"]["previous_folder_id"] is None
def test_no_folder_is_silent() -> None:
assert _call(new_folder_id=None) is None
def test_not_ready_is_silent() -> None:
assert _call(status_state="processing") is None
def test_folder_unchanged_is_silent() -> None:
assert _call(folder_id_changed=False) is None
def test_move_carries_previous_folder_id() -> None:
result = _call(previous_folder_id=3, new_folder_id=7)
assert result is not None
assert result["payload"]["previous_folder_id"] == 3
assert result["payload"]["folder_id"] == 7

View file

@ -0,0 +1,53 @@
"""``Event`` contract: carry caller facts + engine-stamped id/time, round-trip JSON."""
from __future__ import annotations
from datetime import datetime
import pytest
from app.event_bus.event import Event
pytestmark = pytest.mark.unit
def test_event_carries_caller_supplied_facts() -> None:
"""The three caller inputs are stored verbatim."""
event = Event(
event_type="document.indexed",
payload={"document_id": 42, "content_type": "pdf"},
search_space_id=7,
)
assert event.event_type == "document.indexed"
assert event.payload == {"document_id": 42, "content_type": "pdf"}
assert event.search_space_id == 7
def test_event_stamps_identity_and_time_when_not_supplied() -> None:
"""Engine stamps id + time so subscribers can dedup/order."""
event = Event(event_type="x.happened", payload={}, search_space_id=1)
assert event.event_id
assert isinstance(event.occurred_at, datetime)
def test_event_ids_are_unique_per_instance() -> None:
"""Two events published with identical content are still distinct facts."""
first = Event(event_type="x.happened", payload={}, search_space_id=1)
second = Event(event_type="x.happened", payload={}, search_space_id=1)
assert first.event_id != second.event_id
def test_event_survives_json_round_trip() -> None:
"""Serialize → deserialize reproduces the event (subscribers queue it as JSON)."""
original = Event(
event_type="podcast.generated",
payload={"podcast_id": 9, "duration_s": 123.5},
search_space_id=3,
)
restored = Event.model_validate_json(original.model_dump_json())
assert restored == original

View file

@ -0,0 +1,557 @@
"""Parity gate for the parallel refactor of ``stream_new_chat.py``.
The new tree under ``app.tasks.chat.streaming.flows`` is built side-by-side with
the legacy monolithic ``app.tasks.chat.stream_new_chat`` so we can cut over
atomically. This file pins externally-observable behaviour at module
boundaries so a divergence between the two trees fails loudly *before* the
cutover.
What we verify:
1. **Signature parity** ``stream_new_chat`` / ``stream_resume_chat`` from
the new tree have the same call signature as the originals.
2. **Helper extraction parity** the SRP modules in ``flows/`` produce the
same outputs as the inline code in the legacy file for representative
inputs (initial thinking step, image-capability gate, runtime context,
SSE frame sequences, token-usage frame shape, persistence guards).
3. **Wrapper delegation** wrappers like ``load_llm_bundle`` /
``can_recover_provider_rate_limit`` exist and are addressable.
Delete this file along with ``stream_new_chat.py`` once the cutover is done
(see the parent refactor plan).
"""
from __future__ import annotations
import asyncio
import inspect
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from app.agents.new_chat.context import SurfSenseContextSchema
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.stream_new_chat import (
stream_new_chat as old_stream_new_chat,
stream_resume_chat as old_stream_resume_chat,
)
from app.tasks.chat.streaming.flows import (
stream_new_chat as new_stream_new_chat,
stream_resume_chat as new_stream_resume_chat,
)
from app.tasks.chat.streaming.flows.new_chat.initial_thinking_step import (
build_initial_thinking_step,
)
from app.tasks.chat.streaming.flows.new_chat.llm_capability import (
check_image_input_capability,
)
from app.tasks.chat.streaming.flows.new_chat.persistence_spawn import (
await_persist_task,
spawn_persist_assistant_shell_task,
spawn_persist_user_task,
spawn_set_ai_responding_bg,
)
from app.tasks.chat.streaming.flows.new_chat.runtime_context import (
build_new_chat_runtime_context,
)
from app.tasks.chat.streaming.flows.resume_chat.runtime_context import (
build_resume_chat_runtime_context,
)
from app.tasks.chat.streaming.flows.shared.finalize_emit import iter_token_usage_frame
from app.tasks.chat.streaming.flows.shared.first_frames import (
iter_final_frames,
iter_initial_frames,
)
from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle
from app.tasks.chat.streaming.flows.shared.premium_quota import (
PremiumReservation,
needs_premium_quota,
)
from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import (
can_recover_provider_rate_limit,
)
pytestmark = pytest.mark.unit
# --------------------------------------------------------------------- signature
def _normalize_annotation(ann: Any) -> str:
"""Compare-friendly form for an annotation.
The legacy ``stream_new_chat.py`` does NOT use ``from __future__ import
annotations``, so its annotations are evaluated at import time and come
back as type objects / typing generics. The new tree DOES use it, so its
annotations are PEP-563 strings.
Both reprs describe the same types strip the module prefixes / typing
namespace + the ``<class 'X'>`` wrapper so we compare the canonical
declared form.
"""
if ann is inspect.Signature.empty:
return ""
raw = ann if isinstance(ann, str) else repr(ann)
cleaned = (
raw.replace("typing.", "")
.replace("collections.abc.", "")
.replace("app.db.", "")
.replace("app.agents.new_chat.filesystem_selection.", "")
.replace("app.agents.new_chat.context.", "")
)
# Unwrap ``<class 'int'>`` → ``int`` (legacy-side type objects).
if cleaned.startswith("<class '") and cleaned.endswith("'>"):
cleaned = cleaned[len("<class '") : -len("'>")]
return cleaned
def _normalize_sig(sig: inspect.Signature) -> list[tuple[str, Any, str]]:
return [
(p.name, p.default, _normalize_annotation(p.annotation))
for p in sig.parameters.values()
]
def test_stream_new_chat_signature_matches_legacy() -> None:
old = inspect.signature(old_stream_new_chat)
new = inspect.signature(new_stream_new_chat)
assert _normalize_sig(new) == _normalize_sig(old)
assert _normalize_annotation(new.return_annotation) == _normalize_annotation(
old.return_annotation
)
def test_stream_resume_chat_signature_matches_legacy() -> None:
old = inspect.signature(old_stream_resume_chat)
new = inspect.signature(new_stream_resume_chat)
assert _normalize_sig(new) == _normalize_sig(old)
assert _normalize_annotation(new.return_annotation) == _normalize_annotation(
old.return_annotation
)
def test_orchestrators_are_async_generator_functions() -> None:
assert inspect.isasyncgenfunction(new_stream_new_chat)
assert inspect.isasyncgenfunction(new_stream_resume_chat)
# ------------------------------------------------------------ initial thinking
@pytest.mark.parametrize(
"user_query, image_urls, expected_title, expected_action",
[
("hello world", None, "Understanding your request", "Processing"),
(
"",
["data:image/png;base64,AAA"],
"Understanding your request",
"Processing",
),
("", None, "Understanding your request", "Processing"),
],
)
def test_initial_thinking_step_branches(
user_query: str,
image_urls: list[str] | None,
expected_title: str,
expected_action: str,
) -> None:
step = build_initial_thinking_step(
user_query=user_query,
user_image_data_urls=image_urls,
)
assert step.step_id == "thinking-1"
assert step.title == expected_title
assert len(step.items) == 1
assert step.items[0].startswith(f"{expected_action}: ")
def test_initial_thinking_step_truncates_long_query() -> None:
long_query = "x" * 200
step = build_initial_thinking_step(
user_query=long_query,
user_image_data_urls=None,
)
# 80-char truncation + ellipsis, sandwiched after "Processing: ".
assert "..." in step.items[0]
item = step.items[0]
payload = item[len("Processing: ") :]
assert payload.startswith("x" * 80) and payload.endswith("...")
# ------------------------------------------------------------ capability gate
def test_image_capability_passes_without_images() -> None:
assert (
check_image_input_capability(user_image_data_urls=None, agent_config=None)
is None
)
def test_image_capability_passes_when_capability_unknown() -> None:
"""Unknown / unmapped models are not blocked — only models LiteLLM has
*explicitly* marked text-only trip the gate."""
class _AgentConfig:
provider = "openrouter"
model_name = "unknown-mystery-model"
custom_provider = None
config_name = "Unknown"
litellm_params: dict[str, Any] = {}
with patch(
"app.services.provider_capabilities.is_known_text_only_chat_model",
return_value=False,
):
assert (
check_image_input_capability(
user_image_data_urls=["data:image/png;base64,AAA"],
agent_config=_AgentConfig(), # type: ignore[arg-type]
)
is None
)
def test_image_capability_blocks_known_text_only_models() -> None:
class _AgentConfig:
provider = "openai"
model_name = "gpt-3.5-turbo"
custom_provider = None
config_name = "GPT-3.5"
litellm_params: dict[str, Any] = {"base_model": "gpt-3.5-turbo"}
with patch(
"app.services.provider_capabilities.is_known_text_only_chat_model",
return_value=True,
):
result = check_image_input_capability(
user_image_data_urls=["data:image/png;base64,AAA"],
agent_config=_AgentConfig(), # type: ignore[arg-type]
)
assert result is not None
message, error_code = result
assert error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
assert "GPT-3.5" in message
# ---------------------------------------------------------------- runtime ctx
def test_new_chat_runtime_context_prefers_accepted_folder_ids() -> None:
ctx = build_new_chat_runtime_context(
search_space_id=7,
mentioned_document_ids=[1, 2],
accepted_folder_ids=[10],
mentioned_folder_ids=[20, 30],
request_id="req",
turn_id="t1",
)
assert isinstance(ctx, SurfSenseContextSchema)
assert ctx.search_space_id == 7
assert list(ctx.mentioned_document_ids) == [1, 2]
assert list(ctx.mentioned_folder_ids) == [10]
assert ctx.request_id == "req"
assert ctx.turn_id == "t1"
def test_new_chat_runtime_context_falls_back_to_mentioned_folder_ids() -> None:
ctx = build_new_chat_runtime_context(
search_space_id=7,
mentioned_document_ids=None,
accepted_folder_ids=[],
mentioned_folder_ids=[20, 30],
request_id=None,
turn_id="t2",
)
assert list(ctx.mentioned_folder_ids) == [20, 30]
def test_resume_chat_runtime_context_empty_mention_lists() -> None:
ctx = build_resume_chat_runtime_context(
search_space_id=42, request_id="req-r", turn_id="t-r"
)
assert ctx.search_space_id == 42
assert ctx.request_id == "req-r"
assert ctx.turn_id == "t-r"
# ---------------------------------------------------------------- SSE frames
def test_iter_initial_frames_emits_canonical_sequence() -> None:
svc = VercelStreamingService()
frames = list(iter_initial_frames(svc, turn_id="42:1700000000000"))
# Exactly 4 frames: message_start, start_step, turn-info (turn_id), turn-status (busy).
assert len(frames) == 4
assert "42:1700000000000" in frames[2]
assert '"status":"busy"' in frames[3] or '"status": "busy"' in frames[3]
def test_iter_final_frames_emits_idle_then_finish_done() -> None:
svc = VercelStreamingService()
frames = list(iter_final_frames(svc))
assert len(frames) == 4
assert '"status":"idle"' in frames[0] or '"status": "idle"' in frames[0]
# ----------------------------------------------------------- token usage frame
class _FakeAccumulator:
"""Minimal stand-in covering only the fields ``iter_token_usage_frame`` reads."""
def __init__(self, summary: Any = None) -> None:
self._summary = summary
self.calls = [1, 2, 3]
self.grand_total = 100
self.total_cost_micros = 50_000
self.total_prompt_tokens = 60
self.total_completion_tokens = 40
def per_message_summary(self) -> Any:
return self._summary
def serialized_calls(self) -> list[Any]:
return list(self.calls)
def test_token_usage_frame_skipped_when_no_summary() -> None:
svc = VercelStreamingService()
frames = list(
iter_token_usage_frame(
svc,
accumulator=_FakeAccumulator(summary=None), # type: ignore[arg-type]
log_label="parity-empty",
)
)
assert frames == []
def test_token_usage_frame_emitted_when_summary_present() -> None:
svc = VercelStreamingService()
frames = list(
iter_token_usage_frame(
svc,
accumulator=_FakeAccumulator(summary=[{"m": "x", "t": 100}]), # type: ignore[arg-type]
log_label="parity-populated",
)
)
assert len(frames) == 1
# Field shape on the wire is fixed by the FE; assert each surfaces.
payload = frames[0]
for key in (
'"prompt_tokens":60',
'"completion_tokens":40',
'"total_tokens":100',
'"cost_micros":50000',
):
assert key in payload.replace(" ", "")
# ------------------------------------------------------------------ llm_bundle
def test_load_llm_bundle_routes_negative_id_to_yaml_loader() -> None:
async def _run() -> tuple[Any, Any, str | None]:
with (
patch(
"app.tasks.chat.streaming.flows.shared.llm_bundle.load_global_llm_config_by_id",
return_value=None,
),
):
return await load_llm_bundle(
session=AsyncMock(), # type: ignore[arg-type]
config_id=-1,
search_space_id=7,
)
llm, agent_config, error = asyncio.run(_run())
assert llm is None
assert agent_config is None
assert error is not None and "id -1" in error
def test_load_llm_bundle_routes_nonnegative_id_to_db_loader() -> None:
async def _run() -> tuple[Any, Any, str | None]:
with (
patch(
"app.tasks.chat.streaming.flows.shared.llm_bundle.load_agent_config",
new=AsyncMock(return_value=None),
),
):
return await load_llm_bundle(
session=AsyncMock(), # type: ignore[arg-type]
config_id=12,
search_space_id=7,
)
llm, agent_config, error = asyncio.run(_run())
assert llm is None
assert agent_config is None
assert error is not None and "id 12" in error
# ----------------------------------------------------------------- premium quota
def test_needs_premium_quota_requires_user_and_premium_flag() -> None:
class _AgentConfig:
is_premium = True
class _NonPremium:
is_premium = False
assert needs_premium_quota(_AgentConfig(), "user-1") is True # type: ignore[arg-type]
assert needs_premium_quota(_AgentConfig(), None) is False # type: ignore[arg-type]
assert needs_premium_quota(_NonPremium(), "user-1") is False # type: ignore[arg-type]
assert needs_premium_quota(None, "user-1") is False
def test_premium_reservation_dataclass_shape() -> None:
# Sanity: the dataclass exists and carries the fields the orchestrator uses.
r = PremiumReservation(request_id="abc", reserved_micros=100, allowed=True)
assert r.request_id == "abc"
assert r.reserved_micros == 100
assert r.allowed is True
# ----------------------------------------------------------- rate-limit guard
@pytest.mark.parametrize(
"first_event_seen, recovered, requested_id, current_id, expected",
[
(False, False, 0, -1, True),
# Already recovered: no second pass.
(False, True, 0, -1, False),
# User explicitly picked a config: don't silently switch.
(False, False, 5, -1, False),
# Already on a database-backed (positive) id.
(False, False, 0, 7, False),
# User has already seen output: silent rebuild not possible.
(True, False, 0, -1, False),
],
)
def test_can_recover_provider_rate_limit_truth_table(
first_event_seen: bool,
recovered: bool,
requested_id: int,
current_id: int,
expected: bool,
) -> None:
# Use a known rate-limit-shaped exception so the helper's last condition
# is satisfied; the guard only short-circuits to False when one of the
# *other* preconditions fails.
exc = Exception('{"error":{"type":"rate_limit_error","message":"slow"}}')
assert (
can_recover_provider_rate_limit(
exc,
first_event_seen=first_event_seen,
runtime_rate_limit_recovered=recovered,
requested_llm_config_id=requested_id,
current_llm_config_id=current_id,
)
is expected
)
def test_can_recover_provider_rate_limit_rejects_non_rate_limit_exception() -> None:
assert (
can_recover_provider_rate_limit(
ValueError("not a rate limit"),
first_event_seen=False,
runtime_rate_limit_recovered=False,
requested_llm_config_id=0,
current_llm_config_id=-1,
)
is False
)
# --------------------------------------------------------- persistence spawn
def test_spawn_set_ai_responding_bg_noop_without_user_id() -> None:
async def _run() -> set[asyncio.Task]:
background: set[asyncio.Task] = set()
spawn_set_ai_responding_bg(chat_id=1, user_id=None, background_tasks=background)
return background
bg = asyncio.run(_run())
assert bg == set()
def test_spawn_persist_user_task_registers_and_self_unregisters() -> None:
async def _run() -> tuple[int, int]:
background: set[asyncio.Task] = set()
with patch(
"app.tasks.chat.streaming.flows.new_chat.persistence_spawn.persist_user_turn",
new=AsyncMock(return_value=99),
):
task = spawn_persist_user_task(
chat_id=1,
user_id="u",
turn_id="t",
user_query="hi",
user_image_data_urls=None,
mentioned_documents=None,
background_tasks=background,
)
size_before_await = len(background)
result = await asyncio.shield(task)
# Give the done-callback one event-loop tick to run.
await asyncio.sleep(0)
return size_before_await, result # type: ignore[return-value]
size_before, result = asyncio.run(_run())
assert size_before == 1
assert result == 99
def test_spawn_persist_assistant_shell_task_registers() -> None:
async def _run() -> int | None:
background: set[asyncio.Task] = set()
with patch(
"app.tasks.chat.streaming.flows.new_chat.persistence_spawn.persist_assistant_shell",
new=AsyncMock(return_value=42),
):
task = spawn_persist_assistant_shell_task(
chat_id=1,
user_id="u",
turn_id="t",
background_tasks=background,
)
return await asyncio.shield(task)
assert asyncio.run(_run()) == 42
def test_await_persist_task_returns_none_on_failure() -> None:
async def _run() -> int | None:
async def _boom() -> int:
raise RuntimeError("DB down")
task = asyncio.create_task(_boom())
return await await_persist_task(
task,
chat_id=1,
turn_id="t",
log_label="parity-failure",
)
assert asyncio.run(_run()) is None
def test_await_persist_task_returns_none_for_none_input() -> None:
async def _run() -> int | None:
return await await_persist_task(
None,
chat_id=1,
turn_id="t",
log_label="parity-none",
)
assert asyncio.run(_run()) is None