feat: Implement LLM configuration validation in create and update routes

- Added `validate_llm_config` function to `llm_service.py` for validating LLM configurations via test API calls.
- Integrated validation in `create_llm_config` and `update_llm_config` routes in `llm_config_routes.py`, raising HTTP exceptions for invalid configurations.
- Enhanced error handling to provide detailed feedback on configuration issues.
This commit is contained in:
DESKTOP-RTLN3BA\$punk 2025-11-05 12:15:05 -08:00
parent 666dba7150
commit 9466bf595c
9 changed files with 235 additions and 52 deletions

View file

@ -12,6 +12,7 @@ from app.db import (
get_async_session, get_async_session,
) )
from app.schemas import LLMConfigCreate, LLMConfigRead, LLMConfigUpdate from app.schemas import LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
from app.services.llm_service import validate_llm_config
from app.users import current_active_user from app.users import current_active_user
router = APIRouter() router = APIRouter()
@ -98,6 +99,22 @@ async def create_llm_config(
# Verify user has access to the search space # Verify user has access to the search space
await check_search_space_access(session, llm_config.search_space_id, user) await check_search_space_access(session, llm_config.search_space_id, user)
# Validate the LLM configuration by making a test API call
is_valid, error_message = await validate_llm_config(
provider=llm_config.provider.value,
model_name=llm_config.model_name,
api_key=llm_config.api_key,
api_base=llm_config.api_base,
custom_provider=llm_config.custom_provider,
litellm_params=llm_config.litellm_params,
)
if not is_valid:
raise HTTPException(
status_code=400,
detail=f"Invalid LLM configuration: {error_message}",
)
db_llm_config = LLMConfig(**llm_config.model_dump()) db_llm_config = LLMConfig(**llm_config.model_dump())
session.add(db_llm_config) session.add(db_llm_config)
await session.commit() await session.commit()
@ -192,6 +209,39 @@ async def update_llm_config(
update_data = llm_config_update.model_dump(exclude_unset=True) update_data = llm_config_update.model_dump(exclude_unset=True)
# Apply updates to a temporary copy for validation
temp_config = {
"provider": update_data.get("provider", db_llm_config.provider).value
if "provider" in update_data
else db_llm_config.provider.value,
"model_name": update_data.get("model_name", db_llm_config.model_name),
"api_key": update_data.get("api_key", db_llm_config.api_key),
"api_base": update_data.get("api_base", db_llm_config.api_base),
"custom_provider": update_data.get(
"custom_provider", db_llm_config.custom_provider
),
"litellm_params": update_data.get(
"litellm_params", db_llm_config.litellm_params
),
}
# Validate the updated configuration
is_valid, error_message = await validate_llm_config(
provider=temp_config["provider"],
model_name=temp_config["model_name"],
api_key=temp_config["api_key"],
api_base=temp_config["api_base"],
custom_provider=temp_config["custom_provider"],
litellm_params=temp_config["litellm_params"],
)
if not is_valid:
raise HTTPException(
status_code=400,
detail=f"Invalid LLM configuration: {error_message}",
)
# Apply updates to the database object
for key, value in update_data.items(): for key, value in update_data.items():
setattr(db_llm_config, key, value) setattr(db_llm_config, key, value)

View file

@ -1,6 +1,7 @@
import logging import logging
import litellm import litellm
from langchain_core.messages import HumanMessage
from langchain_litellm import ChatLiteLLM from langchain_litellm import ChatLiteLLM
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
@ -19,6 +20,93 @@ class LLMRole:
STRATEGIC = "strategic" STRATEGIC = "strategic"
async def validate_llm_config(
provider: str,
model_name: str,
api_key: str,
api_base: str | None = None,
custom_provider: str | None = None,
litellm_params: dict | None = None,
) -> tuple[bool, str]:
"""
Validate an LLM configuration by attempting to make a test API call.
Args:
provider: LLM provider (e.g., 'OPENAI', 'ANTHROPIC')
model_name: Model identifier
api_key: API key for the provider
api_base: Optional custom API base URL
custom_provider: Optional custom provider string
litellm_params: Optional additional litellm parameters
Returns:
Tuple of (is_valid, error_message)
- is_valid: True if config works, False otherwise
- error_message: Empty string if valid, error description if invalid
"""
try:
# Build the model string for litellm
if custom_provider:
model_string = f"{custom_provider}/{model_name}"
else:
# Map provider enum to litellm format
provider_map = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",
"COMETAPI": "cometapi",
# Chinese LLM providers (OpenAI-compatible)
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
}
provider_prefix = provider_map.get(provider, provider.lower())
model_string = f"{provider_prefix}/{model_name}"
# Create ChatLiteLLM instance
litellm_kwargs = {
"model": model_string,
"api_key": api_key,
"timeout": 30, # Set a timeout for validation
}
# Add optional parameters
if api_base:
litellm_kwargs["api_base"] = api_base
# Add any additional litellm parameters
if litellm_params:
litellm_kwargs.update(litellm_params)
llm = ChatLiteLLM(**litellm_kwargs)
# Make a simple test call
test_message = HumanMessage(content="Hello")
response = await llm.ainvoke([test_message])
# If we got here without exception, the config is valid
if response and response.content:
logger.info(f"Successfully validated LLM config for model: {model_string}")
return True, ""
else:
logger.warning(
f"LLM config validation returned empty response for model: {model_string}"
)
return False, "LLM returned an empty response"
except Exception as e:
error_msg = f"Failed to validate LLM configuration: {e!s}"
logger.error(error_msg)
return False, error_msg
async def get_user_llm_instance( async def get_user_llm_instance(
session: AsyncSession, user_id: str, search_space_id: int, role: str session: AsyncSession, user_id: str, search_space_id: int, role: str
) -> ChatLiteLLM | None: ) -> ChatLiteLLM | None:

View file

@ -898,7 +898,7 @@ async def process_file_in_background(
# Suppress both Python warnings and logging warnings from pdfminer # Suppress both Python warnings and logging warnings from pdfminer
pdfminer_logger = getLogger("pdfminer") pdfminer_logger = getLogger("pdfminer")
original_level = pdfminer_logger.level original_level = pdfminer_logger.level
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings( warnings.filterwarnings(
"ignore", category=UserWarning, module="pdfminer" "ignore", category=UserWarning, module="pdfminer"
@ -907,16 +907,16 @@ async def process_file_in_background(
"ignore", "ignore",
message=".*Cannot set gray non-stroke color.*", message=".*Cannot set gray non-stroke color.*",
) )
warnings.filterwarnings( warnings.filterwarnings("ignore", message=".*invalid float value.*")
"ignore", message=".*invalid float value.*"
)
# Temporarily suppress pdfminer logging warnings # Temporarily suppress pdfminer logging warnings
pdfminer_logger.setLevel(ERROR) pdfminer_logger.setLevel(ERROR)
try: try:
# Process the document # Process the document
result = await docling_service.process_document(file_path, filename) result = await docling_service.process_document(
file_path, filename
)
finally: finally:
# Restore original logging level # Restore original logging level
pdfminer_logger.setLevel(original_level) pdfminer_logger.setLevel(original_level)

View file

@ -73,7 +73,7 @@ async def add_crawled_url_document(
) )
use_firecrawl = bool(config.FIRECRAWL_API_KEY) use_firecrawl = bool(config.FIRECRAWL_API_KEY)
if use_firecrawl: if use_firecrawl:
# Use Firecrawl SDK directly # Use Firecrawl SDK directly
firecrawl_app = AsyncFirecrawlApp(api_key=config.FIRECRAWL_API_KEY) firecrawl_app = AsyncFirecrawlApp(api_key=config.FIRECRAWL_API_KEY)
@ -84,40 +84,50 @@ async def add_crawled_url_document(
await task_logger.log_task_progress( await task_logger.log_task_progress(
log_entry, log_entry,
f"Crawling URL content: {url}", f"Crawling URL content: {url}",
{"stage": "crawling", "crawler_type": "AsyncFirecrawlApp" if use_firecrawl else "AsyncChromiumLoader"}, {
"stage": "crawling",
"crawler_type": "AsyncFirecrawlApp"
if use_firecrawl
else "AsyncChromiumLoader",
},
) )
if use_firecrawl: if use_firecrawl:
# Use async Firecrawl SDK with v1 API - properly awaited # Use async Firecrawl SDK with v1 API - properly awaited
scrape_result = await firecrawl_app.scrape_url( scrape_result = await firecrawl_app.scrape_url(
url=url, url=url, formats=["markdown"]
formats=['markdown']
) )
# scrape_result is a Pydantic ScrapeResponse object # scrape_result is a Pydantic ScrapeResponse object
# Access attributes directly # Access attributes directly
if scrape_result and scrape_result.success: if scrape_result and scrape_result.success:
# Extract markdown content # Extract markdown content
markdown_content = scrape_result.markdown or '' markdown_content = scrape_result.markdown or ""
# Extract metadata - this is a DICT # Extract metadata - this is a DICT
metadata = scrape_result.metadata if scrape_result.metadata else {} metadata = scrape_result.metadata if scrape_result.metadata else {}
# Convert to LangChain Document format # Convert to LangChain Document format
url_crawled = [LangchainDocument( url_crawled = [
page_content=markdown_content, LangchainDocument(
metadata={ page_content=markdown_content,
'source': url, metadata={
'title': metadata.get('title', url), "source": url,
'description': metadata.get('description', ''), "title": metadata.get("title", url),
'language': metadata.get('language', ''), "description": metadata.get("description", ""),
'sourceURL': metadata.get('sourceURL', url), "language": metadata.get("language", ""),
**metadata # Include all other metadata fields "sourceURL": metadata.get("sourceURL", url),
} **metadata, # Include all other metadata fields
)] },
)
]
content_in_markdown = url_crawled[0].page_content content_in_markdown = url_crawled[0].page_content
else: else:
error_msg = scrape_result.error if scrape_result and hasattr(scrape_result, 'error') else "Unknown error" error_msg = (
scrape_result.error
if scrape_result and hasattr(scrape_result, "error")
else "Unknown error"
)
raise ValueError(f"Firecrawl failed to scrape URL: {error_msg}") raise ValueError(f"Firecrawl failed to scrape URL: {error_msg}")
else: else:
# Use AsyncChromiumLoader as fallback # Use AsyncChromiumLoader as fallback
@ -249,7 +259,9 @@ async def add_crawled_url_document(
{"stage": "document_update", "chunks_count": len(chunks)}, {"stage": "document_update", "chunks_count": len(chunks)},
) )
existing_document.title = url_crawled[0].metadata.get('title', url_crawled[0].metadata.get('source', url)) existing_document.title = url_crawled[0].metadata.get(
"title", url_crawled[0].metadata.get("source", url)
)
existing_document.content = summary_content existing_document.content = summary_content
existing_document.content_hash = content_hash existing_document.content_hash = content_hash
existing_document.embedding = summary_embedding existing_document.embedding = summary_embedding
@ -267,7 +279,9 @@ async def add_crawled_url_document(
document = Document( document = Document(
search_space_id=search_space_id, search_space_id=search_space_id,
title=url_crawled[0].metadata.get('title', url_crawled[0].metadata.get('source', url)), title=url_crawled[0].metadata.get(
"title", url_crawled[0].metadata.get("source", url)
),
document_type=DocumentType.CRAWLED_URL, document_type=DocumentType.CRAWLED_URL,
document_metadata=url_crawled[0].metadata, document_metadata=url_crawled[0].metadata,
content=summary_content, content=summary_content,

View file

@ -49,10 +49,10 @@ export default function DashboardLayout({
title: "Upload Documents", title: "Upload Documents",
url: `/dashboard/${search_space_id}/documents/upload`, url: `/dashboard/${search_space_id}/documents/upload`,
}, },
{ // {
title: "Add Webpages", // title: "Add Webpages",
url: `/dashboard/${search_space_id}/documents/webpage`, // url: `/dashboard/${search_space_id}/documents/webpage`,
}, // },
{ {
title: "Add Youtube Videos", title: "Add Youtube Videos",
url: `/dashboard/${search_space_id}/documents/youtube`, url: `/dashboard/${search_space_id}/documents/youtube`,

View file

@ -5,7 +5,9 @@ import { useParams, useRouter } from "next/navigation";
import { useEffect, useMemo } from "react"; import { useEffect, useMemo } from "react";
import ChatInterface from "@/components/chat/ChatInterface"; import ChatInterface from "@/components/chat/ChatInterface";
import { useChatAPI, useChatState } from "@/hooks/use-chat"; import { useChatAPI, useChatState } from "@/hooks/use-chat";
import { useDocumentTypes } from "@/hooks/use-document-types";
import type { Document } from "@/hooks/use-documents"; import type { Document } from "@/hooks/use-documents";
import { useSearchSourceConnectors } from "@/hooks/use-search-source-connectors";
export default function ResearcherPage() { export default function ResearcherPage() {
const { search_space_id, chat_id } = useParams(); const { search_space_id, chat_id } = useParams();
@ -35,6 +37,19 @@ export default function ResearcherPage() {
search_space_id: search_space_id as string, search_space_id: search_space_id as string,
}); });
// Fetch all available sources (document types + live search connectors)
const { documentTypes } = useDocumentTypes(Number(search_space_id));
const { connectors: searchConnectors } = useSearchSourceConnectors(
false,
Number(search_space_id)
);
// Filter for non-indexable connectors (live search)
const liveSearchConnectors = useMemo(
() => searchConnectors.filter((connector) => !connector.is_indexable),
[searchConnectors]
);
// Memoize document IDs to prevent infinite re-renders // Memoize document IDs to prevent infinite re-renders
const documentIds = useMemo(() => { const documentIds = useMemo(() => {
return selectedDocuments.map((doc) => doc.id); return selectedDocuments.map((doc) => doc.id);
@ -135,6 +150,27 @@ export default function ResearcherPage() {
} }
}, [chatIdParam, search_space_id, setSelectedDocuments, setSelectedConnectors, setSearchMode]); }, [chatIdParam, search_space_id, setSelectedDocuments, setSelectedConnectors, setSearchMode]);
// Set all sources as default for new chats
useEffect(() => {
if (isNewChat && selectedConnectors.length === 0 && documentTypes.length > 0) {
// Combine all document types and live search connectors
const allSourceTypes = [
...documentTypes.map((dt) => dt.type),
...liveSearchConnectors.map((c) => c.connector_type),
];
if (allSourceTypes.length > 0) {
setSelectedConnectors(allSourceTypes);
}
}
}, [
isNewChat,
documentTypes,
liveSearchConnectors,
selectedConnectors.length,
setSelectedConnectors,
]);
const loadChatData = async (chatId: string) => { const loadChatData = async (chatId: string) => {
try { try {
const chatData = await fetchChatDetails(chatId); const chatData = await fetchChatDetails(chatId);

View file

@ -115,18 +115,19 @@ const ConnectorSelector = React.memo(
const { search_space_id } = useParams(); const { search_space_id } = useParams();
const [isOpen, setIsOpen] = useState(false); const [isOpen, setIsOpen] = useState(false);
// Fetch immediately (not lazy) so the button can show the correct count
const { documentTypes, isLoading, isLoaded, fetchDocumentTypes } = useDocumentTypes( const { documentTypes, isLoading, isLoaded, fetchDocumentTypes } = useDocumentTypes(
Number(search_space_id), Number(search_space_id),
true false
); );
// Fetch live search connectors (non-indexable) // Fetch live search connectors immediately (non-indexable)
const { const {
connectors: searchConnectors, connectors: searchConnectors,
isLoading: connectorsLoading, isLoading: connectorsLoading,
isLoaded: connectorsLoaded, isLoaded: connectorsLoaded,
fetchConnectors, fetchConnectors,
} = useSearchSourceConnectors(true, Number(search_space_id)); } = useSearchSourceConnectors(false, Number(search_space_id));
// Filter for non-indexable connectors (live search) // Filter for non-indexable connectors (live search)
const liveSearchConnectors = React.useMemo( const liveSearchConnectors = React.useMemo(
@ -134,18 +135,10 @@ const ConnectorSelector = React.memo(
[searchConnectors] [searchConnectors]
); );
const handleOpenChange = useCallback( const handleOpenChange = useCallback((open: boolean) => {
(open: boolean) => { setIsOpen(open);
setIsOpen(open); // Data is already loaded on mount, no need to fetch again
if (open && !isLoaded) { }, []);
fetchDocumentTypes(Number(search_space_id));
}
if (open && !connectorsLoaded) {
fetchConnectors(Number(search_space_id));
}
},
[fetchDocumentTypes, isLoaded, fetchConnectors, connectorsLoaded, search_space_id]
);
const handleConnectorToggle = useCallback( const handleConnectorToggle = useCallback(
(connectorType: string) => { (connectorType: string) => {

View file

@ -77,10 +77,10 @@ const defaultData = {
title: "Upload Documents", title: "Upload Documents",
url: "#", url: "#",
}, },
{ // {
title: "Add Webpages", // title: "Add Webpages",
url: "#", // url: "#",
}, // },
{ {
title: "Manage Documents", title: "Manage Documents",
url: "#", url: "#",

View file

@ -47,6 +47,8 @@ export const useGithubStars = () => {
error, error,
compactFormat: Intl.NumberFormat("en-US", { compactFormat: Intl.NumberFormat("en-US", {
notation: "compact", notation: "compact",
maximumFractionDigits: 1,
minimumFractionDigits: 1,
}).format(stars || 0), }).format(stars || 0),
}; };
}; };