mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
hotpatch(cloud): add llm load balancing
This commit is contained in:
parent
5d5f9d3bfb
commit
6fb656fd8f
21 changed files with 1324 additions and 103 deletions
|
|
@ -0,0 +1,75 @@
|
|||
"""Migrate global LLM configs to Auto mode
|
||||
|
||||
Revision ID: 84
|
||||
Revises: 83
|
||||
|
||||
This migration updates existing search spaces that use global LLM configs
|
||||
(negative IDs) to use the new Auto mode (ID 0) instead.
|
||||
|
||||
Auto mode uses LiteLLM Router to automatically load balance requests across
|
||||
all configured global LLM providers, which helps avoid rate limits.
|
||||
|
||||
Changes:
|
||||
1. Update agent_llm_id from negative values to 0 (Auto mode)
|
||||
2. Update document_summary_llm_id from negative values to 0 (Auto mode)
|
||||
3. Update NULL values to 0 (Auto mode) as the new default
|
||||
|
||||
Note: This migration preserves any custom user-created LLM configs (positive IDs).
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "84"
|
||||
down_revision: str | None = "83"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Migrate global LLM config IDs (negative) and NULL to Auto mode (0)."""
    # Both LLM-selection columns get identical treatment: anything that
    # pointed at a global config (negative ID) or was unset becomes Auto (0).
    for column in ("agent_llm_id", "document_summary_llm_id"):
        op.execute(
            f"""
        UPDATE searchspaces
        SET {column} = 0
        WHERE {column} < 0 OR {column} IS NULL
        """
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert Auto mode back to the first global config (ID -1).

    Note: This is a best-effort revert. We cannot know which specific
    global config each search space was using before, so we default
    to -1 (typically the first/primary global config).
    """
    # Both columns revert identically: Auto (0) becomes the first global
    # config (-1); every other value is left untouched.
    for column in ("agent_llm_id", "document_summary_llm_id"):
        op.execute(
            f"""
        UPDATE searchspaces
        SET {column} = -1
        WHERE {column} = 0
        """
        )
|
||||
|
|
@ -10,8 +10,8 @@ from collections.abc import Sequence
|
|||
from typing import Any
|
||||
|
||||
from deepagents import create_deep_agent
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_litellm import ChatLiteLLM
|
||||
from langgraph.types import Checkpointer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
|
@ -114,7 +114,7 @@ def _map_connectors_to_searchable_types(
|
|||
|
||||
|
||||
async def create_surfsense_deep_agent(
|
||||
llm: ChatLiteLLM,
|
||||
llm: BaseChatModel,
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
connector_service: ConnectorService,
|
||||
|
|
|
|||
|
|
@ -2,8 +2,9 @@
|
|||
LLM configuration utilities for SurfSense agents.
|
||||
|
||||
This module provides functions for loading LLM configurations from:
|
||||
1. YAML files (global configs with negative IDs)
|
||||
2. Database NewLLMConfig table (user-created configs with positive IDs)
|
||||
1. Auto mode (ID 0) - Uses LiteLLM Router for load balancing
|
||||
2. YAML files (global configs with negative IDs)
|
||||
3. Database NewLLMConfig table (user-created configs with positive IDs)
|
||||
|
||||
It also provides utilities for creating ChatLiteLLM instances and
|
||||
managing prompt configurations.
|
||||
|
|
@ -17,6 +18,13 @@ from langchain_litellm import ChatLiteLLM
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.llm_router_service import (
|
||||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
LLMRouterService,
|
||||
is_auto_mode,
|
||||
)
|
||||
|
||||
# Provider mapping for LiteLLM model string construction
|
||||
PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
|
|
@ -58,6 +66,7 @@ class AgentConfig:
|
|||
Complete configuration for the SurfSense agent.
|
||||
|
||||
This combines LLM settings with prompt configuration from NewLLMConfig.
|
||||
Supports Auto mode (ID 0) which uses LiteLLM Router for load balancing.
|
||||
"""
|
||||
|
||||
# LLM Model Settings
|
||||
|
|
@ -77,6 +86,32 @@ class AgentConfig:
|
|||
config_id: int | None = None
|
||||
config_name: str | None = None
|
||||
|
||||
# Auto mode flag
|
||||
is_auto_mode: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_auto_mode(cls) -> "AgentConfig":
|
||||
"""
|
||||
Create an AgentConfig for Auto mode (LiteLLM Router load balancing).
|
||||
|
||||
Returns:
|
||||
AgentConfig instance configured for Auto mode
|
||||
"""
|
||||
return cls(
|
||||
provider="AUTO",
|
||||
model_name="auto",
|
||||
api_key="", # Not needed for router
|
||||
api_base=None,
|
||||
custom_provider=None,
|
||||
litellm_params=None,
|
||||
system_instructions=None,
|
||||
use_default_system_instructions=True,
|
||||
citations_enabled=True,
|
||||
config_id=AUTO_MODE_ID,
|
||||
config_name="Auto (Load Balanced)",
|
||||
is_auto_mode=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_new_llm_config(cls, config) -> "AgentConfig":
|
||||
"""
|
||||
|
|
@ -102,6 +137,7 @@ class AgentConfig:
|
|||
citations_enabled=config.citations_enabled,
|
||||
config_id=config.id,
|
||||
config_name=config.name,
|
||||
is_auto_mode=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -138,6 +174,7 @@ class AgentConfig:
|
|||
citations_enabled=yaml_config.get("citations_enabled", True),
|
||||
config_id=yaml_config.get("id"),
|
||||
config_name=yaml_config.get("name"),
|
||||
is_auto_mode=False,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -261,20 +298,28 @@ async def load_agent_config(
|
|||
search_space_id: int | None = None,
|
||||
) -> "AgentConfig | None":
|
||||
"""
|
||||
Load an agent configuration, supporting both YAML (negative IDs) and database (positive IDs) configs.
|
||||
Load an agent configuration, supporting Auto mode, YAML, and database configs.
|
||||
|
||||
This is the main entry point for loading configurations:
|
||||
- ID 0: Auto mode (uses LiteLLM Router for load balancing)
|
||||
- Negative IDs: Load from YAML file (global configs)
|
||||
- Positive IDs: Load from NewLLMConfig database table
|
||||
|
||||
Args:
|
||||
session: AsyncSession for database access
|
||||
config_id: The config ID (negative for YAML, positive for database)
|
||||
config_id: The config ID (0 for Auto, negative for YAML, positive for database)
|
||||
search_space_id: Optional search space ID for context
|
||||
|
||||
Returns:
|
||||
AgentConfig instance or None if not found
|
||||
"""
|
||||
# Auto mode (ID 0) - use LiteLLM Router
|
||||
if is_auto_mode(config_id):
|
||||
if not LLMRouterService.is_initialized():
|
||||
print("Error: Auto mode requested but LLM Router not initialized")
|
||||
return None
|
||||
return AgentConfig.from_auto_mode()
|
||||
|
||||
if config_id < 0:
|
||||
# Load from YAML (global configs have negative IDs)
|
||||
yaml_config = load_llm_config_from_yaml(config_id)
|
||||
|
|
@ -324,16 +369,30 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
|||
|
||||
def create_chat_litellm_from_agent_config(
|
||||
agent_config: AgentConfig,
|
||||
) -> ChatLiteLLM | None:
|
||||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""
|
||||
Create a ChatLiteLLM instance from an AgentConfig.
|
||||
Create a ChatLiteLLM or ChatLiteLLMRouter instance from an AgentConfig.
|
||||
|
||||
For Auto mode configs, returns a ChatLiteLLMRouter that uses LiteLLM Router
|
||||
for automatic load balancing across available providers.
|
||||
|
||||
Args:
|
||||
agent_config: AgentConfig instance
|
||||
|
||||
Returns:
|
||||
ChatLiteLLM instance or None on error
|
||||
ChatLiteLLM or ChatLiteLLMRouter instance, or None on error
|
||||
"""
|
||||
# Handle Auto mode - return ChatLiteLLMRouter
|
||||
if agent_config.is_auto_mode:
|
||||
if not LLMRouterService.is_initialized():
|
||||
print("Error: Auto mode requested but LLM Router not initialized")
|
||||
return None
|
||||
try:
|
||||
return ChatLiteLLMRouter()
|
||||
except Exception as e:
|
||||
print(f"Error creating ChatLiteLLMRouter: {e}")
|
||||
return None
|
||||
|
||||
# Build the model string
|
||||
if agent_config.custom_provider:
|
||||
model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from app.agents.new_chat.checkpointer import (
|
|||
close_checkpointer,
|
||||
setup_checkpointer_tables,
|
||||
)
|
||||
from app.config import config
|
||||
from app.config import config, initialize_llm_router
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.routes import router as crud_router
|
||||
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||
|
|
@ -23,6 +23,8 @@ async def lifespan(app: FastAPI):
|
|||
await create_db_and_tables()
|
||||
# Setup LangGraph checkpointer tables for conversation persistence
|
||||
await setup_checkpointer_tables()
|
||||
# Initialize LLM Router for Auto mode load balancing
|
||||
initialize_llm_router()
|
||||
# Seed Surfsense documentation
|
||||
await seed_surfsense_docs()
|
||||
yield
|
||||
|
|
|
|||
|
|
@ -4,11 +4,25 @@ import os
|
|||
|
||||
from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
from celery.signals import worker_process_init
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@worker_process_init.connect
def init_worker(**kwargs):
    """Initialize the LLM Router when a Celery worker process starts.

    This ensures the Auto mode (LiteLLM Router) is available for background tasks
    like document summarization.
    """
    # Imported lazily so the worker module can load without pulling in the
    # full app.config (and its YAML reads) at import time.
    from app.config import initialize_llm_router

    initialize_llm_router()
|
||||
|
||||
|
||||
# Get Celery configuration from environment
|
||||
CELERY_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
|
||||
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", "redis://localhost:6379/0")
|
||||
|
|
|
|||
|
|
@ -48,6 +48,63 @@ def load_global_llm_configs():
|
|||
return []
|
||||
|
||||
|
||||
def load_router_settings():
    """
    Load router settings for Auto mode from the global LLM config YAML file.

    Falls back to sane defaults when the file is missing, empty, unreadable,
    or when ``router_settings`` is absent or malformed. Any explicit settings
    found in the YAML are merged on top of the defaults.

    Returns:
        dict: Router settings dictionary
    """
    # Defaults tuned for rate-limit avoidance; YAML values override these.
    default_settings = {
        "routing_strategy": "usage-based-routing",
        "num_retries": 3,
        "allowed_fails": 3,
        "cooldown_time": 60,
    }

    # Try main config file first
    global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"

    if not global_config_file.exists():
        return default_settings

    try:
        with open(global_config_file, encoding="utf-8") as f:
            # An empty YAML file parses to None; normalize so .get() is safe
            # (previously this raised AttributeError and fell into the
            # broad except with a misleading warning).
            data = yaml.safe_load(f) or {}
        settings = data.get("router_settings") or {}
        if not isinstance(settings, dict):
            # A scalar/list here would corrupt the merged dict downstream.
            print("Warning: router_settings in YAML is not a mapping; using defaults")
            return default_settings
        # Merge with defaults (explicit YAML settings win).
        return {**default_settings, **settings}
    except Exception as e:
        print(f"Warning: Failed to load router settings: {e}")
        return default_settings
|
||||
|
||||
|
||||
def initialize_llm_router():
    """
    Initialize the LLM Router service for Auto mode.

    Intended to run once during application startup. A missing or empty set
    of global configs simply disables Auto mode instead of failing startup.
    """
    configs = load_global_llm_configs()
    settings = load_router_settings()

    if not configs:
        print("Info: No global LLM configs found, Auto mode will not be available")
        return

    try:
        # Local import — presumably keeps app.config importable without the
        # router service at module load time; verify against import graph.
        from app.services.llm_router_service import LLMRouterService

        LLMRouterService.initialize(configs, settings)
        strategy = settings.get("routing_strategy", "usage-based-routing")
        print(
            f"Info: LLM Router initialized with {len(configs)} models "
            f"(strategy: {strategy})"
        )
    except Exception as e:
        print(f"Warning: Failed to initialize LLM Router: {e}")
|
||||
|
||||
|
||||
class Config:
|
||||
# Check if ffmpeg is installed
|
||||
if not is_ffmpeg_installed():
|
||||
|
|
@ -156,6 +213,9 @@ class Config:
|
|||
# These can be used as default options for users
|
||||
GLOBAL_LLM_CONFIGS = load_global_llm_configs()
|
||||
|
||||
# Router settings for Auto mode (LiteLLM Router load balancing)
|
||||
ROUTER_SETTINGS = load_router_settings()
|
||||
|
||||
# Chonkie Configuration | Edit this to your needs
|
||||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
|
||||
# Azure OpenAI credentials from environment variables
|
||||
|
|
|
|||
|
|
@ -10,10 +10,39 @@
|
|||
# These configurations will be available to all users as a convenient option
|
||||
# Users can choose to use these global configs or add their own
|
||||
#
|
||||
# AUTO MODE (Recommended):
|
||||
# - Auto mode (ID: 0) uses LiteLLM Router to automatically load balance across all global configs
|
||||
# - This helps avoid rate limits by distributing requests across multiple providers
|
||||
# - New users are automatically assigned Auto mode by default
|
||||
# - Configure router_settings below to customize the load balancing behavior
|
||||
#
|
||||
# Structure matches NewLLMConfig:
|
||||
# - LLM model configuration (provider, model_name, api_key, etc.)
|
||||
# - Prompt configuration (system_instructions, citations_enabled)
|
||||
|
||||
# Router Settings for Auto Mode
|
||||
# These settings control how the LiteLLM Router distributes requests across models
|
||||
router_settings:
|
||||
# Routing strategy options:
|
||||
# - "usage-based-routing": Routes to deployment with lowest current usage (recommended for rate limits)
|
||||
# - "simple-shuffle": Random distribution with optional RPM/TPM weighting
|
||||
# - "least-busy": Routes to least busy deployment
|
||||
# - "latency-based-routing": Routes based on response latency
|
||||
routing_strategy: "usage-based-routing"
|
||||
|
||||
# Number of retries before failing
|
||||
num_retries: 3
|
||||
|
||||
# Number of failures allowed before cooling down a deployment
|
||||
allowed_fails: 3
|
||||
|
||||
# Cooldown time in seconds after allowed_fails is exceeded
|
||||
cooldown_time: 60
|
||||
|
||||
# Fallback models (optional) - when primary fails, try these
|
||||
# Format: [{"primary_model": ["fallback1", "fallback2"]}]
|
||||
# fallbacks: []
|
||||
|
||||
global_llm_configs:
|
||||
# Example: OpenAI GPT-4 Turbo with citations enabled
|
||||
- id: -1
|
||||
|
|
@ -23,6 +52,9 @@ global_llm_configs:
|
|||
model_name: "gpt-4-turbo-preview"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
# Rate limits for load balancing (requests/tokens per minute)
|
||||
rpm: 500 # Requests per minute
|
||||
tpm: 100000 # Tokens per minute
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
|
|
@ -39,6 +71,8 @@ global_llm_configs:
|
|||
model_name: "claude-3-opus-20240229"
|
||||
api_key: "sk-ant-your-anthropic-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 1000
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
|
|
@ -54,6 +88,8 @@ global_llm_configs:
|
|||
model_name: "gpt-3.5-turbo"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 3500 # GPT-3.5 has higher rate limits
|
||||
tpm: 200000
|
||||
litellm_params:
|
||||
temperature: 0.5
|
||||
max_tokens: 2000
|
||||
|
|
@ -69,6 +105,8 @@ global_llm_configs:
|
|||
model_name: "deepseek-chat"
|
||||
api_key: "your-deepseek-api-key-here"
|
||||
api_base: "https://api.deepseek.com/v1"
|
||||
rpm: 60
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
|
|
@ -92,6 +130,8 @@ global_llm_configs:
|
|||
model_name: "llama3-70b-8192"
|
||||
api_key: "your-groq-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 30 # Groq has lower rate limits on free tier
|
||||
tpm: 14400
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 8000
|
||||
|
|
@ -100,6 +140,7 @@ global_llm_configs:
|
|||
citations_enabled: true
|
||||
|
||||
# Notes:
|
||||
# - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing
|
||||
# - Use negative IDs to distinguish global configs from user configs (NewLLMConfig in DB)
|
||||
# - IDs should be unique and sequential (e.g., -1, -2, -3, etc.)
|
||||
# - The 'api_key' field will not be exposed to users via API
|
||||
|
|
@ -107,3 +148,5 @@ global_llm_configs:
|
|||
# - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty
|
||||
# - citations_enabled: true = include citation instructions, false = include anti-citation instructions
|
||||
# - All standard LiteLLM providers are supported
|
||||
# - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute)
|
||||
# These help the router distribute load evenly and avoid rate limit errors
|
||||
|
|
|
|||
|
|
@ -807,11 +807,16 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
) # User's custom instructions
|
||||
|
||||
# Search space-level LLM preferences (shared by all members)
|
||||
# Note: These can be negative IDs for global configs (from YAML) or positive IDs for custom configs (from DB)
|
||||
agent_llm_id = Column(Integer, nullable=True) # For agent/chat operations
|
||||
# Note: ID values:
|
||||
# - 0: Auto mode (uses LiteLLM Router for load balancing) - default for new search spaces
|
||||
# - Negative IDs: Global configs from YAML
|
||||
# - Positive IDs: Custom configs from DB (NewLLMConfig table)
|
||||
agent_llm_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
) # For agent/chat operations, defaults to Auto mode
|
||||
document_summary_llm_id = Column(
|
||||
Integer, nullable=True
|
||||
) # For document summarization
|
||||
Integer, nullable=True, default=0
|
||||
) # For document summarization, defaults to Auto mode
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
|
|
|
|||
|
|
@ -50,13 +50,33 @@ async def get_global_new_llm_configs(
|
|||
These are pre-configured by the system administrator and available to all users.
|
||||
API keys are not exposed through this endpoint.
|
||||
|
||||
Global configs have negative IDs to distinguish from user-created configs.
|
||||
Includes:
|
||||
- Auto mode (ID 0): Uses LiteLLM Router for automatic load balancing
|
||||
- Global configs (negative IDs): Individual pre-configured LLM providers
|
||||
"""
|
||||
try:
|
||||
global_configs = config.GLOBAL_LLM_CONFIGS
|
||||
|
||||
# Transform to new structure, hiding API keys
|
||||
safe_configs = []
|
||||
# Start with Auto mode as the first option (recommended default)
|
||||
safe_configs = [
|
||||
{
|
||||
"id": 0,
|
||||
"name": "Auto (Load Balanced)",
|
||||
"description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling. Recommended for most users.",
|
||||
"provider": "AUTO",
|
||||
"custom_provider": None,
|
||||
"model_name": "auto",
|
||||
"api_base": None,
|
||||
"litellm_params": {},
|
||||
"system_instructions": "",
|
||||
"use_default_system_instructions": True,
|
||||
"citations_enabled": True,
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
]
|
||||
|
||||
# Add individual global configs
|
||||
for cfg in global_configs:
|
||||
safe_config = {
|
||||
"id": cfg.get("id"),
|
||||
|
|
|
|||
|
|
@ -314,11 +314,29 @@ async def _get_llm_config_by_id(
|
|||
) -> dict | None:
|
||||
"""
|
||||
Get an LLM config by ID as a dictionary. Returns database config for positive IDs,
|
||||
global config for negative IDs, or None if ID is None.
|
||||
global config for negative IDs, Auto mode config for ID 0, or None if ID is None.
|
||||
"""
|
||||
if config_id is None:
|
||||
return None
|
||||
|
||||
# Auto mode (ID 0) - uses LiteLLM Router for load balancing
|
||||
if config_id == 0:
|
||||
return {
|
||||
"id": 0,
|
||||
"name": "Auto (Load Balanced)",
|
||||
"description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling",
|
||||
"provider": "AUTO",
|
||||
"custom_provider": None,
|
||||
"model_name": "auto",
|
||||
"api_base": None,
|
||||
"litellm_params": {},
|
||||
"system_instructions": "",
|
||||
"use_default_system_instructions": True,
|
||||
"citations_enabled": True,
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
# Global config - find from YAML
|
||||
global_configs = config.GLOBAL_LLM_CONFIGS
|
||||
|
|
|
|||
|
|
@ -135,14 +135,19 @@ class GlobalNewLLMConfigRead(BaseModel):
|
|||
Schema for reading global LLM configs from YAML.
|
||||
Global configs have negative IDs and no search_space_id.
|
||||
API key is hidden for security.
|
||||
|
||||
ID 0 is reserved for Auto mode which uses LiteLLM Router for load balancing.
|
||||
"""
|
||||
|
||||
id: int = Field(..., description="Negative ID for global configs")
|
||||
id: int = Field(
|
||||
...,
|
||||
description="Config ID: 0 for Auto mode, negative for global configs",
|
||||
)
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
||||
# LLM Model Configuration (no api_key)
|
||||
provider: str # String because YAML doesn't enforce enum
|
||||
provider: str # String because YAML doesn't enforce enum, "AUTO" for Auto mode
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
|
|
@ -154,6 +159,7 @@ class GlobalNewLLMConfigRead(BaseModel):
|
|||
citations_enabled: bool = True
|
||||
|
||||
is_global: bool = True # Always true for global configs
|
||||
is_auto_mode: bool = False # True only for Auto mode (ID 0)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
|
|||
632
surfsense_backend/app/services/llm_router_service.py
Normal file
632
surfsense_backend/app/services/llm_router_service.py
Normal file
|
|
@ -0,0 +1,632 @@
|
|||
"""
|
||||
LiteLLM Router Service for Load Balancing
|
||||
|
||||
This module provides a singleton LiteLLM Router for automatic load balancing
|
||||
across multiple LLM deployments. It handles:
|
||||
- Rate limit management with automatic cooldowns
|
||||
- Automatic failover and retries
|
||||
- Usage-based routing to distribute load evenly
|
||||
|
||||
The router is initialized from global LLM configs and provides both
|
||||
synchronous ChatLiteLLM-like interface and async methods.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from litellm import Router
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Special ID for Auto mode - uses router for load balancing
|
||||
AUTO_MODE_ID = 0
|
||||
|
||||
# Provider mapping for LiteLLM model string construction
|
||||
PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"COMETAPI": "cometapi",
|
||||
"XAI": "xai",
|
||||
"BEDROCK": "bedrock",
|
||||
"AWS_BEDROCK": "bedrock", # Legacy support
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
"DEEPINFRA": "deepinfra",
|
||||
"CEREBRAS": "cerebras",
|
||||
"SAMBANOVA": "sambanova",
|
||||
"AI21": "ai21",
|
||||
"CLOUDFLARE": "cloudflare",
|
||||
"DATABRICKS": "databricks",
|
||||
"DEEPSEEK": "openai",
|
||||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
"HUGGINGFACE": "huggingface",
|
||||
"CUSTOM": "custom",
|
||||
}
|
||||
|
||||
|
||||
class LLMRouterService:
    """
    Singleton service for managing LiteLLM Router.

    The router provides automatic load balancing, failover, and rate limit
    handling across multiple LLM deployments.
    """

    # Singleton state lives on the class; all public access is via classmethods.
    _instance = None
    _router: Router | None = None  # live litellm Router; None until initialize() succeeds
    _model_list: list[dict] = []  # deployment dicts fed to the Router
    _router_settings: dict = {}  # raw settings passed to initialize()
    _initialized: bool = False  # True only after a successful Router build

    def __new__(cls):
        # Classic singleton: every construction returns the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def get_instance(cls) -> "LLMRouterService":
        """Get the singleton instance of the router service."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def initialize(
        cls,
        global_configs: list[dict],
        router_settings: dict | None = None,
    ) -> None:
        """
        Initialize the router with global LLM configurations.

        Safe to call repeatedly: once a router has been built, subsequent
        calls are no-ops. On failure the service stays uninitialized
        (``_router`` is reset to None).

        Args:
            global_configs: List of global LLM config dictionaries from YAML
            router_settings: Optional router settings (routing_strategy, num_retries, etc.)
        """
        instance = cls.get_instance()

        if instance._initialized:
            logger.debug("LLM Router already initialized, skipping")
            return

        # Build model list from global configs; invalid entries are skipped.
        model_list = []
        for config in global_configs:
            deployment = cls._config_to_deployment(config)
            if deployment:
                model_list.append(deployment)

        if not model_list:
            logger.warning("No valid LLM configs found for router initialization")
            return

        instance._model_list = model_list
        instance._router_settings = router_settings or {}

        # Default router settings optimized for rate limit handling
        default_settings = {
            "routing_strategy": "usage-based-routing",  # Best for rate limit management
            "num_retries": 3,
            "allowed_fails": 3,
            "cooldown_time": 60,  # Cooldown for 60 seconds after failures
            "retry_after": 5,  # Wait 5 seconds between retries
        }

        # Merge with provided settings (caller-supplied values win).
        final_settings = {**default_settings, **instance._router_settings}

        try:
            instance._router = Router(
                model_list=model_list,
                routing_strategy=final_settings.get(
                    "routing_strategy", "usage-based-routing"
                ),
                num_retries=final_settings.get("num_retries", 3),
                allowed_fails=final_settings.get("allowed_fails", 3),
                cooldown_time=final_settings.get("cooldown_time", 60),
                set_verbose=False,  # Disable verbose logging in production
            )
            instance._initialized = True
            logger.info(
                f"LLM Router initialized with {len(model_list)} deployments, "
                f"strategy: {final_settings.get('routing_strategy')}"
            )
        except Exception as e:
            logger.error(f"Failed to initialize LLM Router: {e}")
            instance._router = None

    @classmethod
    def _config_to_deployment(cls, config: dict) -> dict | None:
        """
        Convert a global LLM config to a router deployment entry.

        Args:
            config: Global LLM config dictionary

        Returns:
            Router deployment dictionary or None if invalid
        """
        try:
            # Skip if essential fields are missing
            if not config.get("model_name") or not config.get("api_key"):
                return None

            # Build the "<provider>/<model>" string LiteLLM expects.
            if config.get("custom_provider"):
                model_string = f"{config['custom_provider']}/{config['model_name']}"
            else:
                provider = config.get("provider", "").upper()
                # Unknown providers fall back to their lowercased name.
                provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
                model_string = f"{provider_prefix}/{config['model_name']}"

            # Build litellm params
            litellm_params = {
                "model": model_string,
                "api_key": config.get("api_key"),
            }

            # Add optional api_base
            if config.get("api_base"):
                litellm_params["api_base"] = config["api_base"]

            # Add any additional litellm parameters
            if config.get("litellm_params"):
                litellm_params.update(config["litellm_params"])

            # Every deployment shares the "auto" alias, so the router treats
            # all configured providers as one logical model to balance across.
            deployment = {
                "model_name": "auto",  # All configs use same alias for unified routing
                "litellm_params": litellm_params,
            }

            # Add rate limits from config if available
            if config.get("rpm"):
                deployment["rpm"] = config["rpm"]
            if config.get("tpm"):
                deployment["tpm"] = config["tpm"]

            return deployment

        except Exception as e:
            logger.warning(f"Failed to convert config to deployment: {e}")
            return None

    @classmethod
    def get_router(cls) -> Router | None:
        """Get the initialized router instance."""
        instance = cls.get_instance()
        return instance._router

    @classmethod
    def is_initialized(cls) -> bool:
        """Check if the router has been initialized."""
        instance = cls.get_instance()
        return instance._initialized and instance._router is not None

    @classmethod
    def get_model_count(cls) -> int:
        """Get the number of models in the router."""
        instance = cls.get_instance()
        return len(instance._model_list)
|
||||
|
||||
|
||||
class ChatLiteLLMRouter(BaseChatModel):
|
||||
"""
|
||||
A LangChain-compatible chat model that uses LiteLLM Router for load balancing.
|
||||
|
||||
This wraps the LiteLLM Router to provide the same interface as ChatLiteLLM,
|
||||
making it a drop-in replacement for auto-mode routing.
|
||||
"""
|
||||
|
||||
# Use model_config for Pydantic v2 compatibility
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
# Public attributes that Pydantic will manage
|
||||
model: str = "auto"
|
||||
streaming: bool = True
|
||||
|
||||
# Bound tools and tool choice for tool calling
|
||||
_bound_tools: list[dict] | None = None
|
||||
_tool_choice: str | dict | None = None
|
||||
_router: Router | None = None
|
||||
|
||||
    def __init__(
        self,
        router: Router | None = None,
        bound_tools: list[dict] | None = None,
        tool_choice: str | dict | None = None,
        **kwargs,
    ):
        """
        Initialize the ChatLiteLLMRouter.

        Raises:
            ValueError: If no router was supplied and the global singleton
                has not been initialized.

        Args:
            router: LiteLLM Router instance. If None, uses the global singleton.
            bound_tools: Pre-bound tools for tool calling
            tool_choice: Tool choice configuration
        """
        try:
            super().__init__(**kwargs)
            # Store router and tools as private attributes.
            # object.__setattr__ sidesteps Pydantic's managed attribute
            # machinery for these underscore-prefixed (non-field) names.
            resolved_router = router or LLMRouterService.get_router()
            object.__setattr__(self, "_router", resolved_router)
            object.__setattr__(self, "_bound_tools", bound_tools)
            object.__setattr__(self, "_tool_choice", tool_choice)
            if not self._router:
                raise ValueError(
                    "LLM Router not initialized. Call LLMRouterService.initialize() first."
                )
            logger.info(
                f"ChatLiteLLMRouter initialized with {LLMRouterService.get_model_count()} models"
            )
        except Exception as e:
            # Log for visibility, then propagate — construction must not
            # silently yield a half-initialized model.
            logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
            raise
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "litellm-router"
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> dict[str, Any]:
|
||||
return {
|
||||
"model": self.model,
|
||||
"model_count": LLMRouterService.get_model_count(),
|
||||
}
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: list[Any],
|
||||
*,
|
||||
tool_choice: str | dict | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "ChatLiteLLMRouter":
|
||||
"""
|
||||
Bind tools to the model for function/tool calling.
|
||||
|
||||
Args:
|
||||
tools: List of tools to bind (can be LangChain tools, Pydantic models, or dicts)
|
||||
tool_choice: Optional tool choice strategy ("auto", "required", "none", or specific tool)
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Returns:
|
||||
New ChatLiteLLMRouter instance with tools bound
|
||||
"""
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
# Convert tools to OpenAI format
|
||||
formatted_tools = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, dict):
|
||||
# Already in dict format
|
||||
formatted_tools.append(tool)
|
||||
else:
|
||||
# Convert using LangChain utility
|
||||
try:
|
||||
formatted_tools.append(convert_to_openai_tool(tool))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to convert tool {tool}: {e}")
|
||||
continue
|
||||
|
||||
# Create a new instance with tools bound
|
||||
return ChatLiteLLMRouter(
|
||||
router=self._router,
|
||||
bound_tools=formatted_tools if formatted_tools else None,
|
||||
tool_choice=tool_choice,
|
||||
model=self.model,
|
||||
streaming=self.streaming,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""
|
||||
Generate a response using the router (synchronous).
|
||||
"""
|
||||
if not self._router:
|
||||
raise ValueError("Router not initialized")
|
||||
|
||||
# Convert LangChain messages to OpenAI format
|
||||
formatted_messages = self._convert_messages(messages)
|
||||
|
||||
# Add tools if bound
|
||||
call_kwargs = {**kwargs}
|
||||
if self._bound_tools:
|
||||
call_kwargs["tools"] = self._bound_tools
|
||||
if self._tool_choice is not None:
|
||||
call_kwargs["tool_choice"] = self._tool_choice
|
||||
|
||||
# Call router completion
|
||||
response = self._router.completion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
**call_kwargs,
|
||||
)
|
||||
|
||||
# Convert response to ChatResult with potential tool calls
|
||||
message = self._convert_response_to_message(response.choices[0].message)
|
||||
generation = ChatGeneration(message=message)
|
||||
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""
|
||||
Generate a response using the router (asynchronous).
|
||||
"""
|
||||
if not self._router:
|
||||
raise ValueError("Router not initialized")
|
||||
|
||||
# Convert LangChain messages to OpenAI format
|
||||
formatted_messages = self._convert_messages(messages)
|
||||
|
||||
# Add tools if bound
|
||||
call_kwargs = {**kwargs}
|
||||
if self._bound_tools:
|
||||
call_kwargs["tools"] = self._bound_tools
|
||||
if self._tool_choice is not None:
|
||||
call_kwargs["tool_choice"] = self._tool_choice
|
||||
|
||||
# Call router async completion
|
||||
response = await self._router.acompletion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
**call_kwargs,
|
||||
)
|
||||
|
||||
# Convert response to ChatResult with potential tool calls
|
||||
message = self._convert_response_to_message(response.choices[0].message)
|
||||
generation = ChatGeneration(message=message)
|
||||
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
Stream a response using the router (synchronous).
|
||||
"""
|
||||
if not self._router:
|
||||
raise ValueError("Router not initialized")
|
||||
|
||||
formatted_messages = self._convert_messages(messages)
|
||||
|
||||
# Add tools if bound
|
||||
call_kwargs = {**kwargs}
|
||||
if self._bound_tools:
|
||||
call_kwargs["tools"] = self._bound_tools
|
||||
if self._tool_choice is not None:
|
||||
call_kwargs["tool_choice"] = self._tool_choice
|
||||
|
||||
# Call router completion with streaming
|
||||
response = self._router.completion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
**call_kwargs,
|
||||
)
|
||||
|
||||
# Yield chunks
|
||||
for chunk in response:
|
||||
if hasattr(chunk, "choices") and chunk.choices:
|
||||
delta = chunk.choices[0].delta
|
||||
chunk_msg = self._convert_delta_to_chunk(delta)
|
||||
if chunk_msg:
|
||||
yield ChatGenerationChunk(message=chunk_msg)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
Stream a response using the router (asynchronous).
|
||||
"""
|
||||
if not self._router:
|
||||
raise ValueError("Router not initialized")
|
||||
|
||||
formatted_messages = self._convert_messages(messages)
|
||||
|
||||
# Add tools if bound
|
||||
call_kwargs = {**kwargs}
|
||||
if self._bound_tools:
|
||||
call_kwargs["tools"] = self._bound_tools
|
||||
if self._tool_choice is not None:
|
||||
call_kwargs["tool_choice"] = self._tool_choice
|
||||
|
||||
# Call router async completion with streaming
|
||||
response = await self._router.acompletion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
**call_kwargs,
|
||||
)
|
||||
|
||||
# Yield chunks asynchronously
|
||||
async for chunk in response:
|
||||
if hasattr(chunk, "choices") and chunk.choices:
|
||||
delta = chunk.choices[0].delta
|
||||
chunk_msg = self._convert_delta_to_chunk(delta)
|
||||
if chunk_msg:
|
||||
yield ChatGenerationChunk(message=chunk_msg)
|
||||
|
||||
def _convert_messages(self, messages: list[BaseMessage]) -> list[dict]:
|
||||
"""Convert LangChain messages to OpenAI format."""
|
||||
from langchain_core.messages import (
|
||||
AIMessage as AIMsg,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
|
||||
result = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, SystemMessage):
|
||||
result.append({"role": "system", "content": msg.content})
|
||||
elif isinstance(msg, HumanMessage):
|
||||
result.append({"role": "user", "content": msg.content})
|
||||
elif isinstance(msg, AIMsg):
|
||||
ai_msg: dict[str, Any] = {"role": "assistant"}
|
||||
if msg.content:
|
||||
ai_msg["content"] = msg.content
|
||||
# Handle tool calls
|
||||
if hasattr(msg, "tool_calls") and msg.tool_calls:
|
||||
ai_msg["tool_calls"] = [
|
||||
{
|
||||
"id": tc.get("id", ""),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.get("name", ""),
|
||||
"arguments": tc.get("args", "{}")
|
||||
if isinstance(tc.get("args"), str)
|
||||
else __import__("json").dumps(tc.get("args", {})),
|
||||
},
|
||||
}
|
||||
for tc in msg.tool_calls
|
||||
]
|
||||
result.append(ai_msg)
|
||||
elif isinstance(msg, ToolMessage):
|
||||
result.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
"content": msg.content
|
||||
if isinstance(msg.content, str)
|
||||
else __import__("json").dumps(msg.content),
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Fallback for other message types
|
||||
role = getattr(msg, "type", "user")
|
||||
if role == "human":
|
||||
role = "user"
|
||||
elif role == "ai":
|
||||
role = "assistant"
|
||||
result.append({"role": role, "content": msg.content})
|
||||
|
||||
return result
|
||||
|
||||
def _convert_response_to_message(self, response_message: Any) -> AIMessage:
|
||||
"""Convert a LiteLLM response message to a LangChain AIMessage."""
|
||||
import json
|
||||
|
||||
content = getattr(response_message, "content", None) or ""
|
||||
|
||||
# Check for tool calls
|
||||
tool_calls = []
|
||||
if hasattr(response_message, "tool_calls") and response_message.tool_calls:
|
||||
for tc in response_message.tool_calls:
|
||||
tool_call = {
|
||||
"id": tc.id if hasattr(tc, "id") else "",
|
||||
"name": tc.function.name if hasattr(tc, "function") else "",
|
||||
"args": {},
|
||||
}
|
||||
# Parse arguments
|
||||
if hasattr(tc, "function") and hasattr(tc.function, "arguments"):
|
||||
try:
|
||||
tool_call["args"] = json.loads(tc.function.arguments)
|
||||
except json.JSONDecodeError:
|
||||
tool_call["args"] = tc.function.arguments
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
if tool_calls:
|
||||
return AIMessage(content=content, tool_calls=tool_calls)
|
||||
return AIMessage(content=content)
|
||||
|
||||
def _convert_delta_to_chunk(self, delta: Any) -> AIMessageChunk | None:
|
||||
"""Convert a streaming delta to an AIMessageChunk."""
|
||||
|
||||
content = getattr(delta, "content", None) or ""
|
||||
|
||||
# Check for tool calls in delta
|
||||
tool_call_chunks = []
|
||||
if hasattr(delta, "tool_calls") and delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
chunk = {
|
||||
"index": tc.index if hasattr(tc, "index") else 0,
|
||||
"id": tc.id if hasattr(tc, "id") else None,
|
||||
"name": tc.function.name
|
||||
if hasattr(tc, "function") and hasattr(tc.function, "name")
|
||||
else None,
|
||||
"args": tc.function.arguments
|
||||
if hasattr(tc, "function") and hasattr(tc.function, "arguments")
|
||||
else "",
|
||||
}
|
||||
tool_call_chunks.append(chunk)
|
||||
|
||||
if content or tool_call_chunks:
|
||||
if tool_call_chunks:
|
||||
return AIMessageChunk(
|
||||
content=content, tool_call_chunks=tool_call_chunks
|
||||
)
|
||||
return AIMessageChunk(content=content)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_auto_mode_llm() -> ChatLiteLLMRouter | None:
    """
    Build a ChatLiteLLMRouter instance for auto mode.

    Returns:
        A ChatLiteLLMRouter, or None when the global router has not been
        initialized or construction fails.
    """
    if LLMRouterService.is_initialized():
        try:
            return ChatLiteLLMRouter()
        except Exception as e:
            logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
            return None

    logger.warning("LLM Router not initialized for auto mode")
    return None
|
||||
|
||||
|
||||
def is_auto_mode(llm_config_id: int | None) -> bool:
    """
    Determine whether an LLM config ID refers to Auto mode.

    Args:
        llm_config_id: The LLM config ID to test (may be None).

    Returns:
        True when the ID equals the reserved Auto mode ID, False otherwise.
    """
    return AUTO_MODE_ID == llm_config_id
|
||||
|
|
@ -8,6 +8,12 @@ from sqlalchemy.future import select
|
|||
|
||||
from app.config import config
|
||||
from app.db import NewLLMConfig, SearchSpace
|
||||
from app.services.llm_router_service import (
|
||||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
LLMRouterService,
|
||||
is_auto_mode,
|
||||
)
|
||||
|
||||
# Configure litellm to automatically drop unsupported parameters
|
||||
litellm.drop_params = True
|
||||
|
|
@ -23,15 +29,26 @@ class LLMRole:
|
|||
def get_global_llm_config(llm_config_id: int) -> dict | None:
|
||||
"""
|
||||
Get a global LLM configuration by ID.
|
||||
Global configs have negative IDs.
|
||||
Global configs have negative IDs. ID 0 is reserved for Auto mode.
|
||||
|
||||
Args:
|
||||
llm_config_id: The ID of the global config (should be negative)
|
||||
llm_config_id: The ID of the global config (should be negative or 0 for Auto)
|
||||
|
||||
Returns:
|
||||
dict: Global config dictionary or None if not found
|
||||
"""
|
||||
if llm_config_id >= 0:
|
||||
# Auto mode (ID 0) is handled separately via the router
|
||||
if llm_config_id == AUTO_MODE_ID:
|
||||
return {
|
||||
"id": AUTO_MODE_ID,
|
||||
"name": "Auto (Load Balanced)",
|
||||
"description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling",
|
||||
"provider": "AUTO",
|
||||
"model_name": "auto",
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
|
||||
if llm_config_id > 0:
|
||||
return None
|
||||
|
||||
for cfg in config.GLOBAL_LLM_CONFIGS:
|
||||
|
|
@ -145,19 +162,22 @@ async def validate_llm_config(
|
|||
|
||||
async def get_search_space_llm_instance(
|
||||
session: AsyncSession, search_space_id: int, role: str
|
||||
) -> ChatLiteLLM | None:
|
||||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""
|
||||
Get a ChatLiteLLM instance for a specific search space and role.
|
||||
|
||||
LLM preferences are stored at the search space level and shared by all members.
|
||||
|
||||
If Auto mode (ID 0) is configured, returns a ChatLiteLLMRouter that uses
|
||||
LiteLLM Router for automatic load balancing across available providers.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
search_space_id: Search Space ID
|
||||
role: LLM role ('agent' or 'document_summary')
|
||||
|
||||
Returns:
|
||||
ChatLiteLLM instance or None if not found
|
||||
ChatLiteLLM or ChatLiteLLMRouter instance, or None if not found
|
||||
"""
|
||||
try:
|
||||
# Get the search space with its LLM preferences
|
||||
|
|
@ -180,10 +200,28 @@ async def get_search_space_llm_instance(
|
|||
logger.error(f"Invalid LLM role: {role}")
|
||||
return None
|
||||
|
||||
if not llm_config_id:
|
||||
if llm_config_id is None:
|
||||
logger.error(f"No {role} LLM configured for search space {search_space_id}")
|
||||
return None
|
||||
|
||||
# Check for Auto mode (ID 0) - use router for load balancing
|
||||
if is_auto_mode(llm_config_id):
|
||||
if not LLMRouterService.is_initialized():
|
||||
logger.error(
|
||||
"Auto mode requested but LLM Router not initialized. "
|
||||
"Ensure global_llm_config.yaml exists with valid configs."
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
logger.debug(
|
||||
f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}"
|
||||
)
|
||||
return ChatLiteLLMRouter()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
|
||||
return None
|
||||
|
||||
# Check if this is a global config (negative ID)
|
||||
if llm_config_id < 0:
|
||||
global_config = get_global_llm_config(llm_config_id)
|
||||
|
|
@ -328,14 +366,14 @@ async def get_search_space_llm_instance(
|
|||
|
||||
async def get_agent_llm(
|
||||
session: AsyncSession, search_space_id: int
|
||||
) -> ChatLiteLLM | None:
|
||||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""Get the search space's agent LLM instance for chat operations."""
|
||||
return await get_search_space_llm_instance(session, search_space_id, LLMRole.AGENT)
|
||||
|
||||
|
||||
async def get_document_summary_llm(
|
||||
session: AsyncSession, search_space_id: int
|
||||
) -> ChatLiteLLM | None:
|
||||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""Get the search space's document summary LLM instance."""
|
||||
return await get_search_space_llm_instance(
|
||||
session, search_space_id, LLMRole.DOCUMENT_SUMMARY
|
||||
|
|
@ -345,7 +383,7 @@ async def get_document_summary_llm(
|
|||
# Backward-compatible alias (LLM preferences are now per-search-space, not per-user)
|
||||
async def get_user_long_context_llm(
|
||||
session: AsyncSession, user_id: str, search_space_id: int
|
||||
) -> ChatLiteLLM | None:
|
||||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""
|
||||
Deprecated: Use get_document_summary_llm instead.
|
||||
The user_id parameter is ignored as LLM preferences are now per-search-space.
|
||||
|
|
|
|||
|
|
@ -1215,8 +1215,12 @@ async def stream_new_chat(
|
|||
|
||||
except Exception as e:
|
||||
# Handle any errors
|
||||
import traceback
|
||||
|
||||
error_message = f"Error during chat: {e!s}"
|
||||
print(f"[stream_new_chat] {error_message}")
|
||||
print(f"[stream_new_chat] Exception type: {type(e).__name__}")
|
||||
print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}")
|
||||
|
||||
# Close any open text block
|
||||
if current_text_id is not None:
|
||||
|
|
|
|||
|
|
@ -46,7 +46,13 @@ export function DashboardClientLayout({
|
|||
const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom);
|
||||
|
||||
const isOnboardingComplete = useCallback(() => {
|
||||
return !!(preferences.agent_llm_id && preferences.document_summary_llm_id);
|
||||
// Check that both LLM IDs are set (including 0 for Auto mode)
|
||||
return (
|
||||
preferences.agent_llm_id !== null &&
|
||||
preferences.agent_llm_id !== undefined &&
|
||||
preferences.document_summary_llm_id !== null &&
|
||||
preferences.document_summary_llm_id !== undefined
|
||||
);
|
||||
}, [preferences]);
|
||||
|
||||
const { data: access = null, isLoading: accessLoading } = useAtomValue(myAccessAtom);
|
||||
|
|
|
|||
|
|
@ -53,8 +53,12 @@ export default function OnboardPage() {
|
|||
}
|
||||
}, []);
|
||||
|
||||
// Check if onboarding is already complete
|
||||
const isOnboardingComplete = preferences.agent_llm_id && preferences.document_summary_llm_id;
|
||||
// Check if onboarding is already complete (including 0 for Auto mode)
|
||||
const isOnboardingComplete =
|
||||
preferences.agent_llm_id !== null &&
|
||||
preferences.agent_llm_id !== undefined &&
|
||||
preferences.document_summary_llm_id !== null &&
|
||||
preferences.document_summary_llm_id !== undefined;
|
||||
|
||||
// If onboarding is already complete, redirect immediately
|
||||
useEffect(() => {
|
||||
|
|
|
|||
|
|
@ -485,7 +485,8 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
if (agentLlmId === null || agentLlmId === undefined) return false;
|
||||
|
||||
// Check if the configured model actually exists
|
||||
if (agentLlmId < 0) {
|
||||
// Auto mode (ID 0) and global configs (negative IDs) are in globalConfigs
|
||||
if (agentLlmId <= 0) {
|
||||
return globalConfigs?.some((c) => c.id === agentLlmId) ?? false;
|
||||
}
|
||||
return userConfigs?.some((c) => c.id === agentLlmId) ?? false;
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"use client";
|
||||
|
||||
import { useAtomValue } from "jotai";
|
||||
import { AlertCircle, Bot, ChevronRight, Globe, User, X } from "lucide-react";
|
||||
import { AlertCircle, Bot, ChevronRight, Globe, Shuffle, User, X, Zap } from "lucide-react";
|
||||
import { AnimatePresence, motion } from "motion/react";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import { createPortal } from "react-dom";
|
||||
|
|
@ -62,9 +62,13 @@ export function ModelConfigSidebar({
|
|||
return () => window.removeEventListener("keydown", handleEscape);
|
||||
}, [open, onOpenChange]);
|
||||
|
||||
// Check if this is Auto mode
|
||||
const isAutoMode = config && "is_auto_mode" in config && config.is_auto_mode;
|
||||
|
||||
// Get title based on mode
|
||||
const getTitle = () => {
|
||||
if (mode === "create") return "Add New Configuration";
|
||||
if (isAutoMode) return "Auto Mode (Load Balanced)";
|
||||
if (isGlobal) return "View Global Configuration";
|
||||
return "Edit Configuration";
|
||||
};
|
||||
|
|
@ -187,15 +191,37 @@ export function ModelConfigSidebar({
|
|||
)}
|
||||
>
|
||||
{/* Header */}
|
||||
<div className="flex items-center justify-between px-6 py-4 border-b border-border/50 bg-muted/20">
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-between px-6 py-4 border-b border-border/50",
|
||||
isAutoMode ? "bg-gradient-to-r from-violet-500/10 to-purple-500/10" : "bg-muted/20"
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-3">
|
||||
<div className="flex items-center justify-center size-10 rounded-xl bg-primary/10">
|
||||
<Bot className="size-5 text-primary" />
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-center size-10 rounded-xl",
|
||||
isAutoMode ? "bg-gradient-to-br from-violet-500 to-purple-600" : "bg-primary/10"
|
||||
)}
|
||||
>
|
||||
{isAutoMode ? (
|
||||
<Shuffle className="size-5 text-white" />
|
||||
) : (
|
||||
<Bot className="size-5 text-primary" />
|
||||
)}
|
||||
</div>
|
||||
<div>
|
||||
<h2 className="text-base sm:text-lg font-semibold">{getTitle()}</h2>
|
||||
<div className="flex items-center gap-2 mt-0.5">
|
||||
{isGlobal ? (
|
||||
{isAutoMode ? (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className="gap-1 text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300"
|
||||
>
|
||||
<Zap className="size-3" />
|
||||
Recommended
|
||||
</Badge>
|
||||
) : isGlobal ? (
|
||||
<Badge variant="secondary" className="gap-1 text-xs">
|
||||
<Globe className="size-3" />
|
||||
Global
|
||||
|
|
@ -206,7 +232,7 @@ export function ModelConfigSidebar({
|
|||
Custom
|
||||
</Badge>
|
||||
) : null}
|
||||
{config && (
|
||||
{config && !isAutoMode && (
|
||||
<span className="text-xs text-muted-foreground">{config.model_name}</span>
|
||||
)}
|
||||
</div>
|
||||
|
|
@ -226,8 +252,19 @@ export function ModelConfigSidebar({
|
|||
{/* Content - use overflow-y-auto instead of ScrollArea for better compatibility */}
|
||||
<div className="flex-1 overflow-y-auto">
|
||||
<div className="p-6">
|
||||
{/* Auto mode info banner */}
|
||||
{isAutoMode && (
|
||||
<Alert className="mb-6 border-violet-500/30 bg-violet-500/5">
|
||||
<Shuffle className="size-4 text-violet-500" />
|
||||
<AlertDescription className="text-sm text-violet-700 dark:text-violet-400">
|
||||
Auto mode automatically distributes requests across all available LLM
|
||||
providers to optimize performance and avoid rate limits.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
{/* Global config notice */}
|
||||
{isGlobal && mode !== "create" && (
|
||||
{isGlobal && !isAutoMode && mode !== "create" && (
|
||||
<Alert className="mb-6 border-amber-500/30 bg-amber-500/5">
|
||||
<AlertCircle className="size-4 text-amber-500" />
|
||||
<AlertDescription className="text-sm text-amber-700 dark:text-amber-400">
|
||||
|
|
@ -247,6 +284,87 @@ export function ModelConfigSidebar({
|
|||
mode="create"
|
||||
submitLabel="Create & Use"
|
||||
/>
|
||||
) : isAutoMode && config ? (
|
||||
// Special view for Auto mode
|
||||
<div className="space-y-6">
|
||||
{/* Auto Mode Features */}
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-1.5">
|
||||
<div className="text-xs font-medium text-muted-foreground uppercase tracking-wider">
|
||||
How It Works
|
||||
</div>
|
||||
<p className="text-sm text-muted-foreground">{config.description}</p>
|
||||
</div>
|
||||
|
||||
<div className="h-px bg-border/50" />
|
||||
|
||||
<div className="space-y-3">
|
||||
<div className="text-xs font-medium text-muted-foreground uppercase tracking-wider">
|
||||
Key Benefits
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-start gap-3 p-3 rounded-lg bg-violet-50 dark:bg-violet-900/20 border border-violet-200 dark:border-violet-800/50">
|
||||
<Zap className="size-4 text-violet-600 dark:text-violet-400 mt-0.5 shrink-0" />
|
||||
<div>
|
||||
<p className="text-sm font-medium text-violet-900 dark:text-violet-100">
|
||||
Automatic Load Balancing
|
||||
</p>
|
||||
<p className="text-xs text-violet-700 dark:text-violet-300">
|
||||
Distributes requests across all configured LLM providers
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-start gap-3 p-3 rounded-lg bg-violet-50 dark:bg-violet-900/20 border border-violet-200 dark:border-violet-800/50">
|
||||
<Zap className="size-4 text-violet-600 dark:text-violet-400 mt-0.5 shrink-0" />
|
||||
<div>
|
||||
<p className="text-sm font-medium text-violet-900 dark:text-violet-100">
|
||||
Rate Limit Protection
|
||||
</p>
|
||||
<p className="text-xs text-violet-700 dark:text-violet-300">
|
||||
Automatically handles rate limits with cooldowns and retries
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-start gap-3 p-3 rounded-lg bg-violet-50 dark:bg-violet-900/20 border border-violet-200 dark:border-violet-800/50">
|
||||
<Zap className="size-4 text-violet-600 dark:text-violet-400 mt-0.5 shrink-0" />
|
||||
<div>
|
||||
<p className="text-sm font-medium text-violet-900 dark:text-violet-100">
|
||||
Automatic Failover
|
||||
</p>
|
||||
<p className="text-xs text-violet-700 dark:text-violet-300">
|
||||
Falls back to other providers if one becomes unavailable
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Action Buttons */}
|
||||
<div className="flex gap-3 pt-4 border-t border-border/50">
|
||||
<Button
|
||||
variant="outline"
|
||||
className="flex-1"
|
||||
onClick={() => onOpenChange(false)}
|
||||
>
|
||||
Close
|
||||
</Button>
|
||||
<Button
|
||||
className="flex-1 gap-2 bg-gradient-to-r from-violet-500 to-purple-600 hover:from-violet-600 hover:to-purple-700"
|
||||
onClick={handleUseGlobalConfig}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? (
|
||||
<>Loading...</>
|
||||
) : (
|
||||
<>
|
||||
<ChevronRight className="size-4" />
|
||||
Use Auto Mode
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
) : isGlobal && config ? (
|
||||
// Read-only view for global configs
|
||||
<div className="space-y-6">
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import {
|
|||
Globe,
|
||||
Plus,
|
||||
Settings2,
|
||||
Shuffle,
|
||||
Sparkles,
|
||||
User,
|
||||
Zap,
|
||||
|
|
@ -43,8 +44,14 @@ import type {
|
|||
import { cn } from "@/lib/utils";
|
||||
|
||||
// Provider icons mapping
|
||||
const getProviderIcon = (provider: string) => {
|
||||
const getProviderIcon = (provider: string, isAutoMode?: boolean) => {
|
||||
const iconClass = "size-4";
|
||||
|
||||
// Special icon for Auto mode
|
||||
if (isAutoMode || provider?.toUpperCase() === "AUTO") {
|
||||
return <Shuffle className={cn(iconClass, "text-violet-500")} />;
|
||||
}
|
||||
|
||||
switch (provider?.toUpperCase()) {
|
||||
case "OPENAI":
|
||||
return <Sparkles className={cn(iconClass, "text-emerald-500")} />;
|
||||
|
|
@ -90,14 +97,19 @@ export function ModelSelector({ onEdit, onAddNew, className }: ModelSelectorProp
|
|||
const agentLlmId = preferences.agent_llm_id;
|
||||
if (agentLlmId === null || agentLlmId === undefined) return null;
|
||||
|
||||
// Check if it's a global config (negative ID)
|
||||
if (agentLlmId < 0) {
|
||||
// Check if it's Auto mode (ID 0) or global config (negative ID)
|
||||
if (agentLlmId <= 0) {
|
||||
return globalConfigs?.find((c) => c.id === agentLlmId) ?? null;
|
||||
}
|
||||
// Otherwise, check user configs
|
||||
return userConfigs?.find((c) => c.id === agentLlmId) ?? null;
|
||||
}, [preferences, globalConfigs, userConfigs]);
|
||||
|
||||
// Check if current config is Auto mode
|
||||
const isCurrentAutoMode = useMemo(() => {
|
||||
return currentConfig && "is_auto_mode" in currentConfig && currentConfig.is_auto_mode;
|
||||
}, [currentConfig]);
|
||||
|
||||
// Filter configs based on search
|
||||
const filteredGlobalConfigs = useMemo(() => {
|
||||
if (!globalConfigs) return [];
|
||||
|
|
@ -184,14 +196,23 @@ export function ModelSelector({ onEdit, onAddNew, className }: ModelSelectorProp
|
|||
</>
|
||||
) : currentConfig ? (
|
||||
<>
|
||||
{getProviderIcon(currentConfig.provider)}
|
||||
{getProviderIcon(currentConfig.provider, isCurrentAutoMode ?? false)}
|
||||
<span className="max-w-[100px] md:max-w-[150px] truncate hidden md:inline">
|
||||
{currentConfig.name}
|
||||
</span>
|
||||
<Badge variant="secondary" className="ml-1 text-[10px] px-1.5 py-0 h-4 bg-muted/80">
|
||||
{currentConfig.model_name.split("/").pop()?.slice(0, 10) ||
|
||||
currentConfig.model_name.slice(0, 10)}
|
||||
</Badge>
|
||||
{isCurrentAutoMode ? (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className="ml-1 text-[10px] px-1.5 py-0 h-4 bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300"
|
||||
>
|
||||
Balanced
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge variant="secondary" className="ml-1 text-[10px] px-1.5 py-0 h-4 bg-muted/80">
|
||||
{currentConfig.model_name.split("/").pop()?.slice(0, 10) ||
|
||||
currentConfig.model_name.slice(0, 10)}
|
||||
</Badge>
|
||||
)}
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
|
|
@ -246,6 +267,7 @@ export function ModelSelector({ onEdit, onAddNew, className }: ModelSelectorProp
|
|||
</div>
|
||||
{filteredGlobalConfigs.map((config) => {
|
||||
const isSelected = currentConfig?.id === config.id;
|
||||
const isAutoMode = "is_auto_mode" in config && config.is_auto_mode;
|
||||
return (
|
||||
<CommandItem
|
||||
key={`global-${config.id}`}
|
||||
|
|
@ -254,22 +276,33 @@ export function ModelSelector({ onEdit, onAddNew, className }: ModelSelectorProp
|
|||
className={cn(
|
||||
"mx-2 rounded-lg mb-1 cursor-pointer group transition-all",
|
||||
"hover:bg-accent/50",
|
||||
isSelected && "bg-accent/80"
|
||||
isSelected && "bg-accent/80",
|
||||
isAutoMode && "border border-violet-200 dark:border-violet-800/50"
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center justify-between w-full gap-2">
|
||||
<div className="flex items-center gap-3 min-w-0 flex-1">
|
||||
<div className="shrink-0">{getProviderIcon(config.provider)}</div>
|
||||
<div className="shrink-0">
|
||||
{getProviderIcon(config.provider, isAutoMode)}
|
||||
</div>
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="font-medium truncate">{config.name}</span>
|
||||
{isAutoMode && (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className="text-[9px] px-1 py-0 h-3.5 bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300 border-0"
|
||||
>
|
||||
Recommended
|
||||
</Badge>
|
||||
)}
|
||||
{isSelected && <Check className="size-3.5 text-primary shrink-0" />}
|
||||
</div>
|
||||
<div className="flex items-center gap-1.5 mt-0.5">
|
||||
<span className="text-xs text-muted-foreground truncate">
|
||||
{config.model_name}
|
||||
{isAutoMode ? "Auto load balancing" : config.model_name}
|
||||
</span>
|
||||
{config.citations_enabled && (
|
||||
{!isAutoMode && config.citations_enabled && (
|
||||
<Badge
|
||||
variant="outline"
|
||||
className="text-[9px] px-1 py-0 h-3.5 bg-primary/10 text-primary border-primary/20"
|
||||
|
|
@ -280,14 +313,16 @@ export function ModelSelector({ onEdit, onAddNew, className }: ModelSelectorProp
|
|||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="size-7 shrink-0 rounded-md hover:bg-muted opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
onClick={(e) => handleEditConfig(e, config, true)}
|
||||
>
|
||||
<Edit3 className="size-3.5 text-muted-foreground" />
|
||||
</Button>
|
||||
{!isAutoMode && (
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="size-7 shrink-0 rounded-md hover:bg-muted opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
onClick={(e) => handleEditConfig(e, config, true)}
|
||||
>
|
||||
<Edit3 className="size-3.5 text-muted-foreground" />
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</CommandItem>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -1,7 +1,16 @@
|
|||
"use client";
|
||||
|
||||
import { useAtomValue } from "jotai";
|
||||
import { AlertCircle, Bot, CheckCircle, FileText, RefreshCw, RotateCcw, Save } from "lucide-react";
|
||||
import {
|
||||
AlertCircle,
|
||||
Bot,
|
||||
CheckCircle,
|
||||
FileText,
|
||||
RefreshCw,
|
||||
RotateCcw,
|
||||
Save,
|
||||
Shuffle,
|
||||
} from "lucide-react";
|
||||
import { motion } from "motion/react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
|
|
@ -24,6 +33,7 @@ import {
|
|||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const ROLE_DESCRIPTIONS = {
|
||||
agent: {
|
||||
|
|
@ -71,8 +81,8 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom);
|
||||
|
||||
const [assignments, setAssignments] = useState({
|
||||
agent_llm_id: preferences.agent_llm_id || "",
|
||||
document_summary_llm_id: preferences.document_summary_llm_id || "",
|
||||
agent_llm_id: preferences.agent_llm_id ?? "",
|
||||
document_summary_llm_id: preferences.document_summary_llm_id ?? "",
|
||||
});
|
||||
|
||||
const [hasChanges, setHasChanges] = useState(false);
|
||||
|
|
@ -80,8 +90,8 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
|
||||
useEffect(() => {
|
||||
const newAssignments = {
|
||||
agent_llm_id: preferences.agent_llm_id || "",
|
||||
document_summary_llm_id: preferences.document_summary_llm_id || "",
|
||||
agent_llm_id: preferences.agent_llm_id ?? "",
|
||||
document_summary_llm_id: preferences.document_summary_llm_id ?? "",
|
||||
};
|
||||
setAssignments(newAssignments);
|
||||
setHasChanges(false);
|
||||
|
|
@ -97,8 +107,8 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
|
||||
// Check if there are changes compared to current preferences
|
||||
const currentPrefs = {
|
||||
agent_llm_id: preferences.agent_llm_id || "",
|
||||
document_summary_llm_id: preferences.document_summary_llm_id || "",
|
||||
agent_llm_id: preferences.agent_llm_id ?? "",
|
||||
document_summary_llm_id: preferences.document_summary_llm_id ?? "",
|
||||
};
|
||||
|
||||
const hasChangesNow = Object.keys(newAssignments).some(
|
||||
|
|
@ -141,13 +151,19 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
|
||||
const handleReset = () => {
|
||||
setAssignments({
|
||||
agent_llm_id: preferences.agent_llm_id || "",
|
||||
document_summary_llm_id: preferences.document_summary_llm_id || "",
|
||||
agent_llm_id: preferences.agent_llm_id ?? "",
|
||||
document_summary_llm_id: preferences.document_summary_llm_id ?? "",
|
||||
});
|
||||
setHasChanges(false);
|
||||
};
|
||||
|
||||
const isAssignmentComplete = assignments.agent_llm_id && assignments.document_summary_llm_id;
|
||||
const isAssignmentComplete =
|
||||
assignments.agent_llm_id !== "" &&
|
||||
assignments.agent_llm_id !== null &&
|
||||
assignments.agent_llm_id !== undefined &&
|
||||
assignments.document_summary_llm_id !== "" &&
|
||||
assignments.document_summary_llm_id !== null &&
|
||||
assignments.document_summary_llm_id !== undefined;
|
||||
|
||||
// Combine global and custom configs (new system)
|
||||
const allConfigs = [
|
||||
|
|
@ -300,22 +316,47 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
<div className="px-2 py-1.5 text-xs font-semibold text-muted-foreground">
|
||||
Global Configurations
|
||||
</div>
|
||||
{globalConfigs.map((config) => (
|
||||
<SelectItem key={config.id} value={config.id.toString()}>
|
||||
<div className="flex items-center gap-2">
|
||||
<Badge variant="outline" className="text-xs">
|
||||
{config.provider}
|
||||
</Badge>
|
||||
<span>{config.name}</span>
|
||||
<span className="text-muted-foreground">
|
||||
({config.model_name})
|
||||
</span>
|
||||
<Badge variant="secondary" className="text-xs">
|
||||
🌐 Global
|
||||
</Badge>
|
||||
</div>
|
||||
</SelectItem>
|
||||
))}
|
||||
{globalConfigs.map((config) => {
|
||||
const isAutoMode =
|
||||
"is_auto_mode" in config && config.is_auto_mode;
|
||||
return (
|
||||
<SelectItem key={config.id} value={config.id.toString()}>
|
||||
<div className="flex items-center gap-2">
|
||||
{isAutoMode ? (
|
||||
<Badge
|
||||
variant="outline"
|
||||
className="text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300 border-violet-200 dark:border-violet-700"
|
||||
>
|
||||
<Shuffle className="size-3 mr-1" />
|
||||
AUTO
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge variant="outline" className="text-xs">
|
||||
{config.provider}
|
||||
</Badge>
|
||||
)}
|
||||
<span>{config.name}</span>
|
||||
{!isAutoMode && (
|
||||
<span className="text-muted-foreground">
|
||||
({config.model_name})
|
||||
</span>
|
||||
)}
|
||||
{isAutoMode ? (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className="text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300"
|
||||
>
|
||||
Recommended
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge variant="secondary" className="text-xs">
|
||||
🌐 Global
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
</SelectItem>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
)}
|
||||
|
||||
|
|
@ -349,27 +390,65 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
</div>
|
||||
|
||||
{assignedConfig && (
|
||||
<div className="mt-2 md:mt-3 p-2 md:p-3 bg-muted/50 rounded-lg">
|
||||
<div
|
||||
className={cn(
|
||||
"mt-2 md:mt-3 p-2 md:p-3 rounded-lg",
|
||||
"is_auto_mode" in assignedConfig && assignedConfig.is_auto_mode
|
||||
? "bg-violet-50 dark:bg-violet-900/20 border border-violet-200 dark:border-violet-800/50"
|
||||
: "bg-muted/50"
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-1.5 md:gap-2 text-xs md:text-sm flex-wrap">
|
||||
<Bot className="w-3 h-3 md:w-4 md:h-4 shrink-0" />
|
||||
{"is_auto_mode" in assignedConfig && assignedConfig.is_auto_mode ? (
|
||||
<Shuffle className="w-3 h-3 md:w-4 md:h-4 shrink-0 text-violet-600 dark:text-violet-400" />
|
||||
) : (
|
||||
<Bot className="w-3 h-3 md:w-4 md:h-4 shrink-0" />
|
||||
)}
|
||||
<span className="font-medium">Assigned:</span>
|
||||
<Badge variant="secondary" className="text-[10px] md:text-xs">
|
||||
{assignedConfig.provider}
|
||||
</Badge>
|
||||
<span>{assignedConfig.name}</span>
|
||||
{"is_global" in assignedConfig && assignedConfig.is_global && (
|
||||
<Badge variant="outline" className="text-[9px] md:text-xs">
|
||||
🌐 Global
|
||||
{"is_auto_mode" in assignedConfig && assignedConfig.is_auto_mode ? (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className="text-[10px] md:text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300"
|
||||
>
|
||||
AUTO
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge variant="secondary" className="text-[10px] md:text-xs">
|
||||
{assignedConfig.provider}
|
||||
</Badge>
|
||||
)}
|
||||
<span>{assignedConfig.name}</span>
|
||||
{"is_auto_mode" in assignedConfig && assignedConfig.is_auto_mode ? (
|
||||
<Badge
|
||||
variant="outline"
|
||||
className="text-[9px] md:text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300 border-violet-200 dark:border-violet-700"
|
||||
>
|
||||
Recommended
|
||||
</Badge>
|
||||
) : (
|
||||
"is_global" in assignedConfig &&
|
||||
assignedConfig.is_global && (
|
||||
<Badge variant="outline" className="text-[9px] md:text-xs">
|
||||
🌐 Global
|
||||
</Badge>
|
||||
)
|
||||
)}
|
||||
</div>
|
||||
<div className="text-[10px] md:text-xs text-muted-foreground mt-0.5 md:mt-1">
|
||||
Model: {assignedConfig.model_name}
|
||||
</div>
|
||||
{assignedConfig.api_base && (
|
||||
<div className="text-[10px] md:text-xs text-muted-foreground">
|
||||
Base: {assignedConfig.api_base}
|
||||
{"is_auto_mode" in assignedConfig && assignedConfig.is_auto_mode ? (
|
||||
<div className="text-[10px] md:text-xs text-violet-600 dark:text-violet-400 mt-0.5 md:mt-1">
|
||||
Automatically load balances across all available LLM providers
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div className="text-[10px] md:text-xs text-muted-foreground mt-0.5 md:mt-1">
|
||||
Model: {assignedConfig.model_name}
|
||||
</div>
|
||||
{assignedConfig.api_base && (
|
||||
<div className="text-[10px] md:text-xs text-muted-foreground">
|
||||
Base: {assignedConfig.api_base}
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -136,14 +136,15 @@ export const getDefaultSystemInstructionsResponse = z.object({
|
|||
|
||||
/**
|
||||
* Global NewLLMConfig - from YAML, has negative IDs
|
||||
* ID 0 is reserved for "Auto" mode which uses LiteLLM Router for load balancing
|
||||
*/
|
||||
export const globalNewLLMConfig = z.object({
|
||||
id: z.number(), // Negative IDs for global configs
|
||||
id: z.number(), // 0 for Auto mode, negative IDs for global configs
|
||||
name: z.string(),
|
||||
description: z.string().nullable().optional(),
|
||||
|
||||
// LLM Model Configuration (no api_key)
|
||||
provider: z.string(), // String because YAML doesn't enforce enum
|
||||
provider: z.string(), // String because YAML doesn't enforce enum, "AUTO" for Auto mode
|
||||
custom_provider: z.string().nullable().optional(),
|
||||
model_name: z.string(),
|
||||
api_base: z.string().nullable().optional(),
|
||||
|
|
@ -155,6 +156,7 @@ export const globalNewLLMConfig = z.object({
|
|||
citations_enabled: z.boolean().default(true),
|
||||
|
||||
is_global: z.literal(true),
|
||||
is_auto_mode: z.boolean().optional().default(false), // True only for Auto mode (ID 0)
|
||||
});
|
||||
|
||||
export const getGlobalNewLLMConfigsResponse = z.array(globalNewLLMConfig);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue