mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
cloud: added openrouter integration with global configs
This commit is contained in:
parent
ff4e0f9b62
commit
4a51ccdc2c
26 changed files with 911 additions and 178 deletions
|
|
@ -22,7 +22,11 @@ from .chat_deepagent import create_surfsense_deep_agent
|
|||
from .context import SurfSenseContextSchema
|
||||
|
||||
# LLM config
|
||||
from .llm_config import create_chat_litellm_from_config, load_llm_config_from_yaml
|
||||
from .llm_config import (
|
||||
create_chat_litellm_from_config,
|
||||
load_global_llm_config_by_id,
|
||||
load_llm_config_from_yaml,
|
||||
)
|
||||
|
||||
# Middleware
|
||||
from .middleware import (
|
||||
|
|
@ -81,6 +85,7 @@ __all__ = [
|
|||
"get_all_tool_names",
|
||||
"get_default_enabled_tools",
|
||||
"get_tool_by_name",
|
||||
"load_global_llm_config_by_id",
|
||||
"load_llm_config_from_yaml",
|
||||
"search_knowledge_base_async",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -10,10 +10,18 @@ It also provides utilities for creating ChatLiteLLM instances and
|
|||
managing prompt configurations.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_litellm import ChatLiteLLM
|
||||
from litellm import get_model_info
|
||||
from sqlalchemy import select
|
||||
|
|
@ -23,10 +31,64 @@ from app.services.llm_router_service import (
|
|||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
LLMRouterService,
|
||||
_sanitize_content,
|
||||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
    """Return messages whose content is safe to send to any provider.

    Handles three cross-provider incompatibilities:

    - List content with provider-specific blocks (e.g. ``thinking``)
    - List content with bare strings or empty text blocks
    - AI messages with empty content + tool calls: some providers (Bedrock)
      convert ``""`` to ``[{"type":"text","text":""}]`` server-side then
      reject the blank text. The OpenAI spec says ``content`` should be
      ``null`` when an assistant message only carries tool calls.

    The input messages are not reassigned: a message that needs a different
    ``content`` is replaced in the returned list by a shallow copy carrying
    the new value, so message objects the caller still holds (for example in
    conversation state) keep their original content. Messages that need no
    change are returned as the same objects.

    Args:
        messages: Messages about to be forwarded to the underlying provider.

    Returns:
        A new list, in the same order, of provider-safe messages.
    """
    sanitized: list[BaseMessage] = []
    for original in messages:
        content = original.content
        if isinstance(content, list):
            content = _sanitize_content(content)

        # An assistant turn that only carries tool calls must send null
        # content rather than an empty string / empty block list.
        if (
            isinstance(original, AIMessage)
            and not content
            and getattr(original, "tool_calls", None)
        ):
            content = None

        if content is original.content:
            sanitized.append(original)
        else:
            # model_copy(update=...) skips validation, which is what allows
            # the ``None`` content that the message type does not declare.
            sanitized.append(original.model_copy(update={"content": content}))
    return sanitized
|
||||
|
||||
|
||||
class SanitizedChatLiteLLM(ChatLiteLLM):
    """ChatLiteLLM subclass that sanitizes messages before each provider call.

    Strips provider-specific content blocks (e.g. ``thinking`` from reasoning
    models) and normalises bare strings in content arrays before forwarding
    to the underlying provider.

    All four generation hooks (``_generate``, ``_agenerate``, ``_stream`` and
    ``_astream``) are wrapped so that sanitisation does not depend on which
    sync/async, streaming/non-streaming path a caller ends up on. Each hook
    only sanitizes the messages and delegates to the parent implementation
    with the remaining arguments untouched.
    """

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Sanitize ``messages`` and run the synchronous generation path."""
        return super()._generate(
            _sanitize_messages(messages), stop, run_manager, **kwargs
        )

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Sanitize ``messages`` and run the asynchronous generation path."""
        return await super()._agenerate(
            _sanitize_messages(messages), stop, run_manager, **kwargs
        )

    def _stream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Sanitize ``messages`` and yield chunks from the sync stream."""
        yield from super()._stream(
            _sanitize_messages(messages), stop, run_manager, **kwargs
        )

    async def _astream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Sanitize ``messages`` and yield chunks from the async stream."""
        async for chunk in super()._astream(
            _sanitize_messages(messages), stop, run_manager, **kwargs
        ):
            yield chunk
|
||||
|
||||
|
||||
# Provider mapping for LiteLLM model string construction
|
||||
PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
|
|
@ -252,6 +314,28 @@ def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None:
|
|||
return None
|
||||
|
||||
|
||||
def load_global_llm_config_by_id(llm_config_id: int) -> dict | None:
    """
    Resolve a global LLM config by its ID, preferring in-memory configs.

    The in-memory list ``app_config.GLOBAL_LLM_CONFIGS`` is searched first so
    that both static YAML configs and dynamically injected ones (e.g.
    OpenRouter integration models that only exist in memory) are found. If no
    entry matches, the lookup is delegated to ``load_llm_config_from_yaml``.

    Args:
        llm_config_id: The negative ID of the global config to load

    Returns:
        LLM config dict or None if not found
    """
    from app.config import config as app_config

    in_memory_match = next(
        (
            candidate
            for candidate in app_config.GLOBAL_LLM_CONFIGS
            if candidate.get("id") == llm_config_id
        ),
        None,
    )
    if in_memory_match is not None:
        return in_memory_match

    # Fallback to YAML file read (covers edge cases like hot-reload)
    return load_llm_config_from_yaml(llm_config_id)
|
||||
|
||||
|
||||
async def load_new_llm_config_from_db(
|
||||
session: AsyncSession,
|
||||
config_id: int,
|
||||
|
|
@ -359,7 +443,13 @@ async def load_agent_config(
|
|||
return AgentConfig.from_auto_mode()
|
||||
|
||||
if config_id < 0:
|
||||
# Load from YAML (global configs have negative IDs)
|
||||
# Check in-memory configs first (includes static YAML + dynamic OpenRouter)
|
||||
from app.config import config as app_config
|
||||
|
||||
for cfg in app_config.GLOBAL_LLM_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return AgentConfig.from_yaml_config(cfg)
|
||||
# Fallback to YAML file read for safety
|
||||
yaml_config = load_llm_config_from_yaml(config_id)
|
||||
if yaml_config:
|
||||
return AgentConfig.from_yaml_config(yaml_config)
|
||||
|
|
@ -402,7 +492,7 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
|||
if llm_config.get("litellm_params"):
|
||||
litellm_kwargs.update(llm_config["litellm_params"])
|
||||
|
||||
llm = ChatLiteLLM(**litellm_kwargs)
|
||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
_attach_model_profile(llm, model_string)
|
||||
return llm
|
||||
|
||||
|
|
@ -457,6 +547,6 @@ def create_chat_litellm_from_agent_config(
|
|||
if agent_config.litellm_params:
|
||||
litellm_kwargs.update(agent_config.litellm_params)
|
||||
|
||||
llm = ChatLiteLLM(**litellm_kwargs)
|
||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
_attach_model_profile(llm, model_string)
|
||||
return llm
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from litellm import aspeech
|
|||
from app.config import config as app_config
|
||||
from app.services.kokoro_tts_service import get_kokoro_tts_service
|
||||
from app.services.llm_service import get_agent_llm
|
||||
from app.utils.content_utils import extract_text_content, strip_markdown_fences
|
||||
|
||||
from .configuration import Configuration
|
||||
from .prompts import get_podcast_generation_prompt
|
||||
|
|
@ -53,43 +54,32 @@ async def create_podcast_transcript(
|
|||
# Generate the podcast transcript
|
||||
llm_response = await llm.ainvoke(messages)
|
||||
|
||||
# First try the direct approach
|
||||
# Reasoning models (e.g. Kimi K2.5) may return content as a list of
|
||||
# blocks including 'reasoning' entries. Normalise to a plain string.
|
||||
content = strip_markdown_fences(extract_text_content(llm_response.content))
|
||||
|
||||
try:
|
||||
podcast_transcript = PodcastTranscripts.model_validate(
|
||||
json.loads(llm_response.content)
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
podcast_transcript = PodcastTranscripts.model_validate(json.loads(content))
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e:
|
||||
print(f"Direct JSON parsing failed, trying fallback approach: {e!s}")
|
||||
|
||||
# Fallback: Parse the JSON response manually
|
||||
try:
|
||||
# Extract JSON content from the response
|
||||
content = llm_response.content
|
||||
|
||||
# Find the JSON in the content (handle case where LLM might add additional text)
|
||||
json_start = content.find("{")
|
||||
json_end = content.rfind("}") + 1
|
||||
if json_start >= 0 and json_end > json_start:
|
||||
json_str = content[json_start:json_end]
|
||||
|
||||
# Parse the JSON string
|
||||
parsed_data = json.loads(json_str)
|
||||
|
||||
# Convert to Pydantic model
|
||||
podcast_transcript = PodcastTranscripts.model_validate(parsed_data)
|
||||
|
||||
print("Successfully parsed podcast transcript using fallback approach")
|
||||
else:
|
||||
# If JSON structure not found, raise a clear error
|
||||
error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"
|
||||
print(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e2:
|
||||
# Log the error and re-raise it
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e2:
|
||||
error_message = f"Error parsing LLM response (fallback also failed): {e2!s}"
|
||||
print(f"Error parsing LLM response: {e2!s}")
|
||||
print(f"Raw response: {llm_response.content}")
|
||||
print(f"Raw response: {content}")
|
||||
raise
|
||||
|
||||
return {"podcast_transcript": podcast_transcript.podcast_transcripts}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from litellm import aspeech
|
|||
from app.config import config as app_config
|
||||
from app.services.kokoro_tts_service import get_kokoro_tts_service
|
||||
from app.services.llm_service import get_agent_llm
|
||||
from app.utils.content_utils import extract_text_content, strip_markdown_fences
|
||||
|
||||
from .configuration import Configuration
|
||||
from .prompts import (
|
||||
|
|
@ -67,16 +68,14 @@ async def create_presentation_slides(
|
|||
]
|
||||
|
||||
llm_response = await llm.ainvoke(messages)
|
||||
content = strip_markdown_fences(extract_text_content(llm_response.content))
|
||||
|
||||
try:
|
||||
presentation = PresentationSlides.model_validate(
|
||||
json.loads(llm_response.content)
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
presentation = PresentationSlides.model_validate(json.loads(content))
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e:
|
||||
print(f"Direct JSON parsing failed, trying fallback approach: {e!s}")
|
||||
|
||||
try:
|
||||
content = llm_response.content
|
||||
json_start = content.find("{")
|
||||
json_end = content.rfind("}") + 1
|
||||
if json_start >= 0 and json_end > json_start:
|
||||
|
|
@ -89,10 +88,10 @@ async def create_presentation_slides(
|
|||
print(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e2:
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e2:
|
||||
error_message = f"Error parsing LLM response (fallback also failed): {e2!s}"
|
||||
print(f"Error parsing LLM response: {e2!s}")
|
||||
print(f"Raw response: {llm_response.content}")
|
||||
print(f"Raw response: {content}")
|
||||
raise
|
||||
|
||||
return {"slides": presentation.slides}
|
||||
|
|
@ -308,12 +307,7 @@ async def _assign_themes_with_llm(
|
|||
]
|
||||
)
|
||||
|
||||
text = response.content.strip()
|
||||
if text.startswith("```"):
|
||||
lines = text.split("\n")
|
||||
text = "\n".join(
|
||||
line for line in lines if not line.strip().startswith("```")
|
||||
).strip()
|
||||
text = strip_markdown_fences(extract_text_content(response.content))
|
||||
|
||||
assignments = json.loads(text)
|
||||
valid_themes = set(THEME_PRESETS)
|
||||
|
|
@ -424,7 +418,9 @@ async def generate_slide_scene_codes(
|
|||
)
|
||||
|
||||
llm_response = await llm.ainvoke(messages)
|
||||
code, scene_title = _extract_code_and_title(llm_response.content)
|
||||
code, scene_title = _extract_code_and_title(
|
||||
extract_text_content(llm_response.content)
|
||||
)
|
||||
|
||||
code = await _refine_if_needed(llm, code, slide.slide_number)
|
||||
|
||||
|
|
@ -452,7 +448,7 @@ def _extract_code_and_title(content: str) -> tuple[str, str | None]:
|
|||
|
||||
Returns (code, title) where title may be None.
|
||||
"""
|
||||
text = content.strip()
|
||||
text = strip_markdown_fences(content)
|
||||
|
||||
if text.startswith("{"):
|
||||
try:
|
||||
|
|
@ -472,18 +468,7 @@ def _extract_code_and_title(content: str) -> tuple[str, str | None]:
|
|||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
code = text
|
||||
if code.startswith("```"):
|
||||
lines = code.split("\n")
|
||||
start = 1
|
||||
end = len(lines)
|
||||
for i in range(len(lines) - 1, 0, -1):
|
||||
if lines[i].strip().startswith("```"):
|
||||
end = i
|
||||
break
|
||||
code = "\n".join(lines[start:end]).strip()
|
||||
|
||||
return code, None
|
||||
return text, None
|
||||
|
||||
|
||||
async def _refine_if_needed(llm, code: str, slide_number: int) -> str:
|
||||
|
|
@ -512,7 +497,7 @@ async def _refine_if_needed(llm, code: str, slide_number: int) -> str:
|
|||
]
|
||||
|
||||
response = await llm.ainvoke(messages)
|
||||
code, _ = _extract_code_and_title(response.content)
|
||||
code, _ = _extract_code_and_title(extract_text_content(response.content))
|
||||
|
||||
error = _basic_syntax_check(code)
|
||||
if error is None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue