cloud: added openrouter integration with global configs

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-04-15 23:46:29 -07:00
parent ff4e0f9b62
commit 4a51ccdc2c
26 changed files with 911 additions and 178 deletions

View file

@ -22,7 +22,11 @@ from .chat_deepagent import create_surfsense_deep_agent
from .context import SurfSenseContextSchema
# LLM config
from .llm_config import create_chat_litellm_from_config, load_llm_config_from_yaml
from .llm_config import (
create_chat_litellm_from_config,
load_global_llm_config_by_id,
load_llm_config_from_yaml,
)
# Middleware
from .middleware import (
@ -81,6 +85,7 @@ __all__ = [
"get_all_tool_names",
"get_default_enabled_tools",
"get_tool_by_name",
"load_global_llm_config_by_id",
"load_llm_config_from_yaml",
"search_knowledge_base_async",
]

View file

@ -10,10 +10,18 @@ It also provides utilities for creating ChatLiteLLM instances and
managing prompt configurations.
"""
from collections.abc import AsyncIterator
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import yaml
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_litellm import ChatLiteLLM
from litellm import get_model_info
from sqlalchemy import select
@ -23,10 +31,64 @@ from app.services.llm_router_service import (
AUTO_MODE_ID,
ChatLiteLLMRouter,
LLMRouterService,
_sanitize_content,
get_auto_mode_llm,
is_auto_mode,
)
def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
"""Sanitize content on every message so it is safe for any provider.
Handles three cross-provider incompatibilities:
- List content with provider-specific blocks (e.g. ``thinking``)
- List content with bare strings or empty text blocks
- AI messages with empty content + tool calls: some providers (Bedrock)
convert ``""`` to ``[{"type":"text","text":""}]`` server-side then
reject the blank text. The OpenAI spec says ``content`` should be
``null`` when an assistant message only carries tool calls.
"""
for msg in messages:
if isinstance(msg.content, list):
msg.content = _sanitize_content(msg.content)
if (
isinstance(msg, AIMessage)
and (not msg.content or msg.content == "")
and getattr(msg, "tool_calls", None)
):
msg.content = None # type: ignore[assignment]
return messages
class SanitizedChatLiteLLM(ChatLiteLLM):
"""ChatLiteLLM subclass that strips provider-specific content blocks
(e.g. ``thinking`` from reasoning models) and normalises bare strings
in content arrays before forwarding to the underlying provider."""
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
return super()._generate(
_sanitize_messages(messages), stop, run_manager, **kwargs
)
async def _astream(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: AsyncCallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
async for chunk in super()._astream(
_sanitize_messages(messages), stop, run_manager, **kwargs
):
yield chunk
# Provider mapping for LiteLLM model string construction
PROVIDER_MAP = {
"OPENAI": "openai",
@ -252,6 +314,28 @@ def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None:
return None
def load_global_llm_config_by_id(llm_config_id: int) -> dict | None:
"""
Load a global LLM config by ID, checking in-memory configs first.
This handles both static YAML configs and dynamically injected configs
(e.g. OpenRouter integration models that only exist in memory).
Args:
llm_config_id: The negative ID of the global config to load
Returns:
LLM config dict or None if not found
"""
from app.config import config as app_config
for cfg in app_config.GLOBAL_LLM_CONFIGS:
if cfg.get("id") == llm_config_id:
return cfg
# Fallback to YAML file read (covers edge cases like hot-reload)
return load_llm_config_from_yaml(llm_config_id)
async def load_new_llm_config_from_db(
session: AsyncSession,
config_id: int,
@ -359,7 +443,13 @@ async def load_agent_config(
return AgentConfig.from_auto_mode()
if config_id < 0:
# Load from YAML (global configs have negative IDs)
# Check in-memory configs first (includes static YAML + dynamic OpenRouter)
from app.config import config as app_config
for cfg in app_config.GLOBAL_LLM_CONFIGS:
if cfg.get("id") == config_id:
return AgentConfig.from_yaml_config(cfg)
# Fallback to YAML file read for safety
yaml_config = load_llm_config_from_yaml(config_id)
if yaml_config:
return AgentConfig.from_yaml_config(yaml_config)
@ -402,7 +492,7 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
if llm_config.get("litellm_params"):
litellm_kwargs.update(llm_config["litellm_params"])
llm = ChatLiteLLM(**litellm_kwargs)
llm = SanitizedChatLiteLLM(**litellm_kwargs)
_attach_model_profile(llm, model_string)
return llm
@ -457,6 +547,6 @@ def create_chat_litellm_from_agent_config(
if agent_config.litellm_params:
litellm_kwargs.update(agent_config.litellm_params)
llm = ChatLiteLLM(**litellm_kwargs)
llm = SanitizedChatLiteLLM(**litellm_kwargs)
_attach_model_profile(llm, model_string)
return llm

View file

@ -13,6 +13,7 @@ from litellm import aspeech
from app.config import config as app_config
from app.services.kokoro_tts_service import get_kokoro_tts_service
from app.services.llm_service import get_agent_llm
from app.utils.content_utils import extract_text_content, strip_markdown_fences
from .configuration import Configuration
from .prompts import get_podcast_generation_prompt
@ -53,43 +54,32 @@ async def create_podcast_transcript(
# Generate the podcast transcript
llm_response = await llm.ainvoke(messages)
# First try the direct approach
# Reasoning models (e.g. Kimi K2.5) may return content as a list of
# blocks including 'reasoning' entries. Normalise to a plain string.
content = strip_markdown_fences(extract_text_content(llm_response.content))
try:
podcast_transcript = PodcastTranscripts.model_validate(
json.loads(llm_response.content)
)
except (json.JSONDecodeError, ValueError) as e:
podcast_transcript = PodcastTranscripts.model_validate(json.loads(content))
except (json.JSONDecodeError, TypeError, ValueError) as e:
print(f"Direct JSON parsing failed, trying fallback approach: {e!s}")
# Fallback: Parse the JSON response manually
try:
# Extract JSON content from the response
content = llm_response.content
# Find the JSON in the content (handle case where LLM might add additional text)
json_start = content.find("{")
json_end = content.rfind("}") + 1
if json_start >= 0 and json_end > json_start:
json_str = content[json_start:json_end]
# Parse the JSON string
parsed_data = json.loads(json_str)
# Convert to Pydantic model
podcast_transcript = PodcastTranscripts.model_validate(parsed_data)
print("Successfully parsed podcast transcript using fallback approach")
else:
# If JSON structure not found, raise a clear error
error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"
print(error_message)
raise ValueError(error_message)
except (json.JSONDecodeError, ValueError) as e2:
# Log the error and re-raise it
except (json.JSONDecodeError, TypeError, ValueError) as e2:
error_message = f"Error parsing LLM response (fallback also failed): {e2!s}"
print(f"Error parsing LLM response: {e2!s}")
print(f"Raw response: {llm_response.content}")
print(f"Raw response: {content}")
raise
return {"podcast_transcript": podcast_transcript.podcast_transcripts}

View file

@ -16,6 +16,7 @@ from litellm import aspeech
from app.config import config as app_config
from app.services.kokoro_tts_service import get_kokoro_tts_service
from app.services.llm_service import get_agent_llm
from app.utils.content_utils import extract_text_content, strip_markdown_fences
from .configuration import Configuration
from .prompts import (
@ -67,16 +68,14 @@ async def create_presentation_slides(
]
llm_response = await llm.ainvoke(messages)
content = strip_markdown_fences(extract_text_content(llm_response.content))
try:
presentation = PresentationSlides.model_validate(
json.loads(llm_response.content)
)
except (json.JSONDecodeError, ValueError) as e:
presentation = PresentationSlides.model_validate(json.loads(content))
except (json.JSONDecodeError, TypeError, ValueError) as e:
print(f"Direct JSON parsing failed, trying fallback approach: {e!s}")
try:
content = llm_response.content
json_start = content.find("{")
json_end = content.rfind("}") + 1
if json_start >= 0 and json_end > json_start:
@ -89,10 +88,10 @@ async def create_presentation_slides(
print(error_message)
raise ValueError(error_message)
except (json.JSONDecodeError, ValueError) as e2:
except (json.JSONDecodeError, TypeError, ValueError) as e2:
error_message = f"Error parsing LLM response (fallback also failed): {e2!s}"
print(f"Error parsing LLM response: {e2!s}")
print(f"Raw response: {llm_response.content}")
print(f"Raw response: {content}")
raise
return {"slides": presentation.slides}
@ -308,12 +307,7 @@ async def _assign_themes_with_llm(
]
)
text = response.content.strip()
if text.startswith("```"):
lines = text.split("\n")
text = "\n".join(
line for line in lines if not line.strip().startswith("```")
).strip()
text = strip_markdown_fences(extract_text_content(response.content))
assignments = json.loads(text)
valid_themes = set(THEME_PRESETS)
@ -424,7 +418,9 @@ async def generate_slide_scene_codes(
)
llm_response = await llm.ainvoke(messages)
code, scene_title = _extract_code_and_title(llm_response.content)
code, scene_title = _extract_code_and_title(
extract_text_content(llm_response.content)
)
code = await _refine_if_needed(llm, code, slide.slide_number)
@ -452,7 +448,7 @@ def _extract_code_and_title(content: str) -> tuple[str, str | None]:
Returns (code, title) where title may be None.
"""
text = content.strip()
text = strip_markdown_fences(content)
if text.startswith("{"):
try:
@ -472,18 +468,7 @@ def _extract_code_and_title(content: str) -> tuple[str, str | None]:
except (json.JSONDecodeError, ValueError):
pass
code = text
if code.startswith("```"):
lines = code.split("\n")
start = 1
end = len(lines)
for i in range(len(lines) - 1, 0, -1):
if lines[i].strip().startswith("```"):
end = i
break
code = "\n".join(lines[start:end]).strip()
return code, None
return text, None
async def _refine_if_needed(llm, code: str, slide_number: int) -> str:
@ -512,7 +497,7 @@ async def _refine_if_needed(llm, code: str, slide_number: int) -> str:
]
response = await llm.ainvoke(messages)
code, _ = _extract_code_and_title(response.content)
code, _ = _extract_code_and_title(extract_text_content(response.content))
error = _basic_syntax_check(code)
if error is None: