refactor(provider-configuration): standardize provider parameter naming across various modules and improve quota error handling in tests

This commit is contained in:
Anish Sarkar 2026-06-13 14:23:32 +05:30
parent 4a6a282a46
commit e104193ddf
11 changed files with 160 additions and 58 deletions

View file

@ -153,13 +153,13 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
)
cap = derive_supports_image_input(
litellm_provider=cfg.get("litellm_provider"),
provider=cfg.get("litellm_provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
)
block = is_known_text_only_chat_model(
litellm_provider=cfg.get("litellm_provider"),
provider=cfg.get("litellm_provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),

View file

@ -76,8 +76,7 @@ async def test_quota_denial_fails_the_podcast_without_a_transcript(
async def _deny(**_kwargs):
raise QuotaInsufficientError(
usage_type="podcast_generation",
used_micros=5_000_000,
limit_micros=5_000_000,
balance_micros=0,
remaining_micros=0,
)
yield # pragma: no cover - unreachable, satisfies the CM protocol

View file

@ -45,8 +45,9 @@ class _FakeQuotaResult:
class _FakeExecResult:
def __init__(self, thread):
def __init__(self, *, thread=None, scalars=None):
self._thread = thread
self._scalars = scalars or []
def unique(self):
return self
@ -54,14 +55,21 @@ class _FakeExecResult:
def scalar_one_or_none(self):
return self._thread
def scalars(self):
return SimpleNamespace(all=lambda: self._scalars)
class _FakeSession:
def __init__(self, thread):
self.thread = thread
self.commit_count = 0
self.execute_count = 0
async def execute(self, _stmt):
return _FakeExecResult(self.thread)
self.execute_count += 1
if self.execute_count == 1:
return _FakeExecResult(thread=self.thread)
return _FakeExecResult(scalars=[])
async def commit(self):
self.commit_count += 1
@ -71,6 +79,60 @@ def _thread(*, pinned: int | None = None):
return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned)
def _set_global_llm_configs(monkeypatch, config, configs: list[dict]):
    """Patch ``config`` with global LLM configs plus derived connection/model rows.

    For every entry in ``configs`` one connection dict and one model dict are
    built, and three attributes are then patched on ``config`` through
    ``monkeypatch``: ``GLOBAL_LLM_CONFIGS`` (the list passed in),
    ``GLOBAL_CONNECTIONS`` and ``GLOBAL_MODELS``.

    Entries without a ``supports_image_input`` key are updated in place with
    the value returned by ``derive_supports_image_input``; entries that
    already carry the key are left untouched.

    NOTE(review): the block nesting below was reconstructed from a listing
    that had lost its indentation — confirm against the real file.
    """
    from app.services.provider_capabilities import derive_supports_image_input

    connection_rows = []
    model_rows = []

    for entry in configs:
        model_pk = int(entry["id"])
        # Connection ids are offset from the config id by a fixed 100_000.
        connection_pk = model_pk - 100_000
        # Prefer the new "provider" key, falling back to the legacy one.
        provider_name = entry.get("provider") or entry.get("litellm_provider")
        model_name = entry["model_name"]

        if "supports_image_input" not in entry:
            params = entry.get("litellm_params") or {}
            base_model = None
            if isinstance(params, dict):
                base_model = params.get("base_model")
            # Written back onto the entry so the model row below can read it.
            entry["supports_image_input"] = derive_supports_image_input(
                provider=provider_name,
                model_name=model_name,
                base_model=base_model,
                custom_provider=entry.get("custom_provider"),
            )

        connection_row = {
            "id": connection_pk,
            "provider": provider_name,
            "scope": "GLOBAL",
            "enabled": True,
        }
        connection_rows.append(connection_row)

        catalog_info = {
            "auto_pin_tier": entry.get("auto_pin_tier"),
            "quality_score": entry.get("quality_score"),
        }
        model_rows.append(
            {
                "id": model_pk,
                "connection_id": connection_pk,
                "model_id": model_name,
                "display_name": entry.get("name") or model_name,
                "supports_chat": entry.get("supports_chat", True),
                "supports_tools": entry.get("supports_tools", True),
                "supports_image_generation": entry.get(
                    "supports_image_generation", False
                ),
                "capabilities_override": entry.get("capabilities_override") or {},
                "billing_tier": entry.get("billing_tier", "free"),
                "catalog": catalog_info,
                "supports_image_input": entry["supports_image_input"],
            }
        )

    # Patch in the same order as before: configs, then connections, then models.
    for attr_name, payload in (
        ("GLOBAL_LLM_CONFIGS", configs),
        ("GLOBAL_CONNECTIONS", connection_rows),
        ("GLOBAL_MODELS", model_rows),
    ):
        monkeypatch.setattr(config, attr_name, payload)
def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
return {
"id": id_,
@ -108,11 +170,7 @@ async def test_image_turn_filters_out_text_only_candidates(monkeypatch):
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[_text_only_cfg(-1), _vision_cfg(-2)],
)
_set_global_llm_configs(monkeypatch, config, [_text_only_cfg(-1), _vision_cfg(-2)])
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
_premium_allowed,
@ -140,11 +198,7 @@ async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch):
from app.config import config
session = _FakeSession(_thread(pinned=-1))
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[_text_only_cfg(-1), _vision_cfg(-2)],
)
_set_global_llm_configs(monkeypatch, config, [_text_only_cfg(-1), _vision_cfg(-2)])
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
_premium_allowed,
@ -172,9 +226,9 @@ async def test_image_turn_reuses_existing_vision_pin(monkeypatch):
from app.config import config
session = _FakeSession(_thread(pinned=-2))
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)],
)
monkeypatch.setattr(
@ -203,10 +257,8 @@ async def test_image_turn_with_no_vision_candidates_raises(monkeypatch):
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[_text_only_cfg(-1), _text_only_cfg(-2)],
_set_global_llm_configs(
monkeypatch, config, [_text_only_cfg(-1), _text_only_cfg(-2)]
)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
@ -231,11 +283,7 @@ async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch):
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[_text_only_cfg(-1)],
)
_set_global_llm_configs(monkeypatch, config, [_text_only_cfg(-1)])
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
_premium_allowed,
@ -269,7 +317,7 @@ async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch):
"quality_score": 80,
# NOTE: no supports_image_input key
}
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cfg_unannotated_vision])
_set_global_llm_configs(monkeypatch, config, [cfg_unannotated_vision])
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
_premium_allowed,

View file

@ -217,10 +217,64 @@ def test_auto_model_pin_candidates_include_dynamic_openrouter():
model_name="meta-llama/llama-3.3-70b:free",
billing_tier="free",
)
original = config.GLOBAL_LLM_CONFIGS
global_connections = [
{
"id": -110_001,
"provider": "openrouter",
"scope": "GLOBAL",
"enabled": True,
},
{
"id": -110_002,
"provider": "openrouter",
"scope": "GLOBAL",
"enabled": True,
},
]
global_models = [
{
"id": or_premium["id"],
"connection_id": -110_001,
"model_id": or_premium["model_name"],
"display_name": or_premium["name"],
"supports_chat": True,
"supports_image_input": True,
"supports_tools": True,
"supports_image_generation": False,
"capabilities_override": {},
"billing_tier": or_premium["billing_tier"],
"catalog": {
"auto_pin_tier": "A",
"quality_score": 50,
},
},
{
"id": or_free["id"],
"connection_id": -110_002,
"model_id": or_free["model_name"],
"display_name": or_free["name"],
"supports_chat": True,
"supports_image_input": True,
"supports_tools": True,
"supports_image_generation": False,
"capabilities_override": {},
"billing_tier": or_free["billing_tier"],
"catalog": {
"auto_pin_tier": "A",
"quality_score": 50,
},
},
]
original_configs = config.GLOBAL_LLM_CONFIGS
original_connections = config.GLOBAL_CONNECTIONS
original_models = config.GLOBAL_MODELS
try:
config.GLOBAL_LLM_CONFIGS = [or_premium, or_free]
config.GLOBAL_CONNECTIONS = global_connections
config.GLOBAL_MODELS = global_models
candidate_ids = {c["id"] for c in _global_candidates()}
assert candidate_ids == {-10_001, -10_002}
finally:
config.GLOBAL_LLM_CONFIGS = original
config.GLOBAL_LLM_CONFIGS = original_configs
config.GLOBAL_CONNECTIONS = original_connections
config.GLOBAL_MODELS = original_models

View file

@ -6,7 +6,7 @@ def test_openai_compatible_resolver_uses_explicit_api_base() -> None:
model, kwargs = to_litellm(
{
"protocol": "OPENAI_COMPATIBLE",
"litellm_provider": "openai",
"provider": "openai",
"base_url": "http://host.docker.internal:1234/v1",
"api_key": "local-key",
"extra": {},
@ -24,7 +24,7 @@ def test_ollama_resolver_uses_native_api_base() -> None:
model, kwargs = to_litellm(
{
"protocol": "OLLAMA",
"litellm_provider": "ollama_chat",
"provider": "ollama_chat",
"base_url": "http://host.docker.internal:11434",
"api_key": None,
"extra": {},
@ -62,7 +62,6 @@ def test_global_materialization_preserves_tier_and_keeps_key_server_side() -> No
"billing_tier": "premium",
},
],
vision_configs=[],
image_configs=[],
)

View file

@ -25,7 +25,7 @@ def _or_cfg(
) -> dict:
return {
"id": cid,
"litellm_provider": "openrouter",
"provider": "openrouter",
"model_name": model_name,
"billing_tier": tier,
"auto_pin_tier": "B" if tier == "premium" else "C",

View file

@ -32,7 +32,7 @@ pytestmark = pytest.mark.unit
def test_or_modalities_with_image_returns_true():
assert (
derive_supports_image_input(
litellm_provider="openrouter",
provider="openrouter",
model_name="openai/gpt-4o",
openrouter_input_modalities=["text", "image"],
)
@ -43,7 +43,7 @@ def test_or_modalities_with_image_returns_true():
def test_or_modalities_text_only_returns_false():
assert (
derive_supports_image_input(
litellm_provider="openrouter",
provider="openrouter",
model_name="deepseek/deepseek-v3.2-exp",
openrouter_input_modalities=["text"],
)
@ -57,7 +57,7 @@ def test_or_modalities_empty_list_returns_false():
to LiteLLM."""
assert (
derive_supports_image_input(
litellm_provider="openrouter",
provider="openrouter",
model_name="weird/empty-modalities",
openrouter_input_modalities=[],
)
@ -70,7 +70,7 @@ def test_or_modalities_none_falls_through_to_litellm():
to LiteLLM. Using ``openai/gpt-4o`` which is in LiteLLM's map."""
assert (
derive_supports_image_input(
litellm_provider="openai",
provider="openai",
model_name="gpt-4o",
openrouter_input_modalities=None,
)
@ -86,7 +86,7 @@ def test_or_modalities_none_falls_through_to_litellm():
def test_litellm_known_vision_model_returns_true():
assert (
derive_supports_image_input(
litellm_provider="openai",
provider="openai",
model_name="gpt-4o",
)
is True
@ -100,7 +100,7 @@ def test_litellm_base_model_wins_over_model_name():
doesn't know) would shadow the real capability."""
assert (
derive_supports_image_input(
litellm_provider="azure",
provider="azure",
model_name="my-azure-deployment-id",
base_model="gpt-4o",
)
@ -112,7 +112,7 @@ def test_litellm_unknown_model_default_allows():
"""Default-allow on unknown — the safety net is the actual block."""
assert (
derive_supports_image_input(
litellm_provider="custom",
provider="custom",
model_name="brand-new-model-x9-unmapped",
custom_provider="brand_new_proxy",
)
@ -128,7 +128,7 @@ def test_litellm_known_text_only_returns_false():
# Sanity: confirm the helper's negative path. We use a small model
# known not to support vision per the map.
result = derive_supports_image_input(
litellm_provider="openai",
provider="openai",
model_name="deepseek-chat",
)
# We accept either False (LiteLLM said explicit no) or True
@ -147,7 +147,7 @@ def test_litellm_known_text_only_returns_false():
def test_is_known_text_only_returns_false_for_vision_model():
assert (
is_known_text_only_chat_model(
litellm_provider="openai",
provider="openai",
model_name="gpt-4o",
)
is False
@ -160,7 +160,7 @@ def test_is_known_text_only_returns_false_for_unknown_model():
fixing."""
assert (
is_known_text_only_chat_model(
litellm_provider="custom",
provider="custom",
model_name="brand-new-model-x9-unmapped",
custom_provider="brand_new_proxy",
)
@ -181,7 +181,7 @@ def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch):
assert (
is_known_text_only_chat_model(
litellm_provider="openai",
provider="openai",
model_name="gpt-4o",
)
is False
@ -201,7 +201,7 @@ def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch):
assert (
is_known_text_only_chat_model(
litellm_provider="openai",
provider="openai",
model_name="any-model",
)
is True
@ -218,7 +218,7 @@ def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch):
assert (
is_known_text_only_chat_model(
litellm_provider="openai",
provider="openai",
model_name="any-model",
)
is False
@ -237,7 +237,7 @@ def test_is_known_text_only_returns_false_on_missing_key(monkeypatch):
assert (
is_known_text_only_chat_model(
litellm_provider="openai",
provider="openai",
model_name="any-model",
)
is False

View file

@ -105,8 +105,7 @@ async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch):
async def _denying_billable_call(**_kwargs):
raise QuotaInsufficientError(
usage_type="vision_extraction",
used_micros=5_000_000,
limit_micros=5_000_000,
balance_micros=0,
remaining_micros=0,
)
yield # unreachable but required for asynccontextmanager type

View file

@ -131,6 +131,10 @@ def test_serialized_calls_includes_cost_micros():
assert serialized == [
{
"model": "m",
"model_ref": None,
"model_id": None,
"display_name": None,
"provider": None,
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2,

View file

@ -35,7 +35,7 @@ def test_safety_net_does_not_fire_for_azure_gpt_4o():
it text-only."""
assert (
is_known_text_only_chat_model(
litellm_provider="azure",
provider="azure",
model_name="my-azure-deployment",
base_model="gpt-4o",
)
@ -49,7 +49,7 @@ def test_safety_net_does_not_fire_for_unknown_model():
LiteLLM doesn't know about must flow through to the provider."""
assert (
is_known_text_only_chat_model(
litellm_provider="custom",
provider="custom",
custom_provider="brand_new_proxy",
model_name="brand-new-model-x9",
)
@ -69,7 +69,7 @@ def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch):
assert (
is_known_text_only_chat_model(
litellm_provider="openai",
provider="openai",
model_name="gpt-4o",
)
is False
@ -88,7 +88,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false)
assert (
is_known_text_only_chat_model(
litellm_provider="openai",
provider="openai",
model_name="text-only-stub",
)
is True
@ -100,7 +100,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
monkeypatch.setattr(pc.litellm, "get_model_info", _info_true)
assert (
is_known_text_only_chat_model(
litellm_provider="openai",
provider="openai",
model_name="vision-stub",
)
is False
@ -112,7 +112,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing)
assert (
is_known_text_only_chat_model(
litellm_provider="openai",
provider="openai",
model_name="missing-key-stub",
)
is False

View file

@ -98,8 +98,7 @@ async def _denying_billable_call(**kwargs):
_CALL_LOG.append(kwargs)
raise QuotaInsufficientError(
usage_type=kwargs.get("usage_type", "?"),
used_micros=5_000_000,
limit_micros=5_000_000,
balance_micros=0,
remaining_micros=0,
)
yield SimpleNamespace() # pragma: no cover