mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-12 20:45:20 +02:00
refactor(model-connections): move backend model connections to provider capabilities
This commit is contained in:
parent
3089dd4cb6
commit
5d5d574550
31 changed files with 772 additions and 476 deletions
|
|
@ -17,13 +17,6 @@ branch_labels: str | Sequence[str] | None = None
|
|||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
connection_protocol = postgresql.ENUM(
|
||||
"OLLAMA",
|
||||
"OPENAI_COMPATIBLE",
|
||||
"ANTHROPIC",
|
||||
name="connectionprotocol",
|
||||
create_type=False,
|
||||
)
|
||||
connection_scope = postgresql.ENUM(
|
||||
"GLOBAL",
|
||||
"SEARCH_SPACE",
|
||||
|
|
@ -73,36 +66,67 @@ def _add_searchspace_column_if_missing(column_name: str) -> None:
|
|||
op.add_column("searchspaces", sa.Column(column_name, sa.Integer(), nullable=True))
|
||||
|
||||
|
||||
def _drop_column_if_exists(table_name: str, column_name: str) -> None:
    """Drop ``column_name`` from ``table_name`` when ``_column_exists`` reports it.

    Does nothing if the column is not present.
    """
    if not _column_exists(table_name, column_name):
        return
    op.drop_column(table_name, column_name)
|
||||
|
||||
|
||||
def _drop_index_if_exists(table_name: str, index_name: str) -> None:
    """Drop ``index_name`` on ``table_name`` when ``_index_exists`` reports it.

    Does nothing if the index is not present.
    """
    if not _index_exists(table_name, index_name):
        return
    op.drop_index(index_name, table_name=table_name)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
connection_protocol.create(bind, checkfirst=True)
|
||||
op.execute("ALTER TYPE connectionprotocol ADD VALUE IF NOT EXISTS 'ANTHROPIC'")
|
||||
connection_scope.create(bind, checkfirst=True)
|
||||
model_source.create(bind, checkfirst=True)
|
||||
|
||||
if _table_exists("connections"):
|
||||
if _column_exists("connections", "native_provider") and not _column_exists(
|
||||
"connections", "litellm_provider"
|
||||
if _column_exists("connections", "litellm_provider") and not _column_exists(
|
||||
"connections", "provider"
|
||||
):
|
||||
op.alter_column(
|
||||
"connections",
|
||||
"litellm_provider",
|
||||
new_column_name="provider",
|
||||
existing_type=sa.String(length=100),
|
||||
existing_nullable=True,
|
||||
)
|
||||
op.alter_column(
|
||||
"connections",
|
||||
"provider",
|
||||
existing_type=sa.String(length=100),
|
||||
nullable=False,
|
||||
)
|
||||
elif _column_exists("connections", "native_provider") and not _column_exists(
|
||||
"connections", "provider"
|
||||
):
|
||||
op.alter_column(
|
||||
"connections",
|
||||
"native_provider",
|
||||
new_column_name="litellm_provider",
|
||||
new_column_name="provider",
|
||||
existing_type=sa.String(length=100),
|
||||
existing_nullable=True,
|
||||
)
|
||||
elif not _column_exists("connections", "litellm_provider"):
|
||||
op.alter_column(
|
||||
"connections",
|
||||
"provider",
|
||||
existing_type=sa.String(length=100),
|
||||
nullable=False,
|
||||
)
|
||||
elif not _column_exists("connections", "provider"):
|
||||
op.add_column(
|
||||
"connections",
|
||||
sa.Column("litellm_provider", sa.String(length=100), nullable=True),
|
||||
sa.Column("provider", sa.String(length=100), nullable=False),
|
||||
)
|
||||
_drop_index_if_exists("connections", "ix_connections_protocol")
|
||||
_drop_column_if_exists("connections", "protocol")
|
||||
else:
|
||||
op.create_table(
|
||||
"connections",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("protocol", connection_protocol, nullable=False),
|
||||
sa.Column("litellm_provider", sa.String(length=100), nullable=True),
|
||||
sa.Column("provider", sa.String(length=100), nullable=False),
|
||||
sa.Column("base_url", sa.String(length=500), nullable=True),
|
||||
sa.Column("api_key", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
|
|
@ -131,18 +155,20 @@ def upgrade() -> None:
|
|||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
if _index_exists("connections", "ix_connections_native_provider") and not _index_exists(
|
||||
"connections", "ix_connections_litellm_provider"
|
||||
"connections", "ix_connections_provider"
|
||||
):
|
||||
op.execute(
|
||||
"ALTER INDEX ix_connections_native_provider "
|
||||
"RENAME TO ix_connections_litellm_provider"
|
||||
"RENAME TO ix_connections_provider"
|
||||
)
|
||||
_create_index_if_missing("ix_connections_protocol", "connections", ["protocol"])
|
||||
_create_index_if_missing(
|
||||
"ix_connections_litellm_provider",
|
||||
"connections",
|
||||
["litellm_provider"],
|
||||
)
|
||||
if _index_exists("connections", "ix_connections_litellm_provider") and not _index_exists(
|
||||
"connections", "ix_connections_provider"
|
||||
):
|
||||
op.execute(
|
||||
"ALTER INDEX ix_connections_litellm_provider "
|
||||
"RENAME TO ix_connections_provider"
|
||||
)
|
||||
_create_index_if_missing("ix_connections_provider", "connections", ["provider"])
|
||||
_create_index_if_missing("ix_connections_scope", "connections", ["scope"])
|
||||
|
||||
if not _table_exists("models"):
|
||||
|
|
@ -159,24 +185,11 @@ def upgrade() -> None:
|
|||
server_default="DISCOVERED",
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"capabilities",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
server_default=sa.text("'{}'::jsonb"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"capabilities_declared",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
server_default=sa.text("'{}'::jsonb"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"capabilities_verified",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
server_default=sa.text("'{}'::jsonb"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("supports_chat", sa.Boolean(), nullable=True),
|
||||
sa.Column("max_input_tokens", sa.Integer(), nullable=True),
|
||||
sa.Column("supports_image_input", sa.Boolean(), nullable=True),
|
||||
sa.Column("supports_tools", sa.Boolean(), nullable=True),
|
||||
sa.Column("supports_image_generation", sa.Boolean(), nullable=True),
|
||||
sa.Column(
|
||||
"capabilities_override",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
|
|
@ -198,6 +211,24 @@ def upgrade() -> None:
|
|||
"connection_id", "model_id", name="uq_models_connection_model_id"
|
||||
),
|
||||
)
|
||||
else:
|
||||
if not _column_exists("models", "supports_chat"):
|
||||
op.add_column("models", sa.Column("supports_chat", sa.Boolean(), nullable=True))
|
||||
if not _column_exists("models", "max_input_tokens"):
|
||||
op.add_column("models", sa.Column("max_input_tokens", sa.Integer(), nullable=True))
|
||||
if not _column_exists("models", "supports_image_input"):
|
||||
op.add_column(
|
||||
"models", sa.Column("supports_image_input", sa.Boolean(), nullable=True)
|
||||
)
|
||||
if not _column_exists("models", "supports_tools"):
|
||||
op.add_column("models", sa.Column("supports_tools", sa.Boolean(), nullable=True))
|
||||
if not _column_exists("models", "supports_image_generation"):
|
||||
op.add_column(
|
||||
"models", sa.Column("supports_image_generation", sa.Boolean(), nullable=True)
|
||||
)
|
||||
_drop_column_if_exists("models", "capabilities")
|
||||
_drop_column_if_exists("models", "capabilities_declared")
|
||||
_drop_column_if_exists("models", "capabilities_verified")
|
||||
_create_index_if_missing("ix_models_connection_id", "models", ["connection_id"])
|
||||
_create_index_if_missing("ix_models_model_id", "models", ["model_id"])
|
||||
_create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"])
|
||||
|
|
@ -206,6 +237,8 @@ def upgrade() -> None:
|
|||
_add_searchspace_column_if_missing("image_gen_model_id")
|
||||
_add_searchspace_column_if_missing("vision_model_id")
|
||||
|
||||
op.execute("DROP TYPE IF EXISTS connectionprotocol")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("searchspaces", "vision_model_id")
|
||||
|
|
@ -218,11 +251,9 @@ def downgrade() -> None:
|
|||
op.drop_table("models")
|
||||
|
||||
op.drop_index(op.f("ix_connections_scope"), table_name="connections")
|
||||
op.drop_index(op.f("ix_connections_litellm_provider"), table_name="connections")
|
||||
op.drop_index(op.f("ix_connections_protocol"), table_name="connections")
|
||||
op.drop_index(op.f("ix_connections_provider"), table_name="connections")
|
||||
op.drop_table("connections")
|
||||
|
||||
bind = op.get_bind()
|
||||
model_source.drop(bind, checkfirst=True)
|
||||
connection_scope.drop(bind, checkfirst=True)
|
||||
connection_protocol.drop(bind, checkfirst=True)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from app.services.auto_model_pin_service import (
|
|||
auto_model_candidates,
|
||||
choose_auto_model_candidate,
|
||||
)
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.model_resolver import to_litellm
|
||||
from app.utils.signed_image_urls import generate_image_token
|
||||
|
||||
|
|
@ -146,9 +147,7 @@ def create_generate_image_tool(
|
|||
|
||||
if config_id < 0:
|
||||
global_model = _get_global_model(config_id)
|
||||
if not global_model or not (
|
||||
global_model.get("capabilities") or {}
|
||||
).get("image_gen"):
|
||||
if not global_model or not has_capability(global_model, "image_gen"):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
global_connection = _get_global_connection(
|
||||
|
|
@ -191,7 +190,7 @@ def create_generate_image_tool(
|
|||
):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
if not (db_model.capabilities or {}).get("image_gen"):
|
||||
if not has_capability(db_model, "image_gen"):
|
||||
err = f"Model {config_id} is not image-generation capable"
|
||||
return _failed({"error": err}, error=err)
|
||||
|
||||
|
|
|
|||
|
|
@ -49,16 +49,19 @@ def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
|
|||
reject the blank text. The OpenAI spec says ``content`` should be
|
||||
``null`` when an assistant message only carries tool calls.
|
||||
"""
|
||||
sanitized: list[BaseMessage] = []
|
||||
for msg in messages:
|
||||
if isinstance(msg.content, list):
|
||||
msg.content = _sanitize_content(msg.content)
|
||||
next_msg = msg.model_copy(deep=True)
|
||||
if isinstance(next_msg.content, list):
|
||||
next_msg.content = _sanitize_content(next_msg.content)
|
||||
if (
|
||||
isinstance(msg, AIMessage)
|
||||
and (not msg.content or msg.content == "")
|
||||
and getattr(msg, "tool_calls", None)
|
||||
isinstance(next_msg, AIMessage)
|
||||
and (not next_msg.content or next_msg.content == "")
|
||||
and getattr(next_msg, "tool_calls", None)
|
||||
):
|
||||
msg.content = None # type: ignore[assignment]
|
||||
return messages
|
||||
next_msg.content = None # type: ignore[assignment]
|
||||
sanitized.append(next_msg)
|
||||
return sanitized
|
||||
|
||||
|
||||
class SanitizedChatLiteLLM(ChatLiteLLM):
|
||||
|
|
@ -89,6 +92,22 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
|
|||
):
|
||||
yield chunk
|
||||
|
||||
async def _agenerate(
    self,
    messages: list[BaseMessage],
    stop: list[str] | None = None,
    run_manager: AsyncCallbackManagerForLLMRun | None = None,
    stream: bool | None = None,
    **kwargs: Any,
) -> ChatResult:
    """Run the parent ``_agenerate`` on sanitized messages.

    ``messages`` is passed through ``_sanitize_messages`` first; every other
    argument is forwarded to the parent implementation unchanged.
    """
    parent_agenerate = super()._agenerate
    cleaned_messages = _sanitize_messages(messages)
    return await parent_agenerate(
        cleaned_messages,
        stop=stop,
        run_manager=run_manager,
        stream=stream,
        **kwargs,
    )
|
||||
|
||||
|
||||
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
|
||||
"""Attach a ``profile`` dict to ChatLiteLLM with model context metadata."""
|
||||
|
|
@ -210,7 +229,7 @@ class AgentConfig:
|
|||
# BYOK rows have no curated flag; ask LiteLLM (default-allow on
|
||||
# unknown). The streaming safety net still blocks explicit text-only.
|
||||
supports_image_input=derive_supports_image_input(
|
||||
litellm_provider=provider_value.lower(),
|
||||
provider=provider_value.lower(),
|
||||
model_name=config.model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=config.custom_provider,
|
||||
|
|
@ -229,7 +248,7 @@ class AgentConfig:
|
|||
|
||||
system_instructions = yaml_config.get("system_instructions", "")
|
||||
|
||||
provider = yaml_config.get("litellm_provider", "")
|
||||
provider = yaml_config.get("provider") or yaml_config.get("litellm_provider", "")
|
||||
model_name = yaml_config.get("model_name", "")
|
||||
custom_provider = yaml_config.get("custom_provider")
|
||||
litellm_params = yaml_config.get("litellm_params") or {}
|
||||
|
|
@ -245,7 +264,7 @@ class AgentConfig:
|
|||
supports_image_input = bool(yaml_config.get("supports_image_input"))
|
||||
else:
|
||||
supports_image_input = derive_supports_image_input(
|
||||
litellm_provider=provider,
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=custom_provider,
|
||||
|
|
@ -396,8 +415,8 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
|||
if llm_config.get("custom_provider"):
|
||||
model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}"
|
||||
else:
|
||||
litellm_provider = llm_config.get("litellm_provider", "openai")
|
||||
model_string = f"{litellm_provider}/{llm_config['model_name']}"
|
||||
provider = llm_config.get("provider") or llm_config.get("litellm_provider", "openai")
|
||||
model_string = f"{provider}/{llm_config['model_name']}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ from app.config import (
|
|||
initialize_llm_router,
|
||||
initialize_openrouter_integration,
|
||||
initialize_pricing_registration,
|
||||
initialize_vision_llm_router,
|
||||
)
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError
|
||||
|
|
@ -622,7 +621,6 @@ async def lifespan(app: FastAPI):
|
|||
initialize_pricing_registration()
|
||||
initialize_llm_router()
|
||||
initialize_image_gen_router()
|
||||
initialize_vision_llm_router()
|
||||
|
||||
# Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays
|
||||
# worker readiness. ``shield`` so Uvicorn cancelling startup
|
||||
|
|
|
|||
|
|
@ -115,14 +115,12 @@ def init_worker(**kwargs):
|
|||
initialize_llm_router,
|
||||
initialize_openrouter_integration,
|
||||
initialize_pricing_registration,
|
||||
initialize_vision_llm_router,
|
||||
)
|
||||
|
||||
initialize_openrouter_integration()
|
||||
initialize_pricing_registration()
|
||||
initialize_llm_router()
|
||||
initialize_image_gen_router()
|
||||
initialize_vision_llm_router()
|
||||
|
||||
|
||||
# Celery configuration, sourced from the central Config singleton
|
||||
|
|
|
|||
|
|
@ -103,7 +103,7 @@ def load_global_llm_configs():
|
|||
else None
|
||||
)
|
||||
cfg["supports_image_input"] = derive_supports_image_input(
|
||||
litellm_provider=cfg.get("litellm_provider"),
|
||||
provider=cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
|
|
@ -122,7 +122,7 @@ def load_global_llm_configs():
|
|||
# Stamp Auto (Fastest) ranking metadata. YAML configs are always
|
||||
# Tier A — operator-curated, locked first when premium-eligible.
|
||||
# The OpenRouter refresh tick later re-stamps health for any cfg
|
||||
# whose litellm_provider == "openrouter" via _enrich_health.
|
||||
# whose provider == "openrouter" via _enrich_health.
|
||||
try:
|
||||
from app.services.quality_score import static_score_yaml
|
||||
|
||||
|
|
@ -132,7 +132,7 @@ def load_global_llm_configs():
|
|||
cfg["quality_score_static"] = static_q
|
||||
cfg["quality_score"] = static_q
|
||||
cfg["quality_score_health"] = None
|
||||
# YAML cfgs whose litellm_provider is openrouter are also subject
|
||||
# YAML cfgs whose provider is openrouter are also subject
|
||||
# to health gating against their own /endpoints data — a
|
||||
# hand-picked dead OR model is still dead. _enrich_health
|
||||
# re-stamps health_gated for them on the next refresh tick.
|
||||
|
|
@ -362,8 +362,8 @@ def initialize_openrouter_integration():
|
|||
else:
|
||||
print("Info: OpenRouter integration enabled but no models fetched")
|
||||
|
||||
# Image generation + vision LLM emissions are opt-in (issue L).
|
||||
# Both reuse the catalogue already cached by ``service.initialize``
|
||||
# Image generation emissions reuse the catalogue already cached by
|
||||
# ``service.initialize``
|
||||
# so we don't make additional network calls here.
|
||||
if settings.get("image_generation_enabled"):
|
||||
try:
|
||||
|
|
@ -377,18 +377,6 @@ def initialize_openrouter_integration():
|
|||
except Exception as e:
|
||||
print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}")
|
||||
|
||||
if settings.get("vision_enabled"):
|
||||
try:
|
||||
vision_configs = service.get_vision_llm_configs()
|
||||
if vision_configs:
|
||||
config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs)
|
||||
print(
|
||||
f"Info: OpenRouter integration added {len(vision_configs)} "
|
||||
f"vision LLM models"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}")
|
||||
|
||||
refresh_global_model_catalog()
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to initialize OpenRouter integration: {e}")
|
||||
|
|
@ -399,7 +387,6 @@ def materialize_global_configs():
|
|||
|
||||
return materialize_global_model_catalog(
|
||||
chat_configs=getattr(config, "GLOBAL_LLM_CONFIGS", []),
|
||||
vision_configs=getattr(config, "GLOBAL_VISION_LLM_CONFIGS", []),
|
||||
image_configs=getattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", []),
|
||||
)
|
||||
|
||||
|
|
@ -493,29 +480,9 @@ def initialize_image_gen_router():
|
|||
|
||||
|
||||
def initialize_vision_llm_router():
|
||||
vision_configs = load_global_vision_llm_configs()
|
||||
# Reuse the router settings already parsed at Config construction. The
|
||||
# *configs* list is intentionally re-read from YAML (it must exclude the
|
||||
# OpenRouter-injected dynamic models held in config.GLOBAL_VISION_LLM_CONFIGS).
|
||||
router_settings = config.VISION_LLM_ROUTER_SETTINGS
|
||||
|
||||
if not vision_configs:
|
||||
print(
|
||||
"Info: No global vision LLM configs found, "
|
||||
"Vision LLM Auto mode will not be available"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
from app.services.vision_llm_router_service import VisionLLMRouterService
|
||||
|
||||
VisionLLMRouterService.initialize(vision_configs, router_settings)
|
||||
print(
|
||||
f"Info: Vision LLM Router initialized with {len(vision_configs)} models "
|
||||
f"(strategy: {router_settings.get('routing_strategy', 'usage-based-routing')})"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to initialize Vision LLM Router: {e}")
|
||||
# Retired: vision Auto now uses shared capability-filtered model selection
|
||||
# over GLOBAL/BYOK chat models with supports_image_input=true.
|
||||
return
|
||||
|
||||
|
||||
class Config:
|
||||
|
|
@ -874,7 +841,6 @@ class Config:
|
|||
|
||||
GLOBAL_CONNECTIONS, GLOBAL_MODELS = _materialize_global_model_catalog(
|
||||
chat_configs=GLOBAL_LLM_CONFIGS,
|
||||
vision_configs=GLOBAL_VISION_LLM_CONFIGS,
|
||||
image_configs=GLOBAL_IMAGE_GEN_CONFIGS,
|
||||
)
|
||||
del _materialize_global_model_catalog
|
||||
|
|
|
|||
|
|
@ -280,12 +280,6 @@ class VisionProvider(StrEnum):
|
|||
CUSTOM = "CUSTOM"
|
||||
|
||||
|
||||
class ConnectionProtocol(StrEnum):
|
||||
OLLAMA = "OLLAMA"
|
||||
OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE"
|
||||
ANTHROPIC = "ANTHROPIC"
|
||||
|
||||
|
||||
class ConnectionScope(StrEnum):
|
||||
GLOBAL = "GLOBAL"
|
||||
SEARCH_SPACE = "SEARCH_SPACE"
|
||||
|
|
@ -1662,8 +1656,7 @@ class Report(BaseModel, TimestampMixin):
|
|||
class Connection(BaseModel, TimestampMixin):
|
||||
__tablename__ = "connections"
|
||||
|
||||
protocol = Column(SQLAlchemyEnum(ConnectionProtocol), nullable=False, index=True)
|
||||
litellm_provider = Column(String(100), nullable=True, index=True)
|
||||
provider = Column(String(100), nullable=False, index=True)
|
||||
base_url = Column(String(500), nullable=True)
|
||||
api_key = Column(String, nullable=True)
|
||||
extra = Column(JSONB, nullable=False, default=dict, server_default="{}")
|
||||
|
|
@ -1715,9 +1708,11 @@ class Model(BaseModel, TimestampMixin):
|
|||
default=ModelSource.DISCOVERED,
|
||||
server_default=ModelSource.DISCOVERED.value,
|
||||
)
|
||||
capabilities = Column(JSONB, nullable=False, default=dict, server_default="{}")
|
||||
capabilities_declared = Column(JSONB, nullable=False, default=dict, server_default="{}")
|
||||
capabilities_verified = Column(JSONB, nullable=False, default=dict, server_default="{}")
|
||||
supports_chat = Column(Boolean, nullable=True)
|
||||
max_input_tokens = Column(Integer, nullable=True)
|
||||
supports_image_input = Column(Boolean, nullable=True)
|
||||
supports_tools = Column(Boolean, nullable=True)
|
||||
supports_image_generation = Column(Boolean, nullable=True)
|
||||
capabilities_override = Column(JSONB, nullable=False, default=dict, server_default="{}")
|
||||
embedding_dimension = Column(Integer, nullable=True)
|
||||
enabled = Column(Boolean, nullable=False, default=True, server_default="true")
|
||||
|
|
|
|||
|
|
@ -132,7 +132,7 @@ async def list_anonymous_models():
|
|||
id=cfg.get("id", 0),
|
||||
name=cfg.get("name", ""),
|
||||
description=cfg.get("description"),
|
||||
provider=cfg.get("litellm_provider", ""),
|
||||
provider=cfg.get("provider") or cfg.get("litellm_provider", ""),
|
||||
model_name=cfg.get("model_name", ""),
|
||||
billing_tier=cfg.get("billing_tier", "free"),
|
||||
is_premium=cfg.get("billing_tier", "free") == "premium",
|
||||
|
|
@ -161,7 +161,7 @@ async def get_anonymous_model(slug: str):
|
|||
id=cfg.get("id", 0),
|
||||
name=cfg.get("name", ""),
|
||||
description=cfg.get("description"),
|
||||
provider=cfg.get("litellm_provider", ""),
|
||||
provider=cfg.get("provider") or cfg.get("litellm_provider", ""),
|
||||
model_name=cfg.get("model_name", ""),
|
||||
billing_tier=cfg.get("billing_tier", "free"),
|
||||
is_premium=cfg.get("billing_tier", "free") == "premium",
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ from app.services.auto_model_pin_service import (
|
|||
choose_auto_model_candidate,
|
||||
)
|
||||
from app.services.model_resolver import to_litellm
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.signed_image_urls import verify_image_token
|
||||
|
|
@ -166,7 +167,7 @@ async def _execute_image_generation(
|
|||
|
||||
if config_id < 0:
|
||||
global_model = _get_global_model(config_id)
|
||||
if not global_model or not (global_model.get("capabilities") or {}).get("image_gen"):
|
||||
if not global_model or not has_capability(global_model, "image_gen"):
|
||||
raise ValueError(f"Global image generation model {config_id} not found")
|
||||
global_connection = _get_global_connection(global_model["connection_id"])
|
||||
if not global_connection:
|
||||
|
|
@ -200,7 +201,7 @@ async def _execute_image_generation(
|
|||
raise ValueError(f"Image generation model {config_id} not found")
|
||||
if conn.user_id is not None and conn.user_id != search_space.user_id:
|
||||
raise ValueError(f"Image generation model {config_id} not found")
|
||||
if not (db_model.capabilities or {}).get("image_gen"):
|
||||
if not has_capability(db_model, "image_gen"):
|
||||
raise ValueError(f"Model {config_id} is not image-generation capable")
|
||||
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
|
|
@ -272,7 +273,7 @@ async def get_global_image_gen_configs(
|
|||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("litellm_provider"),
|
||||
"provider": cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from sqlalchemy.orm import selectinload
|
|||
from app.config import config
|
||||
from app.db import (
|
||||
Connection,
|
||||
ConnectionProtocol,
|
||||
ConnectionScope,
|
||||
Model,
|
||||
ModelSource,
|
||||
|
|
@ -22,6 +21,7 @@ from app.schemas import (
|
|||
ConnectionRead,
|
||||
ConnectionUpdate,
|
||||
ModelCreate,
|
||||
ModelProviderRead,
|
||||
ModelRead,
|
||||
ModelRolesRead,
|
||||
ModelRolesUpdate,
|
||||
|
|
@ -34,6 +34,8 @@ from app.services.model_connection_service import (
|
|||
persist_verification,
|
||||
test_model,
|
||||
)
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.provider_registry import REGISTRY
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
|
@ -41,16 +43,6 @@ router = APIRouter()
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _default_litellm_provider(protocol: ConnectionProtocol | str) -> str:
    """Map a connection protocol to its default LiteLLM provider string.

    ``protocol`` may be a ``ConnectionProtocol`` member or a plain string; its
    ``value`` attribute is used when present, otherwise ``str(protocol)``.
    Anything that is not OLLAMA or ANTHROPIC resolves to ``"openai"``.
    """
    lookup_key = getattr(protocol, "value", str(protocol))
    # OPENAI_COMPATIBLE maps to the same string as the fallback, so only the
    # two protocols that differ from it need an explicit entry.
    non_default_providers = {
        ConnectionProtocol.OLLAMA.value: "ollama_chat",
        ConnectionProtocol.ANTHROPIC.value: "anthropic",
    }
    return non_default_providers.get(lookup_key, "openai")
|
||||
|
||||
|
||||
def _model_read(model: Model | dict) -> ModelRead:
    """Build a ``ModelRead`` from ``model`` using ``ModelRead.model_validate``."""
    validated = ModelRead.model_validate(model)
    return validated
|
||||
|
||||
|
|
@ -68,8 +60,7 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None
|
|||
|
||||
return ConnectionRead(
|
||||
id=conn.id,
|
||||
protocol=conn.protocol,
|
||||
litellm_provider=conn.litellm_provider,
|
||||
provider=conn.provider,
|
||||
base_url=conn.base_url,
|
||||
extra=conn.extra or {},
|
||||
scope=conn.scope,
|
||||
|
|
@ -85,6 +76,60 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None
|
|||
)
|
||||
|
||||
|
||||
def _apply_model_facts(model: Model, facts: dict) -> None:
    """Copy the capability fact fields from ``facts`` onto ``model``.

    Each listed attribute is assigned ``facts.get(name)``, so a key missing
    from ``facts`` sets the corresponding attribute to ``None``. Keys in
    ``facts`` that are not listed here are ignored.
    """
    fact_fields = (
        "supports_chat",
        "max_input_tokens",
        "supports_image_input",
        "supports_tools",
        "supports_image_generation",
    )
    for field_name in fact_fields:
        setattr(model, field_name, facts.get(field_name))
|
||||
|
||||
|
||||
def _default_model_for(models: list[Model], capability: str) -> int | None:
    """Return the id of the first enabled model that has ``capability``.

    Models are checked in list order; ``None`` is returned when no model
    is both enabled and reported capable by ``has_capability``.
    """
    matching_ids = (
        candidate.id
        for candidate in models
        if candidate.enabled and has_capability(candidate, capability)
    )
    return next(matching_ids, None)
|
||||
|
||||
|
||||
async def _default_unset_roles(
    session: AsyncSession,
    conn: Connection,
    models: list[Model],
) -> None:
    """Fill any unset model roles on the connection's search space.

    Only acts when ``conn`` has SEARCH_SPACE scope and a ``search_space_id``.
    Roles that are already set (not ``None``) are left untouched. The vision
    role prefers the current chat model when that model has the "vision"
    capability, otherwise the first enabled vision-capable model is used.
    Changes are made on the loaded search space object only; no commit is
    issued in this function.
    """
    scoped_to_space = conn.scope == ConnectionScope.SEARCH_SPACE
    if not scoped_to_space or conn.search_space_id is None:
        return

    space = await _get_search_space(session, conn.search_space_id)

    if space.chat_model_id is None:
        space.chat_model_id = _default_model_for(models, "chat")

    if space.vision_model_id is None:
        preferred_vision_id = None
        if space.chat_model_id:
            # Look up the chat model among this connection's models; only the
            # first id match is considered.
            for candidate in models:
                if candidate.id == space.chat_model_id:
                    if candidate and has_capability(candidate, "vision"):
                        preferred_vision_id = candidate.id
                    break
        space.vision_model_id = preferred_vision_id or _default_model_for(
            models, "vision"
        )

    if space.image_gen_model_id is None:
        space.image_gen_model_id = _default_model_for(models, "image_gen")
|
||||
|
||||
|
||||
@router.get("/model-providers", response_model=list[ModelProviderRead])
async def list_model_providers(user: User = Depends(current_active_user)):
    """List every provider in ``REGISTRY`` as a ``ModelProviderRead``.

    Entries are returned in the order produced by ``sorted(REGISTRY.items())``.
    ``local_only`` is true for the "ollama_chat" and "lm_studio" providers.
    """
    # The dependency value itself is unused; presumably it is declared only so
    # the route requires an active user (confirm against current_active_user).
    del user
    local_only = {"ollama_chat", "lm_studio"}
    provider_reads = []
    for provider, spec in sorted(REGISTRY.items()):
        provider_reads.append(
            ModelProviderRead(
                provider=provider,
                transport=spec.transport.value,
                discovery=spec.discovery,
                default_base_url=spec.default_base_url,
                base_url_required=spec.base_url_required,
                auth_style=spec.auth_style,
                local_only=provider in local_only,
            )
        )
    return provider_reads
|
||||
|
||||
|
||||
async def _get_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace:
|
||||
result = await session.execute(select(SearchSpace).where(SearchSpace.id == search_space_id))
|
||||
search_space = result.scalars().first()
|
||||
|
|
@ -180,8 +225,6 @@ async def create_connection(
|
|||
"You don't have permission to create model connections in this search space",
|
||||
)
|
||||
payload = data.model_dump(exclude={"search_space_id"})
|
||||
if not payload.get("litellm_provider"):
|
||||
payload["litellm_provider"] = _default_litellm_provider(data.protocol)
|
||||
|
||||
conn = Connection(
|
||||
**payload,
|
||||
|
|
@ -254,24 +297,21 @@ async def discover_connection_models(
|
|||
model_id=item["model_id"],
|
||||
display_name=item.get("display_name"),
|
||||
source=item["source"],
|
||||
capabilities=item["capabilities"],
|
||||
capabilities_declared=item["capabilities"],
|
||||
capabilities_verified={},
|
||||
capabilities_override={},
|
||||
enabled=False,
|
||||
catalog=item.get("metadata") or {},
|
||||
)
|
||||
_apply_model_facts(db_model, item)
|
||||
session.add(db_model)
|
||||
else:
|
||||
db_model.display_name = item.get("display_name") or db_model.display_name
|
||||
db_model.capabilities_declared = item["capabilities"]
|
||||
db_model.capabilities = {
|
||||
**item["capabilities"],
|
||||
**(db_model.capabilities_override or {}),
|
||||
}
|
||||
_apply_model_facts(db_model, item)
|
||||
db_model.catalog = item.get("metadata") or db_model.catalog
|
||||
await session.commit()
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _default_unset_roles(session, conn, list(conn.models))
|
||||
await session.commit()
|
||||
conn = await _load_connection(session, connection_id)
|
||||
return [_model_read(model) for model in conn.models]
|
||||
|
||||
|
||||
|
|
@ -297,16 +337,17 @@ async def add_manual_model(
|
|||
model_id=model_id,
|
||||
display_name=data.display_name or None,
|
||||
source=ModelSource.MANUAL,
|
||||
capabilities=capabilities,
|
||||
capabilities_declared=capabilities,
|
||||
capabilities_verified={},
|
||||
capabilities_override={},
|
||||
enabled=True,
|
||||
catalog={},
|
||||
)
|
||||
_apply_model_facts(model, capabilities)
|
||||
session.add(model)
|
||||
await session.commit()
|
||||
await session.refresh(model)
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _default_unset_roles(session, conn, list(conn.models))
|
||||
await session.commit()
|
||||
return _model_read(model)
|
||||
|
||||
|
||||
|
|
@ -327,11 +368,6 @@ async def update_model(
|
|||
update = data.model_dump(exclude_unset=True)
|
||||
for key, value in update.items():
|
||||
setattr(model, key, value)
|
||||
if "capabilities_override" in update:
|
||||
model.capabilities = {
|
||||
**(model.capabilities_declared or {}),
|
||||
**(model.capabilities_override or {}),
|
||||
}
|
||||
await session.commit()
|
||||
await session.refresh(model)
|
||||
return _model_read(model)
|
||||
|
|
|
|||
|
|
@ -1741,12 +1741,11 @@ async def handle_new_chat(
|
|||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
# Use agent_llm_id from search space for chat operations
|
||||
# Positive IDs load from NewLLMConfig database table
|
||||
# Negative IDs load from YAML global configs
|
||||
# Falls back to -1 (first global config) if not configured
|
||||
# Use the converged model-connections role for chat operations.
|
||||
# Positive IDs load Model + Connection rows; negative IDs load
|
||||
# virtual GLOBAL models; 0 means Auto.
|
||||
llm_config_id = (
|
||||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
search_space.chat_model_id if search_space.chat_model_id is not None else 0
|
||||
)
|
||||
|
||||
# Release the read-transaction so we don't hold ACCESS SHARE locks
|
||||
|
|
@ -2228,7 +2227,7 @@ async def regenerate_response(
|
|||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
llm_config_id = (
|
||||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
search_space.chat_model_id if search_space.chat_model_id is not None else 0
|
||||
)
|
||||
|
||||
# Release the read-transaction so we don't hold ACCESS SHARE locks
|
||||
|
|
@ -2393,7 +2392,7 @@ async def resume_chat(
|
|||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
llm_config_id = (
|
||||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
search_space.chat_model_id if search_space.chat_model_id is not None else 0
|
||||
)
|
||||
|
||||
decisions = [d.model_dump() for d in request.decisions]
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead:
|
|||
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
|
||||
)
|
||||
supports_image_input = derive_supports_image_input(
|
||||
litellm_provider=provider_value.lower(),
|
||||
provider=provider_value.lower(),
|
||||
model_name=config.model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=config.custom_provider,
|
||||
|
|
@ -147,7 +147,7 @@ async def get_global_new_llm_configs(
|
|||
else None
|
||||
)
|
||||
supports_image_input = derive_supports_image_input(
|
||||
litellm_provider=cfg.get("litellm_provider"),
|
||||
provider=cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=cfg_base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
|
|
@ -157,7 +157,7 @@ async def get_global_new_llm_configs(
|
|||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("litellm_provider"),
|
||||
"provider": cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
|
|
|
|||
|
|
@ -419,7 +419,7 @@ async def _get_llm_config_by_id(
|
|||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("litellm_provider"),
|
||||
"provider": cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base"),
|
||||
|
|
@ -490,7 +490,7 @@ async def _get_image_gen_config_by_id(
|
|||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("litellm_provider"),
|
||||
"provider": cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
|
|
@ -550,7 +550,7 @@ async def _get_vision_llm_config_by_id(
|
|||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("litellm_provider"),
|
||||
"provider": cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ async def get_global_vision_llm_configs(
|
|||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("litellm_provider"),
|
||||
"provider": cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ from .model_connections import (
|
|||
ConnectionRead,
|
||||
ConnectionUpdate,
|
||||
ModelCreate,
|
||||
ModelProviderRead,
|
||||
ModelRead,
|
||||
ModelRolesRead,
|
||||
ModelRolesUpdate,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.db import ConnectionProtocol, ConnectionScope, ModelSource
|
||||
from app.db import ConnectionScope, ModelSource
|
||||
|
||||
|
||||
class ModelRead(BaseModel):
|
||||
|
|
@ -13,9 +13,11 @@ class ModelRead(BaseModel):
|
|||
model_id: str
|
||||
display_name: str | None = None
|
||||
source: ModelSource | str
|
||||
capabilities: dict[str, Any]
|
||||
capabilities_declared: dict[str, Any] = Field(default_factory=dict)
|
||||
capabilities_verified: dict[str, Any] = Field(default_factory=dict)
|
||||
supports_chat: bool | None = None
|
||||
max_input_tokens: int | None = None
|
||||
supports_image_input: bool | None = None
|
||||
supports_tools: bool | None = None
|
||||
supports_image_generation: bool | None = None
|
||||
capabilities_override: dict[str, Any] = Field(default_factory=dict)
|
||||
embedding_dimension: int | None = None
|
||||
enabled: bool
|
||||
|
|
@ -28,8 +30,7 @@ class ModelRead(BaseModel):
|
|||
|
||||
class ConnectionRead(BaseModel):
|
||||
id: int
|
||||
protocol: ConnectionProtocol | str
|
||||
litellm_provider: str | None = None
|
||||
provider: str
|
||||
base_url: str | None = None
|
||||
extra: dict[str, Any] = Field(default_factory=dict)
|
||||
scope: ConnectionScope | str
|
||||
|
|
@ -47,8 +48,7 @@ class ConnectionRead(BaseModel):
|
|||
|
||||
|
||||
class ConnectionCreate(BaseModel):
|
||||
protocol: ConnectionProtocol
|
||||
litellm_provider: str | None = Field(None, max_length=100)
|
||||
provider: str = Field(..., max_length=100)
|
||||
base_url: str | None = Field(None, max_length=500)
|
||||
api_key: str | None = None
|
||||
extra: dict[str, Any] = Field(default_factory=dict)
|
||||
|
|
@ -58,7 +58,7 @@ class ConnectionCreate(BaseModel):
|
|||
|
||||
|
||||
class ConnectionUpdate(BaseModel):
|
||||
litellm_provider: str | None = Field(None, max_length=100)
|
||||
provider: str | None = Field(None, max_length=100)
|
||||
base_url: str | None = Field(None, max_length=500)
|
||||
api_key: str | None = None
|
||||
extra: dict[str, Any] | None = None
|
||||
|
|
@ -79,9 +79,24 @@ class ModelCreate(BaseModel):
|
|||
class ModelUpdate(BaseModel):
|
||||
display_name: str | None = Field(None, max_length=255)
|
||||
enabled: bool | None = None
|
||||
supports_chat: bool | None = None
|
||||
max_input_tokens: int | None = None
|
||||
supports_image_input: bool | None = None
|
||||
supports_tools: bool | None = None
|
||||
supports_image_generation: bool | None = None
|
||||
capabilities_override: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ModelProviderRead(BaseModel):
|
||||
provider: str
|
||||
transport: str
|
||||
discovery: str
|
||||
default_base_url: str | None = None
|
||||
base_url_required: bool
|
||||
auth_style: str
|
||||
local_only: bool = False
|
||||
|
||||
|
||||
class VerifyConnectionResponse(BaseModel):
|
||||
status: str
|
||||
ok: bool
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from sqlalchemy.orm import selectinload
|
|||
|
||||
from app.config import config
|
||||
from app.db import Connection, Model, NewChatThread
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.quality_score import _QUALITY_TOP_K
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
|
|
@ -62,18 +63,13 @@ def _is_usable_global_config(cfg: dict) -> bool:
|
|||
return bool(
|
||||
cfg.get("id") is not None
|
||||
and cfg.get("model_name")
|
||||
and cfg.get("litellm_provider")
|
||||
and (cfg.get("provider") or cfg.get("litellm_provider"))
|
||||
and cfg.get("api_key")
|
||||
)
|
||||
|
||||
|
||||
def _has_capability(model: dict | Model, capability: str) -> bool:
|
||||
caps = (
|
||||
model.get("capabilities", {})
|
||||
if isinstance(model, dict)
|
||||
else model.capabilities or {}
|
||||
)
|
||||
return bool(caps.get(capability))
|
||||
return has_capability(model, capability)
|
||||
|
||||
|
||||
def _prune_runtime_cooldowns(now_ts: float | None = None) -> None:
|
||||
|
|
@ -196,7 +192,7 @@ def _cfg_supports_image_input(cfg: dict) -> bool:
|
|||
else None
|
||||
)
|
||||
return derive_supports_image_input(
|
||||
litellm_provider=cfg.get("litellm_provider"),
|
||||
provider=cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
|
|
@ -253,9 +249,13 @@ def _global_candidates(
|
|||
"model_id": model.get("model_id"),
|
||||
"source": "global",
|
||||
"connection": connection,
|
||||
"capabilities": model.get("capabilities") or {},
|
||||
"supports_chat": model.get("supports_chat"),
|
||||
"supports_image_input": model.get("supports_image_input"),
|
||||
"supports_tools": model.get("supports_tools"),
|
||||
"supports_image_generation": model.get("supports_image_generation"),
|
||||
"capabilities_override": model.get("capabilities_override") or {},
|
||||
"billing_tier": model.get("billing_tier", "free"),
|
||||
"litellm_provider": connection.get("litellm_provider"),
|
||||
"provider": connection.get("provider"),
|
||||
"model_name": model.get("model_id"),
|
||||
"auto_pin_tier": catalog.get("auto_pin_tier")
|
||||
or cfg.get("auto_pin_tier")
|
||||
|
|
@ -310,9 +310,13 @@ async def _db_candidates(
|
|||
"model_id": model.model_id,
|
||||
"source": "db",
|
||||
"connection": conn,
|
||||
"capabilities": model.capabilities or {},
|
||||
"supports_chat": model.supports_chat,
|
||||
"supports_image_input": model.supports_image_input,
|
||||
"supports_tools": model.supports_tools,
|
||||
"supports_image_generation": model.supports_image_generation,
|
||||
"capabilities_override": model.capabilities_override or {},
|
||||
"billing_tier": "byok",
|
||||
"litellm_provider": conn.litellm_provider,
|
||||
"provider": conn.provider,
|
||||
"model_name": model.model_id,
|
||||
"auto_pin_tier": catalog.get("auto_pin_tier") or "BYOK",
|
||||
"quality_score": catalog.get("quality_score") or 75,
|
||||
|
|
@ -357,7 +361,7 @@ def _is_preferred_premium_auto_config(cfg: dict) -> bool:
|
|||
return (
|
||||
cfg.get("source") == "global"
|
||||
and _tier_of(cfg) == "premium"
|
||||
and str(cfg.get("litellm_provider", "")).lower() == "azure"
|
||||
and str(cfg.get("provider", "")).lower() == "azure"
|
||||
and str(cfg.get("model_name", "")).lower() == "gpt-5.4"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -18,8 +18,7 @@ def _connection_key(conn: dict[str, Any]) -> tuple[Any, ...]:
|
|||
# Deliberately includes api_key because two operator-owned credentials for
|
||||
# the same provider/base can have different quota/rate limits upstream.
|
||||
return (
|
||||
conn.get("protocol"),
|
||||
conn.get("litellm_provider"),
|
||||
conn.get("provider"),
|
||||
conn.get("base_url"),
|
||||
conn.get("api_key"),
|
||||
_freeze(conn.get("extra") or {}),
|
||||
|
|
@ -34,16 +33,6 @@ def _freeze(value: Any) -> Any:
|
|||
return value
|
||||
|
||||
|
||||
def _capabilities_for(role: str, config: dict[str, Any]) -> dict[str, bool]:
|
||||
return {
|
||||
"chat": role == "chat",
|
||||
"vision": role == "vision" or bool(config.get("supports_image_input")),
|
||||
"image_gen": role == "image_gen",
|
||||
"embedding": False,
|
||||
"tools": bool(config.get("supports_tools", False)),
|
||||
}
|
||||
|
||||
|
||||
def _catalog_metadata(config: dict[str, Any]) -> dict[str, Any]:
|
||||
return {
|
||||
"billing_tier": config.get("billing_tier", "free"),
|
||||
|
|
@ -72,7 +61,6 @@ def _catalog_metadata(config: dict[str, Any]) -> dict[str, Any]:
|
|||
def materialize_global_model_catalog(
|
||||
*,
|
||||
chat_configs: list[dict[str, Any]],
|
||||
vision_configs: list[dict[str, Any]],
|
||||
image_configs: list[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
connections: list[dict[str, Any]] = []
|
||||
|
|
@ -109,9 +97,13 @@ def materialize_global_model_catalog(
|
|||
"model_id": config["model_name"],
|
||||
"display_name": config.get("name") or config["model_name"],
|
||||
"source": "MANUAL",
|
||||
"capabilities": _capabilities_for(role, config),
|
||||
"capabilities_declared": _capabilities_for(role, config),
|
||||
"capabilities_verified": _capabilities_for(role, config),
|
||||
"supports_chat": role == "chat",
|
||||
"max_input_tokens": config.get("max_input_tokens"),
|
||||
"supports_image_input": (
|
||||
role == "chat" and bool(config.get("supports_image_input"))
|
||||
),
|
||||
"supports_tools": bool(config.get("supports_tools", False)),
|
||||
"supports_image_generation": role == "image_gen",
|
||||
"capabilities_override": {},
|
||||
"embedding_dimension": None,
|
||||
"enabled": True,
|
||||
|
|
@ -125,10 +117,6 @@ def materialize_global_model_catalog(
|
|||
if cfg.get("is_auto_mode"):
|
||||
continue
|
||||
add_config(cfg, "chat")
|
||||
for cfg in vision_configs:
|
||||
if cfg.get("is_auto_mode"):
|
||||
continue
|
||||
add_config(cfg, "vision")
|
||||
for cfg in image_configs:
|
||||
if cfg.get("is_auto_mode"):
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from app.services.auto_model_pin_service import (
|
|||
choose_auto_model_candidate,
|
||||
)
|
||||
from app.services.llm_router_service import AUTO_MODE_ID, ChatLiteLLMRouter, is_auto_mode
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.model_resolver import native_connection_from_config, to_litellm
|
||||
from app.services.token_tracking_service import token_tracker
|
||||
|
||||
|
|
@ -76,7 +77,7 @@ def _legacy_config_connection(
|
|||
api_version: str | None = None,
|
||||
) -> tuple[str, dict]:
|
||||
cfg = {
|
||||
"litellm_provider": provider.lower(),
|
||||
"provider": provider.lower(),
|
||||
"model_name": model_name,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
|
|
@ -136,12 +137,7 @@ def get_global_connection(connection_id: int) -> dict | None:
|
|||
|
||||
|
||||
def _has_capability(model: dict | Model, capability: str) -> bool:
|
||||
caps = (
|
||||
model.get("capabilities", {})
|
||||
if isinstance(model, dict)
|
||||
else model.capabilities or {}
|
||||
)
|
||||
return bool(caps.get(capability))
|
||||
return has_capability(model, capability)
|
||||
|
||||
|
||||
def _chat_litellm_from_resolved(
|
||||
|
|
@ -420,8 +416,6 @@ async def get_vision_llm(
|
|||
unwrapped — they don't consume premium credit (issue M).
|
||||
"""
|
||||
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
|
||||
from app.services.vision_llm_router_service import is_vision_auto_mode
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
|
|
@ -468,7 +462,7 @@ async def get_vision_llm(
|
|||
logger.error(f"No vision LLM configured for search space {search_space_id}")
|
||||
return None
|
||||
|
||||
if is_vision_auto_mode(config_id):
|
||||
if config_id == AUTO_MODE_ID:
|
||||
candidates = await auto_model_candidates(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
|
|
|
|||
36
surfsense_backend/app/services/model_capabilities.py
Normal file
36
surfsense_backend/app/services/model_capabilities.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
"""Override-aware model capability lookup."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
CAPABILITY_FIELDS = {
|
||||
"chat": "supports_chat",
|
||||
"vision": "supports_image_input",
|
||||
"image_gen": "supports_image_generation",
|
||||
"tools": "supports_tools",
|
||||
}
|
||||
|
||||
|
||||
def _get_value(model: Any, key: str) -> Any:
|
||||
if isinstance(model, Mapping):
|
||||
return model.get(key)
|
||||
return getattr(model, key, None)
|
||||
|
||||
|
||||
def has_capability(model: Any, capability: str) -> bool:
|
||||
field = CAPABILITY_FIELDS.get(capability)
|
||||
if field is None:
|
||||
return False
|
||||
|
||||
override = _get_value(model, "capabilities_override") or {}
|
||||
if isinstance(override, Mapping) and field in override:
|
||||
return bool(override[field])
|
||||
if isinstance(override, Mapping) and capability in override:
|
||||
return bool(override[capability])
|
||||
|
||||
return bool(_get_value(model, field))
|
||||
|
||||
|
||||
__all__ = ["CAPABILITY_FIELDS", "has_capability"]
|
||||
|
|
@ -11,8 +11,10 @@ from typing import Any
|
|||
import httpx
|
||||
import litellm
|
||||
|
||||
from app.db import Connection, ConnectionProtocol, Model, ModelSource
|
||||
from app.db import Connection, Model, ModelSource
|
||||
from app.services.model_resolver import ensure_v1, to_litellm
|
||||
from app.services.openrouter_model_normalizer import normalize_openrouter_models
|
||||
from app.services.provider_registry import Transport, spec_for
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -41,6 +43,16 @@ def _anthropic_headers(conn: Connection) -> dict[str, str]:
|
|||
return headers
|
||||
|
||||
|
||||
def _base_url_or_default(conn: Connection) -> str | None:
|
||||
if conn.base_url:
|
||||
return conn.base_url.rstrip("/")
|
||||
if conn.provider == "openai":
|
||||
return "https://api.openai.com/v1"
|
||||
if conn.provider == "anthropic":
|
||||
return "https://api.anthropic.com/v1"
|
||||
return spec_for(conn.provider).default_base_url
|
||||
|
||||
|
||||
def _docker_hint(url: str | None, exc_or_status: Any) -> str:
|
||||
raw = str(exc_or_status)
|
||||
if not url:
|
||||
|
|
@ -61,32 +73,30 @@ def _docker_hint(url: str | None, exc_or_status: Any) -> str:
|
|||
|
||||
|
||||
async def verify_connection(conn: Connection) -> VerifyResult:
|
||||
if not conn.base_url:
|
||||
spec = spec_for(conn.provider)
|
||||
base_url = _base_url_or_default(conn)
|
||||
if spec.base_url_required and not base_url:
|
||||
return VerifyResult("UNREACHABLE", False, "Base URL is required.")
|
||||
|
||||
if conn.protocol == ConnectionProtocol.OLLAMA:
|
||||
url = f"{conn.base_url.rstrip('/')}/api/version"
|
||||
elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE:
|
||||
url = f"{ensure_v1(conn.base_url)}/models"
|
||||
elif conn.protocol == ConnectionProtocol.ANTHROPIC:
|
||||
url = f"{conn.base_url.rstrip('/')}/models"
|
||||
if spec.transport == Transport.OLLAMA and base_url:
|
||||
url = f"{base_url.rstrip('/')}/api/version"
|
||||
elif spec.discovery in {"openai_models", "openrouter"} and base_url:
|
||||
url = f"{ensure_v1(base_url)}/models"
|
||||
elif spec.discovery == "anthropic_models" and base_url:
|
||||
url = f"{base_url.rstrip('/')}/models"
|
||||
else:
|
||||
return VerifyResult("UNREACHABLE", False, "Unsupported connection protocol.")
|
||||
return VerifyResult("OK", True, "Connection uses provider-native authentication.")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client:
|
||||
headers = (
|
||||
_anthropic_headers(conn)
|
||||
if conn.protocol == ConnectionProtocol.ANTHROPIC
|
||||
else _auth_headers(conn)
|
||||
)
|
||||
headers = _anthropic_headers(conn) if spec.auth_style == "x-api-key" else _auth_headers(conn)
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code in (401, 403):
|
||||
return VerifyResult("AUTH_FAILED", False, "Authentication failed.")
|
||||
if response.status_code == 404:
|
||||
if conn.protocol == ConnectionProtocol.OLLAMA and url.endswith("/v1/models"):
|
||||
if spec.transport == Transport.OLLAMA and url.endswith("/v1/models"):
|
||||
message = "Ollama native API should not use /v1."
|
||||
elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE:
|
||||
elif spec.transport == Transport.OPENAI_COMPATIBLE:
|
||||
message = "OpenAI-compatible servers should expose /v1/models."
|
||||
else:
|
||||
message = "Endpoint returned 404."
|
||||
|
|
@ -94,11 +104,11 @@ async def verify_connection(conn: Connection) -> VerifyResult:
|
|||
response.raise_for_status()
|
||||
return VerifyResult("OK", True, "Connection verified.")
|
||||
except httpx.ConnectError as exc:
|
||||
return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc))
|
||||
return VerifyResult("UNREACHABLE", False, _docker_hint(base_url, exc))
|
||||
except httpx.TimeoutException as exc:
|
||||
return VerifyResult("UNREACHABLE", False, f"Connection timed out: {exc}")
|
||||
except httpx.HTTPError as exc:
|
||||
return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc))
|
||||
return VerifyResult("UNREACHABLE", False, _docker_hint(base_url, exc))
|
||||
|
||||
|
||||
async def persist_verification(conn: Connection) -> VerifyResult:
|
||||
|
|
@ -109,123 +119,193 @@ async def persist_verification(conn: Connection) -> VerifyResult:
|
|||
return result
|
||||
|
||||
|
||||
def _litellm_capabilities(model_string: str, model_id: str) -> dict[str, bool]:
|
||||
capabilities = {
|
||||
"chat": True,
|
||||
"vision": False,
|
||||
"tools": False,
|
||||
"image_gen": False,
|
||||
"embedding": False,
|
||||
}
|
||||
with contextlib.suppress(Exception):
|
||||
capabilities["vision"] = bool(litellm.supports_vision(model=model_string))
|
||||
with contextlib.suppress(Exception):
|
||||
capabilities["tools"] = bool(litellm.supports_function_calling(model=model_string))
|
||||
try:
|
||||
info = litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {}
|
||||
mode = str(info.get("mode") or "")
|
||||
capabilities["embedding"] = mode == "embedding"
|
||||
capabilities["image_gen"] = mode in {"image_generation", "image_generation_model"}
|
||||
except Exception:
|
||||
pass
|
||||
return capabilities
|
||||
|
||||
|
||||
def _allowlist(conn: Connection) -> set[str]:
|
||||
"""Per-connection model-id allowlist stored in ``extra.model_ids``.
|
||||
|
||||
Empty/absent means "no restriction" (discover everything), mirroring
|
||||
OpenWebUI's behaviour. A non-empty list restricts discovery to those ids —
|
||||
essential for providers like OpenRouter that expose hundreds of models.
|
||||
"""
|
||||
raw = (conn.extra or {}).get("model_ids") or []
|
||||
return {str(item).strip() for item in raw if str(item).strip()}
|
||||
|
||||
|
||||
async def _discover_openai_shaped_models(conn: Connection, base_url: str | None) -> list[dict[str, Any]]:
|
||||
if not base_url:
|
||||
def _litellm_info(model_string: str, model_id: str) -> dict[str, Any]:
|
||||
with contextlib.suppress(Exception):
|
||||
info = litellm.get_model_info(model=model_string)
|
||||
if isinstance(info, dict):
|
||||
return info
|
||||
return litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {}
|
||||
|
||||
|
||||
def _classify_from_litellm(model_string: str, model_id: str) -> dict[str, Any]:
|
||||
info = _litellm_info(model_string, model_id)
|
||||
mode = info.get("mode")
|
||||
supports_image_input = False
|
||||
supports_tools = False
|
||||
with contextlib.suppress(Exception):
|
||||
supports_image_input = bool(litellm.supports_vision(model=model_string))
|
||||
with contextlib.suppress(Exception):
|
||||
supports_tools = bool(litellm.supports_function_calling(model=model_string))
|
||||
return {
|
||||
"supports_chat": mode in (None, "chat", "completion", "responses"),
|
||||
"max_input_tokens": info.get("max_input_tokens") or info.get("max_tokens"),
|
||||
"supports_image_input": supports_image_input,
|
||||
"supports_tools": supports_tools,
|
||||
"supports_image_generation": mode in {"image_generation", "image_generation_model"},
|
||||
}
|
||||
|
||||
|
||||
def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, Any]:
|
||||
metadata = metadata or {}
|
||||
spec = spec_for(conn.provider)
|
||||
model_string, _ = to_litellm(conn, model_id)
|
||||
facts = _classify_from_litellm(model_string, model_id)
|
||||
if spec.transport == Transport.OLLAMA:
|
||||
caps = set(metadata.get("capabilities") or [])
|
||||
details = metadata.get("details") or {}
|
||||
facts.update(
|
||||
{
|
||||
"supports_chat": "embedding" not in caps,
|
||||
"supports_image_input": "vision" in caps or facts["supports_image_input"],
|
||||
"supports_tools": "tools" in caps or facts["supports_tools"],
|
||||
"supports_image_generation": False,
|
||||
"max_input_tokens": metadata.get("context_length")
|
||||
or metadata.get("num_ctx")
|
||||
or details.get("context_length")
|
||||
or facts["max_input_tokens"],
|
||||
}
|
||||
)
|
||||
return facts
|
||||
|
||||
|
||||
async def _discover_openai_shaped_models(
|
||||
conn: Connection, base_url: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
resolved_base_url = base_url or _base_url_or_default(conn)
|
||||
if not resolved_base_url:
|
||||
return []
|
||||
|
||||
url = f"{ensure_v1(base_url)}/models"
|
||||
url = f"{ensure_v1(resolved_base_url)}/models"
|
||||
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
|
||||
response = await client.get(url, headers=_auth_headers(conn))
|
||||
response.raise_for_status()
|
||||
return [
|
||||
{
|
||||
"model_id": item.get("id"),
|
||||
"display_name": item.get("name") or item.get("id"),
|
||||
"source": ModelSource.DISCOVERED,
|
||||
"capabilities": derive_capabilities(conn, item.get("id"), item),
|
||||
"metadata": item,
|
||||
}
|
||||
for item in response.json().get("data", [])
|
||||
if item.get("id")
|
||||
]
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
for item in response.json().get("data", []):
|
||||
model_id = item.get("id")
|
||||
if not model_id:
|
||||
continue
|
||||
results.append(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"display_name": item.get("name") or model_id,
|
||||
"source": ModelSource.DISCOVERED,
|
||||
**derive_capabilities(conn, model_id, item),
|
||||
"metadata": item,
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
async def _discover_anthropic_models(conn: Connection) -> list[dict[str, Any]]:
|
||||
if not conn.base_url:
|
||||
base_url = _base_url_or_default(conn)
|
||||
if not base_url:
|
||||
return []
|
||||
|
||||
url = f"{conn.base_url.rstrip('/')}/models"
|
||||
url = f"{base_url.rstrip('/')}/models"
|
||||
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
|
||||
response = await client.get(url, headers=_anthropic_headers(conn))
|
||||
response.raise_for_status()
|
||||
models = response.json().get("data", [])
|
||||
return [
|
||||
{
|
||||
"model_id": item.get("id"),
|
||||
"display_name": item.get("display_name") or item.get("id"),
|
||||
"source": ModelSource.DISCOVERED,
|
||||
"capabilities": derive_capabilities(conn, item.get("id"), item),
|
||||
"metadata": item,
|
||||
}
|
||||
for item in models
|
||||
if item.get("id")
|
||||
]
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
for item in response.json().get("data", []):
|
||||
model_id = item.get("id")
|
||||
if not model_id:
|
||||
continue
|
||||
results.append(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"display_name": item.get("display_name") or model_id,
|
||||
"source": ModelSource.DISCOVERED,
|
||||
**derive_capabilities(conn, model_id, item),
|
||||
"metadata": item,
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, bool]:
|
||||
metadata = metadata or {}
|
||||
if conn.protocol == ConnectionProtocol.OLLAMA:
|
||||
caps = metadata.get("capabilities") or []
|
||||
capabilities = {
|
||||
"chat": True,
|
||||
"vision": "vision" in caps,
|
||||
"tools": False,
|
||||
"image_gen": False,
|
||||
"embedding": "embedding" in caps,
|
||||
}
|
||||
return capabilities
|
||||
async def _ollama_tags_then_show(conn: Connection) -> list[dict[str, Any]]:
|
||||
if not conn.base_url:
|
||||
return []
|
||||
|
||||
model_string, _ = to_litellm(conn, model_id)
|
||||
return _litellm_capabilities(model_string, model_id)
|
||||
base_url = conn.base_url.rstrip("/")
|
||||
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
|
||||
response = await client.get(f"{base_url}/api/tags", headers=_auth_headers(conn))
|
||||
response.raise_for_status()
|
||||
models = response.json().get("models", [])
|
||||
results: list[dict[str, Any]] = []
|
||||
for item in models:
|
||||
model_id = item.get("model") or item.get("name")
|
||||
if not model_id:
|
||||
continue
|
||||
metadata = dict(item)
|
||||
with contextlib.suppress(Exception):
|
||||
show_response = await client.post(
|
||||
f"{base_url}/api/show",
|
||||
json={"model": model_id},
|
||||
headers=_auth_headers(conn),
|
||||
)
|
||||
show_response.raise_for_status()
|
||||
metadata.update(show_response.json())
|
||||
results.append(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"display_name": item.get("name") or model_id,
|
||||
"source": ModelSource.DISCOVERED,
|
||||
**derive_capabilities(conn, model_id, metadata),
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
async def _openrouter_models(conn: Connection) -> list[dict[str, Any]]:
|
||||
base_url = _base_url_or_default(conn) or "https://openrouter.ai/api/v1"
|
||||
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
|
||||
response = await client.get(f"{ensure_v1(base_url)}/models", headers=_auth_headers(conn))
|
||||
response.raise_for_status()
|
||||
return normalize_openrouter_models(response.json().get("data", []))
|
||||
|
||||
|
||||
def _litellm_static_models(conn: Connection) -> list[dict[str, Any]]:
|
||||
provider = conn.provider
|
||||
prefix = spec_for(provider).litellm_prefix or provider
|
||||
results: list[dict[str, Any]] = []
|
||||
for model_string, metadata in litellm.model_cost.items():
|
||||
if not isinstance(model_string, str) or not model_string.startswith(f"{prefix}/"):
|
||||
continue
|
||||
model_id = model_string.split("/", 1)[1]
|
||||
results.append(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"display_name": metadata.get("display_name") or model_id,
|
||||
"source": ModelSource.DISCOVERED,
|
||||
**_classify_from_litellm(model_string, model_id),
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
async def discover_models(conn: Connection) -> list[dict[str, Any]]:
|
||||
allowlist = _allowlist(conn)
|
||||
spec = spec_for(conn.provider)
|
||||
|
||||
if conn.protocol == ConnectionProtocol.OLLAMA:
|
||||
url = f"{conn.base_url.rstrip('/')}/api/tags"
|
||||
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
|
||||
response = await client.get(url, headers=_auth_headers(conn))
|
||||
response.raise_for_status()
|
||||
models = response.json().get("models", [])
|
||||
results = [
|
||||
{
|
||||
"model_id": item.get("model") or item.get("name"),
|
||||
"display_name": item.get("name") or item.get("model"),
|
||||
"source": ModelSource.DISCOVERED,
|
||||
"capabilities": derive_capabilities(conn, item.get("model") or item.get("name"), item),
|
||||
"metadata": item,
|
||||
}
|
||||
for item in models
|
||||
if item.get("model") or item.get("name")
|
||||
]
|
||||
elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE:
|
||||
results = await _discover_openai_shaped_models(conn, conn.base_url)
|
||||
elif conn.protocol == ConnectionProtocol.ANTHROPIC:
|
||||
if spec.discovery == "ollama":
|
||||
results = await _ollama_tags_then_show(conn)
|
||||
elif spec.discovery == "openrouter":
|
||||
results = await _openrouter_models(conn)
|
||||
elif spec.discovery == "anthropic_models":
|
||||
results = await _discover_anthropic_models(conn)
|
||||
elif spec.discovery == "openai_models":
|
||||
results = await _discover_openai_shaped_models(conn, conn.base_url)
|
||||
elif spec.discovery == "static":
|
||||
results = _litellm_static_models(conn)
|
||||
else:
|
||||
results = []
|
||||
|
||||
|
|
@ -246,10 +326,7 @@ async def test_model(conn: Connection, model: Model) -> VerifyResult:
|
|||
except Exception as exc:
|
||||
return VerifyResult("UNREACHABLE", False, str(exc))
|
||||
|
||||
model.capabilities_verified = {
|
||||
**(model.capabilities_verified or {}),
|
||||
"chat": True,
|
||||
}
|
||||
model.supports_chat = True
|
||||
return VerifyResult("OK", True, "Model test succeeded.")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ from pathlib import Path
|
|||
|
||||
import httpx
|
||||
|
||||
from app.services.openrouter_model_normalizer import normalize_openrouter_models
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
|
||||
|
|
@ -121,26 +123,13 @@ def _process_models(raw_models: list[dict]) -> list[dict]:
|
|||
"""
|
||||
processed: list[dict] = []
|
||||
|
||||
for model in raw_models:
|
||||
model_id: str = model.get("id", "")
|
||||
name: str = model.get("name", "")
|
||||
context_length = model.get("context_length")
|
||||
|
||||
for normalized in normalize_openrouter_models(raw_models):
|
||||
model_id: str = normalized["model_id"]
|
||||
name: str = normalized.get("display_name") or model_id
|
||||
context_length = normalized.get("max_input_tokens")
|
||||
if "/" not in model_id:
|
||||
continue
|
||||
|
||||
if not _is_text_output_model(model):
|
||||
continue
|
||||
|
||||
if not _supports_tool_calling(model):
|
||||
continue
|
||||
|
||||
if not _has_sufficient_context(model):
|
||||
continue
|
||||
|
||||
if not _is_allowed_model(model):
|
||||
continue
|
||||
|
||||
provider_slug, model_name = model_id.split("/", 1)
|
||||
context_window = _format_context_length(context_length)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,9 +12,7 @@ from typing import TYPE_CHECKING, Any
|
|||
if TYPE_CHECKING:
|
||||
from app.db import Connection
|
||||
|
||||
PROTOCOL_OLLAMA = "OLLAMA"
|
||||
PROTOCOL_OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE"
|
||||
PROTOCOL_ANTHROPIC = "ANTHROPIC"
|
||||
from app.services.provider_registry import Transport, spec_for
|
||||
|
||||
|
||||
def ensure_v1(base_url: str | None) -> str | None:
|
||||
|
|
@ -32,47 +30,25 @@ def _conn_value(conn: Connection | Mapping[str, Any], key: str) -> Any:
|
|||
return getattr(conn, key)
|
||||
|
||||
|
||||
def _protocol_value(protocol: Any) -> str:
|
||||
return getattr(protocol, "value", str(protocol))
|
||||
|
||||
|
||||
def default_litellm_provider(protocol: Any) -> str:
|
||||
protocol_value = _protocol_value(protocol)
|
||||
defaults = {
|
||||
PROTOCOL_OLLAMA: "ollama_chat",
|
||||
PROTOCOL_ANTHROPIC: "anthropic",
|
||||
PROTOCOL_OPENAI_COMPATIBLE: "openai",
|
||||
}
|
||||
return defaults.get(protocol_value, "openai")
|
||||
|
||||
|
||||
def _execution_api_base(protocol: str, base_url: str | None) -> str | None:
|
||||
del protocol
|
||||
if not base_url:
|
||||
return None
|
||||
return base_url.rstrip("/")
|
||||
|
||||
|
||||
def to_litellm(
|
||||
conn: Connection | Mapping[str, Any],
|
||||
model_id: str,
|
||||
) -> tuple[str, dict[str, Any]]:
|
||||
"""Return ``(model_string, litellm_kwargs)`` for any model role."""
|
||||
protocol = _protocol_value(_conn_value(conn, "protocol"))
|
||||
provider = _conn_value(conn, "provider")
|
||||
base_url = _conn_value(conn, "base_url")
|
||||
api_key = _conn_value(conn, "api_key")
|
||||
litellm_provider = (
|
||||
_conn_value(conn, "litellm_provider") or default_litellm_provider(protocol)
|
||||
)
|
||||
extra = _conn_value(conn, "extra") or {}
|
||||
spec = spec_for(provider)
|
||||
|
||||
kwargs: dict[str, Any] = {}
|
||||
if api_key:
|
||||
kwargs["api_key"] = api_key
|
||||
|
||||
model_string = f"{litellm_provider}/{model_id}" if litellm_provider else model_id
|
||||
api_base = _execution_api_base(protocol, base_url)
|
||||
if api_base:
|
||||
prefix = spec.litellm_prefix or str(provider)
|
||||
model_string = f"{prefix}/{model_id}" if prefix else model_id
|
||||
if base_url:
|
||||
api_base = ensure_v1(base_url) if spec.transport == Transport.OPENAI_COMPATIBLE else base_url.rstrip("/")
|
||||
kwargs["api_base"] = api_base
|
||||
|
||||
if api_version := extra.get("api_version"):
|
||||
|
|
@ -84,11 +60,11 @@ def to_litellm(
|
|||
|
||||
def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"""Build an in-memory connection mapping from a global config."""
|
||||
protocol = str(config.get("protocol") or PROTOCOL_OPENAI_COMPATIBLE)
|
||||
litellm_provider = str(
|
||||
config.get("litellm_provider")
|
||||
provider = str(
|
||||
config.get("provider")
|
||||
or config.get("litellm_provider")
|
||||
or config.get("custom_provider")
|
||||
or default_litellm_provider(protocol)
|
||||
or "openai"
|
||||
)
|
||||
extra: dict[str, Any] = {
|
||||
"litellm_params": config.get("litellm_params") or {},
|
||||
|
|
@ -96,8 +72,7 @@ def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]:
|
|||
if config.get("api_version"):
|
||||
extra["api_version"] = config.get("api_version")
|
||||
return {
|
||||
"protocol": protocol,
|
||||
"litellm_provider": litellm_provider,
|
||||
"provider": provider,
|
||||
"base_url": config.get("api_base") or None,
|
||||
"api_key": config.get("api_key") or None,
|
||||
"extra": extra,
|
||||
|
|
@ -105,7 +80,6 @@ def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]:
|
|||
|
||||
|
||||
__all__ = [
|
||||
"default_litellm_provider",
|
||||
"ensure_v1",
|
||||
"native_connection_from_config",
|
||||
"to_litellm",
|
||||
|
|
|
|||
|
|
@ -29,6 +29,13 @@ from app.services.quality_score import (
|
|||
aggregate_health,
|
||||
static_score_or,
|
||||
)
|
||||
from app.services.openrouter_model_normalizer import (
|
||||
is_allowed_model as _shared_is_allowed_model,
|
||||
is_compatible_provider as _shared_is_compatible_provider,
|
||||
is_openrouter_image_model,
|
||||
normalize_openrouter_models,
|
||||
supports_image_input,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -292,24 +299,16 @@ def _generate_configs(
|
|||
use_default: bool = settings.get("use_default_system_instructions", True)
|
||||
citations_enabled: bool = settings.get("citations_enabled", True)
|
||||
|
||||
text_models = [
|
||||
m
|
||||
for m in raw_models
|
||||
if _is_text_output_model(m)
|
||||
and _supports_tool_calling(m)
|
||||
and _has_sufficient_context(m)
|
||||
and _is_compatible_provider(m)
|
||||
and _is_allowed_model(m)
|
||||
and "/" in m.get("id", "")
|
||||
]
|
||||
text_models = normalize_openrouter_models(raw_models)
|
||||
|
||||
configs: list[dict] = []
|
||||
taken: set[int] = set()
|
||||
now_ts = int(time.time())
|
||||
|
||||
for model in text_models:
|
||||
model_id: str = model["id"]
|
||||
name: str = model.get("name", model_id)
|
||||
for normalized in text_models:
|
||||
model = normalized.get("metadata") or {}
|
||||
model_id: str = normalized["model_id"]
|
||||
name: str = normalized.get("display_name") or model_id
|
||||
tier = _openrouter_tier(model)
|
||||
|
||||
static_q = static_score_or(model, now_ts=now_ts)
|
||||
|
|
@ -323,7 +322,7 @@ def _generate_configs(
|
|||
"seo_enabled": seo_enabled,
|
||||
"seo_slug": None,
|
||||
"quota_reserve_tokens": quota_reserve_tokens,
|
||||
"litellm_provider": "openrouter",
|
||||
"provider": "openrouter",
|
||||
"model_name": model_id,
|
||||
"api_key": api_key,
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
|
|
@ -345,7 +344,7 @@ def _generate_configs(
|
|||
# ``stream_new_chat`` as a fail-fast safety net before the
|
||||
# OpenRouter request would otherwise 404 with
|
||||
# ``"No endpoints found that support image input"``.
|
||||
"supports_image_input": _supports_image_input(model),
|
||||
"supports_image_input": bool(normalized.get("supports_image_input")),
|
||||
_OPENROUTER_DYNAMIC_MARKER: True,
|
||||
# Auto (Fastest) ranking metadata. ``quality_score`` is initialised
|
||||
# to the static score and gets re-blended with health on the next
|
||||
|
|
@ -403,10 +402,7 @@ def _generate_image_gen_configs(
|
|||
image_models = [
|
||||
m
|
||||
for m in raw_models
|
||||
if _is_image_output_model(m)
|
||||
and _is_compatible_provider(m)
|
||||
and _is_allowed_model(m)
|
||||
and "/" in m.get("id", "")
|
||||
if is_openrouter_image_model(m)
|
||||
]
|
||||
|
||||
configs: list[dict] = []
|
||||
|
|
@ -420,7 +416,7 @@ def _generate_image_gen_configs(
|
|||
"id": _stable_config_id(model_id, id_offset, taken),
|
||||
"name": name,
|
||||
"description": f"{name} via OpenRouter (image generation)",
|
||||
"litellm_provider": "openrouter",
|
||||
"provider": "openrouter",
|
||||
"model_name": model_id,
|
||||
"api_key": api_key,
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
|
|
@ -468,9 +464,9 @@ def _generate_vision_llm_configs(
|
|||
vision_models = [
|
||||
m
|
||||
for m in raw_models
|
||||
if _is_vision_input_model(m)
|
||||
and _is_compatible_provider(m)
|
||||
and _is_allowed_model(m)
|
||||
if supports_image_input(m)
|
||||
and _shared_is_compatible_provider(m)
|
||||
and _shared_is_allowed_model(m)
|
||||
and "/" in m.get("id", "")
|
||||
]
|
||||
|
||||
|
|
@ -499,7 +495,7 @@ def _generate_vision_llm_configs(
|
|||
"id": _stable_config_id(model_id, id_offset, taken),
|
||||
"name": name,
|
||||
"description": f"{name} via OpenRouter (vision)",
|
||||
"litellm_provider": "openrouter",
|
||||
"provider": "openrouter",
|
||||
"model_name": model_id,
|
||||
"api_key": api_key,
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
|
|
@ -544,11 +540,9 @@ class OpenRouterIntegrationService:
|
|||
# Cached raw catalogue from the most recent fetch. Image / vision
|
||||
# emitters reuse this to avoid a second network call per surface.
|
||||
self._raw_models: list[dict] = []
|
||||
# Image / vision config caches (only populated when the matching
|
||||
# opt-in flag is true on initialize). Refreshed in lockstep with
|
||||
# the chat catalogue.
|
||||
# Image config cache (only populated when the matching opt-in flag is
|
||||
# true on initialize). Refreshed in lockstep with the chat catalogue.
|
||||
self._image_configs: list[dict] = []
|
||||
self._vision_configs: list[dict] = []
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "OpenRouterIntegrationService":
|
||||
|
|
@ -583,7 +577,7 @@ class OpenRouterIntegrationService:
|
|||
self._configs_by_id = {c["id"]: c for c in self._configs}
|
||||
self._raw_pricing = _extract_raw_pricing(raw_models)
|
||||
|
||||
# Populate image / vision caches when their opt-in flag is set.
|
||||
# Populate image cache when its opt-in flag is set.
|
||||
# Empty otherwise so the accessors return [] without re-running
|
||||
# filters every refresh.
|
||||
if settings.get("image_generation_enabled"):
|
||||
|
|
@ -595,15 +589,6 @@ class OpenRouterIntegrationService:
|
|||
else:
|
||||
self._image_configs = []
|
||||
|
||||
if settings.get("vision_enabled"):
|
||||
self._vision_configs = _generate_vision_llm_configs(raw_models, settings)
|
||||
logger.info(
|
||||
"OpenRouter integration: vision LLM emission ON (%d models)",
|
||||
len(self._vision_configs),
|
||||
)
|
||||
else:
|
||||
self._vision_configs = []
|
||||
|
||||
self._initialized = True
|
||||
|
||||
tier_counts = self._tier_counts(self._configs)
|
||||
|
|
@ -657,9 +642,9 @@ class OpenRouterIntegrationService:
|
|||
self._configs = new_configs
|
||||
self._configs_by_id = new_by_id
|
||||
|
||||
# Image / vision lists are atomic-swapped the same way: filter out
|
||||
# Image list is atomic-swapped the same way: filter out
|
||||
# the previous dynamic entries from the live config list and append
|
||||
# the freshly generated ones. No-ops when the opt-in flag is off.
|
||||
# the freshly generated ones. No-op when the opt-in flag is off.
|
||||
if self._settings.get("image_generation_enabled"):
|
||||
new_image = _generate_image_gen_configs(raw_models, self._settings)
|
||||
static_image = [
|
||||
|
|
@ -670,16 +655,6 @@ class OpenRouterIntegrationService:
|
|||
app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image
|
||||
self._image_configs = new_image
|
||||
|
||||
if self._settings.get("vision_enabled"):
|
||||
new_vision = _generate_vision_llm_configs(raw_models, self._settings)
|
||||
static_vision = [
|
||||
c
|
||||
for c in app_config.GLOBAL_VISION_LLM_CONFIGS
|
||||
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
|
||||
]
|
||||
app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision
|
||||
self._vision_configs = new_vision
|
||||
|
||||
# Catalogue churn invalidates per-config "recently healthy" credit
|
||||
# earned by the previous turn's preflight. Drop the whole table so
|
||||
# the next turn re-probes against the freshly loaded configs.
|
||||
|
|
@ -701,7 +676,7 @@ class OpenRouterIntegrationService:
|
|||
)
|
||||
|
||||
# Re-blend health scores against the freshly fetched catalogue. Also
|
||||
# re-stamps health for any YAML-curated cfg with litellm_provider=openrouter
|
||||
# re-stamps health for any YAML-curated cfg with provider=openrouter
|
||||
# so a hand-picked dead OR model is gated like a dynamic one.
|
||||
await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
|
||||
|
||||
|
|
@ -778,7 +753,7 @@ class OpenRouterIntegrationService:
|
|||
the entire previous cycle's cache for this run.
|
||||
"""
|
||||
or_cfgs = [
|
||||
c for c in configs if str(c.get("litellm_provider", "")).lower() == "openrouter"
|
||||
c for c in configs if str(c.get("provider", "")).lower() == "openrouter"
|
||||
]
|
||||
if not or_cfgs:
|
||||
return
|
||||
|
|
@ -959,17 +934,6 @@ class OpenRouterIntegrationService:
|
|||
"""
|
||||
return list(self._image_configs)
|
||||
|
||||
def get_vision_llm_configs(self) -> list[dict]:
|
||||
"""Return the dynamic OpenRouter vision-LLM configs (empty list
|
||||
when the ``vision_enabled`` flag is off).
|
||||
|
||||
Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token``
|
||||
so ``pricing_registration`` can teach LiteLLM the cost of these
|
||||
models the same way it does for chat — which keeps the billable
|
||||
wrapper able to debit accurate micro-USD on a vision call.
|
||||
"""
|
||||
return list(self._vision_configs)
|
||||
|
||||
def get_raw_pricing(self) -> dict[str, dict[str, str]]:
|
||||
"""Return the cached raw OpenRouter pricing map.
|
||||
|
||||
|
|
|
|||
121
surfsense_backend/app/services/openrouter_model_normalizer.py
Normal file
121
surfsense_backend/app/services/openrouter_model_normalizer.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
"""Shared OpenRouter model normalization.
|
||||
|
||||
OpenRouter metadata is richer than generic OpenAI-compatible ``/models``
|
||||
responses. Keep all OpenRouter filtering and capability extraction here so
|
||||
GLOBAL catalogue generation and BYOK discovery agree.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.db import ModelSource
|
||||
|
||||
# Smallest ``context_length`` a model may report and still pass
# ``has_sufficient_context``.
MIN_CONTEXT_LENGTH = 100_000

# Provider slugs (the part of the model id before the first "/") that
# ``is_compatible_provider`` rejects.
EXCLUDED_PROVIDER_SLUGS = {"amazon"}

# Exact model ids that ``is_allowed_model`` rejects.
EXCLUDED_MODEL_IDS: set[str] = {
    "arcee-ai/virtuoso-large",
    "openai/gpt-4-1106-preview",
    "openai/gpt-4-turbo-preview",
    "openai/gpt-4o:extended",
    "openai/o3-deep-research",
    "openai/o4-mini-deep-research",
    "openrouter/free",
}

# Id suffixes (checked on the id with any ":variant" tail removed) that
# ``is_allowed_model`` rejects.
EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",)
|
||||
|
||||
|
||||
def is_text_output_model(model: dict[str, Any]) -> bool:
    """Return True when the model's output modalities are exactly ``["text"]``.

    Args:
        model: Raw OpenRouter model entry.

    Returns:
        ``True`` only if ``architecture.output_modalities`` equals ``["text"]``;
        ``False`` when the field is missing, null, or lists anything else.
    """
    # ``.get("architecture", {})`` would hand back None for an explicit null
    # value and crash on the chained lookup, so normalise with ``or``.
    architecture = model.get("architecture") or {}
    return architecture.get("output_modalities") == ["text"]
|
||||
|
||||
|
||||
def is_image_output_model(model: dict[str, Any]) -> bool:
    """Return True when ``"image"`` is among the model's output modalities.

    Args:
        model: Raw OpenRouter model entry.

    Returns:
        ``True`` if ``architecture.output_modalities`` contains ``"image"``;
        ``False`` when the field is missing or null.
    """
    # Tolerate an explicit ``"architecture": null`` as well as a missing key.
    architecture = model.get("architecture") or {}
    return "image" in (architecture.get("output_modalities") or [])
|
||||
|
||||
|
||||
def supports_image_input(model: dict[str, Any]) -> bool:
    """Return True when ``"image"`` is among the model's input modalities.

    Args:
        model: Raw OpenRouter model entry.

    Returns:
        ``True`` if ``architecture.input_modalities`` contains ``"image"``;
        ``False`` when the field is missing or null.
    """
    # Tolerate an explicit ``"architecture": null`` as well as a missing key.
    architecture = model.get("architecture") or {}
    return "image" in (architecture.get("input_modalities") or [])
|
||||
|
||||
|
||||
def supports_tool_calling(model: dict[str, Any]) -> bool:
    """Return True when the entry lists ``"tools"`` in ``supported_parameters``.

    A missing or null ``supported_parameters`` counts as no tool support.
    """
    return "tools" in (model.get("supported_parameters") or ())
|
||||
|
||||
|
||||
def has_sufficient_context(model: dict[str, Any]) -> bool:
    """Return True when ``context_length`` is at least ``MIN_CONTEXT_LENGTH``.

    A missing, null, or otherwise falsy ``context_length`` is treated as 0.
    """
    declared = model.get("context_length")
    tokens = int(declared) if declared else 0
    return tokens >= MIN_CONTEXT_LENGTH
|
||||
|
||||
|
||||
def is_compatible_provider(model: dict[str, Any]) -> bool:
    """Return True unless the model id's provider slug is excluded.

    The slug is the text before the first "/" in ``id``; an id without a
    "/" yields an empty slug.
    """
    identifier = str(model.get("id") or "")
    head, separator, _ = identifier.partition("/")
    slug = head if separator else ""
    return slug not in EXCLUDED_PROVIDER_SLUGS
|
||||
|
||||
|
||||
def is_allowed_model(model: dict[str, Any]) -> bool:
    """Return False for ids on the exclusion list or with an excluded suffix.

    The suffix check runs on the id with everything from the first ":"
    onward removed.
    """
    identifier = str(model.get("id") or "")
    base_identifier = identifier.partition(":")[0]
    return (
        identifier not in EXCLUDED_MODEL_IDS
        and not base_identifier.endswith(EXCLUDED_MODEL_SUFFIXES)
    )
|
||||
|
||||
|
||||
def is_openrouter_chat_model(model: dict[str, Any]) -> bool:
    """Return True when a raw OpenRouter entry qualifies as a chat model.

    The id must contain a "/", and the entry must pass every predicate
    below, evaluated in order and stopping at the first failure.
    """
    if "/" not in str(model.get("id") or ""):
        return False
    predicates = (
        is_text_output_model,
        supports_tool_calling,
        has_sufficient_context,
        is_compatible_provider,
        is_allowed_model,
    )
    return all(predicate(model) for predicate in predicates)
|
||||
|
||||
|
||||
def is_openrouter_image_model(model: dict[str, Any]) -> bool:
    """Return True when a raw OpenRouter entry qualifies as an image model.

    The id must contain a "/", the model must emit images, and it must pass
    the provider and model exclusion filters.
    """
    if "/" not in str(model.get("id") or ""):
        return False
    predicates = (
        is_image_output_model,
        is_compatible_provider,
        is_allowed_model,
    )
    return all(predicate(model) for predicate in predicates)
|
||||
|
||||
|
||||
def normalize_openrouter_models(raw_models: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert raw OpenRouter entries into normalized chat-model records.

    Entries that fail ``is_openrouter_chat_model`` are dropped. Each kept
    entry becomes a dict with its id, display name, capability flags, and
    the untouched raw entry under ``"metadata"``.
    """

    def _to_record(raw: dict[str, Any]) -> dict[str, Any]:
        identifier = str(raw.get("id") or "")
        return {
            "model_id": identifier,
            "display_name": raw.get("name") or identifier,
            "source": ModelSource.DISCOVERED,
            "supports_chat": True,
            "max_input_tokens": raw.get("context_length"),
            "supports_image_input": supports_image_input(raw),
            "supports_tools": supports_tool_calling(raw),
            "supports_image_generation": False,
            "metadata": raw,
        }

    return [_to_record(raw) for raw in raw_models if is_openrouter_chat_model(raw)]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MIN_CONTEXT_LENGTH",
|
||||
"has_sufficient_context",
|
||||
"is_allowed_model",
|
||||
"is_compatible_provider",
|
||||
"is_image_output_model",
|
||||
"is_openrouter_chat_model",
|
||||
"is_openrouter_image_model",
|
||||
"is_text_output_model",
|
||||
"normalize_openrouter_models",
|
||||
"supports_image_input",
|
||||
"supports_tool_calling",
|
||||
]
|
||||
|
|
@ -143,7 +143,7 @@ def _register_chat_shape_configs(
|
|||
sample_keys: list[str] = []
|
||||
|
||||
for cfg in configs:
|
||||
provider = str(cfg.get("litellm_provider") or "").lower()
|
||||
provider = str(cfg.get("provider") or cfg.get("litellm_provider") or "").lower()
|
||||
model_name = str(cfg.get("model_name") or "").strip()
|
||||
litellm_params = cfg.get("litellm_params") or {}
|
||||
base_model = str(litellm_params.get("base_model") or model_name).strip()
|
||||
|
|
@ -216,9 +216,8 @@ def _register_chat_shape_configs(
|
|||
def register_pricing_from_global_configs() -> None:
|
||||
"""Register pricing for every known LLM deployment with LiteLLM.
|
||||
|
||||
Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS``
|
||||
so vision calls (during indexing) can resolve cost the same way chat
|
||||
calls do — namely:
|
||||
Walks ``config.GLOBAL_LLM_CONFIGS`` so chat and vision calls can resolve
|
||||
cost from the same chat-shaped deployment configs:
|
||||
|
||||
1. ``OPENROUTER``: pulls the cached raw pricing from
|
||||
``OpenRouterIntegrationService`` (populated during its own
|
||||
|
|
@ -245,10 +244,7 @@ def register_pricing_from_global_configs() -> None:
|
|||
from app.config import config as app_config
|
||||
|
||||
chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or [])
|
||||
vision_configs: list[dict] = list(
|
||||
getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or []
|
||||
)
|
||||
if not chat_configs and not vision_configs:
|
||||
if not chat_configs:
|
||||
logger.info("[PricingRegistration] no global configs to register")
|
||||
return
|
||||
|
||||
|
|
@ -267,7 +263,3 @@ def register_pricing_from_global_configs() -> None:
|
|||
|
||||
if chat_configs:
|
||||
_register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat")
|
||||
if vision_configs:
|
||||
_register_chat_shape_configs(
|
||||
vision_configs, or_pricing=or_pricing, label="vision"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
def _candidate_model_strings(
|
||||
*,
|
||||
litellm_provider: str | None,
|
||||
provider: str | None,
|
||||
model_name: str | None,
|
||||
base_model: str | None,
|
||||
custom_provider: str | None,
|
||||
|
|
@ -78,7 +78,7 @@ def _candidate_model_strings(
|
|||
seen.add(key)
|
||||
candidates.append(key)
|
||||
|
||||
provider_prefix = custom_provider or litellm_provider
|
||||
provider_prefix = custom_provider or provider
|
||||
|
||||
primary_model = base_model or model_name
|
||||
bare_model = model_name
|
||||
|
|
@ -113,7 +113,7 @@ def _candidate_model_strings(
|
|||
|
||||
def derive_supports_image_input(
|
||||
*,
|
||||
litellm_provider: str | None = None,
|
||||
provider: str | None = None,
|
||||
model_name: str | None = None,
|
||||
base_model: str | None = None,
|
||||
custom_provider: str | None = None,
|
||||
|
|
@ -147,7 +147,7 @@ def derive_supports_image_input(
|
|||
return False
|
||||
|
||||
for model_string, custom_llm_provider in _candidate_model_strings(
|
||||
litellm_provider=litellm_provider,
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=custom_provider,
|
||||
|
|
@ -172,7 +172,7 @@ def derive_supports_image_input(
|
|||
|
||||
def is_known_text_only_chat_model(
|
||||
*,
|
||||
litellm_provider: str | None = None,
|
||||
provider: str | None = None,
|
||||
model_name: str | None = None,
|
||||
base_model: str | None = None,
|
||||
custom_provider: str | None = None,
|
||||
|
|
@ -193,7 +193,7 @@ def is_known_text_only_chat_model(
|
|||
leads to the regression we're fixing here.
|
||||
"""
|
||||
for model_string, custom_llm_provider in _candidate_model_strings(
|
||||
litellm_provider=litellm_provider,
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=custom_provider,
|
||||
|
|
|
|||
98
surfsense_backend/app/services/provider_registry.py
Normal file
98
surfsense_backend/app/services/provider_registry.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
"""Provider registry for model connections.
|
||||
|
||||
The provider string is the single public identity axis. This registry only
|
||||
describes providers whose behavior differs from LiteLLM's native default.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class Transport(StrEnum):
|
||||
NATIVE = "NATIVE"
|
||||
OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE"
|
||||
OLLAMA = "OLLAMA"
|
||||
|
||||
|
||||
# Allowed values for ``ProviderSpec.discovery``.
DiscoveryKind = Literal[
    "ollama", "openai_models", "anthropic_models", "openrouter", "static", "none"
]

# Allowed values for ``ProviderSpec.auth_style``.
AuthStyle = Literal["bearer", "x-api-key", "none", "native"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ProviderSpec:
    """Immutable description of how one provider is reached and enumerated."""

    # Protocol family used to talk to the provider.
    transport: Transport
    # Prefix for the LiteLLM model string; None means no registry-supplied prefix.
    litellm_prefix: str | None
    # Which model-discovery strategy applies to this provider.
    discovery: DiscoveryKind
    # Base URL presumably used when the connection has none — confirm with callers.
    default_base_url: str | None
    # Presumably: whether a connection must carry a base URL — confirm with callers.
    base_url_required: bool
    # Presumably: how credentials are attached to requests — confirm with callers.
    auth_style: AuthStyle
|
||||
|
||||
|
||||
# Provider name -> spec. Names not listed here get a generic spec from
# ``spec_for``.
REGISTRY: dict[str, ProviderSpec] = {
    "openai": ProviderSpec(
        transport=Transport.NATIVE,
        litellm_prefix="openai",
        discovery="openai_models",
        default_base_url=None,
        base_url_required=False,
        auth_style="bearer",
    ),
    "anthropic": ProviderSpec(
        transport=Transport.NATIVE,
        litellm_prefix="anthropic",
        discovery="anthropic_models",
        default_base_url=None,
        base_url_required=False,
        auth_style="x-api-key",
    ),
    "azure": ProviderSpec(
        transport=Transport.NATIVE,
        litellm_prefix="azure",
        discovery="static",
        default_base_url=None,
        base_url_required=True,
        auth_style="native",
    ),
    "vertex_ai": ProviderSpec(
        transport=Transport.NATIVE,
        litellm_prefix="vertex_ai",
        discovery="static",
        default_base_url=None,
        base_url_required=False,
        auth_style="native",
    ),
    "bedrock": ProviderSpec(
        transport=Transport.NATIVE,
        litellm_prefix="bedrock",
        discovery="static",
        default_base_url=None,
        base_url_required=False,
        auth_style="native",
    ),
    "openrouter": ProviderSpec(
        transport=Transport.OPENAI_COMPATIBLE,
        litellm_prefix="openrouter",
        discovery="openrouter",
        default_base_url="https://openrouter.ai/api/v1",
        base_url_required=False,
        auth_style="bearer",
    ),
    "openai_compatible": ProviderSpec(
        transport=Transport.OPENAI_COMPATIBLE,
        litellm_prefix="openai",
        discovery="openai_models",
        default_base_url=None,
        base_url_required=True,
        auth_style="bearer",
    ),
    "lm_studio": ProviderSpec(
        transport=Transport.OPENAI_COMPATIBLE,
        litellm_prefix="openai",
        discovery="openai_models",
        default_base_url="http://localhost:1234/v1",
        base_url_required=True,
        auth_style="bearer",
    ),
    "ollama_chat": ProviderSpec(
        transport=Transport.OLLAMA,
        litellm_prefix="ollama_chat",
        discovery="ollama",
        default_base_url="http://localhost:11434",
        base_url_required=True,
        auth_style="none",
    ),
}
|
||||
|
||||
|
||||
def spec_for(provider: str | None) -> ProviderSpec:
    """Return the registered spec for ``provider`` or a generic native one.

    The name is stripped of surrounding whitespace before the lookup. Names
    not in ``REGISTRY`` get a ``NATIVE`` spec with static discovery whose
    LiteLLM prefix is the name itself, or ``"openai"`` when the name is
    empty or None.
    """
    provider_key = (provider or "").strip()
    registered = REGISTRY.get(provider_key)
    if registered:
        return registered
    return ProviderSpec(
        Transport.NATIVE,
        provider_key or "openai",
        "static",
        None,
        False,
        "native",
    )
|
||||
|
||||
|
||||
__all__ = ["REGISTRY", "ProviderSpec", "Transport", "spec_for"]
|
||||
|
|
@ -273,7 +273,7 @@ def static_score_yaml(cfg: dict) -> int:
|
|||
listed this model. Pricing / context fall through to lazy ``litellm``
|
||||
lookups; failures are silent (we just lose those sub-points).
|
||||
"""
|
||||
provider = str(cfg.get("litellm_provider", "")).lower()
|
||||
provider = str(cfg.get("provider") or cfg.get("litellm_provider") or "").lower()
|
||||
base = PROVIDER_PRESTIGE_YAML.get(provider, 15)
|
||||
|
||||
model_name = cfg.get("model_name") or ""
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ def check_image_input_capability(
|
|||
else None
|
||||
)
|
||||
if not is_known_text_only_chat_model(
|
||||
litellm_provider=agent_config.provider,
|
||||
provider=agent_config.provider,
|
||||
model_name=agent_config.model_name,
|
||||
base_model=agent_base_model,
|
||||
custom_provider=agent_config.custom_provider,
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from app.agents.chat.runtime.llm_config import (
|
|||
)
|
||||
from app.config import config
|
||||
from app.db import Model, SearchSpace
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.model_resolver import to_litellm
|
||||
|
||||
|
||||
|
|
@ -96,7 +97,7 @@ async def load_llm_bundle(
|
|||
model_id=config_id,
|
||||
search_space=search_space,
|
||||
)
|
||||
if not model or not (model.capabilities or {}).get("chat"):
|
||||
if not model or not has_capability(model, "chat"):
|
||||
return (
|
||||
None,
|
||||
None,
|
||||
|
|
@ -106,12 +107,12 @@ async def load_llm_bundle(
|
|||
agent_config = _agent_config_from_resolved(
|
||||
config_id=config_id,
|
||||
config_name=model.display_name or model.model_id,
|
||||
provider=model.connection.litellm_provider or "",
|
||||
provider=model.connection.provider or "",
|
||||
model_name=model.model_id,
|
||||
api_key=model.connection.api_key,
|
||||
api_base=model.connection.base_url,
|
||||
litellm_params=(model.connection.extra or {}).get("litellm_params"),
|
||||
supports_image_input=bool((model.capabilities or {}).get("vision")),
|
||||
supports_image_input=has_capability(model, "vision"),
|
||||
billing_tier="free",
|
||||
)
|
||||
return (
|
||||
|
|
@ -121,7 +122,7 @@ async def load_llm_bundle(
|
|||
)
|
||||
|
||||
global_model = next((m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None)
|
||||
if not global_model or not (global_model.get("capabilities") or {}).get("chat"):
|
||||
if not global_model or not has_capability(global_model, "chat"):
|
||||
return None, None, f"Failed to load global chat model with id {config_id}"
|
||||
global_connection = next(
|
||||
(
|
||||
|
|
@ -137,12 +138,12 @@ async def load_llm_bundle(
|
|||
agent_config = _agent_config_from_resolved(
|
||||
config_id=config_id,
|
||||
config_name=global_model.get("display_name") or global_model.get("model_id"),
|
||||
provider=global_connection.get("litellm_provider") or "",
|
||||
provider=global_connection.get("provider") or "",
|
||||
model_name=global_model["model_id"],
|
||||
api_key=global_connection.get("api_key"),
|
||||
api_base=global_connection.get("base_url"),
|
||||
litellm_params=(global_connection.get("extra") or {}).get("litellm_params"),
|
||||
supports_image_input=bool((global_model.get("capabilities") or {}).get("vision")),
|
||||
supports_image_input=has_capability(global_model, "vision"),
|
||||
billing_tier=str(global_model.get("billing_tier", "free")).lower(),
|
||||
)
|
||||
return (
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue