mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-21 18:55:16 +02:00
Merge remote-tracking branch 'upstream/dev' into feat/obsidian-plugin
This commit is contained in:
commit
9b1b9a90c0
175 changed files with 10592 additions and 2302 deletions
|
|
@ -26,7 +26,7 @@ COMPOSIO_TOOLKIT_NAMES = {
|
|||
}
|
||||
|
||||
# Toolkits that support indexing (Phase 1: Google services only)
|
||||
INDEXABLE_TOOLKITS = {"googledrive", "gmail", "googlecalendar"}
|
||||
INDEXABLE_TOOLKITS = {"googledrive"}
|
||||
|
||||
# Mapping of toolkit IDs to connector types
|
||||
TOOLKIT_TO_CONNECTOR_TYPE = {
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -66,6 +65,8 @@ class ConfluenceKBSyncService:
|
|||
if dup:
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
@ -184,6 +185,8 @@ class ConfluenceKBSyncService:
|
|||
|
||||
space_id = (document.document_metadata or {}).get("space_id", "")
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.db import Document, DocumentType
|
||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -73,6 +72,8 @@ class DropboxKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from datetime import datetime
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -78,6 +77,8 @@ class GmailKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ from app.db import (
|
|||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -91,6 +90,8 @@ class GoogleCalendarKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
@ -249,6 +250,8 @@ class GoogleCalendarKBSyncService:
|
|||
if not indexable_content:
|
||||
return {"status": "error", "message": "Event produced empty content"}
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from datetime import datetime
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -75,6 +74,8 @@ class GoogleDriveKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -75,6 +74,8 @@ class JiraKBSyncService:
|
|||
if dup:
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
@ -190,6 +191,8 @@ class JiraKBSyncService:
|
|||
state = formatted.get("status", "Unknown")
|
||||
comment_count = len(formatted.get("comments", []))
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.connectors.linear_connector import LinearConnector
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -85,6 +84,8 @@ class LinearKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
@ -226,6 +227,8 @@ class LinearKBSyncService:
|
|||
comment_count = len(formatted_issue.get("comments", []))
|
||||
formatted_issue.get("description", "")
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -133,6 +133,44 @@ PROVIDER_MAP = {
|
|||
}
|
||||
|
||||
|
||||
# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when
|
||||
# a global LLM config does *not* specify ``api_base``: without this, LiteLLM
|
||||
# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``,
|
||||
# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku``
|
||||
# request to an Azure endpoint, which then 404s with ``Resource not found``.
|
||||
# Only providers with a well-known, stable public base URL are listed here —
|
||||
# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
|
||||
# huggingface, databricks, cloudflare, replicate) are intentionally omitted
|
||||
# so their existing config-driven behaviour is preserved.
|
||||
PROVIDER_DEFAULT_API_BASE = {
|
||||
"openrouter": "https://openrouter.ai/api/v1",
|
||||
"groq": "https://api.groq.com/openai/v1",
|
||||
"mistral": "https://api.mistral.ai/v1",
|
||||
"perplexity": "https://api.perplexity.ai",
|
||||
"xai": "https://api.x.ai/v1",
|
||||
"cerebras": "https://api.cerebras.ai/v1",
|
||||
"deepinfra": "https://api.deepinfra.com/v1/openai",
|
||||
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
|
||||
"together_ai": "https://api.together.xyz/v1",
|
||||
"anyscale": "https://api.endpoints.anyscale.com/v1",
|
||||
"cometapi": "https://api.cometapi.com/v1",
|
||||
"sambanova": "https://api.sambanova.ai/v1",
|
||||
}
|
||||
|
||||
|
||||
# Canonical provider → base URL when a config uses a generic ``openai``-style
|
||||
# prefix but the ``provider`` field tells us which API it really is
|
||||
# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but
|
||||
# each has its own base URL).
|
||||
PROVIDER_KEY_DEFAULT_API_BASE = {
|
||||
"DEEPSEEK": "https://api.deepseek.com/v1",
|
||||
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
"MOONSHOT": "https://api.moonshot.ai/v1",
|
||||
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"MINIMAX": "https://api.minimax.io/v1",
|
||||
}
|
||||
|
||||
|
||||
class LLMRouterService:
|
||||
"""
|
||||
Singleton service for managing LiteLLM Router.
|
||||
|
|
@ -224,6 +262,16 @@ class LLMRouterService:
|
|||
# hits ContextWindowExceededError.
|
||||
full_model_list, ctx_fallbacks = cls._build_context_fallback_groups(model_list)
|
||||
|
||||
# Build a general-purpose fallback list so NotFound/timeout/rate-limit
|
||||
# style failures on one deployment don't bubble up as hard errors —
|
||||
# the router retries with a sibling deployment in ``auto-large``.
|
||||
# ``auto-large`` is the large-context subset of ``auto``; if it is
|
||||
# empty we fall back to ``auto`` itself so the router at least picks a
|
||||
# different deployment in the same group.
|
||||
fallbacks: list[dict[str, list[str]]] | None = None
|
||||
if ctx_fallbacks:
|
||||
fallbacks = [{"auto": ["auto-large"]}]
|
||||
|
||||
try:
|
||||
router_kwargs: dict[str, Any] = {
|
||||
"model_list": full_model_list,
|
||||
|
|
@ -237,15 +285,24 @@ class LLMRouterService:
|
|||
}
|
||||
if ctx_fallbacks:
|
||||
router_kwargs["context_window_fallbacks"] = ctx_fallbacks
|
||||
if fallbacks:
|
||||
router_kwargs["fallbacks"] = fallbacks
|
||||
|
||||
instance._router = Router(**router_kwargs)
|
||||
instance._initialized = True
|
||||
|
||||
global _cached_context_profile, _cached_context_profile_computed
|
||||
_cached_context_profile = None
|
||||
_cached_context_profile_computed = False
|
||||
_router_instance_cache.clear()
|
||||
|
||||
logger.info(
|
||||
"LLM Router initialized with %d deployments, "
|
||||
"strategy: %s, context_window_fallbacks: %s",
|
||||
"strategy: %s, context_window_fallbacks: %s, fallbacks: %s",
|
||||
len(model_list),
|
||||
final_settings.get("routing_strategy"),
|
||||
ctx_fallbacks or "none",
|
||||
fallbacks or "none",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize LLM Router: {e}")
|
||||
|
|
@ -348,10 +405,11 @@ class LLMRouterService:
|
|||
return None
|
||||
|
||||
# Build model string
|
||||
provider = config.get("provider", "").upper()
|
||||
if config.get("custom_provider"):
|
||||
model_string = f"{config['custom_provider']}/{config['model_name']}"
|
||||
provider_prefix = config["custom_provider"]
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
else:
|
||||
provider = config.get("provider", "").upper()
|
||||
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
|
||||
|
|
@ -361,9 +419,19 @@ class LLMRouterService:
|
|||
"api_key": config.get("api_key"),
|
||||
}
|
||||
|
||||
# Add optional api_base
|
||||
if config.get("api_base"):
|
||||
litellm_params["api_base"] = config["api_base"]
|
||||
# Resolve ``api_base``. Config value wins; otherwise apply a
|
||||
# provider-aware default so the deployment does not silently
|
||||
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
|
||||
# requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE``
|
||||
# docstring for the motivating bug (OpenRouter models 404-ing
|
||||
# against an Azure endpoint).
|
||||
api_base = config.get("api_base")
|
||||
if not api_base:
|
||||
api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider)
|
||||
if not api_base:
|
||||
api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix)
|
||||
if api_base:
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
# Add any additional litellm parameters
|
||||
if config.get("litellm_params"):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
import litellm
|
||||
|
|
@ -6,7 +7,6 @@ from langchain_litellm import ChatLiteLLM
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
from app.config import config
|
||||
from app.db import NewLLMConfig, SearchSpace
|
||||
from app.services.llm_router_service import (
|
||||
|
|
@ -32,6 +32,39 @@ litellm.callbacks = [token_tracker]
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Providers that require an interactive OAuth / device-flow login before
|
||||
# issuing any completion. LiteLLM implements these with blocking sync polling
|
||||
# (requests + time.sleep), which would freeze the FastAPI event loop if
|
||||
# invoked from validation. They are never usable from a headless backend,
|
||||
# so we reject them at the edge.
|
||||
_INTERACTIVE_AUTH_PROVIDERS: frozenset[str] = frozenset(
|
||||
{
|
||||
"github_copilot",
|
||||
"github-copilot",
|
||||
"githubcopilot",
|
||||
"copilot",
|
||||
}
|
||||
)
|
||||
|
||||
# Hard upper bound for a single validation call. Must exceed the ChatLiteLLM
|
||||
# request timeout (30s) by a small margin so a well-behaved provider never
|
||||
# trips the watchdog, while any pathological/blocking provider is killed.
|
||||
_VALIDATION_TIMEOUT_SECONDS: float = 35.0
|
||||
|
||||
|
||||
def _is_interactive_auth_provider(
|
||||
provider: str | None, custom_provider: str | None
|
||||
) -> bool:
|
||||
"""Return True if the given provider triggers interactive OAuth in LiteLLM."""
|
||||
for raw in (custom_provider, provider):
|
||||
if not raw:
|
||||
continue
|
||||
normalized = raw.strip().lower().replace(" ", "_")
|
||||
if normalized in _INTERACTIVE_AUTH_PROVIDERS:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class LLMRole:
|
||||
AGENT = "agent" # For agent/chat operations
|
||||
DOCUMENT_SUMMARY = "document_summary" # For document summarization
|
||||
|
|
@ -93,6 +126,25 @@ async def validate_llm_config(
|
|||
- is_valid: True if config works, False otherwise
|
||||
- error_message: Empty string if valid, error description if invalid
|
||||
"""
|
||||
# Reject providers that require interactive OAuth/device-flow auth.
|
||||
# LiteLLM's github_copilot provider (and similar) uses a blocking sync
|
||||
# Authenticator that polls GitHub for up to several minutes and prints a
|
||||
# device code to stdout. Running it on the FastAPI event loop will freeze
|
||||
# the entire backend, so we refuse them up front.
|
||||
if _is_interactive_auth_provider(provider, custom_provider):
|
||||
msg = (
|
||||
"Provider requires interactive OAuth/device-flow authentication "
|
||||
"(e.g. github_copilot) and cannot be used in a hosted backend. "
|
||||
"Please choose a provider that authenticates via API key."
|
||||
)
|
||||
logger.warning(
|
||||
"Rejected LLM config validation for interactive-auth provider "
|
||||
"(provider=%r, custom_provider=%r)",
|
||||
provider,
|
||||
custom_provider,
|
||||
)
|
||||
return False, msg
|
||||
|
||||
try:
|
||||
# Build the model string for litellm
|
||||
if custom_provider:
|
||||
|
|
@ -151,11 +203,34 @@ async def validate_llm_config(
|
|||
if litellm_params:
|
||||
litellm_kwargs.update(litellm_params)
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
|
||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
# Make a simple test call
|
||||
# Run the test call in a worker thread with a hard timeout. Some
|
||||
# LiteLLM providers have synchronous blocking code paths (e.g. OAuth
|
||||
# authenticators that call time.sleep and requests.post) that would
|
||||
# otherwise freeze the asyncio event loop. Offloading to a thread and
|
||||
# bounding the wait keeps the server responsive even if a provider
|
||||
# misbehaves.
|
||||
test_message = HumanMessage(content="Hello")
|
||||
response = await llm.ainvoke([test_message])
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
asyncio.to_thread(llm.invoke, [test_message]),
|
||||
timeout=_VALIDATION_TIMEOUT_SECONDS,
|
||||
)
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"LLM config validation timed out after %ss for model: %s",
|
||||
_VALIDATION_TIMEOUT_SECONDS,
|
||||
model_string,
|
||||
)
|
||||
return (
|
||||
False,
|
||||
f"Validation timed out after {int(_VALIDATION_TIMEOUT_SECONDS)}s. "
|
||||
"The provider is unreachable or requires interactive "
|
||||
"authentication that is not supported by the backend.",
|
||||
)
|
||||
|
||||
# If we got here without exception, the config is valid
|
||||
if response and response.content:
|
||||
|
|
@ -303,6 +378,8 @@ async def get_search_space_llm_instance(
|
|||
if disable_streaming:
|
||||
litellm_kwargs["disable_streaming"] = True
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
# Get the LLM configuration from database (NewLLMConfig)
|
||||
|
|
@ -380,6 +457,8 @@ async def get_search_space_llm_instance(
|
|||
if disable_streaming:
|
||||
litellm_kwargs["disable_streaming"] = True
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -481,6 +560,8 @@ async def get_vision_llm(
|
|||
if global_cfg.get("litellm_params"):
|
||||
litellm_kwargs.update(global_cfg["litellm_params"])
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
result = await session.execute(
|
||||
|
|
@ -514,6 +595,8 @@ async def get_vision_llm(
|
|||
if vision_cfg.litellm_params:
|
||||
litellm_kwargs.update(vision_cfg.litellm_params)
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
0
surfsense_backend/app/services/mcp_oauth/__init__.py
Normal file
0
surfsense_backend/app/services/mcp_oauth/__init__.py
Normal file
121
surfsense_backend/app/services/mcp_oauth/discovery.py
Normal file
121
surfsense_backend/app/services/mcp_oauth/discovery.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
"""MCP OAuth 2.1 metadata discovery, Dynamic Client Registration, and token exchange."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def discover_oauth_metadata(
|
||||
mcp_url: str,
|
||||
*,
|
||||
origin_override: str | None = None,
|
||||
timeout: float = 15.0,
|
||||
) -> dict:
|
||||
"""Fetch OAuth 2.1 metadata from the MCP server's well-known endpoint.
|
||||
|
||||
Per the MCP spec the discovery document lives at the *origin* of the
|
||||
MCP server URL. ``origin_override`` can be used when the OAuth server
|
||||
lives on a different domain (e.g. Airtable: MCP at ``mcp.airtable.com``,
|
||||
OAuth at ``airtable.com``).
|
||||
"""
|
||||
if origin_override:
|
||||
origin = origin_override.rstrip("/")
|
||||
else:
|
||||
parsed = urlparse(mcp_url)
|
||||
origin = f"{parsed.scheme}://{parsed.netloc}"
|
||||
discovery_url = f"{origin}/.well-known/oauth-authorization-server"
|
||||
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
resp = await client.get(discovery_url, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def register_client(
|
||||
registration_endpoint: str,
|
||||
redirect_uri: str,
|
||||
*,
|
||||
client_name: str = "SurfSense",
|
||||
timeout: float = 15.0,
|
||||
) -> dict:
|
||||
"""Perform Dynamic Client Registration (RFC 7591)."""
|
||||
payload = {
|
||||
"client_name": client_name,
|
||||
"redirect_uris": [redirect_uri],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
resp = await client.post(
|
||||
registration_endpoint, json=payload, timeout=timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
token_endpoint: str,
|
||||
code: str,
|
||||
redirect_uri: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
code_verifier: str,
|
||||
*,
|
||||
timeout: float = 30.0,
|
||||
) -> dict:
|
||||
"""Exchange an authorization code for access + refresh tokens."""
|
||||
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
|
||||
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
resp = await client.post(
|
||||
token_endpoint,
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_verifier": code_verifier,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Authorization": f"Basic {creds}",
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def refresh_access_token(
|
||||
token_endpoint: str,
|
||||
refresh_token: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
*,
|
||||
timeout: float = 30.0,
|
||||
) -> dict:
|
||||
"""Refresh an expired access token."""
|
||||
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
|
||||
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
resp = await client.post(
|
||||
token_endpoint,
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Authorization": f"Basic {creds}",
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
161
surfsense_backend/app/services/mcp_oauth/registry.py
Normal file
161
surfsense_backend/app/services/mcp_oauth/registry.py
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
"""Registry of MCP services with OAuth support.
|
||||
|
||||
Each entry maps a URL-safe service key to its MCP server endpoint and
|
||||
authentication configuration. Services with ``supports_dcr=True`` use
|
||||
RFC 7591 Dynamic Client Registration (the MCP server issues its own
|
||||
credentials); the rest use pre-configured credentials via env vars.
|
||||
|
||||
``allowed_tools`` whitelists which MCP tools to expose to the agent.
|
||||
An empty list means "load every tool the server advertises" (used for
|
||||
user-managed generic MCP servers). Service-specific entries should
|
||||
curate this list to keep the agent's tool count low and selection
|
||||
accuracy high.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from app.db import SearchSourceConnectorType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MCPServiceConfig:
|
||||
name: str
|
||||
mcp_url: str
|
||||
connector_type: str
|
||||
supports_dcr: bool = True
|
||||
oauth_discovery_origin: str | None = None
|
||||
client_id_env: str | None = None
|
||||
client_secret_env: str | None = None
|
||||
scopes: list[str] = field(default_factory=list)
|
||||
scope_param: str = "scope"
|
||||
auth_endpoint_override: str | None = None
|
||||
token_endpoint_override: str | None = None
|
||||
allowed_tools: list[str] = field(default_factory=list)
|
||||
readonly_tools: frozenset[str] = field(default_factory=frozenset)
|
||||
account_metadata_keys: list[str] = field(default_factory=list)
|
||||
"""``connector.config`` keys exposed by ``get_connected_accounts``.
|
||||
|
||||
Only listed keys are returned to the LLM — tokens and secrets are
|
||||
never included. Every service should at least have its
|
||||
``display_name`` populated during OAuth; additional service-specific
|
||||
fields (e.g. Jira ``cloud_id``) are listed here so the LLM can pass
|
||||
them to action tools.
|
||||
"""
|
||||
|
||||
|
||||
MCP_SERVICES: dict[str, MCPServiceConfig] = {
|
||||
"linear": MCPServiceConfig(
|
||||
name="Linear",
|
||||
mcp_url="https://mcp.linear.app/mcp",
|
||||
connector_type="LINEAR_CONNECTOR",
|
||||
allowed_tools=[
|
||||
"list_issues",
|
||||
"get_issue",
|
||||
"save_issue",
|
||||
],
|
||||
readonly_tools=frozenset({"list_issues", "get_issue"}),
|
||||
account_metadata_keys=["organization_name", "organization_url_key"],
|
||||
),
|
||||
"jira": MCPServiceConfig(
|
||||
name="Jira",
|
||||
mcp_url="https://mcp.atlassian.com/v1/mcp",
|
||||
connector_type="JIRA_CONNECTOR",
|
||||
allowed_tools=[
|
||||
"getAccessibleAtlassianResources",
|
||||
"searchJiraIssuesUsingJql",
|
||||
"getVisibleJiraProjects",
|
||||
"getJiraProjectIssueTypesMetadata",
|
||||
"createJiraIssue",
|
||||
"editJiraIssue",
|
||||
],
|
||||
readonly_tools=frozenset({
|
||||
"getAccessibleAtlassianResources",
|
||||
"searchJiraIssuesUsingJql",
|
||||
"getVisibleJiraProjects",
|
||||
"getJiraProjectIssueTypesMetadata",
|
||||
}),
|
||||
account_metadata_keys=["cloud_id", "site_name", "base_url"],
|
||||
),
|
||||
"clickup": MCPServiceConfig(
|
||||
name="ClickUp",
|
||||
mcp_url="https://mcp.clickup.com/mcp",
|
||||
connector_type="CLICKUP_CONNECTOR",
|
||||
allowed_tools=[
|
||||
"clickup_search",
|
||||
"clickup_get_task",
|
||||
],
|
||||
readonly_tools=frozenset({"clickup_search", "clickup_get_task"}),
|
||||
account_metadata_keys=["workspace_id", "workspace_name"],
|
||||
),
|
||||
"slack": MCPServiceConfig(
|
||||
name="Slack",
|
||||
mcp_url="https://mcp.slack.com/mcp",
|
||||
connector_type="SLACK_CONNECTOR",
|
||||
supports_dcr=False,
|
||||
client_id_env="SLACK_CLIENT_ID",
|
||||
client_secret_env="SLACK_CLIENT_SECRET",
|
||||
auth_endpoint_override="https://slack.com/oauth/v2_user/authorize",
|
||||
token_endpoint_override="https://slack.com/api/oauth.v2.user.access",
|
||||
scopes=[
|
||||
"search:read.public", "search:read.private", "search:read.mpim", "search:read.im",
|
||||
"channels:history", "groups:history", "mpim:history", "im:history",
|
||||
],
|
||||
allowed_tools=[
|
||||
"slack_search_channels",
|
||||
"slack_read_channel",
|
||||
"slack_read_thread",
|
||||
],
|
||||
readonly_tools=frozenset({"slack_search_channels", "slack_read_channel", "slack_read_thread"}),
|
||||
# TODO: oauth.v2.user.access only returns team.id, not team.name.
|
||||
# To populate team_name, either add "team:read" scope and call
|
||||
# GET /api/team.info during OAuth callback, or switch to oauth.v2.access.
|
||||
account_metadata_keys=["team_id", "team_name"],
|
||||
),
|
||||
"airtable": MCPServiceConfig(
|
||||
name="Airtable",
|
||||
mcp_url="https://mcp.airtable.com/mcp",
|
||||
connector_type="AIRTABLE_CONNECTOR",
|
||||
supports_dcr=False,
|
||||
oauth_discovery_origin="https://airtable.com",
|
||||
client_id_env="AIRTABLE_CLIENT_ID",
|
||||
client_secret_env="AIRTABLE_CLIENT_SECRET",
|
||||
scopes=["data.records:read", "schema.bases:read"],
|
||||
allowed_tools=[
|
||||
"list_bases",
|
||||
"list_tables_for_base",
|
||||
"list_records_for_table",
|
||||
],
|
||||
readonly_tools=frozenset({"list_bases", "list_tables_for_base", "list_records_for_table"}),
|
||||
account_metadata_keys=["user_id", "user_email"],
|
||||
),
|
||||
}
|
||||
|
||||
_CONNECTOR_TYPE_TO_SERVICE: dict[str, MCPServiceConfig] = {
|
||||
svc.connector_type: svc for svc in MCP_SERVICES.values()
|
||||
}
|
||||
|
||||
LIVE_CONNECTOR_TYPES: frozenset[SearchSourceConnectorType] = frozenset({
|
||||
SearchSourceConnectorType.SLACK_CONNECTOR,
|
||||
SearchSourceConnectorType.TEAMS_CONNECTOR,
|
||||
SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
SearchSourceConnectorType.CLICKUP_CONNECTOR,
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.AIRTABLE_CONNECTOR,
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.DISCORD_CONNECTOR,
|
||||
SearchSourceConnectorType.LUMA_CONNECTOR,
|
||||
})
|
||||
|
||||
|
||||
def get_service(key: str) -> MCPServiceConfig | None:
|
||||
return MCP_SERVICES.get(key)
|
||||
|
||||
|
||||
def get_service_by_connector_type(connector_type: str) -> MCPServiceConfig | None:
|
||||
"""Look up an MCP service config by its ``connector_type`` enum value."""
|
||||
return _CONNECTOR_TYPE_TO_SERVICE.get(connector_type)
|
||||
|
|
@ -4,7 +4,6 @@ from datetime import datetime
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -74,6 +73,8 @@ class NotionKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
@ -244,6 +245,8 @@ class NotionKBSyncService:
|
|||
f"Final content length: {len(full_content)} chars, verified={content_verified}"
|
||||
)
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
logger.debug("Generating summary and embeddings")
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
|
|
|
|||
|
|
@ -227,8 +227,6 @@ class NotionToolMetadataService:
|
|||
async def _check_account_health(self, connector_id: int) -> bool:
|
||||
"""Check if a Notion connector's token is still valid.
|
||||
|
||||
Uses a lightweight ``users.me()`` call to verify the token.
|
||||
|
||||
Returns True if the token is expired/invalid, False if healthy.
|
||||
"""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.db import Document, DocumentType
|
||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -73,6 +72,8 @@ class OneDriveKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue