This commit is contained in:
Manoj Aggarwal 2026-01-16 11:21:01 -08:00
commit aa90da602b
26 changed files with 406 additions and 234 deletions

View file

@ -18,7 +18,9 @@ logger = logging.getLogger(__name__)
class MCPClient:
"""Client for communicating with an MCP server."""
def __init__(self, command: str, args: list[str], env: dict[str, str] | None = None):
def __init__(
self, command: str, args: list[str], env: dict[str, str] | None = None
):
"""Initialize MCP client.
Args:
@ -44,18 +46,16 @@ class MCPClient:
# Merge env vars with current environment
server_env = os.environ.copy()
server_env.update(self.env)
# Create server parameters with env
server_params = StdioServerParameters(
command=self.command,
args=self.args,
env=server_env
command=self.command, args=self.args, env=server_env
)
# Spawn server process and create session
# Note: Cannot combine these context managers because ClientSession
# needs the read/write streams from stdio_client
async with stdio_client(server=server_params) as (read, write):
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
async with ClientSession(read, write) as session:
# Initialize the connection
await session.initialize()
@ -85,7 +85,9 @@ class MCPClient:
"""
if not self.session:
raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'")
raise RuntimeError(
"Not connected to MCP server. Use 'async with client.connect():'"
)
try:
# Call tools/list RPC method
@ -93,11 +95,15 @@ class MCPClient:
tools = []
for tool in response.tools:
tools.append({
"name": tool.name,
"description": tool.description or "",
"input_schema": tool.inputSchema if hasattr(tool, "inputSchema") else {},
})
tools.append(
{
"name": tool.name,
"description": tool.description or "",
"input_schema": tool.inputSchema
if hasattr(tool, "inputSchema")
else {},
}
)
logger.info("Listed %d tools from MCP server", len(tools))
return tools
@ -121,10 +127,14 @@ class MCPClient:
"""
if not self.session:
raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'")
raise RuntimeError(
"Not connected to MCP server. Use 'async with client.connect():'"
)
try:
logger.info("Calling MCP tool '%s' with arguments: %s", tool_name, arguments)
logger.info(
"Calling MCP tool '%s' with arguments: %s", tool_name, arguments
)
# Call tools/call RPC method
response = await self.session.call_tool(tool_name, arguments=arguments)
@ -147,12 +157,17 @@ class MCPClient:
# Handle validation errors from MCP server responses
# Some MCP servers (like server-memory) return extra fields not in their schema
if "Invalid structured content" in str(e):
logger.warning("MCP server returned data not matching its schema, but continuing: %s", e)
logger.warning(
"MCP server returned data not matching its schema, but continuing: %s",
e,
)
# Try to extract result from error message or return a success message
return "Operation completed (server returned unexpected format)"
raise
except (ValueError, TypeError, AttributeError, KeyError) as e:
logger.error("Failed to call MCP tool '%s': %s", tool_name, e, exc_info=True)
logger.error(
"Failed to call MCP tool '%s': %s", tool_name, e, exc_info=True
)
return f"Error calling tool: {e!s}"

View file

@ -21,7 +21,8 @@ logger = logging.getLogger(__name__)
def _create_dynamic_input_model_from_schema(
tool_name: str, input_schema: dict[str, Any],
tool_name: str,
input_schema: dict[str, Any],
) -> type[BaseModel]:
"""Create a Pydantic model from MCP tool's JSON schema.
@ -41,15 +42,18 @@ def _create_dynamic_input_model_from_schema(
for param_name, param_schema in properties.items():
param_description = param_schema.get("description", "")
is_required = param_name in required_fields
# Use Any type for complex schemas to preserve structure
# This allows the MCP server to do its own validation
from typing import Any as AnyType
from pydantic import Field
if is_required:
field_definitions[param_name] = (AnyType, Field(..., description=param_description))
field_definitions[param_name] = (
AnyType,
Field(..., description=param_description),
)
else:
field_definitions[param_name] = (
AnyType | None,
@ -88,7 +92,7 @@ async def _create_mcp_tool_from_definition(
async def mcp_tool_call(**kwargs) -> str:
"""Execute the MCP tool call via the client."""
logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
try:
# Connect to server and call tool
async with mcp_client.connect():
@ -114,7 +118,8 @@ async def _create_mcp_tool_from_definition(
async def load_mcp_tools(
session: AsyncSession, search_space_id: int,
session: AsyncSession,
search_space_id: int,
) -> list[StructuredTool]:
"""Load all MCP tools from user's active MCP server connectors.
@ -155,9 +160,11 @@ async def load_mcp_tools(
args = server_config.get("args", [])
env = server_config.get("env", {})
if not command:
logger.warning(f"MCP connector {connector.id} server config missing command, skipping")
continue
if not command:
logger.warning(
f"MCP connector {connector.id} missing command, skipping"
)
continue
# Create MCP client
mcp_client = MCPClient(command, args, env)
@ -171,16 +178,18 @@ async def load_mcp_tools(
f"'{command}' (connector {connector.id})"
)
# Create LangChain tools from definitions
for tool_def in tool_definitions:
try:
tool = await _create_mcp_tool_from_definition(tool_def, mcp_client)
tools.append(tool)
except Exception as e:
logger.exception(
f"Failed to create tool '{tool_def.get('name')}' "
f"from connector {connector.id}: {e!s}",
)
# Create LangChain tools from definitions
for tool_def in tool_definitions:
try:
tool = await _create_mcp_tool_from_definition(
tool_def, mcp_client
)
tools.append(tool)
except Exception as e:
logger.exception(
f"Failed to create tool '{tool_def.get('name')}' "
f"from connector {connector.id}: {e!s}",
)
except Exception as e:
logger.exception(

View file

@ -283,7 +283,8 @@ async def build_tools_async(
):
try:
mcp_tools = await load_mcp_tools(
dependencies["db_session"], dependencies["search_space_id"],
dependencies["db_session"],
dependencies["search_space_id"],
)
tools.extend(mcp_tools)
logging.info(

View file

@ -23,7 +23,9 @@ class SearchSourceConnectorBase(BaseModel):
@field_validator("config")
@classmethod
def validate_config_for_connector_type(
cls, config: dict[str, Any], values: dict[str, Any],
cls,
config: dict[str, Any],
values: dict[str, Any],
) -> dict[str, Any]:
connector_type = values.data.get("connector_type")
return validate_connector_config(connector_type, config)

View file

@ -2,11 +2,14 @@
File document processors for different ETL services (Unstructured, LlamaCloud, Docling).
"""
import asyncio
import contextlib
import logging
import ssl
import warnings
from logging import ERROR, getLogger
import httpx
from fastapi import HTTPException
from langchain_core.documents import Document as LangChainDocument
from litellm import atranscription
@ -31,6 +34,122 @@ from .base import (
)
from .markdown_processor import add_received_markdown_file_document
# Constants for LlamaCloud retry configuration
LLAMACLOUD_MAX_RETRIES = 3
LLAMACLOUD_BASE_DELAY = 5 # Base delay in seconds for exponential backoff
LLAMACLOUD_RETRYABLE_EXCEPTIONS = (
ssl.SSLError,
httpx.ConnectError,
httpx.ConnectTimeout,
httpx.ReadTimeout,
httpx.WriteTimeout,
ConnectionError,
TimeoutError,
)
async def parse_with_llamacloud_retry(
file_path: str,
estimated_pages: int,
task_logger: TaskLoggingService | None = None,
log_entry: Log | None = None,
):
"""
Parse a file with LlamaCloud with retry logic for transient SSL/connection errors.
Args:
file_path: Path to the file to parse
estimated_pages: Estimated number of pages for timeout calculation
task_logger: Optional task logger for progress updates
log_entry: Optional log entry for progress updates
Returns:
LlamaParse result object
Raises:
Exception: If all retries fail
"""
from llama_cloud_services import LlamaParse
from llama_cloud_services.parse.utils import ResultType
# Calculate timeouts based on estimated pages
# Base timeout of 300 seconds + 30 seconds per page for large documents
base_timeout = 300
per_page_timeout = 30
job_timeout = base_timeout + (estimated_pages * per_page_timeout)
# Create custom httpx client with larger timeouts for file uploads
# The SSL error often occurs during large file uploads, so we need generous timeouts
custom_timeout = httpx.Timeout(
connect=60.0, # 60 seconds to establish connection
read=300.0, # 5 minutes to read response
write=300.0, # 5 minutes to write/upload (important for large files)
pool=60.0, # 60 seconds to acquire connection from pool
)
last_exception = None
for attempt in range(1, LLAMACLOUD_MAX_RETRIES + 1):
try:
# Create a fresh httpx client for each attempt
async with httpx.AsyncClient(timeout=custom_timeout) as custom_client:
# Create LlamaParse parser instance with optimized settings
parser = LlamaParse(
api_key=app_config.LLAMA_CLOUD_API_KEY,
num_workers=1, # Use single worker for file processing
verbose=True,
language="en",
result_type=ResultType.MD,
# Timeout settings for large files
max_timeout=max(2000, job_timeout), # Overall max timeout
job_timeout_in_seconds=job_timeout,
job_timeout_extra_time_per_page_in_seconds=per_page_timeout,
# Use our custom client with larger timeouts
custom_client=custom_client,
)
# Parse the file asynchronously
result = await parser.aparse(file_path)
return result
except LLAMACLOUD_RETRYABLE_EXCEPTIONS as e:
last_exception = e
error_type = type(e).__name__
if attempt < LLAMACLOUD_MAX_RETRIES:
# Calculate exponential backoff delay
delay = LLAMACLOUD_BASE_DELAY * (2 ** (attempt - 1))
if task_logger and log_entry:
await task_logger.log_task_progress(
log_entry,
f"LlamaCloud upload failed (attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}), retrying in {delay}s",
{
"error_type": error_type,
"error_message": str(e)[:200],
"attempt": attempt,
"retry_delay": delay,
},
)
else:
logging.warning(
f"LlamaCloud upload failed (attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}): {error_type}. "
f"Retrying in {delay}s..."
)
await asyncio.sleep(delay)
else:
logging.error(
f"LlamaCloud upload failed after {LLAMACLOUD_MAX_RETRIES} attempts: {error_type} - {e}"
)
except Exception:
# Non-retryable exception, raise immediately
raise
# All retries exhausted
raise last_exception or RuntimeError("LlamaCloud parsing failed after all retries")
async def add_received_file_document_using_unstructured(
session: AsyncSession,
@ -819,24 +938,18 @@ async def process_file_in_background(
"file_type": "document",
"etl_service": "LLAMACLOUD",
"processing_stage": "parsing",
"estimated_pages": estimated_pages_before,
},
)
from llama_cloud_services import LlamaParse
from llama_cloud_services.parse.utils import ResultType
# Create LlamaParse parser instance
parser = LlamaParse(
api_key=app_config.LLAMA_CLOUD_API_KEY,
num_workers=1, # Use single worker for file processing
verbose=True,
language="en",
result_type=ResultType.MD,
# Parse file with retry logic for SSL/connection errors (common with large files)
result = await parse_with_llamacloud_retry(
file_path=file_path,
estimated_pages=estimated_pages_before,
task_logger=task_logger,
log_entry=log_entry,
)
# Parse the file asynchronously
result = await parser.aparse(file_path)
# Clean up the temp file
import os