mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-15 18:25:18 +02:00
Merge remote-tracking branch 'upstream/dev' into feat/obsidian-plugin
This commit is contained in:
commit
9b1b9a90c0
175 changed files with 10592 additions and 2302 deletions
|
|
@ -77,6 +77,8 @@ services:
|
||||||
- shared_temp:/shared_tmp
|
- shared_temp:/shared_tmp
|
||||||
env_file:
|
env_file:
|
||||||
- ../surfsense_backend/.env
|
- ../surfsense_backend/.env
|
||||||
|
extra_hosts:
|
||||||
|
- "host.docker.internal:host-gateway"
|
||||||
environment:
|
environment:
|
||||||
- DATABASE_URL=${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
|
- DATABASE_URL=${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
|
||||||
- CELERY_BROKER_URL=${REDIS_URL:-redis://redis:6379/0}
|
- CELERY_BROKER_URL=${REDIS_URL:-redis://redis:6379/0}
|
||||||
|
|
@ -118,6 +120,8 @@ services:
|
||||||
- shared_temp:/shared_tmp
|
- shared_temp:/shared_tmp
|
||||||
env_file:
|
env_file:
|
||||||
- ../surfsense_backend/.env
|
- ../surfsense_backend/.env
|
||||||
|
extra_hosts:
|
||||||
|
- "host.docker.internal:host-gateway"
|
||||||
environment:
|
environment:
|
||||||
- DATABASE_URL=${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
|
- DATABASE_URL=${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
|
||||||
- CELERY_BROKER_URL=${REDIS_URL:-redis://redis:6379/0}
|
- CELERY_BROKER_URL=${REDIS_URL:-redis://redis:6379/0}
|
||||||
|
|
|
||||||
|
|
@ -60,6 +60,8 @@ services:
|
||||||
- shared_temp:/shared_tmp
|
- shared_temp:/shared_tmp
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
|
extra_hosts:
|
||||||
|
- "host.docker.internal:host-gateway"
|
||||||
environment:
|
environment:
|
||||||
DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
|
DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
|
||||||
CELERY_BROKER_URL: ${REDIS_URL:-redis://redis:6379/0}
|
CELERY_BROKER_URL: ${REDIS_URL:-redis://redis:6379/0}
|
||||||
|
|
@ -100,6 +102,8 @@ services:
|
||||||
- shared_temp:/shared_tmp
|
- shared_temp:/shared_tmp
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
|
extra_hosts:
|
||||||
|
- "host.docker.internal:host-gateway"
|
||||||
environment:
|
environment:
|
||||||
DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
|
DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
|
||||||
CELERY_BROKER_URL: ${REDIS_URL:-redis://redis:6379/0}
|
CELERY_BROKER_URL: ${REDIS_URL:-redis://redis:6379/0}
|
||||||
|
|
|
||||||
|
|
@ -239,6 +239,9 @@ LLAMA_CLOUD_API_KEY=llx-nnn
|
||||||
# DAYTONA_TARGET=us
|
# DAYTONA_TARGET=us
|
||||||
# DAYTONA_SNAPSHOT_ID=
|
# DAYTONA_SNAPSHOT_ID=
|
||||||
|
|
||||||
|
# Desktop local filesystem mode (chat file tools run against a local folder root)
|
||||||
|
# ENABLE_DESKTOP_LOCAL_FILESYSTEM=FALSE
|
||||||
|
|
||||||
# OPTIONAL: Add these for LangSmith Observability
|
# OPTIONAL: Add these for LangSmith Observability
|
||||||
LANGSMITH_TRACING=true
|
LANGSMITH_TRACING=true
|
||||||
LANGSMITH_ENDPOINT=https://api.smith.langchain.com
|
LANGSMITH_ENDPOINT=https://api.smith.langchain.com
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,6 @@ from deepagents.backends import StateBackend
|
||||||
from deepagents.graph import BASE_AGENT_PROMPT
|
from deepagents.graph import BASE_AGENT_PROMPT
|
||||||
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||||
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
||||||
from deepagents.middleware.summarization import create_summarization_middleware
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain.agents.middleware import TodoListMiddleware
|
from langchain.agents.middleware import TodoListMiddleware
|
||||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||||
|
|
@ -34,18 +33,24 @@ from langgraph.types import Checkpointer
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||||
|
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemSelection
|
||||||
from app.agents.new_chat.llm_config import AgentConfig
|
from app.agents.new_chat.llm_config import AgentConfig
|
||||||
from app.agents.new_chat.middleware import (
|
from app.agents.new_chat.middleware import (
|
||||||
DedupHITLToolCallsMiddleware,
|
DedupHITLToolCallsMiddleware,
|
||||||
|
FileIntentMiddleware,
|
||||||
KnowledgeBaseSearchMiddleware,
|
KnowledgeBaseSearchMiddleware,
|
||||||
MemoryInjectionMiddleware,
|
MemoryInjectionMiddleware,
|
||||||
SurfSenseFilesystemMiddleware,
|
SurfSenseFilesystemMiddleware,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.middleware.safe_summarization import (
|
||||||
|
create_safe_summarization_middleware,
|
||||||
|
)
|
||||||
from app.agents.new_chat.system_prompt import (
|
from app.agents.new_chat.system_prompt import (
|
||||||
build_configurable_system_prompt,
|
build_configurable_system_prompt,
|
||||||
build_surfsense_system_prompt,
|
build_surfsense_system_prompt,
|
||||||
)
|
)
|
||||||
from app.agents.new_chat.tools.registry import build_tools_async
|
from app.agents.new_chat.tools.registry import build_tools_async, get_connector_gated_tools
|
||||||
from app.db import ChatVisibility
|
from app.db import ChatVisibility
|
||||||
from app.services.connector_service import ConnectorService
|
from app.services.connector_service import ConnectorService
|
||||||
from app.utils.perf import get_perf_logger
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
@ -162,6 +167,7 @@ async def create_surfsense_deep_agent(
|
||||||
thread_visibility: ChatVisibility | None = None,
|
thread_visibility: ChatVisibility | None = None,
|
||||||
mentioned_document_ids: list[int] | None = None,
|
mentioned_document_ids: list[int] | None = None,
|
||||||
anon_session_id: str | None = None,
|
anon_session_id: str | None = None,
|
||||||
|
filesystem_selection: FilesystemSelection | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a SurfSense deep agent with configurable tools and prompts.
|
Create a SurfSense deep agent with configurable tools and prompts.
|
||||||
|
|
@ -236,6 +242,8 @@ async def create_surfsense_deep_agent(
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
_t_agent_total = time.perf_counter()
|
_t_agent_total = time.perf_counter()
|
||||||
|
filesystem_selection = filesystem_selection or FilesystemSelection()
|
||||||
|
backend_resolver = build_backend_resolver(filesystem_selection)
|
||||||
|
|
||||||
# Discover available connectors and document types for this search space
|
# Discover available connectors and document types for this search space
|
||||||
available_connectors: list[str] | None = None
|
available_connectors: list[str] | None = None
|
||||||
|
|
@ -285,105 +293,10 @@ async def create_surfsense_deep_agent(
|
||||||
"llm": llm,
|
"llm": llm,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Disable Notion action tools if no Notion connector is configured
|
|
||||||
modified_disabled_tools = list(disabled_tools) if disabled_tools else []
|
modified_disabled_tools = list(disabled_tools) if disabled_tools else []
|
||||||
has_notion_connector = (
|
modified_disabled_tools.extend(
|
||||||
available_connectors is not None and "NOTION_CONNECTOR" in available_connectors
|
get_connector_gated_tools(available_connectors)
|
||||||
)
|
)
|
||||||
if not has_notion_connector:
|
|
||||||
notion_tools = [
|
|
||||||
"create_notion_page",
|
|
||||||
"update_notion_page",
|
|
||||||
"delete_notion_page",
|
|
||||||
]
|
|
||||||
modified_disabled_tools.extend(notion_tools)
|
|
||||||
|
|
||||||
# Disable Linear action tools if no Linear connector is configured
|
|
||||||
has_linear_connector = (
|
|
||||||
available_connectors is not None and "LINEAR_CONNECTOR" in available_connectors
|
|
||||||
)
|
|
||||||
if not has_linear_connector:
|
|
||||||
linear_tools = [
|
|
||||||
"create_linear_issue",
|
|
||||||
"update_linear_issue",
|
|
||||||
"delete_linear_issue",
|
|
||||||
]
|
|
||||||
modified_disabled_tools.extend(linear_tools)
|
|
||||||
|
|
||||||
# Disable Google Drive action tools if no Google Drive connector is configured
|
|
||||||
has_google_drive_connector = (
|
|
||||||
available_connectors is not None and "GOOGLE_DRIVE_FILE" in available_connectors
|
|
||||||
)
|
|
||||||
if not has_google_drive_connector:
|
|
||||||
google_drive_tools = [
|
|
||||||
"create_google_drive_file",
|
|
||||||
"delete_google_drive_file",
|
|
||||||
]
|
|
||||||
modified_disabled_tools.extend(google_drive_tools)
|
|
||||||
|
|
||||||
has_dropbox_connector = (
|
|
||||||
available_connectors is not None and "DROPBOX_FILE" in available_connectors
|
|
||||||
)
|
|
||||||
if not has_dropbox_connector:
|
|
||||||
modified_disabled_tools.extend(["create_dropbox_file", "delete_dropbox_file"])
|
|
||||||
|
|
||||||
has_onedrive_connector = (
|
|
||||||
available_connectors is not None and "ONEDRIVE_FILE" in available_connectors
|
|
||||||
)
|
|
||||||
if not has_onedrive_connector:
|
|
||||||
modified_disabled_tools.extend(["create_onedrive_file", "delete_onedrive_file"])
|
|
||||||
|
|
||||||
# Disable Google Calendar action tools if no Google Calendar connector is configured
|
|
||||||
has_google_calendar_connector = (
|
|
||||||
available_connectors is not None
|
|
||||||
and "GOOGLE_CALENDAR_CONNECTOR" in available_connectors
|
|
||||||
)
|
|
||||||
if not has_google_calendar_connector:
|
|
||||||
calendar_tools = [
|
|
||||||
"create_calendar_event",
|
|
||||||
"update_calendar_event",
|
|
||||||
"delete_calendar_event",
|
|
||||||
]
|
|
||||||
modified_disabled_tools.extend(calendar_tools)
|
|
||||||
|
|
||||||
# Disable Gmail action tools if no Gmail connector is configured
|
|
||||||
has_gmail_connector = (
|
|
||||||
available_connectors is not None
|
|
||||||
and "GOOGLE_GMAIL_CONNECTOR" in available_connectors
|
|
||||||
)
|
|
||||||
if not has_gmail_connector:
|
|
||||||
gmail_tools = [
|
|
||||||
"create_gmail_draft",
|
|
||||||
"update_gmail_draft",
|
|
||||||
"send_gmail_email",
|
|
||||||
"trash_gmail_email",
|
|
||||||
]
|
|
||||||
modified_disabled_tools.extend(gmail_tools)
|
|
||||||
|
|
||||||
# Disable Jira action tools if no Jira connector is configured
|
|
||||||
has_jira_connector = (
|
|
||||||
available_connectors is not None and "JIRA_CONNECTOR" in available_connectors
|
|
||||||
)
|
|
||||||
if not has_jira_connector:
|
|
||||||
jira_tools = [
|
|
||||||
"create_jira_issue",
|
|
||||||
"update_jira_issue",
|
|
||||||
"delete_jira_issue",
|
|
||||||
]
|
|
||||||
modified_disabled_tools.extend(jira_tools)
|
|
||||||
|
|
||||||
# Disable Confluence action tools if no Confluence connector is configured
|
|
||||||
has_confluence_connector = (
|
|
||||||
available_connectors is not None
|
|
||||||
and "CONFLUENCE_CONNECTOR" in available_connectors
|
|
||||||
)
|
|
||||||
if not has_confluence_connector:
|
|
||||||
confluence_tools = [
|
|
||||||
"create_confluence_page",
|
|
||||||
"update_confluence_page",
|
|
||||||
"delete_confluence_page",
|
|
||||||
]
|
|
||||||
modified_disabled_tools.extend(confluence_tools)
|
|
||||||
|
|
||||||
# Remove direct KB search tool; we now pre-seed a scoped filesystem via middleware.
|
# Remove direct KB search tool; we now pre-seed a scoped filesystem via middleware.
|
||||||
if "search_knowledge_base" not in modified_disabled_tools:
|
if "search_knowledge_base" not in modified_disabled_tools:
|
||||||
|
|
@ -407,6 +320,20 @@ async def create_surfsense_deep_agent(
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
_enabled_tool_names = {t.name for t in tools}
|
_enabled_tool_names = {t.name for t in tools}
|
||||||
_user_disabled_tool_names = set(disabled_tools) if disabled_tools else set()
|
_user_disabled_tool_names = set(disabled_tools) if disabled_tools else set()
|
||||||
|
|
||||||
|
# Collect generic MCP connector info so the system prompt can route queries
|
||||||
|
# to their tools instead of falling back to "not in knowledge base".
|
||||||
|
_mcp_connector_tools: dict[str, list[str]] = {}
|
||||||
|
for t in tools:
|
||||||
|
meta = getattr(t, "metadata", None) or {}
|
||||||
|
if meta.get("mcp_is_generic") and meta.get("mcp_connector_name"):
|
||||||
|
_mcp_connector_tools.setdefault(
|
||||||
|
meta["mcp_connector_name"], [],
|
||||||
|
).append(t.name)
|
||||||
|
|
||||||
|
if _mcp_connector_tools:
|
||||||
|
_perf_log.info("MCP connector tool routing: %s", _mcp_connector_tools)
|
||||||
|
|
||||||
if agent_config is not None:
|
if agent_config is not None:
|
||||||
system_prompt = build_configurable_system_prompt(
|
system_prompt = build_configurable_system_prompt(
|
||||||
custom_system_instructions=agent_config.system_instructions,
|
custom_system_instructions=agent_config.system_instructions,
|
||||||
|
|
@ -415,12 +342,14 @@ async def create_surfsense_deep_agent(
|
||||||
thread_visibility=thread_visibility,
|
thread_visibility=thread_visibility,
|
||||||
enabled_tool_names=_enabled_tool_names,
|
enabled_tool_names=_enabled_tool_names,
|
||||||
disabled_tool_names=_user_disabled_tool_names,
|
disabled_tool_names=_user_disabled_tool_names,
|
||||||
|
mcp_connector_tools=_mcp_connector_tools,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
system_prompt = build_surfsense_system_prompt(
|
system_prompt = build_surfsense_system_prompt(
|
||||||
thread_visibility=thread_visibility,
|
thread_visibility=thread_visibility,
|
||||||
enabled_tool_names=_enabled_tool_names,
|
enabled_tool_names=_enabled_tool_names,
|
||||||
disabled_tool_names=_user_disabled_tool_names,
|
disabled_tool_names=_user_disabled_tool_names,
|
||||||
|
mcp_connector_tools=_mcp_connector_tools,
|
||||||
)
|
)
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
|
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
|
||||||
|
|
@ -437,12 +366,15 @@ async def create_surfsense_deep_agent(
|
||||||
gp_middleware = [
|
gp_middleware = [
|
||||||
TodoListMiddleware(),
|
TodoListMiddleware(),
|
||||||
_memory_middleware,
|
_memory_middleware,
|
||||||
|
FileIntentMiddleware(llm=llm),
|
||||||
SurfSenseFilesystemMiddleware(
|
SurfSenseFilesystemMiddleware(
|
||||||
|
backend=backend_resolver,
|
||||||
|
filesystem_mode=filesystem_selection.mode,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
created_by_id=user_id,
|
created_by_id=user_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
),
|
),
|
||||||
create_summarization_middleware(llm, StateBackend),
|
create_safe_summarization_middleware(llm, StateBackend),
|
||||||
PatchToolCallsMiddleware(),
|
PatchToolCallsMiddleware(),
|
||||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
||||||
]
|
]
|
||||||
|
|
@ -458,21 +390,25 @@ async def create_surfsense_deep_agent(
|
||||||
deepagent_middleware = [
|
deepagent_middleware = [
|
||||||
TodoListMiddleware(),
|
TodoListMiddleware(),
|
||||||
_memory_middleware,
|
_memory_middleware,
|
||||||
|
FileIntentMiddleware(llm=llm),
|
||||||
KnowledgeBaseSearchMiddleware(
|
KnowledgeBaseSearchMiddleware(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
|
filesystem_mode=filesystem_selection.mode,
|
||||||
available_connectors=available_connectors,
|
available_connectors=available_connectors,
|
||||||
available_document_types=available_document_types,
|
available_document_types=available_document_types,
|
||||||
mentioned_document_ids=mentioned_document_ids,
|
mentioned_document_ids=mentioned_document_ids,
|
||||||
anon_session_id=anon_session_id,
|
anon_session_id=anon_session_id,
|
||||||
),
|
),
|
||||||
SurfSenseFilesystemMiddleware(
|
SurfSenseFilesystemMiddleware(
|
||||||
|
backend=backend_resolver,
|
||||||
|
filesystem_mode=filesystem_selection.mode,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
created_by_id=user_id,
|
created_by_id=user_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
),
|
),
|
||||||
SubAgentMiddleware(backend=StateBackend, subagents=[general_purpose_spec]),
|
SubAgentMiddleware(backend=StateBackend, subagents=[general_purpose_spec]),
|
||||||
create_summarization_middleware(llm, StateBackend),
|
create_safe_summarization_middleware(llm, StateBackend),
|
||||||
PatchToolCallsMiddleware(),
|
PatchToolCallsMiddleware(),
|
||||||
DedupHITLToolCallsMiddleware(agent_tools=tools),
|
DedupHITLToolCallsMiddleware(agent_tools=tools),
|
||||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,15 @@ Context schema definitions for SurfSense agents.
|
||||||
This module defines the custom state schema used by the SurfSense deep agent.
|
This module defines the custom state schema used by the SurfSense deep agent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import TypedDict
|
from typing import NotRequired, TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
class FileOperationContractState(TypedDict):
|
||||||
|
intent: str
|
||||||
|
confidence: float
|
||||||
|
suggested_path: str
|
||||||
|
timestamp: str
|
||||||
|
turn_id: str
|
||||||
|
|
||||||
|
|
||||||
class SurfSenseContextSchema(TypedDict):
|
class SurfSenseContextSchema(TypedDict):
|
||||||
|
|
@ -24,5 +32,8 @@ class SurfSenseContextSchema(TypedDict):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
search_space_id: int
|
search_space_id: int
|
||||||
|
file_operation_contract: NotRequired[FileOperationContractState]
|
||||||
|
turn_id: NotRequired[str]
|
||||||
|
request_id: NotRequired[str]
|
||||||
# These are runtime-injected and won't be serialized
|
# These are runtime-injected and won't be serialized
|
||||||
# db_session and connector_service are passed when invoking the agent
|
# db_session and connector_service are passed when invoking the agent
|
||||||
|
|
|
||||||
42
surfsense_backend/app/agents/new_chat/filesystem_backends.py
Normal file
42
surfsense_backend/app/agents/new_chat/filesystem_backends.py
Normal file
|
|
@ -0,0 +1,42 @@
|
||||||
|
"""Filesystem backend resolver for cloud and desktop-local modes."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from deepagents.backends.state import StateBackend
|
||||||
|
from langgraph.prebuilt.tool_node import ToolRuntime
|
||||||
|
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||||
|
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
|
||||||
|
MultiRootLocalFolderBackend,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=64)
|
||||||
|
def _cached_multi_root_backend(
|
||||||
|
mounts: tuple[tuple[str, str], ...],
|
||||||
|
) -> MultiRootLocalFolderBackend:
|
||||||
|
return MultiRootLocalFolderBackend(mounts)
|
||||||
|
|
||||||
|
|
||||||
|
def build_backend_resolver(
|
||||||
|
selection: FilesystemSelection,
|
||||||
|
) -> Callable[[ToolRuntime], StateBackend | MultiRootLocalFolderBackend]:
|
||||||
|
"""Create deepagents backend resolver for the selected filesystem mode."""
|
||||||
|
|
||||||
|
if selection.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER and selection.local_mounts:
|
||||||
|
|
||||||
|
def _resolve_local(_runtime: ToolRuntime) -> MultiRootLocalFolderBackend:
|
||||||
|
mounts = tuple(
|
||||||
|
(entry.mount_id, entry.root_path) for entry in selection.local_mounts
|
||||||
|
)
|
||||||
|
return _cached_multi_root_backend(mounts)
|
||||||
|
|
||||||
|
return _resolve_local
|
||||||
|
|
||||||
|
def _resolve_cloud(runtime: ToolRuntime) -> StateBackend:
|
||||||
|
return StateBackend(runtime)
|
||||||
|
|
||||||
|
return _resolve_cloud
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
"""Filesystem mode contracts and selection helpers for chat sessions."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
|
||||||
|
class FilesystemMode(StrEnum):
|
||||||
|
"""Supported filesystem backends for agent tool execution."""
|
||||||
|
|
||||||
|
CLOUD = "cloud"
|
||||||
|
DESKTOP_LOCAL_FOLDER = "desktop_local_folder"
|
||||||
|
|
||||||
|
|
||||||
|
class ClientPlatform(StrEnum):
|
||||||
|
"""Client runtime reported by the caller."""
|
||||||
|
|
||||||
|
WEB = "web"
|
||||||
|
DESKTOP = "desktop"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class LocalFilesystemMount:
|
||||||
|
"""Canonical mount mapping provided by desktop runtime."""
|
||||||
|
|
||||||
|
mount_id: str
|
||||||
|
root_path: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class FilesystemSelection:
|
||||||
|
"""Resolved filesystem selection for a single chat request."""
|
||||||
|
|
||||||
|
mode: FilesystemMode = FilesystemMode.CLOUD
|
||||||
|
client_platform: ClientPlatform = ClientPlatform.WEB
|
||||||
|
local_mounts: tuple[LocalFilesystemMount, ...] = ()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_local_mode(self) -> bool:
|
||||||
|
return self.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
|
||||||
|
|
@ -6,6 +6,9 @@ from app.agents.new_chat.middleware.dedup_tool_calls import (
|
||||||
from app.agents.new_chat.middleware.filesystem import (
|
from app.agents.new_chat.middleware.filesystem import (
|
||||||
SurfSenseFilesystemMiddleware,
|
SurfSenseFilesystemMiddleware,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.middleware.file_intent import (
|
||||||
|
FileIntentMiddleware,
|
||||||
|
)
|
||||||
from app.agents.new_chat.middleware.knowledge_search import (
|
from app.agents.new_chat.middleware.knowledge_search import (
|
||||||
KnowledgeBaseSearchMiddleware,
|
KnowledgeBaseSearchMiddleware,
|
||||||
)
|
)
|
||||||
|
|
@ -15,6 +18,7 @@ from app.agents.new_chat.middleware.memory_injection import (
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DedupHITLToolCallsMiddleware",
|
"DedupHITLToolCallsMiddleware",
|
||||||
|
"FileIntentMiddleware",
|
||||||
"KnowledgeBaseSearchMiddleware",
|
"KnowledgeBaseSearchMiddleware",
|
||||||
"MemoryInjectionMiddleware",
|
"MemoryInjectionMiddleware",
|
||||||
"SurfSenseFilesystemMiddleware",
|
"SurfSenseFilesystemMiddleware",
|
||||||
|
|
|
||||||
352
surfsense_backend/app/agents/new_chat/middleware/file_intent.py
Normal file
352
surfsense_backend/app/agents/new_chat/middleware/file_intent.py
Normal file
|
|
@ -0,0 +1,352 @@
|
||||||
|
"""Semantic file-intent routing middleware for new chat turns.
|
||||||
|
|
||||||
|
This middleware classifies the latest human turn into a small intent set:
|
||||||
|
- chat_only
|
||||||
|
- file_write
|
||||||
|
- file_read
|
||||||
|
|
||||||
|
For ``file_write`` turns it injects a strict system contract so the model
|
||||||
|
uses filesystem tools before claiming success, and provides a deterministic
|
||||||
|
fallback path when no filename is specified by the user.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FileOperationIntent(StrEnum):
|
||||||
|
CHAT_ONLY = "chat_only"
|
||||||
|
FILE_WRITE = "file_write"
|
||||||
|
FILE_READ = "file_read"
|
||||||
|
|
||||||
|
|
||||||
|
class FileIntentPlan(BaseModel):
|
||||||
|
intent: FileOperationIntent = Field(
|
||||||
|
description="Primary user intent for this turn."
|
||||||
|
)
|
||||||
|
confidence: float = Field(
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
default=0.5,
|
||||||
|
description="Model confidence in the selected intent.",
|
||||||
|
)
|
||||||
|
suggested_filename: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional filename (e.g. notes.md) inferred from user request.",
|
||||||
|
)
|
||||||
|
suggested_directory: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Optional directory path (e.g. /reports/q2 or reports/q2) inferred from "
|
||||||
|
"user request."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
suggested_path: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Optional full file path (e.g. /reports/q2/summary.md). If present, this "
|
||||||
|
"takes precedence over suggested_directory + suggested_filename."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text_from_message(message: BaseMessage) -> str:
|
||||||
|
content = getattr(message, "content", "")
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict) and item.get("type") == "text":
|
||||||
|
parts.append(str(item.get("text", "")))
|
||||||
|
return "\n".join(part for part in parts if part)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_json_payload(text: str) -> str:
|
||||||
|
stripped = text.strip()
|
||||||
|
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
|
||||||
|
if fenced:
|
||||||
|
return fenced.group(1)
|
||||||
|
start = stripped.find("{")
|
||||||
|
end = stripped.rfind("}")
|
||||||
|
if start != -1 and end != -1 and end > start:
|
||||||
|
return stripped[start : end + 1]
|
||||||
|
return stripped
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_filename(value: str) -> str:
|
||||||
|
name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
|
||||||
|
name = re.sub(r"\s+", "-", name)
|
||||||
|
name = name.strip("._-")
|
||||||
|
if not name:
|
||||||
|
name = "note"
|
||||||
|
if len(name) > 80:
|
||||||
|
name = name[:80].rstrip("-_.")
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_path_segment(value: str) -> str:
|
||||||
|
segment = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
|
||||||
|
segment = re.sub(r"\s+", "_", segment)
|
||||||
|
segment = segment.strip("._-")
|
||||||
|
return segment
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_text_file_extension(user_text: str) -> str:
|
||||||
|
lowered = user_text.lower()
|
||||||
|
if any(token in lowered for token in ("json", ".json")):
|
||||||
|
return ".json"
|
||||||
|
if any(token in lowered for token in ("yaml", "yml", ".yaml", ".yml")):
|
||||||
|
return ".yaml"
|
||||||
|
if any(token in lowered for token in ("csv", ".csv")):
|
||||||
|
return ".csv"
|
||||||
|
if any(token in lowered for token in ("python", ".py")):
|
||||||
|
return ".py"
|
||||||
|
if any(token in lowered for token in ("typescript", ".ts", ".tsx")):
|
||||||
|
return ".ts"
|
||||||
|
if any(token in lowered for token in ("javascript", ".js", ".mjs", ".cjs")):
|
||||||
|
return ".js"
|
||||||
|
if any(token in lowered for token in ("html", ".html")):
|
||||||
|
return ".html"
|
||||||
|
if any(token in lowered for token in ("css", ".css")):
|
||||||
|
return ".css"
|
||||||
|
if any(token in lowered for token in ("sql", ".sql")):
|
||||||
|
return ".sql"
|
||||||
|
if any(token in lowered for token in ("toml", ".toml")):
|
||||||
|
return ".toml"
|
||||||
|
if any(token in lowered for token in ("ini", ".ini")):
|
||||||
|
return ".ini"
|
||||||
|
if any(token in lowered for token in ("xml", ".xml")):
|
||||||
|
return ".xml"
|
||||||
|
if any(token in lowered for token in ("markdown", ".md", "readme")):
|
||||||
|
return ".md"
|
||||||
|
return ".md"
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_directory(value: str) -> str:
|
||||||
|
raw = value.strip().replace("\\", "/")
|
||||||
|
raw = raw.strip("/")
|
||||||
|
if not raw:
|
||||||
|
return ""
|
||||||
|
parts = [_sanitize_path_segment(part) for part in raw.split("/") if part.strip()]
|
||||||
|
parts = [part for part in parts if part]
|
||||||
|
return "/".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_file_path(value: str) -> str:
|
||||||
|
raw = value.strip().replace("\\", "/").strip()
|
||||||
|
if not raw:
|
||||||
|
return ""
|
||||||
|
had_trailing_slash = raw.endswith("/")
|
||||||
|
raw = raw.strip("/")
|
||||||
|
if not raw:
|
||||||
|
return ""
|
||||||
|
parts = [_sanitize_path_segment(part) for part in raw.split("/") if part.strip()]
|
||||||
|
parts = [part for part in parts if part]
|
||||||
|
if not parts:
|
||||||
|
return ""
|
||||||
|
if had_trailing_slash:
|
||||||
|
return f"/{'/'.join(parts)}/"
|
||||||
|
return f"/{'/'.join(parts)}"
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_directory_from_user_text(user_text: str) -> str | None:
|
||||||
|
patterns = (
|
||||||
|
r"\b(?:in|inside|under)\s+(?:the\s+)?([a-zA-Z0-9 _\-/]+?)\s+folder\b",
|
||||||
|
r"\b(?:in|inside|under)\s+([a-zA-Z0-9 _\-/]+?)\b",
|
||||||
|
)
|
||||||
|
lowered = user_text.lower()
|
||||||
|
for pattern in patterns:
|
||||||
|
match = re.search(pattern, lowered, flags=re.IGNORECASE)
|
||||||
|
if not match:
|
||||||
|
continue
|
||||||
|
candidate = match.group(1).strip()
|
||||||
|
if candidate in {"the", "a", "an"}:
|
||||||
|
continue
|
||||||
|
normalized = _normalize_directory(candidate)
|
||||||
|
if normalized:
|
||||||
|
return normalized
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _fallback_path(
|
||||||
|
suggested_filename: str | None,
|
||||||
|
*,
|
||||||
|
suggested_directory: str | None = None,
|
||||||
|
suggested_path: str | None = None,
|
||||||
|
user_text: str,
|
||||||
|
) -> str:
|
||||||
|
default_extension = _infer_text_file_extension(user_text)
|
||||||
|
inferred_dir = _infer_directory_from_user_text(user_text)
|
||||||
|
|
||||||
|
sanitized_filename = ""
|
||||||
|
if suggested_filename:
|
||||||
|
sanitized_filename = _sanitize_filename(suggested_filename)
|
||||||
|
if sanitized_filename.lower().endswith(".txt"):
|
||||||
|
sanitized_filename = f"{sanitized_filename[:-4]}.md"
|
||||||
|
if not sanitized_filename:
|
||||||
|
sanitized_filename = f"notes{default_extension}"
|
||||||
|
elif "." not in sanitized_filename:
|
||||||
|
sanitized_filename = f"{sanitized_filename}{default_extension}"
|
||||||
|
|
||||||
|
normalized_suggested_path = (
|
||||||
|
_normalize_file_path(suggested_path) if suggested_path else ""
|
||||||
|
)
|
||||||
|
if normalized_suggested_path:
|
||||||
|
if normalized_suggested_path.endswith("/"):
|
||||||
|
return f"{normalized_suggested_path.rstrip('/')}/{sanitized_filename}"
|
||||||
|
return normalized_suggested_path
|
||||||
|
|
||||||
|
directory = _normalize_directory(suggested_directory or "")
|
||||||
|
if not directory and inferred_dir:
|
||||||
|
directory = inferred_dir
|
||||||
|
if directory:
|
||||||
|
return f"/{directory}/{sanitized_filename}"
|
||||||
|
|
||||||
|
return f"/{sanitized_filename}"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_classifier_prompt(*, recent_conversation: str, user_text: str) -> str:
|
||||||
|
return (
|
||||||
|
"Classify the latest user request into a filesystem intent for an AI agent.\n"
|
||||||
|
"Return JSON only with this exact schema:\n"
|
||||||
|
'{"intent":"chat_only|file_write|file_read","confidence":0.0,"suggested_filename":"string or null","suggested_directory":"string or null","suggested_path":"string or null"}\n\n'
|
||||||
|
"Rules:\n"
|
||||||
|
"- Use semantic intent, not literal keywords.\n"
|
||||||
|
"- file_write: user asks to create/save/write/update/edit content as a file.\n"
|
||||||
|
"- file_read: user asks to open/read/list/search existing files.\n"
|
||||||
|
"- chat_only: conversational/analysis responses without required file operations.\n"
|
||||||
|
"- For file_write, choose a concise semantic suggested_filename and match the requested format.\n"
|
||||||
|
"- If the user mentions a folder/directory, populate suggested_directory.\n"
|
||||||
|
"- If user specifies an explicit full path, populate suggested_path.\n"
|
||||||
|
"- Use extensions that match user intent (e.g. .md, .json, .yaml, .csv, .py, .ts, .js, .html, .css, .sql).\n"
|
||||||
|
"- Do not use .txt; prefer .md for generic text notes.\n"
|
||||||
|
"- Do not include dates or timestamps in suggested_filename unless explicitly requested.\n"
|
||||||
|
"- Never include markdown or explanation.\n\n"
|
||||||
|
f"Recent conversation:\n{recent_conversation or '(none)'}\n\n"
|
||||||
|
f"Latest user message:\n{user_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_recent_conversation(messages: list[BaseMessage], *, max_messages: int = 6) -> str:
|
||||||
|
rows: list[str] = []
|
||||||
|
for msg in messages[-max_messages:]:
|
||||||
|
role = "user" if isinstance(msg, HumanMessage) else "assistant"
|
||||||
|
text = re.sub(r"\s+", " ", _extract_text_from_message(msg)).strip()
|
||||||
|
if text:
|
||||||
|
rows.append(f"{role}: {text[:280]}")
|
||||||
|
return "\n".join(rows)
|
||||||
|
|
||||||
|
|
||||||
|
class FileIntentMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
"""Classify file intent and inject a strict file-write contract."""
|
||||||
|
|
||||||
|
tools = ()
|
||||||
|
|
||||||
|
def __init__(self, *, llm: BaseChatModel | None = None) -> None:
|
||||||
|
self.llm = llm
|
||||||
|
|
||||||
|
async def _classify_intent(
|
||||||
|
self, *, messages: list[BaseMessage], user_text: str
|
||||||
|
) -> FileIntentPlan:
|
||||||
|
if self.llm is None:
|
||||||
|
return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)
|
||||||
|
|
||||||
|
prompt = _build_classifier_prompt(
|
||||||
|
recent_conversation=_build_recent_conversation(messages),
|
||||||
|
user_text=user_text,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
response = await self.llm.ainvoke(
|
||||||
|
[HumanMessage(content=prompt)],
|
||||||
|
config={"tags": ["surfsense:internal"]},
|
||||||
|
)
|
||||||
|
payload = json.loads(_extract_json_payload(_extract_text_from_message(response)))
|
||||||
|
plan = FileIntentPlan.model_validate(payload)
|
||||||
|
return plan
|
||||||
|
except (json.JSONDecodeError, ValidationError, ValueError) as exc:
|
||||||
|
logger.warning("File intent classifier returned invalid output: %s", exc)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive fallback
|
||||||
|
logger.warning("File intent classifier failed: %s", exc)
|
||||||
|
|
||||||
|
return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)
|
||||||
|
|
||||||
|
async def abefore_agent( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
state: AgentState,
|
||||||
|
runtime: Runtime[Any],
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
del runtime
|
||||||
|
messages = state.get("messages") or []
|
||||||
|
if not messages:
|
||||||
|
return None
|
||||||
|
|
||||||
|
last_human: HumanMessage | None = None
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if isinstance(msg, HumanMessage):
|
||||||
|
last_human = msg
|
||||||
|
break
|
||||||
|
if last_human is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
user_text = _extract_text_from_message(last_human).strip()
|
||||||
|
if not user_text:
|
||||||
|
return None
|
||||||
|
|
||||||
|
plan = await self._classify_intent(messages=messages, user_text=user_text)
|
||||||
|
suggested_path = _fallback_path(
|
||||||
|
plan.suggested_filename,
|
||||||
|
suggested_directory=plan.suggested_directory,
|
||||||
|
suggested_path=plan.suggested_path,
|
||||||
|
user_text=user_text,
|
||||||
|
)
|
||||||
|
contract = {
|
||||||
|
"intent": plan.intent.value,
|
||||||
|
"confidence": plan.confidence,
|
||||||
|
"suggested_path": suggested_path,
|
||||||
|
"timestamp": datetime.now(UTC).isoformat(),
|
||||||
|
"turn_id": state.get("turn_id", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
if plan.intent != FileOperationIntent.FILE_WRITE:
|
||||||
|
return {"file_operation_contract": contract}
|
||||||
|
|
||||||
|
contract_msg = SystemMessage(
|
||||||
|
content=(
|
||||||
|
"<file_operation_contract>\n"
|
||||||
|
"This turn intent is file_write.\n"
|
||||||
|
f"Suggested default path: {suggested_path}\n"
|
||||||
|
"Rules:\n"
|
||||||
|
"- You MUST call write_file or edit_file before claiming success.\n"
|
||||||
|
"- If no path is provided by the user, use the suggested default path.\n"
|
||||||
|
"- Do not claim a file was created/updated unless tool output confirms it.\n"
|
||||||
|
"- If the write/edit fails, clearly report failure instead of success.\n"
|
||||||
|
"- Do not include timestamps or dates in generated file content unless the user explicitly asks for them.\n"
|
||||||
|
"- For open-ended requests (e.g., random note), generate useful concrete content, not placeholders.\n"
|
||||||
|
"</file_operation_contract>"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Insert just before the latest human turn so it applies to this request.
|
||||||
|
new_messages = list(messages)
|
||||||
|
insert_at = max(len(new_messages) - 1, 0)
|
||||||
|
new_messages.insert(insert_at, contract_msg)
|
||||||
|
return {"messages": new_messages, "file_operation_contract": contract}
|
||||||
|
|
||||||
|
|
@ -26,6 +26,10 @@ from langchain_core.tools import BaseTool, StructuredTool
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
from sqlalchemy import delete, select
|
from sqlalchemy import delete, select
|
||||||
|
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
|
||||||
|
MultiRootLocalFolderBackend,
|
||||||
|
)
|
||||||
from app.agents.new_chat.sandbox import (
|
from app.agents.new_chat.sandbox import (
|
||||||
_evict_sandbox_cache,
|
_evict_sandbox_cache,
|
||||||
delete_sandbox,
|
delete_sandbox,
|
||||||
|
|
@ -50,6 +54,8 @@ SURFSENSE_FILESYSTEM_SYSTEM_PROMPT = """## Following Conventions
|
||||||
|
|
||||||
- Read files before editing — understand existing content before making changes.
|
- Read files before editing — understand existing content before making changes.
|
||||||
- Mimic existing style, naming conventions, and patterns.
|
- Mimic existing style, naming conventions, and patterns.
|
||||||
|
- Never claim a file was created/updated unless filesystem tool output confirms success.
|
||||||
|
- If a file write/edit fails, explicitly report the failure.
|
||||||
|
|
||||||
## Filesystem Tools
|
## Filesystem Tools
|
||||||
|
|
||||||
|
|
@ -109,13 +115,20 @@ Usage:
|
||||||
- Use chunk IDs (`<chunk id='...'>`) as citations in answers.
|
- Use chunk IDs (`<chunk id='...'>`) as citations in answers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new file to the in-memory filesystem (session-only).
|
SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new text file to the in-memory filesystem (session-only).
|
||||||
|
|
||||||
Use this to create scratch/working files during the conversation. Files created
|
Use this to create scratch/working files during the conversation. Files created
|
||||||
here are ephemeral and will not be saved to the user's knowledge base.
|
here are ephemeral and will not be saved to the user's knowledge base.
|
||||||
|
|
||||||
To permanently save a document to the user's knowledge base, use the
|
To permanently save a document to the user's knowledge base, use the
|
||||||
`save_document` tool instead.
|
`save_document` tool instead.
|
||||||
|
|
||||||
|
Supported outputs include common LLM-friendly text formats like markdown, json,
|
||||||
|
yaml, csv, xml, html, css, sql, and code files.
|
||||||
|
|
||||||
|
When creating content from open-ended prompts, produce concrete and useful text,
|
||||||
|
not placeholders. Avoid adding dates/timestamps unless the user explicitly asks
|
||||||
|
for them.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files.
|
SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files.
|
||||||
|
|
@ -182,11 +195,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
backend: Any = None,
|
||||||
|
filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
created_by_id: str | None = None,
|
created_by_id: str | None = None,
|
||||||
thread_id: int | str | None = None,
|
thread_id: int | str | None = None,
|
||||||
tool_token_limit_before_evict: int | None = 20000,
|
tool_token_limit_before_evict: int | None = 20000,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
self._filesystem_mode = filesystem_mode
|
||||||
self._search_space_id = search_space_id
|
self._search_space_id = search_space_id
|
||||||
self._created_by_id = created_by_id
|
self._created_by_id = created_by_id
|
||||||
self._thread_id = thread_id
|
self._thread_id = thread_id
|
||||||
|
|
@ -204,8 +220,17 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
" extract the data, write it as a clean file (CSV, JSON, etc.),"
|
" extract the data, write it as a clean file (CSV, JSON, etc.),"
|
||||||
" and then run your code against it."
|
" and then run your code against it."
|
||||||
)
|
)
|
||||||
|
if filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
|
||||||
|
system_prompt += (
|
||||||
|
"\n\n## Local Folder Mode"
|
||||||
|
"\n\nThis chat is running in desktop local-folder mode."
|
||||||
|
" Keep all file operations local. Do not use save_document."
|
||||||
|
" Always use mount-prefixed absolute paths like /<folder>/file.ext."
|
||||||
|
" If you are unsure which mounts are available, call ls('/') first."
|
||||||
|
)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
backend=backend,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
custom_tool_descriptions={
|
custom_tool_descriptions={
|
||||||
"ls": SURFSENSE_LIST_FILES_TOOL_DESCRIPTION,
|
"ls": SURFSENSE_LIST_FILES_TOOL_DESCRIPTION,
|
||||||
|
|
@ -219,7 +244,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
max_execute_timeout=self._MAX_EXECUTE_TIMEOUT,
|
max_execute_timeout=self._MAX_EXECUTE_TIMEOUT,
|
||||||
)
|
)
|
||||||
self.tools = [t for t in self.tools if t.name != "execute"]
|
self.tools = [t for t in self.tools if t.name != "execute"]
|
||||||
self.tools.append(self._create_save_document_tool())
|
if self._should_persist_documents():
|
||||||
|
self.tools.append(self._create_save_document_tool())
|
||||||
if self._sandbox_available:
|
if self._sandbox_available:
|
||||||
self.tools.append(self._create_execute_code_tool())
|
self.tools.append(self._create_execute_code_tool())
|
||||||
|
|
||||||
|
|
@ -637,15 +663,25 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
runtime: ToolRuntime[None, FilesystemState],
|
runtime: ToolRuntime[None, FilesystemState],
|
||||||
) -> Command | str:
|
) -> Command | str:
|
||||||
resolved_backend = self._get_backend(runtime)
|
resolved_backend = self._get_backend(runtime)
|
||||||
|
target_path = self._resolve_write_target_path(file_path, runtime)
|
||||||
try:
|
try:
|
||||||
validated_path = validate_path(file_path)
|
validated_path = validate_path(target_path)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
return f"Error: {exc}"
|
return f"Error: {exc}"
|
||||||
res: WriteResult = resolved_backend.write(validated_path, content)
|
res: WriteResult = resolved_backend.write(validated_path, content)
|
||||||
if res.error:
|
if res.error:
|
||||||
return res.error
|
return res.error
|
||||||
|
verify_error = self._verify_written_content_sync(
|
||||||
|
backend=resolved_backend,
|
||||||
|
path=validated_path,
|
||||||
|
expected_content=content,
|
||||||
|
)
|
||||||
|
if verify_error:
|
||||||
|
return verify_error
|
||||||
|
|
||||||
if not self._is_kb_document(validated_path):
|
if self._should_persist_documents() and not self._is_kb_document(
|
||||||
|
validated_path
|
||||||
|
):
|
||||||
persist_result = self._run_async_blocking(
|
persist_result = self._run_async_blocking(
|
||||||
self._persist_new_document(
|
self._persist_new_document(
|
||||||
file_path=validated_path, content=content
|
file_path=validated_path, content=content
|
||||||
|
|
@ -682,15 +718,25 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
runtime: ToolRuntime[None, FilesystemState],
|
runtime: ToolRuntime[None, FilesystemState],
|
||||||
) -> Command | str:
|
) -> Command | str:
|
||||||
resolved_backend = self._get_backend(runtime)
|
resolved_backend = self._get_backend(runtime)
|
||||||
|
target_path = self._resolve_write_target_path(file_path, runtime)
|
||||||
try:
|
try:
|
||||||
validated_path = validate_path(file_path)
|
validated_path = validate_path(target_path)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
return f"Error: {exc}"
|
return f"Error: {exc}"
|
||||||
res: WriteResult = await resolved_backend.awrite(validated_path, content)
|
res: WriteResult = await resolved_backend.awrite(validated_path, content)
|
||||||
if res.error:
|
if res.error:
|
||||||
return res.error
|
return res.error
|
||||||
|
verify_error = await self._verify_written_content_async(
|
||||||
|
backend=resolved_backend,
|
||||||
|
path=validated_path,
|
||||||
|
expected_content=content,
|
||||||
|
)
|
||||||
|
if verify_error:
|
||||||
|
return verify_error
|
||||||
|
|
||||||
if not self._is_kb_document(validated_path):
|
if self._should_persist_documents() and not self._is_kb_document(
|
||||||
|
validated_path
|
||||||
|
):
|
||||||
persist_result = await self._persist_new_document(
|
persist_result = await self._persist_new_document(
|
||||||
file_path=validated_path,
|
file_path=validated_path,
|
||||||
content=content,
|
content=content,
|
||||||
|
|
@ -726,6 +772,164 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
"""Return True for paths under /documents/ (KB-sourced, XML-wrapped)."""
|
"""Return True for paths under /documents/ (KB-sourced, XML-wrapped)."""
|
||||||
return path.startswith("/documents/")
|
return path.startswith("/documents/")
|
||||||
|
|
||||||
|
def _should_persist_documents(self) -> bool:
|
||||||
|
"""Only cloud mode persists file content to Document/Chunk tables."""
|
||||||
|
return self._filesystem_mode == FilesystemMode.CLOUD
|
||||||
|
|
||||||
|
def _default_mount_prefix(self, runtime: ToolRuntime[None, FilesystemState]) -> str:
|
||||||
|
backend = self._get_backend(runtime)
|
||||||
|
if isinstance(backend, MultiRootLocalFolderBackend):
|
||||||
|
return f"/{backend.default_mount()}"
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _normalize_local_mount_path(
|
||||||
|
self, candidate: str, runtime: ToolRuntime[None, FilesystemState]
|
||||||
|
) -> str:
|
||||||
|
backend = self._get_backend(runtime)
|
||||||
|
mount_prefix = self._default_mount_prefix(runtime)
|
||||||
|
normalized_candidate = re.sub(r"/+", "/", candidate.strip().replace("\\", "/"))
|
||||||
|
if not mount_prefix or not isinstance(backend, MultiRootLocalFolderBackend):
|
||||||
|
if normalized_candidate.startswith("/"):
|
||||||
|
return normalized_candidate
|
||||||
|
return f"/{normalized_candidate.lstrip('/')}"
|
||||||
|
|
||||||
|
mount_names = set(backend.list_mounts())
|
||||||
|
if normalized_candidate.startswith("/"):
|
||||||
|
first_segment = normalized_candidate.lstrip("/").split("/", 1)[0]
|
||||||
|
if first_segment in mount_names:
|
||||||
|
return normalized_candidate
|
||||||
|
return f"{mount_prefix}{normalized_candidate}"
|
||||||
|
|
||||||
|
relative = normalized_candidate.lstrip("/")
|
||||||
|
first_segment = relative.split("/", 1)[0]
|
||||||
|
if first_segment in mount_names:
|
||||||
|
return f"/{relative}"
|
||||||
|
return f"{mount_prefix}/{relative}"
|
||||||
|
|
||||||
|
def _get_contract_suggested_path(
|
||||||
|
self, runtime: ToolRuntime[None, FilesystemState]
|
||||||
|
) -> str:
|
||||||
|
contract = runtime.state.get("file_operation_contract") or {}
|
||||||
|
suggested = contract.get("suggested_path")
|
||||||
|
if isinstance(suggested, str) and suggested.strip():
|
||||||
|
cleaned = suggested.strip()
|
||||||
|
if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
|
||||||
|
return self._normalize_local_mount_path(cleaned, runtime)
|
||||||
|
return cleaned
|
||||||
|
if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
|
||||||
|
mount_prefix = self._default_mount_prefix(runtime)
|
||||||
|
if mount_prefix:
|
||||||
|
return f"{mount_prefix}/notes.md"
|
||||||
|
return "/notes.md"
|
||||||
|
|
||||||
|
def _resolve_write_target_path(
|
||||||
|
self,
|
||||||
|
file_path: str,
|
||||||
|
runtime: ToolRuntime[None, FilesystemState],
|
||||||
|
) -> str:
|
||||||
|
candidate = file_path.strip()
|
||||||
|
if not candidate:
|
||||||
|
return self._get_contract_suggested_path(runtime)
|
||||||
|
if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
|
||||||
|
return self._normalize_local_mount_path(candidate, runtime)
|
||||||
|
if not candidate.startswith("/"):
|
||||||
|
return f"/{candidate.lstrip('/')}"
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_error_text(value: str) -> bool:
|
||||||
|
return value.startswith("Error:")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _read_for_verification_sync(backend: Any, path: str) -> str:
|
||||||
|
read_raw = getattr(backend, "read_raw", None)
|
||||||
|
if callable(read_raw):
|
||||||
|
return read_raw(path)
|
||||||
|
return backend.read(path, offset=0, limit=200000)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _read_for_verification_async(backend: Any, path: str) -> str:
|
||||||
|
aread_raw = getattr(backend, "aread_raw", None)
|
||||||
|
if callable(aread_raw):
|
||||||
|
return await aread_raw(path)
|
||||||
|
return await backend.aread(path, offset=0, limit=200000)
|
||||||
|
|
||||||
|
def _verify_written_content_sync(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
backend: Any,
|
||||||
|
path: str,
|
||||||
|
expected_content: str,
|
||||||
|
) -> str | None:
|
||||||
|
actual = self._read_for_verification_sync(backend, path)
|
||||||
|
if self._is_error_text(actual):
|
||||||
|
return f"Error: could not verify written file '{path}'."
|
||||||
|
if actual.rstrip() != expected_content.rstrip():
|
||||||
|
return (
|
||||||
|
"Error: file write verification failed; expected content was not fully written "
|
||||||
|
f"to '{path}'."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _verify_written_content_async(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
backend: Any,
|
||||||
|
path: str,
|
||||||
|
expected_content: str,
|
||||||
|
) -> str | None:
|
||||||
|
actual = await self._read_for_verification_async(backend, path)
|
||||||
|
if self._is_error_text(actual):
|
||||||
|
return f"Error: could not verify written file '{path}'."
|
||||||
|
if actual.rstrip() != expected_content.rstrip():
|
||||||
|
return (
|
||||||
|
"Error: file write verification failed; expected content was not fully written "
|
||||||
|
f"to '{path}'."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _verify_edited_content_sync(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
backend: Any,
|
||||||
|
path: str,
|
||||||
|
new_string: str,
|
||||||
|
) -> tuple[str | None, str | None]:
|
||||||
|
updated_content = self._read_for_verification_sync(backend, path)
|
||||||
|
if self._is_error_text(updated_content):
|
||||||
|
return (
|
||||||
|
f"Error: could not verify edited file '{path}'.",
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if new_string and new_string not in updated_content:
|
||||||
|
return (
|
||||||
|
"Error: edit verification failed; updated content was not found in "
|
||||||
|
f"'{path}'.",
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
return None, updated_content
|
||||||
|
|
||||||
|
async def _verify_edited_content_async(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
backend: Any,
|
||||||
|
path: str,
|
||||||
|
new_string: str,
|
||||||
|
) -> tuple[str | None, str | None]:
|
||||||
|
updated_content = await self._read_for_verification_async(backend, path)
|
||||||
|
if self._is_error_text(updated_content):
|
||||||
|
return (
|
||||||
|
f"Error: could not verify edited file '{path}'.",
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if new_string and new_string not in updated_content:
|
||||||
|
return (
|
||||||
|
"Error: edit verification failed; updated content was not found in "
|
||||||
|
f"'{path}'.",
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
return None, updated_content
|
||||||
|
|
||||||
def _create_edit_file_tool(self) -> BaseTool:
|
def _create_edit_file_tool(self) -> BaseTool:
|
||||||
"""Create edit_file with DB persistence (skipped for KB documents)."""
|
"""Create edit_file with DB persistence (skipped for KB documents)."""
|
||||||
tool_description = (
|
tool_description = (
|
||||||
|
|
@ -754,8 +958,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
] = False,
|
] = False,
|
||||||
) -> Command | str:
|
) -> Command | str:
|
||||||
resolved_backend = self._get_backend(runtime)
|
resolved_backend = self._get_backend(runtime)
|
||||||
|
target_path = self._resolve_write_target_path(file_path, runtime)
|
||||||
try:
|
try:
|
||||||
validated_path = validate_path(file_path)
|
validated_path = validate_path(target_path)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
return f"Error: {exc}"
|
return f"Error: {exc}"
|
||||||
res: EditResult = resolved_backend.edit(
|
res: EditResult = resolved_backend.edit(
|
||||||
|
|
@ -767,13 +972,22 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
if res.error:
|
if res.error:
|
||||||
return res.error
|
return res.error
|
||||||
|
|
||||||
if not self._is_kb_document(validated_path):
|
verify_error, updated_content = self._verify_edited_content_sync(
|
||||||
read_result = resolved_backend.read(
|
backend=resolved_backend,
|
||||||
validated_path, offset=0, limit=200000
|
path=validated_path,
|
||||||
)
|
new_string=new_string,
|
||||||
if read_result.error or read_result.file_data is None:
|
)
|
||||||
return f"Error: could not reload edited file '{validated_path}' for persistence."
|
if verify_error:
|
||||||
updated_content = read_result.file_data["content"]
|
return verify_error
|
||||||
|
|
||||||
|
if self._should_persist_documents() and not self._is_kb_document(
|
||||||
|
validated_path
|
||||||
|
):
|
||||||
|
if updated_content is None:
|
||||||
|
return (
|
||||||
|
f"Error: could not reload edited file '{validated_path}' for "
|
||||||
|
"persistence."
|
||||||
|
)
|
||||||
persist_result = self._run_async_blocking(
|
persist_result = self._run_async_blocking(
|
||||||
self._persist_edited_document(
|
self._persist_edited_document(
|
||||||
file_path=validated_path,
|
file_path=validated_path,
|
||||||
|
|
@ -818,8 +1032,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
] = False,
|
] = False,
|
||||||
) -> Command | str:
|
) -> Command | str:
|
||||||
resolved_backend = self._get_backend(runtime)
|
resolved_backend = self._get_backend(runtime)
|
||||||
|
target_path = self._resolve_write_target_path(file_path, runtime)
|
||||||
try:
|
try:
|
||||||
validated_path = validate_path(file_path)
|
validated_path = validate_path(target_path)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
return f"Error: {exc}"
|
return f"Error: {exc}"
|
||||||
res: EditResult = await resolved_backend.aedit(
|
res: EditResult = await resolved_backend.aedit(
|
||||||
|
|
@ -831,13 +1046,22 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
if res.error:
|
if res.error:
|
||||||
return res.error
|
return res.error
|
||||||
|
|
||||||
if not self._is_kb_document(validated_path):
|
verify_error, updated_content = await self._verify_edited_content_async(
|
||||||
read_result = await resolved_backend.aread(
|
backend=resolved_backend,
|
||||||
validated_path, offset=0, limit=200000
|
path=validated_path,
|
||||||
)
|
new_string=new_string,
|
||||||
if read_result.error or read_result.file_data is None:
|
)
|
||||||
return f"Error: could not reload edited file '{validated_path}' for persistence."
|
if verify_error:
|
||||||
updated_content = read_result.file_data["content"]
|
return verify_error
|
||||||
|
|
||||||
|
if self._should_persist_documents() and not self._is_kb_document(
|
||||||
|
validated_path
|
||||||
|
):
|
||||||
|
if updated_content is None:
|
||||||
|
return (
|
||||||
|
f"Error: could not reload edited file '{validated_path}' for "
|
||||||
|
"persistence."
|
||||||
|
)
|
||||||
persist_error = await self._persist_edited_document(
|
persist_error = await self._persist_edited_document(
|
||||||
file_path=validated_path,
|
file_path=validated_path,
|
||||||
updated_content=updated_content,
|
updated_content=updated_content,
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range
|
from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
from app.db import (
|
from app.db import (
|
||||||
NATIVE_TO_LEGACY_DOCTYPE,
|
NATIVE_TO_LEGACY_DOCTYPE,
|
||||||
Chunk,
|
Chunk,
|
||||||
|
|
@ -857,6 +858,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
*,
|
*,
|
||||||
llm: BaseChatModel | None = None,
|
llm: BaseChatModel | None = None,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
|
filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
||||||
available_connectors: list[str] | None = None,
|
available_connectors: list[str] | None = None,
|
||||||
available_document_types: list[str] | None = None,
|
available_document_types: list[str] | None = None,
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
|
|
@ -865,6 +867,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
) -> None:
|
) -> None:
|
||||||
self.llm = llm
|
self.llm = llm
|
||||||
self.search_space_id = search_space_id
|
self.search_space_id = search_space_id
|
||||||
|
self.filesystem_mode = filesystem_mode
|
||||||
self.available_connectors = available_connectors
|
self.available_connectors = available_connectors
|
||||||
self.available_document_types = available_document_types
|
self.available_document_types = available_document_types
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
@ -996,6 +999,9 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
messages = state.get("messages") or []
|
messages = state.get("messages") or []
|
||||||
if not messages:
|
if not messages:
|
||||||
return None
|
return None
|
||||||
|
if self.filesystem_mode != FilesystemMode.CLOUD:
|
||||||
|
# Local-folder mode should not seed cloud KB documents into filesystem.
|
||||||
|
return None
|
||||||
|
|
||||||
last_human = None
|
last_human = None
|
||||||
for msg in reversed(messages):
|
for msg in reversed(messages):
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,316 @@
|
||||||
|
"""Desktop local-folder filesystem backend for deepagents tools."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import fnmatch
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from deepagents.backends.protocol import (
|
||||||
|
EditResult,
|
||||||
|
FileDownloadResponse,
|
||||||
|
FileInfo,
|
||||||
|
FileUploadResponse,
|
||||||
|
GrepMatch,
|
||||||
|
WriteResult,
|
||||||
|
)
|
||||||
|
from deepagents.backends.utils import (
|
||||||
|
create_file_data,
|
||||||
|
format_read_response,
|
||||||
|
perform_string_replacement,
|
||||||
|
)
|
||||||
|
|
||||||
|
_INVALID_PATH = "invalid_path"
|
||||||
|
_FILE_NOT_FOUND = "file_not_found"
|
||||||
|
_IS_DIRECTORY = "is_directory"
|
||||||
|
|
||||||
|
|
||||||
|
class LocalFolderBackend:
|
||||||
|
"""Filesystem backend rooted to a single local folder."""
|
||||||
|
|
||||||
|
def __init__(self, root_path: str) -> None:
|
||||||
|
root = Path(root_path).expanduser().resolve()
|
||||||
|
if not root.exists() or not root.is_dir():
|
||||||
|
msg = f"Local filesystem root does not exist or is not a directory: {root_path}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
self._root = root
|
||||||
|
self._locks: dict[str, threading.Lock] = {}
|
||||||
|
self._locks_mu = threading.Lock()
|
||||||
|
|
||||||
|
def _lock_for(self, path: str) -> threading.Lock:
|
||||||
|
with self._locks_mu:
|
||||||
|
if path not in self._locks:
|
||||||
|
self._locks[path] = threading.Lock()
|
||||||
|
return self._locks[path]
|
||||||
|
|
||||||
|
def _resolve_virtual(self, virtual_path: str, *, allow_root: bool = False) -> Path:
|
||||||
|
if not virtual_path.startswith("/"):
|
||||||
|
msg = f"Invalid path (must be absolute): {virtual_path}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
rel = virtual_path.lstrip("/")
|
||||||
|
candidate = self._root if rel == "" else (self._root / rel)
|
||||||
|
resolved = candidate.resolve()
|
||||||
|
if not allow_root and resolved == self._root:
|
||||||
|
msg = "Path must refer to a file or child directory under root"
|
||||||
|
raise ValueError(msg)
|
||||||
|
if not resolved.is_relative_to(self._root):
|
||||||
|
msg = f"Path escapes local filesystem root: {virtual_path}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _to_virtual(path: Path, root: Path) -> str:
|
||||||
|
rel = path.relative_to(root).as_posix()
|
||||||
|
return "/" if rel == "." else f"/{rel}"
|
||||||
|
|
||||||
|
def _write_text_atomic(self, path: Path, content: str) -> None:
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
temp_path = path.with_suffix(f"{path.suffix}.tmp")
|
||||||
|
temp_path.write_text(content, encoding="utf-8")
|
||||||
|
os.replace(temp_path, path)
|
||||||
|
|
||||||
|
def ls_info(self, path: str) -> list[FileInfo]:
|
||||||
|
try:
|
||||||
|
target = self._resolve_virtual(path, allow_root=True)
|
||||||
|
except ValueError:
|
||||||
|
return []
|
||||||
|
if not target.exists() or not target.is_dir():
|
||||||
|
return []
|
||||||
|
infos: list[FileInfo] = []
|
||||||
|
for child in sorted(target.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower())):
|
||||||
|
infos.append(
|
||||||
|
FileInfo(
|
||||||
|
path=self._to_virtual(child, self._root),
|
||||||
|
is_dir=child.is_dir(),
|
||||||
|
size=child.stat().st_size if child.is_file() else 0,
|
||||||
|
modified_at=str(child.stat().st_mtime),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return infos
|
||||||
|
|
||||||
|
async def als_info(self, path: str) -> list[FileInfo]:
    """List directory entries without blocking the event loop."""
    sync_impl = self.ls_info
    return await asyncio.to_thread(sync_impl, path)
|
||||||
|
|
||||||
|
def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
    """Read a file and format it for display (line-numbered window).

    Returns an ``Error: ...`` string instead of raising when the path is
    invalid, missing, or not a regular file.
    """
    try:
        target = self._resolve_virtual(file_path)
    except ValueError:
        return f"Error: Invalid path '{file_path}'"
    if not target.exists():
        return f"Error: File '{file_path}' not found"
    if not target.is_file():
        return f"Error: Path '{file_path}' is not a file"
    raw_text = target.read_text(encoding="utf-8", errors="replace")
    return format_read_response(create_file_data(raw_text), offset, limit)
|
||||||
|
|
||||||
|
async def aread(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
    """Async wrapper: run the blocking read in a worker thread."""
    sync_impl = self.read
    return await asyncio.to_thread(sync_impl, file_path, offset, limit)
|
||||||
|
|
||||||
|
def read_raw(self, file_path: str) -> str:
    """Read raw file text without line-number formatting.

    Returns an ``Error: ...`` string instead of raising on invalid,
    missing, or non-file paths.
    """
    try:
        target = self._resolve_virtual(file_path)
    except ValueError:
        return f"Error: Invalid path '{file_path}'"
    if not target.exists():
        return f"Error: File '{file_path}' not found"
    if not target.is_file():
        return f"Error: Path '{file_path}' is not a file"
    return target.read_text(encoding="utf-8", errors="replace")
|
||||||
|
|
||||||
|
async def aread_raw(self, file_path: str) -> str:
    """Async variant of read_raw, executed in a worker thread."""
    sync_impl = self.read_raw
    return await asyncio.to_thread(sync_impl, file_path)
|
||||||
|
|
||||||
|
def write(self, file_path: str, content: str) -> WriteResult:
    """Create a new file atomically; refuses to overwrite an existing path.

    Errors are reported through ``WriteResult.error`` rather than raised.
    The per-path lock keeps concurrent writers from racing on the same file.
    """
    try:
        target = self._resolve_virtual(file_path)
    except ValueError:
        return WriteResult(error=f"Error: Invalid path '{file_path}'")
    with self._lock_for(file_path):
        if target.exists():
            return WriteResult(
                error=(
                    f"Cannot write to {file_path} because it already exists. "
                    "Read and then make an edit, or write to a new path."
                )
            )
        self._write_text_atomic(target, content)
        return WriteResult(path=file_path, files_update=None)
|
||||||
|
|
||||||
|
async def awrite(self, file_path: str, content: str) -> WriteResult:
    """Async wrapper: run the blocking write in a worker thread."""
    sync_impl = self.write
    return await asyncio.to_thread(sync_impl, file_path, content)
|
||||||
|
|
||||||
|
def edit(
    self,
    file_path: str,
    old_string: str,
    new_string: str,
    replace_all: bool = False,
) -> EditResult:
    """Replace ``old_string`` with ``new_string`` inside *file_path*.

    Errors (invalid path, missing file, replacement failure) come back via
    ``EditResult.error``.  The per-path lock serializes concurrent edits,
    and the rewrite itself is atomic via ``_write_text_atomic``.
    """
    try:
        target = self._resolve_virtual(file_path)
    except ValueError:
        return EditResult(error=f"Error: Invalid path '{file_path}'")
    with self._lock_for(file_path):
        if not target.exists() or not target.is_file():
            return EditResult(error=f"Error: File '{file_path}' not found")
        original_text = target.read_text(encoding="utf-8", errors="replace")
        outcome = perform_string_replacement(original_text, old_string, new_string, replace_all)
        # The helper signals failure by returning an error string.
        if isinstance(outcome, str):
            return EditResult(error=outcome)
        new_text, occurrences = outcome
        self._write_text_atomic(target, new_text)
        return EditResult(path=file_path, files_update=None, occurrences=int(occurrences))
|
||||||
|
|
||||||
|
async def aedit(
    self,
    file_path: str,
    old_string: str,
    new_string: str,
    replace_all: bool = False,
) -> EditResult:
    """Async wrapper: run the blocking edit in a worker thread."""
    sync_impl = self.edit
    return await asyncio.to_thread(sync_impl, file_path, old_string, new_string, replace_all)
|
||||||
|
|
||||||
|
def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
    """Glob for files/directories matching *pattern*.

    An absolute pattern (leading "/") is matched against the backend root;
    a relative pattern is matched against *path*.  Matches resolving
    outside the root (e.g. via symlinks) are excluded.  Returns [] for an
    invalid *path*.
    """
    try:
        base = self._resolve_virtual(path, allow_root=True)
    except ValueError:
        return []

    if pattern.startswith("/"):
        search_base = self._root
        normalized_pattern = pattern.lstrip("/")
    else:
        search_base = base
        normalized_pattern = pattern

    matches: list[FileInfo] = []
    for hit in search_base.glob(normalized_pattern):
        try:
            resolved = hit.resolve()
            if not resolved.is_relative_to(self._root):
                continue
            # Bug fix: stat inside the guard and only once.  Previously
            # stat() ran twice per match *outside* the try, so a match
            # vanishing between glob() and stat() (TOCTOU) raised and
            # aborted the whole glob.
            stat_result = resolved.stat()
            is_dir = resolved.is_dir()
            is_file = resolved.is_file()
        except Exception:
            continue
        matches.append(
            FileInfo(
                path=self._to_virtual(resolved, self._root),
                is_dir=is_dir,
                size=stat_result.st_size if is_file else 0,
                modified_at=str(stat_result.st_mtime),
            )
        )
    return matches
|
||||||
|
|
||||||
|
async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
    """Async wrapper: run the blocking glob in a worker thread."""
    sync_impl = self.glob_info
    return await asyncio.to_thread(sync_impl, pattern, path)
|
||||||
|
|
||||||
|
def _iter_candidate_files(self, path: str | None, glob: str | None) -> list[Path]:
    """Collect the files grep should scan.

    *path* may name a directory (searched recursively) or a single file.
    *glob*, if given, keeps only candidates whose virtual path or basename
    matches.  Returns [] for invalid or missing paths.
    """
    base_virtual = path or "/"
    try:
        base = self._resolve_virtual(base_virtual, allow_root=True)
    except ValueError:
        return []
    if not base.exists():
        return []

    # Bug fix: a *path* that names a single file previously produced no
    # candidates (rglob("*") on a file yields nothing), so grepping an
    # explicit file silently returned no matches.
    if base.is_file():
        candidates = [base]
    else:
        candidates = [p for p in base.rglob("*") if p.is_file()]
    if glob:
        candidates = [
            p
            for p in candidates
            if fnmatch.fnmatch(self._to_virtual(p, self._root), glob)
            or fnmatch.fnmatch(p.name, glob)
        ]
    return candidates
|
||||||
|
|
||||||
|
def grep_raw(
    self, pattern: str, path: str | None = None, glob: str | None = None
) -> list[GrepMatch] | str:
    """Literal-substring search across candidate files.

    Returns the list of matches, or an error string for an empty pattern.
    Unreadable files are skipped silently.
    """
    if not pattern:
        return "Error: pattern cannot be empty"
    found: list[GrepMatch] = []
    for candidate in self._iter_candidate_files(path, glob):
        try:
            text = candidate.read_text(encoding="utf-8", errors="replace")
        except Exception:
            continue
        # Compute the virtual path once per file rather than per matching line.
        virtual = self._to_virtual(candidate, self._root)
        for line_no, line_text in enumerate(text.splitlines(), start=1):
            if pattern in line_text:
                found.append(GrepMatch(path=virtual, line=line_no, text=line_text))
    return found
|
||||||
|
|
||||||
|
async def agrep_raw(
    self, pattern: str, path: str | None = None, glob: str | None = None
) -> list[GrepMatch] | str:
    """Async wrapper: run the blocking grep in a worker thread."""
    sync_impl = self.grep_raw
    return await asyncio.to_thread(sync_impl, pattern, path, glob)
|
||||||
|
|
||||||
|
def upload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
    """Write each (virtual_path, content) pair atomically; never raises.

    Failures are reported per-file through the response ``error`` field.
    Bug fix: a failed write no longer leaks its ``*.tmp`` file — the
    partial temp file is removed before the error is recorded.
    """
    responses: list[FileUploadResponse] = []
    for virtual_path, content in files:
        try:
            target = self._resolve_virtual(virtual_path)
            target.parent.mkdir(parents=True, exist_ok=True)
            temp_path = target.with_suffix(f"{target.suffix}.tmp")
            try:
                temp_path.write_bytes(content)
                # Atomic swap into place; overwrites any existing file.
                os.replace(temp_path, target)
            except BaseException:
                temp_path.unlink(missing_ok=True)
                raise
            responses.append(FileUploadResponse(path=virtual_path, error=None))
        except FileNotFoundError:
            responses.append(
                FileUploadResponse(path=virtual_path, error=_FILE_NOT_FOUND)
            )
        except IsADirectoryError:
            responses.append(FileUploadResponse(path=virtual_path, error=_IS_DIRECTORY))
        except Exception:
            responses.append(FileUploadResponse(path=virtual_path, error=_INVALID_PATH))
    return responses
|
||||||
|
|
||||||
|
async def aupload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
    """Async wrapper: run the blocking upload in a worker thread."""
    sync_impl = self.upload_files
    return await asyncio.to_thread(sync_impl, files)
|
||||||
|
|
||||||
|
def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
    """Return file bytes per virtual path; errors are reported per entry, never raised."""
    results: list[FileDownloadResponse] = []
    for virtual_path in paths:
        try:
            target = self._resolve_virtual(virtual_path)
            if not target.exists():
                response = FileDownloadResponse(
                    path=virtual_path, content=None, error=_FILE_NOT_FOUND
                )
            elif target.is_dir():
                response = FileDownloadResponse(
                    path=virtual_path, content=None, error=_IS_DIRECTORY
                )
            else:
                response = FileDownloadResponse(
                    path=virtual_path, content=target.read_bytes(), error=None
                )
        except Exception:
            # Covers both resolution failures and read errors.
            response = FileDownloadResponse(
                path=virtual_path, content=None, error=_INVALID_PATH
            )
        results.append(response)
    return results
|
||||||
|
|
||||||
|
async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]:
    """Async wrapper: run the blocking download in a worker thread."""
    sync_impl = self.download_files
    return await asyncio.to_thread(sync_impl, paths)
|
||||||
|
|
@ -0,0 +1,329 @@
|
||||||
|
"""Aggregate multiple LocalFolderBackend roots behind mount-prefixed virtual paths."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from deepagents.backends.protocol import (
|
||||||
|
EditResult,
|
||||||
|
FileDownloadResponse,
|
||||||
|
FileInfo,
|
||||||
|
FileUploadResponse,
|
||||||
|
GrepMatch,
|
||||||
|
WriteResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.local_folder_backend import LocalFolderBackend
|
||||||
|
|
||||||
|
_INVALID_PATH = "invalid_path"
|
||||||
|
_FILE_NOT_FOUND = "file_not_found"
|
||||||
|
_IS_DIRECTORY = "is_directory"
|
||||||
|
|
||||||
|
|
||||||
|
class MultiRootLocalFolderBackend:
    """Route filesystem operations to one of several mounted local roots.

    Virtual paths are namespaced as:
    - `/<mount>/...`
    where `<mount>` is derived from each selected root folder name.
    """

    def __init__(self, mounts: tuple[tuple[str, str], ...]) -> None:
        # mounts: (mount_id, root_directory) pairs; input order determines the
        # default mount and the listing order of "/".
        if not mounts:
            msg = "At least one local mount is required"
            raise ValueError(msg)
        self._mount_to_backend: dict[str, LocalFolderBackend] = {}
        for raw_mount, raw_root in mounts:
            mount = raw_mount.strip()
            if not mount:
                msg = "Mount id cannot be empty"
                raise ValueError(msg)
            if mount in self._mount_to_backend:
                msg = f"Duplicate mount id: {mount}"
                raise ValueError(msg)
            normalized_root = str(Path(raw_root).expanduser().resolve())
            self._mount_to_backend[mount] = LocalFolderBackend(normalized_root)
        # dicts preserve insertion order, so this reflects the input order.
        self._mount_order = tuple(self._mount_to_backend.keys())

    def list_mounts(self) -> tuple[str, ...]:
        """Return mount ids in registration order."""
        return self._mount_order

    def default_mount(self) -> str:
        """Return the first registered mount id."""
        return self._mount_order[0]

    def _mount_error(self) -> str:
        # Human-readable error listing every valid mount prefix.
        mounts = ", ".join(f"/{mount}" for mount in self._mount_order)
        return (
            "Path must start with one of the selected folders: "
            f"{mounts}. Example: /{self._mount_order[0]}/file.txt"
        )

    def _split_mount_path(self, virtual_path: str) -> tuple[str, str]:
        """Split "/<mount>/rest" into (mount, "/rest"); raise ValueError when invalid."""
        if not virtual_path.startswith("/"):
            msg = f"Invalid path (must be absolute): {virtual_path}"
            raise ValueError(msg)
        rel = virtual_path.lstrip("/")
        if not rel:
            # Bare "/" carries no mount component.
            raise ValueError(self._mount_error())
        mount, _, remainder = rel.partition("/")
        backend = self._mount_to_backend.get(mount)
        if backend is None:
            raise ValueError(self._mount_error())
        local_path = f"/{remainder}" if remainder else "/"
        return mount, local_path

    @staticmethod
    def _prefix_mount_path(mount: str, local_path: str) -> str:
        # Inverse of _split_mount_path, used to re-namespace result paths.
        if local_path == "/":
            return f"/{mount}"
        return f"/{mount}{local_path}"

    @staticmethod
    def _get_value(item: Any, key: str) -> Any:
        # Backend results may be dicts or objects; read either uniformly.
        if isinstance(item, dict):
            return item.get(key)
        return getattr(item, key, None)

    @classmethod
    def _get_str(cls, item: Any, key: str) -> str:
        # Non-string values (including None) collapse to "".
        value = cls._get_value(item, key)
        return value if isinstance(value, str) else ""

    @classmethod
    def _get_int(cls, item: Any, key: str) -> int:
        # Non-numeric values collapse to 0; floats are truncated.
        value = cls._get_value(item, key)
        return int(value) if isinstance(value, int | float) else 0

    @classmethod
    def _get_bool(cls, item: Any, key: str) -> bool:
        value = cls._get_value(item, key)
        return bool(value)

    def _list_mount_roots(self) -> list[FileInfo]:
        # Synthetic directory entries for "/": one per mount.
        return [
            FileInfo(path=f"/{mount}", is_dir=True, size=0, modified_at="0")
            for mount in self._mount_order
        ]

    def _transform_infos(self, mount: str, infos: list[FileInfo]) -> list[FileInfo]:
        # Re-prefix backend-local paths with their mount namespace.
        transformed: list[FileInfo] = []
        for info in infos:
            transformed.append(
                FileInfo(
                    path=self._prefix_mount_path(mount, self._get_str(info, "path")),
                    is_dir=self._get_bool(info, "is_dir"),
                    size=self._get_int(info, "size"),
                    modified_at=self._get_str(info, "modified_at"),
                )
            )
        return transformed

    def ls_info(self, path: str) -> list[FileInfo]:
        """List a directory; "/" lists the mounts themselves."""
        if path == "/":
            return self._list_mount_roots()
        try:
            mount, local_path = self._split_mount_path(path)
        except ValueError:
            return []
        return self._transform_infos(mount, self._mount_to_backend[mount].ls_info(local_path))

    async def als_info(self, path: str) -> list[FileInfo]:
        return await asyncio.to_thread(self.ls_info, path)

    def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
        """Read a file via its mount; returns an ``Error: ...`` string on bad paths."""
        try:
            mount, local_path = self._split_mount_path(file_path)
        except ValueError as exc:
            return f"Error: {exc}"
        return self._mount_to_backend[mount].read(local_path, offset, limit)

    async def aread(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
        return await asyncio.to_thread(self.read, file_path, offset, limit)

    def read_raw(self, file_path: str) -> str:
        """Read raw file text (no line-number formatting) via its mount."""
        try:
            mount, local_path = self._split_mount_path(file_path)
        except ValueError as exc:
            return f"Error: {exc}"
        return self._mount_to_backend[mount].read_raw(local_path)

    async def aread_raw(self, file_path: str) -> str:
        return await asyncio.to_thread(self.read_raw, file_path)

    def write(self, file_path: str, content: str) -> WriteResult:
        """Create a file via its mount; the result path is re-prefixed with the mount."""
        try:
            mount, local_path = self._split_mount_path(file_path)
        except ValueError as exc:
            return WriteResult(error=f"Error: {exc}")
        result = self._mount_to_backend[mount].write(local_path, content)
        if result.path:
            result.path = self._prefix_mount_path(mount, result.path)
        return result

    async def awrite(self, file_path: str, content: str) -> WriteResult:
        return await asyncio.to_thread(self.write, file_path, content)

    def edit(
        self,
        file_path: str,
        old_string: str,
        new_string: str,
        replace_all: bool = False,
    ) -> EditResult:
        """String-replace edit via the file's mount; result path is re-prefixed."""
        try:
            mount, local_path = self._split_mount_path(file_path)
        except ValueError as exc:
            return EditResult(error=f"Error: {exc}")
        result = self._mount_to_backend[mount].edit(
            local_path, old_string, new_string, replace_all
        )
        if result.path:
            result.path = self._prefix_mount_path(mount, result.path)
        return result

    async def aedit(
        self,
        file_path: str,
        old_string: str,
        new_string: str,
        replace_all: bool = False,
    ) -> EditResult:
        return await asyncio.to_thread(
            self.edit, file_path, old_string, new_string, replace_all
        )

    def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
        """Glob across one or all mounts.

        With path "/": an absolute pattern is routed to the mount named by its
        first segment; a relative pattern fans out to every mount.
        """
        if path == "/":
            prefixed_results: list[FileInfo] = []
            if pattern.startswith("/"):
                mount, _, remainder = pattern.lstrip("/").partition("/")
                backend = self._mount_to_backend.get(mount)
                if not backend:
                    return []
                local_pattern = f"/{remainder}" if remainder else "/"
                return self._transform_infos(
                    mount, backend.glob_info(local_pattern, path="/")
                )
            for mount, backend in self._mount_to_backend.items():
                prefixed_results.extend(
                    self._transform_infos(mount, backend.glob_info(pattern, path="/"))
                )
            return prefixed_results

        try:
            mount, local_path = self._split_mount_path(path)
        except ValueError:
            return []
        return self._transform_infos(
            mount, self._mount_to_backend[mount].glob_info(pattern, path=local_path)
        )

    async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
        return await asyncio.to_thread(self.glob_info, pattern, path)

    def grep_raw(
        self, pattern: str, path: str | None = None, glob: str | None = None
    ) -> list[GrepMatch] | str:
        """Substring search across one or all mounts; matches get mount-prefixed paths."""
        if not pattern:
            return "Error: pattern cannot be empty"
        if path is None or path == "/":
            all_matches: list[GrepMatch] = []
            for mount, backend in self._mount_to_backend.items():
                result = backend.grep_raw(pattern, path="/", glob=glob)
                if isinstance(result, str):
                    # A backend error string aborts the whole fan-out search.
                    return result
                all_matches.extend(
                    [
                        GrepMatch(
                            path=self._prefix_mount_path(mount, self._get_str(match, "path")),
                            line=self._get_int(match, "line"),
                            text=self._get_str(match, "text"),
                        )
                        for match in result
                    ]
                )
            return all_matches
        try:
            mount, local_path = self._split_mount_path(path)
        except ValueError as exc:
            return f"Error: {exc}"

        result = self._mount_to_backend[mount].grep_raw(
            pattern, path=local_path, glob=glob
        )
        if isinstance(result, str):
            return result
        return [
            GrepMatch(
                path=self._prefix_mount_path(mount, self._get_str(match, "path")),
                line=self._get_int(match, "line"),
                text=self._get_str(match, "text"),
            )
            for match in result
        ]

    async def agrep_raw(
        self, pattern: str, path: str | None = None, glob: str | None = None
    ) -> list[GrepMatch] | str:
        return await asyncio.to_thread(self.grep_raw, pattern, path, glob)

    def upload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
        """Group uploads by mount, delegate, and re-prefix result paths.

        Note: invalid-path responses are emitted first, so response order may
        differ from input order.
        """
        grouped: dict[str, list[tuple[str, bytes]]] = {}
        invalid: list[FileUploadResponse] = []
        for virtual_path, content in files:
            try:
                mount, local_path = self._split_mount_path(virtual_path)
            except ValueError:
                invalid.append(FileUploadResponse(path=virtual_path, error=_INVALID_PATH))
                continue
            grouped.setdefault(mount, []).append((local_path, content))

        responses = list(invalid)
        for mount, mount_files in grouped.items():
            result = self._mount_to_backend[mount].upload_files(mount_files)
            responses.extend(
                [
                    FileUploadResponse(
                        path=self._prefix_mount_path(mount, self._get_str(item, "path")),
                        error=self._get_str(item, "error") or None,
                    )
                    for item in result
                ]
            )
        return responses

    async def aupload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
        return await asyncio.to_thread(self.upload_files, files)

    def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
        """Group downloads by mount, delegate, and re-prefix result paths.

        Note: invalid-path responses are emitted first, so response order may
        differ from input order.
        """
        grouped: dict[str, list[str]] = {}
        invalid: list[FileDownloadResponse] = []
        for virtual_path in paths:
            try:
                mount, local_path = self._split_mount_path(virtual_path)
            except ValueError:
                invalid.append(
                    FileDownloadResponse(path=virtual_path, content=None, error=_INVALID_PATH)
                )
                continue
            grouped.setdefault(mount, []).append(local_path)

        responses = list(invalid)
        for mount, mount_paths in grouped.items():
            result = self._mount_to_backend[mount].download_files(mount_paths)
            responses.extend(
                [
                    FileDownloadResponse(
                        path=self._prefix_mount_path(mount, self._get_str(item, "path")),
                        content=self._get_value(item, "content"),
                        error=self._get_str(item, "error") or None,
                    )
                    for item in result
                ]
            )
        return responses

    async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]:
        return await asyncio.to_thread(self.download_files, paths)
|
||||||
|
|
@ -0,0 +1,123 @@
|
||||||
|
"""Safe wrapper around deepagents' SummarizationMiddleware.
|
||||||
|
|
||||||
|
Upstream issue
|
||||||
|
--------------
|
||||||
|
`deepagents.middleware.summarization.SummarizationMiddleware._aoffload_to_backend`
|
||||||
|
(and its sync counterpart) call
|
||||||
|
``get_buffer_string(filtered_messages)`` before writing the evicted history
|
||||||
|
to the backend file. In recent ``langchain-core`` versions, ``get_buffer_string``
|
||||||
|
accesses ``m.text`` which iterates ``self.content`` — this raises
|
||||||
|
``TypeError: 'NoneType' object is not iterable`` whenever an ``AIMessage``
|
||||||
|
has ``content=None`` (common when a model returns *only* tool_calls, seen
|
||||||
|
frequently with Azure OpenAI ``gpt-5.x`` responses streamed through
|
||||||
|
LiteLLM).
|
||||||
|
|
||||||
|
The exception aborts the whole agent turn, so the user just sees "Error during
|
||||||
|
chat" with no assistant response.
|
||||||
|
|
||||||
|
Fix
|
||||||
|
---
|
||||||
|
We subclass ``SummarizationMiddleware`` and override
|
||||||
|
``_filter_summary_messages`` — the only call site that feeds messages into
|
||||||
|
``get_buffer_string`` — to return *copies* of messages whose ``content`` is
|
||||||
|
``None`` with ``content=""``. The originals flowing through the rest of the
|
||||||
|
agent state are untouched.
|
||||||
|
|
||||||
|
We also expose a drop-in ``create_safe_summarization_middleware`` factory
|
||||||
|
that mirrors ``deepagents.middleware.summarization.create_summarization_middleware``
|
||||||
|
but instantiates our safe subclass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from deepagents.middleware.summarization import (
|
||||||
|
SummarizationMiddleware,
|
||||||
|
compute_summarization_defaults,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from deepagents.backends.protocol import BACKEND_TYPES
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
from langchain_core.messages import AnyMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_message_content(msg: AnyMessage) -> AnyMessage:
    """Return *msg*, or a copy of it with ``content=""`` when content is None.

    ``get_buffer_string`` reads ``message.text`` which iterates
    ``message.content``; when a provider returns only tool_calls the content
    can be ``None`` and that iteration raises ``TypeError``.  The original
    message is never mutated: we prefer pydantic's ``model_copy`` and fall
    back to a shallow copy when the object is not a pydantic model.
    """

    content = getattr(msg, "content", "not-missing")
    if content is not None:
        # Covers both "content present and non-None" and "no content attr".
        return msg

    try:
        return msg.model_copy(update={"content": ""})
    except AttributeError:
        import copy

        duplicate = copy.copy(msg)
        try:
            duplicate.content = ""
        except Exception:  # pragma: no cover - defensive
            logger.debug(
                "Could not sanitize content=None on message of type %s",
                type(msg).__name__,
            )
            return msg
        return duplicate
|
||||||
|
|
||||||
|
|
||||||
|
class SafeSummarizationMiddleware(SummarizationMiddleware):
    """`SummarizationMiddleware` that tolerates messages with ``content=None``.

    Overrides only ``_filter_summary_messages`` — the single helper that both
    the sync and async offload paths invoke immediately before
    ``get_buffer_string`` — so sanitising here covers both paths without
    copying the upstream offload implementations.
    """

    def _filter_summary_messages(self, messages: list[AnyMessage]) -> list[AnyMessage]:
        return [
            _sanitize_message_content(message)
            for message in super()._filter_summary_messages(messages)
        ]
|
||||||
|
|
||||||
|
|
||||||
|
def create_safe_summarization_middleware(
    model: BaseChatModel,
    backend: BACKEND_TYPES,
) -> SafeSummarizationMiddleware:
    """Drop-in replacement for ``create_summarization_middleware``.

    Computes the same defaults as upstream ``deepagents`` but instantiates
    ``SafeSummarizationMiddleware`` so ``content=None`` messages no longer
    crash ``get_buffer_string`` during history offload.
    """

    defaults = compute_summarization_defaults(model)
    middleware = SafeSummarizationMiddleware(
        model=model,
        backend=backend,
        trigger=defaults["trigger"],
        keep=defaults["keep"],
        trim_tokens_to_summarize=None,
        truncate_args_settings=defaults["truncate_args_settings"],
    )
    return middleware
|
||||||
|
|
||||||
|
|
||||||
|
# Public API of this module: the safe middleware subclass and its factory.
__all__ = [
    "SafeSummarizationMiddleware",
    "create_safe_summarization_middleware",
]
|
||||||
|
|
@ -38,8 +38,66 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
|
||||||
* Formatting, summarization, or analysis of content already present in the conversation
|
* Formatting, summarization, or analysis of content already present in the conversation
|
||||||
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
|
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
|
||||||
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
||||||
|
* Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see <tool_routing> below
|
||||||
</knowledge_base_only_policy>
|
</knowledge_base_only_policy>
|
||||||
|
|
||||||
|
<tool_routing>
|
||||||
|
CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable.
|
||||||
|
Their data is NEVER in the knowledge base. You MUST call their tools immediately — never
|
||||||
|
say "I don't see it in the knowledge base" or ask the user if they want you to check.
|
||||||
|
Ignore any knowledge base results for these services.
|
||||||
|
|
||||||
|
When to use which tool:
|
||||||
|
- Linear (issues) → list_issues, get_issue, save_issue (create/update)
|
||||||
|
- ClickUp (tasks) → clickup_search, clickup_get_task
|
||||||
|
- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue
|
||||||
|
- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread
|
||||||
|
- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table
|
||||||
|
- Knowledge base content (Notion, GitHub, files, notes) → automatically searched
|
||||||
|
- Real-time public web data → call web_search
|
||||||
|
- Reading a specific webpage → call scrape_webpage
|
||||||
|
</tool_routing>
|
||||||
|
|
||||||
|
<parameter_resolution>
|
||||||
|
Some service tools require identifiers or context you do not have (account IDs,
|
||||||
|
workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw
|
||||||
|
IDs or technical identifiers — they cannot memorise them.
|
||||||
|
|
||||||
|
Instead, follow this discovery pattern:
|
||||||
|
1. Call a listing/discovery tool to find available options.
|
||||||
|
2. ONE result → use it silently, no question to the user.
|
||||||
|
3. MULTIPLE results → present the options by their display names and let the
|
||||||
|
user choose. Never show raw UUIDs — always use friendly names.
|
||||||
|
|
||||||
|
Discovery tools by level:
|
||||||
|
- Which account/workspace? → get_connected_accounts("<service>")
|
||||||
|
- Which Jira site (cloudId)? → getAccessibleAtlassianResources
|
||||||
|
- Which Jira project? → getVisibleJiraProjects (after resolving cloudId)
|
||||||
|
- Which Jira issue type? → getJiraProjectIssueTypesMetadata (after resolving project)
|
||||||
|
- Which channel? → slack_search_channels
|
||||||
|
- Which base? → list_bases
|
||||||
|
- Which table? → list_tables_for_base (after resolving baseId)
|
||||||
|
- Which task? → clickup_search
|
||||||
|
- Which issue? → list_issues (Linear) or searchJiraIssuesUsingJql (Jira)
|
||||||
|
|
||||||
|
For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to
|
||||||
|
obtain the cloudId, then pass it to other Jira tools. When creating an issue,
|
||||||
|
chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue.
|
||||||
|
If there is only one option at each step, use it silently. If multiple, present
|
||||||
|
friendly names.
|
||||||
|
|
||||||
|
Chain discovery when needed — e.g. for Airtable records: list_bases → pick
|
||||||
|
base → list_tables_for_base → pick table → list_records_for_table.
|
||||||
|
|
||||||
|
MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for
|
||||||
|
the same service, tool names are prefixed to avoid collisions — e.g.
|
||||||
|
linear_25_list_issues and linear_30_list_issues instead of two list_issues.
|
||||||
|
Each prefixed tool's description starts with [Account: <display_name>] so you
|
||||||
|
know which account it targets. Use get_connected_accounts("<service>") to see
|
||||||
|
the full list of accounts with their connector IDs and display names.
|
||||||
|
When only one account is connected, tools have their normal unprefixed names.
|
||||||
|
</parameter_resolution>
|
||||||
|
|
||||||
<memory_protocol>
|
<memory_protocol>
|
||||||
IMPORTANT — After understanding each user message, ALWAYS check: does this message
|
IMPORTANT — After understanding each user message, ALWAYS check: does this message
|
||||||
reveal durable facts about the user (role, interests, preferences, projects,
|
reveal durable facts about the user (role, interests, preferences, projects,
|
||||||
|
|
@ -76,8 +134,66 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
|
||||||
* Formatting, summarization, or analysis of content already present in the conversation
|
* Formatting, summarization, or analysis of content already present in the conversation
|
||||||
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
|
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
|
||||||
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
||||||
|
* Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see <tool_routing> below
|
||||||
</knowledge_base_only_policy>
|
</knowledge_base_only_policy>
|
||||||
|
|
||||||
|
<tool_routing>
|
||||||
|
CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable.
|
||||||
|
Their data is NEVER in the knowledge base. You MUST call their tools immediately — never
|
||||||
|
say "I don't see it in the knowledge base" or ask if they want you to check.
|
||||||
|
Ignore any knowledge base results for these services.
|
||||||
|
|
||||||
|
When to use which tool:
|
||||||
|
- Linear (issues) → list_issues, get_issue, save_issue (create/update)
|
||||||
|
- ClickUp (tasks) → clickup_search, clickup_get_task
|
||||||
|
- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue
|
||||||
|
- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread
|
||||||
|
- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table
|
||||||
|
- Knowledge base content (Notion, GitHub, files, notes) → automatically searched
|
||||||
|
- Real-time public web data → call web_search
|
||||||
|
- Reading a specific webpage → call scrape_webpage
|
||||||
|
</tool_routing>
|
||||||
|
|
||||||
|
<parameter_resolution>
|
||||||
|
Some service tools require identifiers or context you do not have (account IDs,
|
||||||
|
workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw
|
||||||
|
IDs or technical identifiers — they cannot memorise them.
|
||||||
|
|
||||||
|
Instead, follow this discovery pattern:
|
||||||
|
1. Call a listing/discovery tool to find available options.
|
||||||
|
2. ONE result → use it silently, no question to the user.
|
||||||
|
3. MULTIPLE results → present the options by their display names and let the
|
||||||
|
user choose. Never show raw UUIDs — always use friendly names.
|
||||||
|
|
||||||
|
Discovery tools by level:
|
||||||
|
- Which account/workspace? → get_connected_accounts("<service>")
|
||||||
|
- Which Jira site (cloudId)? → getAccessibleAtlassianResources
|
||||||
|
- Which Jira project? → getVisibleJiraProjects (after resolving cloudId)
|
||||||
|
- Which Jira issue type? → getJiraProjectIssueTypesMetadata (after resolving project)
|
||||||
|
- Which channel? → slack_search_channels
|
||||||
|
- Which base? → list_bases
|
||||||
|
- Which table? → list_tables_for_base (after resolving baseId)
|
||||||
|
- Which task? → clickup_search
|
||||||
|
- Which issue? → list_issues (Linear) or searchJiraIssuesUsingJql (Jira)
|
||||||
|
|
||||||
|
For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to
|
||||||
|
obtain the cloudId, then pass it to other Jira tools. When creating an issue,
|
||||||
|
chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue.
|
||||||
|
If there is only one option at each step, use it silently. If multiple, present
|
||||||
|
friendly names.
|
||||||
|
|
||||||
|
Chain discovery when needed — e.g. for Airtable records: list_bases → pick
|
||||||
|
base → list_tables_for_base → pick table → list_records_for_table.
|
||||||
|
|
||||||
|
MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for
|
||||||
|
the same service, tool names are prefixed to avoid collisions — e.g.
|
||||||
|
linear_25_list_issues and linear_30_list_issues instead of two list_issues.
|
||||||
|
Each prefixed tool's description starts with [Account: <display_name>] so you
|
||||||
|
know which account it targets. Use get_connected_accounts("<service>") to see
|
||||||
|
the full list of accounts with their connector IDs and display names.
|
||||||
|
When only one account is connected, tools have their normal unprefixed names.
|
||||||
|
</parameter_resolution>
|
||||||
|
|
||||||
<memory_protocol>
|
<memory_protocol>
|
||||||
IMPORTANT — After understanding each user message, ALWAYS check: does this message
|
IMPORTANT — After understanding each user message, ALWAYS check: does this message
|
||||||
reveal durable facts about the team (decisions, conventions, architecture, processes,
|
reveal durable facts about the team (decisions, conventions, architecture, processes,
|
||||||
|
|
@ -450,6 +566,9 @@ _TOOL_INSTRUCTIONS["generate_resume"] = """
|
||||||
- WHEN NOT TO CALL: General career advice, resume tips, cover letters, or reviewing
|
- WHEN NOT TO CALL: General career advice, resume tips, cover letters, or reviewing
|
||||||
a resume without making changes. For cover letters, use generate_report instead.
|
a resume without making changes. For cover letters, use generate_report instead.
|
||||||
- The tool produces Typst source code that is compiled to a PDF preview automatically.
|
- The tool produces Typst source code that is compiled to a PDF preview automatically.
|
||||||
|
- PAGE POLICY:
|
||||||
|
- Default behavior is ONE PAGE. For new resume creation, set max_pages=1 unless the user explicitly asks for more.
|
||||||
|
- If the user requests a longer resume (e.g., "make it 2 pages"), set max_pages to that value.
|
||||||
- Args:
|
- Args:
|
||||||
- user_info: The user's resume content — work experience, education, skills, contact
|
- user_info: The user's resume content — work experience, education, skills, contact
|
||||||
info, etc. Can be structured or unstructured text.
|
info, etc. Can be structured or unstructured text.
|
||||||
|
|
@ -465,6 +584,7 @@ _TOOL_INSTRUCTIONS["generate_resume"] = """
|
||||||
"keep it to one page"). For revisions, describe what to change.
|
"keep it to one page"). For revisions, describe what to change.
|
||||||
- parent_report_id: Set this when the user wants to MODIFY an existing resume from
|
- parent_report_id: Set this when the user wants to MODIFY an existing resume from
|
||||||
this conversation. Use the report_id from a previous generate_resume result.
|
this conversation. Use the report_id from a previous generate_resume result.
|
||||||
|
- max_pages: Maximum resume length in pages (integer 1-5). Default is 1.
|
||||||
- Returns: Dict with status, report_id, title, and content_type.
|
- Returns: Dict with status, report_id, title, and content_type.
|
||||||
- After calling: Give a brief confirmation. Do NOT paste resume content in chat. Do NOT mention report_id or any internal IDs — the resume card is shown automatically.
|
- After calling: Give a brief confirmation. Do NOT paste resume content in chat. Do NOT mention report_id or any internal IDs — the resume card is shown automatically.
|
||||||
- VERSIONING: Same rules as generate_report — set parent_report_id for modifications
|
- VERSIONING: Same rules as generate_report — set parent_report_id for modifications
|
||||||
|
|
@ -473,17 +593,20 @@ _TOOL_INSTRUCTIONS["generate_resume"] = """
|
||||||
|
|
||||||
_TOOL_EXAMPLES["generate_resume"] = """
|
_TOOL_EXAMPLES["generate_resume"] = """
|
||||||
- User: "Build me a resume. I'm John Doe, engineer at Acme Corp..."
|
- User: "Build me a resume. I'm John Doe, engineer at Acme Corp..."
|
||||||
- Call: `generate_resume(user_info="John Doe, engineer at Acme Corp...")`
|
- Call: `generate_resume(user_info="John Doe, engineer at Acme Corp...", max_pages=1)`
|
||||||
- WHY: Has creation verb "build" + resume → call the tool.
|
- WHY: Has creation verb "build" + resume → call the tool.
|
||||||
- User: "Create my CV with this info: [experience, education, skills]"
|
- User: "Create my CV with this info: [experience, education, skills]"
|
||||||
- Call: `generate_resume(user_info="[experience, education, skills]")`
|
- Call: `generate_resume(user_info="[experience, education, skills]", max_pages=1)`
|
||||||
- User: "Build me a resume" (and there is a resume/CV document in the conversation context)
|
- User: "Build me a resume" (and there is a resume/CV document in the conversation context)
|
||||||
- Extract the FULL content from the document in context, then call:
|
- Extract the FULL content from the document in context, then call:
|
||||||
`generate_resume(user_info="Name: John Doe\\nEmail: john@example.com\\n\\nExperience:\\n- Senior Engineer at Acme Corp (2020-2024)\\n Led team of 5...\\n\\nEducation:\\n- BS Computer Science, MIT (2016-2020)\\n\\nSkills: Python, TypeScript, AWS...")`
|
`generate_resume(user_info="Name: John Doe\\nEmail: john@example.com\\n\\nExperience:\\n- Senior Engineer at Acme Corp (2020-2024)\\n Led team of 5...\\n\\nEducation:\\n- BS Computer Science, MIT (2016-2020)\\n\\nSkills: Python, TypeScript, AWS...", max_pages=1)`
|
||||||
- WHY: Document content is available in context — extract ALL of it into user_info. Do NOT ignore referenced documents.
|
- WHY: Document content is available in context — extract ALL of it into user_info. Do NOT ignore referenced documents.
|
||||||
- User: (after resume generated) "Change my title to Senior Engineer"
|
- User: (after resume generated) "Change my title to Senior Engineer"
|
||||||
- Call: `generate_resume(user_info="", user_instructions="Change the job title to Senior Engineer", parent_report_id=<previous_report_id>)`
|
- Call: `generate_resume(user_info="", user_instructions="Change the job title to Senior Engineer", parent_report_id=<previous_report_id>, max_pages=1)`
|
||||||
- WHY: Modification verb "change" + refers to existing resume → set parent_report_id.
|
- WHY: Modification verb "change" + refers to existing resume → set parent_report_id.
|
||||||
|
- User: (after resume generated) "Make this 2 pages and expand projects"
|
||||||
|
- Call: `generate_resume(user_info="", user_instructions="Expand projects and keep this to at most 2 pages", parent_report_id=<previous_report_id>, max_pages=2)`
|
||||||
|
- WHY: Explicit page increase request → set max_pages to 2.
|
||||||
- User: "How should I structure my resume?"
|
- User: "How should I structure my resume?"
|
||||||
- Do NOT call generate_resume. Answer in chat with advice.
|
- Do NOT call generate_resume. Answer in chat with advice.
|
||||||
- WHY: No creation/modification verb.
|
- WHY: No creation/modification verb.
|
||||||
|
|
@ -692,11 +815,36 @@ Your goal is to provide helpful, informative answers in a clean, readable format
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _build_mcp_routing_block(
|
||||||
|
mcp_connector_tools: dict[str, list[str]] | None,
|
||||||
|
) -> str:
|
||||||
|
"""Build an additional tool routing block for generic MCP connectors.
|
||||||
|
|
||||||
|
When users add MCP servers (e.g. GitLab, GitHub), the LLM needs to know
|
||||||
|
those tools exist and should be called directly — not searched in the
|
||||||
|
knowledge base.
|
||||||
|
"""
|
||||||
|
if not mcp_connector_tools:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
"\n<mcp_tool_routing>",
|
||||||
|
"You also have direct tools from these user-connected MCP servers.",
|
||||||
|
"Their data is NEVER in the knowledge base — call their tools directly.",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
for server_name, tool_names in mcp_connector_tools.items():
|
||||||
|
lines.append(f"- {server_name} → {', '.join(tool_names)}")
|
||||||
|
lines.append("</mcp_tool_routing>\n")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def build_surfsense_system_prompt(
|
def build_surfsense_system_prompt(
|
||||||
today: datetime | None = None,
|
today: datetime | None = None,
|
||||||
thread_visibility: ChatVisibility | None = None,
|
thread_visibility: ChatVisibility | None = None,
|
||||||
enabled_tool_names: set[str] | None = None,
|
enabled_tool_names: set[str] | None = None,
|
||||||
disabled_tool_names: set[str] | None = None,
|
disabled_tool_names: set[str] | None = None,
|
||||||
|
mcp_connector_tools: dict[str, list[str]] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Build the SurfSense system prompt with default settings.
|
Build the SurfSense system prompt with default settings.
|
||||||
|
|
@ -711,6 +859,9 @@ def build_surfsense_system_prompt(
|
||||||
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||||
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
||||||
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
||||||
|
mcp_connector_tools: Mapping of MCP server display name → list of tool names
|
||||||
|
for generic MCP connectors. Injected into the system prompt so the LLM
|
||||||
|
knows to call these tools directly.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Complete system prompt string
|
Complete system prompt string
|
||||||
|
|
@ -718,6 +869,7 @@ def build_surfsense_system_prompt(
|
||||||
|
|
||||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||||
system_instructions = _get_system_instructions(visibility, today)
|
system_instructions = _get_system_instructions(visibility, today)
|
||||||
|
system_instructions += _build_mcp_routing_block(mcp_connector_tools)
|
||||||
tools_instructions = _get_tools_instructions(
|
tools_instructions = _get_tools_instructions(
|
||||||
visibility, enabled_tool_names, disabled_tool_names
|
visibility, enabled_tool_names, disabled_tool_names
|
||||||
)
|
)
|
||||||
|
|
@ -733,6 +885,7 @@ def build_configurable_system_prompt(
|
||||||
thread_visibility: ChatVisibility | None = None,
|
thread_visibility: ChatVisibility | None = None,
|
||||||
enabled_tool_names: set[str] | None = None,
|
enabled_tool_names: set[str] | None = None,
|
||||||
disabled_tool_names: set[str] | None = None,
|
disabled_tool_names: set[str] | None = None,
|
||||||
|
mcp_connector_tools: dict[str, list[str]] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
|
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
|
||||||
|
|
@ -754,6 +907,9 @@ def build_configurable_system_prompt(
|
||||||
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||||
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
||||||
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
||||||
|
mcp_connector_tools: Mapping of MCP server display name → list of tool names
|
||||||
|
for generic MCP connectors. Injected into the system prompt so the LLM
|
||||||
|
knows to call these tools directly.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Complete system prompt string
|
Complete system prompt string
|
||||||
|
|
@ -771,6 +927,8 @@ def build_configurable_system_prompt(
|
||||||
else:
|
else:
|
||||||
system_instructions = ""
|
system_instructions = ""
|
||||||
|
|
||||||
|
system_instructions += _build_mcp_routing_block(mcp_connector_tools)
|
||||||
|
|
||||||
# Tools instructions: only include enabled tools, note disabled ones
|
# Tools instructions: only include enabled tools, note disabled ones
|
||||||
tools_instructions = _get_tools_instructions(
|
tools_instructions = _get_tools_instructions(
|
||||||
thread_visibility, enabled_tool_names, disabled_tool_names
|
thread_visibility, enabled_tool_names, disabled_tool_names
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,109 @@
|
||||||
|
"""Connected-accounts discovery tool.
|
||||||
|
|
||||||
|
Lets the LLM discover which accounts are connected for a given service
|
||||||
|
(e.g. "jira", "linear", "slack") and retrieve the metadata it needs to
|
||||||
|
call action tools — such as Jira's ``cloudId``.
|
||||||
|
|
||||||
|
The tool returns **only** non-sensitive fields explicitly listed in the
|
||||||
|
service's ``account_metadata_keys`` (see ``registry.py``), plus the
|
||||||
|
always-present ``display_name`` and ``connector_id``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import StructuredTool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||||
|
from app.services.mcp_oauth.registry import MCP_SERVICES
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_SERVICE_KEY_BY_CONNECTOR_TYPE: dict[str, str] = {
|
||||||
|
cfg.connector_type: key for key, cfg in MCP_SERVICES.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class GetConnectedAccountsInput(BaseModel):
|
||||||
|
service: str = Field(
|
||||||
|
description=(
|
||||||
|
"Service key to look up connected accounts for. "
|
||||||
|
"Valid values: " + ", ".join(sorted(MCP_SERVICES.keys()))
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_display_name(connector: SearchSourceConnector) -> str:
|
||||||
|
"""Best-effort human-readable label for a connector."""
|
||||||
|
cfg = connector.config or {}
|
||||||
|
if cfg.get("display_name"):
|
||||||
|
return cfg["display_name"]
|
||||||
|
if cfg.get("base_url"):
|
||||||
|
return f"{connector.name} ({cfg['base_url']})"
|
||||||
|
if cfg.get("organization_name"):
|
||||||
|
return f"{connector.name} ({cfg['organization_name']})"
|
||||||
|
return connector.name
|
||||||
|
|
||||||
|
|
||||||
|
def create_get_connected_accounts_tool(
|
||||||
|
db_session: AsyncSession,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str,
|
||||||
|
) -> StructuredTool:
|
||||||
|
|
||||||
|
async def _run(service: str) -> list[dict[str, Any]]:
|
||||||
|
svc_cfg = MCP_SERVICES.get(service)
|
||||||
|
if not svc_cfg:
|
||||||
|
return [{"error": f"Unknown service '{service}'. Valid: {', '.join(sorted(MCP_SERVICES.keys()))}"}]
|
||||||
|
|
||||||
|
try:
|
||||||
|
connector_type = SearchSourceConnectorType(svc_cfg.connector_type)
|
||||||
|
except ValueError:
|
||||||
|
return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}]
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
SearchSourceConnector.user_id == user_id,
|
||||||
|
SearchSourceConnector.connector_type == connector_type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connectors = result.scalars().all()
|
||||||
|
|
||||||
|
if not connectors:
|
||||||
|
return [{"error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."}]
|
||||||
|
|
||||||
|
is_multi = len(connectors) > 1
|
||||||
|
|
||||||
|
accounts: list[dict[str, Any]] = []
|
||||||
|
for conn in connectors:
|
||||||
|
cfg = conn.config or {}
|
||||||
|
entry: dict[str, Any] = {
|
||||||
|
"connector_id": conn.id,
|
||||||
|
"display_name": _extract_display_name(conn),
|
||||||
|
"service": service,
|
||||||
|
}
|
||||||
|
if is_multi:
|
||||||
|
entry["tool_prefix"] = f"{service}_{conn.id}"
|
||||||
|
for key in svc_cfg.account_metadata_keys:
|
||||||
|
if key in cfg:
|
||||||
|
entry[key] = cfg[key]
|
||||||
|
accounts.append(entry)
|
||||||
|
|
||||||
|
return accounts
|
||||||
|
|
||||||
|
return StructuredTool(
|
||||||
|
name="get_connected_accounts",
|
||||||
|
description=(
|
||||||
|
"Discover which accounts are connected for a service (e.g. jira, linear, slack, clickup, airtable). "
|
||||||
|
"Returns display names and service-specific metadata the action tools need "
|
||||||
|
"(e.g. Jira's cloudId). Call this BEFORE using a service's action tools when "
|
||||||
|
"you need an account identifier or are unsure which account to use."
|
||||||
|
),
|
||||||
|
coroutine=_run,
|
||||||
|
args_schema=GetConnectedAccountsInput,
|
||||||
|
metadata={"hitl": False},
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
from app.agents.new_chat.tools.discord.list_channels import (
|
||||||
|
create_list_discord_channels_tool,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.tools.discord.read_messages import (
|
||||||
|
create_read_discord_messages_tool,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.tools.discord.send_message import (
|
||||||
|
create_send_discord_message_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"create_list_discord_channels_tool",
|
||||||
|
"create_read_discord_messages_tool",
|
||||||
|
"create_send_discord_message_tool",
|
||||||
|
]
|
||||||
42
surfsense_backend/app/agents/new_chat/tools/discord/_auth.py
Normal file
42
surfsense_backend/app/agents/new_chat/tools/discord/_auth.py
Normal file
|
|
@ -0,0 +1,42 @@
|
||||||
|
"""Shared auth helper for Discord agent tools (REST API, not gateway bot)."""
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||||
|
from app.utils.oauth_security import TokenEncryption
|
||||||
|
|
||||||
|
DISCORD_API = "https://discord.com/api/v10"
|
||||||
|
|
||||||
|
|
||||||
|
async def get_discord_connector(
|
||||||
|
db_session: AsyncSession,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str,
|
||||||
|
) -> SearchSourceConnector | None:
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
SearchSourceConnector.user_id == user_id,
|
||||||
|
SearchSourceConnector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalars().first()
|
||||||
|
|
||||||
|
|
||||||
|
def get_bot_token(connector: SearchSourceConnector) -> str:
|
||||||
|
"""Extract and decrypt the bot token from connector config."""
|
||||||
|
cfg = dict(connector.config)
|
||||||
|
if cfg.get("_token_encrypted") and config.SECRET_KEY:
|
||||||
|
enc = TokenEncryption(config.SECRET_KEY)
|
||||||
|
if cfg.get("bot_token"):
|
||||||
|
cfg["bot_token"] = enc.decrypt_token(cfg["bot_token"])
|
||||||
|
token = cfg.get("bot_token")
|
||||||
|
if not token:
|
||||||
|
raise ValueError("Discord bot token not found in connector config.")
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
def get_guild_id(connector: SearchSourceConnector) -> str | None:
|
||||||
|
return connector.config.get("guild_id")
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_list_discord_channels_tool(
|
||||||
|
db_session: AsyncSession | None = None,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
):
|
||||||
|
@tool
|
||||||
|
async def list_discord_channels() -> dict[str, Any]:
|
||||||
|
"""List text channels in the connected Discord server.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status and a list of channels (id, name).
|
||||||
|
"""
|
||||||
|
if db_session is None or search_space_id is None or user_id is None:
|
||||||
|
return {"status": "error", "message": "Discord tool not properly configured."}
|
||||||
|
|
||||||
|
try:
|
||||||
|
connector = await get_discord_connector(db_session, search_space_id, user_id)
|
||||||
|
if not connector:
|
||||||
|
return {"status": "error", "message": "No Discord connector found."}
|
||||||
|
|
||||||
|
guild_id = get_guild_id(connector)
|
||||||
|
if not guild_id:
|
||||||
|
return {"status": "error", "message": "No guild ID in Discord connector config."}
|
||||||
|
|
||||||
|
token = get_bot_token(connector)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.get(
|
||||||
|
f"{DISCORD_API}/guilds/{guild_id}/channels",
|
||||||
|
headers={"Authorization": f"Bot {token}"},
|
||||||
|
timeout=15.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code == 401:
|
||||||
|
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
|
||||||
|
if resp.status_code != 200:
|
||||||
|
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
|
||||||
|
|
||||||
|
# Type 0 = text channel
|
||||||
|
channels = [
|
||||||
|
{"id": ch["id"], "name": ch["name"]}
|
||||||
|
for ch in resp.json()
|
||||||
|
if ch.get("type") == 0
|
||||||
|
]
|
||||||
|
return {"status": "success", "guild_id": guild_id, "channels": channels, "total": len(channels)}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
||||||
|
if isinstance(e, GraphInterrupt):
|
||||||
|
raise
|
||||||
|
logger.error("Error listing Discord channels: %s", e, exc_info=True)
|
||||||
|
return {"status": "error", "message": "Failed to list Discord channels."}
|
||||||
|
|
||||||
|
return list_discord_channels
|
||||||
|
|
@ -0,0 +1,80 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_read_discord_messages_tool(
|
||||||
|
db_session: AsyncSession | None = None,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
):
|
||||||
|
@tool
|
||||||
|
async def read_discord_messages(
|
||||||
|
channel_id: str,
|
||||||
|
limit: int = 25,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Read recent messages from a Discord text channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id: The Discord channel ID (from list_discord_channels).
|
||||||
|
limit: Number of messages to fetch (default 25, max 50).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status and a list of messages including
|
||||||
|
id, author, content, timestamp.
|
||||||
|
"""
|
||||||
|
if db_session is None or search_space_id is None or user_id is None:
|
||||||
|
return {"status": "error", "message": "Discord tool not properly configured."}
|
||||||
|
|
||||||
|
limit = min(limit, 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
connector = await get_discord_connector(db_session, search_space_id, user_id)
|
||||||
|
if not connector:
|
||||||
|
return {"status": "error", "message": "No Discord connector found."}
|
||||||
|
|
||||||
|
token = get_bot_token(connector)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.get(
|
||||||
|
f"{DISCORD_API}/channels/{channel_id}/messages",
|
||||||
|
headers={"Authorization": f"Bot {token}"},
|
||||||
|
params={"limit": limit},
|
||||||
|
timeout=15.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code == 401:
|
||||||
|
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
|
||||||
|
if resp.status_code == 403:
|
||||||
|
return {"status": "error", "message": "Bot lacks permission to read this channel."}
|
||||||
|
if resp.status_code != 200:
|
||||||
|
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"id": m["id"],
|
||||||
|
"author": m.get("author", {}).get("username", "Unknown"),
|
||||||
|
"content": m.get("content", ""),
|
||||||
|
"timestamp": m.get("timestamp", ""),
|
||||||
|
}
|
||||||
|
for m in resp.json()
|
||||||
|
]
|
||||||
|
|
||||||
|
return {"status": "success", "channel_id": channel_id, "messages": messages, "total": len(messages)}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
||||||
|
if isinstance(e, GraphInterrupt):
|
||||||
|
raise
|
||||||
|
logger.error("Error reading Discord messages: %s", e, exc_info=True)
|
||||||
|
return {"status": "error", "message": "Failed to read Discord messages."}
|
||||||
|
|
||||||
|
return read_discord_messages
|
||||||
|
|
@ -0,0 +1,96 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
|
|
||||||
|
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_send_discord_message_tool(
|
||||||
|
db_session: AsyncSession | None = None,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
):
|
||||||
|
@tool
|
||||||
|
async def send_discord_message(
|
||||||
|
channel_id: str,
|
||||||
|
content: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Send a message to a Discord text channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id: The Discord channel ID (from list_discord_channels).
|
||||||
|
content: The message text (max 2000 characters).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status, message_id on success.
|
||||||
|
|
||||||
|
IMPORTANT:
|
||||||
|
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||||
|
"""
|
||||||
|
if db_session is None or search_space_id is None or user_id is None:
|
||||||
|
return {"status": "error", "message": "Discord tool not properly configured."}
|
||||||
|
|
||||||
|
if len(content) > 2000:
|
||||||
|
return {"status": "error", "message": "Message exceeds Discord's 2000-character limit."}
|
||||||
|
|
||||||
|
try:
|
||||||
|
connector = await get_discord_connector(db_session, search_space_id, user_id)
|
||||||
|
if not connector:
|
||||||
|
return {"status": "error", "message": "No Discord connector found."}
|
||||||
|
|
||||||
|
result = request_approval(
|
||||||
|
action_type="discord_send_message",
|
||||||
|
tool_name="send_discord_message",
|
||||||
|
params={"channel_id": channel_id, "content": content},
|
||||||
|
context={"connector_id": connector.id},
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.rejected:
|
||||||
|
return {"status": "rejected", "message": "User declined. Message was not sent."}
|
||||||
|
|
||||||
|
final_content = result.params.get("content", content)
|
||||||
|
final_channel = result.params.get("channel_id", channel_id)
|
||||||
|
|
||||||
|
token = get_bot_token(connector)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{DISCORD_API}/channels/{final_channel}/messages",
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bot {token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
json={"content": final_content},
|
||||||
|
timeout=15.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code == 401:
|
||||||
|
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
|
||||||
|
if resp.status_code == 403:
|
||||||
|
return {"status": "error", "message": "Bot lacks permission to send messages in this channel."}
|
||||||
|
if resp.status_code not in (200, 201):
|
||||||
|
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
|
||||||
|
|
||||||
|
msg_data = resp.json()
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"message_id": msg_data.get("id"),
|
||||||
|
"message": f"Message sent to channel {final_channel}.",
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
||||||
|
if isinstance(e, GraphInterrupt):
|
||||||
|
raise
|
||||||
|
logger.error("Error sending Discord message: %s", e, exc_info=True)
|
||||||
|
return {"status": "error", "message": "Failed to send Discord message."}
|
||||||
|
|
||||||
|
return send_discord_message
|
||||||
|
|
@ -1,6 +1,12 @@
|
||||||
from app.agents.new_chat.tools.gmail.create_draft import (
|
from app.agents.new_chat.tools.gmail.create_draft import (
|
||||||
create_create_gmail_draft_tool,
|
create_create_gmail_draft_tool,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.tools.gmail.read_email import (
|
||||||
|
create_read_gmail_email_tool,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.tools.gmail.search_emails import (
|
||||||
|
create_search_gmail_tool,
|
||||||
|
)
|
||||||
from app.agents.new_chat.tools.gmail.send_email import (
|
from app.agents.new_chat.tools.gmail.send_email import (
|
||||||
create_send_gmail_email_tool,
|
create_send_gmail_email_tool,
|
||||||
)
|
)
|
||||||
|
|
@ -13,6 +19,8 @@ from app.agents.new_chat.tools.gmail.update_draft import (
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"create_create_gmail_draft_tool",
|
"create_create_gmail_draft_tool",
|
||||||
|
"create_read_gmail_email_tool",
|
||||||
|
"create_search_gmail_tool",
|
||||||
"create_send_gmail_email_tool",
|
"create_send_gmail_email_tool",
|
||||||
"create_trash_gmail_email_tool",
|
"create_trash_gmail_email_tool",
|
||||||
"create_update_gmail_draft_tool",
|
"create_update_gmail_draft_tool",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,87 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_GMAIL_TYPES = [
|
||||||
|
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||||
|
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_read_gmail_email_tool(
|
||||||
|
db_session: AsyncSession | None = None,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
):
|
||||||
|
@tool
|
||||||
|
async def read_gmail_email(message_id: str) -> dict[str, Any]:
|
||||||
|
"""Read the full content of a specific Gmail email by its message ID.
|
||||||
|
|
||||||
|
Use after search_gmail to get the complete body of an email.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message_id: The Gmail message ID (from search_gmail results).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status and the full email content formatted as markdown.
|
||||||
|
"""
|
||||||
|
if db_session is None or search_space_id is None or user_id is None:
|
||||||
|
return {"status": "error", "message": "Gmail tool not properly configured."}
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
SearchSourceConnector.user_id == user_id,
|
||||||
|
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connector = result.scalars().first()
|
||||||
|
if not connector:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||||
|
}
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
|
||||||
|
|
||||||
|
creds = _build_credentials(connector)
|
||||||
|
|
||||||
|
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||||
|
|
||||||
|
gmail = GoogleGmailConnector(
|
||||||
|
credentials=creds,
|
||||||
|
session=db_session,
|
||||||
|
user_id=user_id,
|
||||||
|
connector_id=connector.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
detail, error = await gmail.get_message_details(message_id)
|
||||||
|
if error:
|
||||||
|
if "re-authenticate" in error.lower() or "authentication failed" in error.lower():
|
||||||
|
return {"status": "auth_error", "message": error, "connector_type": "gmail"}
|
||||||
|
return {"status": "error", "message": error}
|
||||||
|
|
||||||
|
if not detail:
|
||||||
|
return {"status": "not_found", "message": f"Email with ID '{message_id}' not found."}
|
||||||
|
|
||||||
|
content = gmail.format_message_to_markdown(detail)
|
||||||
|
|
||||||
|
return {"status": "success", "message_id": message_id, "content": content}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
||||||
|
if isinstance(e, GraphInterrupt):
|
||||||
|
raise
|
||||||
|
logger.error("Error reading Gmail email: %s", e, exc_info=True)
|
||||||
|
return {"status": "error", "message": "Failed to read email. Please try again."}
|
||||||
|
|
||||||
|
return read_gmail_email
|
||||||
|
|
@ -0,0 +1,165 @@
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_GMAIL_TYPES = [
|
||||||
|
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||||
|
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||||
|
]
|
||||||
|
|
||||||
|
_token_encryption_cache: object | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_token_encryption():
|
||||||
|
global _token_encryption_cache
|
||||||
|
if _token_encryption_cache is None:
|
||||||
|
from app.config import config
|
||||||
|
from app.utils.oauth_security import TokenEncryption
|
||||||
|
|
||||||
|
if not config.SECRET_KEY:
|
||||||
|
raise RuntimeError("SECRET_KEY not configured for token decryption.")
|
||||||
|
_token_encryption_cache = TokenEncryption(config.SECRET_KEY)
|
||||||
|
return _token_encryption_cache
|
||||||
|
|
||||||
|
|
||||||
|
def _build_credentials(connector: SearchSourceConnector):
|
||||||
|
"""Build Google OAuth Credentials from a connector's stored config.
|
||||||
|
|
||||||
|
Handles both native OAuth connectors (with encrypted tokens) and
|
||||||
|
Composio-backed connectors. Shared by Gmail and Calendar tools.
|
||||||
|
"""
|
||||||
|
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||||
|
|
||||||
|
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||||
|
from app.utils.google_credentials import build_composio_credentials
|
||||||
|
|
||||||
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
|
if not cca_id:
|
||||||
|
raise ValueError("Composio connected account ID not found.")
|
||||||
|
return build_composio_credentials(cca_id)
|
||||||
|
|
||||||
|
from google.oauth2.credentials import Credentials
|
||||||
|
|
||||||
|
cfg = dict(connector.config)
|
||||||
|
if cfg.get("_token_encrypted"):
|
||||||
|
enc = _get_token_encryption()
|
||||||
|
for key in ("token", "refresh_token", "client_secret"):
|
||||||
|
if cfg.get(key):
|
||||||
|
cfg[key] = enc.decrypt_token(cfg[key])
|
||||||
|
|
||||||
|
exp = (cfg.get("expiry") or "").replace("Z", "")
|
||||||
|
return Credentials(
|
||||||
|
token=cfg.get("token"),
|
||||||
|
refresh_token=cfg.get("refresh_token"),
|
||||||
|
token_uri=cfg.get("token_uri"),
|
||||||
|
client_id=cfg.get("client_id"),
|
||||||
|
client_secret=cfg.get("client_secret"),
|
||||||
|
scopes=cfg.get("scopes", []),
|
||||||
|
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_search_gmail_tool(
|
||||||
|
db_session: AsyncSession | None = None,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
):
|
||||||
|
@tool
|
||||||
|
async def search_gmail(
|
||||||
|
query: str,
|
||||||
|
max_results: int = 10,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Search emails in the user's Gmail inbox using Gmail search syntax.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Gmail search query, same syntax as the Gmail search bar.
|
||||||
|
Examples: "from:alice@example.com", "subject:meeting",
|
||||||
|
"is:unread", "after:2024/01/01 before:2024/02/01",
|
||||||
|
"has:attachment", "in:sent".
|
||||||
|
max_results: Number of emails to return (default 10, max 20).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status and a list of email summaries including
|
||||||
|
message_id, subject, from, date, snippet.
|
||||||
|
"""
|
||||||
|
if db_session is None or search_space_id is None or user_id is None:
|
||||||
|
return {"status": "error", "message": "Gmail tool not properly configured."}
|
||||||
|
|
||||||
|
max_results = min(max_results, 20)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
SearchSourceConnector.user_id == user_id,
|
||||||
|
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connector = result.scalars().first()
|
||||||
|
if not connector:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||||
|
}
|
||||||
|
|
||||||
|
creds = _build_credentials(connector)
|
||||||
|
|
||||||
|
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||||
|
|
||||||
|
gmail = GoogleGmailConnector(
|
||||||
|
credentials=creds,
|
||||||
|
session=db_session,
|
||||||
|
user_id=user_id,
|
||||||
|
connector_id=connector.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
messages_list, error = await gmail.get_messages_list(
|
||||||
|
max_results=max_results, query=query
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
if "re-authenticate" in error.lower() or "authentication failed" in error.lower():
|
||||||
|
return {"status": "auth_error", "message": error, "connector_type": "gmail"}
|
||||||
|
return {"status": "error", "message": error}
|
||||||
|
|
||||||
|
if not messages_list:
|
||||||
|
return {"status": "success", "emails": [], "total": 0, "message": "No emails found."}
|
||||||
|
|
||||||
|
emails = []
|
||||||
|
for msg in messages_list:
|
||||||
|
detail, err = await gmail.get_message_details(msg["id"])
|
||||||
|
if err:
|
||||||
|
continue
|
||||||
|
headers = {
|
||||||
|
h["name"].lower(): h["value"]
|
||||||
|
for h in detail.get("payload", {}).get("headers", [])
|
||||||
|
}
|
||||||
|
emails.append({
|
||||||
|
"message_id": detail.get("id"),
|
||||||
|
"thread_id": detail.get("threadId"),
|
||||||
|
"subject": headers.get("subject", "No Subject"),
|
||||||
|
"from": headers.get("from", "Unknown"),
|
||||||
|
"to": headers.get("to", ""),
|
||||||
|
"date": headers.get("date", ""),
|
||||||
|
"snippet": detail.get("snippet", ""),
|
||||||
|
"labels": detail.get("labelIds", []),
|
||||||
|
})
|
||||||
|
|
||||||
|
return {"status": "success", "emails": emails, "total": len(emails)}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
||||||
|
if isinstance(e, GraphInterrupt):
|
||||||
|
raise
|
||||||
|
logger.error("Error searching Gmail: %s", e, exc_info=True)
|
||||||
|
return {"status": "error", "message": "Failed to search Gmail. Please try again."}
|
||||||
|
|
||||||
|
return search_gmail
|
||||||
|
|
@ -4,6 +4,9 @@ from app.agents.new_chat.tools.google_calendar.create_event import (
|
||||||
from app.agents.new_chat.tools.google_calendar.delete_event import (
|
from app.agents.new_chat.tools.google_calendar.delete_event import (
|
||||||
create_delete_calendar_event_tool,
|
create_delete_calendar_event_tool,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.tools.google_calendar.search_events import (
|
||||||
|
create_search_calendar_events_tool,
|
||||||
|
)
|
||||||
from app.agents.new_chat.tools.google_calendar.update_event import (
|
from app.agents.new_chat.tools.google_calendar.update_event import (
|
||||||
create_update_calendar_event_tool,
|
create_update_calendar_event_tool,
|
||||||
)
|
)
|
||||||
|
|
@ -11,5 +14,6 @@ from app.agents.new_chat.tools.google_calendar.update_event import (
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"create_create_calendar_event_tool",
|
"create_create_calendar_event_tool",
|
||||||
"create_delete_calendar_event_tool",
|
"create_delete_calendar_event_tool",
|
||||||
|
"create_search_calendar_events_tool",
|
||||||
"create_update_calendar_event_tool",
|
"create_update_calendar_event_tool",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,114 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
|
||||||
|
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_CALENDAR_TYPES = [
|
||||||
|
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||||
|
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_search_calendar_events_tool(
|
||||||
|
db_session: AsyncSession | None = None,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
):
|
||||||
|
@tool
|
||||||
|
async def search_calendar_events(
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
max_results: int = 25,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Search Google Calendar events within a date range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_date: Start date in YYYY-MM-DD format (e.g. "2026-04-01").
|
||||||
|
end_date: End date in YYYY-MM-DD format (e.g. "2026-04-30").
|
||||||
|
max_results: Maximum number of events to return (default 25, max 50).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status and a list of events including
|
||||||
|
event_id, summary, start, end, location, attendees.
|
||||||
|
"""
|
||||||
|
if db_session is None or search_space_id is None or user_id is None:
|
||||||
|
return {"status": "error", "message": "Calendar tool not properly configured."}
|
||||||
|
|
||||||
|
max_results = min(max_results, 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
SearchSourceConnector.user_id == user_id,
|
||||||
|
SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connector = result.scalars().first()
|
||||||
|
if not connector:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
|
||||||
|
}
|
||||||
|
|
||||||
|
creds = _build_credentials(connector)
|
||||||
|
|
||||||
|
from app.connectors.google_calendar_connector import GoogleCalendarConnector
|
||||||
|
|
||||||
|
cal = GoogleCalendarConnector(
|
||||||
|
credentials=creds,
|
||||||
|
session=db_session,
|
||||||
|
user_id=user_id,
|
||||||
|
connector_id=connector.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
events_raw, error = await cal.get_all_primary_calendar_events(
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
max_results=max_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
if error:
|
||||||
|
if "re-authenticate" in error.lower() or "authentication failed" in error.lower():
|
||||||
|
return {"status": "auth_error", "message": error, "connector_type": "google_calendar"}
|
||||||
|
if "no events found" in error.lower():
|
||||||
|
return {"status": "success", "events": [], "total": 0, "message": error}
|
||||||
|
return {"status": "error", "message": error}
|
||||||
|
|
||||||
|
events = []
|
||||||
|
for ev in events_raw:
|
||||||
|
start = ev.get("start", {})
|
||||||
|
end = ev.get("end", {})
|
||||||
|
attendees_raw = ev.get("attendees", [])
|
||||||
|
events.append({
|
||||||
|
"event_id": ev.get("id"),
|
||||||
|
"summary": ev.get("summary", "No Title"),
|
||||||
|
"start": start.get("dateTime") or start.get("date", ""),
|
||||||
|
"end": end.get("dateTime") or end.get("date", ""),
|
||||||
|
"location": ev.get("location", ""),
|
||||||
|
"description": ev.get("description", ""),
|
||||||
|
"html_link": ev.get("htmlLink", ""),
|
||||||
|
"attendees": [
|
||||||
|
a.get("email", "") for a in attendees_raw[:10]
|
||||||
|
],
|
||||||
|
"status": ev.get("status", ""),
|
||||||
|
})
|
||||||
|
|
||||||
|
return {"status": "success", "events": events, "total": len(events)}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
||||||
|
if isinstance(e, GraphInterrupt):
|
||||||
|
raise
|
||||||
|
logger.error("Error searching calendar events: %s", e, exc_info=True)
|
||||||
|
return {"status": "error", "message": "Failed to search calendar events. Please try again."}
|
||||||
|
|
||||||
|
return search_calendar_events
|
||||||
|
|
@ -130,8 +130,8 @@ def request_approval(
|
||||||
try:
|
try:
|
||||||
decision_type, edited_params = _parse_decision(approval)
|
decision_type, edited_params = _parse_decision(approval)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.warning("No approval decision received for %s", tool_name)
|
logger.warning("No approval decision received for %s — rejecting for safety", tool_name)
|
||||||
return HITLResult(rejected=False, decision_type="error", params=params)
|
return HITLResult(rejected=True, decision_type="error", params=params)
|
||||||
|
|
||||||
logger.info("User decision for %s: %s", tool_name, decision_type)
|
logger.info("User decision for %s: %s", tool_name, decision_type)
|
||||||
|
|
||||||
|
|
|
||||||
15
surfsense_backend/app/agents/new_chat/tools/luma/__init__.py
Normal file
15
surfsense_backend/app/agents/new_chat/tools/luma/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
from app.agents.new_chat.tools.luma.create_event import (
|
||||||
|
create_create_luma_event_tool,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.tools.luma.list_events import (
|
||||||
|
create_list_luma_events_tool,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.tools.luma.read_event import (
|
||||||
|
create_read_luma_event_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"create_create_luma_event_tool",
|
||||||
|
"create_list_luma_events_tool",
|
||||||
|
"create_read_luma_event_tool",
|
||||||
|
]
|
||||||
38
surfsense_backend/app/agents/new_chat/tools/luma/_auth.py
Normal file
38
surfsense_backend/app/agents/new_chat/tools/luma/_auth.py
Normal file
|
|
@ -0,0 +1,38 @@
|
||||||
|
"""Shared auth helper for Luma agent tools."""
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||||
|
|
||||||
|
LUMA_API = "https://public-api.luma.com/v1"
|
||||||
|
|
||||||
|
|
||||||
|
async def get_luma_connector(
|
||||||
|
db_session: AsyncSession,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str,
|
||||||
|
) -> SearchSourceConnector | None:
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
SearchSourceConnector.user_id == user_id,
|
||||||
|
SearchSourceConnector.connector_type == SearchSourceConnectorType.LUMA_CONNECTOR,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalars().first()
|
||||||
|
|
||||||
|
|
||||||
|
def get_api_key(connector: SearchSourceConnector) -> str:
|
||||||
|
"""Extract the API key from connector config (handles both key names)."""
|
||||||
|
key = connector.config.get("api_key") or connector.config.get("LUMA_API_KEY")
|
||||||
|
if not key:
|
||||||
|
raise ValueError("Luma API key not found in connector config.")
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
def luma_headers(api_key: str) -> dict[str, str]:
|
||||||
|
return {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"x-luma-api-key": api_key,
|
||||||
|
}
|
||||||
116
surfsense_backend/app/agents/new_chat/tools/luma/create_event.py
Normal file
116
surfsense_backend/app/agents/new_chat/tools/luma/create_event.py
Normal file
|
|
@ -0,0 +1,116 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
|
|
||||||
|
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_create_luma_event_tool(
|
||||||
|
db_session: AsyncSession | None = None,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
):
|
||||||
|
@tool
|
||||||
|
async def create_luma_event(
|
||||||
|
name: str,
|
||||||
|
start_at: str,
|
||||||
|
end_at: str,
|
||||||
|
description: str | None = None,
|
||||||
|
timezone: str = "UTC",
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Create a new event on Luma.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The event title.
|
||||||
|
start_at: Start time in ISO 8601 format (e.g. "2026-05-01T18:00:00").
|
||||||
|
end_at: End time in ISO 8601 format (e.g. "2026-05-01T20:00:00").
|
||||||
|
description: Optional event description (markdown supported).
|
||||||
|
timezone: Timezone string (default "UTC", e.g. "America/New_York").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status, event_id on success.
|
||||||
|
|
||||||
|
IMPORTANT:
|
||||||
|
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||||
|
"""
|
||||||
|
if db_session is None or search_space_id is None or user_id is None:
|
||||||
|
return {"status": "error", "message": "Luma tool not properly configured."}
|
||||||
|
|
||||||
|
try:
|
||||||
|
connector = await get_luma_connector(db_session, search_space_id, user_id)
|
||||||
|
if not connector:
|
||||||
|
return {"status": "error", "message": "No Luma connector found."}
|
||||||
|
|
||||||
|
result = request_approval(
|
||||||
|
action_type="luma_create_event",
|
||||||
|
tool_name="create_luma_event",
|
||||||
|
params={
|
||||||
|
"name": name,
|
||||||
|
"start_at": start_at,
|
||||||
|
"end_at": end_at,
|
||||||
|
"description": description,
|
||||||
|
"timezone": timezone,
|
||||||
|
},
|
||||||
|
context={"connector_id": connector.id},
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.rejected:
|
||||||
|
return {"status": "rejected", "message": "User declined. Event was not created."}
|
||||||
|
|
||||||
|
final_name = result.params.get("name", name)
|
||||||
|
final_start = result.params.get("start_at", start_at)
|
||||||
|
final_end = result.params.get("end_at", end_at)
|
||||||
|
final_desc = result.params.get("description", description)
|
||||||
|
final_tz = result.params.get("timezone", timezone)
|
||||||
|
|
||||||
|
api_key = get_api_key(connector)
|
||||||
|
headers = luma_headers(api_key)
|
||||||
|
|
||||||
|
body: dict[str, Any] = {
|
||||||
|
"name": final_name,
|
||||||
|
"start_at": final_start,
|
||||||
|
"end_at": final_end,
|
||||||
|
"timezone": final_tz,
|
||||||
|
}
|
||||||
|
if final_desc:
|
||||||
|
body["description_md"] = final_desc
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{LUMA_API}/event/create",
|
||||||
|
headers=headers,
|
||||||
|
json=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code == 401:
|
||||||
|
return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"}
|
||||||
|
if resp.status_code == 403:
|
||||||
|
return {"status": "error", "message": "Luma Plus subscription required to create events via API."}
|
||||||
|
if resp.status_code not in (200, 201):
|
||||||
|
return {"status": "error", "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}"}
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
event_id = data.get("api_id") or data.get("event", {}).get("api_id")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"event_id": event_id,
|
||||||
|
"message": f"Event '{final_name}' created on Luma.",
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
||||||
|
if isinstance(e, GraphInterrupt):
|
||||||
|
raise
|
||||||
|
logger.error("Error creating Luma event: %s", e, exc_info=True)
|
||||||
|
return {"status": "error", "message": "Failed to create Luma event."}
|
||||||
|
|
||||||
|
return create_luma_event
|
||||||
100
surfsense_backend/app/agents/new_chat/tools/luma/list_events.py
Normal file
100
surfsense_backend/app/agents/new_chat/tools/luma/list_events.py
Normal file
|
|
@ -0,0 +1,100 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_list_luma_events_tool(
|
||||||
|
db_session: AsyncSession | None = None,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
):
|
||||||
|
@tool
|
||||||
|
async def list_luma_events(
|
||||||
|
max_results: int = 25,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""List upcoming and recent Luma events.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_results: Maximum events to return (default 25, max 50).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status and a list of events including
|
||||||
|
event_id, name, start_at, end_at, location, url.
|
||||||
|
"""
|
||||||
|
if db_session is None or search_space_id is None or user_id is None:
|
||||||
|
return {"status": "error", "message": "Luma tool not properly configured."}
|
||||||
|
|
||||||
|
max_results = min(max_results, 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
connector = await get_luma_connector(db_session, search_space_id, user_id)
|
||||||
|
if not connector:
|
||||||
|
return {"status": "error", "message": "No Luma connector found."}
|
||||||
|
|
||||||
|
api_key = get_api_key(connector)
|
||||||
|
headers = luma_headers(api_key)
|
||||||
|
|
||||||
|
all_entries: list[dict] = []
|
||||||
|
cursor = None
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||||
|
while len(all_entries) < max_results:
|
||||||
|
params: dict[str, Any] = {"limit": min(100, max_results - len(all_entries))}
|
||||||
|
if cursor:
|
||||||
|
params["cursor"] = cursor
|
||||||
|
|
||||||
|
resp = await client.get(
|
||||||
|
f"{LUMA_API}/calendar/list-events",
|
||||||
|
headers=headers,
|
||||||
|
params=params,
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code == 401:
|
||||||
|
return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"}
|
||||||
|
if resp.status_code != 200:
|
||||||
|
return {"status": "error", "message": f"Luma API error: {resp.status_code}"}
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
entries = data.get("entries", [])
|
||||||
|
if not entries:
|
||||||
|
break
|
||||||
|
all_entries.extend(entries)
|
||||||
|
|
||||||
|
next_cursor = data.get("next_cursor")
|
||||||
|
if not next_cursor:
|
||||||
|
break
|
||||||
|
cursor = next_cursor
|
||||||
|
|
||||||
|
events = []
|
||||||
|
for entry in all_entries[:max_results]:
|
||||||
|
ev = entry.get("event", {})
|
||||||
|
geo = ev.get("geo_info", {})
|
||||||
|
events.append({
|
||||||
|
"event_id": entry.get("api_id"),
|
||||||
|
"name": ev.get("name", "Untitled"),
|
||||||
|
"start_at": ev.get("start_at", ""),
|
||||||
|
"end_at": ev.get("end_at", ""),
|
||||||
|
"timezone": ev.get("timezone", ""),
|
||||||
|
"location": geo.get("name", ""),
|
||||||
|
"url": ev.get("url", ""),
|
||||||
|
"visibility": ev.get("visibility", ""),
|
||||||
|
})
|
||||||
|
|
||||||
|
return {"status": "success", "events": events, "total": len(events)}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
||||||
|
if isinstance(e, GraphInterrupt):
|
||||||
|
raise
|
||||||
|
logger.error("Error listing Luma events: %s", e, exc_info=True)
|
||||||
|
return {"status": "error", "message": "Failed to list Luma events."}
|
||||||
|
|
||||||
|
return list_luma_events
|
||||||
|
|
@ -0,0 +1,82 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_read_luma_event_tool(
|
||||||
|
db_session: AsyncSession | None = None,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
):
|
||||||
|
@tool
|
||||||
|
async def read_luma_event(event_id: str) -> dict[str, Any]:
|
||||||
|
"""Read detailed information about a specific Luma event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_id: The Luma event API ID (from list_luma_events).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status and full event details including
|
||||||
|
description, attendees count, meeting URL.
|
||||||
|
"""
|
||||||
|
if db_session is None or search_space_id is None or user_id is None:
|
||||||
|
return {"status": "error", "message": "Luma tool not properly configured."}
|
||||||
|
|
||||||
|
try:
|
||||||
|
connector = await get_luma_connector(db_session, search_space_id, user_id)
|
||||||
|
if not connector:
|
||||||
|
return {"status": "error", "message": "No Luma connector found."}
|
||||||
|
|
||||||
|
api_key = get_api_key(connector)
|
||||||
|
headers = luma_headers(api_key)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||||
|
resp = await client.get(
|
||||||
|
f"{LUMA_API}/events/{event_id}",
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code == 401:
|
||||||
|
return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"}
|
||||||
|
if resp.status_code == 404:
|
||||||
|
return {"status": "not_found", "message": f"Event '{event_id}' not found."}
|
||||||
|
if resp.status_code != 200:
|
||||||
|
return {"status": "error", "message": f"Luma API error: {resp.status_code}"}
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
ev = data.get("event", data)
|
||||||
|
geo = ev.get("geo_info", {})
|
||||||
|
|
||||||
|
event_detail = {
|
||||||
|
"event_id": event_id,
|
||||||
|
"name": ev.get("name", ""),
|
||||||
|
"description": ev.get("description", ""),
|
||||||
|
"start_at": ev.get("start_at", ""),
|
||||||
|
"end_at": ev.get("end_at", ""),
|
||||||
|
"timezone": ev.get("timezone", ""),
|
||||||
|
"location_name": geo.get("name", ""),
|
||||||
|
"address": geo.get("address", ""),
|
||||||
|
"url": ev.get("url", ""),
|
||||||
|
"meeting_url": ev.get("meeting_url", ""),
|
||||||
|
"visibility": ev.get("visibility", ""),
|
||||||
|
"cover_url": ev.get("cover_url", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
return {"status": "success", "event": event_detail}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
||||||
|
if isinstance(e, GraphInterrupt):
|
||||||
|
raise
|
||||||
|
logger.error("Error reading Luma event: %s", e, exc_info=True)
|
||||||
|
return {"status": "error", "message": "Failed to read Luma event."}
|
||||||
|
|
||||||
|
return read_luma_event
|
||||||
|
|
@ -45,6 +45,18 @@ class MCPClient:
|
||||||
async def connect(self, max_retries: int = MAX_RETRIES):
|
async def connect(self, max_retries: int = MAX_RETRIES):
|
||||||
"""Connect to the MCP server and manage its lifecycle.
|
"""Connect to the MCP server and manage its lifecycle.
|
||||||
|
|
||||||
|
Retries only apply to the **connection** phase (spawning the process,
|
||||||
|
initialising the session). Once the session is yielded to the caller,
|
||||||
|
any exception raised by the caller propagates normally -- the context
|
||||||
|
manager will NOT retry after ``yield``.
|
||||||
|
|
||||||
|
Previous implementation wrapped both connection AND yield inside the
|
||||||
|
retry loop. Because ``@asynccontextmanager`` only allows a single
|
||||||
|
``yield``, a failure after yield caused the generator to attempt a
|
||||||
|
second yield on retry, triggering
|
||||||
|
``RuntimeError("generator didn't stop after athrow()")`` and orphaning
|
||||||
|
the stdio subprocess.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
max_retries: Maximum number of connection retry attempts
|
max_retries: Maximum number of connection retry attempts
|
||||||
|
|
||||||
|
|
@ -57,26 +69,22 @@ class MCPClient:
|
||||||
"""
|
"""
|
||||||
last_error = None
|
last_error = None
|
||||||
delay = RETRY_DELAY
|
delay = RETRY_DELAY
|
||||||
|
connected = False
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
for attempt in range(max_retries):
|
||||||
try:
|
try:
|
||||||
# Merge env vars with current environment
|
|
||||||
server_env = os.environ.copy()
|
server_env = os.environ.copy()
|
||||||
server_env.update(self.env)
|
server_env.update(self.env)
|
||||||
|
|
||||||
# Create server parameters with env
|
|
||||||
server_params = StdioServerParameters(
|
server_params = StdioServerParameters(
|
||||||
command=self.command, args=self.args, env=server_env
|
command=self.command, args=self.args, env=server_env
|
||||||
)
|
)
|
||||||
|
|
||||||
# Spawn server process and create session
|
|
||||||
# Note: Cannot combine these context managers because ClientSession
|
|
||||||
# needs the read/write streams from stdio_client
|
|
||||||
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
|
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
|
||||||
async with ClientSession(read, write) as session:
|
async with ClientSession(read, write) as session:
|
||||||
# Initialize the connection
|
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
self.session = session
|
self.session = session
|
||||||
|
connected = True
|
||||||
|
|
||||||
if attempt > 0:
|
if attempt > 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -91,10 +99,16 @@ class MCPClient:
|
||||||
self.command,
|
self.command,
|
||||||
" ".join(self.args),
|
" ".join(self.args),
|
||||||
)
|
)
|
||||||
yield session
|
try:
|
||||||
return # Success, exit retry loop
|
yield session
|
||||||
|
finally:
|
||||||
|
self.session = None
|
||||||
|
return
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
self.session = None
|
||||||
|
if connected:
|
||||||
|
raise
|
||||||
last_error = e
|
last_error = e
|
||||||
if attempt < max_retries - 1:
|
if attempt < max_retries - 1:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
@ -105,7 +119,7 @@ class MCPClient:
|
||||||
delay,
|
delay,
|
||||||
)
|
)
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
delay *= RETRY_BACKOFF # Exponential backoff
|
delay *= RETRY_BACKOFF
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Failed to connect to MCP server after %d attempts: %s",
|
"Failed to connect to MCP server after %d attempts: %s",
|
||||||
|
|
@ -113,10 +127,7 @@ class MCPClient:
|
||||||
e,
|
e,
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
self.session = None
|
|
||||||
|
|
||||||
# All retries exhausted
|
|
||||||
error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
|
error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
|
||||||
if last_error:
|
if last_error:
|
||||||
error_msg += f": {last_error}"
|
error_msg += f": {last_error}"
|
||||||
|
|
@ -161,12 +172,18 @@ class MCPClient:
|
||||||
logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
|
logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
async def call_tool(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
arguments: dict[str, Any],
|
||||||
|
timeout: float = 60.0,
|
||||||
|
) -> Any:
|
||||||
"""Call a tool on the MCP server.
|
"""Call a tool on the MCP server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_name: Name of the tool to call
|
tool_name: Name of the tool to call
|
||||||
arguments: Arguments to pass to the tool
|
arguments: Arguments to pass to the tool
|
||||||
|
timeout: Maximum seconds to wait for the tool to respond
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tool execution result
|
Tool execution result
|
||||||
|
|
@ -185,10 +202,11 @@ class MCPClient:
|
||||||
"Calling MCP tool '%s' with arguments: %s", tool_name, arguments
|
"Calling MCP tool '%s' with arguments: %s", tool_name, arguments
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call tools/call RPC method
|
response = await asyncio.wait_for(
|
||||||
response = await self.session.call_tool(tool_name, arguments=arguments)
|
self.session.call_tool(tool_name, arguments=arguments),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
# Extract content from response
|
|
||||||
result = []
|
result = []
|
||||||
for content in response.content:
|
for content in response.content:
|
||||||
if hasattr(content, "text"):
|
if hasattr(content, "text"):
|
||||||
|
|
@ -202,15 +220,17 @@ class MCPClient:
|
||||||
logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
|
logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
|
||||||
return result_str
|
return result_str
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(
|
||||||
|
"MCP tool '%s' timed out after %.0fs", tool_name, timeout
|
||||||
|
)
|
||||||
|
return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s"
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
# Handle validation errors from MCP server responses
|
|
||||||
# Some MCP servers (like server-memory) return extra fields not in their schema
|
|
||||||
if "Invalid structured content" in str(e):
|
if "Invalid structured content" in str(e):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"MCP server returned data not matching its schema, but continuing: %s",
|
"MCP server returned data not matching its schema, but continuing: %s",
|
||||||
e,
|
e,
|
||||||
)
|
)
|
||||||
# Try to extract result from error message or return a success message
|
|
||||||
return "Operation completed (server returned unexpected format)"
|
return "Operation completed (server returned unexpected format)"
|
||||||
raise
|
raise
|
||||||
except (ValueError, TypeError, AttributeError, KeyError) as e:
|
except (ValueError, TypeError, AttributeError, KeyError) as e:
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -50,6 +50,11 @@ from .confluence import (
|
||||||
create_delete_confluence_page_tool,
|
create_delete_confluence_page_tool,
|
||||||
create_update_confluence_page_tool,
|
create_update_confluence_page_tool,
|
||||||
)
|
)
|
||||||
|
from .discord import (
|
||||||
|
create_list_discord_channels_tool,
|
||||||
|
create_read_discord_messages_tool,
|
||||||
|
create_send_discord_message_tool,
|
||||||
|
)
|
||||||
from .dropbox import (
|
from .dropbox import (
|
||||||
create_create_dropbox_file_tool,
|
create_create_dropbox_file_tool,
|
||||||
create_delete_dropbox_file_tool,
|
create_delete_dropbox_file_tool,
|
||||||
|
|
@ -57,6 +62,8 @@ from .dropbox import (
|
||||||
from .generate_image import create_generate_image_tool
|
from .generate_image import create_generate_image_tool
|
||||||
from .gmail import (
|
from .gmail import (
|
||||||
create_create_gmail_draft_tool,
|
create_create_gmail_draft_tool,
|
||||||
|
create_read_gmail_email_tool,
|
||||||
|
create_search_gmail_tool,
|
||||||
create_send_gmail_email_tool,
|
create_send_gmail_email_tool,
|
||||||
create_trash_gmail_email_tool,
|
create_trash_gmail_email_tool,
|
||||||
create_update_gmail_draft_tool,
|
create_update_gmail_draft_tool,
|
||||||
|
|
@ -64,21 +71,18 @@ from .gmail import (
|
||||||
from .google_calendar import (
|
from .google_calendar import (
|
||||||
create_create_calendar_event_tool,
|
create_create_calendar_event_tool,
|
||||||
create_delete_calendar_event_tool,
|
create_delete_calendar_event_tool,
|
||||||
|
create_search_calendar_events_tool,
|
||||||
create_update_calendar_event_tool,
|
create_update_calendar_event_tool,
|
||||||
)
|
)
|
||||||
from .google_drive import (
|
from .google_drive import (
|
||||||
create_create_google_drive_file_tool,
|
create_create_google_drive_file_tool,
|
||||||
create_delete_google_drive_file_tool,
|
create_delete_google_drive_file_tool,
|
||||||
)
|
)
|
||||||
from .jira import (
|
from .connected_accounts import create_get_connected_accounts_tool
|
||||||
create_create_jira_issue_tool,
|
from .luma import (
|
||||||
create_delete_jira_issue_tool,
|
create_create_luma_event_tool,
|
||||||
create_update_jira_issue_tool,
|
create_list_luma_events_tool,
|
||||||
)
|
create_read_luma_event_tool,
|
||||||
from .linear import (
|
|
||||||
create_create_linear_issue_tool,
|
|
||||||
create_delete_linear_issue_tool,
|
|
||||||
create_update_linear_issue_tool,
|
|
||||||
)
|
)
|
||||||
from .mcp_tool import load_mcp_tools
|
from .mcp_tool import load_mcp_tools
|
||||||
from .notion import (
|
from .notion import (
|
||||||
|
|
@ -95,6 +99,11 @@ from .report import create_generate_report_tool
|
||||||
from .resume import create_generate_resume_tool
|
from .resume import create_generate_resume_tool
|
||||||
from .scrape_webpage import create_scrape_webpage_tool
|
from .scrape_webpage import create_scrape_webpage_tool
|
||||||
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||||
|
from .teams import (
|
||||||
|
create_list_teams_channels_tool,
|
||||||
|
create_read_teams_messages_tool,
|
||||||
|
create_send_teams_message_tool,
|
||||||
|
)
|
||||||
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
|
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
|
||||||
from .video_presentation import create_generate_video_presentation_tool
|
from .video_presentation import create_generate_video_presentation_tool
|
||||||
from .web_search import create_web_search_tool
|
from .web_search import create_web_search_tool
|
||||||
|
|
@ -114,6 +123,8 @@ class ToolDefinition:
|
||||||
factory: Callable that creates the tool. Receives a dict of dependencies.
|
factory: Callable that creates the tool. Receives a dict of dependencies.
|
||||||
requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session")
|
requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session")
|
||||||
enabled_by_default: Whether the tool is enabled when no explicit config is provided
|
enabled_by_default: Whether the tool is enabled when no explicit config is provided
|
||||||
|
required_connector: Searchable type string (e.g. ``"LINEAR_CONNECTOR"``)
|
||||||
|
that must be in ``available_connectors`` for the tool to be enabled.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -123,6 +134,7 @@ class ToolDefinition:
|
||||||
requires: list[str] = field(default_factory=list)
|
requires: list[str] = field(default_factory=list)
|
||||||
enabled_by_default: bool = True
|
enabled_by_default: bool = True
|
||||||
hidden: bool = False
|
hidden: bool = False
|
||||||
|
required_connector: str | None = None
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
@ -221,6 +233,21 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
requires=["db_session"],
|
requires=["db_session"],
|
||||||
),
|
),
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
# SERVICE ACCOUNT DISCOVERY
|
||||||
|
# Generic tool for the LLM to discover connected accounts and resolve
|
||||||
|
# service-specific identifiers (e.g. Jira cloudId, Slack team, etc.)
|
||||||
|
# =========================================================================
|
||||||
|
ToolDefinition(
|
||||||
|
name="get_connected_accounts",
|
||||||
|
description="Discover connected accounts for a service and their metadata",
|
||||||
|
factory=lambda deps: create_get_connected_accounts_tool(
|
||||||
|
db_session=deps["db_session"],
|
||||||
|
search_space_id=deps["search_space_id"],
|
||||||
|
user_id=deps["user_id"],
|
||||||
|
),
|
||||||
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
),
|
||||||
|
# =========================================================================
|
||||||
# MEMORY TOOL - single update_memory, private or team by thread_visibility
|
# MEMORY TOOL - single update_memory, private or team by thread_visibility
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
|
|
@ -248,40 +275,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# LINEAR TOOLS - create, update, delete issues
|
|
||||||
# Auto-disabled when no Linear connector is configured (see chat_deepagent.py)
|
|
||||||
# =========================================================================
|
|
||||||
ToolDefinition(
|
|
||||||
name="create_linear_issue",
|
|
||||||
description="Create a new issue in the user's Linear workspace",
|
|
||||||
factory=lambda deps: create_create_linear_issue_tool(
|
|
||||||
db_session=deps["db_session"],
|
|
||||||
search_space_id=deps["search_space_id"],
|
|
||||||
user_id=deps["user_id"],
|
|
||||||
),
|
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
|
||||||
),
|
|
||||||
ToolDefinition(
|
|
||||||
name="update_linear_issue",
|
|
||||||
description="Update an existing indexed Linear issue",
|
|
||||||
factory=lambda deps: create_update_linear_issue_tool(
|
|
||||||
db_session=deps["db_session"],
|
|
||||||
search_space_id=deps["search_space_id"],
|
|
||||||
user_id=deps["user_id"],
|
|
||||||
),
|
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
|
||||||
),
|
|
||||||
ToolDefinition(
|
|
||||||
name="delete_linear_issue",
|
|
||||||
description="Archive (delete) an existing indexed Linear issue",
|
|
||||||
factory=lambda deps: create_delete_linear_issue_tool(
|
|
||||||
db_session=deps["db_session"],
|
|
||||||
search_space_id=deps["search_space_id"],
|
|
||||||
user_id=deps["user_id"],
|
|
||||||
),
|
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
|
||||||
),
|
|
||||||
# =========================================================================
|
|
||||||
# NOTION TOOLS - create, update, delete pages
|
# NOTION TOOLS - create, update, delete pages
|
||||||
# Auto-disabled when no Notion connector is configured (see chat_deepagent.py)
|
# Auto-disabled when no Notion connector is configured (see chat_deepagent.py)
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
|
@ -294,6 +287,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="NOTION_CONNECTOR",
|
||||||
),
|
),
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="update_notion_page",
|
name="update_notion_page",
|
||||||
|
|
@ -304,6 +298,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="NOTION_CONNECTOR",
|
||||||
),
|
),
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="delete_notion_page",
|
name="delete_notion_page",
|
||||||
|
|
@ -314,6 +309,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="NOTION_CONNECTOR",
|
||||||
),
|
),
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# GOOGLE DRIVE TOOLS - create files, delete files
|
# GOOGLE DRIVE TOOLS - create files, delete files
|
||||||
|
|
@ -328,6 +324,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="GOOGLE_DRIVE_FILE",
|
||||||
),
|
),
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="delete_google_drive_file",
|
name="delete_google_drive_file",
|
||||||
|
|
@ -338,6 +335,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="GOOGLE_DRIVE_FILE",
|
||||||
),
|
),
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# DROPBOX TOOLS - create and trash files
|
# DROPBOX TOOLS - create and trash files
|
||||||
|
|
@ -352,6 +350,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="DROPBOX_FILE",
|
||||||
),
|
),
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="delete_dropbox_file",
|
name="delete_dropbox_file",
|
||||||
|
|
@ -362,6 +361,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="DROPBOX_FILE",
|
||||||
),
|
),
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# ONEDRIVE TOOLS - create and trash files
|
# ONEDRIVE TOOLS - create and trash files
|
||||||
|
|
@ -376,6 +376,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="ONEDRIVE_FILE",
|
||||||
),
|
),
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="delete_onedrive_file",
|
name="delete_onedrive_file",
|
||||||
|
|
@ -386,11 +387,23 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="ONEDRIVE_FILE",
|
||||||
),
|
),
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# GOOGLE CALENDAR TOOLS - create, update, delete events
|
# GOOGLE CALENDAR TOOLS - search, create, update, delete events
|
||||||
# Auto-disabled when no Google Calendar connector is configured
|
# Auto-disabled when no Google Calendar connector is configured
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
ToolDefinition(
|
||||||
|
name="search_calendar_events",
|
||||||
|
description="Search Google Calendar events within a date range",
|
||||||
|
factory=lambda deps: create_search_calendar_events_tool(
|
||||||
|
db_session=deps["db_session"],
|
||||||
|
search_space_id=deps["search_space_id"],
|
||||||
|
user_id=deps["user_id"],
|
||||||
|
),
|
||||||
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="GOOGLE_CALENDAR_CONNECTOR",
|
||||||
|
),
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="create_calendar_event",
|
name="create_calendar_event",
|
||||||
description="Create a new event on Google Calendar",
|
description="Create a new event on Google Calendar",
|
||||||
|
|
@ -400,6 +413,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="GOOGLE_CALENDAR_CONNECTOR",
|
||||||
),
|
),
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="update_calendar_event",
|
name="update_calendar_event",
|
||||||
|
|
@ -410,6 +424,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="GOOGLE_CALENDAR_CONNECTOR",
|
||||||
),
|
),
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="delete_calendar_event",
|
name="delete_calendar_event",
|
||||||
|
|
@ -420,11 +435,34 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="GOOGLE_CALENDAR_CONNECTOR",
|
||||||
),
|
),
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# GMAIL TOOLS - create drafts, update drafts, send emails, trash emails
|
# GMAIL TOOLS - search, read, create drafts, update drafts, send, trash
|
||||||
# Auto-disabled when no Gmail connector is configured
|
# Auto-disabled when no Gmail connector is configured
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
ToolDefinition(
|
||||||
|
name="search_gmail",
|
||||||
|
description="Search emails in Gmail using Gmail search syntax",
|
||||||
|
factory=lambda deps: create_search_gmail_tool(
|
||||||
|
db_session=deps["db_session"],
|
||||||
|
search_space_id=deps["search_space_id"],
|
||||||
|
user_id=deps["user_id"],
|
||||||
|
),
|
||||||
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||||
|
),
|
||||||
|
ToolDefinition(
|
||||||
|
name="read_gmail_email",
|
||||||
|
description="Read the full content of a specific Gmail email",
|
||||||
|
factory=lambda deps: create_read_gmail_email_tool(
|
||||||
|
db_session=deps["db_session"],
|
||||||
|
search_space_id=deps["search_space_id"],
|
||||||
|
user_id=deps["user_id"],
|
||||||
|
),
|
||||||
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||||
|
),
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="create_gmail_draft",
|
name="create_gmail_draft",
|
||||||
description="Create a draft email in Gmail",
|
description="Create a draft email in Gmail",
|
||||||
|
|
@ -434,6 +472,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||||
),
|
),
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="send_gmail_email",
|
name="send_gmail_email",
|
||||||
|
|
@ -444,6 +483,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||||
),
|
),
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="trash_gmail_email",
|
name="trash_gmail_email",
|
||||||
|
|
@ -454,6 +494,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||||
),
|
),
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="update_gmail_draft",
|
name="update_gmail_draft",
|
||||||
|
|
@ -464,40 +505,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
),
|
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||||
# =========================================================================
|
|
||||||
# JIRA TOOLS - create, update, delete issues
|
|
||||||
# Auto-disabled when no Jira connector is configured (see chat_deepagent.py)
|
|
||||||
# =========================================================================
|
|
||||||
ToolDefinition(
|
|
||||||
name="create_jira_issue",
|
|
||||||
description="Create a new issue in the user's Jira project",
|
|
||||||
factory=lambda deps: create_create_jira_issue_tool(
|
|
||||||
db_session=deps["db_session"],
|
|
||||||
search_space_id=deps["search_space_id"],
|
|
||||||
user_id=deps["user_id"],
|
|
||||||
),
|
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
|
||||||
),
|
|
||||||
ToolDefinition(
|
|
||||||
name="update_jira_issue",
|
|
||||||
description="Update an existing indexed Jira issue",
|
|
||||||
factory=lambda deps: create_update_jira_issue_tool(
|
|
||||||
db_session=deps["db_session"],
|
|
||||||
search_space_id=deps["search_space_id"],
|
|
||||||
user_id=deps["user_id"],
|
|
||||||
),
|
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
|
||||||
),
|
|
||||||
ToolDefinition(
|
|
||||||
name="delete_jira_issue",
|
|
||||||
description="Delete an existing indexed Jira issue",
|
|
||||||
factory=lambda deps: create_delete_jira_issue_tool(
|
|
||||||
db_session=deps["db_session"],
|
|
||||||
search_space_id=deps["search_space_id"],
|
|
||||||
user_id=deps["user_id"],
|
|
||||||
),
|
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
|
||||||
),
|
),
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# CONFLUENCE TOOLS - create, update, delete pages
|
# CONFLUENCE TOOLS - create, update, delete pages
|
||||||
|
|
@ -512,6 +520,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="CONFLUENCE_CONNECTOR",
|
||||||
),
|
),
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="update_confluence_page",
|
name="update_confluence_page",
|
||||||
|
|
@ -522,6 +531,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="CONFLUENCE_CONNECTOR",
|
||||||
),
|
),
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="delete_confluence_page",
|
name="delete_confluence_page",
|
||||||
|
|
@ -532,6 +542,118 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
),
|
),
|
||||||
requires=["db_session", "search_space_id", "user_id"],
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="CONFLUENCE_CONNECTOR",
|
||||||
|
),
|
||||||
|
# =========================================================================
|
||||||
|
# DISCORD TOOLS - list channels, read messages, send messages
|
||||||
|
# Auto-disabled when no Discord connector is configured
|
||||||
|
# =========================================================================
|
||||||
|
ToolDefinition(
|
||||||
|
name="list_discord_channels",
|
||||||
|
description="List text channels in the connected Discord server",
|
||||||
|
factory=lambda deps: create_list_discord_channels_tool(
|
||||||
|
db_session=deps["db_session"],
|
||||||
|
search_space_id=deps["search_space_id"],
|
||||||
|
user_id=deps["user_id"],
|
||||||
|
),
|
||||||
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="DISCORD_CONNECTOR",
|
||||||
|
),
|
||||||
|
ToolDefinition(
|
||||||
|
name="read_discord_messages",
|
||||||
|
description="Read recent messages from a Discord text channel",
|
||||||
|
factory=lambda deps: create_read_discord_messages_tool(
|
||||||
|
db_session=deps["db_session"],
|
||||||
|
search_space_id=deps["search_space_id"],
|
||||||
|
user_id=deps["user_id"],
|
||||||
|
),
|
||||||
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="DISCORD_CONNECTOR",
|
||||||
|
),
|
||||||
|
ToolDefinition(
|
||||||
|
name="send_discord_message",
|
||||||
|
description="Send a message to a Discord text channel",
|
||||||
|
factory=lambda deps: create_send_discord_message_tool(
|
||||||
|
db_session=deps["db_session"],
|
||||||
|
search_space_id=deps["search_space_id"],
|
||||||
|
user_id=deps["user_id"],
|
||||||
|
),
|
||||||
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="DISCORD_CONNECTOR",
|
||||||
|
),
|
||||||
|
# =========================================================================
|
||||||
|
# TEAMS TOOLS - list channels, read messages, send messages
|
||||||
|
# Auto-disabled when no Teams connector is configured
|
||||||
|
# =========================================================================
|
||||||
|
ToolDefinition(
|
||||||
|
name="list_teams_channels",
|
||||||
|
description="List Microsoft Teams and their channels",
|
||||||
|
factory=lambda deps: create_list_teams_channels_tool(
|
||||||
|
db_session=deps["db_session"],
|
||||||
|
search_space_id=deps["search_space_id"],
|
||||||
|
user_id=deps["user_id"],
|
||||||
|
),
|
||||||
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="TEAMS_CONNECTOR",
|
||||||
|
),
|
||||||
|
ToolDefinition(
|
||||||
|
name="read_teams_messages",
|
||||||
|
description="Read recent messages from a Microsoft Teams channel",
|
||||||
|
factory=lambda deps: create_read_teams_messages_tool(
|
||||||
|
db_session=deps["db_session"],
|
||||||
|
search_space_id=deps["search_space_id"],
|
||||||
|
user_id=deps["user_id"],
|
||||||
|
),
|
||||||
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="TEAMS_CONNECTOR",
|
||||||
|
),
|
||||||
|
ToolDefinition(
|
||||||
|
name="send_teams_message",
|
||||||
|
description="Send a message to a Microsoft Teams channel",
|
||||||
|
factory=lambda deps: create_send_teams_message_tool(
|
||||||
|
db_session=deps["db_session"],
|
||||||
|
search_space_id=deps["search_space_id"],
|
||||||
|
user_id=deps["user_id"],
|
||||||
|
),
|
||||||
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="TEAMS_CONNECTOR",
|
||||||
|
),
|
||||||
|
# =========================================================================
|
||||||
|
# LUMA TOOLS - list events, read event details, create events
|
||||||
|
# Auto-disabled when no Luma connector is configured
|
||||||
|
# =========================================================================
|
||||||
|
ToolDefinition(
|
||||||
|
name="list_luma_events",
|
||||||
|
description="List upcoming and recent Luma events",
|
||||||
|
factory=lambda deps: create_list_luma_events_tool(
|
||||||
|
db_session=deps["db_session"],
|
||||||
|
search_space_id=deps["search_space_id"],
|
||||||
|
user_id=deps["user_id"],
|
||||||
|
),
|
||||||
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="LUMA_CONNECTOR",
|
||||||
|
),
|
||||||
|
ToolDefinition(
|
||||||
|
name="read_luma_event",
|
||||||
|
description="Read detailed information about a specific Luma event",
|
||||||
|
factory=lambda deps: create_read_luma_event_tool(
|
||||||
|
db_session=deps["db_session"],
|
||||||
|
search_space_id=deps["search_space_id"],
|
||||||
|
user_id=deps["user_id"],
|
||||||
|
),
|
||||||
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="LUMA_CONNECTOR",
|
||||||
|
),
|
||||||
|
ToolDefinition(
|
||||||
|
name="create_luma_event",
|
||||||
|
description="Create a new event on Luma",
|
||||||
|
factory=lambda deps: create_create_luma_event_tool(
|
||||||
|
db_session=deps["db_session"],
|
||||||
|
search_space_id=deps["search_space_id"],
|
||||||
|
user_id=deps["user_id"],
|
||||||
|
),
|
||||||
|
requires=["db_session", "search_space_id", "user_id"],
|
||||||
|
required_connector="LUMA_CONNECTOR",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -549,6 +671,22 @@ def get_tool_by_name(name: str) -> ToolDefinition | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_connector_gated_tools(
|
||||||
|
available_connectors: list[str] | None,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Return tool names to disable"""
|
||||||
|
if available_connectors is None:
|
||||||
|
available = set()
|
||||||
|
else:
|
||||||
|
available = set(available_connectors)
|
||||||
|
|
||||||
|
disabled: list[str] = []
|
||||||
|
for tool_def in BUILTIN_TOOLS:
|
||||||
|
if tool_def.required_connector and tool_def.required_connector not in available:
|
||||||
|
disabled.append(tool_def.name)
|
||||||
|
return disabled
|
||||||
|
|
||||||
|
|
||||||
def get_all_tool_names() -> list[str]:
|
def get_all_tool_names() -> list[str]:
|
||||||
"""Get names of all registered tools."""
|
"""Get names of all registered tools."""
|
||||||
return [tool_def.name for tool_def in BUILTIN_TOOLS]
|
return [tool_def.name for tool_def in BUILTIN_TOOLS]
|
||||||
|
|
@ -690,15 +828,15 @@ async def build_tools_async(
|
||||||
)
|
)
|
||||||
tools.extend(mcp_tools)
|
tools.extend(mcp_tools)
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}",
|
"Registered %d MCP tools: %s",
|
||||||
|
len(mcp_tools), [t.name for t in mcp_tools],
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log error but don't fail - just continue without MCP tools
|
logging.exception("Failed to load MCP tools: %s", e)
|
||||||
logging.exception(f"Failed to load MCP tools: {e!s}")
|
|
||||||
|
|
||||||
# Log all tools being returned to agent
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Total tools for agent: {len(tools)} - {[t.name for t in tools]}",
|
"Total tools for agent: %d — %s",
|
||||||
|
len(tools), [t.name for t in tools],
|
||||||
)
|
)
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
|
||||||
|
|
@ -13,11 +13,13 @@ Uses the same short-lived session pattern as generate_report so no DB
|
||||||
connection is held during the long LLM call.
|
connection is held during the long LLM call.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import io
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import pypdf
|
||||||
import typst
|
import typst
|
||||||
from langchain_core.callbacks import dispatch_custom_event
|
from langchain_core.callbacks import dispatch_custom_event
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
@ -114,7 +116,7 @@ _TEMPLATES: dict[str, dict[str, str]] = {
|
||||||
entries-highlights-nested-bullet: text(13pt, [\\u{2022}], baseline: -0.6pt),
|
entries-highlights-nested-bullet: text(13pt, [\\u{2022}], baseline: -0.6pt),
|
||||||
entries-highlights-space-left: 0cm,
|
entries-highlights-space-left: 0cm,
|
||||||
entries-highlights-space-above: 0.08cm,
|
entries-highlights-space-above: 0.08cm,
|
||||||
entries-highlights-space-between-items: 0.08cm,
|
entries-highlights-space-between-items: 0.02cm,
|
||||||
entries-highlights-space-between-bullet-and-text: 0.3em,
|
entries-highlights-space-between-bullet-and-text: 0.3em,
|
||||||
date: datetime(
|
date: datetime(
|
||||||
year: {year},
|
year: {year},
|
||||||
|
|
@ -166,8 +168,8 @@ Available components (use ONLY these):
|
||||||
#summary([Short paragraph summary]) // Optional summary inside an entry
|
#summary([Short paragraph summary]) // Optional summary inside an entry
|
||||||
#content-area([Free-form content]) // Freeform text block
|
#content-area([Free-form content]) // Freeform text block
|
||||||
|
|
||||||
For skills sections, use bold labels directly:
|
For skills sections, use one bullet per category label:
|
||||||
#strong[Category:] item1, item2, item3
|
- #strong[Category:] item1, item2, item3
|
||||||
|
|
||||||
For simple list sections (e.g. Honors), use plain bullet points:
|
For simple list sections (e.g. Honors), use plain bullet points:
|
||||||
- Item one
|
- Item one
|
||||||
|
|
@ -184,15 +186,19 @@ RULES:
|
||||||
- Every section MUST use == heading.
|
- Every section MUST use == heading.
|
||||||
- Use #regular-entry() for experience, projects, publications, certifications, and similar entries.
|
- Use #regular-entry() for experience, projects, publications, certifications, and similar entries.
|
||||||
- Use #education-entry() for education.
|
- Use #education-entry() for education.
|
||||||
- Use #strong[Label:] for skills categories.
|
- For skills sections, use one bullet line per category with a bold label.
|
||||||
- Keep content professional, concise, and achievement-oriented.
|
- Keep content professional, concise, and achievement-oriented.
|
||||||
- Use action verbs for bullet points (Led, Built, Designed, Reduced, etc.).
|
- Use action verbs for bullet points (Led, Built, Designed, Reduced, etc.).
|
||||||
- This template works for ALL professions — adapt sections to the user's field.
|
- This template works for ALL professions — adapt sections to the user's field.
|
||||||
|
- Default behavior should prioritize concise one-page content.
|
||||||
""",
|
""",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFAULT_TEMPLATE = "classic"
|
DEFAULT_TEMPLATE = "classic"
|
||||||
|
MIN_RESUME_PAGES = 1
|
||||||
|
MAX_RESUME_PAGES = 5
|
||||||
|
MAX_COMPRESSION_ATTEMPTS = 2
|
||||||
|
|
||||||
|
|
||||||
# ─── Template Helpers ─────────────────────────────────────────────────────────
|
# ─── Template Helpers ─────────────────────────────────────────────────────────
|
||||||
|
|
@ -315,6 +321,8 @@ You are an expert resume writer. Generate professional resume content as Typst m
|
||||||
**User Information:**
|
**User Information:**
|
||||||
{user_info}
|
{user_info}
|
||||||
|
|
||||||
|
**Target Maximum Pages:** {max_pages}
|
||||||
|
|
||||||
{user_instructions_section}
|
{user_instructions_section}
|
||||||
|
|
||||||
Generate the resume content now (starting with = Full Name):
|
Generate the resume content now (starting with = Full Name):
|
||||||
|
|
@ -326,6 +334,8 @@ Apply ONLY the requested changes — do NOT rewrite sections that are not affect
|
||||||
|
|
||||||
{llm_reference}
|
{llm_reference}
|
||||||
|
|
||||||
|
**Target Maximum Pages:** {max_pages}
|
||||||
|
|
||||||
**Modification Instructions:** {user_instructions}
|
**Modification Instructions:** {user_instructions}
|
||||||
|
|
||||||
**EXISTING RESUME CONTENT:**
|
**EXISTING RESUME CONTENT:**
|
||||||
|
|
@ -352,6 +362,28 @@ The resume content you generated failed to compile. Fix the error while preservi
|
||||||
(starting with = Full Name), NOT the #import or #show rule:**
|
(starting with = Full Name), NOT the #import or #show rule:**
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_COMPRESS_TO_PAGE_LIMIT_PROMPT = """\
|
||||||
|
The resume compiles, but it exceeds the maximum allowed page count.
|
||||||
|
Compress the resume while preserving high-impact accomplishments and role relevance.
|
||||||
|
|
||||||
|
{llm_reference}
|
||||||
|
|
||||||
|
**Target Maximum Pages:** {max_pages}
|
||||||
|
**Current Page Count:** {actual_pages}
|
||||||
|
**Compression Attempt:** {attempt_number}
|
||||||
|
|
||||||
|
Compression priorities (in this order):
|
||||||
|
1) Keep recent, high-impact, role-relevant bullets.
|
||||||
|
2) Remove low-impact or redundant bullets.
|
||||||
|
3) Shorten verbose wording while preserving meaning.
|
||||||
|
4) Trim older or less relevant details before recent ones.
|
||||||
|
|
||||||
|
Return the complete updated Typst content (starting with = Full Name), and keep it at or below the target pages.
|
||||||
|
|
||||||
|
**EXISTING RESUME CONTENT:**
|
||||||
|
{previous_content}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
# ─── Helpers ─────────────────────────────────────────────────────────────────
|
# ─── Helpers ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
@ -373,6 +405,24 @@ def _compile_typst(source: str) -> bytes:
|
||||||
return typst.compile(source.encode("utf-8"))
|
return typst.compile(source.encode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def _count_pdf_pages(pdf_bytes: bytes) -> int:
|
||||||
|
"""Count the number of pages in compiled PDF bytes."""
|
||||||
|
with io.BytesIO(pdf_bytes) as pdf_stream:
|
||||||
|
reader = pypdf.PdfReader(pdf_stream)
|
||||||
|
return len(reader.pages)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_max_pages(max_pages: int) -> int:
|
||||||
|
"""Validate and normalize max_pages input."""
|
||||||
|
if MIN_RESUME_PAGES <= max_pages <= MAX_RESUME_PAGES:
|
||||||
|
return max_pages
|
||||||
|
msg = (
|
||||||
|
f"max_pages must be between {MIN_RESUME_PAGES} and "
|
||||||
|
f"{MAX_RESUME_PAGES}. Received: {max_pages}"
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
||||||
# ─── Tool Factory ───────────────────────────────────────────────────────────
|
# ─── Tool Factory ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -394,6 +444,7 @@ def create_generate_resume_tool(
|
||||||
user_info: str,
|
user_info: str,
|
||||||
user_instructions: str | None = None,
|
user_instructions: str | None = None,
|
||||||
parent_report_id: int | None = None,
|
parent_report_id: int | None = None,
|
||||||
|
max_pages: int = 1,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Generate a professional resume as a Typst document.
|
Generate a professional resume as a Typst document.
|
||||||
|
|
@ -426,6 +477,8 @@ def create_generate_resume_tool(
|
||||||
"use a modern style"). For revisions, describe what to change.
|
"use a modern style"). For revisions, describe what to change.
|
||||||
parent_report_id: ID of a previous resume to revise (creates
|
parent_report_id: ID of a previous resume to revise (creates
|
||||||
new version in the same version group).
|
new version in the same version group).
|
||||||
|
max_pages: Maximum number of pages for the generated resume.
|
||||||
|
Defaults to 1. Allowed range: 1-5.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with status, report_id, title, and content_type.
|
Dict with status, report_id, title, and content_type.
|
||||||
|
|
@ -469,6 +522,19 @@ def create_generate_resume_tool(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
try:
|
||||||
|
validated_max_pages = _validate_max_pages(max_pages)
|
||||||
|
except ValueError as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
report_id = await _save_failed_report(error_msg)
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"error": error_msg,
|
||||||
|
"report_id": report_id,
|
||||||
|
"title": "Resume",
|
||||||
|
"content_type": "typst",
|
||||||
|
}
|
||||||
|
|
||||||
# ── Phase 1: READ ─────────────────────────────────────────────
|
# ── Phase 1: READ ─────────────────────────────────────────────
|
||||||
async with shielded_async_session() as read_session:
|
async with shielded_async_session() as read_session:
|
||||||
if parent_report_id:
|
if parent_report_id:
|
||||||
|
|
@ -512,6 +578,7 @@ def create_generate_resume_tool(
|
||||||
parent_body = _strip_header(parent_content)
|
parent_body = _strip_header(parent_content)
|
||||||
prompt = _REVISION_PROMPT.format(
|
prompt = _REVISION_PROMPT.format(
|
||||||
llm_reference=llm_reference,
|
llm_reference=llm_reference,
|
||||||
|
max_pages=validated_max_pages,
|
||||||
user_instructions=user_instructions
|
user_instructions=user_instructions
|
||||||
or "Improve and refine the resume.",
|
or "Improve and refine the resume.",
|
||||||
previous_content=parent_body,
|
previous_content=parent_body,
|
||||||
|
|
@ -524,6 +591,7 @@ def create_generate_resume_tool(
|
||||||
prompt = _RESUME_PROMPT.format(
|
prompt = _RESUME_PROMPT.format(
|
||||||
llm_reference=llm_reference,
|
llm_reference=llm_reference,
|
||||||
user_info=user_info,
|
user_info=user_info,
|
||||||
|
max_pages=validated_max_pages,
|
||||||
user_instructions_section=user_instructions_section,
|
user_instructions_section=user_instructions_section,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -551,49 +619,116 @@ def create_generate_resume_tool(
|
||||||
)
|
)
|
||||||
|
|
||||||
name = _extract_name(body) or "Resume"
|
name = _extract_name(body) or "Resume"
|
||||||
header = _build_header(template, name)
|
typst_source = ""
|
||||||
typst_source = header + body
|
actual_pages = 0
|
||||||
|
compression_attempts = 0
|
||||||
|
target_page_met = False
|
||||||
|
|
||||||
compile_error: str | None = None
|
for compression_round in range(MAX_COMPRESSION_ATTEMPTS + 1):
|
||||||
for attempt in range(2):
|
header = _build_header(template, name)
|
||||||
try:
|
typst_source = header + body
|
||||||
_compile_typst(typst_source)
|
compile_error: str | None = None
|
||||||
compile_error = None
|
pdf_bytes: bytes | None = None
|
||||||
break
|
|
||||||
except Exception as e:
|
for compile_attempt in range(2):
|
||||||
compile_error = str(e)
|
try:
|
||||||
logger.warning(
|
pdf_bytes = _compile_typst(typst_source)
|
||||||
f"[generate_resume] Compile attempt {attempt + 1} failed: {compile_error}"
|
compile_error = None
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
compile_error = str(e)
|
||||||
|
logger.warning(
|
||||||
|
"[generate_resume] Compile attempt %s failed: %s",
|
||||||
|
compile_attempt + 1,
|
||||||
|
compile_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
if compile_attempt == 0:
|
||||||
|
dispatch_custom_event(
|
||||||
|
"report_progress",
|
||||||
|
{
|
||||||
|
"phase": "fixing",
|
||||||
|
"message": "Fixing compilation issue...",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
fix_prompt = _FIX_COMPILE_PROMPT.format(
|
||||||
|
llm_reference=llm_reference,
|
||||||
|
error=compile_error,
|
||||||
|
full_source=typst_source,
|
||||||
|
)
|
||||||
|
fix_response = await llm.ainvoke(
|
||||||
|
[HumanMessage(content=fix_prompt)]
|
||||||
|
)
|
||||||
|
if fix_response.content and isinstance(
|
||||||
|
fix_response.content, str
|
||||||
|
):
|
||||||
|
body = _strip_typst_fences(fix_response.content)
|
||||||
|
body = _strip_imports(body)
|
||||||
|
name = _extract_name(body) or name
|
||||||
|
header = _build_header(template, name)
|
||||||
|
typst_source = header + body
|
||||||
|
|
||||||
|
if compile_error or not pdf_bytes:
|
||||||
|
error_msg = (
|
||||||
|
"Typst compilation failed after 2 attempts: "
|
||||||
|
f"{compile_error or 'Unknown compile error'}"
|
||||||
)
|
)
|
||||||
|
report_id = await _save_failed_report(error_msg)
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"error": error_msg,
|
||||||
|
"report_id": report_id,
|
||||||
|
"title": "Resume",
|
||||||
|
"content_type": "typst",
|
||||||
|
}
|
||||||
|
|
||||||
if attempt == 0:
|
actual_pages = _count_pdf_pages(pdf_bytes)
|
||||||
dispatch_custom_event(
|
if actual_pages <= validated_max_pages:
|
||||||
"report_progress",
|
target_page_met = True
|
||||||
{
|
break
|
||||||
"phase": "fixing",
|
|
||||||
"message": "Fixing compilation issue...",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
fix_prompt = _FIX_COMPILE_PROMPT.format(
|
|
||||||
llm_reference=llm_reference,
|
|
||||||
error=compile_error,
|
|
||||||
full_source=typst_source,
|
|
||||||
)
|
|
||||||
fix_response = await llm.ainvoke(
|
|
||||||
[HumanMessage(content=fix_prompt)]
|
|
||||||
)
|
|
||||||
if fix_response.content and isinstance(
|
|
||||||
fix_response.content, str
|
|
||||||
):
|
|
||||||
body = _strip_typst_fences(fix_response.content)
|
|
||||||
body = _strip_imports(body)
|
|
||||||
name = _extract_name(body) or name
|
|
||||||
header = _build_header(template, name)
|
|
||||||
typst_source = header + body
|
|
||||||
|
|
||||||
if compile_error:
|
if compression_round >= MAX_COMPRESSION_ATTEMPTS:
|
||||||
|
break
|
||||||
|
|
||||||
|
compression_attempts += 1
|
||||||
|
dispatch_custom_event(
|
||||||
|
"report_progress",
|
||||||
|
{
|
||||||
|
"phase": "compressing",
|
||||||
|
"message": f"Condensing resume to {validated_max_pages} page(s)...",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
compress_prompt = _COMPRESS_TO_PAGE_LIMIT_PROMPT.format(
|
||||||
|
llm_reference=llm_reference,
|
||||||
|
max_pages=validated_max_pages,
|
||||||
|
actual_pages=actual_pages,
|
||||||
|
attempt_number=compression_attempts,
|
||||||
|
previous_content=body,
|
||||||
|
)
|
||||||
|
compress_response = await llm.ainvoke(
|
||||||
|
[HumanMessage(content=compress_prompt)]
|
||||||
|
)
|
||||||
|
if not compress_response.content or not isinstance(
|
||||||
|
compress_response.content, str
|
||||||
|
):
|
||||||
|
error_msg = "LLM returned empty content while compressing resume"
|
||||||
|
report_id = await _save_failed_report(error_msg)
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"error": error_msg,
|
||||||
|
"report_id": report_id,
|
||||||
|
"title": "Resume",
|
||||||
|
"content_type": "typst",
|
||||||
|
}
|
||||||
|
|
||||||
|
body = _strip_typst_fences(compress_response.content)
|
||||||
|
body = _strip_imports(body)
|
||||||
|
name = _extract_name(body) or name
|
||||||
|
|
||||||
|
if actual_pages > MAX_RESUME_PAGES:
|
||||||
error_msg = (
|
error_msg = (
|
||||||
f"Typst compilation failed after 2 attempts: {compile_error}"
|
"Resume exceeds hard page limit after compression retries. "
|
||||||
|
f"Hard limit: <= {MAX_RESUME_PAGES} page(s), actual: {actual_pages}."
|
||||||
)
|
)
|
||||||
report_id = await _save_failed_report(error_msg)
|
report_id = await _save_failed_report(error_msg)
|
||||||
return {
|
return {
|
||||||
|
|
@ -616,6 +751,11 @@ def create_generate_resume_tool(
|
||||||
"status": "ready",
|
"status": "ready",
|
||||||
"word_count": len(typst_source.split()),
|
"word_count": len(typst_source.split()),
|
||||||
"char_count": len(typst_source),
|
"char_count": len(typst_source),
|
||||||
|
"target_max_pages": validated_max_pages,
|
||||||
|
"actual_page_count": actual_pages,
|
||||||
|
"page_limit_enforced": True,
|
||||||
|
"compression_attempts": compression_attempts,
|
||||||
|
"target_page_met": target_page_met,
|
||||||
}
|
}
|
||||||
|
|
||||||
async with shielded_async_session() as write_session:
|
async with shielded_async_session() as write_session:
|
||||||
|
|
@ -647,7 +787,14 @@ def create_generate_resume_tool(
|
||||||
"title": resume_title,
|
"title": resume_title,
|
||||||
"content_type": "typst",
|
"content_type": "typst",
|
||||||
"is_revision": bool(parent_content),
|
"is_revision": bool(parent_content),
|
||||||
"message": f"Resume generated successfully: {resume_title}",
|
"message": (
|
||||||
|
f"Resume generated successfully: {resume_title}"
|
||||||
|
if target_page_met
|
||||||
|
else (
|
||||||
|
f"Resume generated, but could not fit the target of <= {validated_max_pages} "
|
||||||
|
f"page(s). Final length: {actual_pages} page(s)."
|
||||||
|
)
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
from app.agents.new_chat.tools.teams.list_channels import (
|
||||||
|
create_list_teams_channels_tool,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.tools.teams.read_messages import (
|
||||||
|
create_read_teams_messages_tool,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.tools.teams.send_message import (
|
||||||
|
create_send_teams_message_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"create_list_teams_channels_tool",
|
||||||
|
"create_read_teams_messages_tool",
|
||||||
|
"create_send_teams_message_tool",
|
||||||
|
]
|
||||||
37
surfsense_backend/app/agents/new_chat/tools/teams/_auth.py
Normal file
37
surfsense_backend/app/agents/new_chat/tools/teams/_auth.py
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
"""Shared auth helper for Teams agent tools (Microsoft Graph REST API)."""
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||||
|
|
||||||
|
GRAPH_API = "https://graph.microsoft.com/v1.0"
|
||||||
|
|
||||||
|
|
||||||
|
async def get_teams_connector(
|
||||||
|
db_session: AsyncSession,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str,
|
||||||
|
) -> SearchSourceConnector | None:
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
SearchSourceConnector.user_id == user_id,
|
||||||
|
SearchSourceConnector.connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalars().first()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_access_token(
|
||||||
|
db_session: AsyncSession,
|
||||||
|
connector: SearchSourceConnector,
|
||||||
|
) -> str:
|
||||||
|
"""Get a valid Microsoft Graph access token, refreshing if expired."""
|
||||||
|
from app.connectors.teams_connector import TeamsConnector
|
||||||
|
|
||||||
|
tc = TeamsConnector(
|
||||||
|
session=db_session,
|
||||||
|
connector_id=connector.id,
|
||||||
|
)
|
||||||
|
return await tc._get_valid_token()
|
||||||
|
|
@ -0,0 +1,77 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_list_teams_channels_tool(
|
||||||
|
db_session: AsyncSession | None = None,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
):
|
||||||
|
@tool
|
||||||
|
async def list_teams_channels() -> dict[str, Any]:
|
||||||
|
"""List all Microsoft Teams and their channels the user has access to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status and a list of teams, each containing
|
||||||
|
team_id, team_name, and a list of channels (id, name).
|
||||||
|
"""
|
||||||
|
if db_session is None or search_space_id is None or user_id is None:
|
||||||
|
return {"status": "error", "message": "Teams tool not properly configured."}
|
||||||
|
|
||||||
|
try:
|
||||||
|
connector = await get_teams_connector(db_session, search_space_id, user_id)
|
||||||
|
if not connector:
|
||||||
|
return {"status": "error", "message": "No Teams connector found."}
|
||||||
|
|
||||||
|
token = await get_access_token(db_session, connector)
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||||
|
teams_resp = await client.get(f"{GRAPH_API}/me/joinedTeams", headers=headers)
|
||||||
|
|
||||||
|
if teams_resp.status_code == 401:
|
||||||
|
return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"}
|
||||||
|
if teams_resp.status_code != 200:
|
||||||
|
return {"status": "error", "message": f"Graph API error: {teams_resp.status_code}"}
|
||||||
|
|
||||||
|
teams_data = teams_resp.json().get("value", [])
|
||||||
|
result_teams = []
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||||
|
for team in teams_data:
|
||||||
|
team_id = team["id"]
|
||||||
|
ch_resp = await client.get(
|
||||||
|
f"{GRAPH_API}/teams/{team_id}/channels",
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
channels = []
|
||||||
|
if ch_resp.status_code == 200:
|
||||||
|
channels = [
|
||||||
|
{"id": ch["id"], "name": ch.get("displayName", "")}
|
||||||
|
for ch in ch_resp.json().get("value", [])
|
||||||
|
]
|
||||||
|
result_teams.append({
|
||||||
|
"team_id": team_id,
|
||||||
|
"team_name": team.get("displayName", ""),
|
||||||
|
"channels": channels,
|
||||||
|
})
|
||||||
|
|
||||||
|
return {"status": "success", "teams": result_teams, "total_teams": len(result_teams)}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
||||||
|
if isinstance(e, GraphInterrupt):
|
||||||
|
raise
|
||||||
|
logger.error("Error listing Teams channels: %s", e, exc_info=True)
|
||||||
|
return {"status": "error", "message": "Failed to list Teams channels."}
|
||||||
|
|
||||||
|
return list_teams_channels
|
||||||
|
|
@ -0,0 +1,91 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_read_teams_messages_tool(
|
||||||
|
db_session: AsyncSession | None = None,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
):
|
||||||
|
@tool
|
||||||
|
async def read_teams_messages(
|
||||||
|
team_id: str,
|
||||||
|
channel_id: str,
|
||||||
|
limit: int = 25,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Read recent messages from a Microsoft Teams channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: The team ID (from list_teams_channels).
|
||||||
|
channel_id: The channel ID (from list_teams_channels).
|
||||||
|
limit: Number of messages to fetch (default 25, max 50).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status and a list of messages including
|
||||||
|
id, sender, content, timestamp.
|
||||||
|
"""
|
||||||
|
if db_session is None or search_space_id is None or user_id is None:
|
||||||
|
return {"status": "error", "message": "Teams tool not properly configured."}
|
||||||
|
|
||||||
|
limit = min(limit, 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
connector = await get_teams_connector(db_session, search_space_id, user_id)
|
||||||
|
if not connector:
|
||||||
|
return {"status": "error", "message": "No Teams connector found."}
|
||||||
|
|
||||||
|
token = await get_access_token(db_session, connector)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||||
|
resp = await client.get(
|
||||||
|
f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
params={"$top": limit},
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code == 401:
|
||||||
|
return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"}
|
||||||
|
if resp.status_code == 403:
|
||||||
|
return {"status": "error", "message": "Insufficient permissions to read this channel."}
|
||||||
|
if resp.status_code != 200:
|
||||||
|
return {"status": "error", "message": f"Graph API error: {resp.status_code}"}
|
||||||
|
|
||||||
|
raw_msgs = resp.json().get("value", [])
|
||||||
|
messages = []
|
||||||
|
for m in raw_msgs:
|
||||||
|
sender = m.get("from", {})
|
||||||
|
user_info = sender.get("user", {}) if sender else {}
|
||||||
|
body = m.get("body", {})
|
||||||
|
messages.append({
|
||||||
|
"id": m.get("id"),
|
||||||
|
"sender": user_info.get("displayName", "Unknown"),
|
||||||
|
"content": body.get("content", ""),
|
||||||
|
"content_type": body.get("contentType", "text"),
|
||||||
|
"timestamp": m.get("createdDateTime", ""),
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"team_id": team_id,
|
||||||
|
"channel_id": channel_id,
|
||||||
|
"messages": messages,
|
||||||
|
"total": len(messages),
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
||||||
|
if isinstance(e, GraphInterrupt):
|
||||||
|
raise
|
||||||
|
logger.error("Error reading Teams messages: %s", e, exc_info=True)
|
||||||
|
return {"status": "error", "message": "Failed to read Teams messages."}
|
||||||
|
|
||||||
|
return read_teams_messages
|
||||||
|
|
@ -0,0 +1,101 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
|
|
||||||
|
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_send_teams_message_tool(
|
||||||
|
db_session: AsyncSession | None = None,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
):
|
||||||
|
@tool
|
||||||
|
async def send_teams_message(
|
||||||
|
team_id: str,
|
||||||
|
channel_id: str,
|
||||||
|
content: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Send a message to a Microsoft Teams channel.
|
||||||
|
|
||||||
|
Requires the ChannelMessage.Send OAuth scope. If the user gets a
|
||||||
|
permission error, they may need to re-authenticate with updated scopes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: The team ID (from list_teams_channels).
|
||||||
|
channel_id: The channel ID (from list_teams_channels).
|
||||||
|
content: The message text (HTML supported).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status, message_id on success.
|
||||||
|
|
||||||
|
IMPORTANT:
|
||||||
|
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||||
|
"""
|
||||||
|
if db_session is None or search_space_id is None or user_id is None:
|
||||||
|
return {"status": "error", "message": "Teams tool not properly configured."}
|
||||||
|
|
||||||
|
try:
|
||||||
|
connector = await get_teams_connector(db_session, search_space_id, user_id)
|
||||||
|
if not connector:
|
||||||
|
return {"status": "error", "message": "No Teams connector found."}
|
||||||
|
|
||||||
|
result = request_approval(
|
||||||
|
action_type="teams_send_message",
|
||||||
|
tool_name="send_teams_message",
|
||||||
|
params={"team_id": team_id, "channel_id": channel_id, "content": content},
|
||||||
|
context={"connector_id": connector.id},
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.rejected:
|
||||||
|
return {"status": "rejected", "message": "User declined. Message was not sent."}
|
||||||
|
|
||||||
|
final_content = result.params.get("content", content)
|
||||||
|
final_team = result.params.get("team_id", team_id)
|
||||||
|
final_channel = result.params.get("channel_id", channel_id)
|
||||||
|
|
||||||
|
token = await get_access_token(db_session, connector)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
json={"body": {"content": final_content}},
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code == 401:
|
||||||
|
return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"}
|
||||||
|
if resp.status_code == 403:
|
||||||
|
return {
|
||||||
|
"status": "insufficient_permissions",
|
||||||
|
"message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
|
||||||
|
}
|
||||||
|
if resp.status_code not in (200, 201):
|
||||||
|
return {"status": "error", "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}"}
|
||||||
|
|
||||||
|
msg_data = resp.json()
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"message_id": msg_data.get("id"),
|
||||||
|
"message": f"Message sent to Teams channel.",
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
||||||
|
if isinstance(e, GraphInterrupt):
|
||||||
|
raise
|
||||||
|
logger.error("Error sending Teams message: %s", e, exc_info=True)
|
||||||
|
return {"status": "error", "message": "Failed to send Teams message."}
|
||||||
|
|
||||||
|
return send_teams_message
|
||||||
41
surfsense_backend/app/agents/new_chat/tools/tool_response.py
Normal file
41
surfsense_backend/app/agents/new_chat/tools/tool_response.py
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
"""Standardised response dict factories for LangChain agent tools."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class ToolResponse:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def success(message: str, **data: Any) -> dict[str, Any]:
|
||||||
|
return {"status": "success", "message": message, **data}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def error(error: str, **data: Any) -> dict[str, Any]:
|
||||||
|
return {"status": "error", "error": error, **data}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def auth_error(service: str, **data: Any) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"status": "auth_error",
|
||||||
|
"error": (
|
||||||
|
f"{service} authentication has expired or been revoked. "
|
||||||
|
"Please re-connect the integration in Settings → Connectors."
|
||||||
|
),
|
||||||
|
**data,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def rejected(message: str = "Action was declined by the user.") -> dict[str, Any]:
|
||||||
|
return {"status": "rejected", "message": message}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def not_found(
|
||||||
|
resource: str, identifier: str, **data: Any
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"status": "not_found",
|
||||||
|
"error": f"{resource} '{identifier}' was not found.",
|
||||||
|
**data,
|
||||||
|
}
|
||||||
|
|
@ -141,6 +141,15 @@ def _http_exception_handler(request: Request, exc: HTTPException) -> JSONRespons
|
||||||
exc.status_code,
|
exc.status_code,
|
||||||
message,
|
message,
|
||||||
)
|
)
|
||||||
|
elif exc.status_code >= 400:
|
||||||
|
_error_logger.warning(
|
||||||
|
"[%s] %s %s - HTTPException %d: %s",
|
||||||
|
rid,
|
||||||
|
request.method,
|
||||||
|
request.url.path,
|
||||||
|
exc.status_code,
|
||||||
|
message,
|
||||||
|
)
|
||||||
if should_sanitize:
|
if should_sanitize:
|
||||||
message = GENERIC_5XX_MESSAGE
|
message = GENERIC_5XX_MESSAGE
|
||||||
err_code = "INTERNAL_ERROR"
|
err_code = "INTERNAL_ERROR"
|
||||||
|
|
@ -170,6 +179,15 @@ def _http_exception_handler(request: Request, exc: HTTPException) -> JSONRespons
|
||||||
exc.status_code,
|
exc.status_code,
|
||||||
detail,
|
detail,
|
||||||
)
|
)
|
||||||
|
elif exc.status_code >= 400:
|
||||||
|
_error_logger.warning(
|
||||||
|
"[%s] %s %s - HTTPException %d: %s",
|
||||||
|
rid,
|
||||||
|
request.method,
|
||||||
|
request.url.path,
|
||||||
|
exc.status_code,
|
||||||
|
detail,
|
||||||
|
)
|
||||||
if should_sanitize:
|
if should_sanitize:
|
||||||
detail = GENERIC_5XX_MESSAGE
|
detail = GENERIC_5XX_MESSAGE
|
||||||
code = _status_to_code(exc.status_code, detail)
|
code = _status_to_code(exc.status_code, detail)
|
||||||
|
|
|
||||||
|
|
@ -136,20 +136,12 @@ celery_app.conf.update(
|
||||||
# never block fast user-facing tasks (file uploads, podcasts, etc.)
|
# never block fast user-facing tasks (file uploads, podcasts, etc.)
|
||||||
task_routes={
|
task_routes={
|
||||||
# Connector indexing tasks → connectors queue
|
# Connector indexing tasks → connectors queue
|
||||||
"index_slack_messages": {"queue": CONNECTORS_QUEUE},
|
|
||||||
"index_notion_pages": {"queue": CONNECTORS_QUEUE},
|
"index_notion_pages": {"queue": CONNECTORS_QUEUE},
|
||||||
"index_github_repos": {"queue": CONNECTORS_QUEUE},
|
"index_github_repos": {"queue": CONNECTORS_QUEUE},
|
||||||
"index_linear_issues": {"queue": CONNECTORS_QUEUE},
|
|
||||||
"index_jira_issues": {"queue": CONNECTORS_QUEUE},
|
|
||||||
"index_confluence_pages": {"queue": CONNECTORS_QUEUE},
|
"index_confluence_pages": {"queue": CONNECTORS_QUEUE},
|
||||||
"index_clickup_tasks": {"queue": CONNECTORS_QUEUE},
|
|
||||||
"index_google_calendar_events": {"queue": CONNECTORS_QUEUE},
|
"index_google_calendar_events": {"queue": CONNECTORS_QUEUE},
|
||||||
"index_airtable_records": {"queue": CONNECTORS_QUEUE},
|
|
||||||
"index_google_gmail_messages": {"queue": CONNECTORS_QUEUE},
|
"index_google_gmail_messages": {"queue": CONNECTORS_QUEUE},
|
||||||
"index_google_drive_files": {"queue": CONNECTORS_QUEUE},
|
"index_google_drive_files": {"queue": CONNECTORS_QUEUE},
|
||||||
"index_discord_messages": {"queue": CONNECTORS_QUEUE},
|
|
||||||
"index_teams_messages": {"queue": CONNECTORS_QUEUE},
|
|
||||||
"index_luma_events": {"queue": CONNECTORS_QUEUE},
|
|
||||||
"index_elasticsearch_documents": {"queue": CONNECTORS_QUEUE},
|
"index_elasticsearch_documents": {"queue": CONNECTORS_QUEUE},
|
||||||
"index_crawled_urls": {"queue": CONNECTORS_QUEUE},
|
"index_crawled_urls": {"queue": CONNECTORS_QUEUE},
|
||||||
"index_bookstack_pages": {"queue": CONNECTORS_QUEUE},
|
"index_bookstack_pages": {"queue": CONNECTORS_QUEUE},
|
||||||
|
|
|
||||||
|
|
@ -339,6 +339,9 @@ class Config:
|
||||||
# self-hosted: Full access to local file system connectors (Obsidian, etc.)
|
# self-hosted: Full access to local file system connectors (Obsidian, etc.)
|
||||||
# cloud: Only cloud-based connectors available
|
# cloud: Only cloud-based connectors available
|
||||||
DEPLOYMENT_MODE = os.getenv("SURFSENSE_DEPLOYMENT_MODE", "self-hosted")
|
DEPLOYMENT_MODE = os.getenv("SURFSENSE_DEPLOYMENT_MODE", "self-hosted")
|
||||||
|
ENABLE_DESKTOP_LOCAL_FILESYSTEM = (
|
||||||
|
os.getenv("ENABLE_DESKTOP_LOCAL_FILESYSTEM", "FALSE").upper() == "TRUE"
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_self_hosted(cls) -> bool:
|
def is_self_hosted(cls) -> bool:
|
||||||
|
|
|
||||||
98
surfsense_backend/app/connectors/exceptions.py
Normal file
98
surfsense_backend/app/connectors/exceptions.py
Normal file
|
|
@ -0,0 +1,98 @@
|
||||||
|
"""Standard exception hierarchy for all connectors.
|
||||||
|
|
||||||
|
ConnectorError
|
||||||
|
├── ConnectorAuthError (401/403 — non-retryable)
|
||||||
|
├── ConnectorRateLimitError (429 — retryable, carries ``retry_after``)
|
||||||
|
├── ConnectorTimeoutError (timeout/504 — retryable)
|
||||||
|
└── ConnectorAPIError (5xx or unexpected — retryable when >= 500)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectorError(Exception):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
*,
|
||||||
|
service: str = "",
|
||||||
|
status_code: int | None = None,
|
||||||
|
response_body: Any = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
self.service = service
|
||||||
|
self.status_code = status_code
|
||||||
|
self.response_body = response_body
|
||||||
|
|
||||||
|
@property
|
||||||
|
def retryable(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectorAuthError(ConnectorError):
|
||||||
|
"""Token expired, revoked, insufficient scopes, or needs re-auth (401/403)."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def retryable(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectorRateLimitError(ConnectorError):
|
||||||
|
"""429 Too Many Requests."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Rate limited",
|
||||||
|
*,
|
||||||
|
service: str = "",
|
||||||
|
retry_after: float | None = None,
|
||||||
|
status_code: int = 429,
|
||||||
|
response_body: Any = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
message,
|
||||||
|
service=service,
|
||||||
|
status_code=status_code,
|
||||||
|
response_body=response_body,
|
||||||
|
)
|
||||||
|
self.retry_after = retry_after
|
||||||
|
|
||||||
|
@property
|
||||||
|
def retryable(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectorTimeoutError(ConnectorError):
|
||||||
|
"""Request timeout or gateway timeout (504)."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Request timed out",
|
||||||
|
*,
|
||||||
|
service: str = "",
|
||||||
|
status_code: int | None = None,
|
||||||
|
response_body: Any = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
message,
|
||||||
|
service=service,
|
||||||
|
status_code=status_code,
|
||||||
|
response_body=response_body,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def retryable(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectorAPIError(ConnectorError):
|
||||||
|
"""Generic API error (5xx or unexpected status codes)."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def retryable(self) -> bool:
|
||||||
|
if self.status_code is not None:
|
||||||
|
return self.status_code >= 500
|
||||||
|
return False
|
||||||
|
|
@ -30,6 +30,7 @@ from .jira_add_connector_route import router as jira_add_connector_router
|
||||||
from .linear_add_connector_route import router as linear_add_connector_router
|
from .linear_add_connector_route import router as linear_add_connector_router
|
||||||
from .logs_routes import router as logs_router
|
from .logs_routes import router as logs_router
|
||||||
from .luma_add_connector_route import router as luma_add_connector_router
|
from .luma_add_connector_route import router as luma_add_connector_router
|
||||||
|
from .mcp_oauth_route import router as mcp_oauth_router
|
||||||
from .memory_routes import router as memory_router
|
from .memory_routes import router as memory_router
|
||||||
from .model_list_routes import router as model_list_router
|
from .model_list_routes import router as model_list_router
|
||||||
from .new_chat_routes import router as new_chat_router
|
from .new_chat_routes import router as new_chat_router
|
||||||
|
|
@ -97,6 +98,7 @@ router.include_router(logs_router)
|
||||||
router.include_router(circleback_webhook_router) # Circleback meeting webhooks
|
router.include_router(circleback_webhook_router) # Circleback meeting webhooks
|
||||||
router.include_router(surfsense_docs_router) # Surfsense documentation for citations
|
router.include_router(surfsense_docs_router) # Surfsense documentation for citations
|
||||||
router.include_router(notifications_router) # Notifications with Zero sync
|
router.include_router(notifications_router) # Notifications with Zero sync
|
||||||
|
router.include_router(mcp_oauth_router) # MCP OAuth 2.1 for Linear, Jira, ClickUp, Slack, Airtable
|
||||||
router.include_router(composio_router) # Composio OAuth and toolkit management
|
router.include_router(composio_router) # Composio OAuth and toolkit management
|
||||||
router.include_router(public_chat_router) # Public chat sharing and cloning
|
router.include_router(public_chat_router) # Public chat sharing and cloning
|
||||||
router.include_router(incentive_tasks_router) # Incentive tasks for earning free pages
|
router.include_router(incentive_tasks_router) # Incentive tasks for earning free pages
|
||||||
|
|
|
||||||
|
|
@ -311,7 +311,7 @@ async def airtable_callback(
|
||||||
new_connector = SearchSourceConnector(
|
new_connector = SearchSourceConnector(
|
||||||
name=connector_name,
|
name=connector_name,
|
||||||
connector_type=SearchSourceConnectorType.AIRTABLE_CONNECTOR,
|
connector_type=SearchSourceConnectorType.AIRTABLE_CONNECTOR,
|
||||||
is_indexable=True,
|
is_indexable=False,
|
||||||
config=credentials_dict,
|
config=credentials_dict,
|
||||||
search_space_id=space_id,
|
search_space_id=space_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
|
||||||
|
|
@ -301,7 +301,7 @@ async def clickup_callback(
|
||||||
# Update existing connector
|
# Update existing connector
|
||||||
existing_connector.config = connector_config
|
existing_connector.config = connector_config
|
||||||
existing_connector.name = "ClickUp Connector"
|
existing_connector.name = "ClickUp Connector"
|
||||||
existing_connector.is_indexable = True
|
existing_connector.is_indexable = False
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Updated existing ClickUp connector for user {user_id} in space {space_id}"
|
f"Updated existing ClickUp connector for user {user_id} in space {space_id}"
|
||||||
)
|
)
|
||||||
|
|
@ -310,7 +310,7 @@ async def clickup_callback(
|
||||||
new_connector = SearchSourceConnector(
|
new_connector = SearchSourceConnector(
|
||||||
name="ClickUp Connector",
|
name="ClickUp Connector",
|
||||||
connector_type=SearchSourceConnectorType.CLICKUP_CONNECTOR,
|
connector_type=SearchSourceConnectorType.CLICKUP_CONNECTOR,
|
||||||
is_indexable=True,
|
is_indexable=False,
|
||||||
config=connector_config,
|
config=connector_config,
|
||||||
search_space_id=space_id,
|
search_space_id=space_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
|
||||||
|
|
@ -326,7 +326,7 @@ async def discord_callback(
|
||||||
new_connector = SearchSourceConnector(
|
new_connector = SearchSourceConnector(
|
||||||
name=connector_name,
|
name=connector_name,
|
||||||
connector_type=SearchSourceConnectorType.DISCORD_CONNECTOR,
|
connector_type=SearchSourceConnectorType.DISCORD_CONNECTOR,
|
||||||
is_indexable=True,
|
is_indexable=False,
|
||||||
config=connector_config,
|
config=connector_config,
|
||||||
search_space_id=space_id,
|
search_space_id=space_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
|
||||||
|
|
@ -340,7 +340,7 @@ async def calendar_callback(
|
||||||
config=creds_dict,
|
config=creds_dict,
|
||||||
search_space_id=space_id,
|
search_space_id=space_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
is_indexable=True,
|
is_indexable=False,
|
||||||
)
|
)
|
||||||
session.add(db_connector)
|
session.add(db_connector)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
|
||||||
|
|
@ -371,7 +371,7 @@ async def gmail_callback(
|
||||||
config=creds_dict,
|
config=creds_dict,
|
||||||
search_space_id=space_id,
|
search_space_id=space_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
is_indexable=True,
|
is_indexable=False,
|
||||||
)
|
)
|
||||||
session.add(db_connector)
|
session.add(db_connector)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
|
||||||
|
|
@ -386,7 +386,7 @@ async def jira_callback(
|
||||||
new_connector = SearchSourceConnector(
|
new_connector = SearchSourceConnector(
|
||||||
name=connector_name,
|
name=connector_name,
|
||||||
connector_type=SearchSourceConnectorType.JIRA_CONNECTOR,
|
connector_type=SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||||
is_indexable=True,
|
is_indexable=False,
|
||||||
config=connector_config,
|
config=connector_config,
|
||||||
search_space_id=space_id,
|
search_space_id=space_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
|
||||||
|
|
@ -399,7 +399,7 @@ async def linear_callback(
|
||||||
new_connector = SearchSourceConnector(
|
new_connector = SearchSourceConnector(
|
||||||
name=connector_name,
|
name=connector_name,
|
||||||
connector_type=SearchSourceConnectorType.LINEAR_CONNECTOR,
|
connector_type=SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||||
is_indexable=True,
|
is_indexable=False,
|
||||||
config=connector_config,
|
config=connector_config,
|
||||||
search_space_id=space_id,
|
search_space_id=space_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,7 @@ async def add_luma_connector(
|
||||||
if existing_connector:
|
if existing_connector:
|
||||||
# Update existing connector with new API key
|
# Update existing connector with new API key
|
||||||
existing_connector.config = {"api_key": request.api_key}
|
existing_connector.config = {"api_key": request.api_key}
|
||||||
existing_connector.is_indexable = True
|
existing_connector.is_indexable = False
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(existing_connector)
|
await session.refresh(existing_connector)
|
||||||
|
|
||||||
|
|
@ -82,7 +82,7 @@ async def add_luma_connector(
|
||||||
config={"api_key": request.api_key},
|
config={"api_key": request.api_key},
|
||||||
search_space_id=request.space_id,
|
search_space_id=request.space_id,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
is_indexable=True,
|
is_indexable=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
session.add(db_connector)
|
session.add(db_connector)
|
||||||
|
|
|
||||||
601
surfsense_backend/app/routes/mcp_oauth_route.py
Normal file
601
surfsense_backend/app/routes/mcp_oauth_route.py
Normal file
|
|
@ -0,0 +1,601 @@
|
||||||
|
"""Generic MCP OAuth 2.1 route for services with official MCP servers.
|
||||||
|
|
||||||
|
Handles the full flow: discovery → DCR → PKCE authorization → token exchange
|
||||||
|
→ MCP_CONNECTOR creation. Currently supports Linear, Jira, ClickUp, Slack,
|
||||||
|
and Airtable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.db import (
|
||||||
|
SearchSourceConnector,
|
||||||
|
SearchSourceConnectorType,
|
||||||
|
User,
|
||||||
|
get_async_session,
|
||||||
|
)
|
||||||
|
from app.users import current_active_user
|
||||||
|
from app.utils.connector_naming import generate_unique_connector_name
|
||||||
|
from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_pkce_pair
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_account_metadata(
|
||||||
|
service_key: str, access_token: str, token_json: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Fetch display-friendly account metadata after a successful token exchange.
|
||||||
|
|
||||||
|
DCR services (Linear, Jira, ClickUp) issue MCP-scoped tokens that cannot
|
||||||
|
call their standard REST/GraphQL APIs — metadata discovery for those
|
||||||
|
happens at runtime through MCP tools instead.
|
||||||
|
|
||||||
|
Pre-configured services (Slack, Airtable) use standard OAuth tokens that
|
||||||
|
*can* call their APIs, so we extract metadata here.
|
||||||
|
|
||||||
|
Failures are logged but never block connector creation.
|
||||||
|
"""
|
||||||
|
from app.services.mcp_oauth.registry import MCP_SERVICES
|
||||||
|
|
||||||
|
svc = MCP_SERVICES.get(service_key)
|
||||||
|
if not svc or svc.supports_dcr:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
meta: dict[str, Any] = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if service_key == "slack":
|
||||||
|
team_info = token_json.get("team", {})
|
||||||
|
meta["team_id"] = team_info.get("id", "")
|
||||||
|
# TODO: oauth.v2.user.access only returns team.id, not
|
||||||
|
# team.name. To populate team_name, add "team:read" scope
|
||||||
|
# and call GET /api/team.info here.
|
||||||
|
meta["team_name"] = team_info.get("name", "")
|
||||||
|
if meta["team_name"]:
|
||||||
|
meta["display_name"] = meta["team_name"]
|
||||||
|
elif meta["team_id"]:
|
||||||
|
meta["display_name"] = f"Slack ({meta['team_id']})"
|
||||||
|
|
||||||
|
elif service_key == "airtable":
|
||||||
|
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||||
|
resp = await client.get(
|
||||||
|
"https://api.airtable.com/v0/meta/whoami",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
if resp.status_code == 200:
|
||||||
|
whoami = resp.json()
|
||||||
|
meta["user_id"] = whoami.get("id", "")
|
||||||
|
meta["user_email"] = whoami.get("email", "")
|
||||||
|
meta["display_name"] = whoami.get("email", "Airtable")
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Airtable whoami API returned %d (non-blocking)", resp.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to fetch account metadata for %s (non-blocking)",
|
||||||
|
service_key,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return meta
|
||||||
|
|
||||||
|
_state_manager: OAuthStateManager | None = None
|
||||||
|
_token_encryption: TokenEncryption | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_state_manager() -> OAuthStateManager:
|
||||||
|
global _state_manager
|
||||||
|
if _state_manager is None:
|
||||||
|
if not config.SECRET_KEY:
|
||||||
|
raise HTTPException(status_code=500, detail="SECRET_KEY not configured.")
|
||||||
|
_state_manager = OAuthStateManager(config.SECRET_KEY)
|
||||||
|
return _state_manager
|
||||||
|
|
||||||
|
|
||||||
|
def _get_token_encryption() -> TokenEncryption:
|
||||||
|
global _token_encryption
|
||||||
|
if _token_encryption is None:
|
||||||
|
if not config.SECRET_KEY:
|
||||||
|
raise HTTPException(status_code=500, detail="SECRET_KEY not configured.")
|
||||||
|
_token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||||
|
return _token_encryption
|
||||||
|
|
||||||
|
|
||||||
|
def _build_redirect_uri(service: str) -> str:
|
||||||
|
base = config.BACKEND_URL or "http://localhost:8000"
|
||||||
|
return f"{base.rstrip('/')}/api/v1/auth/mcp/{service}/connector/callback"
|
||||||
|
|
||||||
|
|
||||||
|
def _frontend_redirect(
|
||||||
|
space_id: int | None,
|
||||||
|
*,
|
||||||
|
success: bool = False,
|
||||||
|
connector_id: int | None = None,
|
||||||
|
error: str | None = None,
|
||||||
|
service: str = "mcp",
|
||||||
|
) -> RedirectResponse:
|
||||||
|
if success and space_id:
|
||||||
|
qs = f"success=true&connector={service}-mcp-connector"
|
||||||
|
if connector_id:
|
||||||
|
qs += f"&connectorId={connector_id}"
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?{qs}"
|
||||||
|
)
|
||||||
|
if error and space_id:
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error={error}"
|
||||||
|
)
|
||||||
|
return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# /add — start MCP OAuth flow
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/auth/mcp/{service}/connector/add")
|
||||||
|
async def connect_mcp_service(
|
||||||
|
service: str,
|
||||||
|
space_id: int,
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
from app.services.mcp_oauth.registry import get_service
|
||||||
|
|
||||||
|
svc = get_service(service)
|
||||||
|
if not svc:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from app.services.mcp_oauth.discovery import (
|
||||||
|
discover_oauth_metadata,
|
||||||
|
register_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = await discover_oauth_metadata(
|
||||||
|
svc.mcp_url, origin_override=svc.oauth_discovery_origin,
|
||||||
|
)
|
||||||
|
auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint")
|
||||||
|
token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint")
|
||||||
|
registration_endpoint = metadata.get("registration_endpoint")
|
||||||
|
|
||||||
|
if not auth_endpoint or not token_endpoint:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail=f"{svc.name} MCP server returned incomplete OAuth metadata.",
|
||||||
|
)
|
||||||
|
|
||||||
|
redirect_uri = _build_redirect_uri(service)
|
||||||
|
|
||||||
|
if svc.supports_dcr and registration_endpoint:
|
||||||
|
dcr = await register_client(registration_endpoint, redirect_uri)
|
||||||
|
client_id = dcr.get("client_id")
|
||||||
|
client_secret = dcr.get("client_secret", "")
|
||||||
|
if not client_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail=f"DCR for {svc.name} did not return a client_id.",
|
||||||
|
)
|
||||||
|
elif svc.client_id_env:
|
||||||
|
client_id = getattr(config, svc.client_id_env, None)
|
||||||
|
client_secret = getattr(config, svc.client_secret_env or "", None) or ""
|
||||||
|
if not client_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"{svc.name} MCP OAuth not configured ({svc.client_id_env}).",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail=f"{svc.name} MCP server has no DCR and no fallback credentials.",
|
||||||
|
)
|
||||||
|
|
||||||
|
verifier, challenge = generate_pkce_pair()
|
||||||
|
enc = _get_token_encryption()
|
||||||
|
|
||||||
|
state = _get_state_manager().generate_secure_state(
|
||||||
|
space_id,
|
||||||
|
user.id,
|
||||||
|
service=service,
|
||||||
|
code_verifier=verifier,
|
||||||
|
mcp_client_id=client_id,
|
||||||
|
mcp_client_secret=enc.encrypt_token(client_secret) if client_secret else "",
|
||||||
|
mcp_token_endpoint=token_endpoint,
|
||||||
|
mcp_url=svc.mcp_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
auth_params: dict[str, str] = {
|
||||||
|
"client_id": client_id,
|
||||||
|
"response_type": "code",
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"code_challenge": challenge,
|
||||||
|
"code_challenge_method": "S256",
|
||||||
|
"state": state,
|
||||||
|
}
|
||||||
|
if svc.scopes:
|
||||||
|
auth_params[svc.scope_param] = " ".join(svc.scopes)
|
||||||
|
|
||||||
|
auth_url = f"{auth_endpoint}?{urlencode(auth_params)}"
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Generated %s MCP OAuth URL for user %s, space %s",
|
||||||
|
svc.name, user.id, space_id,
|
||||||
|
)
|
||||||
|
return {"auth_url": auth_url}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to initiate %s MCP OAuth: %s", service, e, exc_info=True)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to initiate {service} MCP OAuth.",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# /callback — handle OAuth redirect
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/auth/mcp/{service}/connector/callback")
async def mcp_oauth_callback(
    service: str,
    code: str | None = None,
    error: str | None = None,
    state: str | None = None,
    session: AsyncSession = Depends(get_async_session),
):
    """Handle the OAuth redirect from an MCP provider.

    Validates the HMAC-signed ``state``, exchanges the authorization ``code``
    for tokens, and either updates an existing connector (re-auth path, when
    the state carries a ``connector_id``) or creates a new one. The browser
    is then redirected back to the frontend.

    Note: the user is identified from the signed state, not from a session
    dependency — the provider redirects the bare browser here.
    """
    # Provider reported an explicit error (e.g. the user denied consent):
    # bounce back to the frontend with a service-specific error slug.
    if error:
        logger.warning("%s MCP OAuth error: %s", service, error)
        space_id = None
        if state:
            try:
                data = _get_state_manager().validate_state(state)
                space_id = data.get("space_id")
            except Exception:
                # State may be invalid/expired; redirect without a space id.
                pass
        return _frontend_redirect(
            space_id, error=f"{service}_mcp_oauth_denied", service=service,
        )

    if not code:
        raise HTTPException(status_code=400, detail="Missing authorization code")
    if not state:
        raise HTTPException(status_code=400, detail="Missing state parameter")

    # The state was signed by us at initiation time and carries the whole
    # flow context (user, space, PKCE verifier, client credentials, URLs).
    data = _get_state_manager().validate_state(state)
    user_id = UUID(data["user_id"])
    space_id = data["space_id"]
    svc_key = data.get("service", service)

    # Defense in depth: the service baked into the state must match the path.
    if svc_key != service:
        raise HTTPException(status_code=400, detail="State/path service mismatch")

    from app.services.mcp_oauth.registry import get_service

    svc = get_service(svc_key)
    if not svc:
        raise HTTPException(status_code=404, detail=f"Unknown MCP service: {svc_key}")

    try:
        from app.services.mcp_oauth.discovery import exchange_code_for_tokens

        enc = _get_token_encryption()
        client_id = data["mcp_client_id"]
        # The client secret was stored encrypted inside the state payload.
        client_secret = (
            enc.decrypt_token(data["mcp_client_secret"])
            if data.get("mcp_client_secret")
            else ""
        )
        token_endpoint = data["mcp_token_endpoint"]
        code_verifier = data["code_verifier"]
        mcp_url = data["mcp_url"]
        redirect_uri = _build_redirect_uri(service)

        token_json = await exchange_code_for_tokens(
            token_endpoint=token_endpoint,
            code=code,
            redirect_uri=redirect_uri,
            client_id=client_id,
            client_secret=client_secret,
            code_verifier=code_verifier,
        )

        access_token = token_json.get("access_token")
        refresh_token = token_json.get("refresh_token")
        expires_in = token_json.get("expires_in")
        scope = token_json.get("scope")

        # Some providers nest the user token under "authed_user"; fall back
        # to those fields when the top-level access token is absent.
        if not access_token and "authed_user" in token_json:
            authed = token_json["authed_user"]
            access_token = authed.get("access_token")
            refresh_token = refresh_token or authed.get("refresh_token")
            scope = scope or authed.get("scope")
            expires_in = expires_in or authed.get("expires_in")

        if not access_token:
            raise HTTPException(
                status_code=400,
                detail=f"No access token received from {svc.name}.",
            )

        # Convert the relative lifetime into an absolute UTC expiry.
        expires_at = None
        if expires_in:
            expires_at = datetime.now(UTC) + timedelta(
                seconds=int(expires_in)
            )

        # Secrets are stored encrypted; "_token_encrypted" tells readers of
        # the config how to interpret the token fields.
        connector_config = {
            "server_config": {
                "transport": "streamable-http",
                "url": mcp_url,
            },
            "mcp_service": svc_key,
            "mcp_oauth": {
                "client_id": client_id,
                "client_secret": enc.encrypt_token(client_secret) if client_secret else "",
                "token_endpoint": token_endpoint,
                "access_token": enc.encrypt_token(access_token),
                "refresh_token": enc.encrypt_token(refresh_token) if refresh_token else None,
                "expires_at": expires_at.isoformat() if expires_at else None,
                "scope": scope,
            },
            "_token_encrypted": True,
        }

        account_meta = await _fetch_account_metadata(svc_key, access_token, token_json)
        if account_meta:
            # Whitelist which metadata keys may be copied into the config.
            _SAFE_META_KEYS = {"display_name", "team_id", "team_name", "user_id", "user_email",
                               "workspace_id", "workspace_name", "organization_name",
                               "organization_url_key", "cloud_id", "site_name", "base_url"}
            for k, v in account_meta.items():
                if k in _SAFE_META_KEYS:
                    connector_config[k] = v
            logger.info(
                "Stored account metadata for %s: display_name=%s",
                svc_key, account_meta.get("display_name", ""),
            )

        # ---- Re-auth path ----
        # A connector_id in the state means we replace an existing connector's
        # config rather than create a new row.
        db_connector_type = SearchSourceConnectorType(svc.connector_type)
        reauth_connector_id = data.get("connector_id")
        if reauth_connector_id:
            result = await session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == reauth_connector_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.search_space_id == space_id,
                    SearchSourceConnector.connector_type == db_connector_type,
                )
            )
            db_connector = result.scalars().first()
            if not db_connector:
                raise HTTPException(
                    status_code=404,
                    detail="Connector not found during re-auth",
                )

            # Replace the config wholesale and flag it so SQLAlchemy persists
            # the JSON mutation.
            db_connector.config = connector_config
            flag_modified(db_connector, "config")
            await session.commit()
            await session.refresh(db_connector)

            _invalidate_cache(space_id)

            logger.info(
                "Re-authenticated %s MCP connector %s for user %s",
                svc.name, db_connector.id, user_id,
            )
            # Only same-origin relative paths are honoured; "//" would be a
            # scheme-relative (open) redirect.
            reauth_return_url = data.get("return_url")
            if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"):
                return RedirectResponse(
                    url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
                )
            return _frontend_redirect(
                space_id, success=True, connector_id=db_connector.id, service=service,
            )

        # ---- New connector path ----
        # NOTE(review): assumes _fetch_account_metadata returns a dict
        # (possibly empty), never None — confirm, else this .get would raise.
        naming_identifier = account_meta.get("display_name")
        connector_name = await generate_unique_connector_name(
            session,
            db_connector_type,
            space_id,
            user_id,
            naming_identifier,
        )

        new_connector = SearchSourceConnector(
            name=connector_name,
            connector_type=db_connector_type,
            is_indexable=False,
            config=connector_config,
            search_space_id=space_id,
            user_id=user_id,
        )
        session.add(new_connector)

        try:
            await session.commit()
        except IntegrityError as e:
            # Unique constraint hit — a connector for this service exists.
            await session.rollback()
            raise HTTPException(
                status_code=409, detail="A connector for this service already exists.",
            ) from e

        _invalidate_cache(space_id)

        logger.info(
            "Created %s MCP connector %s for user %s in space %s",
            svc.name, new_connector.id, user_id, space_id,
        )
        return _frontend_redirect(
            space_id, success=True, connector_id=new_connector.id, service=service,
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(
            "Failed to complete %s MCP OAuth: %s", service, e, exc_info=True,
        )
        raise HTTPException(
            status_code=500,
            detail=f"Failed to complete {service} MCP OAuth.",
        ) from e
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# /reauth — re-authenticate an existing MCP connector
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/auth/mcp/{service}/connector/reauth")
async def reauth_mcp_service(
    service: str,
    space_id: int,
    connector_id: int,
    return_url: str | None = None,
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
):
    """Start a fresh OAuth flow for an existing MCP connector.

    Verifies the caller owns the connector, rediscovers the provider's OAuth
    metadata, obtains client credentials (Dynamic Client Registration when
    supported, else env-configured fallbacks), and returns an authorization
    URL. The connector id (and optional ``return_url``) are embedded in the
    signed state so the callback takes the re-auth path.

    Returns ``{"auth_url": ...}``; raises 404 for unknown services or
    inaccessible connectors, 500/502 for configuration/provider problems.
    """
    from app.services.mcp_oauth.registry import get_service

    svc = get_service(service)
    if not svc:
        raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}")

    # Ownership check: connector must belong to this user, space, and type.
    db_connector_type = SearchSourceConnectorType(svc.connector_type)
    result = await session.execute(
        select(SearchSourceConnector).filter(
            SearchSourceConnector.id == connector_id,
            SearchSourceConnector.user_id == user.id,
            SearchSourceConnector.search_space_id == space_id,
            SearchSourceConnector.connector_type == db_connector_type,
        )
    )
    if not result.scalars().first():
        raise HTTPException(
            status_code=404, detail="Connector not found or access denied",
        )

    try:
        from app.services.mcp_oauth.discovery import (
            discover_oauth_metadata,
            register_client,
        )

        metadata = await discover_oauth_metadata(
            svc.mcp_url, origin_override=svc.oauth_discovery_origin,
        )
        # Explicit per-service overrides win over discovered metadata.
        auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint")
        token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint")
        registration_endpoint = metadata.get("registration_endpoint")

        if not auth_endpoint or not token_endpoint:
            raise HTTPException(
                status_code=502,
                detail=f"{svc.name} MCP server returned incomplete OAuth metadata.",
            )

        redirect_uri = _build_redirect_uri(service)

        # Prefer Dynamic Client Registration; fall back to env credentials.
        if svc.supports_dcr and registration_endpoint:
            dcr = await register_client(registration_endpoint, redirect_uri)
            client_id = dcr.get("client_id")
            client_secret = dcr.get("client_secret", "")
            if not client_id:
                raise HTTPException(
                    status_code=502,
                    detail=f"DCR for {svc.name} did not return a client_id.",
                )
        elif svc.client_id_env:
            client_id = getattr(config, svc.client_id_env, None)
            client_secret = getattr(config, svc.client_secret_env or "", None) or ""
            if not client_id:
                raise HTTPException(
                    status_code=500,
                    detail=f"{svc.name} MCP OAuth not configured ({svc.client_id_env}).",
                )
        else:
            raise HTTPException(
                status_code=502,
                detail=f"{svc.name} MCP server has no DCR and no fallback credentials.",
            )

        verifier, challenge = generate_pkce_pair()
        enc = _get_token_encryption()

        extra: dict = {
            "service": service,
            "code_verifier": verifier,
            # Encrypt the secret before it travels inside the state token.
            "mcp_client_secret": enc.encrypt_token(client_secret) if client_secret else "",
            "mcp_client_id": client_id,
            "mcp_token_endpoint": token_endpoint,
            "mcp_url": svc.mcp_url,
            "connector_id": connector_id,
        }
        # Only same-origin relative paths; reject scheme-relative "//host"
        # URLs up front for consistency with the callback's redirect check.
        if return_url and return_url.startswith("/") and not return_url.startswith("//"):
            extra["return_url"] = return_url

        state = _get_state_manager().generate_secure_state(
            space_id, user.id, **extra,
        )

        auth_params: dict[str, str] = {
            "client_id": client_id,
            "response_type": "code",
            "redirect_uri": redirect_uri,
            "code_challenge": challenge,
            "code_challenge_method": "S256",
            "state": state,
        }
        if svc.scopes:
            auth_params[svc.scope_param] = " ".join(svc.scopes)

        auth_url = f"{auth_endpoint}?{urlencode(auth_params)}"

        logger.info(
            "Initiating %s MCP re-auth for user %s, connector %s",
            svc.name, user.id, connector_id,
        )
        return {"auth_url": auth_url}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(
            "Failed to initiate %s MCP re-auth: %s", service, e, exc_info=True,
        )
        raise HTTPException(
            status_code=500,
            detail=f"Failed to initiate {service} MCP re-auth.",
        ) from e
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _invalidate_cache(space_id: int) -> None:
    """Best-effort drop of the cached MCP tool list for ``space_id``.

    Both the import and the call are guarded: if the cache module is
    unavailable or invalidation fails, we only log at debug level — connector
    changes must never fail because of the cache.
    """
    try:
        from app.agents.new_chat.tools.mcp_tool import (
            invalidate_mcp_tools_cache,
        )
    except Exception:
        logger.debug("MCP cache invalidation skipped", exc_info=True)
        return
    try:
        invalidate_mcp_tools_cache(space_id)
    except Exception:
        logger.debug("MCP cache invalidation skipped", exc_info=True)
|
||||||
|
|
@ -22,6 +22,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from app.agents.new_chat.filesystem_selection import (
|
||||||
|
ClientPlatform,
|
||||||
|
LocalFilesystemMount,
|
||||||
|
FilesystemMode,
|
||||||
|
FilesystemSelection,
|
||||||
|
)
|
||||||
|
from app.config import config
|
||||||
from app.db import (
|
from app.db import (
|
||||||
ChatComment,
|
ChatComment,
|
||||||
ChatVisibility,
|
ChatVisibility,
|
||||||
|
|
@ -36,6 +43,7 @@ from app.db import (
|
||||||
)
|
)
|
||||||
from app.schemas.new_chat import (
|
from app.schemas.new_chat import (
|
||||||
AgentToolInfo,
|
AgentToolInfo,
|
||||||
|
LocalFilesystemMountPayload,
|
||||||
NewChatMessageRead,
|
NewChatMessageRead,
|
||||||
NewChatRequest,
|
NewChatRequest,
|
||||||
NewChatThreadCreate,
|
NewChatThreadCreate,
|
||||||
|
|
@ -63,6 +71,67 @@ _background_tasks: set[asyncio.Task] = set()
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_filesystem_selection(
    *,
    mode: str,
    client_platform: str,
    local_mounts: list[LocalFilesystemMountPayload] | None,
) -> FilesystemSelection:
    """Turn raw request fields into a validated ``FilesystemSelection``.

    Unknown mode/platform strings are rejected with HTTP 400. The desktop
    local-folder mode additionally requires the deployment flag to be on, a
    desktop client, and at least one usable mount; mount entries are trimmed
    and de-duplicated by mount id (first occurrence wins). Every other mode
    resolves to the cloud filesystem.
    """
    try:
        fs_mode = FilesystemMode(mode)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="Invalid filesystem_mode") from exc
    try:
        platform = ClientPlatform(client_platform)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="Invalid client_platform") from exc

    if fs_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER:
        # Anything other than the desktop special case resolves to cloud.
        return FilesystemSelection(
            mode=FilesystemMode.CLOUD,
            client_platform=platform,
        )

    if not config.ENABLE_DESKTOP_LOCAL_FILESYSTEM:
        raise HTTPException(
            status_code=400,
            detail="Desktop local filesystem mode is disabled on this deployment.",
        )
    if platform != ClientPlatform.DESKTOP:
        raise HTTPException(
            status_code=400,
            detail="desktop_local_folder mode is only available on desktop runtime.",
        )

    cleaned: list[tuple[str, str]] = []
    used_ids: set[str] = set()
    for payload in local_mounts or []:
        ident = payload.mount_id.strip()
        root = payload.root_path.strip()
        if not ident or not root or ident in used_ids:
            continue
        used_ids.add(ident)
        cleaned.append((ident, root))

    if not cleaned:
        raise HTTPException(
            status_code=400,
            detail=(
                "local_filesystem_mounts must include at least one mount for "
                "desktop_local_folder mode."
            ),
        )

    return FilesystemSelection(
        mode=fs_mode,
        client_platform=platform,
        local_mounts=tuple(
            LocalFilesystemMount(mount_id=ident, root_path=root)
            for ident, root in cleaned
        ),
    )
|
||||||
|
|
||||||
|
|
||||||
def _try_delete_sandbox(thread_id: int) -> None:
|
def _try_delete_sandbox(thread_id: int) -> None:
|
||||||
"""Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
|
"""Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
|
||||||
from app.agents.new_chat.sandbox import (
|
from app.agents.new_chat.sandbox import (
|
||||||
|
|
@ -1098,6 +1167,7 @@ async def list_agent_tools(
|
||||||
@router.post("/new_chat")
|
@router.post("/new_chat")
|
||||||
async def handle_new_chat(
|
async def handle_new_chat(
|
||||||
request: NewChatRequest,
|
request: NewChatRequest,
|
||||||
|
http_request: Request,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user),
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
|
|
@ -1133,6 +1203,11 @@ async def handle_new_chat(
|
||||||
|
|
||||||
# Check thread-level access based on visibility
|
# Check thread-level access based on visibility
|
||||||
await check_thread_access(session, thread, user)
|
await check_thread_access(session, thread, user)
|
||||||
|
filesystem_selection = _resolve_filesystem_selection(
|
||||||
|
mode=request.filesystem_mode,
|
||||||
|
client_platform=request.client_platform,
|
||||||
|
local_mounts=request.local_filesystem_mounts,
|
||||||
|
)
|
||||||
|
|
||||||
# Get search space to check LLM config preferences
|
# Get search space to check LLM config preferences
|
||||||
search_space_result = await session.execute(
|
search_space_result = await session.execute(
|
||||||
|
|
@ -1175,6 +1250,8 @@ async def handle_new_chat(
|
||||||
thread_visibility=thread.visibility,
|
thread_visibility=thread.visibility,
|
||||||
current_user_display_name=user.display_name or "A team member",
|
current_user_display_name=user.display_name or "A team member",
|
||||||
disabled_tools=request.disabled_tools,
|
disabled_tools=request.disabled_tools,
|
||||||
|
filesystem_selection=filesystem_selection,
|
||||||
|
request_id=getattr(http_request.state, "request_id", "unknown"),
|
||||||
),
|
),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={
|
headers={
|
||||||
|
|
@ -1202,6 +1279,7 @@ async def handle_new_chat(
|
||||||
async def regenerate_response(
|
async def regenerate_response(
|
||||||
thread_id: int,
|
thread_id: int,
|
||||||
request: RegenerateRequest,
|
request: RegenerateRequest,
|
||||||
|
http_request: Request,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user),
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
|
|
@ -1247,6 +1325,11 @@ async def regenerate_response(
|
||||||
|
|
||||||
# Check thread-level access based on visibility
|
# Check thread-level access based on visibility
|
||||||
await check_thread_access(session, thread, user)
|
await check_thread_access(session, thread, user)
|
||||||
|
filesystem_selection = _resolve_filesystem_selection(
|
||||||
|
mode=request.filesystem_mode,
|
||||||
|
client_platform=request.client_platform,
|
||||||
|
local_mounts=request.local_filesystem_mounts,
|
||||||
|
)
|
||||||
|
|
||||||
# Get the checkpointer and state history
|
# Get the checkpointer and state history
|
||||||
checkpointer = await get_checkpointer()
|
checkpointer = await get_checkpointer()
|
||||||
|
|
@ -1412,6 +1495,8 @@ async def regenerate_response(
|
||||||
thread_visibility=thread.visibility,
|
thread_visibility=thread.visibility,
|
||||||
current_user_display_name=user.display_name or "A team member",
|
current_user_display_name=user.display_name or "A team member",
|
||||||
disabled_tools=request.disabled_tools,
|
disabled_tools=request.disabled_tools,
|
||||||
|
filesystem_selection=filesystem_selection,
|
||||||
|
request_id=getattr(http_request.state, "request_id", "unknown"),
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
streaming_completed = True
|
streaming_completed = True
|
||||||
|
|
@ -1477,6 +1562,7 @@ async def regenerate_response(
|
||||||
async def resume_chat(
|
async def resume_chat(
|
||||||
thread_id: int,
|
thread_id: int,
|
||||||
request: ResumeRequest,
|
request: ResumeRequest,
|
||||||
|
http_request: Request,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user),
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
|
|
@ -1498,6 +1584,11 @@ async def resume_chat(
|
||||||
)
|
)
|
||||||
|
|
||||||
await check_thread_access(session, thread, user)
|
await check_thread_access(session, thread, user)
|
||||||
|
filesystem_selection = _resolve_filesystem_selection(
|
||||||
|
mode=request.filesystem_mode,
|
||||||
|
client_platform=request.client_platform,
|
||||||
|
local_mounts=request.local_filesystem_mounts,
|
||||||
|
)
|
||||||
|
|
||||||
search_space_result = await session.execute(
|
search_space_result = await session.execute(
|
||||||
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
|
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
|
||||||
|
|
@ -1526,6 +1617,8 @@ async def resume_chat(
|
||||||
user_id=str(user.id),
|
user_id=str(user.id),
|
||||||
llm_config_id=llm_config_id,
|
llm_config_id=llm_config_id,
|
||||||
thread_visibility=thread.visibility,
|
thread_visibility=thread.visibility,
|
||||||
|
filesystem_selection=filesystem_selection,
|
||||||
|
request_id=getattr(http_request.state, "request_id", "unknown"),
|
||||||
),
|
),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={
|
headers={
|
||||||
|
|
|
||||||
620
surfsense_backend/app/routes/oauth_connector_base.py
Normal file
620
surfsense_backend/app/routes/oauth_connector_base.py
Normal file
|
|
@ -0,0 +1,620 @@
|
||||||
|
"""Reusable base for OAuth 2.0 connector routes.
|
||||||
|
|
||||||
|
Subclasses override ``fetch_account_info``, ``build_connector_config``,
|
||||||
|
and ``get_connector_display_name`` to customise provider-specific behaviour.
|
||||||
|
Call ``build_router()`` to get a FastAPI ``APIRouter`` with ``/connector/add``,
|
||||||
|
``/connector/callback``, and ``/connector/reauth`` endpoints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.db import (
|
||||||
|
SearchSourceConnector,
|
||||||
|
SearchSourceConnectorType,
|
||||||
|
User,
|
||||||
|
get_async_session,
|
||||||
|
)
|
||||||
|
from app.users import current_active_user
|
||||||
|
from app.utils.connector_naming import (
|
||||||
|
check_duplicate_connector,
|
||||||
|
generate_unique_connector_name,
|
||||||
|
)
|
||||||
|
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthConnectorRoute:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
provider_name: str,
|
||||||
|
connector_type: SearchSourceConnectorType,
|
||||||
|
authorize_url: str,
|
||||||
|
token_url: str,
|
||||||
|
client_id_env: str,
|
||||||
|
client_secret_env: str,
|
||||||
|
redirect_uri_env: str,
|
||||||
|
scopes: list[str],
|
||||||
|
auth_prefix: str,
|
||||||
|
use_pkce: bool = False,
|
||||||
|
token_auth_method: str = "body",
|
||||||
|
is_indexable: bool = True,
|
||||||
|
extra_auth_params: dict[str, str] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.provider_name = provider_name
|
||||||
|
self.connector_type = connector_type
|
||||||
|
self.authorize_url = authorize_url
|
||||||
|
self.token_url = token_url
|
||||||
|
self.client_id_env = client_id_env
|
||||||
|
self.client_secret_env = client_secret_env
|
||||||
|
self.redirect_uri_env = redirect_uri_env
|
||||||
|
self.scopes = scopes
|
||||||
|
self.auth_prefix = auth_prefix.rstrip("/")
|
||||||
|
self.use_pkce = use_pkce
|
||||||
|
self.token_auth_method = token_auth_method
|
||||||
|
self.is_indexable = is_indexable
|
||||||
|
self.extra_auth_params = extra_auth_params or {}
|
||||||
|
|
||||||
|
self._state_manager: OAuthStateManager | None = None
|
||||||
|
self._token_encryption: TokenEncryption | None = None
|
||||||
|
|
||||||
|
def _get_client_id(self) -> str:
|
||||||
|
value = getattr(config, self.client_id_env, None)
|
||||||
|
if not value:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"{self.provider_name.title()} OAuth not configured "
|
||||||
|
f"({self.client_id_env} missing).",
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def _get_client_secret(self) -> str:
|
||||||
|
value = getattr(config, self.client_secret_env, None)
|
||||||
|
if not value:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"{self.provider_name.title()} OAuth not configured "
|
||||||
|
f"({self.client_secret_env} missing).",
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def _get_redirect_uri(self) -> str:
|
||||||
|
value = getattr(config, self.redirect_uri_env, None)
|
||||||
|
if not value:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"{self.redirect_uri_env} not configured.",
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def _get_state_manager(self) -> OAuthStateManager:
|
||||||
|
if self._state_manager is None:
|
||||||
|
if not config.SECRET_KEY:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="SECRET_KEY not configured for OAuth security.",
|
||||||
|
)
|
||||||
|
self._state_manager = OAuthStateManager(config.SECRET_KEY)
|
||||||
|
return self._state_manager
|
||||||
|
|
||||||
|
def _get_token_encryption(self) -> TokenEncryption:
|
||||||
|
if self._token_encryption is None:
|
||||||
|
if not config.SECRET_KEY:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="SECRET_KEY not configured for token encryption.",
|
||||||
|
)
|
||||||
|
self._token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||||
|
return self._token_encryption
|
||||||
|
|
||||||
|
def _frontend_redirect(
|
||||||
|
self,
|
||||||
|
space_id: int | None,
|
||||||
|
*,
|
||||||
|
success: bool = False,
|
||||||
|
connector_id: int | None = None,
|
||||||
|
error: str | None = None,
|
||||||
|
) -> RedirectResponse:
|
||||||
|
if success and space_id:
|
||||||
|
connector_slug = f"{self.provider_name}-connector"
|
||||||
|
qs = f"success=true&connector={connector_slug}"
|
||||||
|
if connector_id:
|
||||||
|
qs += f"&connectorId={connector_id}"
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?{qs}"
|
||||||
|
)
|
||||||
|
if error and space_id:
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error={error}"
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{config.NEXT_FRONTEND_URL}/dashboard?error={error}"
|
||||||
|
)
|
||||||
|
return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard")
|
||||||
|
|
||||||
|
async def fetch_account_info(self, access_token: str) -> dict[str, Any]:
|
||||||
|
"""Override to fetch account/workspace info after token exchange.
|
||||||
|
|
||||||
|
Return dict is merged into connector config; key ``"name"`` is used
|
||||||
|
for the display name and dedup.
|
||||||
|
"""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def build_connector_config(
|
||||||
|
self,
|
||||||
|
token_json: dict[str, Any],
|
||||||
|
account_info: dict[str, Any],
|
||||||
|
encryption: TokenEncryption,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Override for custom config shapes. Default: standard encrypted OAuth fields."""
|
||||||
|
access_token = token_json.get("access_token", "")
|
||||||
|
refresh_token = token_json.get("refresh_token")
|
||||||
|
|
||||||
|
expires_at = None
|
||||||
|
if token_json.get("expires_in"):
|
||||||
|
expires_at = datetime.now(UTC) + timedelta(
|
||||||
|
seconds=int(token_json["expires_in"])
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg: dict[str, Any] = {
|
||||||
|
"access_token": encryption.encrypt_token(access_token),
|
||||||
|
"refresh_token": (
|
||||||
|
encryption.encrypt_token(refresh_token) if refresh_token else None
|
||||||
|
),
|
||||||
|
"token_type": token_json.get("token_type", "Bearer"),
|
||||||
|
"expires_in": token_json.get("expires_in"),
|
||||||
|
"expires_at": expires_at.isoformat() if expires_at else None,
|
||||||
|
"scope": token_json.get("scope"),
|
||||||
|
"_token_encrypted": True,
|
||||||
|
}
|
||||||
|
cfg.update(account_info)
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
def get_connector_display_name(self, account_info: dict[str, Any]) -> str:
|
||||||
|
return str(account_info.get("name", self.provider_name.title()))
|
||||||
|
|
||||||
|
    async def on_token_refresh_failure(
        self,
        session: AsyncSession,
        connector: SearchSourceConnector,
    ) -> None:
        """Mark *connector* as needing re-authentication after a failed refresh.

        Sets ``auth_expired: True`` in the connector's JSON config
        (presumably read elsewhere to prompt re-connection — confirm with the
        consumers of this flag). Persistence is best-effort: any error is
        logged and swallowed so the caller's own error handling still runs.
        """
        try:
            # Replace the dict wholesale and flag it so SQLAlchemy persists
            # the JSON mutation.
            connector.config = {**connector.config, "auth_expired": True}
            flag_modified(connector, "config")
            await session.commit()
            await session.refresh(connector)
        except Exception:
            logger.warning(
                "Failed to persist auth_expired flag for connector %s",
                connector.id,
                exc_info=True,
            )
|
||||||
|
|
||||||
|
async def _exchange_code(
|
||||||
|
self, code: str, extra_state: dict[str, Any]
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
client_id = self._get_client_id()
|
||||||
|
client_secret = self._get_client_secret()
|
||||||
|
redirect_uri = self._get_redirect_uri()
|
||||||
|
|
||||||
|
headers: dict[str, str] = {
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
|
}
|
||||||
|
body: dict[str, str] = {
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"code": code,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.token_auth_method == "basic":
|
||||||
|
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
|
||||||
|
headers["Authorization"] = f"Basic {creds}"
|
||||||
|
else:
|
||||||
|
body["client_id"] = client_id
|
||||||
|
body["client_secret"] = client_secret
|
||||||
|
|
||||||
|
if self.use_pkce:
|
||||||
|
verifier = extra_state.get("code_verifier")
|
||||||
|
if verifier:
|
||||||
|
body["code_verifier"] = verifier
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.post(
|
||||||
|
self.token_url, data=body, headers=headers, timeout=30.0
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code != 200:
|
||||||
|
detail = resp.text
|
||||||
|
try:
|
||||||
|
detail = resp.json().get("error_description", detail)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail=f"Token exchange failed: {detail}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
async def refresh_token(
|
||||||
|
self, session: AsyncSession, connector: SearchSourceConnector
|
||||||
|
) -> SearchSourceConnector:
|
||||||
|
encryption = self._get_token_encryption()
|
||||||
|
is_encrypted = connector.config.get("_token_encrypted", False)
|
||||||
|
|
||||||
|
refresh_tok = connector.config.get("refresh_token")
|
||||||
|
if is_encrypted and refresh_tok:
|
||||||
|
try:
|
||||||
|
refresh_tok = encryption.decrypt_token(refresh_tok)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to decrypt refresh token: %s", e)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Failed to decrypt stored refresh token"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if not refresh_tok:
|
||||||
|
await self.on_token_refresh_failure(session, connector)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="No refresh token available. Please re-authenticate.",
|
||||||
|
)
|
||||||
|
|
||||||
|
client_id = self._get_client_id()
|
||||||
|
client_secret = self._get_client_secret()
|
||||||
|
|
||||||
|
headers: dict[str, str] = {
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
|
}
|
||||||
|
body: dict[str, str] = {
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": refresh_tok,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.token_auth_method == "basic":
|
||||||
|
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
|
||||||
|
headers["Authorization"] = f"Basic {creds}"
|
||||||
|
else:
|
||||||
|
body["client_id"] = client_id
|
||||||
|
body["client_secret"] = client_secret
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.post(
|
||||||
|
self.token_url, data=body, headers=headers, timeout=30.0
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code != 200:
|
||||||
|
error_detail = resp.text
|
||||||
|
try:
|
||||||
|
ej = resp.json()
|
||||||
|
error_detail = ej.get("error_description", error_detail)
|
||||||
|
error_code = ej.get("error", "")
|
||||||
|
except Exception:
|
||||||
|
error_code = ""
|
||||||
|
combined = (error_detail + error_code).lower()
|
||||||
|
if any(kw in combined for kw in ("invalid_grant", "expired", "revoked")):
|
||||||
|
await self.on_token_refresh_failure(session, connector)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail=f"{self.provider_name.title()} authentication failed. "
|
||||||
|
"Please re-authenticate.",
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail=f"Token refresh failed: {error_detail}"
|
||||||
|
)
|
||||||
|
|
||||||
|
token_json = resp.json()
|
||||||
|
new_access = token_json.get("access_token")
|
||||||
|
if not new_access:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="No access token received from refresh"
|
||||||
|
)
|
||||||
|
|
||||||
|
expires_at = None
|
||||||
|
if token_json.get("expires_in"):
|
||||||
|
expires_at = datetime.now(UTC) + timedelta(
|
||||||
|
seconds=int(token_json["expires_in"])
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_config = dict(connector.config)
|
||||||
|
updated_config["access_token"] = encryption.encrypt_token(new_access)
|
||||||
|
new_refresh = token_json.get("refresh_token")
|
||||||
|
if new_refresh:
|
||||||
|
updated_config["refresh_token"] = encryption.encrypt_token(new_refresh)
|
||||||
|
updated_config["expires_in"] = token_json.get("expires_in")
|
||||||
|
updated_config["expires_at"] = expires_at.isoformat() if expires_at else None
|
||||||
|
updated_config["scope"] = token_json.get("scope", updated_config.get("scope"))
|
||||||
|
updated_config["_token_encrypted"] = True
|
||||||
|
updated_config.pop("auth_expired", None)
|
||||||
|
|
||||||
|
connector.config = updated_config
|
||||||
|
flag_modified(connector, "config")
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(connector)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Refreshed %s token for connector %s",
|
||||||
|
self.provider_name,
|
||||||
|
connector.id,
|
||||||
|
)
|
||||||
|
return connector
|
||||||
|
|
||||||
|
def build_router(self) -> APIRouter:
|
||||||
|
router = APIRouter()
|
||||||
|
oauth = self
|
||||||
|
|
||||||
|
@router.get(f"{oauth.auth_prefix}/connector/add")
|
||||||
|
async def connect(
|
||||||
|
space_id: int,
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
if not space_id:
|
||||||
|
raise HTTPException(status_code=400, detail="space_id is required")
|
||||||
|
|
||||||
|
client_id = oauth._get_client_id()
|
||||||
|
state_mgr = oauth._get_state_manager()
|
||||||
|
|
||||||
|
extra_state: dict[str, Any] = {}
|
||||||
|
auth_params: dict[str, str] = {
|
||||||
|
"client_id": client_id,
|
||||||
|
"response_type": "code",
|
||||||
|
"redirect_uri": oauth._get_redirect_uri(),
|
||||||
|
"scope": " ".join(oauth.scopes),
|
||||||
|
}
|
||||||
|
|
||||||
|
if oauth.use_pkce:
|
||||||
|
from app.utils.oauth_security import generate_pkce_pair
|
||||||
|
|
||||||
|
verifier, challenge = generate_pkce_pair()
|
||||||
|
extra_state["code_verifier"] = verifier
|
||||||
|
auth_params["code_challenge"] = challenge
|
||||||
|
auth_params["code_challenge_method"] = "S256"
|
||||||
|
|
||||||
|
auth_params.update(oauth.extra_auth_params)
|
||||||
|
|
||||||
|
state_encoded = state_mgr.generate_secure_state(
|
||||||
|
space_id, user.id, **extra_state
|
||||||
|
)
|
||||||
|
auth_params["state"] = state_encoded
|
||||||
|
auth_url = f"{oauth.authorize_url}?{urlencode(auth_params)}"
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Generated %s OAuth URL for user %s, space %s",
|
||||||
|
oauth.provider_name,
|
||||||
|
user.id,
|
||||||
|
space_id,
|
||||||
|
)
|
||||||
|
return {"auth_url": auth_url}
|
||||||
|
|
||||||
|
@router.get(f"{oauth.auth_prefix}/connector/reauth")
|
||||||
|
async def reauth(
|
||||||
|
space_id: int,
|
||||||
|
connector_id: int,
|
||||||
|
return_url: str | None = None,
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
):
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.id == connector_id,
|
||||||
|
SearchSourceConnector.user_id == user.id,
|
||||||
|
SearchSourceConnector.search_space_id == space_id,
|
||||||
|
SearchSourceConnector.connector_type == oauth.connector_type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not result.scalars().first():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"{oauth.provider_name.title()} connector not found "
|
||||||
|
"or access denied",
|
||||||
|
)
|
||||||
|
|
||||||
|
client_id = oauth._get_client_id()
|
||||||
|
state_mgr = oauth._get_state_manager()
|
||||||
|
|
||||||
|
extra: dict[str, Any] = {"connector_id": connector_id}
|
||||||
|
if return_url and return_url.startswith("/") and not return_url.startswith("//"):
|
||||||
|
extra["return_url"] = return_url
|
||||||
|
|
||||||
|
auth_params: dict[str, str] = {
|
||||||
|
"client_id": client_id,
|
||||||
|
"response_type": "code",
|
||||||
|
"redirect_uri": oauth._get_redirect_uri(),
|
||||||
|
"scope": " ".join(oauth.scopes),
|
||||||
|
}
|
||||||
|
|
||||||
|
if oauth.use_pkce:
|
||||||
|
from app.utils.oauth_security import generate_pkce_pair
|
||||||
|
|
||||||
|
verifier, challenge = generate_pkce_pair()
|
||||||
|
extra["code_verifier"] = verifier
|
||||||
|
auth_params["code_challenge"] = challenge
|
||||||
|
auth_params["code_challenge_method"] = "S256"
|
||||||
|
|
||||||
|
auth_params.update(oauth.extra_auth_params)
|
||||||
|
|
||||||
|
state_encoded = state_mgr.generate_secure_state(
|
||||||
|
space_id, user.id, **extra
|
||||||
|
)
|
||||||
|
auth_params["state"] = state_encoded
|
||||||
|
auth_url = f"{oauth.authorize_url}?{urlencode(auth_params)}"
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Initiating %s re-auth for user %s, connector %s",
|
||||||
|
oauth.provider_name,
|
||||||
|
user.id,
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
return {"auth_url": auth_url}
|
||||||
|
|
||||||
|
@router.get(f"{oauth.auth_prefix}/connector/callback")
|
||||||
|
async def callback(
|
||||||
|
code: str | None = None,
|
||||||
|
error: str | None = None,
|
||||||
|
state: str | None = None,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
):
|
||||||
|
error_label = f"{oauth.provider_name}_oauth_denied"
|
||||||
|
|
||||||
|
if error:
|
||||||
|
logger.warning("%s OAuth error: %s", oauth.provider_name, error)
|
||||||
|
space_id = None
|
||||||
|
if state:
|
||||||
|
try:
|
||||||
|
data = oauth._get_state_manager().validate_state(state)
|
||||||
|
space_id = data.get("space_id")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return oauth._frontend_redirect(space_id, error=error_label)
|
||||||
|
|
||||||
|
if not code:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Missing authorization code"
|
||||||
|
)
|
||||||
|
if not state:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Missing state parameter"
|
||||||
|
)
|
||||||
|
|
||||||
|
state_mgr = oauth._get_state_manager()
|
||||||
|
try:
|
||||||
|
data = state_mgr.validate_state(state)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Invalid or expired state parameter."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
user_id = UUID(data["user_id"])
|
||||||
|
space_id = data["space_id"]
|
||||||
|
|
||||||
|
token_json = await oauth._exchange_code(code, data)
|
||||||
|
|
||||||
|
access_token = token_json.get("access_token", "")
|
||||||
|
if not access_token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"No access token received from {oauth.provider_name.title()}",
|
||||||
|
)
|
||||||
|
|
||||||
|
account_info = await oauth.fetch_account_info(access_token)
|
||||||
|
encryption = oauth._get_token_encryption()
|
||||||
|
connector_config = oauth.build_connector_config(
|
||||||
|
token_json, account_info, encryption
|
||||||
|
)
|
||||||
|
|
||||||
|
display_name = oauth.get_connector_display_name(account_info)
|
||||||
|
|
||||||
|
# --- Re-auth path ---
|
||||||
|
reauth_connector_id = data.get("connector_id")
|
||||||
|
reauth_return_url = data.get("return_url")
|
||||||
|
|
||||||
|
if reauth_connector_id:
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.id == reauth_connector_id,
|
||||||
|
SearchSourceConnector.user_id == user_id,
|
||||||
|
SearchSourceConnector.search_space_id == space_id,
|
||||||
|
SearchSourceConnector.connector_type == oauth.connector_type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
db_connector = result.scalars().first()
|
||||||
|
if not db_connector:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="Connector not found or access denied during re-auth",
|
||||||
|
)
|
||||||
|
|
||||||
|
db_connector.config = connector_config
|
||||||
|
flag_modified(db_connector, "config")
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(db_connector)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Re-authenticated %s connector %s for user %s",
|
||||||
|
oauth.provider_name,
|
||||||
|
db_connector.id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"):
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||||
|
)
|
||||||
|
return oauth._frontend_redirect(
|
||||||
|
space_id, success=True, connector_id=db_connector.id
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- New connector path ---
|
||||||
|
is_dup = await check_duplicate_connector(
|
||||||
|
session,
|
||||||
|
oauth.connector_type,
|
||||||
|
space_id,
|
||||||
|
user_id,
|
||||||
|
display_name,
|
||||||
|
)
|
||||||
|
if is_dup:
|
||||||
|
logger.warning(
|
||||||
|
"Duplicate %s connector for user %s (%s)",
|
||||||
|
oauth.provider_name,
|
||||||
|
user_id,
|
||||||
|
display_name,
|
||||||
|
)
|
||||||
|
return oauth._frontend_redirect(
|
||||||
|
space_id,
|
||||||
|
error=f"duplicate_account&connector={oauth.provider_name}-connector",
|
||||||
|
)
|
||||||
|
|
||||||
|
connector_name = await generate_unique_connector_name(
|
||||||
|
session,
|
||||||
|
oauth.connector_type,
|
||||||
|
space_id,
|
||||||
|
user_id,
|
||||||
|
display_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_connector = SearchSourceConnector(
|
||||||
|
name=connector_name,
|
||||||
|
connector_type=oauth.connector_type,
|
||||||
|
is_indexable=oauth.is_indexable,
|
||||||
|
config=connector_config,
|
||||||
|
search_space_id=space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
session.add(new_connector)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await session.commit()
|
||||||
|
except IntegrityError as e:
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=409, detail="A connector for this service already exists."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Created %s connector %s for user %s in space %s",
|
||||||
|
oauth.provider_name,
|
||||||
|
new_connector.id,
|
||||||
|
user_id,
|
||||||
|
space_id,
|
||||||
|
)
|
||||||
|
return oauth._frontend_redirect(
|
||||||
|
space_id, success=True, connector_id=new_connector.id
|
||||||
|
)
|
||||||
|
|
||||||
|
return router
|
||||||
|
|
@ -693,27 +693,10 @@ async def index_connector_content(
|
||||||
user: User = Depends(current_active_user),
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Index content from a connector to a search space.
|
Index content from a KB connector to a search space.
|
||||||
Requires CONNECTORS_UPDATE permission (to trigger indexing).
|
|
||||||
|
|
||||||
Currently supports:
|
Live connectors (Slack, Teams, Linear, Jira, ClickUp, Calendar, Airtable,
|
||||||
- SLACK_CONNECTOR: Indexes messages from all accessible Slack channels
|
Gmail, Discord, Luma) use real-time agent tools instead.
|
||||||
- TEAMS_CONNECTOR: Indexes messages from all accessible Microsoft Teams channels
|
|
||||||
- NOTION_CONNECTOR: Indexes pages from all accessible Notion pages
|
|
||||||
- GITHUB_CONNECTOR: Indexes code and documentation from GitHub repositories
|
|
||||||
- LINEAR_CONNECTOR: Indexes issues and comments from Linear
|
|
||||||
- JIRA_CONNECTOR: Indexes issues and comments from Jira
|
|
||||||
- DISCORD_CONNECTOR: Indexes messages from all accessible Discord channels
|
|
||||||
- LUMA_CONNECTOR: Indexes events from Luma
|
|
||||||
- ELASTICSEARCH_CONNECTOR: Indexes documents from Elasticsearch
|
|
||||||
- WEBCRAWLER_CONNECTOR: Indexes web pages from crawled websites
|
|
||||||
|
|
||||||
Args:
|
|
||||||
connector_id: ID of the connector to use
|
|
||||||
search_space_id: ID of the search space to store indexed content
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with indexing status
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Get the connector first
|
# Get the connector first
|
||||||
|
|
@ -770,9 +753,7 @@ async def index_connector_content(
|
||||||
|
|
||||||
# For calendar connectors, default to today but allow future dates if explicitly provided
|
# For calendar connectors, default to today but allow future dates if explicitly provided
|
||||||
if connector.connector_type in [
|
if connector.connector_type in [
|
||||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
|
||||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||||
SearchSourceConnectorType.LUMA_CONNECTOR,
|
|
||||||
]:
|
]:
|
||||||
# Default to today if no end_date provided (users can manually select future dates)
|
# Default to today if no end_date provided (users can manually select future dates)
|
||||||
indexing_to = today_str if end_date is None else end_date
|
indexing_to = today_str if end_date is None else end_date
|
||||||
|
|
@ -796,33 +777,22 @@ async def index_connector_content(
|
||||||
# For non-calendar connectors, cap at today
|
# For non-calendar connectors, cap at today
|
||||||
indexing_to = end_date if end_date else today_str
|
indexing_to = end_date if end_date else today_str
|
||||||
|
|
||||||
if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
|
from app.services.mcp_oauth.registry import LIVE_CONNECTOR_TYPES
|
||||||
from app.tasks.celery_tasks.connector_tasks import (
|
|
||||||
index_slack_messages_task,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
if connector.connector_type in LIVE_CONNECTOR_TYPES:
|
||||||
f"Triggering Slack indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
return {
|
||||||
)
|
"message": (
|
||||||
index_slack_messages_task.delay(
|
f"{connector.connector_type.value} uses real-time agent tools; "
|
||||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
"background indexing is disabled."
|
||||||
)
|
),
|
||||||
response_message = "Slack indexing started in the background."
|
"indexing_started": False,
|
||||||
|
"connector_id": connector_id,
|
||||||
|
"search_space_id": search_space_id,
|
||||||
|
"indexing_from": indexing_from,
|
||||||
|
"indexing_to": indexing_to,
|
||||||
|
}
|
||||||
|
|
||||||
elif connector.connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR:
|
if connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
|
||||||
from app.tasks.celery_tasks.connector_tasks import (
|
|
||||||
index_teams_messages_task,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Triggering Teams indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
|
||||||
)
|
|
||||||
index_teams_messages_task.delay(
|
|
||||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
|
||||||
)
|
|
||||||
response_message = "Teams indexing started in the background."
|
|
||||||
|
|
||||||
elif connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
|
|
||||||
from app.tasks.celery_tasks.connector_tasks import index_notion_pages_task
|
from app.tasks.celery_tasks.connector_tasks import index_notion_pages_task
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -844,28 +814,6 @@ async def index_connector_content(
|
||||||
)
|
)
|
||||||
response_message = "GitHub indexing started in the background."
|
response_message = "GitHub indexing started in the background."
|
||||||
|
|
||||||
elif connector.connector_type == SearchSourceConnectorType.LINEAR_CONNECTOR:
|
|
||||||
from app.tasks.celery_tasks.connector_tasks import index_linear_issues_task
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Triggering Linear indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
|
||||||
)
|
|
||||||
index_linear_issues_task.delay(
|
|
||||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
|
||||||
)
|
|
||||||
response_message = "Linear indexing started in the background."
|
|
||||||
|
|
||||||
elif connector.connector_type == SearchSourceConnectorType.JIRA_CONNECTOR:
|
|
||||||
from app.tasks.celery_tasks.connector_tasks import index_jira_issues_task
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Triggering Jira indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
|
||||||
)
|
|
||||||
index_jira_issues_task.delay(
|
|
||||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
|
||||||
)
|
|
||||||
response_message = "Jira indexing started in the background."
|
|
||||||
|
|
||||||
elif connector.connector_type == SearchSourceConnectorType.CONFLUENCE_CONNECTOR:
|
elif connector.connector_type == SearchSourceConnectorType.CONFLUENCE_CONNECTOR:
|
||||||
from app.tasks.celery_tasks.connector_tasks import (
|
from app.tasks.celery_tasks.connector_tasks import (
|
||||||
index_confluence_pages_task,
|
index_confluence_pages_task,
|
||||||
|
|
@ -892,59 +840,6 @@ async def index_connector_content(
|
||||||
)
|
)
|
||||||
response_message = "BookStack indexing started in the background."
|
response_message = "BookStack indexing started in the background."
|
||||||
|
|
||||||
elif connector.connector_type == SearchSourceConnectorType.CLICKUP_CONNECTOR:
|
|
||||||
from app.tasks.celery_tasks.connector_tasks import index_clickup_tasks_task
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Triggering ClickUp indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
|
||||||
)
|
|
||||||
index_clickup_tasks_task.delay(
|
|
||||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
|
||||||
)
|
|
||||||
response_message = "ClickUp indexing started in the background."
|
|
||||||
|
|
||||||
elif (
|
|
||||||
connector.connector_type
|
|
||||||
== SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR
|
|
||||||
):
|
|
||||||
from app.tasks.celery_tasks.connector_tasks import (
|
|
||||||
index_google_calendar_events_task,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Triggering Google Calendar indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
|
||||||
)
|
|
||||||
index_google_calendar_events_task.delay(
|
|
||||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
|
||||||
)
|
|
||||||
response_message = "Google Calendar indexing started in the background."
|
|
||||||
elif connector.connector_type == SearchSourceConnectorType.AIRTABLE_CONNECTOR:
|
|
||||||
from app.tasks.celery_tasks.connector_tasks import (
|
|
||||||
index_airtable_records_task,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Triggering Airtable indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
|
||||||
)
|
|
||||||
index_airtable_records_task.delay(
|
|
||||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
|
||||||
)
|
|
||||||
response_message = "Airtable indexing started in the background."
|
|
||||||
elif (
|
|
||||||
connector.connector_type == SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR
|
|
||||||
):
|
|
||||||
from app.tasks.celery_tasks.connector_tasks import (
|
|
||||||
index_google_gmail_messages_task,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Triggering Google Gmail indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
|
||||||
)
|
|
||||||
index_google_gmail_messages_task.delay(
|
|
||||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
|
||||||
)
|
|
||||||
response_message = "Google Gmail indexing started in the background."
|
|
||||||
|
|
||||||
elif (
|
elif (
|
||||||
connector.connector_type == SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR
|
connector.connector_type == SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR
|
||||||
):
|
):
|
||||||
|
|
@ -1089,30 +984,6 @@ async def index_connector_content(
|
||||||
)
|
)
|
||||||
response_message = "Dropbox indexing started in the background."
|
response_message = "Dropbox indexing started in the background."
|
||||||
|
|
||||||
elif connector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR:
|
|
||||||
from app.tasks.celery_tasks.connector_tasks import (
|
|
||||||
index_discord_messages_task,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Triggering Discord indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
|
||||||
)
|
|
||||||
index_discord_messages_task.delay(
|
|
||||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
|
||||||
)
|
|
||||||
response_message = "Discord indexing started in the background."
|
|
||||||
|
|
||||||
elif connector.connector_type == SearchSourceConnectorType.LUMA_CONNECTOR:
|
|
||||||
from app.tasks.celery_tasks.connector_tasks import index_luma_events_task
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Triggering Luma indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
|
||||||
)
|
|
||||||
index_luma_events_task.delay(
|
|
||||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
|
||||||
)
|
|
||||||
response_message = "Luma indexing started in the background."
|
|
||||||
|
|
||||||
elif (
|
elif (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR
|
== SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR
|
||||||
|
|
@ -1319,57 +1190,6 @@ async def _update_connector_timestamp_by_id(session: AsyncSession, connector_id:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
|
|
||||||
|
|
||||||
async def run_slack_indexing_with_new_session(
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create a new session and run the Slack indexing task.
|
|
||||||
This prevents session leaks by creating a dedicated session for the background task.
|
|
||||||
"""
|
|
||||||
async with async_session_maker() as session:
|
|
||||||
await run_slack_indexing(
|
|
||||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def run_slack_indexing(
|
|
||||||
session: AsyncSession,
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Background task to run Slack indexing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Database session
|
|
||||||
connector_id: ID of the Slack connector
|
|
||||||
search_space_id: ID of the search space
|
|
||||||
user_id: ID of the user
|
|
||||||
start_date: Start date for indexing
|
|
||||||
end_date: End date for indexing
|
|
||||||
"""
|
|
||||||
from app.tasks.connector_indexers import index_slack_messages
|
|
||||||
|
|
||||||
await _run_indexing_with_notifications(
|
|
||||||
session=session,
|
|
||||||
connector_id=connector_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
indexing_function=index_slack_messages,
|
|
||||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
|
||||||
supports_heartbeat_callback=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
_AUTH_ERROR_PATTERNS = (
|
_AUTH_ERROR_PATTERNS = (
|
||||||
"failed to refresh linear oauth",
|
"failed to refresh linear oauth",
|
||||||
"failed to refresh your notion connection",
|
"failed to refresh your notion connection",
|
||||||
|
|
@ -1908,215 +1728,6 @@ async def run_github_indexing(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Add new helper functions for Linear indexing
|
|
||||||
async def run_linear_indexing_with_new_session(
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Wrapper to run Linear indexing with its own database session."""
|
|
||||||
logger.info(
|
|
||||||
f"Background task started: Indexing Linear connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
|
|
||||||
)
|
|
||||||
async with async_session_maker() as session:
|
|
||||||
await run_linear_indexing(
|
|
||||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
logger.info(f"Background task finished: Indexing Linear connector {connector_id}")
|
|
||||||
|
|
||||||
|
|
||||||
async def run_linear_indexing(
|
|
||||||
session: AsyncSession,
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Background task to run Linear indexing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Database session
|
|
||||||
connector_id: ID of the Linear connector
|
|
||||||
search_space_id: ID of the search space
|
|
||||||
user_id: ID of the user
|
|
||||||
start_date: Start date for indexing
|
|
||||||
end_date: End date for indexing
|
|
||||||
"""
|
|
||||||
from app.tasks.connector_indexers import index_linear_issues
|
|
||||||
|
|
||||||
await _run_indexing_with_notifications(
|
|
||||||
session=session,
|
|
||||||
connector_id=connector_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
indexing_function=index_linear_issues,
|
|
||||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
|
||||||
supports_heartbeat_callback=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Add new helper functions for discord indexing
|
|
||||||
async def run_discord_indexing_with_new_session(
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create a new session and run the Discord indexing task.
|
|
||||||
This prevents session leaks by creating a dedicated session for the background task.
|
|
||||||
"""
|
|
||||||
async with async_session_maker() as session:
|
|
||||||
await run_discord_indexing(
|
|
||||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def run_discord_indexing(
|
|
||||||
session: AsyncSession,
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Background task to run Discord indexing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Database session
|
|
||||||
connector_id: ID of the Discord connector
|
|
||||||
search_space_id: ID of the search space
|
|
||||||
user_id: ID of the user
|
|
||||||
start_date: Start date for indexing
|
|
||||||
end_date: End date for indexing
|
|
||||||
"""
|
|
||||||
from app.tasks.connector_indexers import index_discord_messages
|
|
||||||
|
|
||||||
await _run_indexing_with_notifications(
|
|
||||||
session=session,
|
|
||||||
connector_id=connector_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
indexing_function=index_discord_messages,
|
|
||||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
|
||||||
supports_heartbeat_callback=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def run_teams_indexing_with_new_session(
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create a new session and run the Microsoft Teams indexing task.
|
|
||||||
This prevents session leaks by creating a dedicated session for the background task.
|
|
||||||
"""
|
|
||||||
async with async_session_maker() as session:
|
|
||||||
await run_teams_indexing(
|
|
||||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def run_teams_indexing(
|
|
||||||
session: AsyncSession,
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Background task to run Microsoft Teams indexing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Database session
|
|
||||||
connector_id: ID of the Teams connector
|
|
||||||
search_space_id: ID of the search space
|
|
||||||
user_id: ID of the user
|
|
||||||
start_date: Start date for indexing
|
|
||||||
end_date: End date for indexing
|
|
||||||
"""
|
|
||||||
from app.tasks.connector_indexers.teams_indexer import index_teams_messages
|
|
||||||
|
|
||||||
await _run_indexing_with_notifications(
|
|
||||||
session=session,
|
|
||||||
connector_id=connector_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
indexing_function=index_teams_messages,
|
|
||||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
|
||||||
supports_heartbeat_callback=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Add new helper functions for Jira indexing
|
|
||||||
async def run_jira_indexing_with_new_session(
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Wrapper to run Jira indexing with its own database session."""
|
|
||||||
logger.info(
|
|
||||||
f"Background task started: Indexing Jira connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
|
|
||||||
)
|
|
||||||
async with async_session_maker() as session:
|
|
||||||
await run_jira_indexing(
|
|
||||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
logger.info(f"Background task finished: Indexing Jira connector {connector_id}")
|
|
||||||
|
|
||||||
|
|
||||||
async def run_jira_indexing(
|
|
||||||
session: AsyncSession,
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Background task to run Jira indexing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Database session
|
|
||||||
connector_id: ID of the Jira connector
|
|
||||||
search_space_id: ID of the search space
|
|
||||||
user_id: ID of the user
|
|
||||||
start_date: Start date for indexing
|
|
||||||
end_date: End date for indexing
|
|
||||||
"""
|
|
||||||
from app.tasks.connector_indexers import index_jira_issues
|
|
||||||
|
|
||||||
await _run_indexing_with_notifications(
|
|
||||||
session=session,
|
|
||||||
connector_id=connector_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
indexing_function=index_jira_issues,
|
|
||||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
|
||||||
supports_heartbeat_callback=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Add new helper functions for Confluence indexing
|
# Add new helper functions for Confluence indexing
|
||||||
async def run_confluence_indexing_with_new_session(
|
async def run_confluence_indexing_with_new_session(
|
||||||
connector_id: int,
|
connector_id: int,
|
||||||
|
|
@ -2172,112 +1783,6 @@ async def run_confluence_indexing(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Add new helper functions for ClickUp indexing
|
|
||||||
async def run_clickup_indexing_with_new_session(
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Wrapper to run ClickUp indexing with its own database session."""
|
|
||||||
logger.info(
|
|
||||||
f"Background task started: Indexing ClickUp connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
|
|
||||||
)
|
|
||||||
async with async_session_maker() as session:
|
|
||||||
await run_clickup_indexing(
|
|
||||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
logger.info(f"Background task finished: Indexing ClickUp connector {connector_id}")
|
|
||||||
|
|
||||||
|
|
||||||
async def run_clickup_indexing(
|
|
||||||
session: AsyncSession,
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Background task to run ClickUp indexing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Database session
|
|
||||||
connector_id: ID of the ClickUp connector
|
|
||||||
search_space_id: ID of the search space
|
|
||||||
user_id: ID of the user
|
|
||||||
start_date: Start date for indexing
|
|
||||||
end_date: End date for indexing
|
|
||||||
"""
|
|
||||||
from app.tasks.connector_indexers import index_clickup_tasks
|
|
||||||
|
|
||||||
await _run_indexing_with_notifications(
|
|
||||||
session=session,
|
|
||||||
connector_id=connector_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
indexing_function=index_clickup_tasks,
|
|
||||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
|
||||||
supports_heartbeat_callback=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Add new helper functions for Airtable indexing
|
|
||||||
async def run_airtable_indexing_with_new_session(
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Wrapper to run Airtable indexing with its own database session."""
|
|
||||||
logger.info(
|
|
||||||
f"Background task started: Indexing Airtable connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
|
|
||||||
)
|
|
||||||
async with async_session_maker() as session:
|
|
||||||
await run_airtable_indexing(
|
|
||||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
logger.info(f"Background task finished: Indexing Airtable connector {connector_id}")
|
|
||||||
|
|
||||||
|
|
||||||
async def run_airtable_indexing(
|
|
||||||
session: AsyncSession,
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Background task to run Airtable indexing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Database session
|
|
||||||
connector_id: ID of the Airtable connector
|
|
||||||
search_space_id: ID of the search space
|
|
||||||
user_id: ID of the user
|
|
||||||
start_date: Start date for indexing
|
|
||||||
end_date: End date for indexing
|
|
||||||
"""
|
|
||||||
from app.tasks.connector_indexers import index_airtable_records
|
|
||||||
|
|
||||||
await _run_indexing_with_notifications(
|
|
||||||
session=session,
|
|
||||||
connector_id=connector_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
indexing_function=index_airtable_records,
|
|
||||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
|
||||||
supports_heartbeat_callback=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Add new helper functions for Google Calendar indexing
|
# Add new helper functions for Google Calendar indexing
|
||||||
async def run_google_calendar_indexing_with_new_session(
|
async def run_google_calendar_indexing_with_new_session(
|
||||||
connector_id: int,
|
connector_id: int,
|
||||||
|
|
@ -2816,58 +2321,6 @@ async def run_dropbox_indexing(
|
||||||
logger.error(f"Failed to update notification: {notif_error!s}")
|
logger.error(f"Failed to update notification: {notif_error!s}")
|
||||||
|
|
||||||
|
|
||||||
# Add new helper functions for luma indexing
|
|
||||||
async def run_luma_indexing_with_new_session(
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create a new session and run the Luma indexing task.
|
|
||||||
This prevents session leaks by creating a dedicated session for the background task.
|
|
||||||
"""
|
|
||||||
async with async_session_maker() as session:
|
|
||||||
await run_luma_indexing(
|
|
||||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def run_luma_indexing(
|
|
||||||
session: AsyncSession,
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Background task to run Luma indexing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Database session
|
|
||||||
connector_id: ID of the Luma connector
|
|
||||||
search_space_id: ID of the search space
|
|
||||||
user_id: ID of the user
|
|
||||||
start_date: Start date for indexing
|
|
||||||
end_date: End date for indexing
|
|
||||||
"""
|
|
||||||
from app.tasks.connector_indexers import index_luma_events
|
|
||||||
|
|
||||||
await _run_indexing_with_notifications(
|
|
||||||
session=session,
|
|
||||||
connector_id=connector_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
indexing_function=index_luma_events,
|
|
||||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
|
||||||
supports_heartbeat_callback=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def run_elasticsearch_indexing_with_new_session(
|
async def run_elasticsearch_indexing_with_new_session(
|
||||||
connector_id: int,
|
connector_id: int,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
|
|
@ -3580,13 +3033,18 @@ async def trust_mcp_tool(
|
||||||
"""Add a tool to the MCP connector's trusted (always-allow) list.
|
"""Add a tool to the MCP connector's trusted (always-allow) list.
|
||||||
|
|
||||||
Once trusted, the tool executes without HITL approval on subsequent calls.
|
Once trusted, the tool executes without HITL approval on subsequent calls.
|
||||||
|
Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors
|
||||||
|
(LINEAR_CONNECTOR, JIRA_CONNECTOR, etc.) by checking for ``server_config``.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
from sqlalchemy import cast
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB
|
||||||
|
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(SearchSourceConnector).filter(
|
select(SearchSourceConnector).filter(
|
||||||
SearchSourceConnector.id == connector_id,
|
SearchSourceConnector.id == connector_id,
|
||||||
SearchSourceConnector.connector_type
|
SearchSourceConnector.user_id == user.id,
|
||||||
== SearchSourceConnectorType.MCP_CONNECTOR,
|
cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), # noqa: W601
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
connector = result.scalars().first()
|
connector = result.scalars().first()
|
||||||
|
|
@ -3631,13 +3089,17 @@ async def untrust_mcp_tool(
|
||||||
"""Remove a tool from the MCP connector's trusted list.
|
"""Remove a tool from the MCP connector's trusted list.
|
||||||
|
|
||||||
The tool will require HITL approval again on subsequent calls.
|
The tool will require HITL approval again on subsequent calls.
|
||||||
|
Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
from sqlalchemy import cast
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB
|
||||||
|
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(SearchSourceConnector).filter(
|
select(SearchSourceConnector).filter(
|
||||||
SearchSourceConnector.id == connector_id,
|
SearchSourceConnector.id == connector_id,
|
||||||
SearchSourceConnector.connector_type
|
SearchSourceConnector.user_id == user.id,
|
||||||
== SearchSourceConnectorType.MCP_CONNECTOR,
|
cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), # noqa: W601
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
connector = result.scalars().first()
|
connector = result.scalars().first()
|
||||||
|
|
|
||||||
|
|
@ -312,7 +312,7 @@ async def slack_callback(
|
||||||
new_connector = SearchSourceConnector(
|
new_connector = SearchSourceConnector(
|
||||||
name=connector_name,
|
name=connector_name,
|
||||||
connector_type=SearchSourceConnectorType.SLACK_CONNECTOR,
|
connector_type=SearchSourceConnectorType.SLACK_CONNECTOR,
|
||||||
is_indexable=True,
|
is_indexable=False,
|
||||||
config=connector_config,
|
config=connector_config,
|
||||||
search_space_id=space_id,
|
search_space_id=space_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,7 @@ SCOPES = [
|
||||||
"Team.ReadBasic.All", # Read basic team information
|
"Team.ReadBasic.All", # Read basic team information
|
||||||
"Channel.ReadBasic.All", # Read basic channel information
|
"Channel.ReadBasic.All", # Read basic channel information
|
||||||
"ChannelMessage.Read.All", # Read messages in channels
|
"ChannelMessage.Read.All", # Read messages in channels
|
||||||
|
"ChannelMessage.Send", # Send messages in channels
|
||||||
]
|
]
|
||||||
|
|
||||||
# Initialize security utilities
|
# Initialize security utilities
|
||||||
|
|
@ -320,7 +321,7 @@ async def teams_callback(
|
||||||
new_connector = SearchSourceConnector(
|
new_connector = SearchSourceConnector(
|
||||||
name=connector_name,
|
name=connector_name,
|
||||||
connector_type=SearchSourceConnectorType.TEAMS_CONNECTOR,
|
connector_type=SearchSourceConnectorType.TEAMS_CONNECTOR,
|
||||||
is_indexable=True,
|
is_indexable=False,
|
||||||
config=connector_config,
|
config=connector_config,
|
||||||
search_space_id=space_id,
|
search_space_id=space_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
|
||||||
|
|
@ -168,6 +168,11 @@ class ChatMessage(BaseModel):
|
||||||
content: str
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class LocalFilesystemMountPayload(BaseModel):
|
||||||
|
mount_id: str
|
||||||
|
root_path: str
|
||||||
|
|
||||||
|
|
||||||
class NewChatRequest(BaseModel):
|
class NewChatRequest(BaseModel):
|
||||||
"""Request schema for the deep agent chat endpoint."""
|
"""Request schema for the deep agent chat endpoint."""
|
||||||
|
|
||||||
|
|
@ -184,6 +189,9 @@ class NewChatRequest(BaseModel):
|
||||||
disabled_tools: list[str] | None = (
|
disabled_tools: list[str] | None = (
|
||||||
None # Optional list of tool names the user has disabled from the UI
|
None # Optional list of tool names the user has disabled from the UI
|
||||||
)
|
)
|
||||||
|
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
|
||||||
|
client_platform: Literal["web", "desktop"] = "web"
|
||||||
|
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
|
||||||
|
|
||||||
|
|
||||||
class RegenerateRequest(BaseModel):
|
class RegenerateRequest(BaseModel):
|
||||||
|
|
@ -204,6 +212,9 @@ class RegenerateRequest(BaseModel):
|
||||||
mentioned_document_ids: list[int] | None = None
|
mentioned_document_ids: list[int] | None = None
|
||||||
mentioned_surfsense_doc_ids: list[int] | None = None
|
mentioned_surfsense_doc_ids: list[int] | None = None
|
||||||
disabled_tools: list[str] | None = None
|
disabled_tools: list[str] | None = None
|
||||||
|
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
|
||||||
|
client_platform: Literal["web", "desktop"] = "web"
|
||||||
|
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
@ -227,6 +238,9 @@ class ResumeDecision(BaseModel):
|
||||||
class ResumeRequest(BaseModel):
|
class ResumeRequest(BaseModel):
|
||||||
search_space_id: int
|
search_space_id: int
|
||||||
decisions: list[ResumeDecision]
|
decisions: list[ResumeDecision]
|
||||||
|
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
|
||||||
|
client_platform: Literal["web", "desktop"] = "web"
|
||||||
|
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ COMPOSIO_TOOLKIT_NAMES = {
|
||||||
}
|
}
|
||||||
|
|
||||||
# Toolkits that support indexing (Phase 1: Google services only)
|
# Toolkits that support indexing (Phase 1: Google services only)
|
||||||
INDEXABLE_TOOLKITS = {"googledrive", "gmail", "googlecalendar"}
|
INDEXABLE_TOOLKITS = {"googledrive"}
|
||||||
|
|
||||||
# Mapping of toolkit IDs to connector types
|
# Mapping of toolkit IDs to connector types
|
||||||
TOOLKIT_TO_CONNECTOR_TYPE = {
|
TOOLKIT_TO_CONNECTOR_TYPE = {
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||||
from app.db import Document, DocumentType
|
from app.db import Document, DocumentType
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -66,6 +65,8 @@ class ConfluenceKBSyncService:
|
||||||
if dup:
|
if dup:
|
||||||
content_hash = unique_hash
|
content_hash = unique_hash
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
@ -184,6 +185,8 @@ class ConfluenceKBSyncService:
|
||||||
|
|
||||||
space_id = (document.document_metadata or {}).get("space_id", "")
|
space_id = (document.document_metadata or {}).get("space_id", "")
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import Document, DocumentType
|
from app.db import Document, DocumentType
|
||||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -73,6 +72,8 @@ class DropboxKBSyncService:
|
||||||
)
|
)
|
||||||
content_hash = unique_hash
|
content_hash = unique_hash
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ from datetime import datetime
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import Document, DocumentType
|
from app.db import Document, DocumentType
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -78,6 +77,8 @@ class GmailKBSyncService:
|
||||||
)
|
)
|
||||||
content_hash = unique_hash
|
content_hash = unique_hash
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@ from app.db import (
|
||||||
SearchSourceConnector,
|
SearchSourceConnector,
|
||||||
SearchSourceConnectorType,
|
SearchSourceConnectorType,
|
||||||
)
|
)
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -91,6 +90,8 @@ class GoogleCalendarKBSyncService:
|
||||||
)
|
)
|
||||||
content_hash = unique_hash
|
content_hash = unique_hash
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
@ -249,6 +250,8 @@ class GoogleCalendarKBSyncService:
|
||||||
if not indexable_content:
|
if not indexable_content:
|
||||||
return {"status": "error", "message": "Event produced empty content"}
|
return {"status": "error", "message": "Event produced empty content"}
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ from datetime import datetime
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import Document, DocumentType
|
from app.db import Document, DocumentType
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -75,6 +74,8 @@ class GoogleDriveKBSyncService:
|
||||||
)
|
)
|
||||||
content_hash = unique_hash
|
content_hash = unique_hash
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.connectors.jira_history import JiraHistoryConnector
|
from app.connectors.jira_history import JiraHistoryConnector
|
||||||
from app.db import Document, DocumentType
|
from app.db import Document, DocumentType
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -75,6 +74,8 @@ class JiraKBSyncService:
|
||||||
if dup:
|
if dup:
|
||||||
content_hash = unique_hash
|
content_hash = unique_hash
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
@ -190,6 +191,8 @@ class JiraKBSyncService:
|
||||||
state = formatted.get("status", "Unknown")
|
state = formatted.get("status", "Unknown")
|
||||||
comment_count = len(formatted.get("comments", []))
|
comment_count = len(formatted.get("comments", []))
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.connectors.linear_connector import LinearConnector
|
from app.connectors.linear_connector import LinearConnector
|
||||||
from app.db import Document, DocumentType
|
from app.db import Document, DocumentType
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -85,6 +84,8 @@ class LinearKBSyncService:
|
||||||
)
|
)
|
||||||
content_hash = unique_hash
|
content_hash = unique_hash
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
@ -226,6 +227,8 @@ class LinearKBSyncService:
|
||||||
comment_count = len(formatted_issue.get("comments", []))
|
comment_count = len(formatted_issue.get("comments", []))
|
||||||
formatted_issue.get("description", "")
|
formatted_issue.get("description", "")
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -133,6 +133,44 @@ PROVIDER_MAP = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when
|
||||||
|
# a global LLM config does *not* specify ``api_base``: without this, LiteLLM
|
||||||
|
# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``,
|
||||||
|
# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku``
|
||||||
|
# request to an Azure endpoint, which then 404s with ``Resource not found``.
|
||||||
|
# Only providers with a well-known, stable public base URL are listed here —
|
||||||
|
# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
|
||||||
|
# huggingface, databricks, cloudflare, replicate) are intentionally omitted
|
||||||
|
# so their existing config-driven behaviour is preserved.
|
||||||
|
PROVIDER_DEFAULT_API_BASE = {
|
||||||
|
"openrouter": "https://openrouter.ai/api/v1",
|
||||||
|
"groq": "https://api.groq.com/openai/v1",
|
||||||
|
"mistral": "https://api.mistral.ai/v1",
|
||||||
|
"perplexity": "https://api.perplexity.ai",
|
||||||
|
"xai": "https://api.x.ai/v1",
|
||||||
|
"cerebras": "https://api.cerebras.ai/v1",
|
||||||
|
"deepinfra": "https://api.deepinfra.com/v1/openai",
|
||||||
|
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
|
||||||
|
"together_ai": "https://api.together.xyz/v1",
|
||||||
|
"anyscale": "https://api.endpoints.anyscale.com/v1",
|
||||||
|
"cometapi": "https://api.cometapi.com/v1",
|
||||||
|
"sambanova": "https://api.sambanova.ai/v1",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Canonical provider → base URL when a config uses a generic ``openai``-style
|
||||||
|
# prefix but the ``provider`` field tells us which API it really is
|
||||||
|
# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but
|
||||||
|
# each has its own base URL).
|
||||||
|
PROVIDER_KEY_DEFAULT_API_BASE = {
|
||||||
|
"DEEPSEEK": "https://api.deepseek.com/v1",
|
||||||
|
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||||
|
"MOONSHOT": "https://api.moonshot.ai/v1",
|
||||||
|
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
|
||||||
|
"MINIMAX": "https://api.minimax.io/v1",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class LLMRouterService:
|
class LLMRouterService:
|
||||||
"""
|
"""
|
||||||
Singleton service for managing LiteLLM Router.
|
Singleton service for managing LiteLLM Router.
|
||||||
|
|
@ -224,6 +262,16 @@ class LLMRouterService:
|
||||||
# hits ContextWindowExceededError.
|
# hits ContextWindowExceededError.
|
||||||
full_model_list, ctx_fallbacks = cls._build_context_fallback_groups(model_list)
|
full_model_list, ctx_fallbacks = cls._build_context_fallback_groups(model_list)
|
||||||
|
|
||||||
|
# Build a general-purpose fallback list so NotFound/timeout/rate-limit
|
||||||
|
# style failures on one deployment don't bubble up as hard errors —
|
||||||
|
# the router retries with a sibling deployment in ``auto-large``.
|
||||||
|
# ``auto-large`` is the large-context subset of ``auto``; if it is
|
||||||
|
# empty we fall back to ``auto`` itself so the router at least picks a
|
||||||
|
# different deployment in the same group.
|
||||||
|
fallbacks: list[dict[str, list[str]]] | None = None
|
||||||
|
if ctx_fallbacks:
|
||||||
|
fallbacks = [{"auto": ["auto-large"]}]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
router_kwargs: dict[str, Any] = {
|
router_kwargs: dict[str, Any] = {
|
||||||
"model_list": full_model_list,
|
"model_list": full_model_list,
|
||||||
|
|
@ -237,15 +285,24 @@ class LLMRouterService:
|
||||||
}
|
}
|
||||||
if ctx_fallbacks:
|
if ctx_fallbacks:
|
||||||
router_kwargs["context_window_fallbacks"] = ctx_fallbacks
|
router_kwargs["context_window_fallbacks"] = ctx_fallbacks
|
||||||
|
if fallbacks:
|
||||||
|
router_kwargs["fallbacks"] = fallbacks
|
||||||
|
|
||||||
instance._router = Router(**router_kwargs)
|
instance._router = Router(**router_kwargs)
|
||||||
instance._initialized = True
|
instance._initialized = True
|
||||||
|
|
||||||
|
global _cached_context_profile, _cached_context_profile_computed
|
||||||
|
_cached_context_profile = None
|
||||||
|
_cached_context_profile_computed = False
|
||||||
|
_router_instance_cache.clear()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"LLM Router initialized with %d deployments, "
|
"LLM Router initialized with %d deployments, "
|
||||||
"strategy: %s, context_window_fallbacks: %s",
|
"strategy: %s, context_window_fallbacks: %s, fallbacks: %s",
|
||||||
len(model_list),
|
len(model_list),
|
||||||
final_settings.get("routing_strategy"),
|
final_settings.get("routing_strategy"),
|
||||||
ctx_fallbacks or "none",
|
ctx_fallbacks or "none",
|
||||||
|
fallbacks or "none",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize LLM Router: {e}")
|
logger.error(f"Failed to initialize LLM Router: {e}")
|
||||||
|
|
@ -348,10 +405,11 @@ class LLMRouterService:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Build model string
|
# Build model string
|
||||||
|
provider = config.get("provider", "").upper()
|
||||||
if config.get("custom_provider"):
|
if config.get("custom_provider"):
|
||||||
model_string = f"{config['custom_provider']}/{config['model_name']}"
|
provider_prefix = config["custom_provider"]
|
||||||
|
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||||
else:
|
else:
|
||||||
provider = config.get("provider", "").upper()
|
|
||||||
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
|
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
|
||||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||||
|
|
||||||
|
|
@ -361,9 +419,19 @@ class LLMRouterService:
|
||||||
"api_key": config.get("api_key"),
|
"api_key": config.get("api_key"),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add optional api_base
|
# Resolve ``api_base``. Config value wins; otherwise apply a
|
||||||
if config.get("api_base"):
|
# provider-aware default so the deployment does not silently
|
||||||
litellm_params["api_base"] = config["api_base"]
|
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
|
||||||
|
# requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE``
|
||||||
|
# docstring for the motivating bug (OpenRouter models 404-ing
|
||||||
|
# against an Azure endpoint).
|
||||||
|
api_base = config.get("api_base")
|
||||||
|
if not api_base:
|
||||||
|
api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider)
|
||||||
|
if not api_base:
|
||||||
|
api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix)
|
||||||
|
if api_base:
|
||||||
|
litellm_params["api_base"] = api_base
|
||||||
|
|
||||||
# Add any additional litellm parameters
|
# Add any additional litellm parameters
|
||||||
if config.get("litellm_params"):
|
if config.get("litellm_params"):
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
|
@ -6,7 +7,6 @@ from langchain_litellm import ChatLiteLLM
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.db import NewLLMConfig, SearchSpace
|
from app.db import NewLLMConfig, SearchSpace
|
||||||
from app.services.llm_router_service import (
|
from app.services.llm_router_service import (
|
||||||
|
|
@ -32,6 +32,39 @@ litellm.callbacks = [token_tracker]
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Providers that require an interactive OAuth / device-flow login before
|
||||||
|
# issuing any completion. LiteLLM implements these with blocking sync polling
|
||||||
|
# (requests + time.sleep), which would freeze the FastAPI event loop if
|
||||||
|
# invoked from validation. They are never usable from a headless backend,
|
||||||
|
# so we reject them at the edge.
|
||||||
|
_INTERACTIVE_AUTH_PROVIDERS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"github_copilot",
|
||||||
|
"github-copilot",
|
||||||
|
"githubcopilot",
|
||||||
|
"copilot",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Hard upper bound for a single validation call. Must exceed the ChatLiteLLM
|
||||||
|
# request timeout (30s) by a small margin so a well-behaved provider never
|
||||||
|
# trips the watchdog, while any pathological/blocking provider is killed.
|
||||||
|
_VALIDATION_TIMEOUT_SECONDS: float = 35.0
|
||||||
|
|
||||||
|
|
||||||
|
def _is_interactive_auth_provider(
|
||||||
|
provider: str | None, custom_provider: str | None
|
||||||
|
) -> bool:
|
||||||
|
"""Return True if the given provider triggers interactive OAuth in LiteLLM."""
|
||||||
|
for raw in (custom_provider, provider):
|
||||||
|
if not raw:
|
||||||
|
continue
|
||||||
|
normalized = raw.strip().lower().replace(" ", "_")
|
||||||
|
if normalized in _INTERACTIVE_AUTH_PROVIDERS:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class LLMRole:
|
class LLMRole:
|
||||||
AGENT = "agent" # For agent/chat operations
|
AGENT = "agent" # For agent/chat operations
|
||||||
DOCUMENT_SUMMARY = "document_summary" # For document summarization
|
DOCUMENT_SUMMARY = "document_summary" # For document summarization
|
||||||
|
|
@ -93,6 +126,25 @@ async def validate_llm_config(
|
||||||
- is_valid: True if config works, False otherwise
|
- is_valid: True if config works, False otherwise
|
||||||
- error_message: Empty string if valid, error description if invalid
|
- error_message: Empty string if valid, error description if invalid
|
||||||
"""
|
"""
|
||||||
|
# Reject providers that require interactive OAuth/device-flow auth.
|
||||||
|
# LiteLLM's github_copilot provider (and similar) uses a blocking sync
|
||||||
|
# Authenticator that polls GitHub for up to several minutes and prints a
|
||||||
|
# device code to stdout. Running it on the FastAPI event loop will freeze
|
||||||
|
# the entire backend, so we refuse them up front.
|
||||||
|
if _is_interactive_auth_provider(provider, custom_provider):
|
||||||
|
msg = (
|
||||||
|
"Provider requires interactive OAuth/device-flow authentication "
|
||||||
|
"(e.g. github_copilot) and cannot be used in a hosted backend. "
|
||||||
|
"Please choose a provider that authenticates via API key."
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
"Rejected LLM config validation for interactive-auth provider "
|
||||||
|
"(provider=%r, custom_provider=%r)",
|
||||||
|
provider,
|
||||||
|
custom_provider,
|
||||||
|
)
|
||||||
|
return False, msg
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Build the model string for litellm
|
# Build the model string for litellm
|
||||||
if custom_provider:
|
if custom_provider:
|
||||||
|
|
@ -151,11 +203,34 @@ async def validate_llm_config(
|
||||||
if litellm_params:
|
if litellm_params:
|
||||||
litellm_kwargs.update(litellm_params)
|
litellm_kwargs.update(litellm_params)
|
||||||
|
|
||||||
|
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||||
|
|
||||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
|
|
||||||
# Make a simple test call
|
# Run the test call in a worker thread with a hard timeout. Some
|
||||||
|
# LiteLLM providers have synchronous blocking code paths (e.g. OAuth
|
||||||
|
# authenticators that call time.sleep and requests.post) that would
|
||||||
|
# otherwise freeze the asyncio event loop. Offloading to a thread and
|
||||||
|
# bounding the wait keeps the server responsive even if a provider
|
||||||
|
# misbehaves.
|
||||||
test_message = HumanMessage(content="Hello")
|
test_message = HumanMessage(content="Hello")
|
||||||
response = await llm.ainvoke([test_message])
|
try:
|
||||||
|
response = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(llm.invoke, [test_message]),
|
||||||
|
timeout=_VALIDATION_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
logger.warning(
|
||||||
|
"LLM config validation timed out after %ss for model: %s",
|
||||||
|
_VALIDATION_TIMEOUT_SECONDS,
|
||||||
|
model_string,
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"Validation timed out after {int(_VALIDATION_TIMEOUT_SECONDS)}s. "
|
||||||
|
"The provider is unreachable or requires interactive "
|
||||||
|
"authentication that is not supported by the backend.",
|
||||||
|
)
|
||||||
|
|
||||||
# If we got here without exception, the config is valid
|
# If we got here without exception, the config is valid
|
||||||
if response and response.content:
|
if response and response.content:
|
||||||
|
|
@ -303,6 +378,8 @@ async def get_search_space_llm_instance(
|
||||||
if disable_streaming:
|
if disable_streaming:
|
||||||
litellm_kwargs["disable_streaming"] = True
|
litellm_kwargs["disable_streaming"] = True
|
||||||
|
|
||||||
|
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||||
|
|
||||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
|
|
||||||
# Get the LLM configuration from database (NewLLMConfig)
|
# Get the LLM configuration from database (NewLLMConfig)
|
||||||
|
|
@ -380,6 +457,8 @@ async def get_search_space_llm_instance(
|
||||||
if disable_streaming:
|
if disable_streaming:
|
||||||
litellm_kwargs["disable_streaming"] = True
|
litellm_kwargs["disable_streaming"] = True
|
||||||
|
|
||||||
|
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||||
|
|
||||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -481,6 +560,8 @@ async def get_vision_llm(
|
||||||
if global_cfg.get("litellm_params"):
|
if global_cfg.get("litellm_params"):
|
||||||
litellm_kwargs.update(global_cfg["litellm_params"])
|
litellm_kwargs.update(global_cfg["litellm_params"])
|
||||||
|
|
||||||
|
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||||
|
|
||||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
|
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
|
|
@ -514,6 +595,8 @@ async def get_vision_llm(
|
||||||
if vision_cfg.litellm_params:
|
if vision_cfg.litellm_params:
|
||||||
litellm_kwargs.update(vision_cfg.litellm_params)
|
litellm_kwargs.update(vision_cfg.litellm_params)
|
||||||
|
|
||||||
|
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||||
|
|
||||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
0
surfsense_backend/app/services/mcp_oauth/__init__.py
Normal file
0
surfsense_backend/app/services/mcp_oauth/__init__.py
Normal file
121
surfsense_backend/app/services/mcp_oauth/discovery.py
Normal file
121
surfsense_backend/app/services/mcp_oauth/discovery.py
Normal file
|
|
@ -0,0 +1,121 @@
|
||||||
|
"""MCP OAuth 2.1 metadata discovery, Dynamic Client Registration, and token exchange."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def discover_oauth_metadata(
|
||||||
|
mcp_url: str,
|
||||||
|
*,
|
||||||
|
origin_override: str | None = None,
|
||||||
|
timeout: float = 15.0,
|
||||||
|
) -> dict:
|
||||||
|
"""Fetch OAuth 2.1 metadata from the MCP server's well-known endpoint.
|
||||||
|
|
||||||
|
Per the MCP spec the discovery document lives at the *origin* of the
|
||||||
|
MCP server URL. ``origin_override`` can be used when the OAuth server
|
||||||
|
lives on a different domain (e.g. Airtable: MCP at ``mcp.airtable.com``,
|
||||||
|
OAuth at ``airtable.com``).
|
||||||
|
"""
|
||||||
|
if origin_override:
|
||||||
|
origin = origin_override.rstrip("/")
|
||||||
|
else:
|
||||||
|
parsed = urlparse(mcp_url)
|
||||||
|
origin = f"{parsed.scheme}://{parsed.netloc}"
|
||||||
|
discovery_url = f"{origin}/.well-known/oauth-authorization-server"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||||
|
resp = await client.get(discovery_url, timeout=timeout)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
async def register_client(
|
||||||
|
registration_endpoint: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
*,
|
||||||
|
client_name: str = "SurfSense",
|
||||||
|
timeout: float = 15.0,
|
||||||
|
) -> dict:
|
||||||
|
"""Perform Dynamic Client Registration (RFC 7591)."""
|
||||||
|
payload = {
|
||||||
|
"client_name": client_name,
|
||||||
|
"redirect_uris": [redirect_uri],
|
||||||
|
"grant_types": ["authorization_code", "refresh_token"],
|
||||||
|
"response_types": ["code"],
|
||||||
|
"token_endpoint_auth_method": "client_secret_basic",
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
registration_endpoint, json=payload, timeout=timeout,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
async def exchange_code_for_tokens(
|
||||||
|
token_endpoint: str,
|
||||||
|
code: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
client_id: str,
|
||||||
|
client_secret: str,
|
||||||
|
code_verifier: str,
|
||||||
|
*,
|
||||||
|
timeout: float = 30.0,
|
||||||
|
) -> dict:
|
||||||
|
"""Exchange an authorization code for access + refresh tokens."""
|
||||||
|
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
token_endpoint,
|
||||||
|
data={
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"code": code,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"code_verifier": code_verifier,
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
|
"Authorization": f"Basic {creds}",
|
||||||
|
},
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
async def refresh_access_token(
|
||||||
|
token_endpoint: str,
|
||||||
|
refresh_token: str,
|
||||||
|
client_id: str,
|
||||||
|
client_secret: str,
|
||||||
|
*,
|
||||||
|
timeout: float = 30.0,
|
||||||
|
) -> dict:
|
||||||
|
"""Refresh an expired access token."""
|
||||||
|
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
token_endpoint,
|
||||||
|
data={
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
|
"Authorization": f"Basic {creds}",
|
||||||
|
},
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
161
surfsense_backend/app/services/mcp_oauth/registry.py
Normal file
161
surfsense_backend/app/services/mcp_oauth/registry.py
Normal file
|
|
@ -0,0 +1,161 @@
|
||||||
|
"""Registry of MCP services with OAuth support.
|
||||||
|
|
||||||
|
Each entry maps a URL-safe service key to its MCP server endpoint and
|
||||||
|
authentication configuration. Services with ``supports_dcr=True`` use
|
||||||
|
RFC 7591 Dynamic Client Registration (the MCP server issues its own
|
||||||
|
credentials); the rest use pre-configured credentials via env vars.
|
||||||
|
|
||||||
|
``allowed_tools`` whitelists which MCP tools to expose to the agent.
|
||||||
|
An empty list means "load every tool the server advertises" (used for
|
||||||
|
user-managed generic MCP servers). Service-specific entries should
|
||||||
|
curate this list to keep the agent's tool count low and selection
|
||||||
|
accuracy high.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from app.db import SearchSourceConnectorType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class MCPServiceConfig:
|
||||||
|
name: str
|
||||||
|
mcp_url: str
|
||||||
|
connector_type: str
|
||||||
|
supports_dcr: bool = True
|
||||||
|
oauth_discovery_origin: str | None = None
|
||||||
|
client_id_env: str | None = None
|
||||||
|
client_secret_env: str | None = None
|
||||||
|
scopes: list[str] = field(default_factory=list)
|
||||||
|
scope_param: str = "scope"
|
||||||
|
auth_endpoint_override: str | None = None
|
||||||
|
token_endpoint_override: str | None = None
|
||||||
|
allowed_tools: list[str] = field(default_factory=list)
|
||||||
|
readonly_tools: frozenset[str] = field(default_factory=frozenset)
|
||||||
|
account_metadata_keys: list[str] = field(default_factory=list)
|
||||||
|
"""``connector.config`` keys exposed by ``get_connected_accounts``.
|
||||||
|
|
||||||
|
Only listed keys are returned to the LLM — tokens and secrets are
|
||||||
|
never included. Every service should at least have its
|
||||||
|
``display_name`` populated during OAuth; additional service-specific
|
||||||
|
fields (e.g. Jira ``cloud_id``) are listed here so the LLM can pass
|
||||||
|
them to action tools.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
MCP_SERVICES: dict[str, MCPServiceConfig] = {
|
||||||
|
"linear": MCPServiceConfig(
|
||||||
|
name="Linear",
|
||||||
|
mcp_url="https://mcp.linear.app/mcp",
|
||||||
|
connector_type="LINEAR_CONNECTOR",
|
||||||
|
allowed_tools=[
|
||||||
|
"list_issues",
|
||||||
|
"get_issue",
|
||||||
|
"save_issue",
|
||||||
|
],
|
||||||
|
readonly_tools=frozenset({"list_issues", "get_issue"}),
|
||||||
|
account_metadata_keys=["organization_name", "organization_url_key"],
|
||||||
|
),
|
||||||
|
"jira": MCPServiceConfig(
|
||||||
|
name="Jira",
|
||||||
|
mcp_url="https://mcp.atlassian.com/v1/mcp",
|
||||||
|
connector_type="JIRA_CONNECTOR",
|
||||||
|
allowed_tools=[
|
||||||
|
"getAccessibleAtlassianResources",
|
||||||
|
"searchJiraIssuesUsingJql",
|
||||||
|
"getVisibleJiraProjects",
|
||||||
|
"getJiraProjectIssueTypesMetadata",
|
||||||
|
"createJiraIssue",
|
||||||
|
"editJiraIssue",
|
||||||
|
],
|
||||||
|
readonly_tools=frozenset({
|
||||||
|
"getAccessibleAtlassianResources",
|
||||||
|
"searchJiraIssuesUsingJql",
|
||||||
|
"getVisibleJiraProjects",
|
||||||
|
"getJiraProjectIssueTypesMetadata",
|
||||||
|
}),
|
||||||
|
account_metadata_keys=["cloud_id", "site_name", "base_url"],
|
||||||
|
),
|
||||||
|
"clickup": MCPServiceConfig(
|
||||||
|
name="ClickUp",
|
||||||
|
mcp_url="https://mcp.clickup.com/mcp",
|
||||||
|
connector_type="CLICKUP_CONNECTOR",
|
||||||
|
allowed_tools=[
|
||||||
|
"clickup_search",
|
||||||
|
"clickup_get_task",
|
||||||
|
],
|
||||||
|
readonly_tools=frozenset({"clickup_search", "clickup_get_task"}),
|
||||||
|
account_metadata_keys=["workspace_id", "workspace_name"],
|
||||||
|
),
|
||||||
|
"slack": MCPServiceConfig(
|
||||||
|
name="Slack",
|
||||||
|
mcp_url="https://mcp.slack.com/mcp",
|
||||||
|
connector_type="SLACK_CONNECTOR",
|
||||||
|
supports_dcr=False,
|
||||||
|
client_id_env="SLACK_CLIENT_ID",
|
||||||
|
client_secret_env="SLACK_CLIENT_SECRET",
|
||||||
|
auth_endpoint_override="https://slack.com/oauth/v2_user/authorize",
|
||||||
|
token_endpoint_override="https://slack.com/api/oauth.v2.user.access",
|
||||||
|
scopes=[
|
||||||
|
"search:read.public", "search:read.private", "search:read.mpim", "search:read.im",
|
||||||
|
"channels:history", "groups:history", "mpim:history", "im:history",
|
||||||
|
],
|
||||||
|
allowed_tools=[
|
||||||
|
"slack_search_channels",
|
||||||
|
"slack_read_channel",
|
||||||
|
"slack_read_thread",
|
||||||
|
],
|
||||||
|
readonly_tools=frozenset({"slack_search_channels", "slack_read_channel", "slack_read_thread"}),
|
||||||
|
# TODO: oauth.v2.user.access only returns team.id, not team.name.
|
||||||
|
# To populate team_name, either add "team:read" scope and call
|
||||||
|
# GET /api/team.info during OAuth callback, or switch to oauth.v2.access.
|
||||||
|
account_metadata_keys=["team_id", "team_name"],
|
||||||
|
),
|
||||||
|
"airtable": MCPServiceConfig(
|
||||||
|
name="Airtable",
|
||||||
|
mcp_url="https://mcp.airtable.com/mcp",
|
||||||
|
connector_type="AIRTABLE_CONNECTOR",
|
||||||
|
supports_dcr=False,
|
||||||
|
oauth_discovery_origin="https://airtable.com",
|
||||||
|
client_id_env="AIRTABLE_CLIENT_ID",
|
||||||
|
client_secret_env="AIRTABLE_CLIENT_SECRET",
|
||||||
|
scopes=["data.records:read", "schema.bases:read"],
|
||||||
|
allowed_tools=[
|
||||||
|
"list_bases",
|
||||||
|
"list_tables_for_base",
|
||||||
|
"list_records_for_table",
|
||||||
|
],
|
||||||
|
readonly_tools=frozenset({"list_bases", "list_tables_for_base", "list_records_for_table"}),
|
||||||
|
account_metadata_keys=["user_id", "user_email"],
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
_CONNECTOR_TYPE_TO_SERVICE: dict[str, MCPServiceConfig] = {
|
||||||
|
svc.connector_type: svc for svc in MCP_SERVICES.values()
|
||||||
|
}
|
||||||
|
|
||||||
|
LIVE_CONNECTOR_TYPES: frozenset[SearchSourceConnectorType] = frozenset({
|
||||||
|
SearchSourceConnectorType.SLACK_CONNECTOR,
|
||||||
|
SearchSourceConnectorType.TEAMS_CONNECTOR,
|
||||||
|
SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||||
|
SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||||
|
SearchSourceConnectorType.CLICKUP_CONNECTOR,
|
||||||
|
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||||
|
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||||
|
SearchSourceConnectorType.AIRTABLE_CONNECTOR,
|
||||||
|
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||||
|
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||||
|
SearchSourceConnectorType.DISCORD_CONNECTOR,
|
||||||
|
SearchSourceConnectorType.LUMA_CONNECTOR,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def get_service(key: str) -> MCPServiceConfig | None:
|
||||||
|
return MCP_SERVICES.get(key)
|
||||||
|
|
||||||
|
|
||||||
|
def get_service_by_connector_type(connector_type: str) -> MCPServiceConfig | None:
|
||||||
|
"""Look up an MCP service config by its ``connector_type`` enum value."""
|
||||||
|
return _CONNECTOR_TYPE_TO_SERVICE.get(connector_type)
|
||||||
|
|
@ -4,7 +4,6 @@ from datetime import datetime
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import Document, DocumentType
|
from app.db import Document, DocumentType
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -74,6 +73,8 @@ class NotionKBSyncService:
|
||||||
)
|
)
|
||||||
content_hash = unique_hash
|
content_hash = unique_hash
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
@ -244,6 +245,8 @@ class NotionKBSyncService:
|
||||||
f"Final content length: {len(full_content)} chars, verified={content_verified}"
|
f"Final content length: {len(full_content)} chars, verified={content_verified}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
logger.debug("Generating summary and embeddings")
|
logger.debug("Generating summary and embeddings")
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
|
|
|
||||||
|
|
@ -227,8 +227,6 @@ class NotionToolMetadataService:
|
||||||
async def _check_account_health(self, connector_id: int) -> bool:
|
async def _check_account_health(self, connector_id: int) -> bool:
|
||||||
"""Check if a Notion connector's token is still valid.
|
"""Check if a Notion connector's token is still valid.
|
||||||
|
|
||||||
Uses a lightweight ``users.me()`` call to verify the token.
|
|
||||||
|
|
||||||
Returns True if the token is expired/invalid, False if healthy.
|
Returns True if the token is expired/invalid, False if healthy.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import Document, DocumentType
|
from app.db import Document, DocumentType
|
||||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -73,6 +72,8 @@ class OneDriveKBSyncService:
|
||||||
)
|
)
|
||||||
content_hash = unique_hash
|
content_hash = unique_hash
|
||||||
|
|
||||||
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
user_llm = await get_user_long_context_llm(
|
user_llm = await get_user_long_context_llm(
|
||||||
self.db_session,
|
self.db_session,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
|
||||||
|
|
@ -39,52 +39,6 @@ def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> N
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="index_slack_messages", bind=True)
|
|
||||||
def index_slack_messages_task(
|
|
||||||
self,
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Celery task to index Slack messages."""
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_slack_messages(
|
|
||||||
connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
_handle_greenlet_error(e, "index_slack_messages", connector_id)
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_slack_messages(
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Index Slack messages with new session."""
|
|
||||||
from app.routes.search_source_connectors_routes import (
|
|
||||||
run_slack_indexing,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with get_celery_session_maker()() as session:
|
|
||||||
await run_slack_indexing(
|
|
||||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="index_notion_pages", bind=True)
|
@celery_app.task(name="index_notion_pages", bind=True)
|
||||||
def index_notion_pages_task(
|
def index_notion_pages_task(
|
||||||
self,
|
self,
|
||||||
|
|
@ -174,92 +128,6 @@ async def _index_github_repos(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="index_linear_issues", bind=True)
|
|
||||||
def index_linear_issues_task(
|
|
||||||
self,
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Celery task to index Linear issues."""
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_linear_issues(
|
|
||||||
connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_linear_issues(
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Index Linear issues with new session."""
|
|
||||||
from app.routes.search_source_connectors_routes import (
|
|
||||||
run_linear_indexing,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with get_celery_session_maker()() as session:
|
|
||||||
await run_linear_indexing(
|
|
||||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="index_jira_issues", bind=True)
|
|
||||||
def index_jira_issues_task(
|
|
||||||
self,
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Celery task to index Jira issues."""
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_jira_issues(
|
|
||||||
connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_jira_issues(
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Index Jira issues with new session."""
|
|
||||||
from app.routes.search_source_connectors_routes import (
|
|
||||||
run_jira_indexing,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with get_celery_session_maker()() as session:
|
|
||||||
await run_jira_indexing(
|
|
||||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="index_confluence_pages", bind=True)
|
@celery_app.task(name="index_confluence_pages", bind=True)
|
||||||
def index_confluence_pages_task(
|
def index_confluence_pages_task(
|
||||||
self,
|
self,
|
||||||
|
|
@ -303,49 +171,6 @@ async def _index_confluence_pages(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="index_clickup_tasks", bind=True)
|
|
||||||
def index_clickup_tasks_task(
|
|
||||||
self,
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Celery task to index ClickUp tasks."""
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_clickup_tasks(
|
|
||||||
connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_clickup_tasks(
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Index ClickUp tasks with new session."""
|
|
||||||
from app.routes.search_source_connectors_routes import (
|
|
||||||
run_clickup_indexing,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with get_celery_session_maker()() as session:
|
|
||||||
await run_clickup_indexing(
|
|
||||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="index_google_calendar_events", bind=True)
|
@celery_app.task(name="index_google_calendar_events", bind=True)
|
||||||
def index_google_calendar_events_task(
|
def index_google_calendar_events_task(
|
||||||
self,
|
self,
|
||||||
|
|
@ -392,49 +217,6 @@ async def _index_google_calendar_events(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="index_airtable_records", bind=True)
|
|
||||||
def index_airtable_records_task(
|
|
||||||
self,
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Celery task to index Airtable records."""
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_airtable_records(
|
|
||||||
connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_airtable_records(
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Index Airtable records with new session."""
|
|
||||||
from app.routes.search_source_connectors_routes import (
|
|
||||||
run_airtable_indexing,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with get_celery_session_maker()() as session:
|
|
||||||
await run_airtable_indexing(
|
|
||||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="index_google_gmail_messages", bind=True)
|
@celery_app.task(name="index_google_gmail_messages", bind=True)
|
||||||
def index_google_gmail_messages_task(
|
def index_google_gmail_messages_task(
|
||||||
self,
|
self,
|
||||||
|
|
@ -622,135 +404,6 @@ async def _index_dropbox_files(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="index_discord_messages", bind=True)
|
|
||||||
def index_discord_messages_task(
|
|
||||||
self,
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Celery task to index Discord messages."""
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_discord_messages(
|
|
||||||
connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_discord_messages(
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Index Discord messages with new session."""
|
|
||||||
from app.routes.search_source_connectors_routes import (
|
|
||||||
run_discord_indexing,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with get_celery_session_maker()() as session:
|
|
||||||
await run_discord_indexing(
|
|
||||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="index_teams_messages", bind=True)
|
|
||||||
def index_teams_messages_task(
|
|
||||||
self,
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Celery task to index Microsoft Teams messages."""
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_teams_messages(
|
|
||||||
connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_teams_messages(
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Index Microsoft Teams messages with new session."""
|
|
||||||
from app.routes.search_source_connectors_routes import (
|
|
||||||
run_teams_indexing,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with get_celery_session_maker()() as session:
|
|
||||||
await run_teams_indexing(
|
|
||||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="index_luma_events", bind=True)
|
|
||||||
def index_luma_events_task(
|
|
||||||
self,
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Celery task to index Luma events."""
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_luma_events(
|
|
||||||
connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_luma_events(
|
|
||||||
connector_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
):
|
|
||||||
"""Index Luma events with new session."""
|
|
||||||
from app.routes.search_source_connectors_routes import (
|
|
||||||
run_luma_indexing,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with get_celery_session_maker()() as session:
|
|
||||||
await run_luma_indexing(
|
|
||||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="index_elasticsearch_documents", bind=True)
|
@celery_app.task(name="index_elasticsearch_documents", bind=True)
|
||||||
def index_elasticsearch_documents_task(
|
def index_elasticsearch_documents_task(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -51,50 +51,51 @@ async def _check_and_trigger_schedules():
|
||||||
|
|
||||||
logger.info(f"Found {len(due_connectors)} connectors due for indexing")
|
logger.info(f"Found {len(due_connectors)} connectors due for indexing")
|
||||||
|
|
||||||
# Import all indexing tasks
|
# Import indexing tasks for KB connectors only.
|
||||||
|
# Live connectors (Linear, Slack, Jira, ClickUp, Airtable, Discord,
|
||||||
|
# Teams, Gmail, Calendar, Luma) use real-time tools instead.
|
||||||
from app.tasks.celery_tasks.connector_tasks import (
|
from app.tasks.celery_tasks.connector_tasks import (
|
||||||
index_airtable_records_task,
|
|
||||||
index_clickup_tasks_task,
|
|
||||||
index_confluence_pages_task,
|
index_confluence_pages_task,
|
||||||
index_crawled_urls_task,
|
index_crawled_urls_task,
|
||||||
index_discord_messages_task,
|
|
||||||
index_elasticsearch_documents_task,
|
index_elasticsearch_documents_task,
|
||||||
index_github_repos_task,
|
index_github_repos_task,
|
||||||
index_google_calendar_events_task,
|
|
||||||
index_google_drive_files_task,
|
index_google_drive_files_task,
|
||||||
index_google_gmail_messages_task,
|
|
||||||
index_jira_issues_task,
|
|
||||||
index_linear_issues_task,
|
|
||||||
index_luma_events_task,
|
|
||||||
index_notion_pages_task,
|
index_notion_pages_task,
|
||||||
index_slack_messages_task,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Map connector types to their tasks
|
|
||||||
task_map = {
|
task_map = {
|
||||||
SearchSourceConnectorType.SLACK_CONNECTOR: index_slack_messages_task,
|
|
||||||
SearchSourceConnectorType.NOTION_CONNECTOR: index_notion_pages_task,
|
SearchSourceConnectorType.NOTION_CONNECTOR: index_notion_pages_task,
|
||||||
SearchSourceConnectorType.GITHUB_CONNECTOR: index_github_repos_task,
|
SearchSourceConnectorType.GITHUB_CONNECTOR: index_github_repos_task,
|
||||||
SearchSourceConnectorType.LINEAR_CONNECTOR: index_linear_issues_task,
|
|
||||||
SearchSourceConnectorType.JIRA_CONNECTOR: index_jira_issues_task,
|
|
||||||
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: index_confluence_pages_task,
|
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: index_confluence_pages_task,
|
||||||
SearchSourceConnectorType.CLICKUP_CONNECTOR: index_clickup_tasks_task,
|
|
||||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task,
|
|
||||||
SearchSourceConnectorType.AIRTABLE_CONNECTOR: index_airtable_records_task,
|
|
||||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: index_google_gmail_messages_task,
|
|
||||||
SearchSourceConnectorType.DISCORD_CONNECTOR: index_discord_messages_task,
|
|
||||||
SearchSourceConnectorType.LUMA_CONNECTOR: index_luma_events_task,
|
|
||||||
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task,
|
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task,
|
||||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task,
|
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task,
|
||||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task,
|
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task,
|
||||||
# Composio connector types (unified with native Google tasks)
|
|
||||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task,
|
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task,
|
||||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR: index_google_gmail_messages_task,
|
|
||||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
from app.services.mcp_oauth.registry import LIVE_CONNECTOR_TYPES
|
||||||
|
|
||||||
|
# Disable obsolete periodic indexing for live connectors in one batch.
|
||||||
|
live_disabled = []
|
||||||
|
for connector in due_connectors:
|
||||||
|
if connector.connector_type in LIVE_CONNECTOR_TYPES:
|
||||||
|
connector.periodic_indexing_enabled = False
|
||||||
|
connector.next_scheduled_at = None
|
||||||
|
live_disabled.append(connector)
|
||||||
|
if live_disabled:
|
||||||
|
await session.commit()
|
||||||
|
for c in live_disabled:
|
||||||
|
logger.info(
|
||||||
|
"Disabled obsolete periodic indexing for live connector %s (%s)",
|
||||||
|
c.id,
|
||||||
|
c.connector_type.value,
|
||||||
|
)
|
||||||
|
|
||||||
# Trigger indexing for each due connector
|
# Trigger indexing for each due connector
|
||||||
for connector in due_connectors:
|
for connector in due_connectors:
|
||||||
|
if connector in live_disabled:
|
||||||
|
continue
|
||||||
|
|
||||||
# Primary guard: Redis lock indicates a task is currently running.
|
# Primary guard: Redis lock indicates a task is currently running.
|
||||||
if is_connector_indexing_locked(connector.id):
|
if is_connector_indexing_locked(connector.id):
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,8 @@ from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemSelection
|
||||||
|
from app.config import config
|
||||||
from app.agents.new_chat.llm_config import (
|
from app.agents.new_chat.llm_config import (
|
||||||
AgentConfig,
|
AgentConfig,
|
||||||
create_chat_litellm_from_agent_config,
|
create_chat_litellm_from_agent_config,
|
||||||
|
|
@ -145,6 +147,102 @@ class StreamResult:
|
||||||
interrupt_value: dict[str, Any] | None = None
|
interrupt_value: dict[str, Any] | None = None
|
||||||
sandbox_files: list[str] = field(default_factory=list)
|
sandbox_files: list[str] = field(default_factory=list)
|
||||||
agent_called_update_memory: bool = False
|
agent_called_update_memory: bool = False
|
||||||
|
request_id: str | None = None
|
||||||
|
turn_id: str = ""
|
||||||
|
filesystem_mode: str = "cloud"
|
||||||
|
client_platform: str = "web"
|
||||||
|
intent_detected: str = "chat_only"
|
||||||
|
intent_confidence: float = 0.0
|
||||||
|
write_attempted: bool = False
|
||||||
|
write_succeeded: bool = False
|
||||||
|
verification_succeeded: bool = False
|
||||||
|
commit_gate_passed: bool = True
|
||||||
|
commit_gate_reason: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_float(value: Any, default: float = 0.0) -> float:
|
||||||
|
try:
|
||||||
|
return float(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_output_to_text(tool_output: Any) -> str:
|
||||||
|
if isinstance(tool_output, dict):
|
||||||
|
if isinstance(tool_output.get("result"), str):
|
||||||
|
return tool_output["result"]
|
||||||
|
if isinstance(tool_output.get("error"), str):
|
||||||
|
return tool_output["error"]
|
||||||
|
return json.dumps(tool_output, ensure_ascii=False)
|
||||||
|
return str(tool_output)
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_output_has_error(tool_output: Any) -> bool:
|
||||||
|
if isinstance(tool_output, dict):
|
||||||
|
if tool_output.get("error"):
|
||||||
|
return True
|
||||||
|
result = tool_output.get("result")
|
||||||
|
if isinstance(result, str) and result.strip().lower().startswith("error:"):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
if isinstance(tool_output, str):
|
||||||
|
return tool_output.strip().lower().startswith("error:")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_resolved_file_path(*, tool_name: str, tool_output: Any) -> str | None:
|
||||||
|
if isinstance(tool_output, dict):
|
||||||
|
path_value = tool_output.get("path")
|
||||||
|
if isinstance(path_value, str) and path_value.strip():
|
||||||
|
return path_value.strip()
|
||||||
|
text = _tool_output_to_text(tool_output)
|
||||||
|
if tool_name == "write_file":
|
||||||
|
match = re.search(r"Updated file\s+(.+)$", text.strip())
|
||||||
|
if match:
|
||||||
|
return match.group(1).strip()
|
||||||
|
if tool_name == "edit_file":
|
||||||
|
match = re.search(r"in '([^']+)'", text)
|
||||||
|
if match:
|
||||||
|
return match.group(1).strip()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _contract_enforcement_active(result: StreamResult) -> bool:
|
||||||
|
# Keep policy deterministic with no env-driven progression modes:
|
||||||
|
# enforce the file-operation contract only in desktop local-folder mode.
|
||||||
|
return result.filesystem_mode == "desktop_local_folder"
|
||||||
|
|
||||||
|
|
||||||
|
def _evaluate_file_contract_outcome(result: StreamResult) -> tuple[bool, str]:
|
||||||
|
if result.intent_detected != "file_write":
|
||||||
|
return True, ""
|
||||||
|
if not result.write_attempted:
|
||||||
|
return False, "no_write_attempt"
|
||||||
|
if not result.write_succeeded:
|
||||||
|
return False, "write_failed"
|
||||||
|
if not result.verification_succeeded:
|
||||||
|
return False, "verification_failed"
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
|
||||||
|
def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"stage": stage,
|
||||||
|
"request_id": result.request_id or "unknown",
|
||||||
|
"turn_id": result.turn_id or "unknown",
|
||||||
|
"chat_id": result.turn_id.split(":", 1)[0] if ":" in result.turn_id else "unknown",
|
||||||
|
"filesystem_mode": result.filesystem_mode,
|
||||||
|
"client_platform": result.client_platform,
|
||||||
|
"intent_detected": result.intent_detected,
|
||||||
|
"intent_confidence": result.intent_confidence,
|
||||||
|
"write_attempted": result.write_attempted,
|
||||||
|
"write_succeeded": result.write_succeeded,
|
||||||
|
"verification_succeeded": result.verification_succeeded,
|
||||||
|
"commit_gate_passed": result.commit_gate_passed,
|
||||||
|
"commit_gate_reason": result.commit_gate_reason or None,
|
||||||
|
}
|
||||||
|
payload.update(extra)
|
||||||
|
_perf_log.info("[file_operation_contract] %s", json.dumps(payload, ensure_ascii=False))
|
||||||
|
|
||||||
|
|
||||||
async def _stream_agent_events(
|
async def _stream_agent_events(
|
||||||
|
|
@ -239,6 +337,8 @@ async def _stream_agent_events(
|
||||||
tool_name = event.get("name", "unknown_tool")
|
tool_name = event.get("name", "unknown_tool")
|
||||||
run_id = event.get("run_id", "")
|
run_id = event.get("run_id", "")
|
||||||
tool_input = event.get("data", {}).get("input", {})
|
tool_input = event.get("data", {}).get("input", {})
|
||||||
|
if tool_name in ("write_file", "edit_file"):
|
||||||
|
result.write_attempted = True
|
||||||
|
|
||||||
if current_text_id is not None:
|
if current_text_id is not None:
|
||||||
yield streaming_service.format_text_end(current_text_id)
|
yield streaming_service.format_text_end(current_text_id)
|
||||||
|
|
@ -514,6 +614,14 @@ async def _stream_agent_events(
|
||||||
else:
|
else:
|
||||||
tool_output = {"result": str(raw_output) if raw_output else "completed"}
|
tool_output = {"result": str(raw_output) if raw_output else "completed"}
|
||||||
|
|
||||||
|
if tool_name in ("write_file", "edit_file"):
|
||||||
|
if _tool_output_has_error(tool_output):
|
||||||
|
# Keep successful evidence if a previous write/edit in this turn succeeded.
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
result.write_succeeded = True
|
||||||
|
result.verification_succeeded = True
|
||||||
|
|
||||||
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
|
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
|
||||||
original_step_id = tool_step_ids.get(
|
original_step_id = tool_step_ids.get(
|
||||||
run_id, f"{step_prefix}-unknown-{run_id[:8]}"
|
run_id, f"{step_prefix}-unknown-{run_id[:8]}"
|
||||||
|
|
@ -925,6 +1033,30 @@ async def _stream_agent_events(
|
||||||
f"Scrape failed: {error_msg}",
|
f"Scrape failed: {error_msg}",
|
||||||
"error",
|
"error",
|
||||||
)
|
)
|
||||||
|
elif tool_name in ("write_file", "edit_file"):
|
||||||
|
resolved_path = _extract_resolved_file_path(
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_output=tool_output,
|
||||||
|
)
|
||||||
|
result_text = _tool_output_to_text(tool_output)
|
||||||
|
if _tool_output_has_error(tool_output):
|
||||||
|
yield streaming_service.format_tool_output_available(
|
||||||
|
tool_call_id,
|
||||||
|
{
|
||||||
|
"status": "error",
|
||||||
|
"error": result_text,
|
||||||
|
"path": resolved_path,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield streaming_service.format_tool_output_available(
|
||||||
|
tool_call_id,
|
||||||
|
{
|
||||||
|
"status": "completed",
|
||||||
|
"path": resolved_path,
|
||||||
|
"result": result_text,
|
||||||
|
},
|
||||||
|
)
|
||||||
elif tool_name == "generate_report":
|
elif tool_name == "generate_report":
|
||||||
# Stream the full report result so frontend can render the ReportCard
|
# Stream the full report result so frontend can render the ReportCard
|
||||||
yield streaming_service.format_tool_output_available(
|
yield streaming_service.format_tool_output_available(
|
||||||
|
|
@ -1143,10 +1275,59 @@ async def _stream_agent_events(
|
||||||
if completion_event:
|
if completion_event:
|
||||||
yield completion_event
|
yield completion_event
|
||||||
|
|
||||||
|
state = await agent.aget_state(config)
|
||||||
|
state_values = getattr(state, "values", {}) or {}
|
||||||
|
contract_state = state_values.get("file_operation_contract") or {}
|
||||||
|
contract_turn_id = contract_state.get("turn_id")
|
||||||
|
current_turn_id = config.get("configurable", {}).get("turn_id", "")
|
||||||
|
intent_value = contract_state.get("intent")
|
||||||
|
if (
|
||||||
|
isinstance(intent_value, str)
|
||||||
|
and intent_value in ("chat_only", "file_write", "file_read")
|
||||||
|
and contract_turn_id == current_turn_id
|
||||||
|
):
|
||||||
|
result.intent_detected = intent_value
|
||||||
|
if (
|
||||||
|
isinstance(intent_value, str)
|
||||||
|
and intent_value in (
|
||||||
|
"chat_only",
|
||||||
|
"file_write",
|
||||||
|
"file_read",
|
||||||
|
)
|
||||||
|
and contract_turn_id != current_turn_id
|
||||||
|
):
|
||||||
|
# Ignore stale intent contracts from previous turns/checkpoints.
|
||||||
|
result.intent_detected = "chat_only"
|
||||||
|
result.intent_confidence = (
|
||||||
|
_safe_float(contract_state.get("confidence"), default=0.0)
|
||||||
|
if contract_turn_id == current_turn_id
|
||||||
|
else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.intent_detected == "file_write":
|
||||||
|
result.commit_gate_passed, result.commit_gate_reason = (
|
||||||
|
_evaluate_file_contract_outcome(result)
|
||||||
|
)
|
||||||
|
if not result.commit_gate_passed:
|
||||||
|
if _contract_enforcement_active(result):
|
||||||
|
gate_notice = (
|
||||||
|
"I could not complete the requested file write because no successful "
|
||||||
|
"write_file/edit_file operation was confirmed."
|
||||||
|
)
|
||||||
|
gate_text_id = streaming_service.generate_text_id()
|
||||||
|
yield streaming_service.format_text_start(gate_text_id)
|
||||||
|
yield streaming_service.format_text_delta(gate_text_id, gate_notice)
|
||||||
|
yield streaming_service.format_text_end(gate_text_id)
|
||||||
|
yield streaming_service.format_terminal_info(gate_notice, "error")
|
||||||
|
accumulated_text = gate_notice
|
||||||
|
else:
|
||||||
|
result.commit_gate_passed = True
|
||||||
|
result.commit_gate_reason = ""
|
||||||
|
|
||||||
result.accumulated_text = accumulated_text
|
result.accumulated_text = accumulated_text
|
||||||
result.agent_called_update_memory = called_update_memory
|
result.agent_called_update_memory = called_update_memory
|
||||||
|
_log_file_contract("turn_outcome", result)
|
||||||
|
|
||||||
state = await agent.aget_state(config)
|
|
||||||
is_interrupted = state.tasks and any(task.interrupts for task in state.tasks)
|
is_interrupted = state.tasks and any(task.interrupts for task in state.tasks)
|
||||||
if is_interrupted:
|
if is_interrupted:
|
||||||
result.is_interrupted = True
|
result.is_interrupted = True
|
||||||
|
|
@ -1167,6 +1348,8 @@ async def stream_new_chat(
|
||||||
thread_visibility: ChatVisibility | None = None,
|
thread_visibility: ChatVisibility | None = None,
|
||||||
current_user_display_name: str | None = None,
|
current_user_display_name: str | None = None,
|
||||||
disabled_tools: list[str] | None = None,
|
disabled_tools: list[str] | None = None,
|
||||||
|
filesystem_selection: FilesystemSelection | None = None,
|
||||||
|
request_id: str | None = None,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""
|
"""
|
||||||
Stream chat responses from the new SurfSense deep agent.
|
Stream chat responses from the new SurfSense deep agent.
|
||||||
|
|
@ -1194,6 +1377,20 @@ async def stream_new_chat(
|
||||||
streaming_service = VercelStreamingService()
|
streaming_service = VercelStreamingService()
|
||||||
stream_result = StreamResult()
|
stream_result = StreamResult()
|
||||||
_t_total = time.perf_counter()
|
_t_total = time.perf_counter()
|
||||||
|
fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud"
|
||||||
|
fs_platform = (
|
||||||
|
filesystem_selection.client_platform.value if filesystem_selection else "web"
|
||||||
|
)
|
||||||
|
stream_result.request_id = request_id
|
||||||
|
stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}"
|
||||||
|
stream_result.filesystem_mode = fs_mode
|
||||||
|
stream_result.client_platform = fs_platform
|
||||||
|
_log_file_contract("turn_start", stream_result)
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_new_chat] filesystem_mode=%s client_platform=%s",
|
||||||
|
fs_mode,
|
||||||
|
fs_platform,
|
||||||
|
)
|
||||||
log_system_snapshot("stream_new_chat_START")
|
log_system_snapshot("stream_new_chat_START")
|
||||||
|
|
||||||
from app.services.token_tracking_service import start_turn
|
from app.services.token_tracking_service import start_turn
|
||||||
|
|
@ -1329,6 +1526,7 @@ async def stream_new_chat(
|
||||||
thread_visibility=visibility,
|
thread_visibility=visibility,
|
||||||
disabled_tools=disabled_tools,
|
disabled_tools=disabled_tools,
|
||||||
mentioned_document_ids=mentioned_document_ids,
|
mentioned_document_ids=mentioned_document_ids,
|
||||||
|
filesystem_selection=filesystem_selection,
|
||||||
)
|
)
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
|
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
|
||||||
|
|
@ -1435,6 +1633,8 @@ async def stream_new_chat(
|
||||||
# We will use this to simulate group chat functionality in the future
|
# We will use this to simulate group chat functionality in the future
|
||||||
"messages": langchain_messages,
|
"messages": langchain_messages,
|
||||||
"search_space_id": search_space_id,
|
"search_space_id": search_space_id,
|
||||||
|
"request_id": request_id or "unknown",
|
||||||
|
"turn_id": stream_result.turn_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
|
|
@ -1464,6 +1664,8 @@ async def stream_new_chat(
|
||||||
# Configure LangGraph with thread_id for memory
|
# Configure LangGraph with thread_id for memory
|
||||||
# If checkpoint_id is provided, fork from that checkpoint (for edit/reload)
|
# If checkpoint_id is provided, fork from that checkpoint (for edit/reload)
|
||||||
configurable = {"thread_id": str(chat_id)}
|
configurable = {"thread_id": str(chat_id)}
|
||||||
|
configurable["request_id"] = request_id or "unknown"
|
||||||
|
configurable["turn_id"] = stream_result.turn_id
|
||||||
if checkpoint_id:
|
if checkpoint_id:
|
||||||
configurable["checkpoint_id"] = checkpoint_id
|
configurable["checkpoint_id"] = checkpoint_id
|
||||||
|
|
||||||
|
|
@ -1871,10 +2073,26 @@ async def stream_resume_chat(
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
llm_config_id: int = -1,
|
llm_config_id: int = -1,
|
||||||
thread_visibility: ChatVisibility | None = None,
|
thread_visibility: ChatVisibility | None = None,
|
||||||
|
filesystem_selection: FilesystemSelection | None = None,
|
||||||
|
request_id: str | None = None,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
streaming_service = VercelStreamingService()
|
streaming_service = VercelStreamingService()
|
||||||
stream_result = StreamResult()
|
stream_result = StreamResult()
|
||||||
_t_total = time.perf_counter()
|
_t_total = time.perf_counter()
|
||||||
|
fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud"
|
||||||
|
fs_platform = (
|
||||||
|
filesystem_selection.client_platform.value if filesystem_selection else "web"
|
||||||
|
)
|
||||||
|
stream_result.request_id = request_id
|
||||||
|
stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}"
|
||||||
|
stream_result.filesystem_mode = fs_mode
|
||||||
|
stream_result.client_platform = fs_platform
|
||||||
|
_log_file_contract("turn_start", stream_result)
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_resume] filesystem_mode=%s client_platform=%s",
|
||||||
|
fs_mode,
|
||||||
|
fs_platform,
|
||||||
|
)
|
||||||
|
|
||||||
from app.services.token_tracking_service import start_turn
|
from app.services.token_tracking_service import start_turn
|
||||||
|
|
||||||
|
|
@ -1991,6 +2209,7 @@ async def stream_resume_chat(
|
||||||
agent_config=agent_config,
|
agent_config=agent_config,
|
||||||
firecrawl_api_key=firecrawl_api_key,
|
firecrawl_api_key=firecrawl_api_key,
|
||||||
thread_visibility=visibility,
|
thread_visibility=visibility,
|
||||||
|
filesystem_selection=filesystem_selection,
|
||||||
)
|
)
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
|
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
|
||||||
|
|
@ -2009,7 +2228,11 @@ async def stream_resume_chat(
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"configurable": {"thread_id": str(chat_id)},
|
"configurable": {
|
||||||
|
"thread_id": str(chat_id),
|
||||||
|
"request_id": request_id or "unknown",
|
||||||
|
"turn_id": stream_result.turn_id,
|
||||||
|
},
|
||||||
"recursion_limit": 80,
|
"recursion_limit": 80,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,75 +1,29 @@
|
||||||
"""
|
"""
|
||||||
Connector indexers module for background tasks.
|
Connector indexers module for background tasks.
|
||||||
|
|
||||||
This module provides a collection of connector indexers for different platforms
|
Each indexer handles content indexing from a specific connector type.
|
||||||
and services. Each indexer is responsible for handling the indexing of content
|
Live connectors (Slack, Linear, Jira, ClickUp, Airtable, Discord, Teams,
|
||||||
from a specific connector type.
|
Luma) now use real-time agent tools instead of background indexing.
|
||||||
|
|
||||||
Available indexers:
|
|
||||||
- Slack: Index messages from Slack channels
|
|
||||||
- Notion: Index pages from Notion workspaces
|
|
||||||
- GitHub: Index repositories and files from GitHub
|
|
||||||
- Linear: Index issues from Linear workspaces
|
|
||||||
- Jira: Index issues from Jira projects
|
|
||||||
- Confluence: Index pages from Confluence spaces
|
|
||||||
- BookStack: Index pages from BookStack wiki instances
|
|
||||||
- Discord: Index messages from Discord servers
|
|
||||||
- ClickUp: Index tasks from ClickUp workspaces
|
|
||||||
- Google Gmail: Index messages from Google Gmail
|
|
||||||
- Google Calendar: Index events from Google Calendar
|
|
||||||
- Luma: Index events from Luma
|
|
||||||
- Webcrawler: Index crawled URLs
|
|
||||||
- Elasticsearch: Index documents from Elasticsearch instances
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Communication platforms
|
|
||||||
# Calendar and scheduling
|
|
||||||
from .airtable_indexer import index_airtable_records
|
|
||||||
from .bookstack_indexer import index_bookstack_pages
|
from .bookstack_indexer import index_bookstack_pages
|
||||||
|
|
||||||
# Note: composio_indexer is imported directly in connector_tasks.py to avoid circular imports
|
|
||||||
from .clickup_indexer import index_clickup_tasks
|
|
||||||
from .confluence_indexer import index_confluence_pages
|
from .confluence_indexer import index_confluence_pages
|
||||||
from .discord_indexer import index_discord_messages
|
|
||||||
|
|
||||||
# Development platforms
|
|
||||||
from .elasticsearch_indexer import index_elasticsearch_documents
|
from .elasticsearch_indexer import index_elasticsearch_documents
|
||||||
from .github_indexer import index_github_repos
|
from .github_indexer import index_github_repos
|
||||||
from .google_calendar_indexer import index_google_calendar_events
|
from .google_calendar_indexer import index_google_calendar_events
|
||||||
from .google_drive_indexer import index_google_drive_files
|
from .google_drive_indexer import index_google_drive_files
|
||||||
from .google_gmail_indexer import index_google_gmail_messages
|
from .google_gmail_indexer import index_google_gmail_messages
|
||||||
from .jira_indexer import index_jira_issues
|
|
||||||
|
|
||||||
# Issue tracking and project management
|
|
||||||
from .linear_indexer import index_linear_issues
|
|
||||||
|
|
||||||
# Documentation and knowledge management
|
|
||||||
from .luma_indexer import index_luma_events
|
|
||||||
from .notion_indexer import index_notion_pages
|
from .notion_indexer import index_notion_pages
|
||||||
from .slack_indexer import index_slack_messages
|
|
||||||
from .webcrawler_indexer import index_crawled_urls
|
from .webcrawler_indexer import index_crawled_urls
|
||||||
|
|
||||||
__all__ = [ # noqa: RUF022
|
__all__ = [
|
||||||
"index_airtable_records",
|
|
||||||
"index_bookstack_pages",
|
"index_bookstack_pages",
|
||||||
# "index_composio_connector", # Imported directly in connector_tasks.py to avoid circular imports
|
|
||||||
"index_clickup_tasks",
|
|
||||||
"index_confluence_pages",
|
"index_confluence_pages",
|
||||||
"index_discord_messages",
|
"index_crawled_urls",
|
||||||
# Development platforms
|
|
||||||
"index_elasticsearch_documents",
|
"index_elasticsearch_documents",
|
||||||
"index_github_repos",
|
"index_github_repos",
|
||||||
# Calendar and scheduling
|
|
||||||
"index_google_calendar_events",
|
"index_google_calendar_events",
|
||||||
"index_google_drive_files",
|
"index_google_drive_files",
|
||||||
"index_luma_events",
|
|
||||||
"index_jira_issues",
|
|
||||||
# Issue tracking and project management
|
|
||||||
"index_linear_issues",
|
|
||||||
# Documentation and knowledge management
|
|
||||||
"index_notion_pages",
|
|
||||||
"index_crawled_urls",
|
|
||||||
# Communication platforms
|
|
||||||
"index_slack_messages",
|
|
||||||
"index_google_gmail_messages",
|
"index_google_gmail_messages",
|
||||||
|
"index_notion_pages",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
129
surfsense_backend/app/utils/async_retry.py
Normal file
129
surfsense_backend/app/utils/async_retry.py
Normal file
|
|
@ -0,0 +1,129 @@
|
||||||
|
"""Async retry decorators for connector API calls, built on tenacity."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from tenacity import (
|
||||||
|
before_sleep_log,
|
||||||
|
retry,
|
||||||
|
retry_if_exception,
|
||||||
|
stop_after_attempt,
|
||||||
|
stop_after_delay,
|
||||||
|
wait_exponential_jitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.connectors.exceptions import (
|
||||||
|
ConnectorAPIError,
|
||||||
|
ConnectorAuthError,
|
||||||
|
ConnectorError,
|
||||||
|
ConnectorRateLimitError,
|
||||||
|
ConnectorTimeoutError,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
F = TypeVar("F", bound=Callable)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_retryable(exc: BaseException) -> bool:
|
||||||
|
if isinstance(exc, ConnectorError):
|
||||||
|
return exc.retryable
|
||||||
|
if isinstance(exc, (httpx.TimeoutException, httpx.ConnectError)):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def build_retry(
|
||||||
|
*,
|
||||||
|
max_attempts: int = 4,
|
||||||
|
max_delay: float = 60.0,
|
||||||
|
initial_delay: float = 1.0,
|
||||||
|
total_timeout: float = 180.0,
|
||||||
|
service: str = "",
|
||||||
|
) -> Callable:
|
||||||
|
"""Configurable tenacity ``@retry`` decorator with exponential backoff + jitter."""
|
||||||
|
_logger = logging.getLogger(f"connector.retry.{service}") if service else logger
|
||||||
|
|
||||||
|
return retry(
|
||||||
|
retry=retry_if_exception(_is_retryable),
|
||||||
|
stop=(stop_after_attempt(max_attempts) | stop_after_delay(total_timeout)),
|
||||||
|
wait=wait_exponential_jitter(initial=initial_delay, max=max_delay),
|
||||||
|
reraise=True,
|
||||||
|
before_sleep=before_sleep_log(_logger, logging.WARNING),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def retry_on_transient(
|
||||||
|
*,
|
||||||
|
service: str = "",
|
||||||
|
max_attempts: int = 4,
|
||||||
|
) -> Callable:
|
||||||
|
"""Shorthand: retry up to *max_attempts* on rate-limits, timeouts, and 5xx."""
|
||||||
|
return build_retry(max_attempts=max_attempts, service=service)
|
||||||
|
|
||||||
|
|
||||||
|
def raise_for_status(
|
||||||
|
response: httpx.Response,
|
||||||
|
*,
|
||||||
|
service: str = "",
|
||||||
|
) -> None:
|
||||||
|
"""Map non-2xx httpx responses to the appropriate ``ConnectorError``."""
|
||||||
|
if response.is_success:
|
||||||
|
return
|
||||||
|
|
||||||
|
status = response.status_code
|
||||||
|
|
||||||
|
try:
|
||||||
|
body = response.json()
|
||||||
|
except Exception:
|
||||||
|
body = response.text[:500] if response.text else None
|
||||||
|
|
||||||
|
if status == 429:
|
||||||
|
retry_after_raw = response.headers.get("Retry-After")
|
||||||
|
retry_after: float | None = None
|
||||||
|
if retry_after_raw:
|
||||||
|
try:
|
||||||
|
retry_after = float(retry_after_raw)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
pass
|
||||||
|
raise ConnectorRateLimitError(
|
||||||
|
f"{service} rate limited (429)",
|
||||||
|
service=service,
|
||||||
|
retry_after=retry_after,
|
||||||
|
response_body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
if status in (401, 403):
|
||||||
|
raise ConnectorAuthError(
|
||||||
|
f"{service} authentication failed ({status})",
|
||||||
|
service=service,
|
||||||
|
status_code=status,
|
||||||
|
response_body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
if status == 504:
|
||||||
|
raise ConnectorTimeoutError(
|
||||||
|
f"{service} gateway timeout (504)",
|
||||||
|
service=service,
|
||||||
|
status_code=status,
|
||||||
|
response_body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
if status >= 500:
|
||||||
|
raise ConnectorAPIError(
|
||||||
|
f"{service} server error ({status})",
|
||||||
|
service=service,
|
||||||
|
status_code=status,
|
||||||
|
response_body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ConnectorAPIError(
|
||||||
|
f"{service} request failed ({status})",
|
||||||
|
service=service,
|
||||||
|
status_code=status,
|
||||||
|
response_body=body,
|
||||||
|
)
|
||||||
|
|
@ -39,7 +39,7 @@ BASE_NAME_FOR_TYPE = {
|
||||||
def get_base_name_for_type(connector_type: SearchSourceConnectorType) -> str:
|
def get_base_name_for_type(connector_type: SearchSourceConnectorType) -> str:
|
||||||
"""Get a friendly display name for a connector type."""
|
"""Get a friendly display name for a connector type."""
|
||||||
return BASE_NAME_FOR_TYPE.get(
|
return BASE_NAME_FOR_TYPE.get(
|
||||||
connector_type, connector_type.replace("_", " ").title()
|
connector_type, connector_type.value.replace("_", " ").title()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -231,9 +231,11 @@ async def generate_unique_connector_name(
|
||||||
base = get_base_name_for_type(connector_type)
|
base = get_base_name_for_type(connector_type)
|
||||||
|
|
||||||
if identifier:
|
if identifier:
|
||||||
return f"{base} - {identifier}"
|
name = f"{base} - {identifier}"
|
||||||
|
return await ensure_unique_connector_name(
|
||||||
|
session, name, search_space_id, user_id,
|
||||||
|
)
|
||||||
|
|
||||||
# Fallback: use counter for uniqueness
|
|
||||||
count = await count_connectors_of_type(
|
count = await count_connectors_of_type(
|
||||||
session, connector_type, search_space_id, user_id
|
session, connector_type, search_space_id, user_id
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -18,19 +18,9 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Mapping of connector types to their corresponding Celery task names
|
# Mapping of connector types to their corresponding Celery task names
|
||||||
CONNECTOR_TASK_MAP = {
|
CONNECTOR_TASK_MAP = {
|
||||||
SearchSourceConnectorType.SLACK_CONNECTOR: "index_slack_messages",
|
|
||||||
SearchSourceConnectorType.TEAMS_CONNECTOR: "index_teams_messages",
|
|
||||||
SearchSourceConnectorType.NOTION_CONNECTOR: "index_notion_pages",
|
SearchSourceConnectorType.NOTION_CONNECTOR: "index_notion_pages",
|
||||||
SearchSourceConnectorType.GITHUB_CONNECTOR: "index_github_repos",
|
SearchSourceConnectorType.GITHUB_CONNECTOR: "index_github_repos",
|
||||||
SearchSourceConnectorType.LINEAR_CONNECTOR: "index_linear_issues",
|
|
||||||
SearchSourceConnectorType.JIRA_CONNECTOR: "index_jira_issues",
|
|
||||||
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "index_confluence_pages",
|
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "index_confluence_pages",
|
||||||
SearchSourceConnectorType.CLICKUP_CONNECTOR: "index_clickup_tasks",
|
|
||||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: "index_google_calendar_events",
|
|
||||||
SearchSourceConnectorType.AIRTABLE_CONNECTOR: "index_airtable_records",
|
|
||||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: "index_google_gmail_messages",
|
|
||||||
SearchSourceConnectorType.DISCORD_CONNECTOR: "index_discord_messages",
|
|
||||||
SearchSourceConnectorType.LUMA_CONNECTOR: "index_luma_events",
|
|
||||||
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: "index_elasticsearch_documents",
|
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: "index_elasticsearch_documents",
|
||||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: "index_crawled_urls",
|
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: "index_crawled_urls",
|
||||||
SearchSourceConnectorType.BOOKSTACK_CONNECTOR: "index_bookstack_pages",
|
SearchSourceConnectorType.BOOKSTACK_CONNECTOR: "index_bookstack_pages",
|
||||||
|
|
@ -83,39 +73,19 @@ def create_periodic_schedule(
|
||||||
f"(frequency: {frequency_minutes} minutes). Triggering first run..."
|
f"(frequency: {frequency_minutes} minutes). Triggering first run..."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Import all indexing tasks
|
|
||||||
from app.tasks.celery_tasks.connector_tasks import (
|
from app.tasks.celery_tasks.connector_tasks import (
|
||||||
index_airtable_records_task,
|
|
||||||
index_bookstack_pages_task,
|
index_bookstack_pages_task,
|
||||||
index_clickup_tasks_task,
|
|
||||||
index_confluence_pages_task,
|
index_confluence_pages_task,
|
||||||
index_crawled_urls_task,
|
index_crawled_urls_task,
|
||||||
index_discord_messages_task,
|
|
||||||
index_elasticsearch_documents_task,
|
index_elasticsearch_documents_task,
|
||||||
index_github_repos_task,
|
index_github_repos_task,
|
||||||
index_google_calendar_events_task,
|
|
||||||
index_google_gmail_messages_task,
|
|
||||||
index_jira_issues_task,
|
|
||||||
index_linear_issues_task,
|
|
||||||
index_luma_events_task,
|
|
||||||
index_notion_pages_task,
|
index_notion_pages_task,
|
||||||
index_slack_messages_task,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Map connector type to task
|
|
||||||
task_map = {
|
task_map = {
|
||||||
SearchSourceConnectorType.SLACK_CONNECTOR: index_slack_messages_task,
|
|
||||||
SearchSourceConnectorType.NOTION_CONNECTOR: index_notion_pages_task,
|
SearchSourceConnectorType.NOTION_CONNECTOR: index_notion_pages_task,
|
||||||
SearchSourceConnectorType.GITHUB_CONNECTOR: index_github_repos_task,
|
SearchSourceConnectorType.GITHUB_CONNECTOR: index_github_repos_task,
|
||||||
SearchSourceConnectorType.LINEAR_CONNECTOR: index_linear_issues_task,
|
|
||||||
SearchSourceConnectorType.JIRA_CONNECTOR: index_jira_issues_task,
|
|
||||||
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: index_confluence_pages_task,
|
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: index_confluence_pages_task,
|
||||||
SearchSourceConnectorType.CLICKUP_CONNECTOR: index_clickup_tasks_task,
|
|
||||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task,
|
|
||||||
SearchSourceConnectorType.AIRTABLE_CONNECTOR: index_airtable_records_task,
|
|
||||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: index_google_gmail_messages_task,
|
|
||||||
SearchSourceConnectorType.DISCORD_CONNECTOR: index_discord_messages_task,
|
|
||||||
SearchSourceConnectorType.LUMA_CONNECTOR: index_luma_events_task,
|
|
||||||
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task,
|
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task,
|
||||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task,
|
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task,
|
||||||
SearchSourceConnectorType.BOOKSTACK_CONNECTOR: index_bookstack_pages_task,
|
SearchSourceConnectorType.BOOKSTACK_CONNECTOR: index_bookstack_pages_task,
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,7 @@ dependencies = [
|
||||||
"deepagents>=0.4.12",
|
"deepagents>=0.4.12",
|
||||||
"stripe>=15.0.0",
|
"stripe>=15.0.0",
|
||||||
"azure-ai-documentintelligence>=1.0.2",
|
"azure-ai-documentintelligence>=1.0.2",
|
||||||
"litellm>=1.83.0",
|
"litellm>=1.83.4",
|
||||||
"langchain-litellm>=0.6.4",
|
"langchain-litellm>=0.6.4",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,213 @@
|
||||||
|
"""Unit tests for resume page-limit helpers and enforcement flow."""
|
||||||
|
|
||||||
|
import io
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pypdf
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools import resume as resume_tool
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeReport:
|
||||||
|
_next_id = 1000
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
setattr(self, key, value)
|
||||||
|
self.id = None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
def __init__(self, parent_report=None):
|
||||||
|
self.parent_report = parent_report
|
||||||
|
self.added: list[_FakeReport] = []
|
||||||
|
|
||||||
|
async def get(self, _model, _id):
|
||||||
|
return self.parent_report
|
||||||
|
|
||||||
|
def add(self, report):
|
||||||
|
self.added.append(report)
|
||||||
|
|
||||||
|
async def commit(self):
|
||||||
|
for report in self.added:
|
||||||
|
if getattr(report, "id", None) is None:
|
||||||
|
report.id = _FakeReport._next_id
|
||||||
|
_FakeReport._next_id += 1
|
||||||
|
|
||||||
|
async def refresh(self, _report):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class _SessionContext:
|
||||||
|
def __init__(self, session):
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self.session
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class _SessionFactory:
|
||||||
|
def __init__(self, sessions):
|
||||||
|
self._sessions = list(sessions)
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
if not self._sessions:
|
||||||
|
raise RuntimeError("No fake sessions left")
|
||||||
|
return _SessionContext(self._sessions.pop(0))
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pdf_with_pages(page_count: int) -> bytes:
|
||||||
|
writer = pypdf.PdfWriter()
|
||||||
|
for _ in range(page_count):
|
||||||
|
writer.add_blank_page(width=612, height=792)
|
||||||
|
output = io.BytesIO()
|
||||||
|
writer.write(output)
|
||||||
|
return output.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def test_count_pdf_pages_reads_compiled_bytes() -> None:
|
||||||
|
pdf_bytes = _make_pdf_with_pages(2)
|
||||||
|
assert resume_tool._count_pdf_pages(pdf_bytes) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_max_pages_rejects_out_of_range() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
resume_tool._validate_max_pages(0)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
resume_tool._validate_max_pages(6)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_resume_defaults_to_one_page_target(monkeypatch) -> None:
|
||||||
|
read_session = _FakeSession()
|
||||||
|
write_session = _FakeSession()
|
||||||
|
session_factory = _SessionFactory([read_session, write_session])
|
||||||
|
monkeypatch.setattr(resume_tool, "shielded_async_session", session_factory)
|
||||||
|
monkeypatch.setattr(resume_tool, "Report", _FakeReport)
|
||||||
|
|
||||||
|
prompts: list[str] = []
|
||||||
|
|
||||||
|
async def _llm_invoke(messages):
|
||||||
|
prompts.append(messages[0].content)
|
||||||
|
return SimpleNamespace(content="= Jane Doe\n== Experience\n- Built systems")
|
||||||
|
|
||||||
|
llm = SimpleNamespace(ainvoke=AsyncMock(side_effect=_llm_invoke))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
resume_tool,
|
||||||
|
"get_document_summary_llm",
|
||||||
|
AsyncMock(return_value=llm),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(resume_tool, "_compile_typst", lambda _source: b"pdf")
|
||||||
|
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: 1)
|
||||||
|
|
||||||
|
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
||||||
|
result = await tool.ainvoke({"user_info": "Jane Doe experience"})
|
||||||
|
|
||||||
|
assert result["status"] == "ready"
|
||||||
|
assert prompts
|
||||||
|
assert "**Target Maximum Pages:** 1" in prompts[0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_resume_compresses_when_over_limit(monkeypatch) -> None:
|
||||||
|
read_session = _FakeSession()
|
||||||
|
write_session = _FakeSession()
|
||||||
|
session_factory = _SessionFactory([read_session, write_session])
|
||||||
|
monkeypatch.setattr(resume_tool, "shielded_async_session", session_factory)
|
||||||
|
monkeypatch.setattr(resume_tool, "Report", _FakeReport)
|
||||||
|
|
||||||
|
responses = [
|
||||||
|
SimpleNamespace(content="= Jane Doe\n== Experience\n- Detailed bullet 1"),
|
||||||
|
SimpleNamespace(content="= Jane Doe\n== Experience\n- Condensed bullet"),
|
||||||
|
]
|
||||||
|
llm = SimpleNamespace(ainvoke=AsyncMock(side_effect=responses))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
resume_tool,
|
||||||
|
"get_document_summary_llm",
|
||||||
|
AsyncMock(return_value=llm),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(resume_tool, "_compile_typst", lambda _source: b"pdf")
|
||||||
|
page_counts = iter([2, 1])
|
||||||
|
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts))
|
||||||
|
|
||||||
|
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
||||||
|
result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1})
|
||||||
|
|
||||||
|
assert result["status"] == "ready"
|
||||||
|
assert write_session.added, "Expected successful report write"
|
||||||
|
metadata = write_session.added[0].report_metadata
|
||||||
|
assert metadata["target_max_pages"] == 1
|
||||||
|
assert metadata["actual_page_count"] == 1
|
||||||
|
assert metadata["compression_attempts"] == 1
|
||||||
|
assert metadata["page_limit_enforced"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_resume_returns_ready_when_target_not_met(monkeypatch) -> None:
|
||||||
|
read_session = _FakeSession()
|
||||||
|
write_session = _FakeSession()
|
||||||
|
session_factory = _SessionFactory([read_session, write_session])
|
||||||
|
monkeypatch.setattr(resume_tool, "shielded_async_session", session_factory)
|
||||||
|
monkeypatch.setattr(resume_tool, "Report", _FakeReport)
|
||||||
|
|
||||||
|
responses = [
|
||||||
|
SimpleNamespace(content="= Jane Doe\n== Experience\n- Long detail"),
|
||||||
|
SimpleNamespace(content="= Jane Doe\n== Experience\n- Still long"),
|
||||||
|
SimpleNamespace(content="= Jane Doe\n== Experience\n- Still too long"),
|
||||||
|
]
|
||||||
|
llm = SimpleNamespace(ainvoke=AsyncMock(side_effect=responses))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
resume_tool,
|
||||||
|
"get_document_summary_llm",
|
||||||
|
AsyncMock(return_value=llm),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(resume_tool, "_compile_typst", lambda _source: b"pdf")
|
||||||
|
page_counts = iter([3, 3, 2])
|
||||||
|
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts))
|
||||||
|
|
||||||
|
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
||||||
|
result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1})
|
||||||
|
|
||||||
|
assert result["status"] == "ready"
|
||||||
|
assert "could not fit the target" in (result["message"] or "").lower()
|
||||||
|
metadata = write_session.added[0].report_metadata
|
||||||
|
assert metadata["target_page_met"] is False
|
||||||
|
assert metadata["actual_page_count"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_resume_fails_when_hard_limit_exceeded(monkeypatch) -> None:
|
||||||
|
read_session = _FakeSession()
|
||||||
|
failed_session = _FakeSession()
|
||||||
|
session_factory = _SessionFactory([read_session, failed_session])
|
||||||
|
monkeypatch.setattr(resume_tool, "shielded_async_session", session_factory)
|
||||||
|
monkeypatch.setattr(resume_tool, "Report", _FakeReport)
|
||||||
|
|
||||||
|
responses = [
|
||||||
|
SimpleNamespace(content="= Jane Doe\n== Experience\n- Long detail"),
|
||||||
|
SimpleNamespace(content="= Jane Doe\n== Experience\n- Still long"),
|
||||||
|
SimpleNamespace(content="= Jane Doe\n== Experience\n- Still too long"),
|
||||||
|
]
|
||||||
|
llm = SimpleNamespace(ainvoke=AsyncMock(side_effect=responses))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
resume_tool,
|
||||||
|
"get_document_summary_llm",
|
||||||
|
AsyncMock(return_value=llm),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(resume_tool, "_compile_typst", lambda _source: b"pdf")
|
||||||
|
page_counts = iter([7, 6, 6])
|
||||||
|
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts))
|
||||||
|
|
||||||
|
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
||||||
|
result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1})
|
||||||
|
|
||||||
|
assert result["status"] == "failed"
|
||||||
|
assert "hard page limit" in (result["error"] or "").lower()
|
||||||
|
assert failed_session.added, "Expected failed report persistence"
|
||||||
|
|
@ -0,0 +1,214 @@
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.file_intent import (
|
||||||
|
FileIntentMiddleware,
|
||||||
|
FileOperationIntent,
|
||||||
|
_fallback_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeLLM:
|
||||||
|
def __init__(self, response_text: str):
|
||||||
|
self._response_text = response_text
|
||||||
|
|
||||||
|
async def ainvoke(self, *_args, **_kwargs):
|
||||||
|
return AIMessage(content=self._response_text)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_write_intent_injects_contract_message():
|
||||||
|
llm = _FakeLLM(
|
||||||
|
'{"intent":"file_write","confidence":0.93,"suggested_filename":"ideas.md"}'
|
||||||
|
)
|
||||||
|
middleware = FileIntentMiddleware(llm=llm)
|
||||||
|
state = {
|
||||||
|
"messages": [HumanMessage(content="Create another random note for me")],
|
||||||
|
"turn_id": "123:456",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
contract = result["file_operation_contract"]
|
||||||
|
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
||||||
|
assert contract["suggested_path"] == "/ideas.md"
|
||||||
|
assert contract["turn_id"] == "123:456"
|
||||||
|
assert any(
|
||||||
|
"file_operation_contract" in str(msg.content)
|
||||||
|
for msg in result["messages"]
|
||||||
|
if hasattr(msg, "content")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_non_write_intent_does_not_inject_contract_message():
|
||||||
|
llm = _FakeLLM(
|
||||||
|
'{"intent":"file_read","confidence":0.88,"suggested_filename":null}'
|
||||||
|
)
|
||||||
|
middleware = FileIntentMiddleware(llm=llm)
|
||||||
|
original_messages = [HumanMessage(content="Read /notes.md")]
|
||||||
|
state = {"messages": original_messages, "turn_id": "abc:def"}
|
||||||
|
|
||||||
|
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["file_operation_contract"]["intent"] == FileOperationIntent.FILE_READ.value
|
||||||
|
assert "messages" not in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_write_null_filename_uses_semantic_default_path():
|
||||||
|
llm = _FakeLLM(
|
||||||
|
'{"intent":"file_write","confidence":0.74,"suggested_filename":null}'
|
||||||
|
)
|
||||||
|
middleware = FileIntentMiddleware(llm=llm)
|
||||||
|
state = {
|
||||||
|
"messages": [HumanMessage(content="create a random markdown file")],
|
||||||
|
"turn_id": "turn:1",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
contract = result["file_operation_contract"]
|
||||||
|
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
||||||
|
assert contract["suggested_path"] == "/notes.md"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_write_null_filename_infers_json_extension():
|
||||||
|
llm = _FakeLLM(
|
||||||
|
'{"intent":"file_write","confidence":0.71,"suggested_filename":null}'
|
||||||
|
)
|
||||||
|
middleware = FileIntentMiddleware(llm=llm)
|
||||||
|
state = {
|
||||||
|
"messages": [HumanMessage(content="create a sample json config file")],
|
||||||
|
"turn_id": "turn:2",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
contract = result["file_operation_contract"]
|
||||||
|
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
||||||
|
assert contract["suggested_path"] == "/notes.json"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_write_txt_suggestion_is_normalized_to_markdown():
|
||||||
|
llm = _FakeLLM(
|
||||||
|
'{"intent":"file_write","confidence":0.82,"suggested_filename":"random.txt"}'
|
||||||
|
)
|
||||||
|
middleware = FileIntentMiddleware(llm=llm)
|
||||||
|
state = {
|
||||||
|
"messages": [HumanMessage(content="create a random file")],
|
||||||
|
"turn_id": "turn:3",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
contract = result["file_operation_contract"]
|
||||||
|
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
||||||
|
assert contract["suggested_path"] == "/random.md"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_write_with_suggested_directory_preserves_folder():
|
||||||
|
llm = _FakeLLM(
|
||||||
|
'{"intent":"file_write","confidence":0.86,"suggested_filename":"random.md","suggested_directory":"pc backups","suggested_path":null}'
|
||||||
|
)
|
||||||
|
middleware = FileIntentMiddleware(llm=llm)
|
||||||
|
state = {
|
||||||
|
"messages": [HumanMessage(content="create a random file in pc backups folder")],
|
||||||
|
"turn_id": "turn:4",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
contract = result["file_operation_contract"]
|
||||||
|
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
||||||
|
assert contract["suggested_path"] == "/pc_backups/random.md"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_write_with_suggested_path_takes_precedence():
|
||||||
|
llm = _FakeLLM(
|
||||||
|
'{"intent":"file_write","confidence":0.9,"suggested_filename":"ignored.md","suggested_directory":"docs","suggested_path":"/reports/q2/summary.md"}'
|
||||||
|
)
|
||||||
|
middleware = FileIntentMiddleware(llm=llm)
|
||||||
|
state = {
|
||||||
|
"messages": [HumanMessage(content="create report")],
|
||||||
|
"turn_id": "turn:5",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
contract = result["file_operation_contract"]
|
||||||
|
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
||||||
|
assert contract["suggested_path"] == "/reports/q2/summary.md"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_write_infers_directory_from_user_text_when_missing():
|
||||||
|
llm = _FakeLLM(
|
||||||
|
'{"intent":"file_write","confidence":0.83,"suggested_filename":"random.md","suggested_directory":null,"suggested_path":null}'
|
||||||
|
)
|
||||||
|
middleware = FileIntentMiddleware(llm=llm)
|
||||||
|
state = {
|
||||||
|
"messages": [HumanMessage(content="create a random file in pc backups folder")],
|
||||||
|
"turn_id": "turn:6",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
contract = result["file_operation_contract"]
|
||||||
|
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
||||||
|
assert contract["suggested_path"] == "/pc_backups/random.md"
|
||||||
|
|
||||||
|
|
||||||
|
def test_fallback_path_normalizes_windows_slashes() -> None:
|
||||||
|
resolved = _fallback_path(
|
||||||
|
suggested_filename="summary.md",
|
||||||
|
suggested_path=r"\reports\q2\summary.md",
|
||||||
|
user_text="create report",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resolved == "/reports/q2/summary.md"
|
||||||
|
|
||||||
|
|
||||||
|
def test_fallback_path_normalizes_windows_drive_path() -> None:
|
||||||
|
resolved = _fallback_path(
|
||||||
|
suggested_filename=None,
|
||||||
|
suggested_path=r"C:\Users\anish\notes\todo.md",
|
||||||
|
user_text="create note",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resolved == "/C/Users/anish/notes/todo.md"
|
||||||
|
|
||||||
|
|
||||||
|
def test_fallback_path_normalizes_mixed_separators_and_duplicate_slashes() -> None:
|
||||||
|
resolved = _fallback_path(
|
||||||
|
suggested_filename="summary.md",
|
||||||
|
suggested_path=r"\\reports\\q2//summary.md",
|
||||||
|
user_text="create report",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resolved == "/reports/q2/summary.md"
|
||||||
|
|
||||||
|
|
||||||
|
def test_fallback_path_keeps_posix_style_absolute_path_for_linux_and_macos() -> None:
|
||||||
|
resolved = _fallback_path(
|
||||||
|
suggested_filename=None,
|
||||||
|
suggested_path="/var/log/surfsense/notes.md",
|
||||||
|
user_text="create note",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resolved == "/var/log/surfsense/notes.md"
|
||||||
|
|
||||||
|
|
@ -0,0 +1,59 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
||||||
|
from app.agents.new_chat.filesystem_selection import (
|
||||||
|
ClientPlatform,
|
||||||
|
FilesystemMode,
|
||||||
|
FilesystemSelection,
|
||||||
|
LocalFilesystemMount,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
|
||||||
|
MultiRootLocalFolderBackend,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class _RuntimeStub:
|
||||||
|
state = {"files": {}}
|
||||||
|
|
||||||
|
|
||||||
|
def test_backend_resolver_returns_multi_root_backend_for_single_root(tmp_path: Path):
|
||||||
|
selection = FilesystemSelection(
|
||||||
|
mode=FilesystemMode.DESKTOP_LOCAL_FOLDER,
|
||||||
|
client_platform=ClientPlatform.DESKTOP,
|
||||||
|
local_mounts=(LocalFilesystemMount(mount_id="tmp", root_path=str(tmp_path)),),
|
||||||
|
)
|
||||||
|
resolver = build_backend_resolver(selection)
|
||||||
|
|
||||||
|
backend = resolver(_RuntimeStub())
|
||||||
|
assert isinstance(backend, MultiRootLocalFolderBackend)
|
||||||
|
|
||||||
|
|
||||||
|
def test_backend_resolver_uses_cloud_mode_by_default():
|
||||||
|
resolver = build_backend_resolver(FilesystemSelection())
|
||||||
|
backend = resolver(_RuntimeStub())
|
||||||
|
# StateBackend class name check keeps this test decoupled
|
||||||
|
# from internal deepagents runtime class identity.
|
||||||
|
assert backend.__class__.__name__ == "StateBackend"
|
||||||
|
|
||||||
|
|
||||||
|
def test_backend_resolver_returns_multi_root_backend_for_multiple_roots(tmp_path: Path):
|
||||||
|
root_one = tmp_path / "resume"
|
||||||
|
root_two = tmp_path / "notes"
|
||||||
|
root_one.mkdir()
|
||||||
|
root_two.mkdir()
|
||||||
|
selection = FilesystemSelection(
|
||||||
|
mode=FilesystemMode.DESKTOP_LOCAL_FOLDER,
|
||||||
|
client_platform=ClientPlatform.DESKTOP,
|
||||||
|
local_mounts=(
|
||||||
|
LocalFilesystemMount(mount_id="resume", root_path=str(root_one)),
|
||||||
|
LocalFilesystemMount(mount_id="notes", root_path=str(root_two)),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
resolver = build_backend_resolver(selection)
|
||||||
|
|
||||||
|
backend = resolver(_RuntimeStub())
|
||||||
|
assert isinstance(backend, MultiRootLocalFolderBackend)
|
||||||
|
|
@ -0,0 +1,164 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
|
||||||
|
MultiRootLocalFolderBackend,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class _BackendWithRawRead:
|
||||||
|
def __init__(self, content: str) -> None:
|
||||||
|
self._content = content
|
||||||
|
|
||||||
|
def read(self, file_path: str, offset: int = 0, limit: int = 200000) -> str:
|
||||||
|
del file_path, offset, limit
|
||||||
|
return " 1\tline1\n 2\tline2"
|
||||||
|
|
||||||
|
async def aread(self, file_path: str, offset: int = 0, limit: int = 200000) -> str:
|
||||||
|
return self.read(file_path, offset, limit)
|
||||||
|
|
||||||
|
def read_raw(self, file_path: str) -> str:
|
||||||
|
del file_path
|
||||||
|
return self._content
|
||||||
|
|
||||||
|
async def aread_raw(self, file_path: str) -> str:
|
||||||
|
return self.read_raw(file_path)
|
||||||
|
|
||||||
|
|
||||||
|
class _RuntimeNoSuggestedPath:
|
||||||
|
state = {"file_operation_contract": {}}
|
||||||
|
|
||||||
|
|
||||||
|
def test_verify_written_content_prefers_raw_sync() -> None:
|
||||||
|
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||||
|
expected = "line1\nline2"
|
||||||
|
backend = _BackendWithRawRead(expected)
|
||||||
|
|
||||||
|
verify_error = middleware._verify_written_content_sync(
|
||||||
|
backend=backend,
|
||||||
|
path="/note.md",
|
||||||
|
expected_content=expected,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert verify_error is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_contract_suggested_path_falls_back_to_notes_md() -> None:
|
||||||
|
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||||
|
middleware._filesystem_mode = FilesystemMode.CLOUD
|
||||||
|
suggested = middleware._get_contract_suggested_path(_RuntimeNoSuggestedPath()) # type: ignore[arg-type]
|
||||||
|
assert suggested == "/notes.md"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_written_content_prefers_raw_async() -> None:
|
||||||
|
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||||
|
expected = "line1\nline2"
|
||||||
|
backend = _BackendWithRawRead(expected)
|
||||||
|
|
||||||
|
verify_error = await middleware._verify_written_content_async(
|
||||||
|
backend=backend,
|
||||||
|
path="/note.md",
|
||||||
|
expected_content=expected,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert verify_error is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_local_mount_path_prefixes_default_mount(tmp_path: Path) -> None:
|
||||||
|
root = tmp_path / "PC Backups"
|
||||||
|
root.mkdir()
|
||||||
|
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
|
||||||
|
runtime = _RuntimeNoSuggestedPath()
|
||||||
|
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||||
|
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
|
||||||
|
|
||||||
|
resolved = middleware._normalize_local_mount_path("/random-note.md", runtime) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert resolved == "/pc_backups/random-note.md"
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_local_mount_path_keeps_explicit_mount(tmp_path: Path) -> None:
|
||||||
|
root = tmp_path / "PC Backups"
|
||||||
|
root.mkdir()
|
||||||
|
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
|
||||||
|
runtime = _RuntimeNoSuggestedPath()
|
||||||
|
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||||
|
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
|
||||||
|
|
||||||
|
resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type]
|
||||||
|
"/pc_backups/notes/random-note.md",
|
||||||
|
runtime,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resolved == "/pc_backups/notes/random-note.md"
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_local_mount_path_windows_backslashes(tmp_path: Path) -> None:
|
||||||
|
root = tmp_path / "PC Backups"
|
||||||
|
root.mkdir()
|
||||||
|
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
|
||||||
|
runtime = _RuntimeNoSuggestedPath()
|
||||||
|
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||||
|
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
|
||||||
|
|
||||||
|
resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type]
|
||||||
|
r"\notes\random-note.md",
|
||||||
|
runtime,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resolved == "/pc_backups/notes/random-note.md"
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_local_mount_path_normalizes_mixed_separators(tmp_path: Path) -> None:
|
||||||
|
root = tmp_path / "PC Backups"
|
||||||
|
root.mkdir()
|
||||||
|
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
|
||||||
|
runtime = _RuntimeNoSuggestedPath()
|
||||||
|
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||||
|
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
|
||||||
|
|
||||||
|
resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type]
|
||||||
|
r"\\notes//nested\\random-note.md",
|
||||||
|
runtime,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resolved == "/pc_backups/notes/nested/random-note.md"
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_local_mount_path_keeps_explicit_mount_with_backslashes(
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
root = tmp_path / "PC Backups"
|
||||||
|
root.mkdir()
|
||||||
|
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
|
||||||
|
runtime = _RuntimeNoSuggestedPath()
|
||||||
|
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||||
|
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
|
||||||
|
|
||||||
|
resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type]
|
||||||
|
r"\pc_backups\notes\random-note.md",
|
||||||
|
runtime,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resolved == "/pc_backups/notes/random-note.md"
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_local_mount_path_prefixes_posix_absolute_path_for_linux_and_macos(
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
root = tmp_path / "PC Backups"
|
||||||
|
root.mkdir()
|
||||||
|
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
|
||||||
|
runtime = _RuntimeNoSuggestedPath()
|
||||||
|
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||||
|
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
|
||||||
|
|
||||||
|
resolved = middleware._normalize_local_mount_path("/var/log/app.log", runtime) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert resolved == "/pc_backups/var/log/app.log"
|
||||||
|
|
@ -0,0 +1,59 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.local_folder_backend import LocalFolderBackend
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_backend_write_read_edit_roundtrip(tmp_path: Path):
|
||||||
|
backend = LocalFolderBackend(str(tmp_path))
|
||||||
|
|
||||||
|
write = backend.write("/notes/test.md", "line1\nline2")
|
||||||
|
assert write.error is None
|
||||||
|
assert write.path == "/notes/test.md"
|
||||||
|
|
||||||
|
read = backend.read("/notes/test.md", offset=0, limit=20)
|
||||||
|
assert "line1" in read
|
||||||
|
assert "line2" in read
|
||||||
|
|
||||||
|
edit = backend.edit("/notes/test.md", "line2", "updated")
|
||||||
|
assert edit.error is None
|
||||||
|
assert edit.occurrences == 1
|
||||||
|
|
||||||
|
read_after = backend.read("/notes/test.md", offset=0, limit=20)
|
||||||
|
assert "updated" in read_after
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_backend_blocks_path_escape(tmp_path: Path):
|
||||||
|
backend = LocalFolderBackend(str(tmp_path))
|
||||||
|
|
||||||
|
result = backend.write("/../../etc/passwd", "bad")
|
||||||
|
assert result.error is not None
|
||||||
|
assert "Invalid path" in result.error
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_backend_glob_and_grep(tmp_path: Path):
|
||||||
|
backend = LocalFolderBackend(str(tmp_path))
|
||||||
|
(tmp_path / "docs").mkdir()
|
||||||
|
(tmp_path / "docs" / "a.txt").write_text("hello world\n")
|
||||||
|
(tmp_path / "docs" / "b.md").write_text("hello markdown\n")
|
||||||
|
|
||||||
|
infos = backend.glob_info("**/*.txt", "/docs")
|
||||||
|
paths = {info["path"] for info in infos}
|
||||||
|
assert "/docs/a.txt" in paths
|
||||||
|
|
||||||
|
grep = backend.grep_raw("hello", "/docs", "*.md")
|
||||||
|
assert isinstance(grep, list)
|
||||||
|
assert any(match["path"] == "/docs/b.md" for match in grep)
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_backend_read_raw_returns_exact_content(tmp_path: Path):
|
||||||
|
backend = LocalFolderBackend(str(tmp_path))
|
||||||
|
expected = "# Title\n\nline 1\nline 2\n"
|
||||||
|
write = backend.write("/notes/raw.md", expected)
|
||||||
|
assert write.error is None
|
||||||
|
|
||||||
|
raw = backend.read_raw("/notes/raw.md")
|
||||||
|
assert raw == expected
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
|
||||||
|
MultiRootLocalFolderBackend,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def test_mount_ids_preserve_client_mapping_order(tmp_path: Path) -> None:
|
||||||
|
root_one = tmp_path / "PC Backups"
|
||||||
|
root_two = tmp_path / "pc_backups"
|
||||||
|
root_three = tmp_path / "notes@2026"
|
||||||
|
root_one.mkdir()
|
||||||
|
root_two.mkdir()
|
||||||
|
root_three.mkdir()
|
||||||
|
|
||||||
|
backend = MultiRootLocalFolderBackend(
|
||||||
|
(
|
||||||
|
("pc_backups", str(root_one)),
|
||||||
|
("pc_backups_2", str(root_two)),
|
||||||
|
("notes_2026", str(root_three)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert backend.list_mounts() == ("pc_backups", "pc_backups_2", "notes_2026")
|
||||||
|
|
@ -0,0 +1,48 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.tasks.chat.stream_new_chat import (
|
||||||
|
StreamResult,
|
||||||
|
_contract_enforcement_active,
|
||||||
|
_evaluate_file_contract_outcome,
|
||||||
|
_tool_output_has_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_output_error_detection():
|
||||||
|
assert _tool_output_has_error("Error: failed to write file")
|
||||||
|
assert _tool_output_has_error({"error": "boom"})
|
||||||
|
assert _tool_output_has_error({"result": "Error: disk is full"})
|
||||||
|
assert not _tool_output_has_error({"result": "Updated file /notes.md"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_write_contract_outcome_reasons():
|
||||||
|
result = StreamResult(intent_detected="file_write")
|
||||||
|
passed, reason = _evaluate_file_contract_outcome(result)
|
||||||
|
assert not passed
|
||||||
|
assert reason == "no_write_attempt"
|
||||||
|
|
||||||
|
result.write_attempted = True
|
||||||
|
passed, reason = _evaluate_file_contract_outcome(result)
|
||||||
|
assert not passed
|
||||||
|
assert reason == "write_failed"
|
||||||
|
|
||||||
|
result.write_succeeded = True
|
||||||
|
passed, reason = _evaluate_file_contract_outcome(result)
|
||||||
|
assert not passed
|
||||||
|
assert reason == "verification_failed"
|
||||||
|
|
||||||
|
result.verification_succeeded = True
|
||||||
|
passed, reason = _evaluate_file_contract_outcome(result)
|
||||||
|
assert passed
|
||||||
|
assert reason == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_contract_enforcement_local_only():
|
||||||
|
result = StreamResult(filesystem_mode="desktop_local_folder")
|
||||||
|
assert _contract_enforcement_active(result)
|
||||||
|
|
||||||
|
result.filesystem_mode = "cloud"
|
||||||
|
assert not _contract_enforcement_active(result)
|
||||||
|
|
||||||
2
surfsense_backend/uv.lock
generated
2
surfsense_backend/uv.lock
generated
|
|
@ -8070,7 +8070,7 @@ requires-dist = [
|
||||||
{ name = "langgraph", specifier = ">=1.1.3" },
|
{ name = "langgraph", specifier = ">=1.1.3" },
|
||||||
{ name = "langgraph-checkpoint-postgres", specifier = ">=3.0.2" },
|
{ name = "langgraph-checkpoint-postgres", specifier = ">=3.0.2" },
|
||||||
{ name = "linkup-sdk", specifier = ">=0.2.4" },
|
{ name = "linkup-sdk", specifier = ">=0.2.4" },
|
||||||
{ name = "litellm", specifier = ">=1.83.0" },
|
{ name = "litellm", specifier = ">=1.83.4" },
|
||||||
{ name = "llama-cloud-services", specifier = ">=0.6.25" },
|
{ name = "llama-cloud-services", specifier = ">=0.6.25" },
|
||||||
{ name = "markdown", specifier = ">=3.7" },
|
{ name = "markdown", specifier = ">=3.7" },
|
||||||
{ name = "markdownify", specifier = ">=0.14.1" },
|
{ name = "markdownify", specifier = ">=0.14.1" },
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,8 @@ export const IPC_CHANNELS = {
|
||||||
FOLDER_SYNC_SEED_MTIMES: 'folder-sync:seed-mtimes',
|
FOLDER_SYNC_SEED_MTIMES: 'folder-sync:seed-mtimes',
|
||||||
BROWSE_FILES: 'browse:files',
|
BROWSE_FILES: 'browse:files',
|
||||||
READ_LOCAL_FILES: 'browse:read-local-files',
|
READ_LOCAL_FILES: 'browse:read-local-files',
|
||||||
|
READ_AGENT_LOCAL_FILE_TEXT: 'agent-filesystem:read-local-file-text',
|
||||||
|
WRITE_AGENT_LOCAL_FILE_TEXT: 'agent-filesystem:write-local-file-text',
|
||||||
// Auth token sync across windows
|
// Auth token sync across windows
|
||||||
GET_AUTH_TOKENS: 'auth:get-tokens',
|
GET_AUTH_TOKENS: 'auth:get-tokens',
|
||||||
SET_AUTH_TOKENS: 'auth:set-tokens',
|
SET_AUTH_TOKENS: 'auth:set-tokens',
|
||||||
|
|
@ -51,4 +53,9 @@ export const IPC_CHANNELS = {
|
||||||
ANALYTICS_RESET: 'analytics:reset',
|
ANALYTICS_RESET: 'analytics:reset',
|
||||||
ANALYTICS_CAPTURE: 'analytics:capture',
|
ANALYTICS_CAPTURE: 'analytics:capture',
|
||||||
ANALYTICS_GET_CONTEXT: 'analytics:get-context',
|
ANALYTICS_GET_CONTEXT: 'analytics:get-context',
|
||||||
|
// Agent filesystem mode
|
||||||
|
AGENT_FILESYSTEM_GET_SETTINGS: 'agent-filesystem:get-settings',
|
||||||
|
AGENT_FILESYSTEM_GET_MOUNTS: 'agent-filesystem:get-mounts',
|
||||||
|
AGENT_FILESYSTEM_SET_SETTINGS: 'agent-filesystem:set-settings',
|
||||||
|
AGENT_FILESYSTEM_PICK_ROOT: 'agent-filesystem:pick-root',
|
||||||
} as const;
|
} as const;
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,14 @@ import {
|
||||||
resetUser as analyticsReset,
|
resetUser as analyticsReset,
|
||||||
trackEvent,
|
trackEvent,
|
||||||
} from '../modules/analytics';
|
} from '../modules/analytics';
|
||||||
|
import {
|
||||||
|
readAgentLocalFileText,
|
||||||
|
writeAgentLocalFileText,
|
||||||
|
getAgentFilesystemMounts,
|
||||||
|
getAgentFilesystemSettings,
|
||||||
|
pickAgentFilesystemRoot,
|
||||||
|
setAgentFilesystemSettings,
|
||||||
|
} from '../modules/agent-filesystem';
|
||||||
|
|
||||||
let authTokens: { bearer: string; refresh: string } | null = null;
|
let authTokens: { bearer: string; refresh: string } | null = null;
|
||||||
|
|
||||||
|
|
@ -118,6 +126,29 @@ export function registerIpcHandlers(): void {
|
||||||
readLocalFiles(paths)
|
readLocalFiles(paths)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
ipcMain.handle(IPC_CHANNELS.READ_AGENT_LOCAL_FILE_TEXT, async (_event, virtualPath: string) => {
|
||||||
|
try {
|
||||||
|
const result = await readAgentLocalFileText(virtualPath);
|
||||||
|
return { ok: true, path: result.path, content: result.content };
|
||||||
|
} catch (error) {
|
||||||
|
const message = error instanceof Error ? error.message : 'Failed to read local file';
|
||||||
|
return { ok: false, path: virtualPath, error: message };
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
ipcMain.handle(
|
||||||
|
IPC_CHANNELS.WRITE_AGENT_LOCAL_FILE_TEXT,
|
||||||
|
async (_event, virtualPath: string, content: string) => {
|
||||||
|
try {
|
||||||
|
const result = await writeAgentLocalFileText(virtualPath, content);
|
||||||
|
return { ok: true, path: result.path };
|
||||||
|
} catch (error) {
|
||||||
|
const message = error instanceof Error ? error.message : 'Failed to write local file';
|
||||||
|
return { ok: false, path: virtualPath, error: message };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
ipcMain.handle(IPC_CHANNELS.SET_AUTH_TOKENS, (_event, tokens: { bearer: string; refresh: string }) => {
|
ipcMain.handle(IPC_CHANNELS.SET_AUTH_TOKENS, (_event, tokens: { bearer: string; refresh: string }) => {
|
||||||
authTokens = tokens;
|
authTokens = tokens;
|
||||||
});
|
});
|
||||||
|
|
@ -191,4 +222,22 @@ export function registerIpcHandlers(): void {
|
||||||
platform: process.platform,
|
platform: process.platform,
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
|
ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_GET_SETTINGS, () =>
|
||||||
|
getAgentFilesystemSettings()
|
||||||
|
);
|
||||||
|
|
||||||
|
ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_GET_MOUNTS, () =>
|
||||||
|
getAgentFilesystemMounts()
|
||||||
|
);
|
||||||
|
|
||||||
|
ipcMain.handle(
|
||||||
|
IPC_CHANNELS.AGENT_FILESYSTEM_SET_SETTINGS,
|
||||||
|
(_event, settings: { mode?: 'cloud' | 'desktop_local_folder'; localRootPaths?: string[] | null }) =>
|
||||||
|
setAgentFilesystemSettings(settings)
|
||||||
|
);
|
||||||
|
|
||||||
|
ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_PICK_ROOT, () =>
|
||||||
|
pickAgentFilesystemRoot()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
254
surfsense_desktop/src/modules/agent-filesystem.ts
Normal file
254
surfsense_desktop/src/modules/agent-filesystem.ts
Normal file
|
|
@ -0,0 +1,254 @@
|
||||||
|
import { app, dialog } from "electron";
|
||||||
|
import { access, mkdir, readFile, writeFile } from "node:fs/promises";
|
||||||
|
import { dirname, isAbsolute, join, relative, resolve } from "node:path";
|
||||||
|
|
||||||
|
export type AgentFilesystemMode = "cloud" | "desktop_local_folder";
|
||||||
|
|
||||||
|
export interface AgentFilesystemSettings {
|
||||||
|
mode: AgentFilesystemMode;
|
||||||
|
localRootPaths: string[];
|
||||||
|
updatedAt: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const SETTINGS_FILENAME = "agent-filesystem-settings.json";
|
||||||
|
const MAX_LOCAL_ROOTS = 5;
|
||||||
|
|
||||||
|
function getSettingsPath(): string {
|
||||||
|
return join(app.getPath("userData"), SETTINGS_FILENAME);
|
||||||
|
}
|
||||||
|
|
||||||
|
function getDefaultSettings(): AgentFilesystemSettings {
|
||||||
|
return {
|
||||||
|
mode: "cloud",
|
||||||
|
localRootPaths: [],
|
||||||
|
updatedAt: new Date().toISOString(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
function normalizeLocalRootPaths(paths: unknown): string[] {
|
||||||
|
if (!Array.isArray(paths)) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
const uniquePaths = new Set<string>();
|
||||||
|
for (const path of paths) {
|
||||||
|
if (typeof path !== "string") continue;
|
||||||
|
const trimmed = path.trim();
|
||||||
|
if (!trimmed) continue;
|
||||||
|
uniquePaths.add(trimmed);
|
||||||
|
if (uniquePaths.size >= MAX_LOCAL_ROOTS) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return [...uniquePaths];
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function getAgentFilesystemSettings(): Promise<AgentFilesystemSettings> {
|
||||||
|
try {
|
||||||
|
const raw = await readFile(getSettingsPath(), "utf8");
|
||||||
|
const parsed = JSON.parse(raw) as Partial<AgentFilesystemSettings>;
|
||||||
|
if (parsed.mode !== "cloud" && parsed.mode !== "desktop_local_folder") {
|
||||||
|
return getDefaultSettings();
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
mode: parsed.mode,
|
||||||
|
localRootPaths: normalizeLocalRootPaths(parsed.localRootPaths),
|
||||||
|
updatedAt: parsed.updatedAt ?? new Date().toISOString(),
|
||||||
|
};
|
||||||
|
} catch {
|
||||||
|
return getDefaultSettings();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function setAgentFilesystemSettings(
|
||||||
|
settings: {
|
||||||
|
mode?: AgentFilesystemMode;
|
||||||
|
localRootPaths?: string[] | null;
|
||||||
|
}
|
||||||
|
): Promise<AgentFilesystemSettings> {
|
||||||
|
const current = await getAgentFilesystemSettings();
|
||||||
|
const nextMode =
|
||||||
|
settings.mode === "cloud" || settings.mode === "desktop_local_folder"
|
||||||
|
? settings.mode
|
||||||
|
: current.mode;
|
||||||
|
const next: AgentFilesystemSettings = {
|
||||||
|
mode: nextMode,
|
||||||
|
localRootPaths:
|
||||||
|
settings.localRootPaths === undefined
|
||||||
|
? current.localRootPaths
|
||||||
|
: normalizeLocalRootPaths(settings.localRootPaths ?? []),
|
||||||
|
updatedAt: new Date().toISOString(),
|
||||||
|
};
|
||||||
|
|
||||||
|
const settingsPath = getSettingsPath();
|
||||||
|
await mkdir(dirname(settingsPath), { recursive: true });
|
||||||
|
await writeFile(settingsPath, JSON.stringify(next, null, 2), "utf8");
|
||||||
|
return next;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function pickAgentFilesystemRoot(): Promise<string | null> {
|
||||||
|
const result = await dialog.showOpenDialog({
|
||||||
|
title: "Select local folder for Agent Filesystem",
|
||||||
|
properties: ["openDirectory"],
|
||||||
|
});
|
||||||
|
if (result.canceled || result.filePaths.length === 0) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return result.filePaths[0] ?? null;
|
||||||
|
}
|
||||||
|
|
||||||
|
function resolveVirtualPath(rootPath: string, virtualPath: string): string {
|
||||||
|
if (!virtualPath.startsWith("/")) {
|
||||||
|
throw new Error("Path must start with '/'");
|
||||||
|
}
|
||||||
|
const normalizedRoot = resolve(rootPath);
|
||||||
|
const relativePath = virtualPath.replace(/^\/+/, "");
|
||||||
|
if (!relativePath) {
|
||||||
|
throw new Error("Path must refer to a file under the selected root");
|
||||||
|
}
|
||||||
|
const absolutePath = resolve(normalizedRoot, relativePath);
|
||||||
|
const rel = relative(normalizedRoot, absolutePath);
|
||||||
|
if (!rel || rel.startsWith("..") || isAbsolute(rel)) {
|
||||||
|
throw new Error("Path escapes selected local root");
|
||||||
|
}
|
||||||
|
return absolutePath;
|
||||||
|
}
|
||||||
|
|
||||||
|
function toVirtualPath(rootPath: string, absolutePath: string): string {
|
||||||
|
const normalizedRoot = resolve(rootPath);
|
||||||
|
const rel = relative(normalizedRoot, absolutePath);
|
||||||
|
if (!rel || rel.startsWith("..") || isAbsolute(rel)) {
|
||||||
|
return "/";
|
||||||
|
}
|
||||||
|
return `/${rel.replace(/\\/g, "/")}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
export type LocalRootMount = {
|
||||||
|
mount: string;
|
||||||
|
rootPath: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
function sanitizeMountName(rawMount: string): string {
|
||||||
|
const normalized = rawMount
|
||||||
|
.trim()
|
||||||
|
.toLowerCase()
|
||||||
|
.replace(/[^a-z0-9_-]+/g, "_")
|
||||||
|
.replace(/_+/g, "_")
|
||||||
|
.replace(/^[_-]+|[_-]+$/g, "");
|
||||||
|
return normalized || "root";
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildRootMounts(rootPaths: string[]): LocalRootMount[] {
|
||||||
|
const mounts: LocalRootMount[] = [];
|
||||||
|
const usedMounts = new Set<string>();
|
||||||
|
for (const rawRootPath of rootPaths) {
|
||||||
|
const normalizedRoot = resolve(rawRootPath);
|
||||||
|
const baseMount = sanitizeMountName(normalizedRoot.split(/[\\/]/).at(-1) || "root");
|
||||||
|
let mount = baseMount;
|
||||||
|
let suffix = 2;
|
||||||
|
while (usedMounts.has(mount)) {
|
||||||
|
mount = `${baseMount}-${suffix}`;
|
||||||
|
suffix += 1;
|
||||||
|
}
|
||||||
|
usedMounts.add(mount);
|
||||||
|
mounts.push({ mount, rootPath: normalizedRoot });
|
||||||
|
}
|
||||||
|
return mounts;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function getAgentFilesystemMounts(): Promise<LocalRootMount[]> {
|
||||||
|
const rootPaths = await resolveCurrentRootPaths();
|
||||||
|
return buildRootMounts(rootPaths);
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseMountedVirtualPath(
|
||||||
|
virtualPath: string,
|
||||||
|
mounts: LocalRootMount[]
|
||||||
|
): {
|
||||||
|
mount: string;
|
||||||
|
subPath: string;
|
||||||
|
} {
|
||||||
|
if (!virtualPath.startsWith("/")) {
|
||||||
|
throw new Error("Path must start with '/'");
|
||||||
|
}
|
||||||
|
const trimmed = virtualPath.replace(/^\/+/, "");
|
||||||
|
if (!trimmed) {
|
||||||
|
throw new Error("Path must include a mounted root segment");
|
||||||
|
}
|
||||||
|
|
||||||
|
const [mount, ...rest] = trimmed.split("/");
|
||||||
|
const remainder = rest.join("/");
|
||||||
|
const directMount = mounts.find((entry) => entry.mount === mount);
|
||||||
|
if (!directMount) {
|
||||||
|
throw new Error(
|
||||||
|
`Unknown mounted root '${mount}'. Available roots: ${mounts.map((entry) => `/${entry.mount}`).join(", ")}`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (!remainder) {
|
||||||
|
throw new Error("Path must include a file path under the mounted root");
|
||||||
|
}
|
||||||
|
return { mount, subPath: `/${remainder}` };
|
||||||
|
}
|
||||||
|
|
||||||
|
function findMountByName(mounts: LocalRootMount[], mountName: string): LocalRootMount | undefined {
|
||||||
|
return mounts.find((entry) => entry.mount === mountName);
|
||||||
|
}
|
||||||
|
|
||||||
|
function toMountedVirtualPath(mount: string, rootPath: string, absolutePath: string): string {
|
||||||
|
const relativePath = toVirtualPath(rootPath, absolutePath);
|
||||||
|
return `/${mount}${relativePath}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function resolveCurrentRootPaths(): Promise<string[]> {
|
||||||
|
const settings = await getAgentFilesystemSettings();
|
||||||
|
if (settings.localRootPaths.length === 0) {
|
||||||
|
throw new Error("No local filesystem roots selected");
|
||||||
|
}
|
||||||
|
return settings.localRootPaths;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function readAgentLocalFileText(
|
||||||
|
virtualPath: string
|
||||||
|
): Promise<{ path: string; content: string }> {
|
||||||
|
const rootPaths = await resolveCurrentRootPaths();
|
||||||
|
const mounts = buildRootMounts(rootPaths);
|
||||||
|
const { mount, subPath } = parseMountedVirtualPath(virtualPath, mounts);
|
||||||
|
const rootMount = findMountByName(mounts, mount);
|
||||||
|
if (!rootMount) {
|
||||||
|
throw new Error(
|
||||||
|
`Unknown mounted root '${mount}'. Available roots: ${mounts.map((entry) => `/${entry.mount}`).join(", ")}`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
const absolutePath = resolveVirtualPath(rootMount.rootPath, subPath);
|
||||||
|
const content = await readFile(absolutePath, "utf8");
|
||||||
|
return {
|
||||||
|
path: toMountedVirtualPath(rootMount.mount, rootMount.rootPath, absolutePath),
|
||||||
|
content,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function writeAgentLocalFileText(
|
||||||
|
virtualPath: string,
|
||||||
|
content: string
|
||||||
|
): Promise<{ path: string }> {
|
||||||
|
const rootPaths = await resolveCurrentRootPaths();
|
||||||
|
const mounts = buildRootMounts(rootPaths);
|
||||||
|
const { mount, subPath } = parseMountedVirtualPath(virtualPath, mounts);
|
||||||
|
const rootMount = findMountByName(mounts, mount);
|
||||||
|
if (!rootMount) {
|
||||||
|
throw new Error(
|
||||||
|
`Unknown mounted root '${mount}'. Available roots: ${mounts.map((entry) => `/${entry.mount}`).join(", ")}`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let selectedAbsolutePath = resolveVirtualPath(rootMount.rootPath, subPath);
|
||||||
|
|
||||||
|
try {
|
||||||
|
await access(selectedAbsolutePath);
|
||||||
|
} catch {
|
||||||
|
// New files are created under the selected mounted root.
|
||||||
|
}
|
||||||
|
await mkdir(dirname(selectedAbsolutePath), { recursive: true });
|
||||||
|
await writeFile(selectedAbsolutePath, content, "utf8");
|
||||||
|
return {
|
||||||
|
path: toMountedVirtualPath(rootMount.mount, rootMount.rootPath, selectedAbsolutePath),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
@ -71,6 +71,10 @@ contextBridge.exposeInMainWorld('electronAPI', {
|
||||||
// Browse files via native dialog
|
// Browse files via native dialog
|
||||||
browseFiles: () => ipcRenderer.invoke(IPC_CHANNELS.BROWSE_FILES),
|
browseFiles: () => ipcRenderer.invoke(IPC_CHANNELS.BROWSE_FILES),
|
||||||
readLocalFiles: (paths: string[]) => ipcRenderer.invoke(IPC_CHANNELS.READ_LOCAL_FILES, paths),
|
readLocalFiles: (paths: string[]) => ipcRenderer.invoke(IPC_CHANNELS.READ_LOCAL_FILES, paths),
|
||||||
|
readAgentLocalFileText: (virtualPath: string) =>
|
||||||
|
ipcRenderer.invoke(IPC_CHANNELS.READ_AGENT_LOCAL_FILE_TEXT, virtualPath),
|
||||||
|
writeAgentLocalFileText: (virtualPath: string, content: string) =>
|
||||||
|
ipcRenderer.invoke(IPC_CHANNELS.WRITE_AGENT_LOCAL_FILE_TEXT, virtualPath, content),
|
||||||
|
|
||||||
// Auth token sync across windows
|
// Auth token sync across windows
|
||||||
getAuthTokens: () => ipcRenderer.invoke(IPC_CHANNELS.GET_AUTH_TOKENS),
|
getAuthTokens: () => ipcRenderer.invoke(IPC_CHANNELS.GET_AUTH_TOKENS),
|
||||||
|
|
@ -101,4 +105,14 @@ contextBridge.exposeInMainWorld('electronAPI', {
|
||||||
analyticsCapture: (event: string, properties?: Record<string, unknown>) =>
|
analyticsCapture: (event: string, properties?: Record<string, unknown>) =>
|
||||||
ipcRenderer.invoke(IPC_CHANNELS.ANALYTICS_CAPTURE, { event, properties }),
|
ipcRenderer.invoke(IPC_CHANNELS.ANALYTICS_CAPTURE, { event, properties }),
|
||||||
getAnalyticsContext: () => ipcRenderer.invoke(IPC_CHANNELS.ANALYTICS_GET_CONTEXT),
|
getAnalyticsContext: () => ipcRenderer.invoke(IPC_CHANNELS.ANALYTICS_GET_CONTEXT),
|
||||||
|
// Agent filesystem mode
|
||||||
|
getAgentFilesystemSettings: () =>
|
||||||
|
ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_GET_SETTINGS),
|
||||||
|
getAgentFilesystemMounts: () =>
|
||||||
|
ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_GET_MOUNTS),
|
||||||
|
setAgentFilesystemSettings: (settings: {
|
||||||
|
mode?: "cloud" | "desktop_local_folder";
|
||||||
|
localRootPaths?: string[] | null;
|
||||||
|
}) => ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_SET_SETTINGS, settings),
|
||||||
|
pickAgentFilesystemRoot: () => ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_PICK_ROOT),
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@ import {
|
||||||
import { useChatSessionStateSync } from "@/hooks/use-chat-session-state";
|
import { useChatSessionStateSync } from "@/hooks/use-chat-session-state";
|
||||||
import { useMessagesSync } from "@/hooks/use-messages-sync";
|
import { useMessagesSync } from "@/hooks/use-messages-sync";
|
||||||
import { documentsApiService } from "@/lib/apis/documents-api.service";
|
import { documentsApiService } from "@/lib/apis/documents-api.service";
|
||||||
|
import { getAgentFilesystemSelection } from "@/lib/agent-filesystem";
|
||||||
import { getBearerToken } from "@/lib/auth-utils";
|
import { getBearerToken } from "@/lib/auth-utils";
|
||||||
import { convertToThreadMessage } from "@/lib/chat/message-utils";
|
import { convertToThreadMessage } from "@/lib/chat/message-utils";
|
||||||
import {
|
import {
|
||||||
|
|
@ -158,7 +159,7 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] {
|
||||||
/**
|
/**
|
||||||
* Tools that should render custom UI in the chat.
|
* Tools that should render custom UI in the chat.
|
||||||
*/
|
*/
|
||||||
const TOOLS_WITH_UI = new Set([
|
const BASE_TOOLS_WITH_UI = new Set([
|
||||||
"web_search",
|
"web_search",
|
||||||
"generate_podcast",
|
"generate_podcast",
|
||||||
"generate_report",
|
"generate_report",
|
||||||
|
|
@ -210,6 +211,7 @@ export default function NewChatPage() {
|
||||||
assistantMsgId: string;
|
assistantMsgId: string;
|
||||||
interruptData: Record<string, unknown>;
|
interruptData: Record<string, unknown>;
|
||||||
} | null>(null);
|
} | null>(null);
|
||||||
|
const toolsWithUI = useMemo(() => new Set([...BASE_TOOLS_WITH_UI]), []);
|
||||||
|
|
||||||
// Get disabled tools from the tool toggle UI
|
// Get disabled tools from the tool toggle UI
|
||||||
const disabledTools = useAtomValue(disabledToolsAtom);
|
const disabledTools = useAtomValue(disabledToolsAtom);
|
||||||
|
|
@ -656,6 +658,15 @@ export default function NewChatPage() {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||||
|
const selection = await getAgentFilesystemSelection();
|
||||||
|
if (
|
||||||
|
selection.filesystem_mode === "desktop_local_folder" &&
|
||||||
|
(!selection.local_filesystem_mounts ||
|
||||||
|
selection.local_filesystem_mounts.length === 0)
|
||||||
|
) {
|
||||||
|
toast.error("Select a local folder before using Local Folder mode.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Build message history for context
|
// Build message history for context
|
||||||
const messageHistory = messages
|
const messageHistory = messages
|
||||||
|
|
@ -691,6 +702,9 @@ export default function NewChatPage() {
|
||||||
chat_id: currentThreadId,
|
chat_id: currentThreadId,
|
||||||
user_query: userQuery.trim(),
|
user_query: userQuery.trim(),
|
||||||
search_space_id: searchSpaceId,
|
search_space_id: searchSpaceId,
|
||||||
|
filesystem_mode: selection.filesystem_mode,
|
||||||
|
client_platform: selection.client_platform,
|
||||||
|
local_filesystem_mounts: selection.local_filesystem_mounts,
|
||||||
messages: messageHistory,
|
messages: messageHistory,
|
||||||
mentioned_document_ids: hasDocumentIds ? mentionedDocumentIds.document_ids : undefined,
|
mentioned_document_ids: hasDocumentIds ? mentionedDocumentIds.document_ids : undefined,
|
||||||
mentioned_surfsense_doc_ids: hasSurfsenseDocIds
|
mentioned_surfsense_doc_ids: hasSurfsenseDocIds
|
||||||
|
|
@ -709,7 +723,7 @@ export default function NewChatPage() {
|
||||||
setMessages((prev) =>
|
setMessages((prev) =>
|
||||||
prev.map((m) =>
|
prev.map((m) =>
|
||||||
m.id === assistantMsgId
|
m.id === assistantMsgId
|
||||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
|
||||||
: m
|
: m
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
@ -724,7 +738,7 @@ export default function NewChatPage() {
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case "tool-input-start":
|
case "tool-input-start":
|
||||||
addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {});
|
addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {});
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
|
@ -734,7 +748,7 @@ export default function NewChatPage() {
|
||||||
} else {
|
} else {
|
||||||
addToolCall(
|
addToolCall(
|
||||||
contentPartsState,
|
contentPartsState,
|
||||||
TOOLS_WITH_UI,
|
toolsWithUI,
|
||||||
parsed.toolCallId,
|
parsed.toolCallId,
|
||||||
parsed.toolName,
|
parsed.toolName,
|
||||||
parsed.input || {}
|
parsed.input || {}
|
||||||
|
|
@ -830,7 +844,7 @@ export default function NewChatPage() {
|
||||||
const tcId = `interrupt-${action.name}`;
|
const tcId = `interrupt-${action.name}`;
|
||||||
addToolCall(
|
addToolCall(
|
||||||
contentPartsState,
|
contentPartsState,
|
||||||
TOOLS_WITH_UI,
|
toolsWithUI,
|
||||||
tcId,
|
tcId,
|
||||||
action.name,
|
action.name,
|
||||||
action.args,
|
action.args,
|
||||||
|
|
@ -844,7 +858,7 @@ export default function NewChatPage() {
|
||||||
setMessages((prev) =>
|
setMessages((prev) =>
|
||||||
prev.map((m) =>
|
prev.map((m) =>
|
||||||
m.id === assistantMsgId
|
m.id === assistantMsgId
|
||||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
|
||||||
: m
|
: m
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
@ -871,7 +885,7 @@ export default function NewChatPage() {
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
|
|
||||||
// Skip persistence for interrupted messages -- handleResume will persist the final version
|
// Skip persistence for interrupted messages -- handleResume will persist the final version
|
||||||
const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
|
const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
|
||||||
if (contentParts.length > 0 && !wasInterrupted) {
|
if (contentParts.length > 0 && !wasInterrupted) {
|
||||||
try {
|
try {
|
||||||
const savedMessage = await appendMessage(currentThreadId, {
|
const savedMessage = await appendMessage(currentThreadId, {
|
||||||
|
|
@ -907,10 +921,10 @@ export default function NewChatPage() {
|
||||||
const hasContent = contentParts.some(
|
const hasContent = contentParts.some(
|
||||||
(part) =>
|
(part) =>
|
||||||
(part.type === "text" && part.text.length > 0) ||
|
(part.type === "text" && part.text.length > 0) ||
|
||||||
(part.type === "tool-call" && TOOLS_WITH_UI.has(part.toolName))
|
(part.type === "tool-call" && toolsWithUI.has(part.toolName))
|
||||||
);
|
);
|
||||||
if (hasContent && currentThreadId) {
|
if (hasContent && currentThreadId) {
|
||||||
const partialContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
|
const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI);
|
||||||
try {
|
try {
|
||||||
const savedMessage = await appendMessage(currentThreadId, {
|
const savedMessage = await appendMessage(currentThreadId, {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
|
|
@ -1074,6 +1088,7 @@ export default function NewChatPage() {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||||
|
const selection = await getAgentFilesystemSelection();
|
||||||
const response = await fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, {
|
const response = await fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: {
|
headers: {
|
||||||
|
|
@ -1083,6 +1098,9 @@ export default function NewChatPage() {
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
search_space_id: searchSpaceId,
|
search_space_id: searchSpaceId,
|
||||||
decisions,
|
decisions,
|
||||||
|
filesystem_mode: selection.filesystem_mode,
|
||||||
|
client_platform: selection.client_platform,
|
||||||
|
local_filesystem_mounts: selection.local_filesystem_mounts,
|
||||||
}),
|
}),
|
||||||
signal: controller.signal,
|
signal: controller.signal,
|
||||||
});
|
});
|
||||||
|
|
@ -1095,7 +1113,7 @@ export default function NewChatPage() {
|
||||||
setMessages((prev) =>
|
setMessages((prev) =>
|
||||||
prev.map((m) =>
|
prev.map((m) =>
|
||||||
m.id === assistantMsgId
|
m.id === assistantMsgId
|
||||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
|
||||||
: m
|
: m
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
@ -1110,7 +1128,7 @@ export default function NewChatPage() {
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case "tool-input-start":
|
case "tool-input-start":
|
||||||
addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {});
|
addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {});
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
|
@ -1122,7 +1140,7 @@ export default function NewChatPage() {
|
||||||
} else {
|
} else {
|
||||||
addToolCall(
|
addToolCall(
|
||||||
contentPartsState,
|
contentPartsState,
|
||||||
TOOLS_WITH_UI,
|
toolsWithUI,
|
||||||
parsed.toolCallId,
|
parsed.toolCallId,
|
||||||
parsed.toolName,
|
parsed.toolName,
|
||||||
parsed.input || {}
|
parsed.input || {}
|
||||||
|
|
@ -1173,7 +1191,7 @@ export default function NewChatPage() {
|
||||||
const tcId = `interrupt-${action.name}`;
|
const tcId = `interrupt-${action.name}`;
|
||||||
addToolCall(
|
addToolCall(
|
||||||
contentPartsState,
|
contentPartsState,
|
||||||
TOOLS_WITH_UI,
|
toolsWithUI,
|
||||||
tcId,
|
tcId,
|
||||||
action.name,
|
action.name,
|
||||||
action.args,
|
action.args,
|
||||||
|
|
@ -1190,7 +1208,7 @@ export default function NewChatPage() {
|
||||||
setMessages((prev) =>
|
setMessages((prev) =>
|
||||||
prev.map((m) =>
|
prev.map((m) =>
|
||||||
m.id === assistantMsgId
|
m.id === assistantMsgId
|
||||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
|
||||||
: m
|
: m
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
@ -1214,7 +1232,7 @@ export default function NewChatPage() {
|
||||||
|
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
|
|
||||||
const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
|
const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
|
||||||
if (contentParts.length > 0) {
|
if (contentParts.length > 0) {
|
||||||
try {
|
try {
|
||||||
const savedMessage = await appendMessage(resumeThreadId, {
|
const savedMessage = await appendMessage(resumeThreadId, {
|
||||||
|
|
@ -1406,6 +1424,7 @@ export default function NewChatPage() {
|
||||||
]);
|
]);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
const selection = await getAgentFilesystemSelection();
|
||||||
const response = await fetch(getRegenerateUrl(threadId), {
|
const response = await fetch(getRegenerateUrl(threadId), {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: {
|
headers: {
|
||||||
|
|
@ -1416,6 +1435,9 @@ export default function NewChatPage() {
|
||||||
search_space_id: searchSpaceId,
|
search_space_id: searchSpaceId,
|
||||||
user_query: newUserQuery || null,
|
user_query: newUserQuery || null,
|
||||||
disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
|
disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
|
||||||
|
filesystem_mode: selection.filesystem_mode,
|
||||||
|
client_platform: selection.client_platform,
|
||||||
|
local_filesystem_mounts: selection.local_filesystem_mounts,
|
||||||
}),
|
}),
|
||||||
signal: controller.signal,
|
signal: controller.signal,
|
||||||
});
|
});
|
||||||
|
|
@ -1428,7 +1450,7 @@ export default function NewChatPage() {
|
||||||
setMessages((prev) =>
|
setMessages((prev) =>
|
||||||
prev.map((m) =>
|
prev.map((m) =>
|
||||||
m.id === assistantMsgId
|
m.id === assistantMsgId
|
||||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
|
||||||
: m
|
: m
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
@ -1443,7 +1465,7 @@ export default function NewChatPage() {
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case "tool-input-start":
|
case "tool-input-start":
|
||||||
addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {});
|
addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {});
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
|
@ -1453,7 +1475,7 @@ export default function NewChatPage() {
|
||||||
} else {
|
} else {
|
||||||
addToolCall(
|
addToolCall(
|
||||||
contentPartsState,
|
contentPartsState,
|
||||||
TOOLS_WITH_UI,
|
toolsWithUI,
|
||||||
parsed.toolCallId,
|
parsed.toolCallId,
|
||||||
parsed.toolName,
|
parsed.toolName,
|
||||||
parsed.input || {}
|
parsed.input || {}
|
||||||
|
|
@ -1502,7 +1524,7 @@ export default function NewChatPage() {
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
|
|
||||||
// Persist messages after streaming completes
|
// Persist messages after streaming completes
|
||||||
const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
|
const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
|
||||||
if (contentParts.length > 0) {
|
if (contentParts.length > 0) {
|
||||||
try {
|
try {
|
||||||
// Persist user message (for both edit and reload modes, since backend deleted it)
|
// Persist user message (for both edit and reload modes, since backend deleted it)
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,7 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { BrainCog, Power, Rocket, Zap } from "lucide-react";
|
|
||||||
import { useEffect, useState } from "react";
|
import { useEffect, useState } from "react";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { DEFAULT_SHORTCUTS, ShortcutRecorder } from "@/components/desktop/shortcut-recorder";
|
|
||||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
|
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
|
||||||
import { Label } from "@/components/ui/label";
|
import { Label } from "@/components/ui/label";
|
||||||
import {
|
import {
|
||||||
|
|
@ -24,9 +22,6 @@ export function DesktopContent() {
|
||||||
const [loading, setLoading] = useState(true);
|
const [loading, setLoading] = useState(true);
|
||||||
const [enabled, setEnabled] = useState(true);
|
const [enabled, setEnabled] = useState(true);
|
||||||
|
|
||||||
const [shortcuts, setShortcuts] = useState(DEFAULT_SHORTCUTS);
|
|
||||||
const [shortcutsLoaded, setShortcutsLoaded] = useState(false);
|
|
||||||
|
|
||||||
const [searchSpaces, setSearchSpaces] = useState<SearchSpace[]>([]);
|
const [searchSpaces, setSearchSpaces] = useState<SearchSpace[]>([]);
|
||||||
const [activeSpaceId, setActiveSpaceId] = useState<string | null>(null);
|
const [activeSpaceId, setActiveSpaceId] = useState<string | null>(null);
|
||||||
|
|
||||||
|
|
@ -37,7 +32,6 @@ export function DesktopContent() {
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!api) {
|
if (!api) {
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
setShortcutsLoaded(true);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -48,15 +42,13 @@ export function DesktopContent() {
|
||||||
|
|
||||||
Promise.all([
|
Promise.all([
|
||||||
api.getAutocompleteEnabled(),
|
api.getAutocompleteEnabled(),
|
||||||
api.getShortcuts?.() ?? Promise.resolve(null),
|
|
||||||
api.getActiveSearchSpace?.() ?? Promise.resolve(null),
|
api.getActiveSearchSpace?.() ?? Promise.resolve(null),
|
||||||
searchSpacesApiService.getSearchSpaces(),
|
searchSpacesApiService.getSearchSpaces(),
|
||||||
hasAutoLaunchApi ? api.getAutoLaunch() : Promise.resolve(null),
|
hasAutoLaunchApi ? api.getAutoLaunch() : Promise.resolve(null),
|
||||||
])
|
])
|
||||||
.then(([autoEnabled, config, spaceId, spaces, autoLaunch]) => {
|
.then(([autoEnabled, spaceId, spaces, autoLaunch]) => {
|
||||||
if (!mounted) return;
|
if (!mounted) return;
|
||||||
setEnabled(autoEnabled);
|
setEnabled(autoEnabled);
|
||||||
if (config) setShortcuts(config);
|
|
||||||
setActiveSpaceId(spaceId);
|
setActiveSpaceId(spaceId);
|
||||||
if (spaces) setSearchSpaces(spaces);
|
if (spaces) setSearchSpaces(spaces);
|
||||||
if (autoLaunch) {
|
if (autoLaunch) {
|
||||||
|
|
@ -65,12 +57,10 @@ export function DesktopContent() {
|
||||||
setAutoLaunchSupported(autoLaunch.supported);
|
setAutoLaunchSupported(autoLaunch.supported);
|
||||||
}
|
}
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
setShortcutsLoaded(true);
|
|
||||||
})
|
})
|
||||||
.catch(() => {
|
.catch(() => {
|
||||||
if (!mounted) return;
|
if (!mounted) return;
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
setShortcutsLoaded(true);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
|
|
@ -82,7 +72,7 @@ export function DesktopContent() {
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col items-center justify-center py-12 text-center">
|
<div className="flex flex-col items-center justify-center py-12 text-center">
|
||||||
<p className="text-sm text-muted-foreground">
|
<p className="text-sm text-muted-foreground">
|
||||||
Desktop settings are only available in the SurfSense desktop app.
|
App preferences are only available in the SurfSense desktop app.
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|
@ -101,24 +91,6 @@ export function DesktopContent() {
|
||||||
await api.setAutocompleteEnabled(checked);
|
await api.setAutocompleteEnabled(checked);
|
||||||
};
|
};
|
||||||
|
|
||||||
const updateShortcut = (
|
|
||||||
key: "generalAssist" | "quickAsk" | "autocomplete",
|
|
||||||
accelerator: string
|
|
||||||
) => {
|
|
||||||
setShortcuts((prev) => {
|
|
||||||
const updated = { ...prev, [key]: accelerator };
|
|
||||||
api.setShortcuts?.({ [key]: accelerator }).catch(() => {
|
|
||||||
toast.error("Failed to update shortcut");
|
|
||||||
});
|
|
||||||
return updated;
|
|
||||||
});
|
|
||||||
toast.success("Shortcut updated");
|
|
||||||
};
|
|
||||||
|
|
||||||
const resetShortcut = (key: "generalAssist" | "quickAsk" | "autocomplete") => {
|
|
||||||
updateShortcut(key, DEFAULT_SHORTCUTS[key]);
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleAutoLaunchToggle = async (checked: boolean) => {
|
const handleAutoLaunchToggle = async (checked: boolean) => {
|
||||||
if (!autoLaunchSupported || !api.setAutoLaunch) {
|
if (!autoLaunchSupported || !api.setAutoLaunch) {
|
||||||
toast.error("Please update the desktop app to configure launch on startup");
|
toast.error("Please update the desktop app to configure launch on startup");
|
||||||
|
|
@ -196,7 +168,6 @@ export function DesktopContent() {
|
||||||
<Card>
|
<Card>
|
||||||
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
|
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
|
||||||
<CardTitle className="text-base md:text-lg flex items-center gap-2">
|
<CardTitle className="text-base md:text-lg flex items-center gap-2">
|
||||||
<Power className="h-4 w-4" />
|
|
||||||
Launch on Startup
|
Launch on Startup
|
||||||
</CardTitle>
|
</CardTitle>
|
||||||
<CardDescription className="text-xs md:text-sm">
|
<CardDescription className="text-xs md:text-sm">
|
||||||
|
|
@ -245,56 +216,6 @@ export function DesktopContent() {
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
|
|
||||||
{/* Keyboard Shortcuts */}
|
|
||||||
<Card>
|
|
||||||
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
|
|
||||||
<CardTitle className="text-base md:text-lg">Keyboard Shortcuts</CardTitle>
|
|
||||||
<CardDescription className="text-xs md:text-sm">
|
|
||||||
Customize the global keyboard shortcuts for desktop features.
|
|
||||||
</CardDescription>
|
|
||||||
</CardHeader>
|
|
||||||
<CardContent className="px-3 md:px-6 pb-3 md:pb-6">
|
|
||||||
{shortcutsLoaded ? (
|
|
||||||
<div className="flex flex-col gap-3">
|
|
||||||
<ShortcutRecorder
|
|
||||||
value={shortcuts.generalAssist}
|
|
||||||
onChange={(accel) => updateShortcut("generalAssist", accel)}
|
|
||||||
onReset={() => resetShortcut("generalAssist")}
|
|
||||||
defaultValue={DEFAULT_SHORTCUTS.generalAssist}
|
|
||||||
label="General Assist"
|
|
||||||
description="Launch SurfSense instantly from any application"
|
|
||||||
icon={Rocket}
|
|
||||||
/>
|
|
||||||
<ShortcutRecorder
|
|
||||||
value={shortcuts.quickAsk}
|
|
||||||
onChange={(accel) => updateShortcut("quickAsk", accel)}
|
|
||||||
onReset={() => resetShortcut("quickAsk")}
|
|
||||||
defaultValue={DEFAULT_SHORTCUTS.quickAsk}
|
|
||||||
label="Quick Assist"
|
|
||||||
description="Select text anywhere, then ask AI to explain, rewrite, or act on it"
|
|
||||||
icon={Zap}
|
|
||||||
/>
|
|
||||||
<ShortcutRecorder
|
|
||||||
value={shortcuts.autocomplete}
|
|
||||||
onChange={(accel) => updateShortcut("autocomplete", accel)}
|
|
||||||
onReset={() => resetShortcut("autocomplete")}
|
|
||||||
defaultValue={DEFAULT_SHORTCUTS.autocomplete}
|
|
||||||
label="Extreme Assist"
|
|
||||||
description="AI drafts text using your screen context and knowledge base"
|
|
||||||
icon={BrainCog}
|
|
||||||
/>
|
|
||||||
<p className="text-[11px] text-muted-foreground">
|
|
||||||
Click a shortcut and press a new key combination to change it.
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
) : (
|
|
||||||
<div className="flex justify-center py-4">
|
|
||||||
<Spinner size="sm" />
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
|
|
||||||
{/* Extreme Assist Toggle */}
|
{/* Extreme Assist Toggle */}
|
||||||
<Card>
|
<Card>
|
||||||
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
|
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue