feat(automations): implement model eligibility checks for automation creation

- Added model eligibility checks to ensure automations can only use billable models (premium or BYOK).
- Introduced new API endpoint to report model eligibility status for search spaces.
- Updated frontend components to display eligibility alerts and disable creation options when models are not billable.
- Enhanced automation creation forms to reflect model eligibility, preventing users from submitting invalid configurations.
- Implemented server-side logic to capture and preserve model preferences across automation edits, ensuring consistent behavior during execution.
This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-05-29 03:13:46 -07:00
parent 5d90fbe99f
commit 409fec94c3
32 changed files with 1451 additions and 67 deletions

View file

@ -0,0 +1,236 @@
"""Lock creation-time model-policy enforcement in ``AutomationService``.
Creation (REST + manual builder) rejects search spaces whose models aren't
billable for automations with HTTP 422, mirroring the runtime backstop. These
tests isolate the new ``_assert_models_billable`` / ``model_eligibility`` paths
without touching the DB commit.
"""
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
import pytest
from fastapi import HTTPException
import app.automations.services.automation as automation_mod
from app.automations.schemas.api import AutomationCreate, AutomationUpdate
from app.automations.schemas.definition.envelope import AutomationDefinition
from app.automations.schemas.definition.plan_step import PlanStep
from app.automations.services.automation import AutomationService
from app.automations.services.model_policy import AutomationModelPolicyError
pytestmark = pytest.mark.unit
class _FakeSession:
def __init__(self, search_space: Any) -> None:
self._search_space = search_space
self.added: list[Any] = []
self.commits = 0
async def get(self, _model: Any, _pk: int) -> Any:
return self._search_space
def add(self, obj: Any) -> None:
self.added.append(obj)
async def commit(self) -> None:
self.commits += 1
def _service(search_space: Any) -> AutomationService:
return AutomationService(
session=_FakeSession(search_space), user=SimpleNamespace(id="u-1")
)
def _definition(**kwargs: Any) -> AutomationDefinition:
return AutomationDefinition(
name="A",
plan=[PlanStep(step_id="s1", action="agent_task")],
**kwargs,
)
async def test_assert_models_billable_raises_422_on_violation(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A blocked model maps the policy error to HTTP 422."""
def _raise(_ss):
raise AutomationModelPolicyError(
[{"kind": "llm", "config_id": 0, "reason": "Auto mode"}]
)
monkeypatch.setattr(automation_mod, "assert_automation_models_billable", _raise)
service = _service(SimpleNamespace(agent_llm_id=0))
with pytest.raises(HTTPException) as exc_info:
await service._assert_models_billable(1)
assert exc_info.value.status_code == 422
async def test_assert_models_billable_raises_404_when_missing(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A missing search space is a 404, not a policy error."""
monkeypatch.setattr(
automation_mod, "assert_automation_models_billable", lambda _ss: None
)
service = _service(None)
with pytest.raises(HTTPException) as exc_info:
await service._assert_models_billable(999)
assert exc_info.value.status_code == 404
async def test_assert_models_billable_returns_search_space_when_ok(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When the policy accepts, the loaded search space is returned for reuse."""
monkeypatch.setattr(
automation_mod, "assert_automation_models_billable", lambda _ss: None
)
search_space = SimpleNamespace(agent_llm_id=-1)
service = _service(search_space)
assert await service._assert_models_billable(1) is search_space
async def test_create_injects_captured_models_from_search_space(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""create() snapshots the search space's model prefs onto the definition."""
monkeypatch.setattr(
automation_mod, "assert_automation_models_billable", lambda _ss: None
)
async def _noop_authorize(self, *_a, **_k):
return None
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
async def _return_added(self, _aid):
return self.session.added[-1]
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
search_space = SimpleNamespace(
agent_llm_id=-1,
image_generation_config_id=5,
vision_llm_config_id=-1,
)
service = _service(search_space)
payload = AutomationCreate(
search_space_id=1,
name="A",
definition=_definition(),
)
automation = await service.create(payload)
assert automation.definition["models"] == {
"agent_llm_id": -1,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
}
async def test_create_treats_unset_prefs_as_auto_zero(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""``None`` search-space prefs are captured as ``0`` (Auto) ids."""
monkeypatch.setattr(
automation_mod, "assert_automation_models_billable", lambda _ss: None
)
async def _noop_authorize(self, *_a, **_k):
return None
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
async def _return_added(self, _aid):
return self.session.added[-1]
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
search_space = SimpleNamespace(
agent_llm_id=None,
image_generation_config_id=None,
vision_llm_config_id=None,
)
service = _service(search_space)
payload = AutomationCreate(search_space_id=1, name="A", definition=_definition())
automation = await service.create(payload)
assert automation.definition["models"] == {
"agent_llm_id": 0,
"image_generation_config_id": 0,
"vision_llm_config_id": 0,
}
async def test_update_preserves_captured_models(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A definition edit carries over the previously captured ``models``."""
captured = {
"agent_llm_id": -1,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
}
existing = SimpleNamespace(
search_space_id=1,
definition={"name": "A", "plan": [], "models": captured},
version=3,
)
async def _noop_authorize(self, *_a, **_k):
return None
async def _return_existing(self, _aid):
return existing
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
monkeypatch.setattr(
AutomationService, "_get_with_triggers_or_raise", _return_existing
)
service = _service(SimpleNamespace())
# The incoming patch definition has no ``models`` (frontend strips it).
patch = AutomationUpdate(definition=_definition())
result = await service.update(7, patch)
assert result.definition["models"] == captured
assert result.version == 4
async def test_model_eligibility_authorizes_and_returns_payload(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""``model_eligibility`` checks read access then returns the eligibility dict."""
authorized: dict[str, Any] = {}
async def _fake_check_permission(_session, _user, ss_id, permission, _msg):
authorized["search_space_id"] = ss_id
authorized["permission"] = permission
monkeypatch.setattr(automation_mod, "check_permission", _fake_check_permission)
monkeypatch.setattr(
automation_mod,
"get_automation_model_eligibility",
lambda _ss: {"allowed": False, "violations": [{"kind": "image"}]},
)
service = _service(SimpleNamespace(agent_llm_id=-2))
result = await service.model_eligibility(search_space_id=5)
assert result == {"allowed": False, "violations": [{"kind": "image"}]}
assert authorized["search_space_id"] == 5
assert authorized["permission"] == "automations:read"

View file

@ -0,0 +1,196 @@
"""Lock the automation model-billing policy.
Automations may only run on billable models: premium global configs
(``billing_tier == "premium"``) or user BYOK configs (positive id). Free
globals and Auto mode (id == 0 / None) are blocked. These tests pin that rule
across all three model slots (chat LLM, image, vision).
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
import app.automations.services.model_policy as model_policy
from app.automations.services.model_policy import (
AutomationModelPolicyError,
assert_automation_models_billable,
assert_models_billable,
get_automation_model_eligibility,
get_model_eligibility,
)
pytestmark = pytest.mark.unit
def _search_space(*, llm: int | None, image: int | None, vision: int | None):
"""Minimal stand-in for the ``SearchSpace`` ORM row the policy reads."""
return SimpleNamespace(
agent_llm_id=llm,
image_generation_config_id=image,
vision_llm_config_id=vision,
)
@pytest.fixture
def patched_globals(monkeypatch: pytest.MonkeyPatch):
"""Stub the global config sources the policy consults for negative ids.
Negative ids: -1 is premium, -2 is free, for each of llm/image/vision.
"""
llm_configs = {
-1: {"id": -1, "billing_tier": "premium"},
-2: {"id": -2, "billing_tier": "free"},
}
monkeypatch.setattr(
"app.agents.new_chat.llm_config.load_global_llm_config_by_id",
lambda cid: llm_configs.get(cid),
)
from app.config import config as app_config
monkeypatch.setattr(
app_config,
"GLOBAL_IMAGE_GEN_CONFIGS",
[
{"id": -1, "billing_tier": "premium"},
{"id": -2, "billing_tier": "free"},
],
raising=False,
)
monkeypatch.setattr(
app_config,
"GLOBAL_VISION_LLM_CONFIGS",
[
{"id": -1, "billing_tier": "premium"},
{"id": -2, "billing_tier": "free"},
],
raising=False,
)
return None
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None:
"""A positive config id is a user-owned BYOK model — always billable."""
allowed, reason = model_policy._classify(kind, 7)
assert allowed is True
assert reason == ""
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
@pytest.mark.parametrize("config_id", [0, None])
def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None:
"""Auto mode (id 0) and an unset slot (None) are blocked."""
allowed, reason = model_policy._classify(kind, config_id)
assert allowed is False
assert "Auto mode" in reason
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
def test_premium_global_is_allowed(kind: str, patched_globals) -> None:
"""A negative (global) id with premium billing tier is allowed."""
allowed, reason = model_policy._classify(kind, -1)
assert allowed is True
assert reason == ""
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
def test_free_global_is_blocked(kind: str, patched_globals) -> None:
"""A negative (global) id with a free billing tier is blocked."""
allowed, reason = model_policy._classify(kind, -2)
assert allowed is False
assert "free model" in reason
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
def test_unknown_global_id_is_blocked(kind: str, patched_globals) -> None:
"""A negative id that resolves to no config is treated as not premium."""
allowed, _ = model_policy._classify(kind, -999)
assert allowed is False
def test_eligibility_all_billable(patched_globals) -> None:
"""Premium LLM + BYOK image + premium vision → allowed, no violations."""
search_space = _search_space(llm=-1, image=5, vision=-1)
result = get_automation_model_eligibility(search_space)
assert result == {"allowed": True, "violations": []}
def test_eligibility_reports_each_violation(patched_globals) -> None:
"""A free LLM, Auto image, and free vision each produce a violation."""
search_space = _search_space(llm=-2, image=0, vision=-2)
result = get_automation_model_eligibility(search_space)
assert result["allowed"] is False
kinds = {v["kind"] for v in result["violations"]}
assert kinds == {"llm", "image", "vision"}
# config_id is echoed back for the UI / settings deep-link.
by_kind = {v["kind"]: v["config_id"] for v in result["violations"]}
assert by_kind == {"llm": -2, "image": 0, "vision": -2}
def test_assert_raises_with_violations(patched_globals) -> None:
"""``assert_automation_models_billable`` raises when any slot is blocked."""
search_space = _search_space(llm=0, image=5, vision=-1)
with pytest.raises(AutomationModelPolicyError) as exc_info:
assert_automation_models_billable(search_space)
assert len(exc_info.value.violations) == 1
assert exc_info.value.violations[0]["kind"] == "llm"
def test_assert_passes_when_all_billable(patched_globals) -> None:
"""No exception when every slot is premium or BYOK."""
search_space = _search_space(llm=3, image=-1, vision=4)
assert assert_automation_models_billable(search_space) is None
# --- ID-based core (used by the runtime backstop against captured snapshots) ---
def test_get_model_eligibility_all_billable(patched_globals) -> None:
"""Premium LLM + BYOK image + premium vision (explicit ids) → allowed."""
result = get_model_eligibility(
agent_llm_id=-1, image_generation_config_id=5, vision_llm_config_id=-1
)
assert result == {"allowed": True, "violations": []}
def test_get_model_eligibility_reports_each_violation(patched_globals) -> None:
"""Free LLM, Auto image, free vision (explicit ids) each produce a violation."""
result = get_model_eligibility(
agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2
)
assert result["allowed"] is False
by_kind = {v["kind"]: v["config_id"] for v in result["violations"]}
assert by_kind == {"llm": -2, "image": 0, "vision": -2}
def test_assert_models_billable_raises(patched_globals) -> None:
"""``assert_models_billable`` raises when any explicit id is blocked."""
with pytest.raises(AutomationModelPolicyError) as exc_info:
assert_models_billable(
agent_llm_id=0, image_generation_config_id=5, vision_llm_config_id=-1
)
assert len(exc_info.value.violations) == 1
assert exc_info.value.violations[0]["kind"] == "llm"
def test_assert_models_billable_passes(patched_globals) -> None:
"""No exception when every explicit id is premium or BYOK."""
assert (
assert_models_billable(
agent_llm_id=3, image_generation_config_id=-1, vision_llm_config_id=4
)
is None
)
def test_search_space_wrapper_delegates_to_core(patched_globals) -> None:
"""The search-space wrapper produces the same result as the ID core."""
search_space = _search_space(llm=-2, image=0, vision=-2)
assert get_automation_model_eligibility(search_space) == get_model_eligibility(
agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2
)