feat(models): add model connection persistence

This commit is contained in:
Anish Sarkar 2026-06-10 21:47:23 +05:30
parent b4c6061353
commit adb857925b
8 changed files with 1033 additions and 1 deletions

View file

@ -0,0 +1,178 @@
"""add model connections
Revision ID: 156
Revises: 155
"""
from collections.abc import Sequence
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
revision: str = "156"
down_revision: str | None = "155"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
connection_protocol = postgresql.ENUM(
"OLLAMA",
"OPENAI_COMPATIBLE",
"NATIVE",
name="connectionprotocol",
create_type=False,
)
connection_scope = postgresql.ENUM(
"GLOBAL",
"SEARCH_SPACE",
"USER",
name="connectionscope",
create_type=False,
)
model_source = postgresql.ENUM(
"DISCOVERED",
"MANUAL",
name="modelsource",
create_type=False,
)
def upgrade() -> None:
bind = op.get_bind()
connection_protocol.create(bind, checkfirst=True)
connection_scope.create(bind, checkfirst=True)
model_source.create(bind, checkfirst=True)
op.create_table(
"connections",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("protocol", connection_protocol, nullable=False),
sa.Column("native_provider", sa.String(length=100), nullable=True),
sa.Column("base_url", sa.String(length=500), nullable=True),
sa.Column("api_key", sa.String(), nullable=True),
sa.Column(
"extra",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column("scope", connection_scope, nullable=False),
sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False),
sa.Column("search_space_id", sa.Integer(), nullable=True),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("last_verified_at", sa.TIMESTAMP(timezone=True), nullable=True),
sa.Column("last_status", sa.String(length=50), nullable=True),
sa.Column("last_error", sa.Text(), nullable=True),
sa.CheckConstraint(
"(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR "
"(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR "
"(scope = 'USER' AND user_id IS NOT NULL)",
name="ck_connections_scope_owner",
),
sa.ForeignKeyConstraint(
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_connections_protocol"), "connections", ["protocol"], unique=False)
op.create_index(
op.f("ix_connections_native_provider"),
"connections",
["native_provider"],
unique=False,
)
op.create_index(op.f("ix_connections_scope"), "connections", ["scope"], unique=False)
op.create_table(
"models",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("connection_id", sa.Integer(), nullable=False),
sa.Column("model_id", sa.String(length=255), nullable=False),
sa.Column("display_name", sa.String(length=255), nullable=True),
sa.Column(
"source",
model_source,
server_default="DISCOVERED",
nullable=False,
),
sa.Column(
"capabilities",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column(
"capabilities_declared",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column(
"capabilities_verified",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column(
"capabilities_override",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column("embedding_dimension", sa.Integer(), nullable=True),
sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False),
sa.Column("billing_tier", sa.String(length=50), nullable=True),
sa.Column(
"catalog",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.ForeignKeyConstraint(["connection_id"], ["connections.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"connection_id", "model_id", name="uq_models_connection_model_id"
),
)
op.create_index(op.f("ix_models_connection_id"), "models", ["connection_id"], unique=False)
op.create_index("ix_models_model_id", "models", ["model_id"], unique=False)
op.create_index(op.f("ix_models_billing_tier"), "models", ["billing_tier"], unique=False)
op.add_column(
"searchspaces",
sa.Column("chat_model_id", sa.Integer(), nullable=True),
)
op.add_column(
"searchspaces",
sa.Column("image_gen_model_id", sa.Integer(), nullable=True),
)
op.add_column(
"searchspaces",
sa.Column("vision_model_id", sa.Integer(), nullable=True),
)
def downgrade() -> None:
op.drop_column("searchspaces", "vision_model_id")
op.drop_column("searchspaces", "image_gen_model_id")
op.drop_column("searchspaces", "chat_model_id")
op.drop_index(op.f("ix_models_billing_tier"), table_name="models")
op.drop_index("ix_models_model_id", table_name="models")
op.drop_index(op.f("ix_models_connection_id"), table_name="models")
op.drop_table("models")
op.drop_index(op.f("ix_connections_scope"), table_name="connections")
op.drop_index(op.f("ix_connections_native_provider"), table_name="connections")
op.drop_index(op.f("ix_connections_protocol"), table_name="connections")
op.drop_table("connections")
bind = op.get_bind()
model_source.drop(bind, checkfirst=True)
connection_scope.drop(bind, checkfirst=True)
connection_protocol.drop(bind, checkfirst=True)

View file

@ -280,6 +280,23 @@ class VisionProvider(StrEnum):
CUSTOM = "CUSTOM"
class ConnectionProtocol(StrEnum):
OLLAMA = "OLLAMA"
OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE"
NATIVE = "NATIVE"
class ConnectionScope(StrEnum):
GLOBAL = "GLOBAL"
SEARCH_SPACE = "SEARCH_SPACE"
USER = "USER"
class ModelSource(StrEnum):
DISCOVERED = "DISCOVERED"
MANUAL = "MANUAL"
class LogLevel(StrEnum):
DEBUG = "DEBUG"
INFO = "INFO"
@ -1642,6 +1659,79 @@ class Report(BaseModel, TimestampMixin):
thread = relationship("NewChatThread")
class Connection(BaseModel, TimestampMixin):
__tablename__ = "connections"
protocol = Column(SQLAlchemyEnum(ConnectionProtocol), nullable=False, index=True)
native_provider = Column(String(100), nullable=True, index=True)
base_url = Column(String(500), nullable=True)
api_key = Column(String, nullable=True)
extra = Column(JSONB, nullable=False, default=dict, server_default="{}")
scope = Column(SQLAlchemyEnum(ConnectionScope), nullable=False, index=True)
enabled = Column(Boolean, nullable=False, default=True, server_default="true")
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=True
)
user_id = Column(
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
last_verified_at = Column(TIMESTAMP(timezone=True), nullable=True)
last_status = Column(String(50), nullable=True)
last_error = Column(Text, nullable=True)
search_space = relationship("SearchSpace", back_populates="connections")
user = relationship("User", back_populates="connections")
models = relationship(
"Model",
back_populates="connection",
order_by="Model.id",
cascade="all, delete-orphan",
passive_deletes=True,
)
__table_args__ = (
CheckConstraint(
"(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR "
"(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR "
"(scope = 'USER' AND user_id IS NOT NULL)",
name="ck_connections_scope_owner",
),
)
class Model(BaseModel, TimestampMixin):
__tablename__ = "models"
connection_id = Column(
Integer, ForeignKey("connections.id", ondelete="CASCADE"), nullable=False, index=True
)
model_id = Column(String(255), nullable=False)
display_name = Column(String(255), nullable=True)
source = Column(
SQLAlchemyEnum(ModelSource),
nullable=False,
default=ModelSource.DISCOVERED,
server_default=ModelSource.DISCOVERED.value,
)
capabilities = Column(JSONB, nullable=False, default=dict, server_default="{}")
capabilities_declared = Column(JSONB, nullable=False, default=dict, server_default="{}")
capabilities_verified = Column(JSONB, nullable=False, default=dict, server_default="{}")
capabilities_override = Column(JSONB, nullable=False, default=dict, server_default="{}")
embedding_dimension = Column(Integer, nullable=True)
enabled = Column(Boolean, nullable=False, default=True, server_default="true")
billing_tier = Column(String(50), nullable=True, index=True)
catalog = Column(JSONB, nullable=False, default=dict, server_default="{}")
connection = relationship("Connection", back_populates="models")
__table_args__ = (
UniqueConstraint("connection_id", "model_id", name="uq_models_connection_model_id"),
Index("ix_models_model_id", "model_id"),
)
class ImageGenerationConfig(BaseModel, TimestampMixin):
"""
Dedicated configuration table for image generation models.
@ -1794,7 +1884,7 @@ class SearchSpace(BaseModel, TimestampMixin):
# - Positive IDs: Custom configs from DB (NewLLMConfig table)
agent_llm_id = Column(
Integer, nullable=True, default=0
) # For agent/chat operations, defaults to Auto mode
) # For chat operations, defaults to Auto mode
image_generation_config_id = Column(
Integer, nullable=True, default=0
) # For image generation, defaults to Auto mode
@ -1802,6 +1892,22 @@ class SearchSpace(BaseModel, TimestampMixin):
Integer, nullable=True, default=0
) # For vision/screenshot analysis, defaults to Auto mode
# New connection/model role bindings. These supersede the legacy config
# columns above without removing them in this PR.
# Note: ID values preserve the existing convention:
# - 0: Auto mode
# - Negative IDs: Global virtual models from global_llm_config.yaml
# - Positive IDs: User/search-space models from the models table
chat_model_id = Column(
Integer, nullable=True, default=0
) # For agent/chat operations, defaults to Auto mode
image_gen_model_id = Column(
Integer, nullable=True, default=0
) # For image generation, defaults to Auto mode when eligible
vision_model_id = Column(
Integer, nullable=True, default=0
) # For vision/screenshot analysis, defaults to Auto mode
ai_file_sort_enabled = Column(
Boolean, nullable=False, default=False, server_default="false"
)
@ -1889,6 +1995,13 @@ class SearchSpace(BaseModel, TimestampMixin):
order_by="VisionLLMConfig.id",
cascade="all, delete-orphan",
)
connections = relationship(
"Connection",
back_populates="search_space",
order_by="Connection.id",
cascade="all, delete-orphan",
passive_deletes=True,
)
automations = relationship(
"Automation",
@ -2429,6 +2542,11 @@ if config.AUTH_TYPE == "GOOGLE":
back_populates="user",
passive_deletes=True,
)
connections = relationship(
"Connection",
back_populates="user",
passive_deletes=True,
)
# Automations created by this user
automations = relationship(
@ -2568,6 +2686,11 @@ else:
back_populates="user",
passive_deletes=True,
)
connections = relationship(
"Connection",
back_populates="user",
passive_deletes=True,
)
# Automations created by this user
automations = relationship(

View file

@ -42,6 +42,7 @@ from .linear_add_connector_route import router as linear_add_connector_router
from .logs_routes import router as logs_router
from .luma_add_connector_route import router as luma_add_connector_router
from .mcp_oauth_route import router as mcp_oauth_router
from .model_connections_routes import router as model_connections_router
from .memory_routes import router as memory_router
from .model_list_routes import router as model_list_router
from .new_chat_routes import router as new_chat_router
@ -117,6 +118,7 @@ router.include_router(confluence_add_connector_router)
router.include_router(clickup_add_connector_router)
router.include_router(dropbox_add_connector_router)
router.include_router(new_llm_config_router) # LLM configs with prompt configuration
router.include_router(model_connections_router) # Connection-centric model catalog
router.include_router(model_list_router) # Dynamic model catalogue from OpenRouter
router.include_router(logs_router)
router.include_router(circleback_webhook_router) # Circleback meeting webhooks

View file

@ -0,0 +1,346 @@
import logging
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.config import config
from app.db import (
Connection,
ConnectionScope,
Model,
Permission,
SearchSpace,
User,
get_async_session,
)
from app.schemas import (
ConnectionCreate,
ConnectionRead,
ConnectionUpdate,
ModelRead,
ModelRolesRead,
ModelRolesUpdate,
ModelUpdate,
VerifyConnectionResponse,
)
from app.services.model_connection_service import (
discover_models,
persist_verification,
test_model,
)
from app.users import current_active_user
from app.utils.rbac import check_permission
router = APIRouter()
logger = logging.getLogger(__name__)
def _model_read(model: Model | dict) -> ModelRead:
return ModelRead.model_validate(model)
def _connection_read(conn: Connection | dict, models: list[Model | dict] | None = None) -> ConnectionRead:
if isinstance(conn, dict):
payload = {
**conn,
"has_api_key": bool(conn.get("api_key")),
"api_key": None,
"models": [_model_read(model) for model in (models or [])],
}
payload.pop("api_key", None)
return ConnectionRead.model_validate(payload)
return ConnectionRead(
id=conn.id,
protocol=conn.protocol,
native_provider=conn.native_provider,
base_url=conn.base_url,
extra=conn.extra or {},
scope=conn.scope,
search_space_id=conn.search_space_id,
user_id=conn.user_id,
enabled=conn.enabled,
has_api_key=bool(conn.api_key),
last_verified_at=conn.last_verified_at,
last_status=conn.last_status,
last_error=conn.last_error,
models=[_model_read(model) for model in (models or conn.models or [])],
created_at=conn.created_at,
)
async def _get_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace:
result = await session.execute(select(SearchSpace).where(SearchSpace.id == search_space_id))
search_space = result.scalars().first()
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found")
return search_space
async def _load_connection(session: AsyncSession, connection_id: int) -> Connection:
result = await session.execute(
select(Connection)
.options(selectinload(Connection.models))
.where(Connection.id == connection_id)
)
conn = result.scalars().first()
if not conn:
raise HTTPException(status_code=404, detail="Connection not found")
return conn
async def _assert_connection_access(
session: AsyncSession,
user: User,
conn: Connection,
permission: str = Permission.LLM_CONFIGS_CREATE.value,
) -> None:
if conn.search_space_id:
await check_permission(
session,
user,
conn.search_space_id,
permission,
"You don't have permission to manage model connections in this search space",
)
return
if conn.user_id != user.id:
raise HTTPException(status_code=403, detail="Connection does not belong to user")
@router.get("/global-model-connections", response_model=list[ConnectionRead])
async def list_global_connections(user: User = Depends(current_active_user)):
del user
models_by_connection: dict[int, list[dict]] = {}
for model in config.GLOBAL_MODELS:
models_by_connection.setdefault(model["connection_id"], []).append(model)
return [
_connection_read(conn, models_by_connection.get(conn["id"], []))
for conn in config.GLOBAL_CONNECTIONS
]
@router.get("/model-connections", response_model=list[ConnectionRead])
async def list_connections(
search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
stmt = select(Connection).options(selectinload(Connection.models))
if search_space_id is not None:
await check_permission(
session,
user,
search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to view model connections in this search space",
)
stmt = stmt.where(Connection.search_space_id == search_space_id)
else:
stmt = stmt.where(Connection.user_id == user.id)
result = await session.execute(stmt.order_by(Connection.id))
return [_connection_read(conn) for conn in result.scalars().all()]
@router.post("/model-connections", response_model=ConnectionRead)
async def create_connection(
data: ConnectionCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
if data.scope == ConnectionScope.GLOBAL:
raise HTTPException(status_code=400, detail="GLOBAL connections are YAML-only")
if data.scope == ConnectionScope.SEARCH_SPACE:
if data.search_space_id is None:
raise HTTPException(status_code=400, detail="search_space_id is required")
await check_permission(
session,
user,
data.search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to create model connections in this search space",
)
conn = Connection(
**data.model_dump(exclude={"search_space_id"}),
search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None,
user_id=user.id,
)
session.add(conn)
await session.commit()
await session.refresh(conn)
return _connection_read(conn)
@router.put("/model-connections/{connection_id}", response_model=ConnectionRead)
async def update_connection(
connection_id: int,
data: ConnectionUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value)
for key, value in data.model_dump(exclude_unset=True).items():
setattr(conn, key, value)
await session.commit()
await session.refresh(conn)
return _connection_read(conn)
@router.delete("/model-connections/{connection_id}")
async def delete_connection(
connection_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_DELETE.value)
await session.delete(conn)
await session.commit()
return {"status": "deleted"}
@router.post("/model-connections/{connection_id}/verify", response_model=VerifyConnectionResponse)
async def verify_model_connection(
connection_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_CREATE.value)
result = await persist_verification(conn)
await session.commit()
return VerifyConnectionResponse(status=result.status, ok=result.ok, message=result.message)
@router.post("/model-connections/{connection_id}/discover", response_model=list[ModelRead])
async def discover_connection_models(
connection_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_CREATE.value)
discovered = await discover_models(conn)
by_model_id = {model.model_id: model for model in conn.models}
for item in discovered:
db_model = by_model_id.get(item["model_id"])
if db_model is None:
db_model = Model(
connection_id=conn.id,
model_id=item["model_id"],
display_name=item.get("display_name"),
source=item["source"],
capabilities=item["capabilities"],
capabilities_declared=item["capabilities"],
capabilities_verified={},
capabilities_override={},
enabled=False,
catalog=item.get("metadata") or {},
)
session.add(db_model)
else:
db_model.display_name = item.get("display_name") or db_model.display_name
db_model.capabilities_declared = item["capabilities"]
db_model.capabilities = {
**item["capabilities"],
**(db_model.capabilities_override or {}),
}
db_model.catalog = item.get("metadata") or db_model.catalog
await session.commit()
await session.refresh(conn)
return [_model_read(model) for model in conn.models]
@router.put("/models/{model_id}", response_model=ModelRead)
async def update_model(
model_id: int,
data: ModelUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
result = await session.execute(
select(Model).options(selectinload(Model.connection)).where(Model.id == model_id)
)
model = result.scalars().first()
if not model:
raise HTTPException(status_code=404, detail="Model not found")
await _assert_connection_access(session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value)
update = data.model_dump(exclude_unset=True)
for key, value in update.items():
setattr(model, key, value)
if "capabilities_override" in update:
model.capabilities = {
**(model.capabilities_declared or {}),
**(model.capabilities_override or {}),
}
await session.commit()
await session.refresh(model)
return _model_read(model)
@router.post("/models/{model_id}/test", response_model=VerifyConnectionResponse)
async def test_connection_model(
model_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
result = await session.execute(
select(Model).options(selectinload(Model.connection)).where(Model.id == model_id)
)
model = result.scalars().first()
if not model:
raise HTTPException(status_code=404, detail="Model not found")
await _assert_connection_access(session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value)
result = await test_model(model.connection, model)
await session.commit()
return VerifyConnectionResponse(status=result.status, ok=result.ok, message=result.message)
@router.get("/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead)
async def get_model_roles(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
await check_permission(
session,
user,
search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to view model roles in this search space",
)
search_space = await _get_search_space(session, search_space_id)
return ModelRolesRead(
chat_model_id=search_space.chat_model_id,
vision_model_id=search_space.vision_model_id,
image_gen_model_id=search_space.image_gen_model_id,
)
@router.put("/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead)
async def update_model_roles(
search_space_id: int,
data: ModelRolesUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
await check_permission(
session,
user,
search_space_id,
Permission.LLM_CONFIGS_UPDATE.value,
"You don't have permission to update model roles in this search space",
)
search_space = await _get_search_space(session, search_space_id)
for key, value in data.model_dump(exclude_unset=True).items():
setattr(search_space, key, value)
await session.commit()
await session.refresh(search_space)
return ModelRolesRead(
chat_model_id=search_space.chat_model_id,
vision_model_id=search_space.vision_model_id,
image_gen_model_id=search_space.image_gen_model_id,
)

View file

@ -44,6 +44,16 @@ from .image_generation import (
ImageGenerationRead,
)
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
from .model_connections import (
ConnectionCreate,
ConnectionRead,
ConnectionUpdate,
ModelRead,
ModelRolesRead,
ModelRolesUpdate,
ModelUpdate,
VerifyConnectionResponse,
)
from .new_chat import (
ChatMessage,
NewChatMessageAppend,

View file

@ -0,0 +1,89 @@
import uuid
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from app.db import ConnectionProtocol, ConnectionScope, ModelSource
class ModelRead(BaseModel):
id: int
connection_id: int
model_id: str
display_name: str | None = None
source: ModelSource | str
capabilities: dict[str, Any]
capabilities_declared: dict[str, Any] = Field(default_factory=dict)
capabilities_verified: dict[str, Any] = Field(default_factory=dict)
capabilities_override: dict[str, Any] = Field(default_factory=dict)
embedding_dimension: int | None = None
enabled: bool
billing_tier: str | None = None
catalog: dict[str, Any] = Field(default_factory=dict)
created_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
class ConnectionRead(BaseModel):
id: int
protocol: ConnectionProtocol | str
native_provider: str | None = None
base_url: str | None = None
extra: dict[str, Any] = Field(default_factory=dict)
scope: ConnectionScope | str
search_space_id: int | None = None
user_id: uuid.UUID | None = None
enabled: bool
has_api_key: bool
last_verified_at: datetime | None = None
last_status: str | None = None
last_error: str | None = None
models: list[ModelRead] = Field(default_factory=list)
created_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
class ConnectionCreate(BaseModel):
protocol: ConnectionProtocol
native_provider: str | None = None
base_url: str | None = Field(None, max_length=500)
api_key: str | None = None
extra: dict[str, Any] = Field(default_factory=dict)
scope: ConnectionScope = ConnectionScope.SEARCH_SPACE
search_space_id: int | None = None
enabled: bool = True
class ConnectionUpdate(BaseModel):
native_provider: str | None = None
base_url: str | None = Field(None, max_length=500)
api_key: str | None = None
extra: dict[str, Any] | None = None
enabled: bool | None = None
class ModelUpdate(BaseModel):
display_name: str | None = Field(None, max_length=255)
enabled: bool | None = None
capabilities_override: dict[str, Any] | None = None
class VerifyConnectionResponse(BaseModel):
status: str
ok: bool
message: str = ""
class ModelRolesRead(BaseModel):
chat_model_id: int | None = 0
vision_model_id: int | None = 0
image_gen_model_id: int | None = 0
class ModelRolesUpdate(BaseModel):
chat_model_id: int | None = None
vision_model_id: int | None = None
image_gen_model_id: int | None = None

View file

@ -0,0 +1,209 @@
"""Connection verification, model discovery, and capability probing."""
from __future__ import annotations
import contextlib
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any
import httpx
import litellm
from app.db import Connection, ConnectionProtocol, Model, ModelSource
from app.services.model_resolver import ensure_v1, to_litellm
logger = logging.getLogger(__name__)
VERIFY_TIMEOUT_SECONDS = 8.0
DISCOVERY_TIMEOUT_SECONDS = 15.0
TEST_TIMEOUT_SECONDS = 30.0
@dataclass(frozen=True)
class VerifyResult:
status: str
ok: bool
message: str = ""
def _auth_headers(conn: Connection) -> dict[str, str]:
if not conn.api_key:
return {}
return {"Authorization": f"Bearer {conn.api_key}"}
def _docker_hint(url: str | None, exc_or_status: Any) -> str:
raw = str(exc_or_status)
if not url:
return raw
if "localhost" in url or "127.0.0.1" in url:
return (
f"{raw}. The backend is running inside Docker; localhost means the "
"backend container. Use host.docker.internal and make sure the model "
"server listens on 0.0.0.0."
)
if "host.docker.internal" in url and ("refused" in raw.lower() or "connect" in raw.lower()):
return (
f"{raw}. The host is reachable only if your local model server is "
"listening on 0.0.0.0. On Linux Docker, add "
"`host.docker.internal:host-gateway` to extra_hosts."
)
return raw
async def verify_connection(conn: Connection) -> VerifyResult:
if not conn.base_url and conn.protocol in (
ConnectionProtocol.OLLAMA,
ConnectionProtocol.OPENAI_COMPATIBLE,
):
return VerifyResult("UNREACHABLE", False, "Base URL is required.")
if conn.protocol == ConnectionProtocol.OLLAMA:
url = f"{conn.base_url.rstrip('/')}/api/version"
elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE:
url = f"{ensure_v1(conn.base_url)}/models"
else:
# Native providers do not share one cheap health endpoint. The model
# probe exercises the real path and is the authoritative check.
return VerifyResult("OK", True, "Native provider configuration accepted.")
try:
async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client:
response = await client.get(url, headers=_auth_headers(conn))
if response.status_code in (401, 403):
return VerifyResult("AUTH_FAILED", False, "Authentication failed.")
if response.status_code == 404:
if conn.protocol == ConnectionProtocol.OLLAMA and url.endswith("/v1/models"):
message = "Ollama native API should not use /v1."
elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE:
message = "OpenAI-compatible servers should expose /v1/models."
else:
message = "Endpoint returned 404."
return VerifyResult("NOT_FOUND", False, message)
response.raise_for_status()
return VerifyResult("OK", True, "Connection verified.")
except httpx.ConnectError as exc:
return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc))
except httpx.TimeoutException as exc:
return VerifyResult("UNREACHABLE", False, f"Connection timed out: {exc}")
except httpx.HTTPError as exc:
return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc))
async def persist_verification(conn: Connection) -> VerifyResult:
result = await verify_connection(conn)
conn.last_verified_at = datetime.now(UTC)
conn.last_status = result.status
conn.last_error = "" if result.ok else result.message
return result
def _litellm_capabilities(model_string: str, model_id: str) -> dict[str, bool]:
capabilities = {
"chat": True,
"vision": False,
"tools": False,
"image_gen": False,
"embedding": False,
}
with contextlib.suppress(Exception):
capabilities["vision"] = bool(litellm.supports_vision(model=model_string))
with contextlib.suppress(Exception):
capabilities["tools"] = bool(litellm.supports_function_calling(model=model_string))
try:
info = litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {}
mode = str(info.get("mode") or "")
capabilities["embedding"] = mode == "embedding"
capabilities["image_gen"] = mode in {"image_generation", "image_generation_model"}
except Exception:
pass
return capabilities
def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, bool]:
metadata = metadata or {}
if conn.protocol == ConnectionProtocol.OLLAMA:
caps = metadata.get("capabilities") or []
capabilities = {
"chat": True,
"vision": "vision" in caps,
"tools": False,
"image_gen": False,
"embedding": "embedding" in caps,
}
return capabilities
model_string, _ = to_litellm(conn, model_id)
return _litellm_capabilities(model_string, model_id)
async def discover_models(conn: Connection) -> list[dict[str, Any]]:
if conn.protocol == ConnectionProtocol.OLLAMA:
url = f"{conn.base_url.rstrip('/')}/api/tags"
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
response = await client.get(url, headers=_auth_headers(conn))
response.raise_for_status()
models = response.json().get("models", [])
return [
{
"model_id": item.get("model") or item.get("name"),
"display_name": item.get("name") or item.get("model"),
"source": ModelSource.DISCOVERED,
"capabilities": derive_capabilities(conn, item.get("model") or item.get("name"), item),
"metadata": item,
}
for item in models
if item.get("model") or item.get("name")
]
if conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE:
url = f"{ensure_v1(conn.base_url)}/models"
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
response = await client.get(url, headers=_auth_headers(conn))
response.raise_for_status()
models = response.json().get("data", [])
return [
{
"model_id": item.get("id"),
"display_name": item.get("name") or item.get("id"),
"source": ModelSource.DISCOVERED,
"capabilities": derive_capabilities(conn, item.get("id"), item),
"metadata": item,
}
for item in models
if item.get("id")
]
# Native providers rely on curated/global catalog entries or manual rows.
return []
async def test_model(conn: Connection, model: Model) -> VerifyResult:
model_string, kwargs = to_litellm(conn, model.model_id)
try:
await litellm.acompletion(
model=model_string,
messages=[{"role": "user", "content": "Hello"}],
timeout=TEST_TIMEOUT_SECONDS,
**kwargs,
)
except Exception as exc:
return VerifyResult("UNREACHABLE", False, str(exc))
model.capabilities_verified = {
**(model.capabilities_verified or {}),
"chat": True,
}
return VerifyResult("OK", True, "Model test succeeded.")
__all__ = [
"VerifyResult",
"derive_capabilities",
"discover_models",
"persist_verification",
"test_model",
"verify_connection",
]

View file

@ -0,0 +1,75 @@
from app.services.global_model_catalog import materialize_global_model_catalog
from app.services.model_resolver import ensure_v1, to_litellm
def test_openai_compatible_resolver_normalizes_v1() -> None:
model, kwargs = to_litellm(
{
"protocol": "OPENAI_COMPATIBLE",
"base_url": "http://host.docker.internal:1234",
"api_key": "local-key",
"extra": {},
},
"qwen/qwen3",
)
assert model == "openai/qwen/qwen3"
assert kwargs["api_base"] == "http://host.docker.internal:1234/v1"
assert kwargs["api_key"] == "local-key"
assert ensure_v1("http://example.com/v1") == "http://example.com/v1"
def test_ollama_resolver_uses_native_api_base() -> None:
model, kwargs = to_litellm(
{
"protocol": "OLLAMA",
"base_url": "http://host.docker.internal:11434",
"api_key": None,
"extra": {},
},
"llama3.2",
)
assert model == "ollama_chat/llama3.2"
assert kwargs["api_base"] == "http://host.docker.internal:11434"
def test_global_materialization_preserves_tier_and_keeps_key_server_side() -> None:
connections, models = materialize_global_model_catalog(
chat_configs=[
{
"id": -101,
"name": "OpenRouter Free",
"provider": "OPENROUTER",
"model_name": "meta-llama/llama-3.1-8b-instruct:free",
"api_key": "sk-global-secret",
"billing_tier": "free",
"anonymous_enabled": True,
"seo_enabled": True,
"rpm": 10,
"tpm": 1000,
},
{
"id": -102,
"name": "OpenRouter Premium",
"provider": "OPENROUTER",
"model_name": "anthropic/claude-sonnet-4",
"api_key": "sk-global-secret",
"billing_tier": "premium",
},
],
vision_configs=[],
image_configs=[],
)
assert len(connections) == 1
assert connections[0]["api_key"] == "sk-global-secret"
assert {model["billing_tier"] for model in models} == {"free", "premium"}
assert models[0]["catalog"]["anonymous_enabled"] is True
assert models[0]["catalog"]["rpm"] == 10
public_connections = [
{key: value for key, value in connection.items() if key != "api_key"}
for connection in connections
]
assert "sk-" not in repr(public_connections)