mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-12 20:45:20 +02:00
fix(chat): harden image generation model routing
This commit is contained in:
parent
c28c4f5785
commit
831ad23c6c
7 changed files with 156 additions and 171 deletions
|
|
@ -23,20 +23,26 @@ from app.db import (
|
|||
)
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.model_resolver import native_connection_from_config, to_litellm
|
||||
from app.services.auto_model_pin_service import (
|
||||
auto_model_candidates,
|
||||
choose_auto_model_candidate,
|
||||
)
|
||||
from app.services.model_resolver import to_litellm
|
||||
from app.utils.signed_image_urls import generate_image_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _get_global_image_gen_config(config_id: int) -> dict | None:
|
||||
"""Get a global image gen config by negative ID."""
|
||||
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return cfg
|
||||
return None
|
||||
def _get_global_model(model_id: int) -> dict | None:
|
||||
return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None)
|
||||
|
||||
|
||||
def _get_global_connection(connection_id: int) -> dict | None:
|
||||
return next(
|
||||
(c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def create_generate_image_tool(
|
||||
|
|
@ -93,6 +99,16 @@ def create_generate_image_tool(
|
|||
# task's session is shared across every tool; without isolation,
|
||||
# autoflushes from a concurrent writer poison this tool too.
|
||||
async with shielded_async_session() as session:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
return _failed(
|
||||
{"error": "Search space not found"},
|
||||
error="Search space not found",
|
||||
)
|
||||
|
||||
if image_gen_model_id_override is not None:
|
||||
# Automation run: use the captured image model, insulated from
|
||||
# later search-space changes. No search-space read needed.
|
||||
|
|
@ -100,16 +116,6 @@ def create_generate_image_tool(
|
|||
image_gen_model_id_override or IMAGE_GEN_AUTO_MODE_ID
|
||||
)
|
||||
else:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
return _failed(
|
||||
{"error": "Search space not found"},
|
||||
error="Search space not found",
|
||||
)
|
||||
|
||||
config_id = (
|
||||
search_space.image_gen_model_id
|
||||
or IMAGE_GEN_AUTO_MODE_ID
|
||||
|
|
@ -122,24 +128,39 @@ def create_generate_image_tool(
|
|||
gen_kwargs["n"] = n
|
||||
|
||||
if is_image_gen_auto_mode(config_id):
|
||||
if not ImageGenRouterService.is_initialized():
|
||||
candidates = await auto_model_candidates(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
user_id=search_space.user_id,
|
||||
capability="image_gen",
|
||||
)
|
||||
if not candidates:
|
||||
err = (
|
||||
"No image generation models configured. "
|
||||
"No image generation models available. "
|
||||
"Please add an image model in Settings > Image Models."
|
||||
)
|
||||
return _failed({"error": err}, error=err)
|
||||
response = await ImageGenRouterService.aimage_generation(
|
||||
prompt=prompt, model="auto", **gen_kwargs
|
||||
config_id = int(
|
||||
choose_auto_model_candidate(candidates, search_space_id)["id"]
|
||||
)
|
||||
elif config_id < 0:
|
||||
cfg = _get_global_image_gen_config(config_id)
|
||||
if not cfg:
|
||||
err = f"Image generation config {config_id} not found"
|
||||
|
||||
if config_id < 0:
|
||||
global_model = _get_global_model(config_id)
|
||||
if not global_model or not (
|
||||
global_model.get("capabilities") or {}
|
||||
).get("image_gen"):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
global_connection = _get_global_connection(
|
||||
global_model["connection_id"]
|
||||
)
|
||||
if not global_connection:
|
||||
err = f"Image generation connection for model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
native_connection_from_config(cfg),
|
||||
cfg["model_name"],
|
||||
global_connection,
|
||||
global_model["model_id"],
|
||||
)
|
||||
gen_kwargs.update(resolved_kwargs)
|
||||
|
||||
|
|
@ -157,6 +178,19 @@ def create_generate_image_tool(
|
|||
if not db_model or not db_model.connection or not db_model.connection.enabled:
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
conn = db_model.connection
|
||||
if (
|
||||
conn.search_space_id is not None
|
||||
and conn.search_space_id != search_space_id
|
||||
):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
if (
|
||||
conn.user_id is not None
|
||||
and conn.user_id != search_space.user_id
|
||||
):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
if not (db_model.capabilities or {}).get("image_gen"):
|
||||
err = f"Model {config_id} is not image-generation capable"
|
||||
return _failed({"error": err}, error=err)
|
||||
|
|
|
|||
|
|
@ -45,10 +45,13 @@ from app.services.billable_calls import (
|
|||
)
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.model_resolver import native_connection_from_config, to_litellm
|
||||
from app.services.auto_model_pin_service import (
|
||||
auto_model_candidates,
|
||||
choose_auto_model_candidate,
|
||||
)
|
||||
from app.services.model_resolver import to_litellm
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.signed_image_urls import verify_image_token
|
||||
|
|
@ -56,22 +59,15 @@ from app.utils.signed_image_urls import verify_image_token
|
|||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _get_global_image_gen_config(config_id: int) -> dict | None:
|
||||
"""Get a global image generation configuration by ID (negative IDs)."""
|
||||
if config_id == IMAGE_GEN_AUTO_MODE_ID:
|
||||
return {
|
||||
"id": IMAGE_GEN_AUTO_MODE_ID,
|
||||
"name": "Auto (Fastest)",
|
||||
"provider": "AUTO",
|
||||
"model_name": "auto",
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
if config_id > 0:
|
||||
return None
|
||||
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return cfg
|
||||
return None
|
||||
def _get_global_model(model_id: int) -> dict | None:
|
||||
return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None)
|
||||
|
||||
|
||||
def _get_global_connection(connection_id: int) -> dict | None:
|
||||
return next(
|
||||
(c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_billing_for_image_gen(
|
||||
|
|
@ -87,30 +83,41 @@ async def _resolve_billing_for_image_gen(
|
|||
config that will actually run, and so we don't open an
|
||||
``ImageGeneration`` row for a request that's about to 402.
|
||||
|
||||
User-owned (positive ID) BYOK configs are always free — they cost
|
||||
the user nothing on our side. Auto mode currently treats as free
|
||||
because the underlying router can dispatch to either premium or
|
||||
free YAML configs and we don't surface the resolved deployment up
|
||||
here yet. Bringing Auto under premium billing would require
|
||||
threading the chosen deployment back from ``ImageGenRouterService``.
|
||||
User-owned (positive ID) BYOK models are always free — they cost
|
||||
the user nothing on our side. Auto mode resolves to one concrete
|
||||
global or BYOK model before billing is calculated.
|
||||
"""
|
||||
resolved_id = config_id
|
||||
if resolved_id is None:
|
||||
resolved_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
|
||||
if is_image_gen_auto_mode(resolved_id):
|
||||
return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
|
||||
candidates = await auto_model_candidates(
|
||||
session,
|
||||
search_space_id=search_space.id,
|
||||
user_id=search_space.user_id,
|
||||
capability="image_gen",
|
||||
)
|
||||
if not candidates:
|
||||
return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
|
||||
selected = choose_auto_model_candidate(candidates, search_space.id)
|
||||
resolved_id = int(selected["id"])
|
||||
|
||||
if resolved_id < 0:
|
||||
cfg = _get_global_image_gen_config(resolved_id) or {}
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
base_model, _ = to_litellm(native_connection_from_config(cfg), cfg.get("model_name", ""))
|
||||
global_model = _get_global_model(resolved_id) or {}
|
||||
global_connection = _get_global_connection(global_model.get("connection_id", 0))
|
||||
billing_tier = str(global_model.get("billing_tier", "free")).lower()
|
||||
if global_connection and global_model.get("model_id"):
|
||||
base_model, _ = to_litellm(global_connection, global_model["model_id"])
|
||||
else:
|
||||
base_model = "global_image_model"
|
||||
catalog = global_model.get("catalog") or {}
|
||||
reserve_micros = int(
|
||||
cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
|
||||
catalog.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
|
||||
)
|
||||
return (billing_tier, base_model, reserve_micros)
|
||||
|
||||
# Positive ID = user-owned BYOK image-gen config — always free.
|
||||
# Positive ID = user-owned BYOK image-gen model — always free.
|
||||
return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)
|
||||
|
||||
|
||||
|
|
@ -146,23 +153,28 @@ async def _execute_image_generation(
|
|||
gen_kwargs["response_format"] = image_gen.response_format
|
||||
|
||||
if is_image_gen_auto_mode(config_id):
|
||||
if not ImageGenRouterService.is_initialized():
|
||||
raise ValueError(
|
||||
"Auto mode requested but Image Generation Router not initialized. "
|
||||
"Ensure global_llm_config.yaml has global_image_generation_configs."
|
||||
)
|
||||
response = await ImageGenRouterService.aimage_generation(
|
||||
prompt=image_gen.prompt, model="auto", **gen_kwargs
|
||||
candidates = await auto_model_candidates(
|
||||
session,
|
||||
search_space_id=search_space.id,
|
||||
user_id=search_space.user_id,
|
||||
capability="image_gen",
|
||||
)
|
||||
elif config_id < 0:
|
||||
# Global config from YAML
|
||||
cfg = _get_global_image_gen_config(config_id)
|
||||
if not cfg:
|
||||
raise ValueError(f"Global image generation config {config_id} not found")
|
||||
if not candidates:
|
||||
raise ValueError("No image-generation models are available for Auto mode")
|
||||
config_id = int(choose_auto_model_candidate(candidates, search_space.id)["id"])
|
||||
image_gen.image_generation_config_id = config_id
|
||||
|
||||
if config_id < 0:
|
||||
global_model = _get_global_model(config_id)
|
||||
if not global_model or not (global_model.get("capabilities") or {}).get("image_gen"):
|
||||
raise ValueError(f"Global image generation model {config_id} not found")
|
||||
global_connection = _get_global_connection(global_model["connection_id"])
|
||||
if not global_connection:
|
||||
raise ValueError(f"Global connection for image model {config_id} not found")
|
||||
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
native_connection_from_config(cfg),
|
||||
cfg["model_name"],
|
||||
global_connection,
|
||||
global_model["model_id"],
|
||||
)
|
||||
gen_kwargs.update(resolved_kwargs)
|
||||
|
||||
|
|
@ -183,6 +195,11 @@ async def _execute_image_generation(
|
|||
db_model = result.scalars().first()
|
||||
if not db_model or not db_model.connection or not db_model.connection.enabled:
|
||||
raise ValueError(f"Image generation model {config_id} not found")
|
||||
conn = db_model.connection
|
||||
if conn.search_space_id is not None and conn.search_space_id != search_space.id:
|
||||
raise ValueError(f"Image generation model {config_id} not found")
|
||||
if conn.user_id is not None and conn.user_id != search_space.user_id:
|
||||
raise ValueError(f"Image generation model {config_id} not found")
|
||||
if not (db_model.capabilities or {}).get("image_gen"):
|
||||
raise ValueError(f"Model {config_id} is not image-generation capable")
|
||||
|
||||
|
|
@ -255,7 +272,7 @@ async def get_global_image_gen_configs(
|
|||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider"),
|
||||
"provider": cfg.get("litellm_provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
|
|
|
|||
|
|
@ -20,23 +20,13 @@ from typing import Any
|
|||
from litellm import Router
|
||||
from litellm.utils import ImageResponse
|
||||
|
||||
from app.services.model_resolver import (
|
||||
NATIVE_PROVIDER_PREFIX,
|
||||
native_connection_from_config,
|
||||
to_litellm,
|
||||
)
|
||||
from app.services.model_resolver import native_connection_from_config, to_litellm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Special ID for Auto mode - uses router for load balancing
|
||||
IMAGE_GEN_AUTO_MODE_ID = 0
|
||||
|
||||
# Provider mapping for LiteLLM model string construction.
|
||||
# Only includes providers that support image generation.
|
||||
# See: https://docs.litellm.ai/docs/image_generation#supported-providers
|
||||
IMAGE_GEN_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX
|
||||
|
||||
|
||||
class ImageGenRouterService:
|
||||
"""
|
||||
Singleton service for managing LiteLLM Router for image generation.
|
||||
|
|
|
|||
|
|
@ -55,7 +55,6 @@ from app.services.openrouter_integration_service import ( # noqa: E402
|
|||
_OPENROUTER_DYNAMIC_MARKER,
|
||||
OpenRouterIntegrationService,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base # noqa: E402
|
||||
from app.services.provider_capabilities import ( # noqa: E402
|
||||
derive_supports_image_input,
|
||||
is_known_text_only_chat_model,
|
||||
|
|
@ -154,13 +153,13 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
|
|||
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
|
||||
)
|
||||
cap = derive_supports_image_input(
|
||||
provider=cfg.get("provider"),
|
||||
litellm_provider=cfg.get("litellm_provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
block = is_known_text_only_chat_model(
|
||||
provider=cfg.get("provider"),
|
||||
litellm_provider=cfg.get("litellm_provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
|
|
@ -179,11 +178,7 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
|
|||
def _build_chat_model_string(cfg: dict) -> str:
|
||||
if cfg.get("custom_provider"):
|
||||
return f"{cfg['custom_provider']}/{cfg['model_name']}"
|
||||
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
|
||||
|
||||
prefix = _PROVIDER_PREFIX_MAP.get(
|
||||
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
|
||||
)
|
||||
prefix = cfg.get("litellm_provider") or "openai"
|
||||
return f"{prefix}/{cfg['model_name']}"
|
||||
|
||||
|
||||
|
|
@ -195,11 +190,6 @@ def _build_chat_model_string(cfg: dict) -> str:
|
|||
async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]:
|
||||
"""Send a 1x1 PNG + `reply with one word: ok` to the chat config."""
|
||||
model_string = _build_chat_model_string(cfg)
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=model_string.split("/", 1)[0],
|
||||
config_api_base=cfg.get("api_base") or None,
|
||||
)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model_string,
|
||||
"api_key": cfg.get("api_key"),
|
||||
|
|
@ -218,8 +208,8 @@ async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]:
|
|||
"max_tokens": 16,
|
||||
"timeout": 60,
|
||||
}
|
||||
if api_base:
|
||||
kwargs["api_base"] = api_base
|
||||
if cfg.get("api_base"):
|
||||
kwargs["api_base"] = cfg["api_base"]
|
||||
if cfg.get("litellm_params"):
|
||||
# Strip pricing keys — they're tracking-only and confuse some
|
||||
# provider validators (e.g. azure/openai reject unknown kwargs
|
||||
|
|
@ -257,20 +247,11 @@ _IMAGE_GEN_PROMPTS: tuple[str, ...] = (
|
|||
|
||||
async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]:
|
||||
"""Generate one tiny image to verify the deployment is reachable."""
|
||||
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
|
||||
|
||||
if cfg.get("custom_provider"):
|
||||
prefix = cfg["custom_provider"]
|
||||
else:
|
||||
prefix = _PROVIDER_PREFIX_MAP.get(
|
||||
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
|
||||
)
|
||||
prefix = cfg.get("litellm_provider") or "openai"
|
||||
model_string = f"{prefix}/{cfg['model_name']}"
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=prefix,
|
||||
config_api_base=cfg.get("api_base") or None,
|
||||
)
|
||||
base_kwargs: dict[str, Any] = {
|
||||
"model": model_string,
|
||||
"api_key": cfg.get("api_key"),
|
||||
|
|
@ -278,8 +259,8 @@ async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]:
|
|||
"size": "1024x1024",
|
||||
"timeout": 120,
|
||||
}
|
||||
if api_base:
|
||||
base_kwargs["api_base"] = api_base
|
||||
if cfg.get("api_base"):
|
||||
base_kwargs["api_base"] = cfg["api_base"]
|
||||
if cfg.get("api_version"):
|
||||
base_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
|
|
|
|||
|
|
@ -49,14 +49,14 @@ async def test_resolve_billing_for_premium_global_config(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-image-1",
|
||||
"billing_tier": "premium",
|
||||
"quota_reserve_micros": 75_000,
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemini-2.5-flash-image",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
|
|
@ -118,7 +118,7 @@ async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": -7,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-image-1",
|
||||
"billing_tier": "premium",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,19 +1,4 @@
|
|||
"""Defense-in-depth: image-gen call sites must not let an empty
|
||||
``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``.
|
||||
|
||||
The bug repro: an OpenRouter image-gen config ships
|
||||
``api_base=""``. The pre-fix call site in
|
||||
``image_generation_routes._execute_image_generation`` did
|
||||
``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which
|
||||
silently dropped the empty string. LiteLLM then fell back to
|
||||
``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``)
|
||||
and OpenRouter's ``image_generation/transformation`` appended
|
||||
``/chat/completions`` to it → 404 ``Resource not found``.
|
||||
|
||||
This test pins the post-fix behaviour: with an empty ``api_base`` in
|
||||
the config, the call site MUST set ``api_base`` to OpenRouter's public
|
||||
URL instead of leaving it unset.
|
||||
"""
|
||||
"""Image-gen call sites must pass each config's explicit ``api_base``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -26,20 +11,17 @@ pytestmark = pytest.mark.unit
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
|
||||
"""The global-config branch (``config_id < 0``) of
|
||||
``_execute_image_generation`` must apply the resolver and pin
|
||||
``api_base`` to OpenRouter when the config ships an empty string.
|
||||
"""
|
||||
async def test_global_openrouter_image_gen_sets_explicit_api_base():
|
||||
"""The global-config branch forwards the explicit OpenRouter base."""
|
||||
from app.routes import image_generation_routes
|
||||
|
||||
cfg = {
|
||||
"id": -20_001,
|
||||
"name": "GPT Image 1 (OpenRouter)",
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "openai/gpt-image-1",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "", # the original bug shape
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
}
|
||||
|
|
@ -80,16 +62,13 @@ async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
|
|||
session=session, image_gen=image_gen, search_space=search_space
|
||||
)
|
||||
|
||||
# The whole point of the fix: even with empty ``api_base`` in the
|
||||
# config, we forward OpenRouter's public URL so the call doesn't
|
||||
# inherit an Azure endpoint.
|
||||
assert captured.get("api_base") == "https://openrouter.ai/api/v1"
|
||||
assert captured["model"] == "openrouter/openai/gpt-image-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_image_tool_global_sets_api_base_when_config_empty():
|
||||
"""Same defense at the agent tool entry point — both surfaces share
|
||||
async def test_generate_image_tool_global_sets_explicit_api_base():
|
||||
"""Same explicit-base behavior at the agent tool entry point — both surfaces share
|
||||
the same OpenRouter config payloads."""
|
||||
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools import (
|
||||
generate_image as gi_module,
|
||||
|
|
@ -98,10 +77,10 @@ async def test_generate_image_tool_global_sets_api_base_when_config_empty():
|
|||
cfg = {
|
||||
"id": -20_001,
|
||||
"name": "GPT Image 1 (OpenRouter)",
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "openai/gpt-image-1",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
}
|
||||
|
|
@ -171,20 +150,16 @@ async def test_generate_image_tool_global_sets_api_base_when_config_empty():
|
|||
assert captured["model"] == "openrouter/openai/gpt-image-1"
|
||||
|
||||
|
||||
def test_image_gen_router_deployment_sets_api_base_when_config_empty():
|
||||
"""The Auto-mode router pool must also resolve ``api_base`` when an
|
||||
OpenRouter config ships an empty string. The deployment dict is fed
|
||||
straight to ``litellm.Router``, so a missing ``api_base`` would
|
||||
leak the same way as the direct call sites.
|
||||
"""
|
||||
def test_image_gen_router_deployment_sets_explicit_api_base():
|
||||
"""The Auto-mode router pool carries explicit api_base into deployments."""
|
||||
from app.services.image_gen_router_service import ImageGenRouterService
|
||||
|
||||
deployment = ImageGenRouterService._config_to_deployment(
|
||||
{
|
||||
"model_name": "openai/gpt-image-1",
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
}
|
||||
)
|
||||
assert deployment is not None
|
||||
|
|
|
|||
|
|
@ -1,12 +1,4 @@
|
|||
"""Defense-in-depth: vision-LLM resolution must not leak ``api_base``
|
||||
defaults from ``litellm.api_base`` either.
|
||||
|
||||
Vision shares the same shape as image-gen — global YAML / OpenRouter
|
||||
dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm``
|
||||
call sites would silently drop the empty string and inherit
|
||||
``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on
|
||||
construction so we test the kwargs we hand to it instead.
|
||||
"""
|
||||
"""Vision LLM resolution must pass explicit per-config ``api_base``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -19,19 +11,16 @@ pytestmark = pytest.mark.unit
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vision_llm_global_openrouter_sets_api_base():
|
||||
"""Global negative-ID branch: an OpenRouter vision config with
|
||||
``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with
|
||||
``api_base="https://openrouter.ai/api/v1"`` — never an empty string,
|
||||
never silently absent."""
|
||||
"""Global negative-ID branch forwards the explicit OpenRouter base."""
|
||||
from app.services import llm_service
|
||||
|
||||
cfg = {
|
||||
"id": -30_001,
|
||||
"name": "GPT-4o Vision (OpenRouter)",
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "openai/gpt-4o",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
"billing_tier": "free",
|
||||
|
|
@ -72,16 +61,15 @@ async def test_get_vision_llm_global_openrouter_sets_api_base():
|
|||
|
||||
|
||||
def test_vision_router_deployment_sets_api_base_when_config_empty():
|
||||
"""Auto-mode vision router: deployments are fed to ``litellm.Router``,
|
||||
so the resolver has to apply at deployment construction time too."""
|
||||
"""Auto-mode vision router carries explicit api_base into deployments."""
|
||||
from app.services.vision_llm_router_service import VisionLLMRouterService
|
||||
|
||||
deployment = VisionLLMRouterService._config_to_deployment(
|
||||
{
|
||||
"model_name": "openai/gpt-4o",
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
}
|
||||
)
|
||||
assert deployment is not None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue