SurfSense/surfsense_backend/app/routes/llm_config_routes.py
2025-11-14 21:53:46 -08:00

552 lines
20 KiB
Python

import logging
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.config import config
from app.db import (
LLMConfig,
SearchSpace,
User,
UserSearchSpacePreference,
get_async_session,
)
from app.schemas import LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
from app.services.llm_service import validate_llm_config
from app.users import current_active_user
router = APIRouter()
logger = logging.getLogger(__name__)
# Ownership guard reused by the endpoints in this module
async def check_search_space_access(
    session: AsyncSession, search_space_id: int, user: User
) -> SearchSpace:
    """Return the search space with this ID if it is owned by ``user``.

    Raises:
        HTTPException: 404 when no search space matches both the given ID and
            the user's ID. A missing search space and one owned by someone
            else are reported with the same message.
    """
    owned_space_query = select(SearchSpace).filter(
        SearchSpace.id == search_space_id, SearchSpace.user_id == user.id
    )
    owned_space = (await session.execute(owned_space_query)).scalars().first()
    if owned_space:
        return owned_space

    raise HTTPException(
        status_code=404,
        detail="Search space not found or you don't have permission to access it",
    )
# Lazily provisions the per-user, per-search-space preference row
async def get_or_create_user_preference(
    session: AsyncSession, user_id, search_space_id: int
) -> UserSearchSpacePreference:
    """Fetch the preference row for (user, search space), inserting one if absent.

    When no row exists, a new one is created with only ``user_id`` and
    ``search_space_id`` set, then committed and refreshed before being returned.
    """
    lookup = select(UserSearchSpacePreference).filter(
        UserSearchSpacePreference.user_id == user_id,
        UserSearchSpacePreference.search_space_id == search_space_id,
    )
    existing = (await session.execute(lookup)).scalars().first()
    if existing:
        return existing

    # Nothing stored yet for this pair: persist an empty preference row
    created = UserSearchSpacePreference(
        user_id=user_id,
        search_space_id=search_space_id,
    )
    session.add(created)
    await session.commit()
    await session.refresh(created)
    return created
class LLMPreferencesUpdate(BaseModel):
    """Request body for changing which LLM config is selected for each role.

    All three fields are optional and default to ``None``.
    """

    # Config ID selected for the long-context role
    long_context_llm_id: int | None = None
    # Config ID selected for the fast role
    fast_llm_id: int | None = None
    # Config ID selected for the strategic role
    strategic_llm_id: int | None = None
class LLMPreferencesRead(BaseModel):
    """Response shape for a user's LLM preferences in one search space.

    Carries the three selected config IDs together with the corresponding
    config objects; every field may be ``None``.
    """

    # Selected config IDs, one per role
    long_context_llm_id: int | None = None
    fast_llm_id: int | None = None
    strategic_llm_id: int | None = None

    # Expanded configs matching the IDs above
    long_context_llm: LLMConfigRead | None = None
    fast_llm: LLMConfigRead | None = None
    strategic_llm: LLMConfigRead | None = None
class GlobalLLMConfigRead(BaseModel):
    """Response shape for a global LLM config; it declares no API-key field."""

    # Identity
    id: int
    name: str

    # Provider and model selection
    provider: str
    custom_provider: str | None = None
    model_name: str

    # Optional connection and behaviour settings
    api_base: str | None = None
    language: str | None = None
    litellm_params: dict | None = None

    # Marks the entry as a global config; defaults to True
    is_global: bool = True
# Global LLM Config endpoints
@router.get("/global-llm-configs", response_model=list[GlobalLLMConfigRead])
async def get_global_llm_configs(
    user: User = Depends(current_active_user),
):
    """
    List the global LLM configurations held in ``config.GLOBAL_LLM_CONFIGS``.

    Only a fixed set of keys is copied from each entry into the response, so
    anything else stored on a global config (such as its API key) is never
    returned by this endpoint.
    """
    try:
        # Keys copied through unchanged; the API key is deliberately not listed
        passthrough_keys = (
            "id",
            "name",
            "provider",
            "custom_provider",
            "model_name",
            "api_base",
            "language",
        )
        return [
            {
                **{key: entry.get(key) for key in passthrough_keys},
                "litellm_params": entry.get("litellm_params", {}),
                "is_global": True,
            }
            for entry in config.GLOBAL_LLM_CONFIGS
        ]
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch global LLM configs: {e!s}"
        ) from e
@router.post("/llm-configs", response_model=LLMConfigRead)
async def create_llm_config(
    llm_config: LLMConfigCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Validate and persist a new LLM configuration for a search space.

    Raises:
        HTTPException: 404 if the target search space is not accessible to
            ``user``; 400 if ``validate_llm_config`` reports the configuration
            as invalid; 500 (after a session rollback) for any other error.
    """
    try:
        # The caller must own the search space the config is created in
        await check_search_space_access(session, llm_config.search_space_id, user)

        # Check the settings with validate_llm_config before storing anything
        validation_kwargs = {
            "provider": llm_config.provider.value,
            "model_name": llm_config.model_name,
            "api_key": llm_config.api_key,
            "api_base": llm_config.api_base,
            "custom_provider": llm_config.custom_provider,
            "litellm_params": llm_config.litellm_params,
        }
        is_valid, error_message = await validate_llm_config(**validation_kwargs)
        if not is_valid:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid LLM configuration: {error_message}",
            )

        new_row = LLMConfig(**llm_config.model_dump())
        session.add(new_row)
        await session.commit()
        await session.refresh(new_row)
        return new_row
    except HTTPException:
        # Deliberate HTTP errors pass through untouched
        raise
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail=f"Failed to create LLM configuration: {e!s}"
        ) from e
@router.get("/llm-configs", response_model=list[LLMConfigRead])
async def read_llm_configs(
    search_space_id: int,
    skip: int = 0,
    limit: int = 200,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Get a page of the LLM configurations belonging to a search space.

    Args:
        search_space_id: Search space whose configurations are listed.
        skip: Number of rows to skip; must be non-negative.
        limit: Maximum number of rows to return; must be non-negative.

    Returns:
        The matching configurations ordered by ID, so consecutive pages are
        stable and do not overlap.

    Raises:
        HTTPException: 400 if ``skip`` or ``limit`` is negative; 404 if the
            search space is not accessible to ``user``; 500 for any other
            error.
    """
    # Negative OFFSET/LIMIT values are rejected by some databases (e.g.
    # PostgreSQL) and would otherwise surface as an opaque 500 error.
    if skip < 0 or limit < 0:
        raise HTTPException(
            status_code=400, detail="skip and limit must be non-negative"
        )
    try:
        # Verify user has access to the search space
        await check_search_space_access(session, search_space_id, user)
        result = await session.execute(
            select(LLMConfig)
            .filter(LLMConfig.search_space_id == search_space_id)
            # Without an explicit ordering, offset/limit paging may repeat or
            # skip rows between requests.
            .order_by(LLMConfig.id)
            .offset(skip)
            .limit(limit)
        )
        return result.scalars().all()
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch LLM configurations: {e!s}"
        ) from e
@router.get("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
async def read_llm_config(
    llm_config_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Fetch one LLM configuration, provided the caller owns its search space.

    Raises:
        HTTPException: 404 if no configuration has this ID or its search space
            is not accessible to ``user``; 500 for any other error.
    """
    try:
        lookup = select(LLMConfig).filter(LLMConfig.id == llm_config_id)
        found = (await session.execute(lookup)).scalars().first()
        if not found:
            raise HTTPException(status_code=404, detail="LLM configuration not found")

        # Ownership is checked through the search space the config belongs to
        await check_search_space_access(session, found.search_space_id, user)
        return found
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch LLM configuration: {e!s}"
        ) from e
@router.put("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
async def update_llm_config(
    llm_config_id: int,
    llm_config_update: LLMConfigUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Update an existing LLM configuration.

    Only fields explicitly present in the request are changed. Before anything
    is written, the stored values overlaid with the requested changes are
    passed to ``validate_llm_config``; the update is rejected if that reports
    the resulting configuration as invalid.

    Raises:
        HTTPException: 404 if the configuration does not exist or its search
            space is not accessible to ``user``; 400 if validation fails; 500
            (after a session rollback) for any other error.
    """
    try:
        # Get the LLM config
        result = await session.execute(
            select(LLMConfig).filter(LLMConfig.id == llm_config_id)
        )
        db_llm_config = result.scalars().first()
        if not db_llm_config:
            raise HTTPException(status_code=404, detail="LLM configuration not found")

        # Verify user has access to the search space
        await check_search_space_access(session, db_llm_config.search_space_id, user)

        update_data = llm_config_update.model_dump(exclude_unset=True)

        def effective(field: str):
            """Return the requested value for ``field`` if sent, else the stored one."""
            return update_data.get(field, getattr(db_llm_config, field))

        # The create endpoint passes ``provider.value`` to the validator.
        # ``model_dump()`` is called without ``mode="json"``, so a provider
        # taken from the request may still be an enum member; normalize both
        # the requested and the stored provider to the plain value.
        provider = effective("provider")
        provider = getattr(provider, "value", provider)

        # Validate the configuration as it would look after the update
        is_valid, error_message = await validate_llm_config(
            provider=provider,
            model_name=effective("model_name"),
            api_key=effective("api_key"),
            api_base=effective("api_base"),
            custom_provider=effective("custom_provider"),
            litellm_params=effective("litellm_params"),
        )
        if not is_valid:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid LLM configuration: {error_message}",
            )

        # Apply updates to the database object
        for key, value in update_data.items():
            setattr(db_llm_config, key, value)
        await session.commit()
        await session.refresh(db_llm_config)
        return db_llm_config
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail=f"Failed to update LLM configuration: {e!s}"
        ) from e
@router.delete("/llm-configs/{llm_config_id}", response_model=dict)
async def delete_llm_config(
    llm_config_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Remove an LLM configuration whose search space the caller owns.

    Raises:
        HTTPException: 404 if no configuration has this ID or its search space
            is not accessible to ``user``; 500 (after a session rollback) for
            any other error.
    """
    try:
        lookup = select(LLMConfig).filter(LLMConfig.id == llm_config_id)
        target = (await session.execute(lookup)).scalars().first()
        if not target:
            raise HTTPException(status_code=404, detail="LLM configuration not found")

        # Ownership is checked through the search space the config belongs to
        await check_search_space_access(session, target.search_space_id, user)

        await session.delete(target)
        await session.commit()
        return {"message": "LLM configuration deleted successfully"}
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail=f"Failed to delete LLM configuration: {e!s}"
        ) from e
# User LLM Preferences endpoints
def _find_global_llm_config(config_id: int) -> dict | None:
    """Return the entry of ``config.GLOBAL_LLM_CONFIGS`` with this ID, or ``None``."""
    return next(
        (cfg for cfg in config.GLOBAL_LLM_CONFIGS if cfg.get("id") == config_id),
        None,
    )


async def _resolve_llm_config(
    session: AsyncSession, config_id: int | None, search_space_id: int
):
    """Expand a stored preference ID into the config it refers to.

    Negative IDs are looked up among the global configs and returned as an
    ``LLMConfigRead``-compatible dict with the API key masked. Any other ID is
    fetched from the database. Returns ``None`` when ``config_id`` is ``None``
    or nothing matches.
    """
    if config_id is None:
        return None

    # Global configs are identified by negative IDs
    if config_id < 0:
        global_cfg = _find_global_llm_config(config_id)
        if global_cfg is None:
            return None
        return {
            "id": global_cfg.get("id"),
            "name": global_cfg.get("name"),
            "provider": global_cfg.get("provider"),
            "custom_provider": global_cfg.get("custom_provider"),
            "model_name": global_cfg.get("model_name"),
            "api_key": "***GLOBAL***",  # Don't expose the actual key
            "api_base": global_cfg.get("api_base"),
            "language": global_cfg.get("language"),
            "litellm_params": global_cfg.get("litellm_params"),
            "created_at": None,
            "search_space_id": search_space_id,
        }

    # It's a custom config, fetch from database
    result = await session.execute(
        select(LLMConfig).filter(LLMConfig.id == config_id)
    )
    return result.scalars().first()


async def _build_llm_preferences_response(
    session: AsyncSession,
    preference: UserSearchSpacePreference,
    search_space_id: int,
) -> dict:
    """Assemble the ``LLMPreferencesRead`` payload for a preference row."""
    # Read the IDs once, before any further awaits on the session
    long_context_id = preference.long_context_llm_id
    fast_id = preference.fast_llm_id
    strategic_id = preference.strategic_llm_id

    long_context_llm = await _resolve_llm_config(
        session, long_context_id, search_space_id
    )
    fast_llm = await _resolve_llm_config(session, fast_id, search_space_id)
    strategic_llm = await _resolve_llm_config(session, strategic_id, search_space_id)

    return {
        "long_context_llm_id": long_context_id,
        "fast_llm_id": fast_id,
        "strategic_llm_id": strategic_id,
        "long_context_llm": long_context_llm,
        "fast_llm": fast_llm,
        "strategic_llm": strategic_llm,
    }


@router.get(
    "/search-spaces/{search_space_id}/llm-preferences",
    response_model=LLMPreferencesRead,
)
async def get_user_llm_preferences(
    search_space_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Get the current user's LLM preferences for a specific search space.

    A preference row is created on first access if none exists yet.

    Raises:
        HTTPException: 404 if the search space is not accessible to ``user``;
            500 for any other error.
    """
    try:
        # Verify user has access to the search space
        await check_search_space_access(session, search_space_id, user)
        # Get or create user preference for this search space
        preference = await get_or_create_user_preference(
            session, user.id, search_space_id
        )
        return await _build_llm_preferences_response(
            session, preference, search_space_id
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch LLM preferences: {e!s}"
        ) from e


@router.put(
    "/search-spaces/{search_space_id}/llm-preferences",
    response_model=LLMPreferencesRead,
)
async def update_user_llm_preferences(
    search_space_id: int,
    preferences: LLMPreferencesUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Update the current user's LLM preferences for a specific search space.

    Only fields present in the request are changed. Each non-null ID must
    refer either to an existing global config (negative ID) or to a config
    stored in this search space. Selecting configs with differing explicit
    languages is allowed but logged as a warning.

    Raises:
        HTTPException: 404 if the search space is not accessible to ``user``
            or a referenced config cannot be found; 500 (after a session
            rollback) for any other error.
    """
    try:
        # Verify user has access to the search space
        await check_search_space_access(session, search_space_id, user)
        # Get or create user preference for this search space
        preference = await get_or_create_user_preference(
            session, user.id, search_space_id
        )

        update_data = preferences.model_dump(exclude_unset=True)

        # Explicitly set languages of the selected configs, for the
        # consistency warning below
        languages = set()
        for llm_config_id in update_data.values():
            if llm_config_id is None:
                continue

            if llm_config_id < 0:
                # Negative ID: must match a global config
                global_config = _find_global_llm_config(llm_config_id)
                if global_config is None:
                    raise HTTPException(
                        status_code=404,
                        detail=f"Global LLM configuration {llm_config_id} not found",
                    )
                language = global_config.get("language")
            else:
                # Custom config: must belong to this search space
                result = await session.execute(
                    select(LLMConfig).filter(
                        LLMConfig.id == llm_config_id,
                        LLMConfig.search_space_id == search_space_id,
                    )
                )
                llm_config = result.scalars().first()
                if not llm_config:
                    raise HTTPException(
                        status_code=404,
                        detail=f"LLM configuration {llm_config_id} not found in this search space",
                    )
                language = llm_config.language

            # Only non-empty languages take part in the consistency check
            if language and language.strip():
                languages.add(language.strip())

        # Mixing configs with and without a language is fine; more than one
        # explicit language is only warned about, never rejected.
        if len(languages) > 1:
            logger.warning(
                "Multiple languages detected in LLM selection for search_space %s: %s. "
                "This may affect response quality.",
                search_space_id,
                languages,
            )

        # Update user preferences
        for key, value in update_data.items():
            setattr(preference, key, value)
        await session.commit()
        await session.refresh(preference)

        # Return updated preferences
        return await _build_llm_preferences_response(
            session, preference, search_space_id
        )
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail=f"Failed to update LLM preferences: {e!s}"
        ) from e