mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-20 21:18:13 +02:00
263 lines
9.3 KiB
Python
263 lines
9.3 KiB
Python
"""Connection verification, model discovery, and capability probing."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime
|
|
from typing import Any
|
|
|
|
import httpx
|
|
import litellm
|
|
|
|
from app.db import Connection, ConnectionProtocol, Model, ModelSource
|
|
from app.services.model_resolver import ensure_v1, to_litellm
|
|
|
|
logger = logging.getLogger(__name__)

# Timeouts, in seconds, for the three kinds of outbound call made here.
# Reachability probe issued by ``verify_connection``.
VERIFY_TIMEOUT_SECONDS = 8.0
# Model-listing requests issued by the discovery helpers.
DISCOVERY_TIMEOUT_SECONDS = 15.0
# Single chat completion issued by ``test_model``.
TEST_TIMEOUT_SECONDS = 30.0
|
|
|
|
|
|
@dataclass(frozen=True)
class VerifyResult:
    """Immutable outcome of a connection check or a model test.

    The status codes produced in this module are ``"OK"``,
    ``"AUTH_FAILED"``, ``"NOT_FOUND"`` and ``"UNREACHABLE"``.
    """

    # Short machine-readable outcome code.
    status: str
    # True only when the check succeeded.
    ok: bool
    # Human-readable detail; empty unless the producer supplies one.
    message: str = ""
|
|
|
|
|
|
def _auth_headers(conn: Connection) -> dict[str, str]:
|
|
if not conn.api_key:
|
|
return {}
|
|
return {"Authorization": f"Bearer {conn.api_key}"}
|
|
|
|
|
|
def _anthropic_headers(conn: Connection) -> dict[str, str]:
|
|
headers = {"anthropic-version": "2023-06-01"}
|
|
if conn.api_key:
|
|
headers["x-api-key"] = conn.api_key
|
|
return headers
|
|
|
|
|
|
def _docker_hint(url: str | None, exc_or_status: Any) -> str:
|
|
raw = str(exc_or_status)
|
|
if not url:
|
|
return raw
|
|
if "localhost" in url or "127.0.0.1" in url:
|
|
return (
|
|
f"{raw}. The backend is running inside Docker; localhost means the "
|
|
"backend container. Use host.docker.internal and make sure the model "
|
|
"server listens on 0.0.0.0."
|
|
)
|
|
if "host.docker.internal" in url and ("refused" in raw.lower() or "connect" in raw.lower()):
|
|
return (
|
|
f"{raw}. The host is reachable only if your local model server is "
|
|
"listening on 0.0.0.0. On Linux Docker, add "
|
|
"`host.docker.internal:host-gateway` to extra_hosts."
|
|
)
|
|
return raw
|
|
|
|
|
|
async def verify_connection(conn: Connection) -> VerifyResult:
    """Probe a connection's endpoint with one GET and classify the outcome.

    The probed URL depends on the protocol: ``/api/version`` for Ollama,
    ``<ensure_v1(base_url)>/models`` for OpenAI-compatible servers and
    ``<base_url>/models`` for Anthropic.

    Args:
        conn: Connection providing ``base_url``, ``protocol`` and ``api_key``.

    Returns:
        A ``VerifyResult`` whose status is one of:

        * ``"OK"`` - the endpoint answered with a success status.
        * ``"AUTH_FAILED"`` - HTTP 401 or 403.
        * ``"NOT_FOUND"`` - HTTP 404, with a protocol-specific hint.
        * ``"UNREACHABLE"`` - missing base URL, unsupported protocol,
          connection failure, timeout, or any other HTTP error status.

    httpx errors are converted into a result rather than raised.
    """
    if not conn.base_url:
        return VerifyResult("UNREACHABLE", False, "Base URL is required.")

    base = conn.base_url.rstrip("/")
    if conn.protocol == ConnectionProtocol.OLLAMA:
        url = f"{base}/api/version"
    elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE:
        url = f"{ensure_v1(conn.base_url)}/models"
    elif conn.protocol == ConnectionProtocol.ANTHROPIC:
        url = f"{base}/models"
    else:
        return VerifyResult("UNREACHABLE", False, "Unsupported connection protocol.")

    headers = (
        _anthropic_headers(conn)
        if conn.protocol == ConnectionProtocol.ANTHROPIC
        else _auth_headers(conn)
    )

    try:
        async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client:
            response = await client.get(url, headers=headers)
            if response.status_code in (401, 403):
                return VerifyResult("AUTH_FAILED", False, "Authentication failed.")
            if response.status_code == 404:
                # The Ollama probe URL always ends in /api/version, so the /v1
                # mistake has to be detected on the base URL, not on the
                # probed URL.
                if conn.protocol == ConnectionProtocol.OLLAMA and base.endswith("/v1"):
                    message = "Ollama native API should not use /v1."
                elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE:
                    message = "OpenAI-compatible servers should expose /v1/models."
                else:
                    message = "Endpoint returned 404."
                return VerifyResult("NOT_FOUND", False, message)
            # Any remaining 4xx/5xx is handled by the HTTPError clause below.
            response.raise_for_status()
            return VerifyResult("OK", True, "Connection verified.")
    except httpx.ConnectError as exc:
        return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc))
    except httpx.TimeoutException as exc:
        return VerifyResult("UNREACHABLE", False, f"Connection timed out: {exc}")
    except httpx.HTTPError as exc:
        return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc))
|
|
|
|
|
|
async def persist_verification(conn: Connection) -> VerifyResult:
    """Verify ``conn`` and record the outcome on the connection object.

    Sets ``last_verified_at`` (current UTC time), ``last_status`` and
    ``last_error`` (cleared on success). Only attributes are assigned here;
    saving the object is left to the caller.

    Returns:
        The ``VerifyResult`` produced by ``verify_connection``.
    """
    outcome = await verify_connection(conn)
    error_text = outcome.message if not outcome.ok else ""

    conn.last_verified_at = datetime.now(tz=UTC)
    conn.last_status = outcome.status
    conn.last_error = error_text
    return outcome
|
|
|
|
|
|
def _litellm_capabilities(model_string: str, model_id: str) -> dict[str, bool]:
|
|
capabilities = {
|
|
"chat": True,
|
|
"vision": False,
|
|
"tools": False,
|
|
"image_gen": False,
|
|
"embedding": False,
|
|
}
|
|
with contextlib.suppress(Exception):
|
|
capabilities["vision"] = bool(litellm.supports_vision(model=model_string))
|
|
with contextlib.suppress(Exception):
|
|
capabilities["tools"] = bool(litellm.supports_function_calling(model=model_string))
|
|
try:
|
|
info = litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {}
|
|
mode = str(info.get("mode") or "")
|
|
capabilities["embedding"] = mode == "embedding"
|
|
capabilities["image_gen"] = mode in {"image_generation", "image_generation_model"}
|
|
except Exception:
|
|
pass
|
|
return capabilities
|
|
|
|
|
|
def _allowlist(conn: Connection) -> set[str]:
|
|
"""Per-connection model-id allowlist stored in ``extra.model_ids``.
|
|
|
|
Empty/absent means "no restriction" (discover everything), mirroring
|
|
OpenWebUI's behaviour. A non-empty list restricts discovery to those ids —
|
|
essential for providers like OpenRouter that expose hundreds of models.
|
|
"""
|
|
raw = (conn.extra or {}).get("model_ids") or []
|
|
return {str(item).strip() for item in raw if str(item).strip()}
|
|
|
|
|
|
async def _discover_openai_shaped_models(conn: Connection, base_url: str | None) -> list[dict[str, Any]]:
    """List models from an OpenAI-style ``/models`` endpoint.

    Args:
        conn: Connection supplying the auth headers and the capability context.
        base_url: Server base URL; it is passed through ``ensure_v1`` first.

    Returns:
        One record per listed entry that has an ``id``. An empty list is
        returned when ``base_url`` is missing.

    Raises:
        httpx.HTTPError: Not caught here; includes error statuses raised by
            ``raise_for_status``.
    """
    if not base_url:
        return []

    endpoint = f"{ensure_v1(base_url)}/models"
    async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
        response = await client.get(endpoint, headers=_auth_headers(conn))
        response.raise_for_status()
        entries = response.json().get("data", [])

    discovered: list[dict[str, Any]] = []
    for entry in entries:
        model_id = entry.get("id")
        if not model_id:
            continue
        discovered.append(
            {
                "model_id": model_id,
                "display_name": entry.get("name") or model_id,
                "source": ModelSource.DISCOVERED,
                "capabilities": derive_capabilities(conn, model_id, entry),
                "metadata": entry,
            }
        )
    return discovered
|
|
|
|
|
|
async def _discover_anthropic_models(conn: Connection) -> list[dict[str, Any]]:
    """List models from an Anthropic-style ``<base_url>/models`` endpoint.

    Anthropic's list-models endpoint is paginated and returns a small default
    page when no ``limit`` is sent. The largest documented page size is
    requested so discovery is not silently cut off after the first page.

    Args:
        conn: Connection supplying ``base_url`` and the API key.

    Returns:
        One record per listed entry that has an ``id``. An empty list is
        returned when the connection has no base URL.

    Raises:
        httpx.HTTPError: Not caught here; includes error statuses raised by
            ``raise_for_status``.
    """
    if not conn.base_url:
        return []

    url = f"{conn.base_url.rstrip('/')}/models"
    async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
        response = await client.get(
            url,
            headers=_anthropic_headers(conn),
            # Maximum page size documented for the Anthropic Models API.
            params={"limit": 1000},
        )
        response.raise_for_status()
        # Treat an explicit null "data" the same as a missing key.
        models = response.json().get("data") or []

    discovered: list[dict[str, Any]] = []
    for item in models:
        model_id = item.get("id")
        if not model_id:
            continue
        discovered.append(
            {
                "model_id": model_id,
                "display_name": item.get("display_name") or model_id,
                "source": ModelSource.DISCOVERED,
                "capabilities": derive_capabilities(conn, model_id, item),
                "metadata": item,
            }
        )
    return discovered
|
|
|
|
|
|
def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, bool]:
    """Work out the capability flags for one model on a connection.

    For Ollama connections the flags come from the ``capabilities`` list in
    ``metadata``; ``tools`` and ``image_gen`` are always False there. Every
    other protocol is resolved through LiteLLM's model metadata.

    Args:
        conn: Connection the model belongs to; its protocol selects the strategy.
        model_id: Provider-side model identifier.
        metadata: Raw model entry returned by the provider, if any.

    Returns:
        Dict with the keys ``chat``, ``vision``, ``tools``, ``image_gen``
        and ``embedding``.
    """
    if conn.protocol != ConnectionProtocol.OLLAMA:
        model_string, _kwargs = to_litellm(conn, model_id)
        return _litellm_capabilities(model_string, model_id)

    advertised = (metadata or {}).get("capabilities") or []
    return {
        "chat": True,
        "vision": "vision" in advertised,
        "tools": False,
        "image_gen": False,
        "embedding": "embedding" in advertised,
    }
|
|
|
|
|
|
async def _discover_ollama_models(conn: Connection) -> list[dict[str, Any]]:
    """List models from an Ollama server's ``/api/tags`` endpoint.

    Returns an empty list when the connection has no base URL, matching the
    other discovery helpers instead of failing on ``None.rstrip``.
    """
    if not conn.base_url:
        return []

    url = f"{conn.base_url.rstrip('/')}/api/tags"
    async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
        response = await client.get(url, headers=_auth_headers(conn))
        response.raise_for_status()
        models = response.json().get("models", [])

    discovered: list[dict[str, Any]] = []
    for item in models:
        # Entries are identified by "model", falling back to "name".
        model_id = item.get("model") or item.get("name")
        if not model_id:
            continue
        discovered.append(
            {
                "model_id": model_id,
                "display_name": item.get("name") or item.get("model"),
                "source": ModelSource.DISCOVERED,
                "capabilities": derive_capabilities(conn, model_id, item),
                "metadata": item,
            }
        )
    return discovered


async def discover_models(conn: Connection) -> list[dict[str, Any]]:
    """Discover the models a connection exposes, honouring its allowlist.

    Args:
        conn: Connection to query; its protocol selects the discovery endpoint.

    Returns:
        A list of records with the keys ``model_id``, ``display_name``,
        ``source``, ``capabilities`` and ``metadata``. Unsupported protocols
        and connections without a base URL yield an empty list. When the
        connection defines a non-empty ``extra.model_ids`` allowlist, only
        those ids are returned.

    Raises:
        httpx.HTTPError: Request failures and error statuses are not caught.
    """
    allowlist = _allowlist(conn)

    if conn.protocol == ConnectionProtocol.OLLAMA:
        results = await _discover_ollama_models(conn)
    elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE:
        results = await _discover_openai_shaped_models(conn, conn.base_url)
    elif conn.protocol == ConnectionProtocol.ANTHROPIC:
        results = await _discover_anthropic_models(conn)
    else:
        results = []

    if allowlist:
        results = [item for item in results if item["model_id"] in allowlist]
    return results
|
|
|
|
|
|
async def test_model(conn: Connection, model: Model) -> VerifyResult:
    """Send a one-message chat completion to check that a model responds.

    On success, ``model.capabilities_verified`` is replaced with a new dict
    that keeps the existing entries and sets ``chat`` to True.

    Args:
        conn: Connection used to build the LiteLLM call arguments.
        model: Model whose ``model_id`` is tested.

    Returns:
        ``VerifyResult("OK", ...)`` on success. Any exception from the
        completion call yields ``"UNREACHABLE"`` with the exception text.
    """
    model_string, call_kwargs = to_litellm(conn, model.model_id)
    probe_messages = [{"role": "user", "content": "Hello"}]

    try:
        await litellm.acompletion(
            model=model_string,
            messages=probe_messages,
            timeout=TEST_TIMEOUT_SECONDS,
            **call_kwargs,
        )
    except Exception as exc:
        return VerifyResult("UNREACHABLE", False, str(exc))

    # Assign a fresh dict rather than mutating the existing one in place.
    verified = dict(model.capabilities_verified or {})
    verified["chat"] = True
    model.capabilities_verified = verified
    return VerifyResult("OK", True, "Model test succeeded.")
|
|
|
|
|
|
# Public API of this module; underscore-prefixed helpers stay private.
__all__ = [
    "VerifyResult",
    "derive_capabilities",
    "discover_models",
    "persist_verification",
    "test_model",
    "verify_connection",
]
|