mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-21 18:55:16 +02:00
Merge pull request #1297 from CREDO23/feature/mcp-migration
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
[Improvement] MCP OAuth trust, 401 recovery, parallel discovery & connector UX
This commit is contained in:
commit
09ab174221
24 changed files with 835 additions and 333 deletions
|
|
@ -314,6 +314,20 @@ async def create_surfsense_deep_agent(
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
_enabled_tool_names = {t.name for t in tools}
|
_enabled_tool_names = {t.name for t in tools}
|
||||||
_user_disabled_tool_names = set(disabled_tools) if disabled_tools else set()
|
_user_disabled_tool_names = set(disabled_tools) if disabled_tools else set()
|
||||||
|
|
||||||
|
# Collect generic MCP connector info so the system prompt can route queries
|
||||||
|
# to their tools instead of falling back to "not in knowledge base".
|
||||||
|
_mcp_connector_tools: dict[str, list[str]] = {}
|
||||||
|
for t in tools:
|
||||||
|
meta = getattr(t, "metadata", None) or {}
|
||||||
|
if meta.get("mcp_is_generic") and meta.get("mcp_connector_name"):
|
||||||
|
_mcp_connector_tools.setdefault(
|
||||||
|
meta["mcp_connector_name"], [],
|
||||||
|
).append(t.name)
|
||||||
|
|
||||||
|
if _mcp_connector_tools:
|
||||||
|
_perf_log.info("MCP connector tool routing: %s", _mcp_connector_tools)
|
||||||
|
|
||||||
if agent_config is not None:
|
if agent_config is not None:
|
||||||
system_prompt = build_configurable_system_prompt(
|
system_prompt = build_configurable_system_prompt(
|
||||||
custom_system_instructions=agent_config.system_instructions,
|
custom_system_instructions=agent_config.system_instructions,
|
||||||
|
|
@ -322,12 +336,14 @@ async def create_surfsense_deep_agent(
|
||||||
thread_visibility=thread_visibility,
|
thread_visibility=thread_visibility,
|
||||||
enabled_tool_names=_enabled_tool_names,
|
enabled_tool_names=_enabled_tool_names,
|
||||||
disabled_tool_names=_user_disabled_tool_names,
|
disabled_tool_names=_user_disabled_tool_names,
|
||||||
|
mcp_connector_tools=_mcp_connector_tools,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
system_prompt = build_surfsense_system_prompt(
|
system_prompt = build_surfsense_system_prompt(
|
||||||
thread_visibility=thread_visibility,
|
thread_visibility=thread_visibility,
|
||||||
enabled_tool_names=_enabled_tool_names,
|
enabled_tool_names=_enabled_tool_names,
|
||||||
disabled_tool_names=_user_disabled_tool_names,
|
disabled_tool_names=_user_disabled_tool_names,
|
||||||
|
mcp_connector_tools=_mcp_connector_tools,
|
||||||
)
|
)
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
|
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
|
||||||
|
|
|
||||||
|
|
@ -815,11 +815,36 @@ Your goal is to provide helpful, informative answers in a clean, readable format
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _build_mcp_routing_block(
|
||||||
|
mcp_connector_tools: dict[str, list[str]] | None,
|
||||||
|
) -> str:
|
||||||
|
"""Build an additional tool routing block for generic MCP connectors.
|
||||||
|
|
||||||
|
When users add MCP servers (e.g. GitLab, GitHub), the LLM needs to know
|
||||||
|
those tools exist and should be called directly — not searched in the
|
||||||
|
knowledge base.
|
||||||
|
"""
|
||||||
|
if not mcp_connector_tools:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
"\n<mcp_tool_routing>",
|
||||||
|
"You also have direct tools from these user-connected MCP servers.",
|
||||||
|
"Their data is NEVER in the knowledge base — call their tools directly.",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
for server_name, tool_names in mcp_connector_tools.items():
|
||||||
|
lines.append(f"- {server_name} → {', '.join(tool_names)}")
|
||||||
|
lines.append("</mcp_tool_routing>\n")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def build_surfsense_system_prompt(
|
def build_surfsense_system_prompt(
|
||||||
today: datetime | None = None,
|
today: datetime | None = None,
|
||||||
thread_visibility: ChatVisibility | None = None,
|
thread_visibility: ChatVisibility | None = None,
|
||||||
enabled_tool_names: set[str] | None = None,
|
enabled_tool_names: set[str] | None = None,
|
||||||
disabled_tool_names: set[str] | None = None,
|
disabled_tool_names: set[str] | None = None,
|
||||||
|
mcp_connector_tools: dict[str, list[str]] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Build the SurfSense system prompt with default settings.
|
Build the SurfSense system prompt with default settings.
|
||||||
|
|
@ -834,6 +859,9 @@ def build_surfsense_system_prompt(
|
||||||
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||||
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
||||||
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
||||||
|
mcp_connector_tools: Mapping of MCP server display name → list of tool names
|
||||||
|
for generic MCP connectors. Injected into the system prompt so the LLM
|
||||||
|
knows to call these tools directly.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Complete system prompt string
|
Complete system prompt string
|
||||||
|
|
@ -841,6 +869,7 @@ def build_surfsense_system_prompt(
|
||||||
|
|
||||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||||
system_instructions = _get_system_instructions(visibility, today)
|
system_instructions = _get_system_instructions(visibility, today)
|
||||||
|
system_instructions += _build_mcp_routing_block(mcp_connector_tools)
|
||||||
tools_instructions = _get_tools_instructions(
|
tools_instructions = _get_tools_instructions(
|
||||||
visibility, enabled_tool_names, disabled_tool_names
|
visibility, enabled_tool_names, disabled_tool_names
|
||||||
)
|
)
|
||||||
|
|
@ -856,6 +885,7 @@ def build_configurable_system_prompt(
|
||||||
thread_visibility: ChatVisibility | None = None,
|
thread_visibility: ChatVisibility | None = None,
|
||||||
enabled_tool_names: set[str] | None = None,
|
enabled_tool_names: set[str] | None = None,
|
||||||
disabled_tool_names: set[str] | None = None,
|
disabled_tool_names: set[str] | None = None,
|
||||||
|
mcp_connector_tools: dict[str, list[str]] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
|
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
|
||||||
|
|
@ -877,6 +907,9 @@ def build_configurable_system_prompt(
|
||||||
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||||
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
||||||
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
||||||
|
mcp_connector_tools: Mapping of MCP server display name → list of tool names
|
||||||
|
for generic MCP connectors. Injected into the system prompt so the LLM
|
||||||
|
knows to call these tools directly.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Complete system prompt string
|
Complete system prompt string
|
||||||
|
|
@ -894,6 +927,8 @@ def build_configurable_system_prompt(
|
||||||
else:
|
else:
|
||||||
system_instructions = ""
|
system_instructions = ""
|
||||||
|
|
||||||
|
system_instructions += _build_mcp_routing_block(mcp_connector_tools)
|
||||||
|
|
||||||
# Tools instructions: only include enabled tools, note disabled ones
|
# Tools instructions: only include enabled tools, note disabled ones
|
||||||
tools_instructions = _get_tools_instructions(
|
tools_instructions = _get_tools_instructions(
|
||||||
thread_visibility, enabled_tool_names, disabled_tool_names
|
thread_visibility, enabled_tool_names, disabled_tool_names
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,18 @@ class MCPClient:
|
||||||
async def connect(self, max_retries: int = MAX_RETRIES):
|
async def connect(self, max_retries: int = MAX_RETRIES):
|
||||||
"""Connect to the MCP server and manage its lifecycle.
|
"""Connect to the MCP server and manage its lifecycle.
|
||||||
|
|
||||||
|
Retries only apply to the **connection** phase (spawning the process,
|
||||||
|
initialising the session). Once the session is yielded to the caller,
|
||||||
|
any exception raised by the caller propagates normally -- the context
|
||||||
|
manager will NOT retry after ``yield``.
|
||||||
|
|
||||||
|
Previous implementation wrapped both connection AND yield inside the
|
||||||
|
retry loop. Because ``@asynccontextmanager`` only allows a single
|
||||||
|
``yield``, a failure after yield caused the generator to attempt a
|
||||||
|
second yield on retry, triggering
|
||||||
|
``RuntimeError("generator didn't stop after athrow()")`` and orphaning
|
||||||
|
the stdio subprocess.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
max_retries: Maximum number of connection retry attempts
|
max_retries: Maximum number of connection retry attempts
|
||||||
|
|
||||||
|
|
@ -57,26 +69,22 @@ class MCPClient:
|
||||||
"""
|
"""
|
||||||
last_error = None
|
last_error = None
|
||||||
delay = RETRY_DELAY
|
delay = RETRY_DELAY
|
||||||
|
connected = False
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
for attempt in range(max_retries):
|
||||||
try:
|
try:
|
||||||
# Merge env vars with current environment
|
|
||||||
server_env = os.environ.copy()
|
server_env = os.environ.copy()
|
||||||
server_env.update(self.env)
|
server_env.update(self.env)
|
||||||
|
|
||||||
# Create server parameters with env
|
|
||||||
server_params = StdioServerParameters(
|
server_params = StdioServerParameters(
|
||||||
command=self.command, args=self.args, env=server_env
|
command=self.command, args=self.args, env=server_env
|
||||||
)
|
)
|
||||||
|
|
||||||
# Spawn server process and create session
|
|
||||||
# Note: Cannot combine these context managers because ClientSession
|
|
||||||
# needs the read/write streams from stdio_client
|
|
||||||
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
|
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
|
||||||
async with ClientSession(read, write) as session:
|
async with ClientSession(read, write) as session:
|
||||||
# Initialize the connection
|
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
self.session = session
|
self.session = session
|
||||||
|
connected = True
|
||||||
|
|
||||||
if attempt > 0:
|
if attempt > 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -91,10 +99,16 @@ class MCPClient:
|
||||||
self.command,
|
self.command,
|
||||||
" ".join(self.args),
|
" ".join(self.args),
|
||||||
)
|
)
|
||||||
yield session
|
try:
|
||||||
return # Success, exit retry loop
|
yield session
|
||||||
|
finally:
|
||||||
|
self.session = None
|
||||||
|
return
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
self.session = None
|
||||||
|
if connected:
|
||||||
|
raise
|
||||||
last_error = e
|
last_error = e
|
||||||
if attempt < max_retries - 1:
|
if attempt < max_retries - 1:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
@ -105,7 +119,7 @@ class MCPClient:
|
||||||
delay,
|
delay,
|
||||||
)
|
)
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
delay *= RETRY_BACKOFF # Exponential backoff
|
delay *= RETRY_BACKOFF
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Failed to connect to MCP server after %d attempts: %s",
|
"Failed to connect to MCP server after %d attempts: %s",
|
||||||
|
|
@ -113,10 +127,7 @@ class MCPClient:
|
||||||
e,
|
e,
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
self.session = None
|
|
||||||
|
|
||||||
# All retries exhausted
|
|
||||||
error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
|
error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
|
||||||
if last_error:
|
if last_error:
|
||||||
error_msg += f": {last_error}"
|
error_msg += f": {last_error}"
|
||||||
|
|
@ -161,12 +172,18 @@ class MCPClient:
|
||||||
logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
|
logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
async def call_tool(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
arguments: dict[str, Any],
|
||||||
|
timeout: float = 60.0,
|
||||||
|
) -> Any:
|
||||||
"""Call a tool on the MCP server.
|
"""Call a tool on the MCP server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_name: Name of the tool to call
|
tool_name: Name of the tool to call
|
||||||
arguments: Arguments to pass to the tool
|
arguments: Arguments to pass to the tool
|
||||||
|
timeout: Maximum seconds to wait for the tool to respond
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tool execution result
|
Tool execution result
|
||||||
|
|
@ -185,10 +202,11 @@ class MCPClient:
|
||||||
"Calling MCP tool '%s' with arguments: %s", tool_name, arguments
|
"Calling MCP tool '%s' with arguments: %s", tool_name, arguments
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call tools/call RPC method
|
response = await asyncio.wait_for(
|
||||||
response = await self.session.call_tool(tool_name, arguments=arguments)
|
self.session.call_tool(tool_name, arguments=arguments),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
# Extract content from response
|
|
||||||
result = []
|
result = []
|
||||||
for content in response.content:
|
for content in response.content:
|
||||||
if hasattr(content, "text"):
|
if hasattr(content, "text"):
|
||||||
|
|
@ -202,15 +220,17 @@ class MCPClient:
|
||||||
logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
|
logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
|
||||||
return result_str
|
return result_str
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(
|
||||||
|
"MCP tool '%s' timed out after %.0fs", tool_name, timeout
|
||||||
|
)
|
||||||
|
return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s"
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
# Handle validation errors from MCP server responses
|
|
||||||
# Some MCP servers (like server-memory) return extra fields not in their schema
|
|
||||||
if "Invalid structured content" in str(e):
|
if "Invalid structured content" in str(e):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"MCP server returned data not matching its schema, but continuing: %s",
|
"MCP server returned data not matching its schema, but continuing: %s",
|
||||||
e,
|
e,
|
||||||
)
|
)
|
||||||
# Try to extract result from error message or return a success message
|
|
||||||
return "Operation completed (server returned unexpected format)"
|
return "Operation completed (server returned unexpected format)"
|
||||||
raise
|
raise
|
||||||
except (ValueError, TypeError, AttributeError, KeyError) as e:
|
except (ValueError, TypeError, AttributeError, KeyError) as e:
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ clicking "Always Allow", which adds the tool name to the connector's
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
@ -27,7 +28,7 @@ if TYPE_CHECKING:
|
||||||
from langchain_core.tools import StructuredTool
|
from langchain_core.tools import StructuredTool
|
||||||
from mcp import ClientSession
|
from mcp import ClientSession
|
||||||
from mcp.client.streamable_http import streamablehttp_client
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
from pydantic import BaseModel, Field, create_model
|
from pydantic import BaseModel, ConfigDict, Field, create_model
|
||||||
from sqlalchemy import cast, select
|
from sqlalchemy import cast, select
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
@ -41,6 +42,9 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
|
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
|
||||||
_MCP_CACHE_MAX_SIZE = 50
|
_MCP_CACHE_MAX_SIZE = 50
|
||||||
|
_MCP_DISCOVERY_TIMEOUT_SECONDS = 30
|
||||||
|
_TOOL_CALL_MAX_RETRIES = 3
|
||||||
|
_TOOL_CALL_RETRY_DELAY = 1.5 # seconds, doubles per attempt
|
||||||
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
|
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -62,7 +66,18 @@ def _create_dynamic_input_model_from_schema(
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
input_schema: dict[str, Any],
|
input_schema: dict[str, Any],
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
"""Create a Pydantic model from MCP tool's JSON schema."""
|
"""Create a Pydantic model from MCP tool's JSON schema.
|
||||||
|
|
||||||
|
Models always allow extra fields (``extra="allow"``) so that parameters
|
||||||
|
missing from a broken or incomplete JSON schema (e.g. ``zod-to-json-schema``
|
||||||
|
producing an empty ``$schema``-only object) can still be forwarded to the
|
||||||
|
MCP server.
|
||||||
|
|
||||||
|
When the schema declares **no** properties, a synthetic ``input_data``
|
||||||
|
field of type ``dict`` is injected so the LLM has a visible parameter to
|
||||||
|
populate. The caller should unpack ``input_data`` before forwarding to
|
||||||
|
the MCP server (see ``_unpack_synthetic_input_data``).
|
||||||
|
"""
|
||||||
properties = input_schema.get("properties", {})
|
properties = input_schema.get("properties", {})
|
||||||
required_fields = input_schema.get("required", [])
|
required_fields = input_schema.get("required", [])
|
||||||
|
|
||||||
|
|
@ -82,8 +97,35 @@ def _create_dynamic_input_model_from_schema(
|
||||||
Field(None, description=param_description),
|
Field(None, description=param_description),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not properties:
|
||||||
|
field_definitions["input_data"] = (
|
||||||
|
dict[str, Any] | None,
|
||||||
|
Field(
|
||||||
|
None,
|
||||||
|
description=(
|
||||||
|
"Arguments to pass to this tool as a JSON object. "
|
||||||
|
"Infer sensible key names from the tool name and description "
|
||||||
|
"(e.g. {\"search\": \"my query\"} for a search tool)."
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
|
model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
|
||||||
return create_model(model_name, **field_definitions)
|
model = create_model(model_name, __config__=ConfigDict(extra="allow"), **field_definitions)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _unpack_synthetic_input_data(kwargs: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Unpack the synthetic ``input_data`` field into top-level kwargs.
|
||||||
|
|
||||||
|
When the MCP tool schema is empty, ``_create_dynamic_input_model_from_schema``
|
||||||
|
adds a catch-all ``input_data: dict`` field. This helper merges that dict
|
||||||
|
back into the top-level kwargs so the MCP server receives flat arguments.
|
||||||
|
"""
|
||||||
|
input_data = kwargs.pop("input_data", None)
|
||||||
|
if isinstance(input_data, dict):
|
||||||
|
kwargs.update(input_data)
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
async def _create_mcp_tool_from_definition_stdio(
|
async def _create_mcp_tool_from_definition_stdio(
|
||||||
|
|
@ -101,7 +143,12 @@ async def _create_mcp_tool_from_definition_stdio(
|
||||||
``GraphInterrupt`` propagates cleanly to LangGraph.
|
``GraphInterrupt`` propagates cleanly to LangGraph.
|
||||||
"""
|
"""
|
||||||
tool_name = tool_def.get("name", "unnamed_tool")
|
tool_name = tool_def.get("name", "unnamed_tool")
|
||||||
tool_description = tool_def.get("description", "No description provided")
|
raw_description = tool_def.get("description", "No description provided")
|
||||||
|
tool_description = (
|
||||||
|
f"[MCP server: {connector_name}] {raw_description}"
|
||||||
|
if connector_name
|
||||||
|
else raw_description
|
||||||
|
)
|
||||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||||
|
|
||||||
logger.debug("MCP tool '%s' input schema: %s", tool_name, input_schema)
|
logger.debug("MCP tool '%s' input schema: %s", tool_name, input_schema)
|
||||||
|
|
@ -119,7 +166,7 @@ async def _create_mcp_tool_from_definition_stdio(
|
||||||
params=kwargs,
|
params=kwargs,
|
||||||
context={
|
context={
|
||||||
"mcp_server": connector_name,
|
"mcp_server": connector_name,
|
||||||
"tool_description": tool_description,
|
"tool_description": raw_description,
|
||||||
"mcp_transport": "stdio",
|
"mcp_transport": "stdio",
|
||||||
"mcp_connector_id": connector_id,
|
"mcp_connector_id": connector_id,
|
||||||
},
|
},
|
||||||
|
|
@ -127,18 +174,32 @@ async def _create_mcp_tool_from_definition_stdio(
|
||||||
)
|
)
|
||||||
if hitl_result.rejected:
|
if hitl_result.rejected:
|
||||||
return "Tool call rejected by user."
|
return "Tool call rejected by user."
|
||||||
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
|
call_kwargs = _unpack_synthetic_input_data(
|
||||||
|
{k: v for k, v in hitl_result.params.items() if v is not None}
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
last_error: Exception | None = None
|
||||||
async with mcp_client.connect():
|
for attempt in range(_TOOL_CALL_MAX_RETRIES):
|
||||||
result = await mcp_client.call_tool(tool_name, call_kwargs)
|
try:
|
||||||
return str(result)
|
async with mcp_client.connect():
|
||||||
except RuntimeError as e:
|
result = await mcp_client.call_tool(tool_name, call_kwargs)
|
||||||
logger.error("MCP tool '%s' connection failed after retries: %s", tool_name, e)
|
return str(result)
|
||||||
return f"Error: MCP tool '{tool_name}' connection failed after retries: {e!s}"
|
except Exception as e:
|
||||||
except Exception as e:
|
last_error = e
|
||||||
logger.exception("MCP tool '%s' execution failed: %s", tool_name, e)
|
if attempt < _TOOL_CALL_MAX_RETRIES - 1:
|
||||||
return f"Error: MCP tool '{tool_name}' execution failed: {e!s}"
|
delay = _TOOL_CALL_RETRY_DELAY * (2 ** attempt)
|
||||||
|
logger.warning(
|
||||||
|
"MCP tool '%s' failed (attempt %d/%d): %s. Retrying in %.1fs...",
|
||||||
|
tool_name, attempt + 1, _TOOL_CALL_MAX_RETRIES, e, delay,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
"MCP tool '%s' failed after %d attempts: %s",
|
||||||
|
tool_name, _TOOL_CALL_MAX_RETRIES, e, exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return f"Error: MCP tool '{tool_name}' failed after {_TOOL_CALL_MAX_RETRIES} attempts: {last_error!s}"
|
||||||
|
|
||||||
tool = StructuredTool(
|
tool = StructuredTool(
|
||||||
name=tool_name,
|
name=tool_name,
|
||||||
|
|
@ -148,6 +209,8 @@ async def _create_mcp_tool_from_definition_stdio(
|
||||||
metadata={
|
metadata={
|
||||||
"mcp_input_schema": input_schema,
|
"mcp_input_schema": input_schema,
|
||||||
"mcp_transport": "stdio",
|
"mcp_transport": "stdio",
|
||||||
|
"mcp_connector_name": connector_name or None,
|
||||||
|
"mcp_is_generic": True,
|
||||||
"hitl": True,
|
"hitl": True,
|
||||||
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
||||||
},
|
},
|
||||||
|
|
@ -167,6 +230,7 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
trusted_tools: list[str] | None = None,
|
trusted_tools: list[str] | None = None,
|
||||||
readonly_tools: frozenset[str] | None = None,
|
readonly_tools: frozenset[str] | None = None,
|
||||||
tool_name_prefix: str | None = None,
|
tool_name_prefix: str | None = None,
|
||||||
|
is_generic_mcp: bool = False,
|
||||||
) -> StructuredTool:
|
) -> StructuredTool:
|
||||||
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
|
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
|
||||||
|
|
||||||
|
|
@ -178,7 +242,7 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
but the actual MCP ``call_tool`` still uses the original name.
|
but the actual MCP ``call_tool`` still uses the original name.
|
||||||
"""
|
"""
|
||||||
original_tool_name = tool_def.get("name", "unnamed_tool")
|
original_tool_name = tool_def.get("name", "unnamed_tool")
|
||||||
tool_description = tool_def.get("description", "No description provided")
|
raw_description = tool_def.get("description", "No description provided")
|
||||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||||
is_readonly = readonly_tools is not None and original_tool_name in readonly_tools
|
is_readonly = readonly_tools is not None and original_tool_name in readonly_tools
|
||||||
|
|
||||||
|
|
@ -188,18 +252,51 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
else original_tool_name
|
else original_tool_name
|
||||||
)
|
)
|
||||||
if tool_name_prefix:
|
if tool_name_prefix:
|
||||||
tool_description = f"[Account: {connector_name}] {tool_description}"
|
tool_description = f"[Account: {connector_name}] {raw_description}"
|
||||||
|
elif is_generic_mcp and connector_name:
|
||||||
|
tool_description = f"[MCP server: {connector_name}] {raw_description}"
|
||||||
|
else:
|
||||||
|
tool_description = raw_description
|
||||||
|
|
||||||
logger.debug("MCP HTTP tool '%s' input schema: %s", exposed_name, input_schema)
|
logger.debug("MCP HTTP tool '%s' input schema: %s", exposed_name, input_schema)
|
||||||
|
|
||||||
input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema)
|
input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema)
|
||||||
|
|
||||||
|
async def _do_mcp_call(
|
||||||
|
call_headers: dict[str, str],
|
||||||
|
call_kwargs: dict[str, Any],
|
||||||
|
timeout: float = 60.0,
|
||||||
|
) -> str:
|
||||||
|
"""Execute a single MCP HTTP call with the given headers."""
|
||||||
|
async with (
|
||||||
|
streamablehttp_client(url, headers=call_headers) as (read, write, _),
|
||||||
|
ClientSession(read, write) as session,
|
||||||
|
):
|
||||||
|
await session.initialize()
|
||||||
|
response = await asyncio.wait_for(
|
||||||
|
session.call_tool(original_tool_name, arguments=call_kwargs),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for content in response.content:
|
||||||
|
if hasattr(content, "text"):
|
||||||
|
result.append(content.text)
|
||||||
|
elif hasattr(content, "data"):
|
||||||
|
result.append(str(content.data))
|
||||||
|
else:
|
||||||
|
result.append(str(content))
|
||||||
|
|
||||||
|
return "\n".join(result) if result else ""
|
||||||
|
|
||||||
async def mcp_http_tool_call(**kwargs) -> str:
|
async def mcp_http_tool_call(**kwargs) -> str:
|
||||||
"""Execute the MCP tool call via HTTP transport."""
|
"""Execute the MCP tool call via HTTP transport."""
|
||||||
logger.debug("MCP HTTP tool '%s' called", exposed_name)
|
logger.debug("MCP HTTP tool '%s' called", exposed_name)
|
||||||
|
|
||||||
if is_readonly:
|
if is_readonly:
|
||||||
call_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
call_kwargs = _unpack_synthetic_input_data(
|
||||||
|
{k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
hitl_result = request_approval(
|
hitl_result = request_approval(
|
||||||
action_type="mcp_tool_call",
|
action_type="mcp_tool_call",
|
||||||
|
|
@ -207,7 +304,7 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
params=kwargs,
|
params=kwargs,
|
||||||
context={
|
context={
|
||||||
"mcp_server": connector_name,
|
"mcp_server": connector_name,
|
||||||
"tool_description": tool_description,
|
"tool_description": raw_description,
|
||||||
"mcp_transport": "http",
|
"mcp_transport": "http",
|
||||||
"mcp_connector_id": connector_id,
|
"mcp_connector_id": connector_id,
|
||||||
},
|
},
|
||||||
|
|
@ -215,34 +312,51 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
)
|
)
|
||||||
if hitl_result.rejected:
|
if hitl_result.rejected:
|
||||||
return "Tool call rejected by user."
|
return "Tool call rejected by user."
|
||||||
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
|
call_kwargs = _unpack_synthetic_input_data(
|
||||||
|
{k: v for k, v in hitl_result.params.items() if v is not None}
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with (
|
result_str = await _do_mcp_call(headers, call_kwargs)
|
||||||
streamablehttp_client(url, headers=headers) as (read, write, _),
|
logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str))
|
||||||
ClientSession(read, write) as session,
|
return result_str
|
||||||
):
|
|
||||||
await session.initialize()
|
except Exception as first_err:
|
||||||
response = await session.call_tool(
|
if not _is_auth_error(first_err) or connector_id is None:
|
||||||
original_tool_name, arguments=call_kwargs,
|
logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, first_err)
|
||||||
|
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {first_err!s}"
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"MCP HTTP tool '%s' got 401 — attempting token refresh for connector %s",
|
||||||
|
exposed_name, connector_id,
|
||||||
|
)
|
||||||
|
fresh_headers = await _force_refresh_and_get_headers(connector_id)
|
||||||
|
if fresh_headers is None:
|
||||||
|
await _mark_connector_auth_expired(connector_id)
|
||||||
|
return (
|
||||||
|
f"Error: MCP tool '{exposed_name}' authentication expired. "
|
||||||
|
"Please re-authenticate the connector in your settings."
|
||||||
)
|
)
|
||||||
|
|
||||||
result = []
|
try:
|
||||||
for content in response.content:
|
result_str = await _do_mcp_call(fresh_headers, call_kwargs)
|
||||||
if hasattr(content, "text"):
|
logger.info(
|
||||||
result.append(content.text)
|
"MCP HTTP tool '%s' succeeded after 401 recovery",
|
||||||
elif hasattr(content, "data"):
|
exposed_name,
|
||||||
result.append(str(content.data))
|
)
|
||||||
else:
|
|
||||||
result.append(str(content))
|
|
||||||
|
|
||||||
result_str = "\n".join(result) if result else ""
|
|
||||||
logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str))
|
|
||||||
return result_str
|
return result_str
|
||||||
|
except Exception as retry_err:
|
||||||
except Exception as e:
|
logger.exception(
|
||||||
logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, e)
|
"MCP HTTP tool '%s' still failing after token refresh: %s",
|
||||||
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {e!s}"
|
exposed_name, retry_err,
|
||||||
|
)
|
||||||
|
if _is_auth_error(retry_err):
|
||||||
|
await _mark_connector_auth_expired(connector_id)
|
||||||
|
return (
|
||||||
|
f"Error: MCP tool '{exposed_name}' authentication expired. "
|
||||||
|
"Please re-authenticate the connector in your settings."
|
||||||
|
)
|
||||||
|
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {retry_err!s}"
|
||||||
|
|
||||||
tool = StructuredTool(
|
tool = StructuredTool(
|
||||||
name=exposed_name,
|
name=exposed_name,
|
||||||
|
|
@ -253,6 +367,8 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
"mcp_input_schema": input_schema,
|
"mcp_input_schema": input_schema,
|
||||||
"mcp_transport": "http",
|
"mcp_transport": "http",
|
||||||
"mcp_url": url,
|
"mcp_url": url,
|
||||||
|
"mcp_connector_name": connector_name or None,
|
||||||
|
"mcp_is_generic": is_generic_mcp,
|
||||||
"hitl": not is_readonly,
|
"hitl": not is_readonly,
|
||||||
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
||||||
"mcp_original_tool_name": original_tool_name,
|
"mcp_original_tool_name": original_tool_name,
|
||||||
|
|
@ -334,6 +450,7 @@ async def _load_http_mcp_tools(
|
||||||
allowed_tools: list[str] | None = None,
|
allowed_tools: list[str] | None = None,
|
||||||
readonly_tools: frozenset[str] | None = None,
|
readonly_tools: frozenset[str] | None = None,
|
||||||
tool_name_prefix: str | None = None,
|
tool_name_prefix: str | None = None,
|
||||||
|
is_generic_mcp: bool = False,
|
||||||
) -> list[StructuredTool]:
|
) -> list[StructuredTool]:
|
||||||
"""Load tools from an HTTP-based MCP server.
|
"""Load tools from an HTTP-based MCP server.
|
||||||
|
|
||||||
|
|
@ -365,66 +482,99 @@ async def _load_http_mcp_tools(
|
||||||
|
|
||||||
allowed_set = set(allowed_tools) if allowed_tools else None
|
allowed_set = set(allowed_tools) if allowed_tools else None
|
||||||
|
|
||||||
try:
|
async def _discover(disc_headers: dict[str, str]) -> list[dict[str, Any]]:
|
||||||
|
"""Connect, initialize, and list tools from the MCP server."""
|
||||||
async with (
|
async with (
|
||||||
streamablehttp_client(url, headers=headers) as (read, write, _),
|
streamablehttp_client(url, headers=disc_headers) as (read, write, _),
|
||||||
ClientSession(read, write) as session,
|
ClientSession(read, write) as session,
|
||||||
):
|
):
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
|
|
||||||
response = await session.list_tools()
|
response = await session.list_tools()
|
||||||
tool_definitions = []
|
return [
|
||||||
for tool in response.tools:
|
{
|
||||||
tool_definitions.append(
|
"name": tool.name,
|
||||||
{
|
"description": tool.description or "",
|
||||||
"name": tool.name,
|
"input_schema": tool.inputSchema
|
||||||
"description": tool.description or "",
|
if hasattr(tool, "inputSchema")
|
||||||
"input_schema": tool.inputSchema
|
else {},
|
||||||
if hasattr(tool, "inputSchema")
|
}
|
||||||
else {},
|
for tool in response.tools
|
||||||
}
|
]
|
||||||
)
|
|
||||||
|
|
||||||
total_discovered = len(tool_definitions)
|
try:
|
||||||
|
tool_definitions = await _discover(headers)
|
||||||
|
except Exception as first_err:
|
||||||
|
if not _is_auth_error(first_err) or connector_id is None:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
|
||||||
|
url, connector_id, first_err,
|
||||||
|
)
|
||||||
|
return tools
|
||||||
|
|
||||||
if allowed_set:
|
logger.warning(
|
||||||
tool_definitions = [
|
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
|
||||||
td for td in tool_definitions if td["name"] in allowed_set
|
connector_id,
|
||||||
]
|
|
||||||
logger.info(
|
|
||||||
"HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter",
|
|
||||||
url, connector_id, len(tool_definitions), total_discovered,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
"Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all",
|
|
||||||
total_discovered, url, connector_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
for tool_def in tool_definitions:
|
|
||||||
try:
|
|
||||||
tool = await _create_mcp_tool_from_definition_http(
|
|
||||||
tool_def,
|
|
||||||
url,
|
|
||||||
headers,
|
|
||||||
connector_name=connector_name,
|
|
||||||
connector_id=connector_id,
|
|
||||||
trusted_tools=trusted_tools,
|
|
||||||
readonly_tools=readonly_tools,
|
|
||||||
tool_name_prefix=tool_name_prefix,
|
|
||||||
)
|
|
||||||
tools.append(tool)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(
|
|
||||||
"Failed to create HTTP tool '%s' from connector %d: %s",
|
|
||||||
tool_def.get("name"), connector_id, e,
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(
|
|
||||||
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
|
|
||||||
url, connector_id, e,
|
|
||||||
)
|
)
|
||||||
|
fresh_headers = await _force_refresh_and_get_headers(connector_id)
|
||||||
|
if fresh_headers is None:
|
||||||
|
await _mark_connector_auth_expired(connector_id)
|
||||||
|
logger.error(
|
||||||
|
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
try:
|
||||||
|
tool_definitions = await _discover(fresh_headers)
|
||||||
|
headers = fresh_headers
|
||||||
|
logger.info(
|
||||||
|
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
except Exception as retry_err:
|
||||||
|
logger.exception(
|
||||||
|
"HTTP MCP discovery for connector %d still failing after refresh: %s",
|
||||||
|
connector_id, retry_err,
|
||||||
|
)
|
||||||
|
if _is_auth_error(retry_err):
|
||||||
|
await _mark_connector_auth_expired(connector_id)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
total_discovered = len(tool_definitions)
|
||||||
|
|
||||||
|
if allowed_set:
|
||||||
|
tool_definitions = [
|
||||||
|
td for td in tool_definitions if td["name"] in allowed_set
|
||||||
|
]
|
||||||
|
logger.info(
|
||||||
|
"HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter",
|
||||||
|
url, connector_id, len(tool_definitions), total_discovered,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all",
|
||||||
|
total_discovered, url, connector_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
for tool_def in tool_definitions:
|
||||||
|
try:
|
||||||
|
tool = await _create_mcp_tool_from_definition_http(
|
||||||
|
tool_def,
|
||||||
|
url,
|
||||||
|
headers,
|
||||||
|
connector_name=connector_name,
|
||||||
|
connector_id=connector_id,
|
||||||
|
trusted_tools=trusted_tools,
|
||||||
|
readonly_tools=readonly_tools,
|
||||||
|
tool_name_prefix=tool_name_prefix,
|
||||||
|
is_generic_mcp=is_generic_mcp,
|
||||||
|
)
|
||||||
|
tools.append(tool)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to create HTTP tool '%s' from connector %d: %s",
|
||||||
|
tool_def.get("name"), connector_id, e,
|
||||||
|
)
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
|
@ -476,6 +626,91 @@ def _inject_oauth_headers(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _refresh_connector_token(
|
||||||
|
session: AsyncSession,
|
||||||
|
connector: "SearchSourceConnector",
|
||||||
|
) -> str | None:
|
||||||
|
"""Refresh the OAuth token for an MCP connector and persist the result.
|
||||||
|
|
||||||
|
This is the shared core used by both proactive (pre-expiry) and reactive
|
||||||
|
(401 recovery) refresh paths. It handles:
|
||||||
|
- Decrypting the current refresh token / client secret
|
||||||
|
- Calling the token endpoint
|
||||||
|
- Encrypting and persisting the new tokens
|
||||||
|
- Clearing ``auth_expired`` if it was set
|
||||||
|
- Invalidating the MCP tools cache
|
||||||
|
|
||||||
|
Returns the **plaintext** new access token on success, or ``None`` on
|
||||||
|
failure (no refresh token, IdP error, etc.).
|
||||||
|
"""
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
|
from app.services.mcp_oauth.discovery import refresh_access_token
|
||||||
|
|
||||||
|
cfg = connector.config or {}
|
||||||
|
mcp_oauth = cfg.get("mcp_oauth", {})
|
||||||
|
|
||||||
|
refresh_token = mcp_oauth.get("refresh_token")
|
||||||
|
if not refresh_token:
|
||||||
|
logger.warning(
|
||||||
|
"MCP connector %s: no refresh_token available",
|
||||||
|
connector.id,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
enc = _get_token_enc()
|
||||||
|
decrypted_refresh = enc.decrypt_token(refresh_token)
|
||||||
|
decrypted_secret = (
|
||||||
|
enc.decrypt_token(mcp_oauth["client_secret"])
|
||||||
|
if mcp_oauth.get("client_secret")
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
token_json = await refresh_access_token(
|
||||||
|
token_endpoint=mcp_oauth["token_endpoint"],
|
||||||
|
refresh_token=decrypted_refresh,
|
||||||
|
client_id=mcp_oauth["client_id"],
|
||||||
|
client_secret=decrypted_secret,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_access = token_json.get("access_token")
|
||||||
|
if not new_access:
|
||||||
|
logger.warning(
|
||||||
|
"MCP connector %s: token refresh returned no access_token",
|
||||||
|
connector.id,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
new_expires_at = None
|
||||||
|
if token_json.get("expires_in"):
|
||||||
|
new_expires_at = datetime.now(UTC) + timedelta(
|
||||||
|
seconds=int(token_json["expires_in"])
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_oauth = dict(mcp_oauth)
|
||||||
|
updated_oauth["access_token"] = enc.encrypt_token(new_access)
|
||||||
|
if token_json.get("refresh_token"):
|
||||||
|
updated_oauth["refresh_token"] = enc.encrypt_token(
|
||||||
|
token_json["refresh_token"]
|
||||||
|
)
|
||||||
|
updated_oauth["expires_at"] = (
|
||||||
|
new_expires_at.isoformat() if new_expires_at else None
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_cfg = {**cfg, "mcp_oauth": updated_oauth}
|
||||||
|
updated_cfg.pop("auth_expired", None)
|
||||||
|
connector.config = updated_cfg
|
||||||
|
flag_modified(connector, "config")
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(connector)
|
||||||
|
|
||||||
|
invalidate_mcp_tools_cache(connector.search_space_id)
|
||||||
|
|
||||||
|
return new_access
|
||||||
|
|
||||||
|
|
||||||
async def _maybe_refresh_mcp_oauth_token(
|
async def _maybe_refresh_mcp_oauth_token(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
connector: "SearchSourceConnector",
|
connector: "SearchSourceConnector",
|
||||||
|
|
@ -504,73 +739,13 @@ async def _maybe_refresh_mcp_oauth_token(
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return server_config
|
return server_config
|
||||||
|
|
||||||
refresh_token = mcp_oauth.get("refresh_token")
|
|
||||||
if not refresh_token:
|
|
||||||
logger.warning(
|
|
||||||
"MCP connector %s token expired but no refresh_token available",
|
|
||||||
connector.id,
|
|
||||||
)
|
|
||||||
return server_config
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from app.services.mcp_oauth.discovery import refresh_access_token
|
new_access = await _refresh_connector_token(session, connector)
|
||||||
|
|
||||||
enc = _get_token_enc()
|
|
||||||
decrypted_refresh = enc.decrypt_token(refresh_token)
|
|
||||||
decrypted_secret = (
|
|
||||||
enc.decrypt_token(mcp_oauth["client_secret"])
|
|
||||||
if mcp_oauth.get("client_secret")
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
token_json = await refresh_access_token(
|
|
||||||
token_endpoint=mcp_oauth["token_endpoint"],
|
|
||||||
refresh_token=decrypted_refresh,
|
|
||||||
client_id=mcp_oauth["client_id"],
|
|
||||||
client_secret=decrypted_secret,
|
|
||||||
)
|
|
||||||
|
|
||||||
new_access = token_json.get("access_token")
|
|
||||||
if not new_access:
|
if not new_access:
|
||||||
logger.warning(
|
|
||||||
"MCP connector %s token refresh returned no access_token",
|
|
||||||
connector.id,
|
|
||||||
)
|
|
||||||
return server_config
|
return server_config
|
||||||
|
|
||||||
new_expires_at = None
|
logger.info("Proactively refreshed MCP OAuth token for connector %s", connector.id)
|
||||||
if token_json.get("expires_in"):
|
|
||||||
new_expires_at = datetime.now(UTC) + timedelta(
|
|
||||||
seconds=int(token_json["expires_in"])
|
|
||||||
)
|
|
||||||
|
|
||||||
updated_oauth = dict(mcp_oauth)
|
|
||||||
updated_oauth["access_token"] = enc.encrypt_token(new_access)
|
|
||||||
if token_json.get("refresh_token"):
|
|
||||||
updated_oauth["refresh_token"] = enc.encrypt_token(
|
|
||||||
token_json["refresh_token"]
|
|
||||||
)
|
|
||||||
updated_oauth["expires_at"] = (
|
|
||||||
new_expires_at.isoformat() if new_expires_at else None
|
|
||||||
)
|
|
||||||
|
|
||||||
from sqlalchemy.orm.attributes import flag_modified
|
|
||||||
|
|
||||||
connector.config = {
|
|
||||||
**cfg,
|
|
||||||
"server_config": server_config,
|
|
||||||
"mcp_oauth": updated_oauth,
|
|
||||||
}
|
|
||||||
flag_modified(connector, "config")
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(connector)
|
|
||||||
|
|
||||||
logger.info("Refreshed MCP OAuth token for connector %s", connector.id)
|
|
||||||
|
|
||||||
# Invalidate cache so next call picks up the new token.
|
|
||||||
invalidate_mcp_tools_cache(connector.search_space_id)
|
|
||||||
|
|
||||||
# Return server_config with the fresh token injected for immediate use.
|
|
||||||
refreshed_config = dict(server_config)
|
refreshed_config = dict(server_config)
|
||||||
refreshed_config["headers"] = {
|
refreshed_config["headers"] = {
|
||||||
**server_config.get("headers", {}),
|
**server_config.get("headers", {}),
|
||||||
|
|
@ -587,6 +762,117 @@ async def _maybe_refresh_mcp_oauth_token(
|
||||||
return server_config
|
return server_config
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Reactive 401 handling helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _is_auth_error(exc: Exception) -> bool:
|
||||||
|
"""Check if an exception indicates an HTTP 401 authentication failure."""
|
||||||
|
try:
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
if isinstance(exc, httpx.HTTPStatusError):
|
||||||
|
return exc.response.status_code == 401
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
err_str = str(exc).lower()
|
||||||
|
return "401" in err_str or "unauthorized" in err_str
|
||||||
|
|
||||||
|
|
||||||
|
async def _force_refresh_and_get_headers(
|
||||||
|
connector_id: int,
|
||||||
|
) -> dict[str, str] | None:
|
||||||
|
"""Force-refresh OAuth token for a connector and return fresh HTTP headers.
|
||||||
|
|
||||||
|
Opens a **new** DB session so this can be called from inside tool closures
|
||||||
|
that don't have access to the original session.
|
||||||
|
|
||||||
|
Returns ``None`` when the connector is not OAuth-backed, has no
|
||||||
|
refresh token, or the refresh itself fails.
|
||||||
|
"""
|
||||||
|
from app.db import async_session_maker
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.id == connector_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connector = result.scalars().first()
|
||||||
|
if not connector:
|
||||||
|
return None
|
||||||
|
|
||||||
|
cfg = connector.config or {}
|
||||||
|
if not cfg.get("mcp_oauth"):
|
||||||
|
return None
|
||||||
|
|
||||||
|
server_config = cfg.get("server_config", {})
|
||||||
|
|
||||||
|
new_access = await _refresh_connector_token(session, connector)
|
||||||
|
if not new_access:
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Force-refreshed MCP OAuth token for connector %s (401 recovery)",
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
**server_config.get("headers", {}),
|
||||||
|
"Authorization": f"Bearer {new_access}",
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to force-refresh MCP OAuth token for connector %s",
|
||||||
|
connector_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _mark_connector_auth_expired(connector_id: int) -> None:
|
||||||
|
"""Set ``config.auth_expired = True`` so the frontend shows re-auth UI."""
|
||||||
|
from app.db import async_session_maker
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.id == connector_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connector = result.scalars().first()
|
||||||
|
if not connector:
|
||||||
|
return
|
||||||
|
|
||||||
|
cfg = dict(connector.config or {})
|
||||||
|
if cfg.get("auth_expired"):
|
||||||
|
return
|
||||||
|
|
||||||
|
cfg["auth_expired"] = True
|
||||||
|
connector.config = cfg
|
||||||
|
|
||||||
|
from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
|
flag_modified(connector, "config")
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Marked MCP connector %s as auth_expired after unrecoverable 401",
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
invalidate_mcp_tools_cache(connector.search_space_id)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to mark connector %s as auth_expired",
|
||||||
|
connector_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
||||||
"""Invalidate cached MCP tools.
|
"""Invalidate cached MCP tools.
|
||||||
|
|
||||||
|
|
@ -661,7 +947,7 @@ async def load_mcp_tools(
|
||||||
multi_account_types,
|
multi_account_types,
|
||||||
)
|
)
|
||||||
|
|
||||||
tools: list[StructuredTool] = []
|
discovery_tasks: list[dict[str, Any]] = []
|
||||||
for connector in connectors:
|
for connector in connectors:
|
||||||
try:
|
try:
|
||||||
cfg = connector.config or {}
|
cfg = connector.config or {}
|
||||||
|
|
@ -674,14 +960,10 @@ async def load_mcp_tools(
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# For MCP OAuth connectors: refresh if needed, then decrypt the
|
|
||||||
# access token and inject it into headers at runtime. The DB
|
|
||||||
# intentionally does NOT store plaintext tokens in server_config.
|
|
||||||
if cfg.get("mcp_oauth"):
|
if cfg.get("mcp_oauth"):
|
||||||
server_config = await _maybe_refresh_mcp_oauth_token(
|
server_config = await _maybe_refresh_mcp_oauth_token(
|
||||||
session, connector, cfg, server_config,
|
session, connector, cfg, server_config,
|
||||||
)
|
)
|
||||||
# Re-read cfg after potential refresh (connector was reloaded from DB).
|
|
||||||
cfg = connector.config or {}
|
cfg = connector.config or {}
|
||||||
server_config = _inject_oauth_headers(cfg, server_config)
|
server_config = _inject_oauth_headers(cfg, server_config)
|
||||||
if server_config is None:
|
if server_config is None:
|
||||||
|
|
@ -689,6 +971,7 @@ async def load_mcp_tools(
|
||||||
"Skipping MCP connector %d — OAuth token decryption failed",
|
"Skipping MCP connector %d — OAuth token decryption failed",
|
||||||
connector.id,
|
connector.id,
|
||||||
)
|
)
|
||||||
|
await _mark_connector_auth_expired(connector.id)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
trusted_tools = cfg.get("trusted_tools", [])
|
trusted_tools = cfg.get("trusted_tools", [])
|
||||||
|
|
@ -703,7 +986,6 @@ async def load_mcp_tools(
|
||||||
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
|
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
|
||||||
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
|
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
|
||||||
|
|
||||||
# Build a prefix only when multiple accounts share the same type.
|
|
||||||
tool_name_prefix: str | None = None
|
tool_name_prefix: str | None = None
|
||||||
if ct in multi_account_types and svc_cfg:
|
if ct in multi_account_types and svc_cfg:
|
||||||
service_key = next(
|
service_key = next(
|
||||||
|
|
@ -713,34 +995,68 @@ async def load_mcp_tools(
|
||||||
if service_key:
|
if service_key:
|
||||||
tool_name_prefix = f"{service_key}_{connector.id}"
|
tool_name_prefix = f"{service_key}_{connector.id}"
|
||||||
|
|
||||||
transport = server_config.get("transport", "stdio")
|
discovery_tasks.append({
|
||||||
|
"connector_id": connector.id,
|
||||||
if transport in ("streamable-http", "http", "sse"):
|
"connector_name": connector.name,
|
||||||
connector_tools = await _load_http_mcp_tools(
|
"server_config": server_config,
|
||||||
connector.id,
|
"trusted_tools": trusted_tools,
|
||||||
connector.name,
|
"allowed_tools": allowed_tools,
|
||||||
server_config,
|
"readonly_tools": readonly_tools,
|
||||||
trusted_tools=trusted_tools,
|
"tool_name_prefix": tool_name_prefix,
|
||||||
allowed_tools=allowed_tools,
|
"transport": server_config.get("transport", "stdio"),
|
||||||
readonly_tools=readonly_tools,
|
"is_generic_mcp": svc_cfg is None,
|
||||||
tool_name_prefix=tool_name_prefix,
|
})
|
||||||
)
|
|
||||||
else:
|
|
||||||
connector_tools = await _load_stdio_mcp_tools(
|
|
||||||
connector.id,
|
|
||||||
connector.name,
|
|
||||||
server_config,
|
|
||||||
trusted_tools=trusted_tools,
|
|
||||||
)
|
|
||||||
|
|
||||||
tools.extend(connector_tools)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to load tools from MCP connector %d: %s",
|
"Failed to prepare MCP connector %d: %s",
|
||||||
connector.id, e,
|
connector.id, e,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]:
|
||||||
|
try:
|
||||||
|
if task["transport"] in ("streamable-http", "http", "sse"):
|
||||||
|
return await asyncio.wait_for(
|
||||||
|
_load_http_mcp_tools(
|
||||||
|
task["connector_id"],
|
||||||
|
task["connector_name"],
|
||||||
|
task["server_config"],
|
||||||
|
trusted_tools=task["trusted_tools"],
|
||||||
|
allowed_tools=task["allowed_tools"],
|
||||||
|
readonly_tools=task["readonly_tools"],
|
||||||
|
tool_name_prefix=task["tool_name_prefix"],
|
||||||
|
is_generic_mcp=task.get("is_generic_mcp", False),
|
||||||
|
),
|
||||||
|
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return await asyncio.wait_for(
|
||||||
|
_load_stdio_mcp_tools(
|
||||||
|
task["connector_id"],
|
||||||
|
task["connector_name"],
|
||||||
|
task["server_config"],
|
||||||
|
trusted_tools=task["trusted_tools"],
|
||||||
|
),
|
||||||
|
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(
|
||||||
|
"MCP connector %d timed out after %ds during discovery",
|
||||||
|
task["connector_id"], _MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to load tools from MCP connector %d: %s",
|
||||||
|
task["connector_id"], e,
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks])
|
||||||
|
tools: list[StructuredTool] = [
|
||||||
|
tool for sublist in results for tool in sublist
|
||||||
|
]
|
||||||
|
|
||||||
_mcp_tools_cache[search_space_id] = (now, tools)
|
_mcp_tools_cache[search_space_id] = (now, tools)
|
||||||
|
|
||||||
if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:
|
if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:
|
||||||
|
|
|
||||||
|
|
@ -3105,13 +3105,18 @@ async def trust_mcp_tool(
|
||||||
"""Add a tool to the MCP connector's trusted (always-allow) list.
|
"""Add a tool to the MCP connector's trusted (always-allow) list.
|
||||||
|
|
||||||
Once trusted, the tool executes without HITL approval on subsequent calls.
|
Once trusted, the tool executes without HITL approval on subsequent calls.
|
||||||
|
Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors
|
||||||
|
(LINEAR_CONNECTOR, JIRA_CONNECTOR, etc.) by checking for ``server_config``.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
from sqlalchemy import cast
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB
|
||||||
|
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(SearchSourceConnector).filter(
|
select(SearchSourceConnector).filter(
|
||||||
SearchSourceConnector.id == connector_id,
|
SearchSourceConnector.id == connector_id,
|
||||||
SearchSourceConnector.connector_type
|
SearchSourceConnector.user_id == user.id,
|
||||||
== SearchSourceConnectorType.MCP_CONNECTOR,
|
cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), # noqa: W601
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
connector = result.scalars().first()
|
connector = result.scalars().first()
|
||||||
|
|
@ -3156,13 +3161,17 @@ async def untrust_mcp_tool(
|
||||||
"""Remove a tool from the MCP connector's trusted list.
|
"""Remove a tool from the MCP connector's trusted list.
|
||||||
|
|
||||||
The tool will require HITL approval again on subsequent calls.
|
The tool will require HITL approval again on subsequent calls.
|
||||||
|
Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
from sqlalchemy import cast
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB
|
||||||
|
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(SearchSourceConnector).filter(
|
select(SearchSourceConnector).filter(
|
||||||
SearchSourceConnector.id == connector_id,
|
SearchSourceConnector.id == connector_id,
|
||||||
SearchSourceConnector.connector_type
|
SearchSourceConnector.user_id == user.id,
|
||||||
== SearchSourceConnectorType.MCP_CONNECTOR,
|
cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), # noqa: W601
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
connector = result.scalars().first()
|
connector = result.scalars().first()
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||||
from app.db import Document, DocumentType
|
from app.db import Document, DocumentType
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -66,6 +65,8 @@ class ConfluenceKBSyncService:
|
||||||
if dup:
|
if dup:
|
||||||
content_hash = unique_hash
|
content_hash = unique_hash
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
@ -184,6 +185,8 @@ class ConfluenceKBSyncService:
|
||||||
|
|
||||||
space_id = (document.document_metadata or {}).get("space_id", "")
|
space_id = (document.document_metadata or {}).get("space_id", "")
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import Document, DocumentType
|
from app.db import Document, DocumentType
|
||||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -73,6 +72,8 @@ class DropboxKBSyncService:
|
||||||
)
|
)
|
||||||
content_hash = unique_hash
|
content_hash = unique_hash
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ from datetime import datetime
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import Document, DocumentType
|
from app.db import Document, DocumentType
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -78,6 +77,8 @@ class GmailKBSyncService:
|
||||||
)
|
)
|
||||||
content_hash = unique_hash
|
content_hash = unique_hash
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@ from app.db import (
|
||||||
SearchSourceConnector,
|
SearchSourceConnector,
|
||||||
SearchSourceConnectorType,
|
SearchSourceConnectorType,
|
||||||
)
|
)
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -91,6 +90,8 @@ class GoogleCalendarKBSyncService:
|
||||||
)
|
)
|
||||||
content_hash = unique_hash
|
content_hash = unique_hash
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
@ -249,6 +250,8 @@ class GoogleCalendarKBSyncService:
|
||||||
if not indexable_content:
|
if not indexable_content:
|
||||||
return {"status": "error", "message": "Event produced empty content"}
|
return {"status": "error", "message": "Event produced empty content"}
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ from datetime import datetime
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import Document, DocumentType
|
from app.db import Document, DocumentType
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -75,6 +74,8 @@ class GoogleDriveKBSyncService:
|
||||||
)
|
)
|
||||||
content_hash = unique_hash
|
content_hash = unique_hash
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.connectors.jira_history import JiraHistoryConnector
|
from app.connectors.jira_history import JiraHistoryConnector
|
||||||
from app.db import Document, DocumentType
|
from app.db import Document, DocumentType
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -75,6 +74,8 @@ class JiraKBSyncService:
|
||||||
if dup:
|
if dup:
|
||||||
content_hash = unique_hash
|
content_hash = unique_hash
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
@ -190,6 +191,8 @@ class JiraKBSyncService:
|
||||||
state = formatted.get("status", "Unknown")
|
state = formatted.get("status", "Unknown")
|
||||||
comment_count = len(formatted.get("comments", []))
|
comment_count = len(formatted.get("comments", []))
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.connectors.linear_connector import LinearConnector
|
from app.connectors.linear_connector import LinearConnector
|
||||||
from app.db import Document, DocumentType
|
from app.db import Document, DocumentType
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -85,6 +84,8 @@ class LinearKBSyncService:
|
||||||
)
|
)
|
||||||
content_hash = unique_hash
|
content_hash = unique_hash
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
@ -226,6 +227,8 @@ class LinearKBSyncService:
|
||||||
comment_count = len(formatted_issue.get("comments", []))
|
comment_count = len(formatted_issue.get("comments", []))
|
||||||
formatted_issue.get("description", "")
|
formatted_issue.get("description", "")
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ from langchain_litellm import ChatLiteLLM
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.db import NewLLMConfig, SearchSpace
|
from app.db import NewLLMConfig, SearchSpace
|
||||||
from app.services.llm_router_service import (
|
from app.services.llm_router_service import (
|
||||||
|
|
@ -204,6 +203,8 @@ async def validate_llm_config(
|
||||||
if litellm_params:
|
if litellm_params:
|
||||||
litellm_kwargs.update(litellm_params)
|
litellm_kwargs.update(litellm_params)
|
||||||
|
|
||||||
|
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||||
|
|
||||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
|
|
||||||
# Run the test call in a worker thread with a hard timeout. Some
|
# Run the test call in a worker thread with a hard timeout. Some
|
||||||
|
|
@ -377,6 +378,8 @@ async def get_search_space_llm_instance(
|
||||||
if disable_streaming:
|
if disable_streaming:
|
||||||
litellm_kwargs["disable_streaming"] = True
|
litellm_kwargs["disable_streaming"] = True
|
||||||
|
|
||||||
|
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||||
|
|
||||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
|
|
||||||
# Get the LLM configuration from database (NewLLMConfig)
|
# Get the LLM configuration from database (NewLLMConfig)
|
||||||
|
|
@ -454,6 +457,8 @@ async def get_search_space_llm_instance(
|
||||||
if disable_streaming:
|
if disable_streaming:
|
||||||
litellm_kwargs["disable_streaming"] = True
|
litellm_kwargs["disable_streaming"] = True
|
||||||
|
|
||||||
|
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||||
|
|
||||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -555,6 +560,8 @@ async def get_vision_llm(
|
||||||
if global_cfg.get("litellm_params"):
|
if global_cfg.get("litellm_params"):
|
||||||
litellm_kwargs.update(global_cfg["litellm_params"])
|
litellm_kwargs.update(global_cfg["litellm_params"])
|
||||||
|
|
||||||
|
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||||
|
|
||||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
|
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
|
|
@ -588,6 +595,8 @@ async def get_vision_llm(
|
||||||
if vision_cfg.litellm_params:
|
if vision_cfg.litellm_params:
|
||||||
litellm_kwargs.update(vision_cfg.litellm_params)
|
litellm_kwargs.update(vision_cfg.litellm_params)
|
||||||
|
|
||||||
|
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||||
|
|
||||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ from datetime import datetime
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import Document, DocumentType
|
from app.db import Document, DocumentType
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -74,6 +73,8 @@ class NotionKBSyncService:
|
||||||
)
|
)
|
||||||
content_hash = unique_hash
|
content_hash = unique_hash
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
@ -244,6 +245,8 @@ class NotionKBSyncService:
|
||||||
f"Final content length: {len(full_content)} chars, verified={content_verified}"
|
f"Final content length: {len(full_content)} chars, verified={content_verified}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
logger.debug("Generating summary and embeddings")
|
logger.debug("Generating summary and embeddings")
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import Document, DocumentType
|
from app.db import Document, DocumentType
|
||||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -73,6 +72,8 @@ class OneDriveKBSyncService:
|
||||||
)
|
)
|
||||||
content_hash = unique_hash
|
content_hash = unique_hash
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
|
||||||
|
|
@ -123,8 +123,9 @@ export const ConnectorIndicator = forwardRef<ConnectorIndicatorHandle, Connector
|
||||||
handleSkipIndexing,
|
handleSkipIndexing,
|
||||||
handleStartEdit,
|
handleStartEdit,
|
||||||
handleSaveConnector,
|
handleSaveConnector,
|
||||||
handleDisconnectConnector,
|
handleDisconnectConnector,
|
||||||
handleBackFromEdit,
|
handleDisconnectFromList,
|
||||||
|
handleBackFromEdit,
|
||||||
handleBackFromConnect,
|
handleBackFromConnect,
|
||||||
handleBackFromYouTube,
|
handleBackFromYouTube,
|
||||||
handleViewAccountsList,
|
handleViewAccountsList,
|
||||||
|
|
@ -225,25 +226,27 @@ export const ConnectorIndicator = forwardRef<ConnectorIndicatorHandle, Connector
|
||||||
{isYouTubeView && searchSpaceId ? (
|
{isYouTubeView && searchSpaceId ? (
|
||||||
<YouTubeCrawlerView searchSpaceId={searchSpaceId} onBack={handleBackFromYouTube} />
|
<YouTubeCrawlerView searchSpaceId={searchSpaceId} onBack={handleBackFromYouTube} />
|
||||||
) : viewingMCPList ? (
|
) : viewingMCPList ? (
|
||||||
<ConnectorAccountsListView
|
<ConnectorAccountsListView
|
||||||
connectorType="MCP_CONNECTOR"
|
connectorType="MCP_CONNECTOR"
|
||||||
connectorTitle="MCP Connectors"
|
connectorTitle="MCP Connectors"
|
||||||
connectors={(allConnectors || []) as SearchSourceConnector[]}
|
connectors={(allConnectors || []) as SearchSourceConnector[]}
|
||||||
indexingConnectorIds={indexingConnectorIds}
|
indexingConnectorIds={indexingConnectorIds}
|
||||||
onBack={handleBackFromMCPList}
|
onBack={handleBackFromMCPList}
|
||||||
onManage={handleStartEdit}
|
onManage={handleStartEdit}
|
||||||
onAddAccount={handleAddNewMCPFromList}
|
onDisconnect={(connector) => handleDisconnectFromList(connector, () => refreshConnectors())}
|
||||||
addButtonText="Add New MCP Server"
|
onAddAccount={handleAddNewMCPFromList}
|
||||||
/>
|
addButtonText="Add New MCP Server"
|
||||||
|
/>
|
||||||
) : viewingAccountsType ? (
|
) : viewingAccountsType ? (
|
||||||
<ConnectorAccountsListView
|
<ConnectorAccountsListView
|
||||||
connectorType={viewingAccountsType.connectorType}
|
connectorType={viewingAccountsType.connectorType}
|
||||||
connectorTitle={viewingAccountsType.connectorTitle}
|
connectorTitle={viewingAccountsType.connectorTitle}
|
||||||
connectors={(connectors || []) as SearchSourceConnector[]}
|
connectors={(connectors || []) as SearchSourceConnector[]}
|
||||||
indexingConnectorIds={indexingConnectorIds}
|
indexingConnectorIds={indexingConnectorIds}
|
||||||
onBack={handleBackFromAccountsList}
|
onBack={handleBackFromAccountsList}
|
||||||
onManage={handleStartEdit}
|
onManage={handleStartEdit}
|
||||||
onAddAccount={() => {
|
onDisconnect={(connector) => handleDisconnectFromList(connector, () => refreshConnectors())}
|
||||||
|
onAddAccount={() => {
|
||||||
// Check both OAUTH_CONNECTORS and COMPOSIO_CONNECTORS
|
// Check both OAUTH_CONNECTORS and COMPOSIO_CONNECTORS
|
||||||
const oauthConnector =
|
const oauthConnector =
|
||||||
OAUTH_CONNECTORS.find(
|
OAUTH_CONNECTORS.find(
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { CheckCircle2, ChevronDown, ChevronUp, Server, XCircle } from "lucide-react";
|
import { CheckCircle2, ChevronDown, ChevronUp, Loader2, Server, XCircle } from "lucide-react";
|
||||||
import { type FC, useRef, useState } from "react";
|
import { type FC, useRef, useState } from "react";
|
||||||
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
|
|
@ -212,7 +212,14 @@ export const MCPConnectForm: FC<ConnectFormProps> = ({ onSubmit, isSubmitting })
|
||||||
variant="secondary"
|
variant="secondary"
|
||||||
className="w-full h-8 text-[13px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80"
|
className="w-full h-8 text-[13px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80"
|
||||||
>
|
>
|
||||||
{isTesting ? "Testing Connection" : "Test Connection"}
|
{isTesting ? (
|
||||||
|
<>
|
||||||
|
<Loader2 className="h-3.5 w-3.5 animate-spin" />
|
||||||
|
Testing Connection...
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
"Test Connection"
|
||||||
|
)}
|
||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { CheckCircle2, ChevronDown, ChevronUp, Server, XCircle } from "lucide-react";
|
import { CheckCircle2, ChevronDown, ChevronUp, Loader2, Server, XCircle } from "lucide-react";
|
||||||
import type { FC } from "react";
|
import type { FC } from "react";
|
||||||
import { useCallback, useEffect, useRef, useState } from "react";
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
||||||
|
|
@ -217,7 +217,14 @@ export const MCPConfig: FC<MCPConfigProps> = ({ connector, onConfigChange, onNam
|
||||||
variant="secondary"
|
variant="secondary"
|
||||||
className="w-full h-8 text-[13px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80"
|
className="w-full h-8 text-[13px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80"
|
||||||
>
|
>
|
||||||
{isTesting ? "Testing Connection" : "Test Connection"}
|
{isTesting ? (
|
||||||
|
<>
|
||||||
|
<Loader2 className="h-3.5 w-3.5 animate-spin" />
|
||||||
|
Testing Connection...
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
"Test Connection"
|
||||||
|
)}
|
||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ import { toast } from "sonner";
|
||||||
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { Spinner } from "@/components/ui/spinner";
|
import { Spinner } from "@/components/ui/spinner";
|
||||||
import { EnumConnectorName } from "@/contracts/enums/connector";
|
|
||||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||||
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
||||||
import { authenticatedFetch } from "@/lib/auth-utils";
|
import { authenticatedFetch } from "@/lib/auth-utils";
|
||||||
|
|
@ -16,23 +15,11 @@ import { DateRangeSelector } from "../../components/date-range-selector";
|
||||||
import { PeriodicSyncConfig } from "../../components/periodic-sync-config";
|
import { PeriodicSyncConfig } from "../../components/periodic-sync-config";
|
||||||
import { SummaryConfig } from "../../components/summary-config";
|
import { SummaryConfig } from "../../components/summary-config";
|
||||||
import { VisionLLMConfig } from "../../components/vision-llm-config";
|
import { VisionLLMConfig } from "../../components/vision-llm-config";
|
||||||
import { LIVE_CONNECTOR_TYPES } from "../../constants/connector-constants";
|
import { LIVE_CONNECTOR_TYPES, getReauthEndpoint } from "../../constants/connector-constants";
|
||||||
import { getConnectorDisplayName } from "../../tabs/all-connectors-tab";
|
import { getConnectorDisplayName } from "../../tabs/all-connectors-tab";
|
||||||
|
import { MCPServiceConfig } from "../components/mcp-service-config";
|
||||||
import { type ConnectorConfigProps, getConnectorConfigComponent } from "../index";
|
import { type ConnectorConfigProps, getConnectorConfigComponent } from "../index";
|
||||||
|
|
||||||
const REAUTH_ENDPOINTS: Partial<Record<string, string>> = {
|
|
||||||
[EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth",
|
|
||||||
[EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth",
|
|
||||||
[EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth",
|
|
||||||
[EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth",
|
|
||||||
[EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/google/calendar/connector/reauth",
|
|
||||||
[EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
|
||||||
[EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
|
||||||
[EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
|
||||||
[EnumConnectorName.ONEDRIVE_CONNECTOR]: "/api/v1/auth/onedrive/connector/reauth",
|
|
||||||
[EnumConnectorName.DROPBOX_CONNECTOR]: "/api/v1/auth/dropbox/connector/reauth",
|
|
||||||
};
|
|
||||||
|
|
||||||
interface ConnectorEditViewProps {
|
interface ConnectorEditViewProps {
|
||||||
connector: SearchSourceConnector;
|
connector: SearchSourceConnector;
|
||||||
startDate: Date | undefined;
|
startDate: Date | undefined;
|
||||||
|
|
@ -86,7 +73,7 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
||||||
}) => {
|
}) => {
|
||||||
const searchSpaceIdAtom = useAtomValue(activeSearchSpaceIdAtom);
|
const searchSpaceIdAtom = useAtomValue(activeSearchSpaceIdAtom);
|
||||||
const isAuthExpired = connector.config?.auth_expired === true;
|
const isAuthExpired = connector.config?.auth_expired === true;
|
||||||
const reauthEndpoint = REAUTH_ENDPOINTS[connector.connector_type];
|
const reauthEndpoint = getReauthEndpoint(connector);
|
||||||
const [reauthing, setReauthing] = useState(false);
|
const [reauthing, setReauthing] = useState(false);
|
||||||
|
|
||||||
const handleReauth = useCallback(async () => {
|
const handleReauth = useCallback(async () => {
|
||||||
|
|
@ -124,10 +111,7 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
||||||
|
|
||||||
// Get connector-specific config component (MCP-backed connectors use a generic view)
|
// Get connector-specific config component (MCP-backed connectors use a generic view)
|
||||||
const ConnectorConfigComponent = useMemo(() => {
|
const ConnectorConfigComponent = useMemo(() => {
|
||||||
if (isMCPBacked) {
|
if (isMCPBacked) return MCPServiceConfig;
|
||||||
const { MCPServiceConfig } = require("../components/mcp-service-config");
|
|
||||||
return MCPServiceConfig as FC<ConnectorConfigProps>;
|
|
||||||
}
|
|
||||||
return getConnectorConfigComponent(connector.connector_type);
|
return getConnectorConfigComponent(connector.connector_type);
|
||||||
}, [connector.connector_type, isMCPBacked]);
|
}, [connector.connector_type, isMCPBacked]);
|
||||||
const [isScrolled, setIsScrolled] = useState(false);
|
const [isScrolled, setIsScrolled] = useState(false);
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import { EnumConnectorName } from "@/contracts/enums/connector";
|
import { EnumConnectorName } from "@/contracts/enums/connector";
|
||||||
|
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Connectors that operate in real time (no background indexing).
|
* Connectors that operate in real time (no background indexing).
|
||||||
|
|
@ -367,5 +368,45 @@ export function getConnectorTelemetryMeta(connectorType: string): ConnectorTelem
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// REAUTH ENDPOINTS
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Legacy (non-MCP) OAuth reauth endpoints, keyed by connector type.
|
||||||
|
* These are used for connectors that were NOT created via MCP OAuth.
|
||||||
|
*/
|
||||||
|
export const LEGACY_REAUTH_ENDPOINTS: Partial<Record<string, string>> = {
|
||||||
|
[EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth",
|
||||||
|
[EnumConnectorName.JIRA_CONNECTOR]: "/api/v1/auth/jira/connector/reauth",
|
||||||
|
[EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth",
|
||||||
|
[EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth",
|
||||||
|
[EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth",
|
||||||
|
[EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/google/calendar/connector/reauth",
|
||||||
|
[EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
||||||
|
[EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
||||||
|
[EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
||||||
|
[EnumConnectorName.ONEDRIVE_CONNECTOR]: "/api/v1/auth/onedrive/connector/reauth",
|
||||||
|
[EnumConnectorName.DROPBOX_CONNECTOR]: "/api/v1/auth/dropbox/connector/reauth",
|
||||||
|
[EnumConnectorName.CONFLUENCE_CONNECTOR]: "/api/v1/auth/confluence/connector/reauth",
|
||||||
|
[EnumConnectorName.TEAMS_CONNECTOR]: "/api/v1/auth/teams/connector/reauth",
|
||||||
|
[EnumConnectorName.DISCORD_CONNECTOR]: "/api/v1/auth/discord/connector/reauth",
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Resolve the reauth endpoint for a connector.
|
||||||
|
*
|
||||||
|
* MCP OAuth connectors (those with ``config.mcp_service``) dynamically build
|
||||||
|
* the URL from the service key. Legacy OAuth connectors fall back to the
|
||||||
|
* static ``LEGACY_REAUTH_ENDPOINTS`` map.
|
||||||
|
*/
|
||||||
|
export function getReauthEndpoint(connector: SearchSourceConnector): string | undefined {
|
||||||
|
const mcpService = connector.config?.mcp_service as string | undefined;
|
||||||
|
if (mcpService) {
|
||||||
|
return `/api/v1/auth/mcp/${mcpService}/connector/reauth`;
|
||||||
|
}
|
||||||
|
return LEGACY_REAUTH_ENDPOINTS[connector.connector_type];
|
||||||
|
}
|
||||||
|
|
||||||
// Re-export IndexingConfigState from schemas for backward compatibility
|
// Re-export IndexingConfigState from schemas for backward compatibility
|
||||||
export type { IndexingConfigState } from "./connector-popup.schemas";
|
export type { IndexingConfigState } from "./connector-popup.schemas";
|
||||||
|
|
|
||||||
|
|
@ -1311,6 +1311,25 @@ export const useConnectorDialog = () => {
|
||||||
[editingConnector, searchSpaceId, deleteConnector, cameFromMCPList, setIsOpen]
|
[editingConnector, searchSpaceId, deleteConnector, cameFromMCPList, setIsOpen]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const handleDisconnectFromList = useCallback(
|
||||||
|
async (connector: SearchSourceConnector, refreshConnectors: () => void) => {
|
||||||
|
if (!searchSpaceId) return;
|
||||||
|
try {
|
||||||
|
await deleteConnector({ id: connector.id });
|
||||||
|
trackConnectorDeleted(Number(searchSpaceId), connector.connector_type, connector.id);
|
||||||
|
toast.success(`${connector.name} disconnected successfully`);
|
||||||
|
refreshConnectors();
|
||||||
|
queryClient.invalidateQueries({
|
||||||
|
queryKey: cacheKeys.logs.summary(Number(searchSpaceId)),
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error disconnecting connector:", error);
|
||||||
|
toast.error("Failed to disconnect connector");
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[searchSpaceId, deleteConnector]
|
||||||
|
);
|
||||||
|
|
||||||
// Handle quick index (index with selected date range, or backend defaults if none selected)
|
// Handle quick index (index with selected date range, or backend defaults if none selected)
|
||||||
const handleQuickIndexConnector = useCallback(
|
const handleQuickIndexConnector = useCallback(
|
||||||
async (
|
async (
|
||||||
|
|
@ -1484,6 +1503,7 @@ export const useConnectorDialog = () => {
|
||||||
handleStartEdit,
|
handleStartEdit,
|
||||||
handleSaveConnector,
|
handleSaveConnector,
|
||||||
handleDisconnectConnector,
|
handleDisconnectConnector,
|
||||||
|
handleDisconnectFromList,
|
||||||
handleBackFromEdit,
|
handleBackFromEdit,
|
||||||
handleBackFromConnect,
|
handleBackFromConnect,
|
||||||
handleBackFromYouTube,
|
handleBackFromYouTube,
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { useAtomValue } from "jotai";
|
import { useAtomValue } from "jotai";
|
||||||
import { ArrowLeft, Plus, RefreshCw, Server } from "lucide-react";
|
import { ArrowLeft, Plus, RefreshCw, Server, Trash2 } from "lucide-react";
|
||||||
import { type FC, useCallback, useState } from "react";
|
import { type FC, useCallback, useState } from "react";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
||||||
|
|
@ -13,25 +13,10 @@ import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
||||||
import { authenticatedFetch } from "@/lib/auth-utils";
|
import { authenticatedFetch } from "@/lib/auth-utils";
|
||||||
import { formatRelativeDate } from "@/lib/format-date";
|
import { formatRelativeDate } from "@/lib/format-date";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { LIVE_CONNECTOR_TYPES } from "../constants/connector-constants";
|
import { LIVE_CONNECTOR_TYPES, getReauthEndpoint } from "../constants/connector-constants";
|
||||||
import { useConnectorStatus } from "../hooks/use-connector-status";
|
import { useConnectorStatus } from "../hooks/use-connector-status";
|
||||||
import { getConnectorDisplayName } from "../tabs/all-connectors-tab";
|
import { getConnectorDisplayName } from "../tabs/all-connectors-tab";
|
||||||
|
|
||||||
const REAUTH_ENDPOINTS: Partial<Record<string, string>> = {
|
|
||||||
[EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth",
|
|
||||||
[EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth",
|
|
||||||
[EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth",
|
|
||||||
[EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth",
|
|
||||||
[EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/google/calendar/connector/reauth",
|
|
||||||
[EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
|
||||||
[EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
|
||||||
[EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
|
||||||
[EnumConnectorName.ONEDRIVE_CONNECTOR]: "/api/v1/auth/onedrive/connector/reauth",
|
|
||||||
[EnumConnectorName.JIRA_CONNECTOR]: "/api/v1/auth/jira/connector/reauth",
|
|
||||||
[EnumConnectorName.DROPBOX_CONNECTOR]: "/api/v1/auth/dropbox/connector/reauth",
|
|
||||||
[EnumConnectorName.CONFLUENCE_CONNECTOR]: "/api/v1/auth/confluence/connector/reauth",
|
|
||||||
};
|
|
||||||
|
|
||||||
interface ConnectorAccountsListViewProps {
|
interface ConnectorAccountsListViewProps {
|
||||||
connectorType: string;
|
connectorType: string;
|
||||||
connectorTitle: string;
|
connectorTitle: string;
|
||||||
|
|
@ -39,15 +24,12 @@ interface ConnectorAccountsListViewProps {
|
||||||
indexingConnectorIds: Set<number>;
|
indexingConnectorIds: Set<number>;
|
||||||
onBack: () => void;
|
onBack: () => void;
|
||||||
onManage: (connector: SearchSourceConnector) => void;
|
onManage: (connector: SearchSourceConnector) => void;
|
||||||
|
onDisconnect?: (connector: SearchSourceConnector) => Promise<void> | void;
|
||||||
onAddAccount: () => void;
|
onAddAccount: () => void;
|
||||||
isConnecting?: boolean;
|
isConnecting?: boolean;
|
||||||
addButtonText?: string;
|
addButtonText?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
function isLiveConnector(connectorType: string): boolean {
|
|
||||||
return LIVE_CONNECTOR_TYPES.has(connectorType) || connectorType === "MCP_CONNECTOR";
|
|
||||||
}
|
|
||||||
|
|
||||||
export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
||||||
connectorType,
|
connectorType,
|
||||||
connectorTitle,
|
connectorTitle,
|
||||||
|
|
@ -55,12 +37,15 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
||||||
indexingConnectorIds,
|
indexingConnectorIds,
|
||||||
onBack,
|
onBack,
|
||||||
onManage,
|
onManage,
|
||||||
|
onDisconnect,
|
||||||
onAddAccount,
|
onAddAccount,
|
||||||
isConnecting = false,
|
isConnecting = false,
|
||||||
addButtonText,
|
addButtonText,
|
||||||
}) => {
|
}) => {
|
||||||
const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom);
|
const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom);
|
||||||
const [reauthingId, setReauthingId] = useState<number | null>(null);
|
const [reauthingId, setReauthingId] = useState<number | null>(null);
|
||||||
|
const [confirmDisconnectId, setConfirmDisconnectId] = useState<number | null>(null);
|
||||||
|
const [disconnectingId, setDisconnectingId] = useState<number | null>(null);
|
||||||
|
|
||||||
// Get connector status
|
// Get connector status
|
||||||
const { isConnectorEnabled, getConnectorStatusMessage } = useConnectorStatus();
|
const { isConnectorEnabled, getConnectorStatusMessage } = useConnectorStatus();
|
||||||
|
|
@ -68,16 +53,15 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
||||||
const isEnabled = isConnectorEnabled(connectorType);
|
const isEnabled = isConnectorEnabled(connectorType);
|
||||||
const statusMessage = getConnectorStatusMessage(connectorType);
|
const statusMessage = getConnectorStatusMessage(connectorType);
|
||||||
|
|
||||||
const reauthEndpoint = REAUTH_ENDPOINTS[connectorType];
|
|
||||||
|
|
||||||
const handleReauth = useCallback(
|
const handleReauth = useCallback(
|
||||||
async (connectorId: number) => {
|
async (connector: SearchSourceConnector) => {
|
||||||
if (!searchSpaceId || !reauthEndpoint) return;
|
const endpoint = getReauthEndpoint(connector);
|
||||||
setReauthingId(connectorId);
|
if (!searchSpaceId || !endpoint) return;
|
||||||
|
setReauthingId(connector.id);
|
||||||
try {
|
try {
|
||||||
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||||
const url = new URL(`${backendUrl}${reauthEndpoint}`);
|
const url = new URL(`${backendUrl}${endpoint}`);
|
||||||
url.searchParams.set("connector_id", String(connectorId));
|
url.searchParams.set("connector_id", String(connector.id));
|
||||||
url.searchParams.set("space_id", String(searchSpaceId));
|
url.searchParams.set("space_id", String(searchSpaceId));
|
||||||
url.searchParams.set("return_url", window.location.pathname);
|
url.searchParams.set("return_url", window.location.pathname);
|
||||||
const response = await authenticatedFetch(url.toString());
|
const response = await authenticatedFetch(url.toString());
|
||||||
|
|
@ -99,7 +83,7 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
||||||
setReauthingId(null);
|
setReauthingId(null);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[searchSpaceId, reauthEndpoint]
|
[searchSpaceId]
|
||||||
);
|
);
|
||||||
|
|
||||||
// Filter connectors to only show those of this type
|
// Filter connectors to only show those of this type
|
||||||
|
|
@ -198,9 +182,11 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<div className="grid grid-cols-1 sm:grid-cols-2 gap-3">
|
<div className="grid grid-cols-1 sm:grid-cols-2 gap-3">
|
||||||
{typeConnectors.map((connector) => {
|
{typeConnectors.map((connector) => {
|
||||||
const isIndexing = indexingConnectorIds.has(connector.id);
|
const isIndexing = indexingConnectorIds.has(connector.id);
|
||||||
const isAuthExpired = !!reauthEndpoint && connector.config?.auth_expired === true;
|
const connectorReauthEndpoint = getReauthEndpoint(connector);
|
||||||
|
const isAuthExpired = !!connectorReauthEndpoint && connector.config?.auth_expired === true;
|
||||||
|
const isLive = LIVE_CONNECTOR_TYPES.has(connector.connector_type) || Boolean(connector.config?.server_config);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
|
|
@ -231,7 +217,7 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
||||||
<Spinner size="xs" />
|
<Spinner size="xs" />
|
||||||
Syncing
|
Syncing
|
||||||
</p>
|
</p>
|
||||||
) : !isLiveConnector(connector.connector_type) ? (
|
) : !isLive ? (
|
||||||
<p className="text-[10px] mt-1 whitespace-nowrap truncate text-muted-foreground">
|
<p className="text-[10px] mt-1 whitespace-nowrap truncate text-muted-foreground">
|
||||||
{connector.last_indexed_at
|
{connector.last_indexed_at
|
||||||
? `Last indexed: ${formatRelativeDate(connector.last_indexed_at)}`
|
? `Last indexed: ${formatRelativeDate(connector.last_indexed_at)}`
|
||||||
|
|
@ -239,28 +225,73 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
||||||
</p>
|
</p>
|
||||||
) : null}
|
) : null}
|
||||||
</div>
|
</div>
|
||||||
{isAuthExpired ? (
|
{isAuthExpired ? (
|
||||||
<Button
|
<Button
|
||||||
size="sm"
|
size="sm"
|
||||||
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-amber-600 hover:bg-amber-700 text-white border-0 shadow-xs shrink-0"
|
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-amber-600 hover:bg-amber-700 text-white border-0 shadow-xs shrink-0"
|
||||||
onClick={() => handleReauth(connector.id)}
|
onClick={() => handleReauth(connector)}
|
||||||
disabled={reauthingId === connector.id}
|
disabled={reauthingId === connector.id}
|
||||||
>
|
>
|
||||||
<RefreshCw
|
<RefreshCw
|
||||||
className={cn("size-3.5", reauthingId === connector.id && "animate-spin")}
|
className={cn("size-3.5", reauthingId === connector.id && "animate-spin")}
|
||||||
/>
|
/>
|
||||||
Re-authenticate
|
Re-authenticate
|
||||||
</Button>
|
</Button>
|
||||||
|
) : isLive && onDisconnect ? (
|
||||||
|
confirmDisconnectId === connector.id ? (
|
||||||
|
<div className="flex items-center gap-1.5 shrink-0">
|
||||||
|
<Button
|
||||||
|
variant="destructive"
|
||||||
|
size="sm"
|
||||||
|
className="h-8 text-[11px] px-3 rounded-lg font-medium shadow-xs"
|
||||||
|
onClick={async () => {
|
||||||
|
setDisconnectingId(connector.id);
|
||||||
|
setConfirmDisconnectId(null);
|
||||||
|
try {
|
||||||
|
await onDisconnect(connector);
|
||||||
|
} finally {
|
||||||
|
setDisconnectingId(null);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
disabled={disconnectingId === connector.id}
|
||||||
|
>
|
||||||
|
{disconnectingId === connector.id ? (
|
||||||
|
<RefreshCw className="size-3.5 animate-spin" />
|
||||||
|
) : (
|
||||||
|
"Confirm"
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
className="h-8 text-[11px] px-2 rounded-lg"
|
||||||
|
onClick={() => setConfirmDisconnectId(null)}
|
||||||
|
disabled={disconnectingId === connector.id}
|
||||||
|
>
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<Button
|
<Button
|
||||||
variant="secondary"
|
variant="secondary"
|
||||||
size="sm"
|
size="sm"
|
||||||
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80 shrink-0"
|
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-red-50 hover:text-red-700 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-red-950 dark:hover:text-red-400 shrink-0"
|
||||||
onClick={() => onManage(connector)}
|
onClick={() => setConfirmDisconnectId(connector.id)}
|
||||||
>
|
>
|
||||||
Manage
|
<Trash2 className="size-3.5" />
|
||||||
|
Disconnect
|
||||||
</Button>
|
</Button>
|
||||||
)}
|
)
|
||||||
|
) : (
|
||||||
|
<Button
|
||||||
|
variant="secondary"
|
||||||
|
size="sm"
|
||||||
|
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80 shrink-0"
|
||||||
|
onClick={() => onManage(connector)}
|
||||||
|
>
|
||||||
|
Manage
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
})}
|
})}
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
import type { ToolCallMessagePartComponent } from "@assistant-ui/react";
|
import type { ToolCallMessagePartComponent } from "@assistant-ui/react";
|
||||||
import { CornerDownLeftIcon, Pen } from "lucide-react";
|
import { CornerDownLeftIcon, Pen } from "lucide-react";
|
||||||
import { useCallback, useEffect, useMemo, useState } from "react";
|
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||||
|
import { toast } from "sonner";
|
||||||
import { TextShimmerLoader } from "@/components/prompt-kit/loader";
|
import { TextShimmerLoader } from "@/components/prompt-kit/loader";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { Input } from "@/components/ui/input";
|
import { Input } from "@/components/ui/input";
|
||||||
|
|
@ -116,8 +117,8 @@ function GenericApprovalCard({
|
||||||
if (phase !== "pending" || !isMCPTool) return;
|
if (phase !== "pending" || !isMCPTool) return;
|
||||||
setProcessing();
|
setProcessing();
|
||||||
onDecision({ type: "approve" });
|
onDecision({ type: "approve" });
|
||||||
connectorsApiService.trustMCPTool(mcpConnectorId, toolName).catch((err) => {
|
connectorsApiService.trustMCPTool(mcpConnectorId, toolName).catch(() => {
|
||||||
console.error("Failed to trust MCP tool:", err);
|
toast.error("Failed to save 'Always Allow' preference. The tool will still require approval next time.");
|
||||||
});
|
});
|
||||||
}, [phase, setProcessing, onDecision, isMCPTool, mcpConnectorId, toolName]);
|
}, [phase, setProcessing, onDecision, isMCPTool, mcpConnectorId, toolName]);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -414,16 +414,8 @@ class ConnectorsApiService {
|
||||||
* Subsequent calls to this tool will skip HITL approval.
|
* Subsequent calls to this tool will skip HITL approval.
|
||||||
*/
|
*/
|
||||||
trustMCPTool = async (connectorId: number, toolName: string): Promise<void> => {
|
trustMCPTool = async (connectorId: number, toolName: string): Promise<void> => {
|
||||||
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
await baseApiService.post(`/api/v1/connectors/mcp/${connectorId}/trust-tool`, undefined, {
|
||||||
const token =
|
body: { tool_name: toolName },
|
||||||
typeof window !== "undefined" ? document.cookie.match(/fapiToken=([^;]+)/)?.[1] : undefined;
|
|
||||||
await fetch(`${backendUrl}/api/v1/connectors/mcp/${connectorId}/trust-tool`, {
|
|
||||||
method: "POST",
|
|
||||||
headers: {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
...(token ? { Authorization: `Bearer ${token}` } : {}),
|
|
||||||
},
|
|
||||||
body: JSON.stringify({ tool_name: toolName }),
|
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -431,16 +423,8 @@ class ConnectorsApiService {
|
||||||
* Remove a tool from the MCP connector's "Always Allow" list.
|
* Remove a tool from the MCP connector's "Always Allow" list.
|
||||||
*/
|
*/
|
||||||
untrustMCPTool = async (connectorId: number, toolName: string): Promise<void> => {
|
untrustMCPTool = async (connectorId: number, toolName: string): Promise<void> => {
|
||||||
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
await baseApiService.post(`/api/v1/connectors/mcp/${connectorId}/untrust-tool`, undefined, {
|
||||||
const token =
|
body: { tool_name: toolName },
|
||||||
typeof window !== "undefined" ? document.cookie.match(/fapiToken=([^;]+)/)?.[1] : undefined;
|
|
||||||
await fetch(`${backendUrl}/api/v1/connectors/mcp/${connectorId}/untrust-tool`, {
|
|
||||||
method: "POST",
|
|
||||||
headers: {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
...(token ? { Authorization: `Bearer ${token}` } : {}),
|
|
||||||
},
|
|
||||||
body: JSON.stringify({ tool_name: toolName }),
|
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue