mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-14 20:55:15 +02:00
refactor(provider-configuration): standardize provider parameter naming across various modules and improve quota error handling in tests
This commit is contained in:
parent
4a6a282a46
commit
e104193ddf
11 changed files with 160 additions and 58 deletions
|
|
@ -153,13 +153,13 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
|
|||
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
|
||||
)
|
||||
cap = derive_supports_image_input(
|
||||
litellm_provider=cfg.get("litellm_provider"),
|
||||
provider=cfg.get("litellm_provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
block = is_known_text_only_chat_model(
|
||||
litellm_provider=cfg.get("litellm_provider"),
|
||||
provider=cfg.get("litellm_provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
|
|
|
|||
|
|
@ -76,8 +76,7 @@ async def test_quota_denial_fails_the_podcast_without_a_transcript(
|
|||
async def _deny(**_kwargs):
|
||||
raise QuotaInsufficientError(
|
||||
usage_type="podcast_generation",
|
||||
used_micros=5_000_000,
|
||||
limit_micros=5_000_000,
|
||||
balance_micros=0,
|
||||
remaining_micros=0,
|
||||
)
|
||||
yield # pragma: no cover - unreachable, satisfies the CM protocol
|
||||
|
|
|
|||
|
|
@ -45,8 +45,9 @@ class _FakeQuotaResult:
|
|||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, thread):
|
||||
def __init__(self, *, thread=None, scalars=None):
|
||||
self._thread = thread
|
||||
self._scalars = scalars or []
|
||||
|
||||
def unique(self):
|
||||
return self
|
||||
|
|
@ -54,14 +55,21 @@ class _FakeExecResult:
|
|||
def scalar_one_or_none(self):
|
||||
return self._thread
|
||||
|
||||
def scalars(self):
|
||||
return SimpleNamespace(all=lambda: self._scalars)
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, thread):
|
||||
self.thread = thread
|
||||
self.commit_count = 0
|
||||
self.execute_count = 0
|
||||
|
||||
async def execute(self, _stmt):
|
||||
return _FakeExecResult(self.thread)
|
||||
self.execute_count += 1
|
||||
if self.execute_count == 1:
|
||||
return _FakeExecResult(thread=self.thread)
|
||||
return _FakeExecResult(scalars=[])
|
||||
|
||||
async def commit(self):
|
||||
self.commit_count += 1
|
||||
|
|
@ -71,6 +79,60 @@ def _thread(*, pinned: int | None = None):
|
|||
return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned)
|
||||
|
||||
|
||||
def _set_global_llm_configs(monkeypatch, config, configs: list[dict]):
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
connections = []
|
||||
models = []
|
||||
for cfg in configs:
|
||||
config_id = int(cfg["id"])
|
||||
connection_id = config_id - 100_000
|
||||
provider = cfg.get("provider") or cfg.get("litellm_provider")
|
||||
model_name = cfg["model_name"]
|
||||
if "supports_image_input" not in cfg:
|
||||
litellm_params = cfg.get("litellm_params") or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model")
|
||||
if isinstance(litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
cfg["supports_image_input"] = derive_supports_image_input(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
connections.append(
|
||||
{
|
||||
"id": connection_id,
|
||||
"provider": provider,
|
||||
"scope": "GLOBAL",
|
||||
"enabled": True,
|
||||
}
|
||||
)
|
||||
model = {
|
||||
"id": config_id,
|
||||
"connection_id": connection_id,
|
||||
"model_id": model_name,
|
||||
"display_name": cfg.get("name") or model_name,
|
||||
"supports_chat": cfg.get("supports_chat", True),
|
||||
"supports_tools": cfg.get("supports_tools", True),
|
||||
"supports_image_generation": cfg.get("supports_image_generation", False),
|
||||
"capabilities_override": cfg.get("capabilities_override") or {},
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
"catalog": {
|
||||
"auto_pin_tier": cfg.get("auto_pin_tier"),
|
||||
"quality_score": cfg.get("quality_score"),
|
||||
},
|
||||
"supports_image_input": cfg["supports_image_input"],
|
||||
}
|
||||
models.append(model)
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", configs)
|
||||
monkeypatch.setattr(config, "GLOBAL_CONNECTIONS", connections)
|
||||
monkeypatch.setattr(config, "GLOBAL_MODELS", models)
|
||||
|
||||
|
||||
def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
|
||||
return {
|
||||
"id": id_,
|
||||
|
|
@ -108,11 +170,7 @@ async def test_image_turn_filters_out_text_only_candidates(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1), _vision_cfg(-2)],
|
||||
)
|
||||
_set_global_llm_configs(monkeypatch, config, [_text_only_cfg(-1), _vision_cfg(-2)])
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
|
||||
_premium_allowed,
|
||||
|
|
@ -140,11 +198,7 @@ async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned=-1))
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1), _vision_cfg(-2)],
|
||||
)
|
||||
_set_global_llm_configs(monkeypatch, config, [_text_only_cfg(-1), _vision_cfg(-2)])
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
|
||||
_premium_allowed,
|
||||
|
|
@ -172,9 +226,9 @@ async def test_image_turn_reuses_existing_vision_pin(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned=-2))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
|
|
@ -203,10 +257,8 @@ async def test_image_turn_with_no_vision_candidates_raises(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1), _text_only_cfg(-2)],
|
||||
_set_global_llm_configs(
|
||||
monkeypatch, config, [_text_only_cfg(-1), _text_only_cfg(-2)]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
|
||||
|
|
@ -231,11 +283,7 @@ async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1)],
|
||||
)
|
||||
_set_global_llm_configs(monkeypatch, config, [_text_only_cfg(-1)])
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
|
||||
_premium_allowed,
|
||||
|
|
@ -269,7 +317,7 @@ async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch):
|
|||
"quality_score": 80,
|
||||
# NOTE: no supports_image_input key
|
||||
}
|
||||
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cfg_unannotated_vision])
|
||||
_set_global_llm_configs(monkeypatch, config, [cfg_unannotated_vision])
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
|
||||
_premium_allowed,
|
||||
|
|
|
|||
|
|
@ -217,10 +217,64 @@ def test_auto_model_pin_candidates_include_dynamic_openrouter():
|
|||
model_name="meta-llama/llama-3.3-70b:free",
|
||||
billing_tier="free",
|
||||
)
|
||||
original = config.GLOBAL_LLM_CONFIGS
|
||||
global_connections = [
|
||||
{
|
||||
"id": -110_001,
|
||||
"provider": "openrouter",
|
||||
"scope": "GLOBAL",
|
||||
"enabled": True,
|
||||
},
|
||||
{
|
||||
"id": -110_002,
|
||||
"provider": "openrouter",
|
||||
"scope": "GLOBAL",
|
||||
"enabled": True,
|
||||
},
|
||||
]
|
||||
global_models = [
|
||||
{
|
||||
"id": or_premium["id"],
|
||||
"connection_id": -110_001,
|
||||
"model_id": or_premium["model_name"],
|
||||
"display_name": or_premium["name"],
|
||||
"supports_chat": True,
|
||||
"supports_image_input": True,
|
||||
"supports_tools": True,
|
||||
"supports_image_generation": False,
|
||||
"capabilities_override": {},
|
||||
"billing_tier": or_premium["billing_tier"],
|
||||
"catalog": {
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 50,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": or_free["id"],
|
||||
"connection_id": -110_002,
|
||||
"model_id": or_free["model_name"],
|
||||
"display_name": or_free["name"],
|
||||
"supports_chat": True,
|
||||
"supports_image_input": True,
|
||||
"supports_tools": True,
|
||||
"supports_image_generation": False,
|
||||
"capabilities_override": {},
|
||||
"billing_tier": or_free["billing_tier"],
|
||||
"catalog": {
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 50,
|
||||
},
|
||||
},
|
||||
]
|
||||
original_configs = config.GLOBAL_LLM_CONFIGS
|
||||
original_connections = config.GLOBAL_CONNECTIONS
|
||||
original_models = config.GLOBAL_MODELS
|
||||
try:
|
||||
config.GLOBAL_LLM_CONFIGS = [or_premium, or_free]
|
||||
config.GLOBAL_CONNECTIONS = global_connections
|
||||
config.GLOBAL_MODELS = global_models
|
||||
candidate_ids = {c["id"] for c in _global_candidates()}
|
||||
assert candidate_ids == {-10_001, -10_002}
|
||||
finally:
|
||||
config.GLOBAL_LLM_CONFIGS = original
|
||||
config.GLOBAL_LLM_CONFIGS = original_configs
|
||||
config.GLOBAL_CONNECTIONS = original_connections
|
||||
config.GLOBAL_MODELS = original_models
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ def test_openai_compatible_resolver_uses_explicit_api_base() -> None:
|
|||
model, kwargs = to_litellm(
|
||||
{
|
||||
"protocol": "OPENAI_COMPATIBLE",
|
||||
"litellm_provider": "openai",
|
||||
"provider": "openai",
|
||||
"base_url": "http://host.docker.internal:1234/v1",
|
||||
"api_key": "local-key",
|
||||
"extra": {},
|
||||
|
|
@ -24,7 +24,7 @@ def test_ollama_resolver_uses_native_api_base() -> None:
|
|||
model, kwargs = to_litellm(
|
||||
{
|
||||
"protocol": "OLLAMA",
|
||||
"litellm_provider": "ollama_chat",
|
||||
"provider": "ollama_chat",
|
||||
"base_url": "http://host.docker.internal:11434",
|
||||
"api_key": None,
|
||||
"extra": {},
|
||||
|
|
@ -62,7 +62,6 @@ def test_global_materialization_preserves_tier_and_keeps_key_server_side() -> No
|
|||
"billing_tier": "premium",
|
||||
},
|
||||
],
|
||||
vision_configs=[],
|
||||
image_configs=[],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ def _or_cfg(
|
|||
) -> dict:
|
||||
return {
|
||||
"id": cid,
|
||||
"litellm_provider": "openrouter",
|
||||
"provider": "openrouter",
|
||||
"model_name": model_name,
|
||||
"billing_tier": tier,
|
||||
"auto_pin_tier": "B" if tier == "premium" else "C",
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ pytestmark = pytest.mark.unit
|
|||
def test_or_modalities_with_image_returns_true():
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
litellm_provider="openrouter",
|
||||
provider="openrouter",
|
||||
model_name="openai/gpt-4o",
|
||||
openrouter_input_modalities=["text", "image"],
|
||||
)
|
||||
|
|
@ -43,7 +43,7 @@ def test_or_modalities_with_image_returns_true():
|
|||
def test_or_modalities_text_only_returns_false():
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
litellm_provider="openrouter",
|
||||
provider="openrouter",
|
||||
model_name="deepseek/deepseek-v3.2-exp",
|
||||
openrouter_input_modalities=["text"],
|
||||
)
|
||||
|
|
@ -57,7 +57,7 @@ def test_or_modalities_empty_list_returns_false():
|
|||
to LiteLLM."""
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
litellm_provider="openrouter",
|
||||
provider="openrouter",
|
||||
model_name="weird/empty-modalities",
|
||||
openrouter_input_modalities=[],
|
||||
)
|
||||
|
|
@ -70,7 +70,7 @@ def test_or_modalities_none_falls_through_to_litellm():
|
|||
to LiteLLM. Using ``openai/gpt-4o`` which is in LiteLLM's map."""
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
litellm_provider="openai",
|
||||
provider="openai",
|
||||
model_name="gpt-4o",
|
||||
openrouter_input_modalities=None,
|
||||
)
|
||||
|
|
@ -86,7 +86,7 @@ def test_or_modalities_none_falls_through_to_litellm():
|
|||
def test_litellm_known_vision_model_returns_true():
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
litellm_provider="openai",
|
||||
provider="openai",
|
||||
model_name="gpt-4o",
|
||||
)
|
||||
is True
|
||||
|
|
@ -100,7 +100,7 @@ def test_litellm_base_model_wins_over_model_name():
|
|||
doesn't know) would shadow the real capability."""
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
litellm_provider="azure",
|
||||
provider="azure",
|
||||
model_name="my-azure-deployment-id",
|
||||
base_model="gpt-4o",
|
||||
)
|
||||
|
|
@ -112,7 +112,7 @@ def test_litellm_unknown_model_default_allows():
|
|||
"""Default-allow on unknown — the safety net is the actual block."""
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
litellm_provider="custom",
|
||||
provider="custom",
|
||||
model_name="brand-new-model-x9-unmapped",
|
||||
custom_provider="brand_new_proxy",
|
||||
)
|
||||
|
|
@ -128,7 +128,7 @@ def test_litellm_known_text_only_returns_false():
|
|||
# Sanity: confirm the helper's negative path. We use a small model
|
||||
# known not to support vision per the map.
|
||||
result = derive_supports_image_input(
|
||||
litellm_provider="openai",
|
||||
provider="openai",
|
||||
model_name="deepseek-chat",
|
||||
)
|
||||
# We accept either False (LiteLLM said explicit no) or True
|
||||
|
|
@ -147,7 +147,7 @@ def test_litellm_known_text_only_returns_false():
|
|||
def test_is_known_text_only_returns_false_for_vision_model():
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
litellm_provider="openai",
|
||||
provider="openai",
|
||||
model_name="gpt-4o",
|
||||
)
|
||||
is False
|
||||
|
|
@ -160,7 +160,7 @@ def test_is_known_text_only_returns_false_for_unknown_model():
|
|||
fixing."""
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
litellm_provider="custom",
|
||||
provider="custom",
|
||||
model_name="brand-new-model-x9-unmapped",
|
||||
custom_provider="brand_new_proxy",
|
||||
)
|
||||
|
|
@ -181,7 +181,7 @@ def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch):
|
|||
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
litellm_provider="openai",
|
||||
provider="openai",
|
||||
model_name="gpt-4o",
|
||||
)
|
||||
is False
|
||||
|
|
@ -201,7 +201,7 @@ def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch):
|
|||
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
litellm_provider="openai",
|
||||
provider="openai",
|
||||
model_name="any-model",
|
||||
)
|
||||
is True
|
||||
|
|
@ -218,7 +218,7 @@ def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch):
|
|||
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
litellm_provider="openai",
|
||||
provider="openai",
|
||||
model_name="any-model",
|
||||
)
|
||||
is False
|
||||
|
|
@ -237,7 +237,7 @@ def test_is_known_text_only_returns_false_on_missing_key(monkeypatch):
|
|||
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
litellm_provider="openai",
|
||||
provider="openai",
|
||||
model_name="any-model",
|
||||
)
|
||||
is False
|
||||
|
|
|
|||
|
|
@ -105,8 +105,7 @@ async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch):
|
|||
async def _denying_billable_call(**_kwargs):
|
||||
raise QuotaInsufficientError(
|
||||
usage_type="vision_extraction",
|
||||
used_micros=5_000_000,
|
||||
limit_micros=5_000_000,
|
||||
balance_micros=0,
|
||||
remaining_micros=0,
|
||||
)
|
||||
yield # unreachable but required for asynccontextmanager type
|
||||
|
|
|
|||
|
|
@ -131,6 +131,10 @@ def test_serialized_calls_includes_cost_micros():
|
|||
assert serialized == [
|
||||
{
|
||||
"model": "m",
|
||||
"model_ref": None,
|
||||
"model_id": None,
|
||||
"display_name": None,
|
||||
"provider": None,
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 2,
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ def test_safety_net_does_not_fire_for_azure_gpt_4o():
|
|||
it text-only."""
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
litellm_provider="azure",
|
||||
provider="azure",
|
||||
model_name="my-azure-deployment",
|
||||
base_model="gpt-4o",
|
||||
)
|
||||
|
|
@ -49,7 +49,7 @@ def test_safety_net_does_not_fire_for_unknown_model():
|
|||
LiteLLM doesn't know about must flow through to the provider."""
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
litellm_provider="custom",
|
||||
provider="custom",
|
||||
custom_provider="brand_new_proxy",
|
||||
model_name="brand-new-model-x9",
|
||||
)
|
||||
|
|
@ -69,7 +69,7 @@ def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch):
|
|||
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
litellm_provider="openai",
|
||||
provider="openai",
|
||||
model_name="gpt-4o",
|
||||
)
|
||||
is False
|
||||
|
|
@ -88,7 +88,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
|
|||
monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false)
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
litellm_provider="openai",
|
||||
provider="openai",
|
||||
model_name="text-only-stub",
|
||||
)
|
||||
is True
|
||||
|
|
@ -100,7 +100,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
|
|||
monkeypatch.setattr(pc.litellm, "get_model_info", _info_true)
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
litellm_provider="openai",
|
||||
provider="openai",
|
||||
model_name="vision-stub",
|
||||
)
|
||||
is False
|
||||
|
|
@ -112,7 +112,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
|
|||
monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing)
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
litellm_provider="openai",
|
||||
provider="openai",
|
||||
model_name="missing-key-stub",
|
||||
)
|
||||
is False
|
||||
|
|
|
|||
|
|
@ -98,8 +98,7 @@ async def _denying_billable_call(**kwargs):
|
|||
_CALL_LOG.append(kwargs)
|
||||
raise QuotaInsufficientError(
|
||||
usage_type=kwargs.get("usage_type", "?"),
|
||||
used_micros=5_000_000,
|
||||
limit_micros=5_000_000,
|
||||
balance_micros=0,
|
||||
remaining_micros=0,
|
||||
)
|
||||
yield SimpleNamespace() # pragma: no cover
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue