mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
cloud: added openrouter integration with global configs
This commit is contained in:
parent
ff4e0f9b62
commit
4a51ccdc2c
26 changed files with 911 additions and 178 deletions
|
|
@ -49,6 +49,49 @@ def _is_context_overflow_error(exc: LiteLLMBadRequestError) -> bool:
|
|||
return bool(_CONTEXT_OVERFLOW_PATTERNS.search(str(exc)))
|
||||
|
||||
|
||||
_UNIVERSAL_CONTENT_TYPES = {
|
||||
"text",
|
||||
"image_url",
|
||||
"input_audio",
|
||||
"refusal",
|
||||
"audio",
|
||||
"file",
|
||||
}
|
||||
|
||||
|
||||
def _sanitize_content(content: Any) -> Any:
|
||||
"""Normalise a LangChain message ``content`` field so it is safe for any
|
||||
downstream provider (Azure, OpenAI, OpenRouter, etc.).
|
||||
|
||||
* Strips provider-specific block types (e.g. ``thinking`` from reasoning models).
|
||||
* Removes text blocks with blank text (Bedrock rejects ``{"type":"text","text":""}``)
|
||||
* Converts bare strings inside a list to ``{"type": "text", "text": ...}`` objects
|
||||
(Azure rejects raw strings in a content array).
|
||||
* Collapses a single-text-block list to a plain string for maximum compatibility.
|
||||
"""
|
||||
if not isinstance(content, list):
|
||||
return content
|
||||
|
||||
filtered: list[dict] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
if block:
|
||||
filtered.append({"type": "text", "text": block})
|
||||
elif isinstance(block, dict):
|
||||
block_type = block.get("type", "text")
|
||||
if block_type not in _UNIVERSAL_CONTENT_TYPES:
|
||||
continue
|
||||
if block_type == "text" and not block.get("text"):
|
||||
continue
|
||||
filtered.append(block)
|
||||
|
||||
if not filtered:
|
||||
return ""
|
||||
if len(filtered) == 1 and filtered[0].get("type") == "text":
|
||||
return filtered[0].get("text", "")
|
||||
return filtered
|
||||
|
||||
|
||||
# Special ID for Auto mode - uses router for load balancing
|
||||
AUTO_MODE_ID = 0
|
||||
|
||||
|
|
@ -103,6 +146,7 @@ class LLMRouterService:
|
|||
_model_list: list[dict] = []
|
||||
_router_settings: dict = {}
|
||||
_initialized: bool = False
|
||||
_premium_model_strings: set[str] = set()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
|
|
@ -135,22 +179,28 @@ class LLMRouterService:
|
|||
logger.debug("LLM Router already initialized, skipping")
|
||||
return
|
||||
|
||||
auto_configs = [
|
||||
c for c in global_configs if c.get("billing_tier", "free") != "premium"
|
||||
]
|
||||
|
||||
model_list = []
|
||||
for config in auto_configs:
|
||||
premium_models: set[str] = set()
|
||||
for config in global_configs:
|
||||
deployment = cls._config_to_deployment(config)
|
||||
if deployment:
|
||||
model_list.append(deployment)
|
||||
if config.get("billing_tier") == "premium":
|
||||
model_string = deployment["litellm_params"]["model"]
|
||||
premium_models.add(model_string)
|
||||
|
||||
if not model_list:
|
||||
logger.warning("No valid LLM configs found for router initialization")
|
||||
return
|
||||
|
||||
instance._model_list = model_list
|
||||
instance._premium_model_strings = premium_models
|
||||
instance._router_settings = router_settings or {}
|
||||
logger.info(
|
||||
"Router pool: %d deployments (%d premium)",
|
||||
len(model_list),
|
||||
len(premium_models),
|
||||
)
|
||||
|
||||
# Default router settings optimized for rate limit handling
|
||||
default_settings = {
|
||||
|
|
@ -197,6 +247,21 @@ class LLMRouterService:
|
|||
logger.error(f"Failed to initialize LLM Router: {e}")
|
||||
instance._router = None
|
||||
|
||||
@classmethod
def is_premium_model(cls, model_string: str) -> bool:
    """Tell whether *model_string* (as reported by LiteLLM) is one of the
    model strings recorded as premium on the router instance."""
    premium_strings = cls.get_instance()._premium_model_strings
    return model_string in premium_strings
|
||||
|
||||
@classmethod
def compute_premium_tokens(cls, calls: list) -> int:
    """Add up ``total_tokens`` over the calls whose ``model`` is premium."""
    router = cls.get_instance()
    premium_total = 0
    for call in calls:
        if call.model in router._premium_model_strings:
            premium_total += call.total_tokens
    return premium_total
|
||||
|
||||
@classmethod
|
||||
def _build_context_fallback_groups(
|
||||
cls, model_list: list[dict]
|
||||
|
|
@ -1044,10 +1109,12 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
result.append({"role": "user", "content": msg.content})
|
||||
elif isinstance(msg, AIMsg):
|
||||
ai_msg: dict[str, Any] = {"role": "assistant"}
|
||||
if msg.content:
|
||||
ai_msg["content"] = msg.content
|
||||
# Handle tool calls
|
||||
if hasattr(msg, "tool_calls") and msg.tool_calls:
|
||||
has_tool_calls = hasattr(msg, "tool_calls") and msg.tool_calls
|
||||
|
||||
sanitized = _sanitize_content(msg.content) if msg.content else ""
|
||||
ai_msg["content"] = sanitized if sanitized else ""
|
||||
|
||||
if has_tool_calls:
|
||||
ai_msg["tool_calls"] = [
|
||||
{
|
||||
"id": tc.get("id", ""),
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from langchain_litellm import ChatLiteLLM
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
from app.config import config
|
||||
from app.db import NewLLMConfig, SearchSpace
|
||||
from app.services.llm_router_service import (
|
||||
|
|
@ -150,7 +151,7 @@ async def validate_llm_config(
|
|||
if litellm_params:
|
||||
litellm_kwargs.update(litellm_params)
|
||||
|
||||
llm = ChatLiteLLM(**litellm_kwargs)
|
||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
# Make a simple test call
|
||||
test_message = HumanMessage(content="Hello")
|
||||
|
|
@ -302,7 +303,7 @@ async def get_search_space_llm_instance(
|
|||
if disable_streaming:
|
||||
litellm_kwargs["disable_streaming"] = True
|
||||
|
||||
return ChatLiteLLM(**litellm_kwargs)
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
# Get the LLM configuration from database (NewLLMConfig)
|
||||
result = await session.execute(
|
||||
|
|
@ -379,7 +380,7 @@ async def get_search_space_llm_instance(
|
|||
if disable_streaming:
|
||||
litellm_kwargs["disable_streaming"] = True
|
||||
|
||||
return ChatLiteLLM(**litellm_kwargs)
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
|
@ -480,7 +481,7 @@ async def get_vision_llm(
|
|||
if global_cfg.get("litellm_params"):
|
||||
litellm_kwargs.update(global_cfg["litellm_params"])
|
||||
|
||||
return ChatLiteLLM(**litellm_kwargs)
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).where(
|
||||
|
|
@ -513,7 +514,7 @@ async def get_vision_llm(
|
|||
if vision_cfg.litellm_params:
|
||||
litellm_kwargs.update(vision_cfg.litellm_params)
|
||||
|
||||
return ChatLiteLLM(**litellm_kwargs)
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
|
|
|||
|
|
@ -86,12 +86,34 @@ def _is_text_output_model(model: dict) -> bool:
|
|||
return output_mods == ["text"]
|
||||
|
||||
|
||||
def _supports_tool_calling(model: dict) -> bool:
|
||||
"""Return True if the model supports function/tool calling."""
|
||||
supported = model.get("supported_parameters") or []
|
||||
return "tools" in supported
|
||||
|
||||
|
||||
MIN_CONTEXT_LENGTH = 100_000
|
||||
|
||||
|
||||
def _has_sufficient_context(model: dict) -> bool:
|
||||
"""Return True if the model's context window is at least MIN_CONTEXT_LENGTH."""
|
||||
ctx = model.get("context_length") or 0
|
||||
return ctx >= MIN_CONTEXT_LENGTH
|
||||
|
||||
|
||||
def _is_allowed_model(model: dict) -> bool:
    """Delegate to the OpenRouter integration service's model exclusion check."""
    # Imported at call time rather than at module import.
    # NOTE(review): presumably to avoid an import cycle - confirm.
    from app.services.openrouter_integration_service import (
        _is_allowed_model as _service_check,
    )

    allowed = _service_check(model)
    return allowed
|
||||
|
||||
|
||||
def _process_models(raw_models: list[dict]) -> list[dict]:
|
||||
"""
|
||||
Transform raw OpenRouter model entries into a flat list of
|
||||
{value, label, provider, context_window} dicts.
|
||||
|
||||
Only text-output models are included (audio/image generators are skipped).
|
||||
Only text-output models with tool-calling support are included.
|
||||
|
||||
Each OpenRouter model is emitted once for OPENROUTER (full id) and,
|
||||
when the slug maps to a native provider, once more with just the
|
||||
|
|
@ -110,6 +132,15 @@ def _process_models(raw_models: list[dict]) -> list[dict]:
|
|||
if not _is_text_output_model(model):
|
||||
continue
|
||||
|
||||
if not _supports_tool_calling(model):
|
||||
continue
|
||||
|
||||
if not _has_sufficient_context(model):
|
||||
continue
|
||||
|
||||
if not _is_allowed_model(model):
|
||||
continue
|
||||
|
||||
provider_slug, model_name = model_id.split("/", 1)
|
||||
context_window = _format_context_length(context_length)
|
||||
|
||||
|
|
|
|||
291
surfsense_backend/app/services/openrouter_integration_service.py
Normal file
291
surfsense_backend/app/services/openrouter_integration_service.py
Normal file
|
|
@ -0,0 +1,291 @@
|
|||
"""
|
||||
OpenRouter Integration Service
|
||||
|
||||
Dynamically fetches all available models from the OpenRouter public API
|
||||
and generates virtual global LLM config entries. These entries are injected
|
||||
into config.GLOBAL_LLM_CONFIGS so they appear alongside static YAML configs
|
||||
in the model selector.
|
||||
|
||||
All actual LLM calls go through LiteLLM with the ``openrouter/`` prefix --
|
||||
this service only manages the catalogue, not the inference path.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
|
||||
|
||||
# Sentinel value stored on each generated config so we can distinguish
|
||||
# dynamic OpenRouter entries from hand-written YAML entries during refresh.
|
||||
_OPENROUTER_DYNAMIC_MARKER = "__openrouter_dynamic__"
|
||||
|
||||
|
||||
def _is_text_output_model(model: dict) -> bool:
|
||||
"""Return True if the model produces text output only (skip image/audio generators)."""
|
||||
output_mods = model.get("architecture", {}).get("output_modalities", [])
|
||||
return output_mods == ["text"]
|
||||
|
||||
|
||||
def _supports_tool_calling(model: dict) -> bool:
|
||||
"""Return True if the model supports function/tool calling."""
|
||||
supported = model.get("supported_parameters") or []
|
||||
return "tools" in supported
|
||||
|
||||
|
||||
MIN_CONTEXT_LENGTH = 100_000
|
||||
|
||||
# Provider slugs whose backend is fundamentally incompatible with our agent's
|
||||
# tool-call message flow (e.g. Amazon Bedrock requires toolConfig alongside
|
||||
# tool history which OpenRouter doesn't relay).
|
||||
_EXCLUDED_PROVIDER_SLUGS = {"amazon"}
|
||||
|
||||
_EXCLUDED_MODEL_IDS: set[str] = {
|
||||
# Deprecated / removed upstream
|
||||
"openai/gpt-4-1106-preview",
|
||||
"openai/gpt-4-turbo-preview",
|
||||
# Permanently no-capacity variant
|
||||
"openai/gpt-4o:extended",
|
||||
# Non-serverless model that requires a dedicated endpoint
|
||||
"arcee-ai/virtuoso-large",
|
||||
# Deep-research models reject standard params (temperature, etc.)
|
||||
"openai/o3-deep-research",
|
||||
"openai/o4-mini-deep-research",
|
||||
}
|
||||
|
||||
_EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",)
|
||||
|
||||
|
||||
def _has_sufficient_context(model: dict) -> bool:
    """Report whether the model's ``context_length`` reaches MIN_CONTEXT_LENGTH."""
    reported = model.get("context_length")
    # Missing or null context length counts as zero.
    window = reported if reported else 0
    return window >= MIN_CONTEXT_LENGTH
|
||||
|
||||
|
||||
def _is_compatible_provider(model: dict) -> bool:
    """Return False when the model id's provider slug is on the exclusion list."""
    model_id = model.get("id", "")
    if "/" in model_id:
        # The provider slug is everything before the first slash.
        provider_slug = model_id.split("/", 1)[0]
    else:
        provider_slug = ""
    return not (provider_slug in _EXCLUDED_PROVIDER_SLUGS)
|
||||
|
||||
|
||||
def _is_allowed_model(model: dict) -> bool:
    """Return False for model ids excluded either exactly or by suffix."""
    model_id = model.get("id", "")
    explicitly_excluded = model_id in _EXCLUDED_MODEL_IDS
    if explicitly_excluded:
        return False
    # Drop any ``:variant`` tail before testing the suffix list.
    base_id, _, _ = model_id.partition(":")
    if base_id.endswith(_EXCLUDED_MODEL_SUFFIXES):
        return False
    return True
|
||||
|
||||
|
||||
def _fetch_models_sync() -> list[dict] | None:
    """Blocking fetch of the OpenRouter model list, for use during startup
    (before the event loop is running).

    Returns the payload's ``data`` entry (an empty list if absent), or None
    when the request or decoding fails; the failure is logged as a warning.
    """
    try:
        with httpx.Client(timeout=20) as http:
            resp = http.get(OPENROUTER_API_URL)
            resp.raise_for_status()
            payload = resp.json()
            return payload.get("data", [])
    except Exception as e:
        logger.warning("Failed to fetch OpenRouter models (sync): %s", e)
        return None
|
||||
|
||||
|
||||
async def _fetch_models_async() -> list[dict] | None:
    """Non-blocking fetch of the OpenRouter model list, for background refresh.

    Returns the payload's ``data`` entry (an empty list if absent), or None
    when the request or decoding fails; the failure is logged as a warning.
    """
    try:
        async with httpx.AsyncClient(timeout=20) as http:
            resp = await http.get(OPENROUTER_API_URL)
            resp.raise_for_status()
            payload = resp.json()
            return payload.get("data", [])
    except Exception as e:
        logger.warning("Failed to fetch OpenRouter models (async): %s", e)
        return None
|
||||
|
||||
|
||||
def _generate_configs(
    raw_models: list[dict],
    settings: dict[str, Any],
) -> list[dict]:
    """
    Convert raw OpenRouter model entries into global LLM config dicts.

    An entry is kept only when its ``id`` is a ``provider/model`` string and it
    passes every catalogue filter (text-only output, tool calling, minimum
    context length, compatible provider, not on the exclusion lists). Entries
    with a missing or non-string ``id`` are skipped rather than raising.

    Kept models are sorted by ID and numbered ``id_offset - position``, so the
    assignment is deterministic for a given set of eligible models.
    NOTE(review): when that set changes between fetches, the positions (and
    therefore the IDs) of later models shift - confirm callers do not persist
    these IDs across refreshes.

    Args:
        raw_models: Model entries as returned by the OpenRouter models endpoint.
        settings: Integration settings; every key is optional and falls back
            to the defaults read below.

    Returns:
        One config dict per eligible model, each tagged with
        ``_OPENROUTER_DYNAMIC_MARKER``.
    """
    id_offset: int = settings.get("id_offset", -10000)
    api_key: str = settings.get("api_key", "")
    billing_tier: str = settings.get("billing_tier", "premium")
    anonymous_enabled: bool = settings.get("anonymous_enabled", False)
    seo_enabled: bool = settings.get("seo_enabled", False)
    quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
    rpm: int = settings.get("rpm", 200)
    tpm: int = settings.get("tpm", 1000000)
    litellm_params: dict = settings.get("litellm_params") or {}
    system_instructions: str = settings.get("system_instructions", "")
    use_default: bool = settings.get("use_default_system_instructions", True)
    citations_enabled: bool = settings.get("citations_enabled", True)

    text_models: list[dict] = []
    for m in raw_models:
        # Validate the id first: the provider/exclusion filters below assume a
        # string id, and one malformed entry must not abort the whole build.
        candidate_id = m.get("id")
        if not isinstance(candidate_id, str) or "/" not in candidate_id:
            continue
        if (
            _is_text_output_model(m)
            and _supports_tool_calling(m)
            and _has_sufficient_context(m)
            and _is_compatible_provider(m)
            and _is_allowed_model(m)
        ):
            text_models.append(m)
    text_models.sort(key=lambda m: m["id"])

    configs: list[dict] = []
    for idx, model in enumerate(text_models):
        model_id: str = model["id"]
        # Fall back to the id when the name is absent, null or empty.
        name: str = model.get("name") or model_id

        cfg: dict[str, Any] = {
            "id": id_offset - idx,
            "name": name,
            "description": f"{name} via OpenRouter",
            "billing_tier": billing_tier,
            "anonymous_enabled": anonymous_enabled,
            "seo_enabled": seo_enabled,
            "seo_slug": None,
            "quota_reserve_tokens": quota_reserve_tokens,
            "provider": "OPENROUTER",
            "model_name": model_id,
            "api_key": api_key,
            "api_base": "",
            "rpm": rpm,
            "tpm": tpm,
            # Per-config copy so one entry's params can be edited without
            # touching the shared settings dict.
            "litellm_params": dict(litellm_params),
            "system_instructions": system_instructions,
            "use_default_system_instructions": use_default,
            "citations_enabled": citations_enabled,
            _OPENROUTER_DYNAMIC_MARKER: True,
        }
        configs.append(cfg)

    return configs
|
||||
|
||||
|
||||
class OpenRouterIntegrationService:
|
||||
"""Singleton that manages the dynamic OpenRouter model catalogue."""
|
||||
|
||||
_instance: "OpenRouterIntegrationService | None" = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._settings: dict[str, Any] = {}
|
||||
self._configs: list[dict] = []
|
||||
self._configs_by_id: dict[int, dict] = {}
|
||||
self._initialized = False
|
||||
self._refresh_task: asyncio.Task | None = None
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "OpenRouterIntegrationService":
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def is_initialized(cls) -> bool:
|
||||
return cls._instance is not None and cls._instance._initialized
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Initialisation (called at startup, before event loop for Celery)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def initialize(self, settings: dict[str, Any]) -> list[dict]:
|
||||
"""
|
||||
Fetch models synchronously and generate configs.
|
||||
Returns the generated configs list.
|
||||
"""
|
||||
self._settings = settings
|
||||
raw_models = _fetch_models_sync()
|
||||
if raw_models is None:
|
||||
logger.warning("OpenRouter integration: could not fetch models at startup")
|
||||
self._initialized = True
|
||||
return []
|
||||
|
||||
self._configs = _generate_configs(raw_models, settings)
|
||||
self._configs_by_id = {c["id"]: c for c in self._configs}
|
||||
self._initialized = True
|
||||
|
||||
logger.info(
|
||||
"OpenRouter integration: loaded %d models (IDs %d to %d)",
|
||||
len(self._configs),
|
||||
self._configs[0]["id"] if self._configs else 0,
|
||||
self._configs[-1]["id"] if self._configs else 0,
|
||||
)
|
||||
return self._configs
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Background refresh
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def refresh(self) -> None:
|
||||
"""Re-fetch from OpenRouter and atomically swap configs in GLOBAL_LLM_CONFIGS."""
|
||||
raw_models = await _fetch_models_async()
|
||||
if raw_models is None:
|
||||
logger.warning("OpenRouter refresh: fetch failed, keeping stale list")
|
||||
return
|
||||
|
||||
new_configs = _generate_configs(raw_models, self._settings)
|
||||
new_by_id = {c["id"]: c for c in new_configs}
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
static_configs = [
|
||||
c
|
||||
for c in app_config.GLOBAL_LLM_CONFIGS
|
||||
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
|
||||
]
|
||||
app_config.GLOBAL_LLM_CONFIGS = static_configs + new_configs
|
||||
|
||||
self._configs = new_configs
|
||||
self._configs_by_id = new_by_id
|
||||
|
||||
logger.info("OpenRouter refresh: updated to %d models", len(new_configs))
|
||||
|
||||
async def _refresh_loop(self, interval_hours: float) -> None:
|
||||
interval_sec = interval_hours * 3600
|
||||
while True:
|
||||
await asyncio.sleep(interval_sec)
|
||||
try:
|
||||
await self.refresh()
|
||||
except Exception:
|
||||
logger.exception("OpenRouter background refresh failed")
|
||||
|
||||
def start_background_refresh(self, interval_hours: float) -> None:
|
||||
if interval_hours <= 0:
|
||||
return
|
||||
loop = asyncio.get_event_loop()
|
||||
self._refresh_task = loop.create_task(self._refresh_loop(interval_hours))
|
||||
logger.info(
|
||||
"OpenRouter background refresh started (every %.1fh)", interval_hours
|
||||
)
|
||||
|
||||
def stop_background_refresh(self) -> None:
|
||||
if self._refresh_task is not None and not self._refresh_task.done():
|
||||
self._refresh_task.cancel()
|
||||
self._refresh_task = None
|
||||
logger.info("OpenRouter background refresh stopped")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Accessors
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_configs(self) -> list[dict]:
|
||||
return self._configs
|
||||
|
||||
def get_config_by_id(self, config_id: int) -> dict | None:
|
||||
return self._configs_by_id.get(config_id)
|
||||
Loading…
Add table
Add a link
Reference in a new issue