mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-03 12:52:39 +02:00
hotpatch(cloud): add llm load balancing
This commit is contained in:
parent
5d5f9d3bfb
commit
6fb656fd8f
21 changed files with 1324 additions and 103 deletions
|
|
@ -0,0 +1,75 @@
|
||||||
|
"""Migrate global LLM configs to Auto mode
|
||||||
|
|
||||||
|
Revision ID: 84
|
||||||
|
Revises: 83
|
||||||
|
|
||||||
|
This migration updates existing search spaces that use global LLM configs
|
||||||
|
(negative IDs) to use the new Auto mode (ID 0) instead.
|
||||||
|
|
||||||
|
Auto mode uses LiteLLM Router to automatically load balance requests across
|
||||||
|
all configured global LLM providers, which helps avoid rate limits.
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
1. Update agent_llm_id from negative values to 0 (Auto mode)
|
||||||
|
2. Update document_summary_llm_id from negative values to 0 (Auto mode)
|
||||||
|
3. Update NULL values to 0 (Auto mode) as the new default
|
||||||
|
|
||||||
|
Note: This migration preserves any custom user-created LLM configs (positive IDs).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "84"
|
||||||
|
down_revision: str | None = "83"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Migrate global LLM config IDs (negative) and NULL to Auto mode (0)."""
    # Both LLM-selection columns follow the same rule: anything that pointed
    # at a global config (negative ID) or was unset (NULL) now means Auto
    # mode (ID 0). Column names are fixed constants, so the f-string below
    # cannot inject anything.
    for column in ("agent_llm_id", "document_summary_llm_id"):
        op.execute(
            f"""
            UPDATE searchspaces
            SET {column} = 0
            WHERE {column} < 0 OR {column} IS NULL
            """
        )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Revert Auto mode back to the first global config (ID -1).

    Note: This is a best-effort revert. We cannot know which specific
    global config each search space was using before, so we default
    to -1 (typically the first/primary global config).
    """
    # Mirror of upgrade(): both columns get the same best-effort rollback.
    # Column names are fixed constants, so the f-string is injection-safe.
    for column in ("agent_llm_id", "document_summary_llm_id"):
        op.execute(
            f"""
            UPDATE searchspaces
            SET {column} = -1
            WHERE {column} = 0
            """
        )
|
||||||
|
|
@ -10,8 +10,8 @@ from collections.abc import Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from deepagents import create_deep_agent
|
from deepagents import create_deep_agent
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langchain_litellm import ChatLiteLLM
|
|
||||||
from langgraph.types import Checkpointer
|
from langgraph.types import Checkpointer
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
|
@ -114,7 +114,7 @@ def _map_connectors_to_searchable_types(
|
||||||
|
|
||||||
|
|
||||||
async def create_surfsense_deep_agent(
|
async def create_surfsense_deep_agent(
|
||||||
llm: ChatLiteLLM,
|
llm: BaseChatModel,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession,
|
||||||
connector_service: ConnectorService,
|
connector_service: ConnectorService,
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,9 @@
|
||||||
LLM configuration utilities for SurfSense agents.
|
LLM configuration utilities for SurfSense agents.
|
||||||
|
|
||||||
This module provides functions for loading LLM configurations from:
|
This module provides functions for loading LLM configurations from:
|
||||||
1. YAML files (global configs with negative IDs)
|
1. Auto mode (ID 0) - Uses LiteLLM Router for load balancing
|
||||||
2. Database NewLLMConfig table (user-created configs with positive IDs)
|
2. YAML files (global configs with negative IDs)
|
||||||
|
3. Database NewLLMConfig table (user-created configs with positive IDs)
|
||||||
|
|
||||||
It also provides utilities for creating ChatLiteLLM instances and
|
It also provides utilities for creating ChatLiteLLM instances and
|
||||||
managing prompt configurations.
|
managing prompt configurations.
|
||||||
|
|
@ -17,6 +18,13 @@ from langchain_litellm import ChatLiteLLM
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.services.llm_router_service import (
|
||||||
|
AUTO_MODE_ID,
|
||||||
|
ChatLiteLLMRouter,
|
||||||
|
LLMRouterService,
|
||||||
|
is_auto_mode,
|
||||||
|
)
|
||||||
|
|
||||||
# Provider mapping for LiteLLM model string construction
|
# Provider mapping for LiteLLM model string construction
|
||||||
PROVIDER_MAP = {
|
PROVIDER_MAP = {
|
||||||
"OPENAI": "openai",
|
"OPENAI": "openai",
|
||||||
|
|
@ -58,6 +66,7 @@ class AgentConfig:
|
||||||
Complete configuration for the SurfSense agent.
|
Complete configuration for the SurfSense agent.
|
||||||
|
|
||||||
This combines LLM settings with prompt configuration from NewLLMConfig.
|
This combines LLM settings with prompt configuration from NewLLMConfig.
|
||||||
|
Supports Auto mode (ID 0) which uses LiteLLM Router for load balancing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# LLM Model Settings
|
# LLM Model Settings
|
||||||
|
|
@ -77,6 +86,32 @@ class AgentConfig:
|
||||||
config_id: int | None = None
|
config_id: int | None = None
|
||||||
config_name: str | None = None
|
config_name: str | None = None
|
||||||
|
|
||||||
|
# Auto mode flag
|
||||||
|
is_auto_mode: bool = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_auto_mode(cls) -> "AgentConfig":
|
||||||
|
"""
|
||||||
|
Create an AgentConfig for Auto mode (LiteLLM Router load balancing).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentConfig instance configured for Auto mode
|
||||||
|
"""
|
||||||
|
return cls(
|
||||||
|
provider="AUTO",
|
||||||
|
model_name="auto",
|
||||||
|
api_key="", # Not needed for router
|
||||||
|
api_base=None,
|
||||||
|
custom_provider=None,
|
||||||
|
litellm_params=None,
|
||||||
|
system_instructions=None,
|
||||||
|
use_default_system_instructions=True,
|
||||||
|
citations_enabled=True,
|
||||||
|
config_id=AUTO_MODE_ID,
|
||||||
|
config_name="Auto (Load Balanced)",
|
||||||
|
is_auto_mode=True,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_new_llm_config(cls, config) -> "AgentConfig":
|
def from_new_llm_config(cls, config) -> "AgentConfig":
|
||||||
"""
|
"""
|
||||||
|
|
@ -102,6 +137,7 @@ class AgentConfig:
|
||||||
citations_enabled=config.citations_enabled,
|
citations_enabled=config.citations_enabled,
|
||||||
config_id=config.id,
|
config_id=config.id,
|
||||||
config_name=config.name,
|
config_name=config.name,
|
||||||
|
is_auto_mode=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -138,6 +174,7 @@ class AgentConfig:
|
||||||
citations_enabled=yaml_config.get("citations_enabled", True),
|
citations_enabled=yaml_config.get("citations_enabled", True),
|
||||||
config_id=yaml_config.get("id"),
|
config_id=yaml_config.get("id"),
|
||||||
config_name=yaml_config.get("name"),
|
config_name=yaml_config.get("name"),
|
||||||
|
is_auto_mode=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -261,20 +298,28 @@ async def load_agent_config(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
) -> "AgentConfig | None":
|
) -> "AgentConfig | None":
|
||||||
"""
|
"""
|
||||||
Load an agent configuration, supporting both YAML (negative IDs) and database (positive IDs) configs.
|
Load an agent configuration, supporting Auto mode, YAML, and database configs.
|
||||||
|
|
||||||
This is the main entry point for loading configurations:
|
This is the main entry point for loading configurations:
|
||||||
|
- ID 0: Auto mode (uses LiteLLM Router for load balancing)
|
||||||
- Negative IDs: Load from YAML file (global configs)
|
- Negative IDs: Load from YAML file (global configs)
|
||||||
- Positive IDs: Load from NewLLMConfig database table
|
- Positive IDs: Load from NewLLMConfig database table
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: AsyncSession for database access
|
session: AsyncSession for database access
|
||||||
config_id: The config ID (negative for YAML, positive for database)
|
config_id: The config ID (0 for Auto, negative for YAML, positive for database)
|
||||||
search_space_id: Optional search space ID for context
|
search_space_id: Optional search space ID for context
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AgentConfig instance or None if not found
|
AgentConfig instance or None if not found
|
||||||
"""
|
"""
|
||||||
|
# Auto mode (ID 0) - use LiteLLM Router
|
||||||
|
if is_auto_mode(config_id):
|
||||||
|
if not LLMRouterService.is_initialized():
|
||||||
|
print("Error: Auto mode requested but LLM Router not initialized")
|
||||||
|
return None
|
||||||
|
return AgentConfig.from_auto_mode()
|
||||||
|
|
||||||
if config_id < 0:
|
if config_id < 0:
|
||||||
# Load from YAML (global configs have negative IDs)
|
# Load from YAML (global configs have negative IDs)
|
||||||
yaml_config = load_llm_config_from_yaml(config_id)
|
yaml_config = load_llm_config_from_yaml(config_id)
|
||||||
|
|
@ -324,16 +369,30 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
||||||
|
|
||||||
def create_chat_litellm_from_agent_config(
|
def create_chat_litellm_from_agent_config(
|
||||||
agent_config: AgentConfig,
|
agent_config: AgentConfig,
|
||||||
) -> ChatLiteLLM | None:
|
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||||
"""
|
"""
|
||||||
Create a ChatLiteLLM instance from an AgentConfig.
|
Create a ChatLiteLLM or ChatLiteLLMRouter instance from an AgentConfig.
|
||||||
|
|
||||||
|
For Auto mode configs, returns a ChatLiteLLMRouter that uses LiteLLM Router
|
||||||
|
for automatic load balancing across available providers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agent_config: AgentConfig instance
|
agent_config: AgentConfig instance
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ChatLiteLLM instance or None on error
|
ChatLiteLLM or ChatLiteLLMRouter instance, or None on error
|
||||||
"""
|
"""
|
||||||
|
# Handle Auto mode - return ChatLiteLLMRouter
|
||||||
|
if agent_config.is_auto_mode:
|
||||||
|
if not LLMRouterService.is_initialized():
|
||||||
|
print("Error: Auto mode requested but LLM Router not initialized")
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return ChatLiteLLMRouter()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error creating ChatLiteLLMRouter: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
# Build the model string
|
# Build the model string
|
||||||
if agent_config.custom_provider:
|
if agent_config.custom_provider:
|
||||||
model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
|
model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from app.agents.new_chat.checkpointer import (
|
||||||
close_checkpointer,
|
close_checkpointer,
|
||||||
setup_checkpointer_tables,
|
setup_checkpointer_tables,
|
||||||
)
|
)
|
||||||
from app.config import config
|
from app.config import config, initialize_llm_router
|
||||||
from app.db import User, create_db_and_tables, get_async_session
|
from app.db import User, create_db_and_tables, get_async_session
|
||||||
from app.routes import router as crud_router
|
from app.routes import router as crud_router
|
||||||
from app.schemas import UserCreate, UserRead, UserUpdate
|
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||||
|
|
@ -23,6 +23,8 @@ async def lifespan(app: FastAPI):
|
||||||
await create_db_and_tables()
|
await create_db_and_tables()
|
||||||
# Setup LangGraph checkpointer tables for conversation persistence
|
# Setup LangGraph checkpointer tables for conversation persistence
|
||||||
await setup_checkpointer_tables()
|
await setup_checkpointer_tables()
|
||||||
|
# Initialize LLM Router for Auto mode load balancing
|
||||||
|
initialize_llm_router()
|
||||||
# Seed Surfsense documentation
|
# Seed Surfsense documentation
|
||||||
await seed_surfsense_docs()
|
await seed_surfsense_docs()
|
||||||
yield
|
yield
|
||||||
|
|
|
||||||
|
|
@ -4,11 +4,25 @@ import os
|
||||||
|
|
||||||
from celery import Celery
|
from celery import Celery
|
||||||
from celery.schedules import crontab
|
from celery.schedules import crontab
|
||||||
|
from celery.signals import worker_process_init
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
@worker_process_init.connect
def init_worker(**kwargs):
    """Initialize the LLM Router when a Celery worker process starts.

    Celery fires ``worker_process_init`` once per child process, which
    ensures the Auto mode (LiteLLM Router) is available for background
    tasks like document summarization.
    """
    # Imported lazily so app configuration is not pulled in at module
    # import time (this module is loaded very early by Celery).
    from app.config import initialize_llm_router

    initialize_llm_router()
|
||||||
|
|
||||||
|
|
||||||
# Get Celery configuration from environment
|
# Get Celery configuration from environment
|
||||||
CELERY_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
|
CELERY_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
|
||||||
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", "redis://localhost:6379/0")
|
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", "redis://localhost:6379/0")
|
||||||
|
|
|
||||||
|
|
@ -48,6 +48,63 @@ def load_global_llm_configs():
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def load_router_settings():
    """
    Load router settings for Auto mode from the global LLM config YAML file.

    Falls back to default settings when the file is missing or empty, when
    the ``router_settings`` section is absent, or when it is malformed.

    Returns:
        dict: Router settings dictionary (defaults merged with YAML overrides)
    """
    # Default router settings; usage-based routing is the recommended
    # strategy for avoiding rate limits.
    default_settings = {
        "routing_strategy": "usage-based-routing",
        "num_retries": 3,
        "allowed_fails": 3,
        "cooldown_time": 60,
    }

    # Try main config file first
    global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"

    if not global_config_file.exists():
        return default_settings

    try:
        with open(global_config_file, encoding="utf-8") as f:
            data = yaml.safe_load(f)
        # yaml.safe_load returns None for an empty file, and the key may be
        # explicitly null in YAML — guard both before calling .get()/merging.
        settings = (data or {}).get("router_settings") or {}
        if not isinstance(settings, dict):
            print("Warning: router_settings in YAML is not a mapping; using defaults")
            return default_settings
        # Merge with defaults: YAML values win, unspecified keys keep defaults.
        return {**default_settings, **settings}
    except Exception as e:
        print(f"Warning: Failed to load router settings: {e}")
        return default_settings
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_llm_router():
    """
    Initialize the LLM Router service for Auto mode.

    This should be called during application startup.
    """
    configs = load_global_llm_configs()
    settings = load_router_settings()

    # Without at least one global config there is nothing to route across.
    if not configs:
        print("Info: No global LLM configs found, Auto mode will not be available")
        return

    try:
        # Local import keeps the router service out of module import time.
        from app.services.llm_router_service import LLMRouterService

        LLMRouterService.initialize(configs, settings)
        print(
            f"Info: LLM Router initialized with {len(configs)} models "
            f"(strategy: {settings.get('routing_strategy', 'usage-based-routing')})"
        )
    except Exception as e:
        print(f"Warning: Failed to initialize LLM Router: {e}")
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
# Check if ffmpeg is installed
|
# Check if ffmpeg is installed
|
||||||
if not is_ffmpeg_installed():
|
if not is_ffmpeg_installed():
|
||||||
|
|
@ -156,6 +213,9 @@ class Config:
|
||||||
# These can be used as default options for users
|
# These can be used as default options for users
|
||||||
GLOBAL_LLM_CONFIGS = load_global_llm_configs()
|
GLOBAL_LLM_CONFIGS = load_global_llm_configs()
|
||||||
|
|
||||||
|
# Router settings for Auto mode (LiteLLM Router load balancing)
|
||||||
|
ROUTER_SETTINGS = load_router_settings()
|
||||||
|
|
||||||
# Chonkie Configuration | Edit this to your needs
|
# Chonkie Configuration | Edit this to your needs
|
||||||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
|
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
|
||||||
# Azure OpenAI credentials from environment variables
|
# Azure OpenAI credentials from environment variables
|
||||||
|
|
|
||||||
|
|
@ -10,10 +10,39 @@
|
||||||
# These configurations will be available to all users as a convenient option
|
# These configurations will be available to all users as a convenient option
|
||||||
# Users can choose to use these global configs or add their own
|
# Users can choose to use these global configs or add their own
|
||||||
#
|
#
|
||||||
|
# AUTO MODE (Recommended):
|
||||||
|
# - Auto mode (ID: 0) uses LiteLLM Router to automatically load balance across all global configs
|
||||||
|
# - This helps avoid rate limits by distributing requests across multiple providers
|
||||||
|
# - New users are automatically assigned Auto mode by default
|
||||||
|
# - Configure router_settings below to customize the load balancing behavior
|
||||||
|
#
|
||||||
# Structure matches NewLLMConfig:
|
# Structure matches NewLLMConfig:
|
||||||
# - LLM model configuration (provider, model_name, api_key, etc.)
|
# - LLM model configuration (provider, model_name, api_key, etc.)
|
||||||
# - Prompt configuration (system_instructions, citations_enabled)
|
# - Prompt configuration (system_instructions, citations_enabled)
|
||||||
|
|
||||||
|
# Router Settings for Auto Mode
|
||||||
|
# These settings control how the LiteLLM Router distributes requests across models
|
||||||
|
router_settings:
|
||||||
|
# Routing strategy options:
|
||||||
|
# - "usage-based-routing": Routes to deployment with lowest current usage (recommended for rate limits)
|
||||||
|
# - "simple-shuffle": Random distribution with optional RPM/TPM weighting
|
||||||
|
# - "least-busy": Routes to least busy deployment
|
||||||
|
# - "latency-based-routing": Routes based on response latency
|
||||||
|
routing_strategy: "usage-based-routing"
|
||||||
|
|
||||||
|
# Number of retries before failing
|
||||||
|
num_retries: 3
|
||||||
|
|
||||||
|
# Number of failures allowed before cooling down a deployment
|
||||||
|
allowed_fails: 3
|
||||||
|
|
||||||
|
# Cooldown time in seconds after allowed_fails is exceeded
|
||||||
|
cooldown_time: 60
|
||||||
|
|
||||||
|
# Fallback models (optional) - when primary fails, try these
|
||||||
|
# Format: [{"primary_model": ["fallback1", "fallback2"]}]
|
||||||
|
# fallbacks: []
|
||||||
|
|
||||||
global_llm_configs:
|
global_llm_configs:
|
||||||
# Example: OpenAI GPT-4 Turbo with citations enabled
|
# Example: OpenAI GPT-4 Turbo with citations enabled
|
||||||
- id: -1
|
- id: -1
|
||||||
|
|
@ -23,6 +52,9 @@ global_llm_configs:
|
||||||
model_name: "gpt-4-turbo-preview"
|
model_name: "gpt-4-turbo-preview"
|
||||||
api_key: "sk-your-openai-api-key-here"
|
api_key: "sk-your-openai-api-key-here"
|
||||||
api_base: ""
|
api_base: ""
|
||||||
|
# Rate limits for load balancing (requests/tokens per minute)
|
||||||
|
rpm: 500 # Requests per minute
|
||||||
|
tpm: 100000 # Tokens per minute
|
||||||
litellm_params:
|
litellm_params:
|
||||||
temperature: 0.7
|
temperature: 0.7
|
||||||
max_tokens: 4000
|
max_tokens: 4000
|
||||||
|
|
@ -39,6 +71,8 @@ global_llm_configs:
|
||||||
model_name: "claude-3-opus-20240229"
|
model_name: "claude-3-opus-20240229"
|
||||||
api_key: "sk-ant-your-anthropic-api-key-here"
|
api_key: "sk-ant-your-anthropic-api-key-here"
|
||||||
api_base: ""
|
api_base: ""
|
||||||
|
rpm: 1000
|
||||||
|
tpm: 100000
|
||||||
litellm_params:
|
litellm_params:
|
||||||
temperature: 0.7
|
temperature: 0.7
|
||||||
max_tokens: 4000
|
max_tokens: 4000
|
||||||
|
|
@ -54,6 +88,8 @@ global_llm_configs:
|
||||||
model_name: "gpt-3.5-turbo"
|
model_name: "gpt-3.5-turbo"
|
||||||
api_key: "sk-your-openai-api-key-here"
|
api_key: "sk-your-openai-api-key-here"
|
||||||
api_base: ""
|
api_base: ""
|
||||||
|
rpm: 3500 # GPT-3.5 has higher rate limits
|
||||||
|
tpm: 200000
|
||||||
litellm_params:
|
litellm_params:
|
||||||
temperature: 0.5
|
temperature: 0.5
|
||||||
max_tokens: 2000
|
max_tokens: 2000
|
||||||
|
|
@ -69,6 +105,8 @@ global_llm_configs:
|
||||||
model_name: "deepseek-chat"
|
model_name: "deepseek-chat"
|
||||||
api_key: "your-deepseek-api-key-here"
|
api_key: "your-deepseek-api-key-here"
|
||||||
api_base: "https://api.deepseek.com/v1"
|
api_base: "https://api.deepseek.com/v1"
|
||||||
|
rpm: 60
|
||||||
|
tpm: 100000
|
||||||
litellm_params:
|
litellm_params:
|
||||||
temperature: 0.7
|
temperature: 0.7
|
||||||
max_tokens: 4000
|
max_tokens: 4000
|
||||||
|
|
@ -92,6 +130,8 @@ global_llm_configs:
|
||||||
model_name: "llama3-70b-8192"
|
model_name: "llama3-70b-8192"
|
||||||
api_key: "your-groq-api-key-here"
|
api_key: "your-groq-api-key-here"
|
||||||
api_base: ""
|
api_base: ""
|
||||||
|
rpm: 30 # Groq has lower rate limits on free tier
|
||||||
|
tpm: 14400
|
||||||
litellm_params:
|
litellm_params:
|
||||||
temperature: 0.7
|
temperature: 0.7
|
||||||
max_tokens: 8000
|
max_tokens: 8000
|
||||||
|
|
@ -100,6 +140,7 @@ global_llm_configs:
|
||||||
citations_enabled: true
|
citations_enabled: true
|
||||||
|
|
||||||
# Notes:
|
# Notes:
|
||||||
|
# - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing
|
||||||
# - Use negative IDs to distinguish global configs from user configs (NewLLMConfig in DB)
|
# - Use negative IDs to distinguish global configs from user configs (NewLLMConfig in DB)
|
||||||
# - IDs should be unique and sequential (e.g., -1, -2, -3, etc.)
|
# - IDs should be unique and sequential (e.g., -1, -2, -3, etc.)
|
||||||
# - The 'api_key' field will not be exposed to users via API
|
# - The 'api_key' field will not be exposed to users via API
|
||||||
|
|
@ -107,3 +148,5 @@ global_llm_configs:
|
||||||
# - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty
|
# - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty
|
||||||
# - citations_enabled: true = include citation instructions, false = include anti-citation instructions
|
# - citations_enabled: true = include citation instructions, false = include anti-citation instructions
|
||||||
# - All standard LiteLLM providers are supported
|
# - All standard LiteLLM providers are supported
|
||||||
|
# - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute)
|
||||||
|
# These help the router distribute load evenly and avoid rate limit errors
|
||||||
|
|
|
||||||
|
|
@ -807,11 +807,16 @@ class SearchSpace(BaseModel, TimestampMixin):
|
||||||
) # User's custom instructions
|
) # User's custom instructions
|
||||||
|
|
||||||
# Search space-level LLM preferences (shared by all members)
|
# Search space-level LLM preferences (shared by all members)
|
||||||
# Note: These can be negative IDs for global configs (from YAML) or positive IDs for custom configs (from DB)
|
# Note: ID values:
|
||||||
agent_llm_id = Column(Integer, nullable=True) # For agent/chat operations
|
# - 0: Auto mode (uses LiteLLM Router for load balancing) - default for new search spaces
|
||||||
|
# - Negative IDs: Global configs from YAML
|
||||||
|
# - Positive IDs: Custom configs from DB (NewLLMConfig table)
|
||||||
|
agent_llm_id = Column(
|
||||||
|
Integer, nullable=True, default=0
|
||||||
|
) # For agent/chat operations, defaults to Auto mode
|
||||||
document_summary_llm_id = Column(
|
document_summary_llm_id = Column(
|
||||||
Integer, nullable=True
|
Integer, nullable=True, default=0
|
||||||
) # For document summarization
|
) # For document summarization, defaults to Auto mode
|
||||||
|
|
||||||
user_id = Column(
|
user_id = Column(
|
||||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||||
|
|
|
||||||
|
|
@ -50,13 +50,33 @@ async def get_global_new_llm_configs(
|
||||||
These are pre-configured by the system administrator and available to all users.
|
These are pre-configured by the system administrator and available to all users.
|
||||||
API keys are not exposed through this endpoint.
|
API keys are not exposed through this endpoint.
|
||||||
|
|
||||||
Global configs have negative IDs to distinguish from user-created configs.
|
Includes:
|
||||||
|
- Auto mode (ID 0): Uses LiteLLM Router for automatic load balancing
|
||||||
|
- Global configs (negative IDs): Individual pre-configured LLM providers
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
global_configs = config.GLOBAL_LLM_CONFIGS
|
global_configs = config.GLOBAL_LLM_CONFIGS
|
||||||
|
|
||||||
# Transform to new structure, hiding API keys
|
# Start with Auto mode as the first option (recommended default)
|
||||||
safe_configs = []
|
safe_configs = [
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"name": "Auto (Load Balanced)",
|
||||||
|
"description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling. Recommended for most users.",
|
||||||
|
"provider": "AUTO",
|
||||||
|
"custom_provider": None,
|
||||||
|
"model_name": "auto",
|
||||||
|
"api_base": None,
|
||||||
|
"litellm_params": {},
|
||||||
|
"system_instructions": "",
|
||||||
|
"use_default_system_instructions": True,
|
||||||
|
"citations_enabled": True,
|
||||||
|
"is_global": True,
|
||||||
|
"is_auto_mode": True,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add individual global configs
|
||||||
for cfg in global_configs:
|
for cfg in global_configs:
|
||||||
safe_config = {
|
safe_config = {
|
||||||
"id": cfg.get("id"),
|
"id": cfg.get("id"),
|
||||||
|
|
|
||||||
|
|
@ -314,11 +314,29 @@ async def _get_llm_config_by_id(
|
||||||
) -> dict | None:
|
) -> dict | None:
|
||||||
"""
|
"""
|
||||||
Get an LLM config by ID as a dictionary. Returns database config for positive IDs,
|
Get an LLM config by ID as a dictionary. Returns database config for positive IDs,
|
||||||
global config for negative IDs, or None if ID is None.
|
global config for negative IDs, Auto mode config for ID 0, or None if ID is None.
|
||||||
"""
|
"""
|
||||||
if config_id is None:
|
if config_id is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Auto mode (ID 0) - uses LiteLLM Router for load balancing
|
||||||
|
if config_id == 0:
|
||||||
|
return {
|
||||||
|
"id": 0,
|
||||||
|
"name": "Auto (Load Balanced)",
|
||||||
|
"description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling",
|
||||||
|
"provider": "AUTO",
|
||||||
|
"custom_provider": None,
|
||||||
|
"model_name": "auto",
|
||||||
|
"api_base": None,
|
||||||
|
"litellm_params": {},
|
||||||
|
"system_instructions": "",
|
||||||
|
"use_default_system_instructions": True,
|
||||||
|
"citations_enabled": True,
|
||||||
|
"is_global": True,
|
||||||
|
"is_auto_mode": True,
|
||||||
|
}
|
||||||
|
|
||||||
if config_id < 0:
|
if config_id < 0:
|
||||||
# Global config - find from YAML
|
# Global config - find from YAML
|
||||||
global_configs = config.GLOBAL_LLM_CONFIGS
|
global_configs = config.GLOBAL_LLM_CONFIGS
|
||||||
|
|
|
||||||
|
|
@ -135,14 +135,19 @@ class GlobalNewLLMConfigRead(BaseModel):
|
||||||
Schema for reading global LLM configs from YAML.
|
Schema for reading global LLM configs from YAML.
|
||||||
Global configs have negative IDs and no search_space_id.
|
Global configs have negative IDs and no search_space_id.
|
||||||
API key is hidden for security.
|
API key is hidden for security.
|
||||||
|
|
||||||
|
ID 0 is reserved for Auto mode which uses LiteLLM Router for load balancing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: int = Field(..., description="Negative ID for global configs")
|
id: int = Field(
|
||||||
|
...,
|
||||||
|
description="Config ID: 0 for Auto mode, negative for global configs",
|
||||||
|
)
|
||||||
name: str
|
name: str
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
|
|
||||||
# LLM Model Configuration (no api_key)
|
# LLM Model Configuration (no api_key)
|
||||||
provider: str # String because YAML doesn't enforce enum
|
provider: str # String because YAML doesn't enforce enum, "AUTO" for Auto mode
|
||||||
custom_provider: str | None = None
|
custom_provider: str | None = None
|
||||||
model_name: str
|
model_name: str
|
||||||
api_base: str | None = None
|
api_base: str | None = None
|
||||||
|
|
@ -154,6 +159,7 @@ class GlobalNewLLMConfigRead(BaseModel):
|
||||||
citations_enabled: bool = True
|
citations_enabled: bool = True
|
||||||
|
|
||||||
is_global: bool = True # Always true for global configs
|
is_global: bool = True # Always true for global configs
|
||||||
|
is_auto_mode: bool = False # True only for Auto mode (ID 0)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
|
||||||
632
surfsense_backend/app/services/llm_router_service.py
Normal file
632
surfsense_backend/app/services/llm_router_service.py
Normal file
|
|
@ -0,0 +1,632 @@
|
||||||
|
"""
|
||||||
|
LiteLLM Router Service for Load Balancing
|
||||||
|
|
||||||
|
This module provides a singleton LiteLLM Router for automatic load balancing
|
||||||
|
across multiple LLM deployments. It handles:
|
||||||
|
- Rate limit management with automatic cooldowns
|
||||||
|
- Automatic failover and retries
|
||||||
|
- Usage-based routing to distribute load evenly
|
||||||
|
|
||||||
|
The router is initialized from global LLM configs and provides both
|
||||||
|
synchronous ChatLiteLLM-like interface and async methods.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||||
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
|
from litellm import Router
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Special ID for Auto mode - uses router for load balancing
|
||||||
|
AUTO_MODE_ID = 0
|
||||||
|
|
||||||
|
# Provider mapping for LiteLLM model string construction
|
||||||
|
PROVIDER_MAP = {
|
||||||
|
"OPENAI": "openai",
|
||||||
|
"ANTHROPIC": "anthropic",
|
||||||
|
"GROQ": "groq",
|
||||||
|
"COHERE": "cohere",
|
||||||
|
"GOOGLE": "gemini",
|
||||||
|
"OLLAMA": "ollama",
|
||||||
|
"MISTRAL": "mistral",
|
||||||
|
"AZURE_OPENAI": "azure",
|
||||||
|
"OPENROUTER": "openrouter",
|
||||||
|
"COMETAPI": "cometapi",
|
||||||
|
"XAI": "xai",
|
||||||
|
"BEDROCK": "bedrock",
|
||||||
|
"AWS_BEDROCK": "bedrock", # Legacy support
|
||||||
|
"VERTEX_AI": "vertex_ai",
|
||||||
|
"TOGETHER_AI": "together_ai",
|
||||||
|
"FIREWORKS_AI": "fireworks_ai",
|
||||||
|
"REPLICATE": "replicate",
|
||||||
|
"PERPLEXITY": "perplexity",
|
||||||
|
"ANYSCALE": "anyscale",
|
||||||
|
"DEEPINFRA": "deepinfra",
|
||||||
|
"CEREBRAS": "cerebras",
|
||||||
|
"SAMBANOVA": "sambanova",
|
||||||
|
"AI21": "ai21",
|
||||||
|
"CLOUDFLARE": "cloudflare",
|
||||||
|
"DATABRICKS": "databricks",
|
||||||
|
"DEEPSEEK": "openai",
|
||||||
|
"ALIBABA_QWEN": "openai",
|
||||||
|
"MOONSHOT": "openai",
|
||||||
|
"ZHIPU": "openai",
|
||||||
|
"HUGGINGFACE": "huggingface",
|
||||||
|
"CUSTOM": "custom",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class LLMRouterService:
|
||||||
|
"""
|
||||||
|
Singleton service for managing LiteLLM Router.
|
||||||
|
|
||||||
|
The router provides automatic load balancing, failover, and rate limit
|
||||||
|
handling across multiple LLM deployments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
_router: Router | None = None
|
||||||
|
_model_list: list[dict] = []
|
||||||
|
_router_settings: dict = {}
|
||||||
|
_initialized: bool = False
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls) -> "LLMRouterService":
|
||||||
|
"""Get the singleton instance of the router service."""
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = cls()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def initialize(
|
||||||
|
cls,
|
||||||
|
global_configs: list[dict],
|
||||||
|
router_settings: dict | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the router with global LLM configurations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
global_configs: List of global LLM config dictionaries from YAML
|
||||||
|
router_settings: Optional router settings (routing_strategy, num_retries, etc.)
|
||||||
|
"""
|
||||||
|
instance = cls.get_instance()
|
||||||
|
|
||||||
|
if instance._initialized:
|
||||||
|
logger.debug("LLM Router already initialized, skipping")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Build model list from global configs
|
||||||
|
model_list = []
|
||||||
|
for config in global_configs:
|
||||||
|
deployment = cls._config_to_deployment(config)
|
||||||
|
if deployment:
|
||||||
|
model_list.append(deployment)
|
||||||
|
|
||||||
|
if not model_list:
|
||||||
|
logger.warning("No valid LLM configs found for router initialization")
|
||||||
|
return
|
||||||
|
|
||||||
|
instance._model_list = model_list
|
||||||
|
instance._router_settings = router_settings or {}
|
||||||
|
|
||||||
|
# Default router settings optimized for rate limit handling
|
||||||
|
default_settings = {
|
||||||
|
"routing_strategy": "usage-based-routing", # Best for rate limit management
|
||||||
|
"num_retries": 3,
|
||||||
|
"allowed_fails": 3,
|
||||||
|
"cooldown_time": 60, # Cooldown for 60 seconds after failures
|
||||||
|
"retry_after": 5, # Wait 5 seconds between retries
|
||||||
|
}
|
||||||
|
|
||||||
|
# Merge with provided settings
|
||||||
|
final_settings = {**default_settings, **instance._router_settings}
|
||||||
|
|
||||||
|
try:
|
||||||
|
instance._router = Router(
|
||||||
|
model_list=model_list,
|
||||||
|
routing_strategy=final_settings.get(
|
||||||
|
"routing_strategy", "usage-based-routing"
|
||||||
|
),
|
||||||
|
num_retries=final_settings.get("num_retries", 3),
|
||||||
|
allowed_fails=final_settings.get("allowed_fails", 3),
|
||||||
|
cooldown_time=final_settings.get("cooldown_time", 60),
|
||||||
|
set_verbose=False, # Disable verbose logging in production
|
||||||
|
)
|
||||||
|
instance._initialized = True
|
||||||
|
logger.info(
|
||||||
|
f"LLM Router initialized with {len(model_list)} deployments, "
|
||||||
|
f"strategy: {final_settings.get('routing_strategy')}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize LLM Router: {e}")
|
||||||
|
instance._router = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _config_to_deployment(cls, config: dict) -> dict | None:
|
||||||
|
"""
|
||||||
|
Convert a global LLM config to a router deployment entry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Global LLM config dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Router deployment dictionary or None if invalid
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Skip if essential fields are missing
|
||||||
|
if not config.get("model_name") or not config.get("api_key"):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Build model string
|
||||||
|
if config.get("custom_provider"):
|
||||||
|
model_string = f"{config['custom_provider']}/{config['model_name']}"
|
||||||
|
else:
|
||||||
|
provider = config.get("provider", "").upper()
|
||||||
|
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
|
||||||
|
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||||
|
|
||||||
|
# Build litellm params
|
||||||
|
litellm_params = {
|
||||||
|
"model": model_string,
|
||||||
|
"api_key": config.get("api_key"),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional api_base
|
||||||
|
if config.get("api_base"):
|
||||||
|
litellm_params["api_base"] = config["api_base"]
|
||||||
|
|
||||||
|
# Add any additional litellm parameters
|
||||||
|
if config.get("litellm_params"):
|
||||||
|
litellm_params.update(config["litellm_params"])
|
||||||
|
|
||||||
|
# Extract rate limits if provided
|
||||||
|
deployment = {
|
||||||
|
"model_name": "auto", # All configs use same alias for unified routing
|
||||||
|
"litellm_params": litellm_params,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add rate limits from config if available
|
||||||
|
if config.get("rpm"):
|
||||||
|
deployment["rpm"] = config["rpm"]
|
||||||
|
if config.get("tpm"):
|
||||||
|
deployment["tpm"] = config["tpm"]
|
||||||
|
|
||||||
|
return deployment
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to convert config to deployment: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_router(cls) -> Router | None:
|
||||||
|
"""Get the initialized router instance."""
|
||||||
|
instance = cls.get_instance()
|
||||||
|
return instance._router
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_initialized(cls) -> bool:
|
||||||
|
"""Check if the router has been initialized."""
|
||||||
|
instance = cls.get_instance()
|
||||||
|
return instance._initialized and instance._router is not None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_model_count(cls) -> int:
|
||||||
|
"""Get the number of models in the router."""
|
||||||
|
instance = cls.get_instance()
|
||||||
|
return len(instance._model_list)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatLiteLLMRouter(BaseChatModel):
|
||||||
|
"""
|
||||||
|
A LangChain-compatible chat model that uses LiteLLM Router for load balancing.
|
||||||
|
|
||||||
|
This wraps the LiteLLM Router to provide the same interface as ChatLiteLLM,
|
||||||
|
making it a drop-in replacement for auto-mode routing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Use model_config for Pydantic v2 compatibility
|
||||||
|
model_config = {"arbitrary_types_allowed": True}
|
||||||
|
|
||||||
|
# Public attributes that Pydantic will manage
|
||||||
|
model: str = "auto"
|
||||||
|
streaming: bool = True
|
||||||
|
|
||||||
|
# Bound tools and tool choice for tool calling
|
||||||
|
_bound_tools: list[dict] | None = None
|
||||||
|
_tool_choice: str | dict | None = None
|
||||||
|
_router: Router | None = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
router: Router | None = None,
|
||||||
|
bound_tools: list[dict] | None = None,
|
||||||
|
tool_choice: str | dict | None = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the ChatLiteLLMRouter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
router: LiteLLM Router instance. If None, uses the global singleton.
|
||||||
|
bound_tools: Pre-bound tools for tool calling
|
||||||
|
tool_choice: Tool choice configuration
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
# Store router and tools as private attributes
|
||||||
|
resolved_router = router or LLMRouterService.get_router()
|
||||||
|
object.__setattr__(self, "_router", resolved_router)
|
||||||
|
object.__setattr__(self, "_bound_tools", bound_tools)
|
||||||
|
object.__setattr__(self, "_tool_choice", tool_choice)
|
||||||
|
if not self._router:
|
||||||
|
raise ValueError(
|
||||||
|
"LLM Router not initialized. Call LLMRouterService.initialize() first."
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"ChatLiteLLMRouter initialized with {LLMRouterService.get_model_count()} models"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "litellm-router"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"model": self.model,
|
||||||
|
"model_count": LLMRouterService.get_model_count(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def bind_tools(
|
||||||
|
self,
|
||||||
|
tools: list[Any],
|
||||||
|
*,
|
||||||
|
tool_choice: str | dict | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "ChatLiteLLMRouter":
|
||||||
|
"""
|
||||||
|
Bind tools to the model for function/tool calling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools: List of tools to bind (can be LangChain tools, Pydantic models, or dicts)
|
||||||
|
tool_choice: Optional tool choice strategy ("auto", "required", "none", or specific tool)
|
||||||
|
**kwargs: Additional arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New ChatLiteLLMRouter instance with tools bound
|
||||||
|
"""
|
||||||
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||||
|
|
||||||
|
# Convert tools to OpenAI format
|
||||||
|
formatted_tools = []
|
||||||
|
for tool in tools:
|
||||||
|
if isinstance(tool, dict):
|
||||||
|
# Already in dict format
|
||||||
|
formatted_tools.append(tool)
|
||||||
|
else:
|
||||||
|
# Convert using LangChain utility
|
||||||
|
try:
|
||||||
|
formatted_tools.append(convert_to_openai_tool(tool))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to convert tool {tool}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create a new instance with tools bound
|
||||||
|
return ChatLiteLLMRouter(
|
||||||
|
router=self._router,
|
||||||
|
bound_tools=formatted_tools if formatted_tools else None,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
model=self.model,
|
||||||
|
streaming=self.streaming,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
run_manager: CallbackManagerForLLMRun | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
"""
|
||||||
|
Generate a response using the router (synchronous).
|
||||||
|
"""
|
||||||
|
if not self._router:
|
||||||
|
raise ValueError("Router not initialized")
|
||||||
|
|
||||||
|
# Convert LangChain messages to OpenAI format
|
||||||
|
formatted_messages = self._convert_messages(messages)
|
||||||
|
|
||||||
|
# Add tools if bound
|
||||||
|
call_kwargs = {**kwargs}
|
||||||
|
if self._bound_tools:
|
||||||
|
call_kwargs["tools"] = self._bound_tools
|
||||||
|
if self._tool_choice is not None:
|
||||||
|
call_kwargs["tool_choice"] = self._tool_choice
|
||||||
|
|
||||||
|
# Call router completion
|
||||||
|
response = self._router.completion(
|
||||||
|
model=self.model,
|
||||||
|
messages=formatted_messages,
|
||||||
|
stop=stop,
|
||||||
|
**call_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert response to ChatResult with potential tool calls
|
||||||
|
message = self._convert_response_to_message(response.choices[0].message)
|
||||||
|
generation = ChatGeneration(message=message)
|
||||||
|
|
||||||
|
return ChatResult(generations=[generation])
|
||||||
|
|
||||||
|
async def _agenerate(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
run_manager: CallbackManagerForLLMRun | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
"""
|
||||||
|
Generate a response using the router (asynchronous).
|
||||||
|
"""
|
||||||
|
if not self._router:
|
||||||
|
raise ValueError("Router not initialized")
|
||||||
|
|
||||||
|
# Convert LangChain messages to OpenAI format
|
||||||
|
formatted_messages = self._convert_messages(messages)
|
||||||
|
|
||||||
|
# Add tools if bound
|
||||||
|
call_kwargs = {**kwargs}
|
||||||
|
if self._bound_tools:
|
||||||
|
call_kwargs["tools"] = self._bound_tools
|
||||||
|
if self._tool_choice is not None:
|
||||||
|
call_kwargs["tool_choice"] = self._tool_choice
|
||||||
|
|
||||||
|
# Call router async completion
|
||||||
|
response = await self._router.acompletion(
|
||||||
|
model=self.model,
|
||||||
|
messages=formatted_messages,
|
||||||
|
stop=stop,
|
||||||
|
**call_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert response to ChatResult with potential tool calls
|
||||||
|
message = self._convert_response_to_message(response.choices[0].message)
|
||||||
|
generation = ChatGeneration(message=message)
|
||||||
|
|
||||||
|
return ChatResult(generations=[generation])
|
||||||
|
|
||||||
|
def _stream(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
run_manager: CallbackManagerForLLMRun | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Stream a response using the router (synchronous).
|
||||||
|
"""
|
||||||
|
if not self._router:
|
||||||
|
raise ValueError("Router not initialized")
|
||||||
|
|
||||||
|
formatted_messages = self._convert_messages(messages)
|
||||||
|
|
||||||
|
# Add tools if bound
|
||||||
|
call_kwargs = {**kwargs}
|
||||||
|
if self._bound_tools:
|
||||||
|
call_kwargs["tools"] = self._bound_tools
|
||||||
|
if self._tool_choice is not None:
|
||||||
|
call_kwargs["tool_choice"] = self._tool_choice
|
||||||
|
|
||||||
|
# Call router completion with streaming
|
||||||
|
response = self._router.completion(
|
||||||
|
model=self.model,
|
||||||
|
messages=formatted_messages,
|
||||||
|
stop=stop,
|
||||||
|
stream=True,
|
||||||
|
**call_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Yield chunks
|
||||||
|
for chunk in response:
|
||||||
|
if hasattr(chunk, "choices") and chunk.choices:
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
chunk_msg = self._convert_delta_to_chunk(delta)
|
||||||
|
if chunk_msg:
|
||||||
|
yield ChatGenerationChunk(message=chunk_msg)
|
||||||
|
|
||||||
|
async def _astream(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
run_manager: CallbackManagerForLLMRun | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Stream a response using the router (asynchronous).
|
||||||
|
"""
|
||||||
|
if not self._router:
|
||||||
|
raise ValueError("Router not initialized")
|
||||||
|
|
||||||
|
formatted_messages = self._convert_messages(messages)
|
||||||
|
|
||||||
|
# Add tools if bound
|
||||||
|
call_kwargs = {**kwargs}
|
||||||
|
if self._bound_tools:
|
||||||
|
call_kwargs["tools"] = self._bound_tools
|
||||||
|
if self._tool_choice is not None:
|
||||||
|
call_kwargs["tool_choice"] = self._tool_choice
|
||||||
|
|
||||||
|
# Call router async completion with streaming
|
||||||
|
response = await self._router.acompletion(
|
||||||
|
model=self.model,
|
||||||
|
messages=formatted_messages,
|
||||||
|
stop=stop,
|
||||||
|
stream=True,
|
||||||
|
**call_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Yield chunks asynchronously
|
||||||
|
async for chunk in response:
|
||||||
|
if hasattr(chunk, "choices") and chunk.choices:
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
chunk_msg = self._convert_delta_to_chunk(delta)
|
||||||
|
if chunk_msg:
|
||||||
|
yield ChatGenerationChunk(message=chunk_msg)
|
||||||
|
|
||||||
|
def _convert_messages(self, messages: list[BaseMessage]) -> list[dict]:
|
||||||
|
"""Convert LangChain messages to OpenAI format."""
|
||||||
|
from langchain_core.messages import (
|
||||||
|
AIMessage as AIMsg,
|
||||||
|
HumanMessage,
|
||||||
|
SystemMessage,
|
||||||
|
ToolMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for msg in messages:
|
||||||
|
if isinstance(msg, SystemMessage):
|
||||||
|
result.append({"role": "system", "content": msg.content})
|
||||||
|
elif isinstance(msg, HumanMessage):
|
||||||
|
result.append({"role": "user", "content": msg.content})
|
||||||
|
elif isinstance(msg, AIMsg):
|
||||||
|
ai_msg: dict[str, Any] = {"role": "assistant"}
|
||||||
|
if msg.content:
|
||||||
|
ai_msg["content"] = msg.content
|
||||||
|
# Handle tool calls
|
||||||
|
if hasattr(msg, "tool_calls") and msg.tool_calls:
|
||||||
|
ai_msg["tool_calls"] = [
|
||||||
|
{
|
||||||
|
"id": tc.get("id", ""),
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tc.get("name", ""),
|
||||||
|
"arguments": tc.get("args", "{}")
|
||||||
|
if isinstance(tc.get("args"), str)
|
||||||
|
else __import__("json").dumps(tc.get("args", {})),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for tc in msg.tool_calls
|
||||||
|
]
|
||||||
|
result.append(ai_msg)
|
||||||
|
elif isinstance(msg, ToolMessage):
|
||||||
|
result.append(
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": msg.tool_call_id,
|
||||||
|
"content": msg.content
|
||||||
|
if isinstance(msg.content, str)
|
||||||
|
else __import__("json").dumps(msg.content),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fallback for other message types
|
||||||
|
role = getattr(msg, "type", "user")
|
||||||
|
if role == "human":
|
||||||
|
role = "user"
|
||||||
|
elif role == "ai":
|
||||||
|
role = "assistant"
|
||||||
|
result.append({"role": role, "content": msg.content})
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _convert_response_to_message(self, response_message: Any) -> AIMessage:
|
||||||
|
"""Convert a LiteLLM response message to a LangChain AIMessage."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
content = getattr(response_message, "content", None) or ""
|
||||||
|
|
||||||
|
# Check for tool calls
|
||||||
|
tool_calls = []
|
||||||
|
if hasattr(response_message, "tool_calls") and response_message.tool_calls:
|
||||||
|
for tc in response_message.tool_calls:
|
||||||
|
tool_call = {
|
||||||
|
"id": tc.id if hasattr(tc, "id") else "",
|
||||||
|
"name": tc.function.name if hasattr(tc, "function") else "",
|
||||||
|
"args": {},
|
||||||
|
}
|
||||||
|
# Parse arguments
|
||||||
|
if hasattr(tc, "function") and hasattr(tc.function, "arguments"):
|
||||||
|
try:
|
||||||
|
tool_call["args"] = json.loads(tc.function.arguments)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
tool_call["args"] = tc.function.arguments
|
||||||
|
tool_calls.append(tool_call)
|
||||||
|
|
||||||
|
if tool_calls:
|
||||||
|
return AIMessage(content=content, tool_calls=tool_calls)
|
||||||
|
return AIMessage(content=content)
|
||||||
|
|
||||||
|
def _convert_delta_to_chunk(self, delta: Any) -> AIMessageChunk | None:
|
||||||
|
"""Convert a streaming delta to an AIMessageChunk."""
|
||||||
|
|
||||||
|
content = getattr(delta, "content", None) or ""
|
||||||
|
|
||||||
|
# Check for tool calls in delta
|
||||||
|
tool_call_chunks = []
|
||||||
|
if hasattr(delta, "tool_calls") and delta.tool_calls:
|
||||||
|
for tc in delta.tool_calls:
|
||||||
|
chunk = {
|
||||||
|
"index": tc.index if hasattr(tc, "index") else 0,
|
||||||
|
"id": tc.id if hasattr(tc, "id") else None,
|
||||||
|
"name": tc.function.name
|
||||||
|
if hasattr(tc, "function") and hasattr(tc.function, "name")
|
||||||
|
else None,
|
||||||
|
"args": tc.function.arguments
|
||||||
|
if hasattr(tc, "function") and hasattr(tc.function, "arguments")
|
||||||
|
else "",
|
||||||
|
}
|
||||||
|
tool_call_chunks.append(chunk)
|
||||||
|
|
||||||
|
if content or tool_call_chunks:
|
||||||
|
if tool_call_chunks:
|
||||||
|
return AIMessageChunk(
|
||||||
|
content=content, tool_call_chunks=tool_call_chunks
|
||||||
|
)
|
||||||
|
return AIMessageChunk(content=content)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_auto_mode_llm() -> ChatLiteLLMRouter | None:
|
||||||
|
"""
|
||||||
|
Get a ChatLiteLLMRouter instance for auto mode.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatLiteLLMRouter instance or None if router not initialized
|
||||||
|
"""
|
||||||
|
if not LLMRouterService.is_initialized():
|
||||||
|
logger.warning("LLM Router not initialized for auto mode")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return ChatLiteLLMRouter()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def is_auto_mode(llm_config_id: int | None) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the given LLM config ID represents Auto mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm_config_id: The LLM config ID to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if this is Auto mode, False otherwise
|
||||||
|
"""
|
||||||
|
return llm_config_id == AUTO_MODE_ID
|
||||||
|
|
@ -8,6 +8,12 @@ from sqlalchemy.future import select
|
||||||
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.db import NewLLMConfig, SearchSpace
|
from app.db import NewLLMConfig, SearchSpace
|
||||||
|
from app.services.llm_router_service import (
|
||||||
|
AUTO_MODE_ID,
|
||||||
|
ChatLiteLLMRouter,
|
||||||
|
LLMRouterService,
|
||||||
|
is_auto_mode,
|
||||||
|
)
|
||||||
|
|
||||||
# Configure litellm to automatically drop unsupported parameters
|
# Configure litellm to automatically drop unsupported parameters
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
|
|
@ -23,15 +29,26 @@ class LLMRole:
|
||||||
def get_global_llm_config(llm_config_id: int) -> dict | None:
|
def get_global_llm_config(llm_config_id: int) -> dict | None:
|
||||||
"""
|
"""
|
||||||
Get a global LLM configuration by ID.
|
Get a global LLM configuration by ID.
|
||||||
Global configs have negative IDs.
|
Global configs have negative IDs. ID 0 is reserved for Auto mode.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
llm_config_id: The ID of the global config (should be negative)
|
llm_config_id: The ID of the global config (should be negative or 0 for Auto)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Global config dictionary or None if not found
|
dict: Global config dictionary or None if not found
|
||||||
"""
|
"""
|
||||||
if llm_config_id >= 0:
|
# Auto mode (ID 0) is handled separately via the router
|
||||||
|
if llm_config_id == AUTO_MODE_ID:
|
||||||
|
return {
|
||||||
|
"id": AUTO_MODE_ID,
|
||||||
|
"name": "Auto (Load Balanced)",
|
||||||
|
"description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling",
|
||||||
|
"provider": "AUTO",
|
||||||
|
"model_name": "auto",
|
||||||
|
"is_auto_mode": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
if llm_config_id > 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
for cfg in config.GLOBAL_LLM_CONFIGS:
|
for cfg in config.GLOBAL_LLM_CONFIGS:
|
||||||
|
|
@ -145,19 +162,22 @@ async def validate_llm_config(
|
||||||
|
|
||||||
async def get_search_space_llm_instance(
|
async def get_search_space_llm_instance(
|
||||||
session: AsyncSession, search_space_id: int, role: str
|
session: AsyncSession, search_space_id: int, role: str
|
||||||
) -> ChatLiteLLM | None:
|
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||||
"""
|
"""
|
||||||
Get a ChatLiteLLM instance for a specific search space and role.
|
Get a ChatLiteLLM instance for a specific search space and role.
|
||||||
|
|
||||||
LLM preferences are stored at the search space level and shared by all members.
|
LLM preferences are stored at the search space level and shared by all members.
|
||||||
|
|
||||||
|
If Auto mode (ID 0) is configured, returns a ChatLiteLLMRouter that uses
|
||||||
|
LiteLLM Router for automatic load balancing across available providers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: Database session
|
session: Database session
|
||||||
search_space_id: Search Space ID
|
search_space_id: Search Space ID
|
||||||
role: LLM role ('agent' or 'document_summary')
|
role: LLM role ('agent' or 'document_summary')
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ChatLiteLLM instance or None if not found
|
ChatLiteLLM or ChatLiteLLMRouter instance, or None if not found
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Get the search space with its LLM preferences
|
# Get the search space with its LLM preferences
|
||||||
|
|
@ -180,10 +200,28 @@ async def get_search_space_llm_instance(
|
||||||
logger.error(f"Invalid LLM role: {role}")
|
logger.error(f"Invalid LLM role: {role}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if not llm_config_id:
|
if llm_config_id is None:
|
||||||
logger.error(f"No {role} LLM configured for search space {search_space_id}")
|
logger.error(f"No {role} LLM configured for search space {search_space_id}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Check for Auto mode (ID 0) - use router for load balancing
|
||||||
|
if is_auto_mode(llm_config_id):
|
||||||
|
if not LLMRouterService.is_initialized():
|
||||||
|
logger.error(
|
||||||
|
"Auto mode requested but LLM Router not initialized. "
|
||||||
|
"Ensure global_llm_config.yaml exists with valid configs."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.debug(
|
||||||
|
f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}"
|
||||||
|
)
|
||||||
|
return ChatLiteLLMRouter()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
# Check if this is a global config (negative ID)
|
# Check if this is a global config (negative ID)
|
||||||
if llm_config_id < 0:
|
if llm_config_id < 0:
|
||||||
global_config = get_global_llm_config(llm_config_id)
|
global_config = get_global_llm_config(llm_config_id)
|
||||||
|
|
@ -328,14 +366,14 @@ async def get_search_space_llm_instance(
|
||||||
|
|
||||||
async def get_agent_llm(
|
async def get_agent_llm(
|
||||||
session: AsyncSession, search_space_id: int
|
session: AsyncSession, search_space_id: int
|
||||||
) -> ChatLiteLLM | None:
|
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||||
"""Get the search space's agent LLM instance for chat operations."""
|
"""Get the search space's agent LLM instance for chat operations."""
|
||||||
return await get_search_space_llm_instance(session, search_space_id, LLMRole.AGENT)
|
return await get_search_space_llm_instance(session, search_space_id, LLMRole.AGENT)
|
||||||
|
|
||||||
|
|
||||||
async def get_document_summary_llm(
|
async def get_document_summary_llm(
|
||||||
session: AsyncSession, search_space_id: int
|
session: AsyncSession, search_space_id: int
|
||||||
) -> ChatLiteLLM | None:
|
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||||
"""Get the search space's document summary LLM instance."""
|
"""Get the search space's document summary LLM instance."""
|
||||||
return await get_search_space_llm_instance(
|
return await get_search_space_llm_instance(
|
||||||
session, search_space_id, LLMRole.DOCUMENT_SUMMARY
|
session, search_space_id, LLMRole.DOCUMENT_SUMMARY
|
||||||
|
|
@ -345,7 +383,7 @@ async def get_document_summary_llm(
|
||||||
# Backward-compatible alias (LLM preferences are now per-search-space, not per-user)
|
# Backward-compatible alias (LLM preferences are now per-search-space, not per-user)
|
||||||
async def get_user_long_context_llm(
|
async def get_user_long_context_llm(
|
||||||
session: AsyncSession, user_id: str, search_space_id: int
|
session: AsyncSession, user_id: str, search_space_id: int
|
||||||
) -> ChatLiteLLM | None:
|
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||||
"""
|
"""
|
||||||
Deprecated: Use get_document_summary_llm instead.
|
Deprecated: Use get_document_summary_llm instead.
|
||||||
The user_id parameter is ignored as LLM preferences are now per-search-space.
|
The user_id parameter is ignored as LLM preferences are now per-search-space.
|
||||||
|
|
|
||||||
|
|
@ -1215,8 +1215,12 @@ async def stream_new_chat(
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Handle any errors
|
# Handle any errors
|
||||||
|
import traceback
|
||||||
|
|
||||||
error_message = f"Error during chat: {e!s}"
|
error_message = f"Error during chat: {e!s}"
|
||||||
print(f"[stream_new_chat] {error_message}")
|
print(f"[stream_new_chat] {error_message}")
|
||||||
|
print(f"[stream_new_chat] Exception type: {type(e).__name__}")
|
||||||
|
print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}")
|
||||||
|
|
||||||
# Close any open text block
|
# Close any open text block
|
||||||
if current_text_id is not None:
|
if current_text_id is not None:
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,13 @@ export function DashboardClientLayout({
|
||||||
const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom);
|
const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom);
|
||||||
|
|
||||||
const isOnboardingComplete = useCallback(() => {
|
const isOnboardingComplete = useCallback(() => {
|
||||||
return !!(preferences.agent_llm_id && preferences.document_summary_llm_id);
|
// Check that both LLM IDs are set (including 0 for Auto mode)
|
||||||
|
return (
|
||||||
|
preferences.agent_llm_id !== null &&
|
||||||
|
preferences.agent_llm_id !== undefined &&
|
||||||
|
preferences.document_summary_llm_id !== null &&
|
||||||
|
preferences.document_summary_llm_id !== undefined
|
||||||
|
);
|
||||||
}, [preferences]);
|
}, [preferences]);
|
||||||
|
|
||||||
const { data: access = null, isLoading: accessLoading } = useAtomValue(myAccessAtom);
|
const { data: access = null, isLoading: accessLoading } = useAtomValue(myAccessAtom);
|
||||||
|
|
|
||||||
|
|
@ -53,8 +53,12 @@ export default function OnboardPage() {
|
||||||
}
|
}
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
// Check if onboarding is already complete
|
// Check if onboarding is already complete (including 0 for Auto mode)
|
||||||
const isOnboardingComplete = preferences.agent_llm_id && preferences.document_summary_llm_id;
|
const isOnboardingComplete =
|
||||||
|
preferences.agent_llm_id !== null &&
|
||||||
|
preferences.agent_llm_id !== undefined &&
|
||||||
|
preferences.document_summary_llm_id !== null &&
|
||||||
|
preferences.document_summary_llm_id !== undefined;
|
||||||
|
|
||||||
// If onboarding is already complete, redirect immediately
|
// If onboarding is already complete, redirect immediately
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
|
|
||||||
|
|
@ -485,7 +485,8 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
||||||
if (agentLlmId === null || agentLlmId === undefined) return false;
|
if (agentLlmId === null || agentLlmId === undefined) return false;
|
||||||
|
|
||||||
// Check if the configured model actually exists
|
// Check if the configured model actually exists
|
||||||
if (agentLlmId < 0) {
|
// Auto mode (ID 0) and global configs (negative IDs) are in globalConfigs
|
||||||
|
if (agentLlmId <= 0) {
|
||||||
return globalConfigs?.some((c) => c.id === agentLlmId) ?? false;
|
return globalConfigs?.some((c) => c.id === agentLlmId) ?? false;
|
||||||
}
|
}
|
||||||
return userConfigs?.some((c) => c.id === agentLlmId) ?? false;
|
return userConfigs?.some((c) => c.id === agentLlmId) ?? false;
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { useAtomValue } from "jotai";
|
import { useAtomValue } from "jotai";
|
||||||
import { AlertCircle, Bot, ChevronRight, Globe, User, X } from "lucide-react";
|
import { AlertCircle, Bot, ChevronRight, Globe, Shuffle, User, X, Zap } from "lucide-react";
|
||||||
import { AnimatePresence, motion } from "motion/react";
|
import { AnimatePresence, motion } from "motion/react";
|
||||||
import { useCallback, useEffect, useState } from "react";
|
import { useCallback, useEffect, useState } from "react";
|
||||||
import { createPortal } from "react-dom";
|
import { createPortal } from "react-dom";
|
||||||
|
|
@ -62,9 +62,13 @@ export function ModelConfigSidebar({
|
||||||
return () => window.removeEventListener("keydown", handleEscape);
|
return () => window.removeEventListener("keydown", handleEscape);
|
||||||
}, [open, onOpenChange]);
|
}, [open, onOpenChange]);
|
||||||
|
|
||||||
|
// Check if this is Auto mode
|
||||||
|
const isAutoMode = config && "is_auto_mode" in config && config.is_auto_mode;
|
||||||
|
|
||||||
// Get title based on mode
|
// Get title based on mode
|
||||||
const getTitle = () => {
|
const getTitle = () => {
|
||||||
if (mode === "create") return "Add New Configuration";
|
if (mode === "create") return "Add New Configuration";
|
||||||
|
if (isAutoMode) return "Auto Mode (Load Balanced)";
|
||||||
if (isGlobal) return "View Global Configuration";
|
if (isGlobal) return "View Global Configuration";
|
||||||
return "Edit Configuration";
|
return "Edit Configuration";
|
||||||
};
|
};
|
||||||
|
|
@ -187,15 +191,37 @@ export function ModelConfigSidebar({
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
{/* Header */}
|
{/* Header */}
|
||||||
<div className="flex items-center justify-between px-6 py-4 border-b border-border/50 bg-muted/20">
|
<div
|
||||||
|
className={cn(
|
||||||
|
"flex items-center justify-between px-6 py-4 border-b border-border/50",
|
||||||
|
isAutoMode ? "bg-gradient-to-r from-violet-500/10 to-purple-500/10" : "bg-muted/20"
|
||||||
|
)}
|
||||||
|
>
|
||||||
<div className="flex items-center gap-3">
|
<div className="flex items-center gap-3">
|
||||||
<div className="flex items-center justify-center size-10 rounded-xl bg-primary/10">
|
<div
|
||||||
|
className={cn(
|
||||||
|
"flex items-center justify-center size-10 rounded-xl",
|
||||||
|
isAutoMode ? "bg-gradient-to-br from-violet-500 to-purple-600" : "bg-primary/10"
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{isAutoMode ? (
|
||||||
|
<Shuffle className="size-5 text-white" />
|
||||||
|
) : (
|
||||||
<Bot className="size-5 text-primary" />
|
<Bot className="size-5 text-primary" />
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<h2 className="text-base sm:text-lg font-semibold">{getTitle()}</h2>
|
<h2 className="text-base sm:text-lg font-semibold">{getTitle()}</h2>
|
||||||
<div className="flex items-center gap-2 mt-0.5">
|
<div className="flex items-center gap-2 mt-0.5">
|
||||||
{isGlobal ? (
|
{isAutoMode ? (
|
||||||
|
<Badge
|
||||||
|
variant="secondary"
|
||||||
|
className="gap-1 text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300"
|
||||||
|
>
|
||||||
|
<Zap className="size-3" />
|
||||||
|
Recommended
|
||||||
|
</Badge>
|
||||||
|
) : isGlobal ? (
|
||||||
<Badge variant="secondary" className="gap-1 text-xs">
|
<Badge variant="secondary" className="gap-1 text-xs">
|
||||||
<Globe className="size-3" />
|
<Globe className="size-3" />
|
||||||
Global
|
Global
|
||||||
|
|
@ -206,7 +232,7 @@ export function ModelConfigSidebar({
|
||||||
Custom
|
Custom
|
||||||
</Badge>
|
</Badge>
|
||||||
) : null}
|
) : null}
|
||||||
{config && (
|
{config && !isAutoMode && (
|
||||||
<span className="text-xs text-muted-foreground">{config.model_name}</span>
|
<span className="text-xs text-muted-foreground">{config.model_name}</span>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|
@ -226,8 +252,19 @@ export function ModelConfigSidebar({
|
||||||
{/* Content - use overflow-y-auto instead of ScrollArea for better compatibility */}
|
{/* Content - use overflow-y-auto instead of ScrollArea for better compatibility */}
|
||||||
<div className="flex-1 overflow-y-auto">
|
<div className="flex-1 overflow-y-auto">
|
||||||
<div className="p-6">
|
<div className="p-6">
|
||||||
|
{/* Auto mode info banner */}
|
||||||
|
{isAutoMode && (
|
||||||
|
<Alert className="mb-6 border-violet-500/30 bg-violet-500/5">
|
||||||
|
<Shuffle className="size-4 text-violet-500" />
|
||||||
|
<AlertDescription className="text-sm text-violet-700 dark:text-violet-400">
|
||||||
|
Auto mode automatically distributes requests across all available LLM
|
||||||
|
providers to optimize performance and avoid rate limits.
|
||||||
|
</AlertDescription>
|
||||||
|
</Alert>
|
||||||
|
)}
|
||||||
|
|
||||||
{/* Global config notice */}
|
{/* Global config notice */}
|
||||||
{isGlobal && mode !== "create" && (
|
{isGlobal && !isAutoMode && mode !== "create" && (
|
||||||
<Alert className="mb-6 border-amber-500/30 bg-amber-500/5">
|
<Alert className="mb-6 border-amber-500/30 bg-amber-500/5">
|
||||||
<AlertCircle className="size-4 text-amber-500" />
|
<AlertCircle className="size-4 text-amber-500" />
|
||||||
<AlertDescription className="text-sm text-amber-700 dark:text-amber-400">
|
<AlertDescription className="text-sm text-amber-700 dark:text-amber-400">
|
||||||
|
|
@ -247,6 +284,87 @@ export function ModelConfigSidebar({
|
||||||
mode="create"
|
mode="create"
|
||||||
submitLabel="Create & Use"
|
submitLabel="Create & Use"
|
||||||
/>
|
/>
|
||||||
|
) : isAutoMode && config ? (
|
||||||
|
// Special view for Auto mode
|
||||||
|
<div className="space-y-6">
|
||||||
|
{/* Auto Mode Features */}
|
||||||
|
<div className="space-y-4">
|
||||||
|
<div className="space-y-1.5">
|
||||||
|
<div className="text-xs font-medium text-muted-foreground uppercase tracking-wider">
|
||||||
|
How It Works
|
||||||
|
</div>
|
||||||
|
<p className="text-sm text-muted-foreground">{config.description}</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="h-px bg-border/50" />
|
||||||
|
|
||||||
|
<div className="space-y-3">
|
||||||
|
<div className="text-xs font-medium text-muted-foreground uppercase tracking-wider">
|
||||||
|
Key Benefits
|
||||||
|
</div>
|
||||||
|
<div className="space-y-2">
|
||||||
|
<div className="flex items-start gap-3 p-3 rounded-lg bg-violet-50 dark:bg-violet-900/20 border border-violet-200 dark:border-violet-800/50">
|
||||||
|
<Zap className="size-4 text-violet-600 dark:text-violet-400 mt-0.5 shrink-0" />
|
||||||
|
<div>
|
||||||
|
<p className="text-sm font-medium text-violet-900 dark:text-violet-100">
|
||||||
|
Automatic Load Balancing
|
||||||
|
</p>
|
||||||
|
<p className="text-xs text-violet-700 dark:text-violet-300">
|
||||||
|
Distributes requests across all configured LLM providers
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="flex items-start gap-3 p-3 rounded-lg bg-violet-50 dark:bg-violet-900/20 border border-violet-200 dark:border-violet-800/50">
|
||||||
|
<Zap className="size-4 text-violet-600 dark:text-violet-400 mt-0.5 shrink-0" />
|
||||||
|
<div>
|
||||||
|
<p className="text-sm font-medium text-violet-900 dark:text-violet-100">
|
||||||
|
Rate Limit Protection
|
||||||
|
</p>
|
||||||
|
<p className="text-xs text-violet-700 dark:text-violet-300">
|
||||||
|
Automatically handles rate limits with cooldowns and retries
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="flex items-start gap-3 p-3 rounded-lg bg-violet-50 dark:bg-violet-900/20 border border-violet-200 dark:border-violet-800/50">
|
||||||
|
<Zap className="size-4 text-violet-600 dark:text-violet-400 mt-0.5 shrink-0" />
|
||||||
|
<div>
|
||||||
|
<p className="text-sm font-medium text-violet-900 dark:text-violet-100">
|
||||||
|
Automatic Failover
|
||||||
|
</p>
|
||||||
|
<p className="text-xs text-violet-700 dark:text-violet-300">
|
||||||
|
Falls back to other providers if one becomes unavailable
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Action Buttons */}
|
||||||
|
<div className="flex gap-3 pt-4 border-t border-border/50">
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
className="flex-1"
|
||||||
|
onClick={() => onOpenChange(false)}
|
||||||
|
>
|
||||||
|
Close
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
className="flex-1 gap-2 bg-gradient-to-r from-violet-500 to-purple-600 hover:from-violet-600 hover:to-purple-700"
|
||||||
|
onClick={handleUseGlobalConfig}
|
||||||
|
disabled={isSubmitting}
|
||||||
|
>
|
||||||
|
{isSubmitting ? (
|
||||||
|
<>Loading...</>
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
|
<ChevronRight className="size-4" />
|
||||||
|
Use Auto Mode
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
) : isGlobal && config ? (
|
) : isGlobal && config ? (
|
||||||
// Read-only view for global configs
|
// Read-only view for global configs
|
||||||
<div className="space-y-6">
|
<div className="space-y-6">
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ import {
|
||||||
Globe,
|
Globe,
|
||||||
Plus,
|
Plus,
|
||||||
Settings2,
|
Settings2,
|
||||||
|
Shuffle,
|
||||||
Sparkles,
|
Sparkles,
|
||||||
User,
|
User,
|
||||||
Zap,
|
Zap,
|
||||||
|
|
@ -43,8 +44,14 @@ import type {
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
// Provider icons mapping
|
// Provider icons mapping
|
||||||
const getProviderIcon = (provider: string) => {
|
const getProviderIcon = (provider: string, isAutoMode?: boolean) => {
|
||||||
const iconClass = "size-4";
|
const iconClass = "size-4";
|
||||||
|
|
||||||
|
// Special icon for Auto mode
|
||||||
|
if (isAutoMode || provider?.toUpperCase() === "AUTO") {
|
||||||
|
return <Shuffle className={cn(iconClass, "text-violet-500")} />;
|
||||||
|
}
|
||||||
|
|
||||||
switch (provider?.toUpperCase()) {
|
switch (provider?.toUpperCase()) {
|
||||||
case "OPENAI":
|
case "OPENAI":
|
||||||
return <Sparkles className={cn(iconClass, "text-emerald-500")} />;
|
return <Sparkles className={cn(iconClass, "text-emerald-500")} />;
|
||||||
|
|
@ -90,14 +97,19 @@ export function ModelSelector({ onEdit, onAddNew, className }: ModelSelectorProp
|
||||||
const agentLlmId = preferences.agent_llm_id;
|
const agentLlmId = preferences.agent_llm_id;
|
||||||
if (agentLlmId === null || agentLlmId === undefined) return null;
|
if (agentLlmId === null || agentLlmId === undefined) return null;
|
||||||
|
|
||||||
// Check if it's a global config (negative ID)
|
// Check if it's Auto mode (ID 0) or global config (negative ID)
|
||||||
if (agentLlmId < 0) {
|
if (agentLlmId <= 0) {
|
||||||
return globalConfigs?.find((c) => c.id === agentLlmId) ?? null;
|
return globalConfigs?.find((c) => c.id === agentLlmId) ?? null;
|
||||||
}
|
}
|
||||||
// Otherwise, check user configs
|
// Otherwise, check user configs
|
||||||
return userConfigs?.find((c) => c.id === agentLlmId) ?? null;
|
return userConfigs?.find((c) => c.id === agentLlmId) ?? null;
|
||||||
}, [preferences, globalConfigs, userConfigs]);
|
}, [preferences, globalConfigs, userConfigs]);
|
||||||
|
|
||||||
|
// Check if current config is Auto mode
|
||||||
|
const isCurrentAutoMode = useMemo(() => {
|
||||||
|
return currentConfig && "is_auto_mode" in currentConfig && currentConfig.is_auto_mode;
|
||||||
|
}, [currentConfig]);
|
||||||
|
|
||||||
// Filter configs based on search
|
// Filter configs based on search
|
||||||
const filteredGlobalConfigs = useMemo(() => {
|
const filteredGlobalConfigs = useMemo(() => {
|
||||||
if (!globalConfigs) return [];
|
if (!globalConfigs) return [];
|
||||||
|
|
@ -184,14 +196,23 @@ export function ModelSelector({ onEdit, onAddNew, className }: ModelSelectorProp
|
||||||
</>
|
</>
|
||||||
) : currentConfig ? (
|
) : currentConfig ? (
|
||||||
<>
|
<>
|
||||||
{getProviderIcon(currentConfig.provider)}
|
{getProviderIcon(currentConfig.provider, isCurrentAutoMode ?? false)}
|
||||||
<span className="max-w-[100px] md:max-w-[150px] truncate hidden md:inline">
|
<span className="max-w-[100px] md:max-w-[150px] truncate hidden md:inline">
|
||||||
{currentConfig.name}
|
{currentConfig.name}
|
||||||
</span>
|
</span>
|
||||||
|
{isCurrentAutoMode ? (
|
||||||
|
<Badge
|
||||||
|
variant="secondary"
|
||||||
|
className="ml-1 text-[10px] px-1.5 py-0 h-4 bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300"
|
||||||
|
>
|
||||||
|
Balanced
|
||||||
|
</Badge>
|
||||||
|
) : (
|
||||||
<Badge variant="secondary" className="ml-1 text-[10px] px-1.5 py-0 h-4 bg-muted/80">
|
<Badge variant="secondary" className="ml-1 text-[10px] px-1.5 py-0 h-4 bg-muted/80">
|
||||||
{currentConfig.model_name.split("/").pop()?.slice(0, 10) ||
|
{currentConfig.model_name.split("/").pop()?.slice(0, 10) ||
|
||||||
currentConfig.model_name.slice(0, 10)}
|
currentConfig.model_name.slice(0, 10)}
|
||||||
</Badge>
|
</Badge>
|
||||||
|
)}
|
||||||
</>
|
</>
|
||||||
) : (
|
) : (
|
||||||
<>
|
<>
|
||||||
|
|
@ -246,6 +267,7 @@ export function ModelSelector({ onEdit, onAddNew, className }: ModelSelectorProp
|
||||||
</div>
|
</div>
|
||||||
{filteredGlobalConfigs.map((config) => {
|
{filteredGlobalConfigs.map((config) => {
|
||||||
const isSelected = currentConfig?.id === config.id;
|
const isSelected = currentConfig?.id === config.id;
|
||||||
|
const isAutoMode = "is_auto_mode" in config && config.is_auto_mode;
|
||||||
return (
|
return (
|
||||||
<CommandItem
|
<CommandItem
|
||||||
key={`global-${config.id}`}
|
key={`global-${config.id}`}
|
||||||
|
|
@ -254,22 +276,33 @@ export function ModelSelector({ onEdit, onAddNew, className }: ModelSelectorProp
|
||||||
className={cn(
|
className={cn(
|
||||||
"mx-2 rounded-lg mb-1 cursor-pointer group transition-all",
|
"mx-2 rounded-lg mb-1 cursor-pointer group transition-all",
|
||||||
"hover:bg-accent/50",
|
"hover:bg-accent/50",
|
||||||
isSelected && "bg-accent/80"
|
isSelected && "bg-accent/80",
|
||||||
|
isAutoMode && "border border-violet-200 dark:border-violet-800/50"
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<div className="flex items-center justify-between w-full gap-2">
|
<div className="flex items-center justify-between w-full gap-2">
|
||||||
<div className="flex items-center gap-3 min-w-0 flex-1">
|
<div className="flex items-center gap-3 min-w-0 flex-1">
|
||||||
<div className="shrink-0">{getProviderIcon(config.provider)}</div>
|
<div className="shrink-0">
|
||||||
|
{getProviderIcon(config.provider, isAutoMode)}
|
||||||
|
</div>
|
||||||
<div className="min-w-0 flex-1">
|
<div className="min-w-0 flex-1">
|
||||||
<div className="flex items-center gap-2">
|
<div className="flex items-center gap-2">
|
||||||
<span className="font-medium truncate">{config.name}</span>
|
<span className="font-medium truncate">{config.name}</span>
|
||||||
|
{isAutoMode && (
|
||||||
|
<Badge
|
||||||
|
variant="secondary"
|
||||||
|
className="text-[9px] px-1 py-0 h-3.5 bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300 border-0"
|
||||||
|
>
|
||||||
|
Recommended
|
||||||
|
</Badge>
|
||||||
|
)}
|
||||||
{isSelected && <Check className="size-3.5 text-primary shrink-0" />}
|
{isSelected && <Check className="size-3.5 text-primary shrink-0" />}
|
||||||
</div>
|
</div>
|
||||||
<div className="flex items-center gap-1.5 mt-0.5">
|
<div className="flex items-center gap-1.5 mt-0.5">
|
||||||
<span className="text-xs text-muted-foreground truncate">
|
<span className="text-xs text-muted-foreground truncate">
|
||||||
{config.model_name}
|
{isAutoMode ? "Auto load balancing" : config.model_name}
|
||||||
</span>
|
</span>
|
||||||
{config.citations_enabled && (
|
{!isAutoMode && config.citations_enabled && (
|
||||||
<Badge
|
<Badge
|
||||||
variant="outline"
|
variant="outline"
|
||||||
className="text-[9px] px-1 py-0 h-3.5 bg-primary/10 text-primary border-primary/20"
|
className="text-[9px] px-1 py-0 h-3.5 bg-primary/10 text-primary border-primary/20"
|
||||||
|
|
@ -280,6 +313,7 @@ export function ModelSelector({ onEdit, onAddNew, className }: ModelSelectorProp
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
{!isAutoMode && (
|
||||||
<Button
|
<Button
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
size="icon"
|
size="icon"
|
||||||
|
|
@ -288,6 +322,7 @@ export function ModelSelector({ onEdit, onAddNew, className }: ModelSelectorProp
|
||||||
>
|
>
|
||||||
<Edit3 className="size-3.5 text-muted-foreground" />
|
<Edit3 className="size-3.5 text-muted-foreground" />
|
||||||
</Button>
|
</Button>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
</CommandItem>
|
</CommandItem>
|
||||||
);
|
);
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,16 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { useAtomValue } from "jotai";
|
import { useAtomValue } from "jotai";
|
||||||
import { AlertCircle, Bot, CheckCircle, FileText, RefreshCw, RotateCcw, Save } from "lucide-react";
|
import {
|
||||||
|
AlertCircle,
|
||||||
|
Bot,
|
||||||
|
CheckCircle,
|
||||||
|
FileText,
|
||||||
|
RefreshCw,
|
||||||
|
RotateCcw,
|
||||||
|
Save,
|
||||||
|
Shuffle,
|
||||||
|
} from "lucide-react";
|
||||||
import { motion } from "motion/react";
|
import { motion } from "motion/react";
|
||||||
import { useEffect, useState } from "react";
|
import { useEffect, useState } from "react";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
|
|
@ -24,6 +33,7 @@ import {
|
||||||
SelectValue,
|
SelectValue,
|
||||||
} from "@/components/ui/select";
|
} from "@/components/ui/select";
|
||||||
import { Spinner } from "@/components/ui/spinner";
|
import { Spinner } from "@/components/ui/spinner";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
const ROLE_DESCRIPTIONS = {
|
const ROLE_DESCRIPTIONS = {
|
||||||
agent: {
|
agent: {
|
||||||
|
|
@ -71,8 +81,8 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
||||||
const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom);
|
const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom);
|
||||||
|
|
||||||
const [assignments, setAssignments] = useState({
|
const [assignments, setAssignments] = useState({
|
||||||
agent_llm_id: preferences.agent_llm_id || "",
|
agent_llm_id: preferences.agent_llm_id ?? "",
|
||||||
document_summary_llm_id: preferences.document_summary_llm_id || "",
|
document_summary_llm_id: preferences.document_summary_llm_id ?? "",
|
||||||
});
|
});
|
||||||
|
|
||||||
const [hasChanges, setHasChanges] = useState(false);
|
const [hasChanges, setHasChanges] = useState(false);
|
||||||
|
|
@ -80,8 +90,8 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const newAssignments = {
|
const newAssignments = {
|
||||||
agent_llm_id: preferences.agent_llm_id || "",
|
agent_llm_id: preferences.agent_llm_id ?? "",
|
||||||
document_summary_llm_id: preferences.document_summary_llm_id || "",
|
document_summary_llm_id: preferences.document_summary_llm_id ?? "",
|
||||||
};
|
};
|
||||||
setAssignments(newAssignments);
|
setAssignments(newAssignments);
|
||||||
setHasChanges(false);
|
setHasChanges(false);
|
||||||
|
|
@ -97,8 +107,8 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
||||||
|
|
||||||
// Check if there are changes compared to current preferences
|
// Check if there are changes compared to current preferences
|
||||||
const currentPrefs = {
|
const currentPrefs = {
|
||||||
agent_llm_id: preferences.agent_llm_id || "",
|
agent_llm_id: preferences.agent_llm_id ?? "",
|
||||||
document_summary_llm_id: preferences.document_summary_llm_id || "",
|
document_summary_llm_id: preferences.document_summary_llm_id ?? "",
|
||||||
};
|
};
|
||||||
|
|
||||||
const hasChangesNow = Object.keys(newAssignments).some(
|
const hasChangesNow = Object.keys(newAssignments).some(
|
||||||
|
|
@ -141,13 +151,19 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
||||||
|
|
||||||
const handleReset = () => {
|
const handleReset = () => {
|
||||||
setAssignments({
|
setAssignments({
|
||||||
agent_llm_id: preferences.agent_llm_id || "",
|
agent_llm_id: preferences.agent_llm_id ?? "",
|
||||||
document_summary_llm_id: preferences.document_summary_llm_id || "",
|
document_summary_llm_id: preferences.document_summary_llm_id ?? "",
|
||||||
});
|
});
|
||||||
setHasChanges(false);
|
setHasChanges(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
const isAssignmentComplete = assignments.agent_llm_id && assignments.document_summary_llm_id;
|
const isAssignmentComplete =
|
||||||
|
assignments.agent_llm_id !== "" &&
|
||||||
|
assignments.agent_llm_id !== null &&
|
||||||
|
assignments.agent_llm_id !== undefined &&
|
||||||
|
assignments.document_summary_llm_id !== "" &&
|
||||||
|
assignments.document_summary_llm_id !== null &&
|
||||||
|
assignments.document_summary_llm_id !== undefined;
|
||||||
|
|
||||||
// Combine global and custom configs (new system)
|
// Combine global and custom configs (new system)
|
||||||
const allConfigs = [
|
const allConfigs = [
|
||||||
|
|
@ -300,22 +316,47 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
||||||
<div className="px-2 py-1.5 text-xs font-semibold text-muted-foreground">
|
<div className="px-2 py-1.5 text-xs font-semibold text-muted-foreground">
|
||||||
Global Configurations
|
Global Configurations
|
||||||
</div>
|
</div>
|
||||||
{globalConfigs.map((config) => (
|
{globalConfigs.map((config) => {
|
||||||
|
const isAutoMode =
|
||||||
|
"is_auto_mode" in config && config.is_auto_mode;
|
||||||
|
return (
|
||||||
<SelectItem key={config.id} value={config.id.toString()}>
|
<SelectItem key={config.id} value={config.id.toString()}>
|
||||||
<div className="flex items-center gap-2">
|
<div className="flex items-center gap-2">
|
||||||
|
{isAutoMode ? (
|
||||||
|
<Badge
|
||||||
|
variant="outline"
|
||||||
|
className="text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300 border-violet-200 dark:border-violet-700"
|
||||||
|
>
|
||||||
|
<Shuffle className="size-3 mr-1" />
|
||||||
|
AUTO
|
||||||
|
</Badge>
|
||||||
|
) : (
|
||||||
<Badge variant="outline" className="text-xs">
|
<Badge variant="outline" className="text-xs">
|
||||||
{config.provider}
|
{config.provider}
|
||||||
</Badge>
|
</Badge>
|
||||||
|
)}
|
||||||
<span>{config.name}</span>
|
<span>{config.name}</span>
|
||||||
|
{!isAutoMode && (
|
||||||
<span className="text-muted-foreground">
|
<span className="text-muted-foreground">
|
||||||
({config.model_name})
|
({config.model_name})
|
||||||
</span>
|
</span>
|
||||||
|
)}
|
||||||
|
{isAutoMode ? (
|
||||||
|
<Badge
|
||||||
|
variant="secondary"
|
||||||
|
className="text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300"
|
||||||
|
>
|
||||||
|
Recommended
|
||||||
|
</Badge>
|
||||||
|
) : (
|
||||||
<Badge variant="secondary" className="text-xs">
|
<Badge variant="secondary" className="text-xs">
|
||||||
🌐 Global
|
🌐 Global
|
||||||
</Badge>
|
</Badge>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
</SelectItem>
|
</SelectItem>
|
||||||
))}
|
);
|
||||||
|
})}
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
|
@ -349,20 +390,56 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{assignedConfig && (
|
{assignedConfig && (
|
||||||
<div className="mt-2 md:mt-3 p-2 md:p-3 bg-muted/50 rounded-lg">
|
<div
|
||||||
|
className={cn(
|
||||||
|
"mt-2 md:mt-3 p-2 md:p-3 rounded-lg",
|
||||||
|
"is_auto_mode" in assignedConfig && assignedConfig.is_auto_mode
|
||||||
|
? "bg-violet-50 dark:bg-violet-900/20 border border-violet-200 dark:border-violet-800/50"
|
||||||
|
: "bg-muted/50"
|
||||||
|
)}
|
||||||
|
>
|
||||||
<div className="flex items-center gap-1.5 md:gap-2 text-xs md:text-sm flex-wrap">
|
<div className="flex items-center gap-1.5 md:gap-2 text-xs md:text-sm flex-wrap">
|
||||||
|
{"is_auto_mode" in assignedConfig && assignedConfig.is_auto_mode ? (
|
||||||
|
<Shuffle className="w-3 h-3 md:w-4 md:h-4 shrink-0 text-violet-600 dark:text-violet-400" />
|
||||||
|
) : (
|
||||||
<Bot className="w-3 h-3 md:w-4 md:h-4 shrink-0" />
|
<Bot className="w-3 h-3 md:w-4 md:h-4 shrink-0" />
|
||||||
|
)}
|
||||||
<span className="font-medium">Assigned:</span>
|
<span className="font-medium">Assigned:</span>
|
||||||
|
{"is_auto_mode" in assignedConfig && assignedConfig.is_auto_mode ? (
|
||||||
|
<Badge
|
||||||
|
variant="secondary"
|
||||||
|
className="text-[10px] md:text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300"
|
||||||
|
>
|
||||||
|
AUTO
|
||||||
|
</Badge>
|
||||||
|
) : (
|
||||||
<Badge variant="secondary" className="text-[10px] md:text-xs">
|
<Badge variant="secondary" className="text-[10px] md:text-xs">
|
||||||
{assignedConfig.provider}
|
{assignedConfig.provider}
|
||||||
</Badge>
|
</Badge>
|
||||||
|
)}
|
||||||
<span>{assignedConfig.name}</span>
|
<span>{assignedConfig.name}</span>
|
||||||
{"is_global" in assignedConfig && assignedConfig.is_global && (
|
{"is_auto_mode" in assignedConfig && assignedConfig.is_auto_mode ? (
|
||||||
|
<Badge
|
||||||
|
variant="outline"
|
||||||
|
className="text-[9px] md:text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300 border-violet-200 dark:border-violet-700"
|
||||||
|
>
|
||||||
|
Recommended
|
||||||
|
</Badge>
|
||||||
|
) : (
|
||||||
|
"is_global" in assignedConfig &&
|
||||||
|
assignedConfig.is_global && (
|
||||||
<Badge variant="outline" className="text-[9px] md:text-xs">
|
<Badge variant="outline" className="text-[9px] md:text-xs">
|
||||||
🌐 Global
|
🌐 Global
|
||||||
</Badge>
|
</Badge>
|
||||||
|
)
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
{"is_auto_mode" in assignedConfig && assignedConfig.is_auto_mode ? (
|
||||||
|
<div className="text-[10px] md:text-xs text-violet-600 dark:text-violet-400 mt-0.5 md:mt-1">
|
||||||
|
Automatically load balances across all available LLM providers
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
<div className="text-[10px] md:text-xs text-muted-foreground mt-0.5 md:mt-1">
|
<div className="text-[10px] md:text-xs text-muted-foreground mt-0.5 md:mt-1">
|
||||||
Model: {assignedConfig.model_name}
|
Model: {assignedConfig.model_name}
|
||||||
</div>
|
</div>
|
||||||
|
|
@ -371,6 +448,8 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
||||||
Base: {assignedConfig.api_base}
|
Base: {assignedConfig.api_base}
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
</CardContent>
|
</CardContent>
|
||||||
|
|
|
||||||
|
|
@ -136,14 +136,15 @@ export const getDefaultSystemInstructionsResponse = z.object({
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Global NewLLMConfig - from YAML, has negative IDs
|
* Global NewLLMConfig - from YAML, has negative IDs
|
||||||
|
* ID 0 is reserved for "Auto" mode which uses LiteLLM Router for load balancing
|
||||||
*/
|
*/
|
||||||
export const globalNewLLMConfig = z.object({
|
export const globalNewLLMConfig = z.object({
|
||||||
id: z.number(), // Negative IDs for global configs
|
id: z.number(), // 0 for Auto mode, negative IDs for global configs
|
||||||
name: z.string(),
|
name: z.string(),
|
||||||
description: z.string().nullable().optional(),
|
description: z.string().nullable().optional(),
|
||||||
|
|
||||||
// LLM Model Configuration (no api_key)
|
// LLM Model Configuration (no api_key)
|
||||||
provider: z.string(), // String because YAML doesn't enforce enum
|
provider: z.string(), // String because YAML doesn't enforce enum, "AUTO" for Auto mode
|
||||||
custom_provider: z.string().nullable().optional(),
|
custom_provider: z.string().nullable().optional(),
|
||||||
model_name: z.string(),
|
model_name: z.string(),
|
||||||
api_base: z.string().nullable().optional(),
|
api_base: z.string().nullable().optional(),
|
||||||
|
|
@ -155,6 +156,7 @@ export const globalNewLLMConfig = z.object({
|
||||||
citations_enabled: z.boolean().default(true),
|
citations_enabled: z.boolean().default(true),
|
||||||
|
|
||||||
is_global: z.literal(true),
|
is_global: z.literal(true),
|
||||||
|
is_auto_mode: z.boolean().optional().default(false), // True only for Auto mode (ID 0)
|
||||||
});
|
});
|
||||||
|
|
||||||
export const getGlobalNewLLMConfigsResponse = z.array(globalNewLLMConfig);
|
export const getGlobalNewLLMConfigsResponse = z.array(globalNewLLMConfig);
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue