mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-26 21:39:43 +02:00
feat: added global llm configurations
This commit is contained in:
parent
48fca3329b
commit
d4345f75e5
24 changed files with 878 additions and 158 deletions
1
surfsense_backend/.gitignore
vendored
1
surfsense_backend/.gitignore
vendored
|
|
@ -11,3 +11,4 @@ celerybeat-schedule*
|
|||
celerybeat-schedule.*
|
||||
celerybeat-schedule.dir
|
||||
celerybeat-schedule.bak
|
||||
global_llm_config.yaml
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
"""remove_fk_constraints_for_global_llm_configs
|
||||
|
||||
Revision ID: 36
|
||||
Revises: 35
|
||||
Create Date: 2025-11-13 23:20:12.912741
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "36"
|
||||
down_revision: str | None = "35"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """
    Drop the foreign keys on the three LLM preference columns.

    Global LLM configs use negative IDs and have no row in the llm_configs
    table, so user_search_space_preferences can no longer constrain these
    columns to reference it.
    """
    fk_names = (
        "user_search_space_preferences_long_context_llm_id_fkey",
        "user_search_space_preferences_fast_llm_id_fkey",
        "user_search_space_preferences_strategic_llm_id_fkey",
    )

    # Same order as the columns are declared: long_context, fast, strategic.
    for fk_name in fk_names:
        op.drop_constraint(
            fk_name,
            "user_search_space_preferences",
            type_="foreignkey",
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """
    Recreate the foreign keys from the LLM preference columns to llm_configs.id.

    Each constraint is restored with ON DELETE SET NULL. This will fail if
    any negative (global config) IDs are still stored in the table.
    """
    # (constraint name, constrained column) in the original creation order.
    restored_fks = (
        (
            "user_search_space_preferences_long_context_llm_id_fkey",
            "long_context_llm_id",
        ),
        (
            "user_search_space_preferences_fast_llm_id_fkey",
            "fast_llm_id",
        ),
        (
            "user_search_space_preferences_strategic_llm_id_fkey",
            "strategic_llm_id",
        ),
    )

    for fk_name, column in restored_fks:
        op.create_foreign_key(
            fk_name,
            "user_search_space_preferences",
            "llm_configs",
            [column],
            ["id"],
            ondelete="SET NULL",
        )
|
||||
|
|
@ -3,6 +3,7 @@ import shutil
|
|||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from chonkie import AutoEmbeddings, CodeChunker, RecursiveChunker
|
||||
from chonkie.embeddings.azure_openai import AzureOpenAIEmbeddings
|
||||
from chonkie.embeddings.registry import EmbeddingsRegistry
|
||||
|
|
@ -80,6 +81,36 @@ def is_ffmpeg_installed():
|
|||
return shutil.which("ffmpeg") is not None
|
||||
|
||||
|
||||
def load_global_llm_configs():
    """
    Load global LLM configurations from the YAML config file.

    Reads ``app/config/global_llm_config.yaml`` under ``BASE_DIR``. The
    example file (``global_llm_config.example.yaml``) is NOT used as a
    fallback; copy it to ``global_llm_config.yaml`` to enable global configs.

    Loading is best-effort: any problem is reported as a warning and an
    empty list is returned, so a bad config file never prevents startup.

    Returns:
        list: The dictionaries listed under ``global_llm_configs``. Empty if
        the file is missing, empty, unreadable, malformed, or the key is
        absent or has no entries.
    """
    global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"

    if not global_config_file.exists():
        # No global configs available
        return []

    try:
        with open(global_config_file, encoding="utf-8") as f:
            data = yaml.safe_load(f)
    except Exception as e:
        # Deliberately broad: this runs while the config module is imported.
        print(f"Warning: Failed to load global LLM configs: {e}")
        return []

    # safe_load yields None for an empty document and may yield any YAML
    # type (list, scalar) if the file does not have a mapping at top level.
    if data is None:
        return []
    if not isinstance(data, dict):
        print(
            "Warning: Failed to load global LLM configs: "
            "top-level YAML value must be a mapping"
        )
        return []

    # A present-but-empty `global_llm_configs:` key parses to None; callers
    # iterate the result, so always hand back a list.
    configs = data.get("global_llm_configs")
    if configs is None:
        return []
    if not isinstance(configs, list):
        print(
            "Warning: Failed to load global LLM configs: "
            "'global_llm_configs' must be a list"
        )
        return []

    # Callers use dict access (`cfg.get("id")`), so drop anything else.
    valid_configs = [cfg for cfg in configs if isinstance(cfg, dict)]
    if len(valid_configs) != len(configs):
        print(
            "Warning: Ignoring global LLM config entries that are not mappings"
        )
    return valid_configs
|
||||
|
||||
|
||||
class Config:
|
||||
# Check if ffmpeg is installed
|
||||
if not is_ffmpeg_installed():
|
||||
|
|
@ -122,6 +153,11 @@ class Config:
|
|||
# LLM instances are now managed per-user through the LLMConfig system
|
||||
# Legacy environment variables removed in favor of user-specific configurations
|
||||
|
||||
# Global LLM Configurations (optional)
|
||||
# Load from global_llm_config.yaml if available
|
||||
# These can be used as default options for users
|
||||
GLOBAL_LLM_CONFIGS = load_global_llm_configs()
|
||||
|
||||
# Chonkie Configuration | Edit this to your needs
|
||||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
|
||||
# Azure OpenAI credentials from environment variables
|
||||
|
|
|
|||
80
surfsense_backend/app/config/global_llm_config.example.yaml
Normal file
80
surfsense_backend/app/config/global_llm_config.example.yaml
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
# Global LLM Configuration
|
||||
#
|
||||
# SETUP INSTRUCTIONS:
|
||||
# 1. For production: Copy this file to global_llm_config.yaml and add your real API keys
|
||||
# 2. For testing: Also copy this file to global_llm_config.yaml; the loader does not fall back to this example file automatically
|
||||
#
|
||||
# NOTE: The example API keys below are placeholders and won't work.
|
||||
# Replace them with your actual API keys to enable global configurations.
|
||||
#
|
||||
# These configurations will be available to all users as a convenient option
|
||||
# Users can choose to use these global configs or add their own
|
||||
|
||||
global_llm_configs:
|
||||
# Example: OpenAI GPT-4 Turbo
|
||||
- id: -1
|
||||
name: "Global GPT-4 Turbo"
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-4-turbo-preview"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
language: "English"
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
|
||||
# Example: Anthropic Claude 3 Opus
|
||||
- id: -2
|
||||
name: "Global Claude 3 Opus"
|
||||
provider: "ANTHROPIC"
|
||||
model_name: "claude-3-opus-20240229"
|
||||
api_key: "sk-ant-your-anthropic-api-key-here"
|
||||
api_base: ""
|
||||
language: "English"
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
|
||||
# Example: Fast model - GPT-3.5 Turbo
|
||||
- id: -3
|
||||
name: "Global GPT-3.5 Turbo"
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-3.5-turbo"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
language: "English"
|
||||
litellm_params:
|
||||
temperature: 0.5
|
||||
max_tokens: 2000
|
||||
|
||||
# Example: Chinese LLM - DeepSeek
|
||||
- id: -4
|
||||
name: "Global DeepSeek Chat"
|
||||
provider: "DEEPSEEK"
|
||||
model_name: "deepseek-chat"
|
||||
api_key: "your-deepseek-api-key-here"
|
||||
api_base: "https://api.deepseek.com/v1"
|
||||
language: "Chinese"
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
|
||||
# Example: Groq - Fast inference
|
||||
- id: -5
|
||||
name: "Global Groq Llama 3"
|
||||
provider: "GROQ"
|
||||
model_name: "llama3-70b-8192"
|
||||
api_key: "your-groq-api-key-here"
|
||||
api_base: ""
|
||||
language: "English"
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 8000
|
||||
|
||||
# Notes:
|
||||
# - Use negative IDs to distinguish global configs from user configs
|
||||
# - IDs should be unique and sequential (e.g., -1, -2, -3, etc.)
|
||||
# - The 'api_key' field will not be exposed to users via API
|
||||
# - Users can select these configs for their long_context, fast, or strategic LLM roles
|
||||
# - All standard LiteLLM providers are supported
|
||||
|
||||
|
|
@ -348,15 +348,11 @@ class UserSearchSpacePreference(BaseModel, TimestampMixin):
|
|||
)
|
||||
|
||||
# User-specific LLM preferences for this search space
|
||||
long_context_llm_id = Column(
|
||||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
fast_llm_id = Column(
|
||||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
strategic_llm_id = Column(
|
||||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
# Note: These can be negative IDs for global configs (from YAML) or positive IDs for custom configs (from DB)
|
||||
# Foreign keys removed to support global configs with negative IDs
|
||||
long_context_llm_id = Column(Integer, nullable=True)
|
||||
fast_llm_id = Column(Integer, nullable=True)
|
||||
strategic_llm_id = Column(Integer, nullable=True)
|
||||
|
||||
# Future RBAC fields can be added here
|
||||
# role = Column(String(50), nullable=True) # e.g., 'owner', 'editor', 'viewer'
|
||||
|
|
@ -365,13 +361,12 @@ class UserSearchSpacePreference(BaseModel, TimestampMixin):
|
|||
user = relationship("User", back_populates="search_space_preferences")
|
||||
search_space = relationship("SearchSpace", back_populates="user_preferences")
|
||||
|
||||
long_context_llm = relationship(
|
||||
"LLMConfig", foreign_keys=[long_context_llm_id], post_update=True
|
||||
)
|
||||
fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True)
|
||||
strategic_llm = relationship(
|
||||
"LLMConfig", foreign_keys=[strategic_llm_id], post_update=True
|
||||
)
|
||||
# Note: Relationships removed because foreign keys no longer exist
|
||||
# Global configs (negative IDs) don't exist in llm_configs table
|
||||
# Application code manually fetches configs when needed
|
||||
# long_context_llm = relationship("LLMConfig", foreign_keys=[long_context_llm_id], post_update=True)
|
||||
# fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True)
|
||||
# strategic_llm = relationship("LLMConfig", foreign_keys=[strategic_llm_id], post_update=True)
|
||||
|
||||
|
||||
class Log(BaseModel, TimestampMixin):
|
||||
|
|
|
|||
|
|
@ -68,9 +68,9 @@ async def handle_chat_data(
|
|||
selectinload(UserSearchSpacePreference.search_space).selectinload(
|
||||
SearchSpace.llm_configs
|
||||
),
|
||||
selectinload(UserSearchSpacePreference.long_context_llm),
|
||||
selectinload(UserSearchSpacePreference.fast_llm),
|
||||
selectinload(UserSearchSpacePreference.strategic_llm),
|
||||
# Note: Removed selectinload for LLM relationships as they no longer exist
|
||||
# Global configs (negative IDs) don't have foreign keys
|
||||
# LLM configs are now fetched manually when needed
|
||||
)
|
||||
.filter(
|
||||
UserSearchSpacePreference.search_space_id == search_space_id,
|
||||
|
|
@ -81,6 +81,8 @@ async def handle_chat_data(
|
|||
# print("UserSearchSpacePreference:", user_preference)
|
||||
|
||||
language = None
|
||||
llm_configs = [] # Initialize to empty list
|
||||
|
||||
if (
|
||||
user_preference
|
||||
and user_preference.search_space
|
||||
|
|
@ -88,16 +90,36 @@ async def handle_chat_data(
|
|||
):
|
||||
llm_configs = user_preference.search_space.llm_configs
|
||||
|
||||
for preferred_llm in [
|
||||
user_preference.fast_llm,
|
||||
user_preference.long_context_llm,
|
||||
user_preference.strategic_llm,
|
||||
]:
|
||||
if preferred_llm and getattr(preferred_llm, "language", None):
|
||||
language = preferred_llm.language
|
||||
break
|
||||
# Manually fetch LLM configs since relationships no longer exist
|
||||
# Check fast_llm, long_context_llm, and strategic_llm IDs
|
||||
from app.config import config as app_config
|
||||
|
||||
if not language:
|
||||
for llm_id in [
|
||||
user_preference.fast_llm_id,
|
||||
user_preference.long_context_llm_id,
|
||||
user_preference.strategic_llm_id,
|
||||
]:
|
||||
if llm_id is not None:
|
||||
# Check if it's a global config (negative ID)
|
||||
if llm_id < 0:
|
||||
# Look in global configs
|
||||
for global_cfg in app_config.GLOBAL_LLM_CONFIGS:
|
||||
if global_cfg.get("id") == llm_id:
|
||||
language = global_cfg.get("language")
|
||||
if language:
|
||||
break
|
||||
else:
|
||||
# Look in custom configs
|
||||
for llm_config in llm_configs:
|
||||
if llm_config.id == llm_id and getattr(
|
||||
llm_config, "language", None
|
||||
):
|
||||
language = llm_config.language
|
||||
break
|
||||
if language:
|
||||
break
|
||||
|
||||
if not language and llm_configs:
|
||||
first_llm_config = llm_configs[0]
|
||||
language = getattr(first_llm_config, "language", None)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
LLMConfig,
|
||||
SearchSpace,
|
||||
|
|
@ -16,6 +18,7 @@ from app.services.llm_service import validate_llm_config
|
|||
from app.users import current_active_user
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Helper function to check search space access
|
||||
|
|
@ -43,16 +46,11 @@ async def get_or_create_user_preference(
|
|||
) -> UserSearchSpacePreference:
|
||||
"""Get or create user preference for a search space"""
|
||||
result = await session.execute(
|
||||
select(UserSearchSpacePreference)
|
||||
.filter(
|
||||
select(UserSearchSpacePreference).filter(
|
||||
UserSearchSpacePreference.user_id == user_id,
|
||||
UserSearchSpacePreference.search_space_id == search_space_id,
|
||||
)
|
||||
.options(
|
||||
selectinload(UserSearchSpacePreference.long_context_llm),
|
||||
selectinload(UserSearchSpacePreference.fast_llm),
|
||||
selectinload(UserSearchSpacePreference.strategic_llm),
|
||||
)
|
||||
# Removed selectinload options since relationships no longer exist
|
||||
)
|
||||
preference = result.scalars().first()
|
||||
|
||||
|
|
@ -88,6 +86,58 @@ class LLMPreferencesRead(BaseModel):
|
|||
strategic_llm: LLMConfigRead | None = None
|
||||
|
||||
|
||||
class GlobalLLMConfigRead(BaseModel):
    """Response schema for a global LLM config.

    Declares no ``api_key`` field, so the key is not part of the response.
    """

    # Identifier and display name of the config.
    id: int
    name: str

    # Provider and model selection.
    provider: str
    custom_provider: str | None = None
    model_name: str

    # Optional connection and behavior settings.
    api_base: str | None = None
    language: str | None = None
    litellm_params: dict | None = None

    # Flag telling clients this entry is a global config.
    is_global: bool = True
|
||||
|
||||
|
||||
# Global LLM Config endpoints
|
||||
|
||||
|
||||
@router.get("/global-llm-configs", response_model=list[GlobalLLMConfigRead])
async def get_global_llm_configs(
    user: User = Depends(current_active_user),
):
    """
    List the global LLM configurations available to every user.

    Only a fixed set of fields is copied from each configured entry, so the
    stored API key is never included in the response. Every returned entry
    is marked with ``is_global: True``.

    Raises:
        HTTPException: 500 if the response list cannot be built.
    """
    # Fields copied verbatim; a missing field becomes None.
    copied_fields = (
        "id",
        "name",
        "provider",
        "custom_provider",
        "model_name",
        "api_base",
        "language",
    )

    try:
        return [
            {
                **{field: entry.get(field) for field in copied_fields},
                "litellm_params": entry.get("litellm_params", {}),
                "is_global": True,
            }
            for entry in config.GLOBAL_LLM_CONFIGS
        ]
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch global LLM configs: {e!s}"
        ) from e
|
||||
|
||||
|
||||
@router.post("/llm-configs", response_model=LLMConfigRead)
|
||||
async def create_llm_config(
|
||||
llm_config: LLMConfigCreate,
|
||||
|
|
@ -309,13 +359,49 @@ async def get_user_llm_preferences(
|
|||
session, user.id, search_space_id
|
||||
)
|
||||
|
||||
# Helper function to get config (global or custom)
|
||||
async def get_config_for_id(config_id):
|
||||
if config_id is None:
|
||||
return None
|
||||
|
||||
# Check if it's a global config (negative ID)
|
||||
if config_id < 0:
|
||||
for cfg in config.GLOBAL_LLM_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
# Return as LLMConfigRead-compatible dict
|
||||
return {
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"provider": cfg.get("provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_key": "***GLOBAL***", # Don't expose the actual key
|
||||
"api_base": cfg.get("api_base"),
|
||||
"language": cfg.get("language"),
|
||||
"litellm_params": cfg.get("litellm_params"),
|
||||
"created_at": None,
|
||||
"search_space_id": search_space_id,
|
||||
}
|
||||
return None
|
||||
|
||||
# It's a custom config, fetch from database
|
||||
result = await session.execute(
|
||||
select(LLMConfig).filter(LLMConfig.id == config_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
# Get the configs (from DB for custom, or constructed for global)
|
||||
long_context_llm = await get_config_for_id(preference.long_context_llm_id)
|
||||
fast_llm = await get_config_for_id(preference.fast_llm_id)
|
||||
strategic_llm = await get_config_for_id(preference.strategic_llm_id)
|
||||
|
||||
return {
|
||||
"long_context_llm_id": preference.long_context_llm_id,
|
||||
"fast_llm_id": preference.fast_llm_id,
|
||||
"strategic_llm_id": preference.strategic_llm_id,
|
||||
"long_context_llm": preference.long_context_llm,
|
||||
"fast_llm": preference.fast_llm,
|
||||
"strategic_llm": preference.strategic_llm,
|
||||
"long_context_llm": long_context_llm,
|
||||
"fast_llm": fast_llm,
|
||||
"strategic_llm": strategic_llm,
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -353,29 +439,57 @@ async def update_user_llm_preferences(
|
|||
|
||||
for _key, llm_config_id in update_data.items():
|
||||
if llm_config_id is not None:
|
||||
# Verify the LLM config belongs to the search space
|
||||
result = await session.execute(
|
||||
select(LLMConfig).filter(
|
||||
LLMConfig.id == llm_config_id,
|
||||
LLMConfig.search_space_id == search_space_id,
|
||||
)
|
||||
)
|
||||
llm_config = result.scalars().first()
|
||||
if not llm_config:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"LLM configuration {llm_config_id} not found in this search space",
|
||||
)
|
||||
# Check if this is a global config (negative ID)
|
||||
if llm_config_id < 0:
|
||||
# Validate global config exists
|
||||
global_config = None
|
||||
for cfg in config.GLOBAL_LLM_CONFIGS:
|
||||
if cfg.get("id") == llm_config_id:
|
||||
global_config = cfg
|
||||
break
|
||||
|
||||
# Collect language for consistency check
|
||||
languages.add(llm_config.language)
|
||||
if not global_config:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Global LLM configuration {llm_config_id} not found",
|
||||
)
|
||||
|
||||
# Check if all selected LLM configs have the same language
|
||||
# Collect language for consistency check (if explicitly set)
|
||||
lang = global_config.get("language")
|
||||
if lang and lang.strip(): # Only add non-empty languages
|
||||
languages.add(lang.strip())
|
||||
else:
|
||||
# Verify the LLM config belongs to the search space (custom config)
|
||||
result = await session.execute(
|
||||
select(LLMConfig).filter(
|
||||
LLMConfig.id == llm_config_id,
|
||||
LLMConfig.search_space_id == search_space_id,
|
||||
)
|
||||
)
|
||||
llm_config = result.scalars().first()
|
||||
if not llm_config:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"LLM configuration {llm_config_id} not found in this search space",
|
||||
)
|
||||
|
||||
# Collect language for consistency check (if explicitly set)
|
||||
if llm_config.language and llm_config.language.strip():
|
||||
languages.add(llm_config.language.strip())
|
||||
|
||||
# Language consistency check - only warn if there are multiple explicit languages
|
||||
# Allow mixing configs with and without language settings
|
||||
if len(languages) > 1:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="All selected LLM configurations must have the same language setting",
|
||||
# Log warning but allow the operation
|
||||
logger.warning(
|
||||
f"Multiple languages detected in LLM selection for search_space {search_space_id}: {languages}. "
|
||||
"This may affect response quality."
|
||||
)
|
||||
# Don't raise an exception - allow users to proceed
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail="All selected LLM configurations must have the same language setting",
|
||||
# )
|
||||
|
||||
# Update user preferences
|
||||
for key, value in update_data.items():
|
||||
|
|
@ -384,19 +498,50 @@ async def update_user_llm_preferences(
|
|||
await session.commit()
|
||||
await session.refresh(preference)
|
||||
|
||||
# Reload relationships
|
||||
await session.refresh(
|
||||
preference, ["long_context_llm", "fast_llm", "strategic_llm"]
|
||||
)
|
||||
# Helper function to get config (global or custom)
|
||||
async def get_config_for_id(config_id):
|
||||
if config_id is None:
|
||||
return None
|
||||
|
||||
# Check if it's a global config (negative ID)
|
||||
if config_id < 0:
|
||||
for cfg in config.GLOBAL_LLM_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
# Return as LLMConfigRead-compatible dict
|
||||
return {
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"provider": cfg.get("provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_key": "***GLOBAL***", # Don't expose the actual key
|
||||
"api_base": cfg.get("api_base"),
|
||||
"language": cfg.get("language"),
|
||||
"litellm_params": cfg.get("litellm_params"),
|
||||
"created_at": None,
|
||||
"search_space_id": search_space_id,
|
||||
}
|
||||
return None
|
||||
|
||||
# It's a custom config, fetch from database
|
||||
result = await session.execute(
|
||||
select(LLMConfig).filter(LLMConfig.id == config_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
# Get the configs (from DB for custom, or constructed for global)
|
||||
long_context_llm = await get_config_for_id(preference.long_context_llm_id)
|
||||
fast_llm = await get_config_for_id(preference.fast_llm_id)
|
||||
strategic_llm = await get_config_for_id(preference.strategic_llm_id)
|
||||
|
||||
# Return updated preferences
|
||||
return {
|
||||
"long_context_llm_id": preference.long_context_llm_id,
|
||||
"fast_llm_id": preference.fast_llm_id,
|
||||
"strategic_llm_id": preference.strategic_llm_id,
|
||||
"long_context_llm": preference.long_context_llm,
|
||||
"fast_llm": preference.fast_llm,
|
||||
"strategic_llm": preference.strategic_llm,
|
||||
"long_context_llm": long_context_llm,
|
||||
"fast_llm": fast_llm,
|
||||
"strategic_llm": strategic_llm,
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -62,7 +62,11 @@ class LLMConfigUpdate(BaseModel):
|
|||
|
||||
class LLMConfigRead(LLMConfigBase, IDModel, TimestampModel):
|
||||
id: int
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
created_at: datetime | None = Field(
|
||||
None, description="Creation timestamp (None for global configs)"
|
||||
)
|
||||
search_space_id: int | None = Field(
|
||||
None, description="Search space ID (None for global configs)"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from langchain_litellm import ChatLiteLLM
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.config import config
|
||||
from app.db import LLMConfig, UserSearchSpacePreference
|
||||
|
||||
# Configure litellm to automatically drop unsupported parameters
|
||||
|
|
@ -20,6 +21,27 @@ class LLMRole:
|
|||
STRATEGIC = "strategic"
|
||||
|
||||
|
||||
def get_global_llm_config(llm_config_id: int) -> dict | None:
    """
    Look up a global LLM configuration by its ID.

    Global configs are identified by negative IDs; zero and positive IDs
    are rejected without searching.

    Args:
        llm_config_id: ID of the global config to find (expected negative).

    Returns:
        dict: The first entry of ``config.GLOBAL_LLM_CONFIGS`` whose ``id``
        equals ``llm_config_id``, or None if the ID is non-negative or no
        entry matches.
    """
    if llm_config_id >= 0:
        return None

    matches = (
        candidate
        for candidate in config.GLOBAL_LLM_CONFIGS
        if candidate.get("id") == llm_config_id
    )
    return next(matches, None)
|
||||
|
||||
|
||||
async def validate_llm_config(
|
||||
provider: str,
|
||||
model_name: str,
|
||||
|
|
@ -171,7 +193,70 @@ async def get_user_llm_instance(
|
|||
)
|
||||
return None
|
||||
|
||||
# Get the LLM configuration
|
||||
# Check if this is a global config (negative ID)
|
||||
if llm_config_id < 0:
|
||||
global_config = get_global_llm_config(llm_config_id)
|
||||
if not global_config:
|
||||
logger.error(f"Global LLM config {llm_config_id} not found")
|
||||
return None
|
||||
|
||||
# Build model string for global config
|
||||
if global_config.get("custom_provider"):
|
||||
model_string = (
|
||||
f"{global_config['custom_provider']}/{global_config['model_name']}"
|
||||
)
|
||||
else:
|
||||
provider_map = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"COMETAPI": "cometapi",
|
||||
"XAI": "xai",
|
||||
"BEDROCK": "bedrock",
|
||||
"AWS_BEDROCK": "bedrock",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
"DEEPINFRA": "deepinfra",
|
||||
"CEREBRAS": "cerebras",
|
||||
"SAMBANOVA": "sambanova",
|
||||
"AI21": "ai21",
|
||||
"CLOUDFLARE": "cloudflare",
|
||||
"DATABRICKS": "databricks",
|
||||
"DEEPSEEK": "openai",
|
||||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
}
|
||||
provider_prefix = provider_map.get(
|
||||
global_config["provider"], global_config["provider"].lower()
|
||||
)
|
||||
model_string = f"{provider_prefix}/{global_config['model_name']}"
|
||||
|
||||
# Create ChatLiteLLM instance from global config
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": global_config["api_key"],
|
||||
}
|
||||
|
||||
if global_config.get("api_base"):
|
||||
litellm_kwargs["api_base"] = global_config["api_base"]
|
||||
|
||||
if global_config.get("litellm_params"):
|
||||
litellm_kwargs.update(global_config["litellm_params"])
|
||||
|
||||
return ChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
# Get the LLM configuration from database (user-specific config)
|
||||
result = await session.execute(
|
||||
select(LLMConfig).where(
|
||||
LLMConfig.id == llm_config_id,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue