mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-04-27 01:36:30 +02:00)
cloud: added openrouter integration with global configs
parent ff4e0f9b62
commit 4a51ccdc2c
26 changed files with 911 additions and 178 deletions
@@ -10,10 +10,18 @@ It also provides utilities for creating ChatLiteLLM instances and
managing prompt configurations.
"""

from collections.abc import AsyncIterator
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import yaml
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_litellm import ChatLiteLLM
from litellm import get_model_info
from sqlalchemy import select

@@ -23,10 +31,64 @@ from app.services.llm_router_service import (
    AUTO_MODE_ID,
    ChatLiteLLMRouter,
    LLMRouterService,
    _sanitize_content,
    get_auto_mode_llm,
    is_auto_mode,
)


def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
    """Sanitize content on every message so it is safe for any provider.

    Handles three cross-provider incompatibilities:
    - List content with provider-specific blocks (e.g. ``thinking``)
    - List content with bare strings or empty text blocks
    - AI messages with empty content + tool calls: some providers (Bedrock)
      convert ``""`` to ``[{"type":"text","text":""}]`` server-side then
      reject the blank text. The OpenAI spec says ``content`` should be
      ``null`` when an assistant message only carries tool calls.
    """
    for msg in messages:
        if isinstance(msg.content, list):
            msg.content = _sanitize_content(msg.content)
        if (
            isinstance(msg, AIMessage)
            and (not msg.content or msg.content == "")
            and getattr(msg, "tool_calls", None)
        ):
            msg.content = None  # type: ignore[assignment]
    return messages


class SanitizedChatLiteLLM(ChatLiteLLM):
    """ChatLiteLLM subclass that strips provider-specific content blocks
    (e.g. ``thinking`` from reasoning models) and normalises bare strings
    in content arrays before forwarding to the underlying provider."""

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        return super()._generate(
            _sanitize_messages(messages), stop, run_manager, **kwargs
        )

    async def _astream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        async for chunk in super()._astream(
            _sanitize_messages(messages), stop, run_manager, **kwargs
        ):
            yield chunk


# Provider mapping for LiteLLM model string construction
PROVIDER_MAP = {
    "OPENAI": "openai",

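To make the sanitizer's tool-call case concrete, here is a minimal sketch of what _sanitize_messages does to an assistant turn that only carries tool calls (the tool name, arguments, and call id are made up for illustration):

from langchain_core.messages import AIMessage

# Assistant message with empty content plus a pending tool call: providers
# like Bedrock would turn the "" into a blank text block and reject it.
msg = AIMessage(
    content="",
    tool_calls=[{"name": "search", "args": {"q": "surfsense"}, "id": "call_1"}],
)

sanitized = _sanitize_messages([msg])
assert sanitized[0].content is None  # normalised to null, per the OpenAI spec
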
@@ -252,6 +314,28 @@ def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None:
    return None


def load_global_llm_config_by_id(llm_config_id: int) -> dict | None:
    """
    Load a global LLM config by ID, checking in-memory configs first.

    This handles both static YAML configs and dynamically injected configs
    (e.g. OpenRouter integration models that only exist in memory).

    Args:
        llm_config_id: The negative ID of the global config to load

    Returns:
        LLM config dict or None if not found
    """
    from app.config import config as app_config

    for cfg in app_config.GLOBAL_LLM_CONFIGS:
        if cfg.get("id") == llm_config_id:
            return cfg
    # Fallback to YAML file read (covers edge cases like hot-reload)
    return load_llm_config_from_yaml(llm_config_id)


async def load_new_llm_config_from_db(
    session: AsyncSession,
    config_id: int,

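A usage sketch for the new helper, assuming a hypothetical in-memory entry injected by the OpenRouter integration (the id value and every field other than "id" are illustrative, not taken from the repo):

from app.config import config as app_config

# Hypothetical dynamically injected global config; global IDs are negative.
app_config.GLOBAL_LLM_CONFIGS.append(
    {"id": -101, "provider": "OPENROUTER", "model_name": "openai/gpt-4o-mini"}
)

cfg = load_global_llm_config_by_id(-101)      # hit: served from memory, no file I/O
missing = load_global_llm_config_by_id(-999)  # miss: falls back to YAML, then None
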
@@ -359,7 +443,13 @@ async def load_agent_config(
        return AgentConfig.from_auto_mode()

    if config_id < 0:
-        # Load from YAML (global configs have negative IDs)
+        # Check in-memory configs first (includes static YAML + dynamic OpenRouter)
+        from app.config import config as app_config
+
+        for cfg in app_config.GLOBAL_LLM_CONFIGS:
+            if cfg.get("id") == config_id:
+                return AgentConfig.from_yaml_config(cfg)
+        # Fallback to YAML file read for safety
        yaml_config = load_llm_config_from_yaml(config_id)
        if yaml_config:
            return AgentConfig.from_yaml_config(yaml_config)

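Note that the in-memory hit feeds the same AgentConfig.from_yaml_config constructor as the file fallback, so a dynamically injected OpenRouter entry behaves exactly like one read from disk; a small sketch (the dict fields are assumed for illustration):

cfg = {"id": -101, "provider": "OPENROUTER", "model_name": "openai/gpt-4o-mini"}
agent_config = AgentConfig.from_yaml_config(cfg)  # same path for memory and file hits
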
@@ -402,7 +492,7 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
    if llm_config.get("litellm_params"):
        litellm_kwargs.update(llm_config["litellm_params"])

-    llm = ChatLiteLLM(**litellm_kwargs)
+    llm = SanitizedChatLiteLLM(**litellm_kwargs)
    _attach_model_profile(llm, model_string)
    return llm

@@ -457,6 +547,6 @@ def create_chat_litellm_from_agent_config(
    if agent_config.litellm_params:
        litellm_kwargs.update(agent_config.litellm_params)

-    llm = ChatLiteLLM(**litellm_kwargs)
+    llm = SanitizedChatLiteLLM(**litellm_kwargs)
    _attach_model_profile(llm, model_string)
    return llm

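Because SanitizedChatLiteLLM only overrides _generate and _astream, the swap in both factory functions is a drop-in replacement; a minimal sketch (the model string and prompt are illustrative):

llm = SanitizedChatLiteLLM(model="openrouter/openai/gpt-4o-mini", temperature=0)

# Messages pass through _sanitize_messages before LiteLLM sees them, so
# provider-specific blocks such as ``thinking`` are stripped transparently.
result = llm.invoke("Hello from SurfSense")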