mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-02 19:55:18 +02:00
nit
This commit is contained in:
commit
aa90da602b
26 changed files with 406 additions and 234 deletions
|
|
@ -18,7 +18,9 @@ logger = logging.getLogger(__name__)
|
|||
class MCPClient:
|
||||
"""Client for communicating with an MCP server."""
|
||||
|
||||
def __init__(self, command: str, args: list[str], env: dict[str, str] | None = None):
|
||||
def __init__(
|
||||
self, command: str, args: list[str], env: dict[str, str] | None = None
|
||||
):
|
||||
"""Initialize MCP client.
|
||||
|
||||
Args:
|
||||
|
|
@ -44,18 +46,16 @@ class MCPClient:
|
|||
# Merge env vars with current environment
|
||||
server_env = os.environ.copy()
|
||||
server_env.update(self.env)
|
||||
|
||||
|
||||
# Create server parameters with env
|
||||
server_params = StdioServerParameters(
|
||||
command=self.command,
|
||||
args=self.args,
|
||||
env=server_env
|
||||
command=self.command, args=self.args, env=server_env
|
||||
)
|
||||
|
||||
|
||||
# Spawn server process and create session
|
||||
# Note: Cannot combine these context managers because ClientSession
|
||||
# needs the read/write streams from stdio_client
|
||||
async with stdio_client(server=server_params) as (read, write):
|
||||
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the connection
|
||||
await session.initialize()
|
||||
|
|
@ -85,7 +85,9 @@ class MCPClient:
|
|||
|
||||
"""
|
||||
if not self.session:
|
||||
raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'")
|
||||
raise RuntimeError(
|
||||
"Not connected to MCP server. Use 'async with client.connect():'"
|
||||
)
|
||||
|
||||
try:
|
||||
# Call tools/list RPC method
|
||||
|
|
@ -93,11 +95,15 @@ class MCPClient:
|
|||
|
||||
tools = []
|
||||
for tool in response.tools:
|
||||
tools.append({
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"input_schema": tool.inputSchema if hasattr(tool, "inputSchema") else {},
|
||||
})
|
||||
tools.append(
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"input_schema": tool.inputSchema
|
||||
if hasattr(tool, "inputSchema")
|
||||
else {},
|
||||
}
|
||||
)
|
||||
|
||||
logger.info("Listed %d tools from MCP server", len(tools))
|
||||
return tools
|
||||
|
|
@ -121,10 +127,14 @@ class MCPClient:
|
|||
|
||||
"""
|
||||
if not self.session:
|
||||
raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'")
|
||||
raise RuntimeError(
|
||||
"Not connected to MCP server. Use 'async with client.connect():'"
|
||||
)
|
||||
|
||||
try:
|
||||
logger.info("Calling MCP tool '%s' with arguments: %s", tool_name, arguments)
|
||||
logger.info(
|
||||
"Calling MCP tool '%s' with arguments: %s", tool_name, arguments
|
||||
)
|
||||
|
||||
# Call tools/call RPC method
|
||||
response = await self.session.call_tool(tool_name, arguments=arguments)
|
||||
|
|
@ -147,12 +157,17 @@ class MCPClient:
|
|||
# Handle validation errors from MCP server responses
|
||||
# Some MCP servers (like server-memory) return extra fields not in their schema
|
||||
if "Invalid structured content" in str(e):
|
||||
logger.warning("MCP server returned data not matching its schema, but continuing: %s", e)
|
||||
logger.warning(
|
||||
"MCP server returned data not matching its schema, but continuing: %s",
|
||||
e,
|
||||
)
|
||||
# Try to extract result from error message or return a success message
|
||||
return "Operation completed (server returned unexpected format)"
|
||||
raise
|
||||
except (ValueError, TypeError, AttributeError, KeyError) as e:
|
||||
logger.error("Failed to call MCP tool '%s': %s", tool_name, e, exc_info=True)
|
||||
logger.error(
|
||||
"Failed to call MCP tool '%s': %s", tool_name, e, exc_info=True
|
||||
)
|
||||
return f"Error calling tool: {e!s}"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
def _create_dynamic_input_model_from_schema(
|
||||
tool_name: str, input_schema: dict[str, Any],
|
||||
tool_name: str,
|
||||
input_schema: dict[str, Any],
|
||||
) -> type[BaseModel]:
|
||||
"""Create a Pydantic model from MCP tool's JSON schema.
|
||||
|
||||
|
|
@ -41,15 +42,18 @@ def _create_dynamic_input_model_from_schema(
|
|||
for param_name, param_schema in properties.items():
|
||||
param_description = param_schema.get("description", "")
|
||||
is_required = param_name in required_fields
|
||||
|
||||
|
||||
# Use Any type for complex schemas to preserve structure
|
||||
# This allows the MCP server to do its own validation
|
||||
from typing import Any as AnyType
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
if is_required:
|
||||
field_definitions[param_name] = (AnyType, Field(..., description=param_description))
|
||||
field_definitions[param_name] = (
|
||||
AnyType,
|
||||
Field(..., description=param_description),
|
||||
)
|
||||
else:
|
||||
field_definitions[param_name] = (
|
||||
AnyType | None,
|
||||
|
|
@ -88,7 +92,7 @@ async def _create_mcp_tool_from_definition(
|
|||
async def mcp_tool_call(**kwargs) -> str:
|
||||
"""Execute the MCP tool call via the client."""
|
||||
logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
|
||||
|
||||
|
||||
try:
|
||||
# Connect to server and call tool
|
||||
async with mcp_client.connect():
|
||||
|
|
@ -114,7 +118,8 @@ async def _create_mcp_tool_from_definition(
|
|||
|
||||
|
||||
async def load_mcp_tools(
|
||||
session: AsyncSession, search_space_id: int,
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> list[StructuredTool]:
|
||||
"""Load all MCP tools from user's active MCP server connectors.
|
||||
|
||||
|
|
@ -155,9 +160,11 @@ async def load_mcp_tools(
|
|||
args = server_config.get("args", [])
|
||||
env = server_config.get("env", {})
|
||||
|
||||
if not command:
|
||||
logger.warning(f"MCP connector {connector.id} server config missing command, skipping")
|
||||
continue
|
||||
if not command:
|
||||
logger.warning(
|
||||
f"MCP connector {connector.id} missing command, skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
# Create MCP client
|
||||
mcp_client = MCPClient(command, args, env)
|
||||
|
|
@ -171,16 +178,18 @@ async def load_mcp_tools(
|
|||
f"'{command}' (connector {connector.id})"
|
||||
)
|
||||
|
||||
# Create LangChain tools from definitions
|
||||
for tool_def in tool_definitions:
|
||||
try:
|
||||
tool = await _create_mcp_tool_from_definition(tool_def, mcp_client)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to create tool '{tool_def.get('name')}' "
|
||||
f"from connector {connector.id}: {e!s}",
|
||||
)
|
||||
# Create LangChain tools from definitions
|
||||
for tool_def in tool_definitions:
|
||||
try:
|
||||
tool = await _create_mcp_tool_from_definition(
|
||||
tool_def, mcp_client
|
||||
)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to create tool '{tool_def.get('name')}' "
|
||||
f"from connector {connector.id}: {e!s}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
|
|
|
|||
|
|
@ -283,7 +283,8 @@ async def build_tools_async(
|
|||
):
|
||||
try:
|
||||
mcp_tools = await load_mcp_tools(
|
||||
dependencies["db_session"], dependencies["search_space_id"],
|
||||
dependencies["db_session"],
|
||||
dependencies["search_space_id"],
|
||||
)
|
||||
tools.extend(mcp_tools)
|
||||
logging.info(
|
||||
|
|
|
|||
|
|
@ -23,7 +23,9 @@ class SearchSourceConnectorBase(BaseModel):
|
|||
@field_validator("config")
@classmethod
def validate_config_for_connector_type(
    cls,
    config: dict[str, Any],
    values: dict[str, Any],
) -> dict[str, Any]:
    """Validate ``config`` against the connector type set on the model.

    Looks up ``connector_type`` from ``values.data`` and delegates the actual
    check to ``validate_connector_config``, returning whatever it returns.
    """
    # NOTE(review): ``values`` is annotated as a dict but is read through
    # ``.data`` -- presumably this is pydantic's validation-info object;
    # confirm and tighten the annotation.
    return validate_connector_config(values.data.get("connector_type"), config)
|
||||
|
|
|
|||
|
|
@ -2,11 +2,14 @@
|
|||
File document processors for different ETL services (Unstructured, LlamaCloud, Docling).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import ssl
|
||||
import warnings
|
||||
from logging import ERROR, getLogger
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
from langchain_core.documents import Document as LangChainDocument
|
||||
from litellm import atranscription
|
||||
|
|
@ -31,6 +34,122 @@ from .base import (
|
|||
)
|
||||
from .markdown_processor import add_received_markdown_file_document
|
||||
|
||||
# --- LlamaCloud retry configuration ---------------------------------------
# Total number of parse attempts before giving up.
LLAMACLOUD_MAX_RETRIES = 3
# Base delay (seconds) for exponential backoff between attempts.
LLAMACLOUD_BASE_DELAY = 5
# Exception types treated as transient; any other exception is not retried.
LLAMACLOUD_RETRYABLE_EXCEPTIONS = (
    ssl.SSLError,
    httpx.ConnectError,
    httpx.ConnectTimeout,
    httpx.ReadTimeout,
    httpx.WriteTimeout,
    ConnectionError,
    TimeoutError,
)
|
||||
|
||||
|
||||
async def parse_with_llamacloud_retry(
    file_path: str,
    estimated_pages: int,
    task_logger: TaskLoggingService | None = None,
    log_entry: Log | None = None,
):
    """
    Parse a file through LlamaCloud, retrying on transient SSL/connection errors.

    Up to ``LLAMACLOUD_MAX_RETRIES`` attempts are made. Each attempt builds a
    new ``httpx.AsyncClient`` and a new ``LlamaParse`` instance. Between failed
    attempts the coroutine sleeps for
    ``LLAMACLOUD_BASE_DELAY * 2 ** (attempt - 1)`` seconds.

    Args:
        file_path: Path to the file to parse
        estimated_pages: Estimated page count, used to size the job timeout
        task_logger: Optional task logger that receives retry progress updates
        log_entry: Optional log entry the progress updates are attached to

    Returns:
        The result object returned by ``LlamaParse.aparse``

    Raises:
        Exception: The last retryable error once every attempt has failed, or
            immediately any error that is not in
            ``LLAMACLOUD_RETRYABLE_EXCEPTIONS``
    """
    from llama_cloud_services import LlamaParse
    from llama_cloud_services.parse.utils import ResultType

    # Job timeout budget: a flat 300 seconds plus 30 seconds per estimated page.
    seconds_per_page = 30
    parse_job_timeout = 300 + seconds_per_page * estimated_pages

    # Generous HTTP timeouts: the original author notes that the SSL error
    # often shows up while uploading large files.
    upload_timeouts = httpx.Timeout(
        connect=60.0,  # establishing the connection
        read=300.0,  # reading the response
        write=300.0,  # writing/uploading the request body
        pool=60.0,  # acquiring a connection from the pool
    )

    final_error = None

    for attempt in range(1, LLAMACLOUD_MAX_RETRIES + 1):
        try:
            # A new HTTP client per attempt; it is closed when the block exits.
            async with httpx.AsyncClient(timeout=upload_timeouts) as http_client:
                llama_parser = LlamaParse(
                    api_key=app_config.LLAMA_CLOUD_API_KEY,
                    num_workers=1,  # single worker for file processing
                    verbose=True,
                    language="en",
                    result_type=ResultType.MD,
                    # Overall ceiling is never below 2000 seconds.
                    max_timeout=max(2000, parse_job_timeout),
                    job_timeout_in_seconds=parse_job_timeout,
                    job_timeout_extra_time_per_page_in_seconds=seconds_per_page,
                    custom_client=http_client,
                )
                return await llama_parser.aparse(file_path)

        except LLAMACLOUD_RETRYABLE_EXCEPTIONS as e:
            # Exceptions outside this tuple are not caught and propagate at once.
            final_error = e
            error_type = type(e).__name__

            if attempt >= LLAMACLOUD_MAX_RETRIES:
                # Last attempt: record the failure and leave the loop to raise.
                logging.error(
                    f"LlamaCloud upload failed after {LLAMACLOUD_MAX_RETRIES} attempts: {error_type} - {e}"
                )
                break

            # Exponential backoff: base delay doubled for each earlier attempt.
            delay = LLAMACLOUD_BASE_DELAY * (2 ** (attempt - 1))

            if not (task_logger and log_entry):
                logging.warning(
                    f"LlamaCloud upload failed (attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}): {error_type}. "
                    f"Retrying in {delay}s..."
                )
            else:
                await task_logger.log_task_progress(
                    log_entry,
                    f"LlamaCloud upload failed (attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}), retrying in {delay}s",
                    {
                        "error_type": error_type,
                        "error_message": str(e)[:200],
                        "attempt": attempt,
                        "retry_delay": delay,
                    },
                )

            await asyncio.sleep(delay)

    # Every attempt failed with a retryable error (or no attempt was made).
    raise final_error or RuntimeError("LlamaCloud parsing failed after all retries")
|
||||
|
||||
|
||||
async def add_received_file_document_using_unstructured(
|
||||
session: AsyncSession,
|
||||
|
|
@ -819,24 +938,18 @@ async def process_file_in_background(
|
|||
"file_type": "document",
|
||||
"etl_service": "LLAMACLOUD",
|
||||
"processing_stage": "parsing",
|
||||
"estimated_pages": estimated_pages_before,
|
||||
},
|
||||
)
|
||||
|
||||
from llama_cloud_services import LlamaParse
|
||||
from llama_cloud_services.parse.utils import ResultType
|
||||
|
||||
# Create LlamaParse parser instance
|
||||
parser = LlamaParse(
|
||||
api_key=app_config.LLAMA_CLOUD_API_KEY,
|
||||
num_workers=1, # Use single worker for file processing
|
||||
verbose=True,
|
||||
language="en",
|
||||
result_type=ResultType.MD,
|
||||
# Parse file with retry logic for SSL/connection errors (common with large files)
|
||||
result = await parse_with_llamacloud_retry(
|
||||
file_path=file_path,
|
||||
estimated_pages=estimated_pages_before,
|
||||
task_logger=task_logger,
|
||||
log_entry=log_entry,
|
||||
)
|
||||
|
||||
# Parse the file asynchronously
|
||||
result = await parser.aparse(file_path)
|
||||
|
||||
# Clean up the temp file
|
||||
import os
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue