mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
Compare commits
11 commits
7245ab4046
...
09ab174221
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
09ab174221 | ||
|
|
739345671b | ||
|
|
45b72de481 | ||
|
|
1712f454f8 | ||
|
|
cf7c14cf44 | ||
|
|
2eb0ff9e5e | ||
|
|
9bb117ffa7 | ||
|
|
80a349ea11 | ||
|
|
0b3551bd06 | ||
|
|
e3172dc282 | ||
|
|
16f47578d7 |
24 changed files with 835 additions and 333 deletions
|
|
@ -314,6 +314,20 @@ async def create_surfsense_deep_agent(
|
|||
_t0 = time.perf_counter()
|
||||
_enabled_tool_names = {t.name for t in tools}
|
||||
_user_disabled_tool_names = set(disabled_tools) if disabled_tools else set()
|
||||
|
||||
# Collect generic MCP connector info so the system prompt can route queries
|
||||
# to their tools instead of falling back to "not in knowledge base".
|
||||
_mcp_connector_tools: dict[str, list[str]] = {}
|
||||
for t in tools:
|
||||
meta = getattr(t, "metadata", None) or {}
|
||||
if meta.get("mcp_is_generic") and meta.get("mcp_connector_name"):
|
||||
_mcp_connector_tools.setdefault(
|
||||
meta["mcp_connector_name"], [],
|
||||
).append(t.name)
|
||||
|
||||
if _mcp_connector_tools:
|
||||
_perf_log.info("MCP connector tool routing: %s", _mcp_connector_tools)
|
||||
|
||||
if agent_config is not None:
|
||||
system_prompt = build_configurable_system_prompt(
|
||||
custom_system_instructions=agent_config.system_instructions,
|
||||
|
|
@ -322,12 +336,14 @@ async def create_surfsense_deep_agent(
|
|||
thread_visibility=thread_visibility,
|
||||
enabled_tool_names=_enabled_tool_names,
|
||||
disabled_tool_names=_user_disabled_tool_names,
|
||||
mcp_connector_tools=_mcp_connector_tools,
|
||||
)
|
||||
else:
|
||||
system_prompt = build_surfsense_system_prompt(
|
||||
thread_visibility=thread_visibility,
|
||||
enabled_tool_names=_enabled_tool_names,
|
||||
disabled_tool_names=_user_disabled_tool_names,
|
||||
mcp_connector_tools=_mcp_connector_tools,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
|
||||
|
|
|
|||
|
|
@ -815,11 +815,36 @@ Your goal is to provide helpful, informative answers in a clean, readable format
|
|||
"""
|
||||
|
||||
|
||||
def _build_mcp_routing_block(
|
||||
mcp_connector_tools: dict[str, list[str]] | None,
|
||||
) -> str:
|
||||
"""Build an additional tool routing block for generic MCP connectors.
|
||||
|
||||
When users add MCP servers (e.g. GitLab, GitHub), the LLM needs to know
|
||||
those tools exist and should be called directly — not searched in the
|
||||
knowledge base.
|
||||
"""
|
||||
if not mcp_connector_tools:
|
||||
return ""
|
||||
|
||||
lines = [
|
||||
"\n<mcp_tool_routing>",
|
||||
"You also have direct tools from these user-connected MCP servers.",
|
||||
"Their data is NEVER in the knowledge base — call their tools directly.",
|
||||
"",
|
||||
]
|
||||
for server_name, tool_names in mcp_connector_tools.items():
|
||||
lines.append(f"- {server_name} → {', '.join(tool_names)}")
|
||||
lines.append("</mcp_tool_routing>\n")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def build_surfsense_system_prompt(
|
||||
today: datetime | None = None,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
enabled_tool_names: set[str] | None = None,
|
||||
disabled_tool_names: set[str] | None = None,
|
||||
mcp_connector_tools: dict[str, list[str]] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build the SurfSense system prompt with default settings.
|
||||
|
|
@ -834,6 +859,9 @@ def build_surfsense_system_prompt(
|
|||
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
||||
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
||||
mcp_connector_tools: Mapping of MCP server display name → list of tool names
|
||||
for generic MCP connectors. Injected into the system prompt so the LLM
|
||||
knows to call these tools directly.
|
||||
|
||||
Returns:
|
||||
Complete system prompt string
|
||||
|
|
@ -841,6 +869,7 @@ def build_surfsense_system_prompt(
|
|||
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
system_instructions = _get_system_instructions(visibility, today)
|
||||
system_instructions += _build_mcp_routing_block(mcp_connector_tools)
|
||||
tools_instructions = _get_tools_instructions(
|
||||
visibility, enabled_tool_names, disabled_tool_names
|
||||
)
|
||||
|
|
@ -856,6 +885,7 @@ def build_configurable_system_prompt(
|
|||
thread_visibility: ChatVisibility | None = None,
|
||||
enabled_tool_names: set[str] | None = None,
|
||||
disabled_tool_names: set[str] | None = None,
|
||||
mcp_connector_tools: dict[str, list[str]] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
|
||||
|
|
@ -877,6 +907,9 @@ def build_configurable_system_prompt(
|
|||
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
||||
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
||||
mcp_connector_tools: Mapping of MCP server display name → list of tool names
|
||||
for generic MCP connectors. Injected into the system prompt so the LLM
|
||||
knows to call these tools directly.
|
||||
|
||||
Returns:
|
||||
Complete system prompt string
|
||||
|
|
@ -894,6 +927,8 @@ def build_configurable_system_prompt(
|
|||
else:
|
||||
system_instructions = ""
|
||||
|
||||
system_instructions += _build_mcp_routing_block(mcp_connector_tools)
|
||||
|
||||
# Tools instructions: only include enabled tools, note disabled ones
|
||||
tools_instructions = _get_tools_instructions(
|
||||
thread_visibility, enabled_tool_names, disabled_tool_names
|
||||
|
|
|
|||
|
|
@ -45,6 +45,18 @@ class MCPClient:
|
|||
async def connect(self, max_retries: int = MAX_RETRIES):
|
||||
"""Connect to the MCP server and manage its lifecycle.
|
||||
|
||||
Retries only apply to the **connection** phase (spawning the process,
|
||||
initialising the session). Once the session is yielded to the caller,
|
||||
any exception raised by the caller propagates normally -- the context
|
||||
manager will NOT retry after ``yield``.
|
||||
|
||||
Previous implementation wrapped both connection AND yield inside the
|
||||
retry loop. Because ``@asynccontextmanager`` only allows a single
|
||||
``yield``, a failure after yield caused the generator to attempt a
|
||||
second yield on retry, triggering
|
||||
``RuntimeError("generator didn't stop after athrow()")`` and orphaning
|
||||
the stdio subprocess.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of connection retry attempts
|
||||
|
||||
|
|
@ -57,26 +69,22 @@ class MCPClient:
|
|||
"""
|
||||
last_error = None
|
||||
delay = RETRY_DELAY
|
||||
connected = False
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# Merge env vars with current environment
|
||||
server_env = os.environ.copy()
|
||||
server_env.update(self.env)
|
||||
|
||||
# Create server parameters with env
|
||||
server_params = StdioServerParameters(
|
||||
command=self.command, args=self.args, env=server_env
|
||||
)
|
||||
|
||||
# Spawn server process and create session
|
||||
# Note: Cannot combine these context managers because ClientSession
|
||||
# needs the read/write streams from stdio_client
|
||||
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the connection
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
connected = True
|
||||
|
||||
if attempt > 0:
|
||||
logger.info(
|
||||
|
|
@ -91,10 +99,16 @@ class MCPClient:
|
|||
self.command,
|
||||
" ".join(self.args),
|
||||
)
|
||||
yield session
|
||||
return # Success, exit retry loop
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
self.session = None
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
self.session = None
|
||||
if connected:
|
||||
raise
|
||||
last_error = e
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(
|
||||
|
|
@ -105,7 +119,7 @@ class MCPClient:
|
|||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
delay *= RETRY_BACKOFF # Exponential backoff
|
||||
delay *= RETRY_BACKOFF
|
||||
else:
|
||||
logger.error(
|
||||
"Failed to connect to MCP server after %d attempts: %s",
|
||||
|
|
@ -113,10 +127,7 @@ class MCPClient:
|
|||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
self.session = None
|
||||
|
||||
# All retries exhausted
|
||||
error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
|
||||
if last_error:
|
||||
error_msg += f": {last_error}"
|
||||
|
|
@ -161,12 +172,18 @@ class MCPClient:
|
|||
logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
|
||||
raise
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
async def call_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
timeout: float = 60.0,
|
||||
) -> Any:
|
||||
"""Call a tool on the MCP server.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to call
|
||||
arguments: Arguments to pass to the tool
|
||||
timeout: Maximum seconds to wait for the tool to respond
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
|
|
@ -185,10 +202,11 @@ class MCPClient:
|
|||
"Calling MCP tool '%s' with arguments: %s", tool_name, arguments
|
||||
)
|
||||
|
||||
# Call tools/call RPC method
|
||||
response = await self.session.call_tool(tool_name, arguments=arguments)
|
||||
response = await asyncio.wait_for(
|
||||
self.session.call_tool(tool_name, arguments=arguments),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Extract content from response
|
||||
result = []
|
||||
for content in response.content:
|
||||
if hasattr(content, "text"):
|
||||
|
|
@ -202,15 +220,17 @@ class MCPClient:
|
|||
logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
|
||||
return result_str
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"MCP tool '%s' timed out after %.0fs", tool_name, timeout
|
||||
)
|
||||
return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s"
|
||||
except RuntimeError as e:
|
||||
# Handle validation errors from MCP server responses
|
||||
# Some MCP servers (like server-memory) return extra fields not in their schema
|
||||
if "Invalid structured content" in str(e):
|
||||
logger.warning(
|
||||
"MCP server returned data not matching its schema, but continuing: %s",
|
||||
e,
|
||||
)
|
||||
# Try to extract result from error message or return a success message
|
||||
return "Operation completed (server returned unexpected format)"
|
||||
raise
|
||||
except (ValueError, TypeError, AttributeError, KeyError) as e:
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ clicking "Always Allow", which adds the tool name to the connector's
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
|
@ -27,7 +28,7 @@ if TYPE_CHECKING:
|
|||
from langchain_core.tools import StructuredTool
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from pydantic import BaseModel, ConfigDict, Field, create_model
|
||||
from sqlalchemy import cast, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
|
@ -41,6 +42,9 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
|
||||
_MCP_CACHE_MAX_SIZE = 50
|
||||
_MCP_DISCOVERY_TIMEOUT_SECONDS = 30
|
||||
_TOOL_CALL_MAX_RETRIES = 3
|
||||
_TOOL_CALL_RETRY_DELAY = 1.5 # seconds, doubles per attempt
|
||||
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
|
||||
|
||||
|
||||
|
|
@ -62,7 +66,18 @@ def _create_dynamic_input_model_from_schema(
|
|||
tool_name: str,
|
||||
input_schema: dict[str, Any],
|
||||
) -> type[BaseModel]:
|
||||
"""Create a Pydantic model from MCP tool's JSON schema."""
|
||||
"""Create a Pydantic model from MCP tool's JSON schema.
|
||||
|
||||
Models always allow extra fields (``extra="allow"``) so that parameters
|
||||
missing from a broken or incomplete JSON schema (e.g. ``zod-to-json-schema``
|
||||
producing an empty ``$schema``-only object) can still be forwarded to the
|
||||
MCP server.
|
||||
|
||||
When the schema declares **no** properties, a synthetic ``input_data``
|
||||
field of type ``dict`` is injected so the LLM has a visible parameter to
|
||||
populate. The caller should unpack ``input_data`` before forwarding to
|
||||
the MCP server (see ``_unpack_synthetic_input_data``).
|
||||
"""
|
||||
properties = input_schema.get("properties", {})
|
||||
required_fields = input_schema.get("required", [])
|
||||
|
||||
|
|
@ -82,8 +97,35 @@ def _create_dynamic_input_model_from_schema(
|
|||
Field(None, description=param_description),
|
||||
)
|
||||
|
||||
if not properties:
|
||||
field_definitions["input_data"] = (
|
||||
dict[str, Any] | None,
|
||||
Field(
|
||||
None,
|
||||
description=(
|
||||
"Arguments to pass to this tool as a JSON object. "
|
||||
"Infer sensible key names from the tool name and description "
|
||||
"(e.g. {\"search\": \"my query\"} for a search tool)."
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
|
||||
return create_model(model_name, **field_definitions)
|
||||
model = create_model(model_name, __config__=ConfigDict(extra="allow"), **field_definitions)
|
||||
return model
|
||||
|
||||
|
||||
def _unpack_synthetic_input_data(kwargs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Unpack the synthetic ``input_data`` field into top-level kwargs.
|
||||
|
||||
When the MCP tool schema is empty, ``_create_dynamic_input_model_from_schema``
|
||||
adds a catch-all ``input_data: dict`` field. This helper merges that dict
|
||||
back into the top-level kwargs so the MCP server receives flat arguments.
|
||||
"""
|
||||
input_data = kwargs.pop("input_data", None)
|
||||
if isinstance(input_data, dict):
|
||||
kwargs.update(input_data)
|
||||
return kwargs
|
||||
|
||||
|
||||
async def _create_mcp_tool_from_definition_stdio(
|
||||
|
|
@ -101,7 +143,12 @@ async def _create_mcp_tool_from_definition_stdio(
|
|||
``GraphInterrupt`` propagates cleanly to LangGraph.
|
||||
"""
|
||||
tool_name = tool_def.get("name", "unnamed_tool")
|
||||
tool_description = tool_def.get("description", "No description provided")
|
||||
raw_description = tool_def.get("description", "No description provided")
|
||||
tool_description = (
|
||||
f"[MCP server: {connector_name}] {raw_description}"
|
||||
if connector_name
|
||||
else raw_description
|
||||
)
|
||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||
|
||||
logger.debug("MCP tool '%s' input schema: %s", tool_name, input_schema)
|
||||
|
|
@ -119,7 +166,7 @@ async def _create_mcp_tool_from_definition_stdio(
|
|||
params=kwargs,
|
||||
context={
|
||||
"mcp_server": connector_name,
|
||||
"tool_description": tool_description,
|
||||
"tool_description": raw_description,
|
||||
"mcp_transport": "stdio",
|
||||
"mcp_connector_id": connector_id,
|
||||
},
|
||||
|
|
@ -127,18 +174,32 @@ async def _create_mcp_tool_from_definition_stdio(
|
|||
)
|
||||
if hitl_result.rejected:
|
||||
return "Tool call rejected by user."
|
||||
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
|
||||
call_kwargs = _unpack_synthetic_input_data(
|
||||
{k: v for k, v in hitl_result.params.items() if v is not None}
|
||||
)
|
||||
|
||||
try:
|
||||
async with mcp_client.connect():
|
||||
result = await mcp_client.call_tool(tool_name, call_kwargs)
|
||||
return str(result)
|
||||
except RuntimeError as e:
|
||||
logger.error("MCP tool '%s' connection failed after retries: %s", tool_name, e)
|
||||
return f"Error: MCP tool '{tool_name}' connection failed after retries: {e!s}"
|
||||
except Exception as e:
|
||||
logger.exception("MCP tool '%s' execution failed: %s", tool_name, e)
|
||||
return f"Error: MCP tool '{tool_name}' execution failed: {e!s}"
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(_TOOL_CALL_MAX_RETRIES):
|
||||
try:
|
||||
async with mcp_client.connect():
|
||||
result = await mcp_client.call_tool(tool_name, call_kwargs)
|
||||
return str(result)
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
if attempt < _TOOL_CALL_MAX_RETRIES - 1:
|
||||
delay = _TOOL_CALL_RETRY_DELAY * (2 ** attempt)
|
||||
logger.warning(
|
||||
"MCP tool '%s' failed (attempt %d/%d): %s. Retrying in %.1fs...",
|
||||
tool_name, attempt + 1, _TOOL_CALL_MAX_RETRIES, e, delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
logger.error(
|
||||
"MCP tool '%s' failed after %d attempts: %s",
|
||||
tool_name, _TOOL_CALL_MAX_RETRIES, e, exc_info=True,
|
||||
)
|
||||
|
||||
return f"Error: MCP tool '{tool_name}' failed after {_TOOL_CALL_MAX_RETRIES} attempts: {last_error!s}"
|
||||
|
||||
tool = StructuredTool(
|
||||
name=tool_name,
|
||||
|
|
@ -148,6 +209,8 @@ async def _create_mcp_tool_from_definition_stdio(
|
|||
metadata={
|
||||
"mcp_input_schema": input_schema,
|
||||
"mcp_transport": "stdio",
|
||||
"mcp_connector_name": connector_name or None,
|
||||
"mcp_is_generic": True,
|
||||
"hitl": True,
|
||||
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
||||
},
|
||||
|
|
@ -167,6 +230,7 @@ async def _create_mcp_tool_from_definition_http(
|
|||
trusted_tools: list[str] | None = None,
|
||||
readonly_tools: frozenset[str] | None = None,
|
||||
tool_name_prefix: str | None = None,
|
||||
is_generic_mcp: bool = False,
|
||||
) -> StructuredTool:
|
||||
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
|
||||
|
||||
|
|
@ -178,7 +242,7 @@ async def _create_mcp_tool_from_definition_http(
|
|||
but the actual MCP ``call_tool`` still uses the original name.
|
||||
"""
|
||||
original_tool_name = tool_def.get("name", "unnamed_tool")
|
||||
tool_description = tool_def.get("description", "No description provided")
|
||||
raw_description = tool_def.get("description", "No description provided")
|
||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||
is_readonly = readonly_tools is not None and original_tool_name in readonly_tools
|
||||
|
||||
|
|
@ -188,18 +252,51 @@ async def _create_mcp_tool_from_definition_http(
|
|||
else original_tool_name
|
||||
)
|
||||
if tool_name_prefix:
|
||||
tool_description = f"[Account: {connector_name}] {tool_description}"
|
||||
tool_description = f"[Account: {connector_name}] {raw_description}"
|
||||
elif is_generic_mcp and connector_name:
|
||||
tool_description = f"[MCP server: {connector_name}] {raw_description}"
|
||||
else:
|
||||
tool_description = raw_description
|
||||
|
||||
logger.debug("MCP HTTP tool '%s' input schema: %s", exposed_name, input_schema)
|
||||
|
||||
input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema)
|
||||
|
||||
async def _do_mcp_call(
|
||||
call_headers: dict[str, str],
|
||||
call_kwargs: dict[str, Any],
|
||||
timeout: float = 60.0,
|
||||
) -> str:
|
||||
"""Execute a single MCP HTTP call with the given headers."""
|
||||
async with (
|
||||
streamablehttp_client(url, headers=call_headers) as (read, write, _),
|
||||
ClientSession(read, write) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
response = await asyncio.wait_for(
|
||||
session.call_tool(original_tool_name, arguments=call_kwargs),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
result = []
|
||||
for content in response.content:
|
||||
if hasattr(content, "text"):
|
||||
result.append(content.text)
|
||||
elif hasattr(content, "data"):
|
||||
result.append(str(content.data))
|
||||
else:
|
||||
result.append(str(content))
|
||||
|
||||
return "\n".join(result) if result else ""
|
||||
|
||||
async def mcp_http_tool_call(**kwargs) -> str:
|
||||
"""Execute the MCP tool call via HTTP transport."""
|
||||
logger.debug("MCP HTTP tool '%s' called", exposed_name)
|
||||
|
||||
if is_readonly:
|
||||
call_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
call_kwargs = _unpack_synthetic_input_data(
|
||||
{k: v for k, v in kwargs.items() if v is not None}
|
||||
)
|
||||
else:
|
||||
hitl_result = request_approval(
|
||||
action_type="mcp_tool_call",
|
||||
|
|
@ -207,7 +304,7 @@ async def _create_mcp_tool_from_definition_http(
|
|||
params=kwargs,
|
||||
context={
|
||||
"mcp_server": connector_name,
|
||||
"tool_description": tool_description,
|
||||
"tool_description": raw_description,
|
||||
"mcp_transport": "http",
|
||||
"mcp_connector_id": connector_id,
|
||||
},
|
||||
|
|
@ -215,34 +312,51 @@ async def _create_mcp_tool_from_definition_http(
|
|||
)
|
||||
if hitl_result.rejected:
|
||||
return "Tool call rejected by user."
|
||||
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
|
||||
call_kwargs = _unpack_synthetic_input_data(
|
||||
{k: v for k, v in hitl_result.params.items() if v is not None}
|
||||
)
|
||||
|
||||
try:
|
||||
async with (
|
||||
streamablehttp_client(url, headers=headers) as (read, write, _),
|
||||
ClientSession(read, write) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
response = await session.call_tool(
|
||||
original_tool_name, arguments=call_kwargs,
|
||||
result_str = await _do_mcp_call(headers, call_kwargs)
|
||||
logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str))
|
||||
return result_str
|
||||
|
||||
except Exception as first_err:
|
||||
if not _is_auth_error(first_err) or connector_id is None:
|
||||
logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, first_err)
|
||||
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {first_err!s}"
|
||||
|
||||
logger.warning(
|
||||
"MCP HTTP tool '%s' got 401 — attempting token refresh for connector %s",
|
||||
exposed_name, connector_id,
|
||||
)
|
||||
fresh_headers = await _force_refresh_and_get_headers(connector_id)
|
||||
if fresh_headers is None:
|
||||
await _mark_connector_auth_expired(connector_id)
|
||||
return (
|
||||
f"Error: MCP tool '{exposed_name}' authentication expired. "
|
||||
"Please re-authenticate the connector in your settings."
|
||||
)
|
||||
|
||||
result = []
|
||||
for content in response.content:
|
||||
if hasattr(content, "text"):
|
||||
result.append(content.text)
|
||||
elif hasattr(content, "data"):
|
||||
result.append(str(content.data))
|
||||
else:
|
||||
result.append(str(content))
|
||||
|
||||
result_str = "\n".join(result) if result else ""
|
||||
logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str))
|
||||
try:
|
||||
result_str = await _do_mcp_call(fresh_headers, call_kwargs)
|
||||
logger.info(
|
||||
"MCP HTTP tool '%s' succeeded after 401 recovery",
|
||||
exposed_name,
|
||||
)
|
||||
return result_str
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, e)
|
||||
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {e!s}"
|
||||
except Exception as retry_err:
|
||||
logger.exception(
|
||||
"MCP HTTP tool '%s' still failing after token refresh: %s",
|
||||
exposed_name, retry_err,
|
||||
)
|
||||
if _is_auth_error(retry_err):
|
||||
await _mark_connector_auth_expired(connector_id)
|
||||
return (
|
||||
f"Error: MCP tool '{exposed_name}' authentication expired. "
|
||||
"Please re-authenticate the connector in your settings."
|
||||
)
|
||||
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {retry_err!s}"
|
||||
|
||||
tool = StructuredTool(
|
||||
name=exposed_name,
|
||||
|
|
@ -253,6 +367,8 @@ async def _create_mcp_tool_from_definition_http(
|
|||
"mcp_input_schema": input_schema,
|
||||
"mcp_transport": "http",
|
||||
"mcp_url": url,
|
||||
"mcp_connector_name": connector_name or None,
|
||||
"mcp_is_generic": is_generic_mcp,
|
||||
"hitl": not is_readonly,
|
||||
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
||||
"mcp_original_tool_name": original_tool_name,
|
||||
|
|
@ -334,6 +450,7 @@ async def _load_http_mcp_tools(
|
|||
allowed_tools: list[str] | None = None,
|
||||
readonly_tools: frozenset[str] | None = None,
|
||||
tool_name_prefix: str | None = None,
|
||||
is_generic_mcp: bool = False,
|
||||
) -> list[StructuredTool]:
|
||||
"""Load tools from an HTTP-based MCP server.
|
||||
|
||||
|
|
@ -365,66 +482,99 @@ async def _load_http_mcp_tools(
|
|||
|
||||
allowed_set = set(allowed_tools) if allowed_tools else None
|
||||
|
||||
try:
|
||||
async def _discover(disc_headers: dict[str, str]) -> list[dict[str, Any]]:
|
||||
"""Connect, initialize, and list tools from the MCP server."""
|
||||
async with (
|
||||
streamablehttp_client(url, headers=headers) as (read, write, _),
|
||||
streamablehttp_client(url, headers=disc_headers) as (read, write, _),
|
||||
ClientSession(read, write) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
|
||||
response = await session.list_tools()
|
||||
tool_definitions = []
|
||||
for tool in response.tools:
|
||||
tool_definitions.append(
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"input_schema": tool.inputSchema
|
||||
if hasattr(tool, "inputSchema")
|
||||
else {},
|
||||
}
|
||||
)
|
||||
return [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"input_schema": tool.inputSchema
|
||||
if hasattr(tool, "inputSchema")
|
||||
else {},
|
||||
}
|
||||
for tool in response.tools
|
||||
]
|
||||
|
||||
total_discovered = len(tool_definitions)
|
||||
try:
|
||||
tool_definitions = await _discover(headers)
|
||||
except Exception as first_err:
|
||||
if not _is_auth_error(first_err) or connector_id is None:
|
||||
logger.exception(
|
||||
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
|
||||
url, connector_id, first_err,
|
||||
)
|
||||
return tools
|
||||
|
||||
if allowed_set:
|
||||
tool_definitions = [
|
||||
td for td in tool_definitions if td["name"] in allowed_set
|
||||
]
|
||||
logger.info(
|
||||
"HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter",
|
||||
url, connector_id, len(tool_definitions), total_discovered,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all",
|
||||
total_discovered, url, connector_id,
|
||||
)
|
||||
|
||||
for tool_def in tool_definitions:
|
||||
try:
|
||||
tool = await _create_mcp_tool_from_definition_http(
|
||||
tool_def,
|
||||
url,
|
||||
headers,
|
||||
connector_name=connector_name,
|
||||
connector_id=connector_id,
|
||||
trusted_tools=trusted_tools,
|
||||
readonly_tools=readonly_tools,
|
||||
tool_name_prefix=tool_name_prefix,
|
||||
)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to create HTTP tool '%s' from connector %d: %s",
|
||||
tool_def.get("name"), connector_id, e,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
|
||||
url, connector_id, e,
|
||||
logger.warning(
|
||||
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
|
||||
connector_id,
|
||||
)
|
||||
fresh_headers = await _force_refresh_and_get_headers(connector_id)
|
||||
if fresh_headers is None:
|
||||
await _mark_connector_auth_expired(connector_id)
|
||||
logger.error(
|
||||
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
|
||||
connector_id,
|
||||
)
|
||||
return tools
|
||||
|
||||
try:
|
||||
tool_definitions = await _discover(fresh_headers)
|
||||
headers = fresh_headers
|
||||
logger.info(
|
||||
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
|
||||
connector_id,
|
||||
)
|
||||
except Exception as retry_err:
|
||||
logger.exception(
|
||||
"HTTP MCP discovery for connector %d still failing after refresh: %s",
|
||||
connector_id, retry_err,
|
||||
)
|
||||
if _is_auth_error(retry_err):
|
||||
await _mark_connector_auth_expired(connector_id)
|
||||
return tools
|
||||
|
||||
total_discovered = len(tool_definitions)
|
||||
|
||||
if allowed_set:
|
||||
tool_definitions = [
|
||||
td for td in tool_definitions if td["name"] in allowed_set
|
||||
]
|
||||
logger.info(
|
||||
"HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter",
|
||||
url, connector_id, len(tool_definitions), total_discovered,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all",
|
||||
total_discovered, url, connector_id,
|
||||
)
|
||||
|
||||
for tool_def in tool_definitions:
|
||||
try:
|
||||
tool = await _create_mcp_tool_from_definition_http(
|
||||
tool_def,
|
||||
url,
|
||||
headers,
|
||||
connector_name=connector_name,
|
||||
connector_id=connector_id,
|
||||
trusted_tools=trusted_tools,
|
||||
readonly_tools=readonly_tools,
|
||||
tool_name_prefix=tool_name_prefix,
|
||||
is_generic_mcp=is_generic_mcp,
|
||||
)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to create HTTP tool '%s' from connector %d: %s",
|
||||
tool_def.get("name"), connector_id, e,
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
|
|
@ -476,6 +626,91 @@ def _inject_oauth_headers(
|
|||
return None
|
||||
|
||||
|
||||
async def _refresh_connector_token(
|
||||
session: AsyncSession,
|
||||
connector: "SearchSourceConnector",
|
||||
) -> str | None:
|
||||
"""Refresh the OAuth token for an MCP connector and persist the result.
|
||||
|
||||
This is the shared core used by both proactive (pre-expiry) and reactive
|
||||
(401 recovery) refresh paths. It handles:
|
||||
- Decrypting the current refresh token / client secret
|
||||
- Calling the token endpoint
|
||||
- Encrypting and persisting the new tokens
|
||||
- Clearing ``auth_expired`` if it was set
|
||||
- Invalidating the MCP tools cache
|
||||
|
||||
Returns the **plaintext** new access token on success, or ``None`` on
|
||||
failure (no refresh token, IdP error, etc.).
|
||||
"""
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.services.mcp_oauth.discovery import refresh_access_token
|
||||
|
||||
cfg = connector.config or {}
|
||||
mcp_oauth = cfg.get("mcp_oauth", {})
|
||||
|
||||
refresh_token = mcp_oauth.get("refresh_token")
|
||||
if not refresh_token:
|
||||
logger.warning(
|
||||
"MCP connector %s: no refresh_token available",
|
||||
connector.id,
|
||||
)
|
||||
return None
|
||||
|
||||
enc = _get_token_enc()
|
||||
decrypted_refresh = enc.decrypt_token(refresh_token)
|
||||
decrypted_secret = (
|
||||
enc.decrypt_token(mcp_oauth["client_secret"])
|
||||
if mcp_oauth.get("client_secret")
|
||||
else ""
|
||||
)
|
||||
|
||||
token_json = await refresh_access_token(
|
||||
token_endpoint=mcp_oauth["token_endpoint"],
|
||||
refresh_token=decrypted_refresh,
|
||||
client_id=mcp_oauth["client_id"],
|
||||
client_secret=decrypted_secret,
|
||||
)
|
||||
|
||||
new_access = token_json.get("access_token")
|
||||
if not new_access:
|
||||
logger.warning(
|
||||
"MCP connector %s: token refresh returned no access_token",
|
||||
connector.id,
|
||||
)
|
||||
return None
|
||||
|
||||
new_expires_at = None
|
||||
if token_json.get("expires_in"):
|
||||
new_expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=int(token_json["expires_in"])
|
||||
)
|
||||
|
||||
updated_oauth = dict(mcp_oauth)
|
||||
updated_oauth["access_token"] = enc.encrypt_token(new_access)
|
||||
if token_json.get("refresh_token"):
|
||||
updated_oauth["refresh_token"] = enc.encrypt_token(
|
||||
token_json["refresh_token"]
|
||||
)
|
||||
updated_oauth["expires_at"] = (
|
||||
new_expires_at.isoformat() if new_expires_at else None
|
||||
)
|
||||
|
||||
updated_cfg = {**cfg, "mcp_oauth": updated_oauth}
|
||||
updated_cfg.pop("auth_expired", None)
|
||||
connector.config = updated_cfg
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
|
||||
invalidate_mcp_tools_cache(connector.search_space_id)
|
||||
|
||||
return new_access
|
||||
|
||||
|
||||
async def _maybe_refresh_mcp_oauth_token(
|
||||
session: AsyncSession,
|
||||
connector: "SearchSourceConnector",
|
||||
|
|
@ -504,73 +739,13 @@ async def _maybe_refresh_mcp_oauth_token(
|
|||
except (ValueError, TypeError):
|
||||
return server_config
|
||||
|
||||
refresh_token = mcp_oauth.get("refresh_token")
|
||||
if not refresh_token:
|
||||
logger.warning(
|
||||
"MCP connector %s token expired but no refresh_token available",
|
||||
connector.id,
|
||||
)
|
||||
return server_config
|
||||
|
||||
try:
|
||||
from app.services.mcp_oauth.discovery import refresh_access_token
|
||||
|
||||
enc = _get_token_enc()
|
||||
decrypted_refresh = enc.decrypt_token(refresh_token)
|
||||
decrypted_secret = (
|
||||
enc.decrypt_token(mcp_oauth["client_secret"])
|
||||
if mcp_oauth.get("client_secret")
|
||||
else ""
|
||||
)
|
||||
|
||||
token_json = await refresh_access_token(
|
||||
token_endpoint=mcp_oauth["token_endpoint"],
|
||||
refresh_token=decrypted_refresh,
|
||||
client_id=mcp_oauth["client_id"],
|
||||
client_secret=decrypted_secret,
|
||||
)
|
||||
|
||||
new_access = token_json.get("access_token")
|
||||
new_access = await _refresh_connector_token(session, connector)
|
||||
if not new_access:
|
||||
logger.warning(
|
||||
"MCP connector %s token refresh returned no access_token",
|
||||
connector.id,
|
||||
)
|
||||
return server_config
|
||||
|
||||
new_expires_at = None
|
||||
if token_json.get("expires_in"):
|
||||
new_expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=int(token_json["expires_in"])
|
||||
)
|
||||
logger.info("Proactively refreshed MCP OAuth token for connector %s", connector.id)
|
||||
|
||||
updated_oauth = dict(mcp_oauth)
|
||||
updated_oauth["access_token"] = enc.encrypt_token(new_access)
|
||||
if token_json.get("refresh_token"):
|
||||
updated_oauth["refresh_token"] = enc.encrypt_token(
|
||||
token_json["refresh_token"]
|
||||
)
|
||||
updated_oauth["expires_at"] = (
|
||||
new_expires_at.isoformat() if new_expires_at else None
|
||||
)
|
||||
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
connector.config = {
|
||||
**cfg,
|
||||
"server_config": server_config,
|
||||
"mcp_oauth": updated_oauth,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
|
||||
logger.info("Refreshed MCP OAuth token for connector %s", connector.id)
|
||||
|
||||
# Invalidate cache so next call picks up the new token.
|
||||
invalidate_mcp_tools_cache(connector.search_space_id)
|
||||
|
||||
# Return server_config with the fresh token injected for immediate use.
|
||||
refreshed_config = dict(server_config)
|
||||
refreshed_config["headers"] = {
|
||||
**server_config.get("headers", {}),
|
||||
|
|
@ -587,6 +762,117 @@ async def _maybe_refresh_mcp_oauth_token(
|
|||
return server_config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reactive 401 handling helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _is_auth_error(exc: Exception) -> bool:
|
||||
"""Check if an exception indicates an HTTP 401 authentication failure."""
|
||||
try:
|
||||
import httpx
|
||||
|
||||
if isinstance(exc, httpx.HTTPStatusError):
|
||||
return exc.response.status_code == 401
|
||||
except ImportError:
|
||||
pass
|
||||
err_str = str(exc).lower()
|
||||
return "401" in err_str or "unauthorized" in err_str
|
||||
|
||||
|
||||
async def _force_refresh_and_get_headers(
|
||||
connector_id: int,
|
||||
) -> dict[str, str] | None:
|
||||
"""Force-refresh OAuth token for a connector and return fresh HTTP headers.
|
||||
|
||||
Opens a **new** DB session so this can be called from inside tool closures
|
||||
that don't have access to the original session.
|
||||
|
||||
Returns ``None`` when the connector is not OAuth-backed, has no
|
||||
refresh token, or the refresh itself fails.
|
||||
"""
|
||||
from app.db import async_session_maker
|
||||
|
||||
try:
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return None
|
||||
|
||||
cfg = connector.config or {}
|
||||
if not cfg.get("mcp_oauth"):
|
||||
return None
|
||||
|
||||
server_config = cfg.get("server_config", {})
|
||||
|
||||
new_access = await _refresh_connector_token(session, connector)
|
||||
if not new_access:
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
"Force-refreshed MCP OAuth token for connector %s (401 recovery)",
|
||||
connector_id,
|
||||
)
|
||||
return {
|
||||
**server_config.get("headers", {}),
|
||||
"Authorization": f"Bearer {new_access}",
|
||||
}
|
||||
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to force-refresh MCP OAuth token for connector %s",
|
||||
connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def _mark_connector_auth_expired(connector_id: int) -> None:
|
||||
"""Set ``config.auth_expired = True`` so the frontend shows re-auth UI."""
|
||||
from app.db import async_session_maker
|
||||
|
||||
try:
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return
|
||||
|
||||
cfg = dict(connector.config or {})
|
||||
if cfg.get("auth_expired"):
|
||||
return
|
||||
|
||||
cfg["auth_expired"] = True
|
||||
connector.config = cfg
|
||||
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
"Marked MCP connector %s as auth_expired after unrecoverable 401",
|
||||
connector_id,
|
||||
)
|
||||
invalidate_mcp_tools_cache(connector.search_space_id)
|
||||
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to mark connector %s as auth_expired",
|
||||
connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
||||
"""Invalidate cached MCP tools.
|
||||
|
||||
|
|
@ -661,7 +947,7 @@ async def load_mcp_tools(
|
|||
multi_account_types,
|
||||
)
|
||||
|
||||
tools: list[StructuredTool] = []
|
||||
discovery_tasks: list[dict[str, Any]] = []
|
||||
for connector in connectors:
|
||||
try:
|
||||
cfg = connector.config or {}
|
||||
|
|
@ -674,14 +960,10 @@ async def load_mcp_tools(
|
|||
)
|
||||
continue
|
||||
|
||||
# For MCP OAuth connectors: refresh if needed, then decrypt the
|
||||
# access token and inject it into headers at runtime. The DB
|
||||
# intentionally does NOT store plaintext tokens in server_config.
|
||||
if cfg.get("mcp_oauth"):
|
||||
server_config = await _maybe_refresh_mcp_oauth_token(
|
||||
session, connector, cfg, server_config,
|
||||
)
|
||||
# Re-read cfg after potential refresh (connector was reloaded from DB).
|
||||
cfg = connector.config or {}
|
||||
server_config = _inject_oauth_headers(cfg, server_config)
|
||||
if server_config is None:
|
||||
|
|
@ -689,6 +971,7 @@ async def load_mcp_tools(
|
|||
"Skipping MCP connector %d — OAuth token decryption failed",
|
||||
connector.id,
|
||||
)
|
||||
await _mark_connector_auth_expired(connector.id)
|
||||
continue
|
||||
|
||||
trusted_tools = cfg.get("trusted_tools", [])
|
||||
|
|
@ -703,7 +986,6 @@ async def load_mcp_tools(
|
|||
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
|
||||
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
|
||||
|
||||
# Build a prefix only when multiple accounts share the same type.
|
||||
tool_name_prefix: str | None = None
|
||||
if ct in multi_account_types and svc_cfg:
|
||||
service_key = next(
|
||||
|
|
@ -713,34 +995,68 @@ async def load_mcp_tools(
|
|||
if service_key:
|
||||
tool_name_prefix = f"{service_key}_{connector.id}"
|
||||
|
||||
transport = server_config.get("transport", "stdio")
|
||||
|
||||
if transport in ("streamable-http", "http", "sse"):
|
||||
connector_tools = await _load_http_mcp_tools(
|
||||
connector.id,
|
||||
connector.name,
|
||||
server_config,
|
||||
trusted_tools=trusted_tools,
|
||||
allowed_tools=allowed_tools,
|
||||
readonly_tools=readonly_tools,
|
||||
tool_name_prefix=tool_name_prefix,
|
||||
)
|
||||
else:
|
||||
connector_tools = await _load_stdio_mcp_tools(
|
||||
connector.id,
|
||||
connector.name,
|
||||
server_config,
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
|
||||
tools.extend(connector_tools)
|
||||
discovery_tasks.append({
|
||||
"connector_id": connector.id,
|
||||
"connector_name": connector.name,
|
||||
"server_config": server_config,
|
||||
"trusted_tools": trusted_tools,
|
||||
"allowed_tools": allowed_tools,
|
||||
"readonly_tools": readonly_tools,
|
||||
"tool_name_prefix": tool_name_prefix,
|
||||
"transport": server_config.get("transport", "stdio"),
|
||||
"is_generic_mcp": svc_cfg is None,
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to load tools from MCP connector %d: %s",
|
||||
"Failed to prepare MCP connector %d: %s",
|
||||
connector.id, e,
|
||||
)
|
||||
|
||||
async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]:
|
||||
try:
|
||||
if task["transport"] in ("streamable-http", "http", "sse"):
|
||||
return await asyncio.wait_for(
|
||||
_load_http_mcp_tools(
|
||||
task["connector_id"],
|
||||
task["connector_name"],
|
||||
task["server_config"],
|
||||
trusted_tools=task["trusted_tools"],
|
||||
allowed_tools=task["allowed_tools"],
|
||||
readonly_tools=task["readonly_tools"],
|
||||
tool_name_prefix=task["tool_name_prefix"],
|
||||
is_generic_mcp=task.get("is_generic_mcp", False),
|
||||
),
|
||||
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||
)
|
||||
else:
|
||||
return await asyncio.wait_for(
|
||||
_load_stdio_mcp_tools(
|
||||
task["connector_id"],
|
||||
task["connector_name"],
|
||||
task["server_config"],
|
||||
trusted_tools=task["trusted_tools"],
|
||||
),
|
||||
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"MCP connector %d timed out after %ds during discovery",
|
||||
task["connector_id"], _MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||
)
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to load tools from MCP connector %d: %s",
|
||||
task["connector_id"], e,
|
||||
)
|
||||
return []
|
||||
|
||||
results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks])
|
||||
tools: list[StructuredTool] = [
|
||||
tool for sublist in results for tool in sublist
|
||||
]
|
||||
|
||||
_mcp_tools_cache[search_space_id] = (now, tools)
|
||||
|
||||
if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:
|
||||
|
|
|
|||
|
|
@ -3105,13 +3105,18 @@ async def trust_mcp_tool(
|
|||
"""Add a tool to the MCP connector's trusted (always-allow) list.
|
||||
|
||||
Once trusted, the tool executes without HITL approval on subsequent calls.
|
||||
Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors
|
||||
(LINEAR_CONNECTOR, JIRA_CONNECTOR, etc.) by checking for ``server_config``.
|
||||
"""
|
||||
try:
|
||||
from sqlalchemy import cast
|
||||
from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.MCP_CONNECTOR,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), # noqa: W601
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
|
@ -3156,13 +3161,17 @@ async def untrust_mcp_tool(
|
|||
"""Remove a tool from the MCP connector's trusted list.
|
||||
|
||||
The tool will require HITL approval again on subsequent calls.
|
||||
Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors.
|
||||
"""
|
||||
try:
|
||||
from sqlalchemy import cast
|
||||
from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.MCP_CONNECTOR,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), # noqa: W601
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -66,6 +65,8 @@ class ConfluenceKBSyncService:
|
|||
if dup:
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
@ -184,6 +185,8 @@ class ConfluenceKBSyncService:
|
|||
|
||||
space_id = (document.document_metadata or {}).get("space_id", "")
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.db import Document, DocumentType
|
||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -73,6 +72,8 @@ class DropboxKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from datetime import datetime
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -78,6 +77,8 @@ class GmailKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ from app.db import (
|
|||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -91,6 +90,8 @@ class GoogleCalendarKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
@ -249,6 +250,8 @@ class GoogleCalendarKBSyncService:
|
|||
if not indexable_content:
|
||||
return {"status": "error", "message": "Event produced empty content"}
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from datetime import datetime
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -75,6 +74,8 @@ class GoogleDriveKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -75,6 +74,8 @@ class JiraKBSyncService:
|
|||
if dup:
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
@ -190,6 +191,8 @@ class JiraKBSyncService:
|
|||
state = formatted.get("status", "Unknown")
|
||||
comment_count = len(formatted.get("comments", []))
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.connectors.linear_connector import LinearConnector
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -85,6 +84,8 @@ class LinearKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
@ -226,6 +227,8 @@ class LinearKBSyncService:
|
|||
comment_count = len(formatted_issue.get("comments", []))
|
||||
formatted_issue.get("description", "")
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from langchain_litellm import ChatLiteLLM
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
from app.config import config
|
||||
from app.db import NewLLMConfig, SearchSpace
|
||||
from app.services.llm_router_service import (
|
||||
|
|
@ -204,6 +203,8 @@ async def validate_llm_config(
|
|||
if litellm_params:
|
||||
litellm_kwargs.update(litellm_params)
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
|
||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
# Run the test call in a worker thread with a hard timeout. Some
|
||||
|
|
@ -377,6 +378,8 @@ async def get_search_space_llm_instance(
|
|||
if disable_streaming:
|
||||
litellm_kwargs["disable_streaming"] = True
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
# Get the LLM configuration from database (NewLLMConfig)
|
||||
|
|
@ -454,6 +457,8 @@ async def get_search_space_llm_instance(
|
|||
if disable_streaming:
|
||||
litellm_kwargs["disable_streaming"] = True
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -555,6 +560,8 @@ async def get_vision_llm(
|
|||
if global_cfg.get("litellm_params"):
|
||||
litellm_kwargs.update(global_cfg["litellm_params"])
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
result = await session.execute(
|
||||
|
|
@ -588,6 +595,8 @@ async def get_vision_llm(
|
|||
if vision_cfg.litellm_params:
|
||||
litellm_kwargs.update(vision_cfg.litellm_params)
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from datetime import datetime
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -74,6 +73,8 @@ class NotionKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
@ -244,6 +245,8 @@ class NotionKBSyncService:
|
|||
f"Final content length: {len(full_content)} chars, verified={content_verified}"
|
||||
)
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
logger.debug("Generating summary and embeddings")
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.db import Document, DocumentType
|
||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -73,6 +72,8 @@ class OneDriveKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
|
|||
|
|
@ -123,8 +123,9 @@ export const ConnectorIndicator = forwardRef<ConnectorIndicatorHandle, Connector
|
|||
handleSkipIndexing,
|
||||
handleStartEdit,
|
||||
handleSaveConnector,
|
||||
handleDisconnectConnector,
|
||||
handleBackFromEdit,
|
||||
handleDisconnectConnector,
|
||||
handleDisconnectFromList,
|
||||
handleBackFromEdit,
|
||||
handleBackFromConnect,
|
||||
handleBackFromYouTube,
|
||||
handleViewAccountsList,
|
||||
|
|
@ -225,25 +226,27 @@ export const ConnectorIndicator = forwardRef<ConnectorIndicatorHandle, Connector
|
|||
{isYouTubeView && searchSpaceId ? (
|
||||
<YouTubeCrawlerView searchSpaceId={searchSpaceId} onBack={handleBackFromYouTube} />
|
||||
) : viewingMCPList ? (
|
||||
<ConnectorAccountsListView
|
||||
connectorType="MCP_CONNECTOR"
|
||||
connectorTitle="MCP Connectors"
|
||||
connectors={(allConnectors || []) as SearchSourceConnector[]}
|
||||
indexingConnectorIds={indexingConnectorIds}
|
||||
onBack={handleBackFromMCPList}
|
||||
onManage={handleStartEdit}
|
||||
onAddAccount={handleAddNewMCPFromList}
|
||||
addButtonText="Add New MCP Server"
|
||||
/>
|
||||
<ConnectorAccountsListView
|
||||
connectorType="MCP_CONNECTOR"
|
||||
connectorTitle="MCP Connectors"
|
||||
connectors={(allConnectors || []) as SearchSourceConnector[]}
|
||||
indexingConnectorIds={indexingConnectorIds}
|
||||
onBack={handleBackFromMCPList}
|
||||
onManage={handleStartEdit}
|
||||
onDisconnect={(connector) => handleDisconnectFromList(connector, () => refreshConnectors())}
|
||||
onAddAccount={handleAddNewMCPFromList}
|
||||
addButtonText="Add New MCP Server"
|
||||
/>
|
||||
) : viewingAccountsType ? (
|
||||
<ConnectorAccountsListView
|
||||
connectorType={viewingAccountsType.connectorType}
|
||||
connectorTitle={viewingAccountsType.connectorTitle}
|
||||
connectors={(connectors || []) as SearchSourceConnector[]}
|
||||
indexingConnectorIds={indexingConnectorIds}
|
||||
onBack={handleBackFromAccountsList}
|
||||
onManage={handleStartEdit}
|
||||
onAddAccount={() => {
|
||||
<ConnectorAccountsListView
|
||||
connectorType={viewingAccountsType.connectorType}
|
||||
connectorTitle={viewingAccountsType.connectorTitle}
|
||||
connectors={(connectors || []) as SearchSourceConnector[]}
|
||||
indexingConnectorIds={indexingConnectorIds}
|
||||
onBack={handleBackFromAccountsList}
|
||||
onManage={handleStartEdit}
|
||||
onDisconnect={(connector) => handleDisconnectFromList(connector, () => refreshConnectors())}
|
||||
onAddAccount={() => {
|
||||
// Check both OAUTH_CONNECTORS and COMPOSIO_CONNECTORS
|
||||
const oauthConnector =
|
||||
OAUTH_CONNECTORS.find(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"use client";
|
||||
|
||||
import { CheckCircle2, ChevronDown, ChevronUp, Server, XCircle } from "lucide-react";
|
||||
import { CheckCircle2, ChevronDown, ChevronUp, Loader2, Server, XCircle } from "lucide-react";
|
||||
import { type FC, useRef, useState } from "react";
|
||||
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
|
@ -212,7 +212,14 @@ export const MCPConnectForm: FC<ConnectFormProps> = ({ onSubmit, isSubmitting })
|
|||
variant="secondary"
|
||||
className="w-full h-8 text-[13px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80"
|
||||
>
|
||||
{isTesting ? "Testing Connection" : "Test Connection"}
|
||||
{isTesting ? (
|
||||
<>
|
||||
<Loader2 className="h-3.5 w-3.5 animate-spin" />
|
||||
Testing Connection...
|
||||
</>
|
||||
) : (
|
||||
"Test Connection"
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"use client";
|
||||
|
||||
import { CheckCircle2, ChevronDown, ChevronUp, Server, XCircle } from "lucide-react";
|
||||
import { CheckCircle2, ChevronDown, ChevronUp, Loader2, Server, XCircle } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
||||
|
|
@ -217,7 +217,14 @@ export const MCPConfig: FC<MCPConfigProps> = ({ connector, onConfigChange, onNam
|
|||
variant="secondary"
|
||||
className="w-full h-8 text-[13px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80"
|
||||
>
|
||||
{isTesting ? "Testing Connection" : "Test Connection"}
|
||||
{isTesting ? (
|
||||
<>
|
||||
<Loader2 className="h-3.5 w-3.5 animate-spin" />
|
||||
Testing Connection...
|
||||
</>
|
||||
) : (
|
||||
"Test Connection"
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import { toast } from "sonner";
|
|||
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { EnumConnectorName } from "@/contracts/enums/connector";
|
||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
||||
import { authenticatedFetch } from "@/lib/auth-utils";
|
||||
|
|
@ -16,23 +15,11 @@ import { DateRangeSelector } from "../../components/date-range-selector";
|
|||
import { PeriodicSyncConfig } from "../../components/periodic-sync-config";
|
||||
import { SummaryConfig } from "../../components/summary-config";
|
||||
import { VisionLLMConfig } from "../../components/vision-llm-config";
|
||||
import { LIVE_CONNECTOR_TYPES } from "../../constants/connector-constants";
|
||||
import { LIVE_CONNECTOR_TYPES, getReauthEndpoint } from "../../constants/connector-constants";
|
||||
import { getConnectorDisplayName } from "../../tabs/all-connectors-tab";
|
||||
import { MCPServiceConfig } from "../components/mcp-service-config";
|
||||
import { type ConnectorConfigProps, getConnectorConfigComponent } from "../index";
|
||||
|
||||
const REAUTH_ENDPOINTS: Partial<Record<string, string>> = {
|
||||
[EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth",
|
||||
[EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth",
|
||||
[EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth",
|
||||
[EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth",
|
||||
[EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/google/calendar/connector/reauth",
|
||||
[EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
||||
[EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
||||
[EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
||||
[EnumConnectorName.ONEDRIVE_CONNECTOR]: "/api/v1/auth/onedrive/connector/reauth",
|
||||
[EnumConnectorName.DROPBOX_CONNECTOR]: "/api/v1/auth/dropbox/connector/reauth",
|
||||
};
|
||||
|
||||
interface ConnectorEditViewProps {
|
||||
connector: SearchSourceConnector;
|
||||
startDate: Date | undefined;
|
||||
|
|
@ -86,7 +73,7 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
|||
}) => {
|
||||
const searchSpaceIdAtom = useAtomValue(activeSearchSpaceIdAtom);
|
||||
const isAuthExpired = connector.config?.auth_expired === true;
|
||||
const reauthEndpoint = REAUTH_ENDPOINTS[connector.connector_type];
|
||||
const reauthEndpoint = getReauthEndpoint(connector);
|
||||
const [reauthing, setReauthing] = useState(false);
|
||||
|
||||
const handleReauth = useCallback(async () => {
|
||||
|
|
@ -124,10 +111,7 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
|||
|
||||
// Get connector-specific config component (MCP-backed connectors use a generic view)
|
||||
const ConnectorConfigComponent = useMemo(() => {
|
||||
if (isMCPBacked) {
|
||||
const { MCPServiceConfig } = require("../components/mcp-service-config");
|
||||
return MCPServiceConfig as FC<ConnectorConfigProps>;
|
||||
}
|
||||
if (isMCPBacked) return MCPServiceConfig;
|
||||
return getConnectorConfigComponent(connector.connector_type);
|
||||
}, [connector.connector_type, isMCPBacked]);
|
||||
const [isScrolled, setIsScrolled] = useState(false);
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import { EnumConnectorName } from "@/contracts/enums/connector";
|
||||
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
||||
|
||||
/**
|
||||
* Connectors that operate in real time (no background indexing).
|
||||
|
|
@ -367,5 +368,45 @@ export function getConnectorTelemetryMeta(connectorType: string): ConnectorTelem
|
|||
};
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// REAUTH ENDPOINTS
|
||||
// =============================================================================
|
||||
|
||||
/**
|
||||
* Legacy (non-MCP) OAuth reauth endpoints, keyed by connector type.
|
||||
* These are used for connectors that were NOT created via MCP OAuth.
|
||||
*/
|
||||
export const LEGACY_REAUTH_ENDPOINTS: Partial<Record<string, string>> = {
|
||||
[EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth",
|
||||
[EnumConnectorName.JIRA_CONNECTOR]: "/api/v1/auth/jira/connector/reauth",
|
||||
[EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth",
|
||||
[EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth",
|
||||
[EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth",
|
||||
[EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/google/calendar/connector/reauth",
|
||||
[EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
||||
[EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
||||
[EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
||||
[EnumConnectorName.ONEDRIVE_CONNECTOR]: "/api/v1/auth/onedrive/connector/reauth",
|
||||
[EnumConnectorName.DROPBOX_CONNECTOR]: "/api/v1/auth/dropbox/connector/reauth",
|
||||
[EnumConnectorName.CONFLUENCE_CONNECTOR]: "/api/v1/auth/confluence/connector/reauth",
|
||||
[EnumConnectorName.TEAMS_CONNECTOR]: "/api/v1/auth/teams/connector/reauth",
|
||||
[EnumConnectorName.DISCORD_CONNECTOR]: "/api/v1/auth/discord/connector/reauth",
|
||||
};
|
||||
|
||||
/**
|
||||
* Resolve the reauth endpoint for a connector.
|
||||
*
|
||||
* MCP OAuth connectors (those with ``config.mcp_service``) dynamically build
|
||||
* the URL from the service key. Legacy OAuth connectors fall back to the
|
||||
* static ``LEGACY_REAUTH_ENDPOINTS`` map.
|
||||
*/
|
||||
export function getReauthEndpoint(connector: SearchSourceConnector): string | undefined {
|
||||
const mcpService = connector.config?.mcp_service as string | undefined;
|
||||
if (mcpService) {
|
||||
return `/api/v1/auth/mcp/${mcpService}/connector/reauth`;
|
||||
}
|
||||
return LEGACY_REAUTH_ENDPOINTS[connector.connector_type];
|
||||
}
|
||||
|
||||
// Re-export IndexingConfigState from schemas for backward compatibility
|
||||
export type { IndexingConfigState } from "./connector-popup.schemas";
|
||||
|
|
|
|||
|
|
@ -1311,6 +1311,25 @@ export const useConnectorDialog = () => {
|
|||
[editingConnector, searchSpaceId, deleteConnector, cameFromMCPList, setIsOpen]
|
||||
);
|
||||
|
||||
const handleDisconnectFromList = useCallback(
|
||||
async (connector: SearchSourceConnector, refreshConnectors: () => void) => {
|
||||
if (!searchSpaceId) return;
|
||||
try {
|
||||
await deleteConnector({ id: connector.id });
|
||||
trackConnectorDeleted(Number(searchSpaceId), connector.connector_type, connector.id);
|
||||
toast.success(`${connector.name} disconnected successfully`);
|
||||
refreshConnectors();
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: cacheKeys.logs.summary(Number(searchSpaceId)),
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error disconnecting connector:", error);
|
||||
toast.error("Failed to disconnect connector");
|
||||
}
|
||||
},
|
||||
[searchSpaceId, deleteConnector]
|
||||
);
|
||||
|
||||
// Handle quick index (index with selected date range, or backend defaults if none selected)
|
||||
const handleQuickIndexConnector = useCallback(
|
||||
async (
|
||||
|
|
@ -1484,6 +1503,7 @@ export const useConnectorDialog = () => {
|
|||
handleStartEdit,
|
||||
handleSaveConnector,
|
||||
handleDisconnectConnector,
|
||||
handleDisconnectFromList,
|
||||
handleBackFromEdit,
|
||||
handleBackFromConnect,
|
||||
handleBackFromYouTube,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"use client";
|
||||
|
||||
import { useAtomValue } from "jotai";
|
||||
import { ArrowLeft, Plus, RefreshCw, Server } from "lucide-react";
|
||||
import { ArrowLeft, Plus, RefreshCw, Server, Trash2 } from "lucide-react";
|
||||
import { type FC, useCallback, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
||||
|
|
@ -13,25 +13,10 @@ import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
|||
import { authenticatedFetch } from "@/lib/auth-utils";
|
||||
import { formatRelativeDate } from "@/lib/format-date";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { LIVE_CONNECTOR_TYPES } from "../constants/connector-constants";
|
||||
import { LIVE_CONNECTOR_TYPES, getReauthEndpoint } from "../constants/connector-constants";
|
||||
import { useConnectorStatus } from "../hooks/use-connector-status";
|
||||
import { getConnectorDisplayName } from "../tabs/all-connectors-tab";
|
||||
|
||||
const REAUTH_ENDPOINTS: Partial<Record<string, string>> = {
|
||||
[EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth",
|
||||
[EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth",
|
||||
[EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth",
|
||||
[EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth",
|
||||
[EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/google/calendar/connector/reauth",
|
||||
[EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
||||
[EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
||||
[EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
||||
[EnumConnectorName.ONEDRIVE_CONNECTOR]: "/api/v1/auth/onedrive/connector/reauth",
|
||||
[EnumConnectorName.JIRA_CONNECTOR]: "/api/v1/auth/jira/connector/reauth",
|
||||
[EnumConnectorName.DROPBOX_CONNECTOR]: "/api/v1/auth/dropbox/connector/reauth",
|
||||
[EnumConnectorName.CONFLUENCE_CONNECTOR]: "/api/v1/auth/confluence/connector/reauth",
|
||||
};
|
||||
|
||||
interface ConnectorAccountsListViewProps {
|
||||
connectorType: string;
|
||||
connectorTitle: string;
|
||||
|
|
@ -39,15 +24,12 @@ interface ConnectorAccountsListViewProps {
|
|||
indexingConnectorIds: Set<number>;
|
||||
onBack: () => void;
|
||||
onManage: (connector: SearchSourceConnector) => void;
|
||||
onDisconnect?: (connector: SearchSourceConnector) => Promise<void> | void;
|
||||
onAddAccount: () => void;
|
||||
isConnecting?: boolean;
|
||||
addButtonText?: string;
|
||||
}
|
||||
|
||||
function isLiveConnector(connectorType: string): boolean {
|
||||
return LIVE_CONNECTOR_TYPES.has(connectorType) || connectorType === "MCP_CONNECTOR";
|
||||
}
|
||||
|
||||
export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
||||
connectorType,
|
||||
connectorTitle,
|
||||
|
|
@ -55,12 +37,15 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
|||
indexingConnectorIds,
|
||||
onBack,
|
||||
onManage,
|
||||
onDisconnect,
|
||||
onAddAccount,
|
||||
isConnecting = false,
|
||||
addButtonText,
|
||||
}) => {
|
||||
const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom);
|
||||
const [reauthingId, setReauthingId] = useState<number | null>(null);
|
||||
const [confirmDisconnectId, setConfirmDisconnectId] = useState<number | null>(null);
|
||||
const [disconnectingId, setDisconnectingId] = useState<number | null>(null);
|
||||
|
||||
// Get connector status
|
||||
const { isConnectorEnabled, getConnectorStatusMessage } = useConnectorStatus();
|
||||
|
|
@ -68,16 +53,15 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
|||
const isEnabled = isConnectorEnabled(connectorType);
|
||||
const statusMessage = getConnectorStatusMessage(connectorType);
|
||||
|
||||
const reauthEndpoint = REAUTH_ENDPOINTS[connectorType];
|
||||
|
||||
const handleReauth = useCallback(
|
||||
async (connectorId: number) => {
|
||||
if (!searchSpaceId || !reauthEndpoint) return;
|
||||
setReauthingId(connectorId);
|
||||
async (connector: SearchSourceConnector) => {
|
||||
const endpoint = getReauthEndpoint(connector);
|
||||
if (!searchSpaceId || !endpoint) return;
|
||||
setReauthingId(connector.id);
|
||||
try {
|
||||
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||
const url = new URL(`${backendUrl}${reauthEndpoint}`);
|
||||
url.searchParams.set("connector_id", String(connectorId));
|
||||
const url = new URL(`${backendUrl}${endpoint}`);
|
||||
url.searchParams.set("connector_id", String(connector.id));
|
||||
url.searchParams.set("space_id", String(searchSpaceId));
|
||||
url.searchParams.set("return_url", window.location.pathname);
|
||||
const response = await authenticatedFetch(url.toString());
|
||||
|
|
@ -99,7 +83,7 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
|||
setReauthingId(null);
|
||||
}
|
||||
},
|
||||
[searchSpaceId, reauthEndpoint]
|
||||
[searchSpaceId]
|
||||
);
|
||||
|
||||
// Filter connectors to only show those of this type
|
||||
|
|
@ -198,9 +182,11 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
|||
</div>
|
||||
) : (
|
||||
<div className="grid grid-cols-1 sm:grid-cols-2 gap-3">
|
||||
{typeConnectors.map((connector) => {
|
||||
const isIndexing = indexingConnectorIds.has(connector.id);
|
||||
const isAuthExpired = !!reauthEndpoint && connector.config?.auth_expired === true;
|
||||
{typeConnectors.map((connector) => {
|
||||
const isIndexing = indexingConnectorIds.has(connector.id);
|
||||
const connectorReauthEndpoint = getReauthEndpoint(connector);
|
||||
const isAuthExpired = !!connectorReauthEndpoint && connector.config?.auth_expired === true;
|
||||
const isLive = LIVE_CONNECTOR_TYPES.has(connector.connector_type) || Boolean(connector.config?.server_config);
|
||||
|
||||
return (
|
||||
<div
|
||||
|
|
@ -231,7 +217,7 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
|||
<Spinner size="xs" />
|
||||
Syncing
|
||||
</p>
|
||||
) : !isLiveConnector(connector.connector_type) ? (
|
||||
) : !isLive ? (
|
||||
<p className="text-[10px] mt-1 whitespace-nowrap truncate text-muted-foreground">
|
||||
{connector.last_indexed_at
|
||||
? `Last indexed: ${formatRelativeDate(connector.last_indexed_at)}`
|
||||
|
|
@ -239,28 +225,73 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
|||
</p>
|
||||
) : null}
|
||||
</div>
|
||||
{isAuthExpired ? (
|
||||
<Button
|
||||
size="sm"
|
||||
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-amber-600 hover:bg-amber-700 text-white border-0 shadow-xs shrink-0"
|
||||
onClick={() => handleReauth(connector.id)}
|
||||
disabled={reauthingId === connector.id}
|
||||
>
|
||||
<RefreshCw
|
||||
className={cn("size-3.5", reauthingId === connector.id && "animate-spin")}
|
||||
/>
|
||||
Re-authenticate
|
||||
</Button>
|
||||
{isAuthExpired ? (
|
||||
<Button
|
||||
size="sm"
|
||||
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-amber-600 hover:bg-amber-700 text-white border-0 shadow-xs shrink-0"
|
||||
onClick={() => handleReauth(connector)}
|
||||
disabled={reauthingId === connector.id}
|
||||
>
|
||||
<RefreshCw
|
||||
className={cn("size-3.5", reauthingId === connector.id && "animate-spin")}
|
||||
/>
|
||||
Re-authenticate
|
||||
</Button>
|
||||
) : isLive && onDisconnect ? (
|
||||
confirmDisconnectId === connector.id ? (
|
||||
<div className="flex items-center gap-1.5 shrink-0">
|
||||
<Button
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
className="h-8 text-[11px] px-3 rounded-lg font-medium shadow-xs"
|
||||
onClick={async () => {
|
||||
setDisconnectingId(connector.id);
|
||||
setConfirmDisconnectId(null);
|
||||
try {
|
||||
await onDisconnect(connector);
|
||||
} finally {
|
||||
setDisconnectingId(null);
|
||||
}
|
||||
}}
|
||||
disabled={disconnectingId === connector.id}
|
||||
>
|
||||
{disconnectingId === connector.id ? (
|
||||
<RefreshCw className="size-3.5 animate-spin" />
|
||||
) : (
|
||||
"Confirm"
|
||||
)}
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-8 text-[11px] px-2 rounded-lg"
|
||||
onClick={() => setConfirmDisconnectId(null)}
|
||||
disabled={disconnectingId === connector.id}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
</div>
|
||||
) : (
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80 shrink-0"
|
||||
onClick={() => onManage(connector)}
|
||||
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-red-50 hover:text-red-700 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-red-950 dark:hover:text-red-400 shrink-0"
|
||||
onClick={() => setConfirmDisconnectId(connector.id)}
|
||||
>
|
||||
Manage
|
||||
<Trash2 className="size-3.5" />
|
||||
Disconnect
|
||||
</Button>
|
||||
)}
|
||||
)
|
||||
) : (
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80 shrink-0"
|
||||
onClick={() => onManage(connector)}
|
||||
>
|
||||
Manage
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import type { ToolCallMessagePartComponent } from "@assistant-ui/react";
|
||||
import { CornerDownLeftIcon, Pen } from "lucide-react";
|
||||
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { TextShimmerLoader } from "@/components/prompt-kit/loader";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
|
|
@ -116,8 +117,8 @@ function GenericApprovalCard({
|
|||
if (phase !== "pending" || !isMCPTool) return;
|
||||
setProcessing();
|
||||
onDecision({ type: "approve" });
|
||||
connectorsApiService.trustMCPTool(mcpConnectorId, toolName).catch((err) => {
|
||||
console.error("Failed to trust MCP tool:", err);
|
||||
connectorsApiService.trustMCPTool(mcpConnectorId, toolName).catch(() => {
|
||||
toast.error("Failed to save 'Always Allow' preference. The tool will still require approval next time.");
|
||||
});
|
||||
}, [phase, setProcessing, onDecision, isMCPTool, mcpConnectorId, toolName]);
|
||||
|
||||
|
|
|
|||
|
|
@ -414,16 +414,8 @@ class ConnectorsApiService {
|
|||
* Subsequent calls to this tool will skip HITL approval.
|
||||
*/
|
||||
trustMCPTool = async (connectorId: number, toolName: string): Promise<void> => {
|
||||
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||
const token =
|
||||
typeof window !== "undefined" ? document.cookie.match(/fapiToken=([^;]+)/)?.[1] : undefined;
|
||||
await fetch(`${backendUrl}/api/v1/connectors/mcp/${connectorId}/trust-tool`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
...(token ? { Authorization: `Bearer ${token}` } : {}),
|
||||
},
|
||||
body: JSON.stringify({ tool_name: toolName }),
|
||||
await baseApiService.post(`/api/v1/connectors/mcp/${connectorId}/trust-tool`, undefined, {
|
||||
body: { tool_name: toolName },
|
||||
});
|
||||
};
|
||||
|
||||
|
|
@ -431,16 +423,8 @@ class ConnectorsApiService {
|
|||
* Remove a tool from the MCP connector's "Always Allow" list.
|
||||
*/
|
||||
untrustMCPTool = async (connectorId: number, toolName: string): Promise<void> => {
|
||||
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||
const token =
|
||||
typeof window !== "undefined" ? document.cookie.match(/fapiToken=([^;]+)/)?.[1] : undefined;
|
||||
await fetch(`${backendUrl}/api/v1/connectors/mcp/${connectorId}/untrust-tool`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
...(token ? { Authorization: `Bearer ${token}` } : {}),
|
||||
},
|
||||
body: JSON.stringify({ tool_name: toolName }),
|
||||
await baseApiService.post(`/api/v1/connectors/mcp/${connectorId}/untrust-tool`, undefined, {
|
||||
body: { tool_name: toolName },
|
||||
});
|
||||
};
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue