mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-14 20:55:15 +02:00
chore: ran linting
This commit is contained in:
parent
ceace003aa
commit
c7409c8995
48 changed files with 342 additions and 187 deletions
|
|
@ -148,7 +148,9 @@ def upgrade() -> None:
|
|||
nullable=False,
|
||||
),
|
||||
sa.Column("scope", connection_scope, nullable=False),
|
||||
sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False),
|
||||
sa.Column(
|
||||
"enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
|
||||
),
|
||||
sa.Column("search_space_id", sa.Integer(), nullable=True),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("last_verified_at", sa.TIMESTAMP(timezone=True), nullable=True),
|
||||
|
|
@ -166,16 +168,16 @@ def upgrade() -> None:
|
|||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
if _index_exists("connections", "ix_connections_native_provider") and not _index_exists(
|
||||
"connections", "ix_connections_provider"
|
||||
):
|
||||
if _index_exists(
|
||||
"connections", "ix_connections_native_provider"
|
||||
) and not _index_exists("connections", "ix_connections_provider"):
|
||||
op.execute(
|
||||
"ALTER INDEX ix_connections_native_provider "
|
||||
"RENAME TO ix_connections_provider"
|
||||
)
|
||||
if _index_exists("connections", "ix_connections_litellm_provider") and not _index_exists(
|
||||
"connections", "ix_connections_provider"
|
||||
):
|
||||
if _index_exists(
|
||||
"connections", "ix_connections_litellm_provider"
|
||||
) and not _index_exists("connections", "ix_connections_provider"):
|
||||
op.execute(
|
||||
"ALTER INDEX ix_connections_litellm_provider "
|
||||
"RENAME TO ix_connections_provider"
|
||||
|
|
@ -209,7 +211,9 @@ def upgrade() -> None:
|
|||
nullable=False,
|
||||
),
|
||||
sa.Column("embedding_dimension", sa.Integer(), nullable=True),
|
||||
sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False),
|
||||
sa.Column(
|
||||
"enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
|
||||
),
|
||||
sa.Column("billing_tier", sa.String(length=50), nullable=True),
|
||||
sa.Column(
|
||||
"catalog",
|
||||
|
|
@ -217,7 +221,9 @@ def upgrade() -> None:
|
|||
server_default=sa.text("'{}'::jsonb"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(["connection_id"], ["connections.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["connection_id"], ["connections.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"connection_id", "model_id", name="uq_models_connection_model_id"
|
||||
|
|
@ -225,18 +231,25 @@ def upgrade() -> None:
|
|||
)
|
||||
else:
|
||||
if not _column_exists("models", "supports_chat"):
|
||||
op.add_column("models", sa.Column("supports_chat", sa.Boolean(), nullable=True))
|
||||
op.add_column(
|
||||
"models", sa.Column("supports_chat", sa.Boolean(), nullable=True)
|
||||
)
|
||||
if not _column_exists("models", "max_input_tokens"):
|
||||
op.add_column("models", sa.Column("max_input_tokens", sa.Integer(), nullable=True))
|
||||
op.add_column(
|
||||
"models", sa.Column("max_input_tokens", sa.Integer(), nullable=True)
|
||||
)
|
||||
if not _column_exists("models", "supports_image_input"):
|
||||
op.add_column(
|
||||
"models", sa.Column("supports_image_input", sa.Boolean(), nullable=True)
|
||||
)
|
||||
if not _column_exists("models", "supports_tools"):
|
||||
op.add_column("models", sa.Column("supports_tools", sa.Boolean(), nullable=True))
|
||||
op.add_column(
|
||||
"models", sa.Column("supports_tools", sa.Boolean(), nullable=True)
|
||||
)
|
||||
if not _column_exists("models", "supports_image_generation"):
|
||||
op.add_column(
|
||||
"models", sa.Column("supports_image_generation", sa.Boolean(), nullable=True)
|
||||
"models",
|
||||
sa.Column("supports_image_generation", sa.Boolean(), nullable=True),
|
||||
)
|
||||
_drop_column_if_exists("models", "capabilities")
|
||||
_drop_column_if_exists("models", "capabilities_declared")
|
||||
|
|
@ -246,7 +259,9 @@ def upgrade() -> None:
|
|||
_create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"])
|
||||
|
||||
_add_searchspace_column_if_missing("chat_model_id", server_default=sa.text("0"))
|
||||
_add_searchspace_column_if_missing("image_gen_model_id", server_default=sa.text("0"))
|
||||
_add_searchspace_column_if_missing(
|
||||
"image_gen_model_id", server_default=sa.text("0")
|
||||
)
|
||||
_add_searchspace_column_if_missing("vision_model_id", server_default=sa.text("0"))
|
||||
for column_name in ("chat_model_id", "image_gen_model_id", "vision_model_id"):
|
||||
op.alter_column(
|
||||
|
|
|
|||
|
|
@ -21,20 +21,21 @@ from app.db import (
|
|||
SearchSpace,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.auto_model_pin_service import (
|
||||
auto_model_candidates,
|
||||
choose_auto_model_candidate,
|
||||
)
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.model_resolver import to_litellm
|
||||
from app.utils.signed_image_urls import generate_image_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_global_model(model_id: int) -> dict | None:
|
||||
return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None)
|
||||
|
||||
|
|
@ -113,13 +114,10 @@ def create_generate_image_tool(
|
|||
if image_gen_model_id_override is not None:
|
||||
# Automation run: use the captured image model, insulated from
|
||||
# later search-space changes. No search-space read needed.
|
||||
config_id = (
|
||||
image_gen_model_id_override or IMAGE_GEN_AUTO_MODE_ID
|
||||
)
|
||||
config_id = image_gen_model_id_override or IMAGE_GEN_AUTO_MODE_ID
|
||||
else:
|
||||
config_id = (
|
||||
search_space.image_gen_model_id
|
||||
or IMAGE_GEN_AUTO_MODE_ID
|
||||
search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
)
|
||||
|
||||
# size/quality/style are intentionally omitted: valid values
|
||||
|
|
@ -147,7 +145,9 @@ def create_generate_image_tool(
|
|||
|
||||
if config_id < 0:
|
||||
global_model = _get_global_model(config_id)
|
||||
if not global_model or not has_capability(global_model, "image_gen"):
|
||||
if not global_model or not has_capability(
|
||||
global_model, "image_gen"
|
||||
):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
global_connection = _get_global_connection(
|
||||
|
|
@ -174,7 +174,11 @@ def create_generate_image_tool(
|
|||
.filter(Model.id == config_id, Model.enabled.is_(True))
|
||||
)
|
||||
db_model = cfg_result.scalars().first()
|
||||
if not db_model or not db_model.connection or not db_model.connection.enabled:
|
||||
if (
|
||||
not db_model
|
||||
or not db_model.connection
|
||||
or not db_model.connection.enabled
|
||||
):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
conn = db_model.connection
|
||||
|
|
|
|||
|
|
@ -200,7 +200,9 @@ class AgentConfig:
|
|||
|
||||
system_instructions = yaml_config.get("system_instructions", "")
|
||||
|
||||
provider = yaml_config.get("provider") or yaml_config.get("litellm_provider", "")
|
||||
provider = yaml_config.get("provider") or yaml_config.get(
|
||||
"litellm_provider", ""
|
||||
)
|
||||
model_name = yaml_config.get("model_name", "")
|
||||
custom_provider = yaml_config.get("custom_provider")
|
||||
litellm_params = yaml_config.get("litellm_params") or {}
|
||||
|
|
@ -291,7 +293,9 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
|||
if llm_config.get("custom_provider"):
|
||||
model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}"
|
||||
else:
|
||||
provider = llm_config.get("provider") or llm_config.get("litellm_provider", "openai")
|
||||
provider = llm_config.get("provider") or llm_config.get(
|
||||
"litellm_provider", "openai"
|
||||
)
|
||||
model_string = f"{provider}/{llm_config['model_name']}"
|
||||
|
||||
litellm_kwargs = {
|
||||
|
|
@ -317,7 +321,9 @@ def create_chat_litellm_from_agent_config(
|
|||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""Create a ChatLiteLLM from an already resolved concrete model config."""
|
||||
if agent_config.is_auto_mode:
|
||||
print("Error: Auto mode must be resolved to a concrete model before LLM creation")
|
||||
print(
|
||||
"Error: Auto mode must be resolved to a concrete model before LLM creation"
|
||||
)
|
||||
return None
|
||||
|
||||
if agent_config.custom_provider:
|
||||
|
|
|
|||
|
|
@ -1582,7 +1582,10 @@ class Model(BaseModel, TimestampMixin):
|
|||
__tablename__ = "models"
|
||||
|
||||
connection_id = Column(
|
||||
Integer, ForeignKey("connections.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
Integer,
|
||||
ForeignKey("connections.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
model_id = Column(String(255), nullable=False)
|
||||
display_name = Column(String(255), nullable=True)
|
||||
|
|
@ -1597,7 +1600,9 @@ class Model(BaseModel, TimestampMixin):
|
|||
supports_image_input = Column(Boolean, nullable=True)
|
||||
supports_tools = Column(Boolean, nullable=True)
|
||||
supports_image_generation = Column(Boolean, nullable=True)
|
||||
capabilities_override = Column(JSONB, nullable=False, default=dict, server_default="{}")
|
||||
capabilities_override = Column(
|
||||
JSONB, nullable=False, default=dict, server_default="{}"
|
||||
)
|
||||
embedding_dimension = Column(Integer, nullable=True)
|
||||
enabled = Column(Boolean, nullable=False, default=True, server_default="true")
|
||||
billing_tier = Column(String(50), nullable=True, index=True)
|
||||
|
|
@ -1606,7 +1611,9 @@ class Model(BaseModel, TimestampMixin):
|
|||
connection = relationship("Connection", back_populates="models")
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("connection_id", "model_id", name="uq_models_connection_model_id"),
|
||||
UniqueConstraint(
|
||||
"connection_id", "model_id", name="uq_models_connection_model_id"
|
||||
),
|
||||
Index("ix_models_model_id", "model_id"),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -43,8 +43,8 @@ from .linear_add_connector_route import router as linear_add_connector_router
|
|||
from .logs_routes import router as logs_router
|
||||
from .luma_add_connector_route import router as luma_add_connector_router
|
||||
from .mcp_oauth_route import router as mcp_oauth_router
|
||||
from .model_connections_routes import router as model_connections_router
|
||||
from .memory_routes import router as memory_router
|
||||
from .model_connections_routes import router as model_connections_router
|
||||
from .model_list_routes import router as model_list_router
|
||||
from .new_chat_routes import router as new_chat_router
|
||||
from .notes_routes import router as notes_router
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ from app.utils.signed_image_urls import verify_image_token
|
|||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_global_model(model_id: int) -> dict | None:
|
||||
return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None)
|
||||
|
||||
|
|
|
|||
|
|
@ -68,7 +68,9 @@ def _preview_model_read(item: dict) -> ModelPreviewRead:
|
|||
)
|
||||
|
||||
|
||||
def _connection_read(conn: Connection | dict, models: list[Model | dict] | None = None) -> ConnectionRead:
|
||||
def _connection_read(
|
||||
conn: Connection | dict, models: list[Model | dict] | None = None
|
||||
) -> ConnectionRead:
|
||||
if isinstance(conn, dict):
|
||||
payload = {
|
||||
**conn,
|
||||
|
|
@ -207,7 +209,9 @@ async def _resolve_role_model_id(
|
|||
return 0
|
||||
|
||||
|
||||
async def _clear_invalid_roles(session: AsyncSession, search_space_id: int) -> SearchSpace:
|
||||
async def _clear_invalid_roles(
|
||||
session: AsyncSession, search_space_id: int
|
||||
) -> SearchSpace:
|
||||
search_space = await _get_search_space(session, search_space_id)
|
||||
search_space.chat_model_id = await _resolve_role_model_id(
|
||||
session,
|
||||
|
|
@ -243,10 +247,14 @@ async def _default_unset_roles(
|
|||
if search_space.vision_model_id is None:
|
||||
vision_default = None
|
||||
if search_space.chat_model_id:
|
||||
chat_model = next((m for m in models if m.id == search_space.chat_model_id), None)
|
||||
chat_model = next(
|
||||
(m for m in models if m.id == search_space.chat_model_id), None
|
||||
)
|
||||
if chat_model and has_capability(chat_model, "vision"):
|
||||
vision_default = chat_model.id
|
||||
search_space.vision_model_id = vision_default or _default_model_for(models, "vision")
|
||||
search_space.vision_model_id = vision_default or _default_model_for(
|
||||
models, "vision"
|
||||
)
|
||||
if search_space.image_gen_model_id is None:
|
||||
search_space.image_gen_model_id = _default_model_for(models, "image_gen")
|
||||
|
||||
|
|
@ -270,7 +278,9 @@ async def list_model_providers(user: User = Depends(current_active_user)):
|
|||
|
||||
|
||||
async def _get_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace:
|
||||
result = await session.execute(select(SearchSpace).where(SearchSpace.id == search_space_id))
|
||||
result = await session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
|
@ -305,7 +315,9 @@ async def _assert_connection_access(
|
|||
)
|
||||
return
|
||||
if conn.user_id != user.id:
|
||||
raise HTTPException(status_code=403, detail="Connection does not belong to user")
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Connection does not belong to user"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/global-model-connections", response_model=list[ConnectionRead])
|
||||
|
|
@ -340,8 +352,7 @@ async def list_connections(
|
|||
stmt = stmt.where(Connection.user_id == user.id)
|
||||
result = await session.execute(stmt.order_by(Connection.id))
|
||||
return [
|
||||
_connection_read(conn, list(conn.models))
|
||||
for conn in result.scalars().all()
|
||||
_connection_read(conn, list(conn.models)) for conn in result.scalars().all()
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -367,7 +378,9 @@ async def create_connection(
|
|||
|
||||
conn = Connection(
|
||||
**payload,
|
||||
search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None,
|
||||
search_space_id=data.search_space_id
|
||||
if data.scope == ConnectionScope.SEARCH_SPACE
|
||||
else None,
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
|
|
@ -389,7 +402,9 @@ async def create_connection(
|
|||
return _connection_read(conn, list(conn.models))
|
||||
|
||||
|
||||
@router.post("/model-connections/discover-preview", response_model=list[ModelPreviewRead])
|
||||
@router.post(
|
||||
"/model-connections/discover-preview", response_model=list[ModelPreviewRead]
|
||||
)
|
||||
async def preview_connection_models(
|
||||
data: ConnectionCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
|
|
@ -411,7 +426,9 @@ async def preview_connection_models(
|
|||
extra=data.extra or {},
|
||||
scope=data.scope,
|
||||
enabled=data.enabled,
|
||||
search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None,
|
||||
search_space_id=data.search_space_id
|
||||
if data.scope == ConnectionScope.SEARCH_SPACE
|
||||
else None,
|
||||
user_id=user.id,
|
||||
)
|
||||
try:
|
||||
|
|
@ -447,7 +464,9 @@ async def test_preview_connection_model(
|
|||
extra=data.extra or {},
|
||||
scope=data.scope,
|
||||
enabled=data.enabled,
|
||||
search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None,
|
||||
search_space_id=data.search_space_id
|
||||
if data.scope == ConnectionScope.SEARCH_SPACE
|
||||
else None,
|
||||
user_id=user.id,
|
||||
)
|
||||
model = Model(
|
||||
|
|
@ -459,7 +478,9 @@ async def test_preview_connection_model(
|
|||
catalog={},
|
||||
)
|
||||
result = await test_model(draft, model)
|
||||
return VerifyConnectionResponse(status=result.status, ok=result.ok, message=result.message)
|
||||
return VerifyConnectionResponse(
|
||||
status=result.status, ok=result.ok, message=result.message
|
||||
)
|
||||
|
||||
|
||||
@router.put("/model-connections/{connection_id}", response_model=ConnectionRead)
|
||||
|
|
@ -470,7 +491,9 @@ async def update_connection(
|
|||
user: User = Depends(current_active_user),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
search_space_id = conn.search_space_id
|
||||
for key, value in data.model_dump(exclude_unset=True).items():
|
||||
setattr(conn, key, value)
|
||||
|
|
@ -489,7 +512,9 @@ async def delete_connection(
|
|||
user: User = Depends(current_active_user),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_DELETE.value)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_DELETE.value
|
||||
)
|
||||
search_space_id = conn.search_space_id
|
||||
await session.delete(conn)
|
||||
await session.commit()
|
||||
|
|
@ -499,27 +524,37 @@ async def delete_connection(
|
|||
return {"status": "deleted"}
|
||||
|
||||
|
||||
@router.post("/model-connections/{connection_id}/verify", response_model=VerifyConnectionResponse)
|
||||
@router.post(
|
||||
"/model-connections/{connection_id}/verify", response_model=VerifyConnectionResponse
|
||||
)
|
||||
async def verify_model_connection(
|
||||
connection_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_CREATE.value)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_CREATE.value
|
||||
)
|
||||
result = await persist_verification(conn)
|
||||
await session.commit()
|
||||
return VerifyConnectionResponse(status=result.status, ok=result.ok, message=result.message)
|
||||
return VerifyConnectionResponse(
|
||||
status=result.status, ok=result.ok, message=result.message
|
||||
)
|
||||
|
||||
|
||||
@router.post("/model-connections/{connection_id}/discover", response_model=list[ModelRead])
|
||||
@router.post(
|
||||
"/model-connections/{connection_id}/discover", response_model=list[ModelRead]
|
||||
)
|
||||
async def discover_connection_models(
|
||||
connection_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_CREATE.value)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_CREATE.value
|
||||
)
|
||||
try:
|
||||
discovered = await discover_models(conn)
|
||||
except ModelDiscoveryError as exc:
|
||||
|
|
@ -561,13 +596,17 @@ async def add_manual_model(
|
|||
user: User = Depends(current_active_user),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
|
||||
model_id = data.model_id.strip()
|
||||
if not model_id:
|
||||
raise HTTPException(status_code=400, detail="model_id is required")
|
||||
if any(existing.model_id == model_id for existing in conn.models):
|
||||
raise HTTPException(status_code=400, detail="Model already exists on this connection")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Model already exists on this connection"
|
||||
)
|
||||
|
||||
capabilities = derive_capabilities(conn, model_id)
|
||||
model = Model(
|
||||
|
|
@ -592,7 +631,9 @@ async def add_manual_model(
|
|||
return _model_read(model)
|
||||
|
||||
|
||||
@router.patch("/model-connections/{connection_id}/models", response_model=list[ModelRead])
|
||||
@router.patch(
|
||||
"/model-connections/{connection_id}/models", response_model=list[ModelRead]
|
||||
)
|
||||
async def bulk_update_models(
|
||||
connection_id: int,
|
||||
data: ModelsBulkUpdate,
|
||||
|
|
@ -600,7 +641,9 @@ async def bulk_update_models(
|
|||
user: User = Depends(current_active_user),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
search_space_id = conn.search_space_id
|
||||
|
||||
model_ids = set(data.model_ids)
|
||||
|
|
@ -632,12 +675,16 @@ async def update_model(
|
|||
user: User = Depends(current_active_user),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(Model).options(selectinload(Model.connection)).where(Model.id == model_id)
|
||||
select(Model)
|
||||
.options(selectinload(Model.connection))
|
||||
.where(Model.id == model_id)
|
||||
)
|
||||
model = result.scalars().first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
await _assert_connection_access(session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value)
|
||||
await _assert_connection_access(
|
||||
session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
search_space_id = model.connection.search_space_id
|
||||
update = data.model_dump(exclude_unset=True)
|
||||
for key, value in update.items():
|
||||
|
|
@ -658,18 +705,26 @@ async def test_connection_model(
|
|||
user: User = Depends(current_active_user),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(Model).options(selectinload(Model.connection)).where(Model.id == model_id)
|
||||
select(Model)
|
||||
.options(selectinload(Model.connection))
|
||||
.where(Model.id == model_id)
|
||||
)
|
||||
model = result.scalars().first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
await _assert_connection_access(session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value)
|
||||
await _assert_connection_access(
|
||||
session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
result = await test_model(model.connection, model)
|
||||
await session.commit()
|
||||
return VerifyConnectionResponse(status=result.status, ok=result.ok, message=result.message)
|
||||
return VerifyConnectionResponse(
|
||||
status=result.status, ok=result.ok, message=result.message
|
||||
)
|
||||
|
||||
|
||||
@router.get("/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead)
|
||||
@router.get(
|
||||
"/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead
|
||||
)
|
||||
async def get_model_roles(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
|
|
@ -692,7 +747,9 @@ async def get_model_roles(
|
|||
)
|
||||
|
||||
|
||||
@router.put("/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead)
|
||||
@router.put(
|
||||
"/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead
|
||||
)
|
||||
async def update_model_roles(
|
||||
search_space_id: int,
|
||||
data: ModelRolesUpdate,
|
||||
|
|
|
|||
|
|
@ -127,7 +127,8 @@ from .video_presentations import (
|
|||
VideoPresentationRead,
|
||||
VideoPresentationUpdate,
|
||||
)
|
||||
__all__ = [
|
||||
|
||||
__all__ = [
|
||||
# Folder schemas
|
||||
"BulkDocumentMove",
|
||||
# Chat schemas (assistant-ui integration)
|
||||
|
|
|
|||
|
|
@ -104,4 +104,3 @@ class ImageGenerationListRead(BaseModel):
|
|||
is_success=obj.response_data is not None,
|
||||
image_count=image_count,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -145,7 +145,9 @@ def _shared_runtime_cooled_down_ids(config_ids: list[int]) -> set[int]:
|
|||
exc_info=True,
|
||||
)
|
||||
return set()
|
||||
return {cid for cid, value in zip(unique_ids, values, strict=False) if value is not None}
|
||||
return {
|
||||
cid for cid, value in zip(unique_ids, values, strict=False) if value is not None
|
||||
}
|
||||
|
||||
|
||||
def _clear_shared_runtime_cooldown(config_id: int | None = None) -> None:
|
||||
|
|
@ -388,7 +390,11 @@ async def _db_candidates(
|
|||
continue
|
||||
if conn.search_space_id is not None and conn.search_space_id != search_space_id:
|
||||
continue
|
||||
if conn.user_id is not None and parsed_user_id is not None and conn.user_id != parsed_user_id:
|
||||
if (
|
||||
conn.user_id is not None
|
||||
and parsed_user_id is not None
|
||||
and conn.user_id != parsed_user_id
|
||||
):
|
||||
continue
|
||||
if conn.user_id is not None and parsed_user_id is None:
|
||||
continue
|
||||
|
|
@ -574,9 +580,7 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
# Distinguish the "no vision-capable cfg" case from generic
|
||||
# "no usable cfg" so the streaming task can map this to the
|
||||
# MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error.
|
||||
raise ValueError(
|
||||
"No vision-capable LLM models are available for Auto mode"
|
||||
)
|
||||
raise ValueError("No vision-capable LLM models are available for Auto mode")
|
||||
raise ValueError("No usable LLM models are available for Auto mode")
|
||||
candidate_by_id = {int(c["id"]): c for c in candidates}
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ logger = logging.getLogger(__name__)
|
|||
# Special ID for Auto mode - uses router for load balancing
|
||||
IMAGE_GEN_AUTO_MODE_ID = 0
|
||||
|
||||
|
||||
class ImageGenRouterService:
|
||||
"""
|
||||
Singleton service for managing LiteLLM Router for image generation.
|
||||
|
|
|
|||
|
|
@ -180,7 +180,11 @@ def _category_from_provider_payload(
|
|||
normalized_type = (provider_error_type or "").lower()
|
||||
if normalized_type == "rate_limit_error":
|
||||
return LLMErrorCategory.RATE_LIMITED
|
||||
if normalized_type in {"authentication_error", "invalid_api_key", "invalid_api_key_error"}:
|
||||
if normalized_type in {
|
||||
"authentication_error",
|
||||
"invalid_api_key",
|
||||
"invalid_api_key_error",
|
||||
}:
|
||||
return LLMErrorCategory.AUTH_FAILED
|
||||
if normalized_type in {"permission_denied", "forbidden"}:
|
||||
return LLMErrorCategory.PERMISSION_DENIED
|
||||
|
|
@ -193,7 +197,10 @@ def _category_from_provider_payload(
|
|||
|
||||
def _category_from_message(raw: str) -> LLMErrorCategory | None:
|
||||
lowered = raw.lower()
|
||||
if any(hint in lowered for hint in ("rate limit", "rate-limited", "temporarily rate-limited")):
|
||||
if any(
|
||||
hint in lowered
|
||||
for hint in ("rate limit", "rate-limited", "temporarily rate-limited")
|
||||
):
|
||||
return LLMErrorCategory.RATE_LIMITED
|
||||
if any(
|
||||
hint in lowered
|
||||
|
|
@ -248,4 +255,3 @@ def adapt_llm_exception(exc: BaseException) -> LLMErrorAdaptation:
|
|||
|
||||
def llm_error_message(exc: BaseException) -> str:
|
||||
return adapt_llm_exception(exc).user_message
|
||||
|
||||
|
|
|
|||
|
|
@ -97,6 +97,7 @@ def _sanitize_content(content: Any) -> Any:
|
|||
# Special ID for Auto mode - uses router for load balancing
|
||||
AUTO_MODE_ID = 0
|
||||
|
||||
|
||||
class LLMRouterService:
|
||||
"""
|
||||
Singleton service for managing LiteLLM Router.
|
||||
|
|
|
|||
|
|
@ -410,6 +410,7 @@ async def get_vision_llm(
|
|||
unwrapped — they don't consume premium credit (issue M).
|
||||
"""
|
||||
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
|
|
@ -427,7 +428,9 @@ async def get_vision_llm(
|
|||
if chat_model_id < 0:
|
||||
chat_model = get_global_model(chat_model_id)
|
||||
if chat_model and _has_capability(chat_model, "vision"):
|
||||
global_connection = get_global_connection(chat_model["connection_id"])
|
||||
global_connection = get_global_connection(
|
||||
chat_model["connection_id"]
|
||||
)
|
||||
if global_connection:
|
||||
model_string, litellm_kwargs = _chat_litellm_from_resolved(
|
||||
conn=global_connection,
|
||||
|
|
@ -466,7 +469,9 @@ async def get_vision_llm(
|
|||
if not candidates:
|
||||
logger.error("No vision-capable models available for Auto mode")
|
||||
return None
|
||||
config_id = int(choose_auto_model_candidate(candidates, search_space_id)["id"])
|
||||
config_id = int(
|
||||
choose_auto_model_candidate(candidates, search_space_id)["id"]
|
||||
)
|
||||
|
||||
if config_id < 0:
|
||||
global_model = get_global_model(config_id)
|
||||
|
|
|
|||
|
|
@ -68,7 +68,9 @@ def _docker_hint(url: str | None, exc_or_status: Any) -> str:
|
|||
"backend container. Use host.docker.internal and make sure the model "
|
||||
"server listens on 0.0.0.0."
|
||||
)
|
||||
if "host.docker.internal" in url and ("refused" in raw.lower() or "connect" in raw.lower()):
|
||||
if "host.docker.internal" in url and (
|
||||
"refused" in raw.lower() or "connect" in raw.lower()
|
||||
):
|
||||
return (
|
||||
f"{raw}. The host is reachable only if your local model server is "
|
||||
"listening on 0.0.0.0. On Linux Docker, add "
|
||||
|
|
@ -152,11 +154,17 @@ async def verify_connection(conn: Connection) -> VerifyResult:
|
|||
elif spec.discovery == "anthropic_models" and base_url:
|
||||
url = f"{base_url.rstrip('/')}/models"
|
||||
else:
|
||||
return VerifyResult("OK", True, "Connection uses provider-native authentication.")
|
||||
return VerifyResult(
|
||||
"OK", True, "Connection uses provider-native authentication."
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client:
|
||||
headers = _anthropic_headers(conn) if spec.auth_style == "x-api-key" else _auth_headers(conn)
|
||||
headers = (
|
||||
_anthropic_headers(conn)
|
||||
if spec.auth_style == "x-api-key"
|
||||
else _auth_headers(conn)
|
||||
)
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code in (401, 403):
|
||||
return VerifyResult("AUTH_FAILED", False, "Authentication failed.")
|
||||
|
|
@ -213,7 +221,9 @@ def _litellm_info(model_string: str, model_id: str) -> dict[str, Any]:
|
|||
info = litellm.get_model_info(model=model_string)
|
||||
if isinstance(info, dict):
|
||||
return info
|
||||
return litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {}
|
||||
return (
|
||||
litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {}
|
||||
)
|
||||
|
||||
|
||||
def _classify_from_litellm(model_string: str, model_id: str) -> dict[str, Any]:
|
||||
|
|
@ -230,11 +240,14 @@ def _classify_from_litellm(model_string: str, model_id: str) -> dict[str, Any]:
|
|||
"max_input_tokens": info.get("max_input_tokens") or info.get("max_tokens"),
|
||||
"supports_image_input": supports_image_input,
|
||||
"supports_tools": supports_tools,
|
||||
"supports_image_generation": mode in {"image_generation", "image_generation_model"},
|
||||
"supports_image_generation": mode
|
||||
in {"image_generation", "image_generation_model"},
|
||||
}
|
||||
|
||||
|
||||
def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, Any]:
|
||||
def derive_capabilities(
|
||||
conn: Connection, model_id: str, metadata: dict | None = None
|
||||
) -> dict[str, Any]:
|
||||
metadata = metadata or {}
|
||||
spec = spec_for(conn.provider)
|
||||
model_string, _ = to_litellm(conn, model_id)
|
||||
|
|
@ -245,7 +258,8 @@ def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None =
|
|||
facts.update(
|
||||
{
|
||||
"supports_chat": "embedding" not in caps,
|
||||
"supports_image_input": "vision" in caps or facts["supports_image_input"],
|
||||
"supports_image_input": "vision" in caps
|
||||
or facts["supports_image_input"],
|
||||
"supports_tools": "tools" in caps or facts["supports_tools"],
|
||||
"supports_image_generation": False,
|
||||
"max_input_tokens": metadata.get("context_length")
|
||||
|
|
@ -351,7 +365,9 @@ async def _ollama_tags_then_show(conn: Connection) -> list[dict[str, Any]]:
|
|||
async def _openrouter_models(conn: Connection) -> list[dict[str, Any]]:
|
||||
base_url = _base_url_or_default(conn) or "https://openrouter.ai/api/v1"
|
||||
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
|
||||
response = await client.get(f"{ensure_v1(base_url)}/models", headers=_auth_headers(conn))
|
||||
response = await client.get(
|
||||
f"{ensure_v1(base_url)}/models", headers=_auth_headers(conn)
|
||||
)
|
||||
response.raise_for_status()
|
||||
return normalize_openrouter_models(response.json().get("data", []))
|
||||
|
||||
|
|
@ -361,7 +377,9 @@ def _litellm_static_models(conn: Connection) -> list[dict[str, Any]]:
|
|||
prefix = spec_for(provider).litellm_prefix or provider
|
||||
results: list[dict[str, Any]] = []
|
||||
for model_string, metadata in litellm.model_cost.items():
|
||||
if not isinstance(model_string, str) or not model_string.startswith(f"{prefix}/"):
|
||||
if not isinstance(model_string, str) or not model_string.startswith(
|
||||
f"{prefix}/"
|
||||
):
|
||||
continue
|
||||
model_id = model_string.split("/", 1)[1]
|
||||
results.append(
|
||||
|
|
@ -414,7 +432,8 @@ async def _discover_bedrock_models(conn: Connection) -> list[dict[str, Any]]:
|
|||
"model_id": model_id,
|
||||
"display_name": item.get("modelName") or model_id,
|
||||
"source": ModelSource.DISCOVERED,
|
||||
"supports_chat": "TEXT" in input_modalities and "TEXT" in output_modalities,
|
||||
"supports_chat": "TEXT" in input_modalities
|
||||
and "TEXT" in output_modalities,
|
||||
"supports_image_input": "IMAGE" in input_modalities,
|
||||
"supports_tools": None,
|
||||
"supports_image_generation": "IMAGE" in output_modalities,
|
||||
|
|
|
|||
|
|
@ -48,14 +48,20 @@ def to_litellm(
|
|||
prefix = spec.litellm_prefix or str(provider)
|
||||
model_string = f"{prefix}/{model_id}" if prefix else model_id
|
||||
if base_url:
|
||||
api_base = ensure_v1(base_url) if spec.transport == Transport.OPENAI_COMPATIBLE else base_url.rstrip("/")
|
||||
api_base = (
|
||||
ensure_v1(base_url)
|
||||
if spec.transport == Transport.OPENAI_COMPATIBLE
|
||||
else base_url.rstrip("/")
|
||||
)
|
||||
kwargs["api_base"] = api_base
|
||||
|
||||
if api_version := extra.get("api_version"):
|
||||
kwargs["api_version"] = api_version
|
||||
kwargs.update(extra.get("litellm_params", {}))
|
||||
kwargs.update(extra.get("kwargs", {}))
|
||||
if provider == "bedrock" and (bearer_token := kwargs.pop("aws_bearer_token_bedrock", None)):
|
||||
if provider == "bedrock" and (
|
||||
bearer_token := kwargs.pop("aws_bearer_token_bedrock", None)
|
||||
):
|
||||
kwargs["api_key"] = bearer_token
|
||||
return model_string, kwargs
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,10 @@ from typing import Any
|
|||
|
||||
import httpx
|
||||
|
||||
from app.services.openrouter_model_normalizer import (
|
||||
is_openrouter_image_model,
|
||||
normalize_openrouter_models,
|
||||
)
|
||||
from app.services.quality_score import (
|
||||
_HEALTH_BLEND_WEIGHT,
|
||||
_HEALTH_ENRICH_CONCURRENCY,
|
||||
|
|
@ -29,13 +33,6 @@ from app.services.quality_score import (
|
|||
aggregate_health,
|
||||
static_score_or,
|
||||
)
|
||||
from app.services.openrouter_model_normalizer import (
|
||||
is_allowed_model as _shared_is_allowed_model,
|
||||
is_compatible_provider as _shared_is_compatible_provider,
|
||||
is_openrouter_image_model,
|
||||
normalize_openrouter_models,
|
||||
supports_image_input,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -395,11 +392,7 @@ def _generate_image_gen_configs(
|
|||
free_rpm: int = settings.get("free_rpm", 20)
|
||||
litellm_params: dict = settings.get("litellm_params") or {}
|
||||
|
||||
image_models = [
|
||||
m
|
||||
for m in raw_models
|
||||
if is_openrouter_image_model(m)
|
||||
]
|
||||
image_models = [m for m in raw_models if is_openrouter_image_model(m)]
|
||||
|
||||
configs: list[dict] = []
|
||||
taken: set[int] = set()
|
||||
|
|
|
|||
|
|
@ -84,7 +84,9 @@ def is_openrouter_image_model(model: dict[str, Any]) -> bool:
|
|||
)
|
||||
|
||||
|
||||
def normalize_openrouter_models(raw_models: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
def normalize_openrouter_models(
|
||||
raw_models: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
normalized: list[dict[str, Any]] = []
|
||||
for model in raw_models:
|
||||
if not is_openrouter_chat_model(model):
|
||||
|
|
|
|||
|
|
@ -46,7 +46,13 @@ REGISTRY: dict[str, ProviderSpec] = {
|
|||
Transport.NATIVE, "openai", "openai_models", None, False, "bearer", "OpenAI"
|
||||
),
|
||||
"anthropic": ProviderSpec(
|
||||
Transport.NATIVE, "anthropic", "anthropic_models", None, False, "x-api-key", "Anthropic"
|
||||
Transport.NATIVE,
|
||||
"anthropic",
|
||||
"anthropic_models",
|
||||
None,
|
||||
False,
|
||||
"x-api-key",
|
||||
"Anthropic",
|
||||
),
|
||||
"azure": ProviderSpec(
|
||||
Transport.NATIVE, "azure", "static", None, True, "native", "Azure OpenAI"
|
||||
|
|
@ -55,7 +61,13 @@ REGISTRY: dict[str, ProviderSpec] = {
|
|||
Transport.NATIVE, "vertex_ai", "static", None, False, "native", "Vertex AI"
|
||||
),
|
||||
"bedrock": ProviderSpec(
|
||||
Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native", "Amazon Bedrock"
|
||||
Transport.NATIVE,
|
||||
"bedrock",
|
||||
"bedrock_models",
|
||||
None,
|
||||
False,
|
||||
"native",
|
||||
"Amazon Bedrock",
|
||||
),
|
||||
"openrouter": ProviderSpec(
|
||||
Transport.OPENAI_COMPATIBLE,
|
||||
|
|
|
|||
|
|
@ -86,4 +86,3 @@ def user_content_to_llm_content(
|
|||
if allow_images and any(part.get("type") == "image_url" for part in parts):
|
||||
return parts
|
||||
return "\n".join(text_chunks)
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,9 @@ def _text_from_content(content: Any) -> str:
|
|||
return "".join(text_parts)
|
||||
|
||||
|
||||
def normalize_ai_message_to_parts(message: AIMessage | Any | None) -> list[dict[str, Any]]:
|
||||
def normalize_ai_message_to_parts(
|
||||
message: AIMessage | Any | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return user-visible assistant-ui parts for a final AI message.
|
||||
|
||||
We intentionally do not backfill provider ``thinking`` /
|
||||
|
|
@ -60,7 +62,9 @@ def last_ai_message(messages: Iterable[Any] | None) -> AIMessage | Any | None:
|
|||
return None
|
||||
|
||||
|
||||
def final_assistant_parts_from_messages(messages: Iterable[Any] | None) -> list[dict[str, Any]]:
|
||||
def final_assistant_parts_from_messages(
|
||||
messages: Iterable[Any] | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
return normalize_ai_message_to_parts(last_ai_message(messages))
|
||||
|
||||
|
||||
|
|
@ -83,4 +87,3 @@ def merge_streamed_and_final_parts(
|
|||
if not has_non_empty_text_part(final_parts):
|
||||
return streamed_parts
|
||||
return [*streamed_parts, *final_parts]
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,9 @@ from app.agents.chat.multi_agent_chat.main_agent.middleware.kb_persistence impor
|
|||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.message_parts_normalizer import (
|
||||
final_assistant_parts_from_messages,
|
||||
)
|
||||
from app.tasks.chat.streaming.contract.file_contract import (
|
||||
contract_enforcement_active,
|
||||
evaluate_file_contract_outcome,
|
||||
|
|
@ -25,9 +28,6 @@ from app.tasks.chat.streaming.graph_stream.event_stream import stream_output
|
|||
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
|
||||
all_interrupt_values,
|
||||
)
|
||||
from app.tasks.chat.message_parts_normalizer import (
|
||||
final_assistant_parts_from_messages,
|
||||
)
|
||||
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
||||
from app.tasks.chat.streaming.shared.utils import safe_float
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
|
|
|||
|
|
@ -146,9 +146,10 @@ def _provider_error_extra(adapted: Any) -> dict[str, Any] | None:
|
|||
|
||||
def _classify_provider_exception(
|
||||
exc: Exception,
|
||||
) -> tuple[
|
||||
str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None
|
||||
] | None:
|
||||
) -> (
|
||||
tuple[str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None]
|
||||
| None
|
||||
):
|
||||
adapted = adapt_llm_exception(exc)
|
||||
|
||||
if adapted.category is LLMErrorCategory.RATE_LIMITED:
|
||||
|
|
|
|||
|
|
@ -55,8 +55,12 @@ def _agent_config_from_resolved(
|
|||
)
|
||||
|
||||
|
||||
async def _load_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace | None:
|
||||
result = await session.execute(select(SearchSpace).where(SearchSpace.id == search_space_id))
|
||||
async def _load_search_space(
|
||||
session: AsyncSession, search_space_id: int
|
||||
) -> SearchSpace | None:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
|
|
@ -131,7 +135,9 @@ async def load_llm_bundle(
|
|||
None,
|
||||
)
|
||||
|
||||
global_model = next((m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None)
|
||||
global_model = next(
|
||||
(m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None
|
||||
)
|
||||
if not global_model or not has_capability(global_model, "chat"):
|
||||
return None, None, f"Failed to load global chat model with id {config_id}"
|
||||
global_connection = next(
|
||||
|
|
@ -144,7 +150,9 @@ async def load_llm_bundle(
|
|||
)
|
||||
if not global_connection:
|
||||
return None, None, f"Failed to load global connection for model {config_id}"
|
||||
model_string, litellm_kwargs = to_litellm(global_connection, global_model["model_id"])
|
||||
model_string, litellm_kwargs = to_litellm(
|
||||
global_connection, global_model["model_id"]
|
||||
)
|
||||
display_name = global_model.get("display_name") or global_model.get("model_id")
|
||||
provider = global_connection.get("provider") or ""
|
||||
register_model_usage_metadata(
|
||||
|
|
|
|||
|
|
@ -37,4 +37,3 @@ def test_sanitize_messages_sets_tool_only_ai_content_to_none() -> None:
|
|||
|
||||
assert sanitized[0].content is None
|
||||
assert message.content == ""
|
||||
|
||||
|
|
|
|||
|
|
@ -77,7 +77,13 @@ async def test_resolve_billing_for_premium_global_config(monkeypatch):
|
|||
config,
|
||||
"GLOBAL_CONNECTIONS",
|
||||
[
|
||||
{"id": -101, "provider": "openai", "api_key": "sk-test", "base_url": None, "extra": {}},
|
||||
{
|
||||
"id": -101,
|
||||
"provider": "openai",
|
||||
"api_key": "sk-test",
|
||||
"base_url": None,
|
||||
"extra": {},
|
||||
},
|
||||
{
|
||||
"id": -102,
|
||||
"provider": "openrouter",
|
||||
|
|
@ -154,7 +160,15 @@ async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
|
|||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_CONNECTIONS",
|
||||
[{"id": -101, "provider": "openai", "api_key": "sk-test", "base_url": None, "extra": {}}],
|
||||
[
|
||||
{
|
||||
"id": -101,
|
||||
"provider": "openai",
|
||||
"api_key": "sk-test",
|
||||
"base_url": None,
|
||||
"extra": {},
|
||||
}
|
||||
],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -120,7 +120,9 @@ def _set_global_llm_configs(monkeypatch, config, configs: list[dict]):
|
|||
"supports_chat": cfg.get("supports_chat", True),
|
||||
"supports_image_input": cfg.get("supports_image_input", True),
|
||||
"supports_tools": cfg.get("supports_tools", True),
|
||||
"supports_image_generation": cfg.get("supports_image_generation", False),
|
||||
"supports_image_generation": cfg.get(
|
||||
"supports_image_generation", False
|
||||
),
|
||||
"capabilities_override": cfg.get("capabilities_override") or {},
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
"catalog": {
|
||||
|
|
@ -157,7 +159,12 @@ async def test_auto_first_turn_pins_one_model(monkeypatch):
|
|||
monkeypatch,
|
||||
config,
|
||||
[
|
||||
{"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"},
|
||||
{
|
||||
"id": -2,
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-free",
|
||||
"api_key": "k1",
|
||||
},
|
||||
{
|
||||
"id": -1,
|
||||
"litellm_provider": "openai",
|
||||
|
|
@ -548,7 +555,12 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch):
|
|||
monkeypatch,
|
||||
config,
|
||||
[
|
||||
{"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"},
|
||||
{
|
||||
"id": -2,
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-free",
|
||||
"api_key": "k1",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -573,7 +585,12 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
|
|||
monkeypatch,
|
||||
config,
|
||||
[
|
||||
{"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"},
|
||||
{
|
||||
"id": -2,
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-free",
|
||||
"api_key": "k1",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -135,7 +135,9 @@ async def test_generate_image_tool_global_sets_explicit_api_base():
|
|||
with (
|
||||
patch.object(gi_module, "shielded_async_session", return_value=session_cm),
|
||||
patch.object(gi_module, "_get_global_model", return_value=global_model),
|
||||
patch.object(gi_module, "_get_global_connection", return_value=global_connection),
|
||||
patch.object(
|
||||
gi_module, "_get_global_connection", return_value=global_connection
|
||||
),
|
||||
patch.object(
|
||||
gi_module, "aimage_generation", side_effect=fake_aimage_generation
|
||||
),
|
||||
|
|
|
|||
|
|
@ -288,4 +288,3 @@ def test_generate_image_gen_configs_assigns_image_id_offset():
|
|||
cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
|
||||
assert all(c["id"] < -20_000 + 1 for c in cfgs)
|
||||
assert all(c["id"] > -29_000_000 for c in cfgs)
|
||||
|
||||
|
|
|
|||
|
|
@ -368,5 +368,3 @@ def test_register_continues_after_individual_failure(monkeypatch, caplog):
|
|||
|
||||
# The good config still registered.
|
||||
assert any("custom-deployment" in payload for payload in successful_calls)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -77,4 +77,3 @@ def test_stream_classifier_keeps_unknown_errors_generic() -> None:
|
|||
assert expected is False
|
||||
assert message == "Error during chat: database exploded"
|
||||
assert extra is None
|
||||
|
||||
|
|
|
|||
|
|
@ -59,4 +59,3 @@ def test_user_images_can_be_dropped_for_text_only_history() -> None:
|
|||
]
|
||||
|
||||
assert user_content_to_llm_content(content, allow_images=False) == "look"
|
||||
|
||||
|
|
|
|||
|
|
@ -65,4 +65,3 @@ def test_merge_does_not_duplicate_when_stream_already_has_text() -> None:
|
|||
final = [{"type": "text", "text": "final answer"}]
|
||||
|
||||
assert merge_streamed_and_final_parts(streamed, final) == streamed
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"use client";
|
||||
import { CalendarDays, AlarmClock, Info } from "lucide-react";
|
||||
import { AlarmClock, CalendarDays, Info } from "lucide-react";
|
||||
import { Table, TableBody, TableHead, TableHeader, TableRow } from "@/components/ui/table";
|
||||
import type { AutomationSummary } from "@/contracts/types/automation.types";
|
||||
import { AutomationRow } from "./automation-row";
|
||||
|
|
|
|||
|
|
@ -37,11 +37,7 @@ export function DashboardClientLayout({
|
|||
const setActiveSearchSpaceIdState = useSetAtom(activeSearchSpaceIdAtom);
|
||||
const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom);
|
||||
|
||||
const {
|
||||
data: modelRoles = {},
|
||||
isLoading: loading,
|
||||
error,
|
||||
} = useAtomValue(modelRolesAtom);
|
||||
const { data: modelRoles = {}, isLoading: loading, error } = useAtomValue(modelRolesAtom);
|
||||
const { data: globalConnections = [], isLoading: globalConfigsLoading } = useAtomValue(
|
||||
globalModelConnectionsAtom
|
||||
);
|
||||
|
|
@ -158,13 +154,13 @@ export function DashboardClientLayout({
|
|||
|
||||
// Determine if we should show loading
|
||||
const shouldShowLoading =
|
||||
(!hasCheckedOnboarding &&
|
||||
(!isSearchSpaceReady ||
|
||||
loading ||
|
||||
accessLoading ||
|
||||
globalConfigsLoading ||
|
||||
modelConnectionsLoading) &&
|
||||
!isOnboardingPage);
|
||||
!hasCheckedOnboarding &&
|
||||
(!isSearchSpaceReady ||
|
||||
loading ||
|
||||
accessLoading ||
|
||||
globalConfigsLoading ||
|
||||
modelConnectionsLoading) &&
|
||||
!isOnboardingPage;
|
||||
|
||||
// Use global loading screen - spinner animation won't reset
|
||||
useGlobalLoadingEffect(shouldShowLoading);
|
||||
|
|
|
|||
|
|
@ -1,12 +1,6 @@
|
|||
"use client";
|
||||
|
||||
import {
|
||||
BookText,
|
||||
Cpu,
|
||||
Earth,
|
||||
Settings,
|
||||
UserKey,
|
||||
} from "lucide-react";
|
||||
import { BookText, Cpu, Earth, Settings, UserKey } from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { useSelectedLayoutSegment } from "next/navigation";
|
||||
import { useTranslations } from "next-intl";
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ import type {
|
|||
ModelPreviewRead,
|
||||
ModelRead,
|
||||
ModelRoles,
|
||||
ModelTestPreviewRequest,
|
||||
ModelsBulkUpdateRequest,
|
||||
ModelTestPreviewRequest,
|
||||
ModelUpdateRequest,
|
||||
VerifyConnectionResponse,
|
||||
} from "@/contracts/types/model-connections.types";
|
||||
|
|
|
|||
|
|
@ -957,7 +957,10 @@ interface ComposerActionProps {
|
|||
searchSpaceId: number;
|
||||
}
|
||||
|
||||
const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false, searchSpaceId }) => {
|
||||
const ComposerAction: FC<ComposerActionProps> = ({
|
||||
isBlockedByOtherUser = false,
|
||||
searchSpaceId,
|
||||
}) => {
|
||||
const mentionedDocuments = useAtomValue(mentionedDocumentsAtom);
|
||||
const setConnectorDialogOpen = useSetAtom(connectorDialogOpenAtom);
|
||||
const [toolsPopoverOpen, setToolsPopoverOpen] = useState(false);
|
||||
|
|
|
|||
|
|
@ -15,9 +15,9 @@ export { default as GeminiIcon } from "./gemini.svg";
|
|||
export { default as GitHubModelsIcon } from "./github.svg";
|
||||
export { default as GroqIcon } from "./groq.svg";
|
||||
export { default as HuggingFaceIcon } from "./huggingface.svg";
|
||||
export { default as LmStudioIcon } from "./lm-studio.svg";
|
||||
export { default as MiniMaxIcon } from "./minimax.svg";
|
||||
export { default as MistralIcon } from "./mistral.svg";
|
||||
export { default as LmStudioIcon } from "./lm-studio.svg";
|
||||
export { default as MoonshotIcon } from "./moonshot.svg";
|
||||
export { default as NscaleIcon } from "./nscale.svg";
|
||||
export { default as OllamaIcon } from "./ollama.svg";
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { useAtom, useAtomValue, useSetAtom } from "jotai";
|
||||
import { AlertTriangle, AlarmClock, Inbox, LibraryBig } from "lucide-react";
|
||||
import { AlarmClock, AlertTriangle, Inbox, LibraryBig } from "lucide-react";
|
||||
import { useParams, usePathname, useRouter } from "next/navigation";
|
||||
import { useTranslations } from "next-intl";
|
||||
import { useTheme } from "next-themes";
|
||||
|
|
|
|||
|
|
@ -66,9 +66,7 @@ export function Header({ mobileMenuTrigger }: HeaderProps) {
|
|||
return (
|
||||
<header className="sticky top-0 z-10 flex h-14 shrink-0 items-center gap-2 bg-main-panel/95 backdrop-blur supports-backdrop-filter:bg-main-panel/60 px-4">
|
||||
{/* Left side - Mobile menu trigger */}
|
||||
<div className="flex flex-1 items-center gap-2 min-w-0">
|
||||
{mobileMenuTrigger}
|
||||
</div>
|
||||
<div className="flex flex-1 items-center gap-2 min-w-0">{mobileMenuTrigger}</div>
|
||||
|
||||
{/* Right side - Actions */}
|
||||
<div className="ml-auto flex items-center gap-2">
|
||||
|
|
|
|||
|
|
@ -82,10 +82,7 @@ function groupedModels(models: ChatModel[]) {
|
|||
}, {});
|
||||
}
|
||||
|
||||
export function ModelSelector({
|
||||
searchSpaceId,
|
||||
className,
|
||||
}: ModelSelectorProps) {
|
||||
export function ModelSelector({ searchSpaceId, className }: ModelSelectorProps) {
|
||||
const router = useRouter();
|
||||
const isMobile = useIsMobile();
|
||||
const [open, setOpen] = useState(false);
|
||||
|
|
@ -250,11 +247,9 @@ export function ModelSelector({
|
|||
className
|
||||
)}
|
||||
>
|
||||
{selected ? (
|
||||
getProviderIcon(selected.provider, { className: "size-4 shrink-0" })
|
||||
) : (
|
||||
getProviderIcon(AUTO_PROVIDER_ICON_KEY, { className: "size-4 shrink-0" })
|
||||
)}
|
||||
{selected
|
||||
? getProviderIcon(selected.provider, { className: "size-4 shrink-0" })
|
||||
: getProviderIcon(AUTO_PROVIDER_ICON_KEY, { className: "size-4 shrink-0" })}
|
||||
<span className="min-w-0 flex-1 truncate text-sm">
|
||||
{selected ? modelName(selected) : "Auto"}
|
||||
</span>
|
||||
|
|
|
|||
|
|
@ -16,12 +16,7 @@ interface ApiBaseUrlFieldProps {
|
|||
}
|
||||
|
||||
/** Shared API Base URL input. The prefilled default is passed in via `value`. */
|
||||
export function ApiBaseUrlField({
|
||||
value,
|
||||
onChange,
|
||||
placeholder,
|
||||
hint,
|
||||
}: ApiBaseUrlFieldProps) {
|
||||
export function ApiBaseUrlField({ value, onChange, placeholder, hint }: ApiBaseUrlFieldProps) {
|
||||
return (
|
||||
<div className="flex flex-col gap-2">
|
||||
<Label>API Base URL</Label>
|
||||
|
|
|
|||
|
|
@ -95,7 +95,8 @@ export function ProviderConnectDialog({
|
|||
})();
|
||||
|
||||
const canRefreshModels = !isAzure && !isVertex && (!isBedrock || canSubmit);
|
||||
const hasEnabledModel = previewModels.some((model) => model.enabled) || Boolean(currentDraft.seedModelId);
|
||||
const hasEnabledModel =
|
||||
previewModels.some((model) => model.enabled) || Boolean(currentDraft.seedModelId);
|
||||
const canConnect = canSubmit && hasEnabledModel;
|
||||
|
||||
return (
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
|
||||
import { useAtomValue } from "jotai";
|
||||
import { AlertCircle, AlarmClock, CornerDownLeftIcon, ExternalLink, Pencil } from "lucide-react";
|
||||
import { AlarmClock, AlertCircle, CornerDownLeftIcon, ExternalLink, Pencil } from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import {
|
||||
AlarmClock,
|
||||
Brain,
|
||||
Calendar,
|
||||
AlarmClock,
|
||||
FileEdit,
|
||||
FilePlus,
|
||||
FileText,
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@ import {
|
|||
type ModelProviderRead,
|
||||
type ModelRead,
|
||||
type ModelRoles,
|
||||
type ModelTestPreviewRequest,
|
||||
type ModelsBulkUpdateRequest,
|
||||
type ModelTestPreviewRequest,
|
||||
type ModelUpdateRequest,
|
||||
modelCreateRequest,
|
||||
modelListResponse,
|
||||
|
|
@ -20,8 +20,8 @@ import {
|
|||
modelProviderListResponse,
|
||||
modelRead,
|
||||
modelRoles,
|
||||
modelTestPreviewRequest,
|
||||
modelsBulkUpdateRequest,
|
||||
modelTestPreviewRequest,
|
||||
modelUpdateRequest,
|
||||
type VerifyConnectionResponse,
|
||||
verifyConnectionResponse,
|
||||
|
|
@ -94,18 +94,16 @@ class ModelConnectionsApiService {
|
|||
);
|
||||
};
|
||||
|
||||
testPreviewModel = async (request: ModelTestPreviewRequest): Promise<VerifyConnectionResponse> => {
|
||||
testPreviewModel = async (
|
||||
request: ModelTestPreviewRequest
|
||||
): Promise<VerifyConnectionResponse> => {
|
||||
const parsed = modelTestPreviewRequest.safeParse(request);
|
||||
if (!parsed.success) {
|
||||
throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", "));
|
||||
}
|
||||
return baseApiService.post(
|
||||
`/api/v1/model-connections/test-preview`,
|
||||
verifyConnectionResponse,
|
||||
{
|
||||
body: parsed.data,
|
||||
}
|
||||
);
|
||||
return baseApiService.post(`/api/v1/model-connections/test-preview`, verifyConnectionResponse, {
|
||||
body: parsed.data,
|
||||
});
|
||||
};
|
||||
|
||||
addManualModel = async (
|
||||
|
|
|
|||
|
|
@ -22,8 +22,7 @@ export function isLlmOnboardingComplete(
|
|||
|
||||
return connections.some((connection) =>
|
||||
connection.models.some(
|
||||
(model) =>
|
||||
model.id === resolvedChatModelId && model.enabled && Boolean(model.supports_chat)
|
||||
(model) => model.id === resolvedChatModelId && model.enabled && Boolean(model.supports_chat)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue