mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-08 20:25:19 +02:00
cloud: added openrouter integration with global configs
This commit is contained in:
parent
ff4e0f9b62
commit
4a51ccdc2c
26 changed files with 911 additions and 178 deletions
|
|
@ -184,17 +184,17 @@ VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300
|
|||
# (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version)
|
||||
PAGES_LIMIT=500
|
||||
|
||||
# Premium token quota per registered user (default: 5,000,000)
|
||||
# Premium token quota per registered user (default: 3,000,000)
|
||||
# Applies only to models with billing_tier=premium in global_llm_config.yaml
|
||||
PREMIUM_TOKEN_LIMIT=5000000
|
||||
PREMIUM_TOKEN_LIMIT=3000000
|
||||
|
||||
# No-login (anonymous) mode — allows public users to chat without an account
|
||||
# Set TRUE to enable /free pages and anonymous chat API
|
||||
NOLOGIN_MODE_ENABLED=FALSE
|
||||
# Total tokens allowed per anonymous session before requiring account creation
|
||||
ANON_TOKEN_LIMIT=1000000
|
||||
ANON_TOKEN_LIMIT=500000
|
||||
# Token count at which the UI shows a soft warning
|
||||
ANON_TOKEN_WARNING_THRESHOLD=800000
|
||||
ANON_TOKEN_WARNING_THRESHOLD=400000
|
||||
# Days before anonymous quota tracking expires in Redis
|
||||
ANON_TOKEN_QUOTA_TTL_DAYS=30
|
||||
# Max document upload size for anonymous users (MB)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,11 @@ from .chat_deepagent import create_surfsense_deep_agent
|
|||
from .context import SurfSenseContextSchema
|
||||
|
||||
# LLM config
|
||||
from .llm_config import create_chat_litellm_from_config, load_llm_config_from_yaml
|
||||
from .llm_config import (
|
||||
create_chat_litellm_from_config,
|
||||
load_global_llm_config_by_id,
|
||||
load_llm_config_from_yaml,
|
||||
)
|
||||
|
||||
# Middleware
|
||||
from .middleware import (
|
||||
|
|
@ -81,6 +85,7 @@ __all__ = [
|
|||
"get_all_tool_names",
|
||||
"get_default_enabled_tools",
|
||||
"get_tool_by_name",
|
||||
"load_global_llm_config_by_id",
|
||||
"load_llm_config_from_yaml",
|
||||
"search_knowledge_base_async",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -10,10 +10,18 @@ It also provides utilities for creating ChatLiteLLM instances and
|
|||
managing prompt configurations.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_litellm import ChatLiteLLM
|
||||
from litellm import get_model_info
|
||||
from sqlalchemy import select
|
||||
|
|
@ -23,10 +31,64 @@ from app.services.llm_router_service import (
|
|||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
LLMRouterService,
|
||||
_sanitize_content,
|
||||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||
"""Sanitize content on every message so it is safe for any provider.
|
||||
|
||||
Handles three cross-provider incompatibilities:
|
||||
- List content with provider-specific blocks (e.g. ``thinking``)
|
||||
- List content with bare strings or empty text blocks
|
||||
- AI messages with empty content + tool calls: some providers (Bedrock)
|
||||
convert ``""`` to ``[{"type":"text","text":""}]`` server-side then
|
||||
reject the blank text. The OpenAI spec says ``content`` should be
|
||||
``null`` when an assistant message only carries tool calls.
|
||||
"""
|
||||
for msg in messages:
|
||||
if isinstance(msg.content, list):
|
||||
msg.content = _sanitize_content(msg.content)
|
||||
if (
|
||||
isinstance(msg, AIMessage)
|
||||
and (not msg.content or msg.content == "")
|
||||
and getattr(msg, "tool_calls", None)
|
||||
):
|
||||
msg.content = None # type: ignore[assignment]
|
||||
return messages
|
||||
|
||||
|
||||
class SanitizedChatLiteLLM(ChatLiteLLM):
|
||||
"""ChatLiteLLM subclass that strips provider-specific content blocks
|
||||
(e.g. ``thinking`` from reasoning models) and normalises bare strings
|
||||
in content arrays before forwarding to the underlying provider."""
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return super()._generate(
|
||||
_sanitize_messages(messages), stop, run_manager, **kwargs
|
||||
)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
async for chunk in super()._astream(
|
||||
_sanitize_messages(messages), stop, run_manager, **kwargs
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
||||
# Provider mapping for LiteLLM model string construction
|
||||
PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
|
|
@ -252,6 +314,28 @@ def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None:
|
|||
return None
|
||||
|
||||
|
||||
def load_global_llm_config_by_id(llm_config_id: int) -> dict | None:
|
||||
"""
|
||||
Load a global LLM config by ID, checking in-memory configs first.
|
||||
|
||||
This handles both static YAML configs and dynamically injected configs
|
||||
(e.g. OpenRouter integration models that only exist in memory).
|
||||
|
||||
Args:
|
||||
llm_config_id: The negative ID of the global config to load
|
||||
|
||||
Returns:
|
||||
LLM config dict or None if not found
|
||||
"""
|
||||
from app.config import config as app_config
|
||||
|
||||
for cfg in app_config.GLOBAL_LLM_CONFIGS:
|
||||
if cfg.get("id") == llm_config_id:
|
||||
return cfg
|
||||
# Fallback to YAML file read (covers edge cases like hot-reload)
|
||||
return load_llm_config_from_yaml(llm_config_id)
|
||||
|
||||
|
||||
async def load_new_llm_config_from_db(
|
||||
session: AsyncSession,
|
||||
config_id: int,
|
||||
|
|
@ -359,7 +443,13 @@ async def load_agent_config(
|
|||
return AgentConfig.from_auto_mode()
|
||||
|
||||
if config_id < 0:
|
||||
# Load from YAML (global configs have negative IDs)
|
||||
# Check in-memory configs first (includes static YAML + dynamic OpenRouter)
|
||||
from app.config import config as app_config
|
||||
|
||||
for cfg in app_config.GLOBAL_LLM_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return AgentConfig.from_yaml_config(cfg)
|
||||
# Fallback to YAML file read for safety
|
||||
yaml_config = load_llm_config_from_yaml(config_id)
|
||||
if yaml_config:
|
||||
return AgentConfig.from_yaml_config(yaml_config)
|
||||
|
|
@ -402,7 +492,7 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
|||
if llm_config.get("litellm_params"):
|
||||
litellm_kwargs.update(llm_config["litellm_params"])
|
||||
|
||||
llm = ChatLiteLLM(**litellm_kwargs)
|
||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
_attach_model_profile(llm, model_string)
|
||||
return llm
|
||||
|
||||
|
|
@ -457,6 +547,6 @@ def create_chat_litellm_from_agent_config(
|
|||
if agent_config.litellm_params:
|
||||
litellm_kwargs.update(agent_config.litellm_params)
|
||||
|
||||
llm = ChatLiteLLM(**litellm_kwargs)
|
||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
_attach_model_profile(llm, model_string)
|
||||
return llm
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from litellm import aspeech
|
|||
from app.config import config as app_config
|
||||
from app.services.kokoro_tts_service import get_kokoro_tts_service
|
||||
from app.services.llm_service import get_agent_llm
|
||||
from app.utils.content_utils import extract_text_content, strip_markdown_fences
|
||||
|
||||
from .configuration import Configuration
|
||||
from .prompts import get_podcast_generation_prompt
|
||||
|
|
@ -53,43 +54,32 @@ async def create_podcast_transcript(
|
|||
# Generate the podcast transcript
|
||||
llm_response = await llm.ainvoke(messages)
|
||||
|
||||
# First try the direct approach
|
||||
# Reasoning models (e.g. Kimi K2.5) may return content as a list of
|
||||
# blocks including 'reasoning' entries. Normalise to a plain string.
|
||||
content = strip_markdown_fences(extract_text_content(llm_response.content))
|
||||
|
||||
try:
|
||||
podcast_transcript = PodcastTranscripts.model_validate(
|
||||
json.loads(llm_response.content)
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
podcast_transcript = PodcastTranscripts.model_validate(json.loads(content))
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e:
|
||||
print(f"Direct JSON parsing failed, trying fallback approach: {e!s}")
|
||||
|
||||
# Fallback: Parse the JSON response manually
|
||||
try:
|
||||
# Extract JSON content from the response
|
||||
content = llm_response.content
|
||||
|
||||
# Find the JSON in the content (handle case where LLM might add additional text)
|
||||
json_start = content.find("{")
|
||||
json_end = content.rfind("}") + 1
|
||||
if json_start >= 0 and json_end > json_start:
|
||||
json_str = content[json_start:json_end]
|
||||
|
||||
# Parse the JSON string
|
||||
parsed_data = json.loads(json_str)
|
||||
|
||||
# Convert to Pydantic model
|
||||
podcast_transcript = PodcastTranscripts.model_validate(parsed_data)
|
||||
|
||||
print("Successfully parsed podcast transcript using fallback approach")
|
||||
else:
|
||||
# If JSON structure not found, raise a clear error
|
||||
error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"
|
||||
print(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e2:
|
||||
# Log the error and re-raise it
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e2:
|
||||
error_message = f"Error parsing LLM response (fallback also failed): {e2!s}"
|
||||
print(f"Error parsing LLM response: {e2!s}")
|
||||
print(f"Raw response: {llm_response.content}")
|
||||
print(f"Raw response: {content}")
|
||||
raise
|
||||
|
||||
return {"podcast_transcript": podcast_transcript.podcast_transcripts}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from litellm import aspeech
|
|||
from app.config import config as app_config
|
||||
from app.services.kokoro_tts_service import get_kokoro_tts_service
|
||||
from app.services.llm_service import get_agent_llm
|
||||
from app.utils.content_utils import extract_text_content, strip_markdown_fences
|
||||
|
||||
from .configuration import Configuration
|
||||
from .prompts import (
|
||||
|
|
@ -67,16 +68,14 @@ async def create_presentation_slides(
|
|||
]
|
||||
|
||||
llm_response = await llm.ainvoke(messages)
|
||||
content = strip_markdown_fences(extract_text_content(llm_response.content))
|
||||
|
||||
try:
|
||||
presentation = PresentationSlides.model_validate(
|
||||
json.loads(llm_response.content)
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
presentation = PresentationSlides.model_validate(json.loads(content))
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e:
|
||||
print(f"Direct JSON parsing failed, trying fallback approach: {e!s}")
|
||||
|
||||
try:
|
||||
content = llm_response.content
|
||||
json_start = content.find("{")
|
||||
json_end = content.rfind("}") + 1
|
||||
if json_start >= 0 and json_end > json_start:
|
||||
|
|
@ -89,10 +88,10 @@ async def create_presentation_slides(
|
|||
print(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e2:
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e2:
|
||||
error_message = f"Error parsing LLM response (fallback also failed): {e2!s}"
|
||||
print(f"Error parsing LLM response: {e2!s}")
|
||||
print(f"Raw response: {llm_response.content}")
|
||||
print(f"Raw response: {content}")
|
||||
raise
|
||||
|
||||
return {"slides": presentation.slides}
|
||||
|
|
@ -308,12 +307,7 @@ async def _assign_themes_with_llm(
|
|||
]
|
||||
)
|
||||
|
||||
text = response.content.strip()
|
||||
if text.startswith("```"):
|
||||
lines = text.split("\n")
|
||||
text = "\n".join(
|
||||
line for line in lines if not line.strip().startswith("```")
|
||||
).strip()
|
||||
text = strip_markdown_fences(extract_text_content(response.content))
|
||||
|
||||
assignments = json.loads(text)
|
||||
valid_themes = set(THEME_PRESETS)
|
||||
|
|
@ -424,7 +418,9 @@ async def generate_slide_scene_codes(
|
|||
)
|
||||
|
||||
llm_response = await llm.ainvoke(messages)
|
||||
code, scene_title = _extract_code_and_title(llm_response.content)
|
||||
code, scene_title = _extract_code_and_title(
|
||||
extract_text_content(llm_response.content)
|
||||
)
|
||||
|
||||
code = await _refine_if_needed(llm, code, slide.slide_number)
|
||||
|
||||
|
|
@ -452,7 +448,7 @@ def _extract_code_and_title(content: str) -> tuple[str, str | None]:
|
|||
|
||||
Returns (code, title) where title may be None.
|
||||
"""
|
||||
text = content.strip()
|
||||
text = strip_markdown_fences(content)
|
||||
|
||||
if text.startswith("{"):
|
||||
try:
|
||||
|
|
@ -472,18 +468,7 @@ def _extract_code_and_title(content: str) -> tuple[str, str | None]:
|
|||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
code = text
|
||||
if code.startswith("```"):
|
||||
lines = code.split("\n")
|
||||
start = 1
|
||||
end = len(lines)
|
||||
for i in range(len(lines) - 1, 0, -1):
|
||||
if lines[i].strip().startswith("```"):
|
||||
end = i
|
||||
break
|
||||
code = "\n".join(lines[start:end]).strip()
|
||||
|
||||
return code, None
|
||||
return text, None
|
||||
|
||||
|
||||
async def _refine_if_needed(llm, code: str, slide_number: int) -> str:
|
||||
|
|
@ -512,7 +497,7 @@ async def _refine_if_needed(llm, code: str, slide_number: int) -> str:
|
|||
]
|
||||
|
||||
response = await llm.ainvoke(messages)
|
||||
code, _ = _extract_code_and_title(response.content)
|
||||
code, _ = _extract_code_and_title(extract_text_content(response.content))
|
||||
|
||||
error = _basic_syntax_check(code)
|
||||
if error is None:
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from app.config import (
|
|||
config,
|
||||
initialize_image_gen_router,
|
||||
initialize_llm_router,
|
||||
initialize_openrouter_integration,
|
||||
initialize_vision_llm_router,
|
||||
)
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
|
|
@ -368,6 +369,26 @@ def _enable_slow_callback_logging(threshold_sec: float = 0.5) -> None:
|
|||
)
|
||||
|
||||
|
||||
def _start_openrouter_background_refresh() -> None:
|
||||
"""Start periodic OpenRouter model refresh if integration is enabled."""
|
||||
from app.services.openrouter_integration_service import OpenRouterIntegrationService
|
||||
|
||||
if not OpenRouterIntegrationService.is_initialized():
|
||||
return
|
||||
settings = config.OPENROUTER_INTEGRATION_SETTINGS
|
||||
if settings:
|
||||
interval = settings.get("refresh_interval_hours", 24)
|
||||
OpenRouterIntegrationService.get_instance().start_background_refresh(interval)
|
||||
|
||||
|
||||
def _stop_openrouter_background_refresh() -> None:
|
||||
"""Cancel the periodic OpenRouter refresh task on shutdown."""
|
||||
from app.services.openrouter_integration_service import OpenRouterIntegrationService
|
||||
|
||||
if OpenRouterIntegrationService.is_initialized():
|
||||
OpenRouterIntegrationService.get_instance().stop_background_refresh()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Tune GC: lower gen-2 threshold so long-lived garbage is collected
|
||||
|
|
@ -378,6 +399,8 @@ async def lifespan(app: FastAPI):
|
|||
_enable_slow_callback_logging(threshold_sec=0.5)
|
||||
await create_db_and_tables()
|
||||
await setup_checkpointer_tables()
|
||||
initialize_openrouter_integration()
|
||||
_start_openrouter_background_refresh()
|
||||
initialize_llm_router()
|
||||
initialize_image_gen_router()
|
||||
initialize_vision_llm_router()
|
||||
|
|
@ -393,6 +416,7 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
yield
|
||||
|
||||
_stop_openrouter_background_refresh()
|
||||
await close_checkpointer()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -21,9 +21,11 @@ def init_worker(**kwargs):
|
|||
from app.config import (
|
||||
initialize_image_gen_router,
|
||||
initialize_llm_router,
|
||||
initialize_openrouter_integration,
|
||||
initialize_vision_llm_router,
|
||||
)
|
||||
|
||||
initialize_openrouter_integration()
|
||||
initialize_llm_router()
|
||||
initialize_image_gen_router()
|
||||
initialize_vision_llm_router()
|
||||
|
|
|
|||
|
|
@ -187,24 +187,82 @@ def load_image_gen_router_settings():
|
|||
return default_settings
|
||||
|
||||
|
||||
def load_openrouter_integration_settings() -> dict | None:
|
||||
"""
|
||||
Load OpenRouter integration settings from the YAML config.
|
||||
|
||||
Returns:
|
||||
dict with settings if present and enabled, None otherwise
|
||||
"""
|
||||
global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"
|
||||
|
||||
if not global_config_file.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
settings = data.get("openrouter_integration")
|
||||
if settings and settings.get("enabled"):
|
||||
return settings
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load OpenRouter integration settings: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def initialize_openrouter_integration():
|
||||
"""
|
||||
If enabled, fetch all OpenRouter models and append them to
|
||||
config.GLOBAL_LLM_CONFIGS as dynamic premium entries.
|
||||
Should be called BEFORE initialize_llm_router() so the router
|
||||
correctly excludes premium models from Auto mode.
|
||||
"""
|
||||
settings = load_openrouter_integration_settings()
|
||||
if not settings:
|
||||
return
|
||||
|
||||
try:
|
||||
from app.services.openrouter_integration_service import (
|
||||
OpenRouterIntegrationService,
|
||||
)
|
||||
|
||||
service = OpenRouterIntegrationService.get_instance()
|
||||
new_configs = service.initialize(settings)
|
||||
|
||||
if new_configs:
|
||||
config.GLOBAL_LLM_CONFIGS.extend(new_configs)
|
||||
print(
|
||||
f"Info: OpenRouter integration added {len(new_configs)} models "
|
||||
f"(billing_tier={settings.get('billing_tier', 'premium')})"
|
||||
)
|
||||
else:
|
||||
print("Info: OpenRouter integration enabled but no models fetched")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to initialize OpenRouter integration: {e}")
|
||||
|
||||
|
||||
def initialize_llm_router():
|
||||
"""
|
||||
Initialize the LLM Router service for Auto mode.
|
||||
This should be called during application startup.
|
||||
This should be called during application startup, AFTER
|
||||
initialize_openrouter_integration() so dynamic models are included.
|
||||
Uses config.GLOBAL_LLM_CONFIGS (in-memory) which includes both
|
||||
static YAML configs and dynamic OpenRouter models.
|
||||
"""
|
||||
global_configs = load_global_llm_configs()
|
||||
all_configs = config.GLOBAL_LLM_CONFIGS
|
||||
router_settings = load_router_settings()
|
||||
|
||||
if not global_configs:
|
||||
if not all_configs:
|
||||
print("Info: No global LLM configs found, Auto mode will not be available")
|
||||
return
|
||||
|
||||
try:
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
|
||||
LLMRouterService.initialize(global_configs, router_settings)
|
||||
LLMRouterService.initialize(all_configs, router_settings)
|
||||
print(
|
||||
f"Info: LLM Router initialized with {len(global_configs)} models "
|
||||
f"Info: LLM Router initialized with {len(all_configs)} models "
|
||||
f"(strategy: {router_settings.get('routing_strategy', 'usage-based-routing')})"
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
@ -326,7 +384,7 @@ class Config:
|
|||
)
|
||||
|
||||
# Premium token quota settings
|
||||
PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "5000000"))
|
||||
PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "3000000"))
|
||||
STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID")
|
||||
STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000"))
|
||||
STRIPE_TOKEN_BUYING_ENABLED = (
|
||||
|
|
@ -335,9 +393,9 @@ class Config:
|
|||
|
||||
# Anonymous / no-login mode settings
|
||||
NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE"
|
||||
ANON_TOKEN_LIMIT = int(os.getenv("ANON_TOKEN_LIMIT", "1000000"))
|
||||
ANON_TOKEN_LIMIT = int(os.getenv("ANON_TOKEN_LIMIT", "500000"))
|
||||
ANON_TOKEN_WARNING_THRESHOLD = int(
|
||||
os.getenv("ANON_TOKEN_WARNING_THRESHOLD", "800000")
|
||||
os.getenv("ANON_TOKEN_WARNING_THRESHOLD", "400000")
|
||||
)
|
||||
ANON_TOKEN_QUOTA_TTL_DAYS = int(os.getenv("ANON_TOKEN_QUOTA_TTL_DAYS", "30"))
|
||||
ANON_MAX_UPLOAD_SIZE_MB = int(os.getenv("ANON_MAX_UPLOAD_SIZE_MB", "5"))
|
||||
|
|
@ -450,6 +508,9 @@ class Config:
|
|||
# Router settings for Vision LLM Auto mode
|
||||
VISION_LLM_ROUTER_SETTINGS = load_vision_llm_router_settings()
|
||||
|
||||
# OpenRouter Integration settings (optional)
|
||||
OPENROUTER_INTEGRATION_SETTINGS = load_openrouter_integration_settings()
|
||||
|
||||
# Chonkie Configuration | Edit this to your needs
|
||||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
|
||||
# Azure OpenAI credentials from environment variables
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# Global LLM Configuration
|
||||
#
|
||||
#
|
||||
# SETUP INSTRUCTIONS:
|
||||
# 1. For production: Copy this file to global_llm_config.yaml and add your real API keys
|
||||
# 2. For testing: The system will use this example file automatically if global_llm_config.yaml doesn't exist
|
||||
|
|
@ -29,16 +29,16 @@ router_settings:
|
|||
# - "least-busy": Routes to least busy deployment
|
||||
# - "latency-based-routing": Routes based on response latency
|
||||
routing_strategy: "usage-based-routing"
|
||||
|
||||
|
||||
# Number of retries before failing
|
||||
num_retries: 3
|
||||
|
||||
|
||||
# Number of failures allowed before cooling down a deployment
|
||||
allowed_fails: 3
|
||||
|
||||
|
||||
# Cooldown time in seconds after allowed_fails is exceeded
|
||||
cooldown_time: 60
|
||||
|
||||
|
||||
# Fallback models (optional) - when primary fails, try these
|
||||
# Format: [{"primary_model": ["fallback1", "fallback2"]}]
|
||||
# fallbacks: []
|
||||
|
|
@ -58,13 +58,13 @@ global_llm_configs:
|
|||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
# Rate limits for load balancing (requests/tokens per minute)
|
||||
rpm: 500 # Requests per minute
|
||||
tpm: 100000 # Tokens per minute
|
||||
rpm: 500 # Requests per minute
|
||||
tpm: 100000 # Tokens per minute
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
# Prompt Configuration
|
||||
system_instructions: "" # Empty = use default SURFSENSE_SYSTEM_INSTRUCTIONS
|
||||
system_instructions: "" # Empty = use default SURFSENSE_SYSTEM_INSTRUCTIONS
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
|
|
@ -103,14 +103,14 @@ global_llm_configs:
|
|||
model_name: "gpt-3.5-turbo"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 3500 # GPT-3.5 has higher rate limits
|
||||
rpm: 3500 # GPT-3.5 has higher rate limits
|
||||
tpm: 200000
|
||||
litellm_params:
|
||||
temperature: 0.5
|
||||
max_tokens: 2000
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: false # Disabled for faster responses
|
||||
citations_enabled: false # Disabled for faster responses
|
||||
|
||||
# Example: Chinese LLM - DeepSeek with custom instructions
|
||||
- id: -4
|
||||
|
|
@ -134,9 +134,9 @@ global_llm_configs:
|
|||
system_instructions: |
|
||||
<system_instruction>
|
||||
You are SurfSense, a reasoning and acting AI agent designed to answer user questions using the user's personal knowledge base.
|
||||
|
||||
|
||||
Today's date (UTC): {resolved_today}
|
||||
|
||||
|
||||
IMPORTANT: Please respond in Chinese (简体中文) unless the user specifically requests another language.
|
||||
</system_instruction>
|
||||
use_default_system_instructions: false
|
||||
|
|
@ -158,7 +158,7 @@ global_llm_configs:
|
|||
model_name: "azure/gpt-4o-deployment"
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
api_version: "2024-02-15-preview" # Azure API version
|
||||
api_version: "2024-02-15-preview" # Azure API version
|
||||
rpm: 1000
|
||||
tpm: 150000
|
||||
litellm_params:
|
||||
|
|
@ -191,7 +191,7 @@ global_llm_configs:
|
|||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
base_model: "gpt-4-turbo" # Maps to gpt-4-turbo-preview
|
||||
base_model: "gpt-4-turbo" # Maps to gpt-4-turbo-preview
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
|
@ -209,7 +209,7 @@ global_llm_configs:
|
|||
model_name: "llama3-70b-8192"
|
||||
api_key: "your-groq-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 30 # Groq has lower rate limits on free tier
|
||||
rpm: 30 # Groq has lower rate limits on free tier
|
||||
tpm: 14400
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
|
|
@ -234,12 +234,48 @@ global_llm_configs:
|
|||
rpm: 60
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0], cannot be 0
|
||||
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0], cannot be 0
|
||||
max_tokens: 4000
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# =============================================================================
|
||||
# OpenRouter Integration
|
||||
# =============================================================================
|
||||
# When enabled, dynamically fetches ALL available models from the OpenRouter API
|
||||
# and injects them as global configs. This gives premium users access to any model
|
||||
# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota.
|
||||
# Models are fetched at startup and refreshed periodically in the background.
|
||||
# All calls go through LiteLLM with the openrouter/ prefix.
|
||||
openrouter_integration:
|
||||
enabled: false
|
||||
api_key: "sk-or-your-openrouter-api-key"
|
||||
# billing_tier: "premium" or "free". Controls whether users need premium tokens.
|
||||
billing_tier: "premium"
|
||||
# anonymous_enabled: set true to also show OpenRouter models to no-login users
|
||||
anonymous_enabled: false
|
||||
seo_enabled: false
|
||||
# quota_reserve_tokens: tokens reserved per call for quota enforcement
|
||||
quota_reserve_tokens: 4000
|
||||
# id_offset: starting negative ID for dynamically generated configs.
|
||||
# Must not overlap with your static global_llm_configs IDs above.
|
||||
id_offset: -10000
|
||||
# refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only)
|
||||
refresh_interval_hours: 24
|
||||
# rpm/tpm: Applied uniformly to all OpenRouter models for LiteLLM Router load balancing.
|
||||
# OpenRouter doesn't expose per-model rate limits via API; actual throttling is handled
|
||||
# upstream by OpenRouter itself (your account limits are at https://openrouter.ai/settings/limits).
|
||||
# These values only matter if you set billing_tier to "free" (adding them to Auto mode).
|
||||
# For premium-only models they are cosmetic. Set conservatively or match your account tier.
|
||||
rpm: 200
|
||||
tpm: 1000000
|
||||
litellm_params:
|
||||
max_tokens: 16384
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# =============================================================================
|
||||
# Image Generation Configuration
|
||||
# =============================================================================
|
||||
|
|
@ -265,7 +301,7 @@ global_image_generation_configs:
|
|||
model_name: "dall-e-3"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 50 # Requests per minute (image gen is rate-limited by RPM, not tokens)
|
||||
rpm: 50 # Requests per minute (image gen is rate-limited by RPM, not tokens)
|
||||
litellm_params: {}
|
||||
|
||||
# Example: OpenAI GPT Image 1
|
||||
|
|
@ -394,7 +430,7 @@ global_vision_llm_configs:
|
|||
#
|
||||
# IMAGE GENERATION NOTES:
|
||||
# - Image generation configs use the same ID scheme as LLM configs (negative for global)
|
||||
# - Supported models: dall-e-2, dall-e-3, gpt-image-1 (OpenAI), azure/* (Azure),
|
||||
# - Supported models: dall-e-2, dall-e-3, gpt-image-1 (OpenAI), azure/* (Azure),
|
||||
# bedrock/* (AWS), vertex_ai/* (Google), recraft/* (Recraft), openrouter/* (OpenRouter)
|
||||
# - The router uses litellm.aimage_generation() for async image generation
|
||||
# - Only RPM (requests per minute) is relevant for image generation rate limiting.
|
||||
|
|
|
|||
|
|
@ -49,6 +49,49 @@ def _is_context_overflow_error(exc: LiteLLMBadRequestError) -> bool:
|
|||
return bool(_CONTEXT_OVERFLOW_PATTERNS.search(str(exc)))
|
||||
|
||||
|
||||
_UNIVERSAL_CONTENT_TYPES = {
|
||||
"text",
|
||||
"image_url",
|
||||
"input_audio",
|
||||
"refusal",
|
||||
"audio",
|
||||
"file",
|
||||
}
|
||||
|
||||
|
||||
def _sanitize_content(content: Any) -> Any:
|
||||
"""Normalise a LangChain message ``content`` field so it is safe for any
|
||||
downstream provider (Azure, OpenAI, OpenRouter, etc.).
|
||||
|
||||
* Strips provider-specific block types (e.g. ``thinking`` from reasoning models).
|
||||
* Removes text blocks with blank text (Bedrock rejects ``{"type":"text","text":""}``)
|
||||
* Converts bare strings inside a list to ``{"type": "text", "text": ...}`` objects
|
||||
(Azure rejects raw strings in a content array).
|
||||
* Collapses a single-text-block list to a plain string for maximum compatibility.
|
||||
"""
|
||||
if not isinstance(content, list):
|
||||
return content
|
||||
|
||||
filtered: list[dict] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
if block:
|
||||
filtered.append({"type": "text", "text": block})
|
||||
elif isinstance(block, dict):
|
||||
block_type = block.get("type", "text")
|
||||
if block_type not in _UNIVERSAL_CONTENT_TYPES:
|
||||
continue
|
||||
if block_type == "text" and not block.get("text"):
|
||||
continue
|
||||
filtered.append(block)
|
||||
|
||||
if not filtered:
|
||||
return ""
|
||||
if len(filtered) == 1 and filtered[0].get("type") == "text":
|
||||
return filtered[0].get("text", "")
|
||||
return filtered
|
||||
|
||||
|
||||
# Special ID for Auto mode - uses router for load balancing
|
||||
AUTO_MODE_ID = 0
|
||||
|
||||
|
|
@ -103,6 +146,7 @@ class LLMRouterService:
|
|||
_model_list: list[dict] = []
|
||||
_router_settings: dict = {}
|
||||
_initialized: bool = False
|
||||
_premium_model_strings: set[str] = set()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
|
|
@ -135,22 +179,28 @@ class LLMRouterService:
|
|||
logger.debug("LLM Router already initialized, skipping")
|
||||
return
|
||||
|
||||
auto_configs = [
|
||||
c for c in global_configs if c.get("billing_tier", "free") != "premium"
|
||||
]
|
||||
|
||||
model_list = []
|
||||
for config in auto_configs:
|
||||
premium_models: set[str] = set()
|
||||
for config in global_configs:
|
||||
deployment = cls._config_to_deployment(config)
|
||||
if deployment:
|
||||
model_list.append(deployment)
|
||||
if config.get("billing_tier") == "premium":
|
||||
model_string = deployment["litellm_params"]["model"]
|
||||
premium_models.add(model_string)
|
||||
|
||||
if not model_list:
|
||||
logger.warning("No valid LLM configs found for router initialization")
|
||||
return
|
||||
|
||||
instance._model_list = model_list
|
||||
instance._premium_model_strings = premium_models
|
||||
instance._router_settings = router_settings or {}
|
||||
logger.info(
|
||||
"Router pool: %d deployments (%d premium)",
|
||||
len(model_list),
|
||||
len(premium_models),
|
||||
)
|
||||
|
||||
# Default router settings optimized for rate limit handling
|
||||
default_settings = {
|
||||
|
|
@ -197,6 +247,21 @@ class LLMRouterService:
|
|||
logger.error(f"Failed to initialize LLM Router: {e}")
|
||||
instance._router = None
|
||||
|
||||
@classmethod
|
||||
def is_premium_model(cls, model_string: str) -> bool:
|
||||
"""Return True if *model_string* (as reported by LiteLLM) belongs to a
|
||||
premium-tier deployment in the router pool."""
|
||||
instance = cls.get_instance()
|
||||
return model_string in instance._premium_model_strings
|
||||
|
||||
@classmethod
|
||||
def compute_premium_tokens(cls, calls: list) -> int:
|
||||
"""Sum ``total_tokens`` for calls whose model is premium."""
|
||||
instance = cls.get_instance()
|
||||
return sum(
|
||||
c.total_tokens for c in calls if c.model in instance._premium_model_strings
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _build_context_fallback_groups(
|
||||
cls, model_list: list[dict]
|
||||
|
|
@ -1044,10 +1109,12 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
result.append({"role": "user", "content": msg.content})
|
||||
elif isinstance(msg, AIMsg):
|
||||
ai_msg: dict[str, Any] = {"role": "assistant"}
|
||||
if msg.content:
|
||||
ai_msg["content"] = msg.content
|
||||
# Handle tool calls
|
||||
if hasattr(msg, "tool_calls") and msg.tool_calls:
|
||||
has_tool_calls = hasattr(msg, "tool_calls") and msg.tool_calls
|
||||
|
||||
sanitized = _sanitize_content(msg.content) if msg.content else ""
|
||||
ai_msg["content"] = sanitized if sanitized else ""
|
||||
|
||||
if has_tool_calls:
|
||||
ai_msg["tool_calls"] = [
|
||||
{
|
||||
"id": tc.get("id", ""),
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from langchain_litellm import ChatLiteLLM
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
from app.config import config
|
||||
from app.db import NewLLMConfig, SearchSpace
|
||||
from app.services.llm_router_service import (
|
||||
|
|
@ -150,7 +151,7 @@ async def validate_llm_config(
|
|||
if litellm_params:
|
||||
litellm_kwargs.update(litellm_params)
|
||||
|
||||
llm = ChatLiteLLM(**litellm_kwargs)
|
||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
# Make a simple test call
|
||||
test_message = HumanMessage(content="Hello")
|
||||
|
|
@ -302,7 +303,7 @@ async def get_search_space_llm_instance(
|
|||
if disable_streaming:
|
||||
litellm_kwargs["disable_streaming"] = True
|
||||
|
||||
return ChatLiteLLM(**litellm_kwargs)
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
# Get the LLM configuration from database (NewLLMConfig)
|
||||
result = await session.execute(
|
||||
|
|
@ -379,7 +380,7 @@ async def get_search_space_llm_instance(
|
|||
if disable_streaming:
|
||||
litellm_kwargs["disable_streaming"] = True
|
||||
|
||||
return ChatLiteLLM(**litellm_kwargs)
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
|
@ -480,7 +481,7 @@ async def get_vision_llm(
|
|||
if global_cfg.get("litellm_params"):
|
||||
litellm_kwargs.update(global_cfg["litellm_params"])
|
||||
|
||||
return ChatLiteLLM(**litellm_kwargs)
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).where(
|
||||
|
|
@ -513,7 +514,7 @@ async def get_vision_llm(
|
|||
if vision_cfg.litellm_params:
|
||||
litellm_kwargs.update(vision_cfg.litellm_params)
|
||||
|
||||
return ChatLiteLLM(**litellm_kwargs)
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
|
|
|||
|
|
@ -86,12 +86,34 @@ def _is_text_output_model(model: dict) -> bool:
|
|||
return output_mods == ["text"]
|
||||
|
||||
|
||||
def _supports_tool_calling(model: dict) -> bool:
|
||||
"""Return True if the model supports function/tool calling."""
|
||||
supported = model.get("supported_parameters") or []
|
||||
return "tools" in supported
|
||||
|
||||
|
||||
MIN_CONTEXT_LENGTH = 100_000
|
||||
|
||||
|
||||
def _has_sufficient_context(model: dict) -> bool:
|
||||
"""Return True if the model's context window is at least MIN_CONTEXT_LENGTH."""
|
||||
ctx = model.get("context_length") or 0
|
||||
return ctx >= MIN_CONTEXT_LENGTH
|
||||
|
||||
|
||||
def _is_allowed_model(model: dict) -> bool:
|
||||
"""Reuse the exclusion list from the OpenRouter integration service."""
|
||||
from app.services.openrouter_integration_service import _is_allowed_model as _check
|
||||
|
||||
return _check(model)
|
||||
|
||||
|
||||
def _process_models(raw_models: list[dict]) -> list[dict]:
|
||||
"""
|
||||
Transform raw OpenRouter model entries into a flat list of
|
||||
{value, label, provider, context_window} dicts.
|
||||
|
||||
Only text-output models are included (audio/image generators are skipped).
|
||||
Only text-output models with tool-calling support are included.
|
||||
|
||||
Each OpenRouter model is emitted once for OPENROUTER (full id) and,
|
||||
when the slug maps to a native provider, once more with just the
|
||||
|
|
@ -110,6 +132,15 @@ def _process_models(raw_models: list[dict]) -> list[dict]:
|
|||
if not _is_text_output_model(model):
|
||||
continue
|
||||
|
||||
if not _supports_tool_calling(model):
|
||||
continue
|
||||
|
||||
if not _has_sufficient_context(model):
|
||||
continue
|
||||
|
||||
if not _is_allowed_model(model):
|
||||
continue
|
||||
|
||||
provider_slug, model_name = model_id.split("/", 1)
|
||||
context_window = _format_context_length(context_length)
|
||||
|
||||
|
|
|
|||
291
surfsense_backend/app/services/openrouter_integration_service.py
Normal file
291
surfsense_backend/app/services/openrouter_integration_service.py
Normal file
|
|
@ -0,0 +1,291 @@
|
|||
"""
|
||||
OpenRouter Integration Service
|
||||
|
||||
Dynamically fetches all available models from the OpenRouter public API
|
||||
and generates virtual global LLM config entries. These entries are injected
|
||||
into config.GLOBAL_LLM_CONFIGS so they appear alongside static YAML configs
|
||||
in the model selector.
|
||||
|
||||
All actual LLM calls go through LiteLLM with the ``openrouter/`` prefix --
|
||||
this service only manages the catalogue, not the inference path.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
|
||||
|
||||
# Sentinel value stored on each generated config so we can distinguish
|
||||
# dynamic OpenRouter entries from hand-written YAML entries during refresh.
|
||||
_OPENROUTER_DYNAMIC_MARKER = "__openrouter_dynamic__"
|
||||
|
||||
|
||||
def _is_text_output_model(model: dict) -> bool:
|
||||
"""Return True if the model produces text output only (skip image/audio generators)."""
|
||||
output_mods = model.get("architecture", {}).get("output_modalities", [])
|
||||
return output_mods == ["text"]
|
||||
|
||||
|
||||
def _supports_tool_calling(model: dict) -> bool:
|
||||
"""Return True if the model supports function/tool calling."""
|
||||
supported = model.get("supported_parameters") or []
|
||||
return "tools" in supported
|
||||
|
||||
|
||||
MIN_CONTEXT_LENGTH = 100_000
|
||||
|
||||
# Provider slugs whose backend is fundamentally incompatible with our agent's
|
||||
# tool-call message flow (e.g. Amazon Bedrock requires toolConfig alongside
|
||||
# tool history which OpenRouter doesn't relay).
|
||||
_EXCLUDED_PROVIDER_SLUGS = {"amazon"}
|
||||
|
||||
_EXCLUDED_MODEL_IDS: set[str] = {
|
||||
# Deprecated / removed upstream
|
||||
"openai/gpt-4-1106-preview",
|
||||
"openai/gpt-4-turbo-preview",
|
||||
# Permanently no-capacity variant
|
||||
"openai/gpt-4o:extended",
|
||||
# Non-serverless model that requires a dedicated endpoint
|
||||
"arcee-ai/virtuoso-large",
|
||||
# Deep-research models reject standard params (temperature, etc.)
|
||||
"openai/o3-deep-research",
|
||||
"openai/o4-mini-deep-research",
|
||||
}
|
||||
|
||||
_EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",)
|
||||
|
||||
|
||||
def _has_sufficient_context(model: dict) -> bool:
|
||||
"""Return True if the model's context window is at least MIN_CONTEXT_LENGTH."""
|
||||
ctx = model.get("context_length") or 0
|
||||
return ctx >= MIN_CONTEXT_LENGTH
|
||||
|
||||
|
||||
def _is_compatible_provider(model: dict) -> bool:
|
||||
"""Return False for models from providers known to be incompatible."""
|
||||
model_id = model.get("id", "")
|
||||
slug = model_id.split("/", 1)[0] if "/" in model_id else ""
|
||||
return slug not in _EXCLUDED_PROVIDER_SLUGS
|
||||
|
||||
|
||||
def _is_allowed_model(model: dict) -> bool:
|
||||
"""Return False for specific model IDs known to be broken or incompatible."""
|
||||
model_id = model.get("id", "")
|
||||
if model_id in _EXCLUDED_MODEL_IDS:
|
||||
return False
|
||||
base_id = model_id.split(":")[0]
|
||||
return not base_id.endswith(_EXCLUDED_MODEL_SUFFIXES)
|
||||
|
||||
|
||||
def _fetch_models_sync() -> list[dict] | None:
|
||||
"""Synchronous fetch for use during startup (before the event loop is running)."""
|
||||
try:
|
||||
with httpx.Client(timeout=20) as client:
|
||||
response = client.get(OPENROUTER_API_URL)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("data", [])
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch OpenRouter models (sync): %s", e)
|
||||
return None
|
||||
|
||||
|
||||
async def _fetch_models_async() -> list[dict] | None:
|
||||
"""Async fetch for background refresh."""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=20) as client:
|
||||
response = await client.get(OPENROUTER_API_URL)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("data", [])
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch OpenRouter models (async): %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def _generate_configs(
|
||||
raw_models: list[dict],
|
||||
settings: dict[str, Any],
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Convert raw OpenRouter model entries into global LLM config dicts.
|
||||
|
||||
Models are sorted by ID for deterministic, stable ID assignment across
|
||||
restarts and refreshes.
|
||||
"""
|
||||
id_offset: int = settings.get("id_offset", -10000)
|
||||
api_key: str = settings.get("api_key", "")
|
||||
billing_tier: str = settings.get("billing_tier", "premium")
|
||||
anonymous_enabled: bool = settings.get("anonymous_enabled", False)
|
||||
seo_enabled: bool = settings.get("seo_enabled", False)
|
||||
quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
|
||||
rpm: int = settings.get("rpm", 200)
|
||||
tpm: int = settings.get("tpm", 1000000)
|
||||
litellm_params: dict = settings.get("litellm_params") or {}
|
||||
system_instructions: str = settings.get("system_instructions", "")
|
||||
use_default: bool = settings.get("use_default_system_instructions", True)
|
||||
citations_enabled: bool = settings.get("citations_enabled", True)
|
||||
|
||||
text_models = [
|
||||
m
|
||||
for m in raw_models
|
||||
if _is_text_output_model(m)
|
||||
and _supports_tool_calling(m)
|
||||
and _has_sufficient_context(m)
|
||||
and _is_compatible_provider(m)
|
||||
and _is_allowed_model(m)
|
||||
and "/" in m.get("id", "")
|
||||
]
|
||||
text_models.sort(key=lambda m: m["id"])
|
||||
|
||||
configs: list[dict] = []
|
||||
for idx, model in enumerate(text_models):
|
||||
model_id: str = model["id"]
|
||||
name: str = model.get("name", model_id)
|
||||
|
||||
cfg: dict[str, Any] = {
|
||||
"id": id_offset - idx,
|
||||
"name": name,
|
||||
"description": f"{name} via OpenRouter",
|
||||
"billing_tier": billing_tier,
|
||||
"anonymous_enabled": anonymous_enabled,
|
||||
"seo_enabled": seo_enabled,
|
||||
"seo_slug": None,
|
||||
"quota_reserve_tokens": quota_reserve_tokens,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": model_id,
|
||||
"api_key": api_key,
|
||||
"api_base": "",
|
||||
"rpm": rpm,
|
||||
"tpm": tpm,
|
||||
"litellm_params": dict(litellm_params),
|
||||
"system_instructions": system_instructions,
|
||||
"use_default_system_instructions": use_default,
|
||||
"citations_enabled": citations_enabled,
|
||||
_OPENROUTER_DYNAMIC_MARKER: True,
|
||||
}
|
||||
configs.append(cfg)
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
class OpenRouterIntegrationService:
|
||||
"""Singleton that manages the dynamic OpenRouter model catalogue."""
|
||||
|
||||
_instance: "OpenRouterIntegrationService | None" = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._settings: dict[str, Any] = {}
|
||||
self._configs: list[dict] = []
|
||||
self._configs_by_id: dict[int, dict] = {}
|
||||
self._initialized = False
|
||||
self._refresh_task: asyncio.Task | None = None
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "OpenRouterIntegrationService":
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def is_initialized(cls) -> bool:
|
||||
return cls._instance is not None and cls._instance._initialized
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Initialisation (called at startup, before event loop for Celery)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def initialize(self, settings: dict[str, Any]) -> list[dict]:
|
||||
"""
|
||||
Fetch models synchronously and generate configs.
|
||||
Returns the generated configs list.
|
||||
"""
|
||||
self._settings = settings
|
||||
raw_models = _fetch_models_sync()
|
||||
if raw_models is None:
|
||||
logger.warning("OpenRouter integration: could not fetch models at startup")
|
||||
self._initialized = True
|
||||
return []
|
||||
|
||||
self._configs = _generate_configs(raw_models, settings)
|
||||
self._configs_by_id = {c["id"]: c for c in self._configs}
|
||||
self._initialized = True
|
||||
|
||||
logger.info(
|
||||
"OpenRouter integration: loaded %d models (IDs %d to %d)",
|
||||
len(self._configs),
|
||||
self._configs[0]["id"] if self._configs else 0,
|
||||
self._configs[-1]["id"] if self._configs else 0,
|
||||
)
|
||||
return self._configs
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Background refresh
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def refresh(self) -> None:
|
||||
"""Re-fetch from OpenRouter and atomically swap configs in GLOBAL_LLM_CONFIGS."""
|
||||
raw_models = await _fetch_models_async()
|
||||
if raw_models is None:
|
||||
logger.warning("OpenRouter refresh: fetch failed, keeping stale list")
|
||||
return
|
||||
|
||||
new_configs = _generate_configs(raw_models, self._settings)
|
||||
new_by_id = {c["id"]: c for c in new_configs}
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
static_configs = [
|
||||
c
|
||||
for c in app_config.GLOBAL_LLM_CONFIGS
|
||||
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
|
||||
]
|
||||
app_config.GLOBAL_LLM_CONFIGS = static_configs + new_configs
|
||||
|
||||
self._configs = new_configs
|
||||
self._configs_by_id = new_by_id
|
||||
|
||||
logger.info("OpenRouter refresh: updated to %d models", len(new_configs))
|
||||
|
||||
async def _refresh_loop(self, interval_hours: float) -> None:
|
||||
interval_sec = interval_hours * 3600
|
||||
while True:
|
||||
await asyncio.sleep(interval_sec)
|
||||
try:
|
||||
await self.refresh()
|
||||
except Exception:
|
||||
logger.exception("OpenRouter background refresh failed")
|
||||
|
||||
def start_background_refresh(self, interval_hours: float) -> None:
|
||||
if interval_hours <= 0:
|
||||
return
|
||||
loop = asyncio.get_event_loop()
|
||||
self._refresh_task = loop.create_task(self._refresh_loop(interval_hours))
|
||||
logger.info(
|
||||
"OpenRouter background refresh started (every %.1fh)", interval_hours
|
||||
)
|
||||
|
||||
def stop_background_refresh(self) -> None:
|
||||
if self._refresh_task is not None and not self._refresh_task.done():
|
||||
self._refresh_task.cancel()
|
||||
self._refresh_task = None
|
||||
logger.info("OpenRouter background refresh stopped")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Accessors
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_configs(self) -> list[dict]:
|
||||
return self._configs
|
||||
|
||||
def get_config_by_id(self, config_id: int) -> dict | None:
|
||||
return self._configs_by_id.get(config_id)
|
||||
|
|
@ -35,7 +35,7 @@ from app.agents.new_chat.llm_config import (
|
|||
create_chat_litellm_from_agent_config,
|
||||
create_chat_litellm_from_config,
|
||||
load_agent_config,
|
||||
load_llm_config_from_yaml,
|
||||
load_global_llm_config_by_id,
|
||||
)
|
||||
from app.agents.new_chat.memory_extraction import (
|
||||
extract_and_save_memory,
|
||||
|
|
@ -1205,8 +1205,8 @@ async def stream_new_chat(
|
|||
# Create ChatLiteLLM from AgentConfig
|
||||
llm = create_chat_litellm_from_agent_config(agent_config)
|
||||
else:
|
||||
# Negative ID: Load from YAML (global configs)
|
||||
llm_config = load_llm_config_from_yaml(llm_config_id=llm_config_id)
|
||||
# Negative ID: Load from in-memory global configs (includes dynamic OpenRouter models)
|
||||
llm_config = load_global_llm_config_by_id(llm_config_id)
|
||||
if not llm_config:
|
||||
yield streaming_service.format_error(
|
||||
f"Failed to load LLM config with id {llm_config_id}"
|
||||
|
|
@ -1214,9 +1214,8 @@ async def stream_new_chat(
|
|||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
# Create ChatLiteLLM from YAML config dict
|
||||
# Create ChatLiteLLM from global config dict
|
||||
llm = create_chat_litellm_from_config(llm_config)
|
||||
# Create AgentConfig from YAML for consistency (uses defaults for prompt settings)
|
||||
agent_config = AgentConfig.from_yaml_config(llm_config)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)",
|
||||
|
|
@ -1224,8 +1223,14 @@ async def stream_new_chat(
|
|||
llm_config_id,
|
||||
)
|
||||
|
||||
# Premium quota reservation
|
||||
if agent_config and agent_config.is_premium and user_id:
|
||||
# Premium quota reservation — applies to explicitly premium configs
|
||||
# AND Auto mode (which may route to premium models).
|
||||
_needs_premium_quota = (
|
||||
agent_config is not None
|
||||
and user_id
|
||||
and (agent_config.is_premium or agent_config.is_auto_mode)
|
||||
)
|
||||
if _needs_premium_quota:
|
||||
import uuid as _uuid
|
||||
|
||||
from app.config import config as _app_config
|
||||
|
|
@ -1246,11 +1251,16 @@ async def stream_new_chat(
|
|||
)
|
||||
_premium_reserved = reserve_amount
|
||||
if not quota_result.allowed:
|
||||
yield streaming_service.format_error(
|
||||
"Premium token quota exceeded. Please purchase more tokens to continue using premium models."
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
if agent_config.is_premium:
|
||||
yield streaming_service.format_error(
|
||||
"Premium token quota exceeded. Please purchase more tokens to continue using premium models."
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
# Auto mode: quota exhausted but we can still proceed
|
||||
# (the router may pick a free model). Reset reservation.
|
||||
_premium_request_id = None
|
||||
_premium_reserved = 0
|
||||
|
||||
if not llm:
|
||||
yield streaming_service.format_error("Failed to create LLM instance")
|
||||
|
|
@ -1658,17 +1668,27 @@ async def stream_new_chat(
|
|||
chat_id, generated_title
|
||||
)
|
||||
|
||||
# Finalize premium quota with actual tokens
|
||||
# Finalize premium quota with actual tokens.
|
||||
# For Auto mode, only count tokens from calls that used premium models.
|
||||
if _premium_request_id and user_id:
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
if agent_config and agent_config.is_auto_mode:
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
|
||||
actual_premium_tokens = LLMRouterService.compute_premium_tokens(
|
||||
accumulator.calls
|
||||
)
|
||||
else:
|
||||
actual_premium_tokens = accumulator.grand_total
|
||||
|
||||
async with shielded_async_session() as quota_session:
|
||||
await TokenQuotaService.premium_finalize(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_premium_request_id,
|
||||
actual_tokens=accumulator.grand_total,
|
||||
actual_tokens=actual_premium_tokens,
|
||||
reserved_tokens=_premium_reserved,
|
||||
)
|
||||
except Exception:
|
||||
|
|
@ -1856,7 +1876,7 @@ async def stream_resume_chat(
|
|||
return
|
||||
llm = create_chat_litellm_from_agent_config(agent_config)
|
||||
else:
|
||||
llm_config = load_llm_config_from_yaml(llm_config_id=llm_config_id)
|
||||
llm_config = load_global_llm_config_by_id(llm_config_id)
|
||||
if not llm_config:
|
||||
yield streaming_service.format_error(
|
||||
f"Failed to load LLM config with id {llm_config_id}"
|
||||
|
|
@ -1869,6 +1889,44 @@ async def stream_resume_chat(
|
|||
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
# Premium quota reservation (same logic as stream_new_chat)
|
||||
_resume_premium_reserved = 0
|
||||
_resume_premium_request_id: str | None = None
|
||||
_resume_needs_premium = (
|
||||
agent_config is not None
|
||||
and user_id
|
||||
and (agent_config.is_premium or agent_config.is_auto_mode)
|
||||
)
|
||||
if _resume_needs_premium:
|
||||
import uuid as _uuid
|
||||
|
||||
from app.config import config as _app_config
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
_resume_premium_request_id = _uuid.uuid4().hex[:16]
|
||||
reserve_amount = min(
|
||||
agent_config.quota_reserve_tokens
|
||||
or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
|
||||
_app_config.QUOTA_MAX_RESERVE_PER_CALL,
|
||||
)
|
||||
async with shielded_async_session() as quota_session:
|
||||
quota_result = await TokenQuotaService.premium_reserve(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_resume_premium_request_id,
|
||||
reserve_tokens=reserve_amount,
|
||||
)
|
||||
_resume_premium_reserved = reserve_amount
|
||||
if not quota_result.allowed:
|
||||
if agent_config.is_premium:
|
||||
yield streaming_service.format_error(
|
||||
"Premium token quota exceeded. Please purchase more tokens to continue using premium models."
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
_resume_premium_request_id = None
|
||||
_resume_premium_reserved = 0
|
||||
|
||||
if not llm:
|
||||
yield streaming_service.format_error("Failed to create LLM instance")
|
||||
yield streaming_service.format_done()
|
||||
|
|
@ -1982,6 +2040,35 @@ async def stream_resume_chat(
|
|||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
# Finalize premium quota for resume path
|
||||
if _resume_premium_request_id and user_id:
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
if agent_config and agent_config.is_auto_mode:
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
|
||||
actual_premium_tokens = LLMRouterService.compute_premium_tokens(
|
||||
accumulator.calls
|
||||
)
|
||||
else:
|
||||
actual_premium_tokens = accumulator.grand_total
|
||||
|
||||
async with shielded_async_session() as quota_session:
|
||||
await TokenQuotaService.premium_finalize(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_resume_premium_request_id,
|
||||
actual_tokens=actual_premium_tokens,
|
||||
reserved_tokens=_resume_premium_reserved,
|
||||
)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to finalize premium quota for user %s (resume)",
|
||||
user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] normal resume_chat: calls=%d total=%d summary=%s",
|
||||
|
|
@ -2018,6 +2105,23 @@ async def stream_resume_chat(
|
|||
|
||||
finally:
|
||||
with anyio.CancelScope(shield=True):
|
||||
# Release premium reservation if not finalized
|
||||
if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
async with shielded_async_session() as quota_session:
|
||||
await TokenQuotaService.premium_release(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
reserved_tokens=_resume_premium_reserved,
|
||||
)
|
||||
_resume_premium_reserved = 0
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to release premium quota for user %s (resume)", user_id
|
||||
)
|
||||
|
||||
try:
|
||||
await session.rollback()
|
||||
await clear_ai_responding(session, chat_id)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,20 @@ if TYPE_CHECKING:
|
|||
from app.db import ChatVisibility
|
||||
|
||||
|
||||
import re
|
||||
|
||||
_FENCE_RE = re.compile(
|
||||
r"^```(?:\w+)?\s*\n(.*?)```\s*$",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
def strip_markdown_fences(text: str) -> str:
|
||||
"""Remove a single markdown code fence (```json ... ```) wrapper if present."""
|
||||
m = _FENCE_RE.match(text.strip())
|
||||
return m.group(1).strip() if m else text
|
||||
|
||||
|
||||
def extract_text_content(content: str | dict | list) -> str:
|
||||
"""Extract plain text content from various message formats."""
|
||||
if isinstance(content, str):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue