Merge remote-tracking branch 'upstream/dev' into feat/obsidian-plugin

This commit is contained in:
Anish Sarkar 2026-04-24 21:34:55 +05:30
commit 9b1b9a90c0
175 changed files with 10592 additions and 2302 deletions

View file

@ -77,6 +77,8 @@ services:
- shared_temp:/shared_tmp
env_file:
- ../surfsense_backend/.env
extra_hosts:
- "host.docker.internal:host-gateway"
environment:
- DATABASE_URL=${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
- CELERY_BROKER_URL=${REDIS_URL:-redis://redis:6379/0}
@ -118,6 +120,8 @@ services:
- shared_temp:/shared_tmp
env_file:
- ../surfsense_backend/.env
extra_hosts:
- "host.docker.internal:host-gateway"
environment:
- DATABASE_URL=${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
- CELERY_BROKER_URL=${REDIS_URL:-redis://redis:6379/0}

View file

@ -60,6 +60,8 @@ services:
- shared_temp:/shared_tmp
env_file:
- .env
extra_hosts:
- "host.docker.internal:host-gateway"
environment:
DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
CELERY_BROKER_URL: ${REDIS_URL:-redis://redis:6379/0}
@ -100,6 +102,8 @@ services:
- shared_temp:/shared_tmp
env_file:
- .env
extra_hosts:
- "host.docker.internal:host-gateway"
environment:
DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
CELERY_BROKER_URL: ${REDIS_URL:-redis://redis:6379/0}

View file

@ -239,6 +239,9 @@ LLAMA_CLOUD_API_KEY=llx-nnn
# DAYTONA_TARGET=us
# DAYTONA_SNAPSHOT_ID=
# Desktop local filesystem mode (chat file tools run against a local folder root)
# ENABLE_DESKTOP_LOCAL_FILESYSTEM=FALSE
# OPTIONAL: Add these for LangSmith Observability
LANGSMITH_TRACING=true
LANGSMITH_ENDPOINT=https://api.smith.langchain.com

View file

@ -24,7 +24,6 @@ from deepagents.backends import StateBackend
from deepagents.graph import BASE_AGENT_PROMPT
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
from deepagents.middleware.summarization import create_summarization_middleware
from langchain.agents import create_agent
from langchain.agents.middleware import TodoListMiddleware
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
@ -34,18 +33,24 @@ from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.context import SurfSenseContextSchema
from app.agents.new_chat.filesystem_backends import build_backend_resolver
from app.agents.new_chat.filesystem_selection import FilesystemSelection
from app.agents.new_chat.llm_config import AgentConfig
from app.agents.new_chat.middleware import (
DedupHITLToolCallsMiddleware,
FileIntentMiddleware,
KnowledgeBaseSearchMiddleware,
MemoryInjectionMiddleware,
SurfSenseFilesystemMiddleware,
)
from app.agents.new_chat.middleware.safe_summarization import (
create_safe_summarization_middleware,
)
from app.agents.new_chat.system_prompt import (
build_configurable_system_prompt,
build_surfsense_system_prompt,
)
from app.agents.new_chat.tools.registry import build_tools_async
from app.agents.new_chat.tools.registry import build_tools_async, get_connector_gated_tools
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
from app.utils.perf import get_perf_logger
@ -162,6 +167,7 @@ async def create_surfsense_deep_agent(
thread_visibility: ChatVisibility | None = None,
mentioned_document_ids: list[int] | None = None,
anon_session_id: str | None = None,
filesystem_selection: FilesystemSelection | None = None,
):
"""
Create a SurfSense deep agent with configurable tools and prompts.
@ -236,6 +242,8 @@ async def create_surfsense_deep_agent(
)
"""
_t_agent_total = time.perf_counter()
filesystem_selection = filesystem_selection or FilesystemSelection()
backend_resolver = build_backend_resolver(filesystem_selection)
# Discover available connectors and document types for this search space
available_connectors: list[str] | None = None
@ -285,105 +293,10 @@ async def create_surfsense_deep_agent(
"llm": llm,
}
# Disable Notion action tools if no Notion connector is configured
modified_disabled_tools = list(disabled_tools) if disabled_tools else []
has_notion_connector = (
available_connectors is not None and "NOTION_CONNECTOR" in available_connectors
modified_disabled_tools.extend(
get_connector_gated_tools(available_connectors)
)
if not has_notion_connector:
notion_tools = [
"create_notion_page",
"update_notion_page",
"delete_notion_page",
]
modified_disabled_tools.extend(notion_tools)
# Disable Linear action tools if no Linear connector is configured
has_linear_connector = (
available_connectors is not None and "LINEAR_CONNECTOR" in available_connectors
)
if not has_linear_connector:
linear_tools = [
"create_linear_issue",
"update_linear_issue",
"delete_linear_issue",
]
modified_disabled_tools.extend(linear_tools)
# Disable Google Drive action tools if no Google Drive connector is configured
has_google_drive_connector = (
available_connectors is not None and "GOOGLE_DRIVE_FILE" in available_connectors
)
if not has_google_drive_connector:
google_drive_tools = [
"create_google_drive_file",
"delete_google_drive_file",
]
modified_disabled_tools.extend(google_drive_tools)
has_dropbox_connector = (
available_connectors is not None and "DROPBOX_FILE" in available_connectors
)
if not has_dropbox_connector:
modified_disabled_tools.extend(["create_dropbox_file", "delete_dropbox_file"])
has_onedrive_connector = (
available_connectors is not None and "ONEDRIVE_FILE" in available_connectors
)
if not has_onedrive_connector:
modified_disabled_tools.extend(["create_onedrive_file", "delete_onedrive_file"])
# Disable Google Calendar action tools if no Google Calendar connector is configured
has_google_calendar_connector = (
available_connectors is not None
and "GOOGLE_CALENDAR_CONNECTOR" in available_connectors
)
if not has_google_calendar_connector:
calendar_tools = [
"create_calendar_event",
"update_calendar_event",
"delete_calendar_event",
]
modified_disabled_tools.extend(calendar_tools)
# Disable Gmail action tools if no Gmail connector is configured
has_gmail_connector = (
available_connectors is not None
and "GOOGLE_GMAIL_CONNECTOR" in available_connectors
)
if not has_gmail_connector:
gmail_tools = [
"create_gmail_draft",
"update_gmail_draft",
"send_gmail_email",
"trash_gmail_email",
]
modified_disabled_tools.extend(gmail_tools)
# Disable Jira action tools if no Jira connector is configured
has_jira_connector = (
available_connectors is not None and "JIRA_CONNECTOR" in available_connectors
)
if not has_jira_connector:
jira_tools = [
"create_jira_issue",
"update_jira_issue",
"delete_jira_issue",
]
modified_disabled_tools.extend(jira_tools)
# Disable Confluence action tools if no Confluence connector is configured
has_confluence_connector = (
available_connectors is not None
and "CONFLUENCE_CONNECTOR" in available_connectors
)
if not has_confluence_connector:
confluence_tools = [
"create_confluence_page",
"update_confluence_page",
"delete_confluence_page",
]
modified_disabled_tools.extend(confluence_tools)
# Remove direct KB search tool; we now pre-seed a scoped filesystem via middleware.
if "search_knowledge_base" not in modified_disabled_tools:
@ -407,6 +320,20 @@ async def create_surfsense_deep_agent(
_t0 = time.perf_counter()
_enabled_tool_names = {t.name for t in tools}
_user_disabled_tool_names = set(disabled_tools) if disabled_tools else set()
# Collect generic MCP connector info so the system prompt can route queries
# to their tools instead of falling back to "not in knowledge base".
_mcp_connector_tools: dict[str, list[str]] = {}
for t in tools:
meta = getattr(t, "metadata", None) or {}
if meta.get("mcp_is_generic") and meta.get("mcp_connector_name"):
_mcp_connector_tools.setdefault(
meta["mcp_connector_name"], [],
).append(t.name)
if _mcp_connector_tools:
_perf_log.info("MCP connector tool routing: %s", _mcp_connector_tools)
if agent_config is not None:
system_prompt = build_configurable_system_prompt(
custom_system_instructions=agent_config.system_instructions,
@ -415,12 +342,14 @@ async def create_surfsense_deep_agent(
thread_visibility=thread_visibility,
enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names,
mcp_connector_tools=_mcp_connector_tools,
)
else:
system_prompt = build_surfsense_system_prompt(
thread_visibility=thread_visibility,
enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names,
mcp_connector_tools=_mcp_connector_tools,
)
_perf_log.info(
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
@ -437,12 +366,15 @@ async def create_surfsense_deep_agent(
gp_middleware = [
TodoListMiddleware(),
_memory_middleware,
FileIntentMiddleware(llm=llm),
SurfSenseFilesystemMiddleware(
backend=backend_resolver,
filesystem_mode=filesystem_selection.mode,
search_space_id=search_space_id,
created_by_id=user_id,
thread_id=thread_id,
),
create_summarization_middleware(llm, StateBackend),
create_safe_summarization_middleware(llm, StateBackend),
PatchToolCallsMiddleware(),
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
]
@ -458,21 +390,25 @@ async def create_surfsense_deep_agent(
deepagent_middleware = [
TodoListMiddleware(),
_memory_middleware,
FileIntentMiddleware(llm=llm),
KnowledgeBaseSearchMiddleware(
llm=llm,
search_space_id=search_space_id,
filesystem_mode=filesystem_selection.mode,
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
anon_session_id=anon_session_id,
),
SurfSenseFilesystemMiddleware(
backend=backend_resolver,
filesystem_mode=filesystem_selection.mode,
search_space_id=search_space_id,
created_by_id=user_id,
thread_id=thread_id,
),
SubAgentMiddleware(backend=StateBackend, subagents=[general_purpose_spec]),
create_summarization_middleware(llm, StateBackend),
create_safe_summarization_middleware(llm, StateBackend),
PatchToolCallsMiddleware(),
DedupHITLToolCallsMiddleware(agent_tools=tools),
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),

View file

@ -4,7 +4,15 @@ Context schema definitions for SurfSense agents.
This module defines the custom state schema used by the SurfSense deep agent.
"""
from typing import TypedDict
from typing import NotRequired, TypedDict
class FileOperationContractState(TypedDict):
intent: str
confidence: float
suggested_path: str
timestamp: str
turn_id: str
class SurfSenseContextSchema(TypedDict):
@ -24,5 +32,8 @@ class SurfSenseContextSchema(TypedDict):
"""
search_space_id: int
file_operation_contract: NotRequired[FileOperationContractState]
turn_id: NotRequired[str]
request_id: NotRequired[str]
# These are runtime-injected and won't be serialized
# db_session and connector_service are passed when invoking the agent
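For reference, the contract written into agent state by the new FileIntentMiddleware (introduced later in this commit) has exactly this shape; the values below are illustrative, not taken from the diff:

from app.agents.new_chat.context import FileOperationContractState

# Illustrative values only.
example_contract: FileOperationContractState = {
    "intent": "file_write",
    "confidence": 0.82,
    "suggested_path": "/reports/q2/summary.md",
    "timestamp": "2026-04-24T16:04:55+00:00",
    "turn_id": "turn-123",
}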

View file

@ -0,0 +1,42 @@
"""Filesystem backend resolver for cloud and desktop-local modes."""
from __future__ import annotations
from collections.abc import Callable
from functools import lru_cache
from deepagents.backends.state import StateBackend
from langgraph.prebuilt.tool_node import ToolRuntime
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
MultiRootLocalFolderBackend,
)
@lru_cache(maxsize=64)
def _cached_multi_root_backend(
mounts: tuple[tuple[str, str], ...],
) -> MultiRootLocalFolderBackend:
return MultiRootLocalFolderBackend(mounts)
def build_backend_resolver(
selection: FilesystemSelection,
) -> Callable[[ToolRuntime], StateBackend | MultiRootLocalFolderBackend]:
"""Create deepagents backend resolver for the selected filesystem mode."""
if selection.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER and selection.local_mounts:
def _resolve_local(_runtime: ToolRuntime) -> MultiRootLocalFolderBackend:
mounts = tuple(
(entry.mount_id, entry.root_path) for entry in selection.local_mounts
)
return _cached_multi_root_backend(mounts)
return _resolve_local
def _resolve_cloud(runtime: ToolRuntime) -> StateBackend:
return StateBackend(runtime)
return _resolve_cloud
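
A short usage sketch of the resolver, assuming a hypothetical desktop mount named "notes" (FilesystemSelection and its helpers are defined in filesystem_selection.py, shown next):

from app.agents.new_chat.filesystem_backends import build_backend_resolver
from app.agents.new_chat.filesystem_selection import (
    ClientPlatform,
    FilesystemMode,
    FilesystemSelection,
    LocalFilesystemMount,
)

# Mount id and root path below are made up for illustration.
selection = FilesystemSelection(
    mode=FilesystemMode.DESKTOP_LOCAL_FOLDER,
    client_platform=ClientPlatform.DESKTOP,
    local_mounts=(LocalFilesystemMount(mount_id="notes", root_path="~/Documents/notes"),),
)
resolver = build_backend_resolver(selection)
# deepagents calls the resolver once per tool invocation; identical mount tuples
# hit the lru_cache and reuse a single MultiRootLocalFolderBackend instance.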

View file

@ -0,0 +1,41 @@
"""Filesystem mode contracts and selection helpers for chat sessions."""
from __future__ import annotations
from dataclasses import dataclass
from enum import StrEnum
class FilesystemMode(StrEnum):
"""Supported filesystem backends for agent tool execution."""
CLOUD = "cloud"
DESKTOP_LOCAL_FOLDER = "desktop_local_folder"
class ClientPlatform(StrEnum):
"""Client runtime reported by the caller."""
WEB = "web"
DESKTOP = "desktop"
@dataclass(slots=True)
class LocalFilesystemMount:
"""Canonical mount mapping provided by desktop runtime."""
mount_id: str
root_path: str
@dataclass(slots=True)
class FilesystemSelection:
"""Resolved filesystem selection for a single chat request."""
mode: FilesystemMode = FilesystemMode.CLOUD
client_platform: ClientPlatform = ClientPlatform.WEB
local_mounts: tuple[LocalFilesystemMount, ...] = ()
@property
def is_local_mode(self) -> bool:
return self.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER

View file

@ -6,6 +6,9 @@ from app.agents.new_chat.middleware.dedup_tool_calls import (
from app.agents.new_chat.middleware.filesystem import (
SurfSenseFilesystemMiddleware,
)
from app.agents.new_chat.middleware.file_intent import (
FileIntentMiddleware,
)
from app.agents.new_chat.middleware.knowledge_search import (
KnowledgeBaseSearchMiddleware,
)
@ -15,6 +18,7 @@ from app.agents.new_chat.middleware.memory_injection import (
__all__ = [
"DedupHITLToolCallsMiddleware",
"FileIntentMiddleware",
"KnowledgeBaseSearchMiddleware",
"MemoryInjectionMiddleware",
"SurfSenseFilesystemMiddleware",

View file

@ -0,0 +1,352 @@
"""Semantic file-intent routing middleware for new chat turns.
This middleware classifies the latest human turn into a small intent set:
- chat_only
- file_write
- file_read
For ``file_write`` turns, it injects a strict system contract so the model
uses filesystem tools before claiming success, and provides a deterministic
fallback path when no filename is specified by the user.
"""
from __future__ import annotations
import json
import logging
import re
from datetime import UTC, datetime
from enum import StrEnum
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.runtime import Runtime
from pydantic import BaseModel, Field, ValidationError
logger = logging.getLogger(__name__)
class FileOperationIntent(StrEnum):
CHAT_ONLY = "chat_only"
FILE_WRITE = "file_write"
FILE_READ = "file_read"
class FileIntentPlan(BaseModel):
intent: FileOperationIntent = Field(
description="Primary user intent for this turn."
)
confidence: float = Field(
ge=0.0,
le=1.0,
default=0.5,
description="Model confidence in the selected intent.",
)
suggested_filename: str | None = Field(
default=None,
description="Optional filename (e.g. notes.md) inferred from user request.",
)
suggested_directory: str | None = Field(
default=None,
description=(
"Optional directory path (e.g. /reports/q2 or reports/q2) inferred from "
"user request."
),
)
suggested_path: str | None = Field(
default=None,
description=(
"Optional full file path (e.g. /reports/q2/summary.md). If present, this "
"takes precedence over suggested_directory + suggested_filename."
),
)
def _extract_text_from_message(message: BaseMessage) -> str:
content = getattr(message, "content", "")
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict) and item.get("type") == "text":
parts.append(str(item.get("text", "")))
return "\n".join(part for part in parts if part)
return str(content)
def _extract_json_payload(text: str) -> str:
stripped = text.strip()
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
if fenced:
return fenced.group(1)
start = stripped.find("{")
end = stripped.rfind("}")
if start != -1 and end != -1 and end > start:
return stripped[start : end + 1]
return stripped
def _sanitize_filename(value: str) -> str:
name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
name = re.sub(r"\s+", "-", name)
name = name.strip("._-")
if not name:
name = "note"
if len(name) > 80:
name = name[:80].rstrip("-_.")
return name
def _sanitize_path_segment(value: str) -> str:
segment = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
segment = re.sub(r"\s+", "_", segment)
segment = segment.strip("._-")
return segment
def _infer_text_file_extension(user_text: str) -> str:
lowered = user_text.lower()
if any(token in lowered for token in ("json", ".json")):
return ".json"
if any(token in lowered for token in ("yaml", "yml", ".yaml", ".yml")):
return ".yaml"
if any(token in lowered for token in ("csv", ".csv")):
return ".csv"
if any(token in lowered for token in ("python", ".py")):
return ".py"
if any(token in lowered for token in ("typescript", ".ts", ".tsx")):
return ".ts"
if any(token in lowered for token in ("javascript", ".js", ".mjs", ".cjs")):
return ".js"
if any(token in lowered for token in ("html", ".html")):
return ".html"
if any(token in lowered for token in ("css", ".css")):
return ".css"
if any(token in lowered for token in ("sql", ".sql")):
return ".sql"
if any(token in lowered for token in ("toml", ".toml")):
return ".toml"
if any(token in lowered for token in ("ini", ".ini")):
return ".ini"
if any(token in lowered for token in ("xml", ".xml")):
return ".xml"
if any(token in lowered for token in ("markdown", ".md", "readme")):
return ".md"
return ".md"
def _normalize_directory(value: str) -> str:
raw = value.strip().replace("\\", "/")
raw = raw.strip("/")
if not raw:
return ""
parts = [_sanitize_path_segment(part) for part in raw.split("/") if part.strip()]
parts = [part for part in parts if part]
return "/".join(parts)
def _normalize_file_path(value: str) -> str:
raw = value.strip().replace("\\", "/").strip()
if not raw:
return ""
had_trailing_slash = raw.endswith("/")
raw = raw.strip("/")
if not raw:
return ""
parts = [_sanitize_path_segment(part) for part in raw.split("/") if part.strip()]
parts = [part for part in parts if part]
if not parts:
return ""
if had_trailing_slash:
return f"/{'/'.join(parts)}/"
return f"/{'/'.join(parts)}"
def _infer_directory_from_user_text(user_text: str) -> str | None:
patterns = (
r"\b(?:in|inside|under)\s+(?:the\s+)?([a-zA-Z0-9 _\-/]+?)\s+folder\b",
r"\b(?:in|inside|under)\s+([a-zA-Z0-9 _\-/]+?)\b",
)
lowered = user_text.lower()
for pattern in patterns:
match = re.search(pattern, lowered, flags=re.IGNORECASE)
if not match:
continue
candidate = match.group(1).strip()
if candidate in {"the", "a", "an"}:
continue
normalized = _normalize_directory(candidate)
if normalized:
return normalized
return None
def _fallback_path(
suggested_filename: str | None,
*,
suggested_directory: str | None = None,
suggested_path: str | None = None,
user_text: str,
) -> str:
default_extension = _infer_text_file_extension(user_text)
inferred_dir = _infer_directory_from_user_text(user_text)
sanitized_filename = ""
if suggested_filename:
sanitized_filename = _sanitize_filename(suggested_filename)
if sanitized_filename.lower().endswith(".txt"):
sanitized_filename = f"{sanitized_filename[:-4]}.md"
if not sanitized_filename:
sanitized_filename = f"notes{default_extension}"
elif "." not in sanitized_filename:
sanitized_filename = f"{sanitized_filename}{default_extension}"
normalized_suggested_path = (
_normalize_file_path(suggested_path) if suggested_path else ""
)
if normalized_suggested_path:
if normalized_suggested_path.endswith("/"):
return f"{normalized_suggested_path.rstrip('/')}/{sanitized_filename}"
return normalized_suggested_path
directory = _normalize_directory(suggested_directory or "")
if not directory and inferred_dir:
directory = inferred_dir
if directory:
return f"/{directory}/{sanitized_filename}"
return f"/{sanitized_filename}"
def _build_classifier_prompt(*, recent_conversation: str, user_text: str) -> str:
return (
"Classify the latest user request into a filesystem intent for an AI agent.\n"
"Return JSON only with this exact schema:\n"
'{"intent":"chat_only|file_write|file_read","confidence":0.0,"suggested_filename":"string or null","suggested_directory":"string or null","suggested_path":"string or null"}\n\n'
"Rules:\n"
"- Use semantic intent, not literal keywords.\n"
"- file_write: user asks to create/save/write/update/edit content as a file.\n"
"- file_read: user asks to open/read/list/search existing files.\n"
"- chat_only: conversational/analysis responses without required file operations.\n"
"- For file_write, choose a concise semantic suggested_filename and match the requested format.\n"
"- If the user mentions a folder/directory, populate suggested_directory.\n"
"- If user specifies an explicit full path, populate suggested_path.\n"
"- Use extensions that match user intent (e.g. .md, .json, .yaml, .csv, .py, .ts, .js, .html, .css, .sql).\n"
"- Do not use .txt; prefer .md for generic text notes.\n"
"- Do not include dates or timestamps in suggested_filename unless explicitly requested.\n"
"- Never include markdown or explanation.\n\n"
f"Recent conversation:\n{recent_conversation or '(none)'}\n\n"
f"Latest user message:\n{user_text}"
)
def _build_recent_conversation(messages: list[BaseMessage], *, max_messages: int = 6) -> str:
rows: list[str] = []
for msg in messages[-max_messages:]:
role = "user" if isinstance(msg, HumanMessage) else "assistant"
text = re.sub(r"\s+", " ", _extract_text_from_message(msg)).strip()
if text:
rows.append(f"{role}: {text[:280]}")
return "\n".join(rows)
class FileIntentMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Classify file intent and inject a strict file-write contract."""
tools = ()
def __init__(self, *, llm: BaseChatModel | None = None) -> None:
self.llm = llm
async def _classify_intent(
self, *, messages: list[BaseMessage], user_text: str
) -> FileIntentPlan:
if self.llm is None:
return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)
prompt = _build_classifier_prompt(
recent_conversation=_build_recent_conversation(messages),
user_text=user_text,
)
try:
response = await self.llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal"]},
)
payload = json.loads(_extract_json_payload(_extract_text_from_message(response)))
plan = FileIntentPlan.model_validate(payload)
return plan
except (json.JSONDecodeError, ValidationError, ValueError) as exc:
logger.warning("File intent classifier returned invalid output: %s", exc)
except Exception as exc: # pragma: no cover - defensive fallback
logger.warning("File intent classifier failed: %s", exc)
return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)
async def abefore_agent( # type: ignore[override]
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
del runtime
messages = state.get("messages") or []
if not messages:
return None
last_human: HumanMessage | None = None
for msg in reversed(messages):
if isinstance(msg, HumanMessage):
last_human = msg
break
if last_human is None:
return None
user_text = _extract_text_from_message(last_human).strip()
if not user_text:
return None
plan = await self._classify_intent(messages=messages, user_text=user_text)
suggested_path = _fallback_path(
plan.suggested_filename,
suggested_directory=plan.suggested_directory,
suggested_path=plan.suggested_path,
user_text=user_text,
)
contract = {
"intent": plan.intent.value,
"confidence": plan.confidence,
"suggested_path": suggested_path,
"timestamp": datetime.now(UTC).isoformat(),
"turn_id": state.get("turn_id", ""),
}
if plan.intent != FileOperationIntent.FILE_WRITE:
return {"file_operation_contract": contract}
contract_msg = SystemMessage(
content=(
"<file_operation_contract>\n"
"This turn intent is file_write.\n"
f"Suggested default path: {suggested_path}\n"
"Rules:\n"
"- You MUST call write_file or edit_file before claiming success.\n"
"- If no path is provided by the user, use the suggested default path.\n"
"- Do not claim a file was created/updated unless tool output confirms it.\n"
"- If the write/edit fails, clearly report failure instead of success.\n"
"- Do not include timestamps or dates in generated file content unless the user explicitly asks for them.\n"
"- For open-ended requests (e.g., random note), generate useful concrete content, not placeholders.\n"
"</file_operation_contract>"
)
)
# Insert just before the latest human turn so it applies to this request.
new_messages = list(messages)
insert_at = max(len(new_messages) - 1, 0)
new_messages.insert(insert_at, contract_msg)
return {"messages": new_messages, "file_operation_contract": contract}

View file

@ -26,6 +26,10 @@ from langchain_core.tools import BaseTool, StructuredTool
from langgraph.types import Command
from sqlalchemy import delete, select
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
MultiRootLocalFolderBackend,
)
from app.agents.new_chat.sandbox import (
_evict_sandbox_cache,
delete_sandbox,
@ -50,6 +54,8 @@ SURFSENSE_FILESYSTEM_SYSTEM_PROMPT = """## Following Conventions
- Read files before editing: understand existing content before making changes.
- Mimic existing style, naming conventions, and patterns.
- Never claim a file was created/updated unless filesystem tool output confirms success.
- If a file write/edit fails, explicitly report the failure.
## Filesystem Tools
@ -109,13 +115,20 @@ Usage:
- Use chunk IDs (`<chunk id='...'>`) as citations in answers.
"""
SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new file to the in-memory filesystem (session-only).
SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new text file to the in-memory filesystem (session-only).
Use this to create scratch/working files during the conversation. Files created
here are ephemeral and will not be saved to the user's knowledge base.
To permanently save a document to the user's knowledge base, use the
`save_document` tool instead.
Supported outputs include common LLM-friendly text formats like markdown, json,
yaml, csv, xml, html, css, sql, and code files.
When creating content from open-ended prompts, produce concrete and useful text,
not placeholders. Avoid adding dates/timestamps unless the user explicitly asks
for them.
"""
SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files.
@ -182,11 +195,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
def __init__(
self,
*,
backend: Any = None,
filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
search_space_id: int | None = None,
created_by_id: str | None = None,
thread_id: int | str | None = None,
tool_token_limit_before_evict: int | None = 20000,
) -> None:
self._filesystem_mode = filesystem_mode
self._search_space_id = search_space_id
self._created_by_id = created_by_id
self._thread_id = thread_id
@ -204,8 +220,17 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
" extract the data, write it as a clean file (CSV, JSON, etc.),"
" and then run your code against it."
)
if filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
system_prompt += (
"\n\n## Local Folder Mode"
"\n\nThis chat is running in desktop local-folder mode."
" Keep all file operations local. Do not use save_document."
" Always use mount-prefixed absolute paths like /<folder>/file.ext."
" If you are unsure which mounts are available, call ls('/') first."
)
super().__init__(
backend=backend,
system_prompt=system_prompt,
custom_tool_descriptions={
"ls": SURFSENSE_LIST_FILES_TOOL_DESCRIPTION,
@ -219,7 +244,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
max_execute_timeout=self._MAX_EXECUTE_TIMEOUT,
)
self.tools = [t for t in self.tools if t.name != "execute"]
self.tools.append(self._create_save_document_tool())
if self._should_persist_documents():
self.tools.append(self._create_save_document_tool())
if self._sandbox_available:
self.tools.append(self._create_execute_code_tool())
@ -637,15 +663,25 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
runtime: ToolRuntime[None, FilesystemState],
) -> Command | str:
resolved_backend = self._get_backend(runtime)
target_path = self._resolve_write_target_path(file_path, runtime)
try:
validated_path = validate_path(file_path)
validated_path = validate_path(target_path)
except ValueError as exc:
return f"Error: {exc}"
res: WriteResult = resolved_backend.write(validated_path, content)
if res.error:
return res.error
verify_error = self._verify_written_content_sync(
backend=resolved_backend,
path=validated_path,
expected_content=content,
)
if verify_error:
return verify_error
if not self._is_kb_document(validated_path):
if self._should_persist_documents() and not self._is_kb_document(
validated_path
):
persist_result = self._run_async_blocking(
self._persist_new_document(
file_path=validated_path, content=content
@ -682,15 +718,25 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
runtime: ToolRuntime[None, FilesystemState],
) -> Command | str:
resolved_backend = self._get_backend(runtime)
target_path = self._resolve_write_target_path(file_path, runtime)
try:
validated_path = validate_path(file_path)
validated_path = validate_path(target_path)
except ValueError as exc:
return f"Error: {exc}"
res: WriteResult = await resolved_backend.awrite(validated_path, content)
if res.error:
return res.error
verify_error = await self._verify_written_content_async(
backend=resolved_backend,
path=validated_path,
expected_content=content,
)
if verify_error:
return verify_error
if not self._is_kb_document(validated_path):
if self._should_persist_documents() and not self._is_kb_document(
validated_path
):
persist_result = await self._persist_new_document(
file_path=validated_path,
content=content,
@ -726,6 +772,164 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
"""Return True for paths under /documents/ (KB-sourced, XML-wrapped)."""
return path.startswith("/documents/")
def _should_persist_documents(self) -> bool:
"""Only cloud mode persists file content to Document/Chunk tables."""
return self._filesystem_mode == FilesystemMode.CLOUD
def _default_mount_prefix(self, runtime: ToolRuntime[None, FilesystemState]) -> str:
backend = self._get_backend(runtime)
if isinstance(backend, MultiRootLocalFolderBackend):
return f"/{backend.default_mount()}"
return ""
def _normalize_local_mount_path(
self, candidate: str, runtime: ToolRuntime[None, FilesystemState]
) -> str:
backend = self._get_backend(runtime)
mount_prefix = self._default_mount_prefix(runtime)
normalized_candidate = re.sub(r"/+", "/", candidate.strip().replace("\\", "/"))
if not mount_prefix or not isinstance(backend, MultiRootLocalFolderBackend):
if normalized_candidate.startswith("/"):
return normalized_candidate
return f"/{normalized_candidate.lstrip('/')}"
mount_names = set(backend.list_mounts())
if normalized_candidate.startswith("/"):
first_segment = normalized_candidate.lstrip("/").split("/", 1)[0]
if first_segment in mount_names:
return normalized_candidate
return f"{mount_prefix}{normalized_candidate}"
relative = normalized_candidate.lstrip("/")
first_segment = relative.split("/", 1)[0]
if first_segment in mount_names:
return f"/{relative}"
return f"{mount_prefix}/{relative}"
def _get_contract_suggested_path(
self, runtime: ToolRuntime[None, FilesystemState]
) -> str:
contract = runtime.state.get("file_operation_contract") or {}
suggested = contract.get("suggested_path")
if isinstance(suggested, str) and suggested.strip():
cleaned = suggested.strip()
if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
return self._normalize_local_mount_path(cleaned, runtime)
return cleaned
if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
mount_prefix = self._default_mount_prefix(runtime)
if mount_prefix:
return f"{mount_prefix}/notes.md"
return "/notes.md"
def _resolve_write_target_path(
self,
file_path: str,
runtime: ToolRuntime[None, FilesystemState],
) -> str:
candidate = file_path.strip()
if not candidate:
return self._get_contract_suggested_path(runtime)
if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
return self._normalize_local_mount_path(candidate, runtime)
if not candidate.startswith("/"):
return f"/{candidate.lstrip('/')}"
return candidate
@staticmethod
def _is_error_text(value: str) -> bool:
return value.startswith("Error:")
@staticmethod
def _read_for_verification_sync(backend: Any, path: str) -> str:
read_raw = getattr(backend, "read_raw", None)
if callable(read_raw):
return read_raw(path)
return backend.read(path, offset=0, limit=200000)
@staticmethod
async def _read_for_verification_async(backend: Any, path: str) -> str:
aread_raw = getattr(backend, "aread_raw", None)
if callable(aread_raw):
return await aread_raw(path)
return await backend.aread(path, offset=0, limit=200000)
def _verify_written_content_sync(
self,
*,
backend: Any,
path: str,
expected_content: str,
) -> str | None:
actual = self._read_for_verification_sync(backend, path)
if self._is_error_text(actual):
return f"Error: could not verify written file '{path}'."
if actual.rstrip() != expected_content.rstrip():
return (
"Error: file write verification failed; expected content was not fully written "
f"to '{path}'."
)
return None
async def _verify_written_content_async(
self,
*,
backend: Any,
path: str,
expected_content: str,
) -> str | None:
actual = await self._read_for_verification_async(backend, path)
if self._is_error_text(actual):
return f"Error: could not verify written file '{path}'."
if actual.rstrip() != expected_content.rstrip():
return (
"Error: file write verification failed; expected content was not fully written "
f"to '{path}'."
)
return None
def _verify_edited_content_sync(
self,
*,
backend: Any,
path: str,
new_string: str,
) -> tuple[str | None, str | None]:
updated_content = self._read_for_verification_sync(backend, path)
if self._is_error_text(updated_content):
return (
f"Error: could not verify edited file '{path}'.",
None,
)
if new_string and new_string not in updated_content:
return (
"Error: edit verification failed; updated content was not found in "
f"'{path}'.",
None,
)
return None, updated_content
async def _verify_edited_content_async(
self,
*,
backend: Any,
path: str,
new_string: str,
) -> tuple[str | None, str | None]:
updated_content = await self._read_for_verification_async(backend, path)
if self._is_error_text(updated_content):
return (
f"Error: could not verify edited file '{path}'.",
None,
)
if new_string and new_string not in updated_content:
return (
"Error: edit verification failed; updated content was not found in "
f"'{path}'.",
None,
)
return None, updated_content
def _create_edit_file_tool(self) -> BaseTool:
"""Create edit_file with DB persistence (skipped for KB documents)."""
tool_description = (
@ -754,8 +958,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
] = False,
) -> Command | str:
resolved_backend = self._get_backend(runtime)
target_path = self._resolve_write_target_path(file_path, runtime)
try:
validated_path = validate_path(file_path)
validated_path = validate_path(target_path)
except ValueError as exc:
return f"Error: {exc}"
res: EditResult = resolved_backend.edit(
@ -767,13 +972,22 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
if res.error:
return res.error
if not self._is_kb_document(validated_path):
read_result = resolved_backend.read(
validated_path, offset=0, limit=200000
)
if read_result.error or read_result.file_data is None:
return f"Error: could not reload edited file '{validated_path}' for persistence."
updated_content = read_result.file_data["content"]
verify_error, updated_content = self._verify_edited_content_sync(
backend=resolved_backend,
path=validated_path,
new_string=new_string,
)
if verify_error:
return verify_error
if self._should_persist_documents() and not self._is_kb_document(
validated_path
):
if updated_content is None:
return (
f"Error: could not reload edited file '{validated_path}' for "
"persistence."
)
persist_result = self._run_async_blocking(
self._persist_edited_document(
file_path=validated_path,
@ -818,8 +1032,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
] = False,
) -> Command | str:
resolved_backend = self._get_backend(runtime)
target_path = self._resolve_write_target_path(file_path, runtime)
try:
validated_path = validate_path(file_path)
validated_path = validate_path(target_path)
except ValueError as exc:
return f"Error: {exc}"
res: EditResult = await resolved_backend.aedit(
@ -831,13 +1046,22 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
if res.error:
return res.error
if not self._is_kb_document(validated_path):
read_result = await resolved_backend.aread(
validated_path, offset=0, limit=200000
)
if read_result.error or read_result.file_data is None:
return f"Error: could not reload edited file '{validated_path}' for persistence."
updated_content = read_result.file_data["content"]
verify_error, updated_content = await self._verify_edited_content_async(
backend=resolved_backend,
path=validated_path,
new_string=new_string,
)
if verify_error:
return verify_error
if self._should_persist_documents() and not self._is_kb_document(
validated_path
):
if updated_content is None:
return (
f"Error: could not reload edited file '{validated_path}' for "
"persistence."
)
persist_error = await self._persist_edited_document(
file_path=validated_path,
updated_content=updated_content,

View file

@ -28,6 +28,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.db import (
NATIVE_TO_LEGACY_DOCTYPE,
Chunk,
@ -857,6 +858,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
*,
llm: BaseChatModel | None = None,
search_space_id: int,
filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
top_k: int = 10,
@ -865,6 +867,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
) -> None:
self.llm = llm
self.search_space_id = search_space_id
self.filesystem_mode = filesystem_mode
self.available_connectors = available_connectors
self.available_document_types = available_document_types
self.top_k = top_k
@ -996,6 +999,9 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
messages = state.get("messages") or []
if not messages:
return None
if self.filesystem_mode != FilesystemMode.CLOUD:
# Local-folder mode should not seed cloud KB documents into filesystem.
return None
last_human = None
for msg in reversed(messages):

View file

@ -0,0 +1,316 @@
"""Desktop local-folder filesystem backend for deepagents tools."""
from __future__ import annotations
import asyncio
import fnmatch
import os
import threading
from pathlib import Path
from deepagents.backends.protocol import (
EditResult,
FileDownloadResponse,
FileInfo,
FileUploadResponse,
GrepMatch,
WriteResult,
)
from deepagents.backends.utils import (
create_file_data,
format_read_response,
perform_string_replacement,
)
_INVALID_PATH = "invalid_path"
_FILE_NOT_FOUND = "file_not_found"
_IS_DIRECTORY = "is_directory"
class LocalFolderBackend:
"""Filesystem backend rooted to a single local folder."""
def __init__(self, root_path: str) -> None:
root = Path(root_path).expanduser().resolve()
if not root.exists() or not root.is_dir():
msg = f"Local filesystem root does not exist or is not a directory: {root_path}"
raise ValueError(msg)
self._root = root
self._locks: dict[str, threading.Lock] = {}
self._locks_mu = threading.Lock()
def _lock_for(self, path: str) -> threading.Lock:
with self._locks_mu:
if path not in self._locks:
self._locks[path] = threading.Lock()
return self._locks[path]
def _resolve_virtual(self, virtual_path: str, *, allow_root: bool = False) -> Path:
if not virtual_path.startswith("/"):
msg = f"Invalid path (must be absolute): {virtual_path}"
raise ValueError(msg)
rel = virtual_path.lstrip("/")
candidate = self._root if rel == "" else (self._root / rel)
resolved = candidate.resolve()
if not allow_root and resolved == self._root:
msg = "Path must refer to a file or child directory under root"
raise ValueError(msg)
if not resolved.is_relative_to(self._root):
msg = f"Path escapes local filesystem root: {virtual_path}"
raise ValueError(msg)
return resolved
@staticmethod
def _to_virtual(path: Path, root: Path) -> str:
rel = path.relative_to(root).as_posix()
return "/" if rel == "." else f"/{rel}"
def _write_text_atomic(self, path: Path, content: str) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
temp_path = path.with_suffix(f"{path.suffix}.tmp")
temp_path.write_text(content, encoding="utf-8")
os.replace(temp_path, path)
def ls_info(self, path: str) -> list[FileInfo]:
try:
target = self._resolve_virtual(path, allow_root=True)
except ValueError:
return []
if not target.exists() or not target.is_dir():
return []
infos: list[FileInfo] = []
for child in sorted(target.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower())):
infos.append(
FileInfo(
path=self._to_virtual(child, self._root),
is_dir=child.is_dir(),
size=child.stat().st_size if child.is_file() else 0,
modified_at=str(child.stat().st_mtime),
)
)
return infos
async def als_info(self, path: str) -> list[FileInfo]:
return await asyncio.to_thread(self.ls_info, path)
def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
try:
path = self._resolve_virtual(file_path)
except ValueError:
return f"Error: Invalid path '{file_path}'"
if not path.exists():
return f"Error: File '{file_path}' not found"
if not path.is_file():
return f"Error: Path '{file_path}' is not a file"
content = path.read_text(encoding="utf-8", errors="replace")
file_data = create_file_data(content)
return format_read_response(file_data, offset, limit)
async def aread(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
return await asyncio.to_thread(self.read, file_path, offset, limit)
def read_raw(self, file_path: str) -> str:
"""Read raw file text without line-number formatting."""
try:
path = self._resolve_virtual(file_path)
except ValueError:
return f"Error: Invalid path '{file_path}'"
if not path.exists():
return f"Error: File '{file_path}' not found"
if not path.is_file():
return f"Error: Path '{file_path}' is not a file"
return path.read_text(encoding="utf-8", errors="replace")
async def aread_raw(self, file_path: str) -> str:
"""Async variant of read_raw."""
return await asyncio.to_thread(self.read_raw, file_path)
def write(self, file_path: str, content: str) -> WriteResult:
try:
path = self._resolve_virtual(file_path)
except ValueError:
return WriteResult(error=f"Error: Invalid path '{file_path}'")
lock = self._lock_for(file_path)
with lock:
if path.exists():
return WriteResult(
error=(
f"Cannot write to {file_path} because it already exists. "
"Read and then make an edit, or write to a new path."
)
)
self._write_text_atomic(path, content)
return WriteResult(path=file_path, files_update=None)
async def awrite(self, file_path: str, content: str) -> WriteResult:
return await asyncio.to_thread(self.write, file_path, content)
def edit(
self,
file_path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
) -> EditResult:
try:
path = self._resolve_virtual(file_path)
except ValueError:
return EditResult(error=f"Error: Invalid path '{file_path}'")
lock = self._lock_for(file_path)
with lock:
if not path.exists() or not path.is_file():
return EditResult(error=f"Error: File '{file_path}' not found")
content = path.read_text(encoding="utf-8", errors="replace")
result = perform_string_replacement(content, old_string, new_string, replace_all)
if isinstance(result, str):
return EditResult(error=result)
updated_content, occurrences = result
self._write_text_atomic(path, updated_content)
return EditResult(path=file_path, files_update=None, occurrences=int(occurrences))
async def aedit(
self,
file_path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
) -> EditResult:
return await asyncio.to_thread(
self.edit, file_path, old_string, new_string, replace_all
)
def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
try:
base = self._resolve_virtual(path, allow_root=True)
except ValueError:
return []
if pattern.startswith("/"):
search_base = self._root
normalized_pattern = pattern.lstrip("/")
else:
search_base = base
normalized_pattern = pattern
matches: list[FileInfo] = []
for hit in search_base.glob(normalized_pattern):
try:
resolved = hit.resolve()
if not resolved.is_relative_to(self._root):
continue
except Exception:
continue
matches.append(
FileInfo(
path=self._to_virtual(resolved, self._root),
is_dir=resolved.is_dir(),
size=resolved.stat().st_size if resolved.is_file() else 0,
modified_at=str(resolved.stat().st_mtime),
)
)
return matches
async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
return await asyncio.to_thread(self.glob_info, pattern, path)
def _iter_candidate_files(self, path: str | None, glob: str | None) -> list[Path]:
base_virtual = path or "/"
try:
base = self._resolve_virtual(base_virtual, allow_root=True)
except ValueError:
return []
if not base.exists():
return []
candidates = [p for p in base.rglob("*") if p.is_file()]
if glob:
candidates = [
p
for p in candidates
if fnmatch.fnmatch(self._to_virtual(p, self._root), glob)
or fnmatch.fnmatch(p.name, glob)
]
return candidates
def grep_raw(
self, pattern: str, path: str | None = None, glob: str | None = None
) -> list[GrepMatch] | str:
if not pattern:
return "Error: pattern cannot be empty"
matches: list[GrepMatch] = []
for file_path in self._iter_candidate_files(path, glob):
try:
lines = file_path.read_text(encoding="utf-8", errors="replace").splitlines()
except Exception:
continue
for idx, line in enumerate(lines, start=1):
if pattern in line:
matches.append(
GrepMatch(
path=self._to_virtual(file_path, self._root),
line=idx,
text=line,
)
)
return matches
async def agrep_raw(
self, pattern: str, path: str | None = None, glob: str | None = None
) -> list[GrepMatch] | str:
return await asyncio.to_thread(self.grep_raw, pattern, path, glob)
def upload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
responses: list[FileUploadResponse] = []
for virtual_path, content in files:
try:
target = self._resolve_virtual(virtual_path)
target.parent.mkdir(parents=True, exist_ok=True)
temp_path = target.with_suffix(f"{target.suffix}.tmp")
temp_path.write_bytes(content)
os.replace(temp_path, target)
responses.append(FileUploadResponse(path=virtual_path, error=None))
except FileNotFoundError:
responses.append(
FileUploadResponse(path=virtual_path, error=_FILE_NOT_FOUND)
)
except IsADirectoryError:
responses.append(FileUploadResponse(path=virtual_path, error=_IS_DIRECTORY))
except Exception:
responses.append(FileUploadResponse(path=virtual_path, error=_INVALID_PATH))
return responses
async def aupload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
return await asyncio.to_thread(self.upload_files, files)
def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
responses: list[FileDownloadResponse] = []
for virtual_path in paths:
try:
target = self._resolve_virtual(virtual_path)
if not target.exists():
responses.append(
FileDownloadResponse(
path=virtual_path, content=None, error=_FILE_NOT_FOUND
)
)
continue
if target.is_dir():
responses.append(
FileDownloadResponse(
path=virtual_path, content=None, error=_IS_DIRECTORY
)
)
continue
responses.append(
FileDownloadResponse(
path=virtual_path, content=target.read_bytes(), error=None
)
)
except Exception:
responses.append(
FileDownloadResponse(path=virtual_path, content=None, error=_INVALID_PATH)
)
return responses
async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]:
return await asyncio.to_thread(self.download_files, paths)
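
A minimal sketch of the single-root backend in isolation (the temporary directory stands in for a user-selected folder; the root must already exist or the constructor raises ValueError):

import tempfile

from app.agents.new_chat.middleware.local_folder_backend import LocalFolderBackend

root = tempfile.mkdtemp()            # stand-in for a user-selected folder
backend = LocalFolderBackend(root)

result = backend.write("/notes/todo.md", "- review PR\n")  # atomic write, creates /notes/
print(result.error)                  # falsy on success; an error string on failure
backend.edit("/notes/todo.md", "review PR", "review diff")
print(backend.read_raw("/notes/todo.md"))   # raw text, no line-number formatting
print(backend.ls_info("/notes"))            # FileInfo entries with virtual paths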

View file

@ -0,0 +1,329 @@
"""Aggregate multiple LocalFolderBackend roots behind mount-prefixed virtual paths."""
from __future__ import annotations
import asyncio
from pathlib import Path
from typing import Any
from deepagents.backends.protocol import (
EditResult,
FileDownloadResponse,
FileInfo,
FileUploadResponse,
GrepMatch,
WriteResult,
)
from app.agents.new_chat.middleware.local_folder_backend import LocalFolderBackend
_INVALID_PATH = "invalid_path"
_FILE_NOT_FOUND = "file_not_found"
_IS_DIRECTORY = "is_directory"
class MultiRootLocalFolderBackend:
"""Route filesystem operations to one of several mounted local roots.
Virtual paths are namespaced as:
- `/<mount>/...`
where `<mount>` is derived from each selected root folder name.
"""
def __init__(self, mounts: tuple[tuple[str, str], ...]) -> None:
if not mounts:
msg = "At least one local mount is required"
raise ValueError(msg)
self._mount_to_backend: dict[str, LocalFolderBackend] = {}
for raw_mount, raw_root in mounts:
mount = raw_mount.strip()
if not mount:
msg = "Mount id cannot be empty"
raise ValueError(msg)
if mount in self._mount_to_backend:
msg = f"Duplicate mount id: {mount}"
raise ValueError(msg)
normalized_root = str(Path(raw_root).expanduser().resolve())
self._mount_to_backend[mount] = LocalFolderBackend(normalized_root)
self._mount_order = tuple(self._mount_to_backend.keys())
def list_mounts(self) -> tuple[str, ...]:
return self._mount_order
def default_mount(self) -> str:
return self._mount_order[0]
def _mount_error(self) -> str:
mounts = ", ".join(f"/{mount}" for mount in self._mount_order)
return (
"Path must start with one of the selected folders: "
f"{mounts}. Example: /{self._mount_order[0]}/file.txt"
)
def _split_mount_path(self, virtual_path: str) -> tuple[str, str]:
if not virtual_path.startswith("/"):
msg = f"Invalid path (must be absolute): {virtual_path}"
raise ValueError(msg)
rel = virtual_path.lstrip("/")
if not rel:
raise ValueError(self._mount_error())
mount, _, remainder = rel.partition("/")
backend = self._mount_to_backend.get(mount)
if backend is None:
raise ValueError(self._mount_error())
local_path = f"/{remainder}" if remainder else "/"
return mount, local_path
@staticmethod
def _prefix_mount_path(mount: str, local_path: str) -> str:
if local_path == "/":
return f"/{mount}"
return f"/{mount}{local_path}"
@staticmethod
def _get_value(item: Any, key: str) -> Any:
if isinstance(item, dict):
return item.get(key)
return getattr(item, key, None)
@classmethod
def _get_str(cls, item: Any, key: str) -> str:
value = cls._get_value(item, key)
return value if isinstance(value, str) else ""
@classmethod
def _get_int(cls, item: Any, key: str) -> int:
value = cls._get_value(item, key)
return int(value) if isinstance(value, int | float) else 0
@classmethod
def _get_bool(cls, item: Any, key: str) -> bool:
value = cls._get_value(item, key)
return bool(value)
def _list_mount_roots(self) -> list[FileInfo]:
return [
FileInfo(path=f"/{mount}", is_dir=True, size=0, modified_at="0")
for mount in self._mount_order
]
def _transform_infos(self, mount: str, infos: list[FileInfo]) -> list[FileInfo]:
transformed: list[FileInfo] = []
for info in infos:
transformed.append(
FileInfo(
path=self._prefix_mount_path(mount, self._get_str(info, "path")),
is_dir=self._get_bool(info, "is_dir"),
size=self._get_int(info, "size"),
modified_at=self._get_str(info, "modified_at"),
)
)
return transformed
def ls_info(self, path: str) -> list[FileInfo]:
if path == "/":
return self._list_mount_roots()
try:
mount, local_path = self._split_mount_path(path)
except ValueError:
return []
return self._transform_infos(mount, self._mount_to_backend[mount].ls_info(local_path))
async def als_info(self, path: str) -> list[FileInfo]:
return await asyncio.to_thread(self.ls_info, path)
def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
try:
mount, local_path = self._split_mount_path(file_path)
except ValueError as exc:
return f"Error: {exc}"
return self._mount_to_backend[mount].read(local_path, offset, limit)
async def aread(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
return await asyncio.to_thread(self.read, file_path, offset, limit)
def read_raw(self, file_path: str) -> str:
try:
mount, local_path = self._split_mount_path(file_path)
except ValueError as exc:
return f"Error: {exc}"
return self._mount_to_backend[mount].read_raw(local_path)
async def aread_raw(self, file_path: str) -> str:
return await asyncio.to_thread(self.read_raw, file_path)
def write(self, file_path: str, content: str) -> WriteResult:
try:
mount, local_path = self._split_mount_path(file_path)
except ValueError as exc:
return WriteResult(error=f"Error: {exc}")
result = self._mount_to_backend[mount].write(local_path, content)
if result.path:
result.path = self._prefix_mount_path(mount, result.path)
return result
async def awrite(self, file_path: str, content: str) -> WriteResult:
return await asyncio.to_thread(self.write, file_path, content)
def edit(
self,
file_path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
) -> EditResult:
try:
mount, local_path = self._split_mount_path(file_path)
except ValueError as exc:
return EditResult(error=f"Error: {exc}")
result = self._mount_to_backend[mount].edit(
local_path, old_string, new_string, replace_all
)
if result.path:
result.path = self._prefix_mount_path(mount, result.path)
return result
async def aedit(
self,
file_path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
) -> EditResult:
return await asyncio.to_thread(
self.edit, file_path, old_string, new_string, replace_all
)
def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
if path == "/":
prefixed_results: list[FileInfo] = []
if pattern.startswith("/"):
mount, _, remainder = pattern.lstrip("/").partition("/")
backend = self._mount_to_backend.get(mount)
if not backend:
return []
local_pattern = f"/{remainder}" if remainder else "/"
return self._transform_infos(
mount, backend.glob_info(local_pattern, path="/")
)
for mount, backend in self._mount_to_backend.items():
prefixed_results.extend(
self._transform_infos(mount, backend.glob_info(pattern, path="/"))
)
return prefixed_results
try:
mount, local_path = self._split_mount_path(path)
except ValueError:
return []
return self._transform_infos(
mount, self._mount_to_backend[mount].glob_info(pattern, path=local_path)
)
async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
return await asyncio.to_thread(self.glob_info, pattern, path)
def grep_raw(
self, pattern: str, path: str | None = None, glob: str | None = None
) -> list[GrepMatch] | str:
if not pattern:
return "Error: pattern cannot be empty"
if path is None or path == "/":
all_matches: list[GrepMatch] = []
for mount, backend in self._mount_to_backend.items():
result = backend.grep_raw(pattern, path="/", glob=glob)
if isinstance(result, str):
return result
all_matches.extend(
[
GrepMatch(
path=self._prefix_mount_path(mount, self._get_str(match, "path")),
line=self._get_int(match, "line"),
text=self._get_str(match, "text"),
)
for match in result
]
)
return all_matches
try:
mount, local_path = self._split_mount_path(path)
except ValueError as exc:
return f"Error: {exc}"
result = self._mount_to_backend[mount].grep_raw(
pattern, path=local_path, glob=glob
)
if isinstance(result, str):
return result
return [
GrepMatch(
path=self._prefix_mount_path(mount, self._get_str(match, "path")),
line=self._get_int(match, "line"),
text=self._get_str(match, "text"),
)
for match in result
]
async def agrep_raw(
self, pattern: str, path: str | None = None, glob: str | None = None
) -> list[GrepMatch] | str:
return await asyncio.to_thread(self.grep_raw, pattern, path, glob)
def upload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
grouped: dict[str, list[tuple[str, bytes]]] = {}
invalid: list[FileUploadResponse] = []
for virtual_path, content in files:
try:
mount, local_path = self._split_mount_path(virtual_path)
except ValueError:
invalid.append(FileUploadResponse(path=virtual_path, error=_INVALID_PATH))
continue
grouped.setdefault(mount, []).append((local_path, content))
responses = list(invalid)
for mount, mount_files in grouped.items():
result = self._mount_to_backend[mount].upload_files(mount_files)
responses.extend(
[
FileUploadResponse(
path=self._prefix_mount_path(mount, self._get_str(item, "path")),
error=self._get_str(item, "error") or None,
)
for item in result
]
)
return responses
async def aupload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
return await asyncio.to_thread(self.upload_files, files)
def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
grouped: dict[str, list[str]] = {}
invalid: list[FileDownloadResponse] = []
for virtual_path in paths:
try:
mount, local_path = self._split_mount_path(virtual_path)
except ValueError:
invalid.append(
FileDownloadResponse(path=virtual_path, content=None, error=_INVALID_PATH)
)
continue
grouped.setdefault(mount, []).append(local_path)
responses = list(invalid)
for mount, mount_paths in grouped.items():
result = self._mount_to_backend[mount].download_files(mount_paths)
responses.extend(
[
FileDownloadResponse(
path=self._prefix_mount_path(mount, self._get_str(item, "path")),
content=self._get_value(item, "content"),
error=self._get_str(item, "error") or None,
)
for item in result
]
)
return responses
async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]:
return await asyncio.to_thread(self.download_files, paths)
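# Illustrative note (not part of the diff): the composite backend routes every
# virtual path by its first segment, which names a mount. Assuming a mount
# called "kb" is registered, a call such as
#     backend.grep_raw("TODO", path="/kb/docs")
# is delegated to self._mount_to_backend["kb"].grep_raw("TODO", path="/docs"),
# and matching paths in the result are re-prefixed with "/kb" before being
# returned to the agent. The mount name here is a hypothetical example.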

View file

@ -0,0 +1,123 @@
"""Safe wrapper around deepagents' SummarizationMiddleware.
Upstream issue
--------------
`deepagents.middleware.summarization.SummarizationMiddleware._aoffload_to_backend`
(and its sync counterpart) call
``get_buffer_string(filtered_messages)`` before writing the evicted history
to the backend file. In recent ``langchain-core`` versions, ``get_buffer_string``
accesses ``m.text``, which iterates ``self.content``; this raises
``TypeError: 'NoneType' object is not iterable`` whenever an ``AIMessage``
has ``content=None`` (common when a model returns *only* tool_calls, seen
frequently with Azure OpenAI ``gpt-5.x`` responses streamed through
LiteLLM).
The exception aborts the whole agent turn, so the user just sees "Error during
chat" with no assistant response.
Fix
---
We subclass ``SummarizationMiddleware`` and override
``_filter_summary_messages`` (the only call site that feeds messages into
``get_buffer_string``) to return *copies*, with ``content=""``, of messages
whose ``content`` is ``None``. The originals flowing through the rest of the
agent state are untouched.
We also expose a drop-in ``create_safe_summarization_middleware`` factory
that mirrors ``deepagents.middleware.summarization.create_summarization_middleware``
but instantiates our safe subclass.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from deepagents.middleware.summarization import (
SummarizationMiddleware,
compute_summarization_defaults,
)
if TYPE_CHECKING:
from deepagents.backends.protocol import BACKEND_TYPES
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AnyMessage
logger = logging.getLogger(__name__)
def _sanitize_message_content(msg: AnyMessage) -> AnyMessage:
"""Return ``msg`` with ``content`` coerced to a non-``None`` value.
``get_buffer_string`` reads ``m.text`` which iterates ``self.content``;
when a provider streams back an ``AIMessage`` with only tool_calls and
no text, ``content`` can be ``None`` and the iteration explodes. We
replace ``None`` with an empty string so downstream consumers that only
care about text see an empty body.
The original message is left untouched; we return a copy via
pydantic's ``model_copy`` when available, otherwise we fall back to
re-setting the attribute on a shallow copy.
"""
if getattr(msg, "content", "not-missing") is not None:
return msg
try:
return msg.model_copy(update={"content": ""})
except AttributeError:
import copy
new_msg = copy.copy(msg)
try:
new_msg.content = ""
except Exception: # pragma: no cover - defensive
logger.debug(
"Could not sanitize content=None on message of type %s",
type(msg).__name__,
)
return msg
return new_msg
class SafeSummarizationMiddleware(SummarizationMiddleware):
"""`SummarizationMiddleware` that tolerates messages with ``content=None``.
Only ``_filter_summary_messages`` is overridden; this is the single
helper invoked by both the sync and async offload paths immediately
before ``get_buffer_string``. Normalising here means we get coverage
for both without having to copy the (long, rapidly-changing) offload
implementations from upstream.
"""
def _filter_summary_messages(self, messages: list[AnyMessage]) -> list[AnyMessage]:
filtered = super()._filter_summary_messages(messages)
return [_sanitize_message_content(m) for m in filtered]
def create_safe_summarization_middleware(
model: BaseChatModel,
backend: BACKEND_TYPES,
) -> SafeSummarizationMiddleware:
"""Drop-in replacement for ``create_summarization_middleware``.
Mirrors the defaults computed by ``deepagents`` but returns our
``SafeSummarizationMiddleware`` subclass so the
``content=None`` crash in ``get_buffer_string`` is avoided.
"""
defaults = compute_summarization_defaults(model)
return SafeSummarizationMiddleware(
model=model,
backend=backend,
trigger=defaults["trigger"],
keep=defaults["keep"],
trim_tokens_to_summarize=None,
truncate_args_settings=defaults["truncate_args_settings"],
)
__all__ = [
"SafeSummarizationMiddleware",
"create_safe_summarization_middleware",
]
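# Illustrative usage sketch (not part of the diff). Assumes `model` is the chat
# model already configured for the agent and `backend` is the same filesystem
# backend passed to the other middleware; the result is intended as a drop-in
# replacement for the upstream create_summarization_middleware(...) entry in
# the agent's middleware list.
#
#     summarization = create_safe_summarization_middleware(model=model, backend=backend)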

View file

@ -38,8 +38,66 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
* Formatting, summarization, or analysis of content already present in the conversation
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
* Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable); see <tool_routing> below
</knowledge_base_only_policy>
<tool_routing>
CRITICAL: You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable.
Their data is NEVER in the knowledge base. You MUST call their tools immediately; never
say "I don't see it in the knowledge base" or ask the user if they want you to check.
Ignore any knowledge base results for these services.
When to use which tool:
- Linear (issues) → list_issues, get_issue, save_issue (create/update)
- ClickUp (tasks) → clickup_search, clickup_get_task
- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue
- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread
- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table
- Knowledge base content (Notion, GitHub, files, notes) → automatically searched
- Real-time public web data → call web_search
- Reading a specific webpage → call scrape_webpage
</tool_routing>
<parameter_resolution>
Some service tools require identifiers or context you do not have (account IDs,
workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw
IDs or technical identifiers; they cannot memorise them.
Instead, follow this discovery pattern:
1. Call a listing/discovery tool to find available options.
2. ONE result → use it silently, no question to the user.
3. MULTIPLE results → present the options by their display names and let the
user choose. Never show raw UUIDs; always use friendly names.
Discovery tools by level:
- Which account/workspace? → get_connected_accounts("<service>")
- Which Jira site (cloudId)? → getAccessibleAtlassianResources
- Which Jira project? → getVisibleJiraProjects (after resolving cloudId)
- Which Jira issue type? → getJiraProjectIssueTypesMetadata (after resolving project)
- Which channel? → slack_search_channels
- Which base? → list_bases
- Which table? → list_tables_for_base (after resolving baseId)
- Which task? → clickup_search
- Which issue? → list_issues (Linear) or searchJiraIssuesUsingJql (Jira)
For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to
obtain the cloudId, then pass it to other Jira tools. When creating an issue,
chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue.
If there is only one option at each step, use it silently. If multiple, present
friendly names.
Chain discovery when needed, e.g. for Airtable records: list_bases → pick
base → list_tables_for_base → pick table → list_records_for_table.
MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for
the same service, tool names are prefixed to avoid collisions, e.g.
linear_25_list_issues and linear_30_list_issues instead of two list_issues.
Each prefixed tool's description starts with [Account: <display_name>] so you
know which account it targets. Use get_connected_accounts("<service>") to see
the full list of accounts with their connector IDs and display names.
When only one account is connected, tools have their normal unprefixed names.
</parameter_resolution>
<memory_protocol>
IMPORTANT: After understanding each user message, ALWAYS check: does this message
reveal durable facts about the user (role, interests, preferences, projects,
@ -76,8 +134,66 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
* Formatting, summarization, or analysis of content already present in the conversation
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
* Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable); see <tool_routing> below
</knowledge_base_only_policy>
<tool_routing>
CRITICAL: You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable.
Their data is NEVER in the knowledge base. You MUST call their tools immediately; never
say "I don't see it in the knowledge base" or ask if they want you to check.
Ignore any knowledge base results for these services.
When to use which tool:
- Linear (issues) → list_issues, get_issue, save_issue (create/update)
- ClickUp (tasks) → clickup_search, clickup_get_task
- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue
- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread
- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table
- Knowledge base content (Notion, GitHub, files, notes) → automatically searched
- Real-time public web data → call web_search
- Reading a specific webpage → call scrape_webpage
</tool_routing>
<parameter_resolution>
Some service tools require identifiers or context you do not have (account IDs,
workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw
IDs or technical identifiers; they cannot memorise them.
Instead, follow this discovery pattern:
1. Call a listing/discovery tool to find available options.
2. ONE result → use it silently, no question to the user.
3. MULTIPLE results → present the options by their display names and let the
user choose. Never show raw UUIDs; always use friendly names.
Discovery tools by level:
- Which account/workspace? → get_connected_accounts("<service>")
- Which Jira site (cloudId)? → getAccessibleAtlassianResources
- Which Jira project? → getVisibleJiraProjects (after resolving cloudId)
- Which Jira issue type? → getJiraProjectIssueTypesMetadata (after resolving project)
- Which channel? → slack_search_channels
- Which base? → list_bases
- Which table? → list_tables_for_base (after resolving baseId)
- Which task? → clickup_search
- Which issue? → list_issues (Linear) or searchJiraIssuesUsingJql (Jira)
For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to
obtain the cloudId, then pass it to other Jira tools. When creating an issue,
chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue.
If there is only one option at each step, use it silently. If multiple, present
friendly names.
Chain discovery when needed, e.g. for Airtable records: list_bases → pick
base → list_tables_for_base → pick table → list_records_for_table.
MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for
the same service, tool names are prefixed to avoid collisions, e.g.
linear_25_list_issues and linear_30_list_issues instead of two list_issues.
Each prefixed tool's description starts with [Account: <display_name>] so you
know which account it targets. Use get_connected_accounts("<service>") to see
the full list of accounts with their connector IDs and display names.
When only one account is connected, tools have their normal unprefixed names.
</parameter_resolution>
<memory_protocol>
IMPORTANT: After understanding each user message, ALWAYS check: does this message
reveal durable facts about the team (decisions, conventions, architecture, processes,
@ -450,6 +566,9 @@ _TOOL_INSTRUCTIONS["generate_resume"] = """
- WHEN NOT TO CALL: General career advice, resume tips, cover letters, or reviewing
a resume without making changes. For cover letters, use generate_report instead.
- The tool produces Typst source code that is compiled to a PDF preview automatically.
- PAGE POLICY:
- Default behavior is ONE PAGE. For new resume creation, set max_pages=1 unless the user explicitly asks for more.
- If the user requests a longer resume (e.g., "make it 2 pages"), set max_pages to that value.
- Args:
- user_info: The user's resume content — work experience, education, skills, contact
info, etc. Can be structured or unstructured text.
@ -465,6 +584,7 @@ _TOOL_INSTRUCTIONS["generate_resume"] = """
"keep it to one page"). For revisions, describe what to change.
- parent_report_id: Set this when the user wants to MODIFY an existing resume from
this conversation. Use the report_id from a previous generate_resume result.
- max_pages: Maximum resume length in pages (integer 1-5). Default is 1.
- Returns: Dict with status, report_id, title, and content_type.
- After calling: Give a brief confirmation. Do NOT paste resume content in chat. Do NOT mention report_id or any internal IDs; the resume card is shown automatically.
- VERSIONING: Same rules as generate_report; set parent_report_id for modifications
@ -473,17 +593,20 @@ _TOOL_INSTRUCTIONS["generate_resume"] = """
_TOOL_EXAMPLES["generate_resume"] = """
- User: "Build me a resume. I'm John Doe, engineer at Acme Corp..."
- Call: `generate_resume(user_info="John Doe, engineer at Acme Corp...")`
- Call: `generate_resume(user_info="John Doe, engineer at Acme Corp...", max_pages=1)`
- WHY: Has creation verb "build" + resume → call the tool.
- User: "Create my CV with this info: [experience, education, skills]"
- Call: `generate_resume(user_info="[experience, education, skills]")`
- Call: `generate_resume(user_info="[experience, education, skills]", max_pages=1)`
- User: "Build me a resume" (and there is a resume/CV document in the conversation context)
- Extract the FULL content from the document in context, then call:
`generate_resume(user_info="Name: John Doe\\nEmail: john@example.com\\n\\nExperience:\\n- Senior Engineer at Acme Corp (2020-2024)\\n Led team of 5...\\n\\nEducation:\\n- BS Computer Science, MIT (2016-2020)\\n\\nSkills: Python, TypeScript, AWS...")`
`generate_resume(user_info="Name: John Doe\\nEmail: john@example.com\\n\\nExperience:\\n- Senior Engineer at Acme Corp (2020-2024)\\n Led team of 5...\\n\\nEducation:\\n- BS Computer Science, MIT (2016-2020)\\n\\nSkills: Python, TypeScript, AWS...", max_pages=1)`
- WHY: Document content is available in context → extract ALL of it into user_info. Do NOT ignore referenced documents.
- User: (after resume generated) "Change my title to Senior Engineer"
- Call: `generate_resume(user_info="", user_instructions="Change the job title to Senior Engineer", parent_report_id=<previous_report_id>)`
- Call: `generate_resume(user_info="", user_instructions="Change the job title to Senior Engineer", parent_report_id=<previous_report_id>, max_pages=1)`
- WHY: Modification verb "change" + refers to existing resume → set parent_report_id.
- User: (after resume generated) "Make this 2 pages and expand projects"
- Call: `generate_resume(user_info="", user_instructions="Expand projects and keep this to at most 2 pages", parent_report_id=<previous_report_id>, max_pages=2)`
- WHY: Explicit page increase request → set max_pages to 2.
- User: "How should I structure my resume?"
- Do NOT call generate_resume. Answer in chat with advice.
- WHY: No creation/modification verb.
@ -692,11 +815,36 @@ Your goal is to provide helpful, informative answers in a clean, readable format
"""
def _build_mcp_routing_block(
mcp_connector_tools: dict[str, list[str]] | None,
) -> str:
"""Build an additional tool routing block for generic MCP connectors.
When users add MCP servers (e.g. GitLab, GitHub), the LLM needs to know
those tools exist and should be called directly, not searched in the
knowledge base.
"""
if not mcp_connector_tools:
return ""
lines = [
"\n<mcp_tool_routing>",
"You also have direct tools from these user-connected MCP servers.",
"Their data is NEVER in the knowledge base — call their tools directly.",
"",
]
for server_name, tool_names in mcp_connector_tools.items():
lines.append(f"- {server_name}{', '.join(tool_names)}")
lines.append("</mcp_tool_routing>\n")
return "\n".join(lines)
def build_surfsense_system_prompt(
today: datetime | None = None,
thread_visibility: ChatVisibility | None = None,
enabled_tool_names: set[str] | None = None,
disabled_tool_names: set[str] | None = None,
mcp_connector_tools: dict[str, list[str]] | None = None,
) -> str:
"""
Build the SurfSense system prompt with default settings.
@ -711,6 +859,9 @@ def build_surfsense_system_prompt(
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
mcp_connector_tools: Mapping of MCP server display name → list of tool names
for generic MCP connectors. Injected into the system prompt so the LLM
knows to call these tools directly.
Returns:
Complete system prompt string
@ -718,6 +869,7 @@ def build_surfsense_system_prompt(
visibility = thread_visibility or ChatVisibility.PRIVATE
system_instructions = _get_system_instructions(visibility, today)
system_instructions += _build_mcp_routing_block(mcp_connector_tools)
tools_instructions = _get_tools_instructions(
visibility, enabled_tool_names, disabled_tool_names
)
@ -733,6 +885,7 @@ def build_configurable_system_prompt(
thread_visibility: ChatVisibility | None = None,
enabled_tool_names: set[str] | None = None,
disabled_tool_names: set[str] | None = None,
mcp_connector_tools: dict[str, list[str]] | None = None,
) -> str:
"""
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
@ -754,6 +907,9 @@ def build_configurable_system_prompt(
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
mcp_connector_tools: Mapping of MCP server display name → list of tool names
for generic MCP connectors. Injected into the system prompt so the LLM
knows to call these tools directly.
Returns:
Complete system prompt string
@ -771,6 +927,8 @@ def build_configurable_system_prompt(
else:
system_instructions = ""
system_instructions += _build_mcp_routing_block(mcp_connector_tools)
# Tools instructions: only include enabled tools, note disabled ones
tools_instructions = _get_tools_instructions(
thread_visibility, enabled_tool_names, disabled_tool_names

View file

@ -0,0 +1,109 @@
"""Connected-accounts discovery tool.
Lets the LLM discover which accounts are connected for a given service
(e.g. "jira", "linear", "slack") and retrieve the metadata it needs to
call action tools such as Jira's ``cloudId``.
The tool returns **only** non-sensitive fields explicitly listed in the
service's ``account_metadata_keys`` (see ``registry.py``), plus the
always-present ``display_name`` and ``connector_id``.
"""
import logging
from typing import Any
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.services.mcp_oauth.registry import MCP_SERVICES
logger = logging.getLogger(__name__)
_SERVICE_KEY_BY_CONNECTOR_TYPE: dict[str, str] = {
cfg.connector_type: key for key, cfg in MCP_SERVICES.items()
}
class GetConnectedAccountsInput(BaseModel):
service: str = Field(
description=(
"Service key to look up connected accounts for. "
"Valid values: " + ", ".join(sorted(MCP_SERVICES.keys()))
),
)
def _extract_display_name(connector: SearchSourceConnector) -> str:
"""Best-effort human-readable label for a connector."""
cfg = connector.config or {}
if cfg.get("display_name"):
return cfg["display_name"]
if cfg.get("base_url"):
return f"{connector.name} ({cfg['base_url']})"
if cfg.get("organization_name"):
return f"{connector.name} ({cfg['organization_name']})"
return connector.name
def create_get_connected_accounts_tool(
db_session: AsyncSession,
search_space_id: int,
user_id: str,
) -> StructuredTool:
async def _run(service: str) -> list[dict[str, Any]]:
svc_cfg = MCP_SERVICES.get(service)
if not svc_cfg:
return [{"error": f"Unknown service '{service}'. Valid: {', '.join(sorted(MCP_SERVICES.keys()))}"}]
try:
connector_type = SearchSourceConnectorType(svc_cfg.connector_type)
except ValueError:
return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type == connector_type,
)
)
connectors = result.scalars().all()
if not connectors:
return [{"error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."}]
is_multi = len(connectors) > 1
accounts: list[dict[str, Any]] = []
for conn in connectors:
cfg = conn.config or {}
entry: dict[str, Any] = {
"connector_id": conn.id,
"display_name": _extract_display_name(conn),
"service": service,
}
if is_multi:
entry["tool_prefix"] = f"{service}_{conn.id}"
for key in svc_cfg.account_metadata_keys:
if key in cfg:
entry[key] = cfg[key]
accounts.append(entry)
return accounts
return StructuredTool(
name="get_connected_accounts",
description=(
"Discover which accounts are connected for a service (e.g. jira, linear, slack, clickup, airtable). "
"Returns display names and service-specific metadata the action tools need "
"(e.g. Jira's cloudId). Call this BEFORE using a service's action tools when "
"you need an account identifier or are unsure which account to use."
),
coroutine=_run,
args_schema=GetConnectedAccountsInput,
metadata={"hitl": False},
)
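# Illustrative only: a hypothetical result for get_connected_accounts("jira")
# with two Jira sites connected (the IDs, names, and cloudId values are made up;
# the extra keys depend on account_metadata_keys in registry.py):
#     [
#         {"connector_id": 12, "display_name": "Acme Jira (https://acme.atlassian.net)",
#          "service": "jira", "tool_prefix": "jira_12", "cloudId": "..."},
#         {"connector_id": 19, "display_name": "Beta Jira",
#          "service": "jira", "tool_prefix": "jira_19", "cloudId": "..."},
#     ]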

View file

@ -0,0 +1,15 @@
from app.agents.new_chat.tools.discord.list_channels import (
create_list_discord_channels_tool,
)
from app.agents.new_chat.tools.discord.read_messages import (
create_read_discord_messages_tool,
)
from app.agents.new_chat.tools.discord.send_message import (
create_send_discord_message_tool,
)
__all__ = [
"create_list_discord_channels_tool",
"create_read_discord_messages_tool",
"create_send_discord_message_tool",
]

View file

@ -0,0 +1,42 @@
"""Shared auth helper for Discord agent tools (REST API, not gateway bot)."""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.config import config
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.utils.oauth_security import TokenEncryption
DISCORD_API = "https://discord.com/api/v10"
async def get_discord_connector(
db_session: AsyncSession,
search_space_id: int,
user_id: str,
) -> SearchSourceConnector | None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR,
)
)
return result.scalars().first()
def get_bot_token(connector: SearchSourceConnector) -> str:
"""Extract and decrypt the bot token from connector config."""
cfg = dict(connector.config)
if cfg.get("_token_encrypted") and config.SECRET_KEY:
enc = TokenEncryption(config.SECRET_KEY)
if cfg.get("bot_token"):
cfg["bot_token"] = enc.decrypt_token(cfg["bot_token"])
token = cfg.get("bot_token")
if not token:
raise ValueError("Discord bot token not found in connector config.")
return token
def get_guild_id(connector: SearchSourceConnector) -> str | None:
return connector.config.get("guild_id")

View file

@ -0,0 +1,67 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
logger = logging.getLogger(__name__)
def create_list_discord_channels_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def list_discord_channels() -> dict[str, Any]:
"""List text channels in the connected Discord server.
Returns:
Dictionary with status and a list of channels (id, name).
"""
if db_session is None or search_space_id is None or user_id is None:
return {"status": "error", "message": "Discord tool not properly configured."}
try:
connector = await get_discord_connector(db_session, search_space_id, user_id)
if not connector:
return {"status": "error", "message": "No Discord connector found."}
guild_id = get_guild_id(connector)
if not guild_id:
return {"status": "error", "message": "No guild ID in Discord connector config."}
token = get_bot_token(connector)
async with httpx.AsyncClient() as client:
resp = await client.get(
f"{DISCORD_API}/guilds/{guild_id}/channels",
headers={"Authorization": f"Bot {token}"},
timeout=15.0,
)
if resp.status_code == 401:
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
if resp.status_code != 200:
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
# Type 0 = text channel
channels = [
{"id": ch["id"], "name": ch["name"]}
for ch in resp.json()
if ch.get("type") == 0
]
return {"status": "success", "guild_id": guild_id, "channels": channels, "total": len(channels)}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error("Error listing Discord channels: %s", e, exc_info=True)
return {"status": "error", "message": "Failed to list Discord channels."}
return list_discord_channels

View file

@ -0,0 +1,80 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
logger = logging.getLogger(__name__)
def create_read_discord_messages_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def read_discord_messages(
channel_id: str,
limit: int = 25,
) -> dict[str, Any]:
"""Read recent messages from a Discord text channel.
Args:
channel_id: The Discord channel ID (from list_discord_channels).
limit: Number of messages to fetch (default 25, max 50).
Returns:
Dictionary with status and a list of messages including
id, author, content, timestamp.
"""
if db_session is None or search_space_id is None or user_id is None:
return {"status": "error", "message": "Discord tool not properly configured."}
limit = min(limit, 50)
try:
connector = await get_discord_connector(db_session, search_space_id, user_id)
if not connector:
return {"status": "error", "message": "No Discord connector found."}
token = get_bot_token(connector)
async with httpx.AsyncClient() as client:
resp = await client.get(
f"{DISCORD_API}/channels/{channel_id}/messages",
headers={"Authorization": f"Bot {token}"},
params={"limit": limit},
timeout=15.0,
)
if resp.status_code == 401:
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
if resp.status_code == 403:
return {"status": "error", "message": "Bot lacks permission to read this channel."}
if resp.status_code != 200:
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
messages = [
{
"id": m["id"],
"author": m.get("author", {}).get("username", "Unknown"),
"content": m.get("content", ""),
"timestamp": m.get("timestamp", ""),
}
for m in resp.json()
]
return {"status": "success", "channel_id": channel_id, "messages": messages, "total": len(messages)}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error("Error reading Discord messages: %s", e, exc_info=True)
return {"status": "error", "message": "Failed to read Discord messages."}
return read_discord_messages

View file

@ -0,0 +1,96 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
logger = logging.getLogger(__name__)
def create_send_discord_message_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def send_discord_message(
channel_id: str,
content: str,
) -> dict[str, Any]:
"""Send a message to a Discord text channel.
Args:
channel_id: The Discord channel ID (from list_discord_channels).
content: The message text (max 2000 characters).
Returns:
Dictionary with status, message_id on success.
IMPORTANT:
- If status is "rejected", the user explicitly declined. Do NOT retry.
"""
if db_session is None or search_space_id is None or user_id is None:
return {"status": "error", "message": "Discord tool not properly configured."}
if len(content) > 2000:
return {"status": "error", "message": "Message exceeds Discord's 2000-character limit."}
try:
connector = await get_discord_connector(db_session, search_space_id, user_id)
if not connector:
return {"status": "error", "message": "No Discord connector found."}
result = request_approval(
action_type="discord_send_message",
tool_name="send_discord_message",
params={"channel_id": channel_id, "content": content},
context={"connector_id": connector.id},
)
if result.rejected:
return {"status": "rejected", "message": "User declined. Message was not sent."}
final_content = result.params.get("content", content)
final_channel = result.params.get("channel_id", channel_id)
token = get_bot_token(connector)
async with httpx.AsyncClient() as client:
resp = await client.post(
f"{DISCORD_API}/channels/{final_channel}/messages",
headers={
"Authorization": f"Bot {token}",
"Content-Type": "application/json",
},
json={"content": final_content},
timeout=15.0,
)
if resp.status_code == 401:
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
if resp.status_code == 403:
return {"status": "error", "message": "Bot lacks permission to send messages in this channel."}
if resp.status_code not in (200, 201):
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
msg_data = resp.json()
return {
"status": "success",
"message_id": msg_data.get("id"),
"message": f"Message sent to channel {final_channel}.",
}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error("Error sending Discord message: %s", e, exc_info=True)
return {"status": "error", "message": "Failed to send Discord message."}
return send_discord_message

View file

@ -1,6 +1,12 @@
from app.agents.new_chat.tools.gmail.create_draft import (
create_create_gmail_draft_tool,
)
from app.agents.new_chat.tools.gmail.read_email import (
create_read_gmail_email_tool,
)
from app.agents.new_chat.tools.gmail.search_emails import (
create_search_gmail_tool,
)
from app.agents.new_chat.tools.gmail.send_email import (
create_send_gmail_email_tool,
)
@ -13,6 +19,8 @@ from app.agents.new_chat.tools.gmail.update_draft import (
__all__ = [
"create_create_gmail_draft_tool",
"create_read_gmail_email_tool",
"create_search_gmail_tool",
"create_send_gmail_email_tool",
"create_trash_gmail_email_tool",
"create_update_gmail_draft_tool",

View file

@ -0,0 +1,87 @@
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
logger = logging.getLogger(__name__)
_GMAIL_TYPES = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
def create_read_gmail_email_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def read_gmail_email(message_id: str) -> dict[str, Any]:
"""Read the full content of a specific Gmail email by its message ID.
Use after search_gmail to get the complete body of an email.
Args:
message_id: The Gmail message ID (from search_gmail results).
Returns:
Dictionary with status and the full email content formatted as markdown.
"""
if db_session is None or search_space_id is None or user_id is None:
return {"status": "error", "message": "Gmail tool not properly configured."}
try:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
creds = _build_credentials(connector)
from app.connectors.google_gmail_connector import GoogleGmailConnector
gmail = GoogleGmailConnector(
credentials=creds,
session=db_session,
user_id=user_id,
connector_id=connector.id,
)
detail, error = await gmail.get_message_details(message_id)
if error:
if "re-authenticate" in error.lower() or "authentication failed" in error.lower():
return {"status": "auth_error", "message": error, "connector_type": "gmail"}
return {"status": "error", "message": error}
if not detail:
return {"status": "not_found", "message": f"Email with ID '{message_id}' not found."}
content = gmail.format_message_to_markdown(detail)
return {"status": "success", "message_id": message_id, "content": content}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error("Error reading Gmail email: %s", e, exc_info=True)
return {"status": "error", "message": "Failed to read email. Please try again."}
return read_gmail_email

View file

@ -0,0 +1,165 @@
import logging
from datetime import datetime
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
logger = logging.getLogger(__name__)
_GMAIL_TYPES = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
_token_encryption_cache: object | None = None
def _get_token_encryption():
global _token_encryption_cache
if _token_encryption_cache is None:
from app.config import config
from app.utils.oauth_security import TokenEncryption
if not config.SECRET_KEY:
raise RuntimeError("SECRET_KEY not configured for token decryption.")
_token_encryption_cache = TokenEncryption(config.SECRET_KEY)
return _token_encryption_cache
def _build_credentials(connector: SearchSourceConnector):
"""Build Google OAuth Credentials from a connector's stored config.
Handles both native OAuth connectors (with encrypted tokens) and
Composio-backed connectors. Shared by Gmail and Calendar tools.
"""
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
raise ValueError("Composio connected account ID not found.")
return build_composio_credentials(cca_id)
from google.oauth2.credentials import Credentials
cfg = dict(connector.config)
if cfg.get("_token_encrypted"):
enc = _get_token_encryption()
for key in ("token", "refresh_token", "client_secret"):
if cfg.get(key):
cfg[key] = enc.decrypt_token(cfg[key])
exp = (cfg.get("expiry") or "").replace("Z", "")
return Credentials(
token=cfg.get("token"),
refresh_token=cfg.get("refresh_token"),
token_uri=cfg.get("token_uri"),
client_id=cfg.get("client_id"),
client_secret=cfg.get("client_secret"),
scopes=cfg.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
def create_search_gmail_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def search_gmail(
query: str,
max_results: int = 10,
) -> dict[str, Any]:
"""Search emails in the user's Gmail inbox using Gmail search syntax.
Args:
query: Gmail search query, same syntax as the Gmail search bar.
Examples: "from:alice@example.com", "subject:meeting",
"is:unread", "after:2024/01/01 before:2024/02/01",
"has:attachment", "in:sent".
max_results: Number of emails to return (default 10, max 20).
Returns:
Dictionary with status and a list of email summaries including
message_id, subject, from, date, snippet.
"""
if db_session is None or search_space_id is None or user_id is None:
return {"status": "error", "message": "Gmail tool not properly configured."}
max_results = min(max_results, 20)
try:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
creds = _build_credentials(connector)
from app.connectors.google_gmail_connector import GoogleGmailConnector
gmail = GoogleGmailConnector(
credentials=creds,
session=db_session,
user_id=user_id,
connector_id=connector.id,
)
messages_list, error = await gmail.get_messages_list(
max_results=max_results, query=query
)
if error:
if "re-authenticate" in error.lower() or "authentication failed" in error.lower():
return {"status": "auth_error", "message": error, "connector_type": "gmail"}
return {"status": "error", "message": error}
if not messages_list:
return {"status": "success", "emails": [], "total": 0, "message": "No emails found."}
emails = []
for msg in messages_list:
detail, err = await gmail.get_message_details(msg["id"])
if err:
continue
headers = {
h["name"].lower(): h["value"]
for h in detail.get("payload", {}).get("headers", [])
}
emails.append({
"message_id": detail.get("id"),
"thread_id": detail.get("threadId"),
"subject": headers.get("subject", "No Subject"),
"from": headers.get("from", "Unknown"),
"to": headers.get("to", ""),
"date": headers.get("date", ""),
"snippet": detail.get("snippet", ""),
"labels": detail.get("labelIds", []),
})
return {"status": "success", "emails": emails, "total": len(emails)}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error("Error searching Gmail: %s", e, exc_info=True)
return {"status": "error", "message": "Failed to search Gmail. Please try again."}
return search_gmail

View file

@ -4,6 +4,9 @@ from app.agents.new_chat.tools.google_calendar.create_event import (
from app.agents.new_chat.tools.google_calendar.delete_event import (
create_delete_calendar_event_tool,
)
from app.agents.new_chat.tools.google_calendar.search_events import (
create_search_calendar_events_tool,
)
from app.agents.new_chat.tools.google_calendar.update_event import (
create_update_calendar_event_tool,
)
@ -11,5 +14,6 @@ from app.agents.new_chat.tools.google_calendar.update_event import (
__all__ = [
"create_create_calendar_event_tool",
"create_delete_calendar_event_tool",
"create_search_calendar_events_tool",
"create_update_calendar_event_tool",
]

View file

@ -0,0 +1,114 @@
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
from app.db import SearchSourceConnector, SearchSourceConnectorType
logger = logging.getLogger(__name__)
_CALENDAR_TYPES = [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
def create_search_calendar_events_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def search_calendar_events(
start_date: str,
end_date: str,
max_results: int = 25,
) -> dict[str, Any]:
"""Search Google Calendar events within a date range.
Args:
start_date: Start date in YYYY-MM-DD format (e.g. "2026-04-01").
end_date: End date in YYYY-MM-DD format (e.g. "2026-04-30").
max_results: Maximum number of events to return (default 25, max 50).
Returns:
Dictionary with status and a list of events including
event_id, summary, start, end, location, attendees.
"""
if db_session is None or search_space_id is None or user_id is None:
return {"status": "error", "message": "Calendar tool not properly configured."}
max_results = min(max_results, 50)
try:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
}
creds = _build_credentials(connector)
from app.connectors.google_calendar_connector import GoogleCalendarConnector
cal = GoogleCalendarConnector(
credentials=creds,
session=db_session,
user_id=user_id,
connector_id=connector.id,
)
events_raw, error = await cal.get_all_primary_calendar_events(
start_date=start_date,
end_date=end_date,
max_results=max_results,
)
if error:
if "re-authenticate" in error.lower() or "authentication failed" in error.lower():
return {"status": "auth_error", "message": error, "connector_type": "google_calendar"}
if "no events found" in error.lower():
return {"status": "success", "events": [], "total": 0, "message": error}
return {"status": "error", "message": error}
events = []
for ev in events_raw:
start = ev.get("start", {})
end = ev.get("end", {})
attendees_raw = ev.get("attendees", [])
events.append({
"event_id": ev.get("id"),
"summary": ev.get("summary", "No Title"),
"start": start.get("dateTime") or start.get("date", ""),
"end": end.get("dateTime") or end.get("date", ""),
"location": ev.get("location", ""),
"description": ev.get("description", ""),
"html_link": ev.get("htmlLink", ""),
"attendees": [
a.get("email", "") for a in attendees_raw[:10]
],
"status": ev.get("status", ""),
})
return {"status": "success", "events": events, "total": len(events)}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error("Error searching calendar events: %s", e, exc_info=True)
return {"status": "error", "message": "Failed to search calendar events. Please try again."}
return search_calendar_events

View file

@ -130,8 +130,8 @@ def request_approval(
try:
decision_type, edited_params = _parse_decision(approval)
except ValueError:
logger.warning("No approval decision received for %s", tool_name)
return HITLResult(rejected=False, decision_type="error", params=params)
logger.warning("No approval decision received for %s — rejecting for safety", tool_name)
return HITLResult(rejected=True, decision_type="error", params=params)
logger.info("User decision for %s: %s", tool_name, decision_type)

View file

@ -0,0 +1,15 @@
from app.agents.new_chat.tools.luma.create_event import (
create_create_luma_event_tool,
)
from app.agents.new_chat.tools.luma.list_events import (
create_list_luma_events_tool,
)
from app.agents.new_chat.tools.luma.read_event import (
create_read_luma_event_tool,
)
__all__ = [
"create_create_luma_event_tool",
"create_list_luma_events_tool",
"create_read_luma_event_tool",
]

View file

@ -0,0 +1,38 @@
"""Shared auth helper for Luma agent tools."""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
LUMA_API = "https://public-api.luma.com/v1"
async def get_luma_connector(
db_session: AsyncSession,
search_space_id: int,
user_id: str,
) -> SearchSourceConnector | None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type == SearchSourceConnectorType.LUMA_CONNECTOR,
)
)
return result.scalars().first()
def get_api_key(connector: SearchSourceConnector) -> str:
"""Extract the API key from connector config (handles both key names)."""
key = connector.config.get("api_key") or connector.config.get("LUMA_API_KEY")
if not key:
raise ValueError("Luma API key not found in connector config.")
return key
def luma_headers(api_key: str) -> dict[str, str]:
return {
"Content-Type": "application/json",
"x-luma-api-key": api_key,
}

View file

@ -0,0 +1,116 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__)
def create_create_luma_event_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def create_luma_event(
name: str,
start_at: str,
end_at: str,
description: str | None = None,
timezone: str = "UTC",
) -> dict[str, Any]:
"""Create a new event on Luma.
Args:
name: The event title.
start_at: Start time in ISO 8601 format (e.g. "2026-05-01T18:00:00").
end_at: End time in ISO 8601 format (e.g. "2026-05-01T20:00:00").
description: Optional event description (markdown supported).
timezone: Timezone string (default "UTC", e.g. "America/New_York").
Returns:
Dictionary with status, event_id on success.
IMPORTANT:
- If status is "rejected", the user explicitly declined. Do NOT retry.
"""
if db_session is None or search_space_id is None or user_id is None:
return {"status": "error", "message": "Luma tool not properly configured."}
try:
connector = await get_luma_connector(db_session, search_space_id, user_id)
if not connector:
return {"status": "error", "message": "No Luma connector found."}
result = request_approval(
action_type="luma_create_event",
tool_name="create_luma_event",
params={
"name": name,
"start_at": start_at,
"end_at": end_at,
"description": description,
"timezone": timezone,
},
context={"connector_id": connector.id},
)
if result.rejected:
return {"status": "rejected", "message": "User declined. Event was not created."}
final_name = result.params.get("name", name)
final_start = result.params.get("start_at", start_at)
final_end = result.params.get("end_at", end_at)
final_desc = result.params.get("description", description)
final_tz = result.params.get("timezone", timezone)
api_key = get_api_key(connector)
headers = luma_headers(api_key)
body: dict[str, Any] = {
"name": final_name,
"start_at": final_start,
"end_at": final_end,
"timezone": final_tz,
}
if final_desc:
body["description_md"] = final_desc
async with httpx.AsyncClient(timeout=20.0) as client:
resp = await client.post(
f"{LUMA_API}/event/create",
headers=headers,
json=body,
)
if resp.status_code == 401:
return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"}
if resp.status_code == 403:
return {"status": "error", "message": "Luma Plus subscription required to create events via API."}
if resp.status_code not in (200, 201):
return {"status": "error", "message": f"Luma API error: {resp.status_code}{resp.text[:200]}"}
data = resp.json()
event_id = data.get("api_id") or data.get("event", {}).get("api_id")
return {
"status": "success",
"event_id": event_id,
"message": f"Event '{final_name}' created on Luma.",
}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error("Error creating Luma event: %s", e, exc_info=True)
return {"status": "error", "message": "Failed to create Luma event."}
return create_luma_event

View file

@ -0,0 +1,100 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__)
def create_list_luma_events_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def list_luma_events(
max_results: int = 25,
) -> dict[str, Any]:
"""List upcoming and recent Luma events.
Args:
max_results: Maximum events to return (default 25, max 50).
Returns:
Dictionary with status and a list of events including
event_id, name, start_at, end_at, location, url.
"""
if db_session is None or search_space_id is None or user_id is None:
return {"status": "error", "message": "Luma tool not properly configured."}
max_results = min(max_results, 50)
try:
connector = await get_luma_connector(db_session, search_space_id, user_id)
if not connector:
return {"status": "error", "message": "No Luma connector found."}
api_key = get_api_key(connector)
headers = luma_headers(api_key)
all_entries: list[dict] = []
cursor = None
async with httpx.AsyncClient(timeout=20.0) as client:
while len(all_entries) < max_results:
params: dict[str, Any] = {"limit": min(100, max_results - len(all_entries))}
if cursor:
params["cursor"] = cursor
resp = await client.get(
f"{LUMA_API}/calendar/list-events",
headers=headers,
params=params,
)
if resp.status_code == 401:
return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"}
if resp.status_code != 200:
return {"status": "error", "message": f"Luma API error: {resp.status_code}"}
data = resp.json()
entries = data.get("entries", [])
if not entries:
break
all_entries.extend(entries)
next_cursor = data.get("next_cursor")
if not next_cursor:
break
cursor = next_cursor
events = []
for entry in all_entries[:max_results]:
ev = entry.get("event", {})
geo = ev.get("geo_info", {})
events.append({
"event_id": entry.get("api_id"),
"name": ev.get("name", "Untitled"),
"start_at": ev.get("start_at", ""),
"end_at": ev.get("end_at", ""),
"timezone": ev.get("timezone", ""),
"location": geo.get("name", ""),
"url": ev.get("url", ""),
"visibility": ev.get("visibility", ""),
})
return {"status": "success", "events": events, "total": len(events)}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error("Error listing Luma events: %s", e, exc_info=True)
return {"status": "error", "message": "Failed to list Luma events."}
return list_luma_events

View file

@ -0,0 +1,82 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__)
def create_read_luma_event_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def read_luma_event(event_id: str) -> dict[str, Any]:
"""Read detailed information about a specific Luma event.
Args:
event_id: The Luma event API ID (from list_luma_events).
Returns:
Dictionary with status and full event details including
description, attendees count, meeting URL.
"""
if db_session is None or search_space_id is None or user_id is None:
return {"status": "error", "message": "Luma tool not properly configured."}
try:
connector = await get_luma_connector(db_session, search_space_id, user_id)
if not connector:
return {"status": "error", "message": "No Luma connector found."}
api_key = get_api_key(connector)
headers = luma_headers(api_key)
async with httpx.AsyncClient(timeout=15.0) as client:
resp = await client.get(
f"{LUMA_API}/events/{event_id}",
headers=headers,
)
if resp.status_code == 401:
return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"}
if resp.status_code == 404:
return {"status": "not_found", "message": f"Event '{event_id}' not found."}
if resp.status_code != 200:
return {"status": "error", "message": f"Luma API error: {resp.status_code}"}
data = resp.json()
ev = data.get("event", data)
geo = ev.get("geo_info", {})
event_detail = {
"event_id": event_id,
"name": ev.get("name", ""),
"description": ev.get("description", ""),
"start_at": ev.get("start_at", ""),
"end_at": ev.get("end_at", ""),
"timezone": ev.get("timezone", ""),
"location_name": geo.get("name", ""),
"address": geo.get("address", ""),
"url": ev.get("url", ""),
"meeting_url": ev.get("meeting_url", ""),
"visibility": ev.get("visibility", ""),
"cover_url": ev.get("cover_url", ""),
}
return {"status": "success", "event": event_detail}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error("Error reading Luma event: %s", e, exc_info=True)
return {"status": "error", "message": "Failed to read Luma event."}
return read_luma_event

View file

@ -45,6 +45,18 @@ class MCPClient:
async def connect(self, max_retries: int = MAX_RETRIES):
"""Connect to the MCP server and manage its lifecycle.
Retries only apply to the **connection** phase (spawning the process,
initialising the session). Once the session is yielded to the caller,
any exception raised by the caller propagates normally -- the context
manager will NOT retry after ``yield``.
Previous implementation wrapped both connection AND yield inside the
retry loop. Because ``@asynccontextmanager`` only allows a single
``yield``, a failure after yield caused the generator to attempt a
second yield on retry, triggering
``RuntimeError("generator didn't stop after athrow()")`` and orphaning
the stdio subprocess.
Args:
max_retries: Maximum number of connection retry attempts
@ -57,26 +69,22 @@ class MCPClient:
"""
last_error = None
delay = RETRY_DELAY
connected = False
for attempt in range(max_retries):
try:
# Merge env vars with current environment
server_env = os.environ.copy()
server_env.update(self.env)
# Create server parameters with env
server_params = StdioServerParameters(
command=self.command, args=self.args, env=server_env
)
# Spawn server process and create session
# Note: Cannot combine these context managers because ClientSession
# needs the read/write streams from stdio_client
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
async with ClientSession(read, write) as session:
# Initialize the connection
await session.initialize()
self.session = session
connected = True
if attempt > 0:
logger.info(
@ -91,10 +99,16 @@ class MCPClient:
self.command,
" ".join(self.args),
)
yield session
return # Success, exit retry loop
try:
yield session
finally:
self.session = None
return
except Exception as e:
self.session = None
if connected:
raise
last_error = e
if attempt < max_retries - 1:
logger.warning(
@ -105,7 +119,7 @@ class MCPClient:
delay,
)
await asyncio.sleep(delay)
delay *= RETRY_BACKOFF # Exponential backoff
delay *= RETRY_BACKOFF
else:
logger.error(
"Failed to connect to MCP server after %d attempts: %s",
@ -113,10 +127,7 @@ class MCPClient:
e,
exc_info=True,
)
finally:
self.session = None
# All retries exhausted
error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
if last_error:
error_msg += f": {last_error}"
@ -161,12 +172,18 @@ class MCPClient:
logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
raise
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
async def call_tool(
self,
tool_name: str,
arguments: dict[str, Any],
timeout: float = 60.0,
) -> Any:
"""Call a tool on the MCP server.
Args:
tool_name: Name of the tool to call
arguments: Arguments to pass to the tool
timeout: Maximum seconds to wait for the tool to respond
Returns:
Tool execution result
@ -185,10 +202,11 @@ class MCPClient:
"Calling MCP tool '%s' with arguments: %s", tool_name, arguments
)
# Call tools/call RPC method
response = await self.session.call_tool(tool_name, arguments=arguments)
response = await asyncio.wait_for(
self.session.call_tool(tool_name, arguments=arguments),
timeout=timeout,
)
# Extract content from response
result = []
for content in response.content:
if hasattr(content, "text"):
@ -202,15 +220,17 @@ class MCPClient:
logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
return result_str
except asyncio.TimeoutError:
logger.error(
"MCP tool '%s' timed out after %.0fs", tool_name, timeout
)
return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s"
except RuntimeError as e:
# Handle validation errors from MCP server responses
# Some MCP servers (like server-memory) return extra fields not in their schema
if "Invalid structured content" in str(e):
logger.warning(
"MCP server returned data not matching its schema, but continuing: %s",
e,
)
# Try to extract result from error message or return a success message
return "Operation completed (server returned unexpected format)"
raise
except (ValueError, TypeError, AttributeError, KeyError) as e:

File diff suppressed because it is too large

View file

@ -50,6 +50,11 @@ from .confluence import (
create_delete_confluence_page_tool,
create_update_confluence_page_tool,
)
from .discord import (
create_list_discord_channels_tool,
create_read_discord_messages_tool,
create_send_discord_message_tool,
)
from .dropbox import (
create_create_dropbox_file_tool,
create_delete_dropbox_file_tool,
@ -57,6 +62,8 @@ from .dropbox import (
from .generate_image import create_generate_image_tool
from .gmail import (
create_create_gmail_draft_tool,
create_read_gmail_email_tool,
create_search_gmail_tool,
create_send_gmail_email_tool,
create_trash_gmail_email_tool,
create_update_gmail_draft_tool,
@ -64,21 +71,18 @@ from .gmail import (
from .google_calendar import (
create_create_calendar_event_tool,
create_delete_calendar_event_tool,
create_search_calendar_events_tool,
create_update_calendar_event_tool,
)
from .google_drive import (
create_create_google_drive_file_tool,
create_delete_google_drive_file_tool,
)
from .jira import (
create_create_jira_issue_tool,
create_delete_jira_issue_tool,
create_update_jira_issue_tool,
)
from .linear import (
create_create_linear_issue_tool,
create_delete_linear_issue_tool,
create_update_linear_issue_tool,
from .connected_accounts import create_get_connected_accounts_tool
from .luma import (
create_create_luma_event_tool,
create_list_luma_events_tool,
create_read_luma_event_tool,
)
from .mcp_tool import load_mcp_tools
from .notion import (
@ -95,6 +99,11 @@ from .report import create_generate_report_tool
from .resume import create_generate_resume_tool
from .scrape_webpage import create_scrape_webpage_tool
from .search_surfsense_docs import create_search_surfsense_docs_tool
from .teams import (
create_list_teams_channels_tool,
create_read_teams_messages_tool,
create_send_teams_message_tool,
)
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
from .video_presentation import create_generate_video_presentation_tool
from .web_search import create_web_search_tool
@ -114,6 +123,8 @@ class ToolDefinition:
factory: Callable that creates the tool. Receives a dict of dependencies.
requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session")
enabled_by_default: Whether the tool is enabled when no explicit config is provided
required_connector: Searchable type string (e.g. ``"LINEAR_CONNECTOR"``)
that must be in ``available_connectors`` for the tool to be enabled.
"""
@ -123,6 +134,7 @@ class ToolDefinition:
requires: list[str] = field(default_factory=list)
enabled_by_default: bool = True
hidden: bool = False
required_connector: str | None = None
# =============================================================================
@ -221,6 +233,21 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
requires=["db_session"],
),
# =========================================================================
# SERVICE ACCOUNT DISCOVERY
# Generic tool for the LLM to discover connected accounts and resolve
# service-specific identifiers (e.g. Jira cloudId, Slack team, etc.)
# =========================================================================
ToolDefinition(
name="get_connected_accounts",
description="Discover connected accounts for a service and their metadata",
factory=lambda deps: create_get_connected_accounts_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
# =========================================================================
# MEMORY TOOL - single update_memory, private or team by thread_visibility
# =========================================================================
ToolDefinition(
@ -248,40 +275,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
],
),
# =========================================================================
# LINEAR TOOLS - create, update, delete issues
# Auto-disabled when no Linear connector is configured (see chat_deepagent.py)
# =========================================================================
ToolDefinition(
name="create_linear_issue",
description="Create a new issue in the user's Linear workspace",
factory=lambda deps: create_create_linear_issue_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
ToolDefinition(
name="update_linear_issue",
description="Update an existing indexed Linear issue",
factory=lambda deps: create_update_linear_issue_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
ToolDefinition(
name="delete_linear_issue",
description="Archive (delete) an existing indexed Linear issue",
factory=lambda deps: create_delete_linear_issue_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
# =========================================================================
# NOTION TOOLS - create, update, delete pages
# Auto-disabled when no Notion connector is configured (see chat_deepagent.py)
# =========================================================================
@ -294,6 +287,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="NOTION_CONNECTOR",
),
ToolDefinition(
name="update_notion_page",
@ -304,6 +298,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="NOTION_CONNECTOR",
),
ToolDefinition(
name="delete_notion_page",
@ -314,6 +309,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="NOTION_CONNECTOR",
),
# =========================================================================
# GOOGLE DRIVE TOOLS - create files, delete files
@ -328,6 +324,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_DRIVE_FILE",
),
ToolDefinition(
name="delete_google_drive_file",
@ -338,6 +335,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_DRIVE_FILE",
),
# =========================================================================
# DROPBOX TOOLS - create and trash files
@ -352,6 +350,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="DROPBOX_FILE",
),
ToolDefinition(
name="delete_dropbox_file",
@ -362,6 +361,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="DROPBOX_FILE",
),
# =========================================================================
# ONEDRIVE TOOLS - create and trash files
@ -376,6 +376,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="ONEDRIVE_FILE",
),
ToolDefinition(
name="delete_onedrive_file",
@ -386,11 +387,23 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="ONEDRIVE_FILE",
),
# =========================================================================
# GOOGLE CALENDAR TOOLS - create, update, delete events
# GOOGLE CALENDAR TOOLS - search, create, update, delete events
# Auto-disabled when no Google Calendar connector is configured
# =========================================================================
ToolDefinition(
name="search_calendar_events",
description="Search Google Calendar events within a date range",
factory=lambda deps: create_search_calendar_events_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_CALENDAR_CONNECTOR",
),
ToolDefinition(
name="create_calendar_event",
description="Create a new event on Google Calendar",
@ -400,6 +413,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_CALENDAR_CONNECTOR",
),
ToolDefinition(
name="update_calendar_event",
@ -410,6 +424,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_CALENDAR_CONNECTOR",
),
ToolDefinition(
name="delete_calendar_event",
@ -420,11 +435,34 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_CALENDAR_CONNECTOR",
),
# =========================================================================
# GMAIL TOOLS - create drafts, update drafts, send emails, trash emails
# GMAIL TOOLS - search, read, create drafts, update drafts, send, trash
# Auto-disabled when no Gmail connector is configured
# =========================================================================
ToolDefinition(
name="search_gmail",
description="Search emails in Gmail using Gmail search syntax",
factory=lambda deps: create_search_gmail_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_GMAIL_CONNECTOR",
),
ToolDefinition(
name="read_gmail_email",
description="Read the full content of a specific Gmail email",
factory=lambda deps: create_read_gmail_email_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_GMAIL_CONNECTOR",
),
ToolDefinition(
name="create_gmail_draft",
description="Create a draft email in Gmail",
@ -434,6 +472,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_GMAIL_CONNECTOR",
),
ToolDefinition(
name="send_gmail_email",
@ -444,6 +483,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_GMAIL_CONNECTOR",
),
ToolDefinition(
name="trash_gmail_email",
@ -454,6 +494,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_GMAIL_CONNECTOR",
),
ToolDefinition(
name="update_gmail_draft",
@ -464,40 +505,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
# =========================================================================
# JIRA TOOLS - create, update, delete issues
# Auto-disabled when no Jira connector is configured (see chat_deepagent.py)
# =========================================================================
ToolDefinition(
name="create_jira_issue",
description="Create a new issue in the user's Jira project",
factory=lambda deps: create_create_jira_issue_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
ToolDefinition(
name="update_jira_issue",
description="Update an existing indexed Jira issue",
factory=lambda deps: create_update_jira_issue_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
ToolDefinition(
name="delete_jira_issue",
description="Delete an existing indexed Jira issue",
factory=lambda deps: create_delete_jira_issue_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_GMAIL_CONNECTOR",
),
# =========================================================================
# CONFLUENCE TOOLS - create, update, delete pages
@ -512,6 +520,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="CONFLUENCE_CONNECTOR",
),
ToolDefinition(
name="update_confluence_page",
@ -522,6 +531,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="CONFLUENCE_CONNECTOR",
),
ToolDefinition(
name="delete_confluence_page",
@ -532,6 +542,118 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="CONFLUENCE_CONNECTOR",
),
# =========================================================================
# DISCORD TOOLS - list channels, read messages, send messages
# Auto-disabled when no Discord connector is configured
# =========================================================================
ToolDefinition(
name="list_discord_channels",
description="List text channels in the connected Discord server",
factory=lambda deps: create_list_discord_channels_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="DISCORD_CONNECTOR",
),
ToolDefinition(
name="read_discord_messages",
description="Read recent messages from a Discord text channel",
factory=lambda deps: create_read_discord_messages_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="DISCORD_CONNECTOR",
),
ToolDefinition(
name="send_discord_message",
description="Send a message to a Discord text channel",
factory=lambda deps: create_send_discord_message_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="DISCORD_CONNECTOR",
),
# =========================================================================
# TEAMS TOOLS - list channels, read messages, send messages
# Auto-disabled when no Teams connector is configured
# =========================================================================
ToolDefinition(
name="list_teams_channels",
description="List Microsoft Teams and their channels",
factory=lambda deps: create_list_teams_channels_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="TEAMS_CONNECTOR",
),
ToolDefinition(
name="read_teams_messages",
description="Read recent messages from a Microsoft Teams channel",
factory=lambda deps: create_read_teams_messages_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="TEAMS_CONNECTOR",
),
ToolDefinition(
name="send_teams_message",
description="Send a message to a Microsoft Teams channel",
factory=lambda deps: create_send_teams_message_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="TEAMS_CONNECTOR",
),
# =========================================================================
# LUMA TOOLS - list events, read event details, create events
# Auto-disabled when no Luma connector is configured
# =========================================================================
ToolDefinition(
name="list_luma_events",
description="List upcoming and recent Luma events",
factory=lambda deps: create_list_luma_events_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="LUMA_CONNECTOR",
),
ToolDefinition(
name="read_luma_event",
description="Read detailed information about a specific Luma event",
factory=lambda deps: create_read_luma_event_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="LUMA_CONNECTOR",
),
ToolDefinition(
name="create_luma_event",
description="Create a new event on Luma",
factory=lambda deps: create_create_luma_event_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="LUMA_CONNECTOR",
),
]
@ -549,6 +671,22 @@ def get_tool_by_name(name: str) -> ToolDefinition | None:
return None
def get_connector_gated_tools(
available_connectors: list[str] | None,
) -> list[str]:
"""Return tool names to disable"""
if available_connectors is None:
available = set()
else:
available = set(available_connectors)
disabled: list[str] = []
for tool_def in BUILTIN_TOOLS:
if tool_def.required_connector and tool_def.required_connector not in available:
disabled.append(tool_def.name)
return disabled
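# A short sketch (hypothetical helper name and connector list) of how the gating
# helper combines with BUILTIN_TOOLS to pick the definitions that stay enabled.
def _enabled_tool_defs(available_connectors: list[str] | None) -> list[ToolDefinition]:
    disabled = set(get_connector_gated_tools(available_connectors))
    return [t for t in BUILTIN_TOOLS if t.name not in disabled]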
def get_all_tool_names() -> list[str]:
"""Get names of all registered tools."""
return [tool_def.name for tool_def in BUILTIN_TOOLS]
@ -690,15 +828,15 @@ async def build_tools_async(
)
tools.extend(mcp_tools)
logging.info(
f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}",
"Registered %d MCP tools: %s",
len(mcp_tools), [t.name for t in mcp_tools],
)
except Exception as e:
# Log error but don't fail - just continue without MCP tools
logging.exception(f"Failed to load MCP tools: {e!s}")
logging.exception("Failed to load MCP tools: %s", e)
# Log all tools being returned to agent
logging.info(
f"Total tools for agent: {len(tools)} - {[t.name for t in tools]}",
"Total tools for agent: %d%s",
len(tools), [t.name for t in tools],
)
return tools

View file

@ -13,11 +13,13 @@ Uses the same short-lived session pattern as generate_report so no DB
connection is held during the long LLM call.
"""
import io
import logging
import re
from datetime import UTC, datetime
from typing import Any
import pypdf
import typst
from langchain_core.callbacks import dispatch_custom_event
from langchain_core.messages import HumanMessage
@ -114,7 +116,7 @@ _TEMPLATES: dict[str, dict[str, str]] = {
entries-highlights-nested-bullet: text(13pt, [\\u{2022}], baseline: -0.6pt),
entries-highlights-space-left: 0cm,
entries-highlights-space-above: 0.08cm,
entries-highlights-space-between-items: 0.08cm,
entries-highlights-space-between-items: 0.02cm,
entries-highlights-space-between-bullet-and-text: 0.3em,
date: datetime(
year: {year},
@ -166,8 +168,8 @@ Available components (use ONLY these):
#summary([Short paragraph summary]) // Optional summary inside an entry
#content-area([Free-form content]) // Freeform text block
For skills sections, use bold labels directly:
#strong[Category:] item1, item2, item3
For skills sections, use one bullet per category label:
- #strong[Category:] item1, item2, item3
For simple list sections (e.g. Honors), use plain bullet points:
- Item one
@ -184,15 +186,19 @@ RULES:
- Every section MUST use == heading.
- Use #regular-entry() for experience, projects, publications, certifications, and similar entries.
- Use #education-entry() for education.
- Use #strong[Label:] for skills categories.
- For skills sections, use one bullet line per category with a bold label.
- Keep content professional, concise, and achievement-oriented.
- Use action verbs for bullet points (Led, Built, Designed, Reduced, etc.).
- This template works for ALL professions; adapt sections to the user's field.
- Default behavior should prioritize concise one-page content.
""",
},
}
DEFAULT_TEMPLATE = "classic"
MIN_RESUME_PAGES = 1
MAX_RESUME_PAGES = 5
MAX_COMPRESSION_ATTEMPTS = 2
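# Together these bound the generation loop below: the requested max_pages must fall
# within MIN_RESUME_PAGES..MAX_RESUME_PAGES, and after the initial draft at most
# MAX_COMPRESSION_ATTEMPTS LLM compression passes are attempted before giving up.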
# ─── Template Helpers ─────────────────────────────────────────────────────────
@ -315,6 +321,8 @@ You are an expert resume writer. Generate professional resume content as Typst m
**User Information:**
{user_info}
**Target Maximum Pages:** {max_pages}
{user_instructions_section}
Generate the resume content now (starting with = Full Name):
@ -326,6 +334,8 @@ Apply ONLY the requested changes — do NOT rewrite sections that are not affect
{llm_reference}
**Target Maximum Pages:** {max_pages}
**Modification Instructions:** {user_instructions}
**EXISTING RESUME CONTENT:**
@ -352,6 +362,28 @@ The resume content you generated failed to compile. Fix the error while preservi
(starting with = Full Name), NOT the #import or #show rule:**
"""
_COMPRESS_TO_PAGE_LIMIT_PROMPT = """\
The resume compiles, but it exceeds the maximum allowed page count.
Compress the resume while preserving high-impact accomplishments and role relevance.
{llm_reference}
**Target Maximum Pages:** {max_pages}
**Current Page Count:** {actual_pages}
**Compression Attempt:** {attempt_number}
Compression priorities (in this order):
1) Keep recent, high-impact, role-relevant bullets.
2) Remove low-impact or redundant bullets.
3) Shorten verbose wording while preserving meaning.
4) Trim older or less relevant details before recent ones.
Return the complete updated Typst content (starting with = Full Name), and keep it at or below the target pages.
**EXISTING RESUME CONTENT:**
{previous_content}
"""
# ─── Helpers ─────────────────────────────────────────────────────────────────
@ -373,6 +405,24 @@ def _compile_typst(source: str) -> bytes:
return typst.compile(source.encode("utf-8"))
def _count_pdf_pages(pdf_bytes: bytes) -> int:
"""Count the number of pages in compiled PDF bytes."""
with io.BytesIO(pdf_bytes) as pdf_stream:
reader = pypdf.PdfReader(pdf_stream)
return len(reader.pages)
def _validate_max_pages(max_pages: int) -> int:
"""Validate and normalize max_pages input."""
if MIN_RESUME_PAGES <= max_pages <= MAX_RESUME_PAGES:
return max_pages
msg = (
f"max_pages must be between {MIN_RESUME_PAGES} and "
f"{MAX_RESUME_PAGES}. Received: {max_pages}"
)
raise ValueError(msg)
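# Illustrative pairing of the two helpers above (hypothetical wrapper; pdf_bytes is
# assumed to come from _compile_typst): validate the cap, then compare page counts.
def _needs_compression(pdf_bytes: bytes, max_pages: int) -> bool:
    limit = _validate_max_pages(max_pages)  # raises ValueError outside 1-5
    return _count_pdf_pages(pdf_bytes) > limit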
# ─── Tool Factory ───────────────────────────────────────────────────────────
@ -394,6 +444,7 @@ def create_generate_resume_tool(
user_info: str,
user_instructions: str | None = None,
parent_report_id: int | None = None,
max_pages: int = 1,
) -> dict[str, Any]:
"""
Generate a professional resume as a Typst document.
@ -426,6 +477,8 @@ def create_generate_resume_tool(
"use a modern style"). For revisions, describe what to change.
parent_report_id: ID of a previous resume to revise (creates
new version in the same version group).
max_pages: Maximum number of pages for the generated resume.
Defaults to 1. Allowed range: 1-5.
Returns:
Dict with status, report_id, title, and content_type.
@ -469,6 +522,19 @@ def create_generate_resume_tool(
return None
try:
try:
validated_max_pages = _validate_max_pages(max_pages)
except ValueError as e:
error_msg = str(e)
report_id = await _save_failed_report(error_msg)
return {
"status": "failed",
"error": error_msg,
"report_id": report_id,
"title": "Resume",
"content_type": "typst",
}
# ── Phase 1: READ ─────────────────────────────────────────────
async with shielded_async_session() as read_session:
if parent_report_id:
@ -512,6 +578,7 @@ def create_generate_resume_tool(
parent_body = _strip_header(parent_content)
prompt = _REVISION_PROMPT.format(
llm_reference=llm_reference,
max_pages=validated_max_pages,
user_instructions=user_instructions
or "Improve and refine the resume.",
previous_content=parent_body,
@ -524,6 +591,7 @@ def create_generate_resume_tool(
prompt = _RESUME_PROMPT.format(
llm_reference=llm_reference,
user_info=user_info,
max_pages=validated_max_pages,
user_instructions_section=user_instructions_section,
)
@ -551,49 +619,116 @@ def create_generate_resume_tool(
)
name = _extract_name(body) or "Resume"
header = _build_header(template, name)
typst_source = header + body
typst_source = ""
actual_pages = 0
compression_attempts = 0
target_page_met = False
compile_error: str | None = None
for attempt in range(2):
try:
_compile_typst(typst_source)
compile_error = None
break
except Exception as e:
compile_error = str(e)
logger.warning(
f"[generate_resume] Compile attempt {attempt + 1} failed: {compile_error}"
for compression_round in range(MAX_COMPRESSION_ATTEMPTS + 1):
header = _build_header(template, name)
typst_source = header + body
compile_error: str | None = None
pdf_bytes: bytes | None = None
for compile_attempt in range(2):
try:
pdf_bytes = _compile_typst(typst_source)
compile_error = None
break
except Exception as e:
compile_error = str(e)
logger.warning(
"[generate_resume] Compile attempt %s failed: %s",
compile_attempt + 1,
compile_error,
)
if compile_attempt == 0:
dispatch_custom_event(
"report_progress",
{
"phase": "fixing",
"message": "Fixing compilation issue...",
},
)
fix_prompt = _FIX_COMPILE_PROMPT.format(
llm_reference=llm_reference,
error=compile_error,
full_source=typst_source,
)
fix_response = await llm.ainvoke(
[HumanMessage(content=fix_prompt)]
)
if fix_response.content and isinstance(
fix_response.content, str
):
body = _strip_typst_fences(fix_response.content)
body = _strip_imports(body)
name = _extract_name(body) or name
header = _build_header(template, name)
typst_source = header + body
if compile_error or not pdf_bytes:
error_msg = (
"Typst compilation failed after 2 attempts: "
f"{compile_error or 'Unknown compile error'}"
)
report_id = await _save_failed_report(error_msg)
return {
"status": "failed",
"error": error_msg,
"report_id": report_id,
"title": "Resume",
"content_type": "typst",
}
if attempt == 0:
dispatch_custom_event(
"report_progress",
{
"phase": "fixing",
"message": "Fixing compilation issue...",
},
)
fix_prompt = _FIX_COMPILE_PROMPT.format(
llm_reference=llm_reference,
error=compile_error,
full_source=typst_source,
)
fix_response = await llm.ainvoke(
[HumanMessage(content=fix_prompt)]
)
if fix_response.content and isinstance(
fix_response.content, str
):
body = _strip_typst_fences(fix_response.content)
body = _strip_imports(body)
name = _extract_name(body) or name
header = _build_header(template, name)
typst_source = header + body
actual_pages = _count_pdf_pages(pdf_bytes)
if actual_pages <= validated_max_pages:
target_page_met = True
break
if compile_error:
if compression_round >= MAX_COMPRESSION_ATTEMPTS:
break
compression_attempts += 1
dispatch_custom_event(
"report_progress",
{
"phase": "compressing",
"message": f"Condensing resume to {validated_max_pages} page(s)...",
},
)
compress_prompt = _COMPRESS_TO_PAGE_LIMIT_PROMPT.format(
llm_reference=llm_reference,
max_pages=validated_max_pages,
actual_pages=actual_pages,
attempt_number=compression_attempts,
previous_content=body,
)
compress_response = await llm.ainvoke(
[HumanMessage(content=compress_prompt)]
)
if not compress_response.content or not isinstance(
compress_response.content, str
):
error_msg = "LLM returned empty content while compressing resume"
report_id = await _save_failed_report(error_msg)
return {
"status": "failed",
"error": error_msg,
"report_id": report_id,
"title": "Resume",
"content_type": "typst",
}
body = _strip_typst_fences(compress_response.content)
body = _strip_imports(body)
name = _extract_name(body) or name
if actual_pages > MAX_RESUME_PAGES:
error_msg = (
f"Typst compilation failed after 2 attempts: {compile_error}"
"Resume exceeds hard page limit after compression retries. "
f"Hard limit: <= {MAX_RESUME_PAGES} page(s), actual: {actual_pages}."
)
report_id = await _save_failed_report(error_msg)
return {
@ -616,6 +751,11 @@ def create_generate_resume_tool(
"status": "ready",
"word_count": len(typst_source.split()),
"char_count": len(typst_source),
"target_max_pages": validated_max_pages,
"actual_page_count": actual_pages,
"page_limit_enforced": True,
"compression_attempts": compression_attempts,
"target_page_met": target_page_met,
}
async with shielded_async_session() as write_session:
@ -647,7 +787,14 @@ def create_generate_resume_tool(
"title": resume_title,
"content_type": "typst",
"is_revision": bool(parent_content),
"message": f"Resume generated successfully: {resume_title}",
"message": (
f"Resume generated successfully: {resume_title}"
if target_page_met
else (
f"Resume generated, but could not fit the target of <= {validated_max_pages} "
f"page(s). Final length: {actual_pages} page(s)."
)
),
}
except Exception as e:

View file

@ -0,0 +1,15 @@
from app.agents.new_chat.tools.teams.list_channels import (
create_list_teams_channels_tool,
)
from app.agents.new_chat.tools.teams.read_messages import (
create_read_teams_messages_tool,
)
from app.agents.new_chat.tools.teams.send_message import (
create_send_teams_message_tool,
)
__all__ = [
"create_list_teams_channels_tool",
"create_read_teams_messages_tool",
"create_send_teams_message_tool",
]

View file

@ -0,0 +1,37 @@
"""Shared auth helper for Teams agent tools (Microsoft Graph REST API)."""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
GRAPH_API = "https://graph.microsoft.com/v1.0"
async def get_teams_connector(
db_session: AsyncSession,
search_space_id: int,
user_id: str,
) -> SearchSourceConnector | None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR,
)
)
return result.scalars().first()
async def get_access_token(
db_session: AsyncSession,
connector: SearchSourceConnector,
) -> str:
"""Get a valid Microsoft Graph access token, refreshing if expired."""
from app.connectors.teams_connector import TeamsConnector
tc = TeamsConnector(
session=db_session,
connector_id=connector.id,
)
return await tc._get_valid_token()
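# Typical composition in the tools below: resolve the connector, mint a Graph token,
# then call f"{GRAPH_API}/..." endpoints with an "Authorization: Bearer <token>" header.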

View file

@ -0,0 +1,77 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__)
def create_list_teams_channels_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def list_teams_channels() -> dict[str, Any]:
"""List all Microsoft Teams and their channels the user has access to.
Returns:
Dictionary with status and a list of teams, each containing
team_id, team_name, and a list of channels (id, name).
"""
if db_session is None or search_space_id is None or user_id is None:
return {"status": "error", "message": "Teams tool not properly configured."}
try:
connector = await get_teams_connector(db_session, search_space_id, user_id)
if not connector:
return {"status": "error", "message": "No Teams connector found."}
token = await get_access_token(db_session, connector)
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient(timeout=20.0) as client:
teams_resp = await client.get(f"{GRAPH_API}/me/joinedTeams", headers=headers)
if teams_resp.status_code == 401:
return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"}
if teams_resp.status_code != 200:
return {"status": "error", "message": f"Graph API error: {teams_resp.status_code}"}
teams_data = teams_resp.json().get("value", [])
result_teams = []
async with httpx.AsyncClient(timeout=20.0) as client:
for team in teams_data:
team_id = team["id"]
ch_resp = await client.get(
f"{GRAPH_API}/teams/{team_id}/channels",
headers=headers,
)
channels = []
if ch_resp.status_code == 200:
channels = [
{"id": ch["id"], "name": ch.get("displayName", "")}
for ch in ch_resp.json().get("value", [])
]
result_teams.append({
"team_id": team_id,
"team_name": team.get("displayName", ""),
"channels": channels,
})
return {"status": "success", "teams": result_teams, "total_teams": len(result_teams)}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error("Error listing Teams channels: %s", e, exc_info=True)
return {"status": "error", "message": "Failed to list Teams channels."}
return list_teams_channels

View file

@ -0,0 +1,91 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__)
def create_read_teams_messages_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def read_teams_messages(
team_id: str,
channel_id: str,
limit: int = 25,
) -> dict[str, Any]:
"""Read recent messages from a Microsoft Teams channel.
Args:
team_id: The team ID (from list_teams_channels).
channel_id: The channel ID (from list_teams_channels).
limit: Number of messages to fetch (default 25, max 50).
Returns:
Dictionary with status and a list of messages including
id, sender, content, timestamp.
"""
if db_session is None or search_space_id is None or user_id is None:
return {"status": "error", "message": "Teams tool not properly configured."}
limit = min(limit, 50)
try:
connector = await get_teams_connector(db_session, search_space_id, user_id)
if not connector:
return {"status": "error", "message": "No Teams connector found."}
token = await get_access_token(db_session, connector)
async with httpx.AsyncClient(timeout=20.0) as client:
resp = await client.get(
f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
headers={"Authorization": f"Bearer {token}"},
params={"$top": limit},
)
if resp.status_code == 401:
return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"}
if resp.status_code == 403:
return {"status": "error", "message": "Insufficient permissions to read this channel."}
if resp.status_code != 200:
return {"status": "error", "message": f"Graph API error: {resp.status_code}"}
raw_msgs = resp.json().get("value", [])
messages = []
for m in raw_msgs:
sender = m.get("from", {})
user_info = sender.get("user", {}) if sender else {}
body = m.get("body", {})
messages.append({
"id": m.get("id"),
"sender": user_info.get("displayName", "Unknown"),
"content": body.get("content", ""),
"content_type": body.get("contentType", "text"),
"timestamp": m.get("createdDateTime", ""),
})
return {
"status": "success",
"team_id": team_id,
"channel_id": channel_id,
"messages": messages,
"total": len(messages),
}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error("Error reading Teams messages: %s", e, exc_info=True)
return {"status": "error", "message": "Failed to read Teams messages."}
return read_teams_messages

View file

@ -0,0 +1,101 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__)
def create_send_teams_message_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def send_teams_message(
team_id: str,
channel_id: str,
content: str,
) -> dict[str, Any]:
"""Send a message to a Microsoft Teams channel.
Requires the ChannelMessage.Send OAuth scope. If the user gets a
permission error, they may need to re-authenticate with updated scopes.
Args:
team_id: The team ID (from list_teams_channels).
channel_id: The channel ID (from list_teams_channels).
content: The message text (HTML supported).
Returns:
Dictionary with status, message_id on success.
IMPORTANT:
- If status is "rejected", the user explicitly declined. Do NOT retry.
"""
if db_session is None or search_space_id is None or user_id is None:
return {"status": "error", "message": "Teams tool not properly configured."}
try:
connector = await get_teams_connector(db_session, search_space_id, user_id)
if not connector:
return {"status": "error", "message": "No Teams connector found."}
result = request_approval(
action_type="teams_send_message",
tool_name="send_teams_message",
params={"team_id": team_id, "channel_id": channel_id, "content": content},
context={"connector_id": connector.id},
)
if result.rejected:
return {"status": "rejected", "message": "User declined. Message was not sent."}
final_content = result.params.get("content", content)
final_team = result.params.get("team_id", team_id)
final_channel = result.params.get("channel_id", channel_id)
token = await get_access_token(db_session, connector)
async with httpx.AsyncClient(timeout=20.0) as client:
resp = await client.post(
f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
headers={
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
},
json={"body": {"content": final_content}},
)
if resp.status_code == 401:
return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"}
if resp.status_code == 403:
return {
"status": "insufficient_permissions",
"message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
}
if resp.status_code not in (200, 201):
return {"status": "error", "message": f"Graph API error: {resp.status_code}{resp.text[:200]}"}
msg_data = resp.json()
return {
"status": "success",
"message_id": msg_data.get("id"),
"message": f"Message sent to Teams channel.",
}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error("Error sending Teams message: %s", e, exc_info=True)
return {"status": "error", "message": "Failed to send Teams message."}
return send_teams_message

View file

@ -0,0 +1,41 @@
"""Standardised response dict factories for LangChain agent tools."""
from __future__ import annotations
from typing import Any
class ToolResponse:
@staticmethod
def success(message: str, **data: Any) -> dict[str, Any]:
return {"status": "success", "message": message, **data}
@staticmethod
def error(error: str, **data: Any) -> dict[str, Any]:
return {"status": "error", "error": error, **data}
@staticmethod
def auth_error(service: str, **data: Any) -> dict[str, Any]:
return {
"status": "auth_error",
"error": (
f"{service} authentication has expired or been revoked. "
"Please re-connect the integration in Settings → Connectors."
),
**data,
}
@staticmethod
def rejected(message: str = "Action was declined by the user.") -> dict[str, Any]:
return {"status": "rejected", "message": message}
@staticmethod
def not_found(
resource: str, identifier: str, **data: Any
) -> dict[str, Any]:
return {
"status": "not_found",
"error": f"{resource} '{identifier}' was not found.",
**data,
}
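# A brief illustration (hypothetical tool body) of the factories above; every branch
# returns the same {"status": ..., ...} shape the agent tools expect.
def fetch_page(page_id: str, page: dict | None) -> dict[str, Any]:
    if page is None:
        return ToolResponse.not_found("Page", page_id)
    return ToolResponse.success("Page fetched", page=page)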

View file

@ -141,6 +141,15 @@ def _http_exception_handler(request: Request, exc: HTTPException) -> JSONRespons
exc.status_code,
message,
)
elif exc.status_code >= 400:
_error_logger.warning(
"[%s] %s %s - HTTPException %d: %s",
rid,
request.method,
request.url.path,
exc.status_code,
message,
)
if should_sanitize:
message = GENERIC_5XX_MESSAGE
err_code = "INTERNAL_ERROR"
@ -170,6 +179,15 @@ def _http_exception_handler(request: Request, exc: HTTPException) -> JSONRespons
exc.status_code,
detail,
)
elif exc.status_code >= 400:
_error_logger.warning(
"[%s] %s %s - HTTPException %d: %s",
rid,
request.method,
request.url.path,
exc.status_code,
detail,
)
if should_sanitize:
detail = GENERIC_5XX_MESSAGE
code = _status_to_code(exc.status_code, detail)

View file

@ -136,20 +136,12 @@ celery_app.conf.update(
# never block fast user-facing tasks (file uploads, podcasts, etc.)
task_routes={
# Connector indexing tasks → connectors queue
"index_slack_messages": {"queue": CONNECTORS_QUEUE},
"index_notion_pages": {"queue": CONNECTORS_QUEUE},
"index_github_repos": {"queue": CONNECTORS_QUEUE},
"index_linear_issues": {"queue": CONNECTORS_QUEUE},
"index_jira_issues": {"queue": CONNECTORS_QUEUE},
"index_confluence_pages": {"queue": CONNECTORS_QUEUE},
"index_clickup_tasks": {"queue": CONNECTORS_QUEUE},
"index_google_calendar_events": {"queue": CONNECTORS_QUEUE},
"index_airtable_records": {"queue": CONNECTORS_QUEUE},
"index_google_gmail_messages": {"queue": CONNECTORS_QUEUE},
"index_google_drive_files": {"queue": CONNECTORS_QUEUE},
"index_discord_messages": {"queue": CONNECTORS_QUEUE},
"index_teams_messages": {"queue": CONNECTORS_QUEUE},
"index_luma_events": {"queue": CONNECTORS_QUEUE},
"index_elasticsearch_documents": {"queue": CONNECTORS_QUEUE},
"index_crawled_urls": {"queue": CONNECTORS_QUEUE},
"index_bookstack_pages": {"queue": CONNECTORS_QUEUE},

View file

@ -339,6 +339,9 @@ class Config:
# self-hosted: Full access to local file system connectors (Obsidian, etc.)
# cloud: Only cloud-based connectors available
DEPLOYMENT_MODE = os.getenv("SURFSENSE_DEPLOYMENT_MODE", "self-hosted")
ENABLE_DESKTOP_LOCAL_FILESYSTEM = (
os.getenv("ENABLE_DESKTOP_LOCAL_FILESYSTEM", "FALSE").upper() == "TRUE"
)
@classmethod
def is_self_hosted(cls) -> bool:

View file

@ -0,0 +1,98 @@
"""Standard exception hierarchy for all connectors.
ConnectorError
ConnectorAuthError (401/403; non-retryable)
ConnectorRateLimitError (429; retryable, carries ``retry_after``)
ConnectorTimeoutError (timeout/504; retryable)
ConnectorAPIError (5xx or unexpected; retryable when >= 500)
"""
from __future__ import annotations
from typing import Any
class ConnectorError(Exception):
def __init__(
self,
message: str,
*,
service: str = "",
status_code: int | None = None,
response_body: Any = None,
) -> None:
super().__init__(message)
self.service = service
self.status_code = status_code
self.response_body = response_body
@property
def retryable(self) -> bool:
return False
class ConnectorAuthError(ConnectorError):
"""Token expired, revoked, insufficient scopes, or needs re-auth (401/403)."""
@property
def retryable(self) -> bool:
return False
class ConnectorRateLimitError(ConnectorError):
"""429 Too Many Requests."""
def __init__(
self,
message: str = "Rate limited",
*,
service: str = "",
retry_after: float | None = None,
status_code: int = 429,
response_body: Any = None,
) -> None:
super().__init__(
message,
service=service,
status_code=status_code,
response_body=response_body,
)
self.retry_after = retry_after
@property
def retryable(self) -> bool:
return True
class ConnectorTimeoutError(ConnectorError):
"""Request timeout or gateway timeout (504)."""
def __init__(
self,
message: str = "Request timed out",
*,
service: str = "",
status_code: int | None = None,
response_body: Any = None,
) -> None:
super().__init__(
message,
service=service,
status_code=status_code,
response_body=response_body,
)
@property
def retryable(self) -> bool:
return True
class ConnectorAPIError(ConnectorError):
"""Generic API error (5xx or unexpected status codes)."""
@property
def retryable(self) -> bool:
if self.status_code is not None:
return self.status_code >= 500
return False
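# A minimal retry sketch built on the hierarchy above (fetch is any coroutine factory
# supplied by the caller; the attempt count and backoff are illustrative).
import asyncio


async def call_with_retry(fetch, attempts: int = 3):
    for attempt in range(attempts):
        try:
            return await fetch()
        except ConnectorError as exc:
            if not exc.retryable or attempt == attempts - 1:
                raise
            delay = getattr(exc, "retry_after", None) or 2 ** attempt
            await asyncio.sleep(delay)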

View file

@ -30,6 +30,7 @@ from .jira_add_connector_route import router as jira_add_connector_router
from .linear_add_connector_route import router as linear_add_connector_router
from .logs_routes import router as logs_router
from .luma_add_connector_route import router as luma_add_connector_router
from .mcp_oauth_route import router as mcp_oauth_router
from .memory_routes import router as memory_router
from .model_list_routes import router as model_list_router
from .new_chat_routes import router as new_chat_router
@ -97,6 +98,7 @@ router.include_router(logs_router)
router.include_router(circleback_webhook_router) # Circleback meeting webhooks
router.include_router(surfsense_docs_router) # Surfsense documentation for citations
router.include_router(notifications_router) # Notifications with Zero sync
router.include_router(mcp_oauth_router) # MCP OAuth 2.1 for Linear, Jira, ClickUp, Slack, Airtable
router.include_router(composio_router) # Composio OAuth and toolkit management
router.include_router(public_chat_router) # Public chat sharing and cloning
router.include_router(incentive_tasks_router) # Incentive tasks for earning free pages

View file

@ -311,7 +311,7 @@ async def airtable_callback(
new_connector = SearchSourceConnector(
name=connector_name,
connector_type=SearchSourceConnectorType.AIRTABLE_CONNECTOR,
is_indexable=True,
is_indexable=False,
config=credentials_dict,
search_space_id=space_id,
user_id=user_id,

View file

@ -301,7 +301,7 @@ async def clickup_callback(
# Update existing connector
existing_connector.config = connector_config
existing_connector.name = "ClickUp Connector"
existing_connector.is_indexable = True
existing_connector.is_indexable = False
logger.info(
f"Updated existing ClickUp connector for user {user_id} in space {space_id}"
)
@ -310,7 +310,7 @@ async def clickup_callback(
new_connector = SearchSourceConnector(
name="ClickUp Connector",
connector_type=SearchSourceConnectorType.CLICKUP_CONNECTOR,
is_indexable=True,
is_indexable=False,
config=connector_config,
search_space_id=space_id,
user_id=user_id,

View file

@ -326,7 +326,7 @@ async def discord_callback(
new_connector = SearchSourceConnector(
name=connector_name,
connector_type=SearchSourceConnectorType.DISCORD_CONNECTOR,
is_indexable=True,
is_indexable=False,
config=connector_config,
search_space_id=space_id,
user_id=user_id,

View file

@ -340,7 +340,7 @@ async def calendar_callback(
config=creds_dict,
search_space_id=space_id,
user_id=user_id,
is_indexable=True,
is_indexable=False,
)
session.add(db_connector)
await session.commit()

View file

@ -371,7 +371,7 @@ async def gmail_callback(
config=creds_dict,
search_space_id=space_id,
user_id=user_id,
is_indexable=True,
is_indexable=False,
)
session.add(db_connector)
await session.commit()

View file

@ -386,7 +386,7 @@ async def jira_callback(
new_connector = SearchSourceConnector(
name=connector_name,
connector_type=SearchSourceConnectorType.JIRA_CONNECTOR,
is_indexable=True,
is_indexable=False,
config=connector_config,
search_space_id=space_id,
user_id=user_id,

View file

@ -399,7 +399,7 @@ async def linear_callback(
new_connector = SearchSourceConnector(
name=connector_name,
connector_type=SearchSourceConnectorType.LINEAR_CONNECTOR,
is_indexable=True,
is_indexable=False,
config=connector_config,
search_space_id=space_id,
user_id=user_id,

View file

@ -61,7 +61,7 @@ async def add_luma_connector(
if existing_connector:
# Update existing connector with new API key
existing_connector.config = {"api_key": request.api_key}
existing_connector.is_indexable = True
existing_connector.is_indexable = False
await session.commit()
await session.refresh(existing_connector)
@ -82,7 +82,7 @@ async def add_luma_connector(
config={"api_key": request.api_key},
search_space_id=request.space_id,
user_id=user.id,
is_indexable=True,
is_indexable=False,
)
session.add(db_connector)

View file

@ -0,0 +1,601 @@
"""Generic MCP OAuth 2.1 route for services with official MCP servers.
Handles the full flow: discovery → DCR → PKCE authorization → token exchange →
MCP_CONNECTOR creation. Currently supports Linear, Jira, ClickUp, Slack,
and Airtable.
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
from urllib.parse import urlencode
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import RedirectResponse
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user
from app.utils.connector_naming import generate_unique_connector_name
from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_pkce_pair
logger = logging.getLogger(__name__)
router = APIRouter()
async def _fetch_account_metadata(
service_key: str, access_token: str, token_json: dict[str, Any],
) -> dict[str, Any]:
"""Fetch display-friendly account metadata after a successful token exchange.
DCR services (Linear, Jira, ClickUp) issue MCP-scoped tokens that cannot
call their standard REST/GraphQL APIs; metadata discovery for those
happens at runtime through MCP tools instead.
Pre-configured services (Slack, Airtable) use standard OAuth tokens that
*can* call their APIs, so we extract metadata here.
Failures are logged but never block connector creation.
"""
from app.services.mcp_oauth.registry import MCP_SERVICES
svc = MCP_SERVICES.get(service_key)
if not svc or svc.supports_dcr:
return {}
import httpx
meta: dict[str, Any] = {}
try:
if service_key == "slack":
team_info = token_json.get("team", {})
meta["team_id"] = team_info.get("id", "")
# TODO: oauth.v2.user.access only returns team.id, not
# team.name. To populate team_name, add "team:read" scope
# and call GET /api/team.info here.
meta["team_name"] = team_info.get("name", "")
if meta["team_name"]:
meta["display_name"] = meta["team_name"]
elif meta["team_id"]:
meta["display_name"] = f"Slack ({meta['team_id']})"
elif service_key == "airtable":
async with httpx.AsyncClient(timeout=15.0) as client:
resp = await client.get(
"https://api.airtable.com/v0/meta/whoami",
headers={"Authorization": f"Bearer {access_token}"},
)
if resp.status_code == 200:
whoami = resp.json()
meta["user_id"] = whoami.get("id", "")
meta["user_email"] = whoami.get("email", "")
meta["display_name"] = whoami.get("email", "Airtable")
else:
logger.warning(
"Airtable whoami API returned %d (non-blocking)", resp.status_code,
)
except Exception:
logger.warning(
"Failed to fetch account metadata for %s (non-blocking)",
service_key,
exc_info=True,
)
return meta
_state_manager: OAuthStateManager | None = None
_token_encryption: TokenEncryption | None = None
def _get_state_manager() -> OAuthStateManager:
global _state_manager
if _state_manager is None:
if not config.SECRET_KEY:
raise HTTPException(status_code=500, detail="SECRET_KEY not configured.")
_state_manager = OAuthStateManager(config.SECRET_KEY)
return _state_manager
def _get_token_encryption() -> TokenEncryption:
global _token_encryption
if _token_encryption is None:
if not config.SECRET_KEY:
raise HTTPException(status_code=500, detail="SECRET_KEY not configured.")
_token_encryption = TokenEncryption(config.SECRET_KEY)
return _token_encryption
def _build_redirect_uri(service: str) -> str:
base = config.BACKEND_URL or "http://localhost:8000"
return f"{base.rstrip('/')}/api/v1/auth/mcp/{service}/connector/callback"
def _frontend_redirect(
space_id: int | None,
*,
success: bool = False,
connector_id: int | None = None,
error: str | None = None,
service: str = "mcp",
) -> RedirectResponse:
if success and space_id:
qs = f"success=true&connector={service}-mcp-connector"
if connector_id:
qs += f"&connectorId={connector_id}"
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?{qs}"
)
if error and space_id:
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error={error}"
)
return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard")
# ---------------------------------------------------------------------------
# /add — start MCP OAuth flow
# ---------------------------------------------------------------------------
@router.get("/auth/mcp/{service}/connector/add")
async def connect_mcp_service(
service: str,
space_id: int,
user: User = Depends(current_active_user),
):
from app.services.mcp_oauth.registry import get_service
svc = get_service(service)
if not svc:
raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}")
try:
from app.services.mcp_oauth.discovery import (
discover_oauth_metadata,
register_client,
)
metadata = await discover_oauth_metadata(
svc.mcp_url, origin_override=svc.oauth_discovery_origin,
)
auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint")
token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint")
registration_endpoint = metadata.get("registration_endpoint")
if not auth_endpoint or not token_endpoint:
raise HTTPException(
status_code=502,
detail=f"{svc.name} MCP server returned incomplete OAuth metadata.",
)
redirect_uri = _build_redirect_uri(service)
if svc.supports_dcr and registration_endpoint:
dcr = await register_client(registration_endpoint, redirect_uri)
client_id = dcr.get("client_id")
client_secret = dcr.get("client_secret", "")
if not client_id:
raise HTTPException(
status_code=502,
detail=f"DCR for {svc.name} did not return a client_id.",
)
elif svc.client_id_env:
client_id = getattr(config, svc.client_id_env, None)
client_secret = getattr(config, svc.client_secret_env or "", None) or ""
if not client_id:
raise HTTPException(
status_code=500,
detail=f"{svc.name} MCP OAuth not configured ({svc.client_id_env}).",
)
else:
raise HTTPException(
status_code=502,
detail=f"{svc.name} MCP server has no DCR and no fallback credentials.",
)
verifier, challenge = generate_pkce_pair()
enc = _get_token_encryption()
state = _get_state_manager().generate_secure_state(
space_id,
user.id,
service=service,
code_verifier=verifier,
mcp_client_id=client_id,
mcp_client_secret=enc.encrypt_token(client_secret) if client_secret else "",
mcp_token_endpoint=token_endpoint,
mcp_url=svc.mcp_url,
)
auth_params: dict[str, str] = {
"client_id": client_id,
"response_type": "code",
"redirect_uri": redirect_uri,
"code_challenge": challenge,
"code_challenge_method": "S256",
"state": state,
}
if svc.scopes:
auth_params[svc.scope_param] = " ".join(svc.scopes)
auth_url = f"{auth_endpoint}?{urlencode(auth_params)}"
logger.info(
"Generated %s MCP OAuth URL for user %s, space %s",
svc.name, user.id, space_id,
)
return {"auth_url": auth_url}
except HTTPException:
raise
except Exception as e:
logger.error("Failed to initiate %s MCP OAuth: %s", service, e, exc_info=True)
raise HTTPException(
status_code=500, detail=f"Failed to initiate {service} MCP OAuth.",
) from e
# ---------------------------------------------------------------------------
# /callback — handle OAuth redirect
# ---------------------------------------------------------------------------
@router.get("/auth/mcp/{service}/connector/callback")
async def mcp_oauth_callback(
service: str,
code: str | None = None,
error: str | None = None,
state: str | None = None,
session: AsyncSession = Depends(get_async_session),
):
if error:
logger.warning("%s MCP OAuth error: %s", service, error)
space_id = None
if state:
try:
data = _get_state_manager().validate_state(state)
space_id = data.get("space_id")
except Exception:
pass
return _frontend_redirect(
space_id, error=f"{service}_mcp_oauth_denied", service=service,
)
if not code:
raise HTTPException(status_code=400, detail="Missing authorization code")
if not state:
raise HTTPException(status_code=400, detail="Missing state parameter")
data = _get_state_manager().validate_state(state)
user_id = UUID(data["user_id"])
space_id = data["space_id"]
svc_key = data.get("service", service)
if svc_key != service:
raise HTTPException(status_code=400, detail="State/path service mismatch")
from app.services.mcp_oauth.registry import get_service
svc = get_service(svc_key)
if not svc:
raise HTTPException(status_code=404, detail=f"Unknown MCP service: {svc_key}")
try:
from app.services.mcp_oauth.discovery import exchange_code_for_tokens
enc = _get_token_encryption()
client_id = data["mcp_client_id"]
client_secret = (
enc.decrypt_token(data["mcp_client_secret"])
if data.get("mcp_client_secret")
else ""
)
token_endpoint = data["mcp_token_endpoint"]
code_verifier = data["code_verifier"]
mcp_url = data["mcp_url"]
redirect_uri = _build_redirect_uri(service)
token_json = await exchange_code_for_tokens(
token_endpoint=token_endpoint,
code=code,
redirect_uri=redirect_uri,
client_id=client_id,
client_secret=client_secret,
code_verifier=code_verifier,
)
access_token = token_json.get("access_token")
refresh_token = token_json.get("refresh_token")
expires_in = token_json.get("expires_in")
scope = token_json.get("scope")
if not access_token and "authed_user" in token_json:
authed = token_json["authed_user"]
access_token = authed.get("access_token")
refresh_token = refresh_token or authed.get("refresh_token")
scope = scope or authed.get("scope")
expires_in = expires_in or authed.get("expires_in")
if not access_token:
raise HTTPException(
status_code=400,
detail=f"No access token received from {svc.name}.",
)
expires_at = None
if expires_in:
expires_at = datetime.now(UTC) + timedelta(
seconds=int(expires_in)
)
connector_config = {
"server_config": {
"transport": "streamable-http",
"url": mcp_url,
},
"mcp_service": svc_key,
"mcp_oauth": {
"client_id": client_id,
"client_secret": enc.encrypt_token(client_secret) if client_secret else "",
"token_endpoint": token_endpoint,
"access_token": enc.encrypt_token(access_token),
"refresh_token": enc.encrypt_token(refresh_token) if refresh_token else None,
"expires_at": expires_at.isoformat() if expires_at else None,
"scope": scope,
},
"_token_encrypted": True,
}
account_meta = await _fetch_account_metadata(svc_key, access_token, token_json)
if account_meta:
_SAFE_META_KEYS = {"display_name", "team_id", "team_name", "user_id", "user_email",
"workspace_id", "workspace_name", "organization_name",
"organization_url_key", "cloud_id", "site_name", "base_url"}
for k, v in account_meta.items():
if k in _SAFE_META_KEYS:
connector_config[k] = v
logger.info(
"Stored account metadata for %s: display_name=%s",
svc_key, account_meta.get("display_name", ""),
)
# ---- Re-auth path ----
db_connector_type = SearchSourceConnectorType(svc.connector_type)
reauth_connector_id = data.get("connector_id")
if reauth_connector_id:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == reauth_connector_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type == db_connector_type,
)
)
db_connector = result.scalars().first()
if not db_connector:
raise HTTPException(
status_code=404,
detail="Connector not found during re-auth",
)
db_connector.config = connector_config
flag_modified(db_connector, "config")
await session.commit()
await session.refresh(db_connector)
_invalidate_cache(space_id)
logger.info(
"Re-authenticated %s MCP connector %s for user %s",
svc.name, db_connector.id, user_id,
)
reauth_return_url = data.get("return_url")
if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"):
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
)
return _frontend_redirect(
space_id, success=True, connector_id=db_connector.id, service=service,
)
# ---- New connector path ----
naming_identifier = account_meta.get("display_name")
connector_name = await generate_unique_connector_name(
session,
db_connector_type,
space_id,
user_id,
naming_identifier,
)
new_connector = SearchSourceConnector(
name=connector_name,
connector_type=db_connector_type,
is_indexable=False,
config=connector_config,
search_space_id=space_id,
user_id=user_id,
)
session.add(new_connector)
try:
await session.commit()
except IntegrityError as e:
await session.rollback()
raise HTTPException(
status_code=409, detail="A connector for this service already exists.",
) from e
_invalidate_cache(space_id)
logger.info(
"Created %s MCP connector %s for user %s in space %s",
svc.name, new_connector.id, user_id, space_id,
)
return _frontend_redirect(
space_id, success=True, connector_id=new_connector.id, service=service,
)
except HTTPException:
raise
except Exception as e:
logger.error(
"Failed to complete %s MCP OAuth: %s", service, e, exc_info=True,
)
raise HTTPException(
status_code=500,
detail=f"Failed to complete {service} MCP OAuth.",
) from e
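# For context, ``exchange_code_for_tokens`` amounts to a standard RFC 6749
# authorization-code grant with the PKCE verifier attached. The function below
# is a minimal illustrative sketch, not the discovery module's implementation.
async def _sketch_code_exchange(
    token_endpoint: str,
    code: str,
    redirect_uri: str,
    client_id: str,
    client_secret: str,
    code_verifier: str,
) -> dict:
    import httpx

    data = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
        "client_id": client_id,
        "code_verifier": code_verifier,
    }
    if client_secret:
        data["client_secret"] = client_secret
    async with httpx.AsyncClient() as client:
        resp = await client.post(token_endpoint, data=data, timeout=30.0)
        resp.raise_for_status()
    return resp.json()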
# ---------------------------------------------------------------------------
# /reauth — re-authenticate an existing MCP connector
# ---------------------------------------------------------------------------
@router.get("/auth/mcp/{service}/connector/reauth")
async def reauth_mcp_service(
service: str,
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
):
from app.services.mcp_oauth.registry import get_service
svc = get_service(service)
if not svc:
raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}")
db_connector_type = SearchSourceConnectorType(svc.connector_type)
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.user_id == user.id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type == db_connector_type,
)
)
if not result.scalars().first():
raise HTTPException(
status_code=404, detail="Connector not found or access denied",
)
try:
from app.services.mcp_oauth.discovery import (
discover_oauth_metadata,
register_client,
)
metadata = await discover_oauth_metadata(
svc.mcp_url, origin_override=svc.oauth_discovery_origin,
)
auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint")
token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint")
registration_endpoint = metadata.get("registration_endpoint")
if not auth_endpoint or not token_endpoint:
raise HTTPException(
status_code=502,
detail=f"{svc.name} MCP server returned incomplete OAuth metadata.",
)
redirect_uri = _build_redirect_uri(service)
if svc.supports_dcr and registration_endpoint:
dcr = await register_client(registration_endpoint, redirect_uri)
client_id = dcr.get("client_id")
client_secret = dcr.get("client_secret", "")
if not client_id:
raise HTTPException(
status_code=502,
detail=f"DCR for {svc.name} did not return a client_id.",
)
elif svc.client_id_env:
client_id = getattr(config, svc.client_id_env, None)
client_secret = getattr(config, svc.client_secret_env or "", None) or ""
if not client_id:
raise HTTPException(
status_code=500,
detail=f"{svc.name} MCP OAuth not configured ({svc.client_id_env}).",
)
else:
raise HTTPException(
status_code=502,
detail=f"{svc.name} MCP server has no DCR and no fallback credentials.",
)
verifier, challenge = generate_pkce_pair()
enc = _get_token_encryption()
extra: dict = {
"service": service,
"code_verifier": verifier,
"mcp_client_id": client_id,
"mcp_client_secret": enc.encrypt_token(client_secret) if client_secret else "",
"mcp_token_endpoint": token_endpoint,
"mcp_url": svc.mcp_url,
"connector_id": connector_id,
}
if return_url and return_url.startswith("/"):
extra["return_url"] = return_url
state = _get_state_manager().generate_secure_state(
space_id, user.id, **extra,
)
auth_params: dict[str, str] = {
"client_id": client_id,
"response_type": "code",
"redirect_uri": redirect_uri,
"code_challenge": challenge,
"code_challenge_method": "S256",
"state": state,
}
if svc.scopes:
auth_params[svc.scope_param] = " ".join(svc.scopes)
auth_url = f"{auth_endpoint}?{urlencode(auth_params)}"
logger.info(
"Initiating %s MCP re-auth for user %s, connector %s",
svc.name, user.id, connector_id,
)
return {"auth_url": auth_url}
except HTTPException:
raise
except Exception as e:
logger.error(
"Failed to initiate %s MCP re-auth: %s", service, e, exc_info=True,
)
raise HTTPException(
status_code=500,
detail=f"Failed to initiate {service} MCP re-auth.",
) from e
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _invalidate_cache(space_id: int) -> None:
try:
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
invalidate_mcp_tools_cache(space_id)
except Exception:
logger.debug("MCP cache invalidation skipped", exc_info=True)

View file

@ -22,6 +22,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.agents.new_chat.filesystem_selection import (
ClientPlatform,
LocalFilesystemMount,
FilesystemMode,
FilesystemSelection,
)
from app.config import config
from app.db import (
ChatComment,
ChatVisibility,
@ -36,6 +43,7 @@ from app.db import (
)
from app.schemas.new_chat import (
AgentToolInfo,
LocalFilesystemMountPayload,
NewChatMessageRead,
NewChatRequest,
NewChatThreadCreate,
@ -63,6 +71,67 @@ _background_tasks: set[asyncio.Task] = set()
router = APIRouter()
def _resolve_filesystem_selection(
*,
mode: str,
client_platform: str,
local_mounts: list[LocalFilesystemMountPayload] | None,
) -> FilesystemSelection:
"""Validate and normalize filesystem mode settings from request payload."""
try:
resolved_mode = FilesystemMode(mode)
except ValueError as exc:
raise HTTPException(status_code=400, detail="Invalid filesystem_mode") from exc
try:
resolved_platform = ClientPlatform(client_platform)
except ValueError as exc:
raise HTTPException(status_code=400, detail="Invalid client_platform") from exc
if resolved_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
if not config.ENABLE_DESKTOP_LOCAL_FILESYSTEM:
raise HTTPException(
status_code=400,
detail="Desktop local filesystem mode is disabled on this deployment.",
)
if resolved_platform != ClientPlatform.DESKTOP:
raise HTTPException(
status_code=400,
detail="desktop_local_folder mode is only available on desktop runtime.",
)
normalized_mounts: list[tuple[str, str]] = []
seen_mounts: set[str] = set()
for mount in local_mounts or []:
mount_id = mount.mount_id.strip()
root_path = mount.root_path.strip()
if not mount_id or not root_path:
continue
if mount_id in seen_mounts:
continue
seen_mounts.add(mount_id)
normalized_mounts.append((mount_id, root_path))
if not normalized_mounts:
raise HTTPException(
status_code=400,
detail=(
"local_filesystem_mounts must include at least one mount for "
"desktop_local_folder mode."
),
)
return FilesystemSelection(
mode=resolved_mode,
client_platform=resolved_platform,
local_mounts=tuple(
LocalFilesystemMount(mount_id=mount_id, root_path=root_path)
for mount_id, root_path in normalized_mounts
),
)
return FilesystemSelection(
mode=FilesystemMode.CLOUD,
client_platform=resolved_platform,
)
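# A minimal usage sketch of the resolver above. The mount id and root path are
# hypothetical, and desktop mode only resolves when
# config.ENABLE_DESKTOP_LOCAL_FILESYSTEM is enabled.
def _example_desktop_selection() -> FilesystemSelection:
    """Illustrative only; not called by the routes below."""
    return _resolve_filesystem_selection(
        mode="desktop_local_folder",
        client_platform="desktop",
        local_mounts=[
            LocalFilesystemMountPayload(
                mount_id="notes", root_path="/Users/me/Documents/Notes"
            )
        ],
    )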
def _try_delete_sandbox(thread_id: int) -> None:
"""Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
from app.agents.new_chat.sandbox import (
@ -1098,6 +1167,7 @@ async def list_agent_tools(
@router.post("/new_chat")
async def handle_new_chat(
request: NewChatRequest,
http_request: Request,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
@ -1133,6 +1203,11 @@ async def handle_new_chat(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,
local_mounts=request.local_filesystem_mounts,
)
# Get search space to check LLM config preferences
search_space_result = await session.execute(
@ -1175,6 +1250,8 @@ async def handle_new_chat(
thread_visibility=thread.visibility,
current_user_display_name=user.display_name or "A team member",
disabled_tools=request.disabled_tools,
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
),
media_type="text/event-stream",
headers={
@ -1202,6 +1279,7 @@ async def handle_new_chat(
async def regenerate_response(
thread_id: int,
request: RegenerateRequest,
http_request: Request,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
@ -1247,6 +1325,11 @@ async def regenerate_response(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,
local_mounts=request.local_filesystem_mounts,
)
# Get the checkpointer and state history
checkpointer = await get_checkpointer()
@ -1412,6 +1495,8 @@ async def regenerate_response(
thread_visibility=thread.visibility,
current_user_display_name=user.display_name or "A team member",
disabled_tools=request.disabled_tools,
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
):
yield chunk
streaming_completed = True
@ -1477,6 +1562,7 @@ async def regenerate_response(
async def resume_chat(
thread_id: int,
request: ResumeRequest,
http_request: Request,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
@ -1498,6 +1584,11 @@ async def resume_chat(
)
await check_thread_access(session, thread, user)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,
local_mounts=request.local_filesystem_mounts,
)
search_space_result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
@ -1526,6 +1617,8 @@ async def resume_chat(
user_id=str(user.id),
llm_config_id=llm_config_id,
thread_visibility=thread.visibility,
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
),
media_type="text/event-stream",
headers={

View file

@ -0,0 +1,620 @@
"""Reusable base for OAuth 2.0 connector routes.
Subclasses override ``fetch_account_info``, ``build_connector_config``,
and ``get_connector_display_name`` to customise provider-specific behaviour.
Call ``build_router()`` to get a FastAPI ``APIRouter`` with ``/connector/add``,
``/connector/callback``, and ``/connector/reauth`` endpoints.
"""
from __future__ import annotations
import base64
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
from urllib.parse import urlencode
from uuid import UUID
import httpx
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import RedirectResponse
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user
from app.utils.connector_naming import (
check_duplicate_connector,
generate_unique_connector_name,
)
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
logger = logging.getLogger(__name__)
class OAuthConnectorRoute:
def __init__(
self,
*,
provider_name: str,
connector_type: SearchSourceConnectorType,
authorize_url: str,
token_url: str,
client_id_env: str,
client_secret_env: str,
redirect_uri_env: str,
scopes: list[str],
auth_prefix: str,
use_pkce: bool = False,
token_auth_method: str = "body",
is_indexable: bool = True,
extra_auth_params: dict[str, str] | None = None,
) -> None:
self.provider_name = provider_name
self.connector_type = connector_type
self.authorize_url = authorize_url
self.token_url = token_url
self.client_id_env = client_id_env
self.client_secret_env = client_secret_env
self.redirect_uri_env = redirect_uri_env
self.scopes = scopes
self.auth_prefix = auth_prefix.rstrip("/")
self.use_pkce = use_pkce
self.token_auth_method = token_auth_method
self.is_indexable = is_indexable
self.extra_auth_params = extra_auth_params or {}
self._state_manager: OAuthStateManager | None = None
self._token_encryption: TokenEncryption | None = None
def _get_client_id(self) -> str:
value = getattr(config, self.client_id_env, None)
if not value:
raise HTTPException(
status_code=500,
detail=f"{self.provider_name.title()} OAuth not configured "
f"({self.client_id_env} missing).",
)
return value
def _get_client_secret(self) -> str:
value = getattr(config, self.client_secret_env, None)
if not value:
raise HTTPException(
status_code=500,
detail=f"{self.provider_name.title()} OAuth not configured "
f"({self.client_secret_env} missing).",
)
return value
def _get_redirect_uri(self) -> str:
value = getattr(config, self.redirect_uri_env, None)
if not value:
raise HTTPException(
status_code=500,
detail=f"{self.redirect_uri_env} not configured.",
)
return value
def _get_state_manager(self) -> OAuthStateManager:
if self._state_manager is None:
if not config.SECRET_KEY:
raise HTTPException(
status_code=500,
detail="SECRET_KEY not configured for OAuth security.",
)
self._state_manager = OAuthStateManager(config.SECRET_KEY)
return self._state_manager
def _get_token_encryption(self) -> TokenEncryption:
if self._token_encryption is None:
if not config.SECRET_KEY:
raise HTTPException(
status_code=500,
detail="SECRET_KEY not configured for token encryption.",
)
self._token_encryption = TokenEncryption(config.SECRET_KEY)
return self._token_encryption
def _frontend_redirect(
self,
space_id: int | None,
*,
success: bool = False,
connector_id: int | None = None,
error: str | None = None,
) -> RedirectResponse:
if success and space_id:
connector_slug = f"{self.provider_name}-connector"
qs = f"success=true&connector={connector_slug}"
if connector_id:
qs += f"&connectorId={connector_id}"
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?{qs}"
)
if error and space_id:
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error={error}"
)
if error:
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard?error={error}"
)
return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard")
async def fetch_account_info(self, access_token: str) -> dict[str, Any]:
"""Override to fetch account/workspace info after token exchange.
The returned dict is merged into the connector config; the ``"name"`` key
is used for the display name and duplicate detection.
"""
return {}
def build_connector_config(
self,
token_json: dict[str, Any],
account_info: dict[str, Any],
encryption: TokenEncryption,
) -> dict[str, Any]:
"""Override for custom config shapes. Default: standard encrypted OAuth fields."""
access_token = token_json.get("access_token", "")
refresh_token = token_json.get("refresh_token")
expires_at = None
if token_json.get("expires_in"):
expires_at = datetime.now(UTC) + timedelta(
seconds=int(token_json["expires_in"])
)
cfg: dict[str, Any] = {
"access_token": encryption.encrypt_token(access_token),
"refresh_token": (
encryption.encrypt_token(refresh_token) if refresh_token else None
),
"token_type": token_json.get("token_type", "Bearer"),
"expires_in": token_json.get("expires_in"),
"expires_at": expires_at.isoformat() if expires_at else None,
"scope": token_json.get("scope"),
"_token_encrypted": True,
}
cfg.update(account_info)
return cfg
def get_connector_display_name(self, account_info: dict[str, Any]) -> str:
return str(account_info.get("name", self.provider_name.title()))
async def on_token_refresh_failure(
self,
session: AsyncSession,
connector: SearchSourceConnector,
) -> None:
try:
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await session.commit()
await session.refresh(connector)
except Exception:
logger.warning(
"Failed to persist auth_expired flag for connector %s",
connector.id,
exc_info=True,
)
async def _exchange_code(
self, code: str, extra_state: dict[str, Any]
) -> dict[str, Any]:
client_id = self._get_client_id()
client_secret = self._get_client_secret()
redirect_uri = self._get_redirect_uri()
headers: dict[str, str] = {
"Content-Type": "application/x-www-form-urlencoded",
}
body: dict[str, str] = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": redirect_uri,
}
if self.token_auth_method == "basic":
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
headers["Authorization"] = f"Basic {creds}"
else:
body["client_id"] = client_id
body["client_secret"] = client_secret
if self.use_pkce:
verifier = extra_state.get("code_verifier")
if verifier:
body["code_verifier"] = verifier
async with httpx.AsyncClient() as client:
resp = await client.post(
self.token_url, data=body, headers=headers, timeout=30.0
)
if resp.status_code != 200:
detail = resp.text
try:
detail = resp.json().get("error_description", detail)
except Exception:
pass
raise HTTPException(
status_code=400, detail=f"Token exchange failed: {detail}"
)
return resp.json()
async def refresh_token(
self, session: AsyncSession, connector: SearchSourceConnector
) -> SearchSourceConnector:
encryption = self._get_token_encryption()
is_encrypted = connector.config.get("_token_encrypted", False)
refresh_tok = connector.config.get("refresh_token")
if is_encrypted and refresh_tok:
try:
refresh_tok = encryption.decrypt_token(refresh_tok)
except Exception as e:
logger.error("Failed to decrypt refresh token: %s", e)
raise HTTPException(
status_code=500, detail="Failed to decrypt stored refresh token"
) from e
if not refresh_tok:
await self.on_token_refresh_failure(session, connector)
raise HTTPException(
status_code=400,
detail="No refresh token available. Please re-authenticate.",
)
client_id = self._get_client_id()
client_secret = self._get_client_secret()
headers: dict[str, str] = {
"Content-Type": "application/x-www-form-urlencoded",
}
body: dict[str, str] = {
"grant_type": "refresh_token",
"refresh_token": refresh_tok,
}
if self.token_auth_method == "basic":
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
headers["Authorization"] = f"Basic {creds}"
else:
body["client_id"] = client_id
body["client_secret"] = client_secret
async with httpx.AsyncClient() as client:
resp = await client.post(
self.token_url, data=body, headers=headers, timeout=30.0
)
if resp.status_code != 200:
error_detail = resp.text
try:
ej = resp.json()
error_detail = ej.get("error_description", error_detail)
error_code = ej.get("error", "")
except Exception:
error_code = ""
combined = (error_detail + error_code).lower()
if any(kw in combined for kw in ("invalid_grant", "expired", "revoked")):
await self.on_token_refresh_failure(session, connector)
raise HTTPException(
status_code=401,
detail=f"{self.provider_name.title()} authentication failed. "
"Please re-authenticate.",
)
raise HTTPException(
status_code=400, detail=f"Token refresh failed: {error_detail}"
)
token_json = resp.json()
new_access = token_json.get("access_token")
if not new_access:
raise HTTPException(
status_code=400, detail="No access token received from refresh"
)
expires_at = None
if token_json.get("expires_in"):
expires_at = datetime.now(UTC) + timedelta(
seconds=int(token_json["expires_in"])
)
updated_config = dict(connector.config)
updated_config["access_token"] = encryption.encrypt_token(new_access)
new_refresh = token_json.get("refresh_token")
if new_refresh:
updated_config["refresh_token"] = encryption.encrypt_token(new_refresh)
updated_config["expires_in"] = token_json.get("expires_in")
updated_config["expires_at"] = expires_at.isoformat() if expires_at else None
updated_config["scope"] = token_json.get("scope", updated_config.get("scope"))
updated_config["_token_encrypted"] = True
updated_config.pop("auth_expired", None)
connector.config = updated_config
flag_modified(connector, "config")
await session.commit()
await session.refresh(connector)
logger.info(
"Refreshed %s token for connector %s",
self.provider_name,
connector.id,
)
return connector
def build_router(self) -> APIRouter:
router = APIRouter()
oauth = self
@router.get(f"{oauth.auth_prefix}/connector/add")
async def connect(
space_id: int,
user: User = Depends(current_active_user),
):
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
client_id = oauth._get_client_id()
state_mgr = oauth._get_state_manager()
extra_state: dict[str, Any] = {}
auth_params: dict[str, str] = {
"client_id": client_id,
"response_type": "code",
"redirect_uri": oauth._get_redirect_uri(),
"scope": " ".join(oauth.scopes),
}
if oauth.use_pkce:
from app.utils.oauth_security import generate_pkce_pair
verifier, challenge = generate_pkce_pair()
extra_state["code_verifier"] = verifier
auth_params["code_challenge"] = challenge
auth_params["code_challenge_method"] = "S256"
auth_params.update(oauth.extra_auth_params)
state_encoded = state_mgr.generate_secure_state(
space_id, user.id, **extra_state
)
auth_params["state"] = state_encoded
auth_url = f"{oauth.authorize_url}?{urlencode(auth_params)}"
logger.info(
"Generated %s OAuth URL for user %s, space %s",
oauth.provider_name,
user.id,
space_id,
)
return {"auth_url": auth_url}
@router.get(f"{oauth.auth_prefix}/connector/reauth")
async def reauth(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
):
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.user_id == user.id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type == oauth.connector_type,
)
)
if not result.scalars().first():
raise HTTPException(
status_code=404,
detail=f"{oauth.provider_name.title()} connector not found "
"or access denied",
)
client_id = oauth._get_client_id()
state_mgr = oauth._get_state_manager()
extra: dict[str, Any] = {"connector_id": connector_id}
if return_url and return_url.startswith("/") and not return_url.startswith("//"):
extra["return_url"] = return_url
auth_params: dict[str, str] = {
"client_id": client_id,
"response_type": "code",
"redirect_uri": oauth._get_redirect_uri(),
"scope": " ".join(oauth.scopes),
}
if oauth.use_pkce:
from app.utils.oauth_security import generate_pkce_pair
verifier, challenge = generate_pkce_pair()
extra["code_verifier"] = verifier
auth_params["code_challenge"] = challenge
auth_params["code_challenge_method"] = "S256"
auth_params.update(oauth.extra_auth_params)
state_encoded = state_mgr.generate_secure_state(
space_id, user.id, **extra
)
auth_params["state"] = state_encoded
auth_url = f"{oauth.authorize_url}?{urlencode(auth_params)}"
logger.info(
"Initiating %s re-auth for user %s, connector %s",
oauth.provider_name,
user.id,
connector_id,
)
return {"auth_url": auth_url}
@router.get(f"{oauth.auth_prefix}/connector/callback")
async def callback(
code: str | None = None,
error: str | None = None,
state: str | None = None,
session: AsyncSession = Depends(get_async_session),
):
error_label = f"{oauth.provider_name}_oauth_denied"
if error:
logger.warning("%s OAuth error: %s", oauth.provider_name, error)
space_id = None
if state:
try:
data = oauth._get_state_manager().validate_state(state)
space_id = data.get("space_id")
except Exception:
pass
return oauth._frontend_redirect(space_id, error=error_label)
if not code:
raise HTTPException(
status_code=400, detail="Missing authorization code"
)
if not state:
raise HTTPException(
status_code=400, detail="Missing state parameter"
)
state_mgr = oauth._get_state_manager()
try:
data = state_mgr.validate_state(state)
except Exception as e:
raise HTTPException(
status_code=400, detail="Invalid or expired state parameter."
) from e
user_id = UUID(data["user_id"])
space_id = data["space_id"]
token_json = await oauth._exchange_code(code, data)
access_token = token_json.get("access_token", "")
if not access_token:
raise HTTPException(
status_code=400,
detail=f"No access token received from {oauth.provider_name.title()}",
)
account_info = await oauth.fetch_account_info(access_token)
encryption = oauth._get_token_encryption()
connector_config = oauth.build_connector_config(
token_json, account_info, encryption
)
display_name = oauth.get_connector_display_name(account_info)
# --- Re-auth path ---
reauth_connector_id = data.get("connector_id")
reauth_return_url = data.get("return_url")
if reauth_connector_id:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == reauth_connector_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type == oauth.connector_type,
)
)
db_connector = result.scalars().first()
if not db_connector:
raise HTTPException(
status_code=404,
detail="Connector not found or access denied during re-auth",
)
db_connector.config = connector_config
flag_modified(db_connector, "config")
await session.commit()
await session.refresh(db_connector)
logger.info(
"Re-authenticated %s connector %s for user %s",
oauth.provider_name,
db_connector.id,
user_id,
)
if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"):
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
)
return oauth._frontend_redirect(
space_id, success=True, connector_id=db_connector.id
)
# --- New connector path ---
is_dup = await check_duplicate_connector(
session,
oauth.connector_type,
space_id,
user_id,
display_name,
)
if is_dup:
logger.warning(
"Duplicate %s connector for user %s (%s)",
oauth.provider_name,
user_id,
display_name,
)
return oauth._frontend_redirect(
space_id,
error=f"duplicate_account&connector={oauth.provider_name}-connector",
)
connector_name = await generate_unique_connector_name(
session,
oauth.connector_type,
space_id,
user_id,
display_name,
)
new_connector = SearchSourceConnector(
name=connector_name,
connector_type=oauth.connector_type,
is_indexable=oauth.is_indexable,
config=connector_config,
search_space_id=space_id,
user_id=user_id,
)
session.add(new_connector)
try:
await session.commit()
except IntegrityError as e:
await session.rollback()
raise HTTPException(
status_code=409, detail="A connector for this service already exists."
) from e
logger.info(
"Created %s connector %s for user %s in space %s",
oauth.provider_name,
new_connector.id,
user_id,
space_id,
)
return oauth._frontend_redirect(
space_id, success=True, connector_id=new_connector.id
)
return router
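# Hypothetical usage sketch. The provider name, endpoint URL, env-var names,
# scopes, and connector type below are placeholders rather than real
# configuration; a concrete provider module would supply its own values and
# mount the router returned by ``build_router()``.
class _ExampleOAuthRoute(OAuthConnectorRoute):
    async def fetch_account_info(self, access_token: str) -> dict[str, Any]:
        # Placeholder endpoint; a real provider would call its own "me"/workspace API.
        async with httpx.AsyncClient() as client:
            resp = await client.get(
                "https://api.example.com/me",
                headers={"Authorization": f"Bearer {access_token}"},
                timeout=30.0,
            )
            resp.raise_for_status()
        return {"name": resp.json().get("name", "Example Workspace")}


# router = _ExampleOAuthRoute(
#     provider_name="example",
#     connector_type=SearchSourceConnectorType.MCP_CONNECTOR,  # placeholder type
#     authorize_url="https://example.com/oauth/authorize",
#     token_url="https://example.com/oauth/token",
#     client_id_env="EXAMPLE_CLIENT_ID",
#     client_secret_env="EXAMPLE_CLIENT_SECRET",
#     redirect_uri_env="EXAMPLE_REDIRECT_URI",
#     scopes=["read"],
#     auth_prefix="/auth/example",
#     use_pkce=True,
# ).build_router()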

View file

@ -693,27 +693,10 @@ async def index_connector_content(
user: User = Depends(current_active_user),
):
"""
Index content from a connector to a search space.
Requires CONNECTORS_UPDATE permission (to trigger indexing).
Index content from a KB connector to a search space.
Currently supports:
- SLACK_CONNECTOR: Indexes messages from all accessible Slack channels
- TEAMS_CONNECTOR: Indexes messages from all accessible Microsoft Teams channels
- NOTION_CONNECTOR: Indexes pages from all accessible Notion pages
- GITHUB_CONNECTOR: Indexes code and documentation from GitHub repositories
- LINEAR_CONNECTOR: Indexes issues and comments from Linear
- JIRA_CONNECTOR: Indexes issues and comments from Jira
- DISCORD_CONNECTOR: Indexes messages from all accessible Discord channels
- LUMA_CONNECTOR: Indexes events from Luma
- ELASTICSEARCH_CONNECTOR: Indexes documents from Elasticsearch
- WEBCRAWLER_CONNECTOR: Indexes web pages from crawled websites
Args:
connector_id: ID of the connector to use
search_space_id: ID of the search space to store indexed content
Returns:
Dictionary with indexing status
Live connectors (Slack, Teams, Linear, Jira, ClickUp, Calendar, Airtable,
Gmail, Discord, Luma) use real-time agent tools instead.
"""
try:
# Get the connector first
@ -770,9 +753,7 @@ async def index_connector_content(
# For calendar connectors, default to today but allow future dates if explicitly provided
if connector.connector_type in [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.LUMA_CONNECTOR,
]:
# Default to today if no end_date provided (users can manually select future dates)
indexing_to = today_str if end_date is None else end_date
@ -796,33 +777,22 @@ async def index_connector_content(
# For non-calendar connectors, cap at today
indexing_to = end_date if end_date else today_str
if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import (
index_slack_messages_task,
)
from app.services.mcp_oauth.registry import LIVE_CONNECTOR_TYPES
logger.info(
f"Triggering Slack indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_slack_messages_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Slack indexing started in the background."
if connector.connector_type in LIVE_CONNECTOR_TYPES:
return {
"message": (
f"{connector.connector_type.value} uses real-time agent tools; "
"background indexing is disabled."
),
"indexing_started": False,
"connector_id": connector_id,
"search_space_id": search_space_id,
"indexing_from": indexing_from,
"indexing_to": indexing_to,
}
elif connector.connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import (
index_teams_messages_task,
)
logger.info(
f"Triggering Teams indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_teams_messages_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Teams indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
if connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import index_notion_pages_task
logger.info(
@ -844,28 +814,6 @@ async def index_connector_content(
)
response_message = "GitHub indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.LINEAR_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import index_linear_issues_task
logger.info(
f"Triggering Linear indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_linear_issues_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Linear indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.JIRA_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import index_jira_issues_task
logger.info(
f"Triggering Jira indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_jira_issues_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Jira indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.CONFLUENCE_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import (
index_confluence_pages_task,
@ -892,59 +840,6 @@ async def index_connector_content(
)
response_message = "BookStack indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.CLICKUP_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import index_clickup_tasks_task
logger.info(
f"Triggering ClickUp indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_clickup_tasks_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "ClickUp indexing started in the background."
elif (
connector.connector_type
== SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR
):
from app.tasks.celery_tasks.connector_tasks import (
index_google_calendar_events_task,
)
logger.info(
f"Triggering Google Calendar indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_google_calendar_events_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Google Calendar indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.AIRTABLE_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import (
index_airtable_records_task,
)
logger.info(
f"Triggering Airtable indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_airtable_records_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Airtable indexing started in the background."
elif (
connector.connector_type == SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR
):
from app.tasks.celery_tasks.connector_tasks import (
index_google_gmail_messages_task,
)
logger.info(
f"Triggering Google Gmail indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_google_gmail_messages_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Google Gmail indexing started in the background."
elif (
connector.connector_type == SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR
):
@ -1089,30 +984,6 @@ async def index_connector_content(
)
response_message = "Dropbox indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import (
index_discord_messages_task,
)
logger.info(
f"Triggering Discord indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_discord_messages_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Discord indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.LUMA_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import index_luma_events_task
logger.info(
f"Triggering Luma indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_luma_events_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Luma indexing started in the background."
elif (
connector.connector_type
== SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR
@ -1319,57 +1190,6 @@ async def _update_connector_timestamp_by_id(session: AsyncSession, connector_id:
await session.rollback()
async def run_slack_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Create a new session and run the Slack indexing task.
This prevents session leaks by creating a dedicated session for the background task.
"""
async with async_session_maker() as session:
await run_slack_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
async def run_slack_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Slack indexing.
Args:
session: Database session
connector_id: ID of the Slack connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_slack_messages
await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_slack_messages,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)
_AUTH_ERROR_PATTERNS = (
"failed to refresh linear oauth",
"failed to refresh your notion connection",
@ -1908,215 +1728,6 @@ async def run_github_indexing(
)
# Add new helper functions for Linear indexing
async def run_linear_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Wrapper to run Linear indexing with its own database session."""
logger.info(
f"Background task started: Indexing Linear connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
)
async with async_session_maker() as session:
await run_linear_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
logger.info(f"Background task finished: Indexing Linear connector {connector_id}")
async def run_linear_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Linear indexing.
Args:
session: Database session
connector_id: ID of the Linear connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_linear_issues
await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_linear_issues,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)
# Add new helper functions for discord indexing
async def run_discord_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Create a new session and run the Discord indexing task.
This prevents session leaks by creating a dedicated session for the background task.
"""
async with async_session_maker() as session:
await run_discord_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
async def run_discord_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Discord indexing.
Args:
session: Database session
connector_id: ID of the Discord connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_discord_messages
await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_discord_messages,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)
async def run_teams_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Create a new session and run the Microsoft Teams indexing task.
This prevents session leaks by creating a dedicated session for the background task.
"""
async with async_session_maker() as session:
await run_teams_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
async def run_teams_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Microsoft Teams indexing.
Args:
session: Database session
connector_id: ID of the Teams connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers.teams_indexer import index_teams_messages
await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_teams_messages,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)
# Add new helper functions for Jira indexing
async def run_jira_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Wrapper to run Jira indexing with its own database session."""
logger.info(
f"Background task started: Indexing Jira connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
)
async with async_session_maker() as session:
await run_jira_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
logger.info(f"Background task finished: Indexing Jira connector {connector_id}")
async def run_jira_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Jira indexing.
Args:
session: Database session
connector_id: ID of the Jira connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_jira_issues
await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_jira_issues,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)
# Add new helper functions for Confluence indexing
async def run_confluence_indexing_with_new_session(
connector_id: int,
@ -2172,112 +1783,6 @@ async def run_confluence_indexing(
)
# Add new helper functions for ClickUp indexing
async def run_clickup_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Wrapper to run ClickUp indexing with its own database session."""
logger.info(
f"Background task started: Indexing ClickUp connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
)
async with async_session_maker() as session:
await run_clickup_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
logger.info(f"Background task finished: Indexing ClickUp connector {connector_id}")
async def run_clickup_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run ClickUp indexing.
Args:
session: Database session
connector_id: ID of the ClickUp connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_clickup_tasks
await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_clickup_tasks,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)
# Add new helper functions for Airtable indexing
async def run_airtable_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Wrapper to run Airtable indexing with its own database session."""
logger.info(
f"Background task started: Indexing Airtable connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
)
async with async_session_maker() as session:
await run_airtable_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
logger.info(f"Background task finished: Indexing Airtable connector {connector_id}")
async def run_airtable_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Airtable indexing.
Args:
session: Database session
connector_id: ID of the Airtable connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_airtable_records
await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_airtable_records,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)
# Add new helper functions for Google Calendar indexing
async def run_google_calendar_indexing_with_new_session(
connector_id: int,
@ -2816,58 +2321,6 @@ async def run_dropbox_indexing(
logger.error(f"Failed to update notification: {notif_error!s}")
# Add new helper functions for luma indexing
async def run_luma_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Create a new session and run the Luma indexing task.
This prevents session leaks by creating a dedicated session for the background task.
"""
async with async_session_maker() as session:
await run_luma_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
async def run_luma_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Luma indexing.
Args:
session: Database session
connector_id: ID of the Luma connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_luma_events
await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_luma_events,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)
async def run_elasticsearch_indexing_with_new_session(
connector_id: int,
search_space_id: int,
@ -3580,13 +3033,18 @@ async def trust_mcp_tool(
"""Add a tool to the MCP connector's trusted (always-allow) list.
Once trusted, the tool executes without HITL approval on subsequent calls.
Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors
(LINEAR_CONNECTOR, JIRA_CONNECTOR, etc.) by checking for ``server_config``.
"""
try:
from sqlalchemy import cast
from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.MCP_CONNECTOR,
SearchSourceConnector.user_id == user.id,
cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), # noqa: W601
)
)
connector = result.scalars().first()
@ -3631,13 +3089,17 @@ async def untrust_mcp_tool(
"""Remove a tool from the MCP connector's trusted list.
The tool will require HITL approval again on subsequent calls.
Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors.
"""
try:
from sqlalchemy import cast
from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.MCP_CONNECTOR,
SearchSourceConnector.user_id == user.id,
cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), # noqa: W601
)
)
connector = result.scalars().first()
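# Note: ``cast(config, JSONB).has_key("server_config")`` renders as the
# Postgres key-existence operator (``config::jsonb ? 'server_config'``); the
# ``noqa: W601`` only silences flake8's false positive for the long-removed
# Python 2 ``dict.has_key`` check.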

View file

@ -312,7 +312,7 @@ async def slack_callback(
new_connector = SearchSourceConnector(
name=connector_name,
connector_type=SearchSourceConnectorType.SLACK_CONNECTOR,
is_indexable=True,
is_indexable=False,
config=connector_config,
search_space_id=space_id,
user_id=user_id,

View file

@ -45,6 +45,7 @@ SCOPES = [
"Team.ReadBasic.All", # Read basic team information
"Channel.ReadBasic.All", # Read basic channel information
"ChannelMessage.Read.All", # Read messages in channels
"ChannelMessage.Send", # Send messages in channels
]
# Initialize security utilities
@ -320,7 +321,7 @@ async def teams_callback(
new_connector = SearchSourceConnector(
name=connector_name,
connector_type=SearchSourceConnectorType.TEAMS_CONNECTOR,
is_indexable=True,
is_indexable=False,
config=connector_config,
search_space_id=space_id,
user_id=user_id,

View file

@ -168,6 +168,11 @@ class ChatMessage(BaseModel):
content: str
class LocalFilesystemMountPayload(BaseModel):
mount_id: str
root_path: str
class NewChatRequest(BaseModel):
"""Request schema for the deep agent chat endpoint."""
@ -184,6 +189,9 @@ class NewChatRequest(BaseModel):
disabled_tools: list[str] | None = (
None # Optional list of tool names the user has disabled from the UI
)
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
client_platform: Literal["web", "desktop"] = "web"
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
class RegenerateRequest(BaseModel):
@ -204,6 +212,9 @@ class RegenerateRequest(BaseModel):
mentioned_document_ids: list[int] | None = None
mentioned_surfsense_doc_ids: list[int] | None = None
disabled_tools: list[str] | None = None
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
client_platform: Literal["web", "desktop"] = "web"
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
# =============================================================================
@ -227,6 +238,9 @@ class ResumeDecision(BaseModel):
class ResumeRequest(BaseModel):
search_space_id: int
decisions: list[ResumeDecision]
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
client_platform: Literal["web", "desktop"] = "web"
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
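# Illustrative request fragment for desktop local-folder mode (the mount id and
# root path are hypothetical). Web clients can omit these fields entirely and
# fall back to the "cloud"/"web" defaults declared above.
_EXAMPLE_DESKTOP_FILESYSTEM_FIELDS = {
    "filesystem_mode": "desktop_local_folder",
    "client_platform": "desktop",
    "local_filesystem_mounts": [
        {"mount_id": "notes", "root_path": "/Users/me/Documents/Notes"},
    ],
}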
# =============================================================================

View file

@ -26,7 +26,7 @@ COMPOSIO_TOOLKIT_NAMES = {
}
# Toolkits that support indexing (Phase 1: Google services only)
INDEXABLE_TOOLKITS = {"googledrive", "gmail", "googlecalendar"}
INDEXABLE_TOOLKITS = {"googledrive"}
# Mapping of toolkit IDs to connector types
TOOLKIT_TO_CONNECTOR_TYPE = {

View file

@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -66,6 +65,8 @@ class ConfluenceKBSyncService:
if dup:
content_hash = unique_hash
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,
@ -184,6 +185,8 @@ class ConfluenceKBSyncService:
space_id = (document.document_metadata or {}).get("space_id", "")
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session, user_id, search_space_id, disable_streaming=True
)

View file

@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, DocumentType
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -73,6 +72,8 @@ class DropboxKBSyncService:
)
content_hash = unique_hash
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,

View file

@ -4,7 +4,6 @@ from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -78,6 +77,8 @@ class GmailKBSyncService:
)
content_hash = unique_hash
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,

View file

@ -14,7 +14,6 @@ from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
)
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -91,6 +90,8 @@ class GoogleCalendarKBSyncService:
)
content_hash = unique_hash
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,
@ -249,6 +250,8 @@ class GoogleCalendarKBSyncService:
if not indexable_content:
return {"status": "error", "message": "Event produced empty content"}
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session, user_id, search_space_id, disable_streaming=True
)

View file

@ -4,7 +4,6 @@ from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -75,6 +74,8 @@ class GoogleDriveKBSyncService:
)
content_hash = unique_hash
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,

View file

@ -6,7 +6,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.jira_history import JiraHistoryConnector
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -75,6 +74,8 @@ class JiraKBSyncService:
if dup:
content_hash = unique_hash
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,
@ -190,6 +191,8 @@ class JiraKBSyncService:
state = formatted.get("status", "Unknown")
comment_count = len(formatted.get("comments", []))
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session, user_id, search_space_id, disable_streaming=True
)

View file

@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.linear_connector import LinearConnector
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -85,6 +84,8 @@ class LinearKBSyncService:
)
content_hash = unique_hash
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,
@ -226,6 +227,8 @@ class LinearKBSyncService:
comment_count = len(formatted_issue.get("comments", []))
formatted_issue.get("description", "")
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session, user_id, search_space_id, disable_streaming=True
)

View file

@ -133,6 +133,44 @@ PROVIDER_MAP = {
}
# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when
# a global LLM config does *not* specify ``api_base``: without this, LiteLLM
# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``,
# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku``
# request to an Azure endpoint, which then 404s with ``Resource not found``.
# Only providers with a well-known, stable public base URL are listed here —
# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
# huggingface, databricks, cloudflare, replicate) are intentionally omitted
# so their existing config-driven behaviour is preserved.
PROVIDER_DEFAULT_API_BASE = {
"openrouter": "https://openrouter.ai/api/v1",
"groq": "https://api.groq.com/openai/v1",
"mistral": "https://api.mistral.ai/v1",
"perplexity": "https://api.perplexity.ai",
"xai": "https://api.x.ai/v1",
"cerebras": "https://api.cerebras.ai/v1",
"deepinfra": "https://api.deepinfra.com/v1/openai",
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
"together_ai": "https://api.together.xyz/v1",
"anyscale": "https://api.endpoints.anyscale.com/v1",
"cometapi": "https://api.cometapi.com/v1",
"sambanova": "https://api.sambanova.ai/v1",
}
# Canonical provider → base URL when a config uses a generic ``openai``-style
# prefix but the ``provider`` field tells us which API it really is
# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but
# each has its own base URL).
PROVIDER_KEY_DEFAULT_API_BASE = {
"DEEPSEEK": "https://api.deepseek.com/v1",
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
"MOONSHOT": "https://api.moonshot.ai/v1",
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
"MINIMAX": "https://api.minimax.io/v1",
}
class LLMRouterService:
"""
Singleton service for managing LiteLLM Router.
@ -224,6 +262,16 @@ class LLMRouterService:
# hits ContextWindowExceededError.
full_model_list, ctx_fallbacks = cls._build_context_fallback_groups(model_list)
# Build a general-purpose fallback list so NotFound/timeout/rate-limit
# style failures on one deployment don't bubble up as hard errors —
# the router retries with a sibling deployment in ``auto-large``.
# ``auto-large`` is the large-context subset of ``auto``; if it is
# empty we fall back to ``auto`` itself so the router at least picks a
# different deployment in the same group.
fallbacks: list[dict[str, list[str]]] | None = None
if ctx_fallbacks:
fallbacks = [{"auto": ["auto-large"]}]
try:
router_kwargs: dict[str, Any] = {
"model_list": full_model_list,
@ -237,15 +285,24 @@ class LLMRouterService:
}
if ctx_fallbacks:
router_kwargs["context_window_fallbacks"] = ctx_fallbacks
if fallbacks:
router_kwargs["fallbacks"] = fallbacks
instance._router = Router(**router_kwargs)
instance._initialized = True
global _cached_context_profile, _cached_context_profile_computed
_cached_context_profile = None
_cached_context_profile_computed = False
_router_instance_cache.clear()
logger.info(
"LLM Router initialized with %d deployments, "
"strategy: %s, context_window_fallbacks: %s",
"strategy: %s, context_window_fallbacks: %s, fallbacks: %s",
len(model_list),
final_settings.get("routing_strategy"),
ctx_fallbacks or "none",
fallbacks or "none",
)
except Exception as e:
logger.error(f"Failed to initialize LLM Router: {e}")
@ -348,10 +405,11 @@ class LLMRouterService:
return None
# Build model string
provider = config.get("provider", "").upper()
if config.get("custom_provider"):
model_string = f"{config['custom_provider']}/{config['model_name']}"
provider_prefix = config["custom_provider"]
model_string = f"{provider_prefix}/{config['model_name']}"
else:
provider = config.get("provider", "").upper()
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{config['model_name']}"
@ -361,9 +419,19 @@ class LLMRouterService:
"api_key": config.get("api_key"),
}
# Add optional api_base
if config.get("api_base"):
litellm_params["api_base"] = config["api_base"]
# Resolve ``api_base``. Config value wins; otherwise apply a
# provider-aware default so the deployment does not silently
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
# requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE``
# docstring for the motivating bug (OpenRouter models 404-ing
# against an Azure endpoint).
api_base = config.get("api_base")
if not api_base:
api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider)
if not api_base:
api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix)
if api_base:
litellm_params["api_base"] = api_base
# Add any additional litellm parameters
if config.get("litellm_params"):
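For reference, the resolution order implemented above can be read as a small standalone helper. This is an illustrative sketch against the two mapping tables defined in this module, not code from the change itself:

def _resolve_api_base_sketch(config: dict, provider: str, provider_prefix: str) -> str | None:
    # 1. An explicit api_base on the LLM config always wins.
    if config.get("api_base"):
        return config["api_base"]
    # 2. Canonical provider key (e.g. "DEEPSEEK") for OpenAI-compatible providers.
    if provider in PROVIDER_KEY_DEFAULT_API_BASE:
        return PROVIDER_KEY_DEFAULT_API_BASE[provider]
    # 3. LiteLLM provider prefix (e.g. "openrouter") with a stable public base URL.
    return PROVIDER_DEFAULT_API_BASE.get(provider_prefix)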

View file

@ -1,3 +1,4 @@
import asyncio
import logging
import litellm
@ -6,7 +7,6 @@ from langchain_litellm import ChatLiteLLM
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
from app.config import config
from app.db import NewLLMConfig, SearchSpace
from app.services.llm_router_service import (
@ -32,6 +32,39 @@ litellm.callbacks = [token_tracker]
logger = logging.getLogger(__name__)
# Providers that require an interactive OAuth / device-flow login before
# issuing any completion. LiteLLM implements these with blocking sync polling
# (requests + time.sleep), which would freeze the FastAPI event loop if
# invoked from validation. They are never usable from a headless backend,
# so we reject them at the edge.
_INTERACTIVE_AUTH_PROVIDERS: frozenset[str] = frozenset(
{
"github_copilot",
"github-copilot",
"githubcopilot",
"copilot",
}
)
# Hard upper bound for a single validation call. Must exceed the ChatLiteLLM
# request timeout (30s) by a small margin so a well-behaved provider never
# trips the watchdog, while any pathological/blocking provider is killed.
_VALIDATION_TIMEOUT_SECONDS: float = 35.0
def _is_interactive_auth_provider(
provider: str | None, custom_provider: str | None
) -> bool:
"""Return True if the given provider triggers interactive OAuth in LiteLLM."""
for raw in (custom_provider, provider):
if not raw:
continue
normalized = raw.strip().lower().replace(" ", "_")
if normalized in _INTERACTIVE_AUTH_PROVIDERS:
return True
return False
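A quick illustration of what the normalization above accepts and rejects (inputs are made up):

# Illustrative only: "GitHub Copilot" normalizes to "github_copilot", which is
# listed in _INTERACTIVE_AUTH_PROVIDERS, so the config is rejected up front.
assert _is_interactive_auth_provider("GitHub Copilot", None) is True
assert _is_interactive_auth_provider("openai", None) is False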
class LLMRole:
AGENT = "agent" # For agent/chat operations
DOCUMENT_SUMMARY = "document_summary" # For document summarization
@ -93,6 +126,25 @@ async def validate_llm_config(
- is_valid: True if config works, False otherwise
- error_message: Empty string if valid, error description if invalid
"""
# Reject providers that require interactive OAuth/device-flow auth.
# LiteLLM's github_copilot provider (and similar) uses a blocking sync
# Authenticator that polls GitHub for up to several minutes and prints a
# device code to stdout. Running it on the FastAPI event loop will freeze
# the entire backend, so we refuse them up front.
if _is_interactive_auth_provider(provider, custom_provider):
msg = (
"Provider requires interactive OAuth/device-flow authentication "
"(e.g. github_copilot) and cannot be used in a hosted backend. "
"Please choose a provider that authenticates via API key."
)
logger.warning(
"Rejected LLM config validation for interactive-auth provider "
"(provider=%r, custom_provider=%r)",
provider,
custom_provider,
)
return False, msg
try:
# Build the model string for litellm
if custom_provider:
@ -151,11 +203,34 @@ async def validate_llm_config(
if litellm_params:
litellm_kwargs.update(litellm_params)
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
llm = SanitizedChatLiteLLM(**litellm_kwargs)
# Make a simple test call
# Run the test call in a worker thread with a hard timeout. Some
# LiteLLM providers have synchronous blocking code paths (e.g. OAuth
# authenticators that call time.sleep and requests.post) that would
# otherwise freeze the asyncio event loop. Offloading to a thread and
# bounding the wait keeps the server responsive even if a provider
# misbehaves.
test_message = HumanMessage(content="Hello")
response = await llm.ainvoke([test_message])
try:
response = await asyncio.wait_for(
asyncio.to_thread(llm.invoke, [test_message]),
timeout=_VALIDATION_TIMEOUT_SECONDS,
)
except TimeoutError:
logger.warning(
"LLM config validation timed out after %ss for model: %s",
_VALIDATION_TIMEOUT_SECONDS,
model_string,
)
return (
False,
f"Validation timed out after {int(_VALIDATION_TIMEOUT_SECONDS)}s. "
"The provider is unreachable or requires interactive "
"authentication that is not supported by the backend.",
)
# If we got here without exception, the config is valid
if response and response.content:
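The thread-offload-plus-watchdog pattern used here, reduced to a minimal self-contained sketch (the callable and timeout handling are placeholders, not the validation code itself):

import asyncio

async def _call_with_watchdog(blocking_fn, *args, timeout: float = 35.0):
    # Run a potentially blocking sync call in a worker thread and bound the wait,
    # so a misbehaving provider cannot freeze the asyncio event loop.
    try:
        return await asyncio.wait_for(asyncio.to_thread(blocking_fn, *args), timeout=timeout)
    except TimeoutError:
        # Callers translate None into a "validation timed out" failure.
        return None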
@ -303,6 +378,8 @@ async def get_search_space_llm_instance(
if disable_streaming:
litellm_kwargs["disable_streaming"] = True
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
return SanitizedChatLiteLLM(**litellm_kwargs)
# Get the LLM configuration from database (NewLLMConfig)
@ -380,6 +457,8 @@ async def get_search_space_llm_instance(
if disable_streaming:
litellm_kwargs["disable_streaming"] = True
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
return SanitizedChatLiteLLM(**litellm_kwargs)
except Exception as e:
@ -481,6 +560,8 @@ async def get_vision_llm(
if global_cfg.get("litellm_params"):
litellm_kwargs.update(global_cfg["litellm_params"])
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
return SanitizedChatLiteLLM(**litellm_kwargs)
result = await session.execute(
@ -514,6 +595,8 @@ async def get_vision_llm(
if vision_cfg.litellm_params:
litellm_kwargs.update(vision_cfg.litellm_params)
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
return SanitizedChatLiteLLM(**litellm_kwargs)
except Exception as e:

View file

@ -0,0 +1,121 @@
"""MCP OAuth 2.1 metadata discovery, Dynamic Client Registration, and token exchange."""
from __future__ import annotations
import base64
import logging
from urllib.parse import urlparse
import httpx
logger = logging.getLogger(__name__)
async def discover_oauth_metadata(
mcp_url: str,
*,
origin_override: str | None = None,
timeout: float = 15.0,
) -> dict:
"""Fetch OAuth 2.1 metadata from the MCP server's well-known endpoint.
Per the MCP spec the discovery document lives at the *origin* of the
MCP server URL. ``origin_override`` can be used when the OAuth server
lives on a different domain (e.g. Airtable: MCP at ``mcp.airtable.com``,
OAuth at ``airtable.com``).
"""
if origin_override:
origin = origin_override.rstrip("/")
else:
parsed = urlparse(mcp_url)
origin = f"{parsed.scheme}://{parsed.netloc}"
discovery_url = f"{origin}/.well-known/oauth-authorization-server"
async with httpx.AsyncClient(follow_redirects=True) as client:
resp = await client.get(discovery_url, timeout=timeout)
resp.raise_for_status()
return resp.json()
async def register_client(
registration_endpoint: str,
redirect_uri: str,
*,
client_name: str = "SurfSense",
timeout: float = 15.0,
) -> dict:
"""Perform Dynamic Client Registration (RFC 7591)."""
payload = {
"client_name": client_name,
"redirect_uris": [redirect_uri],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"token_endpoint_auth_method": "client_secret_basic",
}
async with httpx.AsyncClient(follow_redirects=True) as client:
resp = await client.post(
registration_endpoint, json=payload, timeout=timeout,
)
resp.raise_for_status()
return resp.json()
async def exchange_code_for_tokens(
token_endpoint: str,
code: str,
redirect_uri: str,
client_id: str,
client_secret: str,
code_verifier: str,
*,
timeout: float = 30.0,
) -> dict:
"""Exchange an authorization code for access + refresh tokens."""
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
async with httpx.AsyncClient(follow_redirects=True) as client:
resp = await client.post(
token_endpoint,
data={
"grant_type": "authorization_code",
"code": code,
"redirect_uri": redirect_uri,
"code_verifier": code_verifier,
},
headers={
"Content-Type": "application/x-www-form-urlencoded",
"Authorization": f"Basic {creds}",
},
timeout=timeout,
)
resp.raise_for_status()
return resp.json()
async def refresh_access_token(
token_endpoint: str,
refresh_token: str,
client_id: str,
client_secret: str,
*,
timeout: float = 30.0,
) -> dict:
"""Refresh an expired access token."""
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
async with httpx.AsyncClient(follow_redirects=True) as client:
resp = await client.post(
token_endpoint,
data={
"grant_type": "refresh_token",
"refresh_token": refresh_token,
},
headers={
"Content-Type": "application/x-www-form-urlencoded",
"Authorization": f"Basic {creds}",
},
timeout=timeout,
)
resp.raise_for_status()
return resp.json()
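A hedged end-to-end usage sketch of these helpers; the URLs, redirect URI, authorization code, and PKCE verifier are placeholders, and the browser-side authorization step is elided:

async def _connect_mcp_service_sketch() -> dict:
    # 1. Discover OAuth metadata from the MCP server's origin.
    meta = await discover_oauth_metadata("https://mcp.example.com/mcp")
    # 2. Dynamically register this app as an OAuth client (RFC 7591).
    client = await register_client(
        meta["registration_endpoint"],
        redirect_uri="https://app.example.com/oauth/callback",
    )
    # 3. ...user completes the authorization-code flow in the browser...
    # 4. Exchange the returned code (plus the PKCE verifier) for tokens.
    return await exchange_code_for_tokens(
        meta["token_endpoint"],
        code="<authorization-code>",
        redirect_uri="https://app.example.com/oauth/callback",
        client_id=client["client_id"],
        client_secret=client["client_secret"],
        code_verifier="<pkce-code-verifier>",
    )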

View file

@ -0,0 +1,161 @@
"""Registry of MCP services with OAuth support.
Each entry maps a URL-safe service key to its MCP server endpoint and
authentication configuration. Services with ``supports_dcr=True`` use
RFC 7591 Dynamic Client Registration (the MCP server issues its own
credentials); the rest use pre-configured credentials via env vars.
``allowed_tools`` whitelists which MCP tools to expose to the agent.
An empty list means "load every tool the server advertises" (used for
user-managed generic MCP servers). Service-specific entries should
curate this list to keep the agent's tool count low and selection
accuracy high.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from app.db import SearchSourceConnectorType
@dataclass(frozen=True)
class MCPServiceConfig:
name: str
mcp_url: str
connector_type: str
supports_dcr: bool = True
oauth_discovery_origin: str | None = None
client_id_env: str | None = None
client_secret_env: str | None = None
scopes: list[str] = field(default_factory=list)
scope_param: str = "scope"
auth_endpoint_override: str | None = None
token_endpoint_override: str | None = None
allowed_tools: list[str] = field(default_factory=list)
readonly_tools: frozenset[str] = field(default_factory=frozenset)
account_metadata_keys: list[str] = field(default_factory=list)
"""``connector.config`` keys exposed by ``get_connected_accounts``.
Only listed keys are returned to the LLM; tokens and secrets are
never included. Every service should at least have its
``display_name`` populated during OAuth; additional service-specific
fields (e.g. Jira ``cloud_id``) are listed here so the LLM can pass
them to action tools.
"""
MCP_SERVICES: dict[str, MCPServiceConfig] = {
"linear": MCPServiceConfig(
name="Linear",
mcp_url="https://mcp.linear.app/mcp",
connector_type="LINEAR_CONNECTOR",
allowed_tools=[
"list_issues",
"get_issue",
"save_issue",
],
readonly_tools=frozenset({"list_issues", "get_issue"}),
account_metadata_keys=["organization_name", "organization_url_key"],
),
"jira": MCPServiceConfig(
name="Jira",
mcp_url="https://mcp.atlassian.com/v1/mcp",
connector_type="JIRA_CONNECTOR",
allowed_tools=[
"getAccessibleAtlassianResources",
"searchJiraIssuesUsingJql",
"getVisibleJiraProjects",
"getJiraProjectIssueTypesMetadata",
"createJiraIssue",
"editJiraIssue",
],
readonly_tools=frozenset({
"getAccessibleAtlassianResources",
"searchJiraIssuesUsingJql",
"getVisibleJiraProjects",
"getJiraProjectIssueTypesMetadata",
}),
account_metadata_keys=["cloud_id", "site_name", "base_url"],
),
"clickup": MCPServiceConfig(
name="ClickUp",
mcp_url="https://mcp.clickup.com/mcp",
connector_type="CLICKUP_CONNECTOR",
allowed_tools=[
"clickup_search",
"clickup_get_task",
],
readonly_tools=frozenset({"clickup_search", "clickup_get_task"}),
account_metadata_keys=["workspace_id", "workspace_name"],
),
"slack": MCPServiceConfig(
name="Slack",
mcp_url="https://mcp.slack.com/mcp",
connector_type="SLACK_CONNECTOR",
supports_dcr=False,
client_id_env="SLACK_CLIENT_ID",
client_secret_env="SLACK_CLIENT_SECRET",
auth_endpoint_override="https://slack.com/oauth/v2_user/authorize",
token_endpoint_override="https://slack.com/api/oauth.v2.user.access",
scopes=[
"search:read.public", "search:read.private", "search:read.mpim", "search:read.im",
"channels:history", "groups:history", "mpim:history", "im:history",
],
allowed_tools=[
"slack_search_channels",
"slack_read_channel",
"slack_read_thread",
],
readonly_tools=frozenset({"slack_search_channels", "slack_read_channel", "slack_read_thread"}),
# TODO: oauth.v2.user.access only returns team.id, not team.name.
# To populate team_name, either add "team:read" scope and call
# GET /api/team.info during OAuth callback, or switch to oauth.v2.access.
account_metadata_keys=["team_id", "team_name"],
),
"airtable": MCPServiceConfig(
name="Airtable",
mcp_url="https://mcp.airtable.com/mcp",
connector_type="AIRTABLE_CONNECTOR",
supports_dcr=False,
oauth_discovery_origin="https://airtable.com",
client_id_env="AIRTABLE_CLIENT_ID",
client_secret_env="AIRTABLE_CLIENT_SECRET",
scopes=["data.records:read", "schema.bases:read"],
allowed_tools=[
"list_bases",
"list_tables_for_base",
"list_records_for_table",
],
readonly_tools=frozenset({"list_bases", "list_tables_for_base", "list_records_for_table"}),
account_metadata_keys=["user_id", "user_email"],
),
}
_CONNECTOR_TYPE_TO_SERVICE: dict[str, MCPServiceConfig] = {
svc.connector_type: svc for svc in MCP_SERVICES.values()
}
LIVE_CONNECTOR_TYPES: frozenset[SearchSourceConnectorType] = frozenset({
SearchSourceConnectorType.SLACK_CONNECTOR,
SearchSourceConnectorType.TEAMS_CONNECTOR,
SearchSourceConnectorType.LINEAR_CONNECTOR,
SearchSourceConnectorType.JIRA_CONNECTOR,
SearchSourceConnectorType.CLICKUP_CONNECTOR,
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.AIRTABLE_CONNECTOR,
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
SearchSourceConnectorType.DISCORD_CONNECTOR,
SearchSourceConnectorType.LUMA_CONNECTOR,
})
def get_service(key: str) -> MCPServiceConfig | None:
return MCP_SERVICES.get(key)
def get_service_by_connector_type(connector_type: str) -> MCPServiceConfig | None:
"""Look up an MCP service config by its ``connector_type`` enum value."""
return _CONNECTOR_TYPE_TO_SERVICE.get(connector_type)
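A small illustration of how a caller might consume this registry; the advertised tool list is made up:

def _select_tools_sketch(service_key: str, advertised: list[str]) -> list[str]:
    # Keep only the tools the registry entry whitelists; an empty allowed_tools
    # list means "expose everything the server advertises".
    svc = get_service(service_key)
    if svc is None:
        return []
    if not svc.allowed_tools:
        return advertised
    return [tool for tool in advertised if tool in svc.allowed_tools]

# e.g. _select_tools_sketch("linear", ["list_issues", "get_issue", "delete_issue"])
# returns ["list_issues", "get_issue"] under the registry above.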

View file

@ -4,7 +4,6 @@ from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -74,6 +73,8 @@ class NotionKBSyncService:
)
content_hash = unique_hash
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,
@ -244,6 +245,8 @@ class NotionKBSyncService:
f"Final content length: {len(full_content)} chars, verified={content_verified}"
)
from app.services.llm_service import get_user_long_context_llm
logger.debug("Generating summary and embeddings")
user_llm = await get_user_long_context_llm(
self.db_session,

View file

@ -227,8 +227,6 @@ class NotionToolMetadataService:
async def _check_account_health(self, connector_id: int) -> bool:
"""Check if a Notion connector's token is still valid.
Uses a lightweight ``users.me()`` call to verify the token.
Returns True if the token is expired/invalid, False if healthy.
"""
try:

View file

@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, DocumentType
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -73,6 +72,8 @@ class OneDriveKBSyncService:
)
content_hash = unique_hash
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,

View file

@ -39,52 +39,6 @@ def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> N
)
@celery_app.task(name="index_slack_messages", bind=True)
def index_slack_messages_task(
self,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Celery task to index Slack messages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_slack_messages(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_slack_messages", connector_id)
raise
finally:
loop.close()
async def _index_slack_messages(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Index Slack messages with new session."""
from app.routes.search_source_connectors_routes import (
run_slack_indexing,
)
async with get_celery_session_maker()() as session:
await run_slack_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
@celery_app.task(name="index_notion_pages", bind=True)
def index_notion_pages_task(
self,
@ -174,92 +128,6 @@ async def _index_github_repos(
)
@celery_app.task(name="index_linear_issues", bind=True)
def index_linear_issues_task(
self,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Celery task to index Linear issues."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_linear_issues(
connector_id, search_space_id, user_id, start_date, end_date
)
)
finally:
loop.close()
async def _index_linear_issues(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Index Linear issues with new session."""
from app.routes.search_source_connectors_routes import (
run_linear_indexing,
)
async with get_celery_session_maker()() as session:
await run_linear_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
@celery_app.task(name="index_jira_issues", bind=True)
def index_jira_issues_task(
self,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Celery task to index Jira issues."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_jira_issues(
connector_id, search_space_id, user_id, start_date, end_date
)
)
finally:
loop.close()
async def _index_jira_issues(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Index Jira issues with new session."""
from app.routes.search_source_connectors_routes import (
run_jira_indexing,
)
async with get_celery_session_maker()() as session:
await run_jira_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
@celery_app.task(name="index_confluence_pages", bind=True)
def index_confluence_pages_task(
self,
@ -303,49 +171,6 @@ async def _index_confluence_pages(
)
@celery_app.task(name="index_clickup_tasks", bind=True)
def index_clickup_tasks_task(
self,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Celery task to index ClickUp tasks."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_clickup_tasks(
connector_id, search_space_id, user_id, start_date, end_date
)
)
finally:
loop.close()
async def _index_clickup_tasks(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Index ClickUp tasks with new session."""
from app.routes.search_source_connectors_routes import (
run_clickup_indexing,
)
async with get_celery_session_maker()() as session:
await run_clickup_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
@celery_app.task(name="index_google_calendar_events", bind=True)
def index_google_calendar_events_task(
self,
@ -392,49 +217,6 @@ async def _index_google_calendar_events(
)
@celery_app.task(name="index_airtable_records", bind=True)
def index_airtable_records_task(
self,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Celery task to index Airtable records."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_airtable_records(
connector_id, search_space_id, user_id, start_date, end_date
)
)
finally:
loop.close()
async def _index_airtable_records(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Index Airtable records with new session."""
from app.routes.search_source_connectors_routes import (
run_airtable_indexing,
)
async with get_celery_session_maker()() as session:
await run_airtable_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
@celery_app.task(name="index_google_gmail_messages", bind=True)
def index_google_gmail_messages_task(
self,
@ -622,135 +404,6 @@ async def _index_dropbox_files(
)
@celery_app.task(name="index_discord_messages", bind=True)
def index_discord_messages_task(
self,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Celery task to index Discord messages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_discord_messages(
connector_id, search_space_id, user_id, start_date, end_date
)
)
finally:
loop.close()
async def _index_discord_messages(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Index Discord messages with new session."""
from app.routes.search_source_connectors_routes import (
run_discord_indexing,
)
async with get_celery_session_maker()() as session:
await run_discord_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
@celery_app.task(name="index_teams_messages", bind=True)
def index_teams_messages_task(
self,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Celery task to index Microsoft Teams messages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_teams_messages(
connector_id, search_space_id, user_id, start_date, end_date
)
)
finally:
loop.close()
async def _index_teams_messages(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Index Microsoft Teams messages with new session."""
from app.routes.search_source_connectors_routes import (
run_teams_indexing,
)
async with get_celery_session_maker()() as session:
await run_teams_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
@celery_app.task(name="index_luma_events", bind=True)
def index_luma_events_task(
self,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Celery task to index Luma events."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_luma_events(
connector_id, search_space_id, user_id, start_date, end_date
)
)
finally:
loop.close()
async def _index_luma_events(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Index Luma events with new session."""
from app.routes.search_source_connectors_routes import (
run_luma_indexing,
)
async with get_celery_session_maker()() as session:
await run_luma_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
@celery_app.task(name="index_elasticsearch_documents", bind=True)
def index_elasticsearch_documents_task(
self,

View file

@ -51,50 +51,51 @@ async def _check_and_trigger_schedules():
logger.info(f"Found {len(due_connectors)} connectors due for indexing")
# Import all indexing tasks
# Import indexing tasks for KB connectors only.
# Live connectors (Linear, Slack, Jira, ClickUp, Airtable, Discord,
# Teams, Gmail, Calendar, Luma) use real-time tools instead.
from app.tasks.celery_tasks.connector_tasks import (
index_airtable_records_task,
index_clickup_tasks_task,
index_confluence_pages_task,
index_crawled_urls_task,
index_discord_messages_task,
index_elasticsearch_documents_task,
index_github_repos_task,
index_google_calendar_events_task,
index_google_drive_files_task,
index_google_gmail_messages_task,
index_jira_issues_task,
index_linear_issues_task,
index_luma_events_task,
index_notion_pages_task,
index_slack_messages_task,
)
# Map connector types to their tasks
task_map = {
SearchSourceConnectorType.SLACK_CONNECTOR: index_slack_messages_task,
SearchSourceConnectorType.NOTION_CONNECTOR: index_notion_pages_task,
SearchSourceConnectorType.GITHUB_CONNECTOR: index_github_repos_task,
SearchSourceConnectorType.LINEAR_CONNECTOR: index_linear_issues_task,
SearchSourceConnectorType.JIRA_CONNECTOR: index_jira_issues_task,
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: index_confluence_pages_task,
SearchSourceConnectorType.CLICKUP_CONNECTOR: index_clickup_tasks_task,
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task,
SearchSourceConnectorType.AIRTABLE_CONNECTOR: index_airtable_records_task,
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: index_google_gmail_messages_task,
SearchSourceConnectorType.DISCORD_CONNECTOR: index_discord_messages_task,
SearchSourceConnectorType.LUMA_CONNECTOR: index_luma_events_task,
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task,
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task,
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task,
# Composio connector types (unified with native Google tasks)
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR: index_google_gmail_messages_task,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task,
}
from app.services.mcp_oauth.registry import LIVE_CONNECTOR_TYPES
# Disable obsolete periodic indexing for live connectors in one batch.
live_disabled = []
for connector in due_connectors:
if connector.connector_type in LIVE_CONNECTOR_TYPES:
connector.periodic_indexing_enabled = False
connector.next_scheduled_at = None
live_disabled.append(connector)
if live_disabled:
await session.commit()
for c in live_disabled:
logger.info(
"Disabled obsolete periodic indexing for live connector %s (%s)",
c.id,
c.connector_type.value,
)
# Trigger indexing for each due connector
for connector in due_connectors:
if connector in live_disabled:
continue
# Primary guard: Redis lock indicates a task is currently running.
if is_connector_indexing_locked(connector.id):
logger.info(

View file

@ -30,6 +30,8 @@ from sqlalchemy.orm import selectinload
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer
from app.agents.new_chat.filesystem_selection import FilesystemSelection
from app.config import config
from app.agents.new_chat.llm_config import (
AgentConfig,
create_chat_litellm_from_agent_config,
@ -145,6 +147,102 @@ class StreamResult:
interrupt_value: dict[str, Any] | None = None
sandbox_files: list[str] = field(default_factory=list)
agent_called_update_memory: bool = False
request_id: str | None = None
turn_id: str = ""
filesystem_mode: str = "cloud"
client_platform: str = "web"
intent_detected: str = "chat_only"
intent_confidence: float = 0.0
write_attempted: bool = False
write_succeeded: bool = False
verification_succeeded: bool = False
commit_gate_passed: bool = True
commit_gate_reason: str = ""
def _safe_float(value: Any, default: float = 0.0) -> float:
try:
return float(value)
except (TypeError, ValueError):
return default
def _tool_output_to_text(tool_output: Any) -> str:
if isinstance(tool_output, dict):
if isinstance(tool_output.get("result"), str):
return tool_output["result"]
if isinstance(tool_output.get("error"), str):
return tool_output["error"]
return json.dumps(tool_output, ensure_ascii=False)
return str(tool_output)
def _tool_output_has_error(tool_output: Any) -> bool:
if isinstance(tool_output, dict):
if tool_output.get("error"):
return True
result = tool_output.get("result")
if isinstance(result, str) and result.strip().lower().startswith("error:"):
return True
return False
if isinstance(tool_output, str):
return tool_output.strip().lower().startswith("error:")
return False
def _extract_resolved_file_path(*, tool_name: str, tool_output: Any) -> str | None:
if isinstance(tool_output, dict):
path_value = tool_output.get("path")
if isinstance(path_value, str) and path_value.strip():
return path_value.strip()
text = _tool_output_to_text(tool_output)
if tool_name == "write_file":
match = re.search(r"Updated file\s+(.+)$", text.strip())
if match:
return match.group(1).strip()
if tool_name == "edit_file":
match = re.search(r"in '([^']+)'", text)
if match:
return match.group(1).strip()
return None
def _contract_enforcement_active(result: StreamResult) -> bool:
# Keep policy deterministic with no env-driven progression modes:
# enforce the file-operation contract only in desktop local-folder mode.
return result.filesystem_mode == "desktop_local_folder"
def _evaluate_file_contract_outcome(result: StreamResult) -> tuple[bool, str]:
if result.intent_detected != "file_write":
return True, ""
if not result.write_attempted:
return False, "no_write_attempt"
if not result.write_succeeded:
return False, "write_failed"
if not result.verification_succeeded:
return False, "verification_failed"
return True, ""
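How the gate reads a turn, as a hedged illustration (the StreamResult values below are made up):

# Illustrative only: a turn classified as file_write where the agent never
# called write_file/edit_file fails the gate with reason "no_write_attempt".
_r = StreamResult(intent_detected="file_write", write_attempted=False)
_ok, _reason = _evaluate_file_contract_outcome(_r)
# _ok is False and _reason == "no_write_attempt"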
def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
payload: dict[str, Any] = {
"stage": stage,
"request_id": result.request_id or "unknown",
"turn_id": result.turn_id or "unknown",
"chat_id": result.turn_id.split(":", 1)[0] if ":" in result.turn_id else "unknown",
"filesystem_mode": result.filesystem_mode,
"client_platform": result.client_platform,
"intent_detected": result.intent_detected,
"intent_confidence": result.intent_confidence,
"write_attempted": result.write_attempted,
"write_succeeded": result.write_succeeded,
"verification_succeeded": result.verification_succeeded,
"commit_gate_passed": result.commit_gate_passed,
"commit_gate_reason": result.commit_gate_reason or None,
}
payload.update(extra)
_perf_log.info("[file_operation_contract] %s", json.dumps(payload, ensure_ascii=False))
async def _stream_agent_events(
@ -239,6 +337,8 @@ async def _stream_agent_events(
tool_name = event.get("name", "unknown_tool")
run_id = event.get("run_id", "")
tool_input = event.get("data", {}).get("input", {})
if tool_name in ("write_file", "edit_file"):
result.write_attempted = True
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
@ -514,6 +614,14 @@ async def _stream_agent_events(
else:
tool_output = {"result": str(raw_output) if raw_output else "completed"}
if tool_name in ("write_file", "edit_file"):
if _tool_output_has_error(tool_output):
# Keep successful evidence if a previous write/edit in this turn succeeded.
pass
else:
result.write_succeeded = True
result.verification_succeeded = True
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
original_step_id = tool_step_ids.get(
run_id, f"{step_prefix}-unknown-{run_id[:8]}"
@ -925,6 +1033,30 @@ async def _stream_agent_events(
f"Scrape failed: {error_msg}",
"error",
)
elif tool_name in ("write_file", "edit_file"):
resolved_path = _extract_resolved_file_path(
tool_name=tool_name,
tool_output=tool_output,
)
result_text = _tool_output_to_text(tool_output)
if _tool_output_has_error(tool_output):
yield streaming_service.format_tool_output_available(
tool_call_id,
{
"status": "error",
"error": result_text,
"path": resolved_path,
},
)
else:
yield streaming_service.format_tool_output_available(
tool_call_id,
{
"status": "completed",
"path": resolved_path,
"result": result_text,
},
)
elif tool_name == "generate_report":
# Stream the full report result so frontend can render the ReportCard
yield streaming_service.format_tool_output_available(
@ -1143,10 +1275,59 @@ async def _stream_agent_events(
if completion_event:
yield completion_event
state = await agent.aget_state(config)
state_values = getattr(state, "values", {}) or {}
contract_state = state_values.get("file_operation_contract") or {}
contract_turn_id = contract_state.get("turn_id")
current_turn_id = config.get("configurable", {}).get("turn_id", "")
intent_value = contract_state.get("intent")
if (
isinstance(intent_value, str)
and intent_value in ("chat_only", "file_write", "file_read")
and contract_turn_id == current_turn_id
):
result.intent_detected = intent_value
if (
isinstance(intent_value, str)
and intent_value in (
"chat_only",
"file_write",
"file_read",
)
and contract_turn_id != current_turn_id
):
# Ignore stale intent contracts from previous turns/checkpoints.
result.intent_detected = "chat_only"
result.intent_confidence = (
_safe_float(contract_state.get("confidence"), default=0.0)
if contract_turn_id == current_turn_id
else 0.0
)
if result.intent_detected == "file_write":
result.commit_gate_passed, result.commit_gate_reason = (
_evaluate_file_contract_outcome(result)
)
if not result.commit_gate_passed:
if _contract_enforcement_active(result):
gate_notice = (
"I could not complete the requested file write because no successful "
"write_file/edit_file operation was confirmed."
)
gate_text_id = streaming_service.generate_text_id()
yield streaming_service.format_text_start(gate_text_id)
yield streaming_service.format_text_delta(gate_text_id, gate_notice)
yield streaming_service.format_text_end(gate_text_id)
yield streaming_service.format_terminal_info(gate_notice, "error")
accumulated_text = gate_notice
else:
result.commit_gate_passed = True
result.commit_gate_reason = ""
result.accumulated_text = accumulated_text
result.agent_called_update_memory = called_update_memory
_log_file_contract("turn_outcome", result)
state = await agent.aget_state(config)
is_interrupted = state.tasks and any(task.interrupts for task in state.tasks)
if is_interrupted:
result.is_interrupted = True
@ -1167,6 +1348,8 @@ async def stream_new_chat(
thread_visibility: ChatVisibility | None = None,
current_user_display_name: str | None = None,
disabled_tools: list[str] | None = None,
filesystem_selection: FilesystemSelection | None = None,
request_id: str | None = None,
) -> AsyncGenerator[str, None]:
"""
Stream chat responses from the new SurfSense deep agent.
@ -1194,6 +1377,20 @@ async def stream_new_chat(
streaming_service = VercelStreamingService()
stream_result = StreamResult()
_t_total = time.perf_counter()
fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud"
fs_platform = (
filesystem_selection.client_platform.value if filesystem_selection else "web"
)
stream_result.request_id = request_id
stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}"
stream_result.filesystem_mode = fs_mode
stream_result.client_platform = fs_platform
_log_file_contract("turn_start", stream_result)
_perf_log.info(
"[stream_new_chat] filesystem_mode=%s client_platform=%s",
fs_mode,
fs_platform,
)
log_system_snapshot("stream_new_chat_START")
from app.services.token_tracking_service import start_turn
@ -1329,6 +1526,7 @@ async def stream_new_chat(
thread_visibility=visibility,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
filesystem_selection=filesystem_selection,
)
_perf_log.info(
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
@ -1435,6 +1633,8 @@ async def stream_new_chat(
# We will use this to simulate group chat functionality in the future
"messages": langchain_messages,
"search_space_id": search_space_id,
"request_id": request_id or "unknown",
"turn_id": stream_result.turn_id,
}
_perf_log.info(
@ -1464,6 +1664,8 @@ async def stream_new_chat(
# Configure LangGraph with thread_id for memory
# If checkpoint_id is provided, fork from that checkpoint (for edit/reload)
configurable = {"thread_id": str(chat_id)}
configurable["request_id"] = request_id or "unknown"
configurable["turn_id"] = stream_result.turn_id
if checkpoint_id:
configurable["checkpoint_id"] = checkpoint_id
@ -1871,10 +2073,26 @@ async def stream_resume_chat(
user_id: str | None = None,
llm_config_id: int = -1,
thread_visibility: ChatVisibility | None = None,
filesystem_selection: FilesystemSelection | None = None,
request_id: str | None = None,
) -> AsyncGenerator[str, None]:
streaming_service = VercelStreamingService()
stream_result = StreamResult()
_t_total = time.perf_counter()
fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud"
fs_platform = (
filesystem_selection.client_platform.value if filesystem_selection else "web"
)
stream_result.request_id = request_id
stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}"
stream_result.filesystem_mode = fs_mode
stream_result.client_platform = fs_platform
_log_file_contract("turn_start", stream_result)
_perf_log.info(
"[stream_resume] filesystem_mode=%s client_platform=%s",
fs_mode,
fs_platform,
)
from app.services.token_tracking_service import start_turn
@ -1991,6 +2209,7 @@ async def stream_resume_chat(
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
)
_perf_log.info(
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
@ -2009,7 +2228,11 @@ async def stream_resume_chat(
from langgraph.types import Command
config = {
"configurable": {"thread_id": str(chat_id)},
"configurable": {
"thread_id": str(chat_id),
"request_id": request_id or "unknown",
"turn_id": stream_result.turn_id,
},
"recursion_limit": 80,
}

View file

@ -1,75 +1,29 @@
"""
Connector indexers module for background tasks.
This module provides a collection of connector indexers for different platforms
and services. Each indexer is responsible for handling the indexing of content
from a specific connector type.
Available indexers:
- Slack: Index messages from Slack channels
- Notion: Index pages from Notion workspaces
- GitHub: Index repositories and files from GitHub
- Linear: Index issues from Linear workspaces
- Jira: Index issues from Jira projects
- Confluence: Index pages from Confluence spaces
- BookStack: Index pages from BookStack wiki instances
- Discord: Index messages from Discord servers
- ClickUp: Index tasks from ClickUp workspaces
- Google Gmail: Index messages from Google Gmail
- Google Calendar: Index events from Google Calendar
- Luma: Index events from Luma
- Webcrawler: Index crawled URLs
- Elasticsearch: Index documents from Elasticsearch instances
Each indexer handles content indexing from a specific connector type.
Live connectors (Slack, Linear, Jira, ClickUp, Airtable, Discord, Teams,
Luma) now use real-time agent tools instead of background indexing.
"""
# Communication platforms
# Calendar and scheduling
from .airtable_indexer import index_airtable_records
from .bookstack_indexer import index_bookstack_pages
# Note: composio_indexer is imported directly in connector_tasks.py to avoid circular imports
from .clickup_indexer import index_clickup_tasks
from .confluence_indexer import index_confluence_pages
from .discord_indexer import index_discord_messages
# Development platforms
from .elasticsearch_indexer import index_elasticsearch_documents
from .github_indexer import index_github_repos
from .google_calendar_indexer import index_google_calendar_events
from .google_drive_indexer import index_google_drive_files
from .google_gmail_indexer import index_google_gmail_messages
from .jira_indexer import index_jira_issues
# Issue tracking and project management
from .linear_indexer import index_linear_issues
# Documentation and knowledge management
from .luma_indexer import index_luma_events
from .notion_indexer import index_notion_pages
from .slack_indexer import index_slack_messages
from .webcrawler_indexer import index_crawled_urls
__all__ = [ # noqa: RUF022
"index_airtable_records",
__all__ = [
"index_bookstack_pages",
# "index_composio_connector", # Imported directly in connector_tasks.py to avoid circular imports
"index_clickup_tasks",
"index_confluence_pages",
"index_discord_messages",
# Development platforms
"index_crawled_urls",
"index_elasticsearch_documents",
"index_github_repos",
# Calendar and scheduling
"index_google_calendar_events",
"index_google_drive_files",
"index_luma_events",
"index_jira_issues",
# Issue tracking and project management
"index_linear_issues",
# Documentation and knowledge management
"index_notion_pages",
"index_crawled_urls",
# Communication platforms
"index_slack_messages",
"index_google_gmail_messages",
"index_notion_pages",
]

View file

@ -0,0 +1,129 @@
"""Async retry decorators for connector API calls, built on tenacity."""
from __future__ import annotations
import logging
from collections.abc import Callable
from typing import TypeVar
import httpx
from tenacity import (
before_sleep_log,
retry,
retry_if_exception,
stop_after_attempt,
stop_after_delay,
wait_exponential_jitter,
)
from app.connectors.exceptions import (
ConnectorAPIError,
ConnectorAuthError,
ConnectorError,
ConnectorRateLimitError,
ConnectorTimeoutError,
)
logger = logging.getLogger(__name__)
F = TypeVar("F", bound=Callable)
def _is_retryable(exc: BaseException) -> bool:
if isinstance(exc, ConnectorError):
return exc.retryable
if isinstance(exc, (httpx.TimeoutException, httpx.ConnectError)):
return True
return False
def build_retry(
*,
max_attempts: int = 4,
max_delay: float = 60.0,
initial_delay: float = 1.0,
total_timeout: float = 180.0,
service: str = "",
) -> Callable:
"""Configurable tenacity ``@retry`` decorator with exponential backoff + jitter."""
_logger = logging.getLogger(f"connector.retry.{service}") if service else logger
return retry(
retry=retry_if_exception(_is_retryable),
stop=(stop_after_attempt(max_attempts) | stop_after_delay(total_timeout)),
wait=wait_exponential_jitter(initial=initial_delay, max=max_delay),
reraise=True,
before_sleep=before_sleep_log(_logger, logging.WARNING),
)
def retry_on_transient(
*,
service: str = "",
max_attempts: int = 4,
) -> Callable:
"""Shorthand: retry up to *max_attempts* on rate-limits, timeouts, and 5xx."""
return build_retry(max_attempts=max_attempts, service=service)
def raise_for_status(
response: httpx.Response,
*,
service: str = "",
) -> None:
"""Map non-2xx httpx responses to the appropriate ``ConnectorError``."""
if response.is_success:
return
status = response.status_code
try:
body = response.json()
except Exception:
body = response.text[:500] if response.text else None
if status == 429:
retry_after_raw = response.headers.get("Retry-After")
retry_after: float | None = None
if retry_after_raw:
try:
retry_after = float(retry_after_raw)
except (ValueError, TypeError):
pass
raise ConnectorRateLimitError(
f"{service} rate limited (429)",
service=service,
retry_after=retry_after,
response_body=body,
)
if status in (401, 403):
raise ConnectorAuthError(
f"{service} authentication failed ({status})",
service=service,
status_code=status,
response_body=body,
)
if status == 504:
raise ConnectorTimeoutError(
f"{service} gateway timeout (504)",
service=service,
status_code=status,
response_body=body,
)
if status >= 500:
raise ConnectorAPIError(
f"{service} server error ({status})",
service=service,
status_code=status,
response_body=body,
)
raise ConnectorAPIError(
f"{service} request failed ({status})",
service=service,
status_code=status,
response_body=body,
)
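A minimal sketch of how a connector client would combine these helpers; the endpoint, service name, and client are illustrative:

import httpx

@retry_on_transient(service="example", max_attempts=4)
async def _fetch_items_sketch(client: httpx.AsyncClient, url: str) -> dict:
    # 429/5xx/timeouts surface as retryable ConnectorError subclasses via
    # raise_for_status, so tenacity retries them with exponential backoff.
    resp = await client.get(url)
    raise_for_status(resp, service="example")
    return resp.json()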

View file

@ -39,7 +39,7 @@ BASE_NAME_FOR_TYPE = {
def get_base_name_for_type(connector_type: SearchSourceConnectorType) -> str:
"""Get a friendly display name for a connector type."""
return BASE_NAME_FOR_TYPE.get(
connector_type, connector_type.replace("_", " ").title()
connector_type, connector_type.value.replace("_", " ").title()
)
@ -231,9 +231,11 @@ async def generate_unique_connector_name(
base = get_base_name_for_type(connector_type)
if identifier:
return f"{base} - {identifier}"
name = f"{base} - {identifier}"
return await ensure_unique_connector_name(
session, name, search_space_id, user_id,
)
# Fallback: use counter for uniqueness
count = await count_connectors_of_type(
session, connector_type, search_space_id, user_id
)

View file

@ -18,19 +18,9 @@ logger = logging.getLogger(__name__)
# Mapping of connector types to their corresponding Celery task names
CONNECTOR_TASK_MAP = {
SearchSourceConnectorType.SLACK_CONNECTOR: "index_slack_messages",
SearchSourceConnectorType.TEAMS_CONNECTOR: "index_teams_messages",
SearchSourceConnectorType.NOTION_CONNECTOR: "index_notion_pages",
SearchSourceConnectorType.GITHUB_CONNECTOR: "index_github_repos",
SearchSourceConnectorType.LINEAR_CONNECTOR: "index_linear_issues",
SearchSourceConnectorType.JIRA_CONNECTOR: "index_jira_issues",
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "index_confluence_pages",
SearchSourceConnectorType.CLICKUP_CONNECTOR: "index_clickup_tasks",
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: "index_google_calendar_events",
SearchSourceConnectorType.AIRTABLE_CONNECTOR: "index_airtable_records",
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: "index_google_gmail_messages",
SearchSourceConnectorType.DISCORD_CONNECTOR: "index_discord_messages",
SearchSourceConnectorType.LUMA_CONNECTOR: "index_luma_events",
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: "index_elasticsearch_documents",
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: "index_crawled_urls",
SearchSourceConnectorType.BOOKSTACK_CONNECTOR: "index_bookstack_pages",
@ -83,39 +73,19 @@ def create_periodic_schedule(
f"(frequency: {frequency_minutes} minutes). Triggering first run..."
)
# Import all indexing tasks
from app.tasks.celery_tasks.connector_tasks import (
index_airtable_records_task,
index_bookstack_pages_task,
index_clickup_tasks_task,
index_confluence_pages_task,
index_crawled_urls_task,
index_discord_messages_task,
index_elasticsearch_documents_task,
index_github_repos_task,
index_google_calendar_events_task,
index_google_gmail_messages_task,
index_jira_issues_task,
index_linear_issues_task,
index_luma_events_task,
index_notion_pages_task,
index_slack_messages_task,
)
# Map connector type to task
task_map = {
SearchSourceConnectorType.SLACK_CONNECTOR: index_slack_messages_task,
SearchSourceConnectorType.NOTION_CONNECTOR: index_notion_pages_task,
SearchSourceConnectorType.GITHUB_CONNECTOR: index_github_repos_task,
SearchSourceConnectorType.LINEAR_CONNECTOR: index_linear_issues_task,
SearchSourceConnectorType.JIRA_CONNECTOR: index_jira_issues_task,
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: index_confluence_pages_task,
SearchSourceConnectorType.CLICKUP_CONNECTOR: index_clickup_tasks_task,
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task,
SearchSourceConnectorType.AIRTABLE_CONNECTOR: index_airtable_records_task,
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: index_google_gmail_messages_task,
SearchSourceConnectorType.DISCORD_CONNECTOR: index_discord_messages_task,
SearchSourceConnectorType.LUMA_CONNECTOR: index_luma_events_task,
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task,
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task,
SearchSourceConnectorType.BOOKSTACK_CONNECTOR: index_bookstack_pages_task,

View file

@ -74,7 +74,7 @@ dependencies = [
"deepagents>=0.4.12",
"stripe>=15.0.0",
"azure-ai-documentintelligence>=1.0.2",
"litellm>=1.83.0",
"litellm>=1.83.4",
"langchain-litellm>=0.6.4",
]

View file

@ -0,0 +1,213 @@
"""Unit tests for resume page-limit helpers and enforcement flow."""
import io
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pypdf
import pytest
from app.agents.new_chat.tools import resume as resume_tool
pytestmark = pytest.mark.unit
class _FakeReport:
_next_id = 1000
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
self.id = None
class _FakeSession:
def __init__(self, parent_report=None):
self.parent_report = parent_report
self.added: list[_FakeReport] = []
async def get(self, _model, _id):
return self.parent_report
def add(self, report):
self.added.append(report)
async def commit(self):
for report in self.added:
if getattr(report, "id", None) is None:
report.id = _FakeReport._next_id
_FakeReport._next_id += 1
async def refresh(self, _report):
return None
class _SessionContext:
def __init__(self, session):
self.session = session
async def __aenter__(self):
return self.session
async def __aexit__(self, exc_type, exc, tb):
return False
class _SessionFactory:
def __init__(self, sessions):
self._sessions = list(sessions)
def __call__(self):
if not self._sessions:
raise RuntimeError("No fake sessions left")
return _SessionContext(self._sessions.pop(0))
def _make_pdf_with_pages(page_count: int) -> bytes:
writer = pypdf.PdfWriter()
for _ in range(page_count):
writer.add_blank_page(width=612, height=792)
output = io.BytesIO()
writer.write(output)
return output.getvalue()
def test_count_pdf_pages_reads_compiled_bytes() -> None:
pdf_bytes = _make_pdf_with_pages(2)
assert resume_tool._count_pdf_pages(pdf_bytes) == 2
def test_validate_max_pages_rejects_out_of_range() -> None:
with pytest.raises(ValueError):
resume_tool._validate_max_pages(0)
with pytest.raises(ValueError):
resume_tool._validate_max_pages(6)
@pytest.mark.asyncio
async def test_generate_resume_defaults_to_one_page_target(monkeypatch) -> None:
read_session = _FakeSession()
write_session = _FakeSession()
session_factory = _SessionFactory([read_session, write_session])
monkeypatch.setattr(resume_tool, "shielded_async_session", session_factory)
monkeypatch.setattr(resume_tool, "Report", _FakeReport)
prompts: list[str] = []
async def _llm_invoke(messages):
prompts.append(messages[0].content)
return SimpleNamespace(content="= Jane Doe\n== Experience\n- Built systems")
llm = SimpleNamespace(ainvoke=AsyncMock(side_effect=_llm_invoke))
monkeypatch.setattr(
resume_tool,
"get_document_summary_llm",
AsyncMock(return_value=llm),
)
monkeypatch.setattr(resume_tool, "_compile_typst", lambda _source: b"pdf")
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: 1)
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
result = await tool.ainvoke({"user_info": "Jane Doe experience"})
assert result["status"] == "ready"
assert prompts
assert "**Target Maximum Pages:** 1" in prompts[0]
@pytest.mark.asyncio
async def test_generate_resume_compresses_when_over_limit(monkeypatch) -> None:
read_session = _FakeSession()
write_session = _FakeSession()
session_factory = _SessionFactory([read_session, write_session])
monkeypatch.setattr(resume_tool, "shielded_async_session", session_factory)
monkeypatch.setattr(resume_tool, "Report", _FakeReport)
responses = [
SimpleNamespace(content="= Jane Doe\n== Experience\n- Detailed bullet 1"),
SimpleNamespace(content="= Jane Doe\n== Experience\n- Condensed bullet"),
]
llm = SimpleNamespace(ainvoke=AsyncMock(side_effect=responses))
monkeypatch.setattr(
resume_tool,
"get_document_summary_llm",
AsyncMock(return_value=llm),
)
monkeypatch.setattr(resume_tool, "_compile_typst", lambda _source: b"pdf")
page_counts = iter([2, 1])
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts))
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1})
assert result["status"] == "ready"
assert write_session.added, "Expected successful report write"
metadata = write_session.added[0].report_metadata
assert metadata["target_max_pages"] == 1
assert metadata["actual_page_count"] == 1
assert metadata["compression_attempts"] == 1
assert metadata["page_limit_enforced"] is True
@pytest.mark.asyncio
async def test_generate_resume_returns_ready_when_target_not_met(monkeypatch) -> None:
read_session = _FakeSession()
write_session = _FakeSession()
session_factory = _SessionFactory([read_session, write_session])
monkeypatch.setattr(resume_tool, "shielded_async_session", session_factory)
monkeypatch.setattr(resume_tool, "Report", _FakeReport)
responses = [
SimpleNamespace(content="= Jane Doe\n== Experience\n- Long detail"),
SimpleNamespace(content="= Jane Doe\n== Experience\n- Still long"),
SimpleNamespace(content="= Jane Doe\n== Experience\n- Still too long"),
]
llm = SimpleNamespace(ainvoke=AsyncMock(side_effect=responses))
monkeypatch.setattr(
resume_tool,
"get_document_summary_llm",
AsyncMock(return_value=llm),
)
monkeypatch.setattr(resume_tool, "_compile_typst", lambda _source: b"pdf")
page_counts = iter([3, 3, 2])
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts))
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1})
assert result["status"] == "ready"
assert "could not fit the target" in (result["message"] or "").lower()
metadata = write_session.added[0].report_metadata
assert metadata["target_page_met"] is False
assert metadata["actual_page_count"] == 2
@pytest.mark.asyncio
async def test_generate_resume_fails_when_hard_limit_exceeded(monkeypatch) -> None:
read_session = _FakeSession()
failed_session = _FakeSession()
session_factory = _SessionFactory([read_session, failed_session])
monkeypatch.setattr(resume_tool, "shielded_async_session", session_factory)
monkeypatch.setattr(resume_tool, "Report", _FakeReport)
responses = [
SimpleNamespace(content="= Jane Doe\n== Experience\n- Long detail"),
SimpleNamespace(content="= Jane Doe\n== Experience\n- Still long"),
SimpleNamespace(content="= Jane Doe\n== Experience\n- Still too long"),
]
llm = SimpleNamespace(ainvoke=AsyncMock(side_effect=responses))
monkeypatch.setattr(
resume_tool,
"get_document_summary_llm",
AsyncMock(return_value=llm),
)
monkeypatch.setattr(resume_tool, "_compile_typst", lambda _source: b"pdf")
page_counts = iter([7, 6, 6])
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts))
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1})
assert result["status"] == "failed"
assert "hard page limit" in (result["error"] or "").lower()
assert failed_session.added, "Expected failed report persistence"

View file

@@ -0,0 +1,214 @@
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from app.agents.new_chat.middleware.file_intent import (
FileIntentMiddleware,
FileOperationIntent,
_fallback_path,
)
pytestmark = pytest.mark.unit
class _FakeLLM:
def __init__(self, response_text: str):
self._response_text = response_text
async def ainvoke(self, *_args, **_kwargs):
return AIMessage(content=self._response_text)
@pytest.mark.asyncio
async def test_file_write_intent_injects_contract_message():
llm = _FakeLLM(
'{"intent":"file_write","confidence":0.93,"suggested_filename":"ideas.md"}'
)
middleware = FileIntentMiddleware(llm=llm)
state = {
"messages": [HumanMessage(content="Create another random note for me")],
"turn_id": "123:456",
}
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
assert result is not None
contract = result["file_operation_contract"]
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
assert contract["suggested_path"] == "/ideas.md"
assert contract["turn_id"] == "123:456"
assert any(
"file_operation_contract" in str(msg.content)
for msg in result["messages"]
if hasattr(msg, "content")
)
@pytest.mark.asyncio
async def test_non_write_intent_does_not_inject_contract_message():
llm = _FakeLLM(
'{"intent":"file_read","confidence":0.88,"suggested_filename":null}'
)
middleware = FileIntentMiddleware(llm=llm)
original_messages = [HumanMessage(content="Read /notes.md")]
state = {"messages": original_messages, "turn_id": "abc:def"}
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
assert result is not None
assert result["file_operation_contract"]["intent"] == FileOperationIntent.FILE_READ.value
assert "messages" not in result
@pytest.mark.asyncio
async def test_file_write_null_filename_uses_semantic_default_path():
llm = _FakeLLM(
'{"intent":"file_write","confidence":0.74,"suggested_filename":null}'
)
middleware = FileIntentMiddleware(llm=llm)
state = {
"messages": [HumanMessage(content="create a random markdown file")],
"turn_id": "turn:1",
}
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
assert result is not None
contract = result["file_operation_contract"]
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
assert contract["suggested_path"] == "/notes.md"
@pytest.mark.asyncio
async def test_file_write_null_filename_infers_json_extension():
llm = _FakeLLM(
'{"intent":"file_write","confidence":0.71,"suggested_filename":null}'
)
middleware = FileIntentMiddleware(llm=llm)
state = {
"messages": [HumanMessage(content="create a sample json config file")],
"turn_id": "turn:2",
}
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
assert result is not None
contract = result["file_operation_contract"]
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
assert contract["suggested_path"] == "/notes.json"
@pytest.mark.asyncio
async def test_file_write_txt_suggestion_is_normalized_to_markdown():
llm = _FakeLLM(
'{"intent":"file_write","confidence":0.82,"suggested_filename":"random.txt"}'
)
middleware = FileIntentMiddleware(llm=llm)
state = {
"messages": [HumanMessage(content="create a random file")],
"turn_id": "turn:3",
}
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
assert result is not None
contract = result["file_operation_contract"]
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
assert contract["suggested_path"] == "/random.md"
@pytest.mark.asyncio
async def test_file_write_with_suggested_directory_preserves_folder():
llm = _FakeLLM(
'{"intent":"file_write","confidence":0.86,"suggested_filename":"random.md","suggested_directory":"pc backups","suggested_path":null}'
)
middleware = FileIntentMiddleware(llm=llm)
state = {
"messages": [HumanMessage(content="create a random file in pc backups folder")],
"turn_id": "turn:4",
}
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
assert result is not None
contract = result["file_operation_contract"]
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
assert contract["suggested_path"] == "/pc_backups/random.md"
@pytest.mark.asyncio
async def test_file_write_with_suggested_path_takes_precedence():
llm = _FakeLLM(
'{"intent":"file_write","confidence":0.9,"suggested_filename":"ignored.md","suggested_directory":"docs","suggested_path":"/reports/q2/summary.md"}'
)
middleware = FileIntentMiddleware(llm=llm)
state = {
"messages": [HumanMessage(content="create report")],
"turn_id": "turn:5",
}
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
assert result is not None
contract = result["file_operation_contract"]
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
assert contract["suggested_path"] == "/reports/q2/summary.md"
@pytest.mark.asyncio
async def test_file_write_infers_directory_from_user_text_when_missing():
llm = _FakeLLM(
'{"intent":"file_write","confidence":0.83,"suggested_filename":"random.md","suggested_directory":null,"suggested_path":null}'
)
middleware = FileIntentMiddleware(llm=llm)
state = {
"messages": [HumanMessage(content="create a random file in pc backups folder")],
"turn_id": "turn:6",
}
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
assert result is not None
contract = result["file_operation_contract"]
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
assert contract["suggested_path"] == "/pc_backups/random.md"
def test_fallback_path_normalizes_windows_slashes() -> None:
resolved = _fallback_path(
suggested_filename="summary.md",
suggested_path=r"\reports\q2\summary.md",
user_text="create report",
)
assert resolved == "/reports/q2/summary.md"
def test_fallback_path_normalizes_windows_drive_path() -> None:
resolved = _fallback_path(
suggested_filename=None,
suggested_path=r"C:\Users\anish\notes\todo.md",
user_text="create note",
)
assert resolved == "/C/Users/anish/notes/todo.md"
def test_fallback_path_normalizes_mixed_separators_and_duplicate_slashes() -> None:
resolved = _fallback_path(
suggested_filename="summary.md",
suggested_path=r"\\reports\\q2//summary.md",
user_text="create report",
)
assert resolved == "/reports/q2/summary.md"
def test_fallback_path_keeps_posix_style_absolute_path_for_linux_and_macos() -> None:
resolved = _fallback_path(
suggested_filename=None,
suggested_path="/var/log/surfsense/notes.md",
user_text="create note",
)
assert resolved == "/var/log/surfsense/notes.md"

View file

@@ -0,0 +1,59 @@
from pathlib import Path
import pytest
from app.agents.new_chat.filesystem_backends import build_backend_resolver
from app.agents.new_chat.filesystem_selection import (
ClientPlatform,
FilesystemMode,
FilesystemSelection,
LocalFilesystemMount,
)
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
MultiRootLocalFolderBackend,
)
pytestmark = pytest.mark.unit
class _RuntimeStub:
state = {"files": {}}
def test_backend_resolver_returns_multi_root_backend_for_single_root(tmp_path: Path):
selection = FilesystemSelection(
mode=FilesystemMode.DESKTOP_LOCAL_FOLDER,
client_platform=ClientPlatform.DESKTOP,
local_mounts=(LocalFilesystemMount(mount_id="tmp", root_path=str(tmp_path)),),
)
resolver = build_backend_resolver(selection)
backend = resolver(_RuntimeStub())
assert isinstance(backend, MultiRootLocalFolderBackend)
def test_backend_resolver_uses_cloud_mode_by_default():
resolver = build_backend_resolver(FilesystemSelection())
backend = resolver(_RuntimeStub())
# StateBackend class name check keeps this test decoupled
# from internal deepagents runtime class identity.
assert backend.__class__.__name__ == "StateBackend"
def test_backend_resolver_returns_multi_root_backend_for_multiple_roots(tmp_path: Path):
root_one = tmp_path / "resume"
root_two = tmp_path / "notes"
root_one.mkdir()
root_two.mkdir()
selection = FilesystemSelection(
mode=FilesystemMode.DESKTOP_LOCAL_FOLDER,
client_platform=ClientPlatform.DESKTOP,
local_mounts=(
LocalFilesystemMount(mount_id="resume", root_path=str(root_one)),
LocalFilesystemMount(mount_id="notes", root_path=str(root_two)),
),
)
resolver = build_backend_resolver(selection)
backend = resolver(_RuntimeStub())
assert isinstance(backend, MultiRootLocalFolderBackend)

View file

@@ -0,0 +1,164 @@
from pathlib import Path
import pytest
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
MultiRootLocalFolderBackend,
)
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
pytestmark = pytest.mark.unit
class _BackendWithRawRead:
def __init__(self, content: str) -> None:
self._content = content
def read(self, file_path: str, offset: int = 0, limit: int = 200000) -> str:
del file_path, offset, limit
return " 1\tline1\n 2\tline2"
async def aread(self, file_path: str, offset: int = 0, limit: int = 200000) -> str:
return self.read(file_path, offset, limit)
def read_raw(self, file_path: str) -> str:
del file_path
return self._content
async def aread_raw(self, file_path: str) -> str:
return self.read_raw(file_path)
class _RuntimeNoSuggestedPath:
state = {"file_operation_contract": {}}
def test_verify_written_content_prefers_raw_sync() -> None:
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
expected = "line1\nline2"
backend = _BackendWithRawRead(expected)
verify_error = middleware._verify_written_content_sync(
backend=backend,
path="/note.md",
expected_content=expected,
)
assert verify_error is None
def test_contract_suggested_path_falls_back_to_notes_md() -> None:
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
middleware._filesystem_mode = FilesystemMode.CLOUD
suggested = middleware._get_contract_suggested_path(_RuntimeNoSuggestedPath()) # type: ignore[arg-type]
assert suggested == "/notes.md"
@pytest.mark.asyncio
async def test_verify_written_content_prefers_raw_async() -> None:
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
expected = "line1\nline2"
backend = _BackendWithRawRead(expected)
verify_error = await middleware._verify_written_content_async(
backend=backend,
path="/note.md",
expected_content=expected,
)
assert verify_error is None
def test_normalize_local_mount_path_prefixes_default_mount(tmp_path: Path) -> None:
root = tmp_path / "PC Backups"
root.mkdir()
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
runtime = _RuntimeNoSuggestedPath()
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
resolved = middleware._normalize_local_mount_path("/random-note.md", runtime) # type: ignore[arg-type]
assert resolved == "/pc_backups/random-note.md"
def test_normalize_local_mount_path_keeps_explicit_mount(tmp_path: Path) -> None:
root = tmp_path / "PC Backups"
root.mkdir()
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
runtime = _RuntimeNoSuggestedPath()
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type]
"/pc_backups/notes/random-note.md",
runtime,
)
assert resolved == "/pc_backups/notes/random-note.md"
def test_normalize_local_mount_path_windows_backslashes(tmp_path: Path) -> None:
root = tmp_path / "PC Backups"
root.mkdir()
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
runtime = _RuntimeNoSuggestedPath()
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type]
r"\notes\random-note.md",
runtime,
)
assert resolved == "/pc_backups/notes/random-note.md"
def test_normalize_local_mount_path_normalizes_mixed_separators(tmp_path: Path) -> None:
root = tmp_path / "PC Backups"
root.mkdir()
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
runtime = _RuntimeNoSuggestedPath()
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type]
r"\\notes//nested\\random-note.md",
runtime,
)
assert resolved == "/pc_backups/notes/nested/random-note.md"
def test_normalize_local_mount_path_keeps_explicit_mount_with_backslashes(
tmp_path: Path,
) -> None:
root = tmp_path / "PC Backups"
root.mkdir()
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
runtime = _RuntimeNoSuggestedPath()
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type]
r"\pc_backups\notes\random-note.md",
runtime,
)
assert resolved == "/pc_backups/notes/random-note.md"
def test_normalize_local_mount_path_prefixes_posix_absolute_path_for_linux_and_macos(
tmp_path: Path,
) -> None:
root = tmp_path / "PC Backups"
root.mkdir()
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
runtime = _RuntimeNoSuggestedPath()
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
resolved = middleware._normalize_local_mount_path("/var/log/app.log", runtime) # type: ignore[arg-type]
assert resolved == "/pc_backups/var/log/app.log"

View file

@@ -0,0 +1,59 @@
from pathlib import Path
import pytest
from app.agents.new_chat.middleware.local_folder_backend import LocalFolderBackend
pytestmark = pytest.mark.unit
def test_local_backend_write_read_edit_roundtrip(tmp_path: Path):
backend = LocalFolderBackend(str(tmp_path))
write = backend.write("/notes/test.md", "line1\nline2")
assert write.error is None
assert write.path == "/notes/test.md"
read = backend.read("/notes/test.md", offset=0, limit=20)
assert "line1" in read
assert "line2" in read
edit = backend.edit("/notes/test.md", "line2", "updated")
assert edit.error is None
assert edit.occurrences == 1
read_after = backend.read("/notes/test.md", offset=0, limit=20)
assert "updated" in read_after
def test_local_backend_blocks_path_escape(tmp_path: Path):
backend = LocalFolderBackend(str(tmp_path))
result = backend.write("/../../etc/passwd", "bad")
assert result.error is not None
assert "Invalid path" in result.error
def test_local_backend_glob_and_grep(tmp_path: Path):
backend = LocalFolderBackend(str(tmp_path))
(tmp_path / "docs").mkdir()
(tmp_path / "docs" / "a.txt").write_text("hello world\n")
(tmp_path / "docs" / "b.md").write_text("hello markdown\n")
infos = backend.glob_info("**/*.txt", "/docs")
paths = {info["path"] for info in infos}
assert "/docs/a.txt" in paths
grep = backend.grep_raw("hello", "/docs", "*.md")
assert isinstance(grep, list)
assert any(match["path"] == "/docs/b.md" for match in grep)
def test_local_backend_read_raw_returns_exact_content(tmp_path: Path):
backend = LocalFolderBackend(str(tmp_path))
expected = "# Title\n\nline 1\nline 2\n"
write = backend.write("/notes/raw.md", expected)
assert write.error is None
raw = backend.read_raw("/notes/raw.md")
assert raw == expected

View file

@@ -0,0 +1,28 @@
from pathlib import Path
import pytest
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
MultiRootLocalFolderBackend,
)
pytestmark = pytest.mark.unit
def test_mount_ids_preserve_client_mapping_order(tmp_path: Path) -> None:
root_one = tmp_path / "PC Backups"
root_two = tmp_path / "pc_backups"
root_three = tmp_path / "notes@2026"
root_one.mkdir()
root_two.mkdir()
root_three.mkdir()
backend = MultiRootLocalFolderBackend(
(
("pc_backups", str(root_one)),
("pc_backups_2", str(root_two)),
("notes_2026", str(root_three)),
)
)
assert backend.list_mounts() == ("pc_backups", "pc_backups_2", "notes_2026")

View file

@@ -0,0 +1,48 @@
import pytest
from app.tasks.chat.stream_new_chat import (
StreamResult,
_contract_enforcement_active,
_evaluate_file_contract_outcome,
_tool_output_has_error,
)
pytestmark = pytest.mark.unit
def test_tool_output_error_detection():
assert _tool_output_has_error("Error: failed to write file")
assert _tool_output_has_error({"error": "boom"})
assert _tool_output_has_error({"result": "Error: disk is full"})
assert not _tool_output_has_error({"result": "Updated file /notes.md"})
def test_file_write_contract_outcome_reasons():
result = StreamResult(intent_detected="file_write")
passed, reason = _evaluate_file_contract_outcome(result)
assert not passed
assert reason == "no_write_attempt"
result.write_attempted = True
passed, reason = _evaluate_file_contract_outcome(result)
assert not passed
assert reason == "write_failed"
result.write_succeeded = True
passed, reason = _evaluate_file_contract_outcome(result)
assert not passed
assert reason == "verification_failed"
result.verification_succeeded = True
passed, reason = _evaluate_file_contract_outcome(result)
assert passed
assert reason == ""
def test_contract_enforcement_local_only():
result = StreamResult(filesystem_mode="desktop_local_folder")
assert _contract_enforcement_active(result)
result.filesystem_mode = "cloud"
assert not _contract_enforcement_active(result)

View file

@@ -8070,7 +8070,7 @@ requires-dist = [
{ name = "langgraph", specifier = ">=1.1.3" },
{ name = "langgraph-checkpoint-postgres", specifier = ">=3.0.2" },
{ name = "linkup-sdk", specifier = ">=0.2.4" },
{ name = "litellm", specifier = ">=1.83.0" },
{ name = "litellm", specifier = ">=1.83.4" },
{ name = "llama-cloud-services", specifier = ">=0.6.25" },
{ name = "markdown", specifier = ">=3.7" },
{ name = "markdownify", specifier = ">=0.14.1" },

View file

@@ -34,6 +34,8 @@ export const IPC_CHANNELS = {
FOLDER_SYNC_SEED_MTIMES: 'folder-sync:seed-mtimes',
BROWSE_FILES: 'browse:files',
READ_LOCAL_FILES: 'browse:read-local-files',
READ_AGENT_LOCAL_FILE_TEXT: 'agent-filesystem:read-local-file-text',
WRITE_AGENT_LOCAL_FILE_TEXT: 'agent-filesystem:write-local-file-text',
// Auth token sync across windows
GET_AUTH_TOKENS: 'auth:get-tokens',
SET_AUTH_TOKENS: 'auth:set-tokens',
@@ -51,4 +53,9 @@ export const IPC_CHANNELS = {
ANALYTICS_RESET: 'analytics:reset',
ANALYTICS_CAPTURE: 'analytics:capture',
ANALYTICS_GET_CONTEXT: 'analytics:get-context',
// Agent filesystem mode
AGENT_FILESYSTEM_GET_SETTINGS: 'agent-filesystem:get-settings',
AGENT_FILESYSTEM_GET_MOUNTS: 'agent-filesystem:get-mounts',
AGENT_FILESYSTEM_SET_SETTINGS: 'agent-filesystem:set-settings',
AGENT_FILESYSTEM_PICK_ROOT: 'agent-filesystem:pick-root',
} as const;

View file

@@ -36,6 +36,14 @@ import {
resetUser as analyticsReset,
trackEvent,
} from '../modules/analytics';
import {
readAgentLocalFileText,
writeAgentLocalFileText,
getAgentFilesystemMounts,
getAgentFilesystemSettings,
pickAgentFilesystemRoot,
setAgentFilesystemSettings,
} from '../modules/agent-filesystem';
let authTokens: { bearer: string; refresh: string } | null = null;
@@ -118,6 +126,29 @@ export function registerIpcHandlers(): void {
readLocalFiles(paths)
);
ipcMain.handle(IPC_CHANNELS.READ_AGENT_LOCAL_FILE_TEXT, async (_event, virtualPath: string) => {
try {
const result = await readAgentLocalFileText(virtualPath);
return { ok: true, path: result.path, content: result.content };
} catch (error) {
const message = error instanceof Error ? error.message : 'Failed to read local file';
return { ok: false, path: virtualPath, error: message };
}
});
ipcMain.handle(
IPC_CHANNELS.WRITE_AGENT_LOCAL_FILE_TEXT,
async (_event, virtualPath: string, content: string) => {
try {
const result = await writeAgentLocalFileText(virtualPath, content);
return { ok: true, path: result.path };
} catch (error) {
const message = error instanceof Error ? error.message : 'Failed to write local file';
return { ok: false, path: virtualPath, error: message };
}
}
);
ipcMain.handle(IPC_CHANNELS.SET_AUTH_TOKENS, (_event, tokens: { bearer: string; refresh: string }) => {
authTokens = tokens;
});
@@ -191,4 +222,22 @@ export function registerIpcHandlers(): void {
platform: process.platform,
};
});
ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_GET_SETTINGS, () =>
getAgentFilesystemSettings()
);
ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_GET_MOUNTS, () =>
getAgentFilesystemMounts()
);
ipcMain.handle(
IPC_CHANNELS.AGENT_FILESYSTEM_SET_SETTINGS,
(_event, settings: { mode?: 'cloud' | 'desktop_local_folder'; localRootPaths?: string[] | null }) =>
setAgentFilesystemSettings(settings)
);
ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_PICK_ROOT, () =>
pickAgentFilesystemRoot()
);
}

View file

@@ -0,0 +1,254 @@
import { app, dialog } from "electron";
import { access, mkdir, readFile, writeFile } from "node:fs/promises";
import { dirname, isAbsolute, join, relative, resolve } from "node:path";
export type AgentFilesystemMode = "cloud" | "desktop_local_folder";
export interface AgentFilesystemSettings {
mode: AgentFilesystemMode;
localRootPaths: string[];
updatedAt: string;
}
const SETTINGS_FILENAME = "agent-filesystem-settings.json";
const MAX_LOCAL_ROOTS = 5;
function getSettingsPath(): string {
return join(app.getPath("userData"), SETTINGS_FILENAME);
}
function getDefaultSettings(): AgentFilesystemSettings {
return {
mode: "cloud",
localRootPaths: [],
updatedAt: new Date().toISOString(),
};
}
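// Keep at most MAX_LOCAL_ROOTS unique, non-empty root paths, preserving their original order.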
function normalizeLocalRootPaths(paths: unknown): string[] {
if (!Array.isArray(paths)) {
return [];
}
const uniquePaths = new Set<string>();
for (const path of paths) {
if (typeof path !== "string") continue;
const trimmed = path.trim();
if (!trimmed) continue;
uniquePaths.add(trimmed);
if (uniquePaths.size >= MAX_LOCAL_ROOTS) {
break;
}
}
return [...uniquePaths];
}
export async function getAgentFilesystemSettings(): Promise<AgentFilesystemSettings> {
try {
const raw = await readFile(getSettingsPath(), "utf8");
const parsed = JSON.parse(raw) as Partial<AgentFilesystemSettings>;
if (parsed.mode !== "cloud" && parsed.mode !== "desktop_local_folder") {
return getDefaultSettings();
}
return {
mode: parsed.mode,
localRootPaths: normalizeLocalRootPaths(parsed.localRootPaths),
updatedAt: parsed.updatedAt ?? new Date().toISOString(),
};
} catch {
return getDefaultSettings();
}
}
export async function setAgentFilesystemSettings(
settings: {
mode?: AgentFilesystemMode;
localRootPaths?: string[] | null;
}
): Promise<AgentFilesystemSettings> {
const current = await getAgentFilesystemSettings();
const nextMode =
settings.mode === "cloud" || settings.mode === "desktop_local_folder"
? settings.mode
: current.mode;
const next: AgentFilesystemSettings = {
mode: nextMode,
localRootPaths:
settings.localRootPaths === undefined
? current.localRootPaths
: normalizeLocalRootPaths(settings.localRootPaths ?? []),
updatedAt: new Date().toISOString(),
};
const settingsPath = getSettingsPath();
await mkdir(dirname(settingsPath), { recursive: true });
await writeFile(settingsPath, JSON.stringify(next, null, 2), "utf8");
return next;
}
export async function pickAgentFilesystemRoot(): Promise<string | null> {
const result = await dialog.showOpenDialog({
title: "Select local folder for Agent Filesystem",
properties: ["openDirectory"],
});
if (result.canceled || result.filePaths.length === 0) {
return null;
}
return result.filePaths[0] ?? null;
}
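// Map a '/'-prefixed virtual path to an absolute path under the given root, rejecting anything that escapes it.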
function resolveVirtualPath(rootPath: string, virtualPath: string): string {
if (!virtualPath.startsWith("/")) {
throw new Error("Path must start with '/'");
}
const normalizedRoot = resolve(rootPath);
const relativePath = virtualPath.replace(/^\/+/, "");
if (!relativePath) {
throw new Error("Path must refer to a file under the selected root");
}
const absolutePath = resolve(normalizedRoot, relativePath);
const rel = relative(normalizedRoot, absolutePath);
if (!rel || rel.startsWith("..") || isAbsolute(rel)) {
throw new Error("Path escapes selected local root");
}
return absolutePath;
}
function toVirtualPath(rootPath: string, absolutePath: string): string {
const normalizedRoot = resolve(rootPath);
const rel = relative(normalizedRoot, absolutePath);
if (!rel || rel.startsWith("..") || isAbsolute(rel)) {
return "/";
}
return `/${rel.replace(/\\/g, "/")}`;
}
export type LocalRootMount = {
mount: string;
rootPath: string;
};
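// Derive a filesystem-friendly mount id from a folder name (lower-cased, restricted to [a-z0-9_-]).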
function sanitizeMountName(rawMount: string): string {
const normalized = rawMount
.trim()
.toLowerCase()
.replace(/[^a-z0-9_-]+/g, "_")
.replace(/_+/g, "_")
.replace(/^[_-]+|[_-]+$/g, "");
return normalized || "root";
}
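// Build one mount per configured root, de-duplicating colliding mount names with a numeric suffix.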
function buildRootMounts(rootPaths: string[]): LocalRootMount[] {
const mounts: LocalRootMount[] = [];
const usedMounts = new Set<string>();
for (const rawRootPath of rootPaths) {
const normalizedRoot = resolve(rawRootPath);
const baseMount = sanitizeMountName(normalizedRoot.split(/[\\/]/).at(-1) || "root");
let mount = baseMount;
let suffix = 2;
while (usedMounts.has(mount)) {
mount = `${baseMount}-${suffix}`;
suffix += 1;
}
usedMounts.add(mount);
mounts.push({ mount, rootPath: normalizedRoot });
}
return mounts;
}
export async function getAgentFilesystemMounts(): Promise<LocalRootMount[]> {
const rootPaths = await resolveCurrentRootPaths();
return buildRootMounts(rootPaths);
}
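// Split a '/<mount>/<sub path>' virtual path into its mount name and root-relative sub path, validating the mount.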
function parseMountedVirtualPath(
virtualPath: string,
mounts: LocalRootMount[]
): {
mount: string;
subPath: string;
} {
if (!virtualPath.startsWith("/")) {
throw new Error("Path must start with '/'");
}
const trimmed = virtualPath.replace(/^\/+/, "");
if (!trimmed) {
throw new Error("Path must include a mounted root segment");
}
const [mount, ...rest] = trimmed.split("/");
const remainder = rest.join("/");
const directMount = mounts.find((entry) => entry.mount === mount);
if (!directMount) {
throw new Error(
`Unknown mounted root '${mount}'. Available roots: ${mounts.map((entry) => `/${entry.mount}`).join(", ")}`
);
}
if (!remainder) {
throw new Error("Path must include a file path under the mounted root");
}
return { mount, subPath: `/${remainder}` };
}
function findMountByName(mounts: LocalRootMount[], mountName: string): LocalRootMount | undefined {
return mounts.find((entry) => entry.mount === mountName);
}
function toMountedVirtualPath(mount: string, rootPath: string, absolutePath: string): string {
const relativePath = toVirtualPath(rootPath, absolutePath);
return `/${mount}${relativePath}`;
}
async function resolveCurrentRootPaths(): Promise<string[]> {
const settings = await getAgentFilesystemSettings();
if (settings.localRootPaths.length === 0) {
throw new Error("No local filesystem roots selected");
}
return settings.localRootPaths;
}
export async function readAgentLocalFileText(
virtualPath: string
): Promise<{ path: string; content: string }> {
const rootPaths = await resolveCurrentRootPaths();
const mounts = buildRootMounts(rootPaths);
const { mount, subPath } = parseMountedVirtualPath(virtualPath, mounts);
const rootMount = findMountByName(mounts, mount);
if (!rootMount) {
throw new Error(
`Unknown mounted root '${mount}'. Available roots: ${mounts.map((entry) => `/${entry.mount}`).join(", ")}`
);
}
const absolutePath = resolveVirtualPath(rootMount.rootPath, subPath);
const content = await readFile(absolutePath, "utf8");
return {
path: toMountedVirtualPath(rootMount.mount, rootMount.rootPath, absolutePath),
content,
};
}
export async function writeAgentLocalFileText(
virtualPath: string,
content: string
): Promise<{ path: string }> {
const rootPaths = await resolveCurrentRootPaths();
const mounts = buildRootMounts(rootPaths);
const { mount, subPath } = parseMountedVirtualPath(virtualPath, mounts);
const rootMount = findMountByName(mounts, mount);
if (!rootMount) {
throw new Error(
`Unknown mounted root '${mount}'. Available roots: ${mounts.map((entry) => `/${entry.mount}`).join(", ")}`
);
}
let selectedAbsolutePath = resolveVirtualPath(rootMount.rootPath, subPath);
try {
await access(selectedAbsolutePath);
} catch {
// New files are created under the selected mounted root.
}
await mkdir(dirname(selectedAbsolutePath), { recursive: true });
await writeFile(selectedAbsolutePath, content, "utf8");
return {
path: toMountedVirtualPath(rootMount.mount, rootMount.rootPath, selectedAbsolutePath),
};
}
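For orientation, here is a minimal renderer-side sketch of how these helpers can be driven through the `electronAPI` bridge exposed by the preload script later in this diff. The mount name `pc_backups`, the file path, and the `syncAgentNote` wrapper are illustrative assumptions, and the loose `window as any` cast stands in for whatever typings the app actually declares.

// Illustrative only: exercises the read/write bridge exposed via contextBridge below.
async function syncAgentNote(): Promise<void> {
  const api = (window as any).electronAPI;
  const write = await api.writeAgentLocalFileText("/pc_backups/notes.md", "# Updated notes\n");
  if (!write.ok) {
    console.error("write failed:", write.error);
    return;
  }
  const readBack = await api.readAgentLocalFileText(write.path);
  if (readBack.ok) {
    console.log(`read back ${readBack.content.length} characters from ${readBack.path}`);
  }
}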

View file

@@ -71,6 +71,10 @@ contextBridge.exposeInMainWorld('electronAPI', {
// Browse files via native dialog
browseFiles: () => ipcRenderer.invoke(IPC_CHANNELS.BROWSE_FILES),
readLocalFiles: (paths: string[]) => ipcRenderer.invoke(IPC_CHANNELS.READ_LOCAL_FILES, paths),
readAgentLocalFileText: (virtualPath: string) =>
ipcRenderer.invoke(IPC_CHANNELS.READ_AGENT_LOCAL_FILE_TEXT, virtualPath),
writeAgentLocalFileText: (virtualPath: string, content: string) =>
ipcRenderer.invoke(IPC_CHANNELS.WRITE_AGENT_LOCAL_FILE_TEXT, virtualPath, content),
// Auth token sync across windows
getAuthTokens: () => ipcRenderer.invoke(IPC_CHANNELS.GET_AUTH_TOKENS),
@@ -101,4 +105,14 @@ contextBridge.exposeInMainWorld('electronAPI', {
analyticsCapture: (event: string, properties?: Record<string, unknown>) =>
ipcRenderer.invoke(IPC_CHANNELS.ANALYTICS_CAPTURE, { event, properties }),
getAnalyticsContext: () => ipcRenderer.invoke(IPC_CHANNELS.ANALYTICS_GET_CONTEXT),
// Agent filesystem mode
getAgentFilesystemSettings: () =>
ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_GET_SETTINGS),
getAgentFilesystemMounts: () =>
ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_GET_MOUNTS),
setAgentFilesystemSettings: (settings: {
mode?: "cloud" | "desktop_local_folder";
localRootPaths?: string[] | null;
}) => ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_SET_SETTINGS, settings),
pickAgentFilesystemRoot: () => ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_PICK_ROOT),
});

View file

@@ -46,6 +46,7 @@ import {
import { useChatSessionStateSync } from "@/hooks/use-chat-session-state";
import { useMessagesSync } from "@/hooks/use-messages-sync";
import { documentsApiService } from "@/lib/apis/documents-api.service";
import { getAgentFilesystemSelection } from "@/lib/agent-filesystem";
import { getBearerToken } from "@/lib/auth-utils";
import { convertToThreadMessage } from "@/lib/chat/message-utils";
import {
@@ -158,7 +159,7 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] {
/**
* Tools that should render custom UI in the chat.
*/
const TOOLS_WITH_UI = new Set([
const BASE_TOOLS_WITH_UI = new Set([
"web_search",
"generate_podcast",
"generate_report",
@@ -210,6 +211,7 @@ export default function NewChatPage() {
assistantMsgId: string;
interruptData: Record<string, unknown>;
} | null>(null);
const toolsWithUI = useMemo(() => new Set([...BASE_TOOLS_WITH_UI]), []);
// Get disabled tools from the tool toggle UI
const disabledTools = useAtomValue(disabledToolsAtom);
@@ -656,6 +658,15 @@
try {
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
const selection = await getAgentFilesystemSelection();
if (
selection.filesystem_mode === "desktop_local_folder" &&
(!selection.local_filesystem_mounts ||
selection.local_filesystem_mounts.length === 0)
) {
toast.error("Select a local folder before using Local Folder mode.");
return;
}
// Build message history for context
const messageHistory = messages
@@ -691,6 +702,9 @@
chat_id: currentThreadId,
user_query: userQuery.trim(),
search_space_id: searchSpaceId,
filesystem_mode: selection.filesystem_mode,
client_platform: selection.client_platform,
local_filesystem_mounts: selection.local_filesystem_mounts,
messages: messageHistory,
mentioned_document_ids: hasDocumentIds ? mentionedDocumentIds.document_ids : undefined,
mentioned_surfsense_doc_ids: hasSurfsenseDocIds
@@ -709,7 +723,7 @@
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
: m
)
);
@@ -724,7 +738,7 @@
break;
case "tool-input-start":
addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {});
addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {});
batcher.flush();
break;
@@ -734,7 +748,7 @@
} else {
addToolCall(
contentPartsState,
TOOLS_WITH_UI,
toolsWithUI,
parsed.toolCallId,
parsed.toolName,
parsed.input || {}
@@ -830,7 +844,7 @@
const tcId = `interrupt-${action.name}`;
addToolCall(
contentPartsState,
TOOLS_WITH_UI,
toolsWithUI,
tcId,
action.name,
action.args,
@@ -844,7 +858,7 @@
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
: m
)
);
@@ -871,7 +885,7 @@
batcher.flush();
// Skip persistence for interrupted messages -- handleResume will persist the final version
const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
if (contentParts.length > 0 && !wasInterrupted) {
try {
const savedMessage = await appendMessage(currentThreadId, {
@@ -907,10 +921,10 @@
const hasContent = contentParts.some(
(part) =>
(part.type === "text" && part.text.length > 0) ||
(part.type === "tool-call" && TOOLS_WITH_UI.has(part.toolName))
(part.type === "tool-call" && toolsWithUI.has(part.toolName))
);
if (hasContent && currentThreadId) {
const partialContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI);
try {
const savedMessage = await appendMessage(currentThreadId, {
role: "assistant",
@@ -1074,6 +1088,7 @@
try {
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
const selection = await getAgentFilesystemSelection();
const response = await fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, {
method: "POST",
headers: {
@@ -1083,6 +1098,9 @@
body: JSON.stringify({
search_space_id: searchSpaceId,
decisions,
filesystem_mode: selection.filesystem_mode,
client_platform: selection.client_platform,
local_filesystem_mounts: selection.local_filesystem_mounts,
}),
signal: controller.signal,
});
@@ -1095,7 +1113,7 @@
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
: m
)
);
@@ -1110,7 +1128,7 @@
break;
case "tool-input-start":
addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {});
addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {});
batcher.flush();
break;
@@ -1122,7 +1140,7 @@
} else {
addToolCall(
contentPartsState,
TOOLS_WITH_UI,
toolsWithUI,
parsed.toolCallId,
parsed.toolName,
parsed.input || {}
@@ -1173,7 +1191,7 @@
const tcId = `interrupt-${action.name}`;
addToolCall(
contentPartsState,
TOOLS_WITH_UI,
toolsWithUI,
tcId,
action.name,
action.args,
@@ -1190,7 +1208,7 @@
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
: m
)
);
@@ -1214,7 +1232,7 @@
batcher.flush();
const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
if (contentParts.length > 0) {
try {
const savedMessage = await appendMessage(resumeThreadId, {
@@ -1406,6 +1424,7 @@
]);
try {
const selection = await getAgentFilesystemSelection();
const response = await fetch(getRegenerateUrl(threadId), {
method: "POST",
headers: {
@@ -1416,6 +1435,9 @@
search_space_id: searchSpaceId,
user_query: newUserQuery || null,
disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
filesystem_mode: selection.filesystem_mode,
client_platform: selection.client_platform,
local_filesystem_mounts: selection.local_filesystem_mounts,
}),
signal: controller.signal,
});
@@ -1428,7 +1450,7 @@
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
: m
)
);
@@ -1443,7 +1465,7 @@
break;
case "tool-input-start":
addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {});
addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {});
batcher.flush();
break;
@@ -1453,7 +1475,7 @@
} else {
addToolCall(
contentPartsState,
TOOLS_WITH_UI,
toolsWithUI,
parsed.toolCallId,
parsed.toolName,
parsed.input || {}
@@ -1502,7 +1524,7 @@
batcher.flush();
// Persist messages after streaming completes
const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
if (contentParts.length > 0) {
try {
// Persist user message (for both edit and reload modes, since backend deleted it)

View file

@@ -1,9 +1,7 @@
"use client";
import { BrainCog, Power, Rocket, Zap } from "lucide-react";
import { useEffect, useState } from "react";
import { toast } from "sonner";
import { DEFAULT_SHORTCUTS, ShortcutRecorder } from "@/components/desktop/shortcut-recorder";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
import { Label } from "@/components/ui/label";
import {
@@ -24,9 +22,6 @@ export function DesktopContent() {
const [loading, setLoading] = useState(true);
const [enabled, setEnabled] = useState(true);
const [shortcuts, setShortcuts] = useState(DEFAULT_SHORTCUTS);
const [shortcutsLoaded, setShortcutsLoaded] = useState(false);
const [searchSpaces, setSearchSpaces] = useState<SearchSpace[]>([]);
const [activeSpaceId, setActiveSpaceId] = useState<string | null>(null);
@@ -37,7 +32,6 @@
useEffect(() => {
if (!api) {
setLoading(false);
setShortcutsLoaded(true);
return;
}
@@ -48,15 +42,13 @@
Promise.all([
api.getAutocompleteEnabled(),
api.getShortcuts?.() ?? Promise.resolve(null),
api.getActiveSearchSpace?.() ?? Promise.resolve(null),
searchSpacesApiService.getSearchSpaces(),
hasAutoLaunchApi ? api.getAutoLaunch() : Promise.resolve(null),
])
.then(([autoEnabled, config, spaceId, spaces, autoLaunch]) => {
.then(([autoEnabled, spaceId, spaces, autoLaunch]) => {
if (!mounted) return;
setEnabled(autoEnabled);
if (config) setShortcuts(config);
setActiveSpaceId(spaceId);
if (spaces) setSearchSpaces(spaces);
if (autoLaunch) {
@@ -65,12 +57,10 @@
setAutoLaunchSupported(autoLaunch.supported);
}
setLoading(false);
setShortcutsLoaded(true);
})
.catch(() => {
if (!mounted) return;
setLoading(false);
setShortcutsLoaded(true);
});
return () => {
@@ -82,7 +72,7 @@
return (
<div className="flex flex-col items-center justify-center py-12 text-center">
<p className="text-sm text-muted-foreground">
Desktop settings are only available in the SurfSense desktop app.
App preferences are only available in the SurfSense desktop app.
</p>
</div>
);
@@ -101,24 +91,6 @@
await api.setAutocompleteEnabled(checked);
};
const updateShortcut = (
key: "generalAssist" | "quickAsk" | "autocomplete",
accelerator: string
) => {
setShortcuts((prev) => {
const updated = { ...prev, [key]: accelerator };
api.setShortcuts?.({ [key]: accelerator }).catch(() => {
toast.error("Failed to update shortcut");
});
return updated;
});
toast.success("Shortcut updated");
};
const resetShortcut = (key: "generalAssist" | "quickAsk" | "autocomplete") => {
updateShortcut(key, DEFAULT_SHORTCUTS[key]);
};
const handleAutoLaunchToggle = async (checked: boolean) => {
if (!autoLaunchSupported || !api.setAutoLaunch) {
toast.error("Please update the desktop app to configure launch on startup");
@@ -196,7 +168,6 @@
<Card>
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
<CardTitle className="text-base md:text-lg flex items-center gap-2">
<Power className="h-4 w-4" />
Launch on Startup
</CardTitle>
<CardDescription className="text-xs md:text-sm">
@@ -245,56 +216,6 @@
</CardContent>
</Card>
{/* Keyboard Shortcuts */}
<Card>
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
<CardTitle className="text-base md:text-lg">Keyboard Shortcuts</CardTitle>
<CardDescription className="text-xs md:text-sm">
Customize the global keyboard shortcuts for desktop features.
</CardDescription>
</CardHeader>
<CardContent className="px-3 md:px-6 pb-3 md:pb-6">
{shortcutsLoaded ? (
<div className="flex flex-col gap-3">
<ShortcutRecorder
value={shortcuts.generalAssist}
onChange={(accel) => updateShortcut("generalAssist", accel)}
onReset={() => resetShortcut("generalAssist")}
defaultValue={DEFAULT_SHORTCUTS.generalAssist}
label="General Assist"
description="Launch SurfSense instantly from any application"
icon={Rocket}
/>
<ShortcutRecorder
value={shortcuts.quickAsk}
onChange={(accel) => updateShortcut("quickAsk", accel)}
onReset={() => resetShortcut("quickAsk")}
defaultValue={DEFAULT_SHORTCUTS.quickAsk}
label="Quick Assist"
description="Select text anywhere, then ask AI to explain, rewrite, or act on it"
icon={Zap}
/>
<ShortcutRecorder
value={shortcuts.autocomplete}
onChange={(accel) => updateShortcut("autocomplete", accel)}
onReset={() => resetShortcut("autocomplete")}
defaultValue={DEFAULT_SHORTCUTS.autocomplete}
label="Extreme Assist"
description="AI drafts text using your screen context and knowledge base"
icon={BrainCog}
/>
<p className="text-[11px] text-muted-foreground">
Click a shortcut and press a new key combination to change it.
</p>
</div>
) : (
<div className="flex justify-center py-4">
<Spinner size="sm" />
</div>
)}
</CardContent>
</Card>
{/* Extreme Assist Toggle */}
<Card>
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">

Some files were not shown because too many files have changed in this diff Show more