mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-16 21:05:20 +02:00
feat(database-migrations): add migration to remove legacy model config tables and remove stale model connection code
This commit is contained in:
parent
50668775f8
commit
bd4a04f2e7
93 changed files with 956 additions and 11442 deletions
|
|
@ -4,7 +4,7 @@ Revision ID: 138
|
|||
Revises: 137
|
||||
Create Date: 2026-04-30
|
||||
|
||||
Add a single thread-level column to persist the Auto (Fastest) model pin:
|
||||
Add a single thread-level column to persist the Auto model pin:
|
||||
- pinned_llm_config_id: concrete resolved global LLM config id used for this
|
||||
thread. NULL means "no pin; Auto will resolve on next turn".
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,270 @@
|
|||
"""remove legacy model config tables
|
||||
|
||||
Revision ID: 161
|
||||
Revises: 160
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.types import TypeEngine
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "161"
|
||||
down_revision: str | None = "160"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
litellm_provider = postgresql.ENUM(
|
||||
"OPENAI",
|
||||
"ANTHROPIC",
|
||||
"GOOGLE",
|
||||
"AZURE_OPENAI",
|
||||
"BEDROCK",
|
||||
"VERTEX_AI",
|
||||
"GROQ",
|
||||
"COHERE",
|
||||
"MISTRAL",
|
||||
"DEEPSEEK",
|
||||
"XAI",
|
||||
"OPENROUTER",
|
||||
"TOGETHER_AI",
|
||||
"FIREWORKS_AI",
|
||||
"REPLICATE",
|
||||
"PERPLEXITY",
|
||||
"OLLAMA",
|
||||
"ALIBABA_QWEN",
|
||||
"MOONSHOT",
|
||||
"ZHIPU",
|
||||
"ANYSCALE",
|
||||
"DEEPINFRA",
|
||||
"CEREBRAS",
|
||||
"SAMBANOVA",
|
||||
"AI21",
|
||||
"CLOUDFLARE",
|
||||
"DATABRICKS",
|
||||
"COMETAPI",
|
||||
"HUGGINGFACE",
|
||||
"GITHUB_MODELS",
|
||||
"MINIMAX",
|
||||
"CUSTOM",
|
||||
name="litellmprovider",
|
||||
create_type=False,
|
||||
)
|
||||
image_gen_provider = postgresql.ENUM(
|
||||
"OPENAI",
|
||||
"AZURE_OPENAI",
|
||||
"GOOGLE",
|
||||
"VERTEX_AI",
|
||||
"BEDROCK",
|
||||
"RECRAFT",
|
||||
"OPENROUTER",
|
||||
"XINFERENCE",
|
||||
"NSCALE",
|
||||
name="imagegenprovider",
|
||||
create_type=False,
|
||||
)
|
||||
vision_provider = postgresql.ENUM(
|
||||
"OPENAI",
|
||||
"ANTHROPIC",
|
||||
"GOOGLE",
|
||||
"AZURE_OPENAI",
|
||||
"VERTEX_AI",
|
||||
"BEDROCK",
|
||||
"XAI",
|
||||
"OPENROUTER",
|
||||
"OLLAMA",
|
||||
"GROQ",
|
||||
"TOGETHER_AI",
|
||||
"FIREWORKS_AI",
|
||||
"DEEPSEEK",
|
||||
"MISTRAL",
|
||||
"CUSTOM",
|
||||
name="visionprovider",
|
||||
create_type=False,
|
||||
)
|
||||
|
||||
|
||||
def _table_exists(table_name: str) -> bool:
|
||||
return table_name in sa.inspect(op.get_bind()).get_table_names()
|
||||
|
||||
|
||||
def _column_exists(table_name: str, column_name: str) -> bool:
|
||||
if not _table_exists(table_name):
|
||||
return False
|
||||
return column_name in {
|
||||
column["name"] for column in sa.inspect(op.get_bind()).get_columns(table_name)
|
||||
}
|
||||
|
||||
|
||||
def _drop_column_if_exists(table_name: str, column_name: str) -> None:
|
||||
if _column_exists(table_name, column_name):
|
||||
op.drop_column(table_name, column_name)
|
||||
|
||||
|
||||
def _rename_column_if_exists(
|
||||
table_name: str,
|
||||
old_column_name: str,
|
||||
new_column_name: str,
|
||||
*,
|
||||
existing_type: TypeEngine,
|
||||
existing_nullable: bool = True,
|
||||
) -> None:
|
||||
if _column_exists(table_name, old_column_name) and not _column_exists(
|
||||
table_name, new_column_name
|
||||
):
|
||||
op.alter_column(
|
||||
table_name,
|
||||
old_column_name,
|
||||
new_column_name=new_column_name,
|
||||
existing_type=existing_type,
|
||||
existing_nullable=existing_nullable,
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
for table_name in (
|
||||
"new_llm_configs",
|
||||
"vision_llm_configs",
|
||||
"image_generation_configs",
|
||||
):
|
||||
if _table_exists(table_name):
|
||||
op.drop_table(table_name)
|
||||
|
||||
_drop_column_if_exists("searchspaces", "agent_llm_id")
|
||||
_drop_column_if_exists("searchspaces", "image_generation_config_id")
|
||||
_drop_column_if_exists("searchspaces", "vision_llm_config_id")
|
||||
|
||||
_rename_column_if_exists(
|
||||
"image_generations",
|
||||
"image_generation_config_id",
|
||||
"image_gen_model_id",
|
||||
existing_type=sa.Integer(),
|
||||
)
|
||||
|
||||
op.execute("DROP TYPE IF EXISTS litellmprovider")
|
||||
op.execute("DROP TYPE IF EXISTS imagegenprovider")
|
||||
op.execute("DROP TYPE IF EXISTS visionprovider")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
litellm_provider.create(bind, checkfirst=True)
|
||||
image_gen_provider.create(bind, checkfirst=True)
|
||||
vision_provider.create(bind, checkfirst=True)
|
||||
|
||||
_rename_column_if_exists(
|
||||
"image_generations",
|
||||
"image_gen_model_id",
|
||||
"image_generation_config_id",
|
||||
existing_type=sa.Integer(),
|
||||
)
|
||||
|
||||
if _table_exists("searchspaces"):
|
||||
if not _column_exists("searchspaces", "agent_llm_id"):
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column("agent_llm_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
if not _column_exists("searchspaces", "image_generation_config_id"):
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column("image_generation_config_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
if not _column_exists("searchspaces", "vision_llm_config_id"):
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column("vision_llm_config_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
|
||||
if not _table_exists("image_generation_configs"):
|
||||
op.create_table(
|
||||
"image_generation_configs",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("name", sa.String(length=100), nullable=False),
|
||||
sa.Column("description", sa.String(length=500), nullable=True),
|
||||
sa.Column("provider", image_gen_provider, nullable=False),
|
||||
sa.Column("custom_provider", sa.String(length=100), nullable=True),
|
||||
sa.Column("model_name", sa.String(length=100), nullable=False),
|
||||
sa.Column("api_key", sa.String(), nullable=False),
|
||||
sa.Column("api_base", sa.String(length=500), nullable=True),
|
||||
sa.Column("api_version", sa.String(length=50), nullable=True),
|
||||
sa.Column("litellm_params", sa.JSON(), nullable=True),
|
||||
sa.Column("search_space_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_image_generation_configs_name"),
|
||||
"image_generation_configs",
|
||||
["name"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
if not _table_exists("vision_llm_configs"):
|
||||
op.create_table(
|
||||
"vision_llm_configs",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("name", sa.String(length=100), nullable=False),
|
||||
sa.Column("description", sa.String(length=500), nullable=True),
|
||||
sa.Column("provider", vision_provider, nullable=False),
|
||||
sa.Column("custom_provider", sa.String(length=100), nullable=True),
|
||||
sa.Column("model_name", sa.String(length=100), nullable=False),
|
||||
sa.Column("api_key", sa.String(), nullable=False),
|
||||
sa.Column("api_base", sa.String(length=500), nullable=True),
|
||||
sa.Column("api_version", sa.String(length=50), nullable=True),
|
||||
sa.Column("litellm_params", sa.JSON(), nullable=True),
|
||||
sa.Column("search_space_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_vision_llm_configs_name"),
|
||||
"vision_llm_configs",
|
||||
["name"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
if not _table_exists("new_llm_configs"):
|
||||
op.create_table(
|
||||
"new_llm_configs",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("name", sa.String(length=100), nullable=False),
|
||||
sa.Column("description", sa.String(length=500), nullable=True),
|
||||
sa.Column("provider", litellm_provider, nullable=False),
|
||||
sa.Column("custom_provider", sa.String(length=100), nullable=True),
|
||||
sa.Column("model_name", sa.String(length=100), nullable=False),
|
||||
sa.Column("api_key", sa.String(), nullable=False),
|
||||
sa.Column("api_base", sa.String(length=500), nullable=True),
|
||||
sa.Column("litellm_params", sa.JSON(), nullable=True),
|
||||
sa.Column("system_instructions", sa.Text(), nullable=False),
|
||||
sa.Column("use_default_system_instructions", sa.Boolean(), nullable=False),
|
||||
sa.Column("citations_enabled", sa.Boolean(), nullable=False),
|
||||
sa.Column("search_space_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_new_llm_configs_name"),
|
||||
"new_llm_configs",
|
||||
["name"],
|
||||
unique=False,
|
||||
)
|
||||
|
|
@ -215,7 +215,7 @@ def create_generate_image_tool(
|
|||
prompt=prompt,
|
||||
model=getattr(response, "_hidden_params", {}).get("model"),
|
||||
n=n,
|
||||
image_generation_config_id=config_id,
|
||||
image_gen_model_id=config_id,
|
||||
response_data=response_dict,
|
||||
search_space_id=search_space_id,
|
||||
access_token=access_token,
|
||||
|
|
|
|||
|
|
@ -24,8 +24,6 @@ from langchain_core.messages import AIMessage, BaseMessage
|
|||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_litellm import ChatLiteLLM
|
||||
from litellm import get_model_info
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.chat.runtime.prompt_caching import (
|
||||
apply_litellm_prompt_caching,
|
||||
|
|
@ -34,7 +32,6 @@ from app.services.llm_router_service import (
|
|||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
_sanitize_content,
|
||||
is_auto_mode,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -130,7 +127,7 @@ class AgentConfig:
|
|||
"""
|
||||
Complete configuration for the SurfSense agent.
|
||||
|
||||
This combines LLM settings with prompt configuration from NewLLMConfig.
|
||||
This combines resolved model settings with prompt configuration.
|
||||
Supports Auto mode metadata (ID 0). Runtime callers must resolve Auto to
|
||||
a concrete global or BYOK model before constructing ChatLiteLLM.
|
||||
"""
|
||||
|
|
@ -180,7 +177,7 @@ class AgentConfig:
|
|||
use_default_system_instructions=True,
|
||||
citations_enabled=True,
|
||||
config_id=AUTO_MODE_ID,
|
||||
config_name="Auto (Fastest)",
|
||||
config_name="Auto",
|
||||
is_auto_mode=True,
|
||||
billing_tier="free",
|
||||
is_premium=False,
|
||||
|
|
@ -191,57 +188,12 @@ class AgentConfig:
|
|||
supports_image_input=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_new_llm_config(cls, config) -> "AgentConfig":
|
||||
"""Build an AgentConfig from a NewLLMConfig database model."""
|
||||
# Lazy import: keeps provider_capabilities (and litellm) out of init order.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
provider_value = (
|
||||
config.provider.value
|
||||
if hasattr(config.provider, "value")
|
||||
else str(config.provider)
|
||||
)
|
||||
litellm_params = config.litellm_params or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model")
|
||||
if isinstance(litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
return cls(
|
||||
provider=provider_value,
|
||||
model_name=config.model_name,
|
||||
api_key=config.api_key,
|
||||
api_base=config.api_base,
|
||||
custom_provider=config.custom_provider,
|
||||
litellm_params=config.litellm_params,
|
||||
system_instructions=config.system_instructions,
|
||||
use_default_system_instructions=config.use_default_system_instructions,
|
||||
citations_enabled=config.citations_enabled,
|
||||
config_id=config.id,
|
||||
config_name=config.name,
|
||||
is_auto_mode=False,
|
||||
billing_tier="free",
|
||||
is_premium=False,
|
||||
anonymous_enabled=False,
|
||||
quota_reserve_tokens=None,
|
||||
# BYOK rows have no curated flag; ask LiteLLM (default-allow on
|
||||
# unknown). The streaming safety net still blocks explicit text-only.
|
||||
supports_image_input=derive_supports_image_input(
|
||||
provider=provider_value.lower(),
|
||||
model_name=config.model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=config.custom_provider,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig":
|
||||
"""Build an AgentConfig from a YAML configuration dictionary.
|
||||
|
||||
Supports the same prompt fields as NewLLMConfig (system_instructions,
|
||||
use_default_system_instructions, citations_enabled).
|
||||
Supports prompt fields such as system_instructions,
|
||||
use_default_system_instructions, and citations_enabled.
|
||||
"""
|
||||
# Lazy import: keeps provider_capabilities (and litellm) out of init order.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
|
@ -334,82 +286,6 @@ def load_global_llm_config_by_id(llm_config_id: int) -> dict | None:
|
|||
return load_llm_config_from_yaml(llm_config_id)
|
||||
|
||||
|
||||
async def load_new_llm_config_from_db(
|
||||
session: AsyncSession,
|
||||
config_id: int,
|
||||
) -> "AgentConfig | None":
|
||||
"""Load a NewLLMConfig from the database by ID."""
|
||||
from app.db import NewLLMConfig
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
|
||||
)
|
||||
config = result.scalars().first()
|
||||
|
||||
if not config:
|
||||
print(f"Error: NewLLMConfig with id {config_id} not found")
|
||||
return None
|
||||
|
||||
return AgentConfig.from_new_llm_config(config)
|
||||
except Exception as e:
|
||||
print(f"Error loading NewLLMConfig from database: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def load_agent_llm_config_for_search_space(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> "AgentConfig | None":
|
||||
"""Load the chat model config for a search space via its agent_llm_id.
|
||||
|
||||
Positive id -> DB; negative -> YAML; None -> first global config (-1).
|
||||
"""
|
||||
from app.db import SearchSpace
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
|
||||
if not search_space:
|
||||
print(f"Error: SearchSpace with id {search_space_id} not found")
|
||||
return None
|
||||
|
||||
config_id = (
|
||||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
)
|
||||
return await load_agent_config(session, config_id, search_space_id)
|
||||
except Exception as e:
|
||||
print(f"Error loading chat model config for search space {search_space_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def load_agent_config(
|
||||
session: AsyncSession,
|
||||
config_id: int,
|
||||
search_space_id: int | None = None,
|
||||
) -> "AgentConfig | None":
|
||||
"""Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB."""
|
||||
if is_auto_mode(config_id):
|
||||
return AgentConfig.from_auto_mode()
|
||||
|
||||
if config_id < 0:
|
||||
# In-memory covers static YAML + dynamic OpenRouter configs.
|
||||
from app.config import config as app_config
|
||||
|
||||
for cfg in app_config.GLOBAL_LLM_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return AgentConfig.from_yaml_config(cfg)
|
||||
yaml_config = load_llm_config_from_yaml(config_id)
|
||||
if yaml_config:
|
||||
return AgentConfig.from_yaml_config(yaml_config)
|
||||
return None
|
||||
else:
|
||||
return await load_new_llm_config_from_db(session, config_id)
|
||||
|
||||
|
||||
def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
||||
"""Create a ChatLiteLLM instance from a global LLM config dictionary."""
|
||||
if llm_config.get("custom_provider"):
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@
|
|||
|
||||
Automations run unattended, so every run must be **billable**: it may only use
|
||||
either a premium global model (``billing_tier == "premium"``) or a user-provided
|
||||
BYOK model (a positive config id pointing at a per-user/per-space DB row). Free
|
||||
BYOK model (a positive model id pointing at a per-user/per-space DB row). Free
|
||||
global models and Auto mode are blocked, because Auto can dispatch to a free
|
||||
deployment and free models aren't metered in premium credits.
|
||||
|
||||
Config id conventions (shared across chat / image / vision):
|
||||
Model id conventions (shared across chat / image / vision):
|
||||
- ``id == 0`` → Auto mode (``AUTO_MODE_ID`` / ``IMAGE_GEN_AUTO_MODE_ID`` /
|
||||
``VISION_AUTO_MODE_ID``). Blocked.
|
||||
- ``id < 0`` → global YAML/OpenRouter config. Allowed only if premium.
|
||||
|
|
@ -82,7 +82,7 @@ def get_model_eligibility(
|
|||
|
||||
The ID-based core shared by both the search-space path (creation/eligibility)
|
||||
and the captured-snapshot path (runtime backstop). Each violation is
|
||||
``{"kind", "config_id", "reason"}``.
|
||||
``{"kind", "model_id", "reason"}``.
|
||||
"""
|
||||
checks: list[tuple[ModelKind, int | None]] = [
|
||||
("chat", chat_model_id),
|
||||
|
|
@ -91,10 +91,10 @@ def get_model_eligibility(
|
|||
]
|
||||
|
||||
violations: list[dict] = []
|
||||
for kind, config_id in checks:
|
||||
allowed, reason = _classify(kind, config_id)
|
||||
for kind, model_id in checks:
|
||||
allowed, reason = _classify(kind, model_id)
|
||||
if not allowed:
|
||||
violations.append({"kind": kind, "model_id": config_id, "reason": reason})
|
||||
violations.append({"kind": kind, "model_id": model_id, "reason": reason})
|
||||
|
||||
return {"allowed": not violations, "violations": violations}
|
||||
|
||||
|
|
|
|||
|
|
@ -119,7 +119,7 @@ def load_global_llm_configs():
|
|||
else:
|
||||
seen_slugs[slug] = cfg.get("id", 0)
|
||||
|
||||
# Stamp Auto (Fastest) ranking metadata. YAML configs are always
|
||||
# Stamp Auto ranking metadata. YAML configs are always
|
||||
# Tier A — operator-curated, locked first when premium-eligible.
|
||||
# The OpenRouter refresh tick later re-stamps health for any cfg
|
||||
# whose provider == "openrouter" via _enrich_health.
|
||||
|
|
@ -210,42 +210,6 @@ def load_global_image_gen_configs():
|
|||
return []
|
||||
|
||||
|
||||
def load_global_vision_llm_configs():
|
||||
data = _global_config_data()
|
||||
if not data:
|
||||
return []
|
||||
|
||||
try:
|
||||
configs = copy.deepcopy(data.get("global_vision_llm_configs", []) or [])
|
||||
for cfg in configs:
|
||||
if isinstance(cfg, dict):
|
||||
cfg.setdefault("billing_tier", "free")
|
||||
return configs
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load global vision LLM configs: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def load_vision_llm_router_settings():
|
||||
default_settings = {
|
||||
"routing_strategy": "usage-based-routing",
|
||||
"num_retries": 3,
|
||||
"allowed_fails": 3,
|
||||
"cooldown_time": 60,
|
||||
}
|
||||
|
||||
data = _global_config_data()
|
||||
if not data:
|
||||
return default_settings
|
||||
|
||||
try:
|
||||
settings = data.get("vision_llm_router_settings", {})
|
||||
return {**default_settings, **settings}
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load vision LLM router settings: {e}")
|
||||
return default_settings
|
||||
|
||||
|
||||
def load_image_gen_router_settings():
|
||||
"""
|
||||
Load router settings for image generation Auto mode from YAML file.
|
||||
|
|
@ -482,12 +446,6 @@ def initialize_image_gen_router():
|
|||
print(f"Warning: Failed to initialize Image Generation Router: {e}")
|
||||
|
||||
|
||||
def initialize_vision_llm_router():
|
||||
# Retired: vision Auto now uses shared capability-filtered model selection
|
||||
# over GLOBAL/BYOK chat models with supports_image_input=true.
|
||||
return
|
||||
|
||||
|
||||
class Config:
|
||||
# Check if ffmpeg is installed
|
||||
if not is_ffmpeg_installed():
|
||||
|
|
@ -869,12 +827,6 @@ class Config:
|
|||
# Router settings for Image Generation Auto mode
|
||||
IMAGE_GEN_ROUTER_SETTINGS = load_image_gen_router_settings()
|
||||
|
||||
# Global Vision LLM Configurations (optional)
|
||||
GLOBAL_VISION_LLM_CONFIGS = load_global_vision_llm_configs()
|
||||
|
||||
# Router settings for Vision LLM Auto mode
|
||||
VISION_LLM_ROUTER_SETTINGS = load_vision_llm_router_settings()
|
||||
|
||||
# Virtual GLOBAL connection/model catalog. This is server-only metadata
|
||||
# derived from global_llm_config.yaml; GLOBAL keys are not stored in DB.
|
||||
from app.services.global_model_catalog import (
|
||||
|
|
|
|||
|
|
@ -433,87 +433,11 @@ global_image_generation_configs:
|
|||
# rpm: 30
|
||||
# litellm_params: {}
|
||||
|
||||
# =============================================================================
|
||||
# Vision LLM Configuration
|
||||
# =============================================================================
|
||||
# These configurations power the vision autocomplete feature (screenshot analysis).
|
||||
# Only vision-capable models should be used here (e.g. GPT-4o, Gemini Pro, Claude 3).
|
||||
# Supported providers: OpenAI, Anthropic, Google, Azure OpenAI, Vertex AI, Bedrock,
|
||||
# xAI, OpenRouter, Ollama, Groq, Together AI, Fireworks AI, DeepSeek, Mistral, Custom
|
||||
#
|
||||
# Auto mode (ID 0) uses LiteLLM Router for load balancing across all vision configs.
|
||||
|
||||
# Router Settings for Vision LLM Auto Mode
|
||||
vision_llm_router_settings:
|
||||
routing_strategy: "usage-based-routing"
|
||||
num_retries: 3
|
||||
allowed_fails: 3
|
||||
cooldown_time: 60
|
||||
|
||||
global_vision_llm_configs:
|
||||
# Example: OpenAI GPT-4o (recommended for vision)
|
||||
- id: -1001
|
||||
name: "Global GPT-4o Vision"
|
||||
description: "OpenAI's GPT-4o with strong vision capabilities"
|
||||
litellm_provider: "openai"
|
||||
model_name: "gpt-4o"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: "https://api.openai.com/v1"
|
||||
rpm: 500
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.3
|
||||
max_tokens: 1000
|
||||
|
||||
# Example: Google Gemini 2.0 Flash
|
||||
- id: -1002
|
||||
name: "Global Gemini 2.0 Flash"
|
||||
description: "Google's fast vision model with large context"
|
||||
litellm_provider: "gemini"
|
||||
model_name: "gemini-2.0-flash"
|
||||
api_key: "your-google-ai-api-key-here"
|
||||
api_base: "https://generativelanguage.googleapis.com/v1beta"
|
||||
rpm: 1000
|
||||
tpm: 200000
|
||||
litellm_params:
|
||||
temperature: 0.3
|
||||
max_tokens: 1000
|
||||
|
||||
# Example: Anthropic Claude 3.5 Sonnet
|
||||
- id: -1003
|
||||
name: "Global Claude 3.5 Sonnet Vision"
|
||||
description: "Anthropic's Claude 3.5 Sonnet with vision support"
|
||||
litellm_provider: "anthropic"
|
||||
model_name: "claude-3-5-sonnet-20241022"
|
||||
api_key: "sk-ant-your-anthropic-api-key-here"
|
||||
api_base: "https://api.anthropic.com/v1"
|
||||
rpm: 1000
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.3
|
||||
max_tokens: 1000
|
||||
|
||||
# Example: Azure OpenAI GPT-4o
|
||||
# - id: -1004
|
||||
# name: "Global Azure GPT-4o Vision"
|
||||
# description: "Azure-hosted GPT-4o for vision analysis"
|
||||
# litellm_provider: "azure"
|
||||
# model_name: "azure/gpt-4o-deployment"
|
||||
# api_key: "your-azure-api-key-here"
|
||||
# api_base: "https://your-resource.openai.azure.com"
|
||||
# api_version: "2024-02-15-preview"
|
||||
# rpm: 500
|
||||
# tpm: 100000
|
||||
# litellm_params:
|
||||
# temperature: 0.3
|
||||
# max_tokens: 1000
|
||||
# base_model: "gpt-4o"
|
||||
|
||||
# Notes:
|
||||
# - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing
|
||||
# - Use negative IDs to distinguish global models from BYOK/local DB models
|
||||
# - IDs must be unique across chat, vision, and image generation configs
|
||||
# - Suggested static ranges: chat -1..-999, vision -1001..-1999, image -2001..-2999
|
||||
# - IDs must be unique across chat and image generation configs
|
||||
# - Suggested static ranges: chat -1..-999, image -2001..-2999
|
||||
# - The 'api_key' field will not be exposed to users via API
|
||||
# - system_instructions: Custom prompt or empty string to use defaults
|
||||
# - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty
|
||||
|
|
|
|||
|
|
@ -198,81 +198,6 @@ class DocumentStatus:
|
|||
return None
|
||||
|
||||
|
||||
class LiteLLMProvider(StrEnum):
|
||||
"""
|
||||
Enum for LLM providers supported by LiteLLM.
|
||||
"""
|
||||
|
||||
OPENAI = "OPENAI"
|
||||
ANTHROPIC = "ANTHROPIC"
|
||||
GOOGLE = "GOOGLE"
|
||||
AZURE_OPENAI = "AZURE_OPENAI"
|
||||
BEDROCK = "BEDROCK"
|
||||
VERTEX_AI = "VERTEX_AI"
|
||||
GROQ = "GROQ"
|
||||
COHERE = "COHERE"
|
||||
MISTRAL = "MISTRAL"
|
||||
DEEPSEEK = "DEEPSEEK"
|
||||
XAI = "XAI"
|
||||
OPENROUTER = "OPENROUTER"
|
||||
TOGETHER_AI = "TOGETHER_AI"
|
||||
FIREWORKS_AI = "FIREWORKS_AI"
|
||||
REPLICATE = "REPLICATE"
|
||||
PERPLEXITY = "PERPLEXITY"
|
||||
OLLAMA = "OLLAMA"
|
||||
ALIBABA_QWEN = "ALIBABA_QWEN"
|
||||
MOONSHOT = "MOONSHOT"
|
||||
ZHIPU = "ZHIPU"
|
||||
ANYSCALE = "ANYSCALE"
|
||||
DEEPINFRA = "DEEPINFRA"
|
||||
CEREBRAS = "CEREBRAS"
|
||||
SAMBANOVA = "SAMBANOVA"
|
||||
AI21 = "AI21"
|
||||
CLOUDFLARE = "CLOUDFLARE"
|
||||
DATABRICKS = "DATABRICKS"
|
||||
COMETAPI = "COMETAPI"
|
||||
HUGGINGFACE = "HUGGINGFACE"
|
||||
GITHUB_MODELS = "GITHUB_MODELS"
|
||||
MINIMAX = "MINIMAX"
|
||||
CUSTOM = "CUSTOM"
|
||||
|
||||
|
||||
class ImageGenProvider(StrEnum):
|
||||
"""
|
||||
Enum for image generation providers supported by LiteLLM.
|
||||
This is a subset of LLM providers — only those that support image generation.
|
||||
See: https://docs.litellm.ai/docs/image_generation#supported-providers
|
||||
"""
|
||||
|
||||
OPENAI = "OPENAI"
|
||||
AZURE_OPENAI = "AZURE_OPENAI"
|
||||
GOOGLE = "GOOGLE" # Google AI Studio
|
||||
VERTEX_AI = "VERTEX_AI"
|
||||
BEDROCK = "BEDROCK" # AWS Bedrock
|
||||
RECRAFT = "RECRAFT"
|
||||
OPENROUTER = "OPENROUTER"
|
||||
XINFERENCE = "XINFERENCE"
|
||||
NSCALE = "NSCALE"
|
||||
|
||||
|
||||
class VisionProvider(StrEnum):
|
||||
OPENAI = "OPENAI"
|
||||
ANTHROPIC = "ANTHROPIC"
|
||||
GOOGLE = "GOOGLE"
|
||||
AZURE_OPENAI = "AZURE_OPENAI"
|
||||
VERTEX_AI = "VERTEX_AI"
|
||||
BEDROCK = "BEDROCK"
|
||||
XAI = "XAI"
|
||||
OPENROUTER = "OPENROUTER"
|
||||
OLLAMA = "OLLAMA"
|
||||
GROQ = "GROQ"
|
||||
TOGETHER_AI = "TOGETHER_AI"
|
||||
FIREWORKS_AI = "FIREWORKS_AI"
|
||||
DEEPSEEK = "DEEPSEEK"
|
||||
MISTRAL = "MISTRAL"
|
||||
CUSTOM = "CUSTOM"
|
||||
|
||||
|
||||
class ConnectionScope(StrEnum):
|
||||
GLOBAL = "GLOBAL"
|
||||
SEARCH_SPACE = "SEARCH_SPACE"
|
||||
|
|
@ -710,11 +635,11 @@ class NewChatThread(BaseModel, TimestampMixin):
|
|||
default=False,
|
||||
server_default="false",
|
||||
)
|
||||
# Auto (Fastest) model pin for this thread: concrete resolved global LLM
|
||||
# Auto model pin for this thread: concrete resolved global LLM
|
||||
# config id. NULL means no pin; Auto will resolve on the next turn.
|
||||
# Single-writer invariant: only app.services.auto_model_pin_service sets
|
||||
# or clears this column (plus bulk clears when a search space's
|
||||
# agent_llm_id changes). Unindexed: all reads are by primary key.
|
||||
# chat_model_id changes). Unindexed: all reads are by primary key.
|
||||
pinned_llm_config_id = Column(Integer, nullable=True)
|
||||
|
||||
# Surface metadata for first-party SurfSense and external chat threads.
|
||||
|
|
@ -1686,75 +1611,6 @@ class Model(BaseModel, TimestampMixin):
|
|||
)
|
||||
|
||||
|
||||
class ImageGenerationConfig(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Dedicated configuration table for image generation models.
|
||||
|
||||
Separate from NewLLMConfig because image generation models don't need
|
||||
system_instructions, citations_enabled, or use_default_system_instructions.
|
||||
They only need provider credentials and model parameters.
|
||||
"""
|
||||
|
||||
__tablename__ = "image_generation_configs"
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
description = Column(String(500), nullable=True)
|
||||
|
||||
# Provider & model (uses ImageGenProvider, NOT LiteLLMProvider)
|
||||
provider = Column(SQLAlchemyEnum(ImageGenProvider), nullable=False)
|
||||
custom_provider = Column(String(100), nullable=True)
|
||||
model_name = Column(String(100), nullable=False)
|
||||
|
||||
# Credentials
|
||||
api_key = Column(String, nullable=False)
|
||||
api_base = Column(String(500), nullable=True)
|
||||
api_version = Column(String(50), nullable=True) # Azure-specific
|
||||
|
||||
# Additional litellm parameters
|
||||
litellm_params = Column(JSON, nullable=True, default={})
|
||||
|
||||
# Relationships
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship(
|
||||
"SearchSpace", back_populates="image_generation_configs"
|
||||
)
|
||||
|
||||
# User who created this config
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user = relationship("User", back_populates="image_generation_configs")
|
||||
|
||||
|
||||
class VisionLLMConfig(BaseModel, TimestampMixin):
|
||||
__tablename__ = "vision_llm_configs"
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
description = Column(String(500), nullable=True)
|
||||
|
||||
provider = Column(SQLAlchemyEnum(VisionProvider), nullable=False)
|
||||
custom_provider = Column(String(100), nullable=True)
|
||||
model_name = Column(String(100), nullable=False)
|
||||
|
||||
api_key = Column(String, nullable=False)
|
||||
api_base = Column(String(500), nullable=True)
|
||||
api_version = Column(String(50), nullable=True)
|
||||
|
||||
litellm_params = Column(JSON, nullable=True, default={})
|
||||
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="vision_llm_configs")
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user = relationship("User", back_populates="vision_llm_configs")
|
||||
|
||||
|
||||
class ImageGeneration(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Stores image generation requests and results using litellm.aimage_generation().
|
||||
|
|
@ -1786,10 +1642,9 @@ class ImageGeneration(BaseModel, TimestampMixin):
|
|||
style = Column(String(50), nullable=True) # Model-specific style parameter
|
||||
response_format = Column(String(50), nullable=True) # "url" or "b64_json"
|
||||
|
||||
# Image generation config reference
|
||||
# 0 = Auto mode (router), negative IDs = global configs from YAML,
|
||||
# positive IDs = ImageGenerationConfig records in DB
|
||||
image_generation_config_id = Column(Integer, nullable=True)
|
||||
# Image generation model provenance.
|
||||
# 0 = Auto mode, negative IDs = GLOBAL models, positive IDs = Model records.
|
||||
image_gen_model_id = Column(Integer, nullable=True)
|
||||
|
||||
# Response data (full litellm response as JSONB) — present on success
|
||||
response_data = Column(JSONB, nullable=True)
|
||||
|
|
@ -1831,23 +1686,7 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
|
||||
shared_memory_md = Column(Text, nullable=True, server_default="")
|
||||
|
||||
# Search space-level LLM preferences (shared by all members)
|
||||
# Note: ID values:
|
||||
# - 0: Auto mode (uses LiteLLM Router for load balancing) - default for new search spaces
|
||||
# - Negative IDs: Global configs from YAML
|
||||
# - Positive IDs: Custom configs from DB (NewLLMConfig table)
|
||||
agent_llm_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
) # For chat operations, defaults to Auto mode
|
||||
image_generation_config_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
) # For image generation, defaults to Auto mode
|
||||
vision_llm_config_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
) # For vision/screenshot analysis, defaults to Auto mode
|
||||
|
||||
# New connection/model role bindings. These supersede the legacy config
|
||||
# columns above without removing them in this PR.
|
||||
# Connection/model role bindings.
|
||||
# Note: ID values preserve the existing convention:
|
||||
# - 0: Auto mode
|
||||
# - Negative IDs: Global virtual models from global_llm_config.yaml
|
||||
|
|
@ -1931,24 +1770,6 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
order_by="SearchSourceConnector.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
new_llm_configs = relationship(
|
||||
"NewLLMConfig",
|
||||
back_populates="search_space",
|
||||
order_by="NewLLMConfig.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
image_generation_configs = relationship(
|
||||
"ImageGenerationConfig",
|
||||
back_populates="search_space",
|
||||
order_by="ImageGenerationConfig.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
vision_llm_configs = relationship(
|
||||
"VisionLLMConfig",
|
||||
back_populates="search_space",
|
||||
order_by="VisionLLMConfig.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
connections = relationship(
|
||||
"Connection",
|
||||
back_populates="search_space",
|
||||
|
|
@ -2057,64 +1878,6 @@ class SearchSourceConnector(BaseModel, TimestampMixin):
|
|||
documents = relationship("Document", back_populates="connector")
|
||||
|
||||
|
||||
class NewLLMConfig(BaseModel, TimestampMixin):
|
||||
"""
|
||||
New LLM configuration table that combines model settings with prompt configuration.
|
||||
|
||||
This table provides:
|
||||
- LLM model configuration (provider, model_name, api_key, etc.)
|
||||
- Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS)
|
||||
- Citation toggle (enable/disable citation instructions)
|
||||
|
||||
Note: Tools instructions are built by get_tools_instructions(thread_visibility) (personal vs shared memory).
|
||||
"""
|
||||
|
||||
__tablename__ = "new_llm_configs"
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
description = Column(String(500), nullable=True)
|
||||
|
||||
# === LLM Model Configuration (from original LLMConfig, excluding 'language') ===
|
||||
# Provider from the enum
|
||||
provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False)
|
||||
# Custom provider name when provider is CUSTOM
|
||||
custom_provider = Column(String(100), nullable=True)
|
||||
# Just the model name without provider prefix
|
||||
model_name = Column(String(100), nullable=False)
|
||||
# API Key should be encrypted before storing
|
||||
api_key = Column(String, nullable=False)
|
||||
api_base = Column(String(500), nullable=True)
|
||||
# For any other parameters that litellm supports
|
||||
litellm_params = Column(JSON, nullable=True, default={})
|
||||
|
||||
# === Prompt Configuration ===
|
||||
# Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS)
|
||||
# Users can customize this from the UI
|
||||
system_instructions = Column(
|
||||
Text,
|
||||
nullable=False,
|
||||
default="", # Empty string means use default SURFSENSE_SYSTEM_INSTRUCTIONS
|
||||
)
|
||||
# Whether to use the default system instructions when system_instructions is empty
|
||||
use_default_system_instructions = Column(Boolean, nullable=False, default=True)
|
||||
|
||||
# Citation toggle - when enabled, SURFSENSE_CITATION_INSTRUCTIONS is injected
|
||||
# When disabled, an anti-citation prompt is injected instead
|
||||
citations_enabled = Column(Boolean, nullable=False, default=True)
|
||||
|
||||
# === Relationships ===
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="new_llm_configs")
|
||||
|
||||
# User who created this config
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user = relationship("User", back_populates="new_llm_configs")
|
||||
|
||||
|
||||
class Log(BaseModel, TimestampMixin):
|
||||
__tablename__ = "logs"
|
||||
|
||||
|
|
@ -2481,25 +2244,6 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# LLM configs created by this user
|
||||
new_llm_configs = relationship(
|
||||
"NewLLMConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# Image generation configs created by this user
|
||||
image_generation_configs = relationship(
|
||||
"ImageGenerationConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
vision_llm_configs = relationship(
|
||||
"VisionLLMConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
connections = relationship(
|
||||
"Connection",
|
||||
back_populates="user",
|
||||
|
|
@ -2632,25 +2376,6 @@ else:
|
|||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# LLM configs created by this user
|
||||
new_llm_configs = relationship(
|
||||
"NewLLMConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# Image generation configs created by this user
|
||||
image_generation_configs = relationship(
|
||||
"ImageGenerationConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
vision_llm_configs = relationship(
|
||||
"VisionLLMConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
connections = relationship(
|
||||
"Connection",
|
||||
back_populates="user",
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ def build_configurable_system_prompt(
|
|||
*,
|
||||
model_name: str | None = None,
|
||||
) -> str:
|
||||
"""Build a configurable SurfSense system prompt (NewLLMConfig path).
|
||||
"""Build a configurable SurfSense system prompt.
|
||||
|
||||
See :func:`app.prompts.system_prompt_composer.composer.compose_system_prompt`
|
||||
for full parameter docs.
|
||||
|
|
@ -104,7 +104,7 @@ def build_configurable_system_prompt(
|
|||
def get_default_system_instructions() -> str:
|
||||
"""Return the default ``<system_instruction>`` block (no tools / citations).
|
||||
|
||||
Useful for populating the UI when seeding ``NewLLMConfig.system_instructions``.
|
||||
Useful for populating the UI when editing custom system instructions.
|
||||
The output reflects the current fragment tree, not a baked-in constant.
|
||||
"""
|
||||
resolved_today = datetime.now(UTC).date().isoformat()
|
||||
|
|
|
|||
|
|
@ -348,8 +348,7 @@ def compose_system_prompt(
|
|||
mcp_connector_tools: ``{server_name: [tool_names...]}`` to inject
|
||||
an explicit MCP routing block.
|
||||
custom_system_instructions: Free-form instructions that override
|
||||
the default ``<system_instruction>`` block (legacy support
|
||||
for ``NewLLMConfig.system_instructions``).
|
||||
the default ``<system_instruction>`` block.
|
||||
use_default_system_instructions: When ``custom_system_instructions``
|
||||
is empty/None, fall back to defaults (legacy semantics).
|
||||
citations_enabled: Include ``citations_on.md`` (true) or
|
||||
|
|
|
|||
|
|
@ -47,7 +47,6 @@ from .model_connections_routes import router as model_connections_router
|
|||
from .memory_routes import router as memory_router
|
||||
from .model_list_routes import router as model_list_router
|
||||
from .new_chat_routes import router as new_chat_router
|
||||
from .new_llm_config_routes import router as new_llm_config_router
|
||||
from .notes_routes import router as notes_router
|
||||
from .notion_add_connector_route import router as notion_add_connector_router
|
||||
from .obsidian_plugin_routes import router as obsidian_plugin_router
|
||||
|
|
@ -64,7 +63,6 @@ from .stripe_routes import router as stripe_router
|
|||
from .team_memory_routes import router as team_memory_router
|
||||
from .teams_add_connector_route import router as teams_add_connector_router
|
||||
from .video_presentations_routes import router as video_presentations_router
|
||||
from .vision_llm_routes import router as vision_llm_router
|
||||
from .youtube_routes import router as youtube_router
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -99,7 +97,6 @@ router.include_router(
|
|||
) # Video presentation status and streaming
|
||||
router.include_router(reports_router) # Report CRUD and multi-format export
|
||||
router.include_router(image_generation_router) # Image generation via litellm
|
||||
router.include_router(vision_llm_router) # Vision LLM configs for screenshot analysis
|
||||
router.include_router(search_source_connectors_router)
|
||||
router.include_router(google_calendar_add_connector_router)
|
||||
router.include_router(google_gmail_add_connector_router)
|
||||
|
|
@ -117,7 +114,6 @@ router.include_router(jira_add_connector_router)
|
|||
router.include_router(confluence_add_connector_router)
|
||||
router.include_router(clickup_add_connector_router)
|
||||
router.include_router(dropbox_add_connector_router)
|
||||
router.include_router(new_llm_config_router) # LLM configs with prompt configuration
|
||||
router.include_router(model_connections_router) # Connection-centric model catalog
|
||||
router.include_router(model_list_router) # Dynamic model catalogue from OpenRouter
|
||||
router.include_router(logs_router)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
"""
|
||||
Image Generation routes:
|
||||
- CRUD for ImageGenerationConfig (user-created image model configs)
|
||||
- Global image gen configs endpoint (from YAML)
|
||||
- Image generation execution (calls litellm.aimage_generation())
|
||||
- CRUD for ImageGeneration records (results)
|
||||
- Image serving endpoint (serves b64_json images from DB, protected by signed tokens)
|
||||
|
|
@ -21,7 +19,6 @@ from sqlalchemy.orm import selectinload
|
|||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGeneration,
|
||||
ImageGenerationConfig,
|
||||
Model,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
|
|
@ -30,14 +27,14 @@ from app.db import (
|
|||
get_async_session,
|
||||
)
|
||||
from app.schemas import (
|
||||
GlobalImageGenConfigRead,
|
||||
ImageGenerationConfigCreate,
|
||||
ImageGenerationConfigRead,
|
||||
ImageGenerationConfigUpdate,
|
||||
ImageGenerationCreate,
|
||||
ImageGenerationListRead,
|
||||
ImageGenerationRead,
|
||||
)
|
||||
from app.services.auto_model_pin_service import (
|
||||
auto_model_candidates,
|
||||
choose_auto_model_candidate,
|
||||
)
|
||||
from app.services.billable_calls import (
|
||||
DEFAULT_IMAGE_RESERVE_MICROS,
|
||||
QuotaInsufficientError,
|
||||
|
|
@ -47,12 +44,8 @@ from app.services.image_gen_router_service import (
|
|||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.auto_model_pin_service import (
|
||||
auto_model_candidates,
|
||||
choose_auto_model_candidate,
|
||||
)
|
||||
from app.services.model_resolver import to_litellm
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.model_resolver import to_litellm
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.signed_image_urls import verify_image_token
|
||||
|
|
@ -131,14 +124,14 @@ async def _execute_image_generation(
|
|||
Call litellm.aimage_generation() with the appropriate config.
|
||||
|
||||
Resolution order:
|
||||
1. Explicit image_generation_config_id on the request
|
||||
2. Search space's image_generation_config_id preference
|
||||
1. Explicit image_gen_model_id on the request
|
||||
2. Search space's image_gen_model_id preference
|
||||
3. Falls back to Auto mode if available
|
||||
"""
|
||||
config_id = image_gen.image_generation_config_id
|
||||
config_id = image_gen.image_gen_model_id
|
||||
if config_id is None:
|
||||
config_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
image_gen.image_generation_config_id = config_id
|
||||
image_gen.image_gen_model_id = config_id
|
||||
|
||||
# Build kwargs
|
||||
gen_kwargs = {}
|
||||
|
|
@ -163,7 +156,7 @@ async def _execute_image_generation(
|
|||
if not candidates:
|
||||
raise ValueError("No image-generation models are available for Auto mode")
|
||||
config_id = int(choose_auto_model_candidate(candidates, search_space.id)["id"])
|
||||
image_gen.image_generation_config_id = config_id
|
||||
image_gen.image_gen_model_id = config_id
|
||||
|
||||
if config_id < 0:
|
||||
global_model = _get_global_model(config_id)
|
||||
|
|
@ -228,266 +221,6 @@ async def _execute_image_generation(
|
|||
image_gen.model = hidden["model"]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Image Generation Configs (from YAML)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/global-image-generation-configs",
|
||||
response_model=list[GlobalImageGenConfigRead],
|
||||
)
|
||||
async def get_global_image_gen_configs(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get all global image generation configs. API keys are hidden."""
|
||||
try:
|
||||
global_configs = config.GLOBAL_IMAGE_GEN_CONFIGS
|
||||
safe_configs = []
|
||||
|
||||
if global_configs and len(global_configs) > 0:
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": 0,
|
||||
"name": "Auto (Fastest)",
|
||||
"description": "Automatically routes across available image generation providers.",
|
||||
"provider": "AUTO",
|
||||
"custom_provider": None,
|
||||
"model_name": "auto",
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
# Auto mode currently treated as free until per-deployment
|
||||
# billing-tier surfacing lands (see _resolve_billing_for_image_gen).
|
||||
"billing_tier": "free",
|
||||
"is_premium": False,
|
||||
}
|
||||
)
|
||||
|
||||
for cfg in global_configs:
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": billing_tier,
|
||||
# Mirror chat (``new_llm_config_routes``) so the new-chat
|
||||
# selector's premium badge logic keys off the same
|
||||
# field across chat / image / vision tabs.
|
||||
"is_premium": billing_tier == "premium",
|
||||
"quota_reserve_micros": cfg.get("quota_reserve_micros"),
|
||||
}
|
||||
)
|
||||
|
||||
return safe_configs
|
||||
except Exception as e:
|
||||
logger.exception("Failed to fetch global image generation configs")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configs: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ImageGenerationConfig CRUD
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/image-generation-configs", response_model=ImageGenerationConfigRead)
|
||||
async def create_image_gen_config(
|
||||
config_data: ImageGenerationConfigCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Create a new image generation config for a search space."""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config_data.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
"You don't have permission to create image generation configs in this search space",
|
||||
)
|
||||
|
||||
db_config = ImageGenerationConfig(**config_data.model_dump(), user_id=user.id)
|
||||
session.add(db_config)
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to create ImageGenerationConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead])
|
||||
async def list_image_gen_configs(
|
||||
search_space_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""List image generation configs for a search space."""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
"You don't have permission to view image generation configs in this search space",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig)
|
||||
.filter(ImageGenerationConfig.search_space_id == search_space_id)
|
||||
.order_by(ImageGenerationConfig.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to list ImageGenerationConfigs")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configs: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead
|
||||
)
|
||||
async def get_image_gen_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get a specific image generation config by ID."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
"You don't have permission to view image generation configs in this search space",
|
||||
)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get ImageGenerationConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put(
|
||||
"/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead
|
||||
)
|
||||
async def update_image_gen_config(
|
||||
config_id: int,
|
||||
update_data: ImageGenerationConfigUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Update an existing image generation config."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
"You don't have permission to update image generation configs in this search space",
|
||||
)
|
||||
|
||||
for key, value in update_data.model_dump(exclude_unset=True).items():
|
||||
setattr(db_config, key, value)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to update ImageGenerationConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/image-generation-configs/{config_id}", response_model=dict)
|
||||
async def delete_image_gen_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Delete an image generation config."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_DELETE.value,
|
||||
"You don't have permission to delete image generation configs in this search space",
|
||||
)
|
||||
|
||||
await session.delete(db_config)
|
||||
await session.commit()
|
||||
return {
|
||||
"message": "Image generation config deleted successfully",
|
||||
"id": config_id,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to delete ImageGenerationConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to delete config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Image Generation Execution + Results CRUD
|
||||
# =============================================================================
|
||||
|
|
@ -536,7 +269,7 @@ async def create_image_generation(
|
|||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
|
||||
session, data.image_generation_config_id, search_space
|
||||
session, data.image_gen_model_id, search_space
|
||||
)
|
||||
|
||||
# billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
|
||||
|
|
@ -562,7 +295,7 @@ async def create_image_generation(
|
|||
size=data.size,
|
||||
style=data.style,
|
||||
response_format=data.response_format,
|
||||
image_generation_config_id=data.image_generation_config_id,
|
||||
image_gen_model_id=data.image_gen_model_id,
|
||||
search_space_id=data.search_space_id,
|
||||
created_by_id=user.id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from app.db import (
|
|||
ConnectionScope,
|
||||
Model,
|
||||
ModelSource,
|
||||
NewChatThread,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
User,
|
||||
|
|
@ -708,12 +709,26 @@ async def update_model_roles(
|
|||
search_space = await _get_search_space(session, search_space_id)
|
||||
updates = data.model_dump(exclude_unset=True)
|
||||
if "chat_model_id" in updates:
|
||||
search_space.chat_model_id = await _validate_role_model_id(
|
||||
previous_chat_model_id = search_space.chat_model_id
|
||||
next_chat_model_id = await _validate_role_model_id(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
model_id=updates["chat_model_id"],
|
||||
capability="chat",
|
||||
)
|
||||
search_space.chat_model_id = next_chat_model_id
|
||||
if next_chat_model_id != previous_chat_model_id:
|
||||
await session.execute(
|
||||
update(NewChatThread)
|
||||
.where(NewChatThread.search_space_id == search_space_id)
|
||||
.values(pinned_llm_config_id=None)
|
||||
)
|
||||
logger.info(
|
||||
"Cleared auto model pins for search_space_id=%s after chat_model_id change (%s -> %s)",
|
||||
search_space_id,
|
||||
previous_chat_model_id,
|
||||
next_chat_model_id,
|
||||
)
|
||||
if "vision_model_id" in updates:
|
||||
search_space.vision_model_id = await _validate_role_model_id(
|
||||
session,
|
||||
|
|
|
|||
|
|
@ -1,480 +0,0 @@
|
|||
"""
|
||||
API routes for NewLLMConfig CRUD operations.
|
||||
|
||||
NewLLMConfig combines model settings with prompt configuration:
|
||||
- LLM provider, model, API key, etc.
|
||||
- Configurable system instructions
|
||||
- Citation toggle
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
NewLLMConfig,
|
||||
Permission,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.prompts.default_system_instructions import get_default_system_instructions
|
||||
from app.schemas import (
|
||||
DefaultSystemInstructionsResponse,
|
||||
GlobalNewLLMConfigRead,
|
||||
NewLLMConfigCreate,
|
||||
NewLLMConfigRead,
|
||||
NewLLMConfigUpdate,
|
||||
)
|
||||
from app.services.llm_service import validate_llm_config
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead:
|
||||
"""Augment a BYOK chat config row with the derived ``supports_image_input``.
|
||||
|
||||
There is no DB column for ``supports_image_input`` — the value is
|
||||
resolved at the API boundary from LiteLLM's authoritative model map
|
||||
(default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps
|
||||
the response shape consistent across list / detail / create / update
|
||||
endpoints without having to remember to set the field at every call
|
||||
site.
|
||||
"""
|
||||
provider_value = (
|
||||
config.provider.value
|
||||
if hasattr(config.provider, "value")
|
||||
else str(config.provider)
|
||||
)
|
||||
litellm_params = config.litellm_params or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
|
||||
)
|
||||
supports_image_input = derive_supports_image_input(
|
||||
provider=provider_value.lower(),
|
||||
model_name=config.model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=config.custom_provider,
|
||||
)
|
||||
# ``model_validate`` runs the Pydantic conversion using the ORM
|
||||
# attribute access path enabled by ``ConfigDict(from_attributes=True)``,
|
||||
# then we layer the derived field on. ``model_copy(update=...)`` keeps
|
||||
# the surface immutable from the caller's perspective.
|
||||
base_read = NewLLMConfigRead.model_validate(config)
|
||||
return base_read.model_copy(update={"supports_image_input": supports_image_input})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Configs Routes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/global-new-llm-configs", response_model=list[GlobalNewLLMConfigRead])
|
||||
async def get_global_new_llm_configs(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get all available global NewLLMConfig configurations.
|
||||
These are pre-configured by the system administrator and available to all users.
|
||||
API keys are not exposed through this endpoint.
|
||||
|
||||
Includes:
|
||||
- Auto mode (ID 0): Uses LiteLLM Router for automatic load balancing
|
||||
- Global configs (negative IDs): Individual pre-configured LLM providers
|
||||
"""
|
||||
try:
|
||||
global_configs = config.GLOBAL_LLM_CONFIGS
|
||||
safe_configs = []
|
||||
|
||||
# Only include Auto mode if there are actual global configs to route to
|
||||
# Auto mode requires at least one global config with valid API key
|
||||
if global_configs and len(global_configs) > 0:
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": 0,
|
||||
"name": "Auto (Fastest)",
|
||||
"description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling. Recommended for most users.",
|
||||
"provider": "AUTO",
|
||||
"custom_provider": None,
|
||||
"model_name": "auto",
|
||||
"api_base": None,
|
||||
"litellm_params": {},
|
||||
"system_instructions": "",
|
||||
"use_default_system_instructions": True,
|
||||
"citations_enabled": True,
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
"billing_tier": "free",
|
||||
"is_premium": False,
|
||||
"anonymous_enabled": False,
|
||||
"seo_enabled": False,
|
||||
"seo_slug": None,
|
||||
"seo_title": None,
|
||||
"seo_description": None,
|
||||
"quota_reserve_tokens": None,
|
||||
# Auto routes across the configured pool, which usually
|
||||
# includes at least one vision-capable deployment, so
|
||||
# treat Auto as image-capable. The router itself will
|
||||
# still pick a vision-capable deployment for messages
|
||||
# carrying image_url blocks (LiteLLM Router falls back
|
||||
# on ``404`` per its ``allowed_fails`` policy).
|
||||
"supports_image_input": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Add individual global configs
|
||||
for cfg in global_configs:
|
||||
# Capability resolution: explicit value (YAML override or OR
|
||||
# `_supports_image_input(model)` payload baked in by the
|
||||
# OpenRouter integration service) wins. Fall back to the
|
||||
# LiteLLM-driven helper which default-allows on unknown so
|
||||
# we don't hide vision-capable models that happen to lack a
|
||||
# YAML annotation. The streaming task safety net is the
|
||||
# only place a False ever blocks.
|
||||
if "supports_image_input" in cfg:
|
||||
supports_image_input = bool(cfg.get("supports_image_input"))
|
||||
else:
|
||||
cfg_litellm_params = cfg.get("litellm_params") or {}
|
||||
cfg_base_model = (
|
||||
cfg_litellm_params.get("base_model")
|
||||
if isinstance(cfg_litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
supports_image_input = derive_supports_image_input(
|
||||
provider=cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=cfg_base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
|
||||
safe_config = {
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
# New prompt configuration fields
|
||||
"system_instructions": cfg.get("system_instructions", ""),
|
||||
"use_default_system_instructions": cfg.get(
|
||||
"use_default_system_instructions", True
|
||||
),
|
||||
"citations_enabled": cfg.get("citations_enabled", True),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
"is_premium": cfg.get("billing_tier", "free") == "premium",
|
||||
"anonymous_enabled": cfg.get("anonymous_enabled", False),
|
||||
"seo_enabled": cfg.get("seo_enabled", False),
|
||||
"seo_slug": cfg.get("seo_slug"),
|
||||
"seo_title": cfg.get("seo_title"),
|
||||
"seo_description": cfg.get("seo_description"),
|
||||
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
||||
"supports_image_input": supports_image_input,
|
||||
}
|
||||
safe_configs.append(safe_config)
|
||||
|
||||
return safe_configs
|
||||
except Exception as e:
|
||||
logger.exception("Failed to fetch global NewLLMConfigs")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch global configurations: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CRUD Routes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/new-llm-configs", response_model=NewLLMConfigRead)
|
||||
async def create_new_llm_config(
|
||||
config_data: NewLLMConfigCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Create a new NewLLMConfig for a search space.
|
||||
Requires LLM_CONFIGS_CREATE permission.
|
||||
"""
|
||||
try:
|
||||
# Verify user has permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config_data.search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to create LLM configurations in this search space",
|
||||
)
|
||||
|
||||
# Validate the LLM configuration by making a test API call
|
||||
is_valid, error_message = await validate_llm_config(
|
||||
provider=config_data.provider.value,
|
||||
model_name=config_data.model_name,
|
||||
api_key=config_data.api_key,
|
||||
api_base=config_data.api_base,
|
||||
custom_provider=config_data.custom_provider,
|
||||
litellm_params=config_data.litellm_params,
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid LLM configuration: {error_message}",
|
||||
)
|
||||
|
||||
# Create the config with user association
|
||||
db_config = NewLLMConfig(**config_data.model_dump(), user_id=user.id)
|
||||
session.add(db_config)
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
|
||||
return _serialize_byok_config(db_config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to create NewLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create configuration: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/new-llm-configs", response_model=list[NewLLMConfigRead])
|
||||
async def list_new_llm_configs(
|
||||
search_space_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get all NewLLMConfigs for a search space.
|
||||
Requires LLM_CONFIGS_READ permission.
|
||||
"""
|
||||
try:
|
||||
# Verify user has permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.LLM_CONFIGS_READ.value,
|
||||
"You don't have permission to view LLM configurations in this search space",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig)
|
||||
.filter(NewLLMConfig.search_space_id == search_space_id)
|
||||
.order_by(NewLLMConfig.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
return [_serialize_byok_config(cfg) for cfg in result.scalars().all()]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to list NewLLMConfigs")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configurations: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/new-llm-configs/default-system-instructions",
|
||||
response_model=DefaultSystemInstructionsResponse,
|
||||
)
|
||||
async def get_default_system_instructions_endpoint(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get the default SURFSENSE_SYSTEM_INSTRUCTIONS template.
|
||||
Useful for pre-populating the UI when creating a new configuration.
|
||||
"""
|
||||
return DefaultSystemInstructionsResponse(
|
||||
default_system_instructions=get_default_system_instructions()
|
||||
)
|
||||
|
||||
|
||||
@router.get("/new-llm-configs/{config_id}", response_model=NewLLMConfigRead)
|
||||
async def get_new_llm_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get a specific NewLLMConfig by ID.
|
||||
Requires LLM_CONFIGS_READ permission.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
|
||||
)
|
||||
config = result.scalars().first()
|
||||
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="Configuration not found")
|
||||
|
||||
# Verify user has permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config.search_space_id,
|
||||
Permission.LLM_CONFIGS_READ.value,
|
||||
"You don't have permission to view LLM configurations in this search space",
|
||||
)
|
||||
|
||||
return _serialize_byok_config(config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get NewLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configuration: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/new-llm-configs/{config_id}", response_model=NewLLMConfigRead)
|
||||
async def update_new_llm_config(
|
||||
config_id: int,
|
||||
update_data: NewLLMConfigUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Update an existing NewLLMConfig.
|
||||
Requires LLM_CONFIGS_UPDATE permission.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
|
||||
)
|
||||
config = result.scalars().first()
|
||||
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="Configuration not found")
|
||||
|
||||
# Verify user has permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config.search_space_id,
|
||||
Permission.LLM_CONFIGS_UPDATE.value,
|
||||
"You don't have permission to update LLM configurations in this search space",
|
||||
)
|
||||
|
||||
update_dict = update_data.model_dump(exclude_unset=True)
|
||||
|
||||
# If updating LLM settings, validate them
|
||||
if any(
|
||||
key in update_dict
|
||||
for key in [
|
||||
"provider",
|
||||
"model_name",
|
||||
"api_key",
|
||||
"api_base",
|
||||
"custom_provider",
|
||||
"litellm_params",
|
||||
]
|
||||
):
|
||||
# Build the validation config from existing + updates
|
||||
validation_config = {
|
||||
"provider": update_dict.get("provider", config.provider).value
|
||||
if hasattr(update_dict.get("provider", config.provider), "value")
|
||||
else update_dict.get("provider", config.provider.value),
|
||||
"model_name": update_dict.get("model_name", config.model_name),
|
||||
"api_key": update_dict.get("api_key", config.api_key),
|
||||
"api_base": update_dict.get("api_base", config.api_base),
|
||||
"custom_provider": update_dict.get(
|
||||
"custom_provider", config.custom_provider
|
||||
),
|
||||
"litellm_params": update_dict.get(
|
||||
"litellm_params", config.litellm_params
|
||||
),
|
||||
}
|
||||
|
||||
is_valid, error_message = await validate_llm_config(
|
||||
provider=validation_config["provider"],
|
||||
model_name=validation_config["model_name"],
|
||||
api_key=validation_config["api_key"],
|
||||
api_base=validation_config["api_base"],
|
||||
custom_provider=validation_config["custom_provider"],
|
||||
litellm_params=validation_config["litellm_params"],
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid LLM configuration: {error_message}",
|
||||
)
|
||||
|
||||
# Apply updates
|
||||
for key, value in update_dict.items():
|
||||
setattr(config, key, value)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
|
||||
return _serialize_byok_config(config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to update NewLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update configuration: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/new-llm-configs/{config_id}", response_model=dict)
|
||||
async def delete_new_llm_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Delete a NewLLMConfig.
|
||||
Requires LLM_CONFIGS_DELETE permission.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
|
||||
)
|
||||
config = result.scalars().first()
|
||||
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="Configuration not found")
|
||||
|
||||
# Verify user has permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config.search_space_id,
|
||||
Permission.LLM_CONFIGS_DELETE.value,
|
||||
"You don't have permission to delete LLM configurations in this search space",
|
||||
)
|
||||
|
||||
await session.delete(config)
|
||||
await session.commit()
|
||||
|
||||
return {"message": "Configuration deleted successfully", "id": config_id}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to delete NewLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to delete configuration: {e!s}"
|
||||
) from e
|
||||
|
|
@ -1,27 +1,20 @@
|
|||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import func, update
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGenerationConfig,
|
||||
NewChatThread,
|
||||
NewLLMConfig,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
SearchSpaceRole,
|
||||
User,
|
||||
VisionLLMConfig,
|
||||
get_async_session,
|
||||
get_default_roles_config,
|
||||
)
|
||||
from app.schemas import (
|
||||
LLMPreferencesRead,
|
||||
LLMPreferencesUpdate,
|
||||
SearchSpaceCreate,
|
||||
SearchSpaceRead,
|
||||
SearchSpaceUpdate,
|
||||
|
|
@ -377,357 +370,6 @@ async def delete_search_space(
|
|||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LLM Preferences Routes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def _get_llm_config_by_id(
|
||||
session: AsyncSession, config_id: int | None
|
||||
) -> dict | None:
|
||||
"""
|
||||
Get an LLM config by ID as a dictionary. Returns database config for positive IDs,
|
||||
global config for negative IDs, Auto mode config for ID 0, or None if ID is None.
|
||||
"""
|
||||
if config_id is None:
|
||||
return None
|
||||
|
||||
# Auto mode (ID 0) - uses LiteLLM Router for load balancing
|
||||
if config_id == 0:
|
||||
return {
|
||||
"id": 0,
|
||||
"name": "Auto (Fastest)",
|
||||
"description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling",
|
||||
"provider": "AUTO",
|
||||
"custom_provider": None,
|
||||
"model_name": "auto",
|
||||
"api_base": None,
|
||||
"litellm_params": {},
|
||||
"system_instructions": "",
|
||||
"use_default_system_instructions": True,
|
||||
"citations_enabled": True,
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
# Global config - find from YAML
|
||||
global_configs = config.GLOBAL_LLM_CONFIGS
|
||||
for cfg in global_configs:
|
||||
if cfg.get("id") == config_id:
|
||||
return {
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base"),
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"system_instructions": cfg.get("system_instructions", ""),
|
||||
"use_default_system_instructions": cfg.get(
|
||||
"use_default_system_instructions", True
|
||||
),
|
||||
"citations_enabled": cfg.get("citations_enabled", True),
|
||||
"is_global": True,
|
||||
}
|
||||
return None
|
||||
else:
|
||||
# Database config - convert to dict
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if db_config:
|
||||
return {
|
||||
"id": db_config.id,
|
||||
"name": db_config.name,
|
||||
"description": db_config.description,
|
||||
"provider": db_config.provider.value if db_config.provider else None,
|
||||
"custom_provider": db_config.custom_provider,
|
||||
"model_name": db_config.model_name,
|
||||
"api_key": db_config.api_key,
|
||||
"api_base": db_config.api_base,
|
||||
"litellm_params": db_config.litellm_params or {},
|
||||
"system_instructions": db_config.system_instructions or "",
|
||||
"use_default_system_instructions": db_config.use_default_system_instructions,
|
||||
"citations_enabled": db_config.citations_enabled,
|
||||
"created_at": db_config.created_at.isoformat()
|
||||
if db_config.created_at
|
||||
else None,
|
||||
"search_space_id": db_config.search_space_id,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def _get_image_gen_config_by_id(
|
||||
session: AsyncSession, config_id: int | None
|
||||
) -> dict | None:
|
||||
"""
|
||||
Get an image generation config by ID as a dictionary.
|
||||
Returns Auto mode for ID 0, global config for negative IDs,
|
||||
DB ImageGenerationConfig for positive IDs, or None.
|
||||
"""
|
||||
if config_id is None:
|
||||
return None
|
||||
|
||||
if config_id == 0:
|
||||
return {
|
||||
"id": 0,
|
||||
"name": "Auto (Fastest)",
|
||||
"description": "Automatically routes requests across available image generation providers",
|
||||
"provider": "AUTO",
|
||||
"model_name": "auto",
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
"billing_tier": "free",
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return {
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
}
|
||||
return None
|
||||
|
||||
# Positive ID: query ImageGenerationConfig table
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if db_config:
|
||||
return {
|
||||
"id": db_config.id,
|
||||
"name": db_config.name,
|
||||
"description": db_config.description,
|
||||
"provider": db_config.provider.value if db_config.provider else None,
|
||||
"custom_provider": db_config.custom_provider,
|
||||
"model_name": db_config.model_name,
|
||||
"api_base": db_config.api_base,
|
||||
"api_version": db_config.api_version,
|
||||
"litellm_params": db_config.litellm_params or {},
|
||||
"created_at": db_config.created_at.isoformat()
|
||||
if db_config.created_at
|
||||
else None,
|
||||
"search_space_id": db_config.search_space_id,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def _get_vision_llm_config_by_id(
|
||||
session: AsyncSession, config_id: int | None
|
||||
) -> dict | None:
|
||||
if config_id is None:
|
||||
return None
|
||||
|
||||
if config_id == 0:
|
||||
return {
|
||||
"id": 0,
|
||||
"name": "Auto (Fastest)",
|
||||
"description": "Automatically routes requests across available vision LLM providers",
|
||||
"provider": "AUTO",
|
||||
"model_name": "auto",
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
"billing_tier": "free",
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
for cfg in config.GLOBAL_VISION_LLM_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return {
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
}
|
||||
return None
|
||||
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if db_config:
|
||||
return {
|
||||
"id": db_config.id,
|
||||
"name": db_config.name,
|
||||
"description": db_config.description,
|
||||
"provider": db_config.provider.value if db_config.provider else None,
|
||||
"custom_provider": db_config.custom_provider,
|
||||
"model_name": db_config.model_name,
|
||||
"api_base": db_config.api_base,
|
||||
"api_version": db_config.api_version,
|
||||
"litellm_params": db_config.litellm_params or {},
|
||||
"created_at": db_config.created_at.isoformat()
|
||||
if db_config.created_at
|
||||
else None,
|
||||
"search_space_id": db_config.search_space_id,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/search-spaces/{search_space_id}/llm-preferences",
|
||||
response_model=LLMPreferencesRead,
|
||||
)
|
||||
async def get_llm_preferences(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get LLM preferences (role assignments) for a search space.
|
||||
Requires LLM_CONFIGS_READ permission.
|
||||
"""
|
||||
try:
|
||||
# Check permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.LLM_CONFIGS_READ.value,
|
||||
"You don't have permission to view LLM preferences",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
# Get full config objects for each role
|
||||
agent_llm = await _get_llm_config_by_id(session, search_space.agent_llm_id)
|
||||
image_generation_config = await _get_image_gen_config_by_id(
|
||||
session, search_space.image_generation_config_id
|
||||
)
|
||||
vision_llm_config = await _get_vision_llm_config_by_id(
|
||||
session, search_space.vision_llm_config_id
|
||||
)
|
||||
|
||||
return LLMPreferencesRead(
|
||||
agent_llm_id=search_space.agent_llm_id,
|
||||
image_generation_config_id=search_space.image_generation_config_id,
|
||||
vision_llm_config_id=search_space.vision_llm_config_id,
|
||||
agent_llm=agent_llm,
|
||||
image_generation_config=image_generation_config,
|
||||
vision_llm_config=vision_llm_config,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get LLM preferences")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get LLM preferences: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put(
|
||||
"/search-spaces/{search_space_id}/llm-preferences",
|
||||
response_model=LLMPreferencesRead,
|
||||
)
|
||||
async def update_llm_preferences(
|
||||
search_space_id: int,
|
||||
preferences: LLMPreferencesUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Update LLM preferences (role assignments) for a search space.
|
||||
Requires LLM_CONFIGS_UPDATE permission.
|
||||
"""
|
||||
try:
|
||||
# Check permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.LLM_CONFIGS_UPDATE.value,
|
||||
"You don't have permission to update LLM preferences",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
# Update preferences
|
||||
update_data = preferences.model_dump(exclude_unset=True)
|
||||
previous_agent_llm_id = search_space.agent_llm_id
|
||||
for key, value in update_data.items():
|
||||
setattr(search_space, key, value)
|
||||
|
||||
agent_llm_changed = (
|
||||
"agent_llm_id" in update_data
|
||||
and update_data["agent_llm_id"] != previous_agent_llm_id
|
||||
)
|
||||
if agent_llm_changed:
|
||||
await session.execute(
|
||||
update(NewChatThread)
|
||||
.where(NewChatThread.search_space_id == search_space_id)
|
||||
.values(pinned_llm_config_id=None)
|
||||
)
|
||||
logger.info(
|
||||
"Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)",
|
||||
search_space_id,
|
||||
previous_agent_llm_id,
|
||||
update_data["agent_llm_id"],
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(search_space)
|
||||
|
||||
# Get full config objects for response
|
||||
agent_llm = await _get_llm_config_by_id(session, search_space.agent_llm_id)
|
||||
image_generation_config = await _get_image_gen_config_by_id(
|
||||
session, search_space.image_generation_config_id
|
||||
)
|
||||
vision_llm_config = await _get_vision_llm_config_by_id(
|
||||
session, search_space.vision_llm_config_id
|
||||
)
|
||||
|
||||
return LLMPreferencesRead(
|
||||
agent_llm_id=search_space.agent_llm_id,
|
||||
image_generation_config_id=search_space.image_generation_config_id,
|
||||
vision_llm_config_id=search_space.vision_llm_config_id,
|
||||
agent_llm=agent_llm,
|
||||
image_generation_config=image_generation_config,
|
||||
vision_llm_config=vision_llm_config,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to update LLM preferences")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update LLM preferences: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/searchspaces/{search_space_id}/snapshots")
|
||||
async def list_search_space_snapshots(
|
||||
search_space_id: int,
|
||||
|
|
|
|||
|
|
@ -1,304 +0,0 @@
|
|||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
Permission,
|
||||
User,
|
||||
VisionLLMConfig,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas import (
|
||||
GlobalVisionLLMConfigRead,
|
||||
VisionLLMConfigCreate,
|
||||
VisionLLMConfigRead,
|
||||
VisionLLMConfigUpdate,
|
||||
)
|
||||
from app.services.vision_model_list_service import get_vision_model_list
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Vision Model Catalogue (from OpenRouter, filtered for image-input models)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class VisionModelListItem(BaseModel):
|
||||
value: str
|
||||
label: str
|
||||
provider: str
|
||||
context_window: str | None = None
|
||||
|
||||
|
||||
@router.get("/vision-models", response_model=list[VisionModelListItem])
|
||||
async def list_vision_models(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Return vision-capable models sourced from OpenRouter (filtered by image input)."""
|
||||
try:
|
||||
return await get_vision_model_list()
|
||||
except Exception as e:
|
||||
logger.exception("Failed to fetch vision model list")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch vision model list: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Vision LLM Configs (from YAML)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/global-vision-llm-configs",
|
||||
response_model=list[GlobalVisionLLMConfigRead],
|
||||
)
|
||||
async def get_global_vision_llm_configs(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
global_configs = config.GLOBAL_VISION_LLM_CONFIGS
|
||||
safe_configs = []
|
||||
|
||||
if global_configs and len(global_configs) > 0:
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": 0,
|
||||
"name": "Auto (Fastest)",
|
||||
"description": "Automatically routes across available vision LLM providers.",
|
||||
"provider": "AUTO",
|
||||
"custom_provider": None,
|
||||
"model_name": "auto",
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
# Auto mode treated as free until per-deployment billing-tier
|
||||
# surfacing lands; see ``get_vision_llm`` for parity.
|
||||
"billing_tier": "free",
|
||||
"is_premium": False,
|
||||
}
|
||||
)
|
||||
|
||||
for cfg in global_configs:
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": billing_tier,
|
||||
# Mirror chat (``new_llm_config_routes``) so the new-chat
|
||||
# selector's premium badge logic keys off the same
|
||||
# field across chat / image / vision tabs.
|
||||
"is_premium": billing_tier == "premium",
|
||||
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
||||
"input_cost_per_token": cfg.get("input_cost_per_token"),
|
||||
"output_cost_per_token": cfg.get("output_cost_per_token"),
|
||||
}
|
||||
)
|
||||
|
||||
return safe_configs
|
||||
except Exception as e:
|
||||
logger.exception("Failed to fetch global vision LLM configs")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configs: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# VisionLLMConfig CRUD
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/vision-llm-configs", response_model=VisionLLMConfigRead)
|
||||
async def create_vision_llm_config(
|
||||
config_data: VisionLLMConfigCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config_data.search_space_id,
|
||||
Permission.VISION_CONFIGS_CREATE.value,
|
||||
"You don't have permission to create vision LLM configs in this search space",
|
||||
)
|
||||
|
||||
db_config = VisionLLMConfig(**config_data.model_dump(), user_id=user.id)
|
||||
session.add(db_config)
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to create VisionLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/vision-llm-configs", response_model=list[VisionLLMConfigRead])
|
||||
async def list_vision_llm_configs(
|
||||
search_space_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.VISION_CONFIGS_READ.value,
|
||||
"You don't have permission to view vision LLM configs in this search space",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig)
|
||||
.filter(VisionLLMConfig.search_space_id == search_space_id)
|
||||
.order_by(VisionLLMConfig.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to list VisionLLMConfigs")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configs: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/vision-llm-configs/{config_id}", response_model=VisionLLMConfigRead)
|
||||
async def get_vision_llm_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.VISION_CONFIGS_READ.value,
|
||||
"You don't have permission to view vision LLM configs in this search space",
|
||||
)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get VisionLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/vision-llm-configs/{config_id}", response_model=VisionLLMConfigRead)
|
||||
async def update_vision_llm_config(
|
||||
config_id: int,
|
||||
update_data: VisionLLMConfigUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.VISION_CONFIGS_CREATE.value,
|
||||
"You don't have permission to update vision LLM configs in this search space",
|
||||
)
|
||||
|
||||
for key, value in update_data.model_dump(exclude_unset=True).items():
|
||||
setattr(db_config, key, value)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to update VisionLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/vision-llm-configs/{config_id}", response_model=dict)
|
||||
async def delete_vision_llm_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.VISION_CONFIGS_DELETE.value,
|
||||
"You don't have permission to delete vision LLM configs in this search space",
|
||||
)
|
||||
|
||||
await session.delete(db_config)
|
||||
await session.commit()
|
||||
return {
|
||||
"message": "Vision LLM config deleted successfully",
|
||||
"id": config_id,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to delete VisionLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to delete config: {e!s}"
|
||||
) from e
|
||||
|
|
@ -34,11 +34,6 @@ from .folders import (
|
|||
)
|
||||
from .google_drive import DriveItem, GoogleDriveIndexingOptions, GoogleDriveIndexRequest
|
||||
from .image_generation import (
|
||||
GlobalImageGenConfigRead,
|
||||
ImageGenerationConfigCreate,
|
||||
ImageGenerationConfigPublic,
|
||||
ImageGenerationConfigRead,
|
||||
ImageGenerationConfigUpdate,
|
||||
ImageGenerationCreate,
|
||||
ImageGenerationListRead,
|
||||
ImageGenerationRead,
|
||||
|
|
@ -74,16 +69,6 @@ from .new_chat import (
|
|||
ThreadListItem,
|
||||
ThreadListResponse,
|
||||
)
|
||||
from .new_llm_config import (
|
||||
DefaultSystemInstructionsResponse,
|
||||
GlobalNewLLMConfigRead,
|
||||
LLMPreferencesRead,
|
||||
LLMPreferencesUpdate,
|
||||
NewLLMConfigCreate,
|
||||
NewLLMConfigPublic,
|
||||
NewLLMConfigRead,
|
||||
NewLLMConfigUpdate,
|
||||
)
|
||||
from .rbac_schemas import (
|
||||
InviteAcceptRequest,
|
||||
InviteAcceptResponse,
|
||||
|
|
@ -142,14 +127,6 @@ from .video_presentations import (
|
|||
VideoPresentationRead,
|
||||
VideoPresentationUpdate,
|
||||
)
|
||||
from .vision_llm import (
|
||||
GlobalVisionLLMConfigRead,
|
||||
VisionLLMConfigCreate,
|
||||
VisionLLMConfigPublic,
|
||||
VisionLLMConfigRead,
|
||||
VisionLLMConfigUpdate,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Folder schemas
|
||||
"BulkDocumentMove",
|
||||
|
|
@ -169,7 +146,6 @@ __all__ = [
|
|||
"CreditPurchaseHistoryResponse",
|
||||
"CreditPurchaseRead",
|
||||
"CreditStripeStatusResponse",
|
||||
"DefaultSystemInstructionsResponse",
|
||||
# Document schemas
|
||||
"DocumentBase",
|
||||
"DocumentMove",
|
||||
|
|
@ -192,19 +168,10 @@ __all__ = [
|
|||
"FolderRead",
|
||||
"FolderReorder",
|
||||
"FolderUpdate",
|
||||
"GlobalImageGenConfigRead",
|
||||
"GlobalNewLLMConfigRead",
|
||||
# Vision LLM Config schemas
|
||||
"GlobalVisionLLMConfigRead",
|
||||
"GoogleDriveIndexRequest",
|
||||
"GoogleDriveIndexingOptions",
|
||||
# Base schemas
|
||||
"IDModel",
|
||||
# Image Generation Config schemas
|
||||
"ImageGenerationConfigCreate",
|
||||
"ImageGenerationConfigPublic",
|
||||
"ImageGenerationConfigRead",
|
||||
"ImageGenerationConfigUpdate",
|
||||
# Image Generation schemas
|
||||
"ImageGenerationCreate",
|
||||
"ImageGenerationListRead",
|
||||
|
|
@ -216,9 +183,6 @@ __all__ = [
|
|||
"InviteInfoResponse",
|
||||
"InviteRead",
|
||||
"InviteUpdate",
|
||||
# LLM Preferences schemas
|
||||
"LLMPreferencesRead",
|
||||
"LLMPreferencesUpdate",
|
||||
# Log schemas
|
||||
"LogBase",
|
||||
"LogCreate",
|
||||
|
|
@ -255,11 +219,6 @@ __all__ = [
|
|||
"NewChatThreadRead",
|
||||
"NewChatThreadUpdate",
|
||||
"NewChatThreadWithMessages",
|
||||
# NewLLMConfig schemas
|
||||
"NewLLMConfigCreate",
|
||||
"NewLLMConfigPublic",
|
||||
"NewLLMConfigRead",
|
||||
"NewLLMConfigUpdate",
|
||||
"PagePurchaseHistoryResponse",
|
||||
"PagePurchaseRead",
|
||||
"PaginatedResponse",
|
||||
|
|
@ -303,8 +262,4 @@ __all__ = [
|
|||
"VideoPresentationCreate",
|
||||
"VideoPresentationRead",
|
||||
"VideoPresentationUpdate",
|
||||
"VisionLLMConfigCreate",
|
||||
"VisionLLMConfigPublic",
|
||||
"VisionLLMConfigRead",
|
||||
"VisionLLMConfigUpdate",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,109 +1,10 @@
|
|||
"""
|
||||
Pydantic schemas for Image Generation configs and generation requests.
|
||||
"""Pydantic schemas for image generation requests/results."""
|
||||
|
||||
ImageGenerationConfig: CRUD schemas for user-created image gen model configs.
|
||||
ImageGeneration: Schemas for the actual image generation requests/results.
|
||||
GlobalImageGenConfigRead: Schema for admin-configured YAML configs.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.db import ImageGenProvider
|
||||
|
||||
# =============================================================================
|
||||
# ImageGenerationConfig CRUD Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ImageGenerationConfigBase(BaseModel):
|
||||
"""Base schema with fields for ImageGenerationConfig."""
|
||||
|
||||
name: str = Field(
|
||||
..., max_length=100, description="User-friendly name for the config"
|
||||
)
|
||||
description: str | None = Field(
|
||||
None, max_length=500, description="Optional description"
|
||||
)
|
||||
provider: ImageGenProvider = Field(
|
||||
...,
|
||||
description="Image generation provider (OpenAI, Azure, Google AI Studio, Vertex AI, Bedrock, Recraft, OpenRouter, Xinference, Nscale)",
|
||||
)
|
||||
custom_provider: str | None = Field(
|
||||
None, max_length=100, description="Custom provider name"
|
||||
)
|
||||
model_name: str = Field(
|
||||
..., max_length=100, description="Model name (e.g., dall-e-3, gpt-image-1)"
|
||||
)
|
||||
api_key: str = Field(..., description="API key for the provider")
|
||||
api_base: str | None = Field(
|
||||
None, max_length=500, description="Optional API base URL"
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
None,
|
||||
max_length=50,
|
||||
description="Azure-specific API version (e.g., '2024-02-15-preview')",
|
||||
)
|
||||
litellm_params: dict[str, Any] | None = Field(
|
||||
default=None, description="Additional LiteLLM parameters"
|
||||
)
|
||||
|
||||
|
||||
class ImageGenerationConfigCreate(ImageGenerationConfigBase):
|
||||
"""Schema for creating a new ImageGenerationConfig."""
|
||||
|
||||
search_space_id: int = Field(
|
||||
..., description="Search space ID to associate the config with"
|
||||
)
|
||||
|
||||
|
||||
class ImageGenerationConfigUpdate(BaseModel):
|
||||
"""Schema for updating an existing ImageGenerationConfig. All fields optional."""
|
||||
|
||||
name: str | None = Field(None, max_length=100)
|
||||
description: str | None = Field(None, max_length=500)
|
||||
provider: ImageGenProvider | None = None
|
||||
custom_provider: str | None = Field(None, max_length=100)
|
||||
model_name: str | None = Field(None, max_length=100)
|
||||
api_key: str | None = None
|
||||
api_base: str | None = Field(None, max_length=500)
|
||||
api_version: str | None = Field(None, max_length=50)
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ImageGenerationConfigRead(ImageGenerationConfigBase):
|
||||
"""Schema for reading an ImageGenerationConfig (includes id and timestamps)."""
|
||||
|
||||
id: int
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ImageGenerationConfigPublic(BaseModel):
|
||||
"""Public schema that hides the API key (for list views)."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
description: str | None = None
|
||||
provider: ImageGenProvider
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ImageGeneration (request/result) Schemas
|
||||
# =============================================================================
|
||||
|
|
@ -136,12 +37,12 @@ class ImageGenerationCreate(BaseModel):
|
|||
search_space_id: int = Field(
|
||||
..., description="Search space ID to associate the generation with"
|
||||
)
|
||||
image_generation_config_id: int | None = Field(
|
||||
image_gen_model_id: int | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Image generation config ID. "
|
||||
"0 = Auto mode (router), negative = global YAML config, positive = DB config. "
|
||||
"If not provided, uses the search space's image_generation_config_id preference."
|
||||
"Image generation model ID. "
|
||||
"0 = Auto mode, negative = GLOBAL model, positive = BYOK Model row. "
|
||||
"If not provided, uses the search space's image_gen_model_id preference."
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -157,7 +58,7 @@ class ImageGenerationRead(BaseModel):
|
|||
size: str | None = None
|
||||
style: str | None = None
|
||||
response_format: str | None = None
|
||||
image_generation_config_id: int | None = None
|
||||
image_gen_model_id: int | None = None
|
||||
response_data: dict[str, Any] | None = None
|
||||
error_message: str | None = None
|
||||
search_space_id: int
|
||||
|
|
@ -204,57 +105,3 @@ class ImageGenerationListRead(BaseModel):
|
|||
image_count=image_count,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Image Gen Config (from YAML)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class GlobalImageGenConfigRead(BaseModel):
|
||||
"""
|
||||
Schema for reading global image generation configs from YAML.
|
||||
Global configs have negative IDs. API key is hidden.
|
||||
ID 0 is reserved for Auto mode (LiteLLM Router load balancing).
|
||||
|
||||
The ``billing_tier`` field allows the frontend to show a Premium/Free
|
||||
badge and (more importantly) tells the backend whether to debit the
|
||||
user's premium credit pool when this config is used. ``"free"`` is
|
||||
the default for backward compatibility — admins must explicitly opt
|
||||
a global config into ``"premium"``.
|
||||
"""
|
||||
|
||||
id: int = Field(
|
||||
...,
|
||||
description="Config ID: 0 for Auto mode, negative for global configs",
|
||||
)
|
||||
name: str
|
||||
description: str | None = None
|
||||
provider: str
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
is_global: bool = True
|
||||
is_auto_mode: bool = False
|
||||
billing_tier: str = Field(
|
||||
default="free",
|
||||
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
|
||||
)
|
||||
is_premium: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Convenience boolean derived server-side from "
|
||||
"``billing_tier == 'premium'``. The new-chat model selector "
|
||||
"keys its Free/Premium badge off this field for parity with "
|
||||
"chat (`GlobalLLMConfigRead.is_premium`)."
|
||||
),
|
||||
)
|
||||
quota_reserve_micros: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional override for the reservation amount (in micro-USD) used when "
|
||||
"this image generation is premium. Falls back to "
|
||||
"QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted."
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,256 +0,0 @@
|
|||
"""
|
||||
Pydantic schemas for the NewLLMConfig API.
|
||||
|
||||
NewLLMConfig combines model settings with prompt configuration:
|
||||
- LLM provider, model, API key, etc.
|
||||
- Configurable system instructions
|
||||
- Citation toggle
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.db import LiteLLMProvider
|
||||
|
||||
|
||||
class NewLLMConfigBase(BaseModel):
|
||||
"""Base schema with common fields for NewLLMConfig."""
|
||||
|
||||
name: str = Field(
|
||||
..., max_length=100, description="User-friendly name for the configuration"
|
||||
)
|
||||
description: str | None = Field(
|
||||
None, max_length=500, description="Optional description"
|
||||
)
|
||||
|
||||
# Model Configuration
|
||||
provider: LiteLLMProvider = Field(..., description="LiteLLM provider type")
|
||||
custom_provider: str | None = Field(
|
||||
None, max_length=100, description="Custom provider name when provider is CUSTOM"
|
||||
)
|
||||
model_name: str = Field(
|
||||
..., max_length=100, description="Model name without provider prefix"
|
||||
)
|
||||
api_key: str = Field(..., description="API key for the provider")
|
||||
api_base: str | None = Field(
|
||||
None, max_length=500, description="Optional API base URL"
|
||||
)
|
||||
litellm_params: dict[str, Any] | None = Field(
|
||||
default=None, description="Additional LiteLLM parameters"
|
||||
)
|
||||
|
||||
# Prompt Configuration
|
||||
system_instructions: str = Field(
|
||||
default="",
|
||||
description="Custom system instructions. Empty string uses default SURFSENSE_SYSTEM_INSTRUCTIONS.",
|
||||
)
|
||||
use_default_system_instructions: bool = Field(
|
||||
default=True,
|
||||
description="Whether to use default instructions when system_instructions is empty",
|
||||
)
|
||||
citations_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether to include citation instructions in the system prompt",
|
||||
)
|
||||
|
||||
|
||||
class NewLLMConfigCreate(NewLLMConfigBase):
|
||||
"""Schema for creating a new NewLLMConfig."""
|
||||
|
||||
search_space_id: int = Field(
|
||||
..., description="Search space ID to associate the config with"
|
||||
)
|
||||
|
||||
|
||||
class NewLLMConfigUpdate(BaseModel):
|
||||
"""Schema for updating an existing NewLLMConfig. All fields are optional."""
|
||||
|
||||
name: str | None = Field(None, max_length=100)
|
||||
description: str | None = Field(None, max_length=500)
|
||||
|
||||
# Model Configuration
|
||||
provider: LiteLLMProvider | None = None
|
||||
custom_provider: str | None = Field(None, max_length=100)
|
||||
model_name: str | None = Field(None, max_length=100)
|
||||
api_key: str | None = None
|
||||
api_base: str | None = Field(None, max_length=500)
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
|
||||
# Prompt Configuration
|
||||
system_instructions: str | None = None
|
||||
use_default_system_instructions: bool | None = None
|
||||
citations_enabled: bool | None = None
|
||||
|
||||
|
||||
class NewLLMConfigRead(NewLLMConfigBase):
|
||||
"""Schema for reading a NewLLMConfig (includes id and timestamps)."""
|
||||
|
||||
id: int
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
# Capability flag derived at the API boundary (no DB column). Default
|
||||
# True matches the conservative-allow stance — a BYOK row that the
|
||||
# route forgot to augment is not pre-judged. The streaming-task
|
||||
# safety net is the only place a False actually blocks a request.
|
||||
supports_image_input: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"Whether the BYOK chat config can accept image inputs. Derived "
|
||||
"at the route boundary from LiteLLM's authoritative model map "
|
||||
"(``litellm.supports_vision``) — there is no DB column. "
|
||||
"Default True is the conservative-allow stance for unknown / "
|
||||
"unmapped models."
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class NewLLMConfigPublic(BaseModel):
|
||||
"""
|
||||
Public schema for NewLLMConfig that hides the API key.
|
||||
Used when returning configs in list views or to users who shouldn't see keys.
|
||||
"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
||||
# Model Configuration (no api_key)
|
||||
provider: LiteLLMProvider
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
|
||||
# Prompt Configuration
|
||||
system_instructions: str
|
||||
use_default_system_instructions: bool
|
||||
citations_enabled: bool
|
||||
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
# Capability flag derived at the API boundary (see NewLLMConfigRead).
|
||||
supports_image_input: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"Whether the BYOK chat config can accept image inputs. Derived "
|
||||
"at the route boundary from LiteLLM's authoritative model map. "
|
||||
"Default True is the conservative-allow stance."
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class DefaultSystemInstructionsResponse(BaseModel):
|
||||
"""Response schema for getting default system instructions."""
|
||||
|
||||
default_system_instructions: str = Field(
|
||||
..., description="The default SURFSENSE_SYSTEM_INSTRUCTIONS template"
|
||||
)
|
||||
|
||||
|
||||
class GlobalNewLLMConfigRead(BaseModel):
|
||||
"""
|
||||
Schema for reading global LLM configs from YAML.
|
||||
Global configs have negative IDs and no search_space_id.
|
||||
API key is hidden for security.
|
||||
|
||||
ID 0 is reserved for Auto mode which uses LiteLLM Router for load balancing.
|
||||
"""
|
||||
|
||||
id: int = Field(
|
||||
...,
|
||||
description="Config ID: 0 for Auto mode, negative for global configs",
|
||||
)
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
||||
# Model Configuration (no api_key)
|
||||
provider: str # String because YAML doesn't enforce enum, "AUTO" for Auto mode
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
|
||||
# Prompt Configuration
|
||||
system_instructions: str = ""
|
||||
use_default_system_instructions: bool = True
|
||||
citations_enabled: bool = True
|
||||
|
||||
is_global: bool = True # Always true for global configs
|
||||
is_auto_mode: bool = False # True only for Auto mode (ID 0)
|
||||
|
||||
billing_tier: str = "free"
|
||||
is_premium: bool = False
|
||||
anonymous_enabled: bool = False
|
||||
seo_enabled: bool = False
|
||||
seo_slug: str | None = None
|
||||
seo_title: str | None = None
|
||||
seo_description: str | None = None
|
||||
quota_reserve_tokens: int | None = None
|
||||
supports_image_input: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"Whether the model accepts image inputs (multimodal vision). "
|
||||
"Derived server-side: OpenRouter dynamic configs use "
|
||||
"``architecture.input_modalities``; YAML / BYOK use LiteLLM's "
|
||||
"authoritative model map (``litellm.supports_vision``). The "
|
||||
"new-chat selector hints with a 'No image' badge when this is "
|
||||
"False and there are pending image attachments. The streaming "
|
||||
"task fails fast only when LiteLLM *explicitly* marks a model "
|
||||
"as text-only — unknown / unmapped models default-allow."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LLM Preferences Schemas (for role assignments)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class LLMPreferencesRead(BaseModel):
|
||||
"""Schema for reading LLM preferences (role assignments) for a search space."""
|
||||
|
||||
agent_llm_id: int | None = Field(
|
||||
None, description="ID of the LLM config to use for agent/chat tasks"
|
||||
)
|
||||
image_generation_config_id: int | None = Field(
|
||||
None, description="ID of the image generation config to use"
|
||||
)
|
||||
vision_llm_config_id: int | None = Field(
|
||||
None,
|
||||
description="ID of the vision LLM config to use for vision/screenshot analysis",
|
||||
)
|
||||
agent_llm: dict[str, Any] | None = Field(
|
||||
None, description="Full config for chat model"
|
||||
)
|
||||
image_generation_config: dict[str, Any] | None = Field(
|
||||
None, description="Full config for image generation"
|
||||
)
|
||||
vision_llm_config: dict[str, Any] | None = Field(
|
||||
None, description="Full config for vision LLM"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class LLMPreferencesUpdate(BaseModel):
|
||||
"""Schema for updating LLM preferences."""
|
||||
|
||||
agent_llm_id: int | None = Field(
|
||||
None, description="ID of the LLM config to use for agent/chat tasks"
|
||||
)
|
||||
image_generation_config_id: int | None = Field(
|
||||
None, description="ID of the image generation config to use"
|
||||
)
|
||||
vision_llm_config_id: int | None = Field(
|
||||
None,
|
||||
description="ID of the vision LLM config to use for vision/screenshot analysis",
|
||||
)
|
||||
|
|
@ -1,116 +0,0 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.db import VisionProvider
|
||||
|
||||
|
||||
class VisionLLMConfigBase(BaseModel):
|
||||
name: str = Field(..., max_length=100)
|
||||
description: str | None = Field(None, max_length=500)
|
||||
provider: VisionProvider = Field(...)
|
||||
custom_provider: str | None = Field(None, max_length=100)
|
||||
model_name: str = Field(..., max_length=100)
|
||||
api_key: str = Field(...)
|
||||
api_base: str | None = Field(None, max_length=500)
|
||||
api_version: str | None = Field(None, max_length=50)
|
||||
litellm_params: dict[str, Any] | None = Field(default=None)
|
||||
|
||||
|
||||
class VisionLLMConfigCreate(VisionLLMConfigBase):
|
||||
search_space_id: int = Field(...)
|
||||
|
||||
|
||||
class VisionLLMConfigUpdate(BaseModel):
|
||||
name: str | None = Field(None, max_length=100)
|
||||
description: str | None = Field(None, max_length=500)
|
||||
provider: VisionProvider | None = None
|
||||
custom_provider: str | None = Field(None, max_length=100)
|
||||
model_name: str | None = Field(None, max_length=100)
|
||||
api_key: str | None = None
|
||||
api_base: str | None = Field(None, max_length=500)
|
||||
api_version: str | None = Field(None, max_length=50)
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class VisionLLMConfigRead(VisionLLMConfigBase):
|
||||
id: int
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class VisionLLMConfigPublic(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str | None = None
|
||||
provider: VisionProvider
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class GlobalVisionLLMConfigRead(BaseModel):
|
||||
"""Schema for reading global vision LLM configs from YAML.
|
||||
|
||||
The ``billing_tier`` field allows the frontend to show a Premium/Free
|
||||
badge and (more importantly) tells the backend whether to debit the
|
||||
user's premium credit pool when this config is used. ``"free"`` is
|
||||
the default for backward compatibility — admins must explicitly opt
|
||||
a global config into ``"premium"``.
|
||||
"""
|
||||
|
||||
id: int = Field(...)
|
||||
name: str
|
||||
description: str | None = None
|
||||
provider: str
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
is_global: bool = True
|
||||
is_auto_mode: bool = False
|
||||
billing_tier: str = Field(
|
||||
default="free",
|
||||
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
|
||||
)
|
||||
is_premium: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Convenience boolean derived server-side from "
|
||||
"``billing_tier == 'premium'``. The new-chat model selector "
|
||||
"keys its Free/Premium badge off this field for parity with "
|
||||
"chat (`GlobalLLMConfigRead.is_premium`)."
|
||||
),
|
||||
)
|
||||
quota_reserve_tokens: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional override for the per-call reservation in *tokens* — "
|
||||
"converted to micro-USD via the model's input/output prices at "
|
||||
"reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS."
|
||||
),
|
||||
)
|
||||
input_cost_per_token: float | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional input price in USD/token. Used by pricing_registration to "
|
||||
"register custom Azure / OpenRouter aliases with LiteLLM at startup."
|
||||
),
|
||||
)
|
||||
output_cost_per_token: float | None = Field(
|
||||
default=None,
|
||||
description="Optional output price in USD/token. Pair with input_cost_per_token.",
|
||||
)
|
||||
|
|
@ -1,13 +1,13 @@
|
|||
"""Resolve and persist Auto (Fastest) model pins per chat thread.
|
||||
"""Resolve and persist Auto model pins per chat thread.
|
||||
|
||||
Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we
|
||||
resolve that virtual mode to one concrete global LLM config exactly once and
|
||||
Auto is represented by ``chat_model_id == 0``. For chat threads we
|
||||
resolve that virtual mode to one concrete global model exactly once and
|
||||
persist the chosen config id on ``new_chat_threads.pinned_llm_config_id`` so
|
||||
subsequent turns are stable.
|
||||
|
||||
Single-writer invariant: this module is the only writer of
|
||||
``NewChatThread.pinned_llm_config_id`` (aside from the bulk clear in
|
||||
``search_spaces_routes`` when a search space's ``agent_llm_id`` changes).
|
||||
``model_connections_routes`` when a search space's ``chat_model_id`` changes).
|
||||
Therefore a non-NULL value unambiguously means "this thread has an
|
||||
Auto-resolved pin"; no separate source/policy column is needed.
|
||||
"""
|
||||
|
|
@ -33,8 +33,10 @@ from app.services.token_quota_service import TokenQuotaService
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AUTO_FASTEST_ID = 0
|
||||
AUTO_FASTEST_MODE = "auto_fastest"
|
||||
AUTO_MODE_ID = 0
|
||||
# Stable internal hash namespace for deterministic per-thread selection.
|
||||
# Do not rename: changing this rebalances Auto's model choice for new pins.
|
||||
AUTO_PIN_HASH_NAMESPACE = "auto_fastest"
|
||||
_RUNTIME_COOLDOWN_SECONDS = 600
|
||||
_HEALTHY_TTL_SECONDS = 45
|
||||
|
||||
|
|
@ -383,7 +385,7 @@ def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
|
|||
pool = tier_a if tier_a else eligible
|
||||
pool = sorted(pool, key=lambda c: -int(c.get("quality_score") or 0))
|
||||
top_k = pool[:_QUALITY_TOP_K]
|
||||
digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest()
|
||||
digest = hashlib.sha256(f"{AUTO_PIN_HASH_NAMESPACE}:{thread_id}".encode()).digest()
|
||||
idx = int.from_bytes(digest[:8], "big") % len(top_k)
|
||||
return top_k[idx], len(top_k)
|
||||
|
||||
|
|
@ -425,7 +427,7 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
exclude_config_ids: set[int] | None = None,
|
||||
requires_image_input: bool = False,
|
||||
) -> AutoPinResolution:
|
||||
"""Resolve Auto (Fastest) to one concrete config id and persist the pin.
|
||||
"""Resolve Auto to one concrete config id and persist the pin.
|
||||
|
||||
For non-auto selections, this function clears any existing pin and returns
|
||||
the selected id as-is.
|
||||
|
|
@ -457,7 +459,7 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
)
|
||||
|
||||
# Explicit model selected: clear any stale pin.
|
||||
if selected_llm_config_id != AUTO_FASTEST_ID:
|
||||
if selected_llm_config_id != AUTO_MODE_ID:
|
||||
if thread.pinned_llm_config_id is not None:
|
||||
thread.pinned_llm_config_id = None
|
||||
await session.commit()
|
||||
|
|
|
|||
|
|
@ -450,10 +450,10 @@ async def _resolve_agent_billing_for_search_space(
|
|||
Used by Celery tasks (podcast generation, video presentation) to bill the
|
||||
search-space owner's premium credit pool when the chat model is premium.
|
||||
|
||||
Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``:
|
||||
Resolution rules mirror the chat model role resolver:
|
||||
|
||||
- Search space not found / no ``agent_llm_id``: raise ``ValueError``.
|
||||
- **Auto mode** (``id == AUTO_FASTEST_ID == 0``):
|
||||
- Search space not found / no ``chat_model_id``: raise ``ValueError``.
|
||||
- **Auto mode** (``id == AUTO_MODE_ID == 0``):
|
||||
* ``thread_id`` is set: delegate to
|
||||
``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and
|
||||
recurse into the resolved id. Reuses chat's existing pin if present
|
||||
|
|
@ -469,9 +469,8 @@ async def _resolve_agent_billing_for_search_space(
|
|||
(defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault),
|
||||
``base_model = litellm_params.get("base_model") or model_name`` —
|
||||
NOT provider-prefixed, matching chat's cost-map lookup convention.
|
||||
- **Positive id** (user BYOK ``NewLLMConfig``): always free (matches
|
||||
``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``);
|
||||
``base_model`` from ``litellm_params`` or ``model_name``.
|
||||
- **Positive id** (user BYOK ``Model``): always free; ``base_model`` from
|
||||
the model catalog override or the upstream ``model_id``.
|
||||
|
||||
Note on imports: ``llm_service``, ``auto_model_pin_service``, and
|
||||
``llm_router_service`` are imported lazily inside the function body to
|
||||
|
|
@ -480,8 +479,9 @@ async def _resolve_agent_billing_for_search_space(
|
|||
``billable_calls.py``'s module load path.
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.db import NewLLMConfig, SearchSpace
|
||||
from app.db import Model, SearchSpace
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
|
|
@ -490,20 +490,20 @@ async def _resolve_agent_billing_for_search_space(
|
|||
if search_space is None:
|
||||
raise ValueError(f"Search space {search_space_id} not found")
|
||||
|
||||
agent_llm_id = search_space.agent_llm_id
|
||||
if agent_llm_id is None:
|
||||
chat_model_id = search_space.chat_model_id
|
||||
if chat_model_id is None:
|
||||
raise ValueError(
|
||||
f"Search space {search_space_id} has no agent_llm_id configured"
|
||||
f"Search space {search_space_id} has no chat_model_id configured"
|
||||
)
|
||||
|
||||
owner_user_id: UUID = search_space.user_id
|
||||
|
||||
from app.services.auto_model_pin_service import (
|
||||
AUTO_FASTEST_ID,
|
||||
AUTO_MODE_ID,
|
||||
resolve_or_get_pinned_llm_config_id,
|
||||
)
|
||||
|
||||
if agent_llm_id == AUTO_FASTEST_ID:
|
||||
if chat_model_id == AUTO_MODE_ID:
|
||||
if thread_id is None:
|
||||
return owner_user_id, "free", "auto"
|
||||
try:
|
||||
|
|
@ -512,7 +512,7 @@ async def _resolve_agent_billing_for_search_space(
|
|||
thread_id=thread_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=str(owner_user_id),
|
||||
selected_llm_config_id=AUTO_FASTEST_ID,
|
||||
selected_llm_config_id=AUTO_MODE_ID,
|
||||
)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
|
|
@ -523,28 +523,35 @@ async def _resolve_agent_billing_for_search_space(
|
|||
exc_info=True,
|
||||
)
|
||||
return owner_user_id, "free", "auto"
|
||||
agent_llm_id = resolution.resolved_llm_config_id
|
||||
chat_model_id = resolution.resolved_llm_config_id
|
||||
|
||||
if agent_llm_id < 0:
|
||||
if chat_model_id < 0:
|
||||
from app.services.llm_service import get_global_llm_config
|
||||
|
||||
cfg = get_global_llm_config(agent_llm_id) or {}
|
||||
cfg = get_global_llm_config(chat_model_id) or {}
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
litellm_params = cfg.get("litellm_params") or {}
|
||||
base_model = litellm_params.get("base_model") or cfg.get("model_name") or ""
|
||||
return owner_user_id, billing_tier, base_model
|
||||
|
||||
nlc_result = await session.execute(
|
||||
select(NewLLMConfig).where(
|
||||
NewLLMConfig.id == agent_llm_id,
|
||||
NewLLMConfig.search_space_id == search_space_id,
|
||||
)
|
||||
model_result = await session.execute(
|
||||
select(Model)
|
||||
.options(selectinload(Model.connection))
|
||||
.where(Model.id == chat_model_id, Model.enabled.is_(True))
|
||||
)
|
||||
nlc = nlc_result.scalars().first()
|
||||
model = model_result.scalars().first()
|
||||
base_model = ""
|
||||
if nlc is not None:
|
||||
litellm_params = nlc.litellm_params or {}
|
||||
base_model = litellm_params.get("base_model") or nlc.model_name or ""
|
||||
if (
|
||||
model is not None
|
||||
and model.connection is not None
|
||||
and model.connection.enabled
|
||||
and (
|
||||
model.connection.search_space_id in (None, search_space_id)
|
||||
and model.connection.user_id in (None, owner_user_id)
|
||||
)
|
||||
):
|
||||
catalog = model.catalog or {}
|
||||
base_model = catalog.get("base_model") or model.model_id or ""
|
||||
return owner_user_id, "free", base_model
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,11 @@ from app.services.auto_model_pin_service import (
|
|||
auto_model_candidates,
|
||||
choose_auto_model_candidate,
|
||||
)
|
||||
from app.services.llm_router_service import AUTO_MODE_ID, ChatLiteLLMRouter, is_auto_mode
|
||||
from app.services.llm_router_service import (
|
||||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
is_auto_mode,
|
||||
)
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.model_resolver import native_connection_from_config, to_litellm
|
||||
from app.services.token_tracking_service import token_tracker
|
||||
|
|
@ -96,26 +100,16 @@ class LLMRole:
|
|||
def get_global_llm_config(llm_config_id: int) -> dict | None:
|
||||
"""
|
||||
Get a global LLM configuration by ID.
|
||||
Global configs have negative IDs. ID 0 is reserved for Auto mode.
|
||||
Global configs have negative IDs. Auto mode (ID 0) is resolved through the
|
||||
model-candidate pipeline, not this legacy config lookup.
|
||||
|
||||
Args:
|
||||
llm_config_id: The ID of the global config (should be negative or 0 for Auto)
|
||||
llm_config_id: The ID of the global config (must be negative)
|
||||
|
||||
Returns:
|
||||
dict: Global config dictionary or None if not found
|
||||
"""
|
||||
# Auto mode (ID 0) is handled separately via the router
|
||||
if llm_config_id == AUTO_MODE_ID:
|
||||
return {
|
||||
"id": AUTO_MODE_ID,
|
||||
"name": "Auto (Fastest)",
|
||||
"description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling",
|
||||
"provider": "AUTO",
|
||||
"model_name": "auto",
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
|
||||
if llm_config_id > 0:
|
||||
if llm_config_id >= 0:
|
||||
return None
|
||||
|
||||
for cfg in config.GLOBAL_LLM_CONFIGS:
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ CACHE_TTL_SECONDS = 86400 # 24 hours
|
|||
_cache: list[dict] | None = None
|
||||
_cache_timestamp: float = 0
|
||||
|
||||
# Maps OpenRouter provider slug → our LiteLLMProvider enum value.
|
||||
# Maps OpenRouter provider slug to native LiteLLM provider prefixes.
|
||||
# Only providers where the model-name part (after the slash) can be
|
||||
# used directly with the native provider's litellm prefix are listed.
|
||||
#
|
||||
|
|
|
|||
|
|
@ -281,7 +281,7 @@ def _generate_configs(
|
|||
|
||||
OpenRouter's own ``openrouter/free`` meta-router is filtered out upstream
|
||||
via ``_EXCLUDED_MODEL_IDS``; we don't expose a redundant auto-select layer
|
||||
because our own Auto (Fastest) pin + 24 h refresh + repair logic already
|
||||
because our own Auto pin + 24 h refresh + repair logic already
|
||||
cover the catalogue-churn case.
|
||||
"""
|
||||
id_offset: int = settings.get("id_offset", -10000)
|
||||
|
|
@ -346,7 +346,7 @@ def _generate_configs(
|
|||
# ``"No endpoints found that support image input"``.
|
||||
"supports_image_input": bool(normalized.get("supports_image_input")),
|
||||
_OPENROUTER_DYNAMIC_MARKER: True,
|
||||
# Auto (Fastest) ranking metadata. ``quality_score`` is initialised
|
||||
# Auto ranking metadata. ``quality_score`` is initialised
|
||||
# to the static score and gets re-blended with health on the next
|
||||
# ``_enrich_health`` pass (synchronous on refresh, deferred on cold
|
||||
# start so startup latency is unchanged).
|
||||
|
|
@ -361,11 +361,7 @@ def _generate_configs(
|
|||
return configs
|
||||
|
||||
|
||||
# ID-offset bands used to keep dynamic OpenRouter configs in their own
|
||||
# namespace per surface. Image / vision get separate bands so a single
|
||||
# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to.
|
||||
_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000
|
||||
_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000
|
||||
|
||||
|
||||
def _generate_image_gen_configs(
|
||||
|
|
@ -431,89 +427,6 @@ def _generate_image_gen_configs(
|
|||
return configs
|
||||
|
||||
|
||||
def _generate_vision_llm_configs(
|
||||
raw_models: list[dict], settings: dict[str, Any]
|
||||
) -> list[dict]:
|
||||
"""Convert OpenRouter vision-capable LLMs into global vision-LLM config
|
||||
dicts (matches the YAML shape consumed by ``vision_llm_routes``).
|
||||
|
||||
Filter:
|
||||
- architecture.input_modalities contains "image"
|
||||
- architecture.output_modalities contains "text"
|
||||
- compatible provider (excluded slugs blocked)
|
||||
- allowed model id (excluded list blocked)
|
||||
|
||||
Vision-LLM is invoked from the indexer (image extraction during
|
||||
document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so
|
||||
the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context``
|
||||
filters do not apply: a small-context vision model that doesn't
|
||||
advertise tool-calling is still perfectly viable for "describe this
|
||||
image" prompts.
|
||||
"""
|
||||
id_offset: int = int(
|
||||
settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT
|
||||
)
|
||||
api_key: str = settings.get("api_key", "")
|
||||
rpm: int = settings.get("rpm", 200)
|
||||
tpm: int = settings.get("tpm", 1_000_000)
|
||||
free_rpm: int = settings.get("free_rpm", 20)
|
||||
free_tpm: int = settings.get("free_tpm", 100_000)
|
||||
quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
|
||||
litellm_params: dict = settings.get("litellm_params") or {}
|
||||
|
||||
vision_models = [
|
||||
m
|
||||
for m in raw_models
|
||||
if supports_image_input(m)
|
||||
and _shared_is_compatible_provider(m)
|
||||
and _shared_is_allowed_model(m)
|
||||
and "/" in m.get("id", "")
|
||||
]
|
||||
|
||||
configs: list[dict] = []
|
||||
taken: set[int] = set()
|
||||
for model in vision_models:
|
||||
model_id: str = model["id"]
|
||||
name: str = model.get("name", model_id)
|
||||
tier = _openrouter_tier(model)
|
||||
pricing = model.get("pricing") or {}
|
||||
|
||||
# Capture per-token prices so ``pricing_registration`` can
|
||||
# register them with LiteLLM at startup (and so the cost
|
||||
# estimator in ``estimate_call_reserve_micros`` can resolve
|
||||
# them at reserve time).
|
||||
try:
|
||||
input_cost = float(pricing.get("prompt", 0) or 0)
|
||||
except (TypeError, ValueError):
|
||||
input_cost = 0.0
|
||||
try:
|
||||
output_cost = float(pricing.get("completion", 0) or 0)
|
||||
except (TypeError, ValueError):
|
||||
output_cost = 0.0
|
||||
|
||||
cfg: dict[str, Any] = {
|
||||
"id": _stable_config_id(model_id, id_offset, taken),
|
||||
"name": name,
|
||||
"description": f"{name} via OpenRouter (vision)",
|
||||
"provider": "openrouter",
|
||||
"model_name": model_id,
|
||||
"api_key": api_key,
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"api_version": None,
|
||||
"rpm": free_rpm if tier == "free" else rpm,
|
||||
"tpm": free_tpm if tier == "free" else tpm,
|
||||
"litellm_params": dict(litellm_params),
|
||||
"billing_tier": tier,
|
||||
"quota_reserve_tokens": quota_reserve_tokens,
|
||||
"input_cost_per_token": input_cost or None,
|
||||
"output_cost_per_token": output_cost or None,
|
||||
_OPENROUTER_DYNAMIC_MARKER: True,
|
||||
}
|
||||
configs.append(cfg)
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
class OpenRouterIntegrationService:
|
||||
"""Singleton that manages the dynamic OpenRouter model catalogue."""
|
||||
|
||||
|
|
@ -724,7 +637,7 @@ class OpenRouterIntegrationService:
|
|||
return counts
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Auto (Fastest) health enrichment
|
||||
# Auto health enrichment
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _enrich_health_safely(
|
||||
|
|
|
|||
|
|
@ -154,10 +154,8 @@ def _register_chat_shape_configs(
|
|||
input_cost = _safe_float(entry.get("prompt"))
|
||||
output_cost = _safe_float(entry.get("completion"))
|
||||
else:
|
||||
# Vision configs from ``_generate_vision_llm_configs``
|
||||
# carry their pricing inline because the OpenRouter
|
||||
# raw-pricing cache is keyed by chat-catalogue model_id;
|
||||
# vision flows pick up the inline values here.
|
||||
# Some dynamically materialized configs can carry pricing
|
||||
# inline when the raw OpenRouter cache has no matching entry.
|
||||
input_cost = _safe_float(cfg.get("input_cost_per_token"))
|
||||
output_cost = _safe_float(cfg.get("output_cost_per_token"))
|
||||
if input_cost == 0.0 and output_cost == 0.0:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""Pure-function quality scoring for Auto (Fastest) model selection.
|
||||
"""Pure-function quality scoring for Auto model selection.
|
||||
|
||||
This module is import-free of any service / request-path dependencies. All
|
||||
numbers are computed once during the OpenRouter refresh tick (or YAML load)
|
||||
|
|
|
|||
|
|
@ -1,160 +0,0 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from litellm import Router
|
||||
|
||||
from app.services.model_resolver import native_connection_from_config, to_litellm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VISION_AUTO_MODE_ID = 0
|
||||
|
||||
class VisionLLMRouterService:
|
||||
_instance = None
|
||||
_router: Router | None = None
|
||||
_model_list: list[dict] = []
|
||||
_router_settings: dict = {}
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "VisionLLMRouterService":
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def initialize(
|
||||
cls,
|
||||
global_configs: list[dict],
|
||||
router_settings: dict | None = None,
|
||||
) -> None:
|
||||
instance = cls.get_instance()
|
||||
|
||||
if instance._initialized:
|
||||
logger.debug("Vision LLM Router already initialized, skipping")
|
||||
return
|
||||
|
||||
model_list = []
|
||||
for config in global_configs:
|
||||
deployment = cls._config_to_deployment(config)
|
||||
if deployment:
|
||||
model_list.append(deployment)
|
||||
|
||||
if not model_list:
|
||||
logger.warning(
|
||||
"No valid vision LLM configs found for router initialization"
|
||||
)
|
||||
return
|
||||
|
||||
instance._model_list = model_list
|
||||
instance._router_settings = router_settings or {}
|
||||
|
||||
default_settings = {
|
||||
"routing_strategy": "usage-based-routing",
|
||||
"num_retries": 3,
|
||||
"allowed_fails": 3,
|
||||
"cooldown_time": 60,
|
||||
"retry_after": 5,
|
||||
}
|
||||
|
||||
final_settings = {**default_settings, **instance._router_settings}
|
||||
|
||||
try:
|
||||
instance._router = Router(
|
||||
model_list=model_list,
|
||||
routing_strategy=final_settings.get(
|
||||
"routing_strategy", "usage-based-routing"
|
||||
),
|
||||
num_retries=final_settings.get("num_retries", 3),
|
||||
allowed_fails=final_settings.get("allowed_fails", 3),
|
||||
cooldown_time=final_settings.get("cooldown_time", 60),
|
||||
set_verbose=False,
|
||||
)
|
||||
instance._initialized = True
|
||||
logger.info(
|
||||
"Vision LLM Router initialized with %d deployments, strategy: %s",
|
||||
len(model_list),
|
||||
final_settings.get("routing_strategy"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Vision LLM Router: {e}")
|
||||
instance._router = None
|
||||
|
||||
@classmethod
|
||||
def _config_to_deployment(cls, config: dict) -> dict | None:
|
||||
try:
|
||||
if not config.get("model_name") or not config.get("api_key"):
|
||||
return None
|
||||
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
native_connection_from_config(config),
|
||||
config["model_name"],
|
||||
)
|
||||
litellm_params: dict[str, Any] = {"model": model_string, **resolved_kwargs}
|
||||
|
||||
deployment: dict[str, Any] = {
|
||||
"model_name": "auto",
|
||||
"litellm_params": litellm_params,
|
||||
}
|
||||
|
||||
if config.get("rpm"):
|
||||
deployment["rpm"] = config["rpm"]
|
||||
if config.get("tpm"):
|
||||
deployment["tpm"] = config["tpm"]
|
||||
|
||||
return deployment
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to convert vision config to deployment: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_router(cls) -> Router | None:
|
||||
instance = cls.get_instance()
|
||||
return instance._router
|
||||
|
||||
@classmethod
|
||||
def is_initialized(cls) -> bool:
|
||||
instance = cls.get_instance()
|
||||
return instance._initialized and instance._router is not None
|
||||
|
||||
@classmethod
|
||||
def get_model_count(cls) -> int:
|
||||
instance = cls.get_instance()
|
||||
return len(instance._model_list)
|
||||
|
||||
|
||||
def is_vision_auto_mode(config_id: int | None) -> bool:
|
||||
return config_id == VISION_AUTO_MODE_ID
|
||||
|
||||
|
||||
def build_vision_model_string(
|
||||
litellm_provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
if custom_provider:
|
||||
return f"{custom_provider}/{model_name}"
|
||||
return f"{litellm_provider}/{model_name}"
|
||||
|
||||
|
||||
def get_global_vision_llm_config(config_id: int) -> dict | None:
|
||||
from app.config import config
|
||||
|
||||
if config_id == VISION_AUTO_MODE_ID:
|
||||
return {
|
||||
"id": VISION_AUTO_MODE_ID,
|
||||
"name": "Auto (Fastest)",
|
||||
"provider": "AUTO",
|
||||
"model_name": "auto",
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
if config_id > 0:
|
||||
return None
|
||||
for cfg in config.GLOBAL_VISION_LLM_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return cfg
|
||||
return None
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
"""
|
||||
Service for fetching and caching the vision-capable model list.
|
||||
|
||||
Reuses the same OpenRouter public API and local fallback as the LLM model
|
||||
list service, but filters for models that accept image input.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
|
||||
FALLBACK_FILE = (
|
||||
Path(__file__).parent.parent / "config" / "vision_model_list_fallback.json"
|
||||
)
|
||||
CACHE_TTL_SECONDS = 86400 # 24 hours
|
||||
|
||||
_cache: list[dict] | None = None
|
||||
_cache_timestamp: float = 0
|
||||
|
||||
OPENROUTER_SLUG_TO_VISION_PROVIDER: dict[str, str] = {
|
||||
"openai": "OPENAI",
|
||||
"anthropic": "ANTHROPIC",
|
||||
"google": "GOOGLE",
|
||||
"mistralai": "MISTRAL",
|
||||
"x-ai": "XAI",
|
||||
}
|
||||
|
||||
|
||||
def _format_context_length(length: int | None) -> str | None:
|
||||
if not length:
|
||||
return None
|
||||
if length >= 1_000_000:
|
||||
return f"{length / 1_000_000:g}M"
|
||||
if length >= 1_000:
|
||||
return f"{length / 1_000:g}K"
|
||||
return str(length)
|
||||
|
||||
|
||||
async def _fetch_from_openrouter() -> list[dict] | None:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=15) as client:
|
||||
response = await client.get(OPENROUTER_API_URL)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("data", [])
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch from OpenRouter API for vision models: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def _load_fallback() -> list[dict]:
|
||||
try:
|
||||
with open(FALLBACK_FILE, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error("Failed to load vision model fallback list: %s", e)
|
||||
return []
|
||||
|
||||
|
||||
def _is_vision_model(model: dict) -> bool:
|
||||
"""Return True if the model accepts image input and outputs text."""
|
||||
arch = model.get("architecture", {})
|
||||
input_mods = arch.get("input_modalities", [])
|
||||
output_mods = arch.get("output_modalities", [])
|
||||
return "image" in input_mods and "text" in output_mods
|
||||
|
||||
|
||||
def _process_vision_models(raw_models: list[dict]) -> list[dict]:
|
||||
processed: list[dict] = []
|
||||
|
||||
for model in raw_models:
|
||||
model_id: str = model.get("id", "")
|
||||
name: str = model.get("name", "")
|
||||
context_length = model.get("context_length")
|
||||
|
||||
if "/" not in model_id:
|
||||
continue
|
||||
|
||||
if not _is_vision_model(model):
|
||||
continue
|
||||
|
||||
provider_slug, model_name = model_id.split("/", 1)
|
||||
context_window = _format_context_length(context_length)
|
||||
|
||||
processed.append(
|
||||
{
|
||||
"value": model_id,
|
||||
"label": name,
|
||||
"provider": "OPENROUTER",
|
||||
"context_window": context_window,
|
||||
}
|
||||
)
|
||||
|
||||
direct_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug)
|
||||
if direct_provider:
|
||||
if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"):
|
||||
continue
|
||||
|
||||
processed.append(
|
||||
{
|
||||
"value": model_name,
|
||||
"label": name,
|
||||
"provider": direct_provider,
|
||||
"context_window": context_window,
|
||||
}
|
||||
)
|
||||
|
||||
return processed
|
||||
|
||||
|
||||
async def get_vision_model_list() -> list[dict]:
|
||||
global _cache, _cache_timestamp
|
||||
|
||||
if _cache is not None and (time.time() - _cache_timestamp) < CACHE_TTL_SECONDS:
|
||||
return _cache
|
||||
|
||||
raw_models = await _fetch_from_openrouter()
|
||||
|
||||
if raw_models is None:
|
||||
logger.info("Using fallback vision model list")
|
||||
return _load_fallback()
|
||||
|
||||
processed = _process_vision_models(raw_models)
|
||||
|
||||
_cache = processed
|
||||
_cache_timestamp = time.time()
|
||||
|
||||
return processed
|
||||
|
|
@ -330,31 +330,6 @@ async def probe_chat_configs(report: Report, *, live: bool) -> None:
|
|||
report.add(result)
|
||||
|
||||
|
||||
async def probe_vision_configs(report: Report, *, live: bool) -> None:
|
||||
print("\n[vision configs from global_vision_llm_configs (YAML-static)]")
|
||||
for cfg in config.GLOBAL_VISION_LLM_CONFIGS:
|
||||
if _is_or_dynamic(cfg):
|
||||
continue
|
||||
result = ProbeResult(
|
||||
label=str(cfg.get("name") or cfg.get("model_name")),
|
||||
surface="vision",
|
||||
config_id=cfg.get("id"),
|
||||
)
|
||||
# For vision configs, capability is implied — they're in the
|
||||
# dedicated vision pool. Run the same resolver to flag any
|
||||
# surprise disagreement.
|
||||
cap_ok, cap_note = _probe_chat_capability(cfg)
|
||||
result.capability_ok = cap_ok
|
||||
result.capability_note = cap_note
|
||||
if live:
|
||||
t0 = time.perf_counter()
|
||||
ok, note = await _live_chat_image_call(cfg)
|
||||
result.live_ok = ok
|
||||
result.live_note = note
|
||||
result.duration_s = time.perf_counter() - t0
|
||||
report.add(result)
|
||||
|
||||
|
||||
async def probe_image_gen_configs(report: Report, *, live: bool) -> None:
|
||||
print(
|
||||
"\n[image generation configs from global_image_generation_configs (YAML-static)]"
|
||||
|
|
@ -380,7 +355,7 @@ async def probe_image_gen_configs(report: Report, *, live: bool) -> None:
|
|||
|
||||
|
||||
async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
|
||||
"""Sample one chat (vision-capable), one vision, one image-gen model
|
||||
"""Sample chat/vision-capable and image-gen models
|
||||
from the live OpenRouter catalogue. Doesn't iterate the full pool
|
||||
(would be hundreds of probes); just validates the integration end-
|
||||
to-end on a representative model from each surface."""
|
||||
|
|
@ -405,9 +380,6 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
|
|||
for c in config.GLOBAL_LLM_CONFIGS
|
||||
if c.get("provider") == "OPENROUTER" and c.get("supports_image_input")
|
||||
]
|
||||
or_vision = [
|
||||
c for c in config.GLOBAL_VISION_LLM_CONFIGS if c.get("provider") == "OPENROUTER"
|
||||
]
|
||||
or_image_gen = [
|
||||
c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if c.get("provider") == "OPENROUTER"
|
||||
]
|
||||
|
|
@ -427,11 +399,6 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
|
|||
("or-chat", _pick_first(or_chat, "anthropic/claude")),
|
||||
("or-chat", _pick_first(or_chat, "google/gemini-2.5-flash")),
|
||||
]
|
||||
vision_picks = [
|
||||
("or-vision", _pick_first(or_vision, "openai/gpt-4o")),
|
||||
("or-vision", _pick_first(or_vision, "anthropic/claude")),
|
||||
("or-vision", _pick_first(or_vision, "google/gemini-2.5-flash")),
|
||||
]
|
||||
image_picks = [
|
||||
("or-image", _pick_first(or_image_gen, "google/gemini-2.5-flash-image")),
|
||||
# OpenRouter publishes OpenAI image models as ``openai/gpt-5-image*``
|
||||
|
|
@ -441,11 +408,11 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
|
|||
]
|
||||
|
||||
print(
|
||||
f" catalog: chat={len(or_chat)} vision={len(or_vision)} image_gen={len(or_image_gen)} "
|
||||
f" catalog: chat_vision={len(or_chat)} image_gen={len(or_image_gen)} "
|
||||
f"(service initialized={service.is_initialized() if hasattr(service, 'is_initialized') else 'n/a'})"
|
||||
)
|
||||
|
||||
for surface, picked in chat_picks + vision_picks + image_picks:
|
||||
for surface, picked in chat_picks + image_picks:
|
||||
if not picked:
|
||||
report.add(
|
||||
ProbeResult(
|
||||
|
|
@ -486,7 +453,6 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
|
|||
async def main(args: argparse.Namespace) -> int:
|
||||
print("Loaded global configs:")
|
||||
print(f" chat: {len(config.GLOBAL_LLM_CONFIGS)} entries")
|
||||
print(f" vision: {len(config.GLOBAL_VISION_LLM_CONFIGS)} entries")
|
||||
print(f" image-gen: {len(config.GLOBAL_IMAGE_GEN_CONFIGS)} entries")
|
||||
print(f" OR settings present: {bool(config.OPENROUTER_INTEGRATION_SETTINGS)}")
|
||||
|
||||
|
|
@ -507,8 +473,6 @@ async def main(args: argparse.Namespace) -> int:
|
|||
report = Report()
|
||||
if not args.skip_chat:
|
||||
await probe_chat_configs(report, live=args.live)
|
||||
if not args.skip_vision:
|
||||
await probe_vision_configs(report, live=args.live)
|
||||
if not args.skip_image_gen:
|
||||
await probe_image_gen_configs(report, live=args.live)
|
||||
if not args.skip_openrouter:
|
||||
|
|
@ -528,7 +492,6 @@ def _parse_args() -> argparse.Namespace:
|
|||
)
|
||||
parser.set_defaults(live=True)
|
||||
parser.add_argument("--skip-chat", action="store_true")
|
||||
parser.add_argument("--skip-vision", action="store_true")
|
||||
parser.add_argument("--skip-image-gen", action="store_true")
|
||||
parser.add_argument("--skip-openrouter", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Lock the runtime model-policy backstop in ``build_dependencies``.
|
||||
|
||||
Automations resolve their LLM from the *captured* ``agent_llm_id`` snapshot (so
|
||||
Automations resolve their LLM from the *captured* ``chat_model_id`` snapshot (so
|
||||
runs are insulated from later chat/search-space model changes), and the model
|
||||
policy is re-checked at run time so a captured model that is no longer billable
|
||||
fails the run clearly. When no snapshot is present, resolution falls back to the
|
||||
|
|
@ -45,10 +45,10 @@ def patched_side_effects(monkeypatch: pytest.MonkeyPatch):
|
|||
return None
|
||||
|
||||
|
||||
async def test_build_dependencies_resolves_captured_agent_llm_id(
|
||||
async def test_build_dependencies_resolves_captured_chat_model_id(
|
||||
monkeypatch: pytest.MonkeyPatch, patched_side_effects
|
||||
) -> None:
|
||||
"""The bundle loads with the *captured* ``agent_llm_id``, not the live search space."""
|
||||
"""The bundle loads with the *captured* ``chat_model_id``, not the live search space."""
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def _fake_load(_session, *, config_id, search_space_id):
|
||||
|
|
@ -67,13 +67,13 @@ async def test_build_dependencies_resolves_captured_agent_llm_id(
|
|||
lambda _ss: pytest.fail("search-space policy should not run on captured path"),
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(agent_llm_id=-99)
|
||||
search_space = SimpleNamespace(chat_model_id=-99)
|
||||
result = await build_dependencies(
|
||||
session=_FakeSession(search_space),
|
||||
search_space_id=42,
|
||||
agent_llm_id=-7,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-7,
|
||||
image_gen_model_id=5,
|
||||
vision_model_id=-1,
|
||||
)
|
||||
|
||||
assert captured == {"config_id": -7, "search_space_id": 42}
|
||||
|
|
@ -98,17 +98,17 @@ async def test_build_dependencies_validates_captured_ids(
|
|||
monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load)
|
||||
|
||||
await build_dependencies(
|
||||
session=_FakeSession(SimpleNamespace(agent_llm_id=0)),
|
||||
session=_FakeSession(SimpleNamespace(chat_model_id=0)),
|
||||
search_space_id=42,
|
||||
agent_llm_id=-7,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-7,
|
||||
image_gen_model_id=5,
|
||||
vision_model_id=-1,
|
||||
)
|
||||
|
||||
assert seen == {
|
||||
"agent_llm_id": -7,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -7,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -119,7 +119,7 @@ async def test_build_dependencies_raises_on_captured_policy_violation(
|
|||
|
||||
def _raise(**_kw):
|
||||
raise AutomationModelPolicyError(
|
||||
[{"kind": "image", "config_id": -2, "reason": "free model"}]
|
||||
[{"kind": "image", "model_id": -2, "reason": "free model"}]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(deps_mod, "assert_models_billable", _raise)
|
||||
|
|
@ -131,11 +131,11 @@ async def test_build_dependencies_raises_on_captured_policy_violation(
|
|||
|
||||
with pytest.raises(DependencyError):
|
||||
await build_dependencies(
|
||||
session=_FakeSession(SimpleNamespace(agent_llm_id=-7)),
|
||||
session=_FakeSession(SimpleNamespace(chat_model_id=-7)),
|
||||
search_space_id=42,
|
||||
agent_llm_id=-7,
|
||||
image_generation_config_id=-2,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-7,
|
||||
image_gen_model_id=-2,
|
||||
vision_model_id=-1,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -157,7 +157,7 @@ async def test_build_dependencies_falls_back_to_search_space(
|
|||
lambda **_kw: pytest.fail("captured policy should not run on fallback path"),
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(agent_llm_id=-7)
|
||||
search_space = SimpleNamespace(chat_model_id=-7)
|
||||
result = await build_dependencies(
|
||||
session=_FakeSession(search_space), search_space_id=42
|
||||
)
|
||||
|
|
|
|||
|
|
@ -28,9 +28,9 @@ def _run() -> SimpleNamespace:
|
|||
def test_build_action_ctx_propagates_captured_models() -> None:
|
||||
"""``definition.models`` flows onto the ActionContext model fields."""
|
||||
models = AutomationModels(
|
||||
agent_llm_id=-1,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-1,
|
||||
image_gen_model_id=5,
|
||||
vision_model_id=-1,
|
||||
)
|
||||
ctx = _build_action_ctx(
|
||||
cast(AsyncSession, None),
|
||||
|
|
@ -40,9 +40,9 @@ def test_build_action_ctx_propagates_captured_models() -> None:
|
|||
)
|
||||
|
||||
assert ctx.search_space_id == 42
|
||||
assert ctx.agent_llm_id == -1
|
||||
assert ctx.image_generation_config_id == 5
|
||||
assert ctx.vision_llm_config_id == -1
|
||||
assert ctx.chat_model_id == -1
|
||||
assert ctx.image_gen_model_id == 5
|
||||
assert ctx.vision_model_id == -1
|
||||
|
||||
|
||||
def test_build_action_ctx_none_models_leaves_fields_none() -> None:
|
||||
|
|
@ -54,6 +54,6 @@ def test_build_action_ctx_none_models_leaves_fields_none() -> None:
|
|||
None,
|
||||
)
|
||||
|
||||
assert ctx.agent_llm_id is None
|
||||
assert ctx.image_generation_config_id is None
|
||||
assert ctx.vision_llm_config_id is None
|
||||
assert ctx.chat_model_id is None
|
||||
assert ctx.image_gen_model_id is None
|
||||
assert ctx.vision_model_id is None
|
||||
|
|
|
|||
|
|
@ -40,24 +40,24 @@ def test_automation_definition_models_round_trip() -> None:
|
|||
name="Daily digest",
|
||||
plan=[PlanStep(step_id="s1", action="agent_task")],
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-1,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-1,
|
||||
image_gen_model_id=5,
|
||||
vision_model_id=-1,
|
||||
),
|
||||
)
|
||||
|
||||
dumped = definition.model_dump(mode="json", by_alias=True)
|
||||
assert dumped["models"] == {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
}
|
||||
|
||||
restored = AutomationDefinition.model_validate(dumped)
|
||||
assert restored.models is not None
|
||||
assert restored.models.agent_llm_id == -1
|
||||
assert restored.models.image_generation_config_id == 5
|
||||
assert restored.models.vision_llm_config_id == -1
|
||||
assert restored.models.chat_model_id == -1
|
||||
assert restored.models.image_gen_model_id == 5
|
||||
assert restored.models.vision_model_id == -1
|
||||
|
||||
|
||||
def test_automation_definition_rejects_unknown_top_level_field() -> None:
|
||||
|
|
|
|||
|
|
@ -64,12 +64,12 @@ async def test_assert_models_billable_raises_422_on_violation(
|
|||
|
||||
def _raise(_ss):
|
||||
raise AutomationModelPolicyError(
|
||||
[{"kind": "llm", "config_id": 0, "reason": "Auto mode"}]
|
||||
[{"kind": "llm", "model_id": 0, "reason": "Auto mode"}]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_automation_models_billable", _raise)
|
||||
|
||||
service = _service(SimpleNamespace(agent_llm_id=0))
|
||||
service = _service(SimpleNamespace(chat_model_id=0))
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service._assert_models_billable(1)
|
||||
|
||||
|
|
@ -99,7 +99,7 @@ async def test_assert_models_billable_returns_search_space_when_ok(
|
|||
automation_mod, "assert_automation_models_billable", lambda _ss: None
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(agent_llm_id=-1)
|
||||
search_space = SimpleNamespace(chat_model_id=-1)
|
||||
service = _service(search_space)
|
||||
assert await service._assert_models_billable(1) is search_space
|
||||
|
||||
|
|
@ -123,9 +123,9 @@ async def test_create_injects_captured_models_from_search_space(
|
|||
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
|
||||
|
||||
search_space = SimpleNamespace(
|
||||
agent_llm_id=-1,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-1,
|
||||
image_gen_model_id=5,
|
||||
vision_model_id=-1,
|
||||
)
|
||||
service = _service(search_space)
|
||||
payload = AutomationCreate(
|
||||
|
|
@ -137,9 +137,9 @@ async def test_create_injects_captured_models_from_search_space(
|
|||
automation = await service.create(payload)
|
||||
|
||||
assert automation.definition["models"] == {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -162,9 +162,9 @@ async def test_create_treats_unset_prefs_as_auto_zero(
|
|||
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
|
||||
|
||||
search_space = SimpleNamespace(
|
||||
agent_llm_id=None,
|
||||
image_generation_config_id=None,
|
||||
vision_llm_config_id=None,
|
||||
chat_model_id=None,
|
||||
image_gen_model_id=None,
|
||||
vision_model_id=None,
|
||||
)
|
||||
service = _service(search_space)
|
||||
payload = AutomationCreate(search_space_id=1, name="A", definition=_definition())
|
||||
|
|
@ -172,9 +172,9 @@ async def test_create_treats_unset_prefs_as_auto_zero(
|
|||
automation = await service.create(payload)
|
||||
|
||||
assert automation.definition["models"] == {
|
||||
"agent_llm_id": 0,
|
||||
"image_generation_config_id": 0,
|
||||
"vision_llm_config_id": 0,
|
||||
"chat_model_id": 0,
|
||||
"image_gen_model_id": 0,
|
||||
"vision_model_id": 0,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -195,11 +195,11 @@ async def test_create_honors_selected_models_when_provided(
|
|||
)
|
||||
validated: dict[str, Any] = {}
|
||||
|
||||
def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
|
||||
def _assert_ok(*, chat_model_id, image_gen_model_id, vision_model_id):
|
||||
validated["ids"] = (
|
||||
agent_llm_id,
|
||||
image_generation_config_id,
|
||||
vision_llm_config_id,
|
||||
chat_model_id,
|
||||
image_gen_model_id,
|
||||
vision_model_id,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok)
|
||||
|
|
@ -213,15 +213,15 @@ async def test_create_honors_selected_models_when_provided(
|
|||
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
|
||||
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
|
||||
|
||||
service = _service(SimpleNamespace(agent_llm_id=-99))
|
||||
service = _service(SimpleNamespace(chat_model_id=-99))
|
||||
payload = AutomationCreate(
|
||||
search_space_id=1,
|
||||
name="A",
|
||||
definition=_definition(
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-1,
|
||||
image_generation_config_id=7,
|
||||
vision_llm_config_id=-2,
|
||||
chat_model_id=-1,
|
||||
image_gen_model_id=7,
|
||||
vision_model_id=-2,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
|
@ -230,9 +230,9 @@ async def test_create_honors_selected_models_when_provided(
|
|||
|
||||
assert validated["ids"] == (-1, 7, -2)
|
||||
assert automation.definition["models"] == {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 7,
|
||||
"vision_llm_config_id": -2,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 7,
|
||||
"vision_model_id": -2,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -241,9 +241,9 @@ async def test_create_rejects_unbillable_selected_models(
|
|||
) -> None:
|
||||
"""A non-billable explicit selection maps the policy error to HTTP 422."""
|
||||
|
||||
def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
|
||||
def _raise(*, chat_model_id, image_gen_model_id, vision_model_id):
|
||||
raise AutomationModelPolicyError(
|
||||
[{"kind": "llm", "config_id": -3, "reason": "free model"}]
|
||||
[{"kind": "llm", "model_id": -3, "reason": "free model"}]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_models_billable", _raise)
|
||||
|
|
@ -253,15 +253,15 @@ async def test_create_rejects_unbillable_selected_models(
|
|||
|
||||
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
|
||||
|
||||
service = _service(SimpleNamespace(agent_llm_id=-3))
|
||||
service = _service(SimpleNamespace(chat_model_id=-3))
|
||||
payload = AutomationCreate(
|
||||
search_space_id=1,
|
||||
name="A",
|
||||
definition=_definition(
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-3,
|
||||
image_generation_config_id=7,
|
||||
vision_llm_config_id=-2,
|
||||
chat_model_id=-3,
|
||||
image_gen_model_id=7,
|
||||
vision_model_id=-2,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
|
@ -277,9 +277,9 @@ async def test_update_preserves_captured_models(
|
|||
) -> None:
|
||||
"""A definition edit carries over the previously captured ``models``."""
|
||||
captured = {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
}
|
||||
existing = SimpleNamespace(
|
||||
search_space_id=1,
|
||||
|
|
@ -318,20 +318,20 @@ async def test_update_honors_changed_models_when_valid(
|
|||
"name": "A",
|
||||
"plan": [],
|
||||
"models": {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
},
|
||||
},
|
||||
version=3,
|
||||
)
|
||||
validated: dict[str, Any] = {}
|
||||
|
||||
def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
|
||||
def _assert_ok(*, chat_model_id, image_gen_model_id, vision_model_id):
|
||||
validated["ids"] = (
|
||||
agent_llm_id,
|
||||
image_generation_config_id,
|
||||
vision_llm_config_id,
|
||||
chat_model_id,
|
||||
image_gen_model_id,
|
||||
vision_model_id,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok)
|
||||
|
|
@ -351,9 +351,9 @@ async def test_update_honors_changed_models_when_valid(
|
|||
patch = AutomationUpdate(
|
||||
definition=_definition(
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-2,
|
||||
image_generation_config_id=9,
|
||||
vision_llm_config_id=-2,
|
||||
chat_model_id=-2,
|
||||
image_gen_model_id=9,
|
||||
vision_model_id=-2,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
@ -362,9 +362,9 @@ async def test_update_honors_changed_models_when_valid(
|
|||
|
||||
assert validated["ids"] == (-2, 9, -2)
|
||||
assert result.definition["models"] == {
|
||||
"agent_llm_id": -2,
|
||||
"image_generation_config_id": 9,
|
||||
"vision_llm_config_id": -2,
|
||||
"chat_model_id": -2,
|
||||
"image_gen_model_id": 9,
|
||||
"vision_model_id": -2,
|
||||
}
|
||||
assert result.version == 4
|
||||
|
||||
|
|
@ -379,17 +379,17 @@ async def test_update_rejects_changed_unbillable_models(
|
|||
"name": "A",
|
||||
"plan": [],
|
||||
"models": {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
},
|
||||
},
|
||||
version=3,
|
||||
)
|
||||
|
||||
def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
|
||||
def _raise(*, chat_model_id, image_gen_model_id, vision_model_id):
|
||||
raise AutomationModelPolicyError(
|
||||
[{"kind": "llm", "config_id": -7, "reason": "free model"}]
|
||||
[{"kind": "llm", "model_id": -7, "reason": "free model"}]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_models_billable", _raise)
|
||||
|
|
@ -409,9 +409,9 @@ async def test_update_rejects_changed_unbillable_models(
|
|||
patch = AutomationUpdate(
|
||||
definition=_definition(
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-7,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-7,
|
||||
image_gen_model_id=5,
|
||||
vision_model_id=-1,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
@ -431,9 +431,9 @@ async def test_update_keeps_unchanged_models_without_revalidation(
|
|||
premium without an unrelated edit tripping the policy check.
|
||||
"""
|
||||
captured = {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
}
|
||||
existing = SimpleNamespace(
|
||||
search_space_id=1,
|
||||
|
|
@ -485,7 +485,7 @@ async def test_model_eligibility_authorizes_and_returns_payload(
|
|||
lambda _ss: {"allowed": False, "violations": [{"kind": "image"}]},
|
||||
)
|
||||
|
||||
service = _service(SimpleNamespace(agent_llm_id=-2))
|
||||
service = _service(SimpleNamespace(chat_model_id=-2))
|
||||
result = await service.model_eligibility(search_space_id=5)
|
||||
|
||||
assert result == {"allowed": False, "violations": [{"kind": "image"}]}
|
||||
|
|
|
|||
|
|
@ -27,9 +27,9 @@ pytestmark = pytest.mark.unit
|
|||
def _search_space(*, llm: int | None, image: int | None, vision: int | None):
|
||||
"""Minimal stand-in for the ``SearchSpace`` ORM row the policy reads."""
|
||||
return SimpleNamespace(
|
||||
agent_llm_id=llm,
|
||||
image_generation_config_id=image,
|
||||
vision_llm_config_id=vision,
|
||||
chat_model_id=llm,
|
||||
image_gen_model_id=image,
|
||||
vision_model_id=vision,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -39,29 +39,11 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch):
|
|||
|
||||
Negative ids: -1 is premium, -2 is free, for each of llm/image/vision.
|
||||
"""
|
||||
llm_configs = {
|
||||
-1: {"id": -1, "billing_tier": "premium"},
|
||||
-2: {"id": -2, "billing_tier": "free"},
|
||||
}
|
||||
monkeypatch.setattr(
|
||||
"app.agents.chat.runtime.llm_config.load_global_llm_config_by_id",
|
||||
lambda cid: llm_configs.get(cid),
|
||||
)
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
monkeypatch.setattr(
|
||||
app_config,
|
||||
"GLOBAL_IMAGE_GEN_CONFIGS",
|
||||
[
|
||||
{"id": -1, "billing_tier": "premium"},
|
||||
{"id": -2, "billing_tier": "free"},
|
||||
],
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
app_config,
|
||||
"GLOBAL_VISION_LLM_CONFIGS",
|
||||
"GLOBAL_MODELS",
|
||||
[
|
||||
{"id": -1, "billing_tier": "premium"},
|
||||
{"id": -2, "billing_tier": "free"},
|
||||
|
|
@ -71,7 +53,7 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch):
|
|||
return None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
|
||||
def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None:
|
||||
"""A positive config id is a user-owned BYOK model — always billable."""
|
||||
allowed, reason = model_policy._classify(kind, 7)
|
||||
|
|
@ -79,7 +61,7 @@ def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None:
|
|||
assert reason == ""
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
|
||||
@pytest.mark.parametrize("config_id", [0, None])
|
||||
def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None:
|
||||
"""Auto mode (id 0) and an unset slot (None) are blocked."""
|
||||
|
|
@ -88,7 +70,7 @@ def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None:
|
|||
assert "Auto mode" in reason
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
|
||||
def test_premium_global_is_allowed(kind: str, patched_globals) -> None:
|
||||
"""A negative (global) id with premium billing tier is allowed."""
|
||||
allowed, reason = model_policy._classify(kind, -1)
|
||||
|
|
@ -96,7 +78,7 @@ def test_premium_global_is_allowed(kind: str, patched_globals) -> None:
|
|||
assert reason == ""
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
|
||||
def test_free_global_is_blocked(kind: str, patched_globals) -> None:
|
||||
"""A negative (global) id with a free billing tier is blocked."""
|
||||
allowed, reason = model_policy._classify(kind, -2)
|
||||
|
|
@ -104,7 +86,7 @@ def test_free_global_is_blocked(kind: str, patched_globals) -> None:
|
|||
assert "free model" in reason
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
|
||||
def test_unknown_global_id_is_blocked(kind: str, patched_globals) -> None:
|
||||
"""A negative id that resolves to no config is treated as not premium."""
|
||||
allowed, _ = model_policy._classify(kind, -999)
|
||||
|
|
@ -125,10 +107,10 @@ def test_eligibility_reports_each_violation(patched_globals) -> None:
|
|||
|
||||
assert result["allowed"] is False
|
||||
kinds = {v["kind"] for v in result["violations"]}
|
||||
assert kinds == {"llm", "image", "vision"}
|
||||
# config_id is echoed back for the UI / settings deep-link.
|
||||
by_kind = {v["kind"]: v["config_id"] for v in result["violations"]}
|
||||
assert by_kind == {"llm": -2, "image": 0, "vision": -2}
|
||||
assert kinds == {"chat", "image", "vision"}
|
||||
# model_id is echoed back for the UI / settings deep-link.
|
||||
by_kind = {v["kind"]: v["model_id"] for v in result["violations"]}
|
||||
assert by_kind == {"chat": -2, "image": 0, "vision": -2}
|
||||
|
||||
|
||||
def test_assert_raises_with_violations(patched_globals) -> None:
|
||||
|
|
@ -138,7 +120,7 @@ def test_assert_raises_with_violations(patched_globals) -> None:
|
|||
assert_automation_models_billable(search_space)
|
||||
|
||||
assert len(exc_info.value.violations) == 1
|
||||
assert exc_info.value.violations[0]["kind"] == "llm"
|
||||
assert exc_info.value.violations[0]["kind"] == "chat"
|
||||
|
||||
|
||||
def test_assert_passes_when_all_billable(patched_globals) -> None:
|
||||
|
|
@ -153,7 +135,7 @@ def test_assert_passes_when_all_billable(patched_globals) -> None:
|
|||
def test_get_model_eligibility_all_billable(patched_globals) -> None:
|
||||
"""Premium LLM + BYOK image + premium vision (explicit ids) → allowed."""
|
||||
result = get_model_eligibility(
|
||||
agent_llm_id=-1, image_generation_config_id=5, vision_llm_config_id=-1
|
||||
chat_model_id=-1, image_gen_model_id=5, vision_model_id=-1
|
||||
)
|
||||
assert result == {"allowed": True, "violations": []}
|
||||
|
||||
|
|
@ -161,28 +143,28 @@ def test_get_model_eligibility_all_billable(patched_globals) -> None:
|
|||
def test_get_model_eligibility_reports_each_violation(patched_globals) -> None:
|
||||
"""Free LLM, Auto image, free vision (explicit ids) each produce a violation."""
|
||||
result = get_model_eligibility(
|
||||
agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2
|
||||
chat_model_id=-2, image_gen_model_id=0, vision_model_id=-2
|
||||
)
|
||||
assert result["allowed"] is False
|
||||
by_kind = {v["kind"]: v["config_id"] for v in result["violations"]}
|
||||
assert by_kind == {"llm": -2, "image": 0, "vision": -2}
|
||||
by_kind = {v["kind"]: v["model_id"] for v in result["violations"]}
|
||||
assert by_kind == {"chat": -2, "image": 0, "vision": -2}
|
||||
|
||||
|
||||
def test_assert_models_billable_raises(patched_globals) -> None:
|
||||
"""``assert_models_billable`` raises when any explicit id is blocked."""
|
||||
with pytest.raises(AutomationModelPolicyError) as exc_info:
|
||||
assert_models_billable(
|
||||
agent_llm_id=0, image_generation_config_id=5, vision_llm_config_id=-1
|
||||
chat_model_id=0, image_gen_model_id=5, vision_model_id=-1
|
||||
)
|
||||
assert len(exc_info.value.violations) == 1
|
||||
assert exc_info.value.violations[0]["kind"] == "llm"
|
||||
assert exc_info.value.violations[0]["kind"] == "chat"
|
||||
|
||||
|
||||
def test_assert_models_billable_passes(patched_globals) -> None:
|
||||
"""No exception when every explicit id is premium or BYOK."""
|
||||
assert (
|
||||
assert_models_billable(
|
||||
agent_llm_id=3, image_generation_config_id=-1, vision_llm_config_id=4
|
||||
chat_model_id=3, image_gen_model_id=-1, vision_model_id=4
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
|
@ -192,5 +174,5 @@ def test_search_space_wrapper_delegates_to_core(patched_globals) -> None:
|
|||
"""The search-space wrapper produces the same result as the ID core."""
|
||||
search_space = _search_space(llm=-2, image=0, vision=-2)
|
||||
assert get_automation_model_eligibility(search_space) == get_model_eligibility(
|
||||
agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2
|
||||
chat_model_id=-2, image_gen_model_id=0, vision_model_id=-2
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,110 +0,0 @@
|
|||
"""Unit tests for ``supports_image_input`` derivation on BYOK chat config
|
||||
endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``).
|
||||
|
||||
There is no DB column for ``supports_image_input`` on
|
||||
``NewLLMConfig`` — the value is resolved at the API boundary by
|
||||
``derive_supports_image_input`` so the new-chat selector / streaming
|
||||
task can read the same field shape regardless of source (BYOK vs YAML
|
||||
vs OpenRouter dynamic). Default-allow on unknown so we don't lock the
|
||||
user out of their own model choice.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.db import LiteLLMProvider
|
||||
from app.routes import new_llm_config_routes
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _byok_row(
|
||||
*,
|
||||
id_: int,
|
||||
model_name: str,
|
||||
base_model: str | None = None,
|
||||
provider: LiteLLMProvider = LiteLLMProvider.OPENAI,
|
||||
custom_provider: str | None = None,
|
||||
) -> object:
|
||||
"""Mimic the SQLAlchemy row's attribute surface; ``model_validate``
|
||||
walks ``from_attributes=True`` so a ``SimpleNamespace`` is enough.
|
||||
|
||||
``provider`` is a real ``LiteLLMProvider`` enum value so Pydantic's
|
||||
enum validator accepts it — same as the ORM row would carry."""
|
||||
return SimpleNamespace(
|
||||
id=id_,
|
||||
name=f"BYOK-{id_}",
|
||||
description=None,
|
||||
provider=provider,
|
||||
custom_provider=custom_provider,
|
||||
model_name=model_name,
|
||||
api_key="sk-byok",
|
||||
api_base=None,
|
||||
litellm_params={"base_model": base_model} if base_model else None,
|
||||
system_instructions="",
|
||||
use_default_system_instructions=True,
|
||||
citations_enabled=True,
|
||||
created_at=datetime.now(tz=UTC),
|
||||
search_space_id=42,
|
||||
user_id=uuid4(),
|
||||
)
|
||||
|
||||
|
||||
def test_serialize_byok_known_vision_model_resolves_true():
|
||||
"""The catalog resolver consults LiteLLM's map for ``gpt-4o`` ->
|
||||
True. The serialized row carries that value through to the
|
||||
``NewLLMConfigRead`` schema."""
|
||||
row = _byok_row(id_=1, model_name="gpt-4o")
|
||||
serialized = new_llm_config_routes._serialize_byok_config(row)
|
||||
|
||||
assert serialized.supports_image_input is True
|
||||
assert serialized.id == 1
|
||||
assert serialized.model_name == "gpt-4o"
|
||||
|
||||
|
||||
def test_serialize_byok_unknown_model_default_allows():
|
||||
"""Unknown / unmapped: default-allow. The streaming-task safety net
|
||||
is the actual block, and it requires LiteLLM to *explicitly* say
|
||||
text-only — so a brand new BYOK model should not be pre-judged."""
|
||||
row = _byok_row(
|
||||
id_=2,
|
||||
model_name="brand-new-model-x9-unmapped",
|
||||
provider=LiteLLMProvider.CUSTOM,
|
||||
custom_provider="brand_new_proxy",
|
||||
)
|
||||
serialized = new_llm_config_routes._serialize_byok_config(row)
|
||||
|
||||
assert serialized.supports_image_input is True
|
||||
|
||||
|
||||
def test_serialize_byok_uses_base_model_when_present():
|
||||
"""Azure-style: ``model_name`` is the deployment id, ``base_model``
|
||||
inside ``litellm_params`` is the canonical sku LiteLLM knows. The
|
||||
helper must consult ``base_model`` first or unrecognised deployment
|
||||
ids would shadow the real capability."""
|
||||
row = _byok_row(
|
||||
id_=3,
|
||||
model_name="my-azure-deployment-id-no-litellm-knows-this",
|
||||
base_model="gpt-4o",
|
||||
provider=LiteLLMProvider.AZURE_OPENAI,
|
||||
)
|
||||
serialized = new_llm_config_routes._serialize_byok_config(row)
|
||||
|
||||
assert serialized.supports_image_input is True
|
||||
|
||||
|
||||
def test_serialize_byok_returns_pydantic_read_model():
|
||||
"""The route now returns ``NewLLMConfigRead`` (not the raw ORM) so
|
||||
the schema additions are guaranteed to be present in the API
|
||||
surface. This guards against a future regression where someone
|
||||
deletes the augmentation step and falls back to ORM passthrough."""
|
||||
from app.schemas import NewLLMConfigRead
|
||||
|
||||
row = _byok_row(id_=4, model_name="gpt-4o")
|
||||
serialized = new_llm_config_routes._serialize_byok_config(row)
|
||||
assert isinstance(serialized, NewLLMConfigRead)
|
||||
|
|
@ -1,184 +0,0 @@
|
|||
"""Unit tests for ``is_premium`` derivation on the global image-gen and
|
||||
vision-LLM list endpoints.
|
||||
|
||||
Chat globals (``GET /global-llm-configs``) already emit
|
||||
``is_premium = (billing_tier == "premium")``. Image and vision did not,
|
||||
which made the new-chat ``model-selector`` render the Free/Premium badge
|
||||
on the Chat tab but skip it on the Image and Vision tabs (the selector
|
||||
keys its badge logic off ``is_premium``). These tests pin parity:
|
||||
|
||||
* YAML free entry → ``is_premium=False``
|
||||
* YAML premium entry → ``is_premium=True``
|
||||
* OpenRouter dynamic premium entry → ``is_premium=True``
|
||||
* Auto stub (always emitted when at least one config is present)
|
||||
→ ``is_premium=False``
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
_IMAGE_FIXTURE: list[dict] = [
|
||||
{
|
||||
"id": -1,
|
||||
"name": "DALL-E 3",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "dall-e-3",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"name": "GPT-Image 1 (premium)",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-image-1",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "premium",
|
||||
},
|
||||
{
|
||||
"id": -20_001,
|
||||
"name": "google/gemini-2.5-flash-image (OpenRouter)",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemini-2.5-flash-image",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"billing_tier": "premium",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
_VISION_FIXTURE: list[dict] = [
|
||||
{
|
||||
"id": -1,
|
||||
"name": "GPT-4o Vision",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"name": "Claude 3.5 Sonnet (premium)",
|
||||
"litellm_provider": "anthropic",
|
||||
"model_name": "claude-3-5-sonnet",
|
||||
"api_key": "sk-ant-test",
|
||||
"billing_tier": "premium",
|
||||
},
|
||||
{
|
||||
"id": -30_001,
|
||||
"name": "openai/gpt-4o (OpenRouter)",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "openai/gpt-4o",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"billing_tier": "premium",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Image generation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_image_gen_configs_emit_is_premium(monkeypatch):
|
||||
"""Each emitted config must carry ``is_premium`` derived server-side
|
||||
from ``billing_tier``. The Auto stub is always free.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.routes import image_generation_routes
|
||||
|
||||
monkeypatch.setattr(
|
||||
config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False
|
||||
)
|
||||
|
||||
payload = await image_generation_routes.get_global_image_gen_configs(user=None)
|
||||
|
||||
by_id = {c["id"]: c for c in payload}
|
||||
|
||||
# Auto stub is always emitted when at least one global config exists,
|
||||
# and it must always declare itself free (Auto-mode billing-tier
|
||||
# surfacing is a separate follow-up).
|
||||
assert 0 in by_id, "Auto stub should be emitted when at least one config exists"
|
||||
assert by_id[0]["is_premium"] is False
|
||||
assert by_id[0]["billing_tier"] == "free"
|
||||
|
||||
# YAML free entry — ``is_premium=False``
|
||||
assert by_id[-1]["is_premium"] is False
|
||||
assert by_id[-1]["billing_tier"] == "free"
|
||||
|
||||
# YAML premium entry — ``is_premium=True``
|
||||
assert by_id[-2]["is_premium"] is True
|
||||
assert by_id[-2]["billing_tier"] == "premium"
|
||||
|
||||
# OpenRouter dynamic premium entry — same field, same derivation
|
||||
assert by_id[-20_001]["is_premium"] is True
|
||||
assert by_id[-20_001]["billing_tier"] == "premium"
|
||||
|
||||
# Every emitted dict (including Auto) must have the field — never missing.
|
||||
for cfg in payload:
|
||||
assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
|
||||
assert isinstance(cfg["is_premium"], bool)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch):
|
||||
"""When there are no global configs at all, the endpoint emits an
|
||||
empty list (no Auto stub) — Auto mode would have nothing to route to.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.routes import image_generation_routes
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False)
|
||||
payload = await image_generation_routes.get_global_image_gen_configs(user=None)
|
||||
assert payload == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Vision LLM
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_vision_llm_configs_emit_is_premium(monkeypatch):
|
||||
from app.config import config
|
||||
from app.routes import vision_llm_routes
|
||||
|
||||
monkeypatch.setattr(
|
||||
config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False
|
||||
)
|
||||
|
||||
payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
|
||||
|
||||
by_id = {c["id"]: c for c in payload}
|
||||
|
||||
assert 0 in by_id, "Auto stub should be emitted when at least one config exists"
|
||||
assert by_id[0]["is_premium"] is False
|
||||
assert by_id[0]["billing_tier"] == "free"
|
||||
|
||||
assert by_id[-1]["is_premium"] is False
|
||||
assert by_id[-1]["billing_tier"] == "free"
|
||||
|
||||
assert by_id[-2]["is_premium"] is True
|
||||
assert by_id[-2]["billing_tier"] == "premium"
|
||||
|
||||
assert by_id[-30_001]["is_premium"] is True
|
||||
assert by_id[-30_001]["billing_tier"] == "premium"
|
||||
|
||||
for cfg in payload:
|
||||
assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
|
||||
assert isinstance(cfg["is_premium"], bool)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch):
|
||||
from app.config import config
|
||||
from app.routes import vision_llm_routes
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False)
|
||||
payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
|
||||
assert payload == []
|
||||
|
|
@ -1,106 +0,0 @@
|
|||
"""Unit tests for ``supports_image_input`` derivation on the chat global
|
||||
config endpoint (``GET /global-new-llm-configs``).
|
||||
|
||||
Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``):
|
||||
|
||||
1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML
|
||||
loader for operator overrides, or by the OpenRouter integration from
|
||||
``architecture.input_modalities``) — wins.
|
||||
2. ``derive_supports_image_input`` helper — default-allow on unknown
|
||||
models, only False when LiteLLM / OR modalities are definitive.
|
||||
|
||||
The flag is purely informational at the API boundary. The streaming
|
||||
task safety net (``is_known_text_only_chat_model``) is the actual block,
|
||||
and it requires LiteLLM to *explicitly* mark the model as text-only.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
_FIXTURE: list[dict] = [
|
||||
{
|
||||
"id": -1,
|
||||
"name": "GPT-4o (explicit true)",
|
||||
"description": "vision-capable, explicit YAML override",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
"supports_image_input": True,
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"name": "DeepSeek V3 (explicit false)",
|
||||
"description": "OpenRouter dynamic — modality-derived false",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "deepseek/deepseek-v3.2-exp",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"billing_tier": "free",
|
||||
"supports_image_input": False,
|
||||
},
|
||||
{
|
||||
"id": -10_010,
|
||||
"name": "Unannotated GPT-4o",
|
||||
"description": "no flag set — resolver should derive True via LiteLLM",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
# supports_image_input intentionally absent
|
||||
},
|
||||
{
|
||||
"id": -10_011,
|
||||
"name": "Unannotated unknown model",
|
||||
"description": "unmapped — default-allow True",
|
||||
"litellm_provider": "custom",
|
||||
"custom_provider": "brand_new_proxy",
|
||||
"model_name": "brand-new-model-x9",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch):
|
||||
"""Each emitted chat config carries ``supports_image_input`` as a
|
||||
bool. Explicit values win; unannotated entries are resolved via the
|
||||
helper (default-allow True)."""
|
||||
from app.config import config
|
||||
from app.routes import new_llm_config_routes
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False)
|
||||
|
||||
payload = await new_llm_config_routes.get_global_new_llm_configs(user=None)
|
||||
by_id = {c["id"]: c for c in payload}
|
||||
|
||||
# Auto stub: optimistic True so the user can keep Auto selected with
|
||||
# vision-capable deployments somewhere in the pool.
|
||||
assert 0 in by_id, "Auto stub should be emitted when configs exist"
|
||||
assert by_id[0]["supports_image_input"] is True
|
||||
assert by_id[0]["is_auto_mode"] is True
|
||||
|
||||
# Explicit True is preserved.
|
||||
assert by_id[-1]["supports_image_input"] is True
|
||||
|
||||
# Explicit False is preserved (the exact failure mode the safety net
|
||||
# guards against — DeepSeek V3 over OpenRouter would 404 with "No
|
||||
# endpoints found that support image input").
|
||||
assert by_id[-2]["supports_image_input"] is False
|
||||
|
||||
# Unannotated GPT-4o: resolver consults LiteLLM, which says vision.
|
||||
assert by_id[-10_010]["supports_image_input"] is True
|
||||
|
||||
# Unknown / unmapped model: default-allow rather than pre-judge.
|
||||
assert by_id[-10_011]["supports_image_input"] is True
|
||||
|
||||
for cfg in payload:
|
||||
assert "supports_image_input" in cfg, (
|
||||
f"supports_image_input missing from {cfg.get('id')}"
|
||||
)
|
||||
assert isinstance(cfg["supports_image_input"], bool)
|
||||
|
|
@ -27,9 +27,18 @@ async def test_resolve_billing_for_auto_mode(monkeypatch):
|
|||
from app.routes import image_generation_routes
|
||||
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
search_space = SimpleNamespace(image_generation_config_id=None)
|
||||
async def _no_auto_candidates(*_args, **_kwargs):
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(
|
||||
image_generation_routes,
|
||||
"auto_model_candidates",
|
||||
_no_auto_candidates,
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None)
|
||||
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
|
||||
session=None, # Not consumed on this code path.
|
||||
session=None,
|
||||
config_id=0, # IMAGE_GEN_AUTO_MODE_ID
|
||||
search_space=search_space,
|
||||
)
|
||||
|
|
@ -45,26 +54,42 @@ async def test_resolve_billing_for_premium_global_config(monkeypatch):
|
|||
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_IMAGE_GEN_CONFIGS",
|
||||
"GLOBAL_MODELS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-image-1",
|
||||
"connection_id": -101,
|
||||
"model_id": "gpt-image-1",
|
||||
"billing_tier": "premium",
|
||||
"quota_reserve_micros": 75_000,
|
||||
"catalog": {"quota_reserve_micros": 75_000},
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemini-2.5-flash-image",
|
||||
"connection_id": -102,
|
||||
"model_id": "google/gemini-2.5-flash-image",
|
||||
"billing_tier": "free",
|
||||
"catalog": {},
|
||||
},
|
||||
],
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_CONNECTIONS",
|
||||
[
|
||||
{"id": -101, "provider": "openai", "api_key": "sk-test", "base_url": None, "extra": {}},
|
||||
{
|
||||
"id": -102,
|
||||
"provider": "openrouter",
|
||||
"api_key": "sk-or-test",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"extra": {},
|
||||
},
|
||||
],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(image_generation_config_id=None)
|
||||
search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None)
|
||||
|
||||
# Premium with override.
|
||||
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
|
||||
|
|
@ -94,7 +119,7 @@ async def test_resolve_billing_for_user_owned_byok_is_free():
|
|||
from app.routes import image_generation_routes
|
||||
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
search_space = SimpleNamespace(image_generation_config_id=None)
|
||||
search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None)
|
||||
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
|
||||
session=None, config_id=42, search_space=search_space
|
||||
)
|
||||
|
|
@ -105,7 +130,7 @@ async def test_resolve_billing_for_user_owned_byok_is_free():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
|
||||
"""When the request omits ``image_generation_config_id``, the helper
|
||||
"""When the request omits ``image_gen_model_id``, the helper
|
||||
must consult the search space's default — so a search space pinned
|
||||
to a premium global config still gates new requests by quota.
|
||||
"""
|
||||
|
|
@ -114,19 +139,26 @@ async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
|
|||
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_IMAGE_GEN_CONFIGS",
|
||||
"GLOBAL_MODELS",
|
||||
[
|
||||
{
|
||||
"id": -7,
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-image-1",
|
||||
"connection_id": -101,
|
||||
"model_id": "gpt-image-1",
|
||||
"billing_tier": "premium",
|
||||
"catalog": {},
|
||||
}
|
||||
],
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_CONNECTIONS",
|
||||
[{"id": -101, "provider": "openai", "api_key": "sk-test", "base_url": None, "extra": {}}],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(image_generation_config_id=-7)
|
||||
search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=-7)
|
||||
(
|
||||
tier,
|
||||
model,
|
||||
|
|
|
|||
|
|
@ -1,27 +1,4 @@
|
|||
"""Unit tests for ``_resolve_agent_billing_for_search_space``.
|
||||
|
||||
Validates the resolver used by Celery podcast/video tasks to compute
|
||||
``(owner_user_id, billing_tier, base_model)`` from a search space and its
|
||||
agent LLM config. The resolver mirrors chat's billing-resolution pattern at
|
||||
``stream_new_chat.py:2294-2351`` and is the single integration point that
|
||||
prevents Auto-mode podcast/video from leaking premium credit.
|
||||
|
||||
Coverage:
|
||||
|
||||
* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium
|
||||
global → returns ``("premium", <base_model>)``.
|
||||
* Auto mode + ``thread_id`` set, pin resolves to a negative-id free
|
||||
global → returns ``("free", <base_model>)``.
|
||||
* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config
|
||||
→ always ``"free"``.
|
||||
* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without
|
||||
hitting the pin service.
|
||||
* Negative id (no Auto) → uses ``get_global_llm_config``'s
|
||||
``billing_tier``.
|
||||
* Positive id (user BYOK) → always ``"free"``.
|
||||
* Search space not found → raises ``ValueError``.
|
||||
* ``agent_llm_id`` is None → raises ``ValueError``.
|
||||
"""
|
||||
"""Unit tests for ``_resolve_agent_billing_for_search_space``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -34,11 +11,6 @@ import pytest
|
|||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fakes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, obj):
|
||||
self._obj = obj
|
||||
|
|
@ -51,14 +23,6 @@ class _FakeExecResult:
|
|||
|
||||
|
||||
class _FakeSession:
|
||||
"""Tiny AsyncSession stub.
|
||||
|
||||
``responses`` is a list of objects to return from successive
|
||||
``execute()`` calls (in order). The resolver makes at most two
|
||||
``execute()`` calls (search-space lookup, then optionally NewLLMConfig
|
||||
lookup), so two queued responses cover the matrix.
|
||||
"""
|
||||
|
||||
def __init__(self, responses: list):
|
||||
self._responses = list(responses)
|
||||
|
||||
|
|
@ -67,9 +31,6 @@ class _FakeSession:
|
|||
return _FakeExecResult(None)
|
||||
return _FakeExecResult(self._responses.pop(0))
|
||||
|
||||
async def commit(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakePinResolution:
|
||||
|
|
@ -78,53 +39,33 @@ class _FakePinResolution:
|
|||
from_existing_pin: bool = False
|
||||
|
||||
|
||||
def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=42,
|
||||
agent_llm_id=agent_llm_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
def _make_search_space(*, chat_model_id: int | None, user_id: UUID) -> SimpleNamespace:
|
||||
return SimpleNamespace(id=42, chat_model_id=chat_model_id, user_id=user_id)
|
||||
|
||||
|
||||
def _make_byok_config(
|
||||
*, id_: int, base_model: str | None = None, model_name: str = "gpt-byok"
|
||||
def _make_byok_model(
|
||||
*, id_: int, base_model: str | None = None, model_id: str = "gpt-byok"
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=id_,
|
||||
model_name=model_name,
|
||||
litellm_params={"base_model": base_model} if base_model else {},
|
||||
model_id=model_id,
|
||||
catalog={"base_model": base_model} if base_model else {},
|
||||
connection=SimpleNamespace(enabled=True, search_space_id=42, user_id=None),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
|
||||
"""Auto + thread → pin service resolves to negative-id premium config →
|
||||
resolver returns ``("premium", <base_model>)``."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
|
||||
session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)])
|
||||
|
||||
# Mock the pin service to return a concrete premium config id.
|
||||
async def _fake_resolve_pin(
|
||||
sess,
|
||||
*,
|
||||
thread_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
selected_llm_config_id,
|
||||
force_repin_free=False,
|
||||
):
|
||||
assert selected_llm_config_id == 0
|
||||
assert thread_id == 99
|
||||
async def _fake_resolve_pin(*_args, **kwargs):
|
||||
assert kwargs["selected_llm_config_id"] == 0
|
||||
assert kwargs["thread_id"] == 99
|
||||
return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium")
|
||||
|
||||
# Mock global config lookup to return a premium entry.
|
||||
def _fake_get_global(cfg_id):
|
||||
if cfg_id == -1:
|
||||
return {
|
||||
|
|
@ -135,8 +76,6 @@ async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
|
|||
}
|
||||
return None
|
||||
|
||||
# Lazy imports inside the resolver — patch the *target* modules so the
|
||||
# imported names resolve to our fakes.
|
||||
import app.services.auto_model_pin_service as pin_module
|
||||
import app.services.llm_service as llm_module
|
||||
|
||||
|
|
@ -154,77 +93,18 @@ async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
|
|||
assert base_model == "gpt-5.4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch):
|
||||
"""Auto + thread → pin returns negative-id free config → resolver
|
||||
returns ``("free", <base_model>)``. Same path the pin service takes for
|
||||
out-of-credit users (graceful degradation)."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
|
||||
|
||||
async def _fake_resolve_pin(
|
||||
sess,
|
||||
*,
|
||||
thread_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
selected_llm_config_id,
|
||||
force_repin_free=False,
|
||||
):
|
||||
return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free")
|
||||
|
||||
def _fake_get_global(cfg_id):
|
||||
if cfg_id == -3:
|
||||
return {
|
||||
"id": -3,
|
||||
"model_name": "openrouter/free-model",
|
||||
"billing_tier": "free",
|
||||
"litellm_params": {"base_model": "openrouter/free-model"},
|
||||
}
|
||||
return None
|
||||
|
||||
import app.services.auto_model_pin_service as pin_module
|
||||
import app.services.llm_service as llm_module
|
||||
|
||||
monkeypatch.setattr(
|
||||
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
|
||||
)
|
||||
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42, thread_id=99
|
||||
)
|
||||
|
||||
assert owner == user_id
|
||||
assert tier == "free"
|
||||
assert base_model == "openrouter/free-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch):
|
||||
"""Auto + thread → pin returns positive-id BYOK config → resolver
|
||||
returns ``("free", ...)`` (BYOK is always free per
|
||||
``AgentConfig.from_new_llm_config``)."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
search_space = _make_search_space(agent_llm_id=0, user_id=user_id)
|
||||
byok_cfg = _make_byok_config(
|
||||
id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude"
|
||||
search_space = _make_search_space(chat_model_id=0, user_id=user_id)
|
||||
byok_model = _make_byok_model(
|
||||
id_=17, base_model="anthropic/claude-3-haiku", model_id="my-claude"
|
||||
)
|
||||
session = _FakeSession([search_space, byok_cfg])
|
||||
session = _FakeSession([search_space, byok_model])
|
||||
|
||||
async def _fake_resolve_pin(
|
||||
sess,
|
||||
*,
|
||||
thread_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
selected_llm_config_id,
|
||||
force_repin_free=False,
|
||||
):
|
||||
async def _fake_resolve_pin(*_args, **_kwargs):
|
||||
return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free")
|
||||
|
||||
import app.services.auto_model_pin_service as pin_module
|
||||
|
|
@ -244,13 +124,10 @@ async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_without_thread_id_falls_back_to_free():
|
||||
"""Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking
|
||||
the pin service. Forward-compat fallback for any future direct-API
|
||||
entrypoint that doesn't have a chat thread."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
|
||||
session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)])
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42, thread_id=None
|
||||
|
|
@ -263,13 +140,10 @@ async def test_auto_mode_without_thread_id_falls_back_to_free():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch):
|
||||
"""If the pin service raises ``ValueError`` (thread missing /
|
||||
mismatched search space), the resolver should log and return free
|
||||
rather than killing the whole task."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
|
||||
session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)])
|
||||
|
||||
async def _fake_resolve_pin(*args, **kwargs):
|
||||
raise ValueError("thread missing")
|
||||
|
|
@ -291,12 +165,10 @@ async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negative_id_premium_global_returns_premium(monkeypatch):
|
||||
"""Explicit negative agent_llm_id → ``get_global_llm_config`` →
|
||||
return its ``billing_tier``."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)])
|
||||
session = _FakeSession([_make_search_space(chat_model_id=-1, user_id=user_id)])
|
||||
|
||||
def _fake_get_global(cfg_id):
|
||||
return {
|
||||
|
|
@ -319,50 +191,15 @@ async def test_negative_id_premium_global_returns_premium(monkeypatch):
|
|||
assert base_model == "gpt-5.4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negative_id_free_global_returns_free(monkeypatch):
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)])
|
||||
|
||||
def _fake_get_global(cfg_id):
|
||||
return {
|
||||
"id": cfg_id,
|
||||
"model_name": "openrouter/some-free",
|
||||
"billing_tier": "free",
|
||||
"litellm_params": {"base_model": "openrouter/some-free"},
|
||||
}
|
||||
|
||||
import app.services.llm_service as llm_module
|
||||
|
||||
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42, thread_id=None
|
||||
)
|
||||
|
||||
assert owner == user_id
|
||||
assert tier == "free"
|
||||
assert base_model == "openrouter/some-free"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch):
|
||||
"""When the global config has no ``litellm_params.base_model``, the
|
||||
resolver falls back to ``model_name`` — matching chat's behavior."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)])
|
||||
session = _FakeSession([_make_search_space(chat_model_id=-5, user_id=user_id)])
|
||||
|
||||
def _fake_get_global(cfg_id):
|
||||
return {
|
||||
"id": cfg_id,
|
||||
"model_name": "fallback-model",
|
||||
"billing_tier": "premium",
|
||||
# No litellm_params.
|
||||
}
|
||||
return {"id": cfg_id, "model_name": "fallback-model", "billing_tier": "premium"}
|
||||
|
||||
import app.services.llm_service as llm_module
|
||||
|
||||
|
|
@ -378,14 +215,12 @@ async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypat
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_positive_id_byok_is_always_free():
|
||||
"""Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free,
|
||||
regardless of underlying provider tier."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
search_space = _make_search_space(agent_llm_id=23, user_id=user_id)
|
||||
byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet")
|
||||
session = _FakeSession([search_space, byok_cfg])
|
||||
search_space = _make_search_space(chat_model_id=23, user_id=user_id)
|
||||
byok_model = _make_byok_model(id_=23, base_model="anthropic/claude-3.5-sonnet")
|
||||
session = _FakeSession([search_space, byok_model])
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42
|
||||
|
|
@ -398,13 +233,10 @@ async def test_positive_id_byok_is_always_free():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_positive_id_byok_missing_returns_free_with_empty_base_model():
|
||||
"""If the BYOK config row is missing/deleted but the search space still
|
||||
points at it, the resolver still returns free (no debit) with an empty
|
||||
base_model — billable_call's premium path is skipped, no harm done."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)])
|
||||
session = _FakeSession([_make_search_space(chat_model_id=99, user_id=user_id)])
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42
|
||||
|
|
@ -419,18 +251,18 @@ async def test_positive_id_byok_missing_returns_free_with_empty_base_model():
|
|||
async def test_search_space_not_found_raises_value_error():
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
session = _FakeSession([None])
|
||||
|
||||
with pytest.raises(ValueError, match="Search space"):
|
||||
await _resolve_agent_billing_for_search_space(session, search_space_id=999)
|
||||
await _resolve_agent_billing_for_search_space(
|
||||
_FakeSession([None]), search_space_id=999
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_llm_id_none_raises_value_error():
|
||||
async def test_chat_model_id_none_raises_value_error():
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)])
|
||||
session = _FakeSession([_make_search_space(chat_model_id=None, user_id=user_id)])
|
||||
|
||||
with pytest.raises(ValueError, match="agent_llm_id"):
|
||||
with pytest.raises(ValueError, match="chat_model_id"):
|
||||
await _resolve_agent_billing_for_search_space(session, search_space_id=42)
|
||||
|
|
|
|||
|
|
@ -32,8 +32,9 @@ class _FakeQuotaResult:
|
|||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, thread):
|
||||
def __init__(self, *, thread=None, scalars=None):
|
||||
self._thread = thread
|
||||
self._scalars = scalars or []
|
||||
|
||||
def unique(self):
|
||||
return self
|
||||
|
|
@ -41,19 +42,69 @@ class _FakeExecResult:
|
|||
def scalar_one_or_none(self):
|
||||
return self._thread
|
||||
|
||||
def scalars(self):
|
||||
return SimpleNamespace(all=lambda: self._scalars)
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, thread):
|
||||
def __init__(self, thread, *, models=None):
|
||||
self.thread = thread
|
||||
self.models = models or []
|
||||
self.commit_count = 0
|
||||
self.execute_count = 0
|
||||
|
||||
async def execute(self, _stmt):
|
||||
return _FakeExecResult(self.thread)
|
||||
self.execute_count += 1
|
||||
if self.execute_count == 1:
|
||||
return _FakeExecResult(thread=self.thread)
|
||||
return _FakeExecResult(scalars=self.models)
|
||||
|
||||
async def commit(self):
|
||||
self.commit_count += 1
|
||||
|
||||
|
||||
def _set_global_llm_configs(monkeypatch, config, configs: list[dict]):
|
||||
"""Patch the new global model catalog shape from compact legacy cfg fixtures."""
|
||||
connections = []
|
||||
models = []
|
||||
for cfg in configs:
|
||||
config_id = int(cfg["id"])
|
||||
connection_id = config_id - 100_000
|
||||
provider = cfg.get("provider") or cfg.get("litellm_provider")
|
||||
model_name = cfg["model_name"]
|
||||
connections.append(
|
||||
{
|
||||
"id": connection_id,
|
||||
"provider": provider,
|
||||
"scope": "GLOBAL",
|
||||
"enabled": True,
|
||||
}
|
||||
)
|
||||
models.append(
|
||||
{
|
||||
"id": config_id,
|
||||
"connection_id": connection_id,
|
||||
"model_id": model_name,
|
||||
"display_name": cfg.get("name") or model_name,
|
||||
"supports_chat": cfg.get("supports_chat", True),
|
||||
"supports_image_input": cfg.get("supports_image_input", True),
|
||||
"supports_tools": cfg.get("supports_tools", True),
|
||||
"supports_image_generation": cfg.get("supports_image_generation", False),
|
||||
"capabilities_override": cfg.get("capabilities_override") or {},
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
"catalog": {
|
||||
"auto_pin_tier": cfg.get("auto_pin_tier"),
|
||||
"quality_score": cfg.get("quality_score")
|
||||
or cfg.get("quality_score_static"),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", configs)
|
||||
monkeypatch.setattr(config, "GLOBAL_CONNECTIONS", connections)
|
||||
monkeypatch.setattr(config, "GLOBAL_MODELS", models)
|
||||
|
||||
|
||||
def _thread(
|
||||
*,
|
||||
search_space_id: int = 10,
|
||||
|
|
@ -71,9 +122,9 @@ async def test_auto_first_turn_pins_one_model(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"},
|
||||
{
|
||||
|
|
@ -111,9 +162,9 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -2,
|
||||
|
|
@ -158,9 +209,9 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
|
|
@ -216,9 +267,9 @@ async def test_next_turn_reuses_existing_pin(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
|
|
@ -257,9 +308,9 @@ async def test_premium_eligible_auto_can_pin_premium(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
|
|
@ -295,9 +346,9 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -2,
|
||||
|
|
@ -340,9 +391,9 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -2,
|
||||
|
|
@ -385,9 +436,9 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -2,
|
||||
|
|
@ -433,9 +484,9 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-2))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"},
|
||||
],
|
||||
|
|
@ -458,9 +509,9 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-999))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"},
|
||||
],
|
||||
|
|
@ -487,7 +538,7 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quality-aware pin selection (Auto Fastest upgrade)
|
||||
# Quality-aware pin selection (Auto upgrade)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
|
@ -498,9 +549,9 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
|
|
@ -550,9 +601,9 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
|
|
@ -602,9 +653,9 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
|
|
@ -676,9 +727,9 @@ async def test_top_k_picks_only_high_score_models(monkeypatch):
|
|||
"quality_score": 10,
|
||||
"health_gated": False,
|
||||
}
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[*high_score_cfgs, low_score_trap],
|
||||
)
|
||||
|
||||
|
|
@ -723,9 +774,9 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
|
|
@ -775,9 +826,9 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
|
|
@ -833,9 +884,9 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
|
|
@ -886,9 +937,9 @@ async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
|
|
@ -931,9 +982,9 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
|
|
|
|||
|
|
@ -15,15 +15,19 @@ async def test_global_openrouter_image_gen_sets_explicit_api_base():
|
|||
"""The global-config branch forwards the explicit OpenRouter base."""
|
||||
from app.routes import image_generation_routes
|
||||
|
||||
cfg = {
|
||||
global_model = {
|
||||
"id": -20_001,
|
||||
"name": "GPT Image 1 (OpenRouter)",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "openai/gpt-image-1",
|
||||
"connection_id": -101,
|
||||
"model_id": "openai/gpt-image-1",
|
||||
"supports_image_generation": True,
|
||||
"capabilities_override": {},
|
||||
}
|
||||
global_connection = {
|
||||
"id": -101,
|
||||
"provider": "openrouter",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"extra": {},
|
||||
}
|
||||
|
||||
captured: dict = {}
|
||||
|
|
@ -33,7 +37,7 @@ async def test_global_openrouter_image_gen_sets_explicit_api_base():
|
|||
return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={})
|
||||
|
||||
image_gen = MagicMock()
|
||||
image_gen.image_generation_config_id = cfg["id"]
|
||||
image_gen.image_gen_model_id = global_model["id"]
|
||||
image_gen.prompt = "test"
|
||||
image_gen.n = 1
|
||||
image_gen.quality = None
|
||||
|
|
@ -43,14 +47,19 @@ async def test_global_openrouter_image_gen_sets_explicit_api_base():
|
|||
image_gen.model = None
|
||||
|
||||
search_space = MagicMock()
|
||||
search_space.image_generation_config_id = cfg["id"]
|
||||
search_space.image_gen_model_id = global_model["id"]
|
||||
session = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
image_generation_routes,
|
||||
"_get_global_image_gen_config",
|
||||
return_value=cfg,
|
||||
"_get_global_model",
|
||||
return_value=global_model,
|
||||
),
|
||||
patch.object(
|
||||
image_generation_routes,
|
||||
"_get_global_connection",
|
||||
return_value=global_connection,
|
||||
),
|
||||
patch.object(
|
||||
image_generation_routes,
|
||||
|
|
@ -74,15 +83,19 @@ async def test_generate_image_tool_global_sets_explicit_api_base():
|
|||
generate_image as gi_module,
|
||||
)
|
||||
|
||||
cfg = {
|
||||
global_model = {
|
||||
"id": -20_001,
|
||||
"name": "GPT Image 1 (OpenRouter)",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "openai/gpt-image-1",
|
||||
"connection_id": -101,
|
||||
"model_id": "openai/gpt-image-1",
|
||||
"supports_image_generation": True,
|
||||
"capabilities_override": {},
|
||||
}
|
||||
global_connection = {
|
||||
"id": -101,
|
||||
"provider": "openrouter",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"extra": {},
|
||||
}
|
||||
|
||||
captured: dict = {}
|
||||
|
|
@ -98,7 +111,7 @@ async def test_generate_image_tool_global_sets_explicit_api_base():
|
|||
|
||||
search_space = MagicMock()
|
||||
search_space.id = 1
|
||||
search_space.image_generation_config_id = cfg["id"]
|
||||
search_space.image_gen_model_id = global_model["id"]
|
||||
|
||||
session_cm = AsyncMock()
|
||||
session = AsyncMock()
|
||||
|
|
@ -121,7 +134,8 @@ async def test_generate_image_tool_global_sets_explicit_api_base():
|
|||
|
||||
with (
|
||||
patch.object(gi_module, "shielded_async_session", return_value=session_cm),
|
||||
patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg),
|
||||
patch.object(gi_module, "_get_global_model", return_value=global_model),
|
||||
patch.object(gi_module, "_get_global_connection", return_value=global_connection),
|
||||
patch.object(
|
||||
gi_module, "aimage_generation", side_effect=fake_aimage_generation
|
||||
),
|
||||
|
|
|
|||
|
|
@ -217,7 +217,7 @@ def test_generate_configs_drops_non_text_and_non_tool_models():
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _generate_image_gen_configs / _generate_vision_llm_configs
|
||||
# _generate_image_gen_configs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
|
@ -263,7 +263,7 @@ def test_generate_image_gen_configs_filters_by_image_output():
|
|||
# Each config must carry ``billing_tier`` for routing in image_generation_routes.
|
||||
for c in cfgs:
|
||||
assert c["billing_tier"] in {"free", "premium"}
|
||||
assert c["litellm_provider"] == "openrouter"
|
||||
assert c["provider"] == "openrouter"
|
||||
assert c[_OPENROUTER_DYNAMIC_MARKER] is True
|
||||
# Emit the OpenRouter base URL at source so every call path passes an
|
||||
# explicit api_base and cannot inherit a process-global endpoint.
|
||||
|
|
@ -271,9 +271,7 @@ def test_generate_image_gen_configs_filters_by_image_output():
|
|||
|
||||
|
||||
def test_generate_image_gen_configs_assigns_image_id_offset():
|
||||
"""Image configs use a different id_offset (-20000) so their negative
|
||||
IDs don't collide with chat configs (-10000) or vision configs (-30000).
|
||||
"""
|
||||
"""Image configs use their own id_offset (-20000)."""
|
||||
from app.services.openrouter_integration_service import (
|
||||
_generate_image_gen_configs,
|
||||
)
|
||||
|
|
@ -291,88 +289,3 @@ def test_generate_image_gen_configs_assigns_image_id_offset():
|
|||
assert all(c["id"] < -20_000 + 1 for c in cfgs)
|
||||
assert all(c["id"] > -29_000_000 for c in cfgs)
|
||||
|
||||
|
||||
def test_generate_vision_llm_configs_filters_by_image_input_text_output():
|
||||
"""Vision LLMs must accept image input AND emit text — pure image-gen
|
||||
(no text out) and text-only (no image in) models are excluded.
|
||||
"""
|
||||
from app.services.openrouter_integration_service import (
|
||||
_generate_vision_llm_configs,
|
||||
)
|
||||
|
||||
raw = [
|
||||
# GPT-4o: vision LLM (image in, text out) — must emit.
|
||||
{
|
||||
"id": "openai/gpt-4o",
|
||||
"architecture": {
|
||||
"input_modalities": ["text", "image"],
|
||||
"output_modalities": ["text"],
|
||||
},
|
||||
"context_length": 128_000,
|
||||
"pricing": {"prompt": "0.000005", "completion": "0.000015"},
|
||||
},
|
||||
# Pure image generator — image *output*, no text out. Must NOT emit.
|
||||
{
|
||||
"id": "openai/gpt-image-1",
|
||||
"architecture": {
|
||||
"input_modalities": ["text"],
|
||||
"output_modalities": ["image"],
|
||||
},
|
||||
"context_length": 4_000,
|
||||
"pricing": {"prompt": "0", "completion": "0"},
|
||||
},
|
||||
# Pure text model (no image in). Must NOT emit.
|
||||
{
|
||||
"id": "anthropic/claude-3-haiku",
|
||||
"architecture": {
|
||||
"input_modalities": ["text"],
|
||||
"output_modalities": ["text"],
|
||||
},
|
||||
"context_length": 200_000,
|
||||
"pricing": {"prompt": "0.000001", "completion": "0.000005"},
|
||||
},
|
||||
]
|
||||
|
||||
cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
|
||||
names = {c["model_name"] for c in cfgs}
|
||||
assert names == {"openai/gpt-4o"}
|
||||
|
||||
cfg = cfgs[0]
|
||||
assert cfg["billing_tier"] == "premium"
|
||||
# Pricing carried inline so pricing_registration can register vision
|
||||
# under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache
|
||||
# is cleared.
|
||||
assert cfg["input_cost_per_token"] == pytest.approx(5e-6)
|
||||
assert cfg["output_cost_per_token"] == pytest.approx(15e-6)
|
||||
assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True
|
||||
# Emit the OpenRouter base URL at source so every call path passes an
|
||||
# explicit api_base and cannot inherit a process-global endpoint.
|
||||
assert cfg["api_base"] == "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
def test_generate_vision_llm_configs_drops_chat_only_filters():
|
||||
"""A small-context vision model that doesn't advertise tool calling is
|
||||
still a valid vision LLM for "describe this image" prompts. The chat
|
||||
filters (``supports_tool_calling``, ``has_sufficient_context``) must
|
||||
NOT be applied to vision emission.
|
||||
"""
|
||||
from app.services.openrouter_integration_service import (
|
||||
_generate_vision_llm_configs,
|
||||
)
|
||||
|
||||
raw = [
|
||||
{
|
||||
"id": "tiny/vision-mini",
|
||||
"architecture": {
|
||||
"input_modalities": ["text", "image"],
|
||||
"output_modalities": ["text"],
|
||||
},
|
||||
"supported_parameters": [], # no tools
|
||||
"context_length": 4_000, # well below MIN_CONTEXT_LENGTH
|
||||
"pricing": {"prompt": "0.0000001", "completion": "0.0000005"},
|
||||
}
|
||||
]
|
||||
|
||||
cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
|
||||
assert len(cfgs) == 1
|
||||
assert cfgs[0]["model_name"] == "tiny/vision-mini"
|
||||
|
|
|
|||
|
|
@ -370,77 +370,3 @@ def test_register_continues_after_individual_failure(monkeypatch, caplog):
|
|||
assert any("custom-deployment" in payload for payload in successful_calls)
|
||||
|
||||
|
||||
def test_vision_configs_registered_with_chat_shape(monkeypatch):
|
||||
"""``register_pricing_from_global_configs`` walks
|
||||
``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision
|
||||
calls (during indexing) bill correctly. Vision configs use the same
|
||||
chat-shape token prices, but image-gen pricing is intentionally NOT
|
||||
registered here (handled via ``response_cost`` in LiteLLM).
|
||||
"""
|
||||
from app.config import config
|
||||
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||
|
||||
spy = _patch_register(monkeypatch)
|
||||
_patch_openrouter_pricing(
|
||||
monkeypatch,
|
||||
{"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}},
|
||||
)
|
||||
|
||||
# No chat configs — only vision. Proves the vision walk is a separate
|
||||
# iteration, not piggy-backed on the chat list.
|
||||
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_VISION_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "openai/gpt-4o",
|
||||
"billing_tier": "premium",
|
||||
"input_cost_per_token": 5e-6,
|
||||
"output_cost_per_token": 15e-6,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
register_pricing_from_global_configs()
|
||||
|
||||
assert "openrouter/openai/gpt-4o" in spy.all_keys
|
||||
payload_value = spy.calls[0]["openrouter/openai/gpt-4o"]
|
||||
assert payload_value["mode"] == "chat"
|
||||
assert payload_value["litellm_provider"] == "openrouter"
|
||||
assert payload_value["input_cost_per_token"] == pytest.approx(5e-6)
|
||||
assert payload_value["output_cost_per_token"] == pytest.approx(15e-6)
|
||||
|
||||
|
||||
def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch):
|
||||
"""If the OpenRouter pricing cache misses a vision model (different
|
||||
catalogue surface), the vision walk falls back to inline
|
||||
``input_cost_per_token``/``output_cost_per_token`` on the cfg itself.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||
|
||||
spy = _patch_register(monkeypatch)
|
||||
_patch_openrouter_pricing(monkeypatch, {})
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_VISION_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemini-2.5-flash",
|
||||
"billing_tier": "premium",
|
||||
"input_cost_per_token": 1e-6,
|
||||
"output_cost_per_token": 4e-6,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
register_pricing_from_global_configs()
|
||||
|
||||
assert "openrouter/google/gemini-2.5-flash" in spy.all_keys
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""Unit tests for the Auto (Fastest) quality scoring module."""
|
||||
"""Unit tests for the Auto quality scoring module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
|
|||
|
|
@ -1,77 +0,0 @@
|
|||
"""Vision LLM resolution must pass explicit per-config ``api_base``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vision_llm_global_openrouter_sets_api_base():
|
||||
"""Global negative-ID branch forwards the explicit OpenRouter base."""
|
||||
from app.services import llm_service
|
||||
|
||||
cfg = {
|
||||
"id": -30_001,
|
||||
"name": "GPT-4o Vision (OpenRouter)",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "openai/gpt-4o",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
"billing_tier": "free",
|
||||
}
|
||||
|
||||
search_space = MagicMock()
|
||||
search_space.id = 1
|
||||
search_space.user_id = "user-x"
|
||||
search_space.vision_llm_config_id = cfg["id"]
|
||||
|
||||
session = AsyncMock()
|
||||
scalars = MagicMock()
|
||||
scalars.first.return_value = search_space
|
||||
result = MagicMock()
|
||||
result.scalars.return_value = scalars
|
||||
session.execute.return_value = result
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
class FakeSanitized:
|
||||
def __init__(self, **kwargs):
|
||||
captured.update(kwargs)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.vision_llm_router_service.get_global_vision_llm_config",
|
||||
return_value=cfg,
|
||||
),
|
||||
patch(
|
||||
"app.agents.chat.runtime.llm_config.SanitizedChatLiteLLM",
|
||||
new=FakeSanitized,
|
||||
),
|
||||
):
|
||||
await llm_service.get_vision_llm(session=session, search_space_id=1)
|
||||
|
||||
assert captured.get("api_base") == "https://openrouter.ai/api/v1"
|
||||
assert captured["model"] == "openrouter/openai/gpt-4o"
|
||||
|
||||
|
||||
def test_vision_router_deployment_sets_api_base_when_config_empty():
|
||||
"""Auto-mode vision router carries explicit api_base into deployments."""
|
||||
from app.services.vision_llm_router_service import VisionLLMRouterService
|
||||
|
||||
deployment = VisionLLMRouterService._config_to_deployment(
|
||||
{
|
||||
"model_name": "openai/gpt-4o",
|
||||
"litellm_provider": "openrouter",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
}
|
||||
)
|
||||
assert deployment is not None
|
||||
assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1"
|
||||
assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-4o"
|
||||
Loading…
Add table
Add a link
Reference in a new issue