Mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-05-02 12:22:40 +02:00)

Merge remote-tracking branch 'upstream/dev' into feat/obsidian-plugin

Commit 9b1b9a90c0: 175 changed files with 10592 additions and 2302 deletions
@@ -77,6 +77,8 @@ services:
      - shared_temp:/shared_tmp
    env_file:
      - ../surfsense_backend/.env
    extra_hosts:
      - "host.docker.internal:host-gateway"
    environment:
      - DATABASE_URL=${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
      - CELERY_BROKER_URL=${REDIS_URL:-redis://redis:6379/0}

@@ -118,6 +120,8 @@ services:
      - shared_temp:/shared_tmp
    env_file:
      - ../surfsense_backend/.env
    extra_hosts:
      - "host.docker.internal:host-gateway"
    environment:
      - DATABASE_URL=${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
      - CELERY_BROKER_URL=${REDIS_URL:-redis://redis:6379/0}

@@ -60,6 +60,8 @@ services:
      - shared_temp:/shared_tmp
    env_file:
      - .env
    extra_hosts:
      - "host.docker.internal:host-gateway"
    environment:
      DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
      CELERY_BROKER_URL: ${REDIS_URL:-redis://redis:6379/0}

@@ -100,6 +102,8 @@ services:
      - shared_temp:/shared_tmp
    env_file:
      - .env
    extra_hosts:
      - "host.docker.internal:host-gateway"
    environment:
      DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
      CELERY_BROKER_URL: ${REDIS_URL:-redis://redis:6379/0}

@@ -239,6 +239,9 @@ LLAMA_CLOUD_API_KEY=llx-nnn
# DAYTONA_TARGET=us
# DAYTONA_SNAPSHOT_ID=

# Desktop local filesystem mode (chat file tools run against a local folder root)
# ENABLE_DESKTOP_LOCAL_FILESYSTEM=FALSE

# OPTIONAL: Add these for LangSmith Observability
LANGSMITH_TRACING=true
LANGSMITH_ENDPOINT=https://api.smith.langchain.com
@ -24,7 +24,6 @@ from deepagents.backends import StateBackend
|
|||
from deepagents.graph import BASE_AGENT_PROMPT
|
||||
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
||||
from deepagents.middleware.summarization import create_summarization_middleware
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import TodoListMiddleware
|
||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||
|
|
@ -34,18 +33,24 @@ from langgraph.types import Checkpointer
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemSelection
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.agents.new_chat.middleware import (
|
||||
DedupHITLToolCallsMiddleware,
|
||||
FileIntentMiddleware,
|
||||
KnowledgeBaseSearchMiddleware,
|
||||
MemoryInjectionMiddleware,
|
||||
SurfSenseFilesystemMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.middleware.safe_summarization import (
|
||||
create_safe_summarization_middleware,
|
||||
)
|
||||
from app.agents.new_chat.system_prompt import (
|
||||
build_configurable_system_prompt,
|
||||
build_surfsense_system_prompt,
|
||||
)
|
||||
from app.agents.new_chat.tools.registry import build_tools_async
|
||||
from app.agents.new_chat.tools.registry import build_tools_async, get_connector_gated_tools
|
||||
from app.db import ChatVisibility
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
|
@ -162,6 +167,7 @@ async def create_surfsense_deep_agent(
|
|||
thread_visibility: ChatVisibility | None = None,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
anon_session_id: str | None = None,
|
||||
filesystem_selection: FilesystemSelection | None = None,
|
||||
):
|
||||
"""
|
||||
Create a SurfSense deep agent with configurable tools and prompts.
|
||||
|
|
@ -236,6 +242,8 @@ async def create_surfsense_deep_agent(
|
|||
)
|
||||
"""
|
||||
_t_agent_total = time.perf_counter()
|
||||
filesystem_selection = filesystem_selection or FilesystemSelection()
|
||||
backend_resolver = build_backend_resolver(filesystem_selection)
|
||||
|
||||
# Discover available connectors and document types for this search space
|
||||
available_connectors: list[str] | None = None
|
||||
|
|
@ -285,105 +293,10 @@ async def create_surfsense_deep_agent(
|
|||
"llm": llm,
|
||||
}
|
||||
|
||||
# Disable Notion action tools if no Notion connector is configured
|
||||
modified_disabled_tools = list(disabled_tools) if disabled_tools else []
|
||||
has_notion_connector = (
|
||||
available_connectors is not None and "NOTION_CONNECTOR" in available_connectors
|
||||
modified_disabled_tools.extend(
|
||||
get_connector_gated_tools(available_connectors)
|
||||
)
|
||||
if not has_notion_connector:
|
||||
notion_tools = [
|
||||
"create_notion_page",
|
||||
"update_notion_page",
|
||||
"delete_notion_page",
|
||||
]
|
||||
modified_disabled_tools.extend(notion_tools)
|
||||
|
||||
# Disable Linear action tools if no Linear connector is configured
|
||||
has_linear_connector = (
|
||||
available_connectors is not None and "LINEAR_CONNECTOR" in available_connectors
|
||||
)
|
||||
if not has_linear_connector:
|
||||
linear_tools = [
|
||||
"create_linear_issue",
|
||||
"update_linear_issue",
|
||||
"delete_linear_issue",
|
||||
]
|
||||
modified_disabled_tools.extend(linear_tools)
|
||||
|
||||
# Disable Google Drive action tools if no Google Drive connector is configured
|
||||
has_google_drive_connector = (
|
||||
available_connectors is not None and "GOOGLE_DRIVE_FILE" in available_connectors
|
||||
)
|
||||
if not has_google_drive_connector:
|
||||
google_drive_tools = [
|
||||
"create_google_drive_file",
|
||||
"delete_google_drive_file",
|
||||
]
|
||||
modified_disabled_tools.extend(google_drive_tools)
|
||||
|
||||
has_dropbox_connector = (
|
||||
available_connectors is not None and "DROPBOX_FILE" in available_connectors
|
||||
)
|
||||
if not has_dropbox_connector:
|
||||
modified_disabled_tools.extend(["create_dropbox_file", "delete_dropbox_file"])
|
||||
|
||||
has_onedrive_connector = (
|
||||
available_connectors is not None and "ONEDRIVE_FILE" in available_connectors
|
||||
)
|
||||
if not has_onedrive_connector:
|
||||
modified_disabled_tools.extend(["create_onedrive_file", "delete_onedrive_file"])
|
||||
|
||||
# Disable Google Calendar action tools if no Google Calendar connector is configured
|
||||
has_google_calendar_connector = (
|
||||
available_connectors is not None
|
||||
and "GOOGLE_CALENDAR_CONNECTOR" in available_connectors
|
||||
)
|
||||
if not has_google_calendar_connector:
|
||||
calendar_tools = [
|
||||
"create_calendar_event",
|
||||
"update_calendar_event",
|
||||
"delete_calendar_event",
|
||||
]
|
||||
modified_disabled_tools.extend(calendar_tools)
|
||||
|
||||
# Disable Gmail action tools if no Gmail connector is configured
|
||||
has_gmail_connector = (
|
||||
available_connectors is not None
|
||||
and "GOOGLE_GMAIL_CONNECTOR" in available_connectors
|
||||
)
|
||||
if not has_gmail_connector:
|
||||
gmail_tools = [
|
||||
"create_gmail_draft",
|
||||
"update_gmail_draft",
|
||||
"send_gmail_email",
|
||||
"trash_gmail_email",
|
||||
]
|
||||
modified_disabled_tools.extend(gmail_tools)
|
||||
|
||||
# Disable Jira action tools if no Jira connector is configured
|
||||
has_jira_connector = (
|
||||
available_connectors is not None and "JIRA_CONNECTOR" in available_connectors
|
||||
)
|
||||
if not has_jira_connector:
|
||||
jira_tools = [
|
||||
"create_jira_issue",
|
||||
"update_jira_issue",
|
||||
"delete_jira_issue",
|
||||
]
|
||||
modified_disabled_tools.extend(jira_tools)
|
||||
|
||||
# Disable Confluence action tools if no Confluence connector is configured
|
||||
has_confluence_connector = (
|
||||
available_connectors is not None
|
||||
and "CONFLUENCE_CONNECTOR" in available_connectors
|
||||
)
|
||||
if not has_confluence_connector:
|
||||
confluence_tools = [
|
||||
"create_confluence_page",
|
||||
"update_confluence_page",
|
||||
"delete_confluence_page",
|
||||
]
|
||||
modified_disabled_tools.extend(confluence_tools)
|
||||
|
||||
# Remove direct KB search tool; we now pre-seed a scoped filesystem via middleware.
|
||||
if "search_knowledge_base" not in modified_disabled_tools:
|
||||
|
|
@ -407,6 +320,20 @@ async def create_surfsense_deep_agent(
|
|||
_t0 = time.perf_counter()
|
||||
_enabled_tool_names = {t.name for t in tools}
|
||||
_user_disabled_tool_names = set(disabled_tools) if disabled_tools else set()
|
||||
|
||||
# Collect generic MCP connector info so the system prompt can route queries
|
||||
# to their tools instead of falling back to "not in knowledge base".
|
||||
_mcp_connector_tools: dict[str, list[str]] = {}
|
||||
for t in tools:
|
||||
meta = getattr(t, "metadata", None) or {}
|
||||
if meta.get("mcp_is_generic") and meta.get("mcp_connector_name"):
|
||||
_mcp_connector_tools.setdefault(
|
||||
meta["mcp_connector_name"], [],
|
||||
).append(t.name)
|
||||
|
||||
if _mcp_connector_tools:
|
||||
_perf_log.info("MCP connector tool routing: %s", _mcp_connector_tools)
|
||||
|
||||
if agent_config is not None:
|
||||
system_prompt = build_configurable_system_prompt(
|
||||
custom_system_instructions=agent_config.system_instructions,
|
||||
|
|
@ -415,12 +342,14 @@ async def create_surfsense_deep_agent(
|
|||
thread_visibility=thread_visibility,
|
||||
enabled_tool_names=_enabled_tool_names,
|
||||
disabled_tool_names=_user_disabled_tool_names,
|
||||
mcp_connector_tools=_mcp_connector_tools,
|
||||
)
|
||||
else:
|
||||
system_prompt = build_surfsense_system_prompt(
|
||||
thread_visibility=thread_visibility,
|
||||
enabled_tool_names=_enabled_tool_names,
|
||||
disabled_tool_names=_user_disabled_tool_names,
|
||||
mcp_connector_tools=_mcp_connector_tools,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
|
||||
|
|
@ -437,12 +366,15 @@ async def create_surfsense_deep_agent(
|
|||
gp_middleware = [
|
||||
TodoListMiddleware(),
|
||||
_memory_middleware,
|
||||
FileIntentMiddleware(llm=llm),
|
||||
SurfSenseFilesystemMiddleware(
|
||||
backend=backend_resolver,
|
||||
filesystem_mode=filesystem_selection.mode,
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
thread_id=thread_id,
|
||||
),
|
||||
create_summarization_middleware(llm, StateBackend),
|
||||
create_safe_summarization_middleware(llm, StateBackend),
|
||||
PatchToolCallsMiddleware(),
|
||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
||||
]
|
||||
|
|
@ -458,21 +390,25 @@ async def create_surfsense_deep_agent(
|
|||
deepagent_middleware = [
|
||||
TodoListMiddleware(),
|
||||
_memory_middleware,
|
||||
FileIntentMiddleware(llm=llm),
|
||||
KnowledgeBaseSearchMiddleware(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
filesystem_mode=filesystem_selection.mode,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
anon_session_id=anon_session_id,
|
||||
),
|
||||
SurfSenseFilesystemMiddleware(
|
||||
backend=backend_resolver,
|
||||
filesystem_mode=filesystem_selection.mode,
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
thread_id=thread_id,
|
||||
),
|
||||
SubAgentMiddleware(backend=StateBackend, subagents=[general_purpose_spec]),
|
||||
create_summarization_middleware(llm, StateBackend),
|
||||
create_safe_summarization_middleware(llm, StateBackend),
|
||||
PatchToolCallsMiddleware(),
|
||||
DedupHITLToolCallsMiddleware(agent_tools=tools),
|
||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
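The block above groups generic MCP connector tools by connector name using each tool's metadata, so the system prompt can route queries to those tools instead of falling back to "not in knowledge base". A minimal sketch of that grouping, using a hypothetical stand-in tool type rather than the project's BaseTool:

from dataclasses import dataclass
from typing import Any


@dataclass
class ExampleTool:  # hypothetical stand-in for a LangChain BaseTool
    name: str
    metadata: dict[str, Any] | None = None


def group_mcp_tools(tools: list[ExampleTool]) -> dict[str, list[str]]:
    # Mirrors the metadata keys read in create_surfsense_deep_agent above.
    grouped: dict[str, list[str]] = {}
    for t in tools:
        meta = getattr(t, "metadata", None) or {}
        if meta.get("mcp_is_generic") and meta.get("mcp_connector_name"):
            grouped.setdefault(meta["mcp_connector_name"], []).append(t.name)
    return grouped


tools = [
    ExampleTool("notion_query", {"mcp_is_generic": True, "mcp_connector_name": "Notion MCP"}),
    ExampleTool("search_knowledge_base"),
]
assert group_mcp_tools(tools) == {"Notion MCP": ["notion_query"]}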
@@ -4,7 +4,15 @@ Context schema definitions for SurfSense agents.

This module defines the custom state schema used by the SurfSense deep agent.
"""

from typing import TypedDict
from typing import NotRequired, TypedDict


class FileOperationContractState(TypedDict):
    intent: str
    confidence: float
    suggested_path: str
    timestamp: str
    turn_id: str


class SurfSenseContextSchema(TypedDict):

@@ -24,5 +32,8 @@ class SurfSenseContextSchema(TypedDict):
    """

    search_space_id: int
    file_operation_contract: NotRequired[FileOperationContractState]
    turn_id: NotRequired[str]
    request_id: NotRequired[str]
    # These are runtime-injected and won't be serialized
    # db_session and connector_service are passed when invoking the agent
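The contract stored in file_operation_contract is a plain dict matching FileOperationContractState, and the NotRequired keys may be omitted entirely. A small sketch of how a populated context might look (values are illustrative; any other required fields of the full schema are elided here):

from datetime import UTC, datetime

contract: FileOperationContractState = {
    "intent": "file_write",
    "confidence": 0.9,
    "suggested_path": "/reports/q2/summary.md",
    "timestamp": datetime.now(UTC).isoformat(),
    "turn_id": "turn-123",
}

context: SurfSenseContextSchema = {
    "search_space_id": 1,                  # required
    "file_operation_contract": contract,   # NotRequired: may be absent
    "turn_id": "turn-123",                 # NotRequired
}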
surfsense_backend/app/agents/new_chat/filesystem_backends.py (new file, 42 lines)

@@ -0,0 +1,42 @@
"""Filesystem backend resolver for cloud and desktop-local modes."""

from __future__ import annotations

from collections.abc import Callable
from functools import lru_cache

from deepagents.backends.state import StateBackend
from langgraph.prebuilt.tool_node import ToolRuntime

from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
    MultiRootLocalFolderBackend,
)


@lru_cache(maxsize=64)
def _cached_multi_root_backend(
    mounts: tuple[tuple[str, str], ...],
) -> MultiRootLocalFolderBackend:
    return MultiRootLocalFolderBackend(mounts)


def build_backend_resolver(
    selection: FilesystemSelection,
) -> Callable[[ToolRuntime], StateBackend | MultiRootLocalFolderBackend]:
    """Create deepagents backend resolver for the selected filesystem mode."""

    if selection.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER and selection.local_mounts:

        def _resolve_local(_runtime: ToolRuntime) -> MultiRootLocalFolderBackend:
            mounts = tuple(
                (entry.mount_id, entry.root_path) for entry in selection.local_mounts
            )
            return _cached_multi_root_backend(mounts)

        return _resolve_local

    def _resolve_cloud(runtime: ToolRuntime) -> StateBackend:
        return StateBackend(runtime)

    return _resolve_cloud
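A short usage sketch for the resolver (mount values are made up; ClientPlatform and LocalFilesystemMount come from filesystem_selection.py, shown next):

selection = FilesystemSelection(
    mode=FilesystemMode.DESKTOP_LOCAL_FOLDER,
    client_platform=ClientPlatform.DESKTOP,
    local_mounts=(LocalFilesystemMount(mount_id="vault", root_path="/home/user/vault"),),
)
resolver = build_backend_resolver(selection)
# deepagents calls resolver(tool_runtime) per tool invocation; because
# _cached_multi_root_backend is lru_cache'd on the mounts tuple, requests
# with identical mounts reuse a single MultiRootLocalFolderBackend instance.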
surfsense_backend/app/agents/new_chat/filesystem_selection.py (new file, 41 lines)

@@ -0,0 +1,41 @@
"""Filesystem mode contracts and selection helpers for chat sessions."""

from __future__ import annotations

from dataclasses import dataclass
from enum import StrEnum


class FilesystemMode(StrEnum):
    """Supported filesystem backends for agent tool execution."""

    CLOUD = "cloud"
    DESKTOP_LOCAL_FOLDER = "desktop_local_folder"


class ClientPlatform(StrEnum):
    """Client runtime reported by the caller."""

    WEB = "web"
    DESKTOP = "desktop"


@dataclass(slots=True)
class LocalFilesystemMount:
    """Canonical mount mapping provided by desktop runtime."""

    mount_id: str
    root_path: str


@dataclass(slots=True)
class FilesystemSelection:
    """Resolved filesystem selection for a single chat request."""

    mode: FilesystemMode = FilesystemMode.CLOUD
    client_platform: ClientPlatform = ClientPlatform.WEB
    local_mounts: tuple[LocalFilesystemMount, ...] = ()

    @property
    def is_local_mode(self) -> bool:
        return self.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
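Because both enums are StrEnum, their members compare equal to plain strings, which keeps request payload parsing trivial; a few illustrative checks:

assert FilesystemMode("desktop_local_folder") is FilesystemMode.DESKTOP_LOCAL_FOLDER
assert FilesystemMode.CLOUD == "cloud"

selection = FilesystemSelection()  # defaults: cloud mode, web client, no mounts
assert not selection.is_local_mode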
@@ -6,6 +6,9 @@ from app.agents.new_chat.middleware.dedup_tool_calls import (
from app.agents.new_chat.middleware.filesystem import (
    SurfSenseFilesystemMiddleware,
)
from app.agents.new_chat.middleware.file_intent import (
    FileIntentMiddleware,
)
from app.agents.new_chat.middleware.knowledge_search import (
    KnowledgeBaseSearchMiddleware,
)

@@ -15,6 +18,7 @@ from app.agents.new_chat.middleware.memory_injection import (

__all__ = [
    "DedupHITLToolCallsMiddleware",
    "FileIntentMiddleware",
    "KnowledgeBaseSearchMiddleware",
    "MemoryInjectionMiddleware",
    "SurfSenseFilesystemMiddleware",

surfsense_backend/app/agents/new_chat/middleware/file_intent.py (new file, 352 lines)
@ -0,0 +1,352 @@
|
|||
"""Semantic file-intent routing middleware for new chat turns.
|
||||
|
||||
This middleware classifies the latest human turn into a small intent set:
|
||||
- chat_only
|
||||
- file_write
|
||||
- file_read
|
||||
|
||||
For ``file_write`` turns it injects a strict system contract so the model
|
||||
uses filesystem tools before claiming success, and provides a deterministic
|
||||
fallback path when no filename is specified by the user.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileOperationIntent(StrEnum):
|
||||
CHAT_ONLY = "chat_only"
|
||||
FILE_WRITE = "file_write"
|
||||
FILE_READ = "file_read"
|
||||
|
||||
|
||||
class FileIntentPlan(BaseModel):
|
||||
intent: FileOperationIntent = Field(
|
||||
description="Primary user intent for this turn."
|
||||
)
|
||||
confidence: float = Field(
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
default=0.5,
|
||||
description="Model confidence in the selected intent.",
|
||||
)
|
||||
suggested_filename: str | None = Field(
|
||||
default=None,
|
||||
description="Optional filename (e.g. notes.md) inferred from user request.",
|
||||
)
|
||||
suggested_directory: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional directory path (e.g. /reports/q2 or reports/q2) inferred from "
|
||||
"user request."
|
||||
),
|
||||
)
|
||||
suggested_path: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional full file path (e.g. /reports/q2/summary.md). If present, this "
|
||||
"takes precedence over suggested_directory + suggested_filename."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _extract_text_from_message(message: BaseMessage) -> str:
|
||||
content = getattr(message, "content", "")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict) and item.get("type") == "text":
|
||||
parts.append(str(item.get("text", "")))
|
||||
return "\n".join(part for part in parts if part)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _extract_json_payload(text: str) -> str:
|
||||
stripped = text.strip()
|
||||
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
|
||||
if fenced:
|
||||
return fenced.group(1)
|
||||
start = stripped.find("{")
|
||||
end = stripped.rfind("}")
|
||||
if start != -1 and end != -1 and end > start:
|
||||
return stripped[start : end + 1]
|
||||
return stripped
|
||||
|
||||
|
||||
def _sanitize_filename(value: str) -> str:
|
||||
name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
|
||||
name = re.sub(r"\s+", "-", name)
|
||||
name = name.strip("._-")
|
||||
if not name:
|
||||
name = "note"
|
||||
if len(name) > 80:
|
||||
name = name[:80].rstrip("-_.")
|
||||
return name
|
||||
|
||||
|
||||
def _sanitize_path_segment(value: str) -> str:
|
||||
segment = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
|
||||
segment = re.sub(r"\s+", "_", segment)
|
||||
segment = segment.strip("._-")
|
||||
return segment
|
||||
|
||||
|
||||
def _infer_text_file_extension(user_text: str) -> str:
|
||||
lowered = user_text.lower()
|
||||
if any(token in lowered for token in ("json", ".json")):
|
||||
return ".json"
|
||||
if any(token in lowered for token in ("yaml", "yml", ".yaml", ".yml")):
|
||||
return ".yaml"
|
||||
if any(token in lowered for token in ("csv", ".csv")):
|
||||
return ".csv"
|
||||
if any(token in lowered for token in ("python", ".py")):
|
||||
return ".py"
|
||||
if any(token in lowered for token in ("typescript", ".ts", ".tsx")):
|
||||
return ".ts"
|
||||
if any(token in lowered for token in ("javascript", ".js", ".mjs", ".cjs")):
|
||||
return ".js"
|
||||
if any(token in lowered for token in ("html", ".html")):
|
||||
return ".html"
|
||||
if any(token in lowered for token in ("css", ".css")):
|
||||
return ".css"
|
||||
if any(token in lowered for token in ("sql", ".sql")):
|
||||
return ".sql"
|
||||
if any(token in lowered for token in ("toml", ".toml")):
|
||||
return ".toml"
|
||||
if any(token in lowered for token in ("ini", ".ini")):
|
||||
return ".ini"
|
||||
if any(token in lowered for token in ("xml", ".xml")):
|
||||
return ".xml"
|
||||
if any(token in lowered for token in ("markdown", ".md", "readme")):
|
||||
return ".md"
|
||||
return ".md"
|
||||
|
||||
|
||||
def _normalize_directory(value: str) -> str:
|
||||
raw = value.strip().replace("\\", "/")
|
||||
raw = raw.strip("/")
|
||||
if not raw:
|
||||
return ""
|
||||
parts = [_sanitize_path_segment(part) for part in raw.split("/") if part.strip()]
|
||||
parts = [part for part in parts if part]
|
||||
return "/".join(parts)
|
||||
|
||||
|
||||
def _normalize_file_path(value: str) -> str:
|
||||
raw = value.strip().replace("\\", "/").strip()
|
||||
if not raw:
|
||||
return ""
|
||||
had_trailing_slash = raw.endswith("/")
|
||||
raw = raw.strip("/")
|
||||
if not raw:
|
||||
return ""
|
||||
parts = [_sanitize_path_segment(part) for part in raw.split("/") if part.strip()]
|
||||
parts = [part for part in parts if part]
|
||||
if not parts:
|
||||
return ""
|
||||
if had_trailing_slash:
|
||||
return f"/{'/'.join(parts)}/"
|
||||
return f"/{'/'.join(parts)}"
|
||||
|
||||
|
||||
def _infer_directory_from_user_text(user_text: str) -> str | None:
|
||||
patterns = (
|
||||
r"\b(?:in|inside|under)\s+(?:the\s+)?([a-zA-Z0-9 _\-/]+?)\s+folder\b",
|
||||
r"\b(?:in|inside|under)\s+([a-zA-Z0-9 _\-/]+?)\b",
|
||||
)
|
||||
lowered = user_text.lower()
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, lowered, flags=re.IGNORECASE)
|
||||
if not match:
|
||||
continue
|
||||
candidate = match.group(1).strip()
|
||||
if candidate in {"the", "a", "an"}:
|
||||
continue
|
||||
normalized = _normalize_directory(candidate)
|
||||
if normalized:
|
||||
return normalized
|
||||
return None
|
||||
|
||||
|
||||
def _fallback_path(
|
||||
suggested_filename: str | None,
|
||||
*,
|
||||
suggested_directory: str | None = None,
|
||||
suggested_path: str | None = None,
|
||||
user_text: str,
|
||||
) -> str:
|
||||
default_extension = _infer_text_file_extension(user_text)
|
||||
inferred_dir = _infer_directory_from_user_text(user_text)
|
||||
|
||||
sanitized_filename = ""
|
||||
if suggested_filename:
|
||||
sanitized_filename = _sanitize_filename(suggested_filename)
|
||||
if sanitized_filename.lower().endswith(".txt"):
|
||||
sanitized_filename = f"{sanitized_filename[:-4]}.md"
|
||||
if not sanitized_filename:
|
||||
sanitized_filename = f"notes{default_extension}"
|
||||
elif "." not in sanitized_filename:
|
||||
sanitized_filename = f"{sanitized_filename}{default_extension}"
|
||||
|
||||
normalized_suggested_path = (
|
||||
_normalize_file_path(suggested_path) if suggested_path else ""
|
||||
)
|
||||
if normalized_suggested_path:
|
||||
if normalized_suggested_path.endswith("/"):
|
||||
return f"{normalized_suggested_path.rstrip('/')}/{sanitized_filename}"
|
||||
return normalized_suggested_path
|
||||
|
||||
directory = _normalize_directory(suggested_directory or "")
|
||||
if not directory and inferred_dir:
|
||||
directory = inferred_dir
|
||||
if directory:
|
||||
return f"/{directory}/{sanitized_filename}"
|
||||
|
||||
return f"/{sanitized_filename}"
|
||||
|
||||
|
||||
def _build_classifier_prompt(*, recent_conversation: str, user_text: str) -> str:
|
||||
return (
|
||||
"Classify the latest user request into a filesystem intent for an AI agent.\n"
|
||||
"Return JSON only with this exact schema:\n"
|
||||
'{"intent":"chat_only|file_write|file_read","confidence":0.0,"suggested_filename":"string or null","suggested_directory":"string or null","suggested_path":"string or null"}\n\n'
|
||||
"Rules:\n"
|
||||
"- Use semantic intent, not literal keywords.\n"
|
||||
"- file_write: user asks to create/save/write/update/edit content as a file.\n"
|
||||
"- file_read: user asks to open/read/list/search existing files.\n"
|
||||
"- chat_only: conversational/analysis responses without required file operations.\n"
|
||||
"- For file_write, choose a concise semantic suggested_filename and match the requested format.\n"
|
||||
"- If the user mentions a folder/directory, populate suggested_directory.\n"
|
||||
"- If user specifies an explicit full path, populate suggested_path.\n"
|
||||
"- Use extensions that match user intent (e.g. .md, .json, .yaml, .csv, .py, .ts, .js, .html, .css, .sql).\n"
|
||||
"- Do not use .txt; prefer .md for generic text notes.\n"
|
||||
"- Do not include dates or timestamps in suggested_filename unless explicitly requested.\n"
|
||||
"- Never include markdown or explanation.\n\n"
|
||||
f"Recent conversation:\n{recent_conversation or '(none)'}\n\n"
|
||||
f"Latest user message:\n{user_text}"
|
||||
)
|
||||
|
||||
|
||||
def _build_recent_conversation(messages: list[BaseMessage], *, max_messages: int = 6) -> str:
|
||||
rows: list[str] = []
|
||||
for msg in messages[-max_messages:]:
|
||||
role = "user" if isinstance(msg, HumanMessage) else "assistant"
|
||||
text = re.sub(r"\s+", " ", _extract_text_from_message(msg)).strip()
|
||||
if text:
|
||||
rows.append(f"{role}: {text[:280]}")
|
||||
return "\n".join(rows)
|
||||
|
||||
|
||||
class FileIntentMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Classify file intent and inject a strict file-write contract."""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(self, *, llm: BaseChatModel | None = None) -> None:
|
||||
self.llm = llm
|
||||
|
||||
async def _classify_intent(
|
||||
self, *, messages: list[BaseMessage], user_text: str
|
||||
) -> FileIntentPlan:
|
||||
if self.llm is None:
|
||||
return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)
|
||||
|
||||
prompt = _build_classifier_prompt(
|
||||
recent_conversation=_build_recent_conversation(messages),
|
||||
user_text=user_text,
|
||||
)
|
||||
try:
|
||||
response = await self.llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal"]},
|
||||
)
|
||||
payload = json.loads(_extract_json_payload(_extract_text_from_message(response)))
|
||||
plan = FileIntentPlan.model_validate(payload)
|
||||
return plan
|
||||
except (json.JSONDecodeError, ValidationError, ValueError) as exc:
|
||||
logger.warning("File intent classifier returned invalid output: %s", exc)
|
||||
except Exception as exc: # pragma: no cover - defensive fallback
|
||||
logger.warning("File intent classifier failed: %s", exc)
|
||||
|
||||
return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)
|
||||
|
||||
async def abefore_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
del runtime
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_human: HumanMessage | None = None
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, HumanMessage):
|
||||
last_human = msg
|
||||
break
|
||||
if last_human is None:
|
||||
return None
|
||||
|
||||
user_text = _extract_text_from_message(last_human).strip()
|
||||
if not user_text:
|
||||
return None
|
||||
|
||||
plan = await self._classify_intent(messages=messages, user_text=user_text)
|
||||
suggested_path = _fallback_path(
|
||||
plan.suggested_filename,
|
||||
suggested_directory=plan.suggested_directory,
|
||||
suggested_path=plan.suggested_path,
|
||||
user_text=user_text,
|
||||
)
|
||||
contract = {
|
||||
"intent": plan.intent.value,
|
||||
"confidence": plan.confidence,
|
||||
"suggested_path": suggested_path,
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"turn_id": state.get("turn_id", ""),
|
||||
}
|
||||
|
||||
if plan.intent != FileOperationIntent.FILE_WRITE:
|
||||
return {"file_operation_contract": contract}
|
||||
|
||||
contract_msg = SystemMessage(
|
||||
content=(
|
||||
"<file_operation_contract>\n"
|
||||
"This turn intent is file_write.\n"
|
||||
f"Suggested default path: {suggested_path}\n"
|
||||
"Rules:\n"
|
||||
"- You MUST call write_file or edit_file before claiming success.\n"
|
||||
"- If no path is provided by the user, use the suggested default path.\n"
|
||||
"- Do not claim a file was created/updated unless tool output confirms it.\n"
|
||||
"- If the write/edit fails, clearly report failure instead of success.\n"
|
||||
"- Do not include timestamps or dates in generated file content unless the user explicitly asks for them.\n"
|
||||
"- For open-ended requests (e.g., random note), generate useful concrete content, not placeholders.\n"
|
||||
"</file_operation_contract>"
|
||||
)
|
||||
)
|
||||
|
||||
# Insert just before the latest human turn so it applies to this request.
|
||||
new_messages = list(messages)
|
||||
insert_at = max(len(new_messages) - 1, 0)
|
||||
new_messages.insert(insert_at, contract_msg)
|
||||
return {"messages": new_messages, "file_operation_contract": contract}
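The fallback-path logic above is deterministic, so its behaviour can be illustrated with a couple of checks (inputs are made up; the expected values follow from the helpers defined in this file):

# Illustrative checks for _fallback_path.
assert (
    _fallback_path(None, user_text="write a json config in the configs folder")
    == "/configs/notes.json"
)
assert (
    _fallback_path("meeting notes", user_text="save my meeting notes")
    == "/meeting-notes.md"
)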
@ -26,6 +26,10 @@ from langchain_core.tools import BaseTool, StructuredTool
|
|||
from langgraph.types import Command
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
|
||||
MultiRootLocalFolderBackend,
|
||||
)
|
||||
from app.agents.new_chat.sandbox import (
|
||||
_evict_sandbox_cache,
|
||||
delete_sandbox,
|
||||
|
|
@ -50,6 +54,8 @@ SURFSENSE_FILESYSTEM_SYSTEM_PROMPT = """## Following Conventions
|
|||
|
||||
- Read files before editing — understand existing content before making changes.
|
||||
- Mimic existing style, naming conventions, and patterns.
|
||||
- Never claim a file was created/updated unless filesystem tool output confirms success.
|
||||
- If a file write/edit fails, explicitly report the failure.
|
||||
|
||||
## Filesystem Tools
|
||||
|
||||
|
|
@ -109,13 +115,20 @@ Usage:
|
|||
- Use chunk IDs (`<chunk id='...'>`) as citations in answers.
|
||||
"""
|
||||
|
||||
SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new file to the in-memory filesystem (session-only).
|
||||
SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new text file to the in-memory filesystem (session-only).
|
||||
|
||||
Use this to create scratch/working files during the conversation. Files created
|
||||
here are ephemeral and will not be saved to the user's knowledge base.
|
||||
|
||||
To permanently save a document to the user's knowledge base, use the
|
||||
`save_document` tool instead.
|
||||
|
||||
Supported outputs include common LLM-friendly text formats like markdown, json,
|
||||
yaml, csv, xml, html, css, sql, and code files.
|
||||
|
||||
When creating content from open-ended prompts, produce concrete and useful text,
|
||||
not placeholders. Avoid adding dates/timestamps unless the user explicitly asks
|
||||
for them.
|
||||
"""
|
||||
|
||||
SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files.
|
||||
|
|
@ -182,11 +195,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
def __init__(
|
||||
self,
|
||||
*,
|
||||
backend: Any = None,
|
||||
filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
||||
search_space_id: int | None = None,
|
||||
created_by_id: str | None = None,
|
||||
thread_id: int | str | None = None,
|
||||
tool_token_limit_before_evict: int | None = 20000,
|
||||
) -> None:
|
||||
self._filesystem_mode = filesystem_mode
|
||||
self._search_space_id = search_space_id
|
||||
self._created_by_id = created_by_id
|
||||
self._thread_id = thread_id
|
||||
|
|
@ -204,8 +220,17 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
" extract the data, write it as a clean file (CSV, JSON, etc.),"
|
||||
" and then run your code against it."
|
||||
)
|
||||
if filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
|
||||
system_prompt += (
|
||||
"\n\n## Local Folder Mode"
|
||||
"\n\nThis chat is running in desktop local-folder mode."
|
||||
" Keep all file operations local. Do not use save_document."
|
||||
" Always use mount-prefixed absolute paths like /<folder>/file.ext."
|
||||
" If you are unsure which mounts are available, call ls('/') first."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
backend=backend,
|
||||
system_prompt=system_prompt,
|
||||
custom_tool_descriptions={
|
||||
"ls": SURFSENSE_LIST_FILES_TOOL_DESCRIPTION,
|
||||
|
|
@ -219,7 +244,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
max_execute_timeout=self._MAX_EXECUTE_TIMEOUT,
|
||||
)
|
||||
self.tools = [t for t in self.tools if t.name != "execute"]
|
||||
self.tools.append(self._create_save_document_tool())
|
||||
if self._should_persist_documents():
|
||||
self.tools.append(self._create_save_document_tool())
|
||||
if self._sandbox_available:
|
||||
self.tools.append(self._create_execute_code_tool())
|
||||
|
||||
|
|
@ -637,15 +663,25 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
runtime: ToolRuntime[None, FilesystemState],
|
||||
) -> Command | str:
|
||||
resolved_backend = self._get_backend(runtime)
|
||||
target_path = self._resolve_write_target_path(file_path, runtime)
|
||||
try:
|
||||
validated_path = validate_path(file_path)
|
||||
validated_path = validate_path(target_path)
|
||||
except ValueError as exc:
|
||||
return f"Error: {exc}"
|
||||
res: WriteResult = resolved_backend.write(validated_path, content)
|
||||
if res.error:
|
||||
return res.error
|
||||
verify_error = self._verify_written_content_sync(
|
||||
backend=resolved_backend,
|
||||
path=validated_path,
|
||||
expected_content=content,
|
||||
)
|
||||
if verify_error:
|
||||
return verify_error
|
||||
|
||||
if not self._is_kb_document(validated_path):
|
||||
if self._should_persist_documents() and not self._is_kb_document(
|
||||
validated_path
|
||||
):
|
||||
persist_result = self._run_async_blocking(
|
||||
self._persist_new_document(
|
||||
file_path=validated_path, content=content
|
||||
|
|
@ -682,15 +718,25 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
runtime: ToolRuntime[None, FilesystemState],
|
||||
) -> Command | str:
|
||||
resolved_backend = self._get_backend(runtime)
|
||||
target_path = self._resolve_write_target_path(file_path, runtime)
|
||||
try:
|
||||
validated_path = validate_path(file_path)
|
||||
validated_path = validate_path(target_path)
|
||||
except ValueError as exc:
|
||||
return f"Error: {exc}"
|
||||
res: WriteResult = await resolved_backend.awrite(validated_path, content)
|
||||
if res.error:
|
||||
return res.error
|
||||
verify_error = await self._verify_written_content_async(
|
||||
backend=resolved_backend,
|
||||
path=validated_path,
|
||||
expected_content=content,
|
||||
)
|
||||
if verify_error:
|
||||
return verify_error
|
||||
|
||||
if not self._is_kb_document(validated_path):
|
||||
if self._should_persist_documents() and not self._is_kb_document(
|
||||
validated_path
|
||||
):
|
||||
persist_result = await self._persist_new_document(
|
||||
file_path=validated_path,
|
||||
content=content,
|
||||
|
|
@ -726,6 +772,164 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
"""Return True for paths under /documents/ (KB-sourced, XML-wrapped)."""
|
||||
return path.startswith("/documents/")
|
||||
|
||||
def _should_persist_documents(self) -> bool:
|
||||
"""Only cloud mode persists file content to Document/Chunk tables."""
|
||||
return self._filesystem_mode == FilesystemMode.CLOUD
|
||||
|
||||
def _default_mount_prefix(self, runtime: ToolRuntime[None, FilesystemState]) -> str:
|
||||
backend = self._get_backend(runtime)
|
||||
if isinstance(backend, MultiRootLocalFolderBackend):
|
||||
return f"/{backend.default_mount()}"
|
||||
return ""
|
||||
|
||||
def _normalize_local_mount_path(
|
||||
self, candidate: str, runtime: ToolRuntime[None, FilesystemState]
|
||||
) -> str:
|
||||
backend = self._get_backend(runtime)
|
||||
mount_prefix = self._default_mount_prefix(runtime)
|
||||
normalized_candidate = re.sub(r"/+", "/", candidate.strip().replace("\\", "/"))
|
||||
if not mount_prefix or not isinstance(backend, MultiRootLocalFolderBackend):
|
||||
if normalized_candidate.startswith("/"):
|
||||
return normalized_candidate
|
||||
return f"/{normalized_candidate.lstrip('/')}"
|
||||
|
||||
mount_names = set(backend.list_mounts())
|
||||
if normalized_candidate.startswith("/"):
|
||||
first_segment = normalized_candidate.lstrip("/").split("/", 1)[0]
|
||||
if first_segment in mount_names:
|
||||
return normalized_candidate
|
||||
return f"{mount_prefix}{normalized_candidate}"
|
||||
|
||||
relative = normalized_candidate.lstrip("/")
|
||||
first_segment = relative.split("/", 1)[0]
|
||||
if first_segment in mount_names:
|
||||
return f"/{relative}"
|
||||
return f"{mount_prefix}/{relative}"
|
||||
|
||||
def _get_contract_suggested_path(
|
||||
self, runtime: ToolRuntime[None, FilesystemState]
|
||||
) -> str:
|
||||
contract = runtime.state.get("file_operation_contract") or {}
|
||||
suggested = contract.get("suggested_path")
|
||||
if isinstance(suggested, str) and suggested.strip():
|
||||
cleaned = suggested.strip()
|
||||
if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
|
||||
return self._normalize_local_mount_path(cleaned, runtime)
|
||||
return cleaned
|
||||
if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
|
||||
mount_prefix = self._default_mount_prefix(runtime)
|
||||
if mount_prefix:
|
||||
return f"{mount_prefix}/notes.md"
|
||||
return "/notes.md"
|
||||
|
||||
def _resolve_write_target_path(
|
||||
self,
|
||||
file_path: str,
|
||||
runtime: ToolRuntime[None, FilesystemState],
|
||||
) -> str:
|
||||
candidate = file_path.strip()
|
||||
if not candidate:
|
||||
return self._get_contract_suggested_path(runtime)
|
||||
if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
|
||||
return self._normalize_local_mount_path(candidate, runtime)
|
||||
if not candidate.startswith("/"):
|
||||
return f"/{candidate.lstrip('/')}"
|
||||
return candidate
|
||||
|
||||
@staticmethod
|
||||
def _is_error_text(value: str) -> bool:
|
||||
return value.startswith("Error:")
|
||||
|
||||
@staticmethod
|
||||
def _read_for_verification_sync(backend: Any, path: str) -> str:
|
||||
read_raw = getattr(backend, "read_raw", None)
|
||||
if callable(read_raw):
|
||||
return read_raw(path)
|
||||
return backend.read(path, offset=0, limit=200000)
|
||||
|
||||
@staticmethod
|
||||
async def _read_for_verification_async(backend: Any, path: str) -> str:
|
||||
aread_raw = getattr(backend, "aread_raw", None)
|
||||
if callable(aread_raw):
|
||||
return await aread_raw(path)
|
||||
return await backend.aread(path, offset=0, limit=200000)
|
||||
|
||||
def _verify_written_content_sync(
|
||||
self,
|
||||
*,
|
||||
backend: Any,
|
||||
path: str,
|
||||
expected_content: str,
|
||||
) -> str | None:
|
||||
actual = self._read_for_verification_sync(backend, path)
|
||||
if self._is_error_text(actual):
|
||||
return f"Error: could not verify written file '{path}'."
|
||||
if actual.rstrip() != expected_content.rstrip():
|
||||
return (
|
||||
"Error: file write verification failed; expected content was not fully written "
|
||||
f"to '{path}'."
|
||||
)
|
||||
return None
|
||||
|
||||
async def _verify_written_content_async(
|
||||
self,
|
||||
*,
|
||||
backend: Any,
|
||||
path: str,
|
||||
expected_content: str,
|
||||
) -> str | None:
|
||||
actual = await self._read_for_verification_async(backend, path)
|
||||
if self._is_error_text(actual):
|
||||
return f"Error: could not verify written file '{path}'."
|
||||
if actual.rstrip() != expected_content.rstrip():
|
||||
return (
|
||||
"Error: file write verification failed; expected content was not fully written "
|
||||
f"to '{path}'."
|
||||
)
|
||||
return None
|
||||
|
||||
def _verify_edited_content_sync(
|
||||
self,
|
||||
*,
|
||||
backend: Any,
|
||||
path: str,
|
||||
new_string: str,
|
||||
) -> tuple[str | None, str | None]:
|
||||
updated_content = self._read_for_verification_sync(backend, path)
|
||||
if self._is_error_text(updated_content):
|
||||
return (
|
||||
f"Error: could not verify edited file '{path}'.",
|
||||
None,
|
||||
)
|
||||
if new_string and new_string not in updated_content:
|
||||
return (
|
||||
"Error: edit verification failed; updated content was not found in "
|
||||
f"'{path}'.",
|
||||
None,
|
||||
)
|
||||
return None, updated_content
|
||||
|
||||
async def _verify_edited_content_async(
|
||||
self,
|
||||
*,
|
||||
backend: Any,
|
||||
path: str,
|
||||
new_string: str,
|
||||
) -> tuple[str | None, str | None]:
|
||||
updated_content = await self._read_for_verification_async(backend, path)
|
||||
if self._is_error_text(updated_content):
|
||||
return (
|
||||
f"Error: could not verify edited file '{path}'.",
|
||||
None,
|
||||
)
|
||||
if new_string and new_string not in updated_content:
|
||||
return (
|
||||
"Error: edit verification failed; updated content was not found in "
|
||||
f"'{path}'.",
|
||||
None,
|
||||
)
|
||||
return None, updated_content
|
||||
|
||||
def _create_edit_file_tool(self) -> BaseTool:
|
||||
"""Create edit_file with DB persistence (skipped for KB documents)."""
|
||||
tool_description = (
|
||||
|
|
@ -754,8 +958,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
] = False,
|
||||
) -> Command | str:
|
||||
resolved_backend = self._get_backend(runtime)
|
||||
target_path = self._resolve_write_target_path(file_path, runtime)
|
||||
try:
|
||||
validated_path = validate_path(file_path)
|
||||
validated_path = validate_path(target_path)
|
||||
except ValueError as exc:
|
||||
return f"Error: {exc}"
|
||||
res: EditResult = resolved_backend.edit(
|
||||
|
|
@ -767,13 +972,22 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
if res.error:
|
||||
return res.error
|
||||
|
||||
if not self._is_kb_document(validated_path):
|
||||
read_result = resolved_backend.read(
|
||||
validated_path, offset=0, limit=200000
|
||||
)
|
||||
if read_result.error or read_result.file_data is None:
|
||||
return f"Error: could not reload edited file '{validated_path}' for persistence."
|
||||
updated_content = read_result.file_data["content"]
|
||||
verify_error, updated_content = self._verify_edited_content_sync(
|
||||
backend=resolved_backend,
|
||||
path=validated_path,
|
||||
new_string=new_string,
|
||||
)
|
||||
if verify_error:
|
||||
return verify_error
|
||||
|
||||
if self._should_persist_documents() and not self._is_kb_document(
|
||||
validated_path
|
||||
):
|
||||
if updated_content is None:
|
||||
return (
|
||||
f"Error: could not reload edited file '{validated_path}' for "
|
||||
"persistence."
|
||||
)
|
||||
persist_result = self._run_async_blocking(
|
||||
self._persist_edited_document(
|
||||
file_path=validated_path,
|
||||
|
|
@ -818,8 +1032,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
] = False,
|
||||
) -> Command | str:
|
||||
resolved_backend = self._get_backend(runtime)
|
||||
target_path = self._resolve_write_target_path(file_path, runtime)
|
||||
try:
|
||||
validated_path = validate_path(file_path)
|
||||
validated_path = validate_path(target_path)
|
||||
except ValueError as exc:
|
||||
return f"Error: {exc}"
|
||||
res: EditResult = await resolved_backend.aedit(
|
||||
|
|
@ -831,13 +1046,22 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
if res.error:
|
||||
return res.error
|
||||
|
||||
if not self._is_kb_document(validated_path):
|
||||
read_result = await resolved_backend.aread(
|
||||
validated_path, offset=0, limit=200000
|
||||
)
|
||||
if read_result.error or read_result.file_data is None:
|
||||
return f"Error: could not reload edited file '{validated_path}' for persistence."
|
||||
updated_content = read_result.file_data["content"]
|
||||
verify_error, updated_content = await self._verify_edited_content_async(
|
||||
backend=resolved_backend,
|
||||
path=validated_path,
|
||||
new_string=new_string,
|
||||
)
|
||||
if verify_error:
|
||||
return verify_error
|
||||
|
||||
if self._should_persist_documents() and not self._is_kb_document(
|
||||
validated_path
|
||||
):
|
||||
if updated_content is None:
|
||||
return (
|
||||
f"Error: could not reload edited file '{validated_path}' for "
|
||||
"persistence."
|
||||
)
|
||||
persist_error = await self._persist_edited_document(
|
||||
file_path=validated_path,
|
||||
updated_content=updated_content,
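The write/edit tools above re-read the file after the backend reports success and compare it against the intended content (ignoring trailing whitespace) before the agent may claim success. A condensed standalone sketch of that write-then-verify pattern, not the middleware's exact code:

def write_and_verify(backend, path: str, content: str) -> str:
    # `backend` is assumed to expose write() returning a result with .error
    # and read_raw() returning the raw file text, as in the backends above.
    res = backend.write(path, content)
    if res.error:
        return res.error
    actual = backend.read_raw(path)
    if actual.rstrip() != content.rstrip():
        return f"Error: file write verification failed for '{path}'."
    return f"Wrote and verified '{path}'."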
@@ -28,6 +28,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.db import (
    NATIVE_TO_LEGACY_DOCTYPE,
    Chunk,

@@ -857,6 +858,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
        *,
        llm: BaseChatModel | None = None,
        search_space_id: int,
        filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
        available_connectors: list[str] | None = None,
        available_document_types: list[str] | None = None,
        top_k: int = 10,

@@ -865,6 +867,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
    ) -> None:
        self.llm = llm
        self.search_space_id = search_space_id
        self.filesystem_mode = filesystem_mode
        self.available_connectors = available_connectors
        self.available_document_types = available_document_types
        self.top_k = top_k

@@ -996,6 +999,9 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
        messages = state.get("messages") or []
        if not messages:
            return None
        if self.filesystem_mode != FilesystemMode.CLOUD:
            # Local-folder mode should not seed cloud KB documents into filesystem.
            return None

        last_human = None
        for msg in reversed(messages):
@ -0,0 +1,316 @@
|
|||
"""Desktop local-folder filesystem backend for deepagents tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import fnmatch
|
||||
import os
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
from deepagents.backends.protocol import (
|
||||
EditResult,
|
||||
FileDownloadResponse,
|
||||
FileInfo,
|
||||
FileUploadResponse,
|
||||
GrepMatch,
|
||||
WriteResult,
|
||||
)
|
||||
from deepagents.backends.utils import (
|
||||
create_file_data,
|
||||
format_read_response,
|
||||
perform_string_replacement,
|
||||
)
|
||||
|
||||
_INVALID_PATH = "invalid_path"
|
||||
_FILE_NOT_FOUND = "file_not_found"
|
||||
_IS_DIRECTORY = "is_directory"
|
||||
|
||||
|
||||
class LocalFolderBackend:
|
||||
"""Filesystem backend rooted to a single local folder."""
|
||||
|
||||
def __init__(self, root_path: str) -> None:
|
||||
root = Path(root_path).expanduser().resolve()
|
||||
if not root.exists() or not root.is_dir():
|
||||
msg = f"Local filesystem root does not exist or is not a directory: {root_path}"
|
||||
raise ValueError(msg)
|
||||
self._root = root
|
||||
self._locks: dict[str, threading.Lock] = {}
|
||||
self._locks_mu = threading.Lock()
|
||||
|
||||
def _lock_for(self, path: str) -> threading.Lock:
|
||||
with self._locks_mu:
|
||||
if path not in self._locks:
|
||||
self._locks[path] = threading.Lock()
|
||||
return self._locks[path]
|
||||
|
||||
def _resolve_virtual(self, virtual_path: str, *, allow_root: bool = False) -> Path:
|
||||
if not virtual_path.startswith("/"):
|
||||
msg = f"Invalid path (must be absolute): {virtual_path}"
|
||||
raise ValueError(msg)
|
||||
rel = virtual_path.lstrip("/")
|
||||
candidate = self._root if rel == "" else (self._root / rel)
|
||||
resolved = candidate.resolve()
|
||||
if not allow_root and resolved == self._root:
|
||||
msg = "Path must refer to a file or child directory under root"
|
||||
raise ValueError(msg)
|
||||
if not resolved.is_relative_to(self._root):
|
||||
msg = f"Path escapes local filesystem root: {virtual_path}"
|
||||
raise ValueError(msg)
|
||||
return resolved
|
||||
|
||||
@staticmethod
|
||||
def _to_virtual(path: Path, root: Path) -> str:
|
||||
rel = path.relative_to(root).as_posix()
|
||||
return "/" if rel == "." else f"/{rel}"
|
||||
|
||||
def _write_text_atomic(self, path: Path, content: str) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
temp_path = path.with_suffix(f"{path.suffix}.tmp")
|
||||
temp_path.write_text(content, encoding="utf-8")
|
||||
os.replace(temp_path, path)
|
||||
|
||||
def ls_info(self, path: str) -> list[FileInfo]:
|
||||
try:
|
||||
target = self._resolve_virtual(path, allow_root=True)
|
||||
except ValueError:
|
||||
return []
|
||||
if not target.exists() or not target.is_dir():
|
||||
return []
|
||||
infos: list[FileInfo] = []
|
||||
for child in sorted(target.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower())):
|
||||
infos.append(
|
||||
FileInfo(
|
||||
path=self._to_virtual(child, self._root),
|
||||
is_dir=child.is_dir(),
|
||||
size=child.stat().st_size if child.is_file() else 0,
|
||||
modified_at=str(child.stat().st_mtime),
|
||||
)
|
||||
)
|
||||
return infos
|
||||
|
||||
async def als_info(self, path: str) -> list[FileInfo]:
|
||||
return await asyncio.to_thread(self.ls_info, path)
|
||||
|
||||
def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
|
||||
try:
|
||||
path = self._resolve_virtual(file_path)
|
||||
except ValueError:
|
||||
return f"Error: Invalid path '{file_path}'"
|
||||
if not path.exists():
|
||||
return f"Error: File '{file_path}' not found"
|
||||
if not path.is_file():
|
||||
return f"Error: Path '{file_path}' is not a file"
|
||||
content = path.read_text(encoding="utf-8", errors="replace")
|
||||
file_data = create_file_data(content)
|
||||
return format_read_response(file_data, offset, limit)
|
||||
|
||||
async def aread(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
|
||||
return await asyncio.to_thread(self.read, file_path, offset, limit)
|
||||
|
||||
def read_raw(self, file_path: str) -> str:
|
||||
"""Read raw file text without line-number formatting."""
|
||||
try:
|
||||
path = self._resolve_virtual(file_path)
|
||||
except ValueError:
|
||||
return f"Error: Invalid path '{file_path}'"
|
||||
if not path.exists():
|
||||
return f"Error: File '{file_path}' not found"
|
||||
if not path.is_file():
|
||||
return f"Error: Path '{file_path}' is not a file"
|
||||
return path.read_text(encoding="utf-8", errors="replace")
|
||||
|
||||
async def aread_raw(self, file_path: str) -> str:
|
||||
"""Async variant of read_raw."""
|
||||
return await asyncio.to_thread(self.read_raw, file_path)
|
||||
|
||||
def write(self, file_path: str, content: str) -> WriteResult:
|
||||
try:
|
||||
path = self._resolve_virtual(file_path)
|
||||
except ValueError:
|
||||
return WriteResult(error=f"Error: Invalid path '{file_path}'")
|
||||
lock = self._lock_for(file_path)
|
||||
with lock:
|
||||
if path.exists():
|
||||
return WriteResult(
|
||||
error=(
|
||||
f"Cannot write to {file_path} because it already exists. "
|
||||
"Read and then make an edit, or write to a new path."
|
||||
)
|
||||
)
|
||||
self._write_text_atomic(path, content)
|
||||
return WriteResult(path=file_path, files_update=None)
|
||||
|
||||
async def awrite(self, file_path: str, content: str) -> WriteResult:
|
||||
return await asyncio.to_thread(self.write, file_path, content)
|
||||
|
||||
def edit(
|
||||
self,
|
||||
file_path: str,
|
||||
old_string: str,
|
||||
new_string: str,
|
||||
replace_all: bool = False,
|
||||
) -> EditResult:
|
||||
try:
|
||||
path = self._resolve_virtual(file_path)
|
||||
except ValueError:
|
||||
return EditResult(error=f"Error: Invalid path '{file_path}'")
|
||||
lock = self._lock_for(file_path)
|
||||
with lock:
|
||||
if not path.exists() or not path.is_file():
|
||||
return EditResult(error=f"Error: File '{file_path}' not found")
|
||||
content = path.read_text(encoding="utf-8", errors="replace")
|
||||
result = perform_string_replacement(content, old_string, new_string, replace_all)
|
||||
if isinstance(result, str):
|
||||
return EditResult(error=result)
|
||||
updated_content, occurrences = result
|
||||
self._write_text_atomic(path, updated_content)
|
||||
return EditResult(path=file_path, files_update=None, occurrences=int(occurrences))
|
||||
|
||||
async def aedit(
|
||||
self,
|
||||
file_path: str,
|
||||
old_string: str,
|
||||
new_string: str,
|
||||
replace_all: bool = False,
|
||||
) -> EditResult:
|
||||
return await asyncio.to_thread(
|
||||
self.edit, file_path, old_string, new_string, replace_all
|
||||
)
|
||||
|
||||
def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
|
||||
try:
|
||||
base = self._resolve_virtual(path, allow_root=True)
|
||||
except ValueError:
|
||||
return []
|
||||
|
||||
if pattern.startswith("/"):
|
||||
search_base = self._root
|
||||
normalized_pattern = pattern.lstrip("/")
|
||||
else:
|
||||
search_base = base
|
||||
normalized_pattern = pattern
|
||||
|
||||
matches: list[FileInfo] = []
|
||||
for hit in search_base.glob(normalized_pattern):
|
||||
try:
|
||||
resolved = hit.resolve()
|
||||
if not resolved.is_relative_to(self._root):
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
matches.append(
|
||||
FileInfo(
|
||||
path=self._to_virtual(resolved, self._root),
|
||||
is_dir=resolved.is_dir(),
|
||||
size=resolved.stat().st_size if resolved.is_file() else 0,
|
||||
modified_at=str(resolved.stat().st_mtime),
|
||||
)
|
||||
)
|
||||
return matches
|
||||
|
||||
async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
|
||||
return await asyncio.to_thread(self.glob_info, pattern, path)
|
||||
|
||||
def _iter_candidate_files(self, path: str | None, glob: str | None) -> list[Path]:
|
||||
base_virtual = path or "/"
|
||||
try:
|
||||
base = self._resolve_virtual(base_virtual, allow_root=True)
|
||||
except ValueError:
|
||||
return []
|
||||
if not base.exists():
|
||||
return []
|
||||
|
||||
candidates = [p for p in base.rglob("*") if p.is_file()]
|
||||
if glob:
|
||||
candidates = [
|
||||
p
|
||||
for p in candidates
|
||||
if fnmatch.fnmatch(self._to_virtual(p, self._root), glob)
|
||||
or fnmatch.fnmatch(p.name, glob)
|
||||
]
|
||||
return candidates
|
||||
|
||||
def grep_raw(
|
||||
self, pattern: str, path: str | None = None, glob: str | None = None
|
||||
) -> list[GrepMatch] | str:
|
||||
if not pattern:
|
||||
return "Error: pattern cannot be empty"
|
||||
matches: list[GrepMatch] = []
|
||||
for file_path in self._iter_candidate_files(path, glob):
|
||||
try:
|
||||
lines = file_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
||||
except Exception:
|
||||
continue
|
||||
for idx, line in enumerate(lines, start=1):
|
||||
if pattern in line:
|
||||
matches.append(
|
||||
GrepMatch(
|
||||
path=self._to_virtual(file_path, self._root),
|
||||
line=idx,
|
||||
text=line,
|
||||
)
|
||||
)
|
||||
return matches
|
||||
|
||||
async def agrep_raw(
|
||||
self, pattern: str, path: str | None = None, glob: str | None = None
|
||||
) -> list[GrepMatch] | str:
|
||||
return await asyncio.to_thread(self.grep_raw, pattern, path, glob)
|
||||
|
||||
def upload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
|
||||
responses: list[FileUploadResponse] = []
|
||||
for virtual_path, content in files:
|
||||
try:
|
||||
target = self._resolve_virtual(virtual_path)
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
temp_path = target.with_suffix(f"{target.suffix}.tmp")
|
||||
temp_path.write_bytes(content)
|
||||
os.replace(temp_path, target)
|
||||
responses.append(FileUploadResponse(path=virtual_path, error=None))
|
||||
except FileNotFoundError:
|
||||
responses.append(
|
||||
FileUploadResponse(path=virtual_path, error=_FILE_NOT_FOUND)
|
||||
)
|
||||
except IsADirectoryError:
|
||||
responses.append(FileUploadResponse(path=virtual_path, error=_IS_DIRECTORY))
|
||||
except Exception:
|
||||
responses.append(FileUploadResponse(path=virtual_path, error=_INVALID_PATH))
|
||||
return responses
|
||||
|
||||
async def aupload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
|
||||
return await asyncio.to_thread(self.upload_files, files)
|
||||
|
||||
def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
|
||||
responses: list[FileDownloadResponse] = []
|
||||
for virtual_path in paths:
|
||||
try:
|
||||
target = self._resolve_virtual(virtual_path)
|
||||
if not target.exists():
|
||||
responses.append(
|
||||
FileDownloadResponse(
|
||||
path=virtual_path, content=None, error=_FILE_NOT_FOUND
|
||||
)
|
||||
)
|
||||
continue
|
||||
if target.is_dir():
|
||||
responses.append(
|
||||
FileDownloadResponse(
|
||||
path=virtual_path, content=None, error=_IS_DIRECTORY
|
||||
)
|
||||
)
|
||||
continue
|
||||
responses.append(
|
||||
FileDownloadResponse(
|
||||
path=virtual_path, content=target.read_bytes(), error=None
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
responses.append(
|
||||
FileDownloadResponse(path=virtual_path, content=None, error=_INVALID_PATH)
|
||||
)
|
||||
return responses
|
||||
|
||||
async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]:
|
||||
return await asyncio.to_thread(self.download_files, paths)
|
||||
|
|
@@ -0,0 +1,329 @@
|
|||
"""Aggregate multiple LocalFolderBackend roots behind mount-prefixed virtual paths."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from deepagents.backends.protocol import (
|
||||
EditResult,
|
||||
FileDownloadResponse,
|
||||
FileInfo,
|
||||
FileUploadResponse,
|
||||
GrepMatch,
|
||||
WriteResult,
|
||||
)
|
||||
|
||||
from app.agents.new_chat.middleware.local_folder_backend import LocalFolderBackend
|
||||
|
||||
_INVALID_PATH = "invalid_path"
|
||||
_FILE_NOT_FOUND = "file_not_found"
|
||||
_IS_DIRECTORY = "is_directory"
|
||||
|
||||
|
||||
class MultiRootLocalFolderBackend:
|
||||
"""Route filesystem operations to one of several mounted local roots.
|
||||
|
||||
Virtual paths are namespaced as:
|
||||
- `/<mount>/...`
|
||||
where `<mount>` is derived from each selected root folder name.
|
||||
"""
|
||||
|
||||
def __init__(self, mounts: tuple[tuple[str, str], ...]) -> None:
|
||||
if not mounts:
|
||||
msg = "At least one local mount is required"
|
||||
raise ValueError(msg)
|
||||
self._mount_to_backend: dict[str, LocalFolderBackend] = {}
|
||||
for raw_mount, raw_root in mounts:
|
||||
mount = raw_mount.strip()
|
||||
if not mount:
|
||||
msg = "Mount id cannot be empty"
|
||||
raise ValueError(msg)
|
||||
if mount in self._mount_to_backend:
|
||||
msg = f"Duplicate mount id: {mount}"
|
||||
raise ValueError(msg)
|
||||
normalized_root = str(Path(raw_root).expanduser().resolve())
|
||||
self._mount_to_backend[mount] = LocalFolderBackend(normalized_root)
|
||||
self._mount_order = tuple(self._mount_to_backend.keys())
|
||||
|
||||
def list_mounts(self) -> tuple[str, ...]:
|
||||
return self._mount_order
|
||||
|
||||
def default_mount(self) -> str:
|
||||
return self._mount_order[0]
|
||||
|
||||
def _mount_error(self) -> str:
|
||||
mounts = ", ".join(f"/{mount}" for mount in self._mount_order)
|
||||
return (
|
||||
"Path must start with one of the selected folders: "
|
||||
f"{mounts}. Example: /{self._mount_order[0]}/file.txt"
|
||||
)
|
||||
|
||||
def _split_mount_path(self, virtual_path: str) -> tuple[str, str]:
|
||||
if not virtual_path.startswith("/"):
|
||||
msg = f"Invalid path (must be absolute): {virtual_path}"
|
||||
raise ValueError(msg)
|
||||
rel = virtual_path.lstrip("/")
|
||||
if not rel:
|
||||
raise ValueError(self._mount_error())
|
||||
mount, _, remainder = rel.partition("/")
|
||||
backend = self._mount_to_backend.get(mount)
|
||||
if backend is None:
|
||||
raise ValueError(self._mount_error())
|
||||
local_path = f"/{remainder}" if remainder else "/"
|
||||
return mount, local_path
|
||||
|
||||
@staticmethod
|
||||
def _prefix_mount_path(mount: str, local_path: str) -> str:
|
||||
if local_path == "/":
|
||||
return f"/{mount}"
|
||||
return f"/{mount}{local_path}"
|
||||
|
||||
@staticmethod
|
||||
def _get_value(item: Any, key: str) -> Any:
|
||||
if isinstance(item, dict):
|
||||
return item.get(key)
|
||||
return getattr(item, key, None)
|
||||
|
||||
@classmethod
|
||||
def _get_str(cls, item: Any, key: str) -> str:
|
||||
value = cls._get_value(item, key)
|
||||
return value if isinstance(value, str) else ""
|
||||
|
||||
@classmethod
|
||||
def _get_int(cls, item: Any, key: str) -> int:
|
||||
value = cls._get_value(item, key)
|
||||
return int(value) if isinstance(value, int | float) else 0
|
||||
|
||||
@classmethod
|
||||
def _get_bool(cls, item: Any, key: str) -> bool:
|
||||
value = cls._get_value(item, key)
|
||||
return bool(value)
|
||||
|
||||
def _list_mount_roots(self) -> list[FileInfo]:
|
||||
return [
|
||||
FileInfo(path=f"/{mount}", is_dir=True, size=0, modified_at="0")
|
||||
for mount in self._mount_order
|
||||
]
|
||||
|
||||
def _transform_infos(self, mount: str, infos: list[FileInfo]) -> list[FileInfo]:
|
||||
transformed: list[FileInfo] = []
|
||||
for info in infos:
|
||||
transformed.append(
|
||||
FileInfo(
|
||||
path=self._prefix_mount_path(mount, self._get_str(info, "path")),
|
||||
is_dir=self._get_bool(info, "is_dir"),
|
||||
size=self._get_int(info, "size"),
|
||||
modified_at=self._get_str(info, "modified_at"),
|
||||
)
|
||||
)
|
||||
return transformed
|
||||
|
||||
def ls_info(self, path: str) -> list[FileInfo]:
|
||||
if path == "/":
|
||||
return self._list_mount_roots()
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(path)
|
||||
except ValueError:
|
||||
return []
|
||||
return self._transform_infos(mount, self._mount_to_backend[mount].ls_info(local_path))
|
||||
|
||||
async def als_info(self, path: str) -> list[FileInfo]:
|
||||
return await asyncio.to_thread(self.ls_info, path)
|
||||
|
||||
def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(file_path)
|
||||
except ValueError as exc:
|
||||
return f"Error: {exc}"
|
||||
return self._mount_to_backend[mount].read(local_path, offset, limit)
|
||||
|
||||
async def aread(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
|
||||
return await asyncio.to_thread(self.read, file_path, offset, limit)
|
||||
|
||||
def read_raw(self, file_path: str) -> str:
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(file_path)
|
||||
except ValueError as exc:
|
||||
return f"Error: {exc}"
|
||||
return self._mount_to_backend[mount].read_raw(local_path)
|
||||
|
||||
async def aread_raw(self, file_path: str) -> str:
|
||||
return await asyncio.to_thread(self.read_raw, file_path)
|
||||
|
||||
def write(self, file_path: str, content: str) -> WriteResult:
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(file_path)
|
||||
except ValueError as exc:
|
||||
return WriteResult(error=f"Error: {exc}")
|
||||
result = self._mount_to_backend[mount].write(local_path, content)
|
||||
if result.path:
|
||||
result.path = self._prefix_mount_path(mount, result.path)
|
||||
return result
|
||||
|
||||
async def awrite(self, file_path: str, content: str) -> WriteResult:
|
||||
return await asyncio.to_thread(self.write, file_path, content)
|
||||
|
||||
def edit(
|
||||
self,
|
||||
file_path: str,
|
||||
old_string: str,
|
||||
new_string: str,
|
||||
replace_all: bool = False,
|
||||
) -> EditResult:
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(file_path)
|
||||
except ValueError as exc:
|
||||
return EditResult(error=f"Error: {exc}")
|
||||
result = self._mount_to_backend[mount].edit(
|
||||
local_path, old_string, new_string, replace_all
|
||||
)
|
||||
if result.path:
|
||||
result.path = self._prefix_mount_path(mount, result.path)
|
||||
return result
|
||||
|
||||
async def aedit(
|
||||
self,
|
||||
file_path: str,
|
||||
old_string: str,
|
||||
new_string: str,
|
||||
replace_all: bool = False,
|
||||
) -> EditResult:
|
||||
return await asyncio.to_thread(
|
||||
self.edit, file_path, old_string, new_string, replace_all
|
||||
)
|
||||
|
||||
def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
|
||||
if path == "/":
|
||||
prefixed_results: list[FileInfo] = []
|
||||
if pattern.startswith("/"):
|
||||
mount, _, remainder = pattern.lstrip("/").partition("/")
|
||||
backend = self._mount_to_backend.get(mount)
|
||||
if not backend:
|
||||
return []
|
||||
local_pattern = f"/{remainder}" if remainder else "/"
|
||||
return self._transform_infos(
|
||||
mount, backend.glob_info(local_pattern, path="/")
|
||||
)
|
||||
for mount, backend in self._mount_to_backend.items():
|
||||
prefixed_results.extend(
|
||||
self._transform_infos(mount, backend.glob_info(pattern, path="/"))
|
||||
)
|
||||
return prefixed_results
|
||||
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(path)
|
||||
except ValueError:
|
||||
return []
|
||||
return self._transform_infos(
|
||||
mount, self._mount_to_backend[mount].glob_info(pattern, path=local_path)
|
||||
)
|
||||
|
||||
async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
|
||||
return await asyncio.to_thread(self.glob_info, pattern, path)
|
||||
|
||||
def grep_raw(
|
||||
self, pattern: str, path: str | None = None, glob: str | None = None
|
||||
) -> list[GrepMatch] | str:
|
||||
if not pattern:
|
||||
return "Error: pattern cannot be empty"
|
||||
if path is None or path == "/":
|
||||
all_matches: list[GrepMatch] = []
|
||||
for mount, backend in self._mount_to_backend.items():
|
||||
result = backend.grep_raw(pattern, path="/", glob=glob)
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
all_matches.extend(
|
||||
[
|
||||
GrepMatch(
|
||||
path=self._prefix_mount_path(mount, self._get_str(match, "path")),
|
||||
line=self._get_int(match, "line"),
|
||||
text=self._get_str(match, "text"),
|
||||
)
|
||||
for match in result
|
||||
]
|
||||
)
|
||||
return all_matches
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(path)
|
||||
except ValueError as exc:
|
||||
return f"Error: {exc}"
|
||||
|
||||
result = self._mount_to_backend[mount].grep_raw(
|
||||
pattern, path=local_path, glob=glob
|
||||
)
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
return [
|
||||
GrepMatch(
|
||||
path=self._prefix_mount_path(mount, self._get_str(match, "path")),
|
||||
line=self._get_int(match, "line"),
|
||||
text=self._get_str(match, "text"),
|
||||
)
|
||||
for match in result
|
||||
]
|
||||
|
||||
async def agrep_raw(
|
||||
self, pattern: str, path: str | None = None, glob: str | None = None
|
||||
) -> list[GrepMatch] | str:
|
||||
return await asyncio.to_thread(self.grep_raw, pattern, path, glob)
|
||||
|
||||
def upload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
|
||||
grouped: dict[str, list[tuple[str, bytes]]] = {}
|
||||
invalid: list[FileUploadResponse] = []
|
||||
for virtual_path, content in files:
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(virtual_path)
|
||||
except ValueError:
|
||||
invalid.append(FileUploadResponse(path=virtual_path, error=_INVALID_PATH))
|
||||
continue
|
||||
grouped.setdefault(mount, []).append((local_path, content))
|
||||
|
||||
responses = list(invalid)
|
||||
for mount, mount_files in grouped.items():
|
||||
result = self._mount_to_backend[mount].upload_files(mount_files)
|
||||
responses.extend(
|
||||
[
|
||||
FileUploadResponse(
|
||||
path=self._prefix_mount_path(mount, self._get_str(item, "path")),
|
||||
error=self._get_str(item, "error") or None,
|
||||
)
|
||||
for item in result
|
||||
]
|
||||
)
|
||||
return responses
|
||||
|
||||
async def aupload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
|
||||
return await asyncio.to_thread(self.upload_files, files)
|
||||
|
||||
def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
|
||||
grouped: dict[str, list[str]] = {}
|
||||
invalid: list[FileDownloadResponse] = []
|
||||
for virtual_path in paths:
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(virtual_path)
|
||||
except ValueError:
|
||||
invalid.append(
|
||||
FileDownloadResponse(path=virtual_path, content=None, error=_INVALID_PATH)
|
||||
)
|
||||
continue
|
||||
grouped.setdefault(mount, []).append(local_path)
|
||||
|
||||
responses = list(invalid)
|
||||
for mount, mount_paths in grouped.items():
|
||||
result = self._mount_to_backend[mount].download_files(mount_paths)
|
||||
responses.extend(
|
||||
[
|
||||
FileDownloadResponse(
|
||||
path=self._prefix_mount_path(mount, self._get_str(item, "path")),
|
||||
content=self._get_value(item, "content"),
|
||||
error=self._get_str(item, "error") or None,
|
||||
)
|
||||
for item in result
|
||||
]
|
||||
)
|
||||
return responses
|
||||
|
||||
async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]:
|
||||
return await asyncio.to_thread(self.download_files, paths)
|
||||
|
|
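To make the mount-prefixed routing above concrete, here is a minimal usage sketch. The mount names, folder locations, and file names are illustrative assumptions, not part of the change:

# Sketch only: mounts and paths below are made up for illustration.
backend = MultiRootLocalFolderBackend(
    mounts=(("notes", "~/Documents/notes"), ("demo", "~/projects/demo")),
)

backend.list_mounts()            # ("notes", "demo")
backend.default_mount()          # "notes"
backend.ls_info("/")             # one FileInfo per mount: /notes and /demo
backend.read("/notes/todo.md")   # routed to ~/Documents/notes/todo.md inside that root
backend.read("/todo.md")         # "Error: Path must start with one of the selected folders: ..."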
@@ -0,0 +1,123 @@
|
|||
"""Safe wrapper around deepagents' SummarizationMiddleware.
|
||||
|
||||
Upstream issue
|
||||
--------------
|
||||
`deepagents.middleware.summarization.SummarizationMiddleware._aoffload_to_backend`
|
||||
(and its sync counterpart) call
|
||||
``get_buffer_string(filtered_messages)`` before writing the evicted history
|
||||
to the backend file. In recent ``langchain-core`` versions, ``get_buffer_string``
|
||||
accesses ``m.text`` which iterates ``self.content`` — this raises
|
||||
``TypeError: 'NoneType' object is not iterable`` whenever an ``AIMessage``
|
||||
has ``content=None`` (common when a model returns *only* tool_calls, seen
|
||||
frequently with Azure OpenAI ``gpt-5.x`` responses streamed through
|
||||
LiteLLM).
|
||||
|
||||
The exception aborts the whole agent turn, so the user just sees "Error during
|
||||
chat" with no assistant response.
|
||||
|
||||
Fix
|
||||
---
|
||||
We subclass ``SummarizationMiddleware`` and override
|
||||
``_filter_summary_messages`` — the only call site that feeds messages into
|
||||
``get_buffer_string`` — to return *copies* of messages whose ``content`` is
|
||||
``None`` with ``content=""``. The originals flowing through the rest of the
|
||||
agent state are untouched.
|
||||
|
||||
We also expose a drop-in ``create_safe_summarization_middleware`` factory
|
||||
that mirrors ``deepagents.middleware.summarization.create_summarization_middleware``
|
||||
but instantiates our safe subclass.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from deepagents.middleware.summarization import (
|
||||
SummarizationMiddleware,
|
||||
compute_summarization_defaults,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deepagents.backends.protocol import BACKEND_TYPES
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AnyMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _sanitize_message_content(msg: AnyMessage) -> AnyMessage:
|
||||
"""Return ``msg`` with ``content`` coerced to a non-``None`` value.
|
||||
|
||||
``get_buffer_string`` reads ``m.text`` which iterates ``self.content``;
|
||||
when a provider streams back an ``AIMessage`` with only tool_calls and
|
||||
no text, ``content`` can be ``None`` and the iteration explodes. We
|
||||
replace ``None`` with an empty string so downstream consumers that only
|
||||
care about text see an empty body.
|
||||
|
||||
The original message is left untouched — we return a copy via
|
||||
pydantic's ``model_copy`` when available, otherwise we fall back to
|
||||
re-setting the attribute on a shallow copy.
|
||||
"""
|
||||
|
||||
if getattr(msg, "content", "not-missing") is not None:
|
||||
return msg
|
||||
|
||||
try:
|
||||
return msg.model_copy(update={"content": ""})
|
||||
except AttributeError:
|
||||
import copy
|
||||
|
||||
new_msg = copy.copy(msg)
|
||||
try:
|
||||
new_msg.content = ""
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logger.debug(
|
||||
"Could not sanitize content=None on message of type %s",
|
||||
type(msg).__name__,
|
||||
)
|
||||
return msg
|
||||
return new_msg
|
||||
|
||||
|
||||
class SafeSummarizationMiddleware(SummarizationMiddleware):
|
||||
"""`SummarizationMiddleware` that tolerates messages with ``content=None``.
|
||||
|
||||
Only ``_filter_summary_messages`` is overridden — this is the single
|
||||
helper invoked by both the sync and async offload paths immediately
|
||||
before ``get_buffer_string``. Normalising here means we get coverage
|
||||
for both without having to copy the (long, rapidly-changing) offload
|
||||
implementations from upstream.
|
||||
"""
|
||||
|
||||
def _filter_summary_messages(self, messages: list[AnyMessage]) -> list[AnyMessage]:
|
||||
filtered = super()._filter_summary_messages(messages)
|
||||
return [_sanitize_message_content(m) for m in filtered]
|
||||
|
||||
|
||||
def create_safe_summarization_middleware(
|
||||
model: BaseChatModel,
|
||||
backend: BACKEND_TYPES,
|
||||
) -> SafeSummarizationMiddleware:
|
||||
"""Drop-in replacement for ``create_summarization_middleware``.
|
||||
|
||||
Mirrors the defaults computed by ``deepagents`` but returns our
|
||||
``SafeSummarizationMiddleware`` subclass so the
|
||||
``content=None`` crash in ``get_buffer_string`` is avoided.
|
||||
"""
|
||||
|
||||
defaults = compute_summarization_defaults(model)
|
||||
return SafeSummarizationMiddleware(
|
||||
model=model,
|
||||
backend=backend,
|
||||
trigger=defaults["trigger"],
|
||||
keep=defaults["keep"],
|
||||
trim_tokens_to_summarize=None,
|
||||
truncate_args_settings=defaults["truncate_args_settings"],
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"SafeSummarizationMiddleware",
|
||||
"create_safe_summarization_middleware",
|
||||
]
|
||||
|
|
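For a concrete picture of the crash this wrapper avoids, a small sketch using the helper defined above. The message construction is illustrative; model_copy(update=...) is used only to mimic a provider payload whose text came back as None:

from langchain_core.messages import AIMessage, get_buffer_string

# Mimic a tool-call-only response whose text content is None.
broken = AIMessage(content="").model_copy(update={"content": None})

# get_buffer_string([broken]) would raise:
#   TypeError: 'NoneType' object is not iterable

safe = _sanitize_message_content(broken)   # copy with content="", original untouched
print(get_buffer_string([safe]))           # now serialises cleanly for the offload file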
@@ -38,8 +38,66 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
|
|||
* Formatting, summarization, or analysis of content already present in the conversation
|
||||
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
|
||||
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
||||
* Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see <tool_routing> below
|
||||
</knowledge_base_only_policy>
|
||||
|
||||
<tool_routing>
|
||||
CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable.
|
||||
Their data is NEVER in the knowledge base. You MUST call their tools immediately — never
|
||||
say "I don't see it in the knowledge base" or ask the user if they want you to check.
|
||||
Ignore any knowledge base results for these services.
|
||||
|
||||
When to use which tool:
|
||||
- Linear (issues) → list_issues, get_issue, save_issue (create/update)
|
||||
- ClickUp (tasks) → clickup_search, clickup_get_task
|
||||
- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue
|
||||
- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread
|
||||
- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table
|
||||
- Knowledge base content (Notion, GitHub, files, notes) → automatically searched
|
||||
- Real-time public web data → call web_search
|
||||
- Reading a specific webpage → call scrape_webpage
|
||||
</tool_routing>
|
||||
|
||||
<parameter_resolution>
|
||||
Some service tools require identifiers or context you do not have (account IDs,
|
||||
workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw
|
||||
IDs or technical identifiers — they cannot memorise them.
|
||||
|
||||
Instead, follow this discovery pattern:
|
||||
1. Call a listing/discovery tool to find available options.
|
||||
2. ONE result → use it silently, no question to the user.
|
||||
3. MULTIPLE results → present the options by their display names and let the
|
||||
user choose. Never show raw UUIDs — always use friendly names.
|
||||
|
||||
Discovery tools by level:
|
||||
- Which account/workspace? → get_connected_accounts("<service>")
|
||||
- Which Jira site (cloudId)? → getAccessibleAtlassianResources
|
||||
- Which Jira project? → getVisibleJiraProjects (after resolving cloudId)
|
||||
- Which Jira issue type? → getJiraProjectIssueTypesMetadata (after resolving project)
|
||||
- Which channel? → slack_search_channels
|
||||
- Which base? → list_bases
|
||||
- Which table? → list_tables_for_base (after resolving baseId)
|
||||
- Which task? → clickup_search
|
||||
- Which issue? → list_issues (Linear) or searchJiraIssuesUsingJql (Jira)
|
||||
|
||||
For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to
|
||||
obtain the cloudId, then pass it to other Jira tools. When creating an issue,
|
||||
chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue.
|
||||
If there is only one option at each step, use it silently. If multiple, present
|
||||
friendly names.
|
||||
|
||||
Chain discovery when needed — e.g. for Airtable records: list_bases → pick
|
||||
base → list_tables_for_base → pick table → list_records_for_table.
|
||||
|
||||
MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for
|
||||
the same service, tool names are prefixed to avoid collisions — e.g.
|
||||
linear_25_list_issues and linear_30_list_issues instead of two list_issues.
|
||||
Each prefixed tool's description starts with [Account: <display_name>] so you
|
||||
know which account it targets. Use get_connected_accounts("<service>") to see
|
||||
the full list of accounts with their connector IDs and display names.
|
||||
When only one account is connected, tools have their normal unprefixed names.
|
||||
</parameter_resolution>
|
||||
|
||||
<memory_protocol>
|
||||
IMPORTANT — After understanding each user message, ALWAYS check: does this message
|
||||
reveal durable facts about the user (role, interests, preferences, projects,
|
||||
|
|
@@ -76,8 +134,66 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
|
|||
* Formatting, summarization, or analysis of content already present in the conversation
|
||||
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
|
||||
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
||||
* Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see <tool_routing> below
|
||||
</knowledge_base_only_policy>
|
||||
|
||||
<tool_routing>
|
||||
CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable.
|
||||
Their data is NEVER in the knowledge base. You MUST call their tools immediately — never
|
||||
say "I don't see it in the knowledge base" or ask if they want you to check.
|
||||
Ignore any knowledge base results for these services.
|
||||
|
||||
When to use which tool:
|
||||
- Linear (issues) → list_issues, get_issue, save_issue (create/update)
|
||||
- ClickUp (tasks) → clickup_search, clickup_get_task
|
||||
- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue
|
||||
- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread
|
||||
- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table
|
||||
- Knowledge base content (Notion, GitHub, files, notes) → automatically searched
|
||||
- Real-time public web data → call web_search
|
||||
- Reading a specific webpage → call scrape_webpage
|
||||
</tool_routing>
|
||||
|
||||
<parameter_resolution>
|
||||
Some service tools require identifiers or context you do not have (account IDs,
|
||||
workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw
|
||||
IDs or technical identifiers — they cannot memorise them.
|
||||
|
||||
Instead, follow this discovery pattern:
|
||||
1. Call a listing/discovery tool to find available options.
|
||||
2. ONE result → use it silently, no question to the user.
|
||||
3. MULTIPLE results → present the options by their display names and let the
|
||||
user choose. Never show raw UUIDs — always use friendly names.
|
||||
|
||||
Discovery tools by level:
|
||||
- Which account/workspace? → get_connected_accounts("<service>")
|
||||
- Which Jira site (cloudId)? → getAccessibleAtlassianResources
|
||||
- Which Jira project? → getVisibleJiraProjects (after resolving cloudId)
|
||||
- Which Jira issue type? → getJiraProjectIssueTypesMetadata (after resolving project)
|
||||
- Which channel? → slack_search_channels
|
||||
- Which base? → list_bases
|
||||
- Which table? → list_tables_for_base (after resolving baseId)
|
||||
- Which task? → clickup_search
|
||||
- Which issue? → list_issues (Linear) or searchJiraIssuesUsingJql (Jira)
|
||||
|
||||
For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to
|
||||
obtain the cloudId, then pass it to other Jira tools. When creating an issue,
|
||||
chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue.
|
||||
If there is only one option at each step, use it silently. If multiple, present
|
||||
friendly names.
|
||||
|
||||
Chain discovery when needed — e.g. for Airtable records: list_bases → pick
|
||||
base → list_tables_for_base → pick table → list_records_for_table.
|
||||
|
||||
MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for
|
||||
the same service, tool names are prefixed to avoid collisions — e.g.
|
||||
linear_25_list_issues and linear_30_list_issues instead of two list_issues.
|
||||
Each prefixed tool's description starts with [Account: <display_name>] so you
|
||||
know which account it targets. Use get_connected_accounts("<service>") to see
|
||||
the full list of accounts with their connector IDs and display names.
|
||||
When only one account is connected, tools have their normal unprefixed names.
|
||||
</parameter_resolution>
|
||||
|
||||
<memory_protocol>
|
||||
IMPORTANT — After understanding each user message, ALWAYS check: does this message
|
||||
reveal durable facts about the team (decisions, conventions, architecture, processes,
|
||||
|
|
@@ -450,6 +566,9 @@ _TOOL_INSTRUCTIONS["generate_resume"] = """
|
|||
- WHEN NOT TO CALL: General career advice, resume tips, cover letters, or reviewing
|
||||
a resume without making changes. For cover letters, use generate_report instead.
|
||||
- The tool produces Typst source code that is compiled to a PDF preview automatically.
|
||||
- PAGE POLICY:
|
||||
- Default behavior is ONE PAGE. For new resume creation, set max_pages=1 unless the user explicitly asks for more.
|
||||
- If the user requests a longer resume (e.g., "make it 2 pages"), set max_pages to that value.
|
||||
- Args:
|
||||
- user_info: The user's resume content — work experience, education, skills, contact
|
||||
info, etc. Can be structured or unstructured text.
|
||||
|
|
@@ -465,6 +584,7 @@ _TOOL_INSTRUCTIONS["generate_resume"] = """
|
|||
"keep it to one page"). For revisions, describe what to change.
|
||||
- parent_report_id: Set this when the user wants to MODIFY an existing resume from
|
||||
this conversation. Use the report_id from a previous generate_resume result.
|
||||
- max_pages: Maximum resume length in pages (integer 1-5). Default is 1.
|
||||
- Returns: Dict with status, report_id, title, and content_type.
|
||||
- After calling: Give a brief confirmation. Do NOT paste resume content in chat. Do NOT mention report_id or any internal IDs — the resume card is shown automatically.
|
||||
- VERSIONING: Same rules as generate_report — set parent_report_id for modifications
|
||||
|
|
@@ -473,17 +593,20 @@ _TOOL_INSTRUCTIONS["generate_resume"] = """
|
|||
|
||||
_TOOL_EXAMPLES["generate_resume"] = """
|
||||
- User: "Build me a resume. I'm John Doe, engineer at Acme Corp..."
|
||||
- Call: `generate_resume(user_info="John Doe, engineer at Acme Corp...")`
|
||||
- Call: `generate_resume(user_info="John Doe, engineer at Acme Corp...", max_pages=1)`
|
||||
- WHY: Has creation verb "build" + resume → call the tool.
|
||||
- User: "Create my CV with this info: [experience, education, skills]"
|
||||
- Call: `generate_resume(user_info="[experience, education, skills]")`
|
||||
- Call: `generate_resume(user_info="[experience, education, skills]", max_pages=1)`
|
||||
- User: "Build me a resume" (and there is a resume/CV document in the conversation context)
|
||||
- Extract the FULL content from the document in context, then call:
|
||||
`generate_resume(user_info="Name: John Doe\\nEmail: john@example.com\\n\\nExperience:\\n- Senior Engineer at Acme Corp (2020-2024)\\n Led team of 5...\\n\\nEducation:\\n- BS Computer Science, MIT (2016-2020)\\n\\nSkills: Python, TypeScript, AWS...")`
|
||||
`generate_resume(user_info="Name: John Doe\\nEmail: john@example.com\\n\\nExperience:\\n- Senior Engineer at Acme Corp (2020-2024)\\n Led team of 5...\\n\\nEducation:\\n- BS Computer Science, MIT (2016-2020)\\n\\nSkills: Python, TypeScript, AWS...", max_pages=1)`
|
||||
- WHY: Document content is available in context — extract ALL of it into user_info. Do NOT ignore referenced documents.
|
||||
- User: (after resume generated) "Change my title to Senior Engineer"
|
||||
- Call: `generate_resume(user_info="", user_instructions="Change the job title to Senior Engineer", parent_report_id=<previous_report_id>)`
|
||||
- Call: `generate_resume(user_info="", user_instructions="Change the job title to Senior Engineer", parent_report_id=<previous_report_id>, max_pages=1)`
|
||||
- WHY: Modification verb "change" + refers to existing resume → set parent_report_id.
|
||||
- User: (after resume generated) "Make this 2 pages and expand projects"
|
||||
- Call: `generate_resume(user_info="", user_instructions="Expand projects and keep this to at most 2 pages", parent_report_id=<previous_report_id>, max_pages=2)`
|
||||
- WHY: Explicit page increase request → set max_pages to 2.
|
||||
- User: "How should I structure my resume?"
|
||||
- Do NOT call generate_resume. Answer in chat with advice.
|
||||
- WHY: No creation/modification verb.
|
||||
|
|
@@ -692,11 +815,36 @@ Your goal is to provide helpful, informative answers in a clean, readable format
|
|||
"""
|
||||
|
||||
|
||||
def _build_mcp_routing_block(
|
||||
mcp_connector_tools: dict[str, list[str]] | None,
|
||||
) -> str:
|
||||
"""Build an additional tool routing block for generic MCP connectors.
|
||||
|
||||
When users add MCP servers (e.g. GitLab, GitHub), the LLM needs to know
|
||||
those tools exist and should be called directly — not searched in the
|
||||
knowledge base.
|
||||
"""
|
||||
if not mcp_connector_tools:
|
||||
return ""
|
||||
|
||||
lines = [
|
||||
"\n<mcp_tool_routing>",
|
||||
"You also have direct tools from these user-connected MCP servers.",
|
||||
"Their data is NEVER in the knowledge base — call their tools directly.",
|
||||
"",
|
||||
]
|
||||
for server_name, tool_names in mcp_connector_tools.items():
|
||||
lines.append(f"- {server_name} → {', '.join(tool_names)}")
|
||||
lines.append("</mcp_tool_routing>\n")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def build_surfsense_system_prompt(
|
||||
today: datetime | None = None,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
enabled_tool_names: set[str] | None = None,
|
||||
disabled_tool_names: set[str] | None = None,
|
||||
mcp_connector_tools: dict[str, list[str]] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build the SurfSense system prompt with default settings.
|
||||
|
|
@@ -711,6 +859,9 @@ def build_surfsense_system_prompt(
|
|||
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
||||
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
||||
mcp_connector_tools: Mapping of MCP server display name → list of tool names
|
||||
for generic MCP connectors. Injected into the system prompt so the LLM
|
||||
knows to call these tools directly.
|
||||
|
||||
Returns:
|
||||
Complete system prompt string
|
||||
|
|
@@ -718,6 +869,7 @@ def build_surfsense_system_prompt(
|
|||
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
system_instructions = _get_system_instructions(visibility, today)
|
||||
system_instructions += _build_mcp_routing_block(mcp_connector_tools)
|
||||
tools_instructions = _get_tools_instructions(
|
||||
visibility, enabled_tool_names, disabled_tool_names
|
||||
)
|
||||
|
|
@@ -733,6 +885,7 @@ def build_configurable_system_prompt(
|
|||
thread_visibility: ChatVisibility | None = None,
|
||||
enabled_tool_names: set[str] | None = None,
|
||||
disabled_tool_names: set[str] | None = None,
|
||||
mcp_connector_tools: dict[str, list[str]] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
|
||||
|
|
@@ -754,6 +907,9 @@ def build_configurable_system_prompt(
|
|||
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
||||
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
||||
mcp_connector_tools: Mapping of MCP server display name → list of tool names
|
||||
for generic MCP connectors. Injected into the system prompt so the LLM
|
||||
knows to call these tools directly.
|
||||
|
||||
Returns:
|
||||
Complete system prompt string
|
||||
|
|
@@ -771,6 +927,8 @@ def build_configurable_system_prompt(
|
|||
else:
|
||||
system_instructions = ""
|
||||
|
||||
system_instructions += _build_mcp_routing_block(mcp_connector_tools)
|
||||
|
||||
# Tools instructions: only include enabled tools, note disabled ones
|
||||
tools_instructions = _get_tools_instructions(
|
||||
thread_visibility, enabled_tool_names, disabled_tool_names
|
||||
|
|
|
|||
|
|
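As a reference for the routing block introduced above, this is roughly what _build_mcp_routing_block injects for a hypothetical pair of connected MCP servers (server and tool names are made up; leading and trailing newlines omitted from the commented output):

mcp_connector_tools = {
    "GitLab MCP": ["gitlab_list_projects", "gitlab_get_merge_request"],
    "GitHub MCP": ["github_search_issues"],
}

print(_build_mcp_routing_block(mcp_connector_tools))
# <mcp_tool_routing>
# You also have direct tools from these user-connected MCP servers.
# Their data is NEVER in the knowledge base — call their tools directly.
#
# - GitLab MCP → gitlab_list_projects, gitlab_get_merge_request
# - GitHub MCP → github_search_issues
# </mcp_tool_routing>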
@@ -0,0 +1,109 @@
|
|||
"""Connected-accounts discovery tool.
|
||||
|
||||
Lets the LLM discover which accounts are connected for a given service
|
||||
(e.g. "jira", "linear", "slack") and retrieve the metadata it needs to
|
||||
call action tools — such as Jira's ``cloudId``.
|
||||
|
||||
The tool returns **only** non-sensitive fields explicitly listed in the
|
||||
service's ``account_metadata_keys`` (see ``registry.py``), plus the
|
||||
always-present ``display_name`` and ``connector_id``.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.services.mcp_oauth.registry import MCP_SERVICES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SERVICE_KEY_BY_CONNECTOR_TYPE: dict[str, str] = {
|
||||
cfg.connector_type: key for key, cfg in MCP_SERVICES.items()
|
||||
}
|
||||
|
||||
|
||||
class GetConnectedAccountsInput(BaseModel):
|
||||
service: str = Field(
|
||||
description=(
|
||||
"Service key to look up connected accounts for. "
|
||||
"Valid values: " + ", ".join(sorted(MCP_SERVICES.keys()))
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _extract_display_name(connector: SearchSourceConnector) -> str:
|
||||
"""Best-effort human-readable label for a connector."""
|
||||
cfg = connector.config or {}
|
||||
if cfg.get("display_name"):
|
||||
return cfg["display_name"]
|
||||
if cfg.get("base_url"):
|
||||
return f"{connector.name} ({cfg['base_url']})"
|
||||
if cfg.get("organization_name"):
|
||||
return f"{connector.name} ({cfg['organization_name']})"
|
||||
return connector.name
|
||||
|
||||
|
||||
def create_get_connected_accounts_tool(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> StructuredTool:
|
||||
|
||||
async def _run(service: str) -> list[dict[str, Any]]:
|
||||
svc_cfg = MCP_SERVICES.get(service)
|
||||
if not svc_cfg:
|
||||
return [{"error": f"Unknown service '{service}'. Valid: {', '.join(sorted(MCP_SERVICES.keys()))}"}]
|
||||
|
||||
try:
|
||||
connector_type = SearchSourceConnectorType(svc_cfg.connector_type)
|
||||
except ValueError:
|
||||
return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type == connector_type,
|
||||
)
|
||||
)
|
||||
connectors = result.scalars().all()
|
||||
|
||||
if not connectors:
|
||||
return [{"error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."}]
|
||||
|
||||
is_multi = len(connectors) > 1
|
||||
|
||||
accounts: list[dict[str, Any]] = []
|
||||
for conn in connectors:
|
||||
cfg = conn.config or {}
|
||||
entry: dict[str, Any] = {
|
||||
"connector_id": conn.id,
|
||||
"display_name": _extract_display_name(conn),
|
||||
"service": service,
|
||||
}
|
||||
if is_multi:
|
||||
entry["tool_prefix"] = f"{service}_{conn.id}"
|
||||
for key in svc_cfg.account_metadata_keys:
|
||||
if key in cfg:
|
||||
entry[key] = cfg[key]
|
||||
accounts.append(entry)
|
||||
|
||||
return accounts
|
||||
|
||||
return StructuredTool(
|
||||
name="get_connected_accounts",
|
||||
description=(
|
||||
"Discover which accounts are connected for a service (e.g. jira, linear, slack, clickup, airtable). "
|
||||
"Returns display names and service-specific metadata the action tools need "
|
||||
"(e.g. Jira's cloudId). Call this BEFORE using a service's action tools when "
|
||||
"you need an account identifier or are unsure which account to use."
|
||||
),
|
||||
coroutine=_run,
|
||||
args_schema=GetConnectedAccountsInput,
|
||||
metadata={"hitl": False},
|
||||
)
|
||||
|
|
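For orientation, a sketch of what get_connected_accounts("jira") might return when two Jira connectors exist. All identifiers and names are made up, and cloudId appears only on the assumption that it is listed in the service's account_metadata_keys:

[
    {
        "connector_id": 25,
        "display_name": "Acme Jira (https://acme.atlassian.net)",
        "service": "jira",
        "tool_prefix": "jira_25",   # present only when more than one account is connected
        "cloudId": "11111111-2222-3333-4444-555555555555",
    },
    {
        "connector_id": 30,
        "display_name": "Side Project Jira",
        "service": "jira",
        "tool_prefix": "jira_30",
        "cloudId": "66666666-7777-8888-9999-000000000000",
    },
]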
@@ -0,0 +1,15 @@
|
|||
from app.agents.new_chat.tools.discord.list_channels import (
|
||||
create_list_discord_channels_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.discord.read_messages import (
|
||||
create_read_discord_messages_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.discord.send_message import (
|
||||
create_send_discord_message_tool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_list_discord_channels_tool",
|
||||
"create_read_discord_messages_tool",
|
||||
"create_send_discord_message_tool",
|
||||
]
|
||||
42
surfsense_backend/app/agents/new_chat/tools/discord/_auth.py
Normal file
|
|
@@ -0,0 +1,42 @@
|
|||
"""Shared auth helper for Discord agent tools (REST API, not gateway bot)."""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.config import config
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
DISCORD_API = "https://discord.com/api/v10"
|
||||
|
||||
|
||||
async def get_discord_connector(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> SearchSourceConnector | None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
def get_bot_token(connector: SearchSourceConnector) -> str:
|
||||
"""Extract and decrypt the bot token from connector config."""
|
||||
cfg = dict(connector.config)
|
||||
if cfg.get("_token_encrypted") and config.SECRET_KEY:
|
||||
enc = TokenEncryption(config.SECRET_KEY)
|
||||
if cfg.get("bot_token"):
|
||||
cfg["bot_token"] = enc.decrypt_token(cfg["bot_token"])
|
||||
token = cfg.get("bot_token")
|
||||
if not token:
|
||||
raise ValueError("Discord bot token not found in connector config.")
|
||||
return token
|
||||
|
||||
|
||||
def get_guild_id(connector: SearchSourceConnector) -> str | None:
|
||||
return connector.config.get("guild_id")
|
||||
|
|
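The three helpers above are combined the same way by each Discord tool; a short sketch with placeholder arguments:

# Placeholder values; this mirrors how the list/read/send tools use the helpers.
connector = await get_discord_connector(db_session, search_space_id=1, user_id="user-uuid")
if connector is None:
    ...  # the tool returns {"status": "error", "message": "No Discord connector found."}

token = get_bot_token(connector)     # decrypted when config["_token_encrypted"] is set
guild_id = get_guild_id(connector)   # may be None if the connector has no guild configured

headers = {"Authorization": f"Bot {token}"}   # sent with requests to DISCORD_API endpoints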
@@ -0,0 +1,67 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_list_discord_channels_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def list_discord_channels() -> dict[str, Any]:
|
||||
"""List text channels in the connected Discord server.
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of channels (id, name).
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Discord tool not properly configured."}
|
||||
|
||||
try:
|
||||
connector = await get_discord_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Discord connector found."}
|
||||
|
||||
guild_id = get_guild_id(connector)
|
||||
if not guild_id:
|
||||
return {"status": "error", "message": "No guild ID in Discord connector config."}
|
||||
|
||||
token = get_bot_token(connector)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{DISCORD_API}/guilds/{guild_id}/channels",
|
||||
headers={"Authorization": f"Bot {token}"},
|
||||
timeout=15.0,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
|
||||
if resp.status_code != 200:
|
||||
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
|
||||
|
||||
# Type 0 = text channel
|
||||
channels = [
|
||||
{"id": ch["id"], "name": ch["name"]}
|
||||
for ch in resp.json()
|
||||
if ch.get("type") == 0
|
||||
]
|
||||
return {"status": "success", "guild_id": guild_id, "channels": channels, "total": len(channels)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error listing Discord channels: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to list Discord channels."}
|
||||
|
||||
return list_discord_channels
|
||||
|
|
@@ -0,0 +1,80 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_read_discord_messages_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def read_discord_messages(
|
||||
channel_id: str,
|
||||
limit: int = 25,
|
||||
) -> dict[str, Any]:
|
||||
"""Read recent messages from a Discord text channel.
|
||||
|
||||
Args:
|
||||
channel_id: The Discord channel ID (from list_discord_channels).
|
||||
limit: Number of messages to fetch (default 25, max 50).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of messages including
|
||||
id, author, content, timestamp.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Discord tool not properly configured."}
|
||||
|
||||
limit = min(limit, 50)
|
||||
|
||||
try:
|
||||
connector = await get_discord_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Discord connector found."}
|
||||
|
||||
token = get_bot_token(connector)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{DISCORD_API}/channels/{channel_id}/messages",
|
||||
headers={"Authorization": f"Bot {token}"},
|
||||
params={"limit": limit},
|
||||
timeout=15.0,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
|
||||
if resp.status_code == 403:
|
||||
return {"status": "error", "message": "Bot lacks permission to read this channel."}
|
||||
if resp.status_code != 200:
|
||||
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"id": m["id"],
|
||||
"author": m.get("author", {}).get("username", "Unknown"),
|
||||
"content": m.get("content", ""),
|
||||
"timestamp": m.get("timestamp", ""),
|
||||
}
|
||||
for m in resp.json()
|
||||
]
|
||||
|
||||
return {"status": "success", "channel_id": channel_id, "messages": messages, "total": len(messages)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error reading Discord messages: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to read Discord messages."}
|
||||
|
||||
return read_discord_messages
|
||||
|
|
@@ -0,0 +1,96 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
|
||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_send_discord_message_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def send_discord_message(
|
||||
channel_id: str,
|
||||
content: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Send a message to a Discord text channel.
|
||||
|
||||
Args:
|
||||
channel_id: The Discord channel ID (from list_discord_channels).
|
||||
content: The message text (max 2000 characters).
|
||||
|
||||
Returns:
|
||||
Dictionary with status, message_id on success.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Discord tool not properly configured."}
|
||||
|
||||
if len(content) > 2000:
|
||||
return {"status": "error", "message": "Message exceeds Discord's 2000-character limit."}
|
||||
|
||||
try:
|
||||
connector = await get_discord_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Discord connector found."}
|
||||
|
||||
result = request_approval(
|
||||
action_type="discord_send_message",
|
||||
tool_name="send_discord_message",
|
||||
params={"channel_id": channel_id, "content": content},
|
||||
context={"connector_id": connector.id},
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {"status": "rejected", "message": "User declined. Message was not sent."}
|
||||
|
||||
final_content = result.params.get("content", content)
|
||||
final_channel = result.params.get("channel_id", channel_id)
|
||||
|
||||
token = get_bot_token(connector)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{DISCORD_API}/channels/{final_channel}/messages",
|
||||
headers={
|
||||
"Authorization": f"Bot {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"content": final_content},
|
||||
timeout=15.0,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
|
||||
if resp.status_code == 403:
|
||||
return {"status": "error", "message": "Bot lacks permission to send messages in this channel."}
|
||||
if resp.status_code not in (200, 201):
|
||||
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
|
||||
|
||||
msg_data = resp.json()
|
||||
return {
|
||||
"status": "success",
|
||||
"message_id": msg_data.get("id"),
|
||||
"message": f"Message sent to channel {final_channel}.",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error sending Discord message: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to send Discord message."}
|
||||
|
||||
return send_discord_message
|
||||
|
|
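The approval gate used by send_discord_message follows the contract visible in the call above; a sketch with placeholder values (field names are taken from that call site, not from the full hitl module API):

result = request_approval(
    action_type="discord_send_message",
    tool_name="send_discord_message",
    params={"channel_id": "1234567890", "content": "hello team"},
    context={"connector_id": 7},
)

if result.rejected:
    ...  # user declined in the UI; the tool reports status="rejected" and must not retry
else:
    # the user may have edited the payload before approving
    content = result.params.get("content", "hello team")
    channel_id = result.params.get("channel_id", "1234567890")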
@@ -1,6 +1,12 @@
|
|||
from app.agents.new_chat.tools.gmail.create_draft import (
|
||||
create_create_gmail_draft_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.gmail.read_email import (
|
||||
create_read_gmail_email_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.gmail.search_emails import (
|
||||
create_search_gmail_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.gmail.send_email import (
|
||||
create_send_gmail_email_tool,
|
||||
)
|
||||
|
|
@@ -13,6 +19,8 @@ from app.agents.new_chat.tools.gmail.update_draft import (
|
|||
|
||||
__all__ = [
|
||||
"create_create_gmail_draft_tool",
|
||||
"create_read_gmail_email_tool",
|
||||
"create_search_gmail_tool",
|
||||
"create_send_gmail_email_tool",
|
||||
"create_trash_gmail_email_tool",
|
||||
"create_update_gmail_draft_tool",
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,87 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_GMAIL_TYPES = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
|
||||
def create_read_gmail_email_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def read_gmail_email(message_id: str) -> dict[str, Any]:
|
||||
"""Read the full content of a specific Gmail email by its message ID.
|
||||
|
||||
Use after search_gmail to get the complete body of an email.
|
||||
|
||||
Args:
|
||||
message_id: The Gmail message ID (from search_gmail results).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and the full email content formatted as markdown.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Gmail tool not properly configured."}
|
||||
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
|
||||
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||
|
||||
gmail = GoogleGmailConnector(
|
||||
credentials=creds,
|
||||
session=db_session,
|
||||
user_id=user_id,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
|
||||
detail, error = await gmail.get_message_details(message_id)
|
||||
if error:
|
||||
if "re-authenticate" in error.lower() or "authentication failed" in error.lower():
|
||||
return {"status": "auth_error", "message": error, "connector_type": "gmail"}
|
||||
return {"status": "error", "message": error}
|
||||
|
||||
if not detail:
|
||||
return {"status": "not_found", "message": f"Email with ID '{message_id}' not found."}
|
||||
|
||||
content = gmail.format_message_to_markdown(detail)
|
||||
|
||||
return {"status": "success", "message_id": message_id, "content": content}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error reading Gmail email: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to read email. Please try again."}
|
||||
|
||||
return read_gmail_email
|
||||
|
|
@@ -0,0 +1,165 @@
|
|||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_GMAIL_TYPES = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
_token_encryption_cache: object | None = None
|
||||
|
||||
|
||||
def _get_token_encryption():
|
||||
global _token_encryption_cache
|
||||
if _token_encryption_cache is None:
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise RuntimeError("SECRET_KEY not configured for token decryption.")
|
||||
_token_encryption_cache = TokenEncryption(config.SECRET_KEY)
|
||||
return _token_encryption_cache
|
||||
|
||||
|
||||
def _build_credentials(connector: SearchSourceConnector):
|
||||
"""Build Google OAuth Credentials from a connector's stored config.
|
||||
|
||||
Handles both native OAuth connectors (with encrypted tokens) and
|
||||
Composio-backed connectors. Shared by Gmail and Calendar tools.
|
||||
"""
|
||||
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||
|
||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
raise ValueError("Composio connected account ID not found.")
|
||||
return build_composio_credentials(cca_id)
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
cfg = dict(connector.config)
|
||||
if cfg.get("_token_encrypted"):
|
||||
enc = _get_token_encryption()
|
||||
for key in ("token", "refresh_token", "client_secret"):
|
||||
if cfg.get(key):
|
||||
cfg[key] = enc.decrypt_token(cfg[key])
|
||||
|
||||
exp = (cfg.get("expiry") or "").replace("Z", "")
|
||||
return Credentials(
|
||||
token=cfg.get("token"),
|
||||
refresh_token=cfg.get("refresh_token"),
|
||||
token_uri=cfg.get("token_uri"),
|
||||
client_id=cfg.get("client_id"),
|
||||
client_secret=cfg.get("client_secret"),
|
||||
scopes=cfg.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
|
||||
def create_search_gmail_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def search_gmail(
|
||||
query: str,
|
||||
max_results: int = 10,
|
||||
) -> dict[str, Any]:
|
||||
"""Search emails in the user's Gmail inbox using Gmail search syntax.
|
||||
|
||||
Args:
|
||||
query: Gmail search query, same syntax as the Gmail search bar.
|
||||
Examples: "from:alice@example.com", "subject:meeting",
|
||||
"is:unread", "after:2024/01/01 before:2024/02/01",
|
||||
"has:attachment", "in:sent".
|
||||
max_results: Number of emails to return (default 10, max 20).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of email summaries including
|
||||
message_id, subject, from, date, snippet.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Gmail tool not properly configured."}
|
||||
|
||||
max_results = min(max_results, 20)
|
||||
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||
|
||||
gmail = GoogleGmailConnector(
|
||||
credentials=creds,
|
||||
session=db_session,
|
||||
user_id=user_id,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
|
||||
messages_list, error = await gmail.get_messages_list(
|
||||
max_results=max_results, query=query
|
||||
)
|
||||
if error:
|
||||
if "re-authenticate" in error.lower() or "authentication failed" in error.lower():
|
||||
return {"status": "auth_error", "message": error, "connector_type": "gmail"}
|
||||
return {"status": "error", "message": error}
|
||||
|
||||
if not messages_list:
|
||||
return {"status": "success", "emails": [], "total": 0, "message": "No emails found."}
|
||||
|
||||
emails = []
|
||||
for msg in messages_list:
|
||||
detail, err = await gmail.get_message_details(msg["id"])
|
||||
if err:
|
||||
continue
|
||||
headers = {
|
||||
h["name"].lower(): h["value"]
|
||||
for h in detail.get("payload", {}).get("headers", [])
|
||||
}
|
||||
emails.append({
|
||||
"message_id": detail.get("id"),
|
||||
"thread_id": detail.get("threadId"),
|
||||
"subject": headers.get("subject", "No Subject"),
|
||||
"from": headers.get("from", "Unknown"),
|
||||
"to": headers.get("to", ""),
|
||||
"date": headers.get("date", ""),
|
||||
"snippet": detail.get("snippet", ""),
|
||||
"labels": detail.get("labelIds", []),
|
||||
})
|
||||
|
||||
return {"status": "success", "emails": emails, "total": len(emails)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error searching Gmail: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to search Gmail. Please try again."}
|
||||
|
||||
return search_gmail
|
||||
|
|
@@ -4,6 +4,9 @@ from app.agents.new_chat.tools.google_calendar.create_event import (
from app.agents.new_chat.tools.google_calendar.delete_event import (
    create_delete_calendar_event_tool,
)
from app.agents.new_chat.tools.google_calendar.search_events import (
    create_search_calendar_events_tool,
)
from app.agents.new_chat.tools.google_calendar.update_event import (
    create_update_calendar_event_tool,
)

@@ -11,5 +14,6 @@ from app.agents.new_chat.tools.google_calendar.update_event import (
__all__ = [
    "create_create_calendar_event_tool",
    "create_delete_calendar_event_tool",
    "create_search_calendar_events_tool",
    "create_update_calendar_event_tool",
]

@@ -0,0 +1,114 @@
import logging
from typing import Any

from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
from app.db import SearchSourceConnector, SearchSourceConnectorType

logger = logging.getLogger(__name__)

_CALENDAR_TYPES = [
    SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
    SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]


def create_search_calendar_events_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def search_calendar_events(
        start_date: str,
        end_date: str,
        max_results: int = 25,
    ) -> dict[str, Any]:
        """Search Google Calendar events within a date range.

        Args:
            start_date: Start date in YYYY-MM-DD format (e.g. "2026-04-01").
            end_date: End date in YYYY-MM-DD format (e.g. "2026-04-30").
            max_results: Maximum number of events to return (default 25, max 50).

        Returns:
            Dictionary with status and a list of events including
            event_id, summary, start, end, location, attendees.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Calendar tool not properly configured."}

        max_results = min(max_results, 50)

        try:
            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
                }

            creds = _build_credentials(connector)

            from app.connectors.google_calendar_connector import GoogleCalendarConnector

            cal = GoogleCalendarConnector(
                credentials=creds,
                session=db_session,
                user_id=user_id,
                connector_id=connector.id,
            )

            events_raw, error = await cal.get_all_primary_calendar_events(
                start_date=start_date,
                end_date=end_date,
                max_results=max_results,
            )

            if error:
                if "re-authenticate" in error.lower() or "authentication failed" in error.lower():
                    return {"status": "auth_error", "message": error, "connector_type": "google_calendar"}
                if "no events found" in error.lower():
                    return {"status": "success", "events": [], "total": 0, "message": error}
                return {"status": "error", "message": error}

            events = []
            for ev in events_raw:
                start = ev.get("start", {})
                end = ev.get("end", {})
                attendees_raw = ev.get("attendees", [])
                events.append({
                    "event_id": ev.get("id"),
                    "summary": ev.get("summary", "No Title"),
                    "start": start.get("dateTime") or start.get("date", ""),
                    "end": end.get("dateTime") or end.get("date", ""),
                    "location": ev.get("location", ""),
                    "description": ev.get("description", ""),
                    "html_link": ev.get("htmlLink", ""),
                    "attendees": [
                        a.get("email", "") for a in attendees_raw[:10]
                    ],
                    "status": ev.get("status", ""),
                })

            return {"status": "success", "events": events, "total": len(events)}

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error searching calendar events: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to search calendar events. Please try again."}

    return search_calendar_events

@@ -130,8 +130,8 @@ def request_approval(
    try:
        decision_type, edited_params = _parse_decision(approval)
    except ValueError:
        logger.warning("No approval decision received for %s", tool_name)
        return HITLResult(rejected=False, decision_type="error", params=params)
        logger.warning("No approval decision received for %s — rejecting for safety", tool_name)
        return HITLResult(rejected=True, decision_type="error", params=params)

    logger.info("User decision for %s: %s", tool_name, decision_type)

surfsense_backend/app/agents/new_chat/tools/luma/__init__.py (new file, 15 lines)

@@ -0,0 +1,15 @@
from app.agents.new_chat.tools.luma.create_event import (
    create_create_luma_event_tool,
)
from app.agents.new_chat.tools.luma.list_events import (
    create_list_luma_events_tool,
)
from app.agents.new_chat.tools.luma.read_event import (
    create_read_luma_event_tool,
)

__all__ = [
    "create_create_luma_event_tool",
    "create_list_luma_events_tool",
    "create_read_luma_event_tool",
]

surfsense_backend/app/agents/new_chat/tools/luma/_auth.py (new file, 38 lines)
@@ -0,0 +1,38 @@
"""Shared auth helper for Luma agent tools."""

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.db import SearchSourceConnector, SearchSourceConnectorType

LUMA_API = "https://public-api.luma.com/v1"


async def get_luma_connector(
    db_session: AsyncSession,
    search_space_id: int,
    user_id: str,
) -> SearchSourceConnector | None:
    result = await db_session.execute(
        select(SearchSourceConnector).filter(
            SearchSourceConnector.search_space_id == search_space_id,
            SearchSourceConnector.user_id == user_id,
            SearchSourceConnector.connector_type == SearchSourceConnectorType.LUMA_CONNECTOR,
        )
    )
    return result.scalars().first()


def get_api_key(connector: SearchSourceConnector) -> str:
    """Extract the API key from connector config (handles both key names)."""
    key = connector.config.get("api_key") or connector.config.get("LUMA_API_KEY")
    if not key:
        raise ValueError("Luma API key not found in connector config.")
    return key


def luma_headers(api_key: str) -> dict[str, str]:
    return {
        "Content-Type": "application/json",
        "x-luma-api-key": api_key,
    }

surfsense_backend/app/agents/new_chat/tools/luma/create_event.py (new file, 116 lines)
@@ -0,0 +1,116 @@
import logging
from typing import Any

import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval

from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers

logger = logging.getLogger(__name__)


def create_create_luma_event_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def create_luma_event(
        name: str,
        start_at: str,
        end_at: str,
        description: str | None = None,
        timezone: str = "UTC",
    ) -> dict[str, Any]:
        """Create a new event on Luma.

        Args:
            name: The event title.
            start_at: Start time in ISO 8601 format (e.g. "2026-05-01T18:00:00").
            end_at: End time in ISO 8601 format (e.g. "2026-05-01T20:00:00").
            description: Optional event description (markdown supported).
            timezone: Timezone string (default "UTC", e.g. "America/New_York").

        Returns:
            Dictionary with status, event_id on success.

        IMPORTANT:
        - If status is "rejected", the user explicitly declined. Do NOT retry.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Luma tool not properly configured."}

        try:
            connector = await get_luma_connector(db_session, search_space_id, user_id)
            if not connector:
                return {"status": "error", "message": "No Luma connector found."}

            result = request_approval(
                action_type="luma_create_event",
                tool_name="create_luma_event",
                params={
                    "name": name,
                    "start_at": start_at,
                    "end_at": end_at,
                    "description": description,
                    "timezone": timezone,
                },
                context={"connector_id": connector.id},
            )

            if result.rejected:
                return {"status": "rejected", "message": "User declined. Event was not created."}

            final_name = result.params.get("name", name)
            final_start = result.params.get("start_at", start_at)
            final_end = result.params.get("end_at", end_at)
            final_desc = result.params.get("description", description)
            final_tz = result.params.get("timezone", timezone)

            api_key = get_api_key(connector)
            headers = luma_headers(api_key)

            body: dict[str, Any] = {
                "name": final_name,
                "start_at": final_start,
                "end_at": final_end,
                "timezone": final_tz,
            }
            if final_desc:
                body["description_md"] = final_desc

            async with httpx.AsyncClient(timeout=20.0) as client:
                resp = await client.post(
                    f"{LUMA_API}/event/create",
                    headers=headers,
                    json=body,
                )

            if resp.status_code == 401:
                return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"}
            if resp.status_code == 403:
                return {"status": "error", "message": "Luma Plus subscription required to create events via API."}
            if resp.status_code not in (200, 201):
                return {"status": "error", "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}"}

            data = resp.json()
            event_id = data.get("api_id") or data.get("event", {}).get("api_id")

            return {
                "status": "success",
                "event_id": event_id,
                "message": f"Event '{final_name}' created on Luma.",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error creating Luma event: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to create Luma event."}

    return create_luma_event

surfsense_backend/app/agents/new_chat/tools/luma/list_events.py (new file, 100 lines)
@@ -0,0 +1,100 @@
import logging
from typing import Any

import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers

logger = logging.getLogger(__name__)


def create_list_luma_events_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def list_luma_events(
        max_results: int = 25,
    ) -> dict[str, Any]:
        """List upcoming and recent Luma events.

        Args:
            max_results: Maximum events to return (default 25, max 50).

        Returns:
            Dictionary with status and a list of events including
            event_id, name, start_at, end_at, location, url.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Luma tool not properly configured."}

        max_results = min(max_results, 50)

        try:
            connector = await get_luma_connector(db_session, search_space_id, user_id)
            if not connector:
                return {"status": "error", "message": "No Luma connector found."}

            api_key = get_api_key(connector)
            headers = luma_headers(api_key)

            all_entries: list[dict] = []
            cursor = None

            async with httpx.AsyncClient(timeout=20.0) as client:
                while len(all_entries) < max_results:
                    params: dict[str, Any] = {"limit": min(100, max_results - len(all_entries))}
                    if cursor:
                        params["cursor"] = cursor

                    resp = await client.get(
                        f"{LUMA_API}/calendar/list-events",
                        headers=headers,
                        params=params,
                    )

                    if resp.status_code == 401:
                        return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"}
                    if resp.status_code != 200:
                        return {"status": "error", "message": f"Luma API error: {resp.status_code}"}

                    data = resp.json()
                    entries = data.get("entries", [])
                    if not entries:
                        break
                    all_entries.extend(entries)

                    next_cursor = data.get("next_cursor")
                    if not next_cursor:
                        break
                    cursor = next_cursor

            events = []
            for entry in all_entries[:max_results]:
                ev = entry.get("event", {})
                geo = ev.get("geo_info", {})
                events.append({
                    "event_id": entry.get("api_id"),
                    "name": ev.get("name", "Untitled"),
                    "start_at": ev.get("start_at", ""),
                    "end_at": ev.get("end_at", ""),
                    "timezone": ev.get("timezone", ""),
                    "location": geo.get("name", ""),
                    "url": ev.get("url", ""),
                    "visibility": ev.get("visibility", ""),
                })

            return {"status": "success", "events": events, "total": len(events)}

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error listing Luma events: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to list Luma events."}

    return list_luma_events

@@ -0,0 +1,82 @@
import logging
from typing import Any

import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers

logger = logging.getLogger(__name__)


def create_read_luma_event_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def read_luma_event(event_id: str) -> dict[str, Any]:
        """Read detailed information about a specific Luma event.

        Args:
            event_id: The Luma event API ID (from list_luma_events).

        Returns:
            Dictionary with status and full event details including
            description, attendees count, meeting URL.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Luma tool not properly configured."}

        try:
            connector = await get_luma_connector(db_session, search_space_id, user_id)
            if not connector:
                return {"status": "error", "message": "No Luma connector found."}

            api_key = get_api_key(connector)
            headers = luma_headers(api_key)

            async with httpx.AsyncClient(timeout=15.0) as client:
                resp = await client.get(
                    f"{LUMA_API}/events/{event_id}",
                    headers=headers,
                )

            if resp.status_code == 401:
                return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"}
            if resp.status_code == 404:
                return {"status": "not_found", "message": f"Event '{event_id}' not found."}
            if resp.status_code != 200:
                return {"status": "error", "message": f"Luma API error: {resp.status_code}"}

            data = resp.json()
            ev = data.get("event", data)
            geo = ev.get("geo_info", {})

            event_detail = {
                "event_id": event_id,
                "name": ev.get("name", ""),
                "description": ev.get("description", ""),
                "start_at": ev.get("start_at", ""),
                "end_at": ev.get("end_at", ""),
                "timezone": ev.get("timezone", ""),
                "location_name": geo.get("name", ""),
                "address": geo.get("address", ""),
                "url": ev.get("url", ""),
                "meeting_url": ev.get("meeting_url", ""),
                "visibility": ev.get("visibility", ""),
                "cover_url": ev.get("cover_url", ""),
            }

            return {"status": "success", "event": event_detail}

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error reading Luma event: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to read Luma event."}

    return read_luma_event

@@ -45,6 +45,18 @@ class MCPClient:
    async def connect(self, max_retries: int = MAX_RETRIES):
        """Connect to the MCP server and manage its lifecycle.

        Retries only apply to the **connection** phase (spawning the process,
        initialising the session). Once the session is yielded to the caller,
        any exception raised by the caller propagates normally -- the context
        manager will NOT retry after ``yield``.

        Previous implementation wrapped both connection AND yield inside the
        retry loop. Because ``@asynccontextmanager`` only allows a single
        ``yield``, a failure after yield caused the generator to attempt a
        second yield on retry, triggering
        ``RuntimeError("generator didn't stop after athrow()")`` and orphaning
        the stdio subprocess.

        Args:
            max_retries: Maximum number of connection retry attempts

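For context on the yield/retry constraint the docstring above describes, here is a minimal hedged sketch of the corrected shape (illustrative names only, not the actual MCPClient code): the retry loop wraps only the connection attempt, and the single yield happens exactly once, so a caller failure is never answered with a second yield.

# Sketch only: retry the connection phase, yield the session once (hypothetical helpers).
import asyncio
from contextlib import asynccontextmanager


@asynccontextmanager
async def connect_with_retries(open_conn, max_retries: int = 3):
    last_error = None
    for attempt in range(max_retries):
        try:
            conn = await open_conn()  # only this connection phase is retried
        except Exception as e:
            last_error = e
            await asyncio.sleep(2 ** attempt)
            continue
        try:
            yield conn  # single yield; caller exceptions propagate and are never retried
        finally:
            await conn.close()  # hypothetical cleanup hook
        return
    raise RuntimeError(f"connect failed after {max_retries} attempts: {last_error}")
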
@@ -57,26 +69,22 @@
        """
        last_error = None
        delay = RETRY_DELAY
        connected = False

        for attempt in range(max_retries):
            try:
                # Merge env vars with current environment
                server_env = os.environ.copy()
                server_env.update(self.env)

                # Create server parameters with env
                server_params = StdioServerParameters(
                    command=self.command, args=self.args, env=server_env
                )

                # Spawn server process and create session
                # Note: Cannot combine these context managers because ClientSession
                # needs the read/write streams from stdio_client
                async with stdio_client(server=server_params) as (read, write):  # noqa: SIM117
                    async with ClientSession(read, write) as session:
                        # Initialize the connection
                        await session.initialize()
                        self.session = session
                        connected = True

                        if attempt > 0:
                            logger.info(

@@ -91,10 +99,16 @@
                                self.command,
                                " ".join(self.args),
                            )
                        yield session
                        return  # Success, exit retry loop
                        try:
                            yield session
                        finally:
                            self.session = None
                        return

            except Exception as e:
                self.session = None
                if connected:
                    raise
                last_error = e
                if attempt < max_retries - 1:
                    logger.warning(

@@ -105,7 +119,7 @@
                        delay,
                    )
                    await asyncio.sleep(delay)
                    delay *= RETRY_BACKOFF  # Exponential backoff
                    delay *= RETRY_BACKOFF
                else:
                    logger.error(
                        "Failed to connect to MCP server after %d attempts: %s",

@@ -113,10 +127,7 @@
                        e,
                        exc_info=True,
                    )
            finally:
                self.session = None

        # All retries exhausted
        error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
        if last_error:
            error_msg += f": {last_error}"

@@ -161,12 +172,18 @@
            logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
            raise

    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
    async def call_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        timeout: float = 60.0,
    ) -> Any:
        """Call a tool on the MCP server.

        Args:
            tool_name: Name of the tool to call
            arguments: Arguments to pass to the tool
            timeout: Maximum seconds to wait for the tool to respond

        Returns:
            Tool execution result

@@ -185,10 +202,11 @@
            "Calling MCP tool '%s' with arguments: %s", tool_name, arguments
        )

        # Call tools/call RPC method
        response = await self.session.call_tool(tool_name, arguments=arguments)
        response = await asyncio.wait_for(
            self.session.call_tool(tool_name, arguments=arguments),
            timeout=timeout,
        )

        # Extract content from response
        result = []
        for content in response.content:
            if hasattr(content, "text"):

@@ -202,15 +220,17 @@
            logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
            return result_str

        except asyncio.TimeoutError:
            logger.error(
                "MCP tool '%s' timed out after %.0fs", tool_name, timeout
            )
            return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s"
        except RuntimeError as e:
            # Handle validation errors from MCP server responses
            # Some MCP servers (like server-memory) return extra fields not in their schema
            if "Invalid structured content" in str(e):
                logger.warning(
                    "MCP server returned data not matching its schema, but continuing: %s",
                    e,
                )
                # Try to extract result from error message or return a success message
                return "Operation completed (server returned unexpected format)"
            raise
        except (ValueError, TypeError, AttributeError, KeyError) as e:

File diff suppressed because it is too large

@@ -50,6 +50,11 @@ from .confluence import (
    create_delete_confluence_page_tool,
    create_update_confluence_page_tool,
)
from .discord import (
    create_list_discord_channels_tool,
    create_read_discord_messages_tool,
    create_send_discord_message_tool,
)
from .dropbox import (
    create_create_dropbox_file_tool,
    create_delete_dropbox_file_tool,

@@ -57,6 +62,8 @@ from .dropbox import (
from .generate_image import create_generate_image_tool
from .gmail import (
    create_create_gmail_draft_tool,
    create_read_gmail_email_tool,
    create_search_gmail_tool,
    create_send_gmail_email_tool,
    create_trash_gmail_email_tool,
    create_update_gmail_draft_tool,

@@ -64,21 +71,18 @@ from .gmail import (
from .google_calendar import (
    create_create_calendar_event_tool,
    create_delete_calendar_event_tool,
    create_search_calendar_events_tool,
    create_update_calendar_event_tool,
)
from .google_drive import (
    create_create_google_drive_file_tool,
    create_delete_google_drive_file_tool,
)
from .jira import (
    create_create_jira_issue_tool,
    create_delete_jira_issue_tool,
    create_update_jira_issue_tool,
)
from .linear import (
    create_create_linear_issue_tool,
    create_delete_linear_issue_tool,
    create_update_linear_issue_tool,
from .connected_accounts import create_get_connected_accounts_tool
from .luma import (
    create_create_luma_event_tool,
    create_list_luma_events_tool,
    create_read_luma_event_tool,
)
from .mcp_tool import load_mcp_tools
from .notion import (

@@ -95,6 +99,11 @@ from .report import create_generate_report_tool
from .resume import create_generate_resume_tool
from .scrape_webpage import create_scrape_webpage_tool
from .search_surfsense_docs import create_search_surfsense_docs_tool
from .teams import (
    create_list_teams_channels_tool,
    create_read_teams_messages_tool,
    create_send_teams_message_tool,
)
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
from .video_presentation import create_generate_video_presentation_tool
from .web_search import create_web_search_tool

@@ -114,6 +123,8 @@ class ToolDefinition:
        factory: Callable that creates the tool. Receives a dict of dependencies.
        requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session")
        enabled_by_default: Whether the tool is enabled when no explicit config is provided
        required_connector: Searchable type string (e.g. ``"LINEAR_CONNECTOR"``)
            that must be in ``available_connectors`` for the tool to be enabled.

    """

@@ -123,6 +134,7 @@ class ToolDefinition:
    requires: list[str] = field(default_factory=list)
    enabled_by_default: bool = True
    hidden: bool = False
    required_connector: str | None = None


# =============================================================================

@@ -221,6 +233,21 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
        requires=["db_session"],
    ),
    # =========================================================================
    # SERVICE ACCOUNT DISCOVERY
    # Generic tool for the LLM to discover connected accounts and resolve
    # service-specific identifiers (e.g. Jira cloudId, Slack team, etc.)
    # =========================================================================
    ToolDefinition(
        name="get_connected_accounts",
        description="Discover connected accounts for a service and their metadata",
        factory=lambda deps: create_get_connected_accounts_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    # =========================================================================
    # MEMORY TOOL - single update_memory, private or team by thread_visibility
    # =========================================================================
    ToolDefinition(

@@ -248,40 +275,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
        ],
    ),
    # =========================================================================
    # LINEAR TOOLS - create, update, delete issues
    # Auto-disabled when no Linear connector is configured (see chat_deepagent.py)
    # =========================================================================
    ToolDefinition(
        name="create_linear_issue",
        description="Create a new issue in the user's Linear workspace",
        factory=lambda deps: create_create_linear_issue_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    ToolDefinition(
        name="update_linear_issue",
        description="Update an existing indexed Linear issue",
        factory=lambda deps: create_update_linear_issue_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    ToolDefinition(
        name="delete_linear_issue",
        description="Archive (delete) an existing indexed Linear issue",
        factory=lambda deps: create_delete_linear_issue_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    # =========================================================================
    # NOTION TOOLS - create, update, delete pages
    # Auto-disabled when no Notion connector is configured (see chat_deepagent.py)
    # =========================================================================

@@ -294,6 +287,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="NOTION_CONNECTOR",
    ),
    ToolDefinition(
        name="update_notion_page",

@@ -304,6 +298,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="NOTION_CONNECTOR",
    ),
    ToolDefinition(
        name="delete_notion_page",

@@ -314,6 +309,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="NOTION_CONNECTOR",
    ),
    # =========================================================================
    # GOOGLE DRIVE TOOLS - create files, delete files

@@ -328,6 +324,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_DRIVE_FILE",
    ),
    ToolDefinition(
        name="delete_google_drive_file",

@@ -338,6 +335,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_DRIVE_FILE",
    ),
    # =========================================================================
    # DROPBOX TOOLS - create and trash files

@@ -352,6 +350,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="DROPBOX_FILE",
    ),
    ToolDefinition(
        name="delete_dropbox_file",

@@ -362,6 +361,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="DROPBOX_FILE",
    ),
    # =========================================================================
    # ONEDRIVE TOOLS - create and trash files

@@ -376,6 +376,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="ONEDRIVE_FILE",
    ),
    ToolDefinition(
        name="delete_onedrive_file",

@@ -386,11 +387,23 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="ONEDRIVE_FILE",
    ),
    # =========================================================================
    # GOOGLE CALENDAR TOOLS - create, update, delete events
    # GOOGLE CALENDAR TOOLS - search, create, update, delete events
    # Auto-disabled when no Google Calendar connector is configured
    # =========================================================================
    ToolDefinition(
        name="search_calendar_events",
        description="Search Google Calendar events within a date range",
        factory=lambda deps: create_search_calendar_events_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_CALENDAR_CONNECTOR",
    ),
    ToolDefinition(
        name="create_calendar_event",
        description="Create a new event on Google Calendar",

@@ -400,6 +413,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_CALENDAR_CONNECTOR",
    ),
    ToolDefinition(
        name="update_calendar_event",

@@ -410,6 +424,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_CALENDAR_CONNECTOR",
    ),
    ToolDefinition(
        name="delete_calendar_event",

@@ -420,11 +435,34 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_CALENDAR_CONNECTOR",
    ),
    # =========================================================================
    # GMAIL TOOLS - create drafts, update drafts, send emails, trash emails
    # GMAIL TOOLS - search, read, create drafts, update drafts, send, trash
    # Auto-disabled when no Gmail connector is configured
    # =========================================================================
    ToolDefinition(
        name="search_gmail",
        description="Search emails in Gmail using Gmail search syntax",
        factory=lambda deps: create_search_gmail_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_GMAIL_CONNECTOR",
    ),
    ToolDefinition(
        name="read_gmail_email",
        description="Read the full content of a specific Gmail email",
        factory=lambda deps: create_read_gmail_email_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_GMAIL_CONNECTOR",
    ),
    ToolDefinition(
        name="create_gmail_draft",
        description="Create a draft email in Gmail",

@@ -434,6 +472,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_GMAIL_CONNECTOR",
    ),
    ToolDefinition(
        name="send_gmail_email",

@@ -444,6 +483,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_GMAIL_CONNECTOR",
    ),
    ToolDefinition(
        name="trash_gmail_email",

@@ -454,6 +494,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_GMAIL_CONNECTOR",
    ),
    ToolDefinition(
        name="update_gmail_draft",

@@ -464,40 +505,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    # =========================================================================
    # JIRA TOOLS - create, update, delete issues
    # Auto-disabled when no Jira connector is configured (see chat_deepagent.py)
    # =========================================================================
    ToolDefinition(
        name="create_jira_issue",
        description="Create a new issue in the user's Jira project",
        factory=lambda deps: create_create_jira_issue_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    ToolDefinition(
        name="update_jira_issue",
        description="Update an existing indexed Jira issue",
        factory=lambda deps: create_update_jira_issue_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    ToolDefinition(
        name="delete_jira_issue",
        description="Delete an existing indexed Jira issue",
        factory=lambda deps: create_delete_jira_issue_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_GMAIL_CONNECTOR",
    ),
    # =========================================================================
    # CONFLUENCE TOOLS - create, update, delete pages

@@ -512,6 +520,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="CONFLUENCE_CONNECTOR",
    ),
    ToolDefinition(
        name="update_confluence_page",

@@ -522,6 +531,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="CONFLUENCE_CONNECTOR",
    ),
    ToolDefinition(
        name="delete_confluence_page",

@@ -532,6 +542,118 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="CONFLUENCE_CONNECTOR",
    ),
    # =========================================================================
    # DISCORD TOOLS - list channels, read messages, send messages
    # Auto-disabled when no Discord connector is configured
    # =========================================================================
    ToolDefinition(
        name="list_discord_channels",
        description="List text channels in the connected Discord server",
        factory=lambda deps: create_list_discord_channels_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="DISCORD_CONNECTOR",
    ),
    ToolDefinition(
        name="read_discord_messages",
        description="Read recent messages from a Discord text channel",
        factory=lambda deps: create_read_discord_messages_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="DISCORD_CONNECTOR",
    ),
    ToolDefinition(
        name="send_discord_message",
        description="Send a message to a Discord text channel",
        factory=lambda deps: create_send_discord_message_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="DISCORD_CONNECTOR",
    ),
    # =========================================================================
    # TEAMS TOOLS - list channels, read messages, send messages
    # Auto-disabled when no Teams connector is configured
    # =========================================================================
    ToolDefinition(
        name="list_teams_channels",
        description="List Microsoft Teams and their channels",
        factory=lambda deps: create_list_teams_channels_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="TEAMS_CONNECTOR",
    ),
    ToolDefinition(
        name="read_teams_messages",
        description="Read recent messages from a Microsoft Teams channel",
        factory=lambda deps: create_read_teams_messages_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="TEAMS_CONNECTOR",
    ),
    ToolDefinition(
        name="send_teams_message",
        description="Send a message to a Microsoft Teams channel",
        factory=lambda deps: create_send_teams_message_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="TEAMS_CONNECTOR",
    ),
    # =========================================================================
    # LUMA TOOLS - list events, read event details, create events
    # Auto-disabled when no Luma connector is configured
    # =========================================================================
    ToolDefinition(
        name="list_luma_events",
        description="List upcoming and recent Luma events",
        factory=lambda deps: create_list_luma_events_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="LUMA_CONNECTOR",
    ),
    ToolDefinition(
        name="read_luma_event",
        description="Read detailed information about a specific Luma event",
        factory=lambda deps: create_read_luma_event_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="LUMA_CONNECTOR",
    ),
    ToolDefinition(
        name="create_luma_event",
        description="Create a new event on Luma",
        factory=lambda deps: create_create_luma_event_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="LUMA_CONNECTOR",
    ),
]

@@ -549,6 +671,22 @@ def get_tool_by_name(name: str) -> ToolDefinition | None:
    return None


def get_connector_gated_tools(
    available_connectors: list[str] | None,
) -> list[str]:
    """Return tool names to disable"""
    if available_connectors is None:
        available = set()
    else:
        available = set(available_connectors)

    disabled: list[str] = []
    for tool_def in BUILTIN_TOOLS:
        if tool_def.required_connector and tool_def.required_connector not in available:
            disabled.append(tool_def.name)
    return disabled


def get_all_tool_names() -> list[str]:
    """Get names of all registered tools."""
    return [tool_def.name for tool_def in BUILTIN_TOOLS]

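A hedged usage sketch of the gating helper above (the connector type strings and surrounding setup are illustrative, not part of this change): callers pass the connector types available in the workspace and get back the tool names that should be switched off before the agent is assembled.

# Illustrative only; assumes get_connector_gated_tools from the registry above is in scope.
available = ["NOTION_CONNECTOR", "GOOGLE_GMAIL_CONNECTOR"]
disabled = get_connector_gated_tools(available)
# Every ToolDefinition whose required_connector is missing from `available`
# (for example the Luma, Teams, or Discord tools) shows up in `disabled`.
print(disabled)
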
@@ -690,15 +828,15 @@ async def build_tools_async(
            )
            tools.extend(mcp_tools)
            logging.info(
                f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}",
                "Registered %d MCP tools: %s",
                len(mcp_tools), [t.name for t in mcp_tools],
            )
        except Exception as e:
            # Log error but don't fail - just continue without MCP tools
            logging.exception(f"Failed to load MCP tools: {e!s}")
            logging.exception("Failed to load MCP tools: %s", e)

    # Log all tools being returned to agent
    logging.info(
        f"Total tools for agent: {len(tools)} - {[t.name for t in tools]}",
        "Total tools for agent: %d — %s",
        len(tools), [t.name for t in tools],
    )

    return tools

@@ -13,11 +13,13 @@ Uses the same short-lived session pattern as generate_report so no DB
connection is held during the long LLM call.
"""

import io
import logging
import re
from datetime import UTC, datetime
from typing import Any

import pypdf
import typst
from langchain_core.callbacks import dispatch_custom_event
from langchain_core.messages import HumanMessage

@@ -114,7 +116,7 @@ _TEMPLATES: dict[str, dict[str, str]] = {
        entries-highlights-nested-bullet: text(13pt, [\\u{2022}], baseline: -0.6pt),
        entries-highlights-space-left: 0cm,
        entries-highlights-space-above: 0.08cm,
        entries-highlights-space-between-items: 0.08cm,
        entries-highlights-space-between-items: 0.02cm,
        entries-highlights-space-between-bullet-and-text: 0.3em,
        date: datetime(
            year: {year},

@@ -166,8 +168,8 @@ Available components (use ONLY these):
#summary([Short paragraph summary])  // Optional summary inside an entry
#content-area([Free-form content])  // Freeform text block

For skills sections, use bold labels directly:
#strong[Category:] item1, item2, item3
For skills sections, use one bullet per category label:
- #strong[Category:] item1, item2, item3

For simple list sections (e.g. Honors), use plain bullet points:
- Item one

@@ -184,15 +186,19 @@ RULES:
- Every section MUST use == heading.
- Use #regular-entry() for experience, projects, publications, certifications, and similar entries.
- Use #education-entry() for education.
- Use #strong[Label:] for skills categories.
- For skills sections, use one bullet line per category with a bold label.
- Keep content professional, concise, and achievement-oriented.
- Use action verbs for bullet points (Led, Built, Designed, Reduced, etc.).
- This template works for ALL professions — adapt sections to the user's field.
- Default behavior should prioritize concise one-page content.
""",
    },
}

DEFAULT_TEMPLATE = "classic"
MIN_RESUME_PAGES = 1
MAX_RESUME_PAGES = 5
MAX_COMPRESSION_ATTEMPTS = 2


# ─── Template Helpers ─────────────────────────────────────────────────────────

@@ -315,6 +321,8 @@ You are an expert resume writer. Generate professional resume content as Typst m
**User Information:**
{user_info}

**Target Maximum Pages:** {max_pages}

{user_instructions_section}

Generate the resume content now (starting with = Full Name):

@@ -326,6 +334,8 @@ Apply ONLY the requested changes — do NOT rewrite sections that are not affect

{llm_reference}

**Target Maximum Pages:** {max_pages}

**Modification Instructions:** {user_instructions}

**EXISTING RESUME CONTENT:**

@@ -352,6 +362,28 @@ The resume content you generated failed to compile. Fix the error while preservi
(starting with = Full Name), NOT the #import or #show rule:**
"""

_COMPRESS_TO_PAGE_LIMIT_PROMPT = """\
The resume compiles, but it exceeds the maximum allowed page count.
Compress the resume while preserving high-impact accomplishments and role relevance.

{llm_reference}

**Target Maximum Pages:** {max_pages}
**Current Page Count:** {actual_pages}
**Compression Attempt:** {attempt_number}

Compression priorities (in this order):
1) Keep recent, high-impact, role-relevant bullets.
2) Remove low-impact or redundant bullets.
3) Shorten verbose wording while preserving meaning.
4) Trim older or less relevant details before recent ones.

Return the complete updated Typst content (starting with = Full Name), and keep it at or below the target pages.

**EXISTING RESUME CONTENT:**
{previous_content}
"""


# ─── Helpers ─────────────────────────────────────────────────────────────────

@@ -373,6 +405,24 @@ def _compile_typst(source: str) -> bytes:
    return typst.compile(source.encode("utf-8"))


def _count_pdf_pages(pdf_bytes: bytes) -> int:
    """Count the number of pages in compiled PDF bytes."""
    with io.BytesIO(pdf_bytes) as pdf_stream:
        reader = pypdf.PdfReader(pdf_stream)
        return len(reader.pages)


def _validate_max_pages(max_pages: int) -> int:
    """Validate and normalize max_pages input."""
    if MIN_RESUME_PAGES <= max_pages <= MAX_RESUME_PAGES:
        return max_pages
    msg = (
        f"max_pages must be between {MIN_RESUME_PAGES} and "
        f"{MAX_RESUME_PAGES}. Received: {max_pages}"
    )
    raise ValueError(msg)


# ─── Tool Factory ───────────────────────────────────────────────────────────

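As a rough sketch of how the two helpers above fit together (illustrative, not code from this change): the page limit is validated up front, the compiled PDF is measured, and a compression pass is only triggered when the count exceeds the target.

# Sketch only; assumes `typst_source` is a ready-to-compile Typst string.
validated = _validate_max_pages(2)        # raises ValueError outside the 1-5 range
pdf_bytes = _compile_typst(typst_source)  # compile to PDF bytes
if _count_pdf_pages(pdf_bytes) > validated:
    pass  # would trigger a compression round using _COMPRESS_TO_PAGE_LIMIT_PROMPT
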
@@ -394,6 +444,7 @@ def create_generate_resume_tool(
        user_info: str,
        user_instructions: str | None = None,
        parent_report_id: int | None = None,
        max_pages: int = 1,
    ) -> dict[str, Any]:
        """
        Generate a professional resume as a Typst document.

@@ -426,6 +477,8 @@
                "use a modern style"). For revisions, describe what to change.
            parent_report_id: ID of a previous resume to revise (creates
                new version in the same version group).
            max_pages: Maximum number of pages for the generated resume.
                Defaults to 1. Allowed range: 1-5.

        Returns:
            Dict with status, report_id, title, and content_type.

@@ -469,6 +522,19 @@
            return None

        try:
            try:
                validated_max_pages = _validate_max_pages(max_pages)
            except ValueError as e:
                error_msg = str(e)
                report_id = await _save_failed_report(error_msg)
                return {
                    "status": "failed",
                    "error": error_msg,
                    "report_id": report_id,
                    "title": "Resume",
                    "content_type": "typst",
                }

            # ── Phase 1: READ ─────────────────────────────────────────────
            async with shielded_async_session() as read_session:
                if parent_report_id:

@@ -512,6 +578,7 @@
                parent_body = _strip_header(parent_content)
                prompt = _REVISION_PROMPT.format(
                    llm_reference=llm_reference,
                    max_pages=validated_max_pages,
                    user_instructions=user_instructions
                    or "Improve and refine the resume.",
                    previous_content=parent_body,

@@ -524,6 +591,7 @@
                prompt = _RESUME_PROMPT.format(
                    llm_reference=llm_reference,
                    user_info=user_info,
                    max_pages=validated_max_pages,
                    user_instructions_section=user_instructions_section,
                )

@@ -551,49 +619,116 @@
            )

            name = _extract_name(body) or "Resume"
            header = _build_header(template, name)
            typst_source = header + body
            typst_source = ""
            actual_pages = 0
            compression_attempts = 0
            target_page_met = False

            compile_error: str | None = None
            for attempt in range(2):
                try:
                    _compile_typst(typst_source)
                    compile_error = None
                    break
                except Exception as e:
                    compile_error = str(e)
                    logger.warning(
                        f"[generate_resume] Compile attempt {attempt + 1} failed: {compile_error}"
            for compression_round in range(MAX_COMPRESSION_ATTEMPTS + 1):
                header = _build_header(template, name)
                typst_source = header + body
                compile_error: str | None = None
                pdf_bytes: bytes | None = None

                for compile_attempt in range(2):
                    try:
                        pdf_bytes = _compile_typst(typst_source)
                        compile_error = None
                        break
                    except Exception as e:
                        compile_error = str(e)
                        logger.warning(
                            "[generate_resume] Compile attempt %s failed: %s",
                            compile_attempt + 1,
                            compile_error,
                        )

                        if compile_attempt == 0:
                            dispatch_custom_event(
                                "report_progress",
                                {
                                    "phase": "fixing",
                                    "message": "Fixing compilation issue...",
                                },
                            )
                            fix_prompt = _FIX_COMPILE_PROMPT.format(
                                llm_reference=llm_reference,
                                error=compile_error,
                                full_source=typst_source,
                            )
                            fix_response = await llm.ainvoke(
                                [HumanMessage(content=fix_prompt)]
                            )
                            if fix_response.content and isinstance(
                                fix_response.content, str
                            ):
                                body = _strip_typst_fences(fix_response.content)
                                body = _strip_imports(body)
                                name = _extract_name(body) or name
                                header = _build_header(template, name)
                                typst_source = header + body

                if compile_error or not pdf_bytes:
                    error_msg = (
                        "Typst compilation failed after 2 attempts: "
                        f"{compile_error or 'Unknown compile error'}"
                    )
                    report_id = await _save_failed_report(error_msg)
                    return {
                        "status": "failed",
                        "error": error_msg,
                        "report_id": report_id,
                        "title": "Resume",
                        "content_type": "typst",
                    }

                if attempt == 0:
                    dispatch_custom_event(
                        "report_progress",
                        {
                            "phase": "fixing",
                            "message": "Fixing compilation issue...",
                        },
                    )
                    fix_prompt = _FIX_COMPILE_PROMPT.format(
                        llm_reference=llm_reference,
                        error=compile_error,
                        full_source=typst_source,
                    )
                    fix_response = await llm.ainvoke(
                        [HumanMessage(content=fix_prompt)]
                    )
                    if fix_response.content and isinstance(
                        fix_response.content, str
                    ):
                        body = _strip_typst_fences(fix_response.content)
                        body = _strip_imports(body)
                        name = _extract_name(body) or name
                        header = _build_header(template, name)
                        typst_source = header + body
                actual_pages = _count_pdf_pages(pdf_bytes)
                if actual_pages <= validated_max_pages:
                    target_page_met = True
                    break

                if compile_error:
                    if compression_round >= MAX_COMPRESSION_ATTEMPTS:
                        break

                compression_attempts += 1
                dispatch_custom_event(
                    "report_progress",
                    {
||||
"phase": "compressing",
|
||||
"message": f"Condensing resume to {validated_max_pages} page(s)...",
|
||||
},
|
||||
)
|
||||
compress_prompt = _COMPRESS_TO_PAGE_LIMIT_PROMPT.format(
|
||||
llm_reference=llm_reference,
|
||||
max_pages=validated_max_pages,
|
||||
actual_pages=actual_pages,
|
||||
attempt_number=compression_attempts,
|
||||
previous_content=body,
|
||||
)
|
||||
compress_response = await llm.ainvoke(
|
||||
[HumanMessage(content=compress_prompt)]
|
||||
)
|
||||
if not compress_response.content or not isinstance(
|
||||
compress_response.content, str
|
||||
):
|
||||
error_msg = "LLM returned empty content while compressing resume"
|
||||
report_id = await _save_failed_report(error_msg)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": error_msg,
|
||||
"report_id": report_id,
|
||||
"title": "Resume",
|
||||
"content_type": "typst",
|
||||
}
|
||||
|
||||
body = _strip_typst_fences(compress_response.content)
|
||||
body = _strip_imports(body)
|
||||
name = _extract_name(body) or name
|
||||
|
||||
if actual_pages > MAX_RESUME_PAGES:
|
||||
error_msg = (
|
||||
f"Typst compilation failed after 2 attempts: {compile_error}"
|
||||
"Resume exceeds hard page limit after compression retries. "
|
||||
f"Hard limit: <= {MAX_RESUME_PAGES} page(s), actual: {actual_pages}."
|
||||
)
|
||||
report_id = await _save_failed_report(error_msg)
|
||||
return {
|
||||
|
|
@ -616,6 +751,11 @@ def create_generate_resume_tool(
|
|||
"status": "ready",
|
||||
"word_count": len(typst_source.split()),
|
||||
"char_count": len(typst_source),
|
||||
"target_max_pages": validated_max_pages,
|
||||
"actual_page_count": actual_pages,
|
||||
"page_limit_enforced": True,
|
||||
"compression_attempts": compression_attempts,
|
||||
"target_page_met": target_page_met,
|
||||
}
|
||||
|
||||
async with shielded_async_session() as write_session:
|
||||
|
|
@ -647,7 +787,14 @@ def create_generate_resume_tool(
|
|||
"title": resume_title,
|
||||
"content_type": "typst",
|
||||
"is_revision": bool(parent_content),
|
||||
"message": f"Resume generated successfully: {resume_title}",
|
||||
"message": (
|
||||
f"Resume generated successfully: {resume_title}"
|
||||
if target_page_met
|
||||
else (
|
||||
f"Resume generated, but could not fit the target of <= {validated_max_pages} "
|
||||
f"page(s). Final length: {actual_pages} page(s)."
|
||||
)
|
||||
),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
from app.agents.new_chat.tools.teams.list_channels import (
|
||||
create_list_teams_channels_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.teams.read_messages import (
|
||||
create_read_teams_messages_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.teams.send_message import (
|
||||
create_send_teams_message_tool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_list_teams_channels_tool",
|
||||
"create_read_teams_messages_tool",
|
||||
"create_send_teams_message_tool",
|
||||
]
|
||||
37
surfsense_backend/app/agents/new_chat/tools/teams/_auth.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
"""Shared auth helper for Teams agent tools (Microsoft Graph REST API)."""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
GRAPH_API = "https://graph.microsoft.com/v1.0"
|
||||
|
||||
|
||||
async def get_teams_connector(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> SearchSourceConnector | None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
async def get_access_token(
|
||||
db_session: AsyncSession,
|
||||
connector: SearchSourceConnector,
|
||||
) -> str:
|
||||
"""Get a valid Microsoft Graph access token, refreshing if expired."""
|
||||
from app.connectors.teams_connector import TeamsConnector
|
||||
|
||||
tc = TeamsConnector(
|
||||
session=db_session,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
return await tc._get_valid_token()
|
||||
|
|
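These two helpers are what give every Teams tool an authenticated Microsoft Graph client: look up the connector for the search space, then exchange it for a bearer token. A minimal sketch of that pattern, assuming an open AsyncSession (the IDs are placeholders):

# Sketch: resolve the connector, get a Graph token, and make one request.
import httpx

from app.agents.new_chat.tools.teams._auth import (
    GRAPH_API,
    get_access_token,
    get_teams_connector,
)


async def fetch_joined_teams(db_session, search_space_id: int, user_id: str) -> list[dict]:
    connector = await get_teams_connector(db_session, search_space_id, user_id)
    if connector is None:
        return []  # no Teams connector configured for this search space
    token = await get_access_token(db_session, connector)
    async with httpx.AsyncClient(timeout=20.0) as client:
        resp = await client.get(
            f"{GRAPH_API}/me/joinedTeams",
            headers={"Authorization": f"Bearer {token}"},
        )
    resp.raise_for_status()
    return resp.json().get("value", [])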
@ -0,0 +1,77 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_list_teams_channels_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def list_teams_channels() -> dict[str, Any]:
|
||||
"""List all Microsoft Teams and their channels the user has access to.
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of teams, each containing
|
||||
team_id, team_name, and a list of channels (id, name).
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Teams tool not properly configured."}
|
||||
|
||||
try:
|
||||
connector = await get_teams_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Teams connector found."}
|
||||
|
||||
token = await get_access_token(db_session, connector)
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
teams_resp = await client.get(f"{GRAPH_API}/me/joinedTeams", headers=headers)
|
||||
|
||||
if teams_resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"}
|
||||
if teams_resp.status_code != 200:
|
||||
return {"status": "error", "message": f"Graph API error: {teams_resp.status_code}"}
|
||||
|
||||
teams_data = teams_resp.json().get("value", [])
|
||||
result_teams = []
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
for team in teams_data:
|
||||
team_id = team["id"]
|
||||
ch_resp = await client.get(
|
||||
f"{GRAPH_API}/teams/{team_id}/channels",
|
||||
headers=headers,
|
||||
)
|
||||
channels = []
|
||||
if ch_resp.status_code == 200:
|
||||
channels = [
|
||||
{"id": ch["id"], "name": ch.get("displayName", "")}
|
||||
for ch in ch_resp.json().get("value", [])
|
||||
]
|
||||
result_teams.append({
|
||||
"team_id": team_id,
|
||||
"team_name": team.get("displayName", ""),
|
||||
"channels": channels,
|
||||
})
|
||||
|
||||
return {"status": "success", "teams": result_teams, "total_teams": len(result_teams)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error listing Teams channels: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to list Teams channels."}
|
||||
|
||||
return list_teams_channels
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_read_teams_messages_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def read_teams_messages(
|
||||
team_id: str,
|
||||
channel_id: str,
|
||||
limit: int = 25,
|
||||
) -> dict[str, Any]:
|
||||
"""Read recent messages from a Microsoft Teams channel.
|
||||
|
||||
Args:
|
||||
team_id: The team ID (from list_teams_channels).
|
||||
channel_id: The channel ID (from list_teams_channels).
|
||||
limit: Number of messages to fetch (default 25, max 50).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of messages including
|
||||
id, sender, content, timestamp.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Teams tool not properly configured."}
|
||||
|
||||
limit = min(limit, 50)
|
||||
|
||||
try:
|
||||
connector = await get_teams_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Teams connector found."}
|
||||
|
||||
token = await get_access_token(db_session, connector)
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
resp = await client.get(
|
||||
f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
params={"$top": limit},
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"}
|
||||
if resp.status_code == 403:
|
||||
return {"status": "error", "message": "Insufficient permissions to read this channel."}
|
||||
if resp.status_code != 200:
|
||||
return {"status": "error", "message": f"Graph API error: {resp.status_code}"}
|
||||
|
||||
raw_msgs = resp.json().get("value", [])
|
||||
messages = []
|
||||
for m in raw_msgs:
|
||||
sender = m.get("from", {})
|
||||
user_info = sender.get("user", {}) if sender else {}
|
||||
body = m.get("body", {})
|
||||
messages.append({
|
||||
"id": m.get("id"),
|
||||
"sender": user_info.get("displayName", "Unknown"),
|
||||
"content": body.get("content", ""),
|
||||
"content_type": body.get("contentType", "text"),
|
||||
"timestamp": m.get("createdDateTime", ""),
|
||||
})
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"team_id": team_id,
|
||||
"channel_id": channel_id,
|
||||
"messages": messages,
|
||||
"total": len(messages),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error reading Teams messages: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to read Teams messages."}
|
||||
|
||||
return read_teams_messages
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
|
||||
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_send_teams_message_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def send_teams_message(
|
||||
team_id: str,
|
||||
channel_id: str,
|
||||
content: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Send a message to a Microsoft Teams channel.
|
||||
|
||||
Requires the ChannelMessage.Send OAuth scope. If the user gets a
|
||||
permission error, they may need to re-authenticate with updated scopes.
|
||||
|
||||
Args:
|
||||
team_id: The team ID (from list_teams_channels).
|
||||
channel_id: The channel ID (from list_teams_channels).
|
||||
content: The message text (HTML supported).
|
||||
|
||||
Returns:
|
||||
Dictionary with status, message_id on success.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Teams tool not properly configured."}
|
||||
|
||||
try:
|
||||
connector = await get_teams_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Teams connector found."}
|
||||
|
||||
result = request_approval(
|
||||
action_type="teams_send_message",
|
||||
tool_name="send_teams_message",
|
||||
params={"team_id": team_id, "channel_id": channel_id, "content": content},
|
||||
context={"connector_id": connector.id},
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {"status": "rejected", "message": "User declined. Message was not sent."}
|
||||
|
||||
final_content = result.params.get("content", content)
|
||||
final_team = result.params.get("team_id", team_id)
|
||||
final_channel = result.params.get("channel_id", channel_id)
|
||||
|
||||
token = await get_access_token(db_session, connector)
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
resp = await client.post(
|
||||
f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"body": {"content": final_content}},
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"}
|
||||
if resp.status_code == 403:
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
|
||||
}
|
||||
if resp.status_code not in (200, 201):
|
||||
return {"status": "error", "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}"}
|
||||
|
||||
msg_data = resp.json()
|
||||
return {
|
||||
"status": "success",
|
||||
"message_id": msg_data.get("id"),
|
||||
"message": f"Message sent to Teams channel.",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error sending Teams message: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to send Teams message."}
|
||||
|
||||
return send_teams_message
|
||||
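All three Teams factories above share one shape: bind the session, search space, and user at construction time and hand back a LangChain tool the agent can call without extra wiring. A hedged wiring sketch (the session object and IDs are placeholders, and the direct ainvoke calls are shown only for illustration):

# Sketch: build the Teams tools for one chat turn and invoke two of them directly.
from app.agents.new_chat.tools.teams import (
    create_list_teams_channels_tool,
    create_read_teams_messages_tool,
    create_send_teams_message_tool,
)


def build_teams_tools(db_session, search_space_id: int, user_id: str) -> list:
    return [
        create_list_teams_channels_tool(db_session, search_space_id, user_id),
        create_read_teams_messages_tool(db_session, search_space_id, user_id),
        create_send_teams_message_tool(db_session, search_space_id, user_id),
    ]


async def demo(db_session):
    list_channels, read_messages, _send = build_teams_tools(db_session, 1, "user-uuid")
    teams = await list_channels.ainvoke({})  # {"status": ..., "teams": [...], "total_teams": n}
    messages = await read_messages.ainvoke(
        {"team_id": "<team-id>", "channel_id": "<channel-id>", "limit": 10}
    )
    return teams, messages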
41
surfsense_backend/app/agents/new_chat/tools/tool_response.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
"""Standardised response dict factories for LangChain agent tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ToolResponse:
|
||||
|
||||
@staticmethod
|
||||
def success(message: str, **data: Any) -> dict[str, Any]:
|
||||
return {"status": "success", "message": message, **data}
|
||||
|
||||
@staticmethod
|
||||
def error(error: str, **data: Any) -> dict[str, Any]:
|
||||
return {"status": "error", "error": error, **data}
|
||||
|
||||
@staticmethod
|
||||
def auth_error(service: str, **data: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"error": (
|
||||
f"{service} authentication has expired or been revoked. "
|
||||
"Please re-connect the integration in Settings → Connectors."
|
||||
),
|
||||
**data,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def rejected(message: str = "Action was declined by the user.") -> dict[str, Any]:
|
||||
return {"status": "rejected", "message": message}
|
||||
|
||||
@staticmethod
|
||||
def not_found(
|
||||
resource: str, identifier: str, **data: Any
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"status": "not_found",
|
||||
"error": f"{resource} '{identifier}' was not found.",
|
||||
**data,
|
||||
}
|
||||
|
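ToolResponse simply centralises the status-dict shapes that the tools elsewhere in this change build by hand, so new tools can stay consistent. A short illustrative use (the tool itself is made up):

# Sketch: a hypothetical tool returning the standardised shapes.
from langchain_core.tools import tool

from app.agents.new_chat.tools.tool_response import ToolResponse


@tool
async def lookup_ticket(ticket_id: str) -> dict:
    """Look up a support ticket by id (illustrative only)."""
    ticket = None  # stand-in for a repository/API call
    if ticket is None:
        return ToolResponse.not_found("Ticket", ticket_id)
    return ToolResponse.success("Ticket found.", ticket=ticket)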
|
@ -141,6 +141,15 @@ def _http_exception_handler(request: Request, exc: HTTPException) -> JSONRespons
|
|||
exc.status_code,
|
||||
message,
|
||||
)
|
||||
elif exc.status_code >= 400:
|
||||
_error_logger.warning(
|
||||
"[%s] %s %s - HTTPException %d: %s",
|
||||
rid,
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc.status_code,
|
||||
message,
|
||||
)
|
||||
if should_sanitize:
|
||||
message = GENERIC_5XX_MESSAGE
|
||||
err_code = "INTERNAL_ERROR"
|
||||
|
|
@ -170,6 +179,15 @@ def _http_exception_handler(request: Request, exc: HTTPException) -> JSONRespons
|
|||
exc.status_code,
|
||||
detail,
|
||||
)
|
||||
elif exc.status_code >= 400:
|
||||
_error_logger.warning(
|
||||
"[%s] %s %s - HTTPException %d: %s",
|
||||
rid,
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc.status_code,
|
||||
detail,
|
||||
)
|
||||
if should_sanitize:
|
||||
detail = GENERIC_5XX_MESSAGE
|
||||
code = _status_to_code(exc.status_code, detail)
|
||||
|
|
|
|||
|
|
@ -136,20 +136,12 @@ celery_app.conf.update(
|
|||
# never block fast user-facing tasks (file uploads, podcasts, etc.)
|
||||
task_routes={
|
||||
# Connector indexing tasks → connectors queue
|
||||
"index_slack_messages": {"queue": CONNECTORS_QUEUE},
|
||||
"index_notion_pages": {"queue": CONNECTORS_QUEUE},
|
||||
"index_github_repos": {"queue": CONNECTORS_QUEUE},
|
||||
"index_linear_issues": {"queue": CONNECTORS_QUEUE},
|
||||
"index_jira_issues": {"queue": CONNECTORS_QUEUE},
|
||||
"index_confluence_pages": {"queue": CONNECTORS_QUEUE},
|
||||
"index_clickup_tasks": {"queue": CONNECTORS_QUEUE},
|
||||
"index_google_calendar_events": {"queue": CONNECTORS_QUEUE},
|
||||
"index_airtable_records": {"queue": CONNECTORS_QUEUE},
|
||||
"index_google_gmail_messages": {"queue": CONNECTORS_QUEUE},
|
||||
"index_google_drive_files": {"queue": CONNECTORS_QUEUE},
|
||||
"index_discord_messages": {"queue": CONNECTORS_QUEUE},
|
||||
"index_teams_messages": {"queue": CONNECTORS_QUEUE},
|
||||
"index_luma_events": {"queue": CONNECTORS_QUEUE},
|
||||
"index_elasticsearch_documents": {"queue": CONNECTORS_QUEUE},
|
||||
"index_crawled_urls": {"queue": CONNECTORS_QUEUE},
|
||||
"index_bookstack_pages": {"queue": CONNECTORS_QUEUE},
|
||||
|
|
|
|||
|
|
@ -339,6 +339,9 @@ class Config:
|
|||
# self-hosted: Full access to local file system connectors (Obsidian, etc.)
|
||||
# cloud: Only cloud-based connectors available
|
||||
DEPLOYMENT_MODE = os.getenv("SURFSENSE_DEPLOYMENT_MODE", "self-hosted")
|
||||
ENABLE_DESKTOP_LOCAL_FILESYSTEM = (
|
||||
os.getenv("ENABLE_DESKTOP_LOCAL_FILESYSTEM", "FALSE").upper() == "TRUE"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_self_hosted(cls) -> bool:
|
||||
|
|
|
|||
98
surfsense_backend/app/connectors/exceptions.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
"""Standard exception hierarchy for all connectors.
|
||||
|
||||
ConnectorError
|
||||
├── ConnectorAuthError (401/403 — non-retryable)
|
||||
├── ConnectorRateLimitError (429 — retryable, carries ``retry_after``)
|
||||
├── ConnectorTimeoutError (timeout/504 — retryable)
|
||||
└── ConnectorAPIError (5xx or unexpected — retryable when >= 500)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ConnectorError(Exception):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
service: str = "",
|
||||
status_code: int | None = None,
|
||||
response_body: Any = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.service = service
|
||||
self.status_code = status_code
|
||||
self.response_body = response_body
|
||||
|
||||
@property
|
||||
def retryable(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class ConnectorAuthError(ConnectorError):
|
||||
"""Token expired, revoked, insufficient scopes, or needs re-auth (401/403)."""
|
||||
|
||||
@property
|
||||
def retryable(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class ConnectorRateLimitError(ConnectorError):
|
||||
"""429 Too Many Requests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Rate limited",
|
||||
*,
|
||||
service: str = "",
|
||||
retry_after: float | None = None,
|
||||
status_code: int = 429,
|
||||
response_body: Any = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
message,
|
||||
service=service,
|
||||
status_code=status_code,
|
||||
response_body=response_body,
|
||||
)
|
||||
self.retry_after = retry_after
|
||||
|
||||
@property
|
||||
def retryable(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class ConnectorTimeoutError(ConnectorError):
|
||||
"""Request timeout or gateway timeout (504)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Request timed out",
|
||||
*,
|
||||
service: str = "",
|
||||
status_code: int | None = None,
|
||||
response_body: Any = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
message,
|
||||
service=service,
|
||||
status_code=status_code,
|
||||
response_body=response_body,
|
||||
)
|
||||
|
||||
@property
|
||||
def retryable(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class ConnectorAPIError(ConnectorError):
|
||||
"""Generic API error (5xx or unexpected status codes)."""
|
||||
|
||||
@property
|
||||
def retryable(self) -> bool:
|
||||
if self.status_code is not None:
|
||||
return self.status_code >= 500
|
||||
return False
|
||||
|
|
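Callers are expected to branch on the retryable flag (and on retry_after for 429s) rather than on exception type. A hedged retry-wrapper sketch, not part of the diff:

# Sketch: retry a connector call only when the raised error says it is safe to.
import asyncio

from app.connectors.exceptions import ConnectorError, ConnectorRateLimitError


async def call_with_retries(fn, *args, max_attempts: int = 3, **kwargs):
    for attempt in range(1, max_attempts + 1):
        try:
            return await fn(*args, **kwargs)
        except ConnectorError as exc:
            if not exc.retryable or attempt == max_attempts:
                raise
            if isinstance(exc, ConnectorRateLimitError) and exc.retry_after:
                delay = exc.retry_after  # honour the server-provided wait
            else:
                delay = 2 ** attempt     # simple exponential backoff
            await asyncio.sleep(delay)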
@ -30,6 +30,7 @@ from .jira_add_connector_route import router as jira_add_connector_router
|
|||
from .linear_add_connector_route import router as linear_add_connector_router
|
||||
from .logs_routes import router as logs_router
|
||||
from .luma_add_connector_route import router as luma_add_connector_router
|
||||
from .mcp_oauth_route import router as mcp_oauth_router
|
||||
from .memory_routes import router as memory_router
|
||||
from .model_list_routes import router as model_list_router
|
||||
from .new_chat_routes import router as new_chat_router
|
||||
|
|
@ -97,6 +98,7 @@ router.include_router(logs_router)
|
|||
router.include_router(circleback_webhook_router) # Circleback meeting webhooks
|
||||
router.include_router(surfsense_docs_router) # Surfsense documentation for citations
|
||||
router.include_router(notifications_router) # Notifications with Zero sync
|
||||
router.include_router(mcp_oauth_router) # MCP OAuth 2.1 for Linear, Jira, ClickUp, Slack, Airtable
|
||||
router.include_router(composio_router) # Composio OAuth and toolkit management
|
||||
router.include_router(public_chat_router) # Public chat sharing and cloning
|
||||
router.include_router(incentive_tasks_router) # Incentive tasks for earning free pages
|
||||
|
|
|
|||
|
|
@ -311,7 +311,7 @@ async def airtable_callback(
|
|||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=SearchSourceConnectorType.AIRTABLE_CONNECTOR,
|
||||
is_indexable=True,
|
||||
is_indexable=False,
|
||||
config=credentials_dict,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
|
|
|
|||
|
|
@ -301,7 +301,7 @@ async def clickup_callback(
|
|||
# Update existing connector
|
||||
existing_connector.config = connector_config
|
||||
existing_connector.name = "ClickUp Connector"
|
||||
existing_connector.is_indexable = True
|
||||
existing_connector.is_indexable = False
|
||||
logger.info(
|
||||
f"Updated existing ClickUp connector for user {user_id} in space {space_id}"
|
||||
)
|
||||
|
|
@ -310,7 +310,7 @@ async def clickup_callback(
|
|||
new_connector = SearchSourceConnector(
|
||||
name="ClickUp Connector",
|
||||
connector_type=SearchSourceConnectorType.CLICKUP_CONNECTOR,
|
||||
is_indexable=True,
|
||||
is_indexable=False,
|
||||
config=connector_config,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
|
|
|
|||
|
|
@ -326,7 +326,7 @@ async def discord_callback(
|
|||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=SearchSourceConnectorType.DISCORD_CONNECTOR,
|
||||
is_indexable=True,
|
||||
is_indexable=False,
|
||||
config=connector_config,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
|
|
|
|||
|
|
@ -340,7 +340,7 @@ async def calendar_callback(
|
|||
config=creds_dict,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
is_indexable=True,
|
||||
is_indexable=False,
|
||||
)
|
||||
session.add(db_connector)
|
||||
await session.commit()
|
||||
|
|
|
|||
|
|
@ -371,7 +371,7 @@ async def gmail_callback(
|
|||
config=creds_dict,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
is_indexable=True,
|
||||
is_indexable=False,
|
||||
)
|
||||
session.add(db_connector)
|
||||
await session.commit()
|
||||
|
|
|
|||
|
|
@ -386,7 +386,7 @@ async def jira_callback(
|
|||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
is_indexable=True,
|
||||
is_indexable=False,
|
||||
config=connector_config,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
|
|
|
|||
|
|
@ -399,7 +399,7 @@ async def linear_callback(
|
|||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
is_indexable=True,
|
||||
is_indexable=False,
|
||||
config=connector_config,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ async def add_luma_connector(
|
|||
if existing_connector:
|
||||
# Update existing connector with new API key
|
||||
existing_connector.config = {"api_key": request.api_key}
|
||||
existing_connector.is_indexable = True
|
||||
existing_connector.is_indexable = False
|
||||
await session.commit()
|
||||
await session.refresh(existing_connector)
|
||||
|
||||
|
|
@ -82,7 +82,7 @@ async def add_luma_connector(
|
|||
config={"api_key": request.api_key},
|
||||
search_space_id=request.space_id,
|
||||
user_id=user.id,
|
||||
is_indexable=True,
|
||||
is_indexable=False,
|
||||
)
|
||||
|
||||
session.add(db_connector)
|
||||
|
|
|
|||
601
surfsense_backend/app/routes/mcp_oauth_route.py
Normal file
|
|
@ -0,0 +1,601 @@
|
|||
"""Generic MCP OAuth 2.1 route for services with official MCP servers.
|
||||
|
||||
Handles the full flow: discovery → DCR → PKCE authorization → token exchange
|
||||
→ MCP_CONNECTOR creation. Currently supports Linear, Jira, ClickUp, Slack,
|
||||
and Airtable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import RedirectResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.connector_naming import generate_unique_connector_name
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_pkce_pair
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _fetch_account_metadata(
|
||||
service_key: str, access_token: str, token_json: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch display-friendly account metadata after a successful token exchange.
|
||||
|
||||
DCR services (Linear, Jira, ClickUp) issue MCP-scoped tokens that cannot
|
||||
call their standard REST/GraphQL APIs — metadata discovery for those
|
||||
happens at runtime through MCP tools instead.
|
||||
|
||||
Pre-configured services (Slack, Airtable) use standard OAuth tokens that
|
||||
*can* call their APIs, so we extract metadata here.
|
||||
|
||||
Failures are logged but never block connector creation.
|
||||
"""
|
||||
from app.services.mcp_oauth.registry import MCP_SERVICES
|
||||
|
||||
svc = MCP_SERVICES.get(service_key)
|
||||
if not svc or svc.supports_dcr:
|
||||
return {}
|
||||
|
||||
import httpx
|
||||
|
||||
meta: dict[str, Any] = {}
|
||||
|
||||
try:
|
||||
if service_key == "slack":
|
||||
team_info = token_json.get("team", {})
|
||||
meta["team_id"] = team_info.get("id", "")
|
||||
# TODO: oauth.v2.user.access only returns team.id, not
|
||||
# team.name. To populate team_name, add "team:read" scope
|
||||
# and call GET /api/team.info here.
|
||||
meta["team_name"] = team_info.get("name", "")
|
||||
if meta["team_name"]:
|
||||
meta["display_name"] = meta["team_name"]
|
||||
elif meta["team_id"]:
|
||||
meta["display_name"] = f"Slack ({meta['team_id']})"
|
||||
|
||||
elif service_key == "airtable":
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
resp = await client.get(
|
||||
"https://api.airtable.com/v0/meta/whoami",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
whoami = resp.json()
|
||||
meta["user_id"] = whoami.get("id", "")
|
||||
meta["user_email"] = whoami.get("email", "")
|
||||
meta["display_name"] = whoami.get("email", "Airtable")
|
||||
else:
|
||||
logger.warning(
|
||||
"Airtable whoami API returned %d (non-blocking)", resp.status_code,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch account metadata for %s (non-blocking)",
|
||||
service_key,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return meta
|
||||
|
||||
_state_manager: OAuthStateManager | None = None
|
||||
_token_encryption: TokenEncryption | None = None
|
||||
|
||||
|
||||
def _get_state_manager() -> OAuthStateManager:
|
||||
global _state_manager
|
||||
if _state_manager is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(status_code=500, detail="SECRET_KEY not configured.")
|
||||
_state_manager = OAuthStateManager(config.SECRET_KEY)
|
||||
return _state_manager
|
||||
|
||||
|
||||
def _get_token_encryption() -> TokenEncryption:
|
||||
global _token_encryption
|
||||
if _token_encryption is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(status_code=500, detail="SECRET_KEY not configured.")
|
||||
_token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
return _token_encryption
|
||||
|
||||
|
||||
def _build_redirect_uri(service: str) -> str:
|
||||
base = config.BACKEND_URL or "http://localhost:8000"
|
||||
return f"{base.rstrip('/')}/api/v1/auth/mcp/{service}/connector/callback"
|
||||
|
||||
|
||||
def _frontend_redirect(
|
||||
space_id: int | None,
|
||||
*,
|
||||
success: bool = False,
|
||||
connector_id: int | None = None,
|
||||
error: str | None = None,
|
||||
service: str = "mcp",
|
||||
) -> RedirectResponse:
|
||||
if success and space_id:
|
||||
qs = f"success=true&connector={service}-mcp-connector"
|
||||
if connector_id:
|
||||
qs += f"&connectorId={connector_id}"
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?{qs}"
|
||||
)
|
||||
if error and space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error={error}"
|
||||
)
|
||||
return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /add — start MCP OAuth flow
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/auth/mcp/{service}/connector/add")
|
||||
async def connect_mcp_service(
|
||||
service: str,
|
||||
space_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
from app.services.mcp_oauth.registry import get_service
|
||||
|
||||
svc = get_service(service)
|
||||
if not svc:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}")
|
||||
|
||||
try:
|
||||
from app.services.mcp_oauth.discovery import (
|
||||
discover_oauth_metadata,
|
||||
register_client,
|
||||
)
|
||||
|
||||
metadata = await discover_oauth_metadata(
|
||||
svc.mcp_url, origin_override=svc.oauth_discovery_origin,
|
||||
)
|
||||
auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint")
|
||||
token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint")
|
||||
registration_endpoint = metadata.get("registration_endpoint")
|
||||
|
||||
if not auth_endpoint or not token_endpoint:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"{svc.name} MCP server returned incomplete OAuth metadata.",
|
||||
)
|
||||
|
||||
redirect_uri = _build_redirect_uri(service)
|
||||
|
||||
if svc.supports_dcr and registration_endpoint:
|
||||
dcr = await register_client(registration_endpoint, redirect_uri)
|
||||
client_id = dcr.get("client_id")
|
||||
client_secret = dcr.get("client_secret", "")
|
||||
if not client_id:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"DCR for {svc.name} did not return a client_id.",
|
||||
)
|
||||
elif svc.client_id_env:
|
||||
client_id = getattr(config, svc.client_id_env, None)
|
||||
client_secret = getattr(config, svc.client_secret_env or "", None) or ""
|
||||
if not client_id:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"{svc.name} MCP OAuth not configured ({svc.client_id_env}).",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"{svc.name} MCP server has no DCR and no fallback credentials.",
|
||||
)
|
||||
|
||||
verifier, challenge = generate_pkce_pair()
|
||||
enc = _get_token_encryption()
|
||||
|
||||
state = _get_state_manager().generate_secure_state(
|
||||
space_id,
|
||||
user.id,
|
||||
service=service,
|
||||
code_verifier=verifier,
|
||||
mcp_client_id=client_id,
|
||||
mcp_client_secret=enc.encrypt_token(client_secret) if client_secret else "",
|
||||
mcp_token_endpoint=token_endpoint,
|
||||
mcp_url=svc.mcp_url,
|
||||
)
|
||||
|
||||
auth_params: dict[str, str] = {
|
||||
"client_id": client_id,
|
||||
"response_type": "code",
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_challenge": challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"state": state,
|
||||
}
|
||||
if svc.scopes:
|
||||
auth_params[svc.scope_param] = " ".join(svc.scopes)
|
||||
|
||||
auth_url = f"{auth_endpoint}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
"Generated %s MCP OAuth URL for user %s, space %s",
|
||||
svc.name, user.id, space_id,
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to initiate %s MCP OAuth: %s", service, e, exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate {service} MCP OAuth.",
|
||||
) from e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /callback — handle OAuth redirect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/auth/mcp/{service}/connector/callback")
|
||||
async def mcp_oauth_callback(
|
||||
service: str,
|
||||
code: str | None = None,
|
||||
error: str | None = None,
|
||||
state: str | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
if error:
|
||||
logger.warning("%s MCP OAuth error: %s", service, error)
|
||||
space_id = None
|
||||
if state:
|
||||
try:
|
||||
data = _get_state_manager().validate_state(state)
|
||||
space_id = data.get("space_id")
|
||||
except Exception:
|
||||
pass
|
||||
return _frontend_redirect(
|
||||
space_id, error=f"{service}_mcp_oauth_denied", service=service,
|
||||
)
|
||||
|
||||
if not code:
|
||||
raise HTTPException(status_code=400, detail="Missing authorization code")
|
||||
if not state:
|
||||
raise HTTPException(status_code=400, detail="Missing state parameter")
|
||||
|
||||
data = _get_state_manager().validate_state(state)
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
svc_key = data.get("service", service)
|
||||
|
||||
if svc_key != service:
|
||||
raise HTTPException(status_code=400, detail="State/path service mismatch")
|
||||
|
||||
from app.services.mcp_oauth.registry import get_service
|
||||
|
||||
svc = get_service(svc_key)
|
||||
if not svc:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown MCP service: {svc_key}")
|
||||
|
||||
try:
|
||||
from app.services.mcp_oauth.discovery import exchange_code_for_tokens
|
||||
|
||||
enc = _get_token_encryption()
|
||||
client_id = data["mcp_client_id"]
|
||||
client_secret = (
|
||||
enc.decrypt_token(data["mcp_client_secret"])
|
||||
if data.get("mcp_client_secret")
|
||||
else ""
|
||||
)
|
||||
token_endpoint = data["mcp_token_endpoint"]
|
||||
code_verifier = data["code_verifier"]
|
||||
mcp_url = data["mcp_url"]
|
||||
redirect_uri = _build_redirect_uri(service)
|
||||
|
||||
token_json = await exchange_code_for_tokens(
|
||||
token_endpoint=token_endpoint,
|
||||
code=code,
|
||||
redirect_uri=redirect_uri,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
code_verifier=code_verifier,
|
||||
)
|
||||
|
||||
access_token = token_json.get("access_token")
|
||||
refresh_token = token_json.get("refresh_token")
|
||||
expires_in = token_json.get("expires_in")
|
||||
scope = token_json.get("scope")
|
||||
|
||||
if not access_token and "authed_user" in token_json:
|
||||
authed = token_json["authed_user"]
|
||||
access_token = authed.get("access_token")
|
||||
refresh_token = refresh_token or authed.get("refresh_token")
|
||||
scope = scope or authed.get("scope")
|
||||
expires_in = expires_in or authed.get("expires_in")
|
||||
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"No access token received from {svc.name}.",
|
||||
)
|
||||
|
||||
expires_at = None
|
||||
if expires_in:
|
||||
expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=int(expires_in)
|
||||
)
|
||||
|
||||
connector_config = {
|
||||
"server_config": {
|
||||
"transport": "streamable-http",
|
||||
"url": mcp_url,
|
||||
},
|
||||
"mcp_service": svc_key,
|
||||
"mcp_oauth": {
|
||||
"client_id": client_id,
|
||||
"client_secret": enc.encrypt_token(client_secret) if client_secret else "",
|
||||
"token_endpoint": token_endpoint,
|
||||
"access_token": enc.encrypt_token(access_token),
|
||||
"refresh_token": enc.encrypt_token(refresh_token) if refresh_token else None,
|
||||
"expires_at": expires_at.isoformat() if expires_at else None,
|
||||
"scope": scope,
|
||||
},
|
||||
"_token_encrypted": True,
|
||||
}
|
||||
|
||||
account_meta = await _fetch_account_metadata(svc_key, access_token, token_json)
|
||||
if account_meta:
|
||||
_SAFE_META_KEYS = {"display_name", "team_id", "team_name", "user_id", "user_email",
|
||||
"workspace_id", "workspace_name", "organization_name",
|
||||
"organization_url_key", "cloud_id", "site_name", "base_url"}
|
||||
for k, v in account_meta.items():
|
||||
if k in _SAFE_META_KEYS:
|
||||
connector_config[k] = v
|
||||
logger.info(
|
||||
"Stored account metadata for %s: display_name=%s",
|
||||
svc_key, account_meta.get("display_name", ""),
|
||||
)
|
||||
|
||||
# ---- Re-auth path ----
|
||||
db_connector_type = SearchSourceConnectorType(svc.connector_type)
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
if reauth_connector_id:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type == db_connector_type,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found during re-auth",
|
||||
)
|
||||
|
||||
db_connector.config = connector_config
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
_invalidate_cache(space_id)
|
||||
|
||||
logger.info(
|
||||
"Re-authenticated %s MCP connector %s for user %s",
|
||||
svc.name, db_connector.id, user_id,
|
||||
)
|
||||
reauth_return_url = data.get("return_url")
|
||||
if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return _frontend_redirect(
|
||||
space_id, success=True, connector_id=db_connector.id, service=service,
|
||||
)
|
||||
|
||||
# ---- New connector path ----
|
||||
naming_identifier = account_meta.get("display_name")
|
||||
connector_name = await generate_unique_connector_name(
|
||||
session,
|
||||
db_connector_type,
|
||||
space_id,
|
||||
user_id,
|
||||
naming_identifier,
|
||||
)
|
||||
|
||||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=db_connector_type,
|
||||
is_indexable=False,
|
||||
config=connector_config,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
session.add(new_connector)
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except IntegrityError as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409, detail="A connector for this service already exists.",
|
||||
) from e
|
||||
|
||||
_invalidate_cache(space_id)
|
||||
|
||||
logger.info(
|
||||
"Created %s MCP connector %s for user %s in space %s",
|
||||
svc.name, new_connector.id, user_id, space_id,
|
||||
)
|
||||
return _frontend_redirect(
|
||||
space_id, success=True, connector_id=new_connector.id, service=service,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to complete %s MCP OAuth: %s", service, e, exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to complete {service} MCP OAuth.",
|
||||
) from e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /reauth — re-authenticate an existing MCP connector
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/auth/mcp/{service}/connector/reauth")
|
||||
async def reauth_mcp_service(
|
||||
service: str,
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
from app.services.mcp_oauth.registry import get_service
|
||||
|
||||
svc = get_service(service)
|
||||
if not svc:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}")
|
||||
|
||||
db_connector_type = SearchSourceConnectorType(svc.connector_type)
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type == db_connector_type,
|
||||
)
|
||||
)
|
||||
if not result.scalars().first():
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Connector not found or access denied",
|
||||
)
|
||||
|
||||
try:
|
||||
from app.services.mcp_oauth.discovery import (
|
||||
discover_oauth_metadata,
|
||||
register_client,
|
||||
)
|
||||
|
||||
metadata = await discover_oauth_metadata(
|
||||
svc.mcp_url, origin_override=svc.oauth_discovery_origin,
|
||||
)
|
||||
auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint")
|
||||
token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint")
|
||||
registration_endpoint = metadata.get("registration_endpoint")
|
||||
|
||||
if not auth_endpoint or not token_endpoint:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"{svc.name} MCP server returned incomplete OAuth metadata.",
|
||||
)
|
||||
|
||||
redirect_uri = _build_redirect_uri(service)
|
||||
|
||||
if svc.supports_dcr and registration_endpoint:
|
||||
dcr = await register_client(registration_endpoint, redirect_uri)
|
||||
client_id = dcr.get("client_id")
|
||||
client_secret = dcr.get("client_secret", "")
|
||||
if not client_id:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"DCR for {svc.name} did not return a client_id.",
|
||||
)
|
||||
elif svc.client_id_env:
|
||||
client_id = getattr(config, svc.client_id_env, None)
|
||||
client_secret = getattr(config, svc.client_secret_env or "", None) or ""
|
||||
if not client_id:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"{svc.name} MCP OAuth not configured ({svc.client_id_env}).",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"{svc.name} MCP server has no DCR and no fallback credentials.",
|
||||
)
|
||||
|
||||
verifier, challenge = generate_pkce_pair()
|
||||
enc = _get_token_encryption()
|
||||
|
||||
extra: dict = {
|
||||
"service": service,
|
||||
"code_verifier": verifier,
|
||||
"mcp_client_id": client_id,
|
||||
"mcp_client_secret": enc.encrypt_token(client_secret) if client_secret else "",
|
||||
"mcp_token_endpoint": token_endpoint,
|
||||
"mcp_url": svc.mcp_url,
|
||||
"connector_id": connector_id,
|
||||
}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
|
||||
state = _get_state_manager().generate_secure_state(
|
||||
space_id, user.id, **extra,
|
||||
)
|
||||
|
||||
auth_params: dict[str, str] = {
|
||||
"client_id": client_id,
|
||||
"response_type": "code",
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_challenge": challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"state": state,
|
||||
}
|
||||
if svc.scopes:
|
||||
auth_params[svc.scope_param] = " ".join(svc.scopes)
|
||||
|
||||
auth_url = f"{auth_endpoint}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
"Initiating %s MCP re-auth for user %s, connector %s",
|
||||
svc.name, user.id, connector_id,
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to initiate %s MCP re-auth: %s", service, e, exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to initiate {service} MCP re-auth.",
|
||||
) from e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _invalidate_cache(space_id: int) -> None:
|
||||
try:
|
||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
||||
|
||||
invalidate_mcp_tools_cache(space_id)
|
||||
except Exception:
|
||||
logger.debug("MCP cache invalidation skipped", exc_info=True)
|
||||
|
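Both the /add and /reauth handlers above rely on generate_pkce_pair() for the S256 challenge. For reference, the standard RFC 7636 construction looks roughly like this (a standalone sketch, not necessarily the project helper's exact implementation):

# RFC 7636 S256 sketch: a random verifier and its base64url SHA-256 challenge.
import base64
import hashlib
import secrets


def pkce_pair() -> tuple[str, str]:
    verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode()
    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
    return verifier, challenge


verifier, challenge = pkce_pair()
# `challenge` goes into the authorization URL (code_challenge_method=S256);
# `verifier` stays server-side in the signed state and is sent at token exchange.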
|
@ -22,6 +22,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import (
|
||||
ClientPlatform,
|
||||
LocalFilesystemMount,
|
||||
FilesystemMode,
|
||||
FilesystemSelection,
|
||||
)
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ChatComment,
|
||||
ChatVisibility,
|
||||
|
|
@ -36,6 +43,7 @@ from app.db import (
|
|||
)
|
||||
from app.schemas.new_chat import (
|
||||
AgentToolInfo,
|
||||
LocalFilesystemMountPayload,
|
||||
NewChatMessageRead,
|
||||
NewChatRequest,
|
||||
NewChatThreadCreate,
|
||||
|
|
@ -63,6 +71,67 @@ _background_tasks: set[asyncio.Task] = set()
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
def _resolve_filesystem_selection(
|
||||
*,
|
||||
mode: str,
|
||||
client_platform: str,
|
||||
local_mounts: list[LocalFilesystemMountPayload] | None,
|
||||
) -> FilesystemSelection:
|
||||
"""Validate and normalize filesystem mode settings from request payload."""
|
||||
try:
|
||||
resolved_mode = FilesystemMode(mode)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail="Invalid filesystem_mode") from exc
|
||||
try:
|
||||
resolved_platform = ClientPlatform(client_platform)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail="Invalid client_platform") from exc
|
||||
|
||||
if resolved_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
|
||||
if not config.ENABLE_DESKTOP_LOCAL_FILESYSTEM:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Desktop local filesystem mode is disabled on this deployment.",
|
||||
)
|
||||
if resolved_platform != ClientPlatform.DESKTOP:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="desktop_local_folder mode is only available on desktop runtime.",
|
||||
)
|
||||
normalized_mounts: list[tuple[str, str]] = []
|
||||
seen_mounts: set[str] = set()
|
||||
for mount in local_mounts or []:
|
||||
mount_id = mount.mount_id.strip()
|
||||
root_path = mount.root_path.strip()
|
||||
if not mount_id or not root_path:
|
||||
continue
|
||||
if mount_id in seen_mounts:
|
||||
continue
|
||||
seen_mounts.add(mount_id)
|
||||
normalized_mounts.append((mount_id, root_path))
|
||||
if not normalized_mounts:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
"local_filesystem_mounts must include at least one mount for "
|
||||
"desktop_local_folder mode."
|
||||
),
|
||||
)
|
||||
return FilesystemSelection(
|
||||
mode=resolved_mode,
|
||||
client_platform=resolved_platform,
|
||||
local_mounts=tuple(
|
||||
LocalFilesystemMount(mount_id=mount_id, root_path=root_path)
|
||||
for mount_id, root_path in normalized_mounts
|
||||
),
|
||||
)
|
||||
|
||||
return FilesystemSelection(
|
||||
mode=FilesystemMode.CLOUD,
|
||||
client_platform=resolved_platform,
|
||||
)
|
||||
|
||||
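In practice this means a desktop-local request must be flagged as coming from the desktop runtime and carry at least one non-empty mount, while cloud mode needs neither. An illustrative payload fragment that would pass validation (the enum string values and paths are assumptions, not confirmed by this diff):

# Sketch of the relevant request fields; values are placeholders.
desktop_payload = {
    "filesystem_mode": "desktop_local_folder",  # assumed FilesystemMode value
    "client_platform": "desktop",               # assumed ClientPlatform value
    "local_filesystem_mounts": [
        {"mount_id": "notes", "root_path": "/Users/alex/Documents/notes"},
    ],
}

cloud_payload = {
    "filesystem_mode": "cloud",                 # assumed default mode value
    "client_platform": "web",
    "local_filesystem_mounts": [],
}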
|
||||
def _try_delete_sandbox(thread_id: int) -> None:
|
||||
"""Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
|
||||
from app.agents.new_chat.sandbox import (
|
||||
|
|
@@ -1098,6 +1167,7 @@ async def list_agent_tools(
@router.post("/new_chat")
async def handle_new_chat(
    request: NewChatRequest,
    http_request: Request,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):

@@ -1133,6 +1203,11 @@ async def handle_new_chat(

    # Check thread-level access based on visibility
    await check_thread_access(session, thread, user)
    filesystem_selection = _resolve_filesystem_selection(
        mode=request.filesystem_mode,
        client_platform=request.client_platform,
        local_mounts=request.local_filesystem_mounts,
    )

    # Get search space to check LLM config preferences
    search_space_result = await session.execute(

@@ -1175,6 +1250,8 @@ async def handle_new_chat(
            thread_visibility=thread.visibility,
            current_user_display_name=user.display_name or "A team member",
            disabled_tools=request.disabled_tools,
            filesystem_selection=filesystem_selection,
            request_id=getattr(http_request.state, "request_id", "unknown"),
        ),
        media_type="text/event-stream",
        headers={

@@ -1202,6 +1279,7 @@ async def handle_new_chat(
async def regenerate_response(
    thread_id: int,
    request: RegenerateRequest,
    http_request: Request,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):

@@ -1247,6 +1325,11 @@ async def regenerate_response(

    # Check thread-level access based on visibility
    await check_thread_access(session, thread, user)
    filesystem_selection = _resolve_filesystem_selection(
        mode=request.filesystem_mode,
        client_platform=request.client_platform,
        local_mounts=request.local_filesystem_mounts,
    )

    # Get the checkpointer and state history
    checkpointer = await get_checkpointer()

@@ -1412,6 +1495,8 @@ async def regenerate_response(
                thread_visibility=thread.visibility,
                current_user_display_name=user.display_name or "A team member",
                disabled_tools=request.disabled_tools,
                filesystem_selection=filesystem_selection,
                request_id=getattr(http_request.state, "request_id", "unknown"),
            ):
                yield chunk
            streaming_completed = True

@@ -1477,6 +1562,7 @@ async def regenerate_response(
async def resume_chat(
    thread_id: int,
    request: ResumeRequest,
    http_request: Request,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):

@@ -1498,6 +1584,11 @@ async def resume_chat(
        )

    await check_thread_access(session, thread, user)
    filesystem_selection = _resolve_filesystem_selection(
        mode=request.filesystem_mode,
        client_platform=request.client_platform,
        local_mounts=request.local_filesystem_mounts,
    )

    search_space_result = await session.execute(
        select(SearchSpace).filter(SearchSpace.id == request.search_space_id)

@@ -1526,6 +1617,8 @@ async def resume_chat(
            user_id=str(user.id),
            llm_config_id=llm_config_id,
            thread_visibility=thread.visibility,
            filesystem_selection=filesystem_selection,
            request_id=getattr(http_request.state, "request_id", "unknown"),
        ),
        media_type="text/event-stream",
        headers={

surfsense_backend/app/routes/oauth_connector_base.py (new file, 620 lines)

@@ -0,0 +1,620 @@
"""Reusable base for OAuth 2.0 connector routes.
|
||||
|
||||
Subclasses override ``fetch_account_info``, ``build_connector_config``,
|
||||
and ``get_connector_display_name`` to customise provider-specific behaviour.
|
||||
Call ``build_router()`` to get a FastAPI ``APIRouter`` with ``/connector/add``,
|
||||
``/connector/callback``, and ``/connector/reauth`` endpoints.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import RedirectResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthConnectorRoute:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
provider_name: str,
|
||||
connector_type: SearchSourceConnectorType,
|
||||
authorize_url: str,
|
||||
token_url: str,
|
||||
client_id_env: str,
|
||||
client_secret_env: str,
|
||||
redirect_uri_env: str,
|
||||
scopes: list[str],
|
||||
auth_prefix: str,
|
||||
use_pkce: bool = False,
|
||||
token_auth_method: str = "body",
|
||||
is_indexable: bool = True,
|
||||
extra_auth_params: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
self.provider_name = provider_name
|
||||
self.connector_type = connector_type
|
||||
self.authorize_url = authorize_url
|
||||
self.token_url = token_url
|
||||
self.client_id_env = client_id_env
|
||||
self.client_secret_env = client_secret_env
|
||||
self.redirect_uri_env = redirect_uri_env
|
||||
self.scopes = scopes
|
||||
self.auth_prefix = auth_prefix.rstrip("/")
|
||||
self.use_pkce = use_pkce
|
||||
self.token_auth_method = token_auth_method
|
||||
self.is_indexable = is_indexable
|
||||
self.extra_auth_params = extra_auth_params or {}
|
||||
|
||||
self._state_manager: OAuthStateManager | None = None
|
||||
self._token_encryption: TokenEncryption | None = None
|
||||
|
||||
def _get_client_id(self) -> str:
|
||||
value = getattr(config, self.client_id_env, None)
|
||||
if not value:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"{self.provider_name.title()} OAuth not configured "
|
||||
f"({self.client_id_env} missing).",
|
||||
)
|
||||
return value
|
||||
|
||||
def _get_client_secret(self) -> str:
|
||||
value = getattr(config, self.client_secret_env, None)
|
||||
if not value:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"{self.provider_name.title()} OAuth not configured "
|
||||
f"({self.client_secret_env} missing).",
|
||||
)
|
||||
return value
|
||||
|
||||
def _get_redirect_uri(self) -> str:
|
||||
value = getattr(config, self.redirect_uri_env, None)
|
||||
if not value:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"{self.redirect_uri_env} not configured.",
|
||||
)
|
||||
return value
|
||||
|
||||
def _get_state_manager(self) -> OAuthStateManager:
|
||||
if self._state_manager is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="SECRET_KEY not configured for OAuth security.",
|
||||
)
|
||||
self._state_manager = OAuthStateManager(config.SECRET_KEY)
|
||||
return self._state_manager
|
||||
|
||||
def _get_token_encryption(self) -> TokenEncryption:
|
||||
if self._token_encryption is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="SECRET_KEY not configured for token encryption.",
|
||||
)
|
||||
self._token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
return self._token_encryption
|
||||
|
||||
def _frontend_redirect(
|
||||
self,
|
||||
space_id: int | None,
|
||||
*,
|
||||
success: bool = False,
|
||||
connector_id: int | None = None,
|
||||
error: str | None = None,
|
||||
) -> RedirectResponse:
|
||||
if success and space_id:
|
||||
connector_slug = f"{self.provider_name}-connector"
|
||||
qs = f"success=true&connector={connector_slug}"
|
||||
if connector_id:
|
||||
qs += f"&connectorId={connector_id}"
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?{qs}"
|
||||
)
|
||||
if error and space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error={error}"
|
||||
)
|
||||
if error:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard?error={error}"
|
||||
)
|
||||
return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard")
|
||||
|
||||
async def fetch_account_info(self, access_token: str) -> dict[str, Any]:
|
||||
"""Override to fetch account/workspace info after token exchange.
|
||||
|
||||
Return dict is merged into connector config; key ``"name"`` is used
|
||||
for the display name and dedup.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def build_connector_config(
|
||||
self,
|
||||
token_json: dict[str, Any],
|
||||
account_info: dict[str, Any],
|
||||
encryption: TokenEncryption,
|
||||
) -> dict[str, Any]:
|
||||
"""Override for custom config shapes. Default: standard encrypted OAuth fields."""
|
||||
access_token = token_json.get("access_token", "")
|
||||
refresh_token = token_json.get("refresh_token")
|
||||
|
||||
expires_at = None
|
||||
if token_json.get("expires_in"):
|
||||
expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=int(token_json["expires_in"])
|
||||
)
|
||||
|
||||
cfg: dict[str, Any] = {
|
||||
"access_token": encryption.encrypt_token(access_token),
|
||||
"refresh_token": (
|
||||
encryption.encrypt_token(refresh_token) if refresh_token else None
|
||||
),
|
||||
"token_type": token_json.get("token_type", "Bearer"),
|
||||
"expires_in": token_json.get("expires_in"),
|
||||
"expires_at": expires_at.isoformat() if expires_at else None,
|
||||
"scope": token_json.get("scope"),
|
||||
"_token_encrypted": True,
|
||||
}
|
||||
cfg.update(account_info)
|
||||
return cfg
|
||||
|
||||
def get_connector_display_name(self, account_info: dict[str, Any]) -> str:
|
||||
return str(account_info.get("name", self.provider_name.title()))
|
||||
|
||||
async def on_token_refresh_failure(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
connector: SearchSourceConnector,
|
||||
) -> None:
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired flag for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _exchange_code(
|
||||
self, code: str, extra_state: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
client_id = self._get_client_id()
|
||||
client_secret = self._get_client_secret()
|
||||
redirect_uri = self._get_redirect_uri()
|
||||
|
||||
headers: dict[str, str] = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
body: dict[str, str] = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
if self.token_auth_method == "basic":
|
||||
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
|
||||
headers["Authorization"] = f"Basic {creds}"
|
||||
else:
|
||||
body["client_id"] = client_id
|
||||
body["client_secret"] = client_secret
|
||||
|
||||
if self.use_pkce:
|
||||
verifier = extra_state.get("code_verifier")
|
||||
if verifier:
|
||||
body["code_verifier"] = verifier
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
self.token_url, data=body, headers=headers, timeout=30.0
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
detail = resp.text
|
||||
try:
|
||||
detail = resp.json().get("error_description", detail)
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Token exchange failed: {detail}"
|
||||
)
|
||||
|
||||
return resp.json()
|
||||
|
||||
async def refresh_token(
|
||||
self, session: AsyncSession, connector: SearchSourceConnector
|
||||
) -> SearchSourceConnector:
|
||||
encryption = self._get_token_encryption()
|
||||
is_encrypted = connector.config.get("_token_encrypted", False)
|
||||
|
||||
refresh_tok = connector.config.get("refresh_token")
|
||||
if is_encrypted and refresh_tok:
|
||||
try:
|
||||
refresh_tok = encryption.decrypt_token(refresh_tok)
|
||||
except Exception as e:
|
||||
logger.error("Failed to decrypt refresh token: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to decrypt stored refresh token"
|
||||
) from e
|
||||
|
||||
if not refresh_tok:
|
||||
await self.on_token_refresh_failure(session, connector)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No refresh token available. Please re-authenticate.",
|
||||
)
|
||||
|
||||
client_id = self._get_client_id()
|
||||
client_secret = self._get_client_secret()
|
||||
|
||||
headers: dict[str, str] = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
body: dict[str, str] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_tok,
|
||||
}
|
||||
|
||||
if self.token_auth_method == "basic":
|
||||
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
|
||||
headers["Authorization"] = f"Basic {creds}"
|
||||
else:
|
||||
body["client_id"] = client_id
|
||||
body["client_secret"] = client_secret
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
self.token_url, data=body, headers=headers, timeout=30.0
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
error_detail = resp.text
|
||||
try:
|
||||
ej = resp.json()
|
||||
error_detail = ej.get("error_description", error_detail)
|
||||
error_code = ej.get("error", "")
|
||||
except Exception:
|
||||
error_code = ""
|
||||
combined = (error_detail + error_code).lower()
|
||||
if any(kw in combined for kw in ("invalid_grant", "expired", "revoked")):
|
||||
await self.on_token_refresh_failure(session, connector)
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"{self.provider_name.title()} authentication failed. "
|
||||
"Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Token refresh failed: {error_detail}"
|
||||
)
|
||||
|
||||
token_json = resp.json()
|
||||
new_access = token_json.get("access_token")
|
||||
if not new_access:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No access token received from refresh"
|
||||
)
|
||||
|
||||
expires_at = None
|
||||
if token_json.get("expires_in"):
|
||||
expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=int(token_json["expires_in"])
|
||||
)
|
||||
|
||||
updated_config = dict(connector.config)
|
||||
updated_config["access_token"] = encryption.encrypt_token(new_access)
|
||||
new_refresh = token_json.get("refresh_token")
|
||||
if new_refresh:
|
||||
updated_config["refresh_token"] = encryption.encrypt_token(new_refresh)
|
||||
updated_config["expires_in"] = token_json.get("expires_in")
|
||||
updated_config["expires_at"] = expires_at.isoformat() if expires_at else None
|
||||
updated_config["scope"] = token_json.get("scope", updated_config.get("scope"))
|
||||
updated_config["_token_encrypted"] = True
|
||||
updated_config.pop("auth_expired", None)
|
||||
|
||||
connector.config = updated_config
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
|
||||
logger.info(
|
||||
"Refreshed %s token for connector %s",
|
||||
self.provider_name,
|
||||
connector.id,
|
||||
)
|
||||
return connector
|
||||
|
||||
def build_router(self) -> APIRouter:
|
||||
router = APIRouter()
|
||||
oauth = self
|
||||
|
||||
@router.get(f"{oauth.auth_prefix}/connector/add")
|
||||
async def connect(
|
||||
space_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
||||
client_id = oauth._get_client_id()
|
||||
state_mgr = oauth._get_state_manager()
|
||||
|
||||
extra_state: dict[str, Any] = {}
|
||||
auth_params: dict[str, str] = {
|
||||
"client_id": client_id,
|
||||
"response_type": "code",
|
||||
"redirect_uri": oauth._get_redirect_uri(),
|
||||
"scope": " ".join(oauth.scopes),
|
||||
}
|
||||
|
||||
if oauth.use_pkce:
|
||||
from app.utils.oauth_security import generate_pkce_pair
|
||||
|
||||
verifier, challenge = generate_pkce_pair()
|
||||
extra_state["code_verifier"] = verifier
|
||||
auth_params["code_challenge"] = challenge
|
||||
auth_params["code_challenge_method"] = "S256"
|
||||
|
||||
auth_params.update(oauth.extra_auth_params)
|
||||
|
||||
state_encoded = state_mgr.generate_secure_state(
|
||||
space_id, user.id, **extra_state
|
||||
)
|
||||
auth_params["state"] = state_encoded
|
||||
auth_url = f"{oauth.authorize_url}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
"Generated %s OAuth URL for user %s, space %s",
|
||||
oauth.provider_name,
|
||||
user.id,
|
||||
space_id,
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
@router.get(f"{oauth.auth_prefix}/connector/reauth")
|
||||
async def reauth(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type == oauth.connector_type,
|
||||
)
|
||||
)
|
||||
if not result.scalars().first():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"{oauth.provider_name.title()} connector not found "
|
||||
"or access denied",
|
||||
)
|
||||
|
||||
client_id = oauth._get_client_id()
|
||||
state_mgr = oauth._get_state_manager()
|
||||
|
||||
extra: dict[str, Any] = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/") and not return_url.startswith("//"):
|
||||
extra["return_url"] = return_url
|
||||
|
||||
auth_params: dict[str, str] = {
|
||||
"client_id": client_id,
|
||||
"response_type": "code",
|
||||
"redirect_uri": oauth._get_redirect_uri(),
|
||||
"scope": " ".join(oauth.scopes),
|
||||
}
|
||||
|
||||
if oauth.use_pkce:
|
||||
from app.utils.oauth_security import generate_pkce_pair
|
||||
|
||||
verifier, challenge = generate_pkce_pair()
|
||||
extra["code_verifier"] = verifier
|
||||
auth_params["code_challenge"] = challenge
|
||||
auth_params["code_challenge_method"] = "S256"
|
||||
|
||||
auth_params.update(oauth.extra_auth_params)
|
||||
|
||||
state_encoded = state_mgr.generate_secure_state(
|
||||
space_id, user.id, **extra
|
||||
)
|
||||
auth_params["state"] = state_encoded
|
||||
auth_url = f"{oauth.authorize_url}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
"Initiating %s re-auth for user %s, connector %s",
|
||||
oauth.provider_name,
|
||||
user.id,
|
||||
connector_id,
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
@router.get(f"{oauth.auth_prefix}/connector/callback")
|
||||
async def callback(
|
||||
code: str | None = None,
|
||||
error: str | None = None,
|
||||
state: str | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
error_label = f"{oauth.provider_name}_oauth_denied"
|
||||
|
||||
if error:
|
||||
logger.warning("%s OAuth error: %s", oauth.provider_name, error)
|
||||
space_id = None
|
||||
if state:
|
||||
try:
|
||||
data = oauth._get_state_manager().validate_state(state)
|
||||
space_id = data.get("space_id")
|
||||
except Exception:
|
||||
pass
|
||||
return oauth._frontend_redirect(space_id, error=error_label)
|
||||
|
||||
if not code:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing authorization code"
|
||||
)
|
||||
if not state:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing state parameter"
|
||||
)
|
||||
|
||||
state_mgr = oauth._get_state_manager()
|
||||
try:
|
||||
data = state_mgr.validate_state(state)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid or expired state parameter."
|
||||
) from e
|
||||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
|
||||
token_json = await oauth._exchange_code(code, data)
|
||||
|
||||
access_token = token_json.get("access_token", "")
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"No access token received from {oauth.provider_name.title()}",
|
||||
)
|
||||
|
||||
account_info = await oauth.fetch_account_info(access_token)
|
||||
encryption = oauth._get_token_encryption()
|
||||
connector_config = oauth.build_connector_config(
|
||||
token_json, account_info, encryption
|
||||
)
|
||||
|
||||
display_name = oauth.get_connector_display_name(account_info)
|
||||
|
||||
# --- Re-auth path ---
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
if reauth_connector_id:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type == oauth.connector_type,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
db_connector.config = connector_config
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
"Re-authenticated %s connector %s for user %s",
|
||||
oauth.provider_name,
|
||||
db_connector.id,
|
||||
user_id,
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return oauth._frontend_redirect(
|
||||
space_id, success=True, connector_id=db_connector.id
|
||||
)
|
||||
|
||||
# --- New connector path ---
|
||||
is_dup = await check_duplicate_connector(
|
||||
session,
|
||||
oauth.connector_type,
|
||||
space_id,
|
||||
user_id,
|
||||
display_name,
|
||||
)
|
||||
if is_dup:
|
||||
logger.warning(
|
||||
"Duplicate %s connector for user %s (%s)",
|
||||
oauth.provider_name,
|
||||
user_id,
|
||||
display_name,
|
||||
)
|
||||
return oauth._frontend_redirect(
|
||||
space_id,
|
||||
error=f"duplicate_account&connector={oauth.provider_name}-connector",
|
||||
)
|
||||
|
||||
connector_name = await generate_unique_connector_name(
|
||||
session,
|
||||
oauth.connector_type,
|
||||
space_id,
|
||||
user_id,
|
||||
display_name,
|
||||
)
|
||||
|
||||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=oauth.connector_type,
|
||||
is_indexable=oauth.is_indexable,
|
||||
config=connector_config,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
session.add(new_connector)
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except IntegrityError as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409, detail="A connector for this service already exists."
|
||||
) from e
|
||||
|
||||
logger.info(
|
||||
"Created %s connector %s for user %s in space %s",
|
||||
oauth.provider_name,
|
||||
new_connector.id,
|
||||
user_id,
|
||||
space_id,
|
||||
)
|
||||
return oauth._frontend_redirect(
|
||||
space_id, success=True, connector_id=new_connector.id
|
||||
)
|
||||
|
||||
return router
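# A minimal sketch of how a provider module might use this base class; the
# provider name, URLs, env-var names, scopes, and connector type below are
# hypothetical placeholders, not real SurfSense configuration.
#
#   example_oauth = OAuthConnectorRoute(
#       provider_name="exampleapp",
#       connector_type=SearchSourceConnectorType.MCP_CONNECTOR,  # placeholder type
#       authorize_url="https://exampleapp.test/oauth/authorize",
#       token_url="https://exampleapp.test/oauth/token",
#       client_id_env="EXAMPLEAPP_CLIENT_ID",
#       client_secret_env="EXAMPLEAPP_CLIENT_SECRET",
#       redirect_uri_env="EXAMPLEAPP_REDIRECT_URI",
#       scopes=["read"],
#       auth_prefix="/auth/exampleapp",
#       use_pkce=True,
#   )
#   example_router = example_oauth.build_router()  # mounted alongside the other routers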
|
||||
|
|
@@ -693,27 +693,10 @@ async def index_connector_content(
    user: User = Depends(current_active_user),
):
    """
    Index content from a connector to a search space.
    Requires CONNECTORS_UPDATE permission (to trigger indexing).
    Index content from a KB connector to a search space.

    Currently supports:
    - SLACK_CONNECTOR: Indexes messages from all accessible Slack channels
    - TEAMS_CONNECTOR: Indexes messages from all accessible Microsoft Teams channels
    - NOTION_CONNECTOR: Indexes pages from all accessible Notion pages
    - GITHUB_CONNECTOR: Indexes code and documentation from GitHub repositories
    - LINEAR_CONNECTOR: Indexes issues and comments from Linear
    - JIRA_CONNECTOR: Indexes issues and comments from Jira
    - DISCORD_CONNECTOR: Indexes messages from all accessible Discord channels
    - LUMA_CONNECTOR: Indexes events from Luma
    - ELASTICSEARCH_CONNECTOR: Indexes documents from Elasticsearch
    - WEBCRAWLER_CONNECTOR: Indexes web pages from crawled websites

    Args:
        connector_id: ID of the connector to use
        search_space_id: ID of the search space to store indexed content

    Returns:
        Dictionary with indexing status
    Live connectors (Slack, Teams, Linear, Jira, ClickUp, Calendar, Airtable,
    Gmail, Discord, Luma) use real-time agent tools instead.
    """
    try:
        # Get the connector first
@@ -770,9 +753,7 @@ async def index_connector_content(

        # For calendar connectors, default to today but allow future dates if explicitly provided
        if connector.connector_type in [
            SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
            SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
            SearchSourceConnectorType.LUMA_CONNECTOR,
        ]:
            # Default to today if no end_date provided (users can manually select future dates)
            indexing_to = today_str if end_date is None else end_date
@@ -796,33 +777,22 @@ async def index_connector_content(
            # For non-calendar connectors, cap at today
            indexing_to = end_date if end_date else today_str

        if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
            from app.tasks.celery_tasks.connector_tasks import (
                index_slack_messages_task,
            )
        from app.services.mcp_oauth.registry import LIVE_CONNECTOR_TYPES

            logger.info(
                f"Triggering Slack indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
            )
            index_slack_messages_task.delay(
                connector_id, search_space_id, str(user.id), indexing_from, indexing_to
            )
            response_message = "Slack indexing started in the background."
        if connector.connector_type in LIVE_CONNECTOR_TYPES:
            return {
                "message": (
                    f"{connector.connector_type.value} uses real-time agent tools; "
                    "background indexing is disabled."
                ),
                "indexing_started": False,
                "connector_id": connector_id,
                "search_space_id": search_space_id,
                "indexing_from": indexing_from,
                "indexing_to": indexing_to,
            }

        elif connector.connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR:
            from app.tasks.celery_tasks.connector_tasks import (
                index_teams_messages_task,
            )

            logger.info(
                f"Triggering Teams indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
            )
            index_teams_messages_task.delay(
                connector_id, search_space_id, str(user.id), indexing_from, indexing_to
            )
            response_message = "Teams indexing started in the background."

        elif connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
        if connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
            from app.tasks.celery_tasks.connector_tasks import index_notion_pages_task

            logger.info(
@ -844,28 +814,6 @@ async def index_connector_content(
|
|||
)
|
||||
response_message = "GitHub indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.LINEAR_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import index_linear_issues_task
|
||||
|
||||
logger.info(
|
||||
f"Triggering Linear indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_linear_issues_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Linear indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.JIRA_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import index_jira_issues_task
|
||||
|
||||
logger.info(
|
||||
f"Triggering Jira indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_jira_issues_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Jira indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.CONFLUENCE_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_confluence_pages_task,
|
||||
|
|
@ -892,59 +840,6 @@ async def index_connector_content(
|
|||
)
|
||||
response_message = "BookStack indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.CLICKUP_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import index_clickup_tasks_task
|
||||
|
||||
logger.info(
|
||||
f"Triggering ClickUp indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_clickup_tasks_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "ClickUp indexing started in the background."
|
||||
|
||||
elif (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_google_calendar_events_task,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Triggering Google Calendar indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_google_calendar_events_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Google Calendar indexing started in the background."
|
||||
elif connector.connector_type == SearchSourceConnectorType.AIRTABLE_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_airtable_records_task,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Triggering Airtable indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_airtable_records_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Airtable indexing started in the background."
|
||||
elif (
|
||||
connector.connector_type == SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_google_gmail_messages_task,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Triggering Google Gmail indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_google_gmail_messages_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Google Gmail indexing started in the background."
|
||||
|
||||
elif (
|
||||
connector.connector_type == SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
|
|
@ -1089,30 +984,6 @@ async def index_connector_content(
|
|||
)
|
||||
response_message = "Dropbox indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_discord_messages_task,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Triggering Discord indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_discord_messages_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Discord indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.LUMA_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import index_luma_events_task
|
||||
|
||||
logger.info(
|
||||
f"Triggering Luma indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_luma_events_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Luma indexing started in the background."
|
||||
|
||||
elif (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR
|
||||
|
|
@ -1319,57 +1190,6 @@ async def _update_connector_timestamp_by_id(session: AsyncSession, connector_id:
|
|||
await session.rollback()
|
||||
|
||||
|
||||
async def run_slack_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Create a new session and run the Slack indexing task.
|
||||
This prevents session leaks by creating a dedicated session for the background task.
|
||||
"""
|
||||
async with async_session_maker() as session:
|
||||
await run_slack_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
async def run_slack_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run Slack indexing.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Slack connector
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_slack_messages
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
indexing_function=index_slack_messages,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
_AUTH_ERROR_PATTERNS = (
|
||||
"failed to refresh linear oauth",
|
||||
"failed to refresh your notion connection",
|
||||
|
|
@ -1908,215 +1728,6 @@ async def run_github_indexing(
|
|||
)
|
||||
|
||||
|
||||
# Add new helper functions for Linear indexing
|
||||
async def run_linear_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Wrapper to run Linear indexing with its own database session."""
|
||||
logger.info(
|
||||
f"Background task started: Indexing Linear connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
await run_linear_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
logger.info(f"Background task finished: Indexing Linear connector {connector_id}")
|
||||
|
||||
|
||||
async def run_linear_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run Linear indexing.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Linear connector
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_linear_issues
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
indexing_function=index_linear_issues,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
# Add new helper functions for discord indexing
|
||||
async def run_discord_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Create a new session and run the Discord indexing task.
|
||||
This prevents session leaks by creating a dedicated session for the background task.
|
||||
"""
|
||||
async with async_session_maker() as session:
|
||||
await run_discord_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
async def run_discord_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run Discord indexing.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Discord connector
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_discord_messages
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
indexing_function=index_discord_messages,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
async def run_teams_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Create a new session and run the Microsoft Teams indexing task.
|
||||
This prevents session leaks by creating a dedicated session for the background task.
|
||||
"""
|
||||
async with async_session_maker() as session:
|
||||
await run_teams_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
async def run_teams_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run Microsoft Teams indexing.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Teams connector
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers.teams_indexer import index_teams_messages
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
indexing_function=index_teams_messages,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
# Add new helper functions for Jira indexing
|
||||
async def run_jira_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Wrapper to run Jira indexing with its own database session."""
|
||||
logger.info(
|
||||
f"Background task started: Indexing Jira connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
await run_jira_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
logger.info(f"Background task finished: Indexing Jira connector {connector_id}")
|
||||
|
||||
|
||||
async def run_jira_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run Jira indexing.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Jira connector
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_jira_issues
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
indexing_function=index_jira_issues,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
# Add new helper functions for Confluence indexing
|
||||
async def run_confluence_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
|
|
@ -2172,112 +1783,6 @@ async def run_confluence_indexing(
|
|||
)
|
||||
|
||||
|
||||
# Add new helper functions for ClickUp indexing
|
||||
async def run_clickup_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Wrapper to run ClickUp indexing with its own database session."""
|
||||
logger.info(
|
||||
f"Background task started: Indexing ClickUp connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
await run_clickup_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
logger.info(f"Background task finished: Indexing ClickUp connector {connector_id}")
|
||||
|
||||
|
||||
async def run_clickup_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run ClickUp indexing.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the ClickUp connector
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_clickup_tasks
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
indexing_function=index_clickup_tasks,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
# Add new helper functions for Airtable indexing
|
||||
async def run_airtable_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Wrapper to run Airtable indexing with its own database session."""
|
||||
logger.info(
|
||||
f"Background task started: Indexing Airtable connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
await run_airtable_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
logger.info(f"Background task finished: Indexing Airtable connector {connector_id}")
|
||||
|
||||
|
||||
async def run_airtable_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run Airtable indexing.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Airtable connector
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_airtable_records
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
indexing_function=index_airtable_records,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
# Add new helper functions for Google Calendar indexing
|
||||
async def run_google_calendar_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
|
|
@ -2816,58 +2321,6 @@ async def run_dropbox_indexing(
|
|||
logger.error(f"Failed to update notification: {notif_error!s}")
|
||||
|
||||
|
||||
# Add new helper functions for luma indexing
|
||||
async def run_luma_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Create a new session and run the Luma indexing task.
|
||||
This prevents session leaks by creating a dedicated session for the background task.
|
||||
"""
|
||||
async with async_session_maker() as session:
|
||||
await run_luma_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
async def run_luma_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run Luma indexing.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Luma connector
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_luma_events
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
indexing_function=index_luma_events,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
async def run_elasticsearch_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
|
|
@@ -3580,13 +3033,18 @@ async def trust_mcp_tool(
    """Add a tool to the MCP connector's trusted (always-allow) list.

    Once trusted, the tool executes without HITL approval on subsequent calls.
    Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors
    (LINEAR_CONNECTOR, JIRA_CONNECTOR, etc.) by checking for ``server_config``.
    """
    try:
        from sqlalchemy import cast
        from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB

        result = await session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == connector_id,
                SearchSourceConnector.connector_type
                == SearchSourceConnectorType.MCP_CONNECTOR,
                SearchSourceConnector.user_id == user.id,
                cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"),  # noqa: W601
            )
        )
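        # For reference: ``has_key("server_config")`` on the JSONB-cast column is
        # SQLAlchemy's spelling of PostgreSQL's ``config ? 'server_config'`` check,
        # so only connectors that actually store an MCP ``server_config`` block
        # match this filter.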
        connector = result.scalars().first()

@@ -3631,13 +3089,17 @@ async def untrust_mcp_tool(
    """Remove a tool from the MCP connector's trusted list.

    The tool will require HITL approval again on subsequent calls.
    Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors.
    """
    try:
        from sqlalchemy import cast
        from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB

        result = await session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == connector_id,
                SearchSourceConnector.connector_type
                == SearchSourceConnectorType.MCP_CONNECTOR,
                SearchSourceConnector.user_id == user.id,
                cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"),  # noqa: W601
            )
        )
        connector = result.scalars().first()

@@ -312,7 +312,7 @@ async def slack_callback(
        new_connector = SearchSourceConnector(
            name=connector_name,
            connector_type=SearchSourceConnectorType.SLACK_CONNECTOR,
            is_indexable=True,
            is_indexable=False,
            config=connector_config,
            search_space_id=space_id,
            user_id=user_id,

@@ -45,6 +45,7 @@ SCOPES = [
    "Team.ReadBasic.All",  # Read basic team information
    "Channel.ReadBasic.All",  # Read basic channel information
    "ChannelMessage.Read.All",  # Read messages in channels
    "ChannelMessage.Send",  # Send messages in channels
]

# Initialize security utilities

@@ -320,7 +321,7 @@ async def teams_callback(
        new_connector = SearchSourceConnector(
            name=connector_name,
            connector_type=SearchSourceConnectorType.TEAMS_CONNECTOR,
            is_indexable=True,
            is_indexable=False,
            config=connector_config,
            search_space_id=space_id,
            user_id=user_id,
@@ -168,6 +168,11 @@ class ChatMessage(BaseModel):
    content: str


class LocalFilesystemMountPayload(BaseModel):
    mount_id: str
    root_path: str


class NewChatRequest(BaseModel):
    """Request schema for the deep agent chat endpoint."""

@@ -184,6 +189,9 @@ class NewChatRequest(BaseModel):
    disabled_tools: list[str] | None = (
        None  # Optional list of tool names the user has disabled from the UI
    )
    filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
    client_platform: Literal["web", "desktop"] = "web"
    local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None


class RegenerateRequest(BaseModel):

@@ -204,6 +212,9 @@ class RegenerateRequest(BaseModel):
    mentioned_document_ids: list[int] | None = None
    mentioned_surfsense_doc_ids: list[int] | None = None
    disabled_tools: list[str] | None = None
    filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
    client_platform: Literal["web", "desktop"] = "web"
    local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None


# =============================================================================

@@ -227,6 +238,9 @@ class ResumeDecision(BaseModel):
class ResumeRequest(BaseModel):
    search_space_id: int
    decisions: list[ResumeDecision]
    filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
    client_platform: Literal["web", "desktop"] = "web"
    local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
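# Illustrative payload for the new filesystem fields, shared by NewChatRequest,
# RegenerateRequest, and ResumeRequest; the mount id and path are hypothetical:
#
#   {
#       "filesystem_mode": "desktop_local_folder",
#       "client_platform": "desktop",
#       "local_filesystem_mounts": [
#           {"mount_id": "notes", "root_path": "/Users/me/Documents/notes"}
#       ]
#   }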


# =============================================================================

@@ -26,7 +26,7 @@ COMPOSIO_TOOLKIT_NAMES = {
}

# Toolkits that support indexing (Phase 1: Google services only)
INDEXABLE_TOOLKITS = {"googledrive", "gmail", "googlecalendar"}
INDEXABLE_TOOLKITS = {"googledrive"}
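# Illustrative gate (hypothetical names): a sync path would typically check
# membership before scheduling a Composio KB index, e.g.
#   if toolkit_slug in INDEXABLE_TOOLKITS:
#       schedule_composio_index(toolkit_slug, connector_id)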

# Mapping of toolkit IDs to connector types
TOOLKIT_TO_CONNECTOR_TYPE = {
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -66,6 +65,8 @@ class ConfluenceKBSyncService:
|
|||
if dup:
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
@ -184,6 +185,8 @@ class ConfluenceKBSyncService:
|
|||
|
||||
space_id = (document.document_metadata or {}).get("space_id", "")
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.db import Document, DocumentType
|
||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -73,6 +72,8 @@ class DropboxKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from datetime import datetime
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -78,6 +77,8 @@ class GmailKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ from app.db import (
|
|||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -91,6 +90,8 @@ class GoogleCalendarKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
@ -249,6 +250,8 @@ class GoogleCalendarKBSyncService:
|
|||
if not indexable_content:
|
||||
return {"status": "error", "message": "Event produced empty content"}
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from datetime import datetime
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -75,6 +74,8 @@ class GoogleDriveKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -75,6 +74,8 @@ class JiraKBSyncService:
|
|||
if dup:
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
@ -190,6 +191,8 @@ class JiraKBSyncService:
|
|||
state = formatted.get("status", "Unknown")
|
||||
comment_count = len(formatted.get("comments", []))
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.connectors.linear_connector import LinearConnector
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -85,6 +84,8 @@ class LinearKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
@ -226,6 +227,8 @@ class LinearKBSyncService:
|
|||
comment_count = len(formatted_issue.get("comments", []))
|
||||
formatted_issue.get("description", "")
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||
)
|
||||
|
|
|
|||
|
|
@@ -133,6 +133,44 @@ PROVIDER_MAP = {
}


# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when
# a global LLM config does *not* specify ``api_base``: without this, LiteLLM
# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``,
# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku``
# request to an Azure endpoint, which then 404s with ``Resource not found``.
# Only providers with a well-known, stable public base URL are listed here —
# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
# huggingface, databricks, cloudflare, replicate) are intentionally omitted
# so their existing config-driven behaviour is preserved.
PROVIDER_DEFAULT_API_BASE = {
    "openrouter": "https://openrouter.ai/api/v1",
    "groq": "https://api.groq.com/openai/v1",
    "mistral": "https://api.mistral.ai/v1",
    "perplexity": "https://api.perplexity.ai",
    "xai": "https://api.x.ai/v1",
    "cerebras": "https://api.cerebras.ai/v1",
    "deepinfra": "https://api.deepinfra.com/v1/openai",
    "fireworks_ai": "https://api.fireworks.ai/inference/v1",
    "together_ai": "https://api.together.xyz/v1",
    "anyscale": "https://api.endpoints.anyscale.com/v1",
    "cometapi": "https://api.cometapi.com/v1",
    "sambanova": "https://api.sambanova.ai/v1",
}


# Canonical provider → base URL when a config uses a generic ``openai``-style
# prefix but the ``provider`` field tells us which API it really is
# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but
# each has its own base URL).
PROVIDER_KEY_DEFAULT_API_BASE = {
    "DEEPSEEK": "https://api.deepseek.com/v1",
    "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
    "MOONSHOT": "https://api.moonshot.ai/v1",
    "ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
    "MINIMAX": "https://api.minimax.io/v1",
}
|
||||
|
||||
|
||||
class LLMRouterService:
|
||||
"""
|
||||
Singleton service for managing LiteLLM Router.
|
||||
|
|
@@ -224,6 +262,16 @@ class LLMRouterService:
        # hits ContextWindowExceededError.
        full_model_list, ctx_fallbacks = cls._build_context_fallback_groups(model_list)

        # Build a general-purpose fallback list so NotFound/timeout/rate-limit
        # style failures on one deployment don't bubble up as hard errors —
        # the router retries with a sibling deployment in ``auto-large``.
        # ``auto-large`` is the large-context subset of ``auto``; if it is
        # empty we fall back to ``auto`` itself so the router at least picks a
        # different deployment in the same group.
        fallbacks: list[dict[str, list[str]]] | None = None
        if ctx_fallbacks:
            fallbacks = [{"auto": ["auto-large"]}]

        try:
            router_kwargs: dict[str, Any] = {
                "model_list": full_model_list,
@@ -237,15 +285,24 @@ class LLMRouterService:
            }
            if ctx_fallbacks:
                router_kwargs["context_window_fallbacks"] = ctx_fallbacks
            if fallbacks:
                router_kwargs["fallbacks"] = fallbacks

            instance._router = Router(**router_kwargs)
            instance._initialized = True

            global _cached_context_profile, _cached_context_profile_computed
            _cached_context_profile = None
            _cached_context_profile_computed = False
            _router_instance_cache.clear()

            logger.info(
                "LLM Router initialized with %d deployments, "
                "strategy: %s, context_window_fallbacks: %s",
                "strategy: %s, context_window_fallbacks: %s, fallbacks: %s",
                len(model_list),
                final_settings.get("routing_strategy"),
                ctx_fallbacks or "none",
                fallbacks or "none",
            )
        except Exception as e:
            logger.error(f"Failed to initialize LLM Router: {e}")
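
For readers unfamiliar with LiteLLM's router-level fallbacks, the shape being assembled above can be illustrated with a minimal standalone sketch. The model names and keys below are illustrative only; the real ``model_list`` is built from the user's stored LLM configs.

from litellm import Router

router = Router(
    model_list=[
        {"model_name": "auto", "litellm_params": {"model": "openrouter/openai/gpt-4o-mini", "api_key": "sk-..."}},
        {"model_name": "auto-large", "litellm_params": {"model": "openrouter/google/gemini-2.5-pro", "api_key": "sk-..."}},
    ],
    # On context-window errors, retry the request against the large-context group.
    context_window_fallbacks=[{"auto": ["auto-large"]}],
    # On other per-deployment failures (timeouts, 404s, rate limits), retry in "auto-large".
    fallbacks=[{"auto": ["auto-large"]}],
)
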
@@ -348,10 +405,11 @@ class LLMRouterService:
            return None

        # Build model string
        provider = config.get("provider", "").upper()
        if config.get("custom_provider"):
            model_string = f"{config['custom_provider']}/{config['model_name']}"
            provider_prefix = config["custom_provider"]
            model_string = f"{provider_prefix}/{config['model_name']}"
        else:
            provider = config.get("provider", "").upper()
            provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
            model_string = f"{provider_prefix}/{config['model_name']}"

@@ -361,9 +419,19 @@ class LLMRouterService:
            "api_key": config.get("api_key"),
        }

        # Add optional api_base
        if config.get("api_base"):
            litellm_params["api_base"] = config["api_base"]
        # Resolve ``api_base``. Config value wins; otherwise apply a
        # provider-aware default so the deployment does not silently
        # inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
        # requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE``
        # docstring for the motivating bug (OpenRouter models 404-ing
        # against an Azure endpoint).
        api_base = config.get("api_base")
        if not api_base:
            api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider)
        if not api_base:
            api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix)
        if api_base:
            litellm_params["api_base"] = api_base

        # Add any additional litellm parameters
        if config.get("litellm_params"):

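
The precedence implemented above can be restated as a standalone sketch, assuming the two module-level maps defined earlier in this file. This is illustrative only, not the service code itself.

def resolve_api_base(config: dict, provider_key: str, provider_prefix: str) -> str | None:
    # 1. explicit config value, 2. canonical per-provider URL, 3. per-prefix default
    return (
        config.get("api_base")
        or PROVIDER_KEY_DEFAULT_API_BASE.get(provider_key)
        or PROVIDER_DEFAULT_API_BASE.get(provider_prefix)
    )

# e.g. an OpenRouter config with no explicit api_base:
# resolve_api_base({}, "OPENROUTER", "openrouter") == "https://openrouter.ai/api/v1"
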
@@ -1,3 +1,4 @@
import asyncio
import logging

import litellm

@@ -6,7 +7,6 @@ from langchain_litellm import ChatLiteLLM
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
from app.config import config
from app.db import NewLLMConfig, SearchSpace
from app.services.llm_router_service import (
@@ -32,6 +32,39 @@ litellm.callbacks = [token_tracker]
logger = logging.getLogger(__name__)


# Providers that require an interactive OAuth / device-flow login before
# issuing any completion. LiteLLM implements these with blocking sync polling
# (requests + time.sleep), which would freeze the FastAPI event loop if
# invoked from validation. They are never usable from a headless backend,
# so we reject them at the edge.
_INTERACTIVE_AUTH_PROVIDERS: frozenset[str] = frozenset(
    {
        "github_copilot",
        "github-copilot",
        "githubcopilot",
        "copilot",
    }
)

# Hard upper bound for a single validation call. Must exceed the ChatLiteLLM
# request timeout (30s) by a small margin so a well-behaved provider never
# trips the watchdog, while any pathological/blocking provider is killed.
_VALIDATION_TIMEOUT_SECONDS: float = 35.0


def _is_interactive_auth_provider(
    provider: str | None, custom_provider: str | None
) -> bool:
    """Return True if the given provider triggers interactive OAuth in LiteLLM."""
    for raw in (custom_provider, provider):
        if not raw:
            continue
        normalized = raw.strip().lower().replace(" ", "_")
        if normalized in _INTERACTIVE_AUTH_PROVIDERS:
            return True
    return False


class LLMRole:
    AGENT = "agent"  # For agent/chat operations
    DOCUMENT_SUMMARY = "document_summary"  # For document summarization
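
A quick sanity check of the normalization rule above, using hypothetical inputs, shown only to illustrate how the matching works:

assert _is_interactive_auth_provider("GitHub Copilot", None) is True   # "GitHub Copilot" normalizes to "github_copilot"
assert _is_interactive_auth_provider(None, "copilot") is True          # custom_provider is checked before provider
assert _is_interactive_auth_provider("OPENAI", None) is False          # API-key providers are unaffected
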
@@ -93,6 +126,25 @@ async def validate_llm_config(
    - is_valid: True if config works, False otherwise
    - error_message: Empty string if valid, error description if invalid
    """
    # Reject providers that require interactive OAuth/device-flow auth.
    # LiteLLM's github_copilot provider (and similar) uses a blocking sync
    # Authenticator that polls GitHub for up to several minutes and prints a
    # device code to stdout. Running it on the FastAPI event loop will freeze
    # the entire backend, so we refuse them up front.
    if _is_interactive_auth_provider(provider, custom_provider):
        msg = (
            "Provider requires interactive OAuth/device-flow authentication "
            "(e.g. github_copilot) and cannot be used in a hosted backend. "
            "Please choose a provider that authenticates via API key."
        )
        logger.warning(
            "Rejected LLM config validation for interactive-auth provider "
            "(provider=%r, custom_provider=%r)",
            provider,
            custom_provider,
        )
        return False, msg

    try:
        # Build the model string for litellm
        if custom_provider:
@@ -151,11 +203,34 @@ async def validate_llm_config(
        if litellm_params:
            litellm_kwargs.update(litellm_params)

        from app.agents.new_chat.llm_config import SanitizedChatLiteLLM

        llm = SanitizedChatLiteLLM(**litellm_kwargs)

        # Make a simple test call
        # Run the test call in a worker thread with a hard timeout. Some
        # LiteLLM providers have synchronous blocking code paths (e.g. OAuth
        # authenticators that call time.sleep and requests.post) that would
        # otherwise freeze the asyncio event loop. Offloading to a thread and
        # bounding the wait keeps the server responsive even if a provider
        # misbehaves.
        test_message = HumanMessage(content="Hello")
        response = await llm.ainvoke([test_message])
        try:
            response = await asyncio.wait_for(
                asyncio.to_thread(llm.invoke, [test_message]),
                timeout=_VALIDATION_TIMEOUT_SECONDS,
            )
        except TimeoutError:
            logger.warning(
                "LLM config validation timed out after %ss for model: %s",
                _VALIDATION_TIMEOUT_SECONDS,
                model_string,
            )
            return (
                False,
                f"Validation timed out after {int(_VALIDATION_TIMEOUT_SECONDS)}s. "
                "The provider is unreachable or requires interactive "
                "authentication that is not supported by the backend.",
            )

        # If we got here without exception, the config is valid
        if response and response.content:
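
The same offload-and-bound pattern in isolation, as a generic sketch rather than the SurfSense helper itself, assuming Python 3.11+ where asyncio timeouts surface as the builtin TimeoutError:

import asyncio
import time

def blocking_probe() -> str:
    # Stand-in for a provider SDK call that may sleep or poll synchronously.
    time.sleep(1)
    return "ok"

async def probe_with_deadline(timeout_s: float = 35.0) -> str | None:
    try:
        # to_thread keeps the event loop free; wait_for enforces a hard ceiling.
        return await asyncio.wait_for(asyncio.to_thread(blocking_probe), timeout=timeout_s)
    except TimeoutError:
        return None

# asyncio.run(probe_with_deadline()) returns "ok" for a well-behaved probe, None on timeout.
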
@ -303,6 +378,8 @@ async def get_search_space_llm_instance(
|
|||
if disable_streaming:
|
||||
litellm_kwargs["disable_streaming"] = True
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
# Get the LLM configuration from database (NewLLMConfig)
|
||||
|
|
@ -380,6 +457,8 @@ async def get_search_space_llm_instance(
|
|||
if disable_streaming:
|
||||
litellm_kwargs["disable_streaming"] = True
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -481,6 +560,8 @@ async def get_vision_llm(
|
|||
if global_cfg.get("litellm_params"):
|
||||
litellm_kwargs.update(global_cfg["litellm_params"])
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
result = await session.execute(
|
||||
|
|
@ -514,6 +595,8 @@ async def get_vision_llm(
|
|||
if vision_cfg.litellm_params:
|
||||
litellm_kwargs.update(vision_cfg.litellm_params)
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
0
surfsense_backend/app/services/mcp_oauth/__init__.py
Normal file
0
surfsense_backend/app/services/mcp_oauth/__init__.py
Normal file
121
surfsense_backend/app/services/mcp_oauth/discovery.py
Normal file
121
surfsense_backend/app/services/mcp_oauth/discovery.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
"""MCP OAuth 2.1 metadata discovery, Dynamic Client Registration, and token exchange."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def discover_oauth_metadata(
|
||||
mcp_url: str,
|
||||
*,
|
||||
origin_override: str | None = None,
|
||||
timeout: float = 15.0,
|
||||
) -> dict:
|
||||
"""Fetch OAuth 2.1 metadata from the MCP server's well-known endpoint.
|
||||
|
||||
Per the MCP spec the discovery document lives at the *origin* of the
|
||||
MCP server URL. ``origin_override`` can be used when the OAuth server
|
||||
lives on a different domain (e.g. Airtable: MCP at ``mcp.airtable.com``,
|
||||
OAuth at ``airtable.com``).
|
||||
"""
|
||||
if origin_override:
|
||||
origin = origin_override.rstrip("/")
|
||||
else:
|
||||
parsed = urlparse(mcp_url)
|
||||
origin = f"{parsed.scheme}://{parsed.netloc}"
|
||||
discovery_url = f"{origin}/.well-known/oauth-authorization-server"
|
||||
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
resp = await client.get(discovery_url, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def register_client(
|
||||
registration_endpoint: str,
|
||||
redirect_uri: str,
|
||||
*,
|
||||
client_name: str = "SurfSense",
|
||||
timeout: float = 15.0,
|
||||
) -> dict:
|
||||
"""Perform Dynamic Client Registration (RFC 7591)."""
|
||||
payload = {
|
||||
"client_name": client_name,
|
||||
"redirect_uris": [redirect_uri],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
resp = await client.post(
|
||||
registration_endpoint, json=payload, timeout=timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
token_endpoint: str,
|
||||
code: str,
|
||||
redirect_uri: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
code_verifier: str,
|
||||
*,
|
||||
timeout: float = 30.0,
|
||||
) -> dict:
|
||||
"""Exchange an authorization code for access + refresh tokens."""
|
||||
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
|
||||
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
resp = await client.post(
|
||||
token_endpoint,
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_verifier": code_verifier,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Authorization": f"Basic {creds}",
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def refresh_access_token(
|
||||
token_endpoint: str,
|
||||
refresh_token: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
*,
|
||||
timeout: float = 30.0,
|
||||
) -> dict:
|
||||
"""Refresh an expired access token."""
|
||||
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
|
||||
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
resp = await client.post(
|
||||
token_endpoint,
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Authorization": f"Basic {creds}",
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
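
A hedged sketch of how these discovery helpers chain together for a DCR-capable server. The redirect URI and placeholder values are illustrative; the real flow lives in the OAuth routes.

async def connect_linear_sketch() -> None:
    meta = await discover_oauth_metadata("https://mcp.linear.app/mcp")
    # Typical RFC 8414 metadata keys; exact contents depend on the server.
    registration = await register_client(
        meta["registration_endpoint"],
        redirect_uri="https://surfsense.example/api/mcp/oauth/callback",
    )
    # ...user completes the authorization_endpoint redirect and we receive `code`...
    tokens = await exchange_code_for_tokens(
        meta["token_endpoint"],
        code="<authorization code from callback>",
        redirect_uri="https://surfsense.example/api/mcp/oauth/callback",
        client_id=registration["client_id"],
        client_secret=registration["client_secret"],
        code_verifier="<PKCE verifier generated before the redirect>",
    )
    print(tokens.get("access_token"))

# asyncio.run(connect_linear_sketch())
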
161
surfsense_backend/app/services/mcp_oauth/registry.py
Normal file
161
surfsense_backend/app/services/mcp_oauth/registry.py
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
"""Registry of MCP services with OAuth support.
|
||||
|
||||
Each entry maps a URL-safe service key to its MCP server endpoint and
|
||||
authentication configuration. Services with ``supports_dcr=True`` use
|
||||
RFC 7591 Dynamic Client Registration (the MCP server issues its own
|
||||
credentials); the rest use pre-configured credentials via env vars.
|
||||
|
||||
``allowed_tools`` whitelists which MCP tools to expose to the agent.
|
||||
An empty list means "load every tool the server advertises" (used for
|
||||
user-managed generic MCP servers). Service-specific entries should
|
||||
curate this list to keep the agent's tool count low and selection
|
||||
accuracy high.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from app.db import SearchSourceConnectorType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MCPServiceConfig:
|
||||
name: str
|
||||
mcp_url: str
|
||||
connector_type: str
|
||||
supports_dcr: bool = True
|
||||
oauth_discovery_origin: str | None = None
|
||||
client_id_env: str | None = None
|
||||
client_secret_env: str | None = None
|
||||
scopes: list[str] = field(default_factory=list)
|
||||
scope_param: str = "scope"
|
||||
auth_endpoint_override: str | None = None
|
||||
token_endpoint_override: str | None = None
|
||||
allowed_tools: list[str] = field(default_factory=list)
|
||||
readonly_tools: frozenset[str] = field(default_factory=frozenset)
|
||||
account_metadata_keys: list[str] = field(default_factory=list)
|
||||
"""``connector.config`` keys exposed by ``get_connected_accounts``.
|
||||
|
||||
Only listed keys are returned to the LLM — tokens and secrets are
|
||||
never included. Every service should at least have its
|
||||
``display_name`` populated during OAuth; additional service-specific
|
||||
fields (e.g. Jira ``cloud_id``) are listed here so the LLM can pass
|
||||
them to action tools.
|
||||
"""
|
||||
|
||||
|
||||
MCP_SERVICES: dict[str, MCPServiceConfig] = {
|
||||
"linear": MCPServiceConfig(
|
||||
name="Linear",
|
||||
mcp_url="https://mcp.linear.app/mcp",
|
||||
connector_type="LINEAR_CONNECTOR",
|
||||
allowed_tools=[
|
||||
"list_issues",
|
||||
"get_issue",
|
||||
"save_issue",
|
||||
],
|
||||
readonly_tools=frozenset({"list_issues", "get_issue"}),
|
||||
account_metadata_keys=["organization_name", "organization_url_key"],
|
||||
),
|
||||
"jira": MCPServiceConfig(
|
||||
name="Jira",
|
||||
mcp_url="https://mcp.atlassian.com/v1/mcp",
|
||||
connector_type="JIRA_CONNECTOR",
|
||||
allowed_tools=[
|
||||
"getAccessibleAtlassianResources",
|
||||
"searchJiraIssuesUsingJql",
|
||||
"getVisibleJiraProjects",
|
||||
"getJiraProjectIssueTypesMetadata",
|
||||
"createJiraIssue",
|
||||
"editJiraIssue",
|
||||
],
|
||||
readonly_tools=frozenset({
|
||||
"getAccessibleAtlassianResources",
|
||||
"searchJiraIssuesUsingJql",
|
||||
"getVisibleJiraProjects",
|
||||
"getJiraProjectIssueTypesMetadata",
|
||||
}),
|
||||
account_metadata_keys=["cloud_id", "site_name", "base_url"],
|
||||
),
|
||||
"clickup": MCPServiceConfig(
|
||||
name="ClickUp",
|
||||
mcp_url="https://mcp.clickup.com/mcp",
|
||||
connector_type="CLICKUP_CONNECTOR",
|
||||
allowed_tools=[
|
||||
"clickup_search",
|
||||
"clickup_get_task",
|
||||
],
|
||||
readonly_tools=frozenset({"clickup_search", "clickup_get_task"}),
|
||||
account_metadata_keys=["workspace_id", "workspace_name"],
|
||||
),
|
||||
"slack": MCPServiceConfig(
|
||||
name="Slack",
|
||||
mcp_url="https://mcp.slack.com/mcp",
|
||||
connector_type="SLACK_CONNECTOR",
|
||||
supports_dcr=False,
|
||||
client_id_env="SLACK_CLIENT_ID",
|
||||
client_secret_env="SLACK_CLIENT_SECRET",
|
||||
auth_endpoint_override="https://slack.com/oauth/v2_user/authorize",
|
||||
token_endpoint_override="https://slack.com/api/oauth.v2.user.access",
|
||||
scopes=[
|
||||
"search:read.public", "search:read.private", "search:read.mpim", "search:read.im",
|
||||
"channels:history", "groups:history", "mpim:history", "im:history",
|
||||
],
|
||||
allowed_tools=[
|
||||
"slack_search_channels",
|
||||
"slack_read_channel",
|
||||
"slack_read_thread",
|
||||
],
|
||||
readonly_tools=frozenset({"slack_search_channels", "slack_read_channel", "slack_read_thread"}),
|
||||
# TODO: oauth.v2.user.access only returns team.id, not team.name.
|
||||
# To populate team_name, either add "team:read" scope and call
|
||||
# GET /api/team.info during OAuth callback, or switch to oauth.v2.access.
|
||||
account_metadata_keys=["team_id", "team_name"],
|
||||
),
|
||||
"airtable": MCPServiceConfig(
|
||||
name="Airtable",
|
||||
mcp_url="https://mcp.airtable.com/mcp",
|
||||
connector_type="AIRTABLE_CONNECTOR",
|
||||
supports_dcr=False,
|
||||
oauth_discovery_origin="https://airtable.com",
|
||||
client_id_env="AIRTABLE_CLIENT_ID",
|
||||
client_secret_env="AIRTABLE_CLIENT_SECRET",
|
||||
scopes=["data.records:read", "schema.bases:read"],
|
||||
allowed_tools=[
|
||||
"list_bases",
|
||||
"list_tables_for_base",
|
||||
"list_records_for_table",
|
||||
],
|
||||
readonly_tools=frozenset({"list_bases", "list_tables_for_base", "list_records_for_table"}),
|
||||
account_metadata_keys=["user_id", "user_email"],
|
||||
),
|
||||
}
|
||||
|
||||
_CONNECTOR_TYPE_TO_SERVICE: dict[str, MCPServiceConfig] = {
|
||||
svc.connector_type: svc for svc in MCP_SERVICES.values()
|
||||
}
|
||||
|
||||
LIVE_CONNECTOR_TYPES: frozenset[SearchSourceConnectorType] = frozenset({
|
||||
SearchSourceConnectorType.SLACK_CONNECTOR,
|
||||
SearchSourceConnectorType.TEAMS_CONNECTOR,
|
||||
SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
SearchSourceConnectorType.CLICKUP_CONNECTOR,
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.AIRTABLE_CONNECTOR,
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.DISCORD_CONNECTOR,
|
||||
SearchSourceConnectorType.LUMA_CONNECTOR,
|
||||
})
|
||||
|
||||
|
||||
def get_service(key: str) -> MCPServiceConfig | None:
|
||||
return MCP_SERVICES.get(key)
|
||||
|
||||
|
||||
def get_service_by_connector_type(connector_type: str) -> MCPServiceConfig | None:
|
||||
"""Look up an MCP service config by its ``connector_type`` enum value."""
|
||||
return _CONNECTOR_TYPE_TO_SERVICE.get(connector_type)
|
||||
|
|
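
For example, resolving the curated tool list for a stored connector row might look like the following minimal sketch (the real call sites are in the MCP routes and tools):

svc = get_service_by_connector_type("JIRA_CONNECTOR")
if svc is not None:
    print(svc.mcp_url)        # https://mcp.atlassian.com/v1/mcp
    print(svc.allowed_tools)  # curated whitelist exposed to the agent
    read_only = set(svc.readonly_tools)
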
@ -4,7 +4,6 @@ from datetime import datetime
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -74,6 +73,8 @@ class NotionKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
@ -244,6 +245,8 @@ class NotionKBSyncService:
|
|||
f"Final content length: {len(full_content)} chars, verified={content_verified}"
|
||||
)
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
logger.debug("Generating summary and embeddings")
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
|
|
|
|||
|
|
@ -227,8 +227,6 @@ class NotionToolMetadataService:
|
|||
async def _check_account_health(self, connector_id: int) -> bool:
|
||||
"""Check if a Notion connector's token is still valid.
|
||||
|
||||
Uses a lightweight ``users.me()`` call to verify the token.
|
||||
|
||||
Returns True if the token is expired/invalid, False if healthy.
|
||||
"""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.db import Document, DocumentType
|
||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -73,6 +72,8 @@ class OneDriveKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
|
|
|
|||
|
|
@ -39,52 +39,6 @@ def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> N
|
|||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_slack_messages", bind=True)
|
||||
def index_slack_messages_task(
|
||||
self,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Slack messages."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_slack_messages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_slack_messages", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_slack_messages(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Index Slack messages with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_slack_indexing,
|
||||
)
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_slack_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_notion_pages", bind=True)
|
||||
def index_notion_pages_task(
|
||||
self,
|
||||
|
|
@ -174,92 +128,6 @@ async def _index_github_repos(
|
|||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_linear_issues", bind=True)
|
||||
def index_linear_issues_task(
|
||||
self,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Linear issues."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_linear_issues(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_linear_issues(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Index Linear issues with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_linear_indexing,
|
||||
)
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_linear_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_jira_issues", bind=True)
|
||||
def index_jira_issues_task(
|
||||
self,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Jira issues."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_jira_issues(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_jira_issues(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Index Jira issues with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_jira_indexing,
|
||||
)
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_jira_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_confluence_pages", bind=True)
|
||||
def index_confluence_pages_task(
|
||||
self,
|
||||
|
|
@ -303,49 +171,6 @@ async def _index_confluence_pages(
|
|||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_clickup_tasks", bind=True)
|
||||
def index_clickup_tasks_task(
|
||||
self,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Celery task to index ClickUp tasks."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_clickup_tasks(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_clickup_tasks(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Index ClickUp tasks with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_clickup_indexing,
|
||||
)
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_clickup_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_google_calendar_events", bind=True)
|
||||
def index_google_calendar_events_task(
|
||||
self,
|
||||
|
|
@ -392,49 +217,6 @@ async def _index_google_calendar_events(
|
|||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_airtable_records", bind=True)
|
||||
def index_airtable_records_task(
|
||||
self,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Airtable records."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_airtable_records(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_airtable_records(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Index Airtable records with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_airtable_indexing,
|
||||
)
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_airtable_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_google_gmail_messages", bind=True)
|
||||
def index_google_gmail_messages_task(
|
||||
self,
|
||||
|
|
@ -622,135 +404,6 @@ async def _index_dropbox_files(
|
|||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_discord_messages", bind=True)
|
||||
def index_discord_messages_task(
|
||||
self,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Discord messages."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_discord_messages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_discord_messages(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Index Discord messages with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_discord_indexing,
|
||||
)
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_discord_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_teams_messages", bind=True)
|
||||
def index_teams_messages_task(
|
||||
self,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Microsoft Teams messages."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_teams_messages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_teams_messages(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Index Microsoft Teams messages with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_teams_indexing,
|
||||
)
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_teams_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_luma_events", bind=True)
|
||||
def index_luma_events_task(
|
||||
self,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Luma events."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_luma_events(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_luma_events(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Index Luma events with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_luma_indexing,
|
||||
)
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_luma_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_elasticsearch_documents", bind=True)
|
||||
def index_elasticsearch_documents_task(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -51,50 +51,51 @@ async def _check_and_trigger_schedules():
|
|||
|
||||
logger.info(f"Found {len(due_connectors)} connectors due for indexing")
|
||||
|
||||
# Import all indexing tasks
|
||||
# Import indexing tasks for KB connectors only.
|
||||
# Live connectors (Linear, Slack, Jira, ClickUp, Airtable, Discord,
|
||||
# Teams, Gmail, Calendar, Luma) use real-time tools instead.
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_airtable_records_task,
|
||||
index_clickup_tasks_task,
|
||||
index_confluence_pages_task,
|
||||
index_crawled_urls_task,
|
||||
index_discord_messages_task,
|
||||
index_elasticsearch_documents_task,
|
||||
index_github_repos_task,
|
||||
index_google_calendar_events_task,
|
||||
index_google_drive_files_task,
|
||||
index_google_gmail_messages_task,
|
||||
index_jira_issues_task,
|
||||
index_linear_issues_task,
|
||||
index_luma_events_task,
|
||||
index_notion_pages_task,
|
||||
index_slack_messages_task,
|
||||
)
|
||||
|
||||
# Map connector types to their tasks
|
||||
task_map = {
|
||||
SearchSourceConnectorType.SLACK_CONNECTOR: index_slack_messages_task,
|
||||
SearchSourceConnectorType.NOTION_CONNECTOR: index_notion_pages_task,
|
||||
SearchSourceConnectorType.GITHUB_CONNECTOR: index_github_repos_task,
|
||||
SearchSourceConnectorType.LINEAR_CONNECTOR: index_linear_issues_task,
|
||||
SearchSourceConnectorType.JIRA_CONNECTOR: index_jira_issues_task,
|
||||
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: index_confluence_pages_task,
|
||||
SearchSourceConnectorType.CLICKUP_CONNECTOR: index_clickup_tasks_task,
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task,
|
||||
SearchSourceConnectorType.AIRTABLE_CONNECTOR: index_airtable_records_task,
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: index_google_gmail_messages_task,
|
||||
SearchSourceConnectorType.DISCORD_CONNECTOR: index_discord_messages_task,
|
||||
SearchSourceConnectorType.LUMA_CONNECTOR: index_luma_events_task,
|
||||
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task,
|
||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task,
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task,
|
||||
# Composio connector types (unified with native Google tasks)
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR: index_google_gmail_messages_task,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task,
|
||||
}
|
||||
|
||||
from app.services.mcp_oauth.registry import LIVE_CONNECTOR_TYPES
|
||||
|
||||
# Disable obsolete periodic indexing for live connectors in one batch.
|
||||
live_disabled = []
|
||||
for connector in due_connectors:
|
||||
if connector.connector_type in LIVE_CONNECTOR_TYPES:
|
||||
connector.periodic_indexing_enabled = False
|
||||
connector.next_scheduled_at = None
|
||||
live_disabled.append(connector)
|
||||
if live_disabled:
|
||||
await session.commit()
|
||||
for c in live_disabled:
|
||||
logger.info(
|
||||
"Disabled obsolete periodic indexing for live connector %s (%s)",
|
||||
c.id,
|
||||
c.connector_type.value,
|
||||
)
|
||||
|
||||
# Trigger indexing for each due connector
|
||||
for connector in due_connectors:
|
||||
if connector in live_disabled:
|
||||
continue
|
||||
|
||||
# Primary guard: Redis lock indicates a task is currently running.
|
||||
if is_connector_indexing_locked(connector.id):
|
||||
logger.info(
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ from sqlalchemy.orm import selectinload
|
|||
|
||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemSelection
|
||||
from app.config import config
|
||||
from app.agents.new_chat.llm_config import (
|
||||
AgentConfig,
|
||||
create_chat_litellm_from_agent_config,
|
||||
|
|
@ -145,6 +147,102 @@ class StreamResult:
|
|||
interrupt_value: dict[str, Any] | None = None
|
||||
sandbox_files: list[str] = field(default_factory=list)
|
||||
agent_called_update_memory: bool = False
|
||||
request_id: str | None = None
|
||||
turn_id: str = ""
|
||||
filesystem_mode: str = "cloud"
|
||||
client_platform: str = "web"
|
||||
intent_detected: str = "chat_only"
|
||||
intent_confidence: float = 0.0
|
||||
write_attempted: bool = False
|
||||
write_succeeded: bool = False
|
||||
verification_succeeded: bool = False
|
||||
commit_gate_passed: bool = True
|
||||
commit_gate_reason: str = ""
|
||||
|
||||
|
||||
def _safe_float(value: Any, default: float = 0.0) -> float:
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def _tool_output_to_text(tool_output: Any) -> str:
|
||||
if isinstance(tool_output, dict):
|
||||
if isinstance(tool_output.get("result"), str):
|
||||
return tool_output["result"]
|
||||
if isinstance(tool_output.get("error"), str):
|
||||
return tool_output["error"]
|
||||
return json.dumps(tool_output, ensure_ascii=False)
|
||||
return str(tool_output)
|
||||
|
||||
|
||||
def _tool_output_has_error(tool_output: Any) -> bool:
|
||||
if isinstance(tool_output, dict):
|
||||
if tool_output.get("error"):
|
||||
return True
|
||||
result = tool_output.get("result")
|
||||
if isinstance(result, str) and result.strip().lower().startswith("error:"):
|
||||
return True
|
||||
return False
|
||||
if isinstance(tool_output, str):
|
||||
return tool_output.strip().lower().startswith("error:")
|
||||
return False
|
||||
|
||||
|
||||
def _extract_resolved_file_path(*, tool_name: str, tool_output: Any) -> str | None:
|
||||
if isinstance(tool_output, dict):
|
||||
path_value = tool_output.get("path")
|
||||
if isinstance(path_value, str) and path_value.strip():
|
||||
return path_value.strip()
|
||||
text = _tool_output_to_text(tool_output)
|
||||
if tool_name == "write_file":
|
||||
match = re.search(r"Updated file\s+(.+)$", text.strip())
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
if tool_name == "edit_file":
|
||||
match = re.search(r"in '([^']+)'", text)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
return None
|
||||
|
||||
|
||||
def _contract_enforcement_active(result: StreamResult) -> bool:
|
||||
# Keep policy deterministic with no env-driven progression modes:
|
||||
# enforce the file-operation contract only in desktop local-folder mode.
|
||||
return result.filesystem_mode == "desktop_local_folder"
|
||||
|
||||
|
||||
def _evaluate_file_contract_outcome(result: StreamResult) -> tuple[bool, str]:
|
||||
if result.intent_detected != "file_write":
|
||||
return True, ""
|
||||
if not result.write_attempted:
|
||||
return False, "no_write_attempt"
|
||||
if not result.write_succeeded:
|
||||
return False, "write_failed"
|
||||
if not result.verification_succeeded:
|
||||
return False, "verification_failed"
|
||||
return True, ""
|
||||
|
||||
|
||||
def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
|
||||
payload: dict[str, Any] = {
|
||||
"stage": stage,
|
||||
"request_id": result.request_id or "unknown",
|
||||
"turn_id": result.turn_id or "unknown",
|
||||
"chat_id": result.turn_id.split(":", 1)[0] if ":" in result.turn_id else "unknown",
|
||||
"filesystem_mode": result.filesystem_mode,
|
||||
"client_platform": result.client_platform,
|
||||
"intent_detected": result.intent_detected,
|
||||
"intent_confidence": result.intent_confidence,
|
||||
"write_attempted": result.write_attempted,
|
||||
"write_succeeded": result.write_succeeded,
|
||||
"verification_succeeded": result.verification_succeeded,
|
||||
"commit_gate_passed": result.commit_gate_passed,
|
||||
"commit_gate_reason": result.commit_gate_reason or None,
|
||||
}
|
||||
payload.update(extra)
|
||||
_perf_log.info("[file_operation_contract] %s", json.dumps(payload, ensure_ascii=False))
|
||||
|
||||
|
||||
async def _stream_agent_events(
|
||||
|
|
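
To make the commit gate concrete, a small illustration of how the outcome helper reads a turn's evidence. The values are hypothetical; in practice they are filled in while streaming tool events.

r = StreamResult()
r.filesystem_mode = "desktop_local_folder"
r.intent_detected = "file_write"
r.write_attempted = True
r.write_succeeded = False  # e.g. edit_file returned "Error: ..."
passed, reason = _evaluate_file_contract_outcome(r)
# passed is False and reason == "write_failed"; enforcement applies because
# _contract_enforcement_active(r) is True in desktop local-folder mode.
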
@ -239,6 +337,8 @@ async def _stream_agent_events(
|
|||
tool_name = event.get("name", "unknown_tool")
|
||||
run_id = event.get("run_id", "")
|
||||
tool_input = event.get("data", {}).get("input", {})
|
||||
if tool_name in ("write_file", "edit_file"):
|
||||
result.write_attempted = True
|
||||
|
||||
if current_text_id is not None:
|
||||
yield streaming_service.format_text_end(current_text_id)
|
||||
|
|
@ -514,6 +614,14 @@ async def _stream_agent_events(
|
|||
else:
|
||||
tool_output = {"result": str(raw_output) if raw_output else "completed"}
|
||||
|
||||
if tool_name in ("write_file", "edit_file"):
|
||||
if _tool_output_has_error(tool_output):
|
||||
# Keep successful evidence if a previous write/edit in this turn succeeded.
|
||||
pass
|
||||
else:
|
||||
result.write_succeeded = True
|
||||
result.verification_succeeded = True
|
||||
|
||||
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
|
||||
original_step_id = tool_step_ids.get(
|
||||
run_id, f"{step_prefix}-unknown-{run_id[:8]}"
|
||||
|
|
@ -925,6 +1033,30 @@ async def _stream_agent_events(
|
|||
f"Scrape failed: {error_msg}",
|
||||
"error",
|
||||
)
|
||||
elif tool_name in ("write_file", "edit_file"):
|
||||
resolved_path = _extract_resolved_file_path(
|
||||
tool_name=tool_name,
|
||||
tool_output=tool_output,
|
||||
)
|
||||
result_text = _tool_output_to_text(tool_output)
|
||||
if _tool_output_has_error(tool_output):
|
||||
yield streaming_service.format_tool_output_available(
|
||||
tool_call_id,
|
||||
{
|
||||
"status": "error",
|
||||
"error": result_text,
|
||||
"path": resolved_path,
|
||||
},
|
||||
)
|
||||
else:
|
||||
yield streaming_service.format_tool_output_available(
|
||||
tool_call_id,
|
||||
{
|
||||
"status": "completed",
|
||||
"path": resolved_path,
|
||||
"result": result_text,
|
||||
},
|
||||
)
|
||||
elif tool_name == "generate_report":
|
||||
# Stream the full report result so frontend can render the ReportCard
|
||||
yield streaming_service.format_tool_output_available(
|
||||
|
|
@ -1143,10 +1275,59 @@ async def _stream_agent_events(
|
|||
if completion_event:
|
||||
yield completion_event
|
||||
|
||||
state = await agent.aget_state(config)
|
||||
state_values = getattr(state, "values", {}) or {}
|
||||
contract_state = state_values.get("file_operation_contract") or {}
|
||||
contract_turn_id = contract_state.get("turn_id")
|
||||
current_turn_id = config.get("configurable", {}).get("turn_id", "")
|
||||
intent_value = contract_state.get("intent")
|
||||
if (
|
||||
isinstance(intent_value, str)
|
||||
and intent_value in ("chat_only", "file_write", "file_read")
|
||||
and contract_turn_id == current_turn_id
|
||||
):
|
||||
result.intent_detected = intent_value
|
||||
if (
|
||||
isinstance(intent_value, str)
|
||||
and intent_value in (
|
||||
"chat_only",
|
||||
"file_write",
|
||||
"file_read",
|
||||
)
|
||||
and contract_turn_id != current_turn_id
|
||||
):
|
||||
# Ignore stale intent contracts from previous turns/checkpoints.
|
||||
result.intent_detected = "chat_only"
|
||||
result.intent_confidence = (
|
||||
_safe_float(contract_state.get("confidence"), default=0.0)
|
||||
if contract_turn_id == current_turn_id
|
||||
else 0.0
|
||||
)
|
||||
|
||||
if result.intent_detected == "file_write":
|
||||
result.commit_gate_passed, result.commit_gate_reason = (
|
||||
_evaluate_file_contract_outcome(result)
|
||||
)
|
||||
if not result.commit_gate_passed:
|
||||
if _contract_enforcement_active(result):
|
||||
gate_notice = (
|
||||
"I could not complete the requested file write because no successful "
|
||||
"write_file/edit_file operation was confirmed."
|
||||
)
|
||||
gate_text_id = streaming_service.generate_text_id()
|
||||
yield streaming_service.format_text_start(gate_text_id)
|
||||
yield streaming_service.format_text_delta(gate_text_id, gate_notice)
|
||||
yield streaming_service.format_text_end(gate_text_id)
|
||||
yield streaming_service.format_terminal_info(gate_notice, "error")
|
||||
accumulated_text = gate_notice
|
||||
else:
|
||||
result.commit_gate_passed = True
|
||||
result.commit_gate_reason = ""
|
||||
|
||||
result.accumulated_text = accumulated_text
|
||||
result.agent_called_update_memory = called_update_memory
|
||||
_log_file_contract("turn_outcome", result)
|
||||
|
||||
state = await agent.aget_state(config)
|
||||
is_interrupted = state.tasks and any(task.interrupts for task in state.tasks)
|
||||
if is_interrupted:
|
||||
result.is_interrupted = True
|
||||
|
|
@ -1167,6 +1348,8 @@ async def stream_new_chat(
|
|||
thread_visibility: ChatVisibility | None = None,
|
||||
current_user_display_name: str | None = None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
filesystem_selection: FilesystemSelection | None = None,
|
||||
request_id: str | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream chat responses from the new SurfSense deep agent.
|
||||
|
|
@ -1194,6 +1377,20 @@ async def stream_new_chat(
|
|||
streaming_service = VercelStreamingService()
|
||||
stream_result = StreamResult()
|
||||
_t_total = time.perf_counter()
|
||||
fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud"
|
||||
fs_platform = (
|
||||
filesystem_selection.client_platform.value if filesystem_selection else "web"
|
||||
)
|
||||
stream_result.request_id = request_id
|
||||
stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}"
|
||||
stream_result.filesystem_mode = fs_mode
|
||||
stream_result.client_platform = fs_platform
|
||||
_log_file_contract("turn_start", stream_result)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] filesystem_mode=%s client_platform=%s",
|
||||
fs_mode,
|
||||
fs_platform,
|
||||
)
|
||||
log_system_snapshot("stream_new_chat_START")
|
||||
|
||||
from app.services.token_tracking_service import start_turn
|
||||
|
|
@ -1329,6 +1526,7 @@ async def stream_new_chat(
|
|||
thread_visibility=visibility,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
filesystem_selection=filesystem_selection,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
|
||||
|
|
@ -1435,6 +1633,8 @@ async def stream_new_chat(
|
|||
# We will use this to simulate group chat functionality in the future
|
||||
"messages": langchain_messages,
|
||||
"search_space_id": search_space_id,
|
||||
"request_id": request_id or "unknown",
|
||||
"turn_id": stream_result.turn_id,
|
||||
}
|
||||
|
||||
_perf_log.info(
|
||||
|
|
@ -1464,6 +1664,8 @@ async def stream_new_chat(
|
|||
# Configure LangGraph with thread_id for memory
|
||||
# If checkpoint_id is provided, fork from that checkpoint (for edit/reload)
|
||||
configurable = {"thread_id": str(chat_id)}
|
||||
configurable["request_id"] = request_id or "unknown"
|
||||
configurable["turn_id"] = stream_result.turn_id
|
||||
if checkpoint_id:
|
||||
configurable["checkpoint_id"] = checkpoint_id
|
||||
|
||||
|
|
@ -1871,10 +2073,26 @@ async def stream_resume_chat(
|
|||
user_id: str | None = None,
|
||||
llm_config_id: int = -1,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
filesystem_selection: FilesystemSelection | None = None,
|
||||
request_id: str | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
streaming_service = VercelStreamingService()
|
||||
stream_result = StreamResult()
|
||||
_t_total = time.perf_counter()
|
||||
fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud"
|
||||
fs_platform = (
|
||||
filesystem_selection.client_platform.value if filesystem_selection else "web"
|
||||
)
|
||||
stream_result.request_id = request_id
|
||||
stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}"
|
||||
stream_result.filesystem_mode = fs_mode
|
||||
stream_result.client_platform = fs_platform
|
||||
_log_file_contract("turn_start", stream_result)
|
||||
_perf_log.info(
|
||||
"[stream_resume] filesystem_mode=%s client_platform=%s",
|
||||
fs_mode,
|
||||
fs_platform,
|
||||
)
|
||||
|
||||
from app.services.token_tracking_service import start_turn
|
||||
|
||||
|
|
@ -1991,6 +2209,7 @@ async def stream_resume_chat(
|
|||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
|
||||
|
|
@ -2009,7 +2228,11 @@ async def stream_resume_chat(
|
|||
from langgraph.types import Command
|
||||
|
||||
config = {
|
||||
"configurable": {"thread_id": str(chat_id)},
|
||||
"configurable": {
|
||||
"thread_id": str(chat_id),
|
||||
"request_id": request_id or "unknown",
|
||||
"turn_id": stream_result.turn_id,
|
||||
},
|
||||
"recursion_limit": 80,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,75 +1,29 @@
|
|||
"""
|
||||
Connector indexers module for background tasks.
|
||||
|
||||
This module provides a collection of connector indexers for different platforms
|
||||
and services. Each indexer is responsible for handling the indexing of content
|
||||
from a specific connector type.
|
||||
|
||||
Available indexers:
|
||||
- Slack: Index messages from Slack channels
|
||||
- Notion: Index pages from Notion workspaces
|
||||
- GitHub: Index repositories and files from GitHub
|
||||
- Linear: Index issues from Linear workspaces
|
||||
- Jira: Index issues from Jira projects
|
||||
- Confluence: Index pages from Confluence spaces
|
||||
- BookStack: Index pages from BookStack wiki instances
|
||||
- Discord: Index messages from Discord servers
|
||||
- ClickUp: Index tasks from ClickUp workspaces
|
||||
- Google Gmail: Index messages from Google Gmail
|
||||
- Google Calendar: Index events from Google Calendar
|
||||
- Luma: Index events from Luma
|
||||
- Webcrawler: Index crawled URLs
|
||||
- Elasticsearch: Index documents from Elasticsearch instances
|
||||
Each indexer handles content indexing from a specific connector type.
|
||||
Live connectors (Slack, Linear, Jira, ClickUp, Airtable, Discord, Teams,
|
||||
Luma) now use real-time agent tools instead of background indexing.
|
||||
"""
|
||||
|
||||
# Communication platforms
|
||||
# Calendar and scheduling
|
||||
from .airtable_indexer import index_airtable_records
|
||||
from .bookstack_indexer import index_bookstack_pages
|
||||
|
||||
# Note: composio_indexer is imported directly in connector_tasks.py to avoid circular imports
|
||||
from .clickup_indexer import index_clickup_tasks
|
||||
from .confluence_indexer import index_confluence_pages
|
||||
from .discord_indexer import index_discord_messages
|
||||
|
||||
# Development platforms
|
||||
from .elasticsearch_indexer import index_elasticsearch_documents
|
||||
from .github_indexer import index_github_repos
|
||||
from .google_calendar_indexer import index_google_calendar_events
|
||||
from .google_drive_indexer import index_google_drive_files
|
||||
from .google_gmail_indexer import index_google_gmail_messages
|
||||
from .jira_indexer import index_jira_issues
|
||||
|
||||
# Issue tracking and project management
|
||||
from .linear_indexer import index_linear_issues
|
||||
|
||||
# Documentation and knowledge management
|
||||
from .luma_indexer import index_luma_events
|
||||
from .notion_indexer import index_notion_pages
|
||||
from .slack_indexer import index_slack_messages
|
||||
from .webcrawler_indexer import index_crawled_urls
|
||||
|
||||
__all__ = [ # noqa: RUF022
|
||||
"index_airtable_records",
|
||||
__all__ = [
|
||||
"index_bookstack_pages",
|
||||
# "index_composio_connector", # Imported directly in connector_tasks.py to avoid circular imports
|
||||
"index_clickup_tasks",
|
||||
"index_confluence_pages",
|
||||
"index_discord_messages",
|
||||
# Development platforms
|
||||
"index_crawled_urls",
|
||||
"index_elasticsearch_documents",
|
||||
"index_github_repos",
|
||||
# Calendar and scheduling
|
||||
"index_google_calendar_events",
|
||||
"index_google_drive_files",
|
||||
"index_luma_events",
|
||||
"index_jira_issues",
|
||||
# Issue tracking and project management
|
||||
"index_linear_issues",
|
||||
# Documentation and knowledge management
|
||||
"index_notion_pages",
|
||||
"index_crawled_urls",
|
||||
# Communication platforms
|
||||
"index_slack_messages",
|
||||
"index_google_gmail_messages",
|
||||
"index_notion_pages",
|
||||
]
|
||||
|
|
|
|||
129
surfsense_backend/app/utils/async_retry.py
Normal file
129
surfsense_backend/app/utils/async_retry.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
"""Async retry decorators for connector API calls, built on tenacity."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import TypeVar
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception,
|
||||
stop_after_attempt,
|
||||
stop_after_delay,
|
||||
wait_exponential_jitter,
|
||||
)
|
||||
|
||||
from app.connectors.exceptions import (
|
||||
ConnectorAPIError,
|
||||
ConnectorAuthError,
|
||||
ConnectorError,
|
||||
ConnectorRateLimitError,
|
||||
ConnectorTimeoutError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
F = TypeVar("F", bound=Callable)
|
||||
|
||||
|
||||
def _is_retryable(exc: BaseException) -> bool:
|
||||
if isinstance(exc, ConnectorError):
|
||||
return exc.retryable
|
||||
if isinstance(exc, (httpx.TimeoutException, httpx.ConnectError)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def build_retry(
|
||||
*,
|
||||
max_attempts: int = 4,
|
||||
max_delay: float = 60.0,
|
||||
initial_delay: float = 1.0,
|
||||
total_timeout: float = 180.0,
|
||||
service: str = "",
|
||||
) -> Callable:
|
||||
"""Configurable tenacity ``@retry`` decorator with exponential backoff + jitter."""
|
||||
_logger = logging.getLogger(f"connector.retry.{service}") if service else logger
|
||||
|
||||
return retry(
|
||||
retry=retry_if_exception(_is_retryable),
|
||||
stop=(stop_after_attempt(max_attempts) | stop_after_delay(total_timeout)),
|
||||
wait=wait_exponential_jitter(initial=initial_delay, max=max_delay),
|
||||
reraise=True,
|
||||
before_sleep=before_sleep_log(_logger, logging.WARNING),
|
||||
)
|
||||
|
||||
|
||||
def retry_on_transient(
|
||||
*,
|
||||
service: str = "",
|
||||
max_attempts: int = 4,
|
||||
) -> Callable:
|
||||
"""Shorthand: retry up to *max_attempts* on rate-limits, timeouts, and 5xx."""
|
||||
return build_retry(max_attempts=max_attempts, service=service)
|
||||
|
||||
|
||||
def raise_for_status(
|
||||
response: httpx.Response,
|
||||
*,
|
||||
service: str = "",
|
||||
) -> None:
|
||||
"""Map non-2xx httpx responses to the appropriate ``ConnectorError``."""
|
||||
if response.is_success:
|
||||
return
|
||||
|
||||
status = response.status_code
|
||||
|
||||
try:
|
||||
body = response.json()
|
||||
except Exception:
|
||||
body = response.text[:500] if response.text else None
|
||||
|
||||
if status == 429:
|
||||
retry_after_raw = response.headers.get("Retry-After")
|
||||
retry_after: float | None = None
|
||||
if retry_after_raw:
|
||||
try:
|
||||
retry_after = float(retry_after_raw)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
raise ConnectorRateLimitError(
|
||||
f"{service} rate limited (429)",
|
||||
service=service,
|
||||
retry_after=retry_after,
|
||||
response_body=body,
|
||||
)
|
||||
|
||||
if status in (401, 403):
|
||||
raise ConnectorAuthError(
|
||||
f"{service} authentication failed ({status})",
|
||||
service=service,
|
||||
status_code=status,
|
||||
response_body=body,
|
||||
)
|
||||
|
||||
if status == 504:
|
||||
raise ConnectorTimeoutError(
|
||||
f"{service} gateway timeout (504)",
|
||||
service=service,
|
||||
status_code=status,
|
||||
response_body=body,
|
||||
)
|
||||
|
||||
if status >= 500:
|
||||
raise ConnectorAPIError(
|
||||
f"{service} server error ({status})",
|
||||
service=service,
|
||||
status_code=status,
|
||||
response_body=body,
|
||||
)
|
||||
|
||||
raise ConnectorAPIError(
|
||||
f"{service} request failed ({status})",
|
||||
service=service,
|
||||
status_code=status,
|
||||
response_body=body,
|
||||
)
|
||||
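A minimal usage sketch for these helpers, assuming a hypothetical Slack connector call; the endpoint, client wiring, and token handling below are illustrative and not part of this change, while retry_on_transient and raise_for_status come from the module above:

import httpx

from app.utils.async_retry import raise_for_status, retry_on_transient


@retry_on_transient(service="slack", max_attempts=4)
async def fetch_slack_channels(token: str) -> dict:
    # raise_for_status maps 429/5xx/504 to retryable ConnectorErrors, so
    # tenacity retries them with exponential backoff and jitter; 401/403
    # raise ConnectorAuthError, which is expected to be non-retryable.
    async with httpx.AsyncClient(timeout=30) as client:
        response = await client.get(
            "https://slack.com/api/conversations.list",
            headers={"Authorization": f"Bearer {token}"},
        )
        raise_for_status(response, service="slack")
        return response.json()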
|
|
@ -39,7 +39,7 @@ BASE_NAME_FOR_TYPE = {
|
|||
def get_base_name_for_type(connector_type: SearchSourceConnectorType) -> str:
|
||||
"""Get a friendly display name for a connector type."""
|
||||
return BASE_NAME_FOR_TYPE.get(
|
||||
connector_type, connector_type.replace("_", " ").title()
|
||||
connector_type, connector_type.value.replace("_", " ").title()
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -231,9 +231,11 @@ async def generate_unique_connector_name(
|
|||
base = get_base_name_for_type(connector_type)
|
||||
|
||||
if identifier:
|
||||
return f"{base} - {identifier}"
|
||||
name = f"{base} - {identifier}"
|
||||
return await ensure_unique_connector_name(
|
||||
session, name, search_space_id, user_id,
|
||||
)
|
||||
|
||||
# Fallback: use counter for uniqueness
|
||||
count = await count_connectors_of_type(
|
||||
session, connector_type, search_space_id, user_id
|
||||
)
|
||||
|
|
|
|||
|
|
@ -18,19 +18,9 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# Mapping of connector types to their corresponding Celery task names
|
||||
CONNECTOR_TASK_MAP = {
|
||||
SearchSourceConnectorType.SLACK_CONNECTOR: "index_slack_messages",
|
||||
SearchSourceConnectorType.TEAMS_CONNECTOR: "index_teams_messages",
|
||||
SearchSourceConnectorType.NOTION_CONNECTOR: "index_notion_pages",
|
||||
SearchSourceConnectorType.GITHUB_CONNECTOR: "index_github_repos",
|
||||
SearchSourceConnectorType.LINEAR_CONNECTOR: "index_linear_issues",
|
||||
SearchSourceConnectorType.JIRA_CONNECTOR: "index_jira_issues",
|
||||
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "index_confluence_pages",
|
||||
SearchSourceConnectorType.CLICKUP_CONNECTOR: "index_clickup_tasks",
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: "index_google_calendar_events",
|
||||
SearchSourceConnectorType.AIRTABLE_CONNECTOR: "index_airtable_records",
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: "index_google_gmail_messages",
|
||||
SearchSourceConnectorType.DISCORD_CONNECTOR: "index_discord_messages",
|
||||
SearchSourceConnectorType.LUMA_CONNECTOR: "index_luma_events",
|
||||
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: "index_elasticsearch_documents",
|
||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: "index_crawled_urls",
|
||||
SearchSourceConnectorType.BOOKSTACK_CONNECTOR: "index_bookstack_pages",
|
||||
|
|
@ -83,39 +73,19 @@ def create_periodic_schedule(
|
|||
f"(frequency: {frequency_minutes} minutes). Triggering first run..."
|
||||
)
|
||||
|
||||
# Import all indexing tasks
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_airtable_records_task,
|
||||
index_bookstack_pages_task,
|
||||
index_clickup_tasks_task,
|
||||
index_confluence_pages_task,
|
||||
index_crawled_urls_task,
|
||||
index_discord_messages_task,
|
||||
index_elasticsearch_documents_task,
|
||||
index_github_repos_task,
|
||||
index_google_calendar_events_task,
|
||||
index_google_gmail_messages_task,
|
||||
index_jira_issues_task,
|
||||
index_linear_issues_task,
|
||||
index_luma_events_task,
|
||||
index_notion_pages_task,
|
||||
index_slack_messages_task,
|
||||
)
|
||||
|
||||
# Map connector type to task
|
||||
task_map = {
|
||||
SearchSourceConnectorType.SLACK_CONNECTOR: index_slack_messages_task,
|
||||
SearchSourceConnectorType.NOTION_CONNECTOR: index_notion_pages_task,
|
||||
SearchSourceConnectorType.GITHUB_CONNECTOR: index_github_repos_task,
|
||||
SearchSourceConnectorType.LINEAR_CONNECTOR: index_linear_issues_task,
|
||||
SearchSourceConnectorType.JIRA_CONNECTOR: index_jira_issues_task,
|
||||
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: index_confluence_pages_task,
|
||||
SearchSourceConnectorType.CLICKUP_CONNECTOR: index_clickup_tasks_task,
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task,
|
||||
SearchSourceConnectorType.AIRTABLE_CONNECTOR: index_airtable_records_task,
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: index_google_gmail_messages_task,
|
||||
SearchSourceConnectorType.DISCORD_CONNECTOR: index_discord_messages_task,
|
||||
SearchSourceConnectorType.LUMA_CONNECTOR: index_luma_events_task,
|
||||
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task,
|
||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task,
|
||||
SearchSourceConnectorType.BOOKSTACK_CONNECTOR: index_bookstack_pages_task,
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ dependencies = [
|
|||
"deepagents>=0.4.12",
|
||||
"stripe>=15.0.0",
|
||||
"azure-ai-documentintelligence>=1.0.2",
|
||||
"litellm>=1.83.0",
|
||||
"litellm>=1.83.4",
|
||||
"langchain-litellm>=0.6.4",
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,213 @@
|
|||
"""Unit tests for resume page-limit helpers and enforcement flow."""
|
||||
|
||||
import io
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pypdf
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.tools import resume as resume_tool
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeReport:
|
||||
_next_id = 1000
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
self.id = None
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, parent_report=None):
|
||||
self.parent_report = parent_report
|
||||
self.added: list[_FakeReport] = []
|
||||
|
||||
async def get(self, _model, _id):
|
||||
return self.parent_report
|
||||
|
||||
def add(self, report):
|
||||
self.added.append(report)
|
||||
|
||||
async def commit(self):
|
||||
for report in self.added:
|
||||
if getattr(report, "id", None) is None:
|
||||
report.id = _FakeReport._next_id
|
||||
_FakeReport._next_id += 1
|
||||
|
||||
async def refresh(self, _report):
|
||||
return None
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session):
|
||||
self.session = session
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.session
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
class _SessionFactory:
|
||||
def __init__(self, sessions):
|
||||
self._sessions = list(sessions)
|
||||
|
||||
def __call__(self):
|
||||
if not self._sessions:
|
||||
raise RuntimeError("No fake sessions left")
|
||||
return _SessionContext(self._sessions.pop(0))
|
||||
|
||||
|
||||
def _make_pdf_with_pages(page_count: int) -> bytes:
|
||||
writer = pypdf.PdfWriter()
|
||||
for _ in range(page_count):
|
||||
writer.add_blank_page(width=612, height=792)
|
||||
output = io.BytesIO()
|
||||
writer.write(output)
|
||||
return output.getvalue()
|
||||
|
||||
|
||||
def test_count_pdf_pages_reads_compiled_bytes() -> None:
|
||||
pdf_bytes = _make_pdf_with_pages(2)
|
||||
assert resume_tool._count_pdf_pages(pdf_bytes) == 2
|
||||
|
||||
|
||||
def test_validate_max_pages_rejects_out_of_range() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
resume_tool._validate_max_pages(0)
|
||||
with pytest.raises(ValueError):
|
||||
resume_tool._validate_max_pages(6)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_resume_defaults_to_one_page_target(monkeypatch) -> None:
|
||||
read_session = _FakeSession()
|
||||
write_session = _FakeSession()
|
||||
session_factory = _SessionFactory([read_session, write_session])
|
||||
monkeypatch.setattr(resume_tool, "shielded_async_session", session_factory)
|
||||
monkeypatch.setattr(resume_tool, "Report", _FakeReport)
|
||||
|
||||
prompts: list[str] = []
|
||||
|
||||
async def _llm_invoke(messages):
|
||||
prompts.append(messages[0].content)
|
||||
return SimpleNamespace(content="= Jane Doe\n== Experience\n- Built systems")
|
||||
|
||||
llm = SimpleNamespace(ainvoke=AsyncMock(side_effect=_llm_invoke))
|
||||
monkeypatch.setattr(
|
||||
resume_tool,
|
||||
"get_document_summary_llm",
|
||||
AsyncMock(return_value=llm),
|
||||
)
|
||||
monkeypatch.setattr(resume_tool, "_compile_typst", lambda _source: b"pdf")
|
||||
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: 1)
|
||||
|
||||
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
||||
result = await tool.ainvoke({"user_info": "Jane Doe experience"})
|
||||
|
||||
assert result["status"] == "ready"
|
||||
assert prompts
|
||||
assert "**Target Maximum Pages:** 1" in prompts[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_resume_compresses_when_over_limit(monkeypatch) -> None:
|
||||
read_session = _FakeSession()
|
||||
write_session = _FakeSession()
|
||||
session_factory = _SessionFactory([read_session, write_session])
|
||||
monkeypatch.setattr(resume_tool, "shielded_async_session", session_factory)
|
||||
monkeypatch.setattr(resume_tool, "Report", _FakeReport)
|
||||
|
||||
responses = [
|
||||
SimpleNamespace(content="= Jane Doe\n== Experience\n- Detailed bullet 1"),
|
||||
SimpleNamespace(content="= Jane Doe\n== Experience\n- Condensed bullet"),
|
||||
]
|
||||
llm = SimpleNamespace(ainvoke=AsyncMock(side_effect=responses))
|
||||
monkeypatch.setattr(
|
||||
resume_tool,
|
||||
"get_document_summary_llm",
|
||||
AsyncMock(return_value=llm),
|
||||
)
|
||||
monkeypatch.setattr(resume_tool, "_compile_typst", lambda _source: b"pdf")
|
||||
page_counts = iter([2, 1])
|
||||
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts))
|
||||
|
||||
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
||||
result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1})
|
||||
|
||||
assert result["status"] == "ready"
|
||||
assert write_session.added, "Expected successful report write"
|
||||
metadata = write_session.added[0].report_metadata
|
||||
assert metadata["target_max_pages"] == 1
|
||||
assert metadata["actual_page_count"] == 1
|
||||
assert metadata["compression_attempts"] == 1
|
||||
assert metadata["page_limit_enforced"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_resume_returns_ready_when_target_not_met(monkeypatch) -> None:
|
||||
read_session = _FakeSession()
|
||||
write_session = _FakeSession()
|
||||
session_factory = _SessionFactory([read_session, write_session])
|
||||
monkeypatch.setattr(resume_tool, "shielded_async_session", session_factory)
|
||||
monkeypatch.setattr(resume_tool, "Report", _FakeReport)
|
||||
|
||||
responses = [
|
||||
SimpleNamespace(content="= Jane Doe\n== Experience\n- Long detail"),
|
||||
SimpleNamespace(content="= Jane Doe\n== Experience\n- Still long"),
|
||||
SimpleNamespace(content="= Jane Doe\n== Experience\n- Still too long"),
|
||||
]
|
||||
llm = SimpleNamespace(ainvoke=AsyncMock(side_effect=responses))
|
||||
monkeypatch.setattr(
|
||||
resume_tool,
|
||||
"get_document_summary_llm",
|
||||
AsyncMock(return_value=llm),
|
||||
)
|
||||
monkeypatch.setattr(resume_tool, "_compile_typst", lambda _source: b"pdf")
|
||||
page_counts = iter([3, 3, 2])
|
||||
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts))
|
||||
|
||||
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
||||
result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1})
|
||||
|
||||
assert result["status"] == "ready"
|
||||
assert "could not fit the target" in (result["message"] or "").lower()
|
||||
metadata = write_session.added[0].report_metadata
|
||||
assert metadata["target_page_met"] is False
|
||||
assert metadata["actual_page_count"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_resume_fails_when_hard_limit_exceeded(monkeypatch) -> None:
|
||||
read_session = _FakeSession()
|
||||
failed_session = _FakeSession()
|
||||
session_factory = _SessionFactory([read_session, failed_session])
|
||||
monkeypatch.setattr(resume_tool, "shielded_async_session", session_factory)
|
||||
monkeypatch.setattr(resume_tool, "Report", _FakeReport)
|
||||
|
||||
responses = [
|
||||
SimpleNamespace(content="= Jane Doe\n== Experience\n- Long detail"),
|
||||
SimpleNamespace(content="= Jane Doe\n== Experience\n- Still long"),
|
||||
SimpleNamespace(content="= Jane Doe\n== Experience\n- Still too long"),
|
||||
]
|
||||
llm = SimpleNamespace(ainvoke=AsyncMock(side_effect=responses))
|
||||
monkeypatch.setattr(
|
||||
resume_tool,
|
||||
"get_document_summary_llm",
|
||||
AsyncMock(return_value=llm),
|
||||
)
|
||||
monkeypatch.setattr(resume_tool, "_compile_typst", lambda _source: b"pdf")
|
||||
page_counts = iter([7, 6, 6])
|
||||
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts))
|
||||
|
||||
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
||||
result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1})
|
||||
|
||||
assert result["status"] == "failed"
|
||||
assert "hard page limit" in (result["error"] or "").lower()
|
||||
assert failed_session.added, "Expected failed report persistence"
|
||||
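The flow these tests exercise reads as a compile, count, compress loop. The sketch below only illustrates that contract: _compile_typst and _count_pdf_pages are the real helpers patched above, but the loop structure, prompt wording, retry count, and the hard limit of 5 pages are inferred from the tests rather than copied from the tool:

from langchain_core.messages import HumanMessage


async def enforce_page_limit_sketch(llm, typst_source: str, max_pages: int) -> dict:
    # Illustrative only: compile the resume, count pages, and ask the LLM to
    # condense it (at most twice) until it fits max_pages; fail past the
    # hard limit of 5 pages.
    attempts = 0
    pages = _count_pdf_pages(_compile_typst(typst_source))
    while pages > max_pages and attempts < 2:
        attempts += 1
        condensed = await llm.ainvoke(
            [HumanMessage(content=f"Condense this resume to {max_pages} page(s):\n{typst_source}")]
        )
        typst_source = condensed.content
        pages = _count_pdf_pages(_compile_typst(typst_source))
    if pages > 5:
        return {"status": "failed", "error": "Exceeded hard page limit"}
    return {
        "status": "ready",
        "metadata": {
            "target_max_pages": max_pages,
            "actual_page_count": pages,
            "compression_attempts": attempts,
            "page_limit_enforced": attempts > 0,
            "target_page_met": pages <= max_pages,
        },
    }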
|
|
@ -0,0 +1,214 @@
|
|||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from app.agents.new_chat.middleware.file_intent import (
|
||||
FileIntentMiddleware,
|
||||
FileOperationIntent,
|
||||
_fallback_path,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeLLM:
|
||||
def __init__(self, response_text: str):
|
||||
self._response_text = response_text
|
||||
|
||||
async def ainvoke(self, *_args, **_kwargs):
|
||||
return AIMessage(content=self._response_text)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_write_intent_injects_contract_message():
|
||||
llm = _FakeLLM(
|
||||
'{"intent":"file_write","confidence":0.93,"suggested_filename":"ideas.md"}'
|
||||
)
|
||||
middleware = FileIntentMiddleware(llm=llm)
|
||||
state = {
|
||||
"messages": [HumanMessage(content="Create another random note for me")],
|
||||
"turn_id": "123:456",
|
||||
}
|
||||
|
||||
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
||||
|
||||
assert result is not None
|
||||
contract = result["file_operation_contract"]
|
||||
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
||||
assert contract["suggested_path"] == "/ideas.md"
|
||||
assert contract["turn_id"] == "123:456"
|
||||
assert any(
|
||||
"file_operation_contract" in str(msg.content)
|
||||
for msg in result["messages"]
|
||||
if hasattr(msg, "content")
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_write_intent_does_not_inject_contract_message():
|
||||
llm = _FakeLLM(
|
||||
'{"intent":"file_read","confidence":0.88,"suggested_filename":null}'
|
||||
)
|
||||
middleware = FileIntentMiddleware(llm=llm)
|
||||
original_messages = [HumanMessage(content="Read /notes.md")]
|
||||
state = {"messages": original_messages, "turn_id": "abc:def"}
|
||||
|
||||
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
||||
|
||||
assert result is not None
|
||||
assert result["file_operation_contract"]["intent"] == FileOperationIntent.FILE_READ.value
|
||||
assert "messages" not in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_write_null_filename_uses_semantic_default_path():
|
||||
llm = _FakeLLM(
|
||||
'{"intent":"file_write","confidence":0.74,"suggested_filename":null}'
|
||||
)
|
||||
middleware = FileIntentMiddleware(llm=llm)
|
||||
state = {
|
||||
"messages": [HumanMessage(content="create a random markdown file")],
|
||||
"turn_id": "turn:1",
|
||||
}
|
||||
|
||||
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
||||
|
||||
assert result is not None
|
||||
contract = result["file_operation_contract"]
|
||||
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
||||
assert contract["suggested_path"] == "/notes.md"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_write_null_filename_infers_json_extension():
|
||||
llm = _FakeLLM(
|
||||
'{"intent":"file_write","confidence":0.71,"suggested_filename":null}'
|
||||
)
|
||||
middleware = FileIntentMiddleware(llm=llm)
|
||||
state = {
|
||||
"messages": [HumanMessage(content="create a sample json config file")],
|
||||
"turn_id": "turn:2",
|
||||
}
|
||||
|
||||
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
||||
|
||||
assert result is not None
|
||||
contract = result["file_operation_contract"]
|
||||
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
||||
assert contract["suggested_path"] == "/notes.json"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_write_txt_suggestion_is_normalized_to_markdown():
|
||||
llm = _FakeLLM(
|
||||
'{"intent":"file_write","confidence":0.82,"suggested_filename":"random.txt"}'
|
||||
)
|
||||
middleware = FileIntentMiddleware(llm=llm)
|
||||
state = {
|
||||
"messages": [HumanMessage(content="create a random file")],
|
||||
"turn_id": "turn:3",
|
||||
}
|
||||
|
||||
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
||||
|
||||
assert result is not None
|
||||
contract = result["file_operation_contract"]
|
||||
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
||||
assert contract["suggested_path"] == "/random.md"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_write_with_suggested_directory_preserves_folder():
|
||||
llm = _FakeLLM(
|
||||
'{"intent":"file_write","confidence":0.86,"suggested_filename":"random.md","suggested_directory":"pc backups","suggested_path":null}'
|
||||
)
|
||||
middleware = FileIntentMiddleware(llm=llm)
|
||||
state = {
|
||||
"messages": [HumanMessage(content="create a random file in pc backups folder")],
|
||||
"turn_id": "turn:4",
|
||||
}
|
||||
|
||||
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
||||
|
||||
assert result is not None
|
||||
contract = result["file_operation_contract"]
|
||||
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
||||
assert contract["suggested_path"] == "/pc_backups/random.md"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_write_with_suggested_path_takes_precedence():
|
||||
llm = _FakeLLM(
|
||||
'{"intent":"file_write","confidence":0.9,"suggested_filename":"ignored.md","suggested_directory":"docs","suggested_path":"/reports/q2/summary.md"}'
|
||||
)
|
||||
middleware = FileIntentMiddleware(llm=llm)
|
||||
state = {
|
||||
"messages": [HumanMessage(content="create report")],
|
||||
"turn_id": "turn:5",
|
||||
}
|
||||
|
||||
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
||||
|
||||
assert result is not None
|
||||
contract = result["file_operation_contract"]
|
||||
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
||||
assert contract["suggested_path"] == "/reports/q2/summary.md"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_write_infers_directory_from_user_text_when_missing():
|
||||
llm = _FakeLLM(
|
||||
'{"intent":"file_write","confidence":0.83,"suggested_filename":"random.md","suggested_directory":null,"suggested_path":null}'
|
||||
)
|
||||
middleware = FileIntentMiddleware(llm=llm)
|
||||
state = {
|
||||
"messages": [HumanMessage(content="create a random file in pc backups folder")],
|
||||
"turn_id": "turn:6",
|
||||
}
|
||||
|
||||
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
||||
|
||||
assert result is not None
|
||||
contract = result["file_operation_contract"]
|
||||
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
||||
assert contract["suggested_path"] == "/pc_backups/random.md"
|
||||
|
||||
|
||||
def test_fallback_path_normalizes_windows_slashes() -> None:
|
||||
resolved = _fallback_path(
|
||||
suggested_filename="summary.md",
|
||||
suggested_path=r"\reports\q2\summary.md",
|
||||
user_text="create report",
|
||||
)
|
||||
|
||||
assert resolved == "/reports/q2/summary.md"
|
||||
|
||||
|
||||
def test_fallback_path_normalizes_windows_drive_path() -> None:
|
||||
resolved = _fallback_path(
|
||||
suggested_filename=None,
|
||||
suggested_path=r"C:\Users\anish\notes\todo.md",
|
||||
user_text="create note",
|
||||
)
|
||||
|
||||
assert resolved == "/C/Users/anish/notes/todo.md"
|
||||
|
||||
|
||||
def test_fallback_path_normalizes_mixed_separators_and_duplicate_slashes() -> None:
|
||||
resolved = _fallback_path(
|
||||
suggested_filename="summary.md",
|
||||
suggested_path=r"\\reports\\q2//summary.md",
|
||||
user_text="create report",
|
||||
)
|
||||
|
||||
assert resolved == "/reports/q2/summary.md"
|
||||
|
||||
|
||||
def test_fallback_path_keeps_posix_style_absolute_path_for_linux_and_macos() -> None:
|
||||
resolved = _fallback_path(
|
||||
suggested_filename=None,
|
||||
suggested_path="/var/log/surfsense/notes.md",
|
||||
user_text="create note",
|
||||
)
|
||||
|
||||
assert resolved == "/var/log/surfsense/notes.md"
|
||||
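The normalization rules pinned down by these _fallback_path tests amount to: backslashes become forward slashes, a Windows drive letter becomes a leading path segment, and duplicate slashes collapse. A standalone re-implementation of just that rule set, illustrative rather than the middleware's actual code:

import re


def normalize_suggested_path_sketch(raw: str) -> str:
    # Illustrative only: "\reports\q2\summary.md" -> "/reports/q2/summary.md",
    # "C:\Users\x\todo.md" -> "/C/Users/x/todo.md", "//a//b" -> "/a/b".
    path = raw.replace("\\", "/")
    path = re.sub(r"^([A-Za-z]):", r"/\1", path)
    path = re.sub(r"/{2,}", "/", path)
    return path if path.startswith("/") else "/" + path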
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
||||
from app.agents.new_chat.filesystem_selection import (
|
||||
ClientPlatform,
|
||||
FilesystemMode,
|
||||
FilesystemSelection,
|
||||
LocalFilesystemMount,
|
||||
)
|
||||
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
|
||||
MultiRootLocalFolderBackend,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _RuntimeStub:
|
||||
state = {"files": {}}
|
||||
|
||||
|
||||
def test_backend_resolver_returns_multi_root_backend_for_single_root(tmp_path: Path):
|
||||
selection = FilesystemSelection(
|
||||
mode=FilesystemMode.DESKTOP_LOCAL_FOLDER,
|
||||
client_platform=ClientPlatform.DESKTOP,
|
||||
local_mounts=(LocalFilesystemMount(mount_id="tmp", root_path=str(tmp_path)),),
|
||||
)
|
||||
resolver = build_backend_resolver(selection)
|
||||
|
||||
backend = resolver(_RuntimeStub())
|
||||
assert isinstance(backend, MultiRootLocalFolderBackend)
|
||||
|
||||
|
||||
def test_backend_resolver_uses_cloud_mode_by_default():
|
||||
resolver = build_backend_resolver(FilesystemSelection())
|
||||
backend = resolver(_RuntimeStub())
|
||||
# StateBackend class name check keeps this test decoupled
|
||||
# from internal deepagents runtime class identity.
|
||||
assert backend.__class__.__name__ == "StateBackend"
|
||||
|
||||
|
||||
def test_backend_resolver_returns_multi_root_backend_for_multiple_roots(tmp_path: Path):
|
||||
root_one = tmp_path / "resume"
|
||||
root_two = tmp_path / "notes"
|
||||
root_one.mkdir()
|
||||
root_two.mkdir()
|
||||
selection = FilesystemSelection(
|
||||
mode=FilesystemMode.DESKTOP_LOCAL_FOLDER,
|
||||
client_platform=ClientPlatform.DESKTOP,
|
||||
local_mounts=(
|
||||
LocalFilesystemMount(mount_id="resume", root_path=str(root_one)),
|
||||
LocalFilesystemMount(mount_id="notes", root_path=str(root_two)),
|
||||
),
|
||||
)
|
||||
resolver = build_backend_resolver(selection)
|
||||
|
||||
backend = resolver(_RuntimeStub())
|
||||
assert isinstance(backend, MultiRootLocalFolderBackend)
|
||||
|
|
@ -0,0 +1,164 @@
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
|
||||
MultiRootLocalFolderBackend,
|
||||
)
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _BackendWithRawRead:
|
||||
def __init__(self, content: str) -> None:
|
||||
self._content = content
|
||||
|
||||
def read(self, file_path: str, offset: int = 0, limit: int = 200000) -> str:
|
||||
del file_path, offset, limit
|
||||
return " 1\tline1\n 2\tline2"
|
||||
|
||||
async def aread(self, file_path: str, offset: int = 0, limit: int = 200000) -> str:
|
||||
return self.read(file_path, offset, limit)
|
||||
|
||||
def read_raw(self, file_path: str) -> str:
|
||||
del file_path
|
||||
return self._content
|
||||
|
||||
async def aread_raw(self, file_path: str) -> str:
|
||||
return self.read_raw(file_path)
|
||||
|
||||
|
||||
class _RuntimeNoSuggestedPath:
|
||||
state = {"file_operation_contract": {}}
|
||||
|
||||
|
||||
def test_verify_written_content_prefers_raw_sync() -> None:
|
||||
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||
expected = "line1\nline2"
|
||||
backend = _BackendWithRawRead(expected)
|
||||
|
||||
verify_error = middleware._verify_written_content_sync(
|
||||
backend=backend,
|
||||
path="/note.md",
|
||||
expected_content=expected,
|
||||
)
|
||||
|
||||
assert verify_error is None
|
||||
|
||||
|
||||
def test_contract_suggested_path_falls_back_to_notes_md() -> None:
|
||||
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||
middleware._filesystem_mode = FilesystemMode.CLOUD
|
||||
suggested = middleware._get_contract_suggested_path(_RuntimeNoSuggestedPath()) # type: ignore[arg-type]
|
||||
assert suggested == "/notes.md"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_written_content_prefers_raw_async() -> None:
|
||||
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||
expected = "line1\nline2"
|
||||
backend = _BackendWithRawRead(expected)
|
||||
|
||||
verify_error = await middleware._verify_written_content_async(
|
||||
backend=backend,
|
||||
path="/note.md",
|
||||
expected_content=expected,
|
||||
)
|
||||
|
||||
assert verify_error is None
|
||||
|
||||
|
||||
def test_normalize_local_mount_path_prefixes_default_mount(tmp_path: Path) -> None:
|
||||
root = tmp_path / "PC Backups"
|
||||
root.mkdir()
|
||||
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
|
||||
runtime = _RuntimeNoSuggestedPath()
|
||||
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
|
||||
|
||||
resolved = middleware._normalize_local_mount_path("/random-note.md", runtime) # type: ignore[arg-type]
|
||||
|
||||
assert resolved == "/pc_backups/random-note.md"
|
||||
|
||||
|
||||
def test_normalize_local_mount_path_keeps_explicit_mount(tmp_path: Path) -> None:
|
||||
root = tmp_path / "PC Backups"
|
||||
root.mkdir()
|
||||
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
|
||||
runtime = _RuntimeNoSuggestedPath()
|
||||
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
|
||||
|
||||
resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type]
|
||||
"/pc_backups/notes/random-note.md",
|
||||
runtime,
|
||||
)
|
||||
|
||||
assert resolved == "/pc_backups/notes/random-note.md"
|
||||
|
||||
|
||||
def test_normalize_local_mount_path_windows_backslashes(tmp_path: Path) -> None:
|
||||
root = tmp_path / "PC Backups"
|
||||
root.mkdir()
|
||||
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
|
||||
runtime = _RuntimeNoSuggestedPath()
|
||||
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
|
||||
|
||||
resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type]
|
||||
r"\notes\random-note.md",
|
||||
runtime,
|
||||
)
|
||||
|
||||
assert resolved == "/pc_backups/notes/random-note.md"
|
||||
|
||||
|
||||
def test_normalize_local_mount_path_normalizes_mixed_separators(tmp_path: Path) -> None:
|
||||
root = tmp_path / "PC Backups"
|
||||
root.mkdir()
|
||||
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
|
||||
runtime = _RuntimeNoSuggestedPath()
|
||||
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
|
||||
|
||||
resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type]
|
||||
r"\\notes//nested\\random-note.md",
|
||||
runtime,
|
||||
)
|
||||
|
||||
assert resolved == "/pc_backups/notes/nested/random-note.md"
|
||||
|
||||
|
||||
def test_normalize_local_mount_path_keeps_explicit_mount_with_backslashes(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
root = tmp_path / "PC Backups"
|
||||
root.mkdir()
|
||||
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
|
||||
runtime = _RuntimeNoSuggestedPath()
|
||||
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
|
||||
|
||||
resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type]
|
||||
r"\pc_backups\notes\random-note.md",
|
||||
runtime,
|
||||
)
|
||||
|
||||
assert resolved == "/pc_backups/notes/random-note.md"
|
||||
|
||||
|
||||
def test_normalize_local_mount_path_prefixes_posix_absolute_path_for_linux_and_macos(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
root = tmp_path / "PC Backups"
|
||||
root.mkdir()
|
||||
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
|
||||
runtime = _RuntimeNoSuggestedPath()
|
||||
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
|
||||
|
||||
resolved = middleware._normalize_local_mount_path("/var/log/app.log", runtime) # type: ignore[arg-type]
|
||||
|
||||
assert resolved == "/pc_backups/var/log/app.log"
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.middleware.local_folder_backend import LocalFolderBackend
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_local_backend_write_read_edit_roundtrip(tmp_path: Path):
|
||||
backend = LocalFolderBackend(str(tmp_path))
|
||||
|
||||
write = backend.write("/notes/test.md", "line1\nline2")
|
||||
assert write.error is None
|
||||
assert write.path == "/notes/test.md"
|
||||
|
||||
read = backend.read("/notes/test.md", offset=0, limit=20)
|
||||
assert "line1" in read
|
||||
assert "line2" in read
|
||||
|
||||
edit = backend.edit("/notes/test.md", "line2", "updated")
|
||||
assert edit.error is None
|
||||
assert edit.occurrences == 1
|
||||
|
||||
read_after = backend.read("/notes/test.md", offset=0, limit=20)
|
||||
assert "updated" in read_after
|
||||
|
||||
|
||||
def test_local_backend_blocks_path_escape(tmp_path: Path):
|
||||
backend = LocalFolderBackend(str(tmp_path))
|
||||
|
||||
result = backend.write("/../../etc/passwd", "bad")
|
||||
assert result.error is not None
|
||||
assert "Invalid path" in result.error
|
||||
|
||||
|
||||
def test_local_backend_glob_and_grep(tmp_path: Path):
|
||||
backend = LocalFolderBackend(str(tmp_path))
|
||||
(tmp_path / "docs").mkdir()
|
||||
(tmp_path / "docs" / "a.txt").write_text("hello world\n")
|
||||
(tmp_path / "docs" / "b.md").write_text("hello markdown\n")
|
||||
|
||||
infos = backend.glob_info("**/*.txt", "/docs")
|
||||
paths = {info["path"] for info in infos}
|
||||
assert "/docs/a.txt" in paths
|
||||
|
||||
grep = backend.grep_raw("hello", "/docs", "*.md")
|
||||
assert isinstance(grep, list)
|
||||
assert any(match["path"] == "/docs/b.md" for match in grep)
|
||||
|
||||
|
||||
def test_local_backend_read_raw_returns_exact_content(tmp_path: Path):
|
||||
backend = LocalFolderBackend(str(tmp_path))
|
||||
expected = "# Title\n\nline 1\nline 2\n"
|
||||
write = backend.write("/notes/raw.md", expected)
|
||||
assert write.error is None
|
||||
|
||||
raw = backend.read_raw("/notes/raw.md")
|
||||
assert raw == expected
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
|
||||
MultiRootLocalFolderBackend,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_mount_ids_preserve_client_mapping_order(tmp_path: Path) -> None:
|
||||
root_one = tmp_path / "PC Backups"
|
||||
root_two = tmp_path / "pc_backups"
|
||||
root_three = tmp_path / "notes@2026"
|
||||
root_one.mkdir()
|
||||
root_two.mkdir()
|
||||
root_three.mkdir()
|
||||
|
||||
backend = MultiRootLocalFolderBackend(
|
||||
(
|
||||
("pc_backups", str(root_one)),
|
||||
("pc_backups_2", str(root_two)),
|
||||
("notes_2026", str(root_three)),
|
||||
)
|
||||
)
|
||||
|
||||
assert backend.list_mounts() == ("pc_backups", "pc_backups_2", "notes_2026")
|
||||
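The mount ids asserted here ("pc_backups", "pc_backups_2", "notes_2026") follow the folder-name sanitization convention applied when roots are selected on the desktop side. A rough Python equivalent of that convention, purely for illustration (the actual derivation lives in the Electron agent-filesystem module):

import re
from pathlib import Path


def derive_mount_ids_sketch(root_paths: list[str]) -> list[tuple[str, str]]:
    # Illustrative only: lowercase the folder name, replace unsupported
    # characters with "_", and add a numeric suffix for duplicates.
    mounts: list[tuple[str, str]] = []
    used: set[str] = set()
    for root in root_paths:
        base = re.sub(r"[^a-z0-9_-]+", "_", Path(root).name.lower()).strip("_-") or "root"
        mount, suffix = base, 2
        while mount in used:
            mount = f"{base}_{suffix}"
            suffix += 1
        used.add(mount)
        mounts.append((mount, root))
    return mounts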
|
|
@ -0,0 +1,48 @@
|
|||
import pytest
|
||||
|
||||
from app.tasks.chat.stream_new_chat import (
|
||||
StreamResult,
|
||||
_contract_enforcement_active,
|
||||
_evaluate_file_contract_outcome,
|
||||
_tool_output_has_error,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_tool_output_error_detection():
|
||||
assert _tool_output_has_error("Error: failed to write file")
|
||||
assert _tool_output_has_error({"error": "boom"})
|
||||
assert _tool_output_has_error({"result": "Error: disk is full"})
|
||||
assert not _tool_output_has_error({"result": "Updated file /notes.md"})
|
||||
|
||||
|
||||
def test_file_write_contract_outcome_reasons():
|
||||
result = StreamResult(intent_detected="file_write")
|
||||
passed, reason = _evaluate_file_contract_outcome(result)
|
||||
assert not passed
|
||||
assert reason == "no_write_attempt"
|
||||
|
||||
result.write_attempted = True
|
||||
passed, reason = _evaluate_file_contract_outcome(result)
|
||||
assert not passed
|
||||
assert reason == "write_failed"
|
||||
|
||||
result.write_succeeded = True
|
||||
passed, reason = _evaluate_file_contract_outcome(result)
|
||||
assert not passed
|
||||
assert reason == "verification_failed"
|
||||
|
||||
result.verification_succeeded = True
|
||||
passed, reason = _evaluate_file_contract_outcome(result)
|
||||
assert passed
|
||||
assert reason == ""
|
||||
|
||||
|
||||
def test_contract_enforcement_local_only():
|
||||
result = StreamResult(filesystem_mode="desktop_local_folder")
|
||||
assert _contract_enforcement_active(result)
|
||||
|
||||
result.filesystem_mode = "cloud"
|
||||
assert not _contract_enforcement_active(result)
|
||||
|
||||
2  surfsense_backend/uv.lock  generated
|
|
@ -8070,7 +8070,7 @@ requires-dist = [
|
|||
{ name = "langgraph", specifier = ">=1.1.3" },
|
||||
{ name = "langgraph-checkpoint-postgres", specifier = ">=3.0.2" },
|
||||
{ name = "linkup-sdk", specifier = ">=0.2.4" },
|
||||
{ name = "litellm", specifier = ">=1.83.0" },
|
||||
{ name = "litellm", specifier = ">=1.83.4" },
|
||||
{ name = "llama-cloud-services", specifier = ">=0.6.25" },
|
||||
{ name = "markdown", specifier = ">=3.7" },
|
||||
{ name = "markdownify", specifier = ">=0.14.1" },
|
||||
|
|
|
|||
|
|
@ -34,6 +34,8 @@ export const IPC_CHANNELS = {
|
|||
FOLDER_SYNC_SEED_MTIMES: 'folder-sync:seed-mtimes',
|
||||
BROWSE_FILES: 'browse:files',
|
||||
READ_LOCAL_FILES: 'browse:read-local-files',
|
||||
READ_AGENT_LOCAL_FILE_TEXT: 'agent-filesystem:read-local-file-text',
|
||||
WRITE_AGENT_LOCAL_FILE_TEXT: 'agent-filesystem:write-local-file-text',
|
||||
// Auth token sync across windows
|
||||
GET_AUTH_TOKENS: 'auth:get-tokens',
|
||||
SET_AUTH_TOKENS: 'auth:set-tokens',
|
||||
|
|
@ -51,4 +53,9 @@ export const IPC_CHANNELS = {
|
|||
ANALYTICS_RESET: 'analytics:reset',
|
||||
ANALYTICS_CAPTURE: 'analytics:capture',
|
||||
ANALYTICS_GET_CONTEXT: 'analytics:get-context',
|
||||
// Agent filesystem mode
|
||||
AGENT_FILESYSTEM_GET_SETTINGS: 'agent-filesystem:get-settings',
|
||||
AGENT_FILESYSTEM_GET_MOUNTS: 'agent-filesystem:get-mounts',
|
||||
AGENT_FILESYSTEM_SET_SETTINGS: 'agent-filesystem:set-settings',
|
||||
AGENT_FILESYSTEM_PICK_ROOT: 'agent-filesystem:pick-root',
|
||||
} as const;
|
||||
|
|
|
|||
|
|
@ -36,6 +36,14 @@ import {
|
|||
resetUser as analyticsReset,
|
||||
trackEvent,
|
||||
} from '../modules/analytics';
|
||||
import {
|
||||
readAgentLocalFileText,
|
||||
writeAgentLocalFileText,
|
||||
getAgentFilesystemMounts,
|
||||
getAgentFilesystemSettings,
|
||||
pickAgentFilesystemRoot,
|
||||
setAgentFilesystemSettings,
|
||||
} from '../modules/agent-filesystem';
|
||||
|
||||
let authTokens: { bearer: string; refresh: string } | null = null;
|
||||
|
||||
|
|
@ -118,6 +126,29 @@ export function registerIpcHandlers(): void {
|
|||
readLocalFiles(paths)
|
||||
);
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.READ_AGENT_LOCAL_FILE_TEXT, async (_event, virtualPath: string) => {
|
||||
try {
|
||||
const result = await readAgentLocalFileText(virtualPath);
|
||||
return { ok: true, path: result.path, content: result.content };
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : 'Failed to read local file';
|
||||
return { ok: false, path: virtualPath, error: message };
|
||||
}
|
||||
});
|
||||
|
||||
ipcMain.handle(
|
||||
IPC_CHANNELS.WRITE_AGENT_LOCAL_FILE_TEXT,
|
||||
async (_event, virtualPath: string, content: string) => {
|
||||
try {
|
||||
const result = await writeAgentLocalFileText(virtualPath, content);
|
||||
return { ok: true, path: result.path };
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : 'Failed to write local file';
|
||||
return { ok: false, path: virtualPath, error: message };
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.SET_AUTH_TOKENS, (_event, tokens: { bearer: string; refresh: string }) => {
|
||||
authTokens = tokens;
|
||||
});
|
||||
|
|
@ -191,4 +222,22 @@ export function registerIpcHandlers(): void {
|
|||
platform: process.platform,
|
||||
};
|
||||
});
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_GET_SETTINGS, () =>
|
||||
getAgentFilesystemSettings()
|
||||
);
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_GET_MOUNTS, () =>
|
||||
getAgentFilesystemMounts()
|
||||
);
|
||||
|
||||
ipcMain.handle(
|
||||
IPC_CHANNELS.AGENT_FILESYSTEM_SET_SETTINGS,
|
||||
(_event, settings: { mode?: 'cloud' | 'desktop_local_folder'; localRootPaths?: string[] | null }) =>
|
||||
setAgentFilesystemSettings(settings)
|
||||
);
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_PICK_ROOT, () =>
|
||||
pickAgentFilesystemRoot()
|
||||
);
|
||||
}
|
||||
|
|
|
|||
254  surfsense_desktop/src/modules/agent-filesystem.ts  Normal file
|
|
@ -0,0 +1,254 @@
|
|||
import { app, dialog } from "electron";
|
||||
import { access, mkdir, readFile, writeFile } from "node:fs/promises";
|
||||
import { dirname, isAbsolute, join, relative, resolve } from "node:path";
|
||||
|
||||
export type AgentFilesystemMode = "cloud" | "desktop_local_folder";
|
||||
|
||||
export interface AgentFilesystemSettings {
|
||||
mode: AgentFilesystemMode;
|
||||
localRootPaths: string[];
|
||||
updatedAt: string;
|
||||
}
|
||||
|
||||
const SETTINGS_FILENAME = "agent-filesystem-settings.json";
|
||||
const MAX_LOCAL_ROOTS = 5;
|
||||
|
||||
function getSettingsPath(): string {
|
||||
return join(app.getPath("userData"), SETTINGS_FILENAME);
|
||||
}
|
||||
|
||||
function getDefaultSettings(): AgentFilesystemSettings {
|
||||
return {
|
||||
mode: "cloud",
|
||||
localRootPaths: [],
|
||||
updatedAt: new Date().toISOString(),
|
||||
};
|
||||
}
|
||||
|
||||
function normalizeLocalRootPaths(paths: unknown): string[] {
|
||||
if (!Array.isArray(paths)) {
|
||||
return [];
|
||||
}
|
||||
const uniquePaths = new Set<string>();
|
||||
for (const path of paths) {
|
||||
if (typeof path !== "string") continue;
|
||||
const trimmed = path.trim();
|
||||
if (!trimmed) continue;
|
||||
uniquePaths.add(trimmed);
|
||||
if (uniquePaths.size >= MAX_LOCAL_ROOTS) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return [...uniquePaths];
|
||||
}
|
||||
|
||||
export async function getAgentFilesystemSettings(): Promise<AgentFilesystemSettings> {
|
||||
try {
|
||||
const raw = await readFile(getSettingsPath(), "utf8");
|
||||
const parsed = JSON.parse(raw) as Partial<AgentFilesystemSettings>;
|
||||
if (parsed.mode !== "cloud" && parsed.mode !== "desktop_local_folder") {
|
||||
return getDefaultSettings();
|
||||
}
|
||||
return {
|
||||
mode: parsed.mode,
|
||||
localRootPaths: normalizeLocalRootPaths(parsed.localRootPaths),
|
||||
updatedAt: parsed.updatedAt ?? new Date().toISOString(),
|
||||
};
|
||||
} catch {
|
||||
return getDefaultSettings();
|
||||
}
|
||||
}
|
||||
|
||||
export async function setAgentFilesystemSettings(
|
||||
settings: {
|
||||
mode?: AgentFilesystemMode;
|
||||
localRootPaths?: string[] | null;
|
||||
}
|
||||
): Promise<AgentFilesystemSettings> {
|
||||
const current = await getAgentFilesystemSettings();
|
||||
const nextMode =
|
||||
settings.mode === "cloud" || settings.mode === "desktop_local_folder"
|
||||
? settings.mode
|
||||
: current.mode;
|
||||
const next: AgentFilesystemSettings = {
|
||||
mode: nextMode,
|
||||
localRootPaths:
|
||||
settings.localRootPaths === undefined
|
||||
? current.localRootPaths
|
||||
: normalizeLocalRootPaths(settings.localRootPaths ?? []),
|
||||
updatedAt: new Date().toISOString(),
|
||||
};
|
||||
|
||||
const settingsPath = getSettingsPath();
|
||||
await mkdir(dirname(settingsPath), { recursive: true });
|
||||
await writeFile(settingsPath, JSON.stringify(next, null, 2), "utf8");
|
||||
return next;
|
||||
}
|
||||
|
||||
export async function pickAgentFilesystemRoot(): Promise<string | null> {
|
||||
const result = await dialog.showOpenDialog({
|
||||
title: "Select local folder for Agent Filesystem",
|
||||
properties: ["openDirectory"],
|
||||
});
|
||||
if (result.canceled || result.filePaths.length === 0) {
|
||||
return null;
|
||||
}
|
||||
return result.filePaths[0] ?? null;
|
||||
}
|
||||
|
||||
function resolveVirtualPath(rootPath: string, virtualPath: string): string {
|
||||
if (!virtualPath.startsWith("/")) {
|
||||
throw new Error("Path must start with '/'");
|
||||
}
|
||||
const normalizedRoot = resolve(rootPath);
|
||||
const relativePath = virtualPath.replace(/^\/+/, "");
|
||||
if (!relativePath) {
|
||||
throw new Error("Path must refer to a file under the selected root");
|
||||
}
|
||||
const absolutePath = resolve(normalizedRoot, relativePath);
|
||||
const rel = relative(normalizedRoot, absolutePath);
|
||||
if (!rel || rel.startsWith("..") || isAbsolute(rel)) {
|
||||
throw new Error("Path escapes selected local root");
|
||||
}
|
||||
return absolutePath;
|
||||
}
|
||||
|
||||
function toVirtualPath(rootPath: string, absolutePath: string): string {
|
||||
const normalizedRoot = resolve(rootPath);
|
||||
const rel = relative(normalizedRoot, absolutePath);
|
||||
if (!rel || rel.startsWith("..") || isAbsolute(rel)) {
|
||||
return "/";
|
||||
}
|
||||
return `/${rel.replace(/\\/g, "/")}`;
|
||||
}
|
||||
|
||||
export type LocalRootMount = {
|
||||
mount: string;
|
||||
rootPath: string;
|
||||
};
|
||||
|
||||
function sanitizeMountName(rawMount: string): string {
|
||||
const normalized = rawMount
|
||||
.trim()
|
||||
.toLowerCase()
|
||||
.replace(/[^a-z0-9_-]+/g, "_")
|
||||
.replace(/_+/g, "_")
|
||||
.replace(/^[_-]+|[_-]+$/g, "");
|
||||
return normalized || "root";
|
||||
}
|
||||
|
||||
function buildRootMounts(rootPaths: string[]): LocalRootMount[] {
|
||||
const mounts: LocalRootMount[] = [];
|
||||
const usedMounts = new Set<string>();
|
||||
for (const rawRootPath of rootPaths) {
|
||||
const normalizedRoot = resolve(rawRootPath);
|
||||
const baseMount = sanitizeMountName(normalizedRoot.split(/[\\/]/).at(-1) || "root");
|
||||
let mount = baseMount;
|
||||
let suffix = 2;
|
||||
while (usedMounts.has(mount)) {
|
||||
mount = `${baseMount}-${suffix}`;
|
||||
suffix += 1;
|
||||
}
|
||||
usedMounts.add(mount);
|
||||
mounts.push({ mount, rootPath: normalizedRoot });
|
||||
}
|
||||
return mounts;
|
||||
}
|
||||
|
||||
export async function getAgentFilesystemMounts(): Promise<LocalRootMount[]> {
|
||||
const rootPaths = await resolveCurrentRootPaths();
|
||||
return buildRootMounts(rootPaths);
|
||||
}
|
||||
|
||||
function parseMountedVirtualPath(
|
||||
virtualPath: string,
|
||||
mounts: LocalRootMount[]
|
||||
): {
|
||||
mount: string;
|
||||
subPath: string;
|
||||
} {
|
||||
if (!virtualPath.startsWith("/")) {
|
||||
throw new Error("Path must start with '/'");
|
||||
}
|
||||
const trimmed = virtualPath.replace(/^\/+/, "");
|
||||
if (!trimmed) {
|
||||
throw new Error("Path must include a mounted root segment");
|
||||
}
|
||||
|
||||
const [mount, ...rest] = trimmed.split("/");
|
||||
const remainder = rest.join("/");
|
||||
const directMount = mounts.find((entry) => entry.mount === mount);
|
||||
if (!directMount) {
|
||||
throw new Error(
|
||||
`Unknown mounted root '${mount}'. Available roots: ${mounts.map((entry) => `/${entry.mount}`).join(", ")}`
|
||||
);
|
||||
}
|
||||
if (!remainder) {
|
||||
throw new Error("Path must include a file path under the mounted root");
|
||||
}
|
||||
return { mount, subPath: `/${remainder}` };
|
||||
}
|
||||
|
||||
function findMountByName(mounts: LocalRootMount[], mountName: string): LocalRootMount | undefined {
|
||||
return mounts.find((entry) => entry.mount === mountName);
|
||||
}
|
||||
|
||||
function toMountedVirtualPath(mount: string, rootPath: string, absolutePath: string): string {
|
||||
const relativePath = toVirtualPath(rootPath, absolutePath);
|
||||
return `/${mount}${relativePath}`;
|
||||
}
|
||||
|
||||
async function resolveCurrentRootPaths(): Promise<string[]> {
|
||||
const settings = await getAgentFilesystemSettings();
|
||||
if (settings.localRootPaths.length === 0) {
|
||||
throw new Error("No local filesystem roots selected");
|
||||
}
|
||||
return settings.localRootPaths;
|
||||
}
|
||||
|
||||
export async function readAgentLocalFileText(
|
||||
virtualPath: string
|
||||
): Promise<{ path: string; content: string }> {
|
||||
const rootPaths = await resolveCurrentRootPaths();
|
||||
const mounts = buildRootMounts(rootPaths);
|
||||
const { mount, subPath } = parseMountedVirtualPath(virtualPath, mounts);
|
||||
const rootMount = findMountByName(mounts, mount);
|
||||
if (!rootMount) {
|
||||
throw new Error(
|
||||
`Unknown mounted root '${mount}'. Available roots: ${mounts.map((entry) => `/${entry.mount}`).join(", ")}`
|
||||
);
|
||||
}
|
||||
const absolutePath = resolveVirtualPath(rootMount.rootPath, subPath);
|
||||
const content = await readFile(absolutePath, "utf8");
|
||||
return {
|
||||
path: toMountedVirtualPath(rootMount.mount, rootMount.rootPath, absolutePath),
|
||||
content,
|
||||
};
|
||||
}
|
||||
|
||||
export async function writeAgentLocalFileText(
|
||||
virtualPath: string,
|
||||
content: string
|
||||
): Promise<{ path: string }> {
|
||||
const rootPaths = await resolveCurrentRootPaths();
|
||||
const mounts = buildRootMounts(rootPaths);
|
||||
const { mount, subPath } = parseMountedVirtualPath(virtualPath, mounts);
|
||||
const rootMount = findMountByName(mounts, mount);
|
||||
if (!rootMount) {
|
||||
throw new Error(
|
||||
`Unknown mounted root '${mount}'. Available roots: ${mounts.map((entry) => `/${entry.mount}`).join(", ")}`
|
||||
);
|
||||
}
|
||||
let selectedAbsolutePath = resolveVirtualPath(rootMount.rootPath, subPath);
|
||||
|
||||
try {
|
||||
await access(selectedAbsolutePath);
|
||||
} catch {
|
||||
// New files are created under the selected mounted root.
|
||||
}
|
||||
await mkdir(dirname(selectedAbsolutePath), { recursive: true });
|
||||
await writeFile(selectedAbsolutePath, content, "utf8");
|
||||
return {
|
||||
path: toMountedVirtualPath(rootMount.mount, rootMount.rootPath, selectedAbsolutePath),
|
||||
};
|
||||
}
|
||||
|
|
@ -71,6 +71,10 @@ contextBridge.exposeInMainWorld('electronAPI', {
|
|||
// Browse files via native dialog
|
||||
browseFiles: () => ipcRenderer.invoke(IPC_CHANNELS.BROWSE_FILES),
|
||||
readLocalFiles: (paths: string[]) => ipcRenderer.invoke(IPC_CHANNELS.READ_LOCAL_FILES, paths),
|
||||
readAgentLocalFileText: (virtualPath: string) =>
|
||||
ipcRenderer.invoke(IPC_CHANNELS.READ_AGENT_LOCAL_FILE_TEXT, virtualPath),
|
||||
writeAgentLocalFileText: (virtualPath: string, content: string) =>
|
||||
ipcRenderer.invoke(IPC_CHANNELS.WRITE_AGENT_LOCAL_FILE_TEXT, virtualPath, content),
|
||||
|
||||
// Auth token sync across windows
|
||||
getAuthTokens: () => ipcRenderer.invoke(IPC_CHANNELS.GET_AUTH_TOKENS),
|
||||
|
|
@ -101,4 +105,14 @@ contextBridge.exposeInMainWorld('electronAPI', {
|
|||
analyticsCapture: (event: string, properties?: Record<string, unknown>) =>
|
||||
ipcRenderer.invoke(IPC_CHANNELS.ANALYTICS_CAPTURE, { event, properties }),
|
||||
getAnalyticsContext: () => ipcRenderer.invoke(IPC_CHANNELS.ANALYTICS_GET_CONTEXT),
|
||||
// Agent filesystem mode
|
||||
getAgentFilesystemSettings: () =>
|
||||
ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_GET_SETTINGS),
|
||||
getAgentFilesystemMounts: () =>
|
||||
ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_GET_MOUNTS),
|
||||
setAgentFilesystemSettings: (settings: {
|
||||
mode?: "cloud" | "desktop_local_folder";
|
||||
localRootPaths?: string[] | null;
|
||||
}) => ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_SET_SETTINGS, settings),
|
||||
pickAgentFilesystemRoot: () => ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_PICK_ROOT),
|
||||
});
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ import {
|
|||
import { useChatSessionStateSync } from "@/hooks/use-chat-session-state";
|
||||
import { useMessagesSync } from "@/hooks/use-messages-sync";
|
||||
import { documentsApiService } from "@/lib/apis/documents-api.service";
|
||||
import { getAgentFilesystemSelection } from "@/lib/agent-filesystem";
|
||||
import { getBearerToken } from "@/lib/auth-utils";
|
||||
import { convertToThreadMessage } from "@/lib/chat/message-utils";
|
||||
import {
|
||||
|
|
@ -158,7 +159,7 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] {
|
|||
/**
|
||||
* Tools that should render custom UI in the chat.
|
||||
*/
|
||||
const TOOLS_WITH_UI = new Set([
|
||||
const BASE_TOOLS_WITH_UI = new Set([
|
||||
"web_search",
|
||||
"generate_podcast",
|
||||
"generate_report",
|
||||
|
|
@ -210,6 +211,7 @@ export default function NewChatPage() {
|
|||
assistantMsgId: string;
|
||||
interruptData: Record<string, unknown>;
|
||||
} | null>(null);
|
||||
const toolsWithUI = useMemo(() => new Set([...BASE_TOOLS_WITH_UI]), []);
|
||||
|
||||
// Get disabled tools from the tool toggle UI
|
||||
const disabledTools = useAtomValue(disabledToolsAtom);
|
||||
|
|
@ -656,6 +658,15 @@ export default function NewChatPage() {
|
|||
|
||||
try {
|
||||
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||
const selection = await getAgentFilesystemSelection();
|
||||
if (
|
||||
selection.filesystem_mode === "desktop_local_folder" &&
|
||||
(!selection.local_filesystem_mounts ||
|
||||
selection.local_filesystem_mounts.length === 0)
|
||||
) {
|
||||
toast.error("Select a local folder before using Local Folder mode.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Build message history for context
|
||||
const messageHistory = messages
|
||||
|
|
@ -691,6 +702,9 @@ export default function NewChatPage() {
|
|||
chat_id: currentThreadId,
|
||||
user_query: userQuery.trim(),
|
||||
search_space_id: searchSpaceId,
|
||||
filesystem_mode: selection.filesystem_mode,
|
||||
client_platform: selection.client_platform,
|
||||
local_filesystem_mounts: selection.local_filesystem_mounts,
|
||||
messages: messageHistory,
|
||||
mentioned_document_ids: hasDocumentIds ? mentionedDocumentIds.document_ids : undefined,
|
||||
mentioned_surfsense_doc_ids: hasSurfsenseDocIds
|
||||
|
|
@ -709,7 +723,7 @@ export default function NewChatPage() {
|
|||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
|
|
@@ -724,7 +738,7 @@ export default function NewChatPage() {
break;

case "tool-input-start":
addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {});
addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {});
batcher.flush();
break;

@@ -734,7 +748,7 @@ export default function NewChatPage() {
} else {
addToolCall(
contentPartsState,
TOOLS_WITH_UI,
toolsWithUI,
parsed.toolCallId,
parsed.toolName,
parsed.input || {}

@@ -830,7 +844,7 @@ export default function NewChatPage() {
const tcId = `interrupt-${action.name}`;
addToolCall(
contentPartsState,
TOOLS_WITH_UI,
toolsWithUI,
tcId,
action.name,
action.args,

@@ -844,7 +858,7 @@ export default function NewChatPage() {
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
: m
)
);

@@ -871,7 +885,7 @@ export default function NewChatPage() {
batcher.flush();

// Skip persistence for interrupted messages -- handleResume will persist the final version
const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
if (contentParts.length > 0 && !wasInterrupted) {
try {
const savedMessage = await appendMessage(currentThreadId, {

@@ -907,10 +921,10 @@ export default function NewChatPage() {
const hasContent = contentParts.some(
(part) =>
(part.type === "text" && part.text.length > 0) ||
(part.type === "tool-call" && TOOLS_WITH_UI.has(part.toolName))
(part.type === "tool-call" && toolsWithUI.has(part.toolName))
);
if (hasContent && currentThreadId) {
const partialContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI);
try {
const savedMessage = await appendMessage(currentThreadId, {
role: "assistant",

@@ -1074,6 +1088,7 @@ export default function NewChatPage() {

try {
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
const selection = await getAgentFilesystemSelection();
const response = await fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, {
method: "POST",
headers: {

@@ -1083,6 +1098,9 @@ export default function NewChatPage() {
body: JSON.stringify({
search_space_id: searchSpaceId,
decisions,
filesystem_mode: selection.filesystem_mode,
client_platform: selection.client_platform,
local_filesystem_mounts: selection.local_filesystem_mounts,
}),
signal: controller.signal,
});

@@ -1095,7 +1113,7 @@ export default function NewChatPage() {
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
: m
)
);

@@ -1110,7 +1128,7 @@ export default function NewChatPage() {
break;

case "tool-input-start":
addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {});
addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {});
batcher.flush();
break;

@@ -1122,7 +1140,7 @@ export default function NewChatPage() {
} else {
addToolCall(
contentPartsState,
TOOLS_WITH_UI,
toolsWithUI,
parsed.toolCallId,
parsed.toolName,
parsed.input || {}

@@ -1173,7 +1191,7 @@ export default function NewChatPage() {
const tcId = `interrupt-${action.name}`;
addToolCall(
contentPartsState,
TOOLS_WITH_UI,
toolsWithUI,
tcId,
action.name,
action.args,

@@ -1190,7 +1208,7 @@ export default function NewChatPage() {
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
: m
)
);

@@ -1214,7 +1232,7 @@ export default function NewChatPage() {

batcher.flush();

const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
if (contentParts.length > 0) {
try {
const savedMessage = await appendMessage(resumeThreadId, {

@@ -1406,6 +1424,7 @@ export default function NewChatPage() {
]);

try {
const selection = await getAgentFilesystemSelection();
const response = await fetch(getRegenerateUrl(threadId), {
method: "POST",
headers: {

@@ -1416,6 +1435,9 @@ export default function NewChatPage() {
search_space_id: searchSpaceId,
user_query: newUserQuery || null,
disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
filesystem_mode: selection.filesystem_mode,
client_platform: selection.client_platform,
local_filesystem_mounts: selection.local_filesystem_mounts,
}),
signal: controller.signal,
});
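The regenerate request assembled in the hunk above can be summarized as a small payload type. The field list is taken directly from the diff; the value types are assumptions, since the backend schema is not shown here.

// Hedged sketch of the regenerate payload; field names come from the hunk
// above, types and example values are assumed for illustration.
interface RegenerateRequestBody {
	search_space_id: number;
	user_query: string | null;
	disabled_tools?: string[];
	filesystem_mode: string;
	client_platform: string;
	local_filesystem_mounts: string[];
}

// Example payload mirroring the JSON.stringify call above (values are made up):
const exampleBody: RegenerateRequestBody = {
	search_space_id: 1,
	user_query: null,
	filesystem_mode: "local",
	client_platform: "desktop",
	local_filesystem_mounts: ["/Users/me/notes"],
};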
@@ -1428,7 +1450,7 @@ export default function NewChatPage() {
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
: m
)
);

@@ -1443,7 +1465,7 @@ export default function NewChatPage() {
break;

case "tool-input-start":
addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {});
addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {});
batcher.flush();
break;

@@ -1453,7 +1475,7 @@ export default function NewChatPage() {
} else {
addToolCall(
contentPartsState,
TOOLS_WITH_UI,
toolsWithUI,
parsed.toolCallId,
parsed.toolName,
parsed.input || {}

@@ -1502,7 +1524,7 @@ export default function NewChatPage() {
batcher.flush();

// Persist messages after streaming completes
const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
if (contentParts.length > 0) {
try {
// Persist user message (for both edit and reload modes, since backend deleted it)

@@ -1,9 +1,7 @@
"use client";

import { BrainCog, Power, Rocket, Zap } from "lucide-react";
import { useEffect, useState } from "react";
import { toast } from "sonner";
import { DEFAULT_SHORTCUTS, ShortcutRecorder } from "@/components/desktop/shortcut-recorder";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
import { Label } from "@/components/ui/label";
import {

@@ -24,9 +22,6 @@ export function DesktopContent() {
const [loading, setLoading] = useState(true);
const [enabled, setEnabled] = useState(true);

const [shortcuts, setShortcuts] = useState(DEFAULT_SHORTCUTS);
const [shortcutsLoaded, setShortcutsLoaded] = useState(false);

const [searchSpaces, setSearchSpaces] = useState<SearchSpace[]>([]);
const [activeSpaceId, setActiveSpaceId] = useState<string | null>(null);

@@ -37,7 +32,6 @@ export function DesktopContent() {
useEffect(() => {
if (!api) {
setLoading(false);
setShortcutsLoaded(true);
return;
}

@@ -48,15 +42,13 @@ export function DesktopContent() {

Promise.all([
api.getAutocompleteEnabled(),
api.getShortcuts?.() ?? Promise.resolve(null),
api.getActiveSearchSpace?.() ?? Promise.resolve(null),
searchSpacesApiService.getSearchSpaces(),
hasAutoLaunchApi ? api.getAutoLaunch() : Promise.resolve(null),
])
.then(([autoEnabled, config, spaceId, spaces, autoLaunch]) => {
.then(([autoEnabled, spaceId, spaces, autoLaunch]) => {
if (!mounted) return;
setEnabled(autoEnabled);
if (config) setShortcuts(config);
setActiveSpaceId(spaceId);
if (spaces) setSearchSpaces(spaces);
if (autoLaunch) {
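The DesktopContent component above talks to a desktop bridge object (api) whose methods appear in the Promise.all call. A minimal interface sketch follows; the method names are taken from the hunks, but the signatures and return types are assumptions, not the actual preload contract. getShortcuts/setShortcuts are omitted because this change removes the shortcut settings from the page.

// Hedged sketch of the desktop bridge consumed by DesktopContent.
// Method names come from the diff; parameter and return types are assumed.
interface DesktopBridgeSketch {
	getAutocompleteEnabled(): Promise<boolean>;
	setAutocompleteEnabled(enabled: boolean): Promise<void>;
	getActiveSearchSpace?(): Promise<string | null>;
	getAutoLaunch?(): Promise<{ enabled: boolean; supported: boolean }>;
	setAutoLaunch?(enabled: boolean): Promise<void>;
}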
@@ -65,12 +57,10 @@ export function DesktopContent() {
setAutoLaunchSupported(autoLaunch.supported);
}
setLoading(false);
setShortcutsLoaded(true);
})
.catch(() => {
if (!mounted) return;
setLoading(false);
setShortcutsLoaded(true);
});

return () => {

@@ -82,7 +72,7 @@ export function DesktopContent() {
return (
<div className="flex flex-col items-center justify-center py-12 text-center">
<p className="text-sm text-muted-foreground">
Desktop settings are only available in the SurfSense desktop app.
App preferences are only available in the SurfSense desktop app.
</p>
</div>
);

@@ -101,24 +91,6 @@ export function DesktopContent() {
await api.setAutocompleteEnabled(checked);
};

const updateShortcut = (
key: "generalAssist" | "quickAsk" | "autocomplete",
accelerator: string
) => {
setShortcuts((prev) => {
const updated = { ...prev, [key]: accelerator };
api.setShortcuts?.({ [key]: accelerator }).catch(() => {
toast.error("Failed to update shortcut");
});
return updated;
});
toast.success("Shortcut updated");
};

const resetShortcut = (key: "generalAssist" | "quickAsk" | "autocomplete") => {
updateShortcut(key, DEFAULT_SHORTCUTS[key]);
};

const handleAutoLaunchToggle = async (checked: boolean) => {
if (!autoLaunchSupported || !api.setAutoLaunch) {
toast.error("Please update the desktop app to configure launch on startup");

@@ -196,7 +168,6 @@ export function DesktopContent() {
<Card>
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
<CardTitle className="text-base md:text-lg flex items-center gap-2">
<Power className="h-4 w-4" />
Launch on Startup
</CardTitle>
<CardDescription className="text-xs md:text-sm">

@@ -245,56 +216,6 @@ export function DesktopContent() {
</CardContent>
</Card>

{/* Keyboard Shortcuts */}
<Card>
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
<CardTitle className="text-base md:text-lg">Keyboard Shortcuts</CardTitle>
<CardDescription className="text-xs md:text-sm">
Customize the global keyboard shortcuts for desktop features.
</CardDescription>
</CardHeader>
<CardContent className="px-3 md:px-6 pb-3 md:pb-6">
{shortcutsLoaded ? (
<div className="flex flex-col gap-3">
<ShortcutRecorder
value={shortcuts.generalAssist}
onChange={(accel) => updateShortcut("generalAssist", accel)}
onReset={() => resetShortcut("generalAssist")}
defaultValue={DEFAULT_SHORTCUTS.generalAssist}
label="General Assist"
description="Launch SurfSense instantly from any application"
icon={Rocket}
/>
<ShortcutRecorder
value={shortcuts.quickAsk}
onChange={(accel) => updateShortcut("quickAsk", accel)}
onReset={() => resetShortcut("quickAsk")}
defaultValue={DEFAULT_SHORTCUTS.quickAsk}
label="Quick Assist"
description="Select text anywhere, then ask AI to explain, rewrite, or act on it"
icon={Zap}
/>
<ShortcutRecorder
value={shortcuts.autocomplete}
onChange={(accel) => updateShortcut("autocomplete", accel)}
onReset={() => resetShortcut("autocomplete")}
defaultValue={DEFAULT_SHORTCUTS.autocomplete}
label="Extreme Assist"
description="AI drafts text using your screen context and knowledge base"
icon={BrainCog}
/>
<p className="text-[11px] text-muted-foreground">
Click a shortcut and press a new key combination to change it.
</p>
</div>
) : (
<div className="flex justify-center py-4">
<Spinner size="sm" />
</div>
)}
</CardContent>
</Card>

{/* Extreme Assist Toggle */}
<Card>
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
Some files were not shown because too many files have changed in this diff.