mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-28 21:49:40 +02:00
feat(chat): route models by provider capabilities
This commit is contained in:
parent
8f20a32571
commit
c28c4f5785
18 changed files with 429 additions and 319 deletions
|
|
@ -40,7 +40,7 @@ def check_image_input_capability(
|
|||
else None
|
||||
)
|
||||
if not is_known_text_only_chat_model(
|
||||
provider=agent_config.provider,
|
||||
litellm_provider=agent_config.provider,
|
||||
model_name=agent_config.model_name,
|
||||
base_model=agent_base_model,
|
||||
custom_provider=agent_config.custom_provider,
|
||||
|
|
|
|||
|
|
@ -80,7 +80,6 @@ async def _generate_title(
|
|||
from litellm import acompletion
|
||||
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.token_tracking_service import _turn_accumulator
|
||||
|
||||
# Excludes this turn's own assistant row (pre-written by
|
||||
|
|
@ -125,26 +124,12 @@ async def _generate_title(
|
|||
router = LLMRouterService.get_router()
|
||||
response = await router.acompletion(model="auto", messages=messages)
|
||||
else:
|
||||
# Apply the same ``api_base`` cascade chat / vision / image-gen
|
||||
# call sites use so we never inherit ``litellm.api_base``
|
||||
# (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat
|
||||
# config itself ships an empty ``api_base``. Without this the
|
||||
# title-gen on an OpenRouter chat config would 404 against the
|
||||
# inherited Azure endpoint — see ``provider_api_base`` for the
|
||||
# same bug repro on the image-gen / vision paths.
|
||||
raw_model = getattr(llm, "model", "") or ""
|
||||
provider_prefix = raw_model.split("/", 1)[0] if "/" in raw_model else None
|
||||
provider_value = agent_config.provider if agent_config is not None else None
|
||||
title_api_base = resolve_api_base(
|
||||
provider=provider_value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=getattr(llm, "api_base", None),
|
||||
)
|
||||
response = await acompletion(
|
||||
model=raw_model,
|
||||
messages=messages,
|
||||
api_key=getattr(llm, "api_key", None),
|
||||
api_base=title_api_base,
|
||||
api_base=getattr(llm, "api_base", None),
|
||||
)
|
||||
|
||||
usage_info = None
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
"""Load an LLM + AgentConfig bundle for a given config id.
|
||||
|
||||
Handles both code paths uniformly:
|
||||
- ``config_id >= 0`` → database-backed ``NewLLMConfig`` row (per-user/per-space).
|
||||
- ``config_id < 0`` → YAML-defined global LLM config (built-in defaults).
|
||||
- ``config_id > 0`` → database-backed model-connection ``Model`` row.
|
||||
- ``config_id < 0`` → virtual global model materialized from YAML/OpenRouter.
|
||||
|
||||
Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is
|
||||
``None``. The caller emits the friendly SSE error frame.
|
||||
|
|
@ -12,15 +12,72 @@ from __future__ import annotations
|
|||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
AgentConfig,
|
||||
create_chat_litellm_from_agent_config,
|
||||
create_chat_litellm_from_config,
|
||||
load_agent_config,
|
||||
load_global_llm_config_by_id,
|
||||
SanitizedChatLiteLLM,
|
||||
)
|
||||
from app.config import config
|
||||
from app.db import Model, SearchSpace
|
||||
from app.services.model_resolver import to_litellm
|
||||
|
||||
|
||||
def _agent_config_from_resolved(
    *,
    config_id: int,
    config_name: str | None,
    provider: str,
    model_name: str,
    api_key: str | None,
    api_base: str | None,
    litellm_params: dict | None,
    supports_image_input: bool,
    billing_tier: str = "free",
) -> AgentConfig:
    """Assemble an ``AgentConfig`` from already-resolved model fields.

    All arguments are keyword-only and are forwarded to ``AgentConfig``
    unchanged, with these exceptions:

    - a missing ``api_key`` is passed as an empty string;
    - ``custom_provider`` is always ``None``;
    - ``is_auto_mode`` is always ``False``;
    - ``is_premium`` is derived from ``billing_tier`` (true only when the
      tier is exactly ``"premium"``).
    """
    # AgentConfig receives a string key even when none is configured.
    effective_api_key = api_key or ""
    premium = billing_tier == "premium"

    fields = {
        "provider": provider,
        "model_name": model_name,
        "api_key": effective_api_key,
        "api_base": api_base,
        "custom_provider": None,
        "litellm_params": litellm_params,
        "config_id": config_id,
        "config_name": config_name,
        "is_auto_mode": False,
        "billing_tier": billing_tier,
        "is_premium": premium,
        "supports_image_input": supports_image_input,
    }
    return AgentConfig(**fields)
|
||||
|
||||
|
||||
async def _load_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace | None:
    """Return the ``SearchSpace`` whose id is ``search_space_id``.

    Yields ``None`` when the query produces no row.
    """
    query = select(SearchSpace).where(SearchSpace.id == search_space_id)
    rows = await session.execute(query)
    return rows.scalars().first()
|
||||
|
||||
|
||||
async def _load_db_model(
    session: AsyncSession,
    *,
    model_id: int,
    search_space: SearchSpace,
) -> Model | None:
    """Fetch an enabled ``Model`` row that ``search_space`` is allowed to use.

    The model's ``connection`` relationship is loaded together with the row.
    ``None`` is returned when any of the following holds:

    - no enabled model with ``model_id`` exists;
    - the model has no connection, or the connection is not enabled;
    - the connection is bound to a search space other than ``search_space``;
    - the connection is bound to a user other than ``search_space.user_id``.

    A connection whose ``search_space_id`` / ``user_id`` is ``None`` is not
    restricted on that dimension.
    """
    stmt = (
        select(Model)
        .options(selectinload(Model.connection))
        .where(Model.id == model_id, Model.enabled.is_(True))
    )
    rows = await session.execute(stmt)
    candidate = rows.scalars().first()

    if not candidate:
        return None
    if not candidate.connection:
        return None
    if not candidate.connection.enabled:
        return None

    connection = candidate.connection

    # Only compare ids when the connection is actually scoped; the
    # right-hand attribute is read lazily, exactly as short-circuiting requires.
    wrong_space = (
        connection.search_space_id is not None
        and connection.search_space_id != search_space.id
    )
    if wrong_space:
        return None

    wrong_owner = (
        connection.user_id is not None
        and connection.user_id != search_space.user_id
    )
    if wrong_owner:
        return None

    return candidate
|
||||
|
||||
|
||||
async def load_llm_bundle(
|
||||
|
|
@ -29,29 +86,67 @@ async def load_llm_bundle(
|
|||
config_id: int,
|
||||
search_space_id: int,
|
||||
) -> tuple[Any, AgentConfig | None, str | None]:
|
||||
if config_id >= 0:
|
||||
loaded_agent_config = await load_agent_config(
|
||||
session=session,
|
||||
config_id=config_id,
|
||||
search_space_id=search_space_id,
|
||||
search_space = await _load_search_space(session, search_space_id)
|
||||
if not search_space:
|
||||
return None, None, f"Search space {search_space_id} not found"
|
||||
|
||||
if config_id > 0:
|
||||
model = await _load_db_model(
|
||||
session,
|
||||
model_id=config_id,
|
||||
search_space=search_space,
|
||||
)
|
||||
if not loaded_agent_config:
|
||||
if not model or not (model.capabilities or {}).get("chat"):
|
||||
return (
|
||||
None,
|
||||
None,
|
||||
f"Failed to load NewLLMConfig with id {config_id}",
|
||||
f"Failed to load chat model with id {config_id}",
|
||||
)
|
||||
model_string, litellm_kwargs = to_litellm(model.connection, model.model_id)
|
||||
agent_config = _agent_config_from_resolved(
|
||||
config_id=config_id,
|
||||
config_name=model.display_name or model.model_id,
|
||||
provider=model.connection.litellm_provider or "",
|
||||
model_name=model.model_id,
|
||||
api_key=model.connection.api_key,
|
||||
api_base=model.connection.base_url,
|
||||
litellm_params=(model.connection.extra or {}).get("litellm_params"),
|
||||
supports_image_input=bool((model.capabilities or {}).get("vision")),
|
||||
billing_tier="free",
|
||||
)
|
||||
return (
|
||||
create_chat_litellm_from_agent_config(loaded_agent_config),
|
||||
loaded_agent_config,
|
||||
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
|
||||
agent_config,
|
||||
None,
|
||||
)
|
||||
|
||||
loaded_llm_config = load_global_llm_config_by_id(config_id)
|
||||
if not loaded_llm_config:
|
||||
return None, None, f"Failed to load LLM config with id {config_id}"
|
||||
return (
|
||||
create_chat_litellm_from_config(loaded_llm_config),
|
||||
AgentConfig.from_yaml_config(loaded_llm_config),
|
||||
global_model = next((m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None)
|
||||
if not global_model or not (global_model.get("capabilities") or {}).get("chat"):
|
||||
return None, None, f"Failed to load global chat model with id {config_id}"
|
||||
global_connection = next(
|
||||
(
|
||||
c
|
||||
for c in config.GLOBAL_CONNECTIONS
|
||||
if c.get("id") == global_model.get("connection_id")
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not global_connection:
|
||||
return None, None, f"Failed to load global connection for model {config_id}"
|
||||
model_string, litellm_kwargs = to_litellm(global_connection, global_model["model_id"])
|
||||
agent_config = _agent_config_from_resolved(
|
||||
config_id=config_id,
|
||||
config_name=global_model.get("display_name") or global_model.get("model_id"),
|
||||
provider=global_connection.get("litellm_provider") or "",
|
||||
model_name=global_model["model_id"],
|
||||
api_key=global_connection.get("api_key"),
|
||||
api_base=global_connection.get("base_url"),
|
||||
litellm_params=(global_connection.get("extra") or {}).get("litellm_params"),
|
||||
supports_image_input=bool((global_model.get("capabilities") or {}).get("vision")),
|
||||
billing_tier=str(global_model.get("billing_tier", "free")).lower(),
|
||||
)
|
||||
return (
|
||||
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
|
||||
agent_config,
|
||||
None,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue