mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-07-02 22:01:05 +02:00
feat(chat): route models by provider capabilities
This commit is contained in:
parent
8f20a32571
commit
c28c4f5785
18 changed files with 429 additions and 319 deletions
|
|
@ -9,53 +9,12 @@ from __future__ import annotations
|
|||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.db import Connection
|
||||
|
||||
PROTOCOL_OLLAMA = "OLLAMA"
|
||||
PROTOCOL_OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE"
|
||||
PROTOCOL_NATIVE = "NATIVE"
|
||||
|
||||
NATIVE_PROVIDER_PREFIX: dict[str, str] = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"AZURE": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"COMETAPI": "cometapi",
|
||||
"XAI": "xai",
|
||||
"BEDROCK": "bedrock",
|
||||
"AWS_BEDROCK": "bedrock",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"DEEPSEEK": "openai",
|
||||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
"GITHUB_MODELS": "github",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
"DEEPINFRA": "deepinfra",
|
||||
"CEREBRAS": "cerebras",
|
||||
"SAMBANOVA": "sambanova",
|
||||
"AI21": "ai21",
|
||||
"CLOUDFLARE": "cloudflare",
|
||||
"DATABRICKS": "databricks",
|
||||
"HUGGINGFACE": "huggingface",
|
||||
"MINIMAX": "openai",
|
||||
"RECRAFT": "recraft",
|
||||
"XINFERENCE": "xinference",
|
||||
"NSCALE": "nscale",
|
||||
"CUSTOM": "custom",
|
||||
}
|
||||
PROTOCOL_ANTHROPIC = "ANTHROPIC"
|
||||
|
||||
|
||||
def ensure_v1(base_url: str | None) -> str | None:
|
||||
|
|
@ -77,6 +36,23 @@ def _protocol_value(protocol: Any) -> str:
|
|||
return getattr(protocol, "value", str(protocol))
|
||||
|
||||
|
||||
def default_litellm_provider(protocol: Any) -> str:
    """Return the fallback LiteLLM provider name for a connection protocol.

    Args:
        protocol: A protocol identifier. It is normalised through
            ``_protocol_value`` before the lookup, so anything that helper
            accepts may be passed here.

    Returns:
        ``"ollama_chat"`` for ``PROTOCOL_OLLAMA``, ``"anthropic"`` for
        ``PROTOCOL_ANTHROPIC``, and ``"openai"`` for
        ``PROTOCOL_OPENAI_COMPATIBLE`` as well as for any protocol that is
        not listed.
    """
    fallback = "openai"
    key = _protocol_value(protocol)
    provider_by_protocol = {
        PROTOCOL_OLLAMA: "ollama_chat",
        PROTOCOL_ANTHROPIC: "anthropic",
        PROTOCOL_OPENAI_COMPATIBLE: fallback,
    }
    try:
        return provider_by_protocol[key]
    except KeyError:
        # Unlisted protocols use the same provider as OPENAI_COMPATIBLE.
        return fallback
|
||||
|
||||
|
||||
def _execution_api_base(protocol: str, base_url: str | None) -> str | None:
|
||||
del protocol
|
||||
if not base_url:
|
||||
return None
|
||||
return base_url.rstrip("/")
|
||||
|
||||
|
||||
def to_litellm(
|
||||
conn: Connection | Mapping[str, Any],
|
||||
model_id: str,
|
||||
|
|
@ -85,38 +61,19 @@ def to_litellm(
|
|||
protocol = _protocol_value(_conn_value(conn, "protocol"))
|
||||
base_url = _conn_value(conn, "base_url")
|
||||
api_key = _conn_value(conn, "api_key")
|
||||
native_provider = _conn_value(conn, "native_provider")
|
||||
litellm_provider = (
|
||||
_conn_value(conn, "litellm_provider") or default_litellm_provider(protocol)
|
||||
)
|
||||
extra = _conn_value(conn, "extra") or {}
|
||||
|
||||
kwargs: dict[str, Any] = {}
|
||||
if api_key:
|
||||
kwargs["api_key"] = api_key
|
||||
|
||||
if protocol == PROTOCOL_OLLAMA:
|
||||
model_string = f"ollama_chat/{model_id}"
|
||||
if base_url:
|
||||
kwargs["api_base"] = base_url.rstrip("/")
|
||||
elif protocol == PROTOCOL_OPENAI_COMPATIBLE:
|
||||
model_string = f"openai/{model_id}"
|
||||
api_base = ensure_v1(base_url)
|
||||
if api_base:
|
||||
kwargs["api_base"] = api_base
|
||||
else:
|
||||
provider_key = (native_provider or "").upper()
|
||||
prefix = NATIVE_PROVIDER_PREFIX.get(provider_key, provider_key.lower())
|
||||
if prefix == "custom":
|
||||
custom_provider = extra.get("custom_provider") or native_provider
|
||||
model_string = f"{custom_provider}/{model_id}" if custom_provider else model_id
|
||||
else:
|
||||
model_string = f"{prefix}/{model_id}"
|
||||
|
||||
api_base = resolve_api_base(
|
||||
provider=provider_key,
|
||||
provider_prefix=prefix,
|
||||
config_api_base=base_url,
|
||||
)
|
||||
if api_base:
|
||||
kwargs["api_base"] = api_base
|
||||
model_string = f"{litellm_provider}/{model_id}" if litellm_provider else model_id
|
||||
api_base = _execution_api_base(protocol, base_url)
|
||||
if api_base:
|
||||
kwargs["api_base"] = api_base
|
||||
|
||||
if api_version := extra.get("api_version"):
|
||||
kwargs["api_version"] = api_version
|
||||
|
|
@ -126,18 +83,21 @@ def to_litellm(
|
|||
|
||||
|
||||
def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"""Build an in-memory NATIVE connection mapping from a legacy/global config."""
|
||||
provider = str(config.get("provider") or config.get("custom_provider") or "CUSTOM")
|
||||
"""Build an in-memory connection mapping from a global config."""
|
||||
protocol = str(config.get("protocol") or PROTOCOL_OPENAI_COMPATIBLE)
|
||||
litellm_provider = str(
|
||||
config.get("litellm_provider")
|
||||
or config.get("custom_provider")
|
||||
or default_litellm_provider(protocol)
|
||||
)
|
||||
extra: dict[str, Any] = {
|
||||
"litellm_params": config.get("litellm_params") or {},
|
||||
}
|
||||
if config.get("api_version"):
|
||||
extra["api_version"] = config.get("api_version")
|
||||
if config.get("custom_provider"):
|
||||
extra["custom_provider"] = config.get("custom_provider")
|
||||
return {
|
||||
"protocol": PROTOCOL_NATIVE,
|
||||
"native_provider": provider,
|
||||
"protocol": protocol,
|
||||
"litellm_provider": litellm_provider,
|
||||
"base_url": config.get("api_base") or None,
|
||||
"api_key": config.get("api_key") or None,
|
||||
"extra": extra,
|
||||
|
|
@ -145,7 +105,7 @@ def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]:
|
|||
|
||||
|
||||
__all__ = [
|
||||
"NATIVE_PROVIDER_PREFIX",
|
||||
"default_litellm_provider",
|
||||
"ensure_v1",
|
||||
"native_connection_from_config",
|
||||
"to_litellm",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue