mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-30 21:59:46 +02:00
feat(database-migrations): add migration to remove legacy model config tables and remove stale model connection code
This commit is contained in:
parent
50668775f8
commit
bd4a04f2e7
93 changed files with 956 additions and 11442 deletions
|
|
@ -1,6 +1,6 @@
|
|||
"""Lock the runtime model-policy backstop in ``build_dependencies``.
|
||||
|
||||
Automations resolve their LLM from the *captured* ``agent_llm_id`` snapshot (so
|
||||
Automations resolve their LLM from the *captured* ``chat_model_id`` snapshot (so
|
||||
runs are insulated from later chat/search-space model changes), and the model
|
||||
policy is re-checked at run time so a captured model that is no longer billable
|
||||
fails the run clearly. When no snapshot is present, resolution falls back to the
|
||||
|
|
@ -45,10 +45,10 @@ def patched_side_effects(monkeypatch: pytest.MonkeyPatch):
|
|||
return None
|
||||
|
||||
|
||||
async def test_build_dependencies_resolves_captured_agent_llm_id(
|
||||
async def test_build_dependencies_resolves_captured_chat_model_id(
|
||||
monkeypatch: pytest.MonkeyPatch, patched_side_effects
|
||||
) -> None:
|
||||
"""The bundle loads with the *captured* ``agent_llm_id``, not the live search space."""
|
||||
"""The bundle loads with the *captured* ``chat_model_id``, not the live search space."""
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def _fake_load(_session, *, config_id, search_space_id):
|
||||
|
|
@ -67,13 +67,13 @@ async def test_build_dependencies_resolves_captured_agent_llm_id(
|
|||
lambda _ss: pytest.fail("search-space policy should not run on captured path"),
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(agent_llm_id=-99)
|
||||
search_space = SimpleNamespace(chat_model_id=-99)
|
||||
result = await build_dependencies(
|
||||
session=_FakeSession(search_space),
|
||||
search_space_id=42,
|
||||
agent_llm_id=-7,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-7,
|
||||
image_gen_model_id=5,
|
||||
vision_model_id=-1,
|
||||
)
|
||||
|
||||
assert captured == {"config_id": -7, "search_space_id": 42}
|
||||
|
|
@ -98,17 +98,17 @@ async def test_build_dependencies_validates_captured_ids(
|
|||
monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load)
|
||||
|
||||
await build_dependencies(
|
||||
session=_FakeSession(SimpleNamespace(agent_llm_id=0)),
|
||||
session=_FakeSession(SimpleNamespace(chat_model_id=0)),
|
||||
search_space_id=42,
|
||||
agent_llm_id=-7,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-7,
|
||||
image_gen_model_id=5,
|
||||
vision_model_id=-1,
|
||||
)
|
||||
|
||||
assert seen == {
|
||||
"agent_llm_id": -7,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -7,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -119,7 +119,7 @@ async def test_build_dependencies_raises_on_captured_policy_violation(
|
|||
|
||||
def _raise(**_kw):
|
||||
raise AutomationModelPolicyError(
|
||||
[{"kind": "image", "config_id": -2, "reason": "free model"}]
|
||||
[{"kind": "image", "model_id": -2, "reason": "free model"}]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(deps_mod, "assert_models_billable", _raise)
|
||||
|
|
@ -131,11 +131,11 @@ async def test_build_dependencies_raises_on_captured_policy_violation(
|
|||
|
||||
with pytest.raises(DependencyError):
|
||||
await build_dependencies(
|
||||
session=_FakeSession(SimpleNamespace(agent_llm_id=-7)),
|
||||
session=_FakeSession(SimpleNamespace(chat_model_id=-7)),
|
||||
search_space_id=42,
|
||||
agent_llm_id=-7,
|
||||
image_generation_config_id=-2,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-7,
|
||||
image_gen_model_id=-2,
|
||||
vision_model_id=-1,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -157,7 +157,7 @@ async def test_build_dependencies_falls_back_to_search_space(
|
|||
lambda **_kw: pytest.fail("captured policy should not run on fallback path"),
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(agent_llm_id=-7)
|
||||
search_space = SimpleNamespace(chat_model_id=-7)
|
||||
result = await build_dependencies(
|
||||
session=_FakeSession(search_space), search_space_id=42
|
||||
)
|
||||
|
|
|
|||
|
|
@ -28,9 +28,9 @@ def _run() -> SimpleNamespace:
|
|||
def test_build_action_ctx_propagates_captured_models() -> None:
|
||||
"""``definition.models`` flows onto the ActionContext model fields."""
|
||||
models = AutomationModels(
|
||||
agent_llm_id=-1,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-1,
|
||||
image_gen_model_id=5,
|
||||
vision_model_id=-1,
|
||||
)
|
||||
ctx = _build_action_ctx(
|
||||
cast(AsyncSession, None),
|
||||
|
|
@ -40,9 +40,9 @@ def test_build_action_ctx_propagates_captured_models() -> None:
|
|||
)
|
||||
|
||||
assert ctx.search_space_id == 42
|
||||
assert ctx.agent_llm_id == -1
|
||||
assert ctx.image_generation_config_id == 5
|
||||
assert ctx.vision_llm_config_id == -1
|
||||
assert ctx.chat_model_id == -1
|
||||
assert ctx.image_gen_model_id == 5
|
||||
assert ctx.vision_model_id == -1
|
||||
|
||||
|
||||
def test_build_action_ctx_none_models_leaves_fields_none() -> None:
|
||||
|
|
@ -54,6 +54,6 @@ def test_build_action_ctx_none_models_leaves_fields_none() -> None:
|
|||
None,
|
||||
)
|
||||
|
||||
assert ctx.agent_llm_id is None
|
||||
assert ctx.image_generation_config_id is None
|
||||
assert ctx.vision_llm_config_id is None
|
||||
assert ctx.chat_model_id is None
|
||||
assert ctx.image_gen_model_id is None
|
||||
assert ctx.vision_model_id is None
|
||||
|
|
|
|||
|
|
@ -40,24 +40,24 @@ def test_automation_definition_models_round_trip() -> None:
|
|||
name="Daily digest",
|
||||
plan=[PlanStep(step_id="s1", action="agent_task")],
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-1,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-1,
|
||||
image_gen_model_id=5,
|
||||
vision_model_id=-1,
|
||||
),
|
||||
)
|
||||
|
||||
dumped = definition.model_dump(mode="json", by_alias=True)
|
||||
assert dumped["models"] == {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
}
|
||||
|
||||
restored = AutomationDefinition.model_validate(dumped)
|
||||
assert restored.models is not None
|
||||
assert restored.models.agent_llm_id == -1
|
||||
assert restored.models.image_generation_config_id == 5
|
||||
assert restored.models.vision_llm_config_id == -1
|
||||
assert restored.models.chat_model_id == -1
|
||||
assert restored.models.image_gen_model_id == 5
|
||||
assert restored.models.vision_model_id == -1
|
||||
|
||||
|
||||
def test_automation_definition_rejects_unknown_top_level_field() -> None:
|
||||
|
|
|
|||
|
|
@ -64,12 +64,12 @@ async def test_assert_models_billable_raises_422_on_violation(
|
|||
|
||||
def _raise(_ss):
|
||||
raise AutomationModelPolicyError(
|
||||
[{"kind": "llm", "config_id": 0, "reason": "Auto mode"}]
|
||||
[{"kind": "llm", "model_id": 0, "reason": "Auto mode"}]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_automation_models_billable", _raise)
|
||||
|
||||
service = _service(SimpleNamespace(agent_llm_id=0))
|
||||
service = _service(SimpleNamespace(chat_model_id=0))
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service._assert_models_billable(1)
|
||||
|
||||
|
|
@ -99,7 +99,7 @@ async def test_assert_models_billable_returns_search_space_when_ok(
|
|||
automation_mod, "assert_automation_models_billable", lambda _ss: None
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(agent_llm_id=-1)
|
||||
search_space = SimpleNamespace(chat_model_id=-1)
|
||||
service = _service(search_space)
|
||||
assert await service._assert_models_billable(1) is search_space
|
||||
|
||||
|
|
@ -123,9 +123,9 @@ async def test_create_injects_captured_models_from_search_space(
|
|||
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
|
||||
|
||||
search_space = SimpleNamespace(
|
||||
agent_llm_id=-1,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-1,
|
||||
image_gen_model_id=5,
|
||||
vision_model_id=-1,
|
||||
)
|
||||
service = _service(search_space)
|
||||
payload = AutomationCreate(
|
||||
|
|
@ -137,9 +137,9 @@ async def test_create_injects_captured_models_from_search_space(
|
|||
automation = await service.create(payload)
|
||||
|
||||
assert automation.definition["models"] == {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -162,9 +162,9 @@ async def test_create_treats_unset_prefs_as_auto_zero(
|
|||
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
|
||||
|
||||
search_space = SimpleNamespace(
|
||||
agent_llm_id=None,
|
||||
image_generation_config_id=None,
|
||||
vision_llm_config_id=None,
|
||||
chat_model_id=None,
|
||||
image_gen_model_id=None,
|
||||
vision_model_id=None,
|
||||
)
|
||||
service = _service(search_space)
|
||||
payload = AutomationCreate(search_space_id=1, name="A", definition=_definition())
|
||||
|
|
@ -172,9 +172,9 @@ async def test_create_treats_unset_prefs_as_auto_zero(
|
|||
automation = await service.create(payload)
|
||||
|
||||
assert automation.definition["models"] == {
|
||||
"agent_llm_id": 0,
|
||||
"image_generation_config_id": 0,
|
||||
"vision_llm_config_id": 0,
|
||||
"chat_model_id": 0,
|
||||
"image_gen_model_id": 0,
|
||||
"vision_model_id": 0,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -195,11 +195,11 @@ async def test_create_honors_selected_models_when_provided(
|
|||
)
|
||||
validated: dict[str, Any] = {}
|
||||
|
||||
def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
|
||||
def _assert_ok(*, chat_model_id, image_gen_model_id, vision_model_id):
|
||||
validated["ids"] = (
|
||||
agent_llm_id,
|
||||
image_generation_config_id,
|
||||
vision_llm_config_id,
|
||||
chat_model_id,
|
||||
image_gen_model_id,
|
||||
vision_model_id,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok)
|
||||
|
|
@ -213,15 +213,15 @@ async def test_create_honors_selected_models_when_provided(
|
|||
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
|
||||
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
|
||||
|
||||
service = _service(SimpleNamespace(agent_llm_id=-99))
|
||||
service = _service(SimpleNamespace(chat_model_id=-99))
|
||||
payload = AutomationCreate(
|
||||
search_space_id=1,
|
||||
name="A",
|
||||
definition=_definition(
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-1,
|
||||
image_generation_config_id=7,
|
||||
vision_llm_config_id=-2,
|
||||
chat_model_id=-1,
|
||||
image_gen_model_id=7,
|
||||
vision_model_id=-2,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
|
@ -230,9 +230,9 @@ async def test_create_honors_selected_models_when_provided(
|
|||
|
||||
assert validated["ids"] == (-1, 7, -2)
|
||||
assert automation.definition["models"] == {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 7,
|
||||
"vision_llm_config_id": -2,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 7,
|
||||
"vision_model_id": -2,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -241,9 +241,9 @@ async def test_create_rejects_unbillable_selected_models(
|
|||
) -> None:
|
||||
"""A non-billable explicit selection maps the policy error to HTTP 422."""
|
||||
|
||||
def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
|
||||
def _raise(*, chat_model_id, image_gen_model_id, vision_model_id):
|
||||
raise AutomationModelPolicyError(
|
||||
[{"kind": "llm", "config_id": -3, "reason": "free model"}]
|
||||
[{"kind": "llm", "model_id": -3, "reason": "free model"}]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_models_billable", _raise)
|
||||
|
|
@ -253,15 +253,15 @@ async def test_create_rejects_unbillable_selected_models(
|
|||
|
||||
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
|
||||
|
||||
service = _service(SimpleNamespace(agent_llm_id=-3))
|
||||
service = _service(SimpleNamespace(chat_model_id=-3))
|
||||
payload = AutomationCreate(
|
||||
search_space_id=1,
|
||||
name="A",
|
||||
definition=_definition(
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-3,
|
||||
image_generation_config_id=7,
|
||||
vision_llm_config_id=-2,
|
||||
chat_model_id=-3,
|
||||
image_gen_model_id=7,
|
||||
vision_model_id=-2,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
|
@ -277,9 +277,9 @@ async def test_update_preserves_captured_models(
|
|||
) -> None:
|
||||
"""A definition edit carries over the previously captured ``models``."""
|
||||
captured = {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
}
|
||||
existing = SimpleNamespace(
|
||||
search_space_id=1,
|
||||
|
|
@ -318,20 +318,20 @@ async def test_update_honors_changed_models_when_valid(
|
|||
"name": "A",
|
||||
"plan": [],
|
||||
"models": {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
},
|
||||
},
|
||||
version=3,
|
||||
)
|
||||
validated: dict[str, Any] = {}
|
||||
|
||||
def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
|
||||
def _assert_ok(*, chat_model_id, image_gen_model_id, vision_model_id):
|
||||
validated["ids"] = (
|
||||
agent_llm_id,
|
||||
image_generation_config_id,
|
||||
vision_llm_config_id,
|
||||
chat_model_id,
|
||||
image_gen_model_id,
|
||||
vision_model_id,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok)
|
||||
|
|
@ -351,9 +351,9 @@ async def test_update_honors_changed_models_when_valid(
|
|||
patch = AutomationUpdate(
|
||||
definition=_definition(
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-2,
|
||||
image_generation_config_id=9,
|
||||
vision_llm_config_id=-2,
|
||||
chat_model_id=-2,
|
||||
image_gen_model_id=9,
|
||||
vision_model_id=-2,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
@ -362,9 +362,9 @@ async def test_update_honors_changed_models_when_valid(
|
|||
|
||||
assert validated["ids"] == (-2, 9, -2)
|
||||
assert result.definition["models"] == {
|
||||
"agent_llm_id": -2,
|
||||
"image_generation_config_id": 9,
|
||||
"vision_llm_config_id": -2,
|
||||
"chat_model_id": -2,
|
||||
"image_gen_model_id": 9,
|
||||
"vision_model_id": -2,
|
||||
}
|
||||
assert result.version == 4
|
||||
|
||||
|
|
@ -379,17 +379,17 @@ async def test_update_rejects_changed_unbillable_models(
|
|||
"name": "A",
|
||||
"plan": [],
|
||||
"models": {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
},
|
||||
},
|
||||
version=3,
|
||||
)
|
||||
|
||||
def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
|
||||
def _raise(*, chat_model_id, image_gen_model_id, vision_model_id):
|
||||
raise AutomationModelPolicyError(
|
||||
[{"kind": "llm", "config_id": -7, "reason": "free model"}]
|
||||
[{"kind": "llm", "model_id": -7, "reason": "free model"}]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_models_billable", _raise)
|
||||
|
|
@ -409,9 +409,9 @@ async def test_update_rejects_changed_unbillable_models(
|
|||
patch = AutomationUpdate(
|
||||
definition=_definition(
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-7,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-7,
|
||||
image_gen_model_id=5,
|
||||
vision_model_id=-1,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
@ -431,9 +431,9 @@ async def test_update_keeps_unchanged_models_without_revalidation(
|
|||
premium without an unrelated edit tripping the policy check.
|
||||
"""
|
||||
captured = {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
}
|
||||
existing = SimpleNamespace(
|
||||
search_space_id=1,
|
||||
|
|
@ -485,7 +485,7 @@ async def test_model_eligibility_authorizes_and_returns_payload(
|
|||
lambda _ss: {"allowed": False, "violations": [{"kind": "image"}]},
|
||||
)
|
||||
|
||||
service = _service(SimpleNamespace(agent_llm_id=-2))
|
||||
service = _service(SimpleNamespace(chat_model_id=-2))
|
||||
result = await service.model_eligibility(search_space_id=5)
|
||||
|
||||
assert result == {"allowed": False, "violations": [{"kind": "image"}]}
|
||||
|
|
|
|||
|
|
@ -27,9 +27,9 @@ pytestmark = pytest.mark.unit
|
|||
def _search_space(*, llm: int | None, image: int | None, vision: int | None):
|
||||
"""Minimal stand-in for the ``SearchSpace`` ORM row the policy reads."""
|
||||
return SimpleNamespace(
|
||||
agent_llm_id=llm,
|
||||
image_generation_config_id=image,
|
||||
vision_llm_config_id=vision,
|
||||
chat_model_id=llm,
|
||||
image_gen_model_id=image,
|
||||
vision_model_id=vision,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -39,29 +39,11 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch):
|
|||
|
||||
Negative ids: -1 is premium, -2 is free, for each of llm/image/vision.
|
||||
"""
|
||||
llm_configs = {
|
||||
-1: {"id": -1, "billing_tier": "premium"},
|
||||
-2: {"id": -2, "billing_tier": "free"},
|
||||
}
|
||||
monkeypatch.setattr(
|
||||
"app.agents.chat.runtime.llm_config.load_global_llm_config_by_id",
|
||||
lambda cid: llm_configs.get(cid),
|
||||
)
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
monkeypatch.setattr(
|
||||
app_config,
|
||||
"GLOBAL_IMAGE_GEN_CONFIGS",
|
||||
[
|
||||
{"id": -1, "billing_tier": "premium"},
|
||||
{"id": -2, "billing_tier": "free"},
|
||||
],
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
app_config,
|
||||
"GLOBAL_VISION_LLM_CONFIGS",
|
||||
"GLOBAL_MODELS",
|
||||
[
|
||||
{"id": -1, "billing_tier": "premium"},
|
||||
{"id": -2, "billing_tier": "free"},
|
||||
|
|
@ -71,7 +53,7 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch):
|
|||
return None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
|
||||
def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None:
|
||||
"""A positive config id is a user-owned BYOK model — always billable."""
|
||||
allowed, reason = model_policy._classify(kind, 7)
|
||||
|
|
@ -79,7 +61,7 @@ def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None:
|
|||
assert reason == ""
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
|
||||
@pytest.mark.parametrize("config_id", [0, None])
|
||||
def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None:
|
||||
"""Auto mode (id 0) and an unset slot (None) are blocked."""
|
||||
|
|
@ -88,7 +70,7 @@ def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None:
|
|||
assert "Auto mode" in reason
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
|
||||
def test_premium_global_is_allowed(kind: str, patched_globals) -> None:
|
||||
"""A negative (global) id with premium billing tier is allowed."""
|
||||
allowed, reason = model_policy._classify(kind, -1)
|
||||
|
|
@ -96,7 +78,7 @@ def test_premium_global_is_allowed(kind: str, patched_globals) -> None:
|
|||
assert reason == ""
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
|
||||
def test_free_global_is_blocked(kind: str, patched_globals) -> None:
|
||||
"""A negative (global) id with a free billing tier is blocked."""
|
||||
allowed, reason = model_policy._classify(kind, -2)
|
||||
|
|
@ -104,7 +86,7 @@ def test_free_global_is_blocked(kind: str, patched_globals) -> None:
|
|||
assert "free model" in reason
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
|
||||
def test_unknown_global_id_is_blocked(kind: str, patched_globals) -> None:
|
||||
"""A negative id that resolves to no config is treated as not premium."""
|
||||
allowed, _ = model_policy._classify(kind, -999)
|
||||
|
|
@ -125,10 +107,10 @@ def test_eligibility_reports_each_violation(patched_globals) -> None:
|
|||
|
||||
assert result["allowed"] is False
|
||||
kinds = {v["kind"] for v in result["violations"]}
|
||||
assert kinds == {"llm", "image", "vision"}
|
||||
# config_id is echoed back for the UI / settings deep-link.
|
||||
by_kind = {v["kind"]: v["config_id"] for v in result["violations"]}
|
||||
assert by_kind == {"llm": -2, "image": 0, "vision": -2}
|
||||
assert kinds == {"chat", "image", "vision"}
|
||||
# model_id is echoed back for the UI / settings deep-link.
|
||||
by_kind = {v["kind"]: v["model_id"] for v in result["violations"]}
|
||||
assert by_kind == {"chat": -2, "image": 0, "vision": -2}
|
||||
|
||||
|
||||
def test_assert_raises_with_violations(patched_globals) -> None:
|
||||
|
|
@ -138,7 +120,7 @@ def test_assert_raises_with_violations(patched_globals) -> None:
|
|||
assert_automation_models_billable(search_space)
|
||||
|
||||
assert len(exc_info.value.violations) == 1
|
||||
assert exc_info.value.violations[0]["kind"] == "llm"
|
||||
assert exc_info.value.violations[0]["kind"] == "chat"
|
||||
|
||||
|
||||
def test_assert_passes_when_all_billable(patched_globals) -> None:
|
||||
|
|
@ -153,7 +135,7 @@ def test_assert_passes_when_all_billable(patched_globals) -> None:
|
|||
def test_get_model_eligibility_all_billable(patched_globals) -> None:
|
||||
"""Premium LLM + BYOK image + premium vision (explicit ids) → allowed."""
|
||||
result = get_model_eligibility(
|
||||
agent_llm_id=-1, image_generation_config_id=5, vision_llm_config_id=-1
|
||||
chat_model_id=-1, image_gen_model_id=5, vision_model_id=-1
|
||||
)
|
||||
assert result == {"allowed": True, "violations": []}
|
||||
|
||||
|
|
@ -161,28 +143,28 @@ def test_get_model_eligibility_all_billable(patched_globals) -> None:
|
|||
def test_get_model_eligibility_reports_each_violation(patched_globals) -> None:
|
||||
"""Free LLM, Auto image, free vision (explicit ids) each produce a violation."""
|
||||
result = get_model_eligibility(
|
||||
agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2
|
||||
chat_model_id=-2, image_gen_model_id=0, vision_model_id=-2
|
||||
)
|
||||
assert result["allowed"] is False
|
||||
by_kind = {v["kind"]: v["config_id"] for v in result["violations"]}
|
||||
assert by_kind == {"llm": -2, "image": 0, "vision": -2}
|
||||
by_kind = {v["kind"]: v["model_id"] for v in result["violations"]}
|
||||
assert by_kind == {"chat": -2, "image": 0, "vision": -2}
|
||||
|
||||
|
||||
def test_assert_models_billable_raises(patched_globals) -> None:
|
||||
"""``assert_models_billable`` raises when any explicit id is blocked."""
|
||||
with pytest.raises(AutomationModelPolicyError) as exc_info:
|
||||
assert_models_billable(
|
||||
agent_llm_id=0, image_generation_config_id=5, vision_llm_config_id=-1
|
||||
chat_model_id=0, image_gen_model_id=5, vision_model_id=-1
|
||||
)
|
||||
assert len(exc_info.value.violations) == 1
|
||||
assert exc_info.value.violations[0]["kind"] == "llm"
|
||||
assert exc_info.value.violations[0]["kind"] == "chat"
|
||||
|
||||
|
||||
def test_assert_models_billable_passes(patched_globals) -> None:
|
||||
"""No exception when every explicit id is premium or BYOK."""
|
||||
assert (
|
||||
assert_models_billable(
|
||||
agent_llm_id=3, image_generation_config_id=-1, vision_llm_config_id=4
|
||||
chat_model_id=3, image_gen_model_id=-1, vision_model_id=4
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
|
@ -192,5 +174,5 @@ def test_search_space_wrapper_delegates_to_core(patched_globals) -> None:
|
|||
"""The search-space wrapper produces the same result as the ID core."""
|
||||
search_space = _search_space(llm=-2, image=0, vision=-2)
|
||||
assert get_automation_model_eligibility(search_space) == get_model_eligibility(
|
||||
agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2
|
||||
chat_model_id=-2, image_gen_model_id=0, vision_model_id=-2
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue