feat(automations): implement model eligibility checks for automation creation

- Added model eligibility checks to ensure automations can only use billable models (premium or BYOK).
- Introduced new API endpoint to report model eligibility status for search spaces.
- Updated frontend components to display eligibility alerts and disable creation options when models are not billable.
- Enhanced automation creation forms to reflect model eligibility, preventing users from submitting invalid configurations.
- Implemented server-side logic to capture and preserve model preferences across automation edits, ensuring consistent behavior during execution.
This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-05-29 03:13:46 -07:00
parent 5d90fbe99f
commit 409fec94c3
32 changed files with 1451 additions and 67 deletions

View file

@ -0,0 +1,174 @@
"""Lock the runtime model-policy backstop in ``build_dependencies``.
Automations resolve their LLM from the *captured* ``agent_llm_id`` snapshot (so
runs are insulated from later chat/search-space model changes), and the model
policy is re-checked at run time so a captured model that is no longer billable
fails the run clearly. When no snapshot is present, resolution falls back to the
live search space.
"""
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
import pytest
import app.automations.actions.agent_task.dependencies as deps_mod
from app.automations.actions.agent_task.dependencies import (
DependencyError,
build_dependencies,
)
from app.automations.services.model_policy import AutomationModelPolicyError
pytestmark = pytest.mark.unit
class _FakeSession:
"""Minimal async session whose ``get`` returns a preset search space."""
def __init__(self, search_space: Any) -> None:
self._search_space = search_space
async def get(self, _model: Any, _pk: int) -> Any:
return self._search_space
@pytest.fixture
def patched_side_effects(monkeypatch: pytest.MonkeyPatch):
"""Stub the connector setup + checkpointer so only policy/LLM logic runs."""
async def _fake_setup(_session, *, search_space_id):
return (SimpleNamespace(name="connector"), "fc-key")
monkeypatch.setattr(deps_mod, "setup_connector_and_firecrawl", _fake_setup)
return None
async def test_build_dependencies_resolves_captured_agent_llm_id(
monkeypatch: pytest.MonkeyPatch, patched_side_effects
) -> None:
"""The bundle loads with the *captured* ``agent_llm_id``, not the live search space."""
captured: dict[str, Any] = {}
async def _fake_load(_session, *, config_id, search_space_id):
captured["config_id"] = config_id
captured["search_space_id"] = search_space_id
return (SimpleNamespace(name="llm"), SimpleNamespace(name="agent_config"), None)
monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load)
# Captured path validates the explicit ids; passes for this test.
monkeypatch.setattr(deps_mod, "assert_models_billable", lambda **_kw: None)
# A different value on the live search space proves we ignore it when a
# snapshot is supplied.
monkeypatch.setattr(
deps_mod,
"assert_automation_models_billable",
lambda _ss: pytest.fail("search-space policy should not run on captured path"),
)
search_space = SimpleNamespace(agent_llm_id=-99)
result = await build_dependencies(
session=_FakeSession(search_space),
search_space_id=42,
agent_llm_id=-7,
image_generation_config_id=5,
vision_llm_config_id=-1,
)
assert captured == {"config_id": -7, "search_space_id": 42}
assert result.llm.name == "llm"
assert result.firecrawl_api_key == "fc-key"
async def test_build_dependencies_validates_captured_ids(
monkeypatch: pytest.MonkeyPatch, patched_side_effects
) -> None:
"""The captured ids (not the search space) are what gets policy-checked."""
seen: dict[str, Any] = {}
def _capture(**kwargs):
seen.update(kwargs)
monkeypatch.setattr(deps_mod, "assert_models_billable", _capture)
async def _fake_load(_session, *, config_id, search_space_id):
return (SimpleNamespace(name="llm"), SimpleNamespace(name="agent_config"), None)
monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load)
await build_dependencies(
session=_FakeSession(SimpleNamespace(agent_llm_id=0)),
search_space_id=42,
agent_llm_id=-7,
image_generation_config_id=5,
vision_llm_config_id=-1,
)
assert seen == {
"agent_llm_id": -7,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
}
async def test_build_dependencies_raises_on_captured_policy_violation(
monkeypatch: pytest.MonkeyPatch, patched_side_effects
) -> None:
"""A blocked captured model raises ``DependencyError`` so the step fails clearly."""
def _raise(**_kw):
raise AutomationModelPolicyError(
[{"kind": "image", "config_id": -2, "reason": "free model"}]
)
monkeypatch.setattr(deps_mod, "assert_models_billable", _raise)
monkeypatch.setattr(
deps_mod,
"load_llm_bundle",
lambda *a, **k: pytest.fail("load_llm_bundle should not be called"),
)
with pytest.raises(DependencyError):
await build_dependencies(
session=_FakeSession(SimpleNamespace(agent_llm_id=-7)),
search_space_id=42,
agent_llm_id=-7,
image_generation_config_id=-2,
vision_llm_config_id=-1,
)
async def test_build_dependencies_falls_back_to_search_space(
monkeypatch: pytest.MonkeyPatch, patched_side_effects
) -> None:
"""With no captured snapshot, resolve + validate the live search space."""
captured: dict[str, Any] = {}
async def _fake_load(_session, *, config_id, search_space_id):
captured["config_id"] = config_id
return (SimpleNamespace(name="llm"), SimpleNamespace(name="agent_config"), None)
monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load)
monkeypatch.setattr(deps_mod, "assert_automation_models_billable", lambda _ss: None)
monkeypatch.setattr(
deps_mod,
"assert_models_billable",
lambda **_kw: pytest.fail("captured policy should not run on fallback path"),
)
search_space = SimpleNamespace(agent_llm_id=-7)
result = await build_dependencies(
session=_FakeSession(search_space), search_space_id=42
)
assert captured == {"config_id": -7}
assert result.llm.name == "llm"
async def test_build_dependencies_raises_when_search_space_missing(
patched_side_effects,
) -> None:
"""A missing search space (fallback path) surfaces as a ``DependencyError``."""
with pytest.raises(DependencyError):
await build_dependencies(session=_FakeSession(None), search_space_id=999)

View file

@ -0,0 +1,59 @@
"""Lock that the executor propagates the captured model snapshot into the
``ActionContext``, so runs resolve their own model (insulated from chat /
search-space changes) and not the live search space.
"""
from __future__ import annotations
from types import SimpleNamespace
from typing import cast
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.runtime.executor import _build_action_ctx
from app.automations.schemas.definition.envelope import AutomationModels
from app.automations.schemas.definition.plan_step import PlanStep
pytestmark = pytest.mark.unit
def _run() -> SimpleNamespace:
return SimpleNamespace(
id=1,
automation=SimpleNamespace(search_space_id=42, created_by_user_id="u-1"),
)
def test_build_action_ctx_propagates_captured_models() -> None:
"""``definition.models`` flows onto the ActionContext model fields."""
models = AutomationModels(
agent_llm_id=-1,
image_generation_config_id=5,
vision_llm_config_id=-1,
)
ctx = _build_action_ctx(
cast(AsyncSession, None),
_run(),
PlanStep(step_id="s1", action="agent_task"),
models,
)
assert ctx.search_space_id == 42
assert ctx.agent_llm_id == -1
assert ctx.image_generation_config_id == 5
assert ctx.vision_llm_config_id == -1
def test_build_action_ctx_none_models_leaves_fields_none() -> None:
"""No captured snapshot → model fields are None (defensive fallback path)."""
ctx = _build_action_ctx(
cast(AsyncSession, None),
_run(),
PlanStep(step_id="s1", action="agent_task"),
None,
)
assert ctx.agent_llm_id is None
assert ctx.image_generation_config_id is None
assert ctx.vision_llm_config_id is None

View file

@ -5,7 +5,10 @@ from __future__ import annotations
import pytest
from pydantic import ValidationError
from app.automations.schemas.definition.envelope import AutomationDefinition
from app.automations.schemas.definition.envelope import (
AutomationDefinition,
AutomationModels,
)
from app.automations.schemas.definition.plan_step import PlanStep
pytestmark = pytest.mark.unit
@ -27,6 +30,34 @@ def test_automation_definition_accepts_minimal_valid_input_with_sensible_default
assert definition.goal is None
assert definition.inputs is None
assert definition.triggers == []
# ``models`` is optional (populated server-side at create()).
assert definition.models is None
def test_automation_definition_models_round_trip() -> None:
"""The captured ``models`` snapshot survives a model_dump/validate round-trip."""
definition = AutomationDefinition(
name="Daily digest",
plan=[PlanStep(step_id="s1", action="agent_task")],
models=AutomationModels(
agent_llm_id=-1,
image_generation_config_id=5,
vision_llm_config_id=-1,
),
)
dumped = definition.model_dump(mode="json", by_alias=True)
assert dumped["models"] == {
"agent_llm_id": -1,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
}
restored = AutomationDefinition.model_validate(dumped)
assert restored.models is not None
assert restored.models.agent_llm_id == -1
assert restored.models.image_generation_config_id == 5
assert restored.models.vision_llm_config_id == -1
def test_automation_definition_rejects_unknown_top_level_field() -> None:

View file

@ -0,0 +1,236 @@
"""Lock creation-time model-policy enforcement in ``AutomationService``.
Creation (REST + manual builder) rejects search spaces whose models aren't
billable for automations with HTTP 422, mirroring the runtime backstop. These
tests isolate the new ``_assert_models_billable`` / ``model_eligibility`` paths
without touching the DB commit.
"""
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
import pytest
from fastapi import HTTPException
import app.automations.services.automation as automation_mod
from app.automations.schemas.api import AutomationCreate, AutomationUpdate
from app.automations.schemas.definition.envelope import AutomationDefinition
from app.automations.schemas.definition.plan_step import PlanStep
from app.automations.services.automation import AutomationService
from app.automations.services.model_policy import AutomationModelPolicyError
pytestmark = pytest.mark.unit
class _FakeSession:
def __init__(self, search_space: Any) -> None:
self._search_space = search_space
self.added: list[Any] = []
self.commits = 0
async def get(self, _model: Any, _pk: int) -> Any:
return self._search_space
def add(self, obj: Any) -> None:
self.added.append(obj)
async def commit(self) -> None:
self.commits += 1
def _service(search_space: Any) -> AutomationService:
return AutomationService(
session=_FakeSession(search_space), user=SimpleNamespace(id="u-1")
)
def _definition(**kwargs: Any) -> AutomationDefinition:
return AutomationDefinition(
name="A",
plan=[PlanStep(step_id="s1", action="agent_task")],
**kwargs,
)
async def test_assert_models_billable_raises_422_on_violation(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A blocked model maps the policy error to HTTP 422."""
def _raise(_ss):
raise AutomationModelPolicyError(
[{"kind": "llm", "config_id": 0, "reason": "Auto mode"}]
)
monkeypatch.setattr(automation_mod, "assert_automation_models_billable", _raise)
service = _service(SimpleNamespace(agent_llm_id=0))
with pytest.raises(HTTPException) as exc_info:
await service._assert_models_billable(1)
assert exc_info.value.status_code == 422
async def test_assert_models_billable_raises_404_when_missing(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A missing search space is a 404, not a policy error."""
monkeypatch.setattr(
automation_mod, "assert_automation_models_billable", lambda _ss: None
)
service = _service(None)
with pytest.raises(HTTPException) as exc_info:
await service._assert_models_billable(999)
assert exc_info.value.status_code == 404
async def test_assert_models_billable_returns_search_space_when_ok(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When the policy accepts, the loaded search space is returned for reuse."""
monkeypatch.setattr(
automation_mod, "assert_automation_models_billable", lambda _ss: None
)
search_space = SimpleNamespace(agent_llm_id=-1)
service = _service(search_space)
assert await service._assert_models_billable(1) is search_space
async def test_create_injects_captured_models_from_search_space(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""create() snapshots the search space's model prefs onto the definition."""
monkeypatch.setattr(
automation_mod, "assert_automation_models_billable", lambda _ss: None
)
async def _noop_authorize(self, *_a, **_k):
return None
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
async def _return_added(self, _aid):
return self.session.added[-1]
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
search_space = SimpleNamespace(
agent_llm_id=-1,
image_generation_config_id=5,
vision_llm_config_id=-1,
)
service = _service(search_space)
payload = AutomationCreate(
search_space_id=1,
name="A",
definition=_definition(),
)
automation = await service.create(payload)
assert automation.definition["models"] == {
"agent_llm_id": -1,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
}
async def test_create_treats_unset_prefs_as_auto_zero(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""``None`` search-space prefs are captured as ``0`` (Auto) ids."""
monkeypatch.setattr(
automation_mod, "assert_automation_models_billable", lambda _ss: None
)
async def _noop_authorize(self, *_a, **_k):
return None
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
async def _return_added(self, _aid):
return self.session.added[-1]
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
search_space = SimpleNamespace(
agent_llm_id=None,
image_generation_config_id=None,
vision_llm_config_id=None,
)
service = _service(search_space)
payload = AutomationCreate(search_space_id=1, name="A", definition=_definition())
automation = await service.create(payload)
assert automation.definition["models"] == {
"agent_llm_id": 0,
"image_generation_config_id": 0,
"vision_llm_config_id": 0,
}
async def test_update_preserves_captured_models(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A definition edit carries over the previously captured ``models``."""
captured = {
"agent_llm_id": -1,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
}
existing = SimpleNamespace(
search_space_id=1,
definition={"name": "A", "plan": [], "models": captured},
version=3,
)
async def _noop_authorize(self, *_a, **_k):
return None
async def _return_existing(self, _aid):
return existing
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
monkeypatch.setattr(
AutomationService, "_get_with_triggers_or_raise", _return_existing
)
service = _service(SimpleNamespace())
# The incoming patch definition has no ``models`` (frontend strips it).
patch = AutomationUpdate(definition=_definition())
result = await service.update(7, patch)
assert result.definition["models"] == captured
assert result.version == 4
async def test_model_eligibility_authorizes_and_returns_payload(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""``model_eligibility`` checks read access then returns the eligibility dict."""
authorized: dict[str, Any] = {}
async def _fake_check_permission(_session, _user, ss_id, permission, _msg):
authorized["search_space_id"] = ss_id
authorized["permission"] = permission
monkeypatch.setattr(automation_mod, "check_permission", _fake_check_permission)
monkeypatch.setattr(
automation_mod,
"get_automation_model_eligibility",
lambda _ss: {"allowed": False, "violations": [{"kind": "image"}]},
)
service = _service(SimpleNamespace(agent_llm_id=-2))
result = await service.model_eligibility(search_space_id=5)
assert result == {"allowed": False, "violations": [{"kind": "image"}]}
assert authorized["search_space_id"] == 5
assert authorized["permission"] == "automations:read"

View file

@ -0,0 +1,196 @@
"""Lock the automation model-billing policy.
Automations may only run on billable models: premium global configs
(``billing_tier == "premium"``) or user BYOK configs (positive id). Free
globals and Auto mode (id == 0 / None) are blocked. These tests pin that rule
across all three model slots (chat LLM, image, vision).
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
import app.automations.services.model_policy as model_policy
from app.automations.services.model_policy import (
AutomationModelPolicyError,
assert_automation_models_billable,
assert_models_billable,
get_automation_model_eligibility,
get_model_eligibility,
)
pytestmark = pytest.mark.unit
def _search_space(*, llm: int | None, image: int | None, vision: int | None):
"""Minimal stand-in for the ``SearchSpace`` ORM row the policy reads."""
return SimpleNamespace(
agent_llm_id=llm,
image_generation_config_id=image,
vision_llm_config_id=vision,
)
@pytest.fixture
def patched_globals(monkeypatch: pytest.MonkeyPatch):
"""Stub the global config sources the policy consults for negative ids.
Negative ids: -1 is premium, -2 is free, for each of llm/image/vision.
"""
llm_configs = {
-1: {"id": -1, "billing_tier": "premium"},
-2: {"id": -2, "billing_tier": "free"},
}
monkeypatch.setattr(
"app.agents.new_chat.llm_config.load_global_llm_config_by_id",
lambda cid: llm_configs.get(cid),
)
from app.config import config as app_config
monkeypatch.setattr(
app_config,
"GLOBAL_IMAGE_GEN_CONFIGS",
[
{"id": -1, "billing_tier": "premium"},
{"id": -2, "billing_tier": "free"},
],
raising=False,
)
monkeypatch.setattr(
app_config,
"GLOBAL_VISION_LLM_CONFIGS",
[
{"id": -1, "billing_tier": "premium"},
{"id": -2, "billing_tier": "free"},
],
raising=False,
)
return None
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None:
"""A positive config id is a user-owned BYOK model — always billable."""
allowed, reason = model_policy._classify(kind, 7)
assert allowed is True
assert reason == ""
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
@pytest.mark.parametrize("config_id", [0, None])
def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None:
"""Auto mode (id 0) and an unset slot (None) are blocked."""
allowed, reason = model_policy._classify(kind, config_id)
assert allowed is False
assert "Auto mode" in reason
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
def test_premium_global_is_allowed(kind: str, patched_globals) -> None:
"""A negative (global) id with premium billing tier is allowed."""
allowed, reason = model_policy._classify(kind, -1)
assert allowed is True
assert reason == ""
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
def test_free_global_is_blocked(kind: str, patched_globals) -> None:
"""A negative (global) id with a free billing tier is blocked."""
allowed, reason = model_policy._classify(kind, -2)
assert allowed is False
assert "free model" in reason
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
def test_unknown_global_id_is_blocked(kind: str, patched_globals) -> None:
"""A negative id that resolves to no config is treated as not premium."""
allowed, _ = model_policy._classify(kind, -999)
assert allowed is False
def test_eligibility_all_billable(patched_globals) -> None:
"""Premium LLM + BYOK image + premium vision → allowed, no violations."""
search_space = _search_space(llm=-1, image=5, vision=-1)
result = get_automation_model_eligibility(search_space)
assert result == {"allowed": True, "violations": []}
def test_eligibility_reports_each_violation(patched_globals) -> None:
"""A free LLM, Auto image, and free vision each produce a violation."""
search_space = _search_space(llm=-2, image=0, vision=-2)
result = get_automation_model_eligibility(search_space)
assert result["allowed"] is False
kinds = {v["kind"] for v in result["violations"]}
assert kinds == {"llm", "image", "vision"}
# config_id is echoed back for the UI / settings deep-link.
by_kind = {v["kind"]: v["config_id"] for v in result["violations"]}
assert by_kind == {"llm": -2, "image": 0, "vision": -2}
def test_assert_raises_with_violations(patched_globals) -> None:
"""``assert_automation_models_billable`` raises when any slot is blocked."""
search_space = _search_space(llm=0, image=5, vision=-1)
with pytest.raises(AutomationModelPolicyError) as exc_info:
assert_automation_models_billable(search_space)
assert len(exc_info.value.violations) == 1
assert exc_info.value.violations[0]["kind"] == "llm"
def test_assert_passes_when_all_billable(patched_globals) -> None:
"""No exception when every slot is premium or BYOK."""
search_space = _search_space(llm=3, image=-1, vision=4)
assert assert_automation_models_billable(search_space) is None
# --- ID-based core (used by the runtime backstop against captured snapshots) ---
def test_get_model_eligibility_all_billable(patched_globals) -> None:
"""Premium LLM + BYOK image + premium vision (explicit ids) → allowed."""
result = get_model_eligibility(
agent_llm_id=-1, image_generation_config_id=5, vision_llm_config_id=-1
)
assert result == {"allowed": True, "violations": []}
def test_get_model_eligibility_reports_each_violation(patched_globals) -> None:
"""Free LLM, Auto image, free vision (explicit ids) each produce a violation."""
result = get_model_eligibility(
agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2
)
assert result["allowed"] is False
by_kind = {v["kind"]: v["config_id"] for v in result["violations"]}
assert by_kind == {"llm": -2, "image": 0, "vision": -2}
def test_assert_models_billable_raises(patched_globals) -> None:
"""``assert_models_billable`` raises when any explicit id is blocked."""
with pytest.raises(AutomationModelPolicyError) as exc_info:
assert_models_billable(
agent_llm_id=0, image_generation_config_id=5, vision_llm_config_id=-1
)
assert len(exc_info.value.violations) == 1
assert exc_info.value.violations[0]["kind"] == "llm"
def test_assert_models_billable_passes(patched_globals) -> None:
"""No exception when every explicit id is premium or BYOK."""
assert (
assert_models_billable(
agent_llm_id=3, image_generation_config_id=-1, vision_llm_config_id=4
)
is None
)
def test_search_space_wrapper_delegates_to_core(patched_globals) -> None:
"""The search-space wrapper produces the same result as the ID core."""
search_space = _search_space(llm=-2, image=0, vision=-2)
assert get_automation_model_eligibility(search_space) == get_model_eligibility(
agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2
)