mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-02 19:55:18 +02:00
Merge remote-tracking branch 'upstream/dev' into feat/whatsapp-gateway-integration
This commit is contained in:
commit
e3de7c4667
465 changed files with 29171 additions and 6994 deletions
|
|
@ -60,7 +60,6 @@ class TestReadOnlyToolsAllowed:
|
|||
"glob",
|
||||
"web_search",
|
||||
"scrape_webpage",
|
||||
"search_surfsense_docs",
|
||||
"get_connected_accounts",
|
||||
"write_todos",
|
||||
"task",
|
||||
|
|
|
|||
|
|
@ -22,12 +22,6 @@ from app.agents.new_chat.subagents.config import (
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@tool
|
||||
def search_surfsense_docs(query: str) -> str:
|
||||
"""Search the user's KB."""
|
||||
return ""
|
||||
|
||||
|
||||
@tool
|
||||
def web_search(query: str) -> str:
|
||||
"""Search the public web."""
|
||||
|
|
@ -95,7 +89,6 @@ def generate_report(topic: str) -> str:
|
|||
|
||||
|
||||
ALL_TOOLS = [
|
||||
search_surfsense_docs,
|
||||
web_search,
|
||||
scrape_webpage,
|
||||
read_file,
|
||||
|
|
@ -161,7 +154,7 @@ class TestReportWriterSubagent:
|
|||
names = {t.name for t in spec["tools"]} # type: ignore[index]
|
||||
assert names == REPORT_WRITER_TOOLS & {t.name for t in ALL_TOOLS}
|
||||
assert "generate_report" in names
|
||||
assert "search_surfsense_docs" in names
|
||||
assert "read_file" in names
|
||||
|
||||
def test_deny_rules_block_writes_but_allow_generate_report(self) -> None:
|
||||
spec = build_report_writer_subagent(tools=ALL_TOOLS)
|
||||
|
|
@ -272,9 +265,9 @@ class TestFilterToolsWarningSuppression:
|
|||
# Allowed set asks for two registry tools (one present, one
|
||||
# not) plus a bunch of middleware-provided names.
|
||||
_filter_tools(
|
||||
[search_surfsense_docs],
|
||||
[web_search],
|
||||
allowed_names={
|
||||
"search_surfsense_docs",
|
||||
"web_search",
|
||||
"scrape_webpage", # legitimately missing → should warn
|
||||
"read_file", # mw-provided → suppressed
|
||||
"ls",
|
||||
|
|
@ -322,7 +315,6 @@ class TestDenyPatternsCoverage:
|
|||
|
||||
def test_deny_patterns_do_not_match_safe_read_tools(self) -> None:
|
||||
canonical_reads = [
|
||||
"search_surfsense_docs",
|
||||
"read_file",
|
||||
"ls_tree",
|
||||
"grep",
|
||||
|
|
|
|||
0
surfsense_backend/tests/unit/automations/__init__.py
Normal file
0
surfsense_backend/tests/unit/automations/__init__.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
"""Lock ``build_auto_decisions`` — the HITL auto-approve/reject wire mapper.
|
||||
|
||||
``build_auto_decisions`` walks ``state.interrupts`` (duck-typed) and produces
|
||||
two parallel resume maps: one keyed by LangGraph ``Interrupt.id`` and one
|
||||
keyed by ``tool_call_id`` for the subagent middleware bridge. Both carry
|
||||
the same decision payload.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.actions.builtin.agent_task.auto_decide import build_auto_decisions
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _state(interrupts: list[Any]) -> SimpleNamespace:
|
||||
"""Build a duck-typed LangGraph state stub carrying ``interrupts``."""
|
||||
return SimpleNamespace(interrupts=interrupts)
|
||||
|
||||
|
||||
def _interrupt(*, id_: str, value: Any) -> SimpleNamespace:
|
||||
"""Build a duck-typed interrupt with the canonical ``(id, value)`` shape."""
|
||||
return SimpleNamespace(id=id_, value=value)
|
||||
|
||||
|
||||
def test_build_auto_decisions_produces_one_decision_per_action_request() -> None:
|
||||
"""An interrupt carrying N ``action_requests`` produces N decisions of
|
||||
the requested type in both maps. This is the canonical batched-HITL
|
||||
wire shape — losing a decision would leave a pending action stuck."""
|
||||
interrupt = _interrupt(
|
||||
id_="lg-1",
|
||||
value={
|
||||
"tool_call_id": "tc-1",
|
||||
"action_requests": [{"id": "a"}, {"id": "b"}],
|
||||
},
|
||||
)
|
||||
|
||||
lg_map, routed = build_auto_decisions(_state([interrupt]), "approve")
|
||||
|
||||
assert lg_map == {"lg-1": {"decisions": [{"type": "approve"}, {"type": "approve"}]}}
|
||||
assert routed == {"tc-1": {"decisions": [{"type": "approve"}, {"type": "approve"}]}}
|
||||
|
||||
|
||||
def test_build_auto_decisions_defaults_to_one_decision_for_scalar_interrupt() -> None:
|
||||
"""When an interrupt's value has no ``action_requests`` list, the
|
||||
function defaults to a single decision. Locks compatibility with
|
||||
older single-action interrupt shapes still emitted by some tools."""
|
||||
interrupt = _interrupt(id_="lg-2", value={"tool_call_id": "tc-2"})
|
||||
|
||||
lg_map, routed = build_auto_decisions(_state([interrupt]), "reject")
|
||||
|
||||
assert lg_map == {"lg-2": {"decisions": [{"type": "reject"}]}}
|
||||
assert routed == {"tc-2": {"decisions": [{"type": "reject"}]}}
|
||||
|
||||
|
||||
def test_build_auto_decisions_skips_interrupts_with_invalid_shape() -> None:
|
||||
"""Interrupts missing the canonical ``(str id, dict value)`` shape are
|
||||
skipped silently rather than crashing the resume loop. Locks the
|
||||
resilience contract — a malformed interrupt from a misbehaving tool
|
||||
shouldn't take down the whole agent_task step."""
|
||||
good = _interrupt(id_="lg-good", value={"tool_call_id": "tc-good"})
|
||||
bad_value = _interrupt(id_="lg-bad-value", value="not a dict")
|
||||
bad_id = _interrupt(id_=None, value={"tool_call_id": "tc-bad-id"}) # type: ignore[arg-type]
|
||||
|
||||
lg_map, routed = build_auto_decisions(_state([good, bad_value, bad_id]), "approve")
|
||||
|
||||
assert lg_map == {"lg-good": {"decisions": [{"type": "approve"}]}}
|
||||
assert routed == {"tc-good": {"decisions": [{"type": "approve"}]}}
|
||||
|
|
@ -0,0 +1,174 @@
|
|||
"""Lock the runtime model-policy backstop in ``build_dependencies``.
|
||||
|
||||
Automations resolve their LLM from the *captured* ``agent_llm_id`` snapshot (so
|
||||
runs are insulated from later chat/search-space model changes), and the model
|
||||
policy is re-checked at run time so a captured model that is no longer billable
|
||||
fails the run clearly. When no snapshot is present, resolution falls back to the
|
||||
live search space.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
import app.automations.actions.agent_task.dependencies as deps_mod
|
||||
from app.automations.actions.agent_task.dependencies import (
|
||||
DependencyError,
|
||||
build_dependencies,
|
||||
)
|
||||
from app.automations.services.model_policy import AutomationModelPolicyError
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
"""Minimal async session whose ``get`` returns a preset search space."""
|
||||
|
||||
def __init__(self, search_space: Any) -> None:
|
||||
self._search_space = search_space
|
||||
|
||||
async def get(self, _model: Any, _pk: int) -> Any:
|
||||
return self._search_space
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_side_effects(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Stub the connector setup + checkpointer so only policy/LLM logic runs."""
|
||||
|
||||
async def _fake_setup(_session, *, search_space_id):
|
||||
return (SimpleNamespace(name="connector"), "fc-key")
|
||||
|
||||
monkeypatch.setattr(deps_mod, "setup_connector_and_firecrawl", _fake_setup)
|
||||
return None
|
||||
|
||||
|
||||
async def test_build_dependencies_resolves_captured_agent_llm_id(
|
||||
monkeypatch: pytest.MonkeyPatch, patched_side_effects
|
||||
) -> None:
|
||||
"""The bundle loads with the *captured* ``agent_llm_id``, not the live search space."""
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def _fake_load(_session, *, config_id, search_space_id):
|
||||
captured["config_id"] = config_id
|
||||
captured["search_space_id"] = search_space_id
|
||||
return (SimpleNamespace(name="llm"), SimpleNamespace(name="agent_config"), None)
|
||||
|
||||
monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load)
|
||||
# Captured path validates the explicit ids; passes for this test.
|
||||
monkeypatch.setattr(deps_mod, "assert_models_billable", lambda **_kw: None)
|
||||
# A different value on the live search space proves we ignore it when a
|
||||
# snapshot is supplied.
|
||||
monkeypatch.setattr(
|
||||
deps_mod,
|
||||
"assert_automation_models_billable",
|
||||
lambda _ss: pytest.fail("search-space policy should not run on captured path"),
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(agent_llm_id=-99)
|
||||
result = await build_dependencies(
|
||||
session=_FakeSession(search_space),
|
||||
search_space_id=42,
|
||||
agent_llm_id=-7,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
)
|
||||
|
||||
assert captured == {"config_id": -7, "search_space_id": 42}
|
||||
assert result.llm.name == "llm"
|
||||
assert result.firecrawl_api_key == "fc-key"
|
||||
|
||||
|
||||
async def test_build_dependencies_validates_captured_ids(
|
||||
monkeypatch: pytest.MonkeyPatch, patched_side_effects
|
||||
) -> None:
|
||||
"""The captured ids (not the search space) are what gets policy-checked."""
|
||||
seen: dict[str, Any] = {}
|
||||
|
||||
def _capture(**kwargs):
|
||||
seen.update(kwargs)
|
||||
|
||||
monkeypatch.setattr(deps_mod, "assert_models_billable", _capture)
|
||||
|
||||
async def _fake_load(_session, *, config_id, search_space_id):
|
||||
return (SimpleNamespace(name="llm"), SimpleNamespace(name="agent_config"), None)
|
||||
|
||||
monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load)
|
||||
|
||||
await build_dependencies(
|
||||
session=_FakeSession(SimpleNamespace(agent_llm_id=0)),
|
||||
search_space_id=42,
|
||||
agent_llm_id=-7,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
)
|
||||
|
||||
assert seen == {
|
||||
"agent_llm_id": -7,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
}
|
||||
|
||||
|
||||
async def test_build_dependencies_raises_on_captured_policy_violation(
|
||||
monkeypatch: pytest.MonkeyPatch, patched_side_effects
|
||||
) -> None:
|
||||
"""A blocked captured model raises ``DependencyError`` so the step fails clearly."""
|
||||
|
||||
def _raise(**_kw):
|
||||
raise AutomationModelPolicyError(
|
||||
[{"kind": "image", "config_id": -2, "reason": "free model"}]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(deps_mod, "assert_models_billable", _raise)
|
||||
monkeypatch.setattr(
|
||||
deps_mod,
|
||||
"load_llm_bundle",
|
||||
lambda *a, **k: pytest.fail("load_llm_bundle should not be called"),
|
||||
)
|
||||
|
||||
with pytest.raises(DependencyError):
|
||||
await build_dependencies(
|
||||
session=_FakeSession(SimpleNamespace(agent_llm_id=-7)),
|
||||
search_space_id=42,
|
||||
agent_llm_id=-7,
|
||||
image_generation_config_id=-2,
|
||||
vision_llm_config_id=-1,
|
||||
)
|
||||
|
||||
|
||||
async def test_build_dependencies_falls_back_to_search_space(
|
||||
monkeypatch: pytest.MonkeyPatch, patched_side_effects
|
||||
) -> None:
|
||||
"""With no captured snapshot, resolve + validate the live search space."""
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def _fake_load(_session, *, config_id, search_space_id):
|
||||
captured["config_id"] = config_id
|
||||
return (SimpleNamespace(name="llm"), SimpleNamespace(name="agent_config"), None)
|
||||
|
||||
monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load)
|
||||
monkeypatch.setattr(deps_mod, "assert_automation_models_billable", lambda _ss: None)
|
||||
monkeypatch.setattr(
|
||||
deps_mod,
|
||||
"assert_models_billable",
|
||||
lambda **_kw: pytest.fail("captured policy should not run on fallback path"),
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(agent_llm_id=-7)
|
||||
result = await build_dependencies(
|
||||
session=_FakeSession(search_space), search_space_id=42
|
||||
)
|
||||
|
||||
assert captured == {"config_id": -7}
|
||||
assert result.llm.name == "llm"
|
||||
|
||||
|
||||
async def test_build_dependencies_raises_when_search_space_missing(
|
||||
patched_side_effects,
|
||||
) -> None:
|
||||
"""A missing search space (fallback path) surfaces as a ``DependencyError``."""
|
||||
with pytest.raises(DependencyError):
|
||||
await build_dependencies(session=_FakeSession(None), search_space_id=999)
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
"""Lock ``extract_final_assistant_message`` — what surfaces in run output.
|
||||
|
||||
Each scenario is one shape the agent runtime is observed to produce.
|
||||
Locking these means we can refactor the extractor without losing
|
||||
backwards compatibility with already-stored ``run.output`` payloads.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
|
||||
from app.automations.actions.builtin.agent_task.finalize import (
|
||||
extract_final_assistant_message,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_extract_returns_last_ai_message_string_content() -> None:
|
||||
"""The canonical shape: the agent's final ``AIMessage`` carries a
|
||||
plain string. That string is returned verbatim, trimmed."""
|
||||
result = {
|
||||
"messages": [
|
||||
HumanMessage(content="ask"),
|
||||
AIMessage(content="the answer"),
|
||||
]
|
||||
}
|
||||
|
||||
assert extract_final_assistant_message(result) == "the answer"
|
||||
|
||||
|
||||
def test_extract_concatenates_text_parts_and_skips_non_text_parts() -> None:
|
||||
"""Multi-part AIMessage content (Anthropic / OpenAI list shape) joins
|
||||
its ``text`` parts in order; non-text parts (tool_use, images, ...)
|
||||
are skipped. Locks the wire shape used when the model emits tool
|
||||
calls alongside narrative text in the same turn."""
|
||||
result = {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "Hello "},
|
||||
{"type": "tool_use", "name": "search", "input": {}},
|
||||
{"type": "text", "text": "world"},
|
||||
]
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
assert extract_final_assistant_message(result) == "Hello world"
|
||||
|
||||
|
||||
def test_extract_returns_last_ai_message_skipping_tool_messages() -> None:
|
||||
"""When the transcript ends with tool calls and tool results, the
|
||||
extractor still walks back to the **last** ``AIMessage`` (the agent's
|
||||
final narrative answer). Locks resilience against trailing
|
||||
``ToolMessage`` payloads in the transcript."""
|
||||
result = {
|
||||
"messages": [
|
||||
HumanMessage(content="ask"),
|
||||
AIMessage(content="thinking..."),
|
||||
ToolMessage(content="tool output", tool_call_id="tc-1"),
|
||||
AIMessage(content="final answer"),
|
||||
ToolMessage(content="trailing tool noise", tool_call_id="tc-2"),
|
||||
]
|
||||
}
|
||||
|
||||
assert extract_final_assistant_message(result) == "final answer"
|
||||
|
||||
|
||||
def test_extract_returns_none_when_no_assistant_text_is_present() -> None:
|
||||
"""No ``AIMessage`` with extractable text → ``None`` rather than the
|
||||
empty string. Lets callers branch on "did the agent actually say
|
||||
anything?" rather than guess whether ``""`` means silence or empty
|
||||
output. Empty-string contents are normalized to ``None`` too."""
|
||||
no_ai = {"messages": [HumanMessage(content="just a question")]}
|
||||
only_tools = {
|
||||
"messages": [
|
||||
AIMessage(content=[{"type": "tool_use", "name": "x", "input": {}}])
|
||||
]
|
||||
}
|
||||
empty_string = {"messages": [AIMessage(content=" ")]}
|
||||
|
||||
assert extract_final_assistant_message(no_ai) is None
|
||||
assert extract_final_assistant_message(only_tools) is None
|
||||
assert extract_final_assistant_message(empty_string) is None
|
||||
39
surfsense_backend/tests/unit/automations/conftest.py
Normal file
39
surfsense_backend/tests/unit/automations/conftest.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
"""Shared fixtures for the ``app.automations`` unit-test tree.
|
||||
|
||||
Provides registry isolation: the built-in ``schedule`` trigger and
|
||||
``agent_task`` action self-register at import time. Tests that register
|
||||
additional triggers/actions (or assert on the registry contents) must
|
||||
not leak that state to other tests. These fixtures snapshot and restore
|
||||
the module-level registry dicts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.actions import store as action_store
|
||||
from app.automations.triggers import store as trigger_store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def isolated_action_registry() -> Iterator[None]:
|
||||
"""Snapshot and restore the action registry around a test."""
|
||||
snapshot = dict(action_store._REGISTRY)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
action_store._REGISTRY.clear()
|
||||
action_store._REGISTRY.update(snapshot)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def isolated_trigger_registry() -> Iterator[None]:
|
||||
"""Snapshot and restore the trigger registry around a test."""
|
||||
snapshot = dict(trigger_store._REGISTRY)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
trigger_store._REGISTRY.clear()
|
||||
trigger_store._REGISTRY.update(snapshot)
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
"""Lock the ``DispatchError`` exception contract.
|
||||
|
||||
``DispatchError`` is the uniform exception type the dispatch layer raises
|
||||
for any "cannot turn this fire request into a run" condition. Other
|
||||
modules (templates of error envelopes, run records) compare on
|
||||
``isinstance(exc, DispatchError)``, so the inheritance is the contract.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.dispatch.errors import DispatchError
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_dispatch_error_is_exception_subclass_and_carries_message() -> None:
|
||||
"""Lifting a string into ``DispatchError`` preserves the message and
|
||||
behaves as a regular ``Exception`` for ``isinstance`` / ``raise`` /
|
||||
``except`` consumers."""
|
||||
error = DispatchError("missing trigger")
|
||||
|
||||
assert isinstance(error, Exception)
|
||||
assert str(error) == "missing trigger"
|
||||
|
||||
with pytest.raises(DispatchError):
|
||||
raise error
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
"""Lock the input-validation contract enforced before a run is enqueued.
|
||||
|
||||
``validate_inputs`` is the pure schema check that ``enqueue_run`` runs against
|
||||
merged inputs. ``enqueue_run`` itself needs a real DB session, so tests target
|
||||
this pure function directly; the contract — not the symbol — is what's locked.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.dispatch.errors import DispatchError
|
||||
from app.automations.dispatch.inputs import validate_inputs
|
||||
from app.automations.schemas.definition.envelope import AutomationDefinition
|
||||
from app.automations.schemas.definition.inputs import Inputs
|
||||
from app.automations.schemas.definition.plan_step import PlanStep
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _minimal_definition(*, inputs: Inputs | None = None) -> AutomationDefinition:
|
||||
"""One-step definition with an optional declared input schema."""
|
||||
return AutomationDefinition(
|
||||
name="test",
|
||||
inputs=inputs,
|
||||
plan=[PlanStep(step_id="s1", action="agent_task")],
|
||||
)
|
||||
|
||||
|
||||
def test_validate_inputs_passes_through_when_no_schema_is_declared() -> None:
|
||||
"""When the definition declares no input schema, runtime inputs reach
|
||||
the template context **unchanged**. Regression site: previously this
|
||||
branch returned ``{}``, which stripped runtime keys like ``fired_at``
|
||||
and ``last_fired_at`` and made Jinja blow up on ``{{ inputs.* }}``.
|
||||
"""
|
||||
definition = _minimal_definition(inputs=None)
|
||||
runtime_inputs = {
|
||||
"fired_at": "2026-01-01T00:00:00+00:00",
|
||||
"last_fired_at": None,
|
||||
"static_key": "value",
|
||||
}
|
||||
|
||||
assert validate_inputs(definition, runtime_inputs) == runtime_inputs
|
||||
|
||||
|
||||
def test_validate_inputs_returns_inputs_when_they_match_declared_schema() -> None:
|
||||
"""With a declared JSON schema, inputs that satisfy it pass through
|
||||
unchanged (validation succeeds; the function does not coerce or
|
||||
strip extra fields not mentioned in the schema)."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"topic": {"type": "string"}},
|
||||
"required": ["topic"],
|
||||
}
|
||||
definition = _minimal_definition(inputs=Inputs(schema=schema))
|
||||
|
||||
inputs = {"topic": "weekly report"}
|
||||
|
||||
assert validate_inputs(definition, inputs) == inputs
|
||||
|
||||
|
||||
def test_validate_inputs_raises_dispatch_error_when_inputs_violate_schema() -> None:
|
||||
"""Inputs that don't match the declared schema must surface as
|
||||
``DispatchError`` (not the raw ``jsonschema.ValidationError``), so every
|
||||
caller can handle one dispatch-domain exception type uniformly."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"topic": {"type": "string"}},
|
||||
"required": ["topic"],
|
||||
}
|
||||
definition = _minimal_definition(inputs=Inputs(schema=schema))
|
||||
|
||||
with pytest.raises(DispatchError):
|
||||
validate_inputs(definition, {"topic": 42}) # type violates string
|
||||
|
|
@ -0,0 +1,272 @@
|
|||
"""Lock the ``execute_step`` orchestration contract.
|
||||
|
||||
Covers the pure step-execution logic: predicate gate, params rendering,
|
||||
action lookup, retry budget, error shaping. The ``ActionContext.session``
|
||||
is never touched by ``execute_step`` itself (it's only forwarded to the
|
||||
handler), so unit tests pass ``None`` cast to the type.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.automations.actions.store import register_action
|
||||
from app.automations.actions.types import ActionContext, ActionDefinition
|
||||
from app.automations.runtime.step import execute_step
|
||||
from app.automations.schemas.definition.plan_step import PlanStep
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _AnyParams(BaseModel):
|
||||
"""Open params model used by test actions — they never validate."""
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
def _action_context() -> ActionContext:
|
||||
"""Minimal context: session is unused by ``execute_step``, only forwarded."""
|
||||
return ActionContext(
|
||||
session=cast(AsyncSession, None),
|
||||
run_id=1,
|
||||
step_id="s1",
|
||||
search_space_id=1,
|
||||
creator_user_id=None,
|
||||
)
|
||||
|
||||
|
||||
async def test_execute_step_runs_registered_action_handler_and_wraps_result(
|
||||
isolated_action_registry: None,
|
||||
) -> None:
|
||||
"""A step pointing at a registered action runs its handler with the
|
||||
step's params and returns a ``succeeded`` entry carrying the handler's
|
||||
output plus ``attempts=1`` (one try, no retries triggered)."""
|
||||
invocations: list[dict[str, Any]] = []
|
||||
|
||||
async def echo(params: dict[str, Any]) -> dict[str, Any]:
|
||||
invocations.append(params)
|
||||
return {"echoed": params["value"]}
|
||||
|
||||
register_action(
|
||||
ActionDefinition(
|
||||
type="test_echo",
|
||||
name="Echo",
|
||||
description="Test action.",
|
||||
params_model=_AnyParams,
|
||||
build_handler=lambda _ctx: echo,
|
||||
)
|
||||
)
|
||||
|
||||
step = PlanStep(step_id="s1", action="test_echo", params={"value": "hello"})
|
||||
|
||||
result = await execute_step(
|
||||
step=step,
|
||||
template_context={},
|
||||
action_context=_action_context(),
|
||||
default_max_retries=0,
|
||||
default_retry_backoff="none",
|
||||
default_timeout_seconds=30,
|
||||
)
|
||||
|
||||
assert result["status"] == "succeeded"
|
||||
assert result["step_id"] == "s1"
|
||||
assert result["action"] == "test_echo"
|
||||
assert result["attempts"] == 1
|
||||
assert result["result"] == {"echoed": "hello"}
|
||||
assert invocations == [{"value": "hello"}]
|
||||
|
||||
|
||||
async def test_execute_step_skips_step_when_predicate_is_falsy(
|
||||
isolated_action_registry: None,
|
||||
) -> None:
|
||||
"""If ``step.when`` evaluates to falsy in the template context, the
|
||||
handler is **not** invoked, the result entry has ``status=skipped``
|
||||
and ``attempts=0``, and no ``result`` key is present."""
|
||||
invoked = False
|
||||
|
||||
async def must_not_run(_params: dict[str, Any]) -> dict[str, Any]:
|
||||
nonlocal invoked
|
||||
invoked = True
|
||||
return {}
|
||||
|
||||
register_action(
|
||||
ActionDefinition(
|
||||
type="test_guarded",
|
||||
name="Guarded",
|
||||
description="Test action that should not run.",
|
||||
params_model=_AnyParams,
|
||||
build_handler=lambda _ctx: must_not_run,
|
||||
)
|
||||
)
|
||||
|
||||
step = PlanStep(
|
||||
step_id="s1",
|
||||
action="test_guarded",
|
||||
when="inputs.enabled",
|
||||
params={},
|
||||
)
|
||||
|
||||
result = await execute_step(
|
||||
step=step,
|
||||
template_context={"inputs": {"enabled": False}},
|
||||
action_context=_action_context(),
|
||||
default_max_retries=0,
|
||||
default_retry_backoff="none",
|
||||
default_timeout_seconds=30,
|
||||
)
|
||||
|
||||
assert result["status"] == "skipped"
|
||||
assert result["attempts"] == 0
|
||||
assert "result" not in result
|
||||
assert invoked is False
|
||||
|
||||
|
||||
async def test_execute_step_fails_when_step_references_an_unknown_action(
|
||||
isolated_action_registry: None,
|
||||
) -> None:
|
||||
"""A step pointing at an action that isn't in the registry must fail
|
||||
with ``ActionNotFound`` rather than crashing. Catches typos in the
|
||||
plan and removed actions without the run going off the rails."""
|
||||
step = PlanStep(step_id="s1", action="no_such_action", params={})
|
||||
|
||||
result = await execute_step(
|
||||
step=step,
|
||||
template_context={},
|
||||
action_context=_action_context(),
|
||||
default_max_retries=0,
|
||||
default_retry_backoff="none",
|
||||
default_timeout_seconds=30,
|
||||
)
|
||||
|
||||
assert result["status"] == "failed"
|
||||
assert result["attempts"] == 0
|
||||
assert result["error"]["type"] == "ActionNotFound"
|
||||
assert "no_such_action" in result["error"]["message"]
|
||||
|
||||
|
||||
async def test_execute_step_retries_failing_handler_up_to_default_budget(
|
||||
isolated_action_registry: None,
|
||||
) -> None:
|
||||
"""A handler that raises on every attempt consumes the retry budget
|
||||
(1 initial try + ``default_max_retries`` retries) and the step ends
|
||||
``failed`` with the exception's type and message surfaced through
|
||||
the error envelope."""
|
||||
calls = 0
|
||||
|
||||
async def always_fails(_params: dict[str, Any]) -> dict[str, Any]:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
raise RuntimeError("boom")
|
||||
|
||||
register_action(
|
||||
ActionDefinition(
|
||||
type="test_fails",
|
||||
name="Fails",
|
||||
description="Always raises.",
|
||||
params_model=_AnyParams,
|
||||
build_handler=lambda _ctx: always_fails,
|
||||
)
|
||||
)
|
||||
|
||||
step = PlanStep(step_id="s1", action="test_fails", params={})
|
||||
|
||||
result = await execute_step(
|
||||
step=step,
|
||||
template_context={},
|
||||
action_context=_action_context(),
|
||||
default_max_retries=2,
|
||||
default_retry_backoff="none",
|
||||
default_timeout_seconds=30,
|
||||
)
|
||||
|
||||
assert result["status"] == "failed"
|
||||
assert result["attempts"] == 3
|
||||
assert calls == 3
|
||||
assert result["error"]["type"] == "RuntimeError"
|
||||
assert "boom" in result["error"]["message"]
|
||||
|
||||
|
||||
async def test_execute_step_succeeds_when_handler_recovers_within_retry_budget(
|
||||
isolated_action_registry: None,
|
||||
) -> None:
|
||||
"""A handler that fails the first N times and then succeeds yields a
|
||||
``succeeded`` entry with ``attempts == N + 1``. Locks that retries
|
||||
can actually recover (not just exhaust)."""
|
||||
calls = 0
|
||||
|
||||
async def flaky(_params: dict[str, Any]) -> dict[str, Any]:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
if calls < 3:
|
||||
raise RuntimeError("transient")
|
||||
return {"ok": True}
|
||||
|
||||
register_action(
|
||||
ActionDefinition(
|
||||
type="test_flaky",
|
||||
name="Flaky",
|
||||
description="Fails twice, succeeds third time.",
|
||||
params_model=_AnyParams,
|
||||
build_handler=lambda _ctx: flaky,
|
||||
)
|
||||
)
|
||||
|
||||
step = PlanStep(step_id="s1", action="test_flaky", params={})
|
||||
|
||||
result = await execute_step(
|
||||
step=step,
|
||||
template_context={},
|
||||
action_context=_action_context(),
|
||||
default_max_retries=2,
|
||||
default_retry_backoff="none",
|
||||
default_timeout_seconds=30,
|
||||
)
|
||||
|
||||
assert result["status"] == "succeeded"
|
||||
assert result["attempts"] == 3
|
||||
assert result["result"] == {"ok": True}
|
||||
assert calls == 3
|
||||
|
||||
|
||||
async def test_execute_step_renders_step_params_through_template_engine(
|
||||
isolated_action_registry: None,
|
||||
) -> None:
|
||||
"""Step params are rendered against the template context before the
|
||||
handler is invoked. String values containing Jinja expressions get
|
||||
substituted from ``inputs`` and ``steps`` in the run context."""
|
||||
received: list[dict[str, Any]] = []
|
||||
|
||||
async def capture(params: dict[str, Any]) -> dict[str, Any]:
|
||||
received.append(params)
|
||||
return {}
|
||||
|
||||
register_action(
|
||||
ActionDefinition(
|
||||
type="test_capture",
|
||||
name="Capture",
|
||||
description="Captures the params passed in.",
|
||||
params_model=_AnyParams,
|
||||
build_handler=lambda _ctx: capture,
|
||||
)
|
||||
)
|
||||
|
||||
step = PlanStep(
|
||||
step_id="s1",
|
||||
action="test_capture",
|
||||
params={"message": "Hello {{ inputs.name }}"},
|
||||
)
|
||||
|
||||
await execute_step(
|
||||
step=step,
|
||||
template_context={"inputs": {"name": "World"}, "steps": {}},
|
||||
action_context=_action_context(),
|
||||
default_max_retries=0,
|
||||
default_retry_backoff="none",
|
||||
default_timeout_seconds=30,
|
||||
)
|
||||
|
||||
assert received == [{"message": "Hello World"}]
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
"""Lock that the executor propagates the captured model snapshot into the
|
||||
``ActionContext``, so runs resolve their own model (insulated from chat /
|
||||
search-space changes) and not the live search space.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.automations.runtime.executor import _build_action_ctx
|
||||
from app.automations.schemas.definition.envelope import AutomationModels
|
||||
from app.automations.schemas.definition.plan_step import PlanStep
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _run() -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=1,
|
||||
automation=SimpleNamespace(search_space_id=42, created_by_user_id="u-1"),
|
||||
)
|
||||
|
||||
|
||||
def test_build_action_ctx_propagates_captured_models() -> None:
|
||||
"""``definition.models`` flows onto the ActionContext model fields."""
|
||||
models = AutomationModels(
|
||||
agent_llm_id=-1,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
)
|
||||
ctx = _build_action_ctx(
|
||||
cast(AsyncSession, None),
|
||||
_run(),
|
||||
PlanStep(step_id="s1", action="agent_task"),
|
||||
models,
|
||||
)
|
||||
|
||||
assert ctx.search_space_id == 42
|
||||
assert ctx.agent_llm_id == -1
|
||||
assert ctx.image_generation_config_id == 5
|
||||
assert ctx.vision_llm_config_id == -1
|
||||
|
||||
|
||||
def test_build_action_ctx_none_models_leaves_fields_none() -> None:
|
||||
"""No captured snapshot → model fields are None (defensive fallback path)."""
|
||||
ctx = _build_action_ctx(
|
||||
cast(AsyncSession, None),
|
||||
_run(),
|
||||
PlanStep(step_id="s1", action="agent_task"),
|
||||
None,
|
||||
)
|
||||
|
||||
assert ctx.agent_llm_id is None
|
||||
assert ctx.image_generation_config_id is None
|
||||
assert ctx.vision_llm_config_id is None
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
"""Lock the ``with_retries`` policy: budget, recovery, exhaustion, timeout, backoff.
|
||||
|
||||
Tests with ``backoff="none"`` to keep wall-clock time zero. Backoff sleep
|
||||
values themselves are observed by monkeypatching ``asyncio.sleep`` so we
|
||||
don't introduce flakiness via real timing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.runtime.retries import with_retries
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
async def test_with_retries_returns_result_and_attempts_one_on_first_success() -> None:
|
||||
"""A coroutine that succeeds on the first call returns its result
|
||||
paired with ``attempts=1`` — no retry consumed."""
|
||||
calls = 0
|
||||
|
||||
async def succeed() -> str:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
return "ok"
|
||||
|
||||
result, attempts = await with_retries(
|
||||
succeed, max_retries=2, backoff="none", timeout=None
|
||||
)
|
||||
|
||||
assert result == "ok"
|
||||
assert attempts == 1
|
||||
assert calls == 1
|
||||
|
||||
|
||||
async def test_with_retries_returns_attempt_count_when_succeeding_after_failures() -> (
|
||||
None
|
||||
):
|
||||
"""A coroutine that fails twice then succeeds returns ``attempts=3``
|
||||
(the actual attempt that produced the result). Locks the contract
|
||||
that the caller can distinguish first-try success from a recovery."""
|
||||
calls = 0
|
||||
|
||||
async def flaky() -> str:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
if calls < 3:
|
||||
raise RuntimeError("transient")
|
||||
return "ok"
|
||||
|
||||
result, attempts = await with_retries(
|
||||
flaky, max_retries=5, backoff="none", timeout=None
|
||||
)
|
||||
|
||||
assert result == "ok"
|
||||
assert attempts == 3
|
||||
assert calls == 3
|
||||
|
||||
|
||||
async def test_with_retries_reraises_after_exhausting_the_budget() -> None:
|
||||
"""When the coroutine raises on every attempt within
|
||||
``1 + max_retries`` tries, the last exception propagates and the
|
||||
handler is called exactly ``1 + max_retries`` times."""
|
||||
calls = 0
|
||||
|
||||
async def always_fails() -> str:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
raise RuntimeError(f"boom-{calls}")
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom-3"):
|
||||
await with_retries(always_fails, max_retries=2, backoff="none", timeout=None)
|
||||
|
||||
assert calls == 3 # 1 initial + 2 retries
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
"""Lock the request-side automation API schemas — the public validation gate."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.automations.schemas.api.automation import AutomationCreate, AutomationUpdate
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
_VALID_DEFINITION = {
|
||||
"name": "Test",
|
||||
"plan": [{"step_id": "s1", "action": "agent_task"}],
|
||||
}
|
||||
|
||||
|
||||
def test_automation_create_accepts_valid_minimal_payload() -> None:
|
||||
"""Happy path: just search_space_id, name, and a valid definition.
|
||||
Triggers default to ``[]`` so users can attach them later."""
|
||||
payload = AutomationCreate.model_validate(
|
||||
{
|
||||
"search_space_id": 1,
|
||||
"name": "Daily digest",
|
||||
"definition": _VALID_DEFINITION,
|
||||
}
|
||||
)
|
||||
|
||||
assert payload.name == "Daily digest"
|
||||
assert payload.description is None
|
||||
assert payload.triggers == []
|
||||
|
||||
|
||||
def test_automation_create_cascades_validation_into_nested_definition() -> None:
|
||||
"""A bad ``definition`` (e.g. empty plan) fails at the API boundary,
|
||||
not at the DB layer. Locks the cascade so corrupt definitions can't
|
||||
sneak through a misshapen wire payload."""
|
||||
with pytest.raises(ValidationError):
|
||||
AutomationCreate.model_validate(
|
||||
{
|
||||
"search_space_id": 1,
|
||||
"name": "Bad",
|
||||
"definition": {"name": "X", "plan": []}, # empty plan
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_automation_create_rejects_unknown_top_level_field() -> None:
|
||||
"""``extra='forbid'`` catches typos in API payloads at the boundary."""
|
||||
with pytest.raises(ValidationError):
|
||||
AutomationCreate.model_validate(
|
||||
{
|
||||
"search_space_id": 1,
|
||||
"name": "X",
|
||||
"definition": _VALID_DEFINITION,
|
||||
"owner": "tg", # not allowed
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_automation_create_rejects_empty_name() -> None:
|
||||
"""Name is required and constrained to 1..200 chars."""
|
||||
with pytest.raises(ValidationError):
|
||||
AutomationCreate.model_validate(
|
||||
{
|
||||
"search_space_id": 1,
|
||||
"name": "",
|
||||
"definition": _VALID_DEFINITION,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_automation_update_accepts_partial_payload_with_no_fields() -> None:
|
||||
"""All fields on ``AutomationUpdate`` are optional. An empty body is
|
||||
a valid no-op update (the service layer decides what to do with it)."""
|
||||
update = AutomationUpdate.model_validate({})
|
||||
|
||||
assert update.name is None
|
||||
assert update.description is None
|
||||
assert update.status is None
|
||||
assert update.definition is None
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
"""Lock the request-side trigger API schemas."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.automations.persistence.enums.trigger_type import TriggerType
|
||||
from app.automations.schemas.api.trigger import TriggerCreate, TriggerUpdate
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_trigger_create_uses_safe_defaults_for_optional_fields() -> None:
|
||||
"""Defaults: empty ``params`` and ``static_inputs``, ``enabled=True``.
|
||||
These let callers create a trigger with just ``type`` + the params
|
||||
the trigger requires."""
|
||||
trigger = TriggerCreate(type=TriggerType.SCHEDULE) # type: ignore[arg-type]
|
||||
|
||||
assert trigger.type is TriggerType.SCHEDULE
|
||||
assert trigger.params == {}
|
||||
assert trigger.static_inputs == {}
|
||||
assert trigger.enabled is True
|
||||
|
||||
|
||||
def test_trigger_create_rejects_unknown_trigger_type_string() -> None:
|
||||
"""``type`` is a ``TriggerType`` enum, so any string outside the
|
||||
enum's known values fails validation at the boundary."""
|
||||
with pytest.raises(ValidationError):
|
||||
TriggerCreate.model_validate({"type": "webhook"}) # not in TriggerType
|
||||
|
||||
|
||||
def test_trigger_create_rejects_unknown_field() -> None:
|
||||
"""``extra='forbid'`` catches typos in trigger payloads."""
|
||||
with pytest.raises(ValidationError):
|
||||
TriggerCreate.model_validate(
|
||||
{"type": "schedule", "param": {}} # typo: param vs params
|
||||
)
|
||||
|
||||
|
||||
def test_trigger_update_accepts_partial_payload_with_no_fields() -> None:
|
||||
"""``TriggerUpdate`` is fully optional — empty body is valid (no-op)."""
|
||||
update = TriggerUpdate()
|
||||
|
||||
assert update.enabled is None
|
||||
assert update.params is None
|
||||
assert update.static_inputs is None
|
||||
|
|
@ -0,0 +1,90 @@
|
|||
"""Lock the ``AutomationDefinition`` envelope contract."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.automations.schemas.definition.envelope import (
|
||||
AutomationDefinition,
|
||||
AutomationModels,
|
||||
)
|
||||
from app.automations.schemas.definition.plan_step import PlanStep
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_automation_definition_accepts_minimal_valid_input_with_sensible_defaults() -> (
|
||||
None
|
||||
):
|
||||
"""A definition with just ``name`` + a one-step ``plan`` is valid and
|
||||
fills in the rest with safe defaults so users don't have to write
|
||||
out every section to get started."""
|
||||
definition = AutomationDefinition(
|
||||
name="Daily digest",
|
||||
plan=[PlanStep(step_id="s1", action="agent_task")],
|
||||
)
|
||||
|
||||
assert definition.name == "Daily digest"
|
||||
assert definition.schema_version == "1.0"
|
||||
assert definition.goal is None
|
||||
assert definition.inputs is None
|
||||
assert definition.triggers == []
|
||||
# ``models`` is optional (populated server-side at create()).
|
||||
assert definition.models is None
|
||||
|
||||
|
||||
def test_automation_definition_models_round_trip() -> None:
|
||||
"""The captured ``models`` snapshot survives a model_dump/validate round-trip."""
|
||||
definition = AutomationDefinition(
|
||||
name="Daily digest",
|
||||
plan=[PlanStep(step_id="s1", action="agent_task")],
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-1,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
),
|
||||
)
|
||||
|
||||
dumped = definition.model_dump(mode="json", by_alias=True)
|
||||
assert dumped["models"] == {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
}
|
||||
|
||||
restored = AutomationDefinition.model_validate(dumped)
|
||||
assert restored.models is not None
|
||||
assert restored.models.agent_llm_id == -1
|
||||
assert restored.models.image_generation_config_id == 5
|
||||
assert restored.models.vision_llm_config_id == -1
|
||||
|
||||
|
||||
def test_automation_definition_rejects_unknown_top_level_field() -> None:
|
||||
"""``extra='forbid'`` catches typos at validation time (e.g. ``pln``
|
||||
instead of ``plan``) before the bad definition reaches storage."""
|
||||
with pytest.raises(ValidationError):
|
||||
AutomationDefinition.model_validate(
|
||||
{
|
||||
"name": "X",
|
||||
"plan": [{"step_id": "s1", "action": "agent_task"}],
|
||||
"extra_field": "unexpected",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_automation_definition_rejects_empty_plan() -> None:
|
||||
"""An automation with no plan steps has nothing to execute and must
|
||||
be rejected at validation time."""
|
||||
with pytest.raises(ValidationError):
|
||||
AutomationDefinition(name="X", plan=[])
|
||||
|
||||
|
||||
def test_automation_definition_rejects_empty_name() -> None:
|
||||
"""Name is required and must be non-empty so list views and audit
|
||||
logs have something meaningful to display."""
|
||||
with pytest.raises(ValidationError):
|
||||
AutomationDefinition(
|
||||
name="",
|
||||
plan=[PlanStep(step_id="s1", action="agent_task")],
|
||||
)
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
"""Lock the ``Execution`` defaults + literal-constraint contract.
|
||||
|
||||
These defaults control production behavior of every automation that
|
||||
doesn't override them; the defaults *are* the contract.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.automations.schemas.definition.execution import Execution
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_execution_uses_production_defaults_when_no_overrides_provided() -> None:
|
||||
"""The defaults shipped to prod: 10-minute wall clock, 2 retries
|
||||
per step, exponential backoff, drop overlapping runs. Changing any
|
||||
of these is a behavioral release-note change."""
|
||||
execution = Execution()
|
||||
|
||||
assert execution.timeout_seconds == 600
|
||||
assert execution.max_retries == 2
|
||||
assert execution.retry_backoff == "exponential"
|
||||
assert execution.concurrency == "drop_if_running"
|
||||
assert execution.on_failure == []
|
||||
|
||||
|
||||
def test_execution_rejects_unknown_retry_backoff_strategy() -> None:
|
||||
"""``retry_backoff`` is constrained to a closed set — typos like
|
||||
``"expontential"`` must fail validation, not silently coerce."""
|
||||
with pytest.raises(ValidationError):
|
||||
Execution(retry_backoff="expontential") # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_execution_rejects_unknown_concurrency_strategy() -> None:
|
||||
"""Same closed-set constraint on ``concurrency``."""
|
||||
with pytest.raises(ValidationError):
|
||||
Execution(concurrency="parallel") # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_execution_rejects_invalid_numeric_bounds() -> None:
|
||||
"""``timeout_seconds > 0`` and ``max_retries >= 0``. Zero or negative
|
||||
values would produce nonsensical run behavior."""
|
||||
with pytest.raises(ValidationError):
|
||||
Execution(timeout_seconds=0)
|
||||
with pytest.raises(ValidationError):
|
||||
Execution(max_retries=-1)
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
"""Lock the ``Inputs`` JSON ``schema``-alias roundtrip.
|
||||
|
||||
The field is ``schema_`` in Python (``schema`` shadows a Pydantic builtin)
|
||||
but is wire-named ``schema``. Locking the roundtrip means JSON definitions
|
||||
authored anywhere (UI raw editor, NL drafter, CLI export) speak the same
|
||||
wire shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.automations.schemas.definition.inputs import Inputs
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_inputs_parses_wire_field_named_schema_into_schema_attribute() -> None:
|
||||
"""JSON payloads use ``schema`` (the convention). The model stores it
|
||||
on the Python attribute ``schema_`` without shadowing the builtin."""
|
||||
parsed = Inputs.model_validate({"schema": {"type": "object"}})
|
||||
|
||||
assert parsed.schema_ == {"type": "object"}
|
||||
|
||||
|
||||
def test_inputs_serializes_schema_attribute_back_to_wire_field_named_schema() -> None:
|
||||
"""Round-trip: serializing emits ``schema`` (alias), not ``schema_``.
|
||||
Locks the consumer-visible JSON shape regardless of the Python name."""
|
||||
inputs = Inputs(schema={"type": "object"}) # type: ignore[call-arg]
|
||||
|
||||
assert inputs.model_dump() == {"schema": {"type": "object"}}
|
||||
|
||||
|
||||
def test_inputs_rejects_unknown_field() -> None:
|
||||
"""``extra='forbid'`` catches typos like ``shema`` so bad definitions
|
||||
don't silently lose their input declaration."""
|
||||
with pytest.raises(ValidationError):
|
||||
Inputs.model_validate({"schema": {}, "extra": "x"})
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
"""Lock the ``Metadata`` ``extra='allow'`` contract — the only schema
|
||||
that does. Free-form annotations on definitions (e.g. ``owner``,
|
||||
``project``, ``created_by_ai``) need to round-trip through the envelope
|
||||
without being rejected.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.schemas.definition.metadata import Metadata
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_metadata_preserves_unknown_keys() -> None:
|
||||
"""Unlike every other definition sub-schema, ``Metadata`` allows
|
||||
extra keys and round-trips them — that's its purpose."""
|
||||
metadata = Metadata.model_validate(
|
||||
{
|
||||
"tags": ["weekly", "report"],
|
||||
"owner": "tg",
|
||||
"created_by_ai": True,
|
||||
}
|
||||
)
|
||||
|
||||
dumped = metadata.model_dump()
|
||||
|
||||
assert dumped["tags"] == ["weekly", "report"]
|
||||
assert dumped["owner"] == "tg"
|
||||
assert dumped["created_by_ai"] is True
|
||||
|
||||
|
||||
def test_metadata_defaults_tags_to_empty_list() -> None:
|
||||
"""No tags is the common case; the default is the empty list so
|
||||
callers can append without a None check."""
|
||||
assert Metadata().tags == []
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
"""Lock the ``PlanStep`` validation contract."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.automations.schemas.definition.plan_step import PlanStep
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_plan_step_accepts_minimal_input_with_safe_defaults() -> None:
|
||||
"""A step with just ``step_id`` + ``action`` is valid. Defaults
|
||||
(no when, empty params, no output_as override, no retry/timeout
|
||||
override) let the run inherit automation-wide defaults."""
|
||||
step = PlanStep(step_id="s1", action="agent_task")
|
||||
|
||||
assert step.step_id == "s1"
|
||||
assert step.action == "agent_task"
|
||||
assert step.when is None
|
||||
assert step.params == {}
|
||||
assert step.output_as is None
|
||||
assert step.max_retries is None
|
||||
assert step.timeout_seconds is None
|
||||
|
||||
|
||||
def test_plan_step_rejects_empty_step_id_and_action() -> None:
|
||||
"""``step_id`` and ``action`` are addressing primitives — empty
|
||||
strings would silently break runtime lookups."""
|
||||
with pytest.raises(ValidationError):
|
||||
PlanStep(step_id="", action="agent_task")
|
||||
with pytest.raises(ValidationError):
|
||||
PlanStep(step_id="s1", action="")
|
||||
|
||||
|
||||
def test_plan_step_rejects_negative_max_retries_and_non_positive_timeout() -> None:
|
||||
"""Numeric constraints: ``max_retries >= 0`` and ``timeout_seconds > 0``.
|
||||
Negative budgets or zero timeouts produce nonsensical run behavior."""
|
||||
with pytest.raises(ValidationError):
|
||||
PlanStep(step_id="s1", action="agent_task", max_retries=-1)
|
||||
with pytest.raises(ValidationError):
|
||||
PlanStep(step_id="s1", action="agent_task", timeout_seconds=0)
|
||||
|
||||
|
||||
def test_plan_step_rejects_unknown_field() -> None:
|
||||
"""``extra='forbid'`` catches typos like ``actoin`` (instead of
|
||||
``action``) before the bad step reaches storage."""
|
||||
with pytest.raises(ValidationError):
|
||||
PlanStep.model_validate(
|
||||
{"step_id": "s1", "action": "agent_task", "actoin": "agent_task"}
|
||||
)
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
"""Lock the ``TriggerSpec`` validation contract — the entry shape used
|
||||
inside an automation's ``triggers[]`` array.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.automations.schemas.definition.trigger_spec import TriggerSpec
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_trigger_spec_accepts_type_with_default_empty_params() -> None:
|
||||
"""``type`` is required; ``params`` defaults to ``{}`` so triggers
|
||||
that take no params don't need an explicit body."""
|
||||
spec = TriggerSpec(type="schedule")
|
||||
|
||||
assert spec.type == "schedule"
|
||||
assert spec.params == {}
|
||||
|
||||
|
||||
def test_trigger_spec_rejects_empty_type() -> None:
|
||||
"""``type`` is the registry lookup key — empty would silently miss."""
|
||||
with pytest.raises(ValidationError):
|
||||
TriggerSpec(type="")
|
||||
|
||||
|
||||
def test_trigger_spec_rejects_unknown_field() -> None:
|
||||
"""``extra='forbid'`` catches typos at definition-validation time."""
|
||||
with pytest.raises(ValidationError):
|
||||
TriggerSpec.model_validate({"type": "schedule", "paramz": {}})
|
||||
|
|
@ -0,0 +1,493 @@
|
|||
"""Lock creation-time model-policy enforcement in ``AutomationService``.
|
||||
|
||||
Creation (REST + manual builder) rejects search spaces whose models aren't
|
||||
billable for automations with HTTP 422, mirroring the runtime backstop. These
|
||||
tests isolate the new ``_assert_models_billable`` / ``model_eligibility`` paths
|
||||
without touching the DB commit.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
import app.automations.services.automation as automation_mod
|
||||
from app.automations.schemas.api import AutomationCreate, AutomationUpdate
|
||||
from app.automations.schemas.definition.envelope import (
|
||||
AutomationDefinition,
|
||||
AutomationModels,
|
||||
)
|
||||
from app.automations.schemas.definition.plan_step import PlanStep
|
||||
from app.automations.services.automation import AutomationService
|
||||
from app.automations.services.model_policy import AutomationModelPolicyError
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, search_space: Any) -> None:
|
||||
self._search_space = search_space
|
||||
self.added: list[Any] = []
|
||||
self.commits = 0
|
||||
|
||||
async def get(self, _model: Any, _pk: int) -> Any:
|
||||
return self._search_space
|
||||
|
||||
def add(self, obj: Any) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.commits += 1
|
||||
|
||||
|
||||
def _service(search_space: Any) -> AutomationService:
|
||||
return AutomationService(
|
||||
session=_FakeSession(search_space), user=SimpleNamespace(id="u-1")
|
||||
)
|
||||
|
||||
|
||||
def _definition(**kwargs: Any) -> AutomationDefinition:
|
||||
return AutomationDefinition(
|
||||
name="A",
|
||||
plan=[PlanStep(step_id="s1", action="agent_task")],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def test_assert_models_billable_raises_422_on_violation(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A blocked model maps the policy error to HTTP 422."""
|
||||
|
||||
def _raise(_ss):
|
||||
raise AutomationModelPolicyError(
|
||||
[{"kind": "llm", "config_id": 0, "reason": "Auto mode"}]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_automation_models_billable", _raise)
|
||||
|
||||
service = _service(SimpleNamespace(agent_llm_id=0))
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service._assert_models_billable(1)
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
|
||||
|
||||
async def test_assert_models_billable_raises_404_when_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A missing search space is a 404, not a policy error."""
|
||||
monkeypatch.setattr(
|
||||
automation_mod, "assert_automation_models_billable", lambda _ss: None
|
||||
)
|
||||
|
||||
service = _service(None)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service._assert_models_billable(999)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
async def test_assert_models_billable_returns_search_space_when_ok(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When the policy accepts, the loaded search space is returned for reuse."""
|
||||
monkeypatch.setattr(
|
||||
automation_mod, "assert_automation_models_billable", lambda _ss: None
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(agent_llm_id=-1)
|
||||
service = _service(search_space)
|
||||
assert await service._assert_models_billable(1) is search_space
|
||||
|
||||
|
||||
async def test_create_injects_captured_models_from_search_space(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""create() snapshots the search space's model prefs onto the definition."""
|
||||
monkeypatch.setattr(
|
||||
automation_mod, "assert_automation_models_billable", lambda _ss: None
|
||||
)
|
||||
|
||||
async def _noop_authorize(self, *_a, **_k):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
|
||||
|
||||
async def _return_added(self, _aid):
|
||||
return self.session.added[-1]
|
||||
|
||||
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
|
||||
|
||||
search_space = SimpleNamespace(
|
||||
agent_llm_id=-1,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
)
|
||||
service = _service(search_space)
|
||||
payload = AutomationCreate(
|
||||
search_space_id=1,
|
||||
name="A",
|
||||
definition=_definition(),
|
||||
)
|
||||
|
||||
automation = await service.create(payload)
|
||||
|
||||
assert automation.definition["models"] == {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
}
|
||||
|
||||
|
||||
async def test_create_treats_unset_prefs_as_auto_zero(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""``None`` search-space prefs are captured as ``0`` (Auto) ids."""
|
||||
monkeypatch.setattr(
|
||||
automation_mod, "assert_automation_models_billable", lambda _ss: None
|
||||
)
|
||||
|
||||
async def _noop_authorize(self, *_a, **_k):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
|
||||
|
||||
async def _return_added(self, _aid):
|
||||
return self.session.added[-1]
|
||||
|
||||
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
|
||||
|
||||
search_space = SimpleNamespace(
|
||||
agent_llm_id=None,
|
||||
image_generation_config_id=None,
|
||||
vision_llm_config_id=None,
|
||||
)
|
||||
service = _service(search_space)
|
||||
payload = AutomationCreate(search_space_id=1, name="A", definition=_definition())
|
||||
|
||||
automation = await service.create(payload)
|
||||
|
||||
assert automation.definition["models"] == {
|
||||
"agent_llm_id": 0,
|
||||
"image_generation_config_id": 0,
|
||||
"vision_llm_config_id": 0,
|
||||
}
|
||||
|
||||
|
||||
async def test_create_honors_selected_models_when_provided(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When the payload carries ``definition.models`` they are validated + kept.
|
||||
|
||||
The search-space snapshot path is bypassed entirely (no
|
||||
``assert_automation_models_billable`` call).
|
||||
"""
|
||||
|
||||
def _fail_snapshot(_ss):
|
||||
raise AssertionError("snapshot path should not run when models are provided")
|
||||
|
||||
monkeypatch.setattr(
|
||||
automation_mod, "assert_automation_models_billable", _fail_snapshot
|
||||
)
|
||||
validated: dict[str, Any] = {}
|
||||
|
||||
def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
|
||||
validated["ids"] = (
|
||||
agent_llm_id,
|
||||
image_generation_config_id,
|
||||
vision_llm_config_id,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok)
|
||||
|
||||
async def _noop_authorize(self, *_a, **_k):
|
||||
return None
|
||||
|
||||
async def _return_added(self, _aid):
|
||||
return self.session.added[-1]
|
||||
|
||||
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
|
||||
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
|
||||
|
||||
service = _service(SimpleNamespace(agent_llm_id=-99))
|
||||
payload = AutomationCreate(
|
||||
search_space_id=1,
|
||||
name="A",
|
||||
definition=_definition(
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-1,
|
||||
image_generation_config_id=7,
|
||||
vision_llm_config_id=-2,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
automation = await service.create(payload)
|
||||
|
||||
assert validated["ids"] == (-1, 7, -2)
|
||||
assert automation.definition["models"] == {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 7,
|
||||
"vision_llm_config_id": -2,
|
||||
}
|
||||
|
||||
|
||||
async def test_create_rejects_unbillable_selected_models(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A non-billable explicit selection maps the policy error to HTTP 422."""
|
||||
|
||||
def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
|
||||
raise AutomationModelPolicyError(
|
||||
[{"kind": "llm", "config_id": -3, "reason": "free model"}]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_models_billable", _raise)
|
||||
|
||||
async def _noop_authorize(self, *_a, **_k):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
|
||||
|
||||
service = _service(SimpleNamespace(agent_llm_id=-3))
|
||||
payload = AutomationCreate(
|
||||
search_space_id=1,
|
||||
name="A",
|
||||
definition=_definition(
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-3,
|
||||
image_generation_config_id=7,
|
||||
vision_llm_config_id=-2,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.create(payload)
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
|
||||
|
||||
async def test_update_preserves_captured_models(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A definition edit carries over the previously captured ``models``."""
|
||||
captured = {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
}
|
||||
existing = SimpleNamespace(
|
||||
search_space_id=1,
|
||||
definition={"name": "A", "plan": [], "models": captured},
|
||||
version=3,
|
||||
)
|
||||
|
||||
async def _noop_authorize(self, *_a, **_k):
|
||||
return None
|
||||
|
||||
async def _return_existing(self, _aid):
|
||||
return existing
|
||||
|
||||
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
|
||||
monkeypatch.setattr(
|
||||
AutomationService, "_get_with_triggers_or_raise", _return_existing
|
||||
)
|
||||
|
||||
service = _service(SimpleNamespace())
|
||||
# The incoming patch definition has no ``models`` (frontend strips it).
|
||||
patch = AutomationUpdate(definition=_definition())
|
||||
|
||||
result = await service.update(7, patch)
|
||||
|
||||
assert result.definition["models"] == captured
|
||||
assert result.version == 4
|
||||
|
||||
|
||||
async def test_update_honors_changed_models_when_valid(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A definition edit with a *changed* models block validates + keeps it."""
|
||||
existing = SimpleNamespace(
|
||||
search_space_id=1,
|
||||
definition={
|
||||
"name": "A",
|
||||
"plan": [],
|
||||
"models": {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
},
|
||||
},
|
||||
version=3,
|
||||
)
|
||||
validated: dict[str, Any] = {}
|
||||
|
||||
def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
|
||||
validated["ids"] = (
|
||||
agent_llm_id,
|
||||
image_generation_config_id,
|
||||
vision_llm_config_id,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok)
|
||||
|
||||
async def _noop_authorize(self, *_a, **_k):
|
||||
return None
|
||||
|
||||
async def _return_existing(self, _aid):
|
||||
return existing
|
||||
|
||||
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
|
||||
monkeypatch.setattr(
|
||||
AutomationService, "_get_with_triggers_or_raise", _return_existing
|
||||
)
|
||||
|
||||
service = _service(SimpleNamespace())
|
||||
patch = AutomationUpdate(
|
||||
definition=_definition(
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-2,
|
||||
image_generation_config_id=9,
|
||||
vision_llm_config_id=-2,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
result = await service.update(7, patch)
|
||||
|
||||
assert validated["ids"] == (-2, 9, -2)
|
||||
assert result.definition["models"] == {
|
||||
"agent_llm_id": -2,
|
||||
"image_generation_config_id": 9,
|
||||
"vision_llm_config_id": -2,
|
||||
}
|
||||
assert result.version == 4
|
||||
|
||||
|
||||
async def test_update_rejects_changed_unbillable_models(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A *changed* non-billable models block is rejected with HTTP 422."""
|
||||
existing = SimpleNamespace(
|
||||
search_space_id=1,
|
||||
definition={
|
||||
"name": "A",
|
||||
"plan": [],
|
||||
"models": {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
},
|
||||
},
|
||||
version=3,
|
||||
)
|
||||
|
||||
def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
|
||||
raise AutomationModelPolicyError(
|
||||
[{"kind": "llm", "config_id": -7, "reason": "free model"}]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_models_billable", _raise)
|
||||
|
||||
async def _noop_authorize(self, *_a, **_k):
|
||||
return None
|
||||
|
||||
async def _return_existing(self, _aid):
|
||||
return existing
|
||||
|
||||
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
|
||||
monkeypatch.setattr(
|
||||
AutomationService, "_get_with_triggers_or_raise", _return_existing
|
||||
)
|
||||
|
||||
service = _service(SimpleNamespace())
|
||||
patch = AutomationUpdate(
|
||||
definition=_definition(
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-7,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.update(7, patch)
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
|
||||
|
||||
async def test_update_keeps_unchanged_models_without_revalidation(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""An unchanged models block is kept as-is and is NOT re-validated.
|
||||
|
||||
Lets users edit an automation whose captured model later drifted out of
|
||||
premium without an unrelated edit tripping the policy check.
|
||||
"""
|
||||
captured = {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
}
|
||||
existing = SimpleNamespace(
|
||||
search_space_id=1,
|
||||
definition={"name": "A", "plan": [], "models": captured},
|
||||
version=3,
|
||||
)
|
||||
|
||||
def _fail(*_a, **_k):
|
||||
raise AssertionError("unchanged models must not be re-validated")
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_models_billable", _fail)
|
||||
|
||||
async def _noop_authorize(self, *_a, **_k):
|
||||
return None
|
||||
|
||||
async def _return_existing(self, _aid):
|
||||
return existing
|
||||
|
||||
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
|
||||
monkeypatch.setattr(
|
||||
AutomationService, "_get_with_triggers_or_raise", _return_existing
|
||||
)
|
||||
|
||||
service = _service(SimpleNamespace())
|
||||
patch = AutomationUpdate(
|
||||
definition=_definition(models=AutomationModels(**captured))
|
||||
)
|
||||
|
||||
result = await service.update(7, patch)
|
||||
|
||||
assert result.definition["models"] == captured
|
||||
assert result.version == 4
|
||||
|
||||
|
||||
async def test_model_eligibility_authorizes_and_returns_payload(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""``model_eligibility`` checks read access then returns the eligibility dict."""
|
||||
authorized: dict[str, Any] = {}
|
||||
|
||||
async def _fake_check_permission(_session, _user, ss_id, permission, _msg):
|
||||
authorized["search_space_id"] = ss_id
|
||||
authorized["permission"] = permission
|
||||
|
||||
monkeypatch.setattr(automation_mod, "check_permission", _fake_check_permission)
|
||||
monkeypatch.setattr(
|
||||
automation_mod,
|
||||
"get_automation_model_eligibility",
|
||||
lambda _ss: {"allowed": False, "violations": [{"kind": "image"}]},
|
||||
)
|
||||
|
||||
service = _service(SimpleNamespace(agent_llm_id=-2))
|
||||
result = await service.model_eligibility(search_space_id=5)
|
||||
|
||||
assert result == {"allowed": False, "violations": [{"kind": "image"}]}
|
||||
assert authorized["search_space_id"] == 5
|
||||
assert authorized["permission"] == "automations:read"
|
||||
|
|
@ -0,0 +1,196 @@
|
|||
"""Lock the automation model-billing policy.
|
||||
|
||||
Automations may only run on billable models: premium global configs
|
||||
(``billing_tier == "premium"``) or user BYOK configs (positive id). Free
|
||||
globals and Auto mode (id == 0 / None) are blocked. These tests pin that rule
|
||||
across all three model slots (chat LLM, image, vision).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
import app.automations.services.model_policy as model_policy
|
||||
from app.automations.services.model_policy import (
|
||||
AutomationModelPolicyError,
|
||||
assert_automation_models_billable,
|
||||
assert_models_billable,
|
||||
get_automation_model_eligibility,
|
||||
get_model_eligibility,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _search_space(*, llm: int | None, image: int | None, vision: int | None):
|
||||
"""Minimal stand-in for the ``SearchSpace`` ORM row the policy reads."""
|
||||
return SimpleNamespace(
|
||||
agent_llm_id=llm,
|
||||
image_generation_config_id=image,
|
||||
vision_llm_config_id=vision,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_globals(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Stub the global config sources the policy consults for negative ids.
|
||||
|
||||
Negative ids: -1 is premium, -2 is free, for each of llm/image/vision.
|
||||
"""
|
||||
llm_configs = {
|
||||
-1: {"id": -1, "billing_tier": "premium"},
|
||||
-2: {"id": -2, "billing_tier": "free"},
|
||||
}
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.llm_config.load_global_llm_config_by_id",
|
||||
lambda cid: llm_configs.get(cid),
|
||||
)
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
monkeypatch.setattr(
|
||||
app_config,
|
||||
"GLOBAL_IMAGE_GEN_CONFIGS",
|
||||
[
|
||||
{"id": -1, "billing_tier": "premium"},
|
||||
{"id": -2, "billing_tier": "free"},
|
||||
],
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
app_config,
|
||||
"GLOBAL_VISION_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -1, "billing_tier": "premium"},
|
||||
{"id": -2, "billing_tier": "free"},
|
||||
],
|
||||
raising=False,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None:
|
||||
"""A positive config id is a user-owned BYOK model — always billable."""
|
||||
allowed, reason = model_policy._classify(kind, 7)
|
||||
assert allowed is True
|
||||
assert reason == ""
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
@pytest.mark.parametrize("config_id", [0, None])
|
||||
def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None:
|
||||
"""Auto mode (id 0) and an unset slot (None) are blocked."""
|
||||
allowed, reason = model_policy._classify(kind, config_id)
|
||||
assert allowed is False
|
||||
assert "Auto mode" in reason
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
def test_premium_global_is_allowed(kind: str, patched_globals) -> None:
|
||||
"""A negative (global) id with premium billing tier is allowed."""
|
||||
allowed, reason = model_policy._classify(kind, -1)
|
||||
assert allowed is True
|
||||
assert reason == ""
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
def test_free_global_is_blocked(kind: str, patched_globals) -> None:
|
||||
"""A negative (global) id with a free billing tier is blocked."""
|
||||
allowed, reason = model_policy._classify(kind, -2)
|
||||
assert allowed is False
|
||||
assert "free model" in reason
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
def test_unknown_global_id_is_blocked(kind: str, patched_globals) -> None:
|
||||
"""A negative id that resolves to no config is treated as not premium."""
|
||||
allowed, _ = model_policy._classify(kind, -999)
|
||||
assert allowed is False
|
||||
|
||||
|
||||
def test_eligibility_all_billable(patched_globals) -> None:
|
||||
"""Premium LLM + BYOK image + premium vision → allowed, no violations."""
|
||||
search_space = _search_space(llm=-1, image=5, vision=-1)
|
||||
result = get_automation_model_eligibility(search_space)
|
||||
assert result == {"allowed": True, "violations": []}
|
||||
|
||||
|
||||
def test_eligibility_reports_each_violation(patched_globals) -> None:
|
||||
"""A free LLM, Auto image, and free vision each produce a violation."""
|
||||
search_space = _search_space(llm=-2, image=0, vision=-2)
|
||||
result = get_automation_model_eligibility(search_space)
|
||||
|
||||
assert result["allowed"] is False
|
||||
kinds = {v["kind"] for v in result["violations"]}
|
||||
assert kinds == {"llm", "image", "vision"}
|
||||
# config_id is echoed back for the UI / settings deep-link.
|
||||
by_kind = {v["kind"]: v["config_id"] for v in result["violations"]}
|
||||
assert by_kind == {"llm": -2, "image": 0, "vision": -2}
|
||||
|
||||
|
||||
def test_assert_raises_with_violations(patched_globals) -> None:
|
||||
"""``assert_automation_models_billable`` raises when any slot is blocked."""
|
||||
search_space = _search_space(llm=0, image=5, vision=-1)
|
||||
with pytest.raises(AutomationModelPolicyError) as exc_info:
|
||||
assert_automation_models_billable(search_space)
|
||||
|
||||
assert len(exc_info.value.violations) == 1
|
||||
assert exc_info.value.violations[0]["kind"] == "llm"
|
||||
|
||||
|
||||
def test_assert_passes_when_all_billable(patched_globals) -> None:
|
||||
"""No exception when every slot is premium or BYOK."""
|
||||
search_space = _search_space(llm=3, image=-1, vision=4)
|
||||
assert assert_automation_models_billable(search_space) is None
|
||||
|
||||
|
||||
# --- ID-based core (used by the runtime backstop against captured snapshots) ---
|
||||
|
||||
|
||||
def test_get_model_eligibility_all_billable(patched_globals) -> None:
|
||||
"""Premium LLM + BYOK image + premium vision (explicit ids) → allowed."""
|
||||
result = get_model_eligibility(
|
||||
agent_llm_id=-1, image_generation_config_id=5, vision_llm_config_id=-1
|
||||
)
|
||||
assert result == {"allowed": True, "violations": []}
|
||||
|
||||
|
||||
def test_get_model_eligibility_reports_each_violation(patched_globals) -> None:
|
||||
"""Free LLM, Auto image, free vision (explicit ids) each produce a violation."""
|
||||
result = get_model_eligibility(
|
||||
agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2
|
||||
)
|
||||
assert result["allowed"] is False
|
||||
by_kind = {v["kind"]: v["config_id"] for v in result["violations"]}
|
||||
assert by_kind == {"llm": -2, "image": 0, "vision": -2}
|
||||
|
||||
|
||||
def test_assert_models_billable_raises(patched_globals) -> None:
|
||||
"""``assert_models_billable`` raises when any explicit id is blocked."""
|
||||
with pytest.raises(AutomationModelPolicyError) as exc_info:
|
||||
assert_models_billable(
|
||||
agent_llm_id=0, image_generation_config_id=5, vision_llm_config_id=-1
|
||||
)
|
||||
assert len(exc_info.value.violations) == 1
|
||||
assert exc_info.value.violations[0]["kind"] == "llm"
|
||||
|
||||
|
||||
def test_assert_models_billable_passes(patched_globals) -> None:
|
||||
"""No exception when every explicit id is premium or BYOK."""
|
||||
assert (
|
||||
assert_models_billable(
|
||||
agent_llm_id=3, image_generation_config_id=-1, vision_llm_config_id=4
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
def test_search_space_wrapper_delegates_to_core(patched_globals) -> None:
|
||||
"""The search-space wrapper produces the same result as the ID core."""
|
||||
search_space = _search_space(llm=-2, image=0, vision=-2)
|
||||
assert get_automation_model_eligibility(search_space) == get_model_eligibility(
|
||||
agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2
|
||||
)
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
"""Lock the ``{run, inputs, steps}`` namespace exposed to every template."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.templating.context import build_run_context
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_build_run_context_exposes_run_inputs_and_steps_namespaces() -> None:
|
||||
"""The namespace handed to templates groups run metadata under ``run``,
|
||||
runtime + static inputs under ``inputs``, and step outputs (keyed by
|
||||
``output_as`` / ``step_id``) under ``steps``. Locks the contract that
|
||||
every plan template body relies on."""
|
||||
creator = UUID("00000000-0000-0000-0000-000000000001")
|
||||
started = datetime(2026, 5, 28, 14, 30, tzinfo=UTC)
|
||||
|
||||
ctx = build_run_context(
|
||||
run_id=42,
|
||||
automation_id=7,
|
||||
automation_name="Weekly digest",
|
||||
automation_version=3,
|
||||
search_space_id=1,
|
||||
creator_id=creator,
|
||||
trigger_id=11,
|
||||
trigger_type="schedule",
|
||||
started_at=started,
|
||||
attempt=2,
|
||||
inputs={"topic": "weekly"},
|
||||
step_outputs={"summarize": {"text": "ok"}},
|
||||
)
|
||||
|
||||
assert ctx == {
|
||||
"run": {
|
||||
"id": 42,
|
||||
"automation_id": 7,
|
||||
"automation_name": "Weekly digest",
|
||||
"automation_version": 3,
|
||||
"search_space_id": 1,
|
||||
"creator_id": creator,
|
||||
"trigger_id": 11,
|
||||
"trigger_type": "schedule",
|
||||
"started_at": started,
|
||||
"attempt": 2,
|
||||
},
|
||||
"inputs": {"topic": "weekly"},
|
||||
"steps": {"summarize": {"text": "ok"}},
|
||||
}
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
"""Lock the sandbox boundary: disallowed filters/tests reject, finalize coerces non-strings.
|
||||
|
||||
These behaviors live in ``environment.py`` but are observed through the
|
||||
public ``render_template`` surface — the same surface every step uses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
from jinja2.exceptions import TemplateError
|
||||
|
||||
from app.automations.templating.render import render_template
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_environment_rejects_filters_not_in_the_allowlist() -> None:
|
||||
"""A template that pipes through a Jinja built-in **not** in the
|
||||
allowlist (e.g. ``pprint``) must fail rather than rendering. Locks
|
||||
the sandbox surface against accidental re-introduction of removed
|
||||
filters."""
|
||||
with pytest.raises(TemplateError):
|
||||
render_template("{{ value | pprint }}", {"value": {"k": 1}})
|
||||
|
||||
|
||||
def test_environment_finalizes_datetime_output_to_iso_string() -> None:
|
||||
"""A datetime that lands directly at an output site is stringified
|
||||
via ``isoformat()`` rather than producing ``str(datetime)`` (which
|
||||
has a space separator). Locks the wire shape templates produce
|
||||
when emitting ``inputs.fired_at`` and other datetime values."""
|
||||
dt = datetime(2026, 5, 28, 14, 30, tzinfo=UTC)
|
||||
|
||||
assert (
|
||||
render_template("{{ moment }}", {"moment": dt}) == "2026-05-28T14:30:00+00:00"
|
||||
)
|
||||
|
||||
|
||||
def test_environment_finalizes_none_output_to_empty_string() -> None:
|
||||
"""A ``None`` at an output site becomes the empty string. Lets
|
||||
templates write ``{{ inputs.last_fired_at }}`` unconditionally on
|
||||
the first run without exploding on the null."""
|
||||
assert render_template("{{ missing }}", {"missing": None}) == ""
|
||||
|
||||
|
||||
def test_environment_finalizes_dict_output_to_json() -> None:
|
||||
"""A dict at an output site is JSON-serialized. Same for lists.
|
||||
Locks the wire shape so users embedding structured values into
|
||||
prompts get deterministic, parseable output."""
|
||||
rendered = render_template("{{ payload }}", {"payload": {"a": 1, "b": [2, 3]}})
|
||||
|
||||
assert rendered == '{"a": 1, "b": [2, 3]}'
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
"""Lock the custom Jinja filters: ``date`` and ``slugify``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.templating.filters import filter_date, filter_slugify
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_filter_slugify_produces_url_safe_slug_from_typical_title() -> None:
|
||||
"""``filter_slugify`` lowercases, replaces non-alphanumerics with
|
||||
hyphens, collapses repeats, and trims edge hyphens — the standard
|
||||
URL-slug contract users expect when piping titles into paths."""
|
||||
assert filter_slugify("Hello, World! 2026") == "hello-world-2026"
|
||||
|
||||
|
||||
def test_filter_date_formats_datetime_with_strftime_format() -> None:
|
||||
"""``filter_date`` calls ``strftime`` on datetime-like values with the
|
||||
provided format. Default format yields ISO date (YYYY-MM-DD)."""
|
||||
dt = datetime(2026, 5, 28, 14, 30, tzinfo=UTC)
|
||||
|
||||
assert filter_date(dt) == "2026-05-28"
|
||||
assert filter_date(dt, "%Y/%m/%d %H:%M") == "2026/05/28 14:30"
|
||||
|
||||
|
||||
def test_filter_date_returns_empty_string_for_none() -> None:
|
||||
"""``None`` (e.g., a never-fired ``last_fired_at``) renders as the
|
||||
empty string rather than the literal ``"None"`` or raising. This is
|
||||
what lets templates write ``{{ inputs.last_fired_at | date }}``
|
||||
unconditionally on the first run."""
|
||||
assert filter_date(None) == ""
|
||||
|
||||
|
||||
def test_filter_date_passes_strings_through_unchanged() -> None:
|
||||
"""Already-formatted ISO strings (the JSON-serialized shape of
|
||||
runtime inputs like ``fired_at``) pass through unchanged so callers
|
||||
don't have to special-case the type."""
|
||||
assert filter_date("2026-05-28T14:30:00+00:00") == "2026-05-28T14:30:00+00:00"
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
"""Lock the public template-rendering surface: render, predicate, recursive."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from jinja2 import UndefinedError
|
||||
|
||||
from app.automations.templating.render import (
|
||||
evaluate_predicate,
|
||||
render_template,
|
||||
render_value,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_render_template_substitutes_context_variables() -> None:
|
||||
"""A template referencing a context variable produces the substituted
|
||||
string. Most basic contract of the template engine."""
|
||||
result = render_template("Hello {{ name }}!", {"name": "World"})
|
||||
|
||||
assert result == "Hello World!"
|
||||
|
||||
|
||||
def test_render_template_raises_on_undefined_variable() -> None:
|
||||
"""Referencing a variable that isn't in the context raises rather than
|
||||
rendering the empty string. Locks the StrictUndefined safety net so
|
||||
template typos surface as run failures instead of silent corruption."""
|
||||
with pytest.raises(UndefinedError):
|
||||
render_template("Hello {{ missing }}!", {})
|
||||
|
||||
|
||||
def test_evaluate_predicate_returns_truthy_outcome_of_expression() -> None:
|
||||
"""``evaluate_predicate`` compiles a Jinja **expression** (not template
|
||||
body) and coerces the value to ``bool``. Drives ``step.when`` gating."""
|
||||
assert evaluate_predicate("inputs.count > 0", {"inputs": {"count": 3}}) is True
|
||||
assert evaluate_predicate("inputs.count > 0", {"inputs": {"count": 0}}) is False
|
||||
|
||||
|
||||
def test_render_value_renders_strings_recursively_through_dicts_and_lists() -> None:
|
||||
"""``render_value`` walks dicts and lists, renders string leaves through
|
||||
the template engine, and leaves non-strings untouched. This is the
|
||||
primitive ``execute_step`` uses to render step params at run time."""
|
||||
context = {"inputs": {"name": "World"}, "topic": "weekly"}
|
||||
|
||||
rendered = render_value(
|
||||
{
|
||||
"greeting": "Hello {{ inputs.name }}",
|
||||
"tags": ["{{ topic }}", "static"],
|
||||
"config": {"retries": 3, "label": "{{ topic }}-{{ inputs.name }}"},
|
||||
},
|
||||
context,
|
||||
)
|
||||
|
||||
assert rendered == {
|
||||
"greeting": "Hello World",
|
||||
"tags": ["weekly", "static"],
|
||||
"config": {"retries": 3, "label": "weekly-World"},
|
||||
}
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
"""Lock the ``params_schema`` derivation on action + trigger definitions.
|
||||
|
||||
Both definition dataclasses expose ``params_schema`` as the JSON Schema
|
||||
of their ``params_model``. This is what the registry endpoints surface
|
||||
to the UI as the "what shape do these params take?" contract.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.automations.actions.types import ActionDefinition
|
||||
from app.automations.triggers.types import TriggerDefinition
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _Topic(BaseModel):
|
||||
"""Model with one required string field — minimal schema fingerprint."""
|
||||
|
||||
topic: str
|
||||
|
||||
|
||||
def test_action_definition_params_schema_reflects_params_model() -> None:
|
||||
"""``ActionDefinition.params_schema`` returns a JSON Schema derived
|
||||
from the Pydantic ``params_model`` — required fields and types are
|
||||
visible to clients consuming the registry endpoint."""
|
||||
definition = ActionDefinition(
|
||||
type="t",
|
||||
name="N",
|
||||
description="D",
|
||||
params_model=_Topic,
|
||||
build_handler=lambda _ctx: lambda _p: {}, # type: ignore[arg-type,return-value]
|
||||
)
|
||||
|
||||
schema = definition.params_schema
|
||||
|
||||
assert schema["type"] == "object"
|
||||
assert schema["properties"]["topic"]["type"] == "string"
|
||||
assert "topic" in schema["required"]
|
||||
|
||||
|
||||
def test_trigger_definition_params_schema_reflects_params_model() -> None:
|
||||
"""Same JSON-Schema derivation contract on the trigger side."""
|
||||
definition = TriggerDefinition(
|
||||
type="t",
|
||||
description="D",
|
||||
params_model=_Topic,
|
||||
)
|
||||
|
||||
schema = definition.params_schema
|
||||
|
||||
assert schema["type"] == "object"
|
||||
assert schema["properties"]["topic"]["type"] == "string"
|
||||
assert "topic" in schema["required"]
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
"""Lock the bundled import side-effects.
|
||||
|
||||
Importing ``app.automations`` (the package) registers the v1 bundled
|
||||
action (``agent_task``) and the v1 bundled trigger (``schedule``). If the
|
||||
import chain breaks (e.g. someone removes ``from . import definition``
|
||||
in a sub-package ``__init__``), the system would silently launch with an
|
||||
empty registry. These tests are the canary.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
import app.automations # noqa: F401 (force the package import + its side-effects)
|
||||
from app.automations.actions.store import get_action
|
||||
from app.automations.persistence.enums.trigger_type import TriggerType
|
||||
from app.automations.triggers.store import get_trigger
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_bundled_agent_task_action_is_registered_after_package_import() -> None:
|
||||
"""``agent_task`` — the v1 default action — must be discoverable in
|
||||
the registry after the package is imported."""
|
||||
definition = get_action("agent_task")
|
||||
|
||||
assert definition is not None
|
||||
assert definition.type == "agent_task"
|
||||
|
||||
|
||||
def test_bundled_schedule_trigger_is_registered_after_package_import() -> None:
|
||||
"""``schedule`` — the only v1 trigger — must be discoverable in the
|
||||
registry after the package is imported."""
|
||||
definition = get_trigger(TriggerType.SCHEDULE.value)
|
||||
|
||||
assert definition is not None
|
||||
assert definition.type == TriggerType.SCHEDULE.value
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
"""Lock the persistence enum string values + members.
|
||||
|
||||
These enums are mirrored by Postgres enum types, embedded in stored DB
|
||||
rows, and surfaced in the JSON API. Renaming a value (or removing a
|
||||
member) silently breaks production data and previously-issued API
|
||||
responses, so the strings + the set of members are the contract.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.persistence.enums.automation_status import AutomationStatus
|
||||
from app.automations.persistence.enums.run_status import RunStatus
|
||||
from app.automations.persistence.enums.trigger_type import TriggerType
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_automation_status_string_values_are_stable() -> None:
|
||||
"""The exact strings persisted to Postgres and served in API JSON."""
|
||||
assert {member.value for member in AutomationStatus} == {
|
||||
"active",
|
||||
"paused",
|
||||
"archived",
|
||||
}
|
||||
|
||||
|
||||
def test_run_status_string_values_are_stable() -> None:
|
||||
"""Run lifecycle states embedded in the ``automation_runs`` table."""
|
||||
assert {member.value for member in RunStatus} == {
|
||||
"pending",
|
||||
"running",
|
||||
"succeeded",
|
||||
"failed",
|
||||
"cancelled",
|
||||
"timed_out",
|
||||
}
|
||||
|
||||
|
||||
def test_trigger_type_keeps_manual_member_even_though_unregistered() -> None:
|
||||
"""``schedule`` and ``event`` are registered; ``MANUAL`` is reserved
|
||||
(mirrors the Postgres enum) but the trigger store does not register it.
|
||||
The enum must keep every member so DB rows and migrations stay valid."""
|
||||
assert {member.value for member in TriggerType} == {"schedule", "event", "manual"}
|
||||
117
surfsense_backend/tests/unit/automations/test_stores.py
Normal file
117
surfsense_backend/tests/unit/automations/test_stores.py
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
"""Lock the trigger + action registry contracts.
|
||||
|
||||
Both stores share the same API shape (register/get/all + duplicate-raise),
|
||||
so they're tested together to keep the contract visible side-by-side.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.automations.actions.store import (
|
||||
get_action,
|
||||
register_action,
|
||||
)
|
||||
from app.automations.actions.types import ActionDefinition
|
||||
from app.automations.triggers.store import (
|
||||
all_triggers,
|
||||
get_trigger,
|
||||
register_trigger,
|
||||
)
|
||||
from app.automations.triggers.types import TriggerDefinition
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _Params(BaseModel):
|
||||
"""Empty params model used by test-only registrations."""
|
||||
|
||||
|
||||
def _trigger(type_: str = "test_trigger") -> TriggerDefinition:
|
||||
return TriggerDefinition(
|
||||
type=type_, description="Test trigger.", params_model=_Params
|
||||
)
|
||||
|
||||
|
||||
def _action(type_: str = "test_action") -> ActionDefinition:
|
||||
return ActionDefinition(
|
||||
type=type_,
|
||||
name="Test",
|
||||
description="Test action.",
|
||||
params_model=_Params,
|
||||
build_handler=lambda _ctx: lambda _p: {}, # type: ignore[arg-type,return-value]
|
||||
)
|
||||
|
||||
|
||||
def test_register_trigger_then_get_trigger_returns_the_same_definition(
|
||||
isolated_trigger_registry: None,
|
||||
) -> None:
|
||||
"""The canonical round-trip: register, look up by type, get the same
|
||||
definition back. Locks the basic registry contract."""
|
||||
definition = _trigger()
|
||||
register_trigger(definition)
|
||||
|
||||
assert get_trigger("test_trigger") is definition
|
||||
|
||||
|
||||
def test_register_action_then_get_action_returns_the_same_definition(
|
||||
isolated_action_registry: None,
|
||||
) -> None:
|
||||
"""Same round-trip contract for the action registry."""
|
||||
definition = _action()
|
||||
register_action(definition)
|
||||
|
||||
assert get_action("test_action") is definition
|
||||
|
||||
|
||||
def test_get_trigger_returns_none_for_unknown_type(
|
||||
isolated_trigger_registry: None,
|
||||
) -> None:
|
||||
"""An unknown type returns ``None`` (not raises). Lets callers like
|
||||
the dispatcher branch on "is this trigger still registered?" without
|
||||
try/except."""
|
||||
assert get_trigger("never_registered") is None
|
||||
|
||||
|
||||
def test_get_action_returns_none_for_unknown_type(
|
||||
isolated_action_registry: None,
|
||||
) -> None:
|
||||
"""Same ``None``-not-raise contract on the action side."""
|
||||
assert get_action("never_registered") is None
|
||||
|
||||
|
||||
def test_register_trigger_rejects_duplicate_type(
|
||||
isolated_trigger_registry: None,
|
||||
) -> None:
|
||||
"""Re-registering the same ``type`` raises rather than silently
|
||||
overwriting. Locks the safety net against accidental double-import
|
||||
(e.g., circular imports re-running the registration block)."""
|
||||
register_trigger(_trigger())
|
||||
|
||||
with pytest.raises(ValueError, match="test_trigger"):
|
||||
register_trigger(_trigger())
|
||||
|
||||
|
||||
def test_register_action_rejects_duplicate_type(
|
||||
isolated_action_registry: None,
|
||||
) -> None:
|
||||
"""Same duplicate-rejection contract on the action side."""
|
||||
register_action(_action())
|
||||
|
||||
with pytest.raises(ValueError, match="test_action"):
|
||||
register_action(_action())
|
||||
|
||||
|
||||
def test_all_triggers_returns_defensive_snapshot(
|
||||
isolated_trigger_registry: None,
|
||||
) -> None:
|
||||
"""``all_triggers()`` returns a copy: mutating the returned dict does
|
||||
not corrupt the internal registry. Locks the snapshot contract that
|
||||
UI/listing endpoints rely on."""
|
||||
register_trigger(_trigger("snapshot_test"))
|
||||
|
||||
snapshot = all_triggers()
|
||||
snapshot.pop("snapshot_test")
|
||||
|
||||
assert get_trigger("snapshot_test") is not None
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
"""The ``event`` trigger self-registers on the triggers store at import."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.triggers import get_trigger
|
||||
from app.automations.triggers.builtin.event.params import EventTriggerParams
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_event_trigger_is_registered() -> None:
|
||||
definition = get_trigger("event")
|
||||
|
||||
assert definition is not None
|
||||
assert definition.type == "event"
|
||||
assert definition.params_model is EventTriggerParams
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
"""Behavior tests for the ``matches`` filter grammar."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.triggers.builtin.event.filter import FilterError, matches
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_empty_filter_matches_any_payload() -> None:
|
||||
assert matches({}, {"document_id": 42, "document_type": "FILE"}) is True
|
||||
assert matches({}, {}) is True
|
||||
|
||||
|
||||
def test_scalar_value_is_implicit_equality() -> None:
|
||||
flt = {"document_type": "FILE"}
|
||||
assert matches(flt, {"document_type": "FILE"}) is True
|
||||
assert matches(flt, {"document_type": "WEBPAGE"}) is False
|
||||
|
||||
|
||||
def test_multiple_fields_are_anded() -> None:
|
||||
flt = {"document_type": "FILE", "search_space_id": 7}
|
||||
assert matches(flt, {"document_type": "FILE", "search_space_id": 7}) is True
|
||||
assert matches(flt, {"document_type": "FILE", "search_space_id": 9}) is False
|
||||
|
||||
|
||||
def test_gt_operator_compares_greater_than() -> None:
|
||||
flt = {"page_count": {"$gt": 10}}
|
||||
assert matches(flt, {"page_count": 20}) is True
|
||||
assert matches(flt, {"page_count": 10}) is False
|
||||
assert matches(flt, {"page_count": 5}) is False
|
||||
|
||||
|
||||
def test_remaining_comparison_operators() -> None:
|
||||
assert matches({"n": {"$gte": 10}}, {"n": 10}) is True
|
||||
assert matches({"n": {"$gte": 10}}, {"n": 9}) is False
|
||||
|
||||
assert matches({"n": {"$lt": 10}}, {"n": 9}) is True
|
||||
assert matches({"n": {"$lt": 10}}, {"n": 10}) is False
|
||||
|
||||
assert matches({"n": {"$lte": 10}}, {"n": 10}) is True
|
||||
assert matches({"n": {"$lte": 10}}, {"n": 11}) is False
|
||||
|
||||
assert matches({"s": {"$eq": "FILE"}}, {"s": "FILE"}) is True
|
||||
assert matches({"s": {"$eq": "FILE"}}, {"s": "WEB"}) is False
|
||||
|
||||
assert matches({"s": {"$ne": "FILE"}}, {"s": "WEB"}) is True
|
||||
assert matches({"s": {"$ne": "FILE"}}, {"s": "FILE"}) is False
|
||||
|
||||
|
||||
def test_multiple_operators_on_one_field_are_anded() -> None:
|
||||
flt = {"n": {"$gte": 10, "$lt": 20}}
|
||||
assert matches(flt, {"n": 15}) is True
|
||||
assert matches(flt, {"n": 10}) is True
|
||||
assert matches(flt, {"n": 20}) is False
|
||||
assert matches(flt, {"n": 5}) is False
|
||||
|
||||
|
||||
def test_in_and_nin_membership_operators() -> None:
|
||||
flt_in = {"document_type": {"$in": ["FILE", "WEBPAGE"]}}
|
||||
assert matches(flt_in, {"document_type": "FILE"}) is True
|
||||
assert matches(flt_in, {"document_type": "SLACK"}) is False
|
||||
|
||||
flt_nin = {"document_type": {"$nin": ["FILE", "WEBPAGE"]}}
|
||||
assert matches(flt_nin, {"document_type": "SLACK"}) is True
|
||||
assert matches(flt_nin, {"document_type": "FILE"}) is False
|
||||
|
||||
|
||||
def test_or_matches_when_any_branch_holds() -> None:
|
||||
flt = {"$or": [{"document_type": "FILE"}, {"document_type": "WEBPAGE"}]}
|
||||
assert matches(flt, {"document_type": "WEBPAGE"}) is True
|
||||
assert matches(flt, {"document_type": "SLACK"}) is False
|
||||
|
||||
|
||||
def test_and_matches_when_every_branch_holds() -> None:
|
||||
flt = {"$and": [{"n": {"$gt": 5}}, {"n": {"$lt": 10}}]}
|
||||
assert matches(flt, {"n": 7}) is True
|
||||
assert matches(flt, {"n": 12}) is False
|
||||
|
||||
|
||||
def test_not_inverts_its_subexpression() -> None:
|
||||
flt = {"$not": {"document_type": "FILE"}}
|
||||
assert matches(flt, {"document_type": "WEBPAGE"}) is True
|
||||
assert matches(flt, {"document_type": "FILE"}) is False
|
||||
|
||||
|
||||
def test_missing_field_never_matches_and_never_raises() -> None:
|
||||
# Conservative: an absent field fails the constraint, and comparisons must
|
||||
# not raise on the missing value — including $ne (absence isn't "not equal").
|
||||
assert matches({"document_type": "FILE"}, {}) is False
|
||||
assert matches({"page_count": {"$gt": 5}}, {}) is False
|
||||
assert matches({"document_type": {"$in": ["FILE"]}}, {}) is False
|
||||
assert matches({"document_type": {"$ne": "FILE"}}, {}) is False
|
||||
|
||||
|
||||
def test_logical_operators_compose_with_fields() -> None:
|
||||
flt = {
|
||||
"search_space_id": 7,
|
||||
"$or": [{"document_type": "FILE"}, {"document_type": "WEBPAGE"}],
|
||||
}
|
||||
assert matches(flt, {"search_space_id": 7, "document_type": "FILE"}) is True
|
||||
assert matches(flt, {"search_space_id": 9, "document_type": "FILE"}) is False
|
||||
assert matches(flt, {"search_space_id": 7, "document_type": "SLACK"}) is False
|
||||
|
||||
|
||||
def test_unknown_field_operator_raises_filter_error() -> None:
|
||||
with pytest.raises(FilterError):
|
||||
matches({"n": {"$regex": "x"}}, {"n": "xyz"})
|
||||
|
||||
|
||||
def test_unknown_logical_operator_raises_filter_error() -> None:
|
||||
with pytest.raises(FilterError):
|
||||
matches({"$nor": [{"document_type": "FILE"}]}, {"document_type": "FILE"})
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
"""An event hands its payload + metadata to the run as inputs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.triggers.builtin.event.inputs import event_runtime_inputs
|
||||
from app.event_bus import Event
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_runtime_inputs_flatten_payload_with_event_metadata() -> None:
|
||||
event = Event(
|
||||
event_type="document.indexed",
|
||||
payload={"document_id": 42, "document_type": "FILE"},
|
||||
search_space_id=7,
|
||||
)
|
||||
|
||||
inputs = event_runtime_inputs(event)
|
||||
|
||||
assert inputs["document_id"] == 42
|
||||
assert inputs["document_type"] == "FILE"
|
||||
assert inputs["event_type"] == "document.indexed"
|
||||
assert inputs["event_id"] == event.event_id
|
||||
assert inputs["occurred_at"] == event.occurred_at.isoformat()
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
"""Which triggers an event fires: event_type equality + filter match."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.triggers.builtin.event.match import trigger_matches_event
|
||||
from app.event_bus import Event
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _event(event_type: str = "document.indexed", **payload) -> Event:
|
||||
return Event(event_type=event_type, payload=payload, search_space_id=7)
|
||||
|
||||
|
||||
def test_matches_when_event_type_equal_and_filter_passes() -> None:
|
||||
params = {"event_type": "document.indexed", "filter": {"document_type": "FILE"}}
|
||||
assert trigger_matches_event(params, _event(document_type="FILE")) is True
|
||||
|
||||
|
||||
def test_no_match_when_event_type_differs() -> None:
|
||||
params = {"event_type": "document.indexed", "filter": {}}
|
||||
assert trigger_matches_event(params, _event("podcast.generated")) is False
|
||||
|
||||
|
||||
def test_no_match_when_filter_rejects_payload() -> None:
|
||||
params = {"event_type": "document.indexed", "filter": {"document_type": "FILE"}}
|
||||
assert trigger_matches_event(params, _event(document_type="WEBPAGE")) is False
|
||||
|
||||
|
||||
def test_empty_filter_matches_any_payload_of_that_type() -> None:
|
||||
params = {"event_type": "document.indexed", "filter": {}}
|
||||
assert trigger_matches_event(params, _event(document_type="ANYTHING")) is True
|
||||
|
||||
|
||||
def test_missing_filter_key_is_treated_as_empty() -> None:
|
||||
params = {"event_type": "document.indexed"}
|
||||
assert trigger_matches_event(params, _event(document_type="X")) is True
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
"""``EventTriggerParams`` contract: an event_type to listen for + an optional filter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.triggers.builtin.event.params import EventTriggerParams
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_accepts_event_type_and_filter() -> None:
|
||||
params = EventTriggerParams(
|
||||
event_type="document.indexed",
|
||||
filter={"document_type": "FILE"},
|
||||
)
|
||||
|
||||
assert params.event_type == "document.indexed"
|
||||
assert params.filter == {"document_type": "FILE"}
|
||||
|
||||
|
||||
def test_filter_defaults_to_empty() -> None:
|
||||
params = EventTriggerParams(event_type="document.indexed")
|
||||
|
||||
assert params.filter == {}
|
||||
|
||||
|
||||
def test_event_type_is_required() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
EventTriggerParams(filter={"x": 1})
|
||||
|
||||
|
||||
def test_event_type_must_not_be_blank() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
EventTriggerParams(event_type="")
|
||||
|
||||
|
||||
def test_extra_keys_are_forbidden() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
EventTriggerParams(event_type="document.indexed", typo=True)
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
"""Lock the cron + timezone + UTC normalization contract."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from app.automations.triggers.builtin.schedule.cron import (
|
||||
InvalidCronError,
|
||||
compute_next_fire_at,
|
||||
validate_cron,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_compute_next_fire_at_returns_next_match_normalized_to_utc() -> None:
|
||||
"""``compute_next_fire_at`` evaluates the cron in the given IANA timezone
|
||||
and returns the next strictly-later match expressed in UTC.
|
||||
|
||||
Setup: ``0 9 * * 1-5`` (09:00 Monday-Friday) in ``Africa/Kigali``
|
||||
(UTC+2, no DST). With ``after`` = Tuesday 05:00 UTC (= 07:00 local),
|
||||
the next fire is the same Tuesday at 09:00 local = 07:00 UTC.
|
||||
"""
|
||||
after = datetime(2026, 5, 26, 5, 0, tzinfo=UTC) # Tue 07:00 Kigali
|
||||
|
||||
next_fire = compute_next_fire_at("0 9 * * 1-5", "Africa/Kigali", after=after)
|
||||
|
||||
assert next_fire == datetime(2026, 5, 26, 7, 0, tzinfo=UTC)
|
||||
|
||||
|
||||
def test_compute_next_fire_at_respects_dst_offset_change() -> None:
|
||||
"""A daily cron in a DST-observing tz fires at the same local hour
|
||||
across the DST boundary, which produces a different UTC offset on
|
||||
either side of the transition.
|
||||
|
||||
Setup: ``0 9 * * *`` (09:00 every day) in ``America/New_York``.
|
||||
NY is UTC-5 in winter (EST), UTC-4 in summer (EDT). Evaluating from
|
||||
each side of the spring-forward in 2026 (Sun Mar 8 at 02:00 → 03:00):
|
||||
|
||||
- winter: ``after`` = 2026-02-15 (EST, UTC-5) → next 09:00 EST = 14:00 UTC
|
||||
- summer: ``after`` = 2026-04-15 (EDT, UTC-4) → next 09:00 EDT = 13:00 UTC
|
||||
"""
|
||||
winter_after = datetime(2026, 2, 15, 0, 0, tzinfo=UTC)
|
||||
summer_after = datetime(2026, 4, 15, 0, 0, tzinfo=UTC)
|
||||
|
||||
winter_fire = compute_next_fire_at(
|
||||
"0 9 * * *", "America/New_York", after=winter_after
|
||||
)
|
||||
summer_fire = compute_next_fire_at(
|
||||
"0 9 * * *", "America/New_York", after=summer_after
|
||||
)
|
||||
|
||||
assert winter_fire == datetime(2026, 2, 15, 14, 0, tzinfo=UTC)
|
||||
assert summer_fire == datetime(2026, 4, 15, 13, 0, tzinfo=UTC)
|
||||
|
||||
|
||||
def test_compute_next_fire_at_is_strictly_after_when_after_equals_a_match() -> None:
|
||||
"""When ``after`` lands exactly on a cron match, the result jumps to the
|
||||
next match — never the same instant. Required so the schedule-tick
|
||||
can pass ``next_fire_at`` itself as ``after`` to advance to the
|
||||
following slot without double-firing.
|
||||
|
||||
Setup: weekday 09:00 Kigali. ``after`` = Mon 09:00 Kigali = 07:00 UTC
|
||||
(an exact match) → next fire must be Tue 09:00 Kigali = next day 07:00 UTC.
|
||||
"""
|
||||
after = datetime(2026, 5, 25, 7, 0, tzinfo=UTC) # Mon 09:00 Kigali — exact match
|
||||
|
||||
next_fire = compute_next_fire_at("0 9 * * 1-5", "Africa/Kigali", after=after)
|
||||
|
||||
assert next_fire == datetime(2026, 5, 26, 7, 0, tzinfo=UTC) # Tue 09:00 Kigali
|
||||
|
||||
|
||||
def test_validate_cron_rejects_malformed_cron_expression() -> None:
|
||||
"""A syntactically invalid cron must be rejected at validation time so
|
||||
bad triggers can't reach storage and explode at fire time."""
|
||||
with pytest.raises(InvalidCronError):
|
||||
validate_cron("this is not cron", "UTC")
|
||||
|
||||
|
||||
def test_validate_cron_rejects_unknown_timezone() -> None:
|
||||
"""A non-IANA timezone string must be rejected at validation time —
|
||||
the same protective gate as the cron expression itself."""
|
||||
with pytest.raises(InvalidCronError):
|
||||
validate_cron("0 9 * * *", "Mars/Olympus_Mons")
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
"""Lock the ``ScheduleTriggerParams`` validation contract."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.automations.triggers.builtin.schedule.params import ScheduleTriggerParams
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_schedule_params_accept_valid_cron_and_iana_timezone() -> None:
|
||||
"""A well-formed cron + IANA timezone yields a populated model.
|
||||
Locks the round-trip path users go through when creating a trigger."""
|
||||
params = ScheduleTriggerParams(cron="0 9 * * 1-5", timezone="Africa/Kigali")
|
||||
|
||||
assert params.cron == "0 9 * * 1-5"
|
||||
assert params.timezone == "Africa/Kigali"
|
||||
|
||||
|
||||
def test_schedule_params_reject_malformed_cron_with_validation_error() -> None:
|
||||
"""``InvalidCronError`` from ``validate_cron`` must surface as
|
||||
Pydantic ``ValidationError`` so the FastAPI layer returns 422 instead
|
||||
of letting the bad value reach storage."""
|
||||
with pytest.raises(ValidationError):
|
||||
ScheduleTriggerParams(cron="not cron", timezone="UTC")
|
||||
|
||||
|
||||
def test_schedule_params_reject_unknown_timezone_with_validation_error() -> None:
|
||||
"""An unknown timezone is rejected at the API boundary — same gate
|
||||
as the cron expression itself."""
|
||||
with pytest.raises(ValidationError):
|
||||
ScheduleTriggerParams(cron="0 9 * * *", timezone="Mars/Olympus_Mons")
|
||||
0
surfsense_backend/tests/unit/event_bus/__init__.py
Normal file
0
surfsense_backend/tests/unit/event_bus/__init__.py
Normal file
25
surfsense_backend/tests/unit/event_bus/conftest.py
Normal file
25
surfsense_backend/tests/unit/event_bus/conftest.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
"""Shared fixtures for the ``app.event_bus`` unit-test tree.
|
||||
|
||||
The event-type catalog is a module-level registry populated at import. Tests
|
||||
that register their own event types (or assert on registry contents) snapshot
|
||||
and restore it so state never leaks between tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
|
||||
from app.event_bus.catalog import catalog
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def isolated_event_catalog() -> Iterator[None]:
|
||||
"""Snapshot and restore the event-type catalog around a test."""
|
||||
snapshot = dict(catalog._registry)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
catalog._registry.clear()
|
||||
catalog._registry.update(snapshot)
|
||||
181
surfsense_backend/tests/unit/event_bus/test_bus.py
Normal file
181
surfsense_backend/tests/unit/event_bus/test_bus.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
"""``EventBus`` contract: subscribe, publish (stamp + fan out), dispatch.
|
||||
|
||||
Each test uses a fresh ``EventBus`` — no shared global state.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.event_bus import Event, EventBus
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _event() -> Event:
|
||||
return Event(event_type="x.happened", payload={"k": "v"}, search_space_id=1)
|
||||
|
||||
|
||||
async def _noop(_event: Event) -> None:
|
||||
return None
|
||||
|
||||
|
||||
async def _other(_event: Event) -> None:
|
||||
return None
|
||||
|
||||
|
||||
# --- registry -------------------------------------------------------------
|
||||
|
||||
|
||||
def test_subscribe_then_subscribers_returns_the_handler() -> None:
|
||||
bus = EventBus()
|
||||
bus.subscribe(_noop)
|
||||
|
||||
assert _noop in bus.subscribers()
|
||||
|
||||
|
||||
def test_subscribe_is_idempotent_for_the_same_handler() -> None:
|
||||
"""Registering the same handler twice must not make it fire twice."""
|
||||
bus = EventBus()
|
||||
bus.subscribe(_noop)
|
||||
bus.subscribe(_noop)
|
||||
|
||||
assert bus.subscribers().count(_noop) == 1
|
||||
|
||||
|
||||
def test_distinct_handlers_both_register() -> None:
|
||||
bus = EventBus()
|
||||
bus.subscribe(_noop)
|
||||
bus.subscribe(_other)
|
||||
|
||||
registered = bus.subscribers()
|
||||
assert _noop in registered
|
||||
assert _other in registered
|
||||
|
||||
|
||||
def test_subscribers_returns_a_defensive_snapshot() -> None:
|
||||
"""Mutating the returned list must not corrupt the registry."""
|
||||
bus = EventBus()
|
||||
bus.subscribe(_noop)
|
||||
|
||||
snapshot = bus.subscribers()
|
||||
snapshot.clear()
|
||||
|
||||
assert _noop in bus.subscribers()
|
||||
|
||||
|
||||
def test_subscribe_returns_handler_so_it_can_be_used_as_a_decorator() -> None:
|
||||
bus = EventBus()
|
||||
returned = bus.subscribe(_other)
|
||||
|
||||
assert returned is _other
|
||||
|
||||
|
||||
def test_two_buses_do_not_share_subscribers() -> None:
|
||||
"""The registry is per-instance, not global."""
|
||||
a = EventBus()
|
||||
b = EventBus()
|
||||
a.subscribe(_noop)
|
||||
|
||||
assert _noop in a.subscribers()
|
||||
assert _noop not in b.subscribers()
|
||||
|
||||
|
||||
# --- dispatch -------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_dispatch_delivers_event_to_every_subscriber() -> None:
|
||||
bus = EventBus()
|
||||
seen: list[tuple[str, Event]] = []
|
||||
|
||||
async def first(event: Event) -> None:
|
||||
seen.append(("first", event))
|
||||
|
||||
async def second(event: Event) -> None:
|
||||
seen.append(("second", event))
|
||||
|
||||
bus.subscribe(first)
|
||||
bus.subscribe(second)
|
||||
|
||||
event = _event()
|
||||
await bus.dispatch(event)
|
||||
|
||||
assert ("first", event) in seen
|
||||
assert ("second", event) in seen
|
||||
|
||||
|
||||
async def test_dispatch_isolates_a_failing_subscriber() -> None:
|
||||
"""A subscriber that raises must not stop a healthy one from running."""
|
||||
bus = EventBus()
|
||||
healthy_ran = False
|
||||
|
||||
async def boom(_event: Event) -> None:
|
||||
raise RuntimeError("subscriber blew up")
|
||||
|
||||
async def healthy(_event: Event) -> None:
|
||||
nonlocal healthy_ran
|
||||
healthy_ran = True
|
||||
|
||||
bus.subscribe(boom)
|
||||
bus.subscribe(healthy)
|
||||
|
||||
await bus.dispatch(_event())
|
||||
|
||||
assert healthy_ran is True
|
||||
|
||||
|
||||
async def test_dispatch_never_propagates_subscriber_errors() -> None:
|
||||
"""``dispatch`` itself must not raise even if every subscriber fails."""
|
||||
bus = EventBus()
|
||||
|
||||
async def boom(_event: Event) -> None:
|
||||
raise ValueError("nope")
|
||||
|
||||
bus.subscribe(boom)
|
||||
|
||||
await bus.dispatch(_event()) # must not raise
|
||||
|
||||
|
||||
async def test_dispatch_with_no_subscribers_is_a_noop() -> None:
|
||||
bus = EventBus()
|
||||
await bus.dispatch(_event()) # must not raise
|
||||
|
||||
|
||||
# --- publish --------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_publish_builds_a_stamped_event_and_fans_it_out() -> None:
|
||||
bus = EventBus()
|
||||
received: list[Event] = []
|
||||
|
||||
async def handler(event: Event) -> None:
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe(handler)
|
||||
await bus.publish("document.indexed", {"document_id": 42}, search_space_id=7)
|
||||
|
||||
assert len(received) == 1
|
||||
event = received[0]
|
||||
assert event.event_type == "document.indexed"
|
||||
assert event.payload == {"document_id": 42}
|
||||
assert event.search_space_id == 7
|
||||
# Engine-stamped identity/time on the way through.
|
||||
assert event.event_id
|
||||
assert event.occurred_at
|
||||
|
||||
|
||||
async def test_publish_defaults_payload_to_empty_dict() -> None:
|
||||
bus = EventBus()
|
||||
received: list[Event] = []
|
||||
|
||||
async def handler(event: Event) -> None:
|
||||
received.append(event)
|
||||
|
||||
bus.subscribe(handler)
|
||||
await bus.publish("x.happened", search_space_id=1)
|
||||
|
||||
assert received[0].payload == {}
|
||||
|
||||
|
||||
async def test_publish_with_no_subscribers_is_a_noop() -> None:
|
||||
await EventBus().publish("x.happened", search_space_id=1) # must not raise
|
||||
77
surfsense_backend/tests/unit/event_bus/test_catalog.py
Normal file
77
surfsense_backend/tests/unit/event_bus/test_catalog.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
"""EventCatalog contract: register, look up, snapshot, derive schema."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.event_bus.catalog import EventCatalog, EventType
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _SamplePayload(BaseModel):
|
||||
document_id: int
|
||||
|
||||
|
||||
def _event_type(type_: str = "test.thing") -> EventType:
|
||||
return EventType(
|
||||
type=type_,
|
||||
description="A thing happened.",
|
||||
payload_model=_SamplePayload,
|
||||
)
|
||||
|
||||
|
||||
def test_register_then_get_returns_the_event_type(isolated_event_catalog: None) -> None:
|
||||
from app.event_bus.catalog import catalog
|
||||
|
||||
catalog.register(_event_type())
|
||||
|
||||
assert catalog.get("test.thing") is not None
|
||||
assert catalog.get("test.thing").type == "test.thing"
|
||||
|
||||
|
||||
def test_get_unknown_type_returns_none(isolated_event_catalog: None) -> None:
|
||||
from app.event_bus.catalog import catalog
|
||||
|
||||
assert catalog.get("does.not.exist") is None
|
||||
|
||||
|
||||
def test_register_duplicate_type_raises(isolated_event_catalog: None) -> None:
|
||||
"""A type is a contract; registering it twice is a bug, not an override."""
|
||||
from app.event_bus.catalog import catalog
|
||||
|
||||
catalog.register(_event_type())
|
||||
|
||||
with pytest.raises(ValueError, match="already registered"):
|
||||
catalog.register(_event_type())
|
||||
|
||||
|
||||
def test_all_is_a_defensive_snapshot(isolated_event_catalog: None) -> None:
|
||||
"""Mutating the returned dict must not corrupt the registry."""
|
||||
from app.event_bus.catalog import catalog
|
||||
|
||||
catalog.register(_event_type())
|
||||
|
||||
snapshot = catalog.all()
|
||||
snapshot.clear()
|
||||
|
||||
assert catalog.get("test.thing") is not None
|
||||
|
||||
|
||||
def test_payload_schema_is_derived_from_the_payload_model() -> None:
|
||||
"""The JSON Schema a UI/validator consumes comes from the payload model."""
|
||||
event_type = _event_type()
|
||||
|
||||
assert event_type.payload_schema == _SamplePayload.model_json_schema()
|
||||
|
||||
|
||||
def test_each_catalog_instance_has_its_own_registry() -> None:
|
||||
"""Two EventCatalog instances are fully independent."""
|
||||
a = EventCatalog()
|
||||
b = EventCatalog()
|
||||
|
||||
a.register(_event_type())
|
||||
|
||||
assert a.get("test.thing") is not None
|
||||
assert b.get("test.thing") is None
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
"""``document.entered_folder`` payload contract + catalog registration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.event_bus.catalog import catalog
|
||||
from app.event_bus.events.document_entered_folder import (
|
||||
EVENT_TYPE,
|
||||
DocumentEnteredFolderPayload,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _payload(**overrides: object) -> DocumentEnteredFolderPayload:
|
||||
base: dict[str, object] = {
|
||||
"document_id": 42,
|
||||
"folder_id": 7,
|
||||
"document_type": "FILE",
|
||||
"title": "Q3 report.pdf",
|
||||
}
|
||||
base.update(overrides)
|
||||
return DocumentEnteredFolderPayload(**base)
|
||||
|
||||
|
||||
def test_payload_carries_the_filterable_fields() -> None:
|
||||
payload = _payload(connector_id=12, created_by_id="abc")
|
||||
|
||||
assert payload.document_id == 42
|
||||
assert payload.folder_id == 7
|
||||
assert payload.document_type == "FILE"
|
||||
assert payload.connector_id == 12
|
||||
|
||||
|
||||
def test_first_placement_is_not_a_move() -> None:
|
||||
"""No previous folder (created or AI-sorted into place) → not a move."""
|
||||
assert _payload(previous_folder_id=None).is_move is False
|
||||
|
||||
|
||||
def test_change_between_folders_is_a_move() -> None:
|
||||
assert _payload(previous_folder_id=3).is_move is True
|
||||
|
||||
|
||||
def test_is_move_is_serialized_for_filtering() -> None:
|
||||
"""Filters match against the dumped payload, so ``is_move`` must appear there."""
|
||||
dumped = _payload(previous_folder_id=3).model_dump()
|
||||
|
||||
assert dumped["is_move"] is True
|
||||
|
||||
|
||||
def test_event_type_is_registered_in_the_catalog() -> None:
|
||||
registered = catalog.get(EVENT_TYPE)
|
||||
|
||||
assert registered is not None
|
||||
assert registered.payload_model is DocumentEnteredFolderPayload
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
"""payload_if_entered_folder: decides whether a document commit warrants an event."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from app.event_bus.events.document_entered_folder import payload_if_entered_folder
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _call(**overrides: Any) -> dict[str, Any] | None:
|
||||
defaults: dict[str, Any] = {
|
||||
"document_id": 1,
|
||||
"search_space_id": 10,
|
||||
"new_folder_id": 7,
|
||||
"previous_folder_id": None,
|
||||
"folder_id_changed": True,
|
||||
"status_state": "ready",
|
||||
"document_type": "FILE",
|
||||
"title": "report.pdf",
|
||||
"connector_id": None,
|
||||
"created_by_id": None,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return payload_if_entered_folder(**defaults)
|
||||
|
||||
|
||||
def test_folder_set_ready_fires() -> None:
|
||||
result = _call()
|
||||
|
||||
assert result is not None
|
||||
assert result["event_type"] == "document.entered_folder"
|
||||
assert result["search_space_id"] == 10
|
||||
assert result["payload"]["folder_id"] == 7
|
||||
assert result["payload"]["previous_folder_id"] is None
|
||||
|
||||
|
||||
def test_no_folder_is_silent() -> None:
|
||||
assert _call(new_folder_id=None) is None
|
||||
|
||||
|
||||
def test_not_ready_is_silent() -> None:
|
||||
assert _call(status_state="processing") is None
|
||||
|
||||
|
||||
def test_folder_unchanged_is_silent() -> None:
|
||||
assert _call(folder_id_changed=False) is None
|
||||
|
||||
|
||||
def test_move_carries_previous_folder_id() -> None:
|
||||
result = _call(previous_folder_id=3, new_folder_id=7)
|
||||
|
||||
assert result is not None
|
||||
assert result["payload"]["previous_folder_id"] == 3
|
||||
assert result["payload"]["folder_id"] == 7
|
||||
53
surfsense_backend/tests/unit/event_bus/test_event.py
Normal file
53
surfsense_backend/tests/unit/event_bus/test_event.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""``Event`` contract: carry caller facts + engine-stamped id/time, round-trip JSON."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from app.event_bus.event import Event
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_event_carries_caller_supplied_facts() -> None:
|
||||
"""The three caller inputs are stored verbatim."""
|
||||
event = Event(
|
||||
event_type="document.indexed",
|
||||
payload={"document_id": 42, "content_type": "pdf"},
|
||||
search_space_id=7,
|
||||
)
|
||||
|
||||
assert event.event_type == "document.indexed"
|
||||
assert event.payload == {"document_id": 42, "content_type": "pdf"}
|
||||
assert event.search_space_id == 7
|
||||
|
||||
|
||||
def test_event_stamps_identity_and_time_when_not_supplied() -> None:
|
||||
"""Engine stamps id + time so subscribers can dedup/order."""
|
||||
event = Event(event_type="x.happened", payload={}, search_space_id=1)
|
||||
|
||||
assert event.event_id
|
||||
assert isinstance(event.occurred_at, datetime)
|
||||
|
||||
|
||||
def test_event_ids_are_unique_per_instance() -> None:
|
||||
"""Two events published with identical content are still distinct facts."""
|
||||
first = Event(event_type="x.happened", payload={}, search_space_id=1)
|
||||
second = Event(event_type="x.happened", payload={}, search_space_id=1)
|
||||
|
||||
assert first.event_id != second.event_id
|
||||
|
||||
|
||||
def test_event_survives_json_round_trip() -> None:
|
||||
"""Serialize → deserialize reproduces the event (subscribers queue it as JSON)."""
|
||||
original = Event(
|
||||
event_type="podcast.generated",
|
||||
payload={"podcast_id": 9, "duration_s": 123.5},
|
||||
search_space_id=3,
|
||||
)
|
||||
|
||||
restored = Event.model_validate_json(original.model_dump_json())
|
||||
|
||||
assert restored == original
|
||||
|
|
@ -0,0 +1,557 @@
|
|||
"""Parity gate for the parallel refactor of ``stream_new_chat.py``.
|
||||
|
||||
The new tree under ``app.tasks.chat.streaming.flows`` is built side-by-side with
|
||||
the legacy monolithic ``app.tasks.chat.stream_new_chat`` so we can cut over
|
||||
atomically. This file pins externally-observable behaviour at module
|
||||
boundaries so a divergence between the two trees fails loudly *before* the
|
||||
cutover.
|
||||
|
||||
What we verify:
|
||||
|
||||
1. **Signature parity** — ``stream_new_chat`` / ``stream_resume_chat`` from
|
||||
the new tree have the same call signature as the originals.
|
||||
2. **Helper extraction parity** — the SRP modules in ``flows/`` produce the
|
||||
same outputs as the inline code in the legacy file for representative
|
||||
inputs (initial thinking step, image-capability gate, runtime context,
|
||||
SSE frame sequences, token-usage frame shape, persistence guards).
|
||||
3. **Wrapper delegation** — wrappers like ``load_llm_bundle`` /
|
||||
``can_recover_provider_rate_limit`` exist and are addressable.
|
||||
|
||||
Delete this file along with ``stream_new_chat.py`` once the cutover is done
|
||||
(see the parent refactor plan).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.stream_new_chat import (
|
||||
stream_new_chat as old_stream_new_chat,
|
||||
stream_resume_chat as old_stream_resume_chat,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows import (
|
||||
stream_new_chat as new_stream_new_chat,
|
||||
stream_resume_chat as new_stream_resume_chat,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.new_chat.initial_thinking_step import (
|
||||
build_initial_thinking_step,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.new_chat.llm_capability import (
|
||||
check_image_input_capability,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.new_chat.persistence_spawn import (
|
||||
await_persist_task,
|
||||
spawn_persist_assistant_shell_task,
|
||||
spawn_persist_user_task,
|
||||
spawn_set_ai_responding_bg,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.new_chat.runtime_context import (
|
||||
build_new_chat_runtime_context,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.resume_chat.runtime_context import (
|
||||
build_resume_chat_runtime_context,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.finalize_emit import iter_token_usage_frame
|
||||
from app.tasks.chat.streaming.flows.shared.first_frames import (
|
||||
iter_final_frames,
|
||||
iter_initial_frames,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle
|
||||
from app.tasks.chat.streaming.flows.shared.premium_quota import (
|
||||
PremiumReservation,
|
||||
needs_premium_quota,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import (
|
||||
can_recover_provider_rate_limit,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- signature
|
||||
|
||||
|
||||
def _normalize_annotation(ann: Any) -> str:
|
||||
"""Compare-friendly form for an annotation.
|
||||
|
||||
The legacy ``stream_new_chat.py`` does NOT use ``from __future__ import
|
||||
annotations``, so its annotations are evaluated at import time and come
|
||||
back as type objects / typing generics. The new tree DOES use it, so its
|
||||
annotations are PEP-563 strings.
|
||||
|
||||
Both reprs describe the same types — strip the module prefixes / typing
|
||||
namespace + the ``<class 'X'>`` wrapper so we compare the canonical
|
||||
declared form.
|
||||
"""
|
||||
if ann is inspect.Signature.empty:
|
||||
return ""
|
||||
raw = ann if isinstance(ann, str) else repr(ann)
|
||||
cleaned = (
|
||||
raw.replace("typing.", "")
|
||||
.replace("collections.abc.", "")
|
||||
.replace("app.db.", "")
|
||||
.replace("app.agents.new_chat.filesystem_selection.", "")
|
||||
.replace("app.agents.new_chat.context.", "")
|
||||
)
|
||||
# Unwrap ``<class 'int'>`` → ``int`` (legacy-side type objects).
|
||||
if cleaned.startswith("<class '") and cleaned.endswith("'>"):
|
||||
cleaned = cleaned[len("<class '") : -len("'>")]
|
||||
return cleaned
|
||||
|
||||
|
||||
def _normalize_sig(sig: inspect.Signature) -> list[tuple[str, Any, str]]:
|
||||
return [
|
||||
(p.name, p.default, _normalize_annotation(p.annotation))
|
||||
for p in sig.parameters.values()
|
||||
]
|
||||
|
||||
|
||||
def test_stream_new_chat_signature_matches_legacy() -> None:
|
||||
old = inspect.signature(old_stream_new_chat)
|
||||
new = inspect.signature(new_stream_new_chat)
|
||||
assert _normalize_sig(new) == _normalize_sig(old)
|
||||
assert _normalize_annotation(new.return_annotation) == _normalize_annotation(
|
||||
old.return_annotation
|
||||
)
|
||||
|
||||
|
||||
def test_stream_resume_chat_signature_matches_legacy() -> None:
|
||||
old = inspect.signature(old_stream_resume_chat)
|
||||
new = inspect.signature(new_stream_resume_chat)
|
||||
assert _normalize_sig(new) == _normalize_sig(old)
|
||||
assert _normalize_annotation(new.return_annotation) == _normalize_annotation(
|
||||
old.return_annotation
|
||||
)
|
||||
|
||||
|
||||
def test_orchestrators_are_async_generator_functions() -> None:
|
||||
assert inspect.isasyncgenfunction(new_stream_new_chat)
|
||||
assert inspect.isasyncgenfunction(new_stream_resume_chat)
|
||||
|
||||
|
||||
# ------------------------------------------------------------ initial thinking
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"user_query, image_urls, expected_title, expected_action",
|
||||
[
|
||||
("hello world", None, "Understanding your request", "Processing"),
|
||||
(
|
||||
"",
|
||||
["data:image/png;base64,AAA"],
|
||||
"Understanding your request",
|
||||
"Processing",
|
||||
),
|
||||
("", None, "Understanding your request", "Processing"),
|
||||
],
|
||||
)
|
||||
def test_initial_thinking_step_branches(
|
||||
user_query: str,
|
||||
image_urls: list[str] | None,
|
||||
expected_title: str,
|
||||
expected_action: str,
|
||||
) -> None:
|
||||
step = build_initial_thinking_step(
|
||||
user_query=user_query,
|
||||
user_image_data_urls=image_urls,
|
||||
)
|
||||
assert step.step_id == "thinking-1"
|
||||
assert step.title == expected_title
|
||||
assert len(step.items) == 1
|
||||
assert step.items[0].startswith(f"{expected_action}: ")
|
||||
|
||||
|
||||
def test_initial_thinking_step_truncates_long_query() -> None:
|
||||
long_query = "x" * 200
|
||||
step = build_initial_thinking_step(
|
||||
user_query=long_query,
|
||||
user_image_data_urls=None,
|
||||
)
|
||||
# 80-char truncation + ellipsis, sandwiched after "Processing: ".
|
||||
assert "..." in step.items[0]
|
||||
item = step.items[0]
|
||||
payload = item[len("Processing: ") :]
|
||||
assert payload.startswith("x" * 80) and payload.endswith("...")
|
||||
|
||||
|
||||
# ------------------------------------------------------------ capability gate
|
||||
|
||||
|
||||
def test_image_capability_passes_without_images() -> None:
|
||||
assert (
|
||||
check_image_input_capability(user_image_data_urls=None, agent_config=None)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
def test_image_capability_passes_when_capability_unknown() -> None:
|
||||
"""Unknown / unmapped models are not blocked — only models LiteLLM has
|
||||
*explicitly* marked text-only trip the gate."""
|
||||
|
||||
class _AgentConfig:
|
||||
provider = "openrouter"
|
||||
model_name = "unknown-mystery-model"
|
||||
custom_provider = None
|
||||
config_name = "Unknown"
|
||||
litellm_params: dict[str, Any] = {}
|
||||
|
||||
with patch(
|
||||
"app.services.provider_capabilities.is_known_text_only_chat_model",
|
||||
return_value=False,
|
||||
):
|
||||
assert (
|
||||
check_image_input_capability(
|
||||
user_image_data_urls=["data:image/png;base64,AAA"],
|
||||
agent_config=_AgentConfig(), # type: ignore[arg-type]
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
def test_image_capability_blocks_known_text_only_models() -> None:
|
||||
class _AgentConfig:
|
||||
provider = "openai"
|
||||
model_name = "gpt-3.5-turbo"
|
||||
custom_provider = None
|
||||
config_name = "GPT-3.5"
|
||||
litellm_params: dict[str, Any] = {"base_model": "gpt-3.5-turbo"}
|
||||
|
||||
with patch(
|
||||
"app.services.provider_capabilities.is_known_text_only_chat_model",
|
||||
return_value=True,
|
||||
):
|
||||
result = check_image_input_capability(
|
||||
user_image_data_urls=["data:image/png;base64,AAA"],
|
||||
agent_config=_AgentConfig(), # type: ignore[arg-type]
|
||||
)
|
||||
assert result is not None
|
||||
message, error_code = result
|
||||
assert error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
|
||||
assert "GPT-3.5" in message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- runtime ctx
|
||||
|
||||
|
||||
def test_new_chat_runtime_context_prefers_accepted_folder_ids() -> None:
|
||||
ctx = build_new_chat_runtime_context(
|
||||
search_space_id=7,
|
||||
mentioned_document_ids=[1, 2],
|
||||
accepted_folder_ids=[10],
|
||||
mentioned_folder_ids=[20, 30],
|
||||
request_id="req",
|
||||
turn_id="t1",
|
||||
)
|
||||
assert isinstance(ctx, SurfSenseContextSchema)
|
||||
assert ctx.search_space_id == 7
|
||||
assert list(ctx.mentioned_document_ids) == [1, 2]
|
||||
assert list(ctx.mentioned_folder_ids) == [10]
|
||||
assert ctx.request_id == "req"
|
||||
assert ctx.turn_id == "t1"
|
||||
|
||||
|
||||
def test_new_chat_runtime_context_falls_back_to_mentioned_folder_ids() -> None:
|
||||
ctx = build_new_chat_runtime_context(
|
||||
search_space_id=7,
|
||||
mentioned_document_ids=None,
|
||||
accepted_folder_ids=[],
|
||||
mentioned_folder_ids=[20, 30],
|
||||
request_id=None,
|
||||
turn_id="t2",
|
||||
)
|
||||
assert list(ctx.mentioned_folder_ids) == [20, 30]
|
||||
|
||||
|
||||
def test_resume_chat_runtime_context_empty_mention_lists() -> None:
|
||||
ctx = build_resume_chat_runtime_context(
|
||||
search_space_id=42, request_id="req-r", turn_id="t-r"
|
||||
)
|
||||
assert ctx.search_space_id == 42
|
||||
assert ctx.request_id == "req-r"
|
||||
assert ctx.turn_id == "t-r"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- SSE frames
|
||||
|
||||
|
||||
def test_iter_initial_frames_emits_canonical_sequence() -> None:
|
||||
svc = VercelStreamingService()
|
||||
frames = list(iter_initial_frames(svc, turn_id="42:1700000000000"))
|
||||
# Exactly 4 frames: message_start, start_step, turn-info (turn_id), turn-status (busy).
|
||||
assert len(frames) == 4
|
||||
assert "42:1700000000000" in frames[2]
|
||||
assert '"status":"busy"' in frames[3] or '"status": "busy"' in frames[3]
|
||||
|
||||
|
||||
def test_iter_final_frames_emits_idle_then_finish_done() -> None:
|
||||
svc = VercelStreamingService()
|
||||
frames = list(iter_final_frames(svc))
|
||||
assert len(frames) == 4
|
||||
assert '"status":"idle"' in frames[0] or '"status": "idle"' in frames[0]
|
||||
|
||||
|
||||
# ----------------------------------------------------------- token usage frame
|
||||
|
||||
|
||||
class _FakeAccumulator:
|
||||
"""Minimal stand-in covering only the fields ``iter_token_usage_frame`` reads."""
|
||||
|
||||
def __init__(self, summary: Any = None) -> None:
|
||||
self._summary = summary
|
||||
self.calls = [1, 2, 3]
|
||||
self.grand_total = 100
|
||||
self.total_cost_micros = 50_000
|
||||
self.total_prompt_tokens = 60
|
||||
self.total_completion_tokens = 40
|
||||
|
||||
def per_message_summary(self) -> Any:
|
||||
return self._summary
|
||||
|
||||
def serialized_calls(self) -> list[Any]:
|
||||
return list(self.calls)
|
||||
|
||||
|
||||
def test_token_usage_frame_skipped_when_no_summary() -> None:
|
||||
svc = VercelStreamingService()
|
||||
frames = list(
|
||||
iter_token_usage_frame(
|
||||
svc,
|
||||
accumulator=_FakeAccumulator(summary=None), # type: ignore[arg-type]
|
||||
log_label="parity-empty",
|
||||
)
|
||||
)
|
||||
assert frames == []
|
||||
|
||||
|
||||
def test_token_usage_frame_emitted_when_summary_present() -> None:
|
||||
svc = VercelStreamingService()
|
||||
frames = list(
|
||||
iter_token_usage_frame(
|
||||
svc,
|
||||
accumulator=_FakeAccumulator(summary=[{"m": "x", "t": 100}]), # type: ignore[arg-type]
|
||||
log_label="parity-populated",
|
||||
)
|
||||
)
|
||||
assert len(frames) == 1
|
||||
# Field shape on the wire is fixed by the FE; assert each surfaces.
|
||||
payload = frames[0]
|
||||
for key in (
|
||||
'"prompt_tokens":60',
|
||||
'"completion_tokens":40',
|
||||
'"total_tokens":100',
|
||||
'"cost_micros":50000',
|
||||
):
|
||||
assert key in payload.replace(" ", "")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ llm_bundle
|
||||
|
||||
|
||||
def test_load_llm_bundle_routes_negative_id_to_yaml_loader() -> None:
|
||||
async def _run() -> tuple[Any, Any, str | None]:
|
||||
with (
|
||||
patch(
|
||||
"app.tasks.chat.streaming.flows.shared.llm_bundle.load_global_llm_config_by_id",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
return await load_llm_bundle(
|
||||
session=AsyncMock(), # type: ignore[arg-type]
|
||||
config_id=-1,
|
||||
search_space_id=7,
|
||||
)
|
||||
|
||||
llm, agent_config, error = asyncio.run(_run())
|
||||
assert llm is None
|
||||
assert agent_config is None
|
||||
assert error is not None and "id -1" in error
|
||||
|
||||
|
||||
def test_load_llm_bundle_routes_nonnegative_id_to_db_loader() -> None:
|
||||
async def _run() -> tuple[Any, Any, str | None]:
|
||||
with (
|
||||
patch(
|
||||
"app.tasks.chat.streaming.flows.shared.llm_bundle.load_agent_config",
|
||||
new=AsyncMock(return_value=None),
|
||||
),
|
||||
):
|
||||
return await load_llm_bundle(
|
||||
session=AsyncMock(), # type: ignore[arg-type]
|
||||
config_id=12,
|
||||
search_space_id=7,
|
||||
)
|
||||
|
||||
llm, agent_config, error = asyncio.run(_run())
|
||||
assert llm is None
|
||||
assert agent_config is None
|
||||
assert error is not None and "id 12" in error
|
||||
|
||||
|
||||
# ----------------------------------------------------------------- premium quota
|
||||
|
||||
|
||||
def test_needs_premium_quota_requires_user_and_premium_flag() -> None:
|
||||
class _AgentConfig:
|
||||
is_premium = True
|
||||
|
||||
class _NonPremium:
|
||||
is_premium = False
|
||||
|
||||
assert needs_premium_quota(_AgentConfig(), "user-1") is True # type: ignore[arg-type]
|
||||
assert needs_premium_quota(_AgentConfig(), None) is False # type: ignore[arg-type]
|
||||
assert needs_premium_quota(_NonPremium(), "user-1") is False # type: ignore[arg-type]
|
||||
assert needs_premium_quota(None, "user-1") is False
|
||||
|
||||
|
||||
def test_premium_reservation_dataclass_shape() -> None:
|
||||
# Sanity: the dataclass exists and carries the fields the orchestrator uses.
|
||||
r = PremiumReservation(request_id="abc", reserved_micros=100, allowed=True)
|
||||
assert r.request_id == "abc"
|
||||
assert r.reserved_micros == 100
|
||||
assert r.allowed is True
|
||||
|
||||
|
||||
# ----------------------------------------------------------- rate-limit guard
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"first_event_seen, recovered, requested_id, current_id, expected",
|
||||
[
|
||||
(False, False, 0, -1, True),
|
||||
# Already recovered: no second pass.
|
||||
(False, True, 0, -1, False),
|
||||
# User explicitly picked a config: don't silently switch.
|
||||
(False, False, 5, -1, False),
|
||||
# Already on a database-backed (positive) id.
|
||||
(False, False, 0, 7, False),
|
||||
# User has already seen output: silent rebuild not possible.
|
||||
(True, False, 0, -1, False),
|
||||
],
|
||||
)
|
||||
def test_can_recover_provider_rate_limit_truth_table(
|
||||
first_event_seen: bool,
|
||||
recovered: bool,
|
||||
requested_id: int,
|
||||
current_id: int,
|
||||
expected: bool,
|
||||
) -> None:
|
||||
# Use a known rate-limit-shaped exception so the helper's last condition
|
||||
# is satisfied; the guard only short-circuits to False when one of the
|
||||
# *other* preconditions fails.
|
||||
exc = Exception('{"error":{"type":"rate_limit_error","message":"slow"}}')
|
||||
assert (
|
||||
can_recover_provider_rate_limit(
|
||||
exc,
|
||||
first_event_seen=first_event_seen,
|
||||
runtime_rate_limit_recovered=recovered,
|
||||
requested_llm_config_id=requested_id,
|
||||
current_llm_config_id=current_id,
|
||||
)
|
||||
is expected
|
||||
)
|
||||
|
||||
|
||||
def test_can_recover_provider_rate_limit_rejects_non_rate_limit_exception() -> None:
|
||||
assert (
|
||||
can_recover_provider_rate_limit(
|
||||
ValueError("not a rate limit"),
|
||||
first_event_seen=False,
|
||||
runtime_rate_limit_recovered=False,
|
||||
requested_llm_config_id=0,
|
||||
current_llm_config_id=-1,
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------- persistence spawn
|
||||
|
||||
|
||||
def test_spawn_set_ai_responding_bg_noop_without_user_id() -> None:
|
||||
async def _run() -> set[asyncio.Task]:
|
||||
background: set[asyncio.Task] = set()
|
||||
spawn_set_ai_responding_bg(chat_id=1, user_id=None, background_tasks=background)
|
||||
return background
|
||||
|
||||
bg = asyncio.run(_run())
|
||||
assert bg == set()
|
||||
|
||||
|
||||
def test_spawn_persist_user_task_registers_and_self_unregisters() -> None:
|
||||
async def _run() -> tuple[int, int]:
|
||||
background: set[asyncio.Task] = set()
|
||||
with patch(
|
||||
"app.tasks.chat.streaming.flows.new_chat.persistence_spawn.persist_user_turn",
|
||||
new=AsyncMock(return_value=99),
|
||||
):
|
||||
task = spawn_persist_user_task(
|
||||
chat_id=1,
|
||||
user_id="u",
|
||||
turn_id="t",
|
||||
user_query="hi",
|
||||
user_image_data_urls=None,
|
||||
mentioned_documents=None,
|
||||
background_tasks=background,
|
||||
)
|
||||
size_before_await = len(background)
|
||||
result = await asyncio.shield(task)
|
||||
# Give the done-callback one event-loop tick to run.
|
||||
await asyncio.sleep(0)
|
||||
return size_before_await, result # type: ignore[return-value]
|
||||
|
||||
size_before, result = asyncio.run(_run())
|
||||
assert size_before == 1
|
||||
assert result == 99
|
||||
|
||||
|
||||
def test_spawn_persist_assistant_shell_task_registers() -> None:
|
||||
async def _run() -> int | None:
|
||||
background: set[asyncio.Task] = set()
|
||||
with patch(
|
||||
"app.tasks.chat.streaming.flows.new_chat.persistence_spawn.persist_assistant_shell",
|
||||
new=AsyncMock(return_value=42),
|
||||
):
|
||||
task = spawn_persist_assistant_shell_task(
|
||||
chat_id=1,
|
||||
user_id="u",
|
||||
turn_id="t",
|
||||
background_tasks=background,
|
||||
)
|
||||
return await asyncio.shield(task)
|
||||
|
||||
assert asyncio.run(_run()) == 42
|
||||
|
||||
|
||||
def test_await_persist_task_returns_none_on_failure() -> None:
|
||||
async def _run() -> int | None:
|
||||
async def _boom() -> int:
|
||||
raise RuntimeError("DB down")
|
||||
|
||||
task = asyncio.create_task(_boom())
|
||||
return await await_persist_task(
|
||||
task,
|
||||
chat_id=1,
|
||||
turn_id="t",
|
||||
log_label="parity-failure",
|
||||
)
|
||||
|
||||
assert asyncio.run(_run()) is None
|
||||
|
||||
|
||||
def test_await_persist_task_returns_none_for_none_input() -> None:
|
||||
async def _run() -> int | None:
|
||||
return await await_persist_task(
|
||||
None,
|
||||
chat_id=1,
|
||||
turn_id="t",
|
||||
log_label="parity-none",
|
||||
)
|
||||
|
||||
assert asyncio.run(_run()) is None
|
||||
Loading…
Add table
Add a link
Reference in a new issue