cloud: added openrouter integration with global configs

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-04-15 23:46:29 -07:00
parent ff4e0f9b62
commit 4a51ccdc2c
26 changed files with 911 additions and 178 deletions

View file

@ -49,6 +49,49 @@ def _is_context_overflow_error(exc: LiteLLMBadRequestError) -> bool:
return bool(_CONTEXT_OVERFLOW_PATTERNS.search(str(exc)))
_UNIVERSAL_CONTENT_TYPES = {
"text",
"image_url",
"input_audio",
"refusal",
"audio",
"file",
}
def _sanitize_content(content: Any) -> Any:
"""Normalise a LangChain message ``content`` field so it is safe for any
downstream provider (Azure, OpenAI, OpenRouter, etc.).
* Strips provider-specific block types (e.g. ``thinking`` from reasoning models).
* Removes text blocks with blank text (Bedrock rejects ``{"type":"text","text":""}``)
* Converts bare strings inside a list to ``{"type": "text", "text": ...}`` objects
(Azure rejects raw strings in a content array).
* Collapses a single-text-block list to a plain string for maximum compatibility.
"""
if not isinstance(content, list):
return content
filtered: list[dict] = []
for block in content:
if isinstance(block, str):
if block:
filtered.append({"type": "text", "text": block})
elif isinstance(block, dict):
block_type = block.get("type", "text")
if block_type not in _UNIVERSAL_CONTENT_TYPES:
continue
if block_type == "text" and not block.get("text"):
continue
filtered.append(block)
if not filtered:
return ""
if len(filtered) == 1 and filtered[0].get("type") == "text":
return filtered[0].get("text", "")
return filtered
# Special ID for Auto mode - uses router for load balancing
AUTO_MODE_ID = 0
@ -103,6 +146,7 @@ class LLMRouterService:
_model_list: list[dict] = []
_router_settings: dict = {}
_initialized: bool = False
_premium_model_strings: set[str] = set()
def __new__(cls):
if cls._instance is None:
@ -135,22 +179,28 @@ class LLMRouterService:
logger.debug("LLM Router already initialized, skipping")
return
auto_configs = [
c for c in global_configs if c.get("billing_tier", "free") != "premium"
]
model_list = []
for config in auto_configs:
premium_models: set[str] = set()
for config in global_configs:
deployment = cls._config_to_deployment(config)
if deployment:
model_list.append(deployment)
if config.get("billing_tier") == "premium":
model_string = deployment["litellm_params"]["model"]
premium_models.add(model_string)
if not model_list:
logger.warning("No valid LLM configs found for router initialization")
return
instance._model_list = model_list
instance._premium_model_strings = premium_models
instance._router_settings = router_settings or {}
logger.info(
"Router pool: %d deployments (%d premium)",
len(model_list),
len(premium_models),
)
# Default router settings optimized for rate limit handling
default_settings = {
@ -197,6 +247,21 @@ class LLMRouterService:
logger.error(f"Failed to initialize LLM Router: {e}")
instance._router = None
@classmethod
def is_premium_model(cls, model_string: str) -> bool:
"""Return True if *model_string* (as reported by LiteLLM) belongs to a
premium-tier deployment in the router pool."""
instance = cls.get_instance()
return model_string in instance._premium_model_strings
@classmethod
def compute_premium_tokens(cls, calls: list) -> int:
"""Sum ``total_tokens`` for calls whose model is premium."""
instance = cls.get_instance()
return sum(
c.total_tokens for c in calls if c.model in instance._premium_model_strings
)
@classmethod
def _build_context_fallback_groups(
cls, model_list: list[dict]
@ -1044,10 +1109,12 @@ class ChatLiteLLMRouter(BaseChatModel):
result.append({"role": "user", "content": msg.content})
elif isinstance(msg, AIMsg):
ai_msg: dict[str, Any] = {"role": "assistant"}
if msg.content:
ai_msg["content"] = msg.content
# Handle tool calls
if hasattr(msg, "tool_calls") and msg.tool_calls:
has_tool_calls = hasattr(msg, "tool_calls") and msg.tool_calls
sanitized = _sanitize_content(msg.content) if msg.content else ""
ai_msg["content"] = sanitized if sanitized else ""
if has_tool_calls:
ai_msg["tool_calls"] = [
{
"id": tc.get("id", ""),

View file

@ -6,6 +6,7 @@ from langchain_litellm import ChatLiteLLM
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
from app.config import config
from app.db import NewLLMConfig, SearchSpace
from app.services.llm_router_service import (
@ -150,7 +151,7 @@ async def validate_llm_config(
if litellm_params:
litellm_kwargs.update(litellm_params)
llm = ChatLiteLLM(**litellm_kwargs)
llm = SanitizedChatLiteLLM(**litellm_kwargs)
# Make a simple test call
test_message = HumanMessage(content="Hello")
@ -302,7 +303,7 @@ async def get_search_space_llm_instance(
if disable_streaming:
litellm_kwargs["disable_streaming"] = True
return ChatLiteLLM(**litellm_kwargs)
return SanitizedChatLiteLLM(**litellm_kwargs)
# Get the LLM configuration from database (NewLLMConfig)
result = await session.execute(
@ -379,7 +380,7 @@ async def get_search_space_llm_instance(
if disable_streaming:
litellm_kwargs["disable_streaming"] = True
return ChatLiteLLM(**litellm_kwargs)
return SanitizedChatLiteLLM(**litellm_kwargs)
except Exception as e:
logger.error(
@ -480,7 +481,7 @@ async def get_vision_llm(
if global_cfg.get("litellm_params"):
litellm_kwargs.update(global_cfg["litellm_params"])
return ChatLiteLLM(**litellm_kwargs)
return SanitizedChatLiteLLM(**litellm_kwargs)
result = await session.execute(
select(VisionLLMConfig).where(
@ -513,7 +514,7 @@ async def get_vision_llm(
if vision_cfg.litellm_params:
litellm_kwargs.update(vision_cfg.litellm_params)
return ChatLiteLLM(**litellm_kwargs)
return SanitizedChatLiteLLM(**litellm_kwargs)
except Exception as e:
logger.error(

View file

@ -86,12 +86,34 @@ def _is_text_output_model(model: dict) -> bool:
return output_mods == ["text"]
def _supports_tool_calling(model: dict) -> bool:
"""Return True if the model supports function/tool calling."""
supported = model.get("supported_parameters") or []
return "tools" in supported
MIN_CONTEXT_LENGTH = 100_000
def _has_sufficient_context(model: dict) -> bool:
"""Return True if the model's context window is at least MIN_CONTEXT_LENGTH."""
ctx = model.get("context_length") or 0
return ctx >= MIN_CONTEXT_LENGTH
def _is_allowed_model(model: dict) -> bool:
"""Reuse the exclusion list from the OpenRouter integration service."""
from app.services.openrouter_integration_service import _is_allowed_model as _check
return _check(model)
def _process_models(raw_models: list[dict]) -> list[dict]:
"""
Transform raw OpenRouter model entries into a flat list of
{value, label, provider, context_window} dicts.
Only text-output models are included (audio/image generators are skipped).
Only text-output models with tool-calling support are included.
Each OpenRouter model is emitted once for OPENROUTER (full id) and,
when the slug maps to a native provider, once more with just the
@ -110,6 +132,15 @@ def _process_models(raw_models: list[dict]) -> list[dict]:
if not _is_text_output_model(model):
continue
if not _supports_tool_calling(model):
continue
if not _has_sufficient_context(model):
continue
if not _is_allowed_model(model):
continue
provider_slug, model_name = model_id.split("/", 1)
context_window = _format_context_length(context_length)

View file

@ -0,0 +1,291 @@
"""
OpenRouter Integration Service
Dynamically fetches all available models from the OpenRouter public API
and generates virtual global LLM config entries. These entries are injected
into config.GLOBAL_LLM_CONFIGS so they appear alongside static YAML configs
in the model selector.
All actual LLM calls go through LiteLLM with the ``openrouter/`` prefix --
this service only manages the catalogue, not the inference path.
"""
import asyncio
import logging
import threading
from typing import Any
import httpx
logger = logging.getLogger(__name__)
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
# Sentinel value stored on each generated config so we can distinguish
# dynamic OpenRouter entries from hand-written YAML entries during refresh.
_OPENROUTER_DYNAMIC_MARKER = "__openrouter_dynamic__"
def _is_text_output_model(model: dict) -> bool:
"""Return True if the model produces text output only (skip image/audio generators)."""
output_mods = model.get("architecture", {}).get("output_modalities", [])
return output_mods == ["text"]
def _supports_tool_calling(model: dict) -> bool:
"""Return True if the model supports function/tool calling."""
supported = model.get("supported_parameters") or []
return "tools" in supported
MIN_CONTEXT_LENGTH = 100_000
# Provider slugs whose backend is fundamentally incompatible with our agent's
# tool-call message flow (e.g. Amazon Bedrock requires toolConfig alongside
# tool history which OpenRouter doesn't relay).
_EXCLUDED_PROVIDER_SLUGS = {"amazon"}
_EXCLUDED_MODEL_IDS: set[str] = {
# Deprecated / removed upstream
"openai/gpt-4-1106-preview",
"openai/gpt-4-turbo-preview",
# Permanently no-capacity variant
"openai/gpt-4o:extended",
# Non-serverless model that requires a dedicated endpoint
"arcee-ai/virtuoso-large",
# Deep-research models reject standard params (temperature, etc.)
"openai/o3-deep-research",
"openai/o4-mini-deep-research",
}
_EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",)
def _has_sufficient_context(model: dict) -> bool:
"""Return True if the model's context window is at least MIN_CONTEXT_LENGTH."""
ctx = model.get("context_length") or 0
return ctx >= MIN_CONTEXT_LENGTH
def _is_compatible_provider(model: dict) -> bool:
"""Return False for models from providers known to be incompatible."""
model_id = model.get("id", "")
slug = model_id.split("/", 1)[0] if "/" in model_id else ""
return slug not in _EXCLUDED_PROVIDER_SLUGS
def _is_allowed_model(model: dict) -> bool:
"""Return False for specific model IDs known to be broken or incompatible."""
model_id = model.get("id", "")
if model_id in _EXCLUDED_MODEL_IDS:
return False
base_id = model_id.split(":")[0]
return not base_id.endswith(_EXCLUDED_MODEL_SUFFIXES)
def _fetch_models_sync() -> list[dict] | None:
"""Synchronous fetch for use during startup (before the event loop is running)."""
try:
with httpx.Client(timeout=20) as client:
response = client.get(OPENROUTER_API_URL)
response.raise_for_status()
data = response.json()
return data.get("data", [])
except Exception as e:
logger.warning("Failed to fetch OpenRouter models (sync): %s", e)
return None
async def _fetch_models_async() -> list[dict] | None:
"""Async fetch for background refresh."""
try:
async with httpx.AsyncClient(timeout=20) as client:
response = await client.get(OPENROUTER_API_URL)
response.raise_for_status()
data = response.json()
return data.get("data", [])
except Exception as e:
logger.warning("Failed to fetch OpenRouter models (async): %s", e)
return None
def _generate_configs(
raw_models: list[dict],
settings: dict[str, Any],
) -> list[dict]:
"""
Convert raw OpenRouter model entries into global LLM config dicts.
Models are sorted by ID for deterministic, stable ID assignment across
restarts and refreshes.
"""
id_offset: int = settings.get("id_offset", -10000)
api_key: str = settings.get("api_key", "")
billing_tier: str = settings.get("billing_tier", "premium")
anonymous_enabled: bool = settings.get("anonymous_enabled", False)
seo_enabled: bool = settings.get("seo_enabled", False)
quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
rpm: int = settings.get("rpm", 200)
tpm: int = settings.get("tpm", 1000000)
litellm_params: dict = settings.get("litellm_params") or {}
system_instructions: str = settings.get("system_instructions", "")
use_default: bool = settings.get("use_default_system_instructions", True)
citations_enabled: bool = settings.get("citations_enabled", True)
text_models = [
m
for m in raw_models
if _is_text_output_model(m)
and _supports_tool_calling(m)
and _has_sufficient_context(m)
and _is_compatible_provider(m)
and _is_allowed_model(m)
and "/" in m.get("id", "")
]
text_models.sort(key=lambda m: m["id"])
configs: list[dict] = []
for idx, model in enumerate(text_models):
model_id: str = model["id"]
name: str = model.get("name", model_id)
cfg: dict[str, Any] = {
"id": id_offset - idx,
"name": name,
"description": f"{name} via OpenRouter",
"billing_tier": billing_tier,
"anonymous_enabled": anonymous_enabled,
"seo_enabled": seo_enabled,
"seo_slug": None,
"quota_reserve_tokens": quota_reserve_tokens,
"provider": "OPENROUTER",
"model_name": model_id,
"api_key": api_key,
"api_base": "",
"rpm": rpm,
"tpm": tpm,
"litellm_params": dict(litellm_params),
"system_instructions": system_instructions,
"use_default_system_instructions": use_default,
"citations_enabled": citations_enabled,
_OPENROUTER_DYNAMIC_MARKER: True,
}
configs.append(cfg)
return configs
class OpenRouterIntegrationService:
"""Singleton that manages the dynamic OpenRouter model catalogue."""
_instance: "OpenRouterIntegrationService | None" = None
_lock = threading.Lock()
def __init__(self) -> None:
self._settings: dict[str, Any] = {}
self._configs: list[dict] = []
self._configs_by_id: dict[int, dict] = {}
self._initialized = False
self._refresh_task: asyncio.Task | None = None
@classmethod
def get_instance(cls) -> "OpenRouterIntegrationService":
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
@classmethod
def is_initialized(cls) -> bool:
return cls._instance is not None and cls._instance._initialized
# ------------------------------------------------------------------
# Initialisation (called at startup, before event loop for Celery)
# ------------------------------------------------------------------
def initialize(self, settings: dict[str, Any]) -> list[dict]:
"""
Fetch models synchronously and generate configs.
Returns the generated configs list.
"""
self._settings = settings
raw_models = _fetch_models_sync()
if raw_models is None:
logger.warning("OpenRouter integration: could not fetch models at startup")
self._initialized = True
return []
self._configs = _generate_configs(raw_models, settings)
self._configs_by_id = {c["id"]: c for c in self._configs}
self._initialized = True
logger.info(
"OpenRouter integration: loaded %d models (IDs %d to %d)",
len(self._configs),
self._configs[0]["id"] if self._configs else 0,
self._configs[-1]["id"] if self._configs else 0,
)
return self._configs
# ------------------------------------------------------------------
# Background refresh
# ------------------------------------------------------------------
async def refresh(self) -> None:
"""Re-fetch from OpenRouter and atomically swap configs in GLOBAL_LLM_CONFIGS."""
raw_models = await _fetch_models_async()
if raw_models is None:
logger.warning("OpenRouter refresh: fetch failed, keeping stale list")
return
new_configs = _generate_configs(raw_models, self._settings)
new_by_id = {c["id"]: c for c in new_configs}
from app.config import config as app_config
static_configs = [
c
for c in app_config.GLOBAL_LLM_CONFIGS
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
]
app_config.GLOBAL_LLM_CONFIGS = static_configs + new_configs
self._configs = new_configs
self._configs_by_id = new_by_id
logger.info("OpenRouter refresh: updated to %d models", len(new_configs))
async def _refresh_loop(self, interval_hours: float) -> None:
interval_sec = interval_hours * 3600
while True:
await asyncio.sleep(interval_sec)
try:
await self.refresh()
except Exception:
logger.exception("OpenRouter background refresh failed")
def start_background_refresh(self, interval_hours: float) -> None:
if interval_hours <= 0:
return
loop = asyncio.get_event_loop()
self._refresh_task = loop.create_task(self._refresh_loop(interval_hours))
logger.info(
"OpenRouter background refresh started (every %.1fh)", interval_hours
)
def stop_background_refresh(self) -> None:
if self._refresh_task is not None and not self._refresh_task.done():
self._refresh_task.cancel()
self._refresh_task = None
logger.info("OpenRouter background refresh stopped")
# ------------------------------------------------------------------
# Accessors
# ------------------------------------------------------------------
def get_configs(self) -> list[dict]:
return self._configs
def get_config_by_id(self, config_id: int) -> dict | None:
return self._configs_by_id.get(config_id)