mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-01 11:56:25 +02:00
cloud: added openrouter integration with global configs
This commit is contained in:
parent
ff4e0f9b62
commit
4a51ccdc2c
26 changed files with 911 additions and 178 deletions
|
|
@ -49,6 +49,49 @@ def _is_context_overflow_error(exc: LiteLLMBadRequestError) -> bool:
|
|||
return bool(_CONTEXT_OVERFLOW_PATTERNS.search(str(exc)))
|
||||
|
||||
|
||||
_UNIVERSAL_CONTENT_TYPES = {
|
||||
"text",
|
||||
"image_url",
|
||||
"input_audio",
|
||||
"refusal",
|
||||
"audio",
|
||||
"file",
|
||||
}
|
||||
|
||||
|
||||
def _sanitize_content(content: Any) -> Any:
|
||||
"""Normalise a LangChain message ``content`` field so it is safe for any
|
||||
downstream provider (Azure, OpenAI, OpenRouter, etc.).
|
||||
|
||||
* Strips provider-specific block types (e.g. ``thinking`` from reasoning models).
|
||||
* Removes text blocks with blank text (Bedrock rejects ``{"type":"text","text":""}``)
|
||||
* Converts bare strings inside a list to ``{"type": "text", "text": ...}`` objects
|
||||
(Azure rejects raw strings in a content array).
|
||||
* Collapses a single-text-block list to a plain string for maximum compatibility.
|
||||
"""
|
||||
if not isinstance(content, list):
|
||||
return content
|
||||
|
||||
filtered: list[dict] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
if block:
|
||||
filtered.append({"type": "text", "text": block})
|
||||
elif isinstance(block, dict):
|
||||
block_type = block.get("type", "text")
|
||||
if block_type not in _UNIVERSAL_CONTENT_TYPES:
|
||||
continue
|
||||
if block_type == "text" and not block.get("text"):
|
||||
continue
|
||||
filtered.append(block)
|
||||
|
||||
if not filtered:
|
||||
return ""
|
||||
if len(filtered) == 1 and filtered[0].get("type") == "text":
|
||||
return filtered[0].get("text", "")
|
||||
return filtered
|
||||
|
||||
|
||||
# Sentinel LLM-config ID for "Auto" mode: rather than pinning one deployment,
# requests are dispatched through the router for load balancing.
AUTO_MODE_ID = 0
|
||||
|
||||
|
|
@ -103,6 +146,7 @@ class LLMRouterService:
|
|||
_model_list: list[dict] = []
|
||||
_router_settings: dict = {}
|
||||
_initialized: bool = False
|
||||
_premium_model_strings: set[str] = set()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
|
|
@ -135,22 +179,28 @@ class LLMRouterService:
|
|||
logger.debug("LLM Router already initialized, skipping")
|
||||
return
|
||||
|
||||
auto_configs = [
|
||||
c for c in global_configs if c.get("billing_tier", "free") != "premium"
|
||||
]
|
||||
|
||||
model_list = []
|
||||
for config in auto_configs:
|
||||
premium_models: set[str] = set()
|
||||
for config in global_configs:
|
||||
deployment = cls._config_to_deployment(config)
|
||||
if deployment:
|
||||
model_list.append(deployment)
|
||||
if config.get("billing_tier") == "premium":
|
||||
model_string = deployment["litellm_params"]["model"]
|
||||
premium_models.add(model_string)
|
||||
|
||||
if not model_list:
|
||||
logger.warning("No valid LLM configs found for router initialization")
|
||||
return
|
||||
|
||||
instance._model_list = model_list
|
||||
instance._premium_model_strings = premium_models
|
||||
instance._router_settings = router_settings or {}
|
||||
logger.info(
|
||||
"Router pool: %d deployments (%d premium)",
|
||||
len(model_list),
|
||||
len(premium_models),
|
||||
)
|
||||
|
||||
# Default router settings optimized for rate limit handling
|
||||
default_settings = {
|
||||
|
|
@ -197,6 +247,21 @@ class LLMRouterService:
|
|||
logger.error(f"Failed to initialize LLM Router: {e}")
|
||||
instance._router = None
|
||||
|
||||
@classmethod
def is_premium_model(cls, model_string: str) -> bool:
    """Return True if *model_string* (as reported by LiteLLM) belongs to a
    premium-tier deployment in the router pool."""
    # Membership is checked against the singleton's cached premium set.
    return model_string in cls.get_instance()._premium_model_strings
|
||||
|
||||
@classmethod
def compute_premium_tokens(cls, calls: list) -> int:
    """Sum ``total_tokens`` for calls whose model is premium."""
    premium = cls.get_instance()._premium_model_strings
    total = 0
    for call in calls:
        if call.model in premium:
            total += call.total_tokens
    return total
|
||||
|
||||
@classmethod
|
||||
def _build_context_fallback_groups(
|
||||
cls, model_list: list[dict]
|
||||
|
|
@ -1044,10 +1109,12 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
result.append({"role": "user", "content": msg.content})
|
||||
elif isinstance(msg, AIMsg):
|
||||
ai_msg: dict[str, Any] = {"role": "assistant"}
|
||||
if msg.content:
|
||||
ai_msg["content"] = msg.content
|
||||
# Handle tool calls
|
||||
if hasattr(msg, "tool_calls") and msg.tool_calls:
|
||||
has_tool_calls = hasattr(msg, "tool_calls") and msg.tool_calls
|
||||
|
||||
sanitized = _sanitize_content(msg.content) if msg.content else ""
|
||||
ai_msg["content"] = sanitized if sanitized else ""
|
||||
|
||||
if has_tool_calls:
|
||||
ai_msg["tool_calls"] = [
|
||||
{
|
||||
"id": tc.get("id", ""),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue