This commit is contained in:
CREDO23 2026-04-24 10:56:10 +02:00
commit fa5d6369fb
106 changed files with 5673 additions and 787 deletions

View file

@ -239,6 +239,9 @@ LLAMA_CLOUD_API_KEY=llx-nnn
# DAYTONA_TARGET=us
# DAYTONA_SNAPSHOT_ID=
# Desktop local filesystem mode (chat file tools run against a local folder root)
# ENABLE_DESKTOP_LOCAL_FILESYSTEM=FALSE
# OPTIONAL: Add these for LangSmith Observability
LANGSMITH_TRACING=true
LANGSMITH_ENDPOINT=https://api.smith.langchain.com

View file

@ -33,9 +33,12 @@ from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.context import SurfSenseContextSchema
from app.agents.new_chat.filesystem_backends import build_backend_resolver
from app.agents.new_chat.filesystem_selection import FilesystemSelection
from app.agents.new_chat.llm_config import AgentConfig
from app.agents.new_chat.middleware import (
DedupHITLToolCallsMiddleware,
FileIntentMiddleware,
KnowledgeBaseSearchMiddleware,
MemoryInjectionMiddleware,
SurfSenseFilesystemMiddleware,
@ -164,6 +167,7 @@ async def create_surfsense_deep_agent(
thread_visibility: ChatVisibility | None = None,
mentioned_document_ids: list[int] | None = None,
anon_session_id: str | None = None,
filesystem_selection: FilesystemSelection | None = None,
):
"""
Create a SurfSense deep agent with configurable tools and prompts.
@ -238,6 +242,8 @@ async def create_surfsense_deep_agent(
)
"""
_t_agent_total = time.perf_counter()
filesystem_selection = filesystem_selection or FilesystemSelection()
backend_resolver = build_backend_resolver(filesystem_selection)
# Discover available connectors and document types for this search space
available_connectors: list[str] | None = None
@ -314,6 +320,20 @@ async def create_surfsense_deep_agent(
_t0 = time.perf_counter()
_enabled_tool_names = {t.name for t in tools}
_user_disabled_tool_names = set(disabled_tools) if disabled_tools else set()
# Collect generic MCP connector info so the system prompt can route queries
# to their tools instead of falling back to "not in knowledge base".
_mcp_connector_tools: dict[str, list[str]] = {}
for t in tools:
meta = getattr(t, "metadata", None) or {}
if meta.get("mcp_is_generic") and meta.get("mcp_connector_name"):
_mcp_connector_tools.setdefault(
meta["mcp_connector_name"], [],
).append(t.name)
if _mcp_connector_tools:
_perf_log.info("MCP connector tool routing: %s", _mcp_connector_tools)
if agent_config is not None:
system_prompt = build_configurable_system_prompt(
custom_system_instructions=agent_config.system_instructions,
@ -322,12 +342,14 @@ async def create_surfsense_deep_agent(
thread_visibility=thread_visibility,
enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names,
mcp_connector_tools=_mcp_connector_tools,
)
else:
system_prompt = build_surfsense_system_prompt(
thread_visibility=thread_visibility,
enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names,
mcp_connector_tools=_mcp_connector_tools,
)
_perf_log.info(
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
@ -344,7 +366,10 @@ async def create_surfsense_deep_agent(
gp_middleware = [
TodoListMiddleware(),
_memory_middleware,
FileIntentMiddleware(llm=llm),
SurfSenseFilesystemMiddleware(
backend=backend_resolver,
filesystem_mode=filesystem_selection.mode,
search_space_id=search_space_id,
created_by_id=user_id,
thread_id=thread_id,
@ -365,15 +390,19 @@ async def create_surfsense_deep_agent(
deepagent_middleware = [
TodoListMiddleware(),
_memory_middleware,
FileIntentMiddleware(llm=llm),
KnowledgeBaseSearchMiddleware(
llm=llm,
search_space_id=search_space_id,
filesystem_mode=filesystem_selection.mode,
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
anon_session_id=anon_session_id,
),
SurfSenseFilesystemMiddleware(
backend=backend_resolver,
filesystem_mode=filesystem_selection.mode,
search_space_id=search_space_id,
created_by_id=user_id,
thread_id=thread_id,

View file

@ -4,7 +4,15 @@ Context schema definitions for SurfSense agents.
This module defines the custom state schema used by the SurfSense deep agent.
"""
from typing import TypedDict
from typing import NotRequired, TypedDict
class FileOperationContractState(TypedDict):
intent: str
confidence: float
suggested_path: str
timestamp: str
turn_id: str
class SurfSenseContextSchema(TypedDict):
@ -24,5 +32,8 @@ class SurfSenseContextSchema(TypedDict):
"""
search_space_id: int
file_operation_contract: NotRequired[FileOperationContractState]
turn_id: NotRequired[str]
request_id: NotRequired[str]
# These are runtime-injected and won't be serialized
# db_session and connector_service are passed when invoking the agent

View file

@ -0,0 +1,42 @@
"""Filesystem backend resolver for cloud and desktop-local modes."""
from __future__ import annotations
from collections.abc import Callable
from functools import lru_cache
from deepagents.backends.state import StateBackend
from langgraph.prebuilt.tool_node import ToolRuntime
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
MultiRootLocalFolderBackend,
)
@lru_cache(maxsize=64)
def _cached_multi_root_backend(
mounts: tuple[tuple[str, str], ...],
) -> MultiRootLocalFolderBackend:
return MultiRootLocalFolderBackend(mounts)
def build_backend_resolver(
selection: FilesystemSelection,
) -> Callable[[ToolRuntime], StateBackend | MultiRootLocalFolderBackend]:
"""Create deepagents backend resolver for the selected filesystem mode."""
if selection.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER and selection.local_mounts:
def _resolve_local(_runtime: ToolRuntime) -> MultiRootLocalFolderBackend:
mounts = tuple(
(entry.mount_id, entry.root_path) for entry in selection.local_mounts
)
return _cached_multi_root_backend(mounts)
return _resolve_local
def _resolve_cloud(runtime: ToolRuntime) -> StateBackend:
return StateBackend(runtime)
return _resolve_cloud

View file

@ -0,0 +1,41 @@
"""Filesystem mode contracts and selection helpers for chat sessions."""
from __future__ import annotations
from dataclasses import dataclass
from enum import StrEnum
class FilesystemMode(StrEnum):
"""Supported filesystem backends for agent tool execution."""
CLOUD = "cloud"
DESKTOP_LOCAL_FOLDER = "desktop_local_folder"
class ClientPlatform(StrEnum):
"""Client runtime reported by the caller."""
WEB = "web"
DESKTOP = "desktop"
@dataclass(slots=True)
class LocalFilesystemMount:
"""Canonical mount mapping provided by desktop runtime."""
mount_id: str
root_path: str
@dataclass(slots=True)
class FilesystemSelection:
"""Resolved filesystem selection for a single chat request."""
mode: FilesystemMode = FilesystemMode.CLOUD
client_platform: ClientPlatform = ClientPlatform.WEB
local_mounts: tuple[LocalFilesystemMount, ...] = ()
@property
def is_local_mode(self) -> bool:
return self.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER

View file

@ -6,6 +6,9 @@ from app.agents.new_chat.middleware.dedup_tool_calls import (
from app.agents.new_chat.middleware.filesystem import (
SurfSenseFilesystemMiddleware,
)
from app.agents.new_chat.middleware.file_intent import (
FileIntentMiddleware,
)
from app.agents.new_chat.middleware.knowledge_search import (
KnowledgeBaseSearchMiddleware,
)
@ -15,6 +18,7 @@ from app.agents.new_chat.middleware.memory_injection import (
__all__ = [
"DedupHITLToolCallsMiddleware",
"FileIntentMiddleware",
"KnowledgeBaseSearchMiddleware",
"MemoryInjectionMiddleware",
"SurfSenseFilesystemMiddleware",

View file

@ -0,0 +1,352 @@
"""Semantic file-intent routing middleware for new chat turns.
This middleware classifies the latest human turn into a small intent set:
- chat_only
- file_write
- file_read
For ``file_write`` turns it injects a strict system contract so the model
uses filesystem tools before claiming success, and provides a deterministic
fallback path when no filename is specified by the user.
"""
from __future__ import annotations
import json
import logging
import re
from datetime import UTC, datetime
from enum import StrEnum
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.runtime import Runtime
from pydantic import BaseModel, Field, ValidationError
logger = logging.getLogger(__name__)
class FileOperationIntent(StrEnum):
CHAT_ONLY = "chat_only"
FILE_WRITE = "file_write"
FILE_READ = "file_read"
class FileIntentPlan(BaseModel):
intent: FileOperationIntent = Field(
description="Primary user intent for this turn."
)
confidence: float = Field(
ge=0.0,
le=1.0,
default=0.5,
description="Model confidence in the selected intent.",
)
suggested_filename: str | None = Field(
default=None,
description="Optional filename (e.g. notes.md) inferred from user request.",
)
suggested_directory: str | None = Field(
default=None,
description=(
"Optional directory path (e.g. /reports/q2 or reports/q2) inferred from "
"user request."
),
)
suggested_path: str | None = Field(
default=None,
description=(
"Optional full file path (e.g. /reports/q2/summary.md). If present, this "
"takes precedence over suggested_directory + suggested_filename."
),
)
def _extract_text_from_message(message: BaseMessage) -> str:
content = getattr(message, "content", "")
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict) and item.get("type") == "text":
parts.append(str(item.get("text", "")))
return "\n".join(part for part in parts if part)
return str(content)
def _extract_json_payload(text: str) -> str:
stripped = text.strip()
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
if fenced:
return fenced.group(1)
start = stripped.find("{")
end = stripped.rfind("}")
if start != -1 and end != -1 and end > start:
return stripped[start : end + 1]
return stripped
def _sanitize_filename(value: str) -> str:
name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
name = re.sub(r"\s+", "-", name)
name = name.strip("._-")
if not name:
name = "note"
if len(name) > 80:
name = name[:80].rstrip("-_.")
return name
def _sanitize_path_segment(value: str) -> str:
segment = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
segment = re.sub(r"\s+", "_", segment)
segment = segment.strip("._-")
return segment
def _infer_text_file_extension(user_text: str) -> str:
lowered = user_text.lower()
if any(token in lowered for token in ("json", ".json")):
return ".json"
if any(token in lowered for token in ("yaml", "yml", ".yaml", ".yml")):
return ".yaml"
if any(token in lowered for token in ("csv", ".csv")):
return ".csv"
if any(token in lowered for token in ("python", ".py")):
return ".py"
if any(token in lowered for token in ("typescript", ".ts", ".tsx")):
return ".ts"
if any(token in lowered for token in ("javascript", ".js", ".mjs", ".cjs")):
return ".js"
if any(token in lowered for token in ("html", ".html")):
return ".html"
if any(token in lowered for token in ("css", ".css")):
return ".css"
if any(token in lowered for token in ("sql", ".sql")):
return ".sql"
if any(token in lowered for token in ("toml", ".toml")):
return ".toml"
if any(token in lowered for token in ("ini", ".ini")):
return ".ini"
if any(token in lowered for token in ("xml", ".xml")):
return ".xml"
if any(token in lowered for token in ("markdown", ".md", "readme")):
return ".md"
return ".md"
def _normalize_directory(value: str) -> str:
raw = value.strip().replace("\\", "/")
raw = raw.strip("/")
if not raw:
return ""
parts = [_sanitize_path_segment(part) for part in raw.split("/") if part.strip()]
parts = [part for part in parts if part]
return "/".join(parts)
def _normalize_file_path(value: str) -> str:
raw = value.strip().replace("\\", "/").strip()
if not raw:
return ""
had_trailing_slash = raw.endswith("/")
raw = raw.strip("/")
if not raw:
return ""
parts = [_sanitize_path_segment(part) for part in raw.split("/") if part.strip()]
parts = [part for part in parts if part]
if not parts:
return ""
if had_trailing_slash:
return f"/{'/'.join(parts)}/"
return f"/{'/'.join(parts)}"
def _infer_directory_from_user_text(user_text: str) -> str | None:
patterns = (
r"\b(?:in|inside|under)\s+(?:the\s+)?([a-zA-Z0-9 _\-/]+?)\s+folder\b",
r"\b(?:in|inside|under)\s+([a-zA-Z0-9 _\-/]+?)\b",
)
lowered = user_text.lower()
for pattern in patterns:
match = re.search(pattern, lowered, flags=re.IGNORECASE)
if not match:
continue
candidate = match.group(1).strip()
if candidate in {"the", "a", "an"}:
continue
normalized = _normalize_directory(candidate)
if normalized:
return normalized
return None
def _fallback_path(
suggested_filename: str | None,
*,
suggested_directory: str | None = None,
suggested_path: str | None = None,
user_text: str,
) -> str:
default_extension = _infer_text_file_extension(user_text)
inferred_dir = _infer_directory_from_user_text(user_text)
sanitized_filename = ""
if suggested_filename:
sanitized_filename = _sanitize_filename(suggested_filename)
if sanitized_filename.lower().endswith(".txt"):
sanitized_filename = f"{sanitized_filename[:-4]}.md"
if not sanitized_filename:
sanitized_filename = f"notes{default_extension}"
elif "." not in sanitized_filename:
sanitized_filename = f"{sanitized_filename}{default_extension}"
normalized_suggested_path = (
_normalize_file_path(suggested_path) if suggested_path else ""
)
if normalized_suggested_path:
if normalized_suggested_path.endswith("/"):
return f"{normalized_suggested_path.rstrip('/')}/{sanitized_filename}"
return normalized_suggested_path
directory = _normalize_directory(suggested_directory or "")
if not directory and inferred_dir:
directory = inferred_dir
if directory:
return f"/{directory}/{sanitized_filename}"
return f"/{sanitized_filename}"
def _build_classifier_prompt(*, recent_conversation: str, user_text: str) -> str:
return (
"Classify the latest user request into a filesystem intent for an AI agent.\n"
"Return JSON only with this exact schema:\n"
'{"intent":"chat_only|file_write|file_read","confidence":0.0,"suggested_filename":"string or null","suggested_directory":"string or null","suggested_path":"string or null"}\n\n'
"Rules:\n"
"- Use semantic intent, not literal keywords.\n"
"- file_write: user asks to create/save/write/update/edit content as a file.\n"
"- file_read: user asks to open/read/list/search existing files.\n"
"- chat_only: conversational/analysis responses without required file operations.\n"
"- For file_write, choose a concise semantic suggested_filename and match the requested format.\n"
"- If the user mentions a folder/directory, populate suggested_directory.\n"
"- If user specifies an explicit full path, populate suggested_path.\n"
"- Use extensions that match user intent (e.g. .md, .json, .yaml, .csv, .py, .ts, .js, .html, .css, .sql).\n"
"- Do not use .txt; prefer .md for generic text notes.\n"
"- Do not include dates or timestamps in suggested_filename unless explicitly requested.\n"
"- Never include markdown or explanation.\n\n"
f"Recent conversation:\n{recent_conversation or '(none)'}\n\n"
f"Latest user message:\n{user_text}"
)
def _build_recent_conversation(messages: list[BaseMessage], *, max_messages: int = 6) -> str:
rows: list[str] = []
for msg in messages[-max_messages:]:
role = "user" if isinstance(msg, HumanMessage) else "assistant"
text = re.sub(r"\s+", " ", _extract_text_from_message(msg)).strip()
if text:
rows.append(f"{role}: {text[:280]}")
return "\n".join(rows)
class FileIntentMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Classify file intent and inject a strict file-write contract."""
tools = ()
def __init__(self, *, llm: BaseChatModel | None = None) -> None:
self.llm = llm
async def _classify_intent(
self, *, messages: list[BaseMessage], user_text: str
) -> FileIntentPlan:
if self.llm is None:
return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)
prompt = _build_classifier_prompt(
recent_conversation=_build_recent_conversation(messages),
user_text=user_text,
)
try:
response = await self.llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal"]},
)
payload = json.loads(_extract_json_payload(_extract_text_from_message(response)))
plan = FileIntentPlan.model_validate(payload)
return plan
except (json.JSONDecodeError, ValidationError, ValueError) as exc:
logger.warning("File intent classifier returned invalid output: %s", exc)
except Exception as exc: # pragma: no cover - defensive fallback
logger.warning("File intent classifier failed: %s", exc)
return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)
async def abefore_agent( # type: ignore[override]
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
del runtime
messages = state.get("messages") or []
if not messages:
return None
last_human: HumanMessage | None = None
for msg in reversed(messages):
if isinstance(msg, HumanMessage):
last_human = msg
break
if last_human is None:
return None
user_text = _extract_text_from_message(last_human).strip()
if not user_text:
return None
plan = await self._classify_intent(messages=messages, user_text=user_text)
suggested_path = _fallback_path(
plan.suggested_filename,
suggested_directory=plan.suggested_directory,
suggested_path=plan.suggested_path,
user_text=user_text,
)
contract = {
"intent": plan.intent.value,
"confidence": plan.confidence,
"suggested_path": suggested_path,
"timestamp": datetime.now(UTC).isoformat(),
"turn_id": state.get("turn_id", ""),
}
if plan.intent != FileOperationIntent.FILE_WRITE:
return {"file_operation_contract": contract}
contract_msg = SystemMessage(
content=(
"<file_operation_contract>\n"
"This turn intent is file_write.\n"
f"Suggested default path: {suggested_path}\n"
"Rules:\n"
"- You MUST call write_file or edit_file before claiming success.\n"
"- If no path is provided by the user, use the suggested default path.\n"
"- Do not claim a file was created/updated unless tool output confirms it.\n"
"- If the write/edit fails, clearly report failure instead of success.\n"
"- Do not include timestamps or dates in generated file content unless the user explicitly asks for them.\n"
"- For open-ended requests (e.g., random note), generate useful concrete content, not placeholders.\n"
"</file_operation_contract>"
)
)
# Insert just before the latest human turn so it applies to this request.
new_messages = list(messages)
insert_at = max(len(new_messages) - 1, 0)
new_messages.insert(insert_at, contract_msg)
return {"messages": new_messages, "file_operation_contract": contract}

View file

@ -26,6 +26,10 @@ from langchain_core.tools import BaseTool, StructuredTool
from langgraph.types import Command
from sqlalchemy import delete, select
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
MultiRootLocalFolderBackend,
)
from app.agents.new_chat.sandbox import (
_evict_sandbox_cache,
delete_sandbox,
@ -50,6 +54,8 @@ SURFSENSE_FILESYSTEM_SYSTEM_PROMPT = """## Following Conventions
- Read files before editing understand existing content before making changes.
- Mimic existing style, naming conventions, and patterns.
- Never claim a file was created/updated unless filesystem tool output confirms success.
- If a file write/edit fails, explicitly report the failure.
## Filesystem Tools
@ -109,13 +115,20 @@ Usage:
- Use chunk IDs (`<chunk id='...'>`) as citations in answers.
"""
SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new file to the in-memory filesystem (session-only).
SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new text file to the in-memory filesystem (session-only).
Use this to create scratch/working files during the conversation. Files created
here are ephemeral and will not be saved to the user's knowledge base.
To permanently save a document to the user's knowledge base, use the
`save_document` tool instead.
Supported outputs include common LLM-friendly text formats like markdown, json,
yaml, csv, xml, html, css, sql, and code files.
When creating content from open-ended prompts, produce concrete and useful text,
not placeholders. Avoid adding dates/timestamps unless the user explicitly asks
for them.
"""
SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files.
@ -182,11 +195,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
def __init__(
self,
*,
backend: Any = None,
filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
search_space_id: int | None = None,
created_by_id: str | None = None,
thread_id: int | str | None = None,
tool_token_limit_before_evict: int | None = 20000,
) -> None:
self._filesystem_mode = filesystem_mode
self._search_space_id = search_space_id
self._created_by_id = created_by_id
self._thread_id = thread_id
@ -204,8 +220,17 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
" extract the data, write it as a clean file (CSV, JSON, etc.),"
" and then run your code against it."
)
if filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
system_prompt += (
"\n\n## Local Folder Mode"
"\n\nThis chat is running in desktop local-folder mode."
" Keep all file operations local. Do not use save_document."
" Always use mount-prefixed absolute paths like /<folder>/file.ext."
" If you are unsure which mounts are available, call ls('/') first."
)
super().__init__(
backend=backend,
system_prompt=system_prompt,
custom_tool_descriptions={
"ls": SURFSENSE_LIST_FILES_TOOL_DESCRIPTION,
@ -219,7 +244,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
max_execute_timeout=self._MAX_EXECUTE_TIMEOUT,
)
self.tools = [t for t in self.tools if t.name != "execute"]
self.tools.append(self._create_save_document_tool())
if self._should_persist_documents():
self.tools.append(self._create_save_document_tool())
if self._sandbox_available:
self.tools.append(self._create_execute_code_tool())
@ -637,15 +663,25 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
runtime: ToolRuntime[None, FilesystemState],
) -> Command | str:
resolved_backend = self._get_backend(runtime)
target_path = self._resolve_write_target_path(file_path, runtime)
try:
validated_path = validate_path(file_path)
validated_path = validate_path(target_path)
except ValueError as exc:
return f"Error: {exc}"
res: WriteResult = resolved_backend.write(validated_path, content)
if res.error:
return res.error
verify_error = self._verify_written_content_sync(
backend=resolved_backend,
path=validated_path,
expected_content=content,
)
if verify_error:
return verify_error
if not self._is_kb_document(validated_path):
if self._should_persist_documents() and not self._is_kb_document(
validated_path
):
persist_result = self._run_async_blocking(
self._persist_new_document(
file_path=validated_path, content=content
@ -682,15 +718,25 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
runtime: ToolRuntime[None, FilesystemState],
) -> Command | str:
resolved_backend = self._get_backend(runtime)
target_path = self._resolve_write_target_path(file_path, runtime)
try:
validated_path = validate_path(file_path)
validated_path = validate_path(target_path)
except ValueError as exc:
return f"Error: {exc}"
res: WriteResult = await resolved_backend.awrite(validated_path, content)
if res.error:
return res.error
verify_error = await self._verify_written_content_async(
backend=resolved_backend,
path=validated_path,
expected_content=content,
)
if verify_error:
return verify_error
if not self._is_kb_document(validated_path):
if self._should_persist_documents() and not self._is_kb_document(
validated_path
):
persist_result = await self._persist_new_document(
file_path=validated_path,
content=content,
@ -726,6 +772,164 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
"""Return True for paths under /documents/ (KB-sourced, XML-wrapped)."""
return path.startswith("/documents/")
def _should_persist_documents(self) -> bool:
"""Only cloud mode persists file content to Document/Chunk tables."""
return self._filesystem_mode == FilesystemMode.CLOUD
def _default_mount_prefix(self, runtime: ToolRuntime[None, FilesystemState]) -> str:
backend = self._get_backend(runtime)
if isinstance(backend, MultiRootLocalFolderBackend):
return f"/{backend.default_mount()}"
return ""
def _normalize_local_mount_path(
self, candidate: str, runtime: ToolRuntime[None, FilesystemState]
) -> str:
backend = self._get_backend(runtime)
mount_prefix = self._default_mount_prefix(runtime)
normalized_candidate = re.sub(r"/+", "/", candidate.strip().replace("\\", "/"))
if not mount_prefix or not isinstance(backend, MultiRootLocalFolderBackend):
if normalized_candidate.startswith("/"):
return normalized_candidate
return f"/{normalized_candidate.lstrip('/')}"
mount_names = set(backend.list_mounts())
if normalized_candidate.startswith("/"):
first_segment = normalized_candidate.lstrip("/").split("/", 1)[0]
if first_segment in mount_names:
return normalized_candidate
return f"{mount_prefix}{normalized_candidate}"
relative = normalized_candidate.lstrip("/")
first_segment = relative.split("/", 1)[0]
if first_segment in mount_names:
return f"/{relative}"
return f"{mount_prefix}/{relative}"
def _get_contract_suggested_path(
self, runtime: ToolRuntime[None, FilesystemState]
) -> str:
contract = runtime.state.get("file_operation_contract") or {}
suggested = contract.get("suggested_path")
if isinstance(suggested, str) and suggested.strip():
cleaned = suggested.strip()
if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
return self._normalize_local_mount_path(cleaned, runtime)
return cleaned
if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
mount_prefix = self._default_mount_prefix(runtime)
if mount_prefix:
return f"{mount_prefix}/notes.md"
return "/notes.md"
def _resolve_write_target_path(
self,
file_path: str,
runtime: ToolRuntime[None, FilesystemState],
) -> str:
candidate = file_path.strip()
if not candidate:
return self._get_contract_suggested_path(runtime)
if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
return self._normalize_local_mount_path(candidate, runtime)
if not candidate.startswith("/"):
return f"/{candidate.lstrip('/')}"
return candidate
@staticmethod
def _is_error_text(value: str) -> bool:
return value.startswith("Error:")
@staticmethod
def _read_for_verification_sync(backend: Any, path: str) -> str:
read_raw = getattr(backend, "read_raw", None)
if callable(read_raw):
return read_raw(path)
return backend.read(path, offset=0, limit=200000)
@staticmethod
async def _read_for_verification_async(backend: Any, path: str) -> str:
aread_raw = getattr(backend, "aread_raw", None)
if callable(aread_raw):
return await aread_raw(path)
return await backend.aread(path, offset=0, limit=200000)
def _verify_written_content_sync(
self,
*,
backend: Any,
path: str,
expected_content: str,
) -> str | None:
actual = self._read_for_verification_sync(backend, path)
if self._is_error_text(actual):
return f"Error: could not verify written file '{path}'."
if actual.rstrip() != expected_content.rstrip():
return (
"Error: file write verification failed; expected content was not fully written "
f"to '{path}'."
)
return None
async def _verify_written_content_async(
self,
*,
backend: Any,
path: str,
expected_content: str,
) -> str | None:
actual = await self._read_for_verification_async(backend, path)
if self._is_error_text(actual):
return f"Error: could not verify written file '{path}'."
if actual.rstrip() != expected_content.rstrip():
return (
"Error: file write verification failed; expected content was not fully written "
f"to '{path}'."
)
return None
def _verify_edited_content_sync(
self,
*,
backend: Any,
path: str,
new_string: str,
) -> tuple[str | None, str | None]:
updated_content = self._read_for_verification_sync(backend, path)
if self._is_error_text(updated_content):
return (
f"Error: could not verify edited file '{path}'.",
None,
)
if new_string and new_string not in updated_content:
return (
"Error: edit verification failed; updated content was not found in "
f"'{path}'.",
None,
)
return None, updated_content
async def _verify_edited_content_async(
self,
*,
backend: Any,
path: str,
new_string: str,
) -> tuple[str | None, str | None]:
updated_content = await self._read_for_verification_async(backend, path)
if self._is_error_text(updated_content):
return (
f"Error: could not verify edited file '{path}'.",
None,
)
if new_string and new_string not in updated_content:
return (
"Error: edit verification failed; updated content was not found in "
f"'{path}'.",
None,
)
return None, updated_content
def _create_edit_file_tool(self) -> BaseTool:
"""Create edit_file with DB persistence (skipped for KB documents)."""
tool_description = (
@ -754,8 +958,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
] = False,
) -> Command | str:
resolved_backend = self._get_backend(runtime)
target_path = self._resolve_write_target_path(file_path, runtime)
try:
validated_path = validate_path(file_path)
validated_path = validate_path(target_path)
except ValueError as exc:
return f"Error: {exc}"
res: EditResult = resolved_backend.edit(
@ -767,13 +972,22 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
if res.error:
return res.error
if not self._is_kb_document(validated_path):
read_result = resolved_backend.read(
validated_path, offset=0, limit=200000
)
if read_result.error or read_result.file_data is None:
return f"Error: could not reload edited file '{validated_path}' for persistence."
updated_content = read_result.file_data["content"]
verify_error, updated_content = self._verify_edited_content_sync(
backend=resolved_backend,
path=validated_path,
new_string=new_string,
)
if verify_error:
return verify_error
if self._should_persist_documents() and not self._is_kb_document(
validated_path
):
if updated_content is None:
return (
f"Error: could not reload edited file '{validated_path}' for "
"persistence."
)
persist_result = self._run_async_blocking(
self._persist_edited_document(
file_path=validated_path,
@ -818,8 +1032,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
] = False,
) -> Command | str:
resolved_backend = self._get_backend(runtime)
target_path = self._resolve_write_target_path(file_path, runtime)
try:
validated_path = validate_path(file_path)
validated_path = validate_path(target_path)
except ValueError as exc:
return f"Error: {exc}"
res: EditResult = await resolved_backend.aedit(
@ -831,13 +1046,22 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
if res.error:
return res.error
if not self._is_kb_document(validated_path):
read_result = await resolved_backend.aread(
validated_path, offset=0, limit=200000
)
if read_result.error or read_result.file_data is None:
return f"Error: could not reload edited file '{validated_path}' for persistence."
updated_content = read_result.file_data["content"]
verify_error, updated_content = await self._verify_edited_content_async(
backend=resolved_backend,
path=validated_path,
new_string=new_string,
)
if verify_error:
return verify_error
if self._should_persist_documents() and not self._is_kb_document(
validated_path
):
if updated_content is None:
return (
f"Error: could not reload edited file '{validated_path}' for "
"persistence."
)
persist_error = await self._persist_edited_document(
file_path=validated_path,
updated_content=updated_content,

View file

@ -28,6 +28,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.db import (
NATIVE_TO_LEGACY_DOCTYPE,
Chunk,
@ -857,6 +858,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
*,
llm: BaseChatModel | None = None,
search_space_id: int,
filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
top_k: int = 10,
@ -865,6 +867,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
) -> None:
self.llm = llm
self.search_space_id = search_space_id
self.filesystem_mode = filesystem_mode
self.available_connectors = available_connectors
self.available_document_types = available_document_types
self.top_k = top_k
@ -996,6 +999,9 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
messages = state.get("messages") or []
if not messages:
return None
if self.filesystem_mode != FilesystemMode.CLOUD:
# Local-folder mode should not seed cloud KB documents into filesystem.
return None
last_human = None
for msg in reversed(messages):

View file

@ -0,0 +1,316 @@
"""Desktop local-folder filesystem backend for deepagents tools."""
from __future__ import annotations
import asyncio
import fnmatch
import os
import threading
from pathlib import Path
from deepagents.backends.protocol import (
EditResult,
FileDownloadResponse,
FileInfo,
FileUploadResponse,
GrepMatch,
WriteResult,
)
from deepagents.backends.utils import (
create_file_data,
format_read_response,
perform_string_replacement,
)
_INVALID_PATH = "invalid_path"
_FILE_NOT_FOUND = "file_not_found"
_IS_DIRECTORY = "is_directory"
class LocalFolderBackend:
"""Filesystem backend rooted to a single local folder."""
def __init__(self, root_path: str) -> None:
root = Path(root_path).expanduser().resolve()
if not root.exists() or not root.is_dir():
msg = f"Local filesystem root does not exist or is not a directory: {root_path}"
raise ValueError(msg)
self._root = root
self._locks: dict[str, threading.Lock] = {}
self._locks_mu = threading.Lock()
def _lock_for(self, path: str) -> threading.Lock:
with self._locks_mu:
if path not in self._locks:
self._locks[path] = threading.Lock()
return self._locks[path]
def _resolve_virtual(self, virtual_path: str, *, allow_root: bool = False) -> Path:
if not virtual_path.startswith("/"):
msg = f"Invalid path (must be absolute): {virtual_path}"
raise ValueError(msg)
rel = virtual_path.lstrip("/")
candidate = self._root if rel == "" else (self._root / rel)
resolved = candidate.resolve()
if not allow_root and resolved == self._root:
msg = "Path must refer to a file or child directory under root"
raise ValueError(msg)
if not resolved.is_relative_to(self._root):
msg = f"Path escapes local filesystem root: {virtual_path}"
raise ValueError(msg)
return resolved
@staticmethod
def _to_virtual(path: Path, root: Path) -> str:
rel = path.relative_to(root).as_posix()
return "/" if rel == "." else f"/{rel}"
def _write_text_atomic(self, path: Path, content: str) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
temp_path = path.with_suffix(f"{path.suffix}.tmp")
temp_path.write_text(content, encoding="utf-8")
os.replace(temp_path, path)
def ls_info(self, path: str) -> list[FileInfo]:
try:
target = self._resolve_virtual(path, allow_root=True)
except ValueError:
return []
if not target.exists() or not target.is_dir():
return []
infos: list[FileInfo] = []
for child in sorted(target.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower())):
infos.append(
FileInfo(
path=self._to_virtual(child, self._root),
is_dir=child.is_dir(),
size=child.stat().st_size if child.is_file() else 0,
modified_at=str(child.stat().st_mtime),
)
)
return infos
async def als_info(self, path: str) -> list[FileInfo]:
return await asyncio.to_thread(self.ls_info, path)
def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
try:
path = self._resolve_virtual(file_path)
except ValueError:
return f"Error: Invalid path '{file_path}'"
if not path.exists():
return f"Error: File '{file_path}' not found"
if not path.is_file():
return f"Error: Path '{file_path}' is not a file"
content = path.read_text(encoding="utf-8", errors="replace")
file_data = create_file_data(content)
return format_read_response(file_data, offset, limit)
async def aread(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
return await asyncio.to_thread(self.read, file_path, offset, limit)
def read_raw(self, file_path: str) -> str:
"""Read raw file text without line-number formatting."""
try:
path = self._resolve_virtual(file_path)
except ValueError:
return f"Error: Invalid path '{file_path}'"
if not path.exists():
return f"Error: File '{file_path}' not found"
if not path.is_file():
return f"Error: Path '{file_path}' is not a file"
return path.read_text(encoding="utf-8", errors="replace")
async def aread_raw(self, file_path: str) -> str:
"""Async variant of read_raw."""
return await asyncio.to_thread(self.read_raw, file_path)
def write(self, file_path: str, content: str) -> WriteResult:
try:
path = self._resolve_virtual(file_path)
except ValueError:
return WriteResult(error=f"Error: Invalid path '{file_path}'")
lock = self._lock_for(file_path)
with lock:
if path.exists():
return WriteResult(
error=(
f"Cannot write to {file_path} because it already exists. "
"Read and then make an edit, or write to a new path."
)
)
self._write_text_atomic(path, content)
return WriteResult(path=file_path, files_update=None)
async def awrite(self, file_path: str, content: str) -> WriteResult:
return await asyncio.to_thread(self.write, file_path, content)
def edit(
self,
file_path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
) -> EditResult:
try:
path = self._resolve_virtual(file_path)
except ValueError:
return EditResult(error=f"Error: Invalid path '{file_path}'")
lock = self._lock_for(file_path)
with lock:
if not path.exists() or not path.is_file():
return EditResult(error=f"Error: File '{file_path}' not found")
content = path.read_text(encoding="utf-8", errors="replace")
result = perform_string_replacement(content, old_string, new_string, replace_all)
if isinstance(result, str):
return EditResult(error=result)
updated_content, occurrences = result
self._write_text_atomic(path, updated_content)
return EditResult(path=file_path, files_update=None, occurrences=int(occurrences))
async def aedit(
self,
file_path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
) -> EditResult:
return await asyncio.to_thread(
self.edit, file_path, old_string, new_string, replace_all
)
def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
try:
base = self._resolve_virtual(path, allow_root=True)
except ValueError:
return []
if pattern.startswith("/"):
search_base = self._root
normalized_pattern = pattern.lstrip("/")
else:
search_base = base
normalized_pattern = pattern
matches: list[FileInfo] = []
for hit in search_base.glob(normalized_pattern):
try:
resolved = hit.resolve()
if not resolved.is_relative_to(self._root):
continue
except Exception:
continue
matches.append(
FileInfo(
path=self._to_virtual(resolved, self._root),
is_dir=resolved.is_dir(),
size=resolved.stat().st_size if resolved.is_file() else 0,
modified_at=str(resolved.stat().st_mtime),
)
)
return matches
async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
return await asyncio.to_thread(self.glob_info, pattern, path)
def _iter_candidate_files(self, path: str | None, glob: str | None) -> list[Path]:
base_virtual = path or "/"
try:
base = self._resolve_virtual(base_virtual, allow_root=True)
except ValueError:
return []
if not base.exists():
return []
candidates = [p for p in base.rglob("*") if p.is_file()]
if glob:
candidates = [
p
for p in candidates
if fnmatch.fnmatch(self._to_virtual(p, self._root), glob)
or fnmatch.fnmatch(p.name, glob)
]
return candidates
def grep_raw(
self, pattern: str, path: str | None = None, glob: str | None = None
) -> list[GrepMatch] | str:
if not pattern:
return "Error: pattern cannot be empty"
matches: list[GrepMatch] = []
for file_path in self._iter_candidate_files(path, glob):
try:
lines = file_path.read_text(encoding="utf-8", errors="replace").splitlines()
except Exception:
continue
for idx, line in enumerate(lines, start=1):
if pattern in line:
matches.append(
GrepMatch(
path=self._to_virtual(file_path, self._root),
line=idx,
text=line,
)
)
return matches
async def agrep_raw(
self, pattern: str, path: str | None = None, glob: str | None = None
) -> list[GrepMatch] | str:
return await asyncio.to_thread(self.grep_raw, pattern, path, glob)
def upload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
responses: list[FileUploadResponse] = []
for virtual_path, content in files:
try:
target = self._resolve_virtual(virtual_path)
target.parent.mkdir(parents=True, exist_ok=True)
temp_path = target.with_suffix(f"{target.suffix}.tmp")
temp_path.write_bytes(content)
os.replace(temp_path, target)
responses.append(FileUploadResponse(path=virtual_path, error=None))
except FileNotFoundError:
responses.append(
FileUploadResponse(path=virtual_path, error=_FILE_NOT_FOUND)
)
except IsADirectoryError:
responses.append(FileUploadResponse(path=virtual_path, error=_IS_DIRECTORY))
except Exception:
responses.append(FileUploadResponse(path=virtual_path, error=_INVALID_PATH))
return responses
async def aupload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
return await asyncio.to_thread(self.upload_files, files)
def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
responses: list[FileDownloadResponse] = []
for virtual_path in paths:
try:
target = self._resolve_virtual(virtual_path)
if not target.exists():
responses.append(
FileDownloadResponse(
path=virtual_path, content=None, error=_FILE_NOT_FOUND
)
)
continue
if target.is_dir():
responses.append(
FileDownloadResponse(
path=virtual_path, content=None, error=_IS_DIRECTORY
)
)
continue
responses.append(
FileDownloadResponse(
path=virtual_path, content=target.read_bytes(), error=None
)
)
except Exception:
responses.append(
FileDownloadResponse(path=virtual_path, content=None, error=_INVALID_PATH)
)
return responses
async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]:
return await asyncio.to_thread(self.download_files, paths)

View file

@ -0,0 +1,329 @@
"""Aggregate multiple LocalFolderBackend roots behind mount-prefixed virtual paths."""
from __future__ import annotations
import asyncio
from pathlib import Path
from typing import Any
from deepagents.backends.protocol import (
EditResult,
FileDownloadResponse,
FileInfo,
FileUploadResponse,
GrepMatch,
WriteResult,
)
from app.agents.new_chat.middleware.local_folder_backend import LocalFolderBackend
_INVALID_PATH = "invalid_path"
_FILE_NOT_FOUND = "file_not_found"
_IS_DIRECTORY = "is_directory"
class MultiRootLocalFolderBackend:
"""Route filesystem operations to one of several mounted local roots.
Virtual paths are namespaced as:
- `/<mount>/...`
where `<mount>` is derived from each selected root folder name.
"""
def __init__(self, mounts: tuple[tuple[str, str], ...]) -> None:
if not mounts:
msg = "At least one local mount is required"
raise ValueError(msg)
self._mount_to_backend: dict[str, LocalFolderBackend] = {}
for raw_mount, raw_root in mounts:
mount = raw_mount.strip()
if not mount:
msg = "Mount id cannot be empty"
raise ValueError(msg)
if mount in self._mount_to_backend:
msg = f"Duplicate mount id: {mount}"
raise ValueError(msg)
normalized_root = str(Path(raw_root).expanduser().resolve())
self._mount_to_backend[mount] = LocalFolderBackend(normalized_root)
self._mount_order = tuple(self._mount_to_backend.keys())
def list_mounts(self) -> tuple[str, ...]:
return self._mount_order
def default_mount(self) -> str:
return self._mount_order[0]
def _mount_error(self) -> str:
mounts = ", ".join(f"/{mount}" for mount in self._mount_order)
return (
"Path must start with one of the selected folders: "
f"{mounts}. Example: /{self._mount_order[0]}/file.txt"
)
def _split_mount_path(self, virtual_path: str) -> tuple[str, str]:
if not virtual_path.startswith("/"):
msg = f"Invalid path (must be absolute): {virtual_path}"
raise ValueError(msg)
rel = virtual_path.lstrip("/")
if not rel:
raise ValueError(self._mount_error())
mount, _, remainder = rel.partition("/")
backend = self._mount_to_backend.get(mount)
if backend is None:
raise ValueError(self._mount_error())
local_path = f"/{remainder}" if remainder else "/"
return mount, local_path
@staticmethod
def _prefix_mount_path(mount: str, local_path: str) -> str:
if local_path == "/":
return f"/{mount}"
return f"/{mount}{local_path}"
@staticmethod
def _get_value(item: Any, key: str) -> Any:
if isinstance(item, dict):
return item.get(key)
return getattr(item, key, None)
@classmethod
def _get_str(cls, item: Any, key: str) -> str:
value = cls._get_value(item, key)
return value if isinstance(value, str) else ""
@classmethod
def _get_int(cls, item: Any, key: str) -> int:
value = cls._get_value(item, key)
return int(value) if isinstance(value, int | float) else 0
@classmethod
def _get_bool(cls, item: Any, key: str) -> bool:
value = cls._get_value(item, key)
return bool(value)
def _list_mount_roots(self) -> list[FileInfo]:
return [
FileInfo(path=f"/{mount}", is_dir=True, size=0, modified_at="0")
for mount in self._mount_order
]
def _transform_infos(self, mount: str, infos: list[FileInfo]) -> list[FileInfo]:
transformed: list[FileInfo] = []
for info in infos:
transformed.append(
FileInfo(
path=self._prefix_mount_path(mount, self._get_str(info, "path")),
is_dir=self._get_bool(info, "is_dir"),
size=self._get_int(info, "size"),
modified_at=self._get_str(info, "modified_at"),
)
)
return transformed
def ls_info(self, path: str) -> list[FileInfo]:
if path == "/":
return self._list_mount_roots()
try:
mount, local_path = self._split_mount_path(path)
except ValueError:
return []
return self._transform_infos(mount, self._mount_to_backend[mount].ls_info(local_path))
async def als_info(self, path: str) -> list[FileInfo]:
return await asyncio.to_thread(self.ls_info, path)
def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
try:
mount, local_path = self._split_mount_path(file_path)
except ValueError as exc:
return f"Error: {exc}"
return self._mount_to_backend[mount].read(local_path, offset, limit)
async def aread(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
return await asyncio.to_thread(self.read, file_path, offset, limit)
def read_raw(self, file_path: str) -> str:
try:
mount, local_path = self._split_mount_path(file_path)
except ValueError as exc:
return f"Error: {exc}"
return self._mount_to_backend[mount].read_raw(local_path)
async def aread_raw(self, file_path: str) -> str:
return await asyncio.to_thread(self.read_raw, file_path)
def write(self, file_path: str, content: str) -> WriteResult:
try:
mount, local_path = self._split_mount_path(file_path)
except ValueError as exc:
return WriteResult(error=f"Error: {exc}")
result = self._mount_to_backend[mount].write(local_path, content)
if result.path:
result.path = self._prefix_mount_path(mount, result.path)
return result
async def awrite(self, file_path: str, content: str) -> WriteResult:
return await asyncio.to_thread(self.write, file_path, content)
def edit(
self,
file_path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
) -> EditResult:
try:
mount, local_path = self._split_mount_path(file_path)
except ValueError as exc:
return EditResult(error=f"Error: {exc}")
result = self._mount_to_backend[mount].edit(
local_path, old_string, new_string, replace_all
)
if result.path:
result.path = self._prefix_mount_path(mount, result.path)
return result
async def aedit(
self,
file_path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
) -> EditResult:
return await asyncio.to_thread(
self.edit, file_path, old_string, new_string, replace_all
)
def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
if path == "/":
prefixed_results: list[FileInfo] = []
if pattern.startswith("/"):
mount, _, remainder = pattern.lstrip("/").partition("/")
backend = self._mount_to_backend.get(mount)
if not backend:
return []
local_pattern = f"/{remainder}" if remainder else "/"
return self._transform_infos(
mount, backend.glob_info(local_pattern, path="/")
)
for mount, backend in self._mount_to_backend.items():
prefixed_results.extend(
self._transform_infos(mount, backend.glob_info(pattern, path="/"))
)
return prefixed_results
try:
mount, local_path = self._split_mount_path(path)
except ValueError:
return []
return self._transform_infos(
mount, self._mount_to_backend[mount].glob_info(pattern, path=local_path)
)
async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
return await asyncio.to_thread(self.glob_info, pattern, path)
def grep_raw(
self, pattern: str, path: str | None = None, glob: str | None = None
) -> list[GrepMatch] | str:
if not pattern:
return "Error: pattern cannot be empty"
if path is None or path == "/":
all_matches: list[GrepMatch] = []
for mount, backend in self._mount_to_backend.items():
result = backend.grep_raw(pattern, path="/", glob=glob)
if isinstance(result, str):
return result
all_matches.extend(
[
GrepMatch(
path=self._prefix_mount_path(mount, self._get_str(match, "path")),
line=self._get_int(match, "line"),
text=self._get_str(match, "text"),
)
for match in result
]
)
return all_matches
try:
mount, local_path = self._split_mount_path(path)
except ValueError as exc:
return f"Error: {exc}"
result = self._mount_to_backend[mount].grep_raw(
pattern, path=local_path, glob=glob
)
if isinstance(result, str):
return result
return [
GrepMatch(
path=self._prefix_mount_path(mount, self._get_str(match, "path")),
line=self._get_int(match, "line"),
text=self._get_str(match, "text"),
)
for match in result
]
async def agrep_raw(
self, pattern: str, path: str | None = None, glob: str | None = None
) -> list[GrepMatch] | str:
return await asyncio.to_thread(self.grep_raw, pattern, path, glob)
def upload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
grouped: dict[str, list[tuple[str, bytes]]] = {}
invalid: list[FileUploadResponse] = []
for virtual_path, content in files:
try:
mount, local_path = self._split_mount_path(virtual_path)
except ValueError:
invalid.append(FileUploadResponse(path=virtual_path, error=_INVALID_PATH))
continue
grouped.setdefault(mount, []).append((local_path, content))
responses = list(invalid)
for mount, mount_files in grouped.items():
result = self._mount_to_backend[mount].upload_files(mount_files)
responses.extend(
[
FileUploadResponse(
path=self._prefix_mount_path(mount, self._get_str(item, "path")),
error=self._get_str(item, "error") or None,
)
for item in result
]
)
return responses
async def aupload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
return await asyncio.to_thread(self.upload_files, files)
def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
grouped: dict[str, list[str]] = {}
invalid: list[FileDownloadResponse] = []
for virtual_path in paths:
try:
mount, local_path = self._split_mount_path(virtual_path)
except ValueError:
invalid.append(
FileDownloadResponse(path=virtual_path, content=None, error=_INVALID_PATH)
)
continue
grouped.setdefault(mount, []).append(local_path)
responses = list(invalid)
for mount, mount_paths in grouped.items():
result = self._mount_to_backend[mount].download_files(mount_paths)
responses.extend(
[
FileDownloadResponse(
path=self._prefix_mount_path(mount, self._get_str(item, "path")),
content=self._get_value(item, "content"),
error=self._get_str(item, "error") or None,
)
for item in result
]
)
return responses
async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]:
return await asyncio.to_thread(self.download_files, paths)

View file

@ -815,11 +815,36 @@ Your goal is to provide helpful, informative answers in a clean, readable format
"""
def _build_mcp_routing_block(
mcp_connector_tools: dict[str, list[str]] | None,
) -> str:
"""Build an additional tool routing block for generic MCP connectors.
When users add MCP servers (e.g. GitLab, GitHub), the LLM needs to know
those tools exist and should be called directly not searched in the
knowledge base.
"""
if not mcp_connector_tools:
return ""
lines = [
"\n<mcp_tool_routing>",
"You also have direct tools from these user-connected MCP servers.",
"Their data is NEVER in the knowledge base — call their tools directly.",
"",
]
for server_name, tool_names in mcp_connector_tools.items():
lines.append(f"- {server_name}{', '.join(tool_names)}")
lines.append("</mcp_tool_routing>\n")
return "\n".join(lines)
def build_surfsense_system_prompt(
today: datetime | None = None,
thread_visibility: ChatVisibility | None = None,
enabled_tool_names: set[str] | None = None,
disabled_tool_names: set[str] | None = None,
mcp_connector_tools: dict[str, list[str]] | None = None,
) -> str:
"""
Build the SurfSense system prompt with default settings.
@ -834,6 +859,9 @@ def build_surfsense_system_prompt(
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
mcp_connector_tools: Mapping of MCP server display name list of tool names
for generic MCP connectors. Injected into the system prompt so the LLM
knows to call these tools directly.
Returns:
Complete system prompt string
@ -841,6 +869,7 @@ def build_surfsense_system_prompt(
visibility = thread_visibility or ChatVisibility.PRIVATE
system_instructions = _get_system_instructions(visibility, today)
system_instructions += _build_mcp_routing_block(mcp_connector_tools)
tools_instructions = _get_tools_instructions(
visibility, enabled_tool_names, disabled_tool_names
)
@ -856,6 +885,7 @@ def build_configurable_system_prompt(
thread_visibility: ChatVisibility | None = None,
enabled_tool_names: set[str] | None = None,
disabled_tool_names: set[str] | None = None,
mcp_connector_tools: dict[str, list[str]] | None = None,
) -> str:
"""
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
@ -877,6 +907,9 @@ def build_configurable_system_prompt(
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
mcp_connector_tools: Mapping of MCP server display name list of tool names
for generic MCP connectors. Injected into the system prompt so the LLM
knows to call these tools directly.
Returns:
Complete system prompt string
@ -894,6 +927,8 @@ def build_configurable_system_prompt(
else:
system_instructions = ""
system_instructions += _build_mcp_routing_block(mcp_connector_tools)
# Tools instructions: only include enabled tools, note disabled ones
tools_instructions = _get_tools_instructions(
thread_visibility, enabled_tool_names, disabled_tool_names

View file

@ -45,6 +45,18 @@ class MCPClient:
async def connect(self, max_retries: int = MAX_RETRIES):
"""Connect to the MCP server and manage its lifecycle.
Retries only apply to the **connection** phase (spawning the process,
initialising the session). Once the session is yielded to the caller,
any exception raised by the caller propagates normally -- the context
manager will NOT retry after ``yield``.
Previous implementation wrapped both connection AND yield inside the
retry loop. Because ``@asynccontextmanager`` only allows a single
``yield``, a failure after yield caused the generator to attempt a
second yield on retry, triggering
``RuntimeError("generator didn't stop after athrow()")`` and orphaning
the stdio subprocess.
Args:
max_retries: Maximum number of connection retry attempts
@ -57,26 +69,22 @@ class MCPClient:
"""
last_error = None
delay = RETRY_DELAY
connected = False
for attempt in range(max_retries):
try:
# Merge env vars with current environment
server_env = os.environ.copy()
server_env.update(self.env)
# Create server parameters with env
server_params = StdioServerParameters(
command=self.command, args=self.args, env=server_env
)
# Spawn server process and create session
# Note: Cannot combine these context managers because ClientSession
# needs the read/write streams from stdio_client
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
async with ClientSession(read, write) as session:
# Initialize the connection
await session.initialize()
self.session = session
connected = True
if attempt > 0:
logger.info(
@ -91,10 +99,16 @@ class MCPClient:
self.command,
" ".join(self.args),
)
yield session
return # Success, exit retry loop
try:
yield session
finally:
self.session = None
return
except Exception as e:
self.session = None
if connected:
raise
last_error = e
if attempt < max_retries - 1:
logger.warning(
@ -105,7 +119,7 @@ class MCPClient:
delay,
)
await asyncio.sleep(delay)
delay *= RETRY_BACKOFF # Exponential backoff
delay *= RETRY_BACKOFF
else:
logger.error(
"Failed to connect to MCP server after %d attempts: %s",
@ -113,10 +127,7 @@ class MCPClient:
e,
exc_info=True,
)
finally:
self.session = None
# All retries exhausted
error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
if last_error:
error_msg += f": {last_error}"
@ -161,12 +172,18 @@ class MCPClient:
logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
raise
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
async def call_tool(
self,
tool_name: str,
arguments: dict[str, Any],
timeout: float = 60.0,
) -> Any:
"""Call a tool on the MCP server.
Args:
tool_name: Name of the tool to call
arguments: Arguments to pass to the tool
timeout: Maximum seconds to wait for the tool to respond
Returns:
Tool execution result
@ -185,10 +202,11 @@ class MCPClient:
"Calling MCP tool '%s' with arguments: %s", tool_name, arguments
)
# Call tools/call RPC method
response = await self.session.call_tool(tool_name, arguments=arguments)
response = await asyncio.wait_for(
self.session.call_tool(tool_name, arguments=arguments),
timeout=timeout,
)
# Extract content from response
result = []
for content in response.content:
if hasattr(content, "text"):
@ -202,15 +220,17 @@ class MCPClient:
logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
return result_str
except asyncio.TimeoutError:
logger.error(
"MCP tool '%s' timed out after %.0fs", tool_name, timeout
)
return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s"
except RuntimeError as e:
# Handle validation errors from MCP server responses
# Some MCP servers (like server-memory) return extra fields not in their schema
if "Invalid structured content" in str(e):
logger.warning(
"MCP server returned data not matching its schema, but continuing: %s",
e,
)
# Try to extract result from error message or return a success message
return "Operation completed (server returned unexpected format)"
raise
except (ValueError, TypeError, AttributeError, KeyError) as e:

View file

@ -16,6 +16,7 @@ clicking "Always Allow", which adds the tool name to the connector's
from __future__ import annotations
import asyncio
import logging
import time
from collections import defaultdict
@ -27,7 +28,7 @@ if TYPE_CHECKING:
from langchain_core.tools import StructuredTool
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from pydantic import BaseModel, Field, create_model
from pydantic import BaseModel, ConfigDict, Field, create_model
from sqlalchemy import cast, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession
@ -41,6 +42,9 @@ logger = logging.getLogger(__name__)
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
_MCP_CACHE_MAX_SIZE = 50
_MCP_DISCOVERY_TIMEOUT_SECONDS = 30
_TOOL_CALL_MAX_RETRIES = 3
_TOOL_CALL_RETRY_DELAY = 1.5 # seconds, doubles per attempt
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
@ -62,7 +66,18 @@ def _create_dynamic_input_model_from_schema(
tool_name: str,
input_schema: dict[str, Any],
) -> type[BaseModel]:
"""Create a Pydantic model from MCP tool's JSON schema."""
"""Create a Pydantic model from MCP tool's JSON schema.
Models always allow extra fields (``extra="allow"``) so that parameters
missing from a broken or incomplete JSON schema (e.g. ``zod-to-json-schema``
producing an empty ``$schema``-only object) can still be forwarded to the
MCP server.
When the schema declares **no** properties, a synthetic ``input_data``
field of type ``dict`` is injected so the LLM has a visible parameter to
populate. The caller should unpack ``input_data`` before forwarding to
the MCP server (see ``_unpack_synthetic_input_data``).
"""
properties = input_schema.get("properties", {})
required_fields = input_schema.get("required", [])
@ -82,8 +97,35 @@ def _create_dynamic_input_model_from_schema(
Field(None, description=param_description),
)
if not properties:
field_definitions["input_data"] = (
dict[str, Any] | None,
Field(
None,
description=(
"Arguments to pass to this tool as a JSON object. "
"Infer sensible key names from the tool name and description "
"(e.g. {\"search\": \"my query\"} for a search tool)."
),
),
)
model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
return create_model(model_name, **field_definitions)
model = create_model(model_name, __config__=ConfigDict(extra="allow"), **field_definitions)
return model
def _unpack_synthetic_input_data(kwargs: dict[str, Any]) -> dict[str, Any]:
"""Unpack the synthetic ``input_data`` field into top-level kwargs.
When the MCP tool schema is empty, ``_create_dynamic_input_model_from_schema``
adds a catch-all ``input_data: dict`` field. This helper merges that dict
back into the top-level kwargs so the MCP server receives flat arguments.
"""
input_data = kwargs.pop("input_data", None)
if isinstance(input_data, dict):
kwargs.update(input_data)
return kwargs
async def _create_mcp_tool_from_definition_stdio(
@ -101,7 +143,12 @@ async def _create_mcp_tool_from_definition_stdio(
``GraphInterrupt`` propagates cleanly to LangGraph.
"""
tool_name = tool_def.get("name", "unnamed_tool")
tool_description = tool_def.get("description", "No description provided")
raw_description = tool_def.get("description", "No description provided")
tool_description = (
f"[MCP server: {connector_name}] {raw_description}"
if connector_name
else raw_description
)
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
logger.debug("MCP tool '%s' input schema: %s", tool_name, input_schema)
@ -119,7 +166,7 @@ async def _create_mcp_tool_from_definition_stdio(
params=kwargs,
context={
"mcp_server": connector_name,
"tool_description": tool_description,
"tool_description": raw_description,
"mcp_transport": "stdio",
"mcp_connector_id": connector_id,
},
@ -127,18 +174,32 @@ async def _create_mcp_tool_from_definition_stdio(
)
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in hitl_result.params.items() if v is not None}
)
try:
async with mcp_client.connect():
result = await mcp_client.call_tool(tool_name, call_kwargs)
return str(result)
except RuntimeError as e:
logger.error("MCP tool '%s' connection failed after retries: %s", tool_name, e)
return f"Error: MCP tool '{tool_name}' connection failed after retries: {e!s}"
except Exception as e:
logger.exception("MCP tool '%s' execution failed: %s", tool_name, e)
return f"Error: MCP tool '{tool_name}' execution failed: {e!s}"
last_error: Exception | None = None
for attempt in range(_TOOL_CALL_MAX_RETRIES):
try:
async with mcp_client.connect():
result = await mcp_client.call_tool(tool_name, call_kwargs)
return str(result)
except Exception as e:
last_error = e
if attempt < _TOOL_CALL_MAX_RETRIES - 1:
delay = _TOOL_CALL_RETRY_DELAY * (2 ** attempt)
logger.warning(
"MCP tool '%s' failed (attempt %d/%d): %s. Retrying in %.1fs...",
tool_name, attempt + 1, _TOOL_CALL_MAX_RETRIES, e, delay,
)
await asyncio.sleep(delay)
else:
logger.error(
"MCP tool '%s' failed after %d attempts: %s",
tool_name, _TOOL_CALL_MAX_RETRIES, e, exc_info=True,
)
return f"Error: MCP tool '{tool_name}' failed after {_TOOL_CALL_MAX_RETRIES} attempts: {last_error!s}"
tool = StructuredTool(
name=tool_name,
@ -148,6 +209,8 @@ async def _create_mcp_tool_from_definition_stdio(
metadata={
"mcp_input_schema": input_schema,
"mcp_transport": "stdio",
"mcp_connector_name": connector_name or None,
"mcp_is_generic": True,
"hitl": True,
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
},
@ -167,6 +230,7 @@ async def _create_mcp_tool_from_definition_http(
trusted_tools: list[str] | None = None,
readonly_tools: frozenset[str] | None = None,
tool_name_prefix: str | None = None,
is_generic_mcp: bool = False,
) -> StructuredTool:
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
@ -178,7 +242,7 @@ async def _create_mcp_tool_from_definition_http(
but the actual MCP ``call_tool`` still uses the original name.
"""
original_tool_name = tool_def.get("name", "unnamed_tool")
tool_description = tool_def.get("description", "No description provided")
raw_description = tool_def.get("description", "No description provided")
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
is_readonly = readonly_tools is not None and original_tool_name in readonly_tools
@ -188,18 +252,51 @@ async def _create_mcp_tool_from_definition_http(
else original_tool_name
)
if tool_name_prefix:
tool_description = f"[Account: {connector_name}] {tool_description}"
tool_description = f"[Account: {connector_name}] {raw_description}"
elif is_generic_mcp and connector_name:
tool_description = f"[MCP server: {connector_name}] {raw_description}"
else:
tool_description = raw_description
logger.debug("MCP HTTP tool '%s' input schema: %s", exposed_name, input_schema)
input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema)
async def _do_mcp_call(
call_headers: dict[str, str],
call_kwargs: dict[str, Any],
timeout: float = 60.0,
) -> str:
"""Execute a single MCP HTTP call with the given headers."""
async with (
streamablehttp_client(url, headers=call_headers) as (read, write, _),
ClientSession(read, write) as session,
):
await session.initialize()
response = await asyncio.wait_for(
session.call_tool(original_tool_name, arguments=call_kwargs),
timeout=timeout,
)
result = []
for content in response.content:
if hasattr(content, "text"):
result.append(content.text)
elif hasattr(content, "data"):
result.append(str(content.data))
else:
result.append(str(content))
return "\n".join(result) if result else ""
async def mcp_http_tool_call(**kwargs) -> str:
"""Execute the MCP tool call via HTTP transport."""
logger.debug("MCP HTTP tool '%s' called", exposed_name)
if is_readonly:
call_kwargs = {k: v for k, v in kwargs.items() if v is not None}
call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in kwargs.items() if v is not None}
)
else:
hitl_result = request_approval(
action_type="mcp_tool_call",
@ -207,7 +304,7 @@ async def _create_mcp_tool_from_definition_http(
params=kwargs,
context={
"mcp_server": connector_name,
"tool_description": tool_description,
"tool_description": raw_description,
"mcp_transport": "http",
"mcp_connector_id": connector_id,
},
@ -215,34 +312,51 @@ async def _create_mcp_tool_from_definition_http(
)
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in hitl_result.params.items() if v is not None}
)
try:
async with (
streamablehttp_client(url, headers=headers) as (read, write, _),
ClientSession(read, write) as session,
):
await session.initialize()
response = await session.call_tool(
original_tool_name, arguments=call_kwargs,
result_str = await _do_mcp_call(headers, call_kwargs)
logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str))
return result_str
except Exception as first_err:
if not _is_auth_error(first_err) or connector_id is None:
logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, first_err)
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {first_err!s}"
logger.warning(
"MCP HTTP tool '%s' got 401 — attempting token refresh for connector %s",
exposed_name, connector_id,
)
fresh_headers = await _force_refresh_and_get_headers(connector_id)
if fresh_headers is None:
await _mark_connector_auth_expired(connector_id)
return (
f"Error: MCP tool '{exposed_name}' authentication expired. "
"Please re-authenticate the connector in your settings."
)
result = []
for content in response.content:
if hasattr(content, "text"):
result.append(content.text)
elif hasattr(content, "data"):
result.append(str(content.data))
else:
result.append(str(content))
result_str = "\n".join(result) if result else ""
logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str))
try:
result_str = await _do_mcp_call(fresh_headers, call_kwargs)
logger.info(
"MCP HTTP tool '%s' succeeded after 401 recovery",
exposed_name,
)
return result_str
except Exception as e:
logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, e)
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {e!s}"
except Exception as retry_err:
logger.exception(
"MCP HTTP tool '%s' still failing after token refresh: %s",
exposed_name, retry_err,
)
if _is_auth_error(retry_err):
await _mark_connector_auth_expired(connector_id)
return (
f"Error: MCP tool '{exposed_name}' authentication expired. "
"Please re-authenticate the connector in your settings."
)
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {retry_err!s}"
tool = StructuredTool(
name=exposed_name,
@ -253,6 +367,8 @@ async def _create_mcp_tool_from_definition_http(
"mcp_input_schema": input_schema,
"mcp_transport": "http",
"mcp_url": url,
"mcp_connector_name": connector_name or None,
"mcp_is_generic": is_generic_mcp,
"hitl": not is_readonly,
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
"mcp_original_tool_name": original_tool_name,
@ -334,6 +450,7 @@ async def _load_http_mcp_tools(
allowed_tools: list[str] | None = None,
readonly_tools: frozenset[str] | None = None,
tool_name_prefix: str | None = None,
is_generic_mcp: bool = False,
) -> list[StructuredTool]:
"""Load tools from an HTTP-based MCP server.
@ -365,66 +482,99 @@ async def _load_http_mcp_tools(
allowed_set = set(allowed_tools) if allowed_tools else None
try:
async def _discover(disc_headers: dict[str, str]) -> list[dict[str, Any]]:
"""Connect, initialize, and list tools from the MCP server."""
async with (
streamablehttp_client(url, headers=headers) as (read, write, _),
streamablehttp_client(url, headers=disc_headers) as (read, write, _),
ClientSession(read, write) as session,
):
await session.initialize()
response = await session.list_tools()
tool_definitions = []
for tool in response.tools:
tool_definitions.append(
{
"name": tool.name,
"description": tool.description or "",
"input_schema": tool.inputSchema
if hasattr(tool, "inputSchema")
else {},
}
)
return [
{
"name": tool.name,
"description": tool.description or "",
"input_schema": tool.inputSchema
if hasattr(tool, "inputSchema")
else {},
}
for tool in response.tools
]
total_discovered = len(tool_definitions)
try:
tool_definitions = await _discover(headers)
except Exception as first_err:
if not _is_auth_error(first_err) or connector_id is None:
logger.exception(
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
url, connector_id, first_err,
)
return tools
if allowed_set:
tool_definitions = [
td for td in tool_definitions if td["name"] in allowed_set
]
logger.info(
"HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter",
url, connector_id, len(tool_definitions), total_discovered,
)
else:
logger.info(
"Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all",
total_discovered, url, connector_id,
)
for tool_def in tool_definitions:
try:
tool = await _create_mcp_tool_from_definition_http(
tool_def,
url,
headers,
connector_name=connector_name,
connector_id=connector_id,
trusted_tools=trusted_tools,
readonly_tools=readonly_tools,
tool_name_prefix=tool_name_prefix,
)
tools.append(tool)
except Exception as e:
logger.exception(
"Failed to create HTTP tool '%s' from connector %d: %s",
tool_def.get("name"), connector_id, e,
)
except Exception as e:
logger.exception(
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
url, connector_id, e,
logger.warning(
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
connector_id,
)
fresh_headers = await _force_refresh_and_get_headers(connector_id)
if fresh_headers is None:
await _mark_connector_auth_expired(connector_id)
logger.error(
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
connector_id,
)
return tools
try:
tool_definitions = await _discover(fresh_headers)
headers = fresh_headers
logger.info(
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
connector_id,
)
except Exception as retry_err:
logger.exception(
"HTTP MCP discovery for connector %d still failing after refresh: %s",
connector_id, retry_err,
)
if _is_auth_error(retry_err):
await _mark_connector_auth_expired(connector_id)
return tools
total_discovered = len(tool_definitions)
if allowed_set:
tool_definitions = [
td for td in tool_definitions if td["name"] in allowed_set
]
logger.info(
"HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter",
url, connector_id, len(tool_definitions), total_discovered,
)
else:
logger.info(
"Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all",
total_discovered, url, connector_id,
)
for tool_def in tool_definitions:
try:
tool = await _create_mcp_tool_from_definition_http(
tool_def,
url,
headers,
connector_name=connector_name,
connector_id=connector_id,
trusted_tools=trusted_tools,
readonly_tools=readonly_tools,
tool_name_prefix=tool_name_prefix,
is_generic_mcp=is_generic_mcp,
)
tools.append(tool)
except Exception as e:
logger.exception(
"Failed to create HTTP tool '%s' from connector %d: %s",
tool_def.get("name"), connector_id, e,
)
return tools
@ -476,6 +626,91 @@ def _inject_oauth_headers(
return None
async def _refresh_connector_token(
session: AsyncSession,
connector: "SearchSourceConnector",
) -> str | None:
"""Refresh the OAuth token for an MCP connector and persist the result.
This is the shared core used by both proactive (pre-expiry) and reactive
(401 recovery) refresh paths. It handles:
- Decrypting the current refresh token / client secret
- Calling the token endpoint
- Encrypting and persisting the new tokens
- Clearing ``auth_expired`` if it was set
- Invalidating the MCP tools cache
Returns the **plaintext** new access token on success, or ``None`` on
failure (no refresh token, IdP error, etc.).
"""
from datetime import UTC, datetime, timedelta
from sqlalchemy.orm.attributes import flag_modified
from app.services.mcp_oauth.discovery import refresh_access_token
cfg = connector.config or {}
mcp_oauth = cfg.get("mcp_oauth", {})
refresh_token = mcp_oauth.get("refresh_token")
if not refresh_token:
logger.warning(
"MCP connector %s: no refresh_token available",
connector.id,
)
return None
enc = _get_token_enc()
decrypted_refresh = enc.decrypt_token(refresh_token)
decrypted_secret = (
enc.decrypt_token(mcp_oauth["client_secret"])
if mcp_oauth.get("client_secret")
else ""
)
token_json = await refresh_access_token(
token_endpoint=mcp_oauth["token_endpoint"],
refresh_token=decrypted_refresh,
client_id=mcp_oauth["client_id"],
client_secret=decrypted_secret,
)
new_access = token_json.get("access_token")
if not new_access:
logger.warning(
"MCP connector %s: token refresh returned no access_token",
connector.id,
)
return None
new_expires_at = None
if token_json.get("expires_in"):
new_expires_at = datetime.now(UTC) + timedelta(
seconds=int(token_json["expires_in"])
)
updated_oauth = dict(mcp_oauth)
updated_oauth["access_token"] = enc.encrypt_token(new_access)
if token_json.get("refresh_token"):
updated_oauth["refresh_token"] = enc.encrypt_token(
token_json["refresh_token"]
)
updated_oauth["expires_at"] = (
new_expires_at.isoformat() if new_expires_at else None
)
updated_cfg = {**cfg, "mcp_oauth": updated_oauth}
updated_cfg.pop("auth_expired", None)
connector.config = updated_cfg
flag_modified(connector, "config")
await session.commit()
await session.refresh(connector)
invalidate_mcp_tools_cache(connector.search_space_id)
return new_access
async def _maybe_refresh_mcp_oauth_token(
session: AsyncSession,
connector: "SearchSourceConnector",
@ -504,73 +739,13 @@ async def _maybe_refresh_mcp_oauth_token(
except (ValueError, TypeError):
return server_config
refresh_token = mcp_oauth.get("refresh_token")
if not refresh_token:
logger.warning(
"MCP connector %s token expired but no refresh_token available",
connector.id,
)
return server_config
try:
from app.services.mcp_oauth.discovery import refresh_access_token
enc = _get_token_enc()
decrypted_refresh = enc.decrypt_token(refresh_token)
decrypted_secret = (
enc.decrypt_token(mcp_oauth["client_secret"])
if mcp_oauth.get("client_secret")
else ""
)
token_json = await refresh_access_token(
token_endpoint=mcp_oauth["token_endpoint"],
refresh_token=decrypted_refresh,
client_id=mcp_oauth["client_id"],
client_secret=decrypted_secret,
)
new_access = token_json.get("access_token")
new_access = await _refresh_connector_token(session, connector)
if not new_access:
logger.warning(
"MCP connector %s token refresh returned no access_token",
connector.id,
)
return server_config
new_expires_at = None
if token_json.get("expires_in"):
new_expires_at = datetime.now(UTC) + timedelta(
seconds=int(token_json["expires_in"])
)
logger.info("Proactively refreshed MCP OAuth token for connector %s", connector.id)
updated_oauth = dict(mcp_oauth)
updated_oauth["access_token"] = enc.encrypt_token(new_access)
if token_json.get("refresh_token"):
updated_oauth["refresh_token"] = enc.encrypt_token(
token_json["refresh_token"]
)
updated_oauth["expires_at"] = (
new_expires_at.isoformat() if new_expires_at else None
)
from sqlalchemy.orm.attributes import flag_modified
connector.config = {
**cfg,
"server_config": server_config,
"mcp_oauth": updated_oauth,
}
flag_modified(connector, "config")
await session.commit()
await session.refresh(connector)
logger.info("Refreshed MCP OAuth token for connector %s", connector.id)
# Invalidate cache so next call picks up the new token.
invalidate_mcp_tools_cache(connector.search_space_id)
# Return server_config with the fresh token injected for immediate use.
refreshed_config = dict(server_config)
refreshed_config["headers"] = {
**server_config.get("headers", {}),
@ -587,6 +762,117 @@ async def _maybe_refresh_mcp_oauth_token(
return server_config
# ---------------------------------------------------------------------------
# Reactive 401 handling helpers
# ---------------------------------------------------------------------------
def _is_auth_error(exc: Exception) -> bool:
"""Check if an exception indicates an HTTP 401 authentication failure."""
try:
import httpx
if isinstance(exc, httpx.HTTPStatusError):
return exc.response.status_code == 401
except ImportError:
pass
err_str = str(exc).lower()
return "401" in err_str or "unauthorized" in err_str
async def _force_refresh_and_get_headers(
connector_id: int,
) -> dict[str, str] | None:
"""Force-refresh OAuth token for a connector and return fresh HTTP headers.
Opens a **new** DB session so this can be called from inside tool closures
that don't have access to the original session.
Returns ``None`` when the connector is not OAuth-backed, has no
refresh token, or the refresh itself fails.
"""
from app.db import async_session_maker
try:
async with async_session_maker() as session:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
)
)
connector = result.scalars().first()
if not connector:
return None
cfg = connector.config or {}
if not cfg.get("mcp_oauth"):
return None
server_config = cfg.get("server_config", {})
new_access = await _refresh_connector_token(session, connector)
if not new_access:
return None
logger.info(
"Force-refreshed MCP OAuth token for connector %s (401 recovery)",
connector_id,
)
return {
**server_config.get("headers", {}),
"Authorization": f"Bearer {new_access}",
}
except Exception:
logger.warning(
"Failed to force-refresh MCP OAuth token for connector %s",
connector_id,
exc_info=True,
)
return None
async def _mark_connector_auth_expired(connector_id: int) -> None:
"""Set ``config.auth_expired = True`` so the frontend shows re-auth UI."""
from app.db import async_session_maker
try:
async with async_session_maker() as session:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
)
)
connector = result.scalars().first()
if not connector:
return
cfg = dict(connector.config or {})
if cfg.get("auth_expired"):
return
cfg["auth_expired"] = True
connector.config = cfg
from sqlalchemy.orm.attributes import flag_modified
flag_modified(connector, "config")
await session.commit()
logger.info(
"Marked MCP connector %s as auth_expired after unrecoverable 401",
connector_id,
)
invalidate_mcp_tools_cache(connector.search_space_id)
except Exception:
logger.warning(
"Failed to mark connector %s as auth_expired",
connector_id,
exc_info=True,
)
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
"""Invalidate cached MCP tools.
@ -661,7 +947,7 @@ async def load_mcp_tools(
multi_account_types,
)
tools: list[StructuredTool] = []
discovery_tasks: list[dict[str, Any]] = []
for connector in connectors:
try:
cfg = connector.config or {}
@ -674,14 +960,10 @@ async def load_mcp_tools(
)
continue
# For MCP OAuth connectors: refresh if needed, then decrypt the
# access token and inject it into headers at runtime. The DB
# intentionally does NOT store plaintext tokens in server_config.
if cfg.get("mcp_oauth"):
server_config = await _maybe_refresh_mcp_oauth_token(
session, connector, cfg, server_config,
)
# Re-read cfg after potential refresh (connector was reloaded from DB).
cfg = connector.config or {}
server_config = _inject_oauth_headers(cfg, server_config)
if server_config is None:
@ -689,6 +971,7 @@ async def load_mcp_tools(
"Skipping MCP connector %d — OAuth token decryption failed",
connector.id,
)
await _mark_connector_auth_expired(connector.id)
continue
trusted_tools = cfg.get("trusted_tools", [])
@ -703,7 +986,6 @@ async def load_mcp_tools(
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
# Build a prefix only when multiple accounts share the same type.
tool_name_prefix: str | None = None
if ct in multi_account_types and svc_cfg:
service_key = next(
@ -713,34 +995,68 @@ async def load_mcp_tools(
if service_key:
tool_name_prefix = f"{service_key}_{connector.id}"
transport = server_config.get("transport", "stdio")
if transport in ("streamable-http", "http", "sse"):
connector_tools = await _load_http_mcp_tools(
connector.id,
connector.name,
server_config,
trusted_tools=trusted_tools,
allowed_tools=allowed_tools,
readonly_tools=readonly_tools,
tool_name_prefix=tool_name_prefix,
)
else:
connector_tools = await _load_stdio_mcp_tools(
connector.id,
connector.name,
server_config,
trusted_tools=trusted_tools,
)
tools.extend(connector_tools)
discovery_tasks.append({
"connector_id": connector.id,
"connector_name": connector.name,
"server_config": server_config,
"trusted_tools": trusted_tools,
"allowed_tools": allowed_tools,
"readonly_tools": readonly_tools,
"tool_name_prefix": tool_name_prefix,
"transport": server_config.get("transport", "stdio"),
"is_generic_mcp": svc_cfg is None,
})
except Exception as e:
logger.exception(
"Failed to load tools from MCP connector %d: %s",
"Failed to prepare MCP connector %d: %s",
connector.id, e,
)
async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]:
try:
if task["transport"] in ("streamable-http", "http", "sse"):
return await asyncio.wait_for(
_load_http_mcp_tools(
task["connector_id"],
task["connector_name"],
task["server_config"],
trusted_tools=task["trusted_tools"],
allowed_tools=task["allowed_tools"],
readonly_tools=task["readonly_tools"],
tool_name_prefix=task["tool_name_prefix"],
is_generic_mcp=task.get("is_generic_mcp", False),
),
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
)
else:
return await asyncio.wait_for(
_load_stdio_mcp_tools(
task["connector_id"],
task["connector_name"],
task["server_config"],
trusted_tools=task["trusted_tools"],
),
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(
"MCP connector %d timed out after %ds during discovery",
task["connector_id"], _MCP_DISCOVERY_TIMEOUT_SECONDS,
)
return []
except Exception as e:
logger.exception(
"Failed to load tools from MCP connector %d: %s",
task["connector_id"], e,
)
return []
results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks])
tools: list[StructuredTool] = [
tool for sublist in results for tool in sublist
]
_mcp_tools_cache[search_space_id] = (now, tools)
if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:

View file

@ -141,6 +141,15 @@ def _http_exception_handler(request: Request, exc: HTTPException) -> JSONRespons
exc.status_code,
message,
)
elif exc.status_code >= 400:
_error_logger.warning(
"[%s] %s %s - HTTPException %d: %s",
rid,
request.method,
request.url.path,
exc.status_code,
message,
)
if should_sanitize:
message = GENERIC_5XX_MESSAGE
err_code = "INTERNAL_ERROR"
@ -170,6 +179,15 @@ def _http_exception_handler(request: Request, exc: HTTPException) -> JSONRespons
exc.status_code,
detail,
)
elif exc.status_code >= 400:
_error_logger.warning(
"[%s] %s %s - HTTPException %d: %s",
rid,
request.method,
request.url.path,
exc.status_code,
detail,
)
if should_sanitize:
detail = GENERIC_5XX_MESSAGE
code = _status_to_code(exc.status_code, detail)

View file

@ -339,6 +339,9 @@ class Config:
# self-hosted: Full access to local file system connectors (Obsidian, etc.)
# cloud: Only cloud-based connectors available
DEPLOYMENT_MODE = os.getenv("SURFSENSE_DEPLOYMENT_MODE", "self-hosted")
ENABLE_DESKTOP_LOCAL_FILESYSTEM = (
os.getenv("ENABLE_DESKTOP_LOCAL_FILESYSTEM", "FALSE").upper() == "TRUE"
)
@classmethod
def is_self_hosted(cls) -> bool:

View file

@ -22,6 +22,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.agents.new_chat.filesystem_selection import (
ClientPlatform,
LocalFilesystemMount,
FilesystemMode,
FilesystemSelection,
)
from app.config import config
from app.db import (
ChatComment,
ChatVisibility,
@ -36,6 +43,7 @@ from app.db import (
)
from app.schemas.new_chat import (
AgentToolInfo,
LocalFilesystemMountPayload,
NewChatMessageRead,
NewChatRequest,
NewChatThreadCreate,
@ -63,6 +71,67 @@ _background_tasks: set[asyncio.Task] = set()
router = APIRouter()
def _resolve_filesystem_selection(
*,
mode: str,
client_platform: str,
local_mounts: list[LocalFilesystemMountPayload] | None,
) -> FilesystemSelection:
"""Validate and normalize filesystem mode settings from request payload."""
try:
resolved_mode = FilesystemMode(mode)
except ValueError as exc:
raise HTTPException(status_code=400, detail="Invalid filesystem_mode") from exc
try:
resolved_platform = ClientPlatform(client_platform)
except ValueError as exc:
raise HTTPException(status_code=400, detail="Invalid client_platform") from exc
if resolved_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
if not config.ENABLE_DESKTOP_LOCAL_FILESYSTEM:
raise HTTPException(
status_code=400,
detail="Desktop local filesystem mode is disabled on this deployment.",
)
if resolved_platform != ClientPlatform.DESKTOP:
raise HTTPException(
status_code=400,
detail="desktop_local_folder mode is only available on desktop runtime.",
)
normalized_mounts: list[tuple[str, str]] = []
seen_mounts: set[str] = set()
for mount in local_mounts or []:
mount_id = mount.mount_id.strip()
root_path = mount.root_path.strip()
if not mount_id or not root_path:
continue
if mount_id in seen_mounts:
continue
seen_mounts.add(mount_id)
normalized_mounts.append((mount_id, root_path))
if not normalized_mounts:
raise HTTPException(
status_code=400,
detail=(
"local_filesystem_mounts must include at least one mount for "
"desktop_local_folder mode."
),
)
return FilesystemSelection(
mode=resolved_mode,
client_platform=resolved_platform,
local_mounts=tuple(
LocalFilesystemMount(mount_id=mount_id, root_path=root_path)
for mount_id, root_path in normalized_mounts
),
)
return FilesystemSelection(
mode=FilesystemMode.CLOUD,
client_platform=resolved_platform,
)
def _try_delete_sandbox(thread_id: int) -> None:
"""Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
from app.agents.new_chat.sandbox import (
@ -1098,6 +1167,7 @@ async def list_agent_tools(
@router.post("/new_chat")
async def handle_new_chat(
request: NewChatRequest,
http_request: Request,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
@ -1133,6 +1203,11 @@ async def handle_new_chat(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,
local_mounts=request.local_filesystem_mounts,
)
# Get search space to check LLM config preferences
search_space_result = await session.execute(
@ -1175,6 +1250,8 @@ async def handle_new_chat(
thread_visibility=thread.visibility,
current_user_display_name=user.display_name or "A team member",
disabled_tools=request.disabled_tools,
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
),
media_type="text/event-stream",
headers={
@ -1202,6 +1279,7 @@ async def handle_new_chat(
async def regenerate_response(
thread_id: int,
request: RegenerateRequest,
http_request: Request,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
@ -1247,6 +1325,11 @@ async def regenerate_response(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,
local_mounts=request.local_filesystem_mounts,
)
# Get the checkpointer and state history
checkpointer = await get_checkpointer()
@ -1412,6 +1495,8 @@ async def regenerate_response(
thread_visibility=thread.visibility,
current_user_display_name=user.display_name or "A team member",
disabled_tools=request.disabled_tools,
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
):
yield chunk
streaming_completed = True
@ -1477,6 +1562,7 @@ async def regenerate_response(
async def resume_chat(
thread_id: int,
request: ResumeRequest,
http_request: Request,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
@ -1498,6 +1584,11 @@ async def resume_chat(
)
await check_thread_access(session, thread, user)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,
local_mounts=request.local_filesystem_mounts,
)
search_space_result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
@ -1526,6 +1617,8 @@ async def resume_chat(
user_id=str(user.id),
llm_config_id=llm_config_id,
thread_visibility=thread.visibility,
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
),
media_type="text/event-stream",
headers={

View file

@ -3105,13 +3105,18 @@ async def trust_mcp_tool(
"""Add a tool to the MCP connector's trusted (always-allow) list.
Once trusted, the tool executes without HITL approval on subsequent calls.
Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors
(LINEAR_CONNECTOR, JIRA_CONNECTOR, etc.) by checking for ``server_config``.
"""
try:
from sqlalchemy import cast
from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.MCP_CONNECTOR,
SearchSourceConnector.user_id == user.id,
cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), # noqa: W601
)
)
connector = result.scalars().first()
@ -3156,13 +3161,17 @@ async def untrust_mcp_tool(
"""Remove a tool from the MCP connector's trusted list.
The tool will require HITL approval again on subsequent calls.
Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors.
"""
try:
from sqlalchemy import cast
from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.MCP_CONNECTOR,
SearchSourceConnector.user_id == user.id,
cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), # noqa: W601
)
)
connector = result.scalars().first()

View file

@ -168,6 +168,11 @@ class ChatMessage(BaseModel):
content: str
class LocalFilesystemMountPayload(BaseModel):
mount_id: str
root_path: str
class NewChatRequest(BaseModel):
"""Request schema for the deep agent chat endpoint."""
@ -184,6 +189,9 @@ class NewChatRequest(BaseModel):
disabled_tools: list[str] | None = (
None # Optional list of tool names the user has disabled from the UI
)
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
client_platform: Literal["web", "desktop"] = "web"
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
class RegenerateRequest(BaseModel):
@ -204,6 +212,9 @@ class RegenerateRequest(BaseModel):
mentioned_document_ids: list[int] | None = None
mentioned_surfsense_doc_ids: list[int] | None = None
disabled_tools: list[str] | None = None
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
client_platform: Literal["web", "desktop"] = "web"
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
# =============================================================================
@ -227,6 +238,9 @@ class ResumeDecision(BaseModel):
class ResumeRequest(BaseModel):
search_space_id: int
decisions: list[ResumeDecision]
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
client_platform: Literal["web", "desktop"] = "web"
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
# =============================================================================

View file

@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -66,6 +65,8 @@ class ConfluenceKBSyncService:
if dup:
content_hash = unique_hash
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,
@ -184,6 +185,8 @@ class ConfluenceKBSyncService:
space_id = (document.document_metadata or {}).get("space_id", "")
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session, user_id, search_space_id, disable_streaming=True
)

View file

@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, DocumentType
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -73,6 +72,8 @@ class DropboxKBSyncService:
)
content_hash = unique_hash
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,

View file

@ -4,7 +4,6 @@ from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -78,6 +77,8 @@ class GmailKBSyncService:
)
content_hash = unique_hash
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,

View file

@ -14,7 +14,6 @@ from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
)
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -91,6 +90,8 @@ class GoogleCalendarKBSyncService:
)
content_hash = unique_hash
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,
@ -249,6 +250,8 @@ class GoogleCalendarKBSyncService:
if not indexable_content:
return {"status": "error", "message": "Event produced empty content"}
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session, user_id, search_space_id, disable_streaming=True
)

View file

@ -4,7 +4,6 @@ from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -75,6 +74,8 @@ class GoogleDriveKBSyncService:
)
content_hash = unique_hash
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,

View file

@ -6,7 +6,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.jira_history import JiraHistoryConnector
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -75,6 +74,8 @@ class JiraKBSyncService:
if dup:
content_hash = unique_hash
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,
@ -190,6 +191,8 @@ class JiraKBSyncService:
state = formatted.get("status", "Unknown")
comment_count = len(formatted.get("comments", []))
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session, user_id, search_space_id, disable_streaming=True
)

View file

@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.linear_connector import LinearConnector
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -85,6 +84,8 @@ class LinearKBSyncService:
)
content_hash = unique_hash
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,
@ -226,6 +227,8 @@ class LinearKBSyncService:
comment_count = len(formatted_issue.get("comments", []))
formatted_issue.get("description", "")
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session, user_id, search_space_id, disable_streaming=True
)

View file

@ -7,7 +7,6 @@ from langchain_litellm import ChatLiteLLM
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
from app.config import config
from app.db import NewLLMConfig, SearchSpace
from app.services.llm_router_service import (
@ -204,6 +203,8 @@ async def validate_llm_config(
if litellm_params:
litellm_kwargs.update(litellm_params)
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
llm = SanitizedChatLiteLLM(**litellm_kwargs)
# Run the test call in a worker thread with a hard timeout. Some
@ -377,6 +378,8 @@ async def get_search_space_llm_instance(
if disable_streaming:
litellm_kwargs["disable_streaming"] = True
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
return SanitizedChatLiteLLM(**litellm_kwargs)
# Get the LLM configuration from database (NewLLMConfig)
@ -454,6 +457,8 @@ async def get_search_space_llm_instance(
if disable_streaming:
litellm_kwargs["disable_streaming"] = True
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
return SanitizedChatLiteLLM(**litellm_kwargs)
except Exception as e:
@ -555,6 +560,8 @@ async def get_vision_llm(
if global_cfg.get("litellm_params"):
litellm_kwargs.update(global_cfg["litellm_params"])
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
return SanitizedChatLiteLLM(**litellm_kwargs)
result = await session.execute(
@ -588,6 +595,8 @@ async def get_vision_llm(
if vision_cfg.litellm_params:
litellm_kwargs.update(vision_cfg.litellm_params)
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
return SanitizedChatLiteLLM(**litellm_kwargs)
except Exception as e:

View file

@ -4,7 +4,6 @@ from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -74,6 +73,8 @@ class NotionKBSyncService:
)
content_hash = unique_hash
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,
@ -244,6 +245,8 @@ class NotionKBSyncService:
f"Final content length: {len(full_content)} chars, verified={content_verified}"
)
from app.services.llm_service import get_user_long_context_llm
logger.debug("Generating summary and embeddings")
user_llm = await get_user_long_context_llm(
self.db_session,

View file

@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, DocumentType
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -73,6 +72,8 @@ class OneDriveKBSyncService:
)
content_hash = unique_hash
from app.services.llm_service import get_user_long_context_llm
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,

View file

@ -30,6 +30,8 @@ from sqlalchemy.orm import selectinload
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer
from app.agents.new_chat.filesystem_selection import FilesystemSelection
from app.config import config
from app.agents.new_chat.llm_config import (
AgentConfig,
create_chat_litellm_from_agent_config,
@ -145,6 +147,102 @@ class StreamResult:
interrupt_value: dict[str, Any] | None = None
sandbox_files: list[str] = field(default_factory=list)
agent_called_update_memory: bool = False
request_id: str | None = None
turn_id: str = ""
filesystem_mode: str = "cloud"
client_platform: str = "web"
intent_detected: str = "chat_only"
intent_confidence: float = 0.0
write_attempted: bool = False
write_succeeded: bool = False
verification_succeeded: bool = False
commit_gate_passed: bool = True
commit_gate_reason: str = ""
def _safe_float(value: Any, default: float = 0.0) -> float:
try:
return float(value)
except (TypeError, ValueError):
return default
def _tool_output_to_text(tool_output: Any) -> str:
if isinstance(tool_output, dict):
if isinstance(tool_output.get("result"), str):
return tool_output["result"]
if isinstance(tool_output.get("error"), str):
return tool_output["error"]
return json.dumps(tool_output, ensure_ascii=False)
return str(tool_output)
def _tool_output_has_error(tool_output: Any) -> bool:
if isinstance(tool_output, dict):
if tool_output.get("error"):
return True
result = tool_output.get("result")
if isinstance(result, str) and result.strip().lower().startswith("error:"):
return True
return False
if isinstance(tool_output, str):
return tool_output.strip().lower().startswith("error:")
return False
def _extract_resolved_file_path(*, tool_name: str, tool_output: Any) -> str | None:
if isinstance(tool_output, dict):
path_value = tool_output.get("path")
if isinstance(path_value, str) and path_value.strip():
return path_value.strip()
text = _tool_output_to_text(tool_output)
if tool_name == "write_file":
match = re.search(r"Updated file\s+(.+)$", text.strip())
if match:
return match.group(1).strip()
if tool_name == "edit_file":
match = re.search(r"in '([^']+)'", text)
if match:
return match.group(1).strip()
return None
def _contract_enforcement_active(result: StreamResult) -> bool:
# Keep policy deterministic with no env-driven progression modes:
# enforce the file-operation contract only in desktop local-folder mode.
return result.filesystem_mode == "desktop_local_folder"
def _evaluate_file_contract_outcome(result: StreamResult) -> tuple[bool, str]:
if result.intent_detected != "file_write":
return True, ""
if not result.write_attempted:
return False, "no_write_attempt"
if not result.write_succeeded:
return False, "write_failed"
if not result.verification_succeeded:
return False, "verification_failed"
return True, ""
def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
payload: dict[str, Any] = {
"stage": stage,
"request_id": result.request_id or "unknown",
"turn_id": result.turn_id or "unknown",
"chat_id": result.turn_id.split(":", 1)[0] if ":" in result.turn_id else "unknown",
"filesystem_mode": result.filesystem_mode,
"client_platform": result.client_platform,
"intent_detected": result.intent_detected,
"intent_confidence": result.intent_confidence,
"write_attempted": result.write_attempted,
"write_succeeded": result.write_succeeded,
"verification_succeeded": result.verification_succeeded,
"commit_gate_passed": result.commit_gate_passed,
"commit_gate_reason": result.commit_gate_reason or None,
}
payload.update(extra)
_perf_log.info("[file_operation_contract] %s", json.dumps(payload, ensure_ascii=False))
async def _stream_agent_events(
@ -239,6 +337,8 @@ async def _stream_agent_events(
tool_name = event.get("name", "unknown_tool")
run_id = event.get("run_id", "")
tool_input = event.get("data", {}).get("input", {})
if tool_name in ("write_file", "edit_file"):
result.write_attempted = True
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
@ -514,6 +614,14 @@ async def _stream_agent_events(
else:
tool_output = {"result": str(raw_output) if raw_output else "completed"}
if tool_name in ("write_file", "edit_file"):
if _tool_output_has_error(tool_output):
# Keep successful evidence if a previous write/edit in this turn succeeded.
pass
else:
result.write_succeeded = True
result.verification_succeeded = True
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
original_step_id = tool_step_ids.get(
run_id, f"{step_prefix}-unknown-{run_id[:8]}"
@ -925,6 +1033,30 @@ async def _stream_agent_events(
f"Scrape failed: {error_msg}",
"error",
)
elif tool_name in ("write_file", "edit_file"):
resolved_path = _extract_resolved_file_path(
tool_name=tool_name,
tool_output=tool_output,
)
result_text = _tool_output_to_text(tool_output)
if _tool_output_has_error(tool_output):
yield streaming_service.format_tool_output_available(
tool_call_id,
{
"status": "error",
"error": result_text,
"path": resolved_path,
},
)
else:
yield streaming_service.format_tool_output_available(
tool_call_id,
{
"status": "completed",
"path": resolved_path,
"result": result_text,
},
)
elif tool_name == "generate_report":
# Stream the full report result so frontend can render the ReportCard
yield streaming_service.format_tool_output_available(
@ -1143,10 +1275,59 @@ async def _stream_agent_events(
if completion_event:
yield completion_event
state = await agent.aget_state(config)
state_values = getattr(state, "values", {}) or {}
contract_state = state_values.get("file_operation_contract") or {}
contract_turn_id = contract_state.get("turn_id")
current_turn_id = config.get("configurable", {}).get("turn_id", "")
intent_value = contract_state.get("intent")
if (
isinstance(intent_value, str)
and intent_value in ("chat_only", "file_write", "file_read")
and contract_turn_id == current_turn_id
):
result.intent_detected = intent_value
if (
isinstance(intent_value, str)
and intent_value in (
"chat_only",
"file_write",
"file_read",
)
and contract_turn_id != current_turn_id
):
# Ignore stale intent contracts from previous turns/checkpoints.
result.intent_detected = "chat_only"
result.intent_confidence = (
_safe_float(contract_state.get("confidence"), default=0.0)
if contract_turn_id == current_turn_id
else 0.0
)
if result.intent_detected == "file_write":
result.commit_gate_passed, result.commit_gate_reason = (
_evaluate_file_contract_outcome(result)
)
if not result.commit_gate_passed:
if _contract_enforcement_active(result):
gate_notice = (
"I could not complete the requested file write because no successful "
"write_file/edit_file operation was confirmed."
)
gate_text_id = streaming_service.generate_text_id()
yield streaming_service.format_text_start(gate_text_id)
yield streaming_service.format_text_delta(gate_text_id, gate_notice)
yield streaming_service.format_text_end(gate_text_id)
yield streaming_service.format_terminal_info(gate_notice, "error")
accumulated_text = gate_notice
else:
result.commit_gate_passed = True
result.commit_gate_reason = ""
result.accumulated_text = accumulated_text
result.agent_called_update_memory = called_update_memory
_log_file_contract("turn_outcome", result)
state = await agent.aget_state(config)
is_interrupted = state.tasks and any(task.interrupts for task in state.tasks)
if is_interrupted:
result.is_interrupted = True
@ -1167,6 +1348,8 @@ async def stream_new_chat(
thread_visibility: ChatVisibility | None = None,
current_user_display_name: str | None = None,
disabled_tools: list[str] | None = None,
filesystem_selection: FilesystemSelection | None = None,
request_id: str | None = None,
) -> AsyncGenerator[str, None]:
"""
Stream chat responses from the new SurfSense deep agent.
@ -1194,6 +1377,20 @@ async def stream_new_chat(
streaming_service = VercelStreamingService()
stream_result = StreamResult()
_t_total = time.perf_counter()
fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud"
fs_platform = (
filesystem_selection.client_platform.value if filesystem_selection else "web"
)
stream_result.request_id = request_id
stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}"
stream_result.filesystem_mode = fs_mode
stream_result.client_platform = fs_platform
_log_file_contract("turn_start", stream_result)
_perf_log.info(
"[stream_new_chat] filesystem_mode=%s client_platform=%s",
fs_mode,
fs_platform,
)
log_system_snapshot("stream_new_chat_START")
from app.services.token_tracking_service import start_turn
@ -1329,6 +1526,7 @@ async def stream_new_chat(
thread_visibility=visibility,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
filesystem_selection=filesystem_selection,
)
_perf_log.info(
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
@ -1435,6 +1633,8 @@ async def stream_new_chat(
# We will use this to simulate group chat functionality in the future
"messages": langchain_messages,
"search_space_id": search_space_id,
"request_id": request_id or "unknown",
"turn_id": stream_result.turn_id,
}
_perf_log.info(
@ -1464,6 +1664,8 @@ async def stream_new_chat(
# Configure LangGraph with thread_id for memory
# If checkpoint_id is provided, fork from that checkpoint (for edit/reload)
configurable = {"thread_id": str(chat_id)}
configurable["request_id"] = request_id or "unknown"
configurable["turn_id"] = stream_result.turn_id
if checkpoint_id:
configurable["checkpoint_id"] = checkpoint_id
@ -1871,10 +2073,26 @@ async def stream_resume_chat(
user_id: str | None = None,
llm_config_id: int = -1,
thread_visibility: ChatVisibility | None = None,
filesystem_selection: FilesystemSelection | None = None,
request_id: str | None = None,
) -> AsyncGenerator[str, None]:
streaming_service = VercelStreamingService()
stream_result = StreamResult()
_t_total = time.perf_counter()
fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud"
fs_platform = (
filesystem_selection.client_platform.value if filesystem_selection else "web"
)
stream_result.request_id = request_id
stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}"
stream_result.filesystem_mode = fs_mode
stream_result.client_platform = fs_platform
_log_file_contract("turn_start", stream_result)
_perf_log.info(
"[stream_resume] filesystem_mode=%s client_platform=%s",
fs_mode,
fs_platform,
)
from app.services.token_tracking_service import start_turn
@ -1991,6 +2209,7 @@ async def stream_resume_chat(
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
)
_perf_log.info(
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
@ -2009,7 +2228,11 @@ async def stream_resume_chat(
from langgraph.types import Command
config = {
"configurable": {"thread_id": str(chat_id)},
"configurable": {
"thread_id": str(chat_id),
"request_id": request_id or "unknown",
"turn_id": stream_result.turn_id,
},
"recursion_limit": 80,
}

View file

@ -0,0 +1,214 @@
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from app.agents.new_chat.middleware.file_intent import (
FileIntentMiddleware,
FileOperationIntent,
_fallback_path,
)
pytestmark = pytest.mark.unit
class _FakeLLM:
def __init__(self, response_text: str):
self._response_text = response_text
async def ainvoke(self, *_args, **_kwargs):
return AIMessage(content=self._response_text)
@pytest.mark.asyncio
async def test_file_write_intent_injects_contract_message():
llm = _FakeLLM(
'{"intent":"file_write","confidence":0.93,"suggested_filename":"ideas.md"}'
)
middleware = FileIntentMiddleware(llm=llm)
state = {
"messages": [HumanMessage(content="Create another random note for me")],
"turn_id": "123:456",
}
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
assert result is not None
contract = result["file_operation_contract"]
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
assert contract["suggested_path"] == "/ideas.md"
assert contract["turn_id"] == "123:456"
assert any(
"file_operation_contract" in str(msg.content)
for msg in result["messages"]
if hasattr(msg, "content")
)
@pytest.mark.asyncio
async def test_non_write_intent_does_not_inject_contract_message():
llm = _FakeLLM(
'{"intent":"file_read","confidence":0.88,"suggested_filename":null}'
)
middleware = FileIntentMiddleware(llm=llm)
original_messages = [HumanMessage(content="Read /notes.md")]
state = {"messages": original_messages, "turn_id": "abc:def"}
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
assert result is not None
assert result["file_operation_contract"]["intent"] == FileOperationIntent.FILE_READ.value
assert "messages" not in result
@pytest.mark.asyncio
async def test_file_write_null_filename_uses_semantic_default_path():
llm = _FakeLLM(
'{"intent":"file_write","confidence":0.74,"suggested_filename":null}'
)
middleware = FileIntentMiddleware(llm=llm)
state = {
"messages": [HumanMessage(content="create a random markdown file")],
"turn_id": "turn:1",
}
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
assert result is not None
contract = result["file_operation_contract"]
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
assert contract["suggested_path"] == "/notes.md"
@pytest.mark.asyncio
async def test_file_write_null_filename_infers_json_extension():
llm = _FakeLLM(
'{"intent":"file_write","confidence":0.71,"suggested_filename":null}'
)
middleware = FileIntentMiddleware(llm=llm)
state = {
"messages": [HumanMessage(content="create a sample json config file")],
"turn_id": "turn:2",
}
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
assert result is not None
contract = result["file_operation_contract"]
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
assert contract["suggested_path"] == "/notes.json"
@pytest.mark.asyncio
async def test_file_write_txt_suggestion_is_normalized_to_markdown():
llm = _FakeLLM(
'{"intent":"file_write","confidence":0.82,"suggested_filename":"random.txt"}'
)
middleware = FileIntentMiddleware(llm=llm)
state = {
"messages": [HumanMessage(content="create a random file")],
"turn_id": "turn:3",
}
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
assert result is not None
contract = result["file_operation_contract"]
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
assert contract["suggested_path"] == "/random.md"
@pytest.mark.asyncio
async def test_file_write_with_suggested_directory_preserves_folder():
llm = _FakeLLM(
'{"intent":"file_write","confidence":0.86,"suggested_filename":"random.md","suggested_directory":"pc backups","suggested_path":null}'
)
middleware = FileIntentMiddleware(llm=llm)
state = {
"messages": [HumanMessage(content="create a random file in pc backups folder")],
"turn_id": "turn:4",
}
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
assert result is not None
contract = result["file_operation_contract"]
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
assert contract["suggested_path"] == "/pc_backups/random.md"
@pytest.mark.asyncio
async def test_file_write_with_suggested_path_takes_precedence():
llm = _FakeLLM(
'{"intent":"file_write","confidence":0.9,"suggested_filename":"ignored.md","suggested_directory":"docs","suggested_path":"/reports/q2/summary.md"}'
)
middleware = FileIntentMiddleware(llm=llm)
state = {
"messages": [HumanMessage(content="create report")],
"turn_id": "turn:5",
}
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
assert result is not None
contract = result["file_operation_contract"]
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
assert contract["suggested_path"] == "/reports/q2/summary.md"
@pytest.mark.asyncio
async def test_file_write_infers_directory_from_user_text_when_missing():
llm = _FakeLLM(
'{"intent":"file_write","confidence":0.83,"suggested_filename":"random.md","suggested_directory":null,"suggested_path":null}'
)
middleware = FileIntentMiddleware(llm=llm)
state = {
"messages": [HumanMessage(content="create a random file in pc backups folder")],
"turn_id": "turn:6",
}
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
assert result is not None
contract = result["file_operation_contract"]
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
assert contract["suggested_path"] == "/pc_backups/random.md"
def test_fallback_path_normalizes_windows_slashes() -> None:
resolved = _fallback_path(
suggested_filename="summary.md",
suggested_path=r"\reports\q2\summary.md",
user_text="create report",
)
assert resolved == "/reports/q2/summary.md"
def test_fallback_path_normalizes_windows_drive_path() -> None:
resolved = _fallback_path(
suggested_filename=None,
suggested_path=r"C:\Users\anish\notes\todo.md",
user_text="create note",
)
assert resolved == "/C/Users/anish/notes/todo.md"
def test_fallback_path_normalizes_mixed_separators_and_duplicate_slashes() -> None:
resolved = _fallback_path(
suggested_filename="summary.md",
suggested_path=r"\\reports\\q2//summary.md",
user_text="create report",
)
assert resolved == "/reports/q2/summary.md"
def test_fallback_path_keeps_posix_style_absolute_path_for_linux_and_macos() -> None:
resolved = _fallback_path(
suggested_filename=None,
suggested_path="/var/log/surfsense/notes.md",
user_text="create note",
)
assert resolved == "/var/log/surfsense/notes.md"

View file

@ -0,0 +1,59 @@
from pathlib import Path
import pytest
from app.agents.new_chat.filesystem_backends import build_backend_resolver
from app.agents.new_chat.filesystem_selection import (
ClientPlatform,
FilesystemMode,
FilesystemSelection,
LocalFilesystemMount,
)
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
MultiRootLocalFolderBackend,
)
pytestmark = pytest.mark.unit
class _RuntimeStub:
state = {"files": {}}
def test_backend_resolver_returns_multi_root_backend_for_single_root(tmp_path: Path):
selection = FilesystemSelection(
mode=FilesystemMode.DESKTOP_LOCAL_FOLDER,
client_platform=ClientPlatform.DESKTOP,
local_mounts=(LocalFilesystemMount(mount_id="tmp", root_path=str(tmp_path)),),
)
resolver = build_backend_resolver(selection)
backend = resolver(_RuntimeStub())
assert isinstance(backend, MultiRootLocalFolderBackend)
def test_backend_resolver_uses_cloud_mode_by_default():
resolver = build_backend_resolver(FilesystemSelection())
backend = resolver(_RuntimeStub())
# StateBackend class name check keeps this test decoupled
# from internal deepagents runtime class identity.
assert backend.__class__.__name__ == "StateBackend"
def test_backend_resolver_returns_multi_root_backend_for_multiple_roots(tmp_path: Path):
root_one = tmp_path / "resume"
root_two = tmp_path / "notes"
root_one.mkdir()
root_two.mkdir()
selection = FilesystemSelection(
mode=FilesystemMode.DESKTOP_LOCAL_FOLDER,
client_platform=ClientPlatform.DESKTOP,
local_mounts=(
LocalFilesystemMount(mount_id="resume", root_path=str(root_one)),
LocalFilesystemMount(mount_id="notes", root_path=str(root_two)),
),
)
resolver = build_backend_resolver(selection)
backend = resolver(_RuntimeStub())
assert isinstance(backend, MultiRootLocalFolderBackend)

View file

@ -0,0 +1,164 @@
from pathlib import Path
import pytest
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
MultiRootLocalFolderBackend,
)
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
pytestmark = pytest.mark.unit
class _BackendWithRawRead:
def __init__(self, content: str) -> None:
self._content = content
def read(self, file_path: str, offset: int = 0, limit: int = 200000) -> str:
del file_path, offset, limit
return " 1\tline1\n 2\tline2"
async def aread(self, file_path: str, offset: int = 0, limit: int = 200000) -> str:
return self.read(file_path, offset, limit)
def read_raw(self, file_path: str) -> str:
del file_path
return self._content
async def aread_raw(self, file_path: str) -> str:
return self.read_raw(file_path)
class _RuntimeNoSuggestedPath:
state = {"file_operation_contract": {}}
def test_verify_written_content_prefers_raw_sync() -> None:
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
expected = "line1\nline2"
backend = _BackendWithRawRead(expected)
verify_error = middleware._verify_written_content_sync(
backend=backend,
path="/note.md",
expected_content=expected,
)
assert verify_error is None
def test_contract_suggested_path_falls_back_to_notes_md() -> None:
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
middleware._filesystem_mode = FilesystemMode.CLOUD
suggested = middleware._get_contract_suggested_path(_RuntimeNoSuggestedPath()) # type: ignore[arg-type]
assert suggested == "/notes.md"
@pytest.mark.asyncio
async def test_verify_written_content_prefers_raw_async() -> None:
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
expected = "line1\nline2"
backend = _BackendWithRawRead(expected)
verify_error = await middleware._verify_written_content_async(
backend=backend,
path="/note.md",
expected_content=expected,
)
assert verify_error is None
def test_normalize_local_mount_path_prefixes_default_mount(tmp_path: Path) -> None:
root = tmp_path / "PC Backups"
root.mkdir()
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
runtime = _RuntimeNoSuggestedPath()
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
resolved = middleware._normalize_local_mount_path("/random-note.md", runtime) # type: ignore[arg-type]
assert resolved == "/pc_backups/random-note.md"
def test_normalize_local_mount_path_keeps_explicit_mount(tmp_path: Path) -> None:
root = tmp_path / "PC Backups"
root.mkdir()
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
runtime = _RuntimeNoSuggestedPath()
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type]
"/pc_backups/notes/random-note.md",
runtime,
)
assert resolved == "/pc_backups/notes/random-note.md"
def test_normalize_local_mount_path_windows_backslashes(tmp_path: Path) -> None:
root = tmp_path / "PC Backups"
root.mkdir()
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
runtime = _RuntimeNoSuggestedPath()
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type]
r"\notes\random-note.md",
runtime,
)
assert resolved == "/pc_backups/notes/random-note.md"
def test_normalize_local_mount_path_normalizes_mixed_separators(tmp_path: Path) -> None:
root = tmp_path / "PC Backups"
root.mkdir()
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
runtime = _RuntimeNoSuggestedPath()
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type]
r"\\notes//nested\\random-note.md",
runtime,
)
assert resolved == "/pc_backups/notes/nested/random-note.md"
def test_normalize_local_mount_path_keeps_explicit_mount_with_backslashes(
tmp_path: Path,
) -> None:
root = tmp_path / "PC Backups"
root.mkdir()
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
runtime = _RuntimeNoSuggestedPath()
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type]
r"\pc_backups\notes\random-note.md",
runtime,
)
assert resolved == "/pc_backups/notes/random-note.md"
def test_normalize_local_mount_path_prefixes_posix_absolute_path_for_linux_and_macos(
tmp_path: Path,
) -> None:
root = tmp_path / "PC Backups"
root.mkdir()
backend = MultiRootLocalFolderBackend((("pc_backups", str(root)),))
runtime = _RuntimeNoSuggestedPath()
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign]
resolved = middleware._normalize_local_mount_path("/var/log/app.log", runtime) # type: ignore[arg-type]
assert resolved == "/pc_backups/var/log/app.log"

View file

@ -0,0 +1,59 @@
from pathlib import Path
import pytest
from app.agents.new_chat.middleware.local_folder_backend import LocalFolderBackend
pytestmark = pytest.mark.unit
def test_local_backend_write_read_edit_roundtrip(tmp_path: Path):
backend = LocalFolderBackend(str(tmp_path))
write = backend.write("/notes/test.md", "line1\nline2")
assert write.error is None
assert write.path == "/notes/test.md"
read = backend.read("/notes/test.md", offset=0, limit=20)
assert "line1" in read
assert "line2" in read
edit = backend.edit("/notes/test.md", "line2", "updated")
assert edit.error is None
assert edit.occurrences == 1
read_after = backend.read("/notes/test.md", offset=0, limit=20)
assert "updated" in read_after
def test_local_backend_blocks_path_escape(tmp_path: Path):
backend = LocalFolderBackend(str(tmp_path))
result = backend.write("/../../etc/passwd", "bad")
assert result.error is not None
assert "Invalid path" in result.error
def test_local_backend_glob_and_grep(tmp_path: Path):
backend = LocalFolderBackend(str(tmp_path))
(tmp_path / "docs").mkdir()
(tmp_path / "docs" / "a.txt").write_text("hello world\n")
(tmp_path / "docs" / "b.md").write_text("hello markdown\n")
infos = backend.glob_info("**/*.txt", "/docs")
paths = {info["path"] for info in infos}
assert "/docs/a.txt" in paths
grep = backend.grep_raw("hello", "/docs", "*.md")
assert isinstance(grep, list)
assert any(match["path"] == "/docs/b.md" for match in grep)
def test_local_backend_read_raw_returns_exact_content(tmp_path: Path):
backend = LocalFolderBackend(str(tmp_path))
expected = "# Title\n\nline 1\nline 2\n"
write = backend.write("/notes/raw.md", expected)
assert write.error is None
raw = backend.read_raw("/notes/raw.md")
assert raw == expected

View file

@ -0,0 +1,28 @@
from pathlib import Path
import pytest
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
MultiRootLocalFolderBackend,
)
pytestmark = pytest.mark.unit
def test_mount_ids_preserve_client_mapping_order(tmp_path: Path) -> None:
root_one = tmp_path / "PC Backups"
root_two = tmp_path / "pc_backups"
root_three = tmp_path / "notes@2026"
root_one.mkdir()
root_two.mkdir()
root_three.mkdir()
backend = MultiRootLocalFolderBackend(
(
("pc_backups", str(root_one)),
("pc_backups_2", str(root_two)),
("notes_2026", str(root_three)),
)
)
assert backend.list_mounts() == ("pc_backups", "pc_backups_2", "notes_2026")

View file

@ -0,0 +1,48 @@
import pytest
from app.tasks.chat.stream_new_chat import (
StreamResult,
_contract_enforcement_active,
_evaluate_file_contract_outcome,
_tool_output_has_error,
)
pytestmark = pytest.mark.unit
def test_tool_output_error_detection():
assert _tool_output_has_error("Error: failed to write file")
assert _tool_output_has_error({"error": "boom"})
assert _tool_output_has_error({"result": "Error: disk is full"})
assert not _tool_output_has_error({"result": "Updated file /notes.md"})
def test_file_write_contract_outcome_reasons():
result = StreamResult(intent_detected="file_write")
passed, reason = _evaluate_file_contract_outcome(result)
assert not passed
assert reason == "no_write_attempt"
result.write_attempted = True
passed, reason = _evaluate_file_contract_outcome(result)
assert not passed
assert reason == "write_failed"
result.write_succeeded = True
passed, reason = _evaluate_file_contract_outcome(result)
assert not passed
assert reason == "verification_failed"
result.verification_succeeded = True
passed, reason = _evaluate_file_contract_outcome(result)
assert passed
assert reason == ""
def test_contract_enforcement_local_only():
result = StreamResult(filesystem_mode="desktop_local_folder")
assert _contract_enforcement_active(result)
result.filesystem_mode = "cloud"
assert not _contract_enforcement_active(result)

View file

@ -34,6 +34,8 @@ export const IPC_CHANNELS = {
FOLDER_SYNC_SEED_MTIMES: 'folder-sync:seed-mtimes',
BROWSE_FILES: 'browse:files',
READ_LOCAL_FILES: 'browse:read-local-files',
READ_AGENT_LOCAL_FILE_TEXT: 'agent-filesystem:read-local-file-text',
WRITE_AGENT_LOCAL_FILE_TEXT: 'agent-filesystem:write-local-file-text',
// Auth token sync across windows
GET_AUTH_TOKENS: 'auth:get-tokens',
SET_AUTH_TOKENS: 'auth:set-tokens',
@ -51,4 +53,9 @@ export const IPC_CHANNELS = {
ANALYTICS_RESET: 'analytics:reset',
ANALYTICS_CAPTURE: 'analytics:capture',
ANALYTICS_GET_CONTEXT: 'analytics:get-context',
// Agent filesystem mode
AGENT_FILESYSTEM_GET_SETTINGS: 'agent-filesystem:get-settings',
AGENT_FILESYSTEM_GET_MOUNTS: 'agent-filesystem:get-mounts',
AGENT_FILESYSTEM_SET_SETTINGS: 'agent-filesystem:set-settings',
AGENT_FILESYSTEM_PICK_ROOT: 'agent-filesystem:pick-root',
} as const;

View file

@ -36,6 +36,14 @@ import {
resetUser as analyticsReset,
trackEvent,
} from '../modules/analytics';
import {
readAgentLocalFileText,
writeAgentLocalFileText,
getAgentFilesystemMounts,
getAgentFilesystemSettings,
pickAgentFilesystemRoot,
setAgentFilesystemSettings,
} from '../modules/agent-filesystem';
let authTokens: { bearer: string; refresh: string } | null = null;
@ -118,6 +126,29 @@ export function registerIpcHandlers(): void {
readLocalFiles(paths)
);
ipcMain.handle(IPC_CHANNELS.READ_AGENT_LOCAL_FILE_TEXT, async (_event, virtualPath: string) => {
try {
const result = await readAgentLocalFileText(virtualPath);
return { ok: true, path: result.path, content: result.content };
} catch (error) {
const message = error instanceof Error ? error.message : 'Failed to read local file';
return { ok: false, path: virtualPath, error: message };
}
});
ipcMain.handle(
IPC_CHANNELS.WRITE_AGENT_LOCAL_FILE_TEXT,
async (_event, virtualPath: string, content: string) => {
try {
const result = await writeAgentLocalFileText(virtualPath, content);
return { ok: true, path: result.path };
} catch (error) {
const message = error instanceof Error ? error.message : 'Failed to write local file';
return { ok: false, path: virtualPath, error: message };
}
}
);
ipcMain.handle(IPC_CHANNELS.SET_AUTH_TOKENS, (_event, tokens: { bearer: string; refresh: string }) => {
authTokens = tokens;
});
@ -191,4 +222,22 @@ export function registerIpcHandlers(): void {
platform: process.platform,
};
});
ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_GET_SETTINGS, () =>
getAgentFilesystemSettings()
);
ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_GET_MOUNTS, () =>
getAgentFilesystemMounts()
);
ipcMain.handle(
IPC_CHANNELS.AGENT_FILESYSTEM_SET_SETTINGS,
(_event, settings: { mode?: 'cloud' | 'desktop_local_folder'; localRootPaths?: string[] | null }) =>
setAgentFilesystemSettings(settings)
);
ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_PICK_ROOT, () =>
pickAgentFilesystemRoot()
);
}

View file

@ -0,0 +1,254 @@
import { app, dialog } from "electron";
import { access, mkdir, readFile, writeFile } from "node:fs/promises";
import { dirname, isAbsolute, join, relative, resolve } from "node:path";
export type AgentFilesystemMode = "cloud" | "desktop_local_folder";
export interface AgentFilesystemSettings {
mode: AgentFilesystemMode;
localRootPaths: string[];
updatedAt: string;
}
const SETTINGS_FILENAME = "agent-filesystem-settings.json";
const MAX_LOCAL_ROOTS = 5;
function getSettingsPath(): string {
return join(app.getPath("userData"), SETTINGS_FILENAME);
}
function getDefaultSettings(): AgentFilesystemSettings {
return {
mode: "cloud",
localRootPaths: [],
updatedAt: new Date().toISOString(),
};
}
function normalizeLocalRootPaths(paths: unknown): string[] {
if (!Array.isArray(paths)) {
return [];
}
const uniquePaths = new Set<string>();
for (const path of paths) {
if (typeof path !== "string") continue;
const trimmed = path.trim();
if (!trimmed) continue;
uniquePaths.add(trimmed);
if (uniquePaths.size >= MAX_LOCAL_ROOTS) {
break;
}
}
return [...uniquePaths];
}
export async function getAgentFilesystemSettings(): Promise<AgentFilesystemSettings> {
try {
const raw = await readFile(getSettingsPath(), "utf8");
const parsed = JSON.parse(raw) as Partial<AgentFilesystemSettings>;
if (parsed.mode !== "cloud" && parsed.mode !== "desktop_local_folder") {
return getDefaultSettings();
}
return {
mode: parsed.mode,
localRootPaths: normalizeLocalRootPaths(parsed.localRootPaths),
updatedAt: parsed.updatedAt ?? new Date().toISOString(),
};
} catch {
return getDefaultSettings();
}
}
export async function setAgentFilesystemSettings(
settings: {
mode?: AgentFilesystemMode;
localRootPaths?: string[] | null;
}
): Promise<AgentFilesystemSettings> {
const current = await getAgentFilesystemSettings();
const nextMode =
settings.mode === "cloud" || settings.mode === "desktop_local_folder"
? settings.mode
: current.mode;
const next: AgentFilesystemSettings = {
mode: nextMode,
localRootPaths:
settings.localRootPaths === undefined
? current.localRootPaths
: normalizeLocalRootPaths(settings.localRootPaths ?? []),
updatedAt: new Date().toISOString(),
};
const settingsPath = getSettingsPath();
await mkdir(dirname(settingsPath), { recursive: true });
await writeFile(settingsPath, JSON.stringify(next, null, 2), "utf8");
return next;
}
export async function pickAgentFilesystemRoot(): Promise<string | null> {
const result = await dialog.showOpenDialog({
title: "Select local folder for Agent Filesystem",
properties: ["openDirectory"],
});
if (result.canceled || result.filePaths.length === 0) {
return null;
}
return result.filePaths[0] ?? null;
}
function resolveVirtualPath(rootPath: string, virtualPath: string): string {
if (!virtualPath.startsWith("/")) {
throw new Error("Path must start with '/'");
}
const normalizedRoot = resolve(rootPath);
const relativePath = virtualPath.replace(/^\/+/, "");
if (!relativePath) {
throw new Error("Path must refer to a file under the selected root");
}
const absolutePath = resolve(normalizedRoot, relativePath);
const rel = relative(normalizedRoot, absolutePath);
if (!rel || rel.startsWith("..") || isAbsolute(rel)) {
throw new Error("Path escapes selected local root");
}
return absolutePath;
}
function toVirtualPath(rootPath: string, absolutePath: string): string {
const normalizedRoot = resolve(rootPath);
const rel = relative(normalizedRoot, absolutePath);
if (!rel || rel.startsWith("..") || isAbsolute(rel)) {
return "/";
}
return `/${rel.replace(/\\/g, "/")}`;
}
export type LocalRootMount = {
mount: string;
rootPath: string;
};
function sanitizeMountName(rawMount: string): string {
const normalized = rawMount
.trim()
.toLowerCase()
.replace(/[^a-z0-9_-]+/g, "_")
.replace(/_+/g, "_")
.replace(/^[_-]+|[_-]+$/g, "");
return normalized || "root";
}
function buildRootMounts(rootPaths: string[]): LocalRootMount[] {
const mounts: LocalRootMount[] = [];
const usedMounts = new Set<string>();
for (const rawRootPath of rootPaths) {
const normalizedRoot = resolve(rawRootPath);
const baseMount = sanitizeMountName(normalizedRoot.split(/[\\/]/).at(-1) || "root");
let mount = baseMount;
let suffix = 2;
while (usedMounts.has(mount)) {
mount = `${baseMount}-${suffix}`;
suffix += 1;
}
usedMounts.add(mount);
mounts.push({ mount, rootPath: normalizedRoot });
}
return mounts;
}
export async function getAgentFilesystemMounts(): Promise<LocalRootMount[]> {
const rootPaths = await resolveCurrentRootPaths();
return buildRootMounts(rootPaths);
}
function parseMountedVirtualPath(
virtualPath: string,
mounts: LocalRootMount[]
): {
mount: string;
subPath: string;
} {
if (!virtualPath.startsWith("/")) {
throw new Error("Path must start with '/'");
}
const trimmed = virtualPath.replace(/^\/+/, "");
if (!trimmed) {
throw new Error("Path must include a mounted root segment");
}
const [mount, ...rest] = trimmed.split("/");
const remainder = rest.join("/");
const directMount = mounts.find((entry) => entry.mount === mount);
if (!directMount) {
throw new Error(
`Unknown mounted root '${mount}'. Available roots: ${mounts.map((entry) => `/${entry.mount}`).join(", ")}`
);
}
if (!remainder) {
throw new Error("Path must include a file path under the mounted root");
}
return { mount, subPath: `/${remainder}` };
}
function findMountByName(mounts: LocalRootMount[], mountName: string): LocalRootMount | undefined {
return mounts.find((entry) => entry.mount === mountName);
}
function toMountedVirtualPath(mount: string, rootPath: string, absolutePath: string): string {
const relativePath = toVirtualPath(rootPath, absolutePath);
return `/${mount}${relativePath}`;
}
async function resolveCurrentRootPaths(): Promise<string[]> {
const settings = await getAgentFilesystemSettings();
if (settings.localRootPaths.length === 0) {
throw new Error("No local filesystem roots selected");
}
return settings.localRootPaths;
}
export async function readAgentLocalFileText(
virtualPath: string
): Promise<{ path: string; content: string }> {
const rootPaths = await resolveCurrentRootPaths();
const mounts = buildRootMounts(rootPaths);
const { mount, subPath } = parseMountedVirtualPath(virtualPath, mounts);
const rootMount = findMountByName(mounts, mount);
if (!rootMount) {
throw new Error(
`Unknown mounted root '${mount}'. Available roots: ${mounts.map((entry) => `/${entry.mount}`).join(", ")}`
);
}
const absolutePath = resolveVirtualPath(rootMount.rootPath, subPath);
const content = await readFile(absolutePath, "utf8");
return {
path: toMountedVirtualPath(rootMount.mount, rootMount.rootPath, absolutePath),
content,
};
}
export async function writeAgentLocalFileText(
virtualPath: string,
content: string
): Promise<{ path: string }> {
const rootPaths = await resolveCurrentRootPaths();
const mounts = buildRootMounts(rootPaths);
const { mount, subPath } = parseMountedVirtualPath(virtualPath, mounts);
const rootMount = findMountByName(mounts, mount);
if (!rootMount) {
throw new Error(
`Unknown mounted root '${mount}'. Available roots: ${mounts.map((entry) => `/${entry.mount}`).join(", ")}`
);
}
let selectedAbsolutePath = resolveVirtualPath(rootMount.rootPath, subPath);
try {
await access(selectedAbsolutePath);
} catch {
// New files are created under the selected mounted root.
}
await mkdir(dirname(selectedAbsolutePath), { recursive: true });
await writeFile(selectedAbsolutePath, content, "utf8");
return {
path: toMountedVirtualPath(rootMount.mount, rootMount.rootPath, selectedAbsolutePath),
};
}

View file

@ -71,6 +71,10 @@ contextBridge.exposeInMainWorld('electronAPI', {
// Browse files via native dialog
browseFiles: () => ipcRenderer.invoke(IPC_CHANNELS.BROWSE_FILES),
readLocalFiles: (paths: string[]) => ipcRenderer.invoke(IPC_CHANNELS.READ_LOCAL_FILES, paths),
readAgentLocalFileText: (virtualPath: string) =>
ipcRenderer.invoke(IPC_CHANNELS.READ_AGENT_LOCAL_FILE_TEXT, virtualPath),
writeAgentLocalFileText: (virtualPath: string, content: string) =>
ipcRenderer.invoke(IPC_CHANNELS.WRITE_AGENT_LOCAL_FILE_TEXT, virtualPath, content),
// Auth token sync across windows
getAuthTokens: () => ipcRenderer.invoke(IPC_CHANNELS.GET_AUTH_TOKENS),
@ -101,4 +105,14 @@ contextBridge.exposeInMainWorld('electronAPI', {
analyticsCapture: (event: string, properties?: Record<string, unknown>) =>
ipcRenderer.invoke(IPC_CHANNELS.ANALYTICS_CAPTURE, { event, properties }),
getAnalyticsContext: () => ipcRenderer.invoke(IPC_CHANNELS.ANALYTICS_GET_CONTEXT),
// Agent filesystem mode
getAgentFilesystemSettings: () =>
ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_GET_SETTINGS),
getAgentFilesystemMounts: () =>
ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_GET_MOUNTS),
setAgentFilesystemSettings: (settings: {
mode?: "cloud" | "desktop_local_folder";
localRootPaths?: string[] | null;
}) => ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_SET_SETTINGS, settings),
pickAgentFilesystemRoot: () => ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_PICK_ROOT),
});

View file

@ -46,6 +46,7 @@ import {
import { useChatSessionStateSync } from "@/hooks/use-chat-session-state";
import { useMessagesSync } from "@/hooks/use-messages-sync";
import { documentsApiService } from "@/lib/apis/documents-api.service";
import { getAgentFilesystemSelection } from "@/lib/agent-filesystem";
import { getBearerToken } from "@/lib/auth-utils";
import { convertToThreadMessage } from "@/lib/chat/message-utils";
import {
@ -158,7 +159,7 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] {
/**
* Tools that should render custom UI in the chat.
*/
const TOOLS_WITH_UI = new Set([
const BASE_TOOLS_WITH_UI = new Set([
"web_search",
"generate_podcast",
"generate_report",
@ -210,6 +211,7 @@ export default function NewChatPage() {
assistantMsgId: string;
interruptData: Record<string, unknown>;
} | null>(null);
const toolsWithUI = useMemo(() => new Set([...BASE_TOOLS_WITH_UI]), []);
// Get disabled tools from the tool toggle UI
const disabledTools = useAtomValue(disabledToolsAtom);
@ -656,6 +658,15 @@ export default function NewChatPage() {
try {
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
const selection = await getAgentFilesystemSelection();
if (
selection.filesystem_mode === "desktop_local_folder" &&
(!selection.local_filesystem_mounts ||
selection.local_filesystem_mounts.length === 0)
) {
toast.error("Select a local folder before using Local Folder mode.");
return;
}
// Build message history for context
const messageHistory = messages
@ -691,6 +702,9 @@ export default function NewChatPage() {
chat_id: currentThreadId,
user_query: userQuery.trim(),
search_space_id: searchSpaceId,
filesystem_mode: selection.filesystem_mode,
client_platform: selection.client_platform,
local_filesystem_mounts: selection.local_filesystem_mounts,
messages: messageHistory,
mentioned_document_ids: hasDocumentIds ? mentionedDocumentIds.document_ids : undefined,
mentioned_surfsense_doc_ids: hasSurfsenseDocIds
@ -709,7 +723,7 @@ export default function NewChatPage() {
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
: m
)
);
@ -724,7 +738,7 @@ export default function NewChatPage() {
break;
case "tool-input-start":
addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {});
addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {});
batcher.flush();
break;
@ -734,7 +748,7 @@ export default function NewChatPage() {
} else {
addToolCall(
contentPartsState,
TOOLS_WITH_UI,
toolsWithUI,
parsed.toolCallId,
parsed.toolName,
parsed.input || {}
@ -830,7 +844,7 @@ export default function NewChatPage() {
const tcId = `interrupt-${action.name}`;
addToolCall(
contentPartsState,
TOOLS_WITH_UI,
toolsWithUI,
tcId,
action.name,
action.args,
@ -844,7 +858,7 @@ export default function NewChatPage() {
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
: m
)
);
@ -871,7 +885,7 @@ export default function NewChatPage() {
batcher.flush();
// Skip persistence for interrupted messages -- handleResume will persist the final version
const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
if (contentParts.length > 0 && !wasInterrupted) {
try {
const savedMessage = await appendMessage(currentThreadId, {
@ -907,10 +921,10 @@ export default function NewChatPage() {
const hasContent = contentParts.some(
(part) =>
(part.type === "text" && part.text.length > 0) ||
(part.type === "tool-call" && TOOLS_WITH_UI.has(part.toolName))
(part.type === "tool-call" && toolsWithUI.has(part.toolName))
);
if (hasContent && currentThreadId) {
const partialContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI);
try {
const savedMessage = await appendMessage(currentThreadId, {
role: "assistant",
@ -1074,6 +1088,7 @@ export default function NewChatPage() {
try {
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
const selection = await getAgentFilesystemSelection();
const response = await fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, {
method: "POST",
headers: {
@ -1083,6 +1098,9 @@ export default function NewChatPage() {
body: JSON.stringify({
search_space_id: searchSpaceId,
decisions,
filesystem_mode: selection.filesystem_mode,
client_platform: selection.client_platform,
local_filesystem_mounts: selection.local_filesystem_mounts,
}),
signal: controller.signal,
});
@ -1095,7 +1113,7 @@ export default function NewChatPage() {
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
: m
)
);
@ -1110,7 +1128,7 @@ export default function NewChatPage() {
break;
case "tool-input-start":
addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {});
addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {});
batcher.flush();
break;
@ -1122,7 +1140,7 @@ export default function NewChatPage() {
} else {
addToolCall(
contentPartsState,
TOOLS_WITH_UI,
toolsWithUI,
parsed.toolCallId,
parsed.toolName,
parsed.input || {}
@ -1173,7 +1191,7 @@ export default function NewChatPage() {
const tcId = `interrupt-${action.name}`;
addToolCall(
contentPartsState,
TOOLS_WITH_UI,
toolsWithUI,
tcId,
action.name,
action.args,
@ -1190,7 +1208,7 @@ export default function NewChatPage() {
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
: m
)
);
@ -1214,7 +1232,7 @@ export default function NewChatPage() {
batcher.flush();
const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
if (contentParts.length > 0) {
try {
const savedMessage = await appendMessage(resumeThreadId, {
@ -1406,6 +1424,7 @@ export default function NewChatPage() {
]);
try {
const selection = await getAgentFilesystemSelection();
const response = await fetch(getRegenerateUrl(threadId), {
method: "POST",
headers: {
@ -1416,6 +1435,9 @@ export default function NewChatPage() {
search_space_id: searchSpaceId,
user_query: newUserQuery || null,
disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
filesystem_mode: selection.filesystem_mode,
client_platform: selection.client_platform,
local_filesystem_mounts: selection.local_filesystem_mounts,
}),
signal: controller.signal,
});
@ -1428,7 +1450,7 @@ export default function NewChatPage() {
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
? { ...m, content: buildContentForUI(contentPartsState, toolsWithUI) }
: m
)
);
@ -1443,7 +1465,7 @@ export default function NewChatPage() {
break;
case "tool-input-start":
addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {});
addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {});
batcher.flush();
break;
@ -1453,7 +1475,7 @@ export default function NewChatPage() {
} else {
addToolCall(
contentPartsState,
TOOLS_WITH_UI,
toolsWithUI,
parsed.toolCallId,
parsed.toolName,
parsed.input || {}
@ -1502,7 +1524,7 @@ export default function NewChatPage() {
batcher.flush();
// Persist messages after streaming completes
const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
if (contentParts.length > 0) {
try {
// Persist user message (for both edit and reload modes, since backend deleted it)

View file

@ -1,9 +1,7 @@
"use client";
import { BrainCog, Power, Rocket, Zap } from "lucide-react";
import { useEffect, useState } from "react";
import { toast } from "sonner";
import { DEFAULT_SHORTCUTS, ShortcutRecorder } from "@/components/desktop/shortcut-recorder";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
import { Label } from "@/components/ui/label";
import {
@ -24,9 +22,6 @@ export function DesktopContent() {
const [loading, setLoading] = useState(true);
const [enabled, setEnabled] = useState(true);
const [shortcuts, setShortcuts] = useState(DEFAULT_SHORTCUTS);
const [shortcutsLoaded, setShortcutsLoaded] = useState(false);
const [searchSpaces, setSearchSpaces] = useState<SearchSpace[]>([]);
const [activeSpaceId, setActiveSpaceId] = useState<string | null>(null);
@ -37,7 +32,6 @@ export function DesktopContent() {
useEffect(() => {
if (!api) {
setLoading(false);
setShortcutsLoaded(true);
return;
}
@ -48,15 +42,13 @@ export function DesktopContent() {
Promise.all([
api.getAutocompleteEnabled(),
api.getShortcuts?.() ?? Promise.resolve(null),
api.getActiveSearchSpace?.() ?? Promise.resolve(null),
searchSpacesApiService.getSearchSpaces(),
hasAutoLaunchApi ? api.getAutoLaunch() : Promise.resolve(null),
])
.then(([autoEnabled, config, spaceId, spaces, autoLaunch]) => {
.then(([autoEnabled, spaceId, spaces, autoLaunch]) => {
if (!mounted) return;
setEnabled(autoEnabled);
if (config) setShortcuts(config);
setActiveSpaceId(spaceId);
if (spaces) setSearchSpaces(spaces);
if (autoLaunch) {
@ -65,12 +57,10 @@ export function DesktopContent() {
setAutoLaunchSupported(autoLaunch.supported);
}
setLoading(false);
setShortcutsLoaded(true);
})
.catch(() => {
if (!mounted) return;
setLoading(false);
setShortcutsLoaded(true);
});
return () => {
@ -82,7 +72,7 @@ export function DesktopContent() {
return (
<div className="flex flex-col items-center justify-center py-12 text-center">
<p className="text-sm text-muted-foreground">
Desktop settings are only available in the SurfSense desktop app.
App preferences are only available in the SurfSense desktop app.
</p>
</div>
);
@ -101,24 +91,6 @@ export function DesktopContent() {
await api.setAutocompleteEnabled(checked);
};
const updateShortcut = (
key: "generalAssist" | "quickAsk" | "autocomplete",
accelerator: string
) => {
setShortcuts((prev) => {
const updated = { ...prev, [key]: accelerator };
api.setShortcuts?.({ [key]: accelerator }).catch(() => {
toast.error("Failed to update shortcut");
});
return updated;
});
toast.success("Shortcut updated");
};
const resetShortcut = (key: "generalAssist" | "quickAsk" | "autocomplete") => {
updateShortcut(key, DEFAULT_SHORTCUTS[key]);
};
const handleAutoLaunchToggle = async (checked: boolean) => {
if (!autoLaunchSupported || !api.setAutoLaunch) {
toast.error("Please update the desktop app to configure launch on startup");
@ -196,7 +168,6 @@ export function DesktopContent() {
<Card>
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
<CardTitle className="text-base md:text-lg flex items-center gap-2">
<Power className="h-4 w-4" />
Launch on Startup
</CardTitle>
<CardDescription className="text-xs md:text-sm">
@ -245,56 +216,6 @@ export function DesktopContent() {
</CardContent>
</Card>
{/* Keyboard Shortcuts */}
<Card>
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
<CardTitle className="text-base md:text-lg">Keyboard Shortcuts</CardTitle>
<CardDescription className="text-xs md:text-sm">
Customize the global keyboard shortcuts for desktop features.
</CardDescription>
</CardHeader>
<CardContent className="px-3 md:px-6 pb-3 md:pb-6">
{shortcutsLoaded ? (
<div className="flex flex-col gap-3">
<ShortcutRecorder
value={shortcuts.generalAssist}
onChange={(accel) => updateShortcut("generalAssist", accel)}
onReset={() => resetShortcut("generalAssist")}
defaultValue={DEFAULT_SHORTCUTS.generalAssist}
label="General Assist"
description="Launch SurfSense instantly from any application"
icon={Rocket}
/>
<ShortcutRecorder
value={shortcuts.quickAsk}
onChange={(accel) => updateShortcut("quickAsk", accel)}
onReset={() => resetShortcut("quickAsk")}
defaultValue={DEFAULT_SHORTCUTS.quickAsk}
label="Quick Assist"
description="Select text anywhere, then ask AI to explain, rewrite, or act on it"
icon={Zap}
/>
<ShortcutRecorder
value={shortcuts.autocomplete}
onChange={(accel) => updateShortcut("autocomplete", accel)}
onReset={() => resetShortcut("autocomplete")}
defaultValue={DEFAULT_SHORTCUTS.autocomplete}
label="Extreme Assist"
description="AI drafts text using your screen context and knowledge base"
icon={BrainCog}
/>
<p className="text-[11px] text-muted-foreground">
Click a shortcut and press a new key combination to change it.
</p>
</div>
) : (
<div className="flex justify-center py-4">
<Spinner size="sm" />
</div>
)}
</CardContent>
</Card>
{/* Extreme Assist Toggle */}
<Card>
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">

View file

@ -0,0 +1,205 @@
"use client";
import { BrainCog, Rocket, RotateCcw, Zap } from "lucide-react";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner";
import { DEFAULT_SHORTCUTS, keyEventToAccelerator } from "@/components/desktop/shortcut-recorder";
import { Button } from "@/components/ui/button";
import { ShortcutKbd } from "@/components/ui/shortcut-kbd";
import { Spinner } from "@/components/ui/spinner";
import { useElectronAPI } from "@/hooks/use-platform";
type ShortcutKey = "generalAssist" | "quickAsk" | "autocomplete";
type ShortcutMap = typeof DEFAULT_SHORTCUTS;
const HOTKEY_ROWS: Array<{ key: ShortcutKey; label: string; icon: React.ElementType }> = [
{ key: "generalAssist", label: "General Assist", icon: Rocket },
{ key: "quickAsk", label: "Quick Assist", icon: Zap },
{ key: "autocomplete", label: "Extreme Assist", icon: BrainCog },
];
function acceleratorToKeys(accel: string, isMac: boolean): string[] {
if (!accel) return [];
return accel.split("+").map((part) => {
if (part === "CommandOrControl") {
return isMac ? "⌘" : "Ctrl";
}
if (part === "Alt") {
return isMac ? "⌥" : "Alt";
}
if (part === "Shift") {
return isMac ? "⇧" : "Shift";
}
if (part === "Space") return "Space";
return part.length === 1 ? part.toUpperCase() : part;
});
}
function HotkeyRow({
label,
value,
defaultValue,
icon: Icon,
isMac,
onChange,
onReset,
}: {
label: string;
value: string;
defaultValue: string;
icon: React.ElementType;
isMac: boolean;
onChange: (accelerator: string) => void;
onReset: () => void;
}) {
const [recording, setRecording] = useState(false);
const inputRef = useRef<HTMLButtonElement>(null);
const isDefault = value === defaultValue;
const displayKeys = useMemo(() => acceleratorToKeys(value, isMac), [value, isMac]);
const handleKeyDown = useCallback(
(e: React.KeyboardEvent) => {
if (!recording) return;
e.preventDefault();
e.stopPropagation();
if (e.key === "Escape") {
setRecording(false);
return;
}
const accel = keyEventToAccelerator(e);
if (accel) {
onChange(accel);
setRecording(false);
}
},
[onChange, recording]
);
return (
<div className="flex items-center justify-between gap-2.5 border-border/60 border-b py-3 last:border-b-0">
<div className="flex items-center gap-2.5 min-w-0">
<div className="flex size-7 shrink-0 items-center justify-center rounded-md bg-primary/10 text-primary">
<Icon className="size-3.5" />
</div>
<p className="text-sm text-foreground truncate">{label}</p>
</div>
<div className="flex shrink-0 items-center gap-1">
{!isDefault && (
<Button
variant="ghost"
size="icon"
className="size-7 text-muted-foreground hover:text-foreground"
onClick={onReset}
title="Reset to default"
>
<RotateCcw className="size-3" />
</Button>
)}
<button
ref={inputRef}
type="button"
title={recording ? "Press shortcut keys" : "Click to edit shortcut"}
onClick={() => setRecording(true)}
onKeyDown={handleKeyDown}
onBlur={() => setRecording(false)}
className={
recording
? "flex h-7 items-center rounded-md border border-transparent bg-primary/5 outline-none ring-0 focus:outline-none focus-visible:outline-none focus-visible:ring-0"
: "flex h-7 cursor-pointer items-center rounded-md border border-transparent bg-transparent outline-none ring-0 transition-colors hover:bg-accent hover:text-accent-foreground focus:outline-none focus-visible:outline-none focus-visible:ring-0"
}
>
{recording ? (
<span className="px-2 text-[9px] text-primary whitespace-nowrap">
Press hotkeys...
</span>
) : (
<ShortcutKbd keys={displayKeys} className="ml-0 px-1.5 text-foreground/85" />
)}
</button>
</div>
</div>
);
}
export function DesktopShortcutsContent() {
const api = useElectronAPI();
const [shortcuts, setShortcuts] = useState(DEFAULT_SHORTCUTS);
const [shortcutsLoaded, setShortcutsLoaded] = useState(false);
const isMac = api?.versions?.platform === "darwin";
useEffect(() => {
if (!api) {
setShortcutsLoaded(true);
return;
}
let mounted = true;
(api.getShortcuts?.() ?? Promise.resolve(null))
.then((config: ShortcutMap | null) => {
if (!mounted) return;
if (config) setShortcuts(config);
setShortcutsLoaded(true);
})
.catch(() => {
if (!mounted) return;
setShortcutsLoaded(true);
});
return () => {
mounted = false;
};
}, [api]);
if (!api) {
return (
<div className="flex flex-col items-center justify-center py-12 text-center">
<p className="text-sm text-muted-foreground">Hotkeys are only available in the SurfSense desktop app.</p>
</div>
);
}
const updateShortcut = (
key: "generalAssist" | "quickAsk" | "autocomplete",
accelerator: string
) => {
setShortcuts((prev) => {
const updated = { ...prev, [key]: accelerator };
api.setShortcuts?.({ [key]: accelerator }).catch(() => {
toast.error("Failed to update shortcut");
});
return updated;
});
toast.success("Shortcut updated");
};
const resetShortcut = (key: ShortcutKey) => {
updateShortcut(key, DEFAULT_SHORTCUTS[key]);
};
return (
shortcutsLoaded ? (
<div className="flex flex-col gap-3">
<div>
{HOTKEY_ROWS.map((row) => (
<HotkeyRow
key={row.key}
label={row.label}
value={shortcuts[row.key]}
defaultValue={DEFAULT_SHORTCUTS[row.key]}
icon={row.icon}
isMac={isMac}
onChange={(accel) => updateShortcut(row.key, accel)}
onReset={() => resetShortcut(row.key)}
/>
))}
</div>
</div>
) : (
<div className="flex justify-center py-4">
<Spinner size="sm" />
</div>
)
);
}

View file

@ -1,7 +1,7 @@
"use client";
import { useAtomValue } from "jotai";
import { ArrowUp, ChevronDown, ClipboardCopy, Download, Info, Pen } from "lucide-react";
import { ArrowUp, ChevronDown, ClipboardCopy, Download, Info, Pencil } from "lucide-react";
import { useCallback, useEffect, useRef, useState } from "react";
import { toast } from "sonner";
import { z } from "zod";
@ -241,7 +241,7 @@ export function MemoryContent() {
onClick={openInput}
className="absolute bottom-3 right-3 z-10 h-[54px] w-[54px] rounded-full border bg-muted/60 backdrop-blur-sm shadow-sm"
>
<Pen className="!h-5 !w-5" />
<Pencil className="!h-5 !w-5" />
</Button>
)}
</div>

View file

@ -1,7 +1,7 @@
"use client";
import { useAtomValue } from "jotai";
import { AlertTriangle, Globe, Lock, PenLine, Sparkles, Trash2 } from "lucide-react";
import { AlertTriangle, Globe, Lock, Pencil, Sparkles, Trash2 } from "lucide-react";
import { useCallback, useState } from "react";
import { toast } from "sonner";
import {
@ -308,7 +308,7 @@ export function PromptsContent() {
className="size-7"
onClick={() => handleEdit(prompt)}
>
<PenLine className="size-3.5" />
<Pencil className="size-3.5" />
</Button>
<Button
variant="ghost"

View file

@ -2,17 +2,18 @@
import { IconBrandGoogleFilled } from "@tabler/icons-react";
import { useAtom } from "jotai";
import { BrainCog, Eye, EyeOff, Rocket, Zap } from "lucide-react";
import { BrainCog, Eye, EyeOff, Rocket, RotateCcw, Zap } from "lucide-react";
import Image from "next/image";
import { useRouter } from "next/navigation";
import { useCallback, useEffect, useState } from "react";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner";
import { loginMutationAtom } from "@/atoms/auth/auth-mutation.atoms";
import { DEFAULT_SHORTCUTS, ShortcutRecorder } from "@/components/desktop/shortcut-recorder";
import { DEFAULT_SHORTCUTS, keyEventToAccelerator } from "@/components/desktop/shortcut-recorder";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { Separator } from "@/components/ui/separator";
import { ShortcutKbd } from "@/components/ui/shortcut-kbd";
import { Spinner } from "@/components/ui/spinner";
import { useElectronAPI } from "@/hooks/use-platform";
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
@ -20,6 +21,137 @@ import { setBearerToken } from "@/lib/auth-utils";
import { AUTH_TYPE, BACKEND_URL } from "@/lib/env-config";
const isGoogleAuth = AUTH_TYPE === "GOOGLE";
type ShortcutKey = "generalAssist" | "quickAsk" | "autocomplete";
type ShortcutMap = typeof DEFAULT_SHORTCUTS;
const HOTKEY_ROWS: Array<{ key: ShortcutKey; label: string; description: string; icon: React.ElementType }> = [
{
key: "generalAssist",
label: "General Assist",
description: "Launch SurfSense instantly from any application",
icon: Rocket,
},
{
key: "quickAsk",
label: "Quick Assist",
description: "Select text anywhere, then ask AI to explain, rewrite, or act on it",
icon: Zap,
},
{
key: "autocomplete",
label: "Extreme Assist",
description: "AI drafts text using your screen context and knowledge base",
icon: BrainCog,
},
];
function acceleratorToKeys(accel: string, isMac: boolean): string[] {
if (!accel) return [];
return accel.split("+").map((part) => {
if (part === "CommandOrControl") {
return isMac ? "⌘" : "Ctrl";
}
if (part === "Alt") {
return isMac ? "⌥" : "Alt";
}
if (part === "Shift") {
return isMac ? "⇧" : "Shift";
}
if (part === "Space") return "Space";
return part.length === 1 ? part.toUpperCase() : part;
});
}
function HotkeyRow({
label,
description,
value,
defaultValue,
icon: Icon,
isMac,
onChange,
onReset,
}: {
label: string;
description: string;
value: string;
defaultValue: string;
icon: React.ElementType;
isMac: boolean;
onChange: (accelerator: string) => void;
onReset: () => void;
}) {
const [recording, setRecording] = useState(false);
const inputRef = useRef<HTMLButtonElement>(null);
const isDefault = value === defaultValue;
const displayKeys = useMemo(() => acceleratorToKeys(value, isMac), [value, isMac]);
const handleKeyDown = useCallback(
(e: React.KeyboardEvent) => {
if (!recording) return;
e.preventDefault();
e.stopPropagation();
if (e.key === "Escape") {
setRecording(false);
return;
}
const accel = keyEventToAccelerator(e);
if (accel) {
onChange(accel);
setRecording(false);
}
},
[onChange, recording]
);
return (
<div className="flex items-center justify-between gap-2.5 border-border/60 border-b py-3 last:border-b-0">
<div className="flex items-center gap-2.5 min-w-0">
<div className="flex size-7 shrink-0 items-center justify-center rounded-md bg-primary/10 text-primary">
<Icon className="size-3.5" />
</div>
<div className="min-w-0">
<p className="text-sm font-medium text-foreground truncate">{label}</p>
<p className="text-xs text-muted-foreground line-clamp-2">{description}</p>
</div>
</div>
<div className="flex shrink-0 items-center gap-1">
{!isDefault && (
<Button
variant="ghost"
size="icon"
className="size-7 text-muted-foreground hover:text-foreground"
onClick={onReset}
title="Reset to default"
>
<RotateCcw className="size-3" />
</Button>
)}
<button
ref={inputRef}
type="button"
title={recording ? "Press shortcut keys" : "Click to edit shortcut"}
onClick={() => setRecording(true)}
onKeyDown={handleKeyDown}
onBlur={() => setRecording(false)}
className={
recording
? "flex h-7 items-center rounded-md border border-transparent bg-primary/5 outline-none ring-0 focus:outline-none focus-visible:outline-none focus-visible:ring-0"
: "flex h-7 cursor-pointer items-center rounded-md border border-transparent bg-transparent outline-none ring-0 transition-colors hover:bg-accent hover:text-accent-foreground focus:outline-none focus-visible:outline-none focus-visible:ring-0"
}
>
{recording ? (
<span className="px-2 text-[9px] text-primary whitespace-nowrap">Press hotkeys...</span>
) : (
<ShortcutKbd keys={displayKeys} className="ml-0 px-1.5 text-foreground/85" />
)}
</button>
</div>
</div>
);
}
export default function DesktopLoginPage() {
const router = useRouter();
@ -33,6 +165,7 @@ export default function DesktopLoginPage() {
const [shortcuts, setShortcuts] = useState(DEFAULT_SHORTCUTS);
const [shortcutsLoaded, setShortcutsLoaded] = useState(false);
const isMac = api?.versions?.platform === "darwin";
useEffect(() => {
if (!api?.getShortcuts) {
@ -41,7 +174,7 @@ export default function DesktopLoginPage() {
}
api
.getShortcuts()
.then((config) => {
.then((config: ShortcutMap | null) => {
if (config) setShortcuts(config);
setShortcutsLoaded(true);
})
@ -117,18 +250,8 @@ export default function DesktopLoginPage() {
};
return (
<div className="relative flex min-h-svh items-center justify-center bg-background p-4 sm:p-6">
{/* Subtle radial glow */}
<div className="pointer-events-none fixed inset-0 overflow-hidden">
<div
className="absolute -top-1/2 left-1/2 size-[800px] -translate-x-1/2 rounded-full opacity-[0.03]"
style={{
background: "radial-gradient(circle, hsl(var(--primary)) 0%, transparent 70%)",
}}
/>
</div>
<div className="relative flex w-full max-w-md flex-col overflow-hidden rounded-xl border bg-card shadow-lg">
<div className="relative flex min-h-svh items-center justify-center bg-background p-4 sm:p-6 select-none">
<div className="relative flex w-full max-w-md flex-col overflow-hidden bg-card shadow-lg">
{/* Header */}
<div className="flex flex-col items-center px-6 pt-6 pb-2 text-center">
<Image
@ -141,7 +264,7 @@ export default function DesktopLoginPage() {
/>
<h1 className="text-lg font-semibold tracking-tight">Welcome to SurfSense Desktop</h1>
<p className="mt-1 text-sm text-muted-foreground">
Configure shortcuts, then sign in to get started.
Configure shortcuts, then sign in to get started
</p>
</div>
@ -151,41 +274,24 @@ export default function DesktopLoginPage() {
{/* ---- Shortcuts ---- */}
{shortcutsLoaded ? (
<div className="flex flex-col gap-2">
<p className="text-xs font-medium uppercase tracking-wider text-muted-foreground">
Keyboard Shortcuts
</p>
<div className="flex flex-col gap-1.5">
<ShortcutRecorder
value={shortcuts.generalAssist}
onChange={(accel) => updateShortcut("generalAssist", accel)}
onReset={() => resetShortcut("generalAssist")}
defaultValue={DEFAULT_SHORTCUTS.generalAssist}
label="General Assist"
description="Launch SurfSense instantly from any application"
icon={Rocket}
/>
<ShortcutRecorder
value={shortcuts.quickAsk}
onChange={(accel) => updateShortcut("quickAsk", accel)}
onReset={() => resetShortcut("quickAsk")}
defaultValue={DEFAULT_SHORTCUTS.quickAsk}
label="Quick Assist"
description="Select text anywhere, then ask AI to explain, rewrite, or act on it"
icon={Zap}
/>
<ShortcutRecorder
value={shortcuts.autocomplete}
onChange={(accel) => updateShortcut("autocomplete", accel)}
onReset={() => resetShortcut("autocomplete")}
defaultValue={DEFAULT_SHORTCUTS.autocomplete}
label="Extreme Assist"
description="AI drafts text using your screen context and knowledge base"
icon={BrainCog}
/>
{/* <p className="text-xs font-medium uppercase tracking-wider text-muted-foreground">
Hotkeys
</p> */}
<div>
{HOTKEY_ROWS.map((row) => (
<HotkeyRow
key={row.key}
label={row.label}
description={row.description}
value={shortcuts[row.key]}
defaultValue={DEFAULT_SHORTCUTS[row.key]}
icon={row.icon}
isMac={isMac}
onChange={(accel) => updateShortcut(row.key, accel)}
onReset={() => resetShortcut(row.key)}
/>
))}
</div>
<p className="text-[11px] text-muted-foreground text-center mt-1">
Click a shortcut and press a new key combination to change it.
</p>
</div>
) : (
<div className="flex justify-center py-6">
@ -197,9 +303,9 @@ export default function DesktopLoginPage() {
{/* ---- Auth ---- */}
<div className="flex flex-col gap-3">
<p className="text-xs font-medium uppercase tracking-wider text-muted-foreground">
{/* <p className="text-xs font-medium uppercase tracking-wider text-muted-foreground">
Sign In
</p>
</p> */}
{isGoogleAuth ? (
<Button variant="outline" className="w-full gap-2 h-10" onClick={handleGoogleLogin}>
@ -261,15 +367,9 @@ export default function DesktopLoginPage() {
</div>
</div>
<Button type="submit" disabled={isLoggingIn} className="h-9 mt-1">
{isLoggingIn ? (
<>
<Spinner size="sm" className="text-primary-foreground" />
Signing in
</>
) : (
"Sign in"
)}
<Button type="submit" disabled={isLoggingIn} className="relative h-9 mt-1">
<span className={isLoggingIn ? "opacity-0" : ""}>Sign in</span>
{isLoggingIn && <Spinner size="sm" className="absolute text-primary-foreground" />}
</Button>
</form>
)}

View file

@ -3,14 +3,18 @@ import { rightPanelCollapsedAtom, rightPanelTabAtom } from "@/atoms/layout/right
interface EditorPanelState {
isOpen: boolean;
kind: "document" | "local_file";
documentId: number | null;
localFilePath: string | null;
searchSpaceId: number | null;
title: string | null;
}
const initialState: EditorPanelState = {
isOpen: false,
kind: "document",
documentId: null,
localFilePath: null,
searchSpaceId: null,
title: null,
};
@ -26,20 +30,38 @@ export const openEditorPanelAtom = atom(
(
get,
set,
{
documentId,
searchSpaceId,
title,
}: { documentId: number; searchSpaceId: number; title?: string }
payload:
| { documentId: number; searchSpaceId: number; title?: string; kind?: "document" }
| {
kind: "local_file";
localFilePath: string;
title?: string;
searchSpaceId?: number;
}
) => {
if (!get(editorPanelAtom).isOpen) {
set(preEditorCollapsedAtom, get(rightPanelCollapsedAtom));
}
if (payload.kind === "local_file") {
set(editorPanelAtom, {
isOpen: true,
kind: "local_file",
documentId: null,
localFilePath: payload.localFilePath,
searchSpaceId: payload.searchSpaceId ?? null,
title: payload.title ?? null,
});
set(rightPanelTabAtom, "editor");
set(rightPanelCollapsedAtom, false);
return;
}
set(editorPanelAtom, {
isOpen: true,
documentId,
searchSpaceId,
title: title ?? null,
kind: "document",
documentId: payload.documentId,
localFilePath: null,
searchSpaceId: payload.searchSpaceId,
title: payload.title ?? null,
});
set(rightPanelTabAtom, "editor");
set(rightPanelCollapsedAtom, false);

View file

@ -123,8 +123,9 @@ export const ConnectorIndicator = forwardRef<ConnectorIndicatorHandle, Connector
handleSkipIndexing,
handleStartEdit,
handleSaveConnector,
handleDisconnectConnector,
handleBackFromEdit,
handleDisconnectConnector,
handleDisconnectFromList,
handleBackFromEdit,
handleBackFromConnect,
handleBackFromYouTube,
handleViewAccountsList,
@ -225,25 +226,27 @@ export const ConnectorIndicator = forwardRef<ConnectorIndicatorHandle, Connector
{isYouTubeView && searchSpaceId ? (
<YouTubeCrawlerView searchSpaceId={searchSpaceId} onBack={handleBackFromYouTube} />
) : viewingMCPList ? (
<ConnectorAccountsListView
connectorType="MCP_CONNECTOR"
connectorTitle="MCP Connectors"
connectors={(allConnectors || []) as SearchSourceConnector[]}
indexingConnectorIds={indexingConnectorIds}
onBack={handleBackFromMCPList}
onManage={handleStartEdit}
onAddAccount={handleAddNewMCPFromList}
addButtonText="Add New MCP Server"
/>
<ConnectorAccountsListView
connectorType="MCP_CONNECTOR"
connectorTitle="MCP Connectors"
connectors={(allConnectors || []) as SearchSourceConnector[]}
indexingConnectorIds={indexingConnectorIds}
onBack={handleBackFromMCPList}
onManage={handleStartEdit}
onDisconnect={(connector) => handleDisconnectFromList(connector, () => refreshConnectors())}
onAddAccount={handleAddNewMCPFromList}
addButtonText="Add New MCP Server"
/>
) : viewingAccountsType ? (
<ConnectorAccountsListView
connectorType={viewingAccountsType.connectorType}
connectorTitle={viewingAccountsType.connectorTitle}
connectors={(connectors || []) as SearchSourceConnector[]}
indexingConnectorIds={indexingConnectorIds}
onBack={handleBackFromAccountsList}
onManage={handleStartEdit}
onAddAccount={() => {
<ConnectorAccountsListView
connectorType={viewingAccountsType.connectorType}
connectorTitle={viewingAccountsType.connectorTitle}
connectors={(connectors || []) as SearchSourceConnector[]}
indexingConnectorIds={indexingConnectorIds}
onBack={handleBackFromAccountsList}
onManage={handleStartEdit}
onDisconnect={(connector) => handleDisconnectFromList(connector, () => refreshConnectors())}
onAddAccount={() => {
// Check both OAUTH_CONNECTORS and COMPOSIO_CONNECTORS
const oauthConnector =
OAUTH_CONNECTORS.find(

View file

@ -1,6 +1,6 @@
"use client";
import { CheckCircle2, ChevronDown, ChevronUp, Server, XCircle } from "lucide-react";
import { CheckCircle2, ChevronDown, ChevronUp, Loader2, Server, XCircle } from "lucide-react";
import { type FC, useRef, useState } from "react";
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
import { Button } from "@/components/ui/button";
@ -212,7 +212,14 @@ export const MCPConnectForm: FC<ConnectFormProps> = ({ onSubmit, isSubmitting })
variant="secondary"
className="w-full h-8 text-[13px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80"
>
{isTesting ? "Testing Connection" : "Test Connection"}
{isTesting ? (
<>
<Loader2 className="h-3.5 w-3.5 animate-spin" />
Testing Connection...
</>
) : (
"Test Connection"
)}
</Button>
</div>

View file

@ -1,6 +1,6 @@
"use client";
import { CheckCircle2, ChevronDown, ChevronUp, Server, XCircle } from "lucide-react";
import { CheckCircle2, ChevronDown, ChevronUp, Loader2, Server, XCircle } from "lucide-react";
import type { FC } from "react";
import { useCallback, useEffect, useRef, useState } from "react";
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
@ -217,7 +217,14 @@ export const MCPConfig: FC<MCPConfigProps> = ({ connector, onConfigChange, onNam
variant="secondary"
className="w-full h-8 text-[13px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80"
>
{isTesting ? "Testing Connection" : "Test Connection"}
{isTesting ? (
<>
<Loader2 className="h-3.5 w-3.5 animate-spin" />
Testing Connection...
</>
) : (
"Test Connection"
)}
</Button>
</div>

View file

@ -7,7 +7,6 @@ import { toast } from "sonner";
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
import { Button } from "@/components/ui/button";
import { Spinner } from "@/components/ui/spinner";
import { EnumConnectorName } from "@/contracts/enums/connector";
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
import { authenticatedFetch } from "@/lib/auth-utils";
@ -16,23 +15,11 @@ import { DateRangeSelector } from "../../components/date-range-selector";
import { PeriodicSyncConfig } from "../../components/periodic-sync-config";
import { SummaryConfig } from "../../components/summary-config";
import { VisionLLMConfig } from "../../components/vision-llm-config";
import { LIVE_CONNECTOR_TYPES } from "../../constants/connector-constants";
import { LIVE_CONNECTOR_TYPES, getReauthEndpoint } from "../../constants/connector-constants";
import { getConnectorDisplayName } from "../../tabs/all-connectors-tab";
import { MCPServiceConfig } from "../components/mcp-service-config";
import { type ConnectorConfigProps, getConnectorConfigComponent } from "../index";
const REAUTH_ENDPOINTS: Partial<Record<string, string>> = {
[EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth",
[EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth",
[EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth",
[EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth",
[EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/google/calendar/connector/reauth",
[EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
[EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
[EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
[EnumConnectorName.ONEDRIVE_CONNECTOR]: "/api/v1/auth/onedrive/connector/reauth",
[EnumConnectorName.DROPBOX_CONNECTOR]: "/api/v1/auth/dropbox/connector/reauth",
};
interface ConnectorEditViewProps {
connector: SearchSourceConnector;
startDate: Date | undefined;
@ -86,7 +73,7 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
}) => {
const searchSpaceIdAtom = useAtomValue(activeSearchSpaceIdAtom);
const isAuthExpired = connector.config?.auth_expired === true;
const reauthEndpoint = REAUTH_ENDPOINTS[connector.connector_type];
const reauthEndpoint = getReauthEndpoint(connector);
const [reauthing, setReauthing] = useState(false);
const handleReauth = useCallback(async () => {
@ -124,10 +111,7 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
// Get connector-specific config component (MCP-backed connectors use a generic view)
const ConnectorConfigComponent = useMemo(() => {
if (isMCPBacked) {
const { MCPServiceConfig } = require("../components/mcp-service-config");
return MCPServiceConfig as FC<ConnectorConfigProps>;
}
if (isMCPBacked) return MCPServiceConfig;
return getConnectorConfigComponent(connector.connector_type);
}, [connector.connector_type, isMCPBacked]);
const [isScrolled, setIsScrolled] = useState(false);

View file

@ -1,4 +1,5 @@
import { EnumConnectorName } from "@/contracts/enums/connector";
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
/**
* Connectors that operate in real time (no background indexing).
@ -367,5 +368,45 @@ export function getConnectorTelemetryMeta(connectorType: string): ConnectorTelem
};
}
// =============================================================================
// REAUTH ENDPOINTS
// =============================================================================
/**
* Legacy (non-MCP) OAuth reauth endpoints, keyed by connector type.
* These are used for connectors that were NOT created via MCP OAuth.
*/
export const LEGACY_REAUTH_ENDPOINTS: Partial<Record<string, string>> = {
[EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth",
[EnumConnectorName.JIRA_CONNECTOR]: "/api/v1/auth/jira/connector/reauth",
[EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth",
[EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth",
[EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth",
[EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/google/calendar/connector/reauth",
[EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
[EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
[EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
[EnumConnectorName.ONEDRIVE_CONNECTOR]: "/api/v1/auth/onedrive/connector/reauth",
[EnumConnectorName.DROPBOX_CONNECTOR]: "/api/v1/auth/dropbox/connector/reauth",
[EnumConnectorName.CONFLUENCE_CONNECTOR]: "/api/v1/auth/confluence/connector/reauth",
[EnumConnectorName.TEAMS_CONNECTOR]: "/api/v1/auth/teams/connector/reauth",
[EnumConnectorName.DISCORD_CONNECTOR]: "/api/v1/auth/discord/connector/reauth",
};
/**
* Resolve the reauth endpoint for a connector.
*
* MCP OAuth connectors (those with ``config.mcp_service``) dynamically build
* the URL from the service key. Legacy OAuth connectors fall back to the
* static ``LEGACY_REAUTH_ENDPOINTS`` map.
*/
export function getReauthEndpoint(connector: SearchSourceConnector): string | undefined {
const mcpService = connector.config?.mcp_service as string | undefined;
if (mcpService) {
return `/api/v1/auth/mcp/${mcpService}/connector/reauth`;
}
return LEGACY_REAUTH_ENDPOINTS[connector.connector_type];
}
// Re-export IndexingConfigState from schemas for backward compatibility
export type { IndexingConfigState } from "./connector-popup.schemas";

View file

@ -1311,6 +1311,25 @@ export const useConnectorDialog = () => {
[editingConnector, searchSpaceId, deleteConnector, cameFromMCPList, setIsOpen]
);
const handleDisconnectFromList = useCallback(
async (connector: SearchSourceConnector, refreshConnectors: () => void) => {
if (!searchSpaceId) return;
try {
await deleteConnector({ id: connector.id });
trackConnectorDeleted(Number(searchSpaceId), connector.connector_type, connector.id);
toast.success(`${connector.name} disconnected successfully`);
refreshConnectors();
queryClient.invalidateQueries({
queryKey: cacheKeys.logs.summary(Number(searchSpaceId)),
});
} catch (error) {
console.error("Error disconnecting connector:", error);
toast.error("Failed to disconnect connector");
}
},
[searchSpaceId, deleteConnector]
);
// Handle quick index (index with selected date range, or backend defaults if none selected)
const handleQuickIndexConnector = useCallback(
async (
@ -1484,6 +1503,7 @@ export const useConnectorDialog = () => {
handleStartEdit,
handleSaveConnector,
handleDisconnectConnector,
handleDisconnectFromList,
handleBackFromEdit,
handleBackFromConnect,
handleBackFromYouTube,

View file

@ -1,7 +1,7 @@
"use client";
import { useAtomValue } from "jotai";
import { ArrowLeft, Plus, RefreshCw, Server } from "lucide-react";
import { ArrowLeft, Plus, RefreshCw, Server, Trash2 } from "lucide-react";
import { type FC, useCallback, useState } from "react";
import { toast } from "sonner";
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
@ -13,25 +13,10 @@ import type { SearchSourceConnector } from "@/contracts/types/connector.types";
import { authenticatedFetch } from "@/lib/auth-utils";
import { formatRelativeDate } from "@/lib/format-date";
import { cn } from "@/lib/utils";
import { LIVE_CONNECTOR_TYPES } from "../constants/connector-constants";
import { LIVE_CONNECTOR_TYPES, getReauthEndpoint } from "../constants/connector-constants";
import { useConnectorStatus } from "../hooks/use-connector-status";
import { getConnectorDisplayName } from "../tabs/all-connectors-tab";
const REAUTH_ENDPOINTS: Partial<Record<string, string>> = {
[EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth",
[EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth",
[EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth",
[EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth",
[EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/google/calendar/connector/reauth",
[EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
[EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
[EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
[EnumConnectorName.ONEDRIVE_CONNECTOR]: "/api/v1/auth/onedrive/connector/reauth",
[EnumConnectorName.JIRA_CONNECTOR]: "/api/v1/auth/jira/connector/reauth",
[EnumConnectorName.DROPBOX_CONNECTOR]: "/api/v1/auth/dropbox/connector/reauth",
[EnumConnectorName.CONFLUENCE_CONNECTOR]: "/api/v1/auth/confluence/connector/reauth",
};
interface ConnectorAccountsListViewProps {
connectorType: string;
connectorTitle: string;
@ -39,15 +24,12 @@ interface ConnectorAccountsListViewProps {
indexingConnectorIds: Set<number>;
onBack: () => void;
onManage: (connector: SearchSourceConnector) => void;
onDisconnect?: (connector: SearchSourceConnector) => Promise<void> | void;
onAddAccount: () => void;
isConnecting?: boolean;
addButtonText?: string;
}
function isLiveConnector(connectorType: string): boolean {
return LIVE_CONNECTOR_TYPES.has(connectorType) || connectorType === "MCP_CONNECTOR";
}
export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
connectorType,
connectorTitle,
@ -55,12 +37,15 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
indexingConnectorIds,
onBack,
onManage,
onDisconnect,
onAddAccount,
isConnecting = false,
addButtonText,
}) => {
const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom);
const [reauthingId, setReauthingId] = useState<number | null>(null);
const [confirmDisconnectId, setConfirmDisconnectId] = useState<number | null>(null);
const [disconnectingId, setDisconnectingId] = useState<number | null>(null);
// Get connector status
const { isConnectorEnabled, getConnectorStatusMessage } = useConnectorStatus();
@ -68,16 +53,15 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
const isEnabled = isConnectorEnabled(connectorType);
const statusMessage = getConnectorStatusMessage(connectorType);
const reauthEndpoint = REAUTH_ENDPOINTS[connectorType];
const handleReauth = useCallback(
async (connectorId: number) => {
if (!searchSpaceId || !reauthEndpoint) return;
setReauthingId(connectorId);
async (connector: SearchSourceConnector) => {
const endpoint = getReauthEndpoint(connector);
if (!searchSpaceId || !endpoint) return;
setReauthingId(connector.id);
try {
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
const url = new URL(`${backendUrl}${reauthEndpoint}`);
url.searchParams.set("connector_id", String(connectorId));
const url = new URL(`${backendUrl}${endpoint}`);
url.searchParams.set("connector_id", String(connector.id));
url.searchParams.set("space_id", String(searchSpaceId));
url.searchParams.set("return_url", window.location.pathname);
const response = await authenticatedFetch(url.toString());
@ -99,7 +83,7 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
setReauthingId(null);
}
},
[searchSpaceId, reauthEndpoint]
[searchSpaceId]
);
// Filter connectors to only show those of this type
@ -198,9 +182,11 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
</div>
) : (
<div className="grid grid-cols-1 sm:grid-cols-2 gap-3">
{typeConnectors.map((connector) => {
const isIndexing = indexingConnectorIds.has(connector.id);
const isAuthExpired = !!reauthEndpoint && connector.config?.auth_expired === true;
{typeConnectors.map((connector) => {
const isIndexing = indexingConnectorIds.has(connector.id);
const connectorReauthEndpoint = getReauthEndpoint(connector);
const isAuthExpired = !!connectorReauthEndpoint && connector.config?.auth_expired === true;
const isLive = LIVE_CONNECTOR_TYPES.has(connector.connector_type) || Boolean(connector.config?.server_config);
return (
<div
@ -231,7 +217,7 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
<Spinner size="xs" />
Syncing
</p>
) : !isLiveConnector(connector.connector_type) ? (
) : !isLive ? (
<p className="text-[10px] mt-1 whitespace-nowrap truncate text-muted-foreground">
{connector.last_indexed_at
? `Last indexed: ${formatRelativeDate(connector.last_indexed_at)}`
@ -239,28 +225,73 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
</p>
) : null}
</div>
{isAuthExpired ? (
<Button
size="sm"
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-amber-600 hover:bg-amber-700 text-white border-0 shadow-xs shrink-0"
onClick={() => handleReauth(connector.id)}
disabled={reauthingId === connector.id}
>
<RefreshCw
className={cn("size-3.5", reauthingId === connector.id && "animate-spin")}
/>
Re-authenticate
</Button>
{isAuthExpired ? (
<Button
size="sm"
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-amber-600 hover:bg-amber-700 text-white border-0 shadow-xs shrink-0"
onClick={() => handleReauth(connector)}
disabled={reauthingId === connector.id}
>
<RefreshCw
className={cn("size-3.5", reauthingId === connector.id && "animate-spin")}
/>
Re-authenticate
</Button>
) : isLive && onDisconnect ? (
confirmDisconnectId === connector.id ? (
<div className="flex items-center gap-1.5 shrink-0">
<Button
variant="destructive"
size="sm"
className="h-8 text-[11px] px-3 rounded-lg font-medium shadow-xs"
onClick={async () => {
setDisconnectingId(connector.id);
setConfirmDisconnectId(null);
try {
await onDisconnect(connector);
} finally {
setDisconnectingId(null);
}
}}
disabled={disconnectingId === connector.id}
>
{disconnectingId === connector.id ? (
<RefreshCw className="size-3.5 animate-spin" />
) : (
"Confirm"
)}
</Button>
<Button
variant="ghost"
size="sm"
className="h-8 text-[11px] px-2 rounded-lg"
onClick={() => setConfirmDisconnectId(null)}
disabled={disconnectingId === connector.id}
>
Cancel
</Button>
</div>
) : (
<Button
variant="secondary"
size="sm"
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80 shrink-0"
onClick={() => onManage(connector)}
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-red-50 hover:text-red-700 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-red-950 dark:hover:text-red-400 shrink-0"
onClick={() => setConfirmDisconnectId(connector.id)}
>
Manage
<Trash2 className="size-3.5" />
Disconnect
</Button>
)}
)
) : (
<Button
variant="secondary"
size="sm"
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80 shrink-0"
onClick={() => onManage(connector)}
>
Manage
</Button>
)}
</div>
);
})}

View file

@ -7,16 +7,20 @@ import {
unstable_memoizeMarkdownComponents as memoizeMarkdownComponents,
useIsMarkdownCodeBlock,
} from "@assistant-ui/react-markdown";
import { useSetAtom } from "jotai";
import { ExternalLinkIcon } from "lucide-react";
import dynamic from "next/dynamic";
import { useParams } from "next/navigation";
import { useTheme } from "next-themes";
import { memo, type ReactNode } from "react";
import rehypeKatex from "rehype-katex";
import remarkGfm from "remark-gfm";
import remarkMath from "remark-math";
import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom";
import { ImagePreview, ImageRoot, ImageZoom } from "@/components/assistant-ui/image";
import "katex/dist/katex.min.css";
import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation";
import { useElectronAPI } from "@/hooks/use-platform";
import { Skeleton } from "@/components/ui/skeleton";
import {
Table,
@ -222,6 +226,18 @@ function extractDomain(url: string): string {
}
}
// Canonical local-file virtual paths are mount-prefixed: /<mount>/<relative/path>
const LOCAL_FILE_PATH_REGEX = /^\/[a-z0-9_-]+\/[^\s`]+(?:\/[^\s`]+)*$/;
function isVirtualFilePathToken(value: string): boolean {
if (!LOCAL_FILE_PATH_REGEX.test(value) || value.startsWith("//")) {
return false;
}
const normalized = value.replace(/\/+$/, "");
const segments = normalized.split("/").filter(Boolean);
return segments.length >= 2;
}
function MarkdownImage({ src, alt }: { src?: string; alt?: string }) {
if (!src) return null;
@ -392,7 +408,51 @@ const defaultComponents = memoizeMarkdownComponents({
code: function Code({ className, children, ...props }) {
const isCodeBlock = useIsMarkdownCodeBlock();
const { resolvedTheme } = useTheme();
const openEditorPanel = useSetAtom(openEditorPanelAtom);
const params = useParams();
const electronAPI = useElectronAPI();
const language = /language-(\w+)/.exec(className || "")?.[1] ?? "text";
const codeString = String(children).replace(/\n$/, "");
const isWebLocalFileCodeBlock =
isCodeBlock &&
!electronAPI &&
isVirtualFilePathToken(codeString.trim()) &&
!codeString.trim().startsWith("//") &&
!codeString.includes("\n");
if (!isCodeBlock) {
const inlineValue = String(children ?? "").trim();
const isLocalPath =
!!electronAPI && isVirtualFilePathToken(inlineValue) && !inlineValue.startsWith("//");
const displayLocalPath = inlineValue.replace(/^\/+/, "");
const searchSpaceIdParam = params?.search_space_id;
const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam)
? Number(searchSpaceIdParam[0])
: Number(searchSpaceIdParam);
if (isLocalPath) {
return (
<button
type="button"
className={cn(
"cursor-pointer font-mono text-[0.9em] font-medium text-primary underline underline-offset-4 transition-colors hover:text-primary/80"
)}
onClick={(event) => {
event.preventDefault();
event.stopPropagation();
openEditorPanel({
kind: "local_file",
localFilePath: inlineValue,
title: inlineValue.split("/").pop() || inlineValue,
searchSpaceId: Number.isFinite(parsedSearchSpaceId)
? parsedSearchSpaceId
: undefined,
});
}}
title="Open in editor panel"
>
{displayLocalPath}
</button>
);
}
return (
<code
className={cn(
@ -405,8 +465,19 @@ const defaultComponents = memoizeMarkdownComponents({
</code>
);
}
const language = /language-(\w+)/.exec(className || "")?.[1] ?? "text";
const codeString = String(children).replace(/\n$/, "");
if (isWebLocalFileCodeBlock) {
return (
<code
className={cn(
"aui-md-inline-code rounded-md border bg-muted px-1.5 py-0.5 font-mono text-[0.9em] font-normal",
className
)}
{...props}
>
{codeString.trim()}
</code>
);
}
return (
<LazyMarkdownCodeBlock
className={className}

View file

@ -1104,7 +1104,13 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
group.tools.flatMap((t, i) =>
i === 0
? [t.description]
: [<Dot key={i} className="inline h-4 w-4" />, t.description]
: [
<Dot
key={`dot-${group.label}-${t.description}`}
className="inline h-4 w-4"
/>,
t.description,
]
)}
</TooltipContent>
</Tooltip>

View file

@ -1,6 +1,6 @@
import { ActionBarPrimitive, AuiIf, MessagePrimitive, useAuiState } from "@assistant-ui/react";
import { useAtomValue } from "jotai";
import { CheckIcon, CopyIcon, FileText, Pen } from "lucide-react";
import { CheckIcon, CopyIcon, FileText, Pencil } from "lucide-react";
import Image from "next/image";
import { type FC, useState } from "react";
import { currentThreadAtom } from "@/atoms/chat/current-thread.atom";
@ -136,7 +136,7 @@ const UserActionBar: FC = () => {
{canEdit && (
<ActionBarPrimitive.Edit asChild>
<TooltipIconButton tooltip="Edit" className="aui-user-action-edit">
<Pen />
<Pencil />
</TooltipIconButton>
</ActionBarPrimitive.Edit>
)}

View file

@ -1,6 +1,6 @@
"use client";
import { MoreHorizontal, PenLine, Trash2 } from "lucide-react";
import { MoreHorizontal, Pencil, Trash2 } from "lucide-react";
import { Button } from "@/components/ui/button";
import {
DropdownMenu,
@ -29,7 +29,7 @@ export function CommentActions({ canEdit, canDelete, onEdit, onDelete }: Comment
<DropdownMenuContent align="end">
{canEdit && (
<DropdownMenuItem onClick={onEdit}>
<PenLine className="mr-2 size-4" />
<Pencil className="mr-2 size-4" />
Edit
</DropdownMenuItem>
)}

View file

@ -8,7 +8,7 @@ import {
History,
MoreHorizontal,
Move,
PenLine,
Pencil,
Trash2,
} from "lucide-react";
import React, { useCallback, useRef, useState } from "react";
@ -266,7 +266,7 @@ export const DocumentNode = React.memo(function DocumentNode({
</DropdownMenuItem>
{isEditable && (
<DropdownMenuItem onClick={() => onEdit(doc)}>
<PenLine className="mr-2 h-4 w-4" />
<Pencil className="mr-2 h-4 w-4" />
Edit
</DropdownMenuItem>
)}
@ -309,7 +309,7 @@ export const DocumentNode = React.memo(function DocumentNode({
</ContextMenuItem>
{isEditable && (
<ContextMenuItem onClick={() => onEdit(doc)}>
<PenLine className="mr-2 h-4 w-4" />
<Pencil className="mr-2 h-4 w-4" />
Edit
</ContextMenuItem>
)}

View file

@ -12,7 +12,7 @@ import {
FolderPlus,
MoreHorizontal,
Move,
PenLine,
Pencil,
RefreshCw,
Trash2,
} from "lucide-react";
@ -399,7 +399,7 @@ export const FolderNode = React.memo(function FolderNode({
startRename();
}}
>
<PenLine className="mr-2 h-4 w-4" />
<Pencil className="mr-2 h-4 w-4" />
Rename
</DropdownMenuItem>
<DropdownMenuItem
@ -456,7 +456,7 @@ export const FolderNode = React.memo(function FolderNode({
New subfolder
</ContextMenuItem>
<ContextMenuItem onClick={() => startRename()}>
<PenLine className="mr-2 h-4 w-4" />
<Pencil className="mr-2 h-4 w-4" />
Rename
</ContextMenuItem>
<ContextMenuItem onClick={() => onMove(folder)}>

View file

@ -1,18 +1,31 @@
"use client";
import { useAtomValue, useSetAtom } from "jotai";
import { Download, FileQuestionMark, FileText, Loader2, RefreshCw, XIcon } from "lucide-react";
import {
Check,
Copy,
Download,
FileQuestionMark,
FileText,
Pencil,
RefreshCw,
XIcon,
} from "lucide-react";
import dynamic from "next/dynamic";
import { useCallback, useEffect, useRef, useState } from "react";
import { toast } from "sonner";
import { closeEditorPanelAtom, editorPanelAtom } from "@/atoms/editor/editor-panel.atom";
import { VersionHistoryButton } from "@/components/documents/version-history";
import { SourceCodeEditor } from "@/components/editor/source-code-editor";
import { MarkdownViewer } from "@/components/markdown-viewer";
import { Alert, AlertDescription } from "@/components/ui/alert";
import { Button } from "@/components/ui/button";
import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/ui/drawer";
import { Spinner } from "@/components/ui/spinner";
import { useMediaQuery } from "@/hooks/use-media-query";
import { useElectronAPI } from "@/hooks/use-platform";
import { authenticatedFetch, getBearerToken, redirectToLogin } from "@/lib/auth-utils";
import { inferMonacoLanguageFromPath } from "@/lib/editor-language";
const PlateEditor = dynamic(
() => import("@/components/editor/plate-editor").then((m) => ({ default: m.PlateEditor })),
@ -32,6 +45,7 @@ interface EditorContent {
}
const EDITABLE_DOCUMENT_TYPES = new Set(["FILE", "NOTE"]);
type EditorRenderMode = "rich_markdown" | "source_code";
function EditorPanelSkeleton() {
return (
@ -54,27 +68,38 @@ function EditorPanelSkeleton() {
}
export function EditorPanelContent({
kind = "document",
documentId,
localFilePath,
searchSpaceId,
title,
onClose,
}: {
documentId: number;
searchSpaceId: number;
kind?: "document" | "local_file";
documentId?: number;
localFilePath?: string;
searchSpaceId?: number;
title: string | null;
onClose?: () => void;
}) {
const electronAPI = useElectronAPI();
const [editorDoc, setEditorDoc] = useState<EditorContent | null>(null);
const [isLoading, setIsLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
const [saving, setSaving] = useState(false);
const [downloading, setDownloading] = useState(false);
const [isEditing, setIsEditing] = useState(false);
const [editedMarkdown, setEditedMarkdown] = useState<string | null>(null);
const [localFileContent, setLocalFileContent] = useState("");
const [hasCopied, setHasCopied] = useState(false);
const markdownRef = useRef<string>("");
const copyResetTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
const initialLoadDone = useRef(false);
const changeCountRef = useRef(0);
const [displayTitle, setDisplayTitle] = useState(title || "Untitled");
const isLocalFileMode = kind === "local_file";
const editorRenderMode: EditorRenderMode = isLocalFileMode ? "source_code" : "rich_markdown";
const isLargeDocument = (editorDoc?.content_size_bytes ?? 0) > LARGE_DOCUMENT_THRESHOLD;
@ -84,17 +109,48 @@ export function EditorPanelContent({
setError(null);
setEditorDoc(null);
setEditedMarkdown(null);
setLocalFileContent("");
setHasCopied(false);
setIsEditing(false);
initialLoadDone.current = false;
changeCountRef.current = 0;
const doFetch = async () => {
const token = getBearerToken();
if (!token) {
redirectToLogin();
return;
}
try {
if (isLocalFileMode) {
if (!localFilePath) {
throw new Error("Missing local file path");
}
if (!electronAPI?.readAgentLocalFileText) {
throw new Error("Local file editor is available only in desktop mode.");
}
const readResult = await electronAPI.readAgentLocalFileText(localFilePath);
if (!readResult.ok) {
throw new Error(readResult.error || "Failed to read local file");
}
const inferredTitle = localFilePath.split("/").pop() || localFilePath;
const content: EditorContent = {
document_id: -1,
title: inferredTitle,
document_type: "NOTE",
source_markdown: readResult.content,
};
markdownRef.current = content.source_markdown;
setLocalFileContent(content.source_markdown);
setDisplayTitle(title || inferredTitle);
setEditorDoc(content);
initialLoadDone.current = true;
return;
}
if (!documentId || !searchSpaceId) {
throw new Error("Missing document context");
}
const token = getBearerToken();
if (!token) {
redirectToLogin();
return;
}
const url = new URL(
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content`
);
@ -136,7 +192,15 @@ export function EditorPanelContent({
doFetch().catch(() => {});
return () => controller.abort();
}, [documentId, searchSpaceId, title]);
}, [documentId, electronAPI, isLocalFileMode, localFilePath, searchSpaceId, title]);
useEffect(() => {
return () => {
if (copyResetTimeoutRef.current) {
clearTimeout(copyResetTimeoutRef.current);
}
};
}, []);
const handleMarkdownChange = useCallback((md: string) => {
markdownRef.current = md;
@ -146,16 +210,55 @@ export function EditorPanelContent({
setEditedMarkdown(md);
}, []);
const handleSave = useCallback(async () => {
const token = getBearerToken();
if (!token) {
toast.error("Please login to save");
redirectToLogin();
return;
const handleCopy = useCallback(async () => {
try {
const textToCopy = markdownRef.current ?? editorDoc?.source_markdown ?? "";
await navigator.clipboard.writeText(textToCopy);
setHasCopied(true);
if (copyResetTimeoutRef.current) {
clearTimeout(copyResetTimeoutRef.current);
}
copyResetTimeoutRef.current = setTimeout(() => {
setHasCopied(false);
}, 1400);
} catch (err) {
console.error("Error copying content:", err);
}
}, [editorDoc?.source_markdown]);
const handleSave = useCallback(async (options?: { silent?: boolean }) => {
setSaving(true);
try {
if (isLocalFileMode) {
if (!localFilePath) {
throw new Error("Missing local file path");
}
if (!electronAPI?.writeAgentLocalFileText) {
throw new Error("Local file editor is available only in desktop mode.");
}
const contentToSave = markdownRef.current;
const writeResult = await electronAPI.writeAgentLocalFileText(
localFilePath,
contentToSave
);
if (!writeResult.ok) {
throw new Error(writeResult.error || "Failed to save local file");
}
setEditorDoc((prev) =>
prev ? { ...prev, source_markdown: contentToSave } : prev
);
setEditedMarkdown(markdownRef.current === contentToSave ? null : markdownRef.current);
return true;
}
if (!searchSpaceId || !documentId) {
throw new Error("Missing document context");
}
const token = getBearerToken();
if (!token) {
toast.error("Please login to save");
redirectToLogin();
return;
}
const response = await authenticatedFetch(
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/save`,
{
@ -175,39 +278,190 @@ export function EditorPanelContent({
setEditorDoc((prev) => (prev ? { ...prev, source_markdown: markdownRef.current } : prev));
setEditedMarkdown(null);
toast.success("Document saved! Reindexing in background...");
return true;
} catch (err) {
console.error("Error saving document:", err);
toast.error(err instanceof Error ? err.message : "Failed to save document");
return false;
} finally {
setSaving(false);
}
}, [documentId, searchSpaceId]);
}, [documentId, electronAPI, isLocalFileMode, localFilePath, searchSpaceId]);
const isEditableType = editorDoc
? EDITABLE_DOCUMENT_TYPES.has(editorDoc.document_type ?? "") && !isLargeDocument
? (editorRenderMode === "source_code" ||
EDITABLE_DOCUMENT_TYPES.has(editorDoc.document_type ?? "")) &&
!isLargeDocument
: false;
const hasUnsavedChanges = editedMarkdown !== null;
const showDesktopHeader = !!onClose;
const showEditingActions = isEditableType && isEditing;
const localFileLanguage = inferMonacoLanguageFromPath(localFilePath);
const handleCancelEditing = useCallback(() => {
const savedContent = editorDoc?.source_markdown ?? "";
markdownRef.current = savedContent;
setLocalFileContent(savedContent);
setEditedMarkdown(null);
changeCountRef.current = 0;
setIsEditing(false);
}, [editorDoc?.source_markdown]);
return (
<>
<div className="flex items-center justify-between px-4 py-2 shrink-0 border-b">
<div className="flex-1 min-w-0">
<h2 className="text-sm font-semibold truncate">{displayTitle}</h2>
{isEditableType && editedMarkdown !== null && (
<p className="text-[10px] text-muted-foreground">Unsaved changes</p>
)}
{showDesktopHeader ? (
<div className="shrink-0 border-b">
<div className="flex h-14 items-center justify-between px-4">
<h2 className="text-lg font-medium text-muted-foreground select-none">File</h2>
<div className="flex items-center gap-1 shrink-0">
<Button variant="ghost" size="icon" onClick={onClose} className="size-7 shrink-0">
<XIcon className="size-4" />
<span className="sr-only">Close editor panel</span>
</Button>
</div>
</div>
<div className="flex h-10 items-center justify-between gap-2 border-t px-4">
<div className="min-w-0 flex-1">
<p className="truncate text-sm text-muted-foreground">{displayTitle}</p>
</div>
<div className="flex items-center gap-1 shrink-0">
{showEditingActions ? (
<>
<Button
variant="ghost"
size="sm"
className="h-6 px-2 text-xs"
onClick={handleCancelEditing}
disabled={saving}
>
Cancel
</Button>
<Button
variant="secondary"
size="sm"
className="relative h-6 w-[56px] px-0 text-xs"
onClick={async () => {
const saveSucceeded = await handleSave({ silent: true });
if (saveSucceeded) setIsEditing(false);
}}
disabled={saving || !hasUnsavedChanges}
>
<span className={saving ? "opacity-0" : ""}>Save</span>
{saving && <Spinner size="xs" className="absolute" />}
</Button>
</>
) : (
<>
<Button
variant="ghost"
size="icon"
className="size-6"
onClick={() => {
void handleCopy();
}}
disabled={isLoading || !editorDoc}
>
{hasCopied ? <Check className="size-3.5" /> : <Copy className="size-3.5" />}
<span className="sr-only">
{hasCopied ? "Copied file contents" : "Copy file contents"}
</span>
</Button>
{isEditableType && (
<Button
variant="ghost"
size="icon"
className="size-6"
onClick={() => {
changeCountRef.current = 0;
setEditedMarkdown(null);
setIsEditing(true);
}}
>
<Pencil className="size-3.5" />
<span className="sr-only">Edit document</span>
</Button>
)}
</>
)}
{!showEditingActions && !isLocalFileMode && editorDoc?.document_type && documentId && (
<VersionHistoryButton documentId={documentId} documentType={editorDoc.document_type} />
)}
</div>
</div>
</div>
<div className="flex items-center gap-1 shrink-0">
{editorDoc?.document_type && (
<VersionHistoryButton documentId={documentId} documentType={editorDoc.document_type} />
)}
{onClose && (
<Button variant="ghost" size="icon" onClick={onClose} className="size-7 shrink-0">
<XIcon className="size-4" />
<span className="sr-only">Close editor panel</span>
</Button>
)}
) : (
<div className="flex h-14 items-center justify-between border-b px-4 shrink-0">
<div className="flex-1 min-w-0">
<h2 className="text-sm font-semibold truncate">{displayTitle}</h2>
</div>
<div className="flex items-center gap-1 shrink-0">
{showEditingActions ? (
<>
<Button
variant="ghost"
size="sm"
className="h-6 px-2 text-xs"
onClick={handleCancelEditing}
disabled={saving}
>
Cancel
</Button>
<Button
variant="secondary"
size="sm"
className="relative h-6 w-[56px] px-0 text-xs"
onClick={async () => {
const saveSucceeded = await handleSave({ silent: true });
if (saveSucceeded) setIsEditing(false);
}}
disabled={saving || !hasUnsavedChanges}
>
<span className={saving ? "opacity-0" : ""}>Save</span>
{saving && <Spinner size="xs" className="absolute" />}
</Button>
</>
) : (
<>
<Button
variant="ghost"
size="icon"
className="size-6"
onClick={() => {
void handleCopy();
}}
disabled={isLoading || !editorDoc}
>
{hasCopied ? <Check className="size-3.5" /> : <Copy className="size-3.5" />}
<span className="sr-only">
{hasCopied ? "Copied file contents" : "Copy file contents"}
</span>
</Button>
{isEditableType && (
<Button
variant="ghost"
size="icon"
className="size-6"
onClick={() => {
changeCountRef.current = 0;
setEditedMarkdown(null);
setIsEditing(true);
}}
>
<Pencil className="size-3.5" />
<span className="sr-only">Edit document</span>
</Button>
)}
{!isLocalFileMode && editorDoc?.document_type && documentId && (
<VersionHistoryButton
documentId={documentId}
documentType={editorDoc.document_type}
/>
)}
</>
)}
</div>
</div>
</div>
)}
<div className="flex-1 overflow-hidden">
{isLoading ? (
@ -234,7 +488,7 @@ export function EditorPanelContent({
</p>
</div>
</div>
) : isLargeDocument ? (
) : isLargeDocument && !isLocalFileMode ? (
<div className="h-full overflow-y-auto px-5 py-4">
<Alert className="mb-4">
<FileText className="size-4" />
@ -252,6 +506,9 @@ export function EditorPanelContent({
onClick={async () => {
setDownloading(true);
try {
if (!searchSpaceId || !documentId) {
throw new Error("Missing document context");
}
const response = await authenticatedFetch(
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/download-markdown`,
{ method: "GET" }
@ -277,7 +534,7 @@ export function EditorPanelContent({
}}
>
{downloading ? (
<Loader2 className="size-3.5 animate-spin" />
<Spinner size="xs" />
) : (
<Download className="size-3.5" />
)}
@ -287,19 +544,36 @@ export function EditorPanelContent({
</Alert>
<MarkdownViewer content={editorDoc.source_markdown} />
</div>
) : editorRenderMode === "source_code" ? (
<div className="h-full overflow-hidden">
<SourceCodeEditor
path={localFilePath ?? "local-file.txt"}
language={localFileLanguage}
value={localFileContent}
onSave={() => {
void handleSave({ silent: true });
}}
readOnly={!isEditing}
onChange={(next) => {
markdownRef.current = next;
setLocalFileContent(next);
if (!initialLoadDone.current) return;
setEditedMarkdown(next === (editorDoc?.source_markdown ?? "") ? null : next);
}}
/>
</div>
) : isEditableType ? (
<PlateEditor
key={documentId}
key={`${isLocalFileMode ? localFilePath ?? "local-file" : documentId}-${isEditing ? "editing" : "viewing"}`}
preset="full"
markdown={editorDoc.source_markdown}
onMarkdownChange={handleMarkdownChange}
readOnly={false}
readOnly={!isEditing}
placeholder="Start writing..."
editorVariant="default"
onSave={handleSave}
hasUnsavedChanges={editedMarkdown !== null}
isSaving={saving}
defaultEditing={true}
allowModeToggle={false}
reserveToolbarSpace
defaultEditing={isEditing}
className="[&_[role=toolbar]]:!bg-sidebar"
/>
) : (
@ -324,13 +598,19 @@ function DesktopEditorPanel() {
return () => document.removeEventListener("keydown", handleKeyDown);
}, [closePanel]);
if (!panelState.isOpen || !panelState.documentId || !panelState.searchSpaceId) return null;
const hasTarget =
panelState.kind === "document"
? !!panelState.documentId && !!panelState.searchSpaceId
: !!panelState.localFilePath;
if (!panelState.isOpen || !hasTarget) return null;
return (
<div className="flex w-[50%] max-w-[700px] min-w-[380px] flex-col border-l bg-sidebar text-sidebar-foreground animate-in slide-in-from-right-4 duration-300 ease-out">
<EditorPanelContent
documentId={panelState.documentId}
searchSpaceId={panelState.searchSpaceId}
kind={panelState.kind}
documentId={panelState.documentId ?? undefined}
localFilePath={panelState.localFilePath ?? undefined}
searchSpaceId={panelState.searchSpaceId ?? undefined}
title={panelState.title}
onClose={closePanel}
/>
@ -342,7 +622,13 @@ function MobileEditorDrawer() {
const panelState = useAtomValue(editorPanelAtom);
const closePanel = useSetAtom(closeEditorPanelAtom);
if (!panelState.documentId || !panelState.searchSpaceId) return null;
if (panelState.kind === "local_file") return null;
const hasTarget =
panelState.kind === "document"
? !!panelState.documentId && !!panelState.searchSpaceId
: !!panelState.localFilePath;
if (!hasTarget) return null;
return (
<Drawer
@ -360,8 +646,10 @@ function MobileEditorDrawer() {
<DrawerTitle className="sr-only">{panelState.title || "Editor"}</DrawerTitle>
<div className="min-h-0 flex-1 flex flex-col overflow-hidden">
<EditorPanelContent
documentId={panelState.documentId}
searchSpaceId={panelState.searchSpaceId}
kind={panelState.kind}
documentId={panelState.documentId ?? undefined}
localFilePath={panelState.localFilePath ?? undefined}
searchSpaceId={panelState.searchSpaceId ?? undefined}
title={panelState.title}
/>
</div>
@ -373,8 +661,13 @@ function MobileEditorDrawer() {
export function EditorPanel() {
const panelState = useAtomValue(editorPanelAtom);
const isDesktop = useMediaQuery("(min-width: 1024px)");
const hasTarget =
panelState.kind === "document"
? !!panelState.documentId && !!panelState.searchSpaceId
: !!panelState.localFilePath;
if (!panelState.isOpen || !panelState.documentId) return null;
if (!panelState.isOpen || !hasTarget) return null;
if (!isDesktop && panelState.kind === "local_file") return null;
if (isDesktop) {
return <DesktopEditorPanel />;
@ -386,8 +679,12 @@ export function EditorPanel() {
export function MobileEditorPanel() {
const panelState = useAtomValue(editorPanelAtom);
const isDesktop = useMediaQuery("(min-width: 1024px)");
const hasTarget =
panelState.kind === "document"
? !!panelState.documentId && !!panelState.searchSpaceId
: !!panelState.localFilePath;
if (isDesktop || !panelState.isOpen || !panelState.documentId) return null;
if (isDesktop || !panelState.isOpen || !hasTarget || panelState.kind === "local_file") return null;
return <MobileEditorDrawer />;
}

View file

@ -11,12 +11,15 @@ interface EditorSaveContextValue {
isSaving: boolean;
/** Whether the user can toggle between editing and viewing modes */
canToggleMode: boolean;
/** Whether fixed-toolbar space should be reserved even when controls are hidden */
reserveToolbarSpace: boolean;
}
export const EditorSaveContext = createContext<EditorSaveContextValue>({
hasUnsavedChanges: false,
isSaving: false,
canToggleMode: false,
reserveToolbarSpace: false,
});
export function useEditorSave() {

View file

@ -42,6 +42,10 @@ export interface PlateEditorProps {
hasUnsavedChanges?: boolean;
/** Whether a save is in progress */
isSaving?: boolean;
/** Whether edit/view mode toggle UI should be available in toolbars. */
allowModeToggle?: boolean;
/** Reserve fixed-toolbar vertical space even when controls are hidden. */
reserveToolbarSpace?: boolean;
/** Start the editor in editing mode instead of viewing mode. Ignored when readOnly is true. */
defaultEditing?: boolean;
/**
@ -91,6 +95,8 @@ export function PlateEditor({
onSave,
hasUnsavedChanges = false,
isSaving = false,
allowModeToggle = true,
reserveToolbarSpace = false,
defaultEditing = false,
preset = "full",
extraPlugins = [],
@ -174,7 +180,7 @@ export function PlateEditor({
}, [html, markdown, editor]);
// When not forced read-only, the user can toggle between editing/viewing.
const canToggleMode = !readOnly;
const canToggleMode = !readOnly && allowModeToggle;
const contextProviderValue = useMemo(
() => ({
@ -182,8 +188,9 @@ export function PlateEditor({
hasUnsavedChanges,
isSaving,
canToggleMode,
reserveToolbarSpace,
}),
[onSave, hasUnsavedChanges, isSaving, canToggleMode]
[onSave, hasUnsavedChanges, isSaving, canToggleMode, reserveToolbarSpace]
);
return (

View file

@ -1,19 +1,40 @@
"use client";
import { createPlatePlugin } from "platejs/react";
import { useEditorReadOnly } from "platejs/react";
import { useEditorSave } from "@/components/editor/editor-save-context";
import { FixedToolbar } from "@/components/ui/fixed-toolbar";
import { FixedToolbarButtons } from "@/components/ui/fixed-toolbar-buttons";
function ConditionalFixedToolbar() {
const readOnly = useEditorReadOnly();
const { onSave, hasUnsavedChanges, canToggleMode, reserveToolbarSpace } = useEditorSave();
const hasVisibleControls =
!readOnly || canToggleMode || (!!onSave && hasUnsavedChanges && !readOnly);
if (!hasVisibleControls) {
if (!reserveToolbarSpace) return null;
return (
<FixedToolbar className="pointer-events-none opacity-0">
<div className="h-8 w-full" />
</FixedToolbar>
);
}
return (
<FixedToolbar>
<FixedToolbarButtons />
</FixedToolbar>
);
}
export const FixedToolbarKit = [
createPlatePlugin({
key: "fixed-toolbar",
render: {
beforeEditable: () => (
<FixedToolbar>
<FixedToolbarButtons />
</FixedToolbar>
),
beforeEditable: () => <ConditionalFixedToolbar />,
},
}),
];

View file

@ -0,0 +1,152 @@
"use client";
import dynamic from "next/dynamic";
import { useEffect, useRef } from "react";
import { useTheme } from "next-themes";
import { Spinner } from "@/components/ui/spinner";
const MonacoEditor = dynamic(() => import("@monaco-editor/react"), {
ssr: false,
});
interface SourceCodeEditorProps {
value: string;
onChange: (next: string) => void;
path?: string;
language?: string;
readOnly?: boolean;
fontSize?: number;
onSave?: () => Promise<void> | void;
}
export function SourceCodeEditor({
value,
onChange,
path,
language = "plaintext",
readOnly = false,
fontSize = 12,
onSave,
}: SourceCodeEditorProps) {
const { resolvedTheme } = useTheme();
const onSaveRef = useRef(onSave);
const monacoRef = useRef<any>(null);
const normalizedModelPath = (() => {
const raw = (path || "local-file.txt").trim();
const withLeadingSlash = raw.startsWith("/") ? raw : `/${raw}`;
// Monaco model paths should be stable and POSIX-like across platforms.
return withLeadingSlash.replace(/\\/g, "/").replace(/\/{2,}/g, "/");
})();
useEffect(() => {
onSaveRef.current = onSave;
}, [onSave]);
const resolveCssColorToHex = (cssColorValue: string): string | null => {
if (typeof document === "undefined") return null;
const probe = document.createElement("div");
probe.style.color = cssColorValue;
probe.style.position = "absolute";
probe.style.pointerEvents = "none";
probe.style.opacity = "0";
document.body.appendChild(probe);
const computedColor = getComputedStyle(probe).color;
probe.remove();
const match = computedColor.match(/rgba?\((\d+),\s*(\d+),\s*(\d+)/i);
if (!match) return null;
const toHex = (value: string) => Number(value).toString(16).padStart(2, "0");
return `#${toHex(match[1])}${toHex(match[2])}${toHex(match[3])}`;
};
const applySidebarTheme = (monaco: any) => {
const isDark = resolvedTheme === "dark";
const themeName = isDark ? "surfsense-dark" : "surfsense-light";
const fallbackBg = isDark ? "#1e1e1e" : "#ffffff";
const sidebarBgHex = resolveCssColorToHex("var(--sidebar)") ?? fallbackBg;
monaco.editor.defineTheme(themeName, {
base: isDark ? "vs-dark" : "vs",
inherit: true,
rules: [],
colors: {
"editor.background": sidebarBgHex,
"editorGutter.background": sidebarBgHex,
"minimap.background": sidebarBgHex,
"editorLineNumber.background": sidebarBgHex,
"editor.lineHighlightBackground": "#00000000",
},
});
monaco.editor.setTheme(themeName);
};
useEffect(() => {
if (!monacoRef.current) return;
applySidebarTheme(monacoRef.current);
}, [resolvedTheme]);
const isManualSaveEnabled = !!onSave && !readOnly;
return (
<div className="h-full w-full overflow-hidden bg-sidebar [&_.monaco-editor]:!bg-sidebar [&_.monaco-editor_.margin]:!bg-sidebar [&_.monaco-editor_.monaco-editor-background]:!bg-sidebar [&_.monaco-editor-background]:!bg-sidebar [&_.monaco-scrollable-element_.scrollbar_.slider]:rounded-full [&_.monaco-scrollable-element_.scrollbar_.slider]:bg-foreground/25 [&_.monaco-scrollable-element_.scrollbar_.slider:hover]:bg-foreground/40">
<MonacoEditor
path={normalizedModelPath}
language={language}
value={value}
theme={resolvedTheme === "dark" ? "surfsense-dark" : "surfsense-light"}
onChange={(next) => onChange(next ?? "")}
loading={
<div className="flex h-full w-full items-center justify-center">
<Spinner size="md" className="text-muted-foreground" />
</div>
}
beforeMount={(monaco) => {
monacoRef.current = monaco;
applySidebarTheme(monaco);
}}
onMount={(editor, monaco) => {
monacoRef.current = monaco;
applySidebarTheme(monaco);
if (!isManualSaveEnabled) return;
editor.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.KeyS, () => {
void onSaveRef.current?.();
});
}}
options={{
automaticLayout: true,
minimap: { enabled: false },
lineNumbers: "on",
lineNumbersMinChars: 3,
lineDecorationsWidth: 12,
glyphMargin: false,
folding: true,
overviewRulerLanes: 0,
hideCursorInOverviewRuler: true,
scrollBeyondLastLine: false,
renderLineHighlight: "none",
selectionHighlight: false,
occurrencesHighlight: "off",
quickSuggestions: false,
suggestOnTriggerCharacters: false,
acceptSuggestionOnEnter: "off",
parameterHints: { enabled: false },
wordBasedSuggestions: "off",
wordWrap: "off",
scrollbar: {
vertical: "auto",
horizontal: "auto",
verticalScrollbarSize: 8,
horizontalScrollbarSize: 8,
alwaysConsumeMouseWheel: false,
},
tabSize: 2,
insertSpaces: true,
fontSize,
fontFamily:
"ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, Liberation Mono, monospace",
renderWhitespace: "selection",
smoothScrolling: true,
readOnly,
}}
/>
</div>
);
}

View file

@ -9,7 +9,7 @@ import { Switch } from "@/components/ui/switch";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import { useAnonymousMode } from "@/contexts/anonymous-mode";
import { useLoginGate } from "@/contexts/login-gate";
import { BACKEND_URL } from "@/lib/env-config";
import { anonymousChatApiService } from "@/lib/apis/anonymous-chat-api.service";
import { cn } from "@/lib/utils";
const ANON_ALLOWED_EXTENSIONS = new Set([
@ -128,24 +128,12 @@ export const FreeComposer: FC = () => {
}
try {
const formData = new FormData();
formData.append("file", file);
const res = await fetch(`${BACKEND_URL}/api/v1/public/anon-chat/upload`, {
method: "POST",
credentials: "include",
body: formData,
});
if (res.status === 409) {
gate("upload more documents");
const result = await anonymousChatApiService.uploadDocument(file);
if (!result.ok) {
if (result.reason === "quota_exceeded") gate("upload more documents");
return;
}
if (!res.ok) {
const body = await res.json().catch(() => ({}));
throw new Error(body.detail || `Upload failed: ${res.status}`);
}
const data = await res.json();
const data = result.data;
if (anonMode.isAnonymous) {
anonMode.setUploadedDoc({
filename: data.filename,

View file

@ -1,7 +1,7 @@
"use client";
import { useAtom, useAtomValue, useSetAtom } from "jotai";
import { PanelRight, PanelRightClose } from "lucide-react";
import { PanelRight } from "lucide-react";
import dynamic from "next/dynamic";
import { startTransition, useEffect } from "react";
import { closeHitlEditPanelAtom, hitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
@ -49,11 +49,11 @@ function CollapseButton({ onClick }: { onClick: () => void }) {
<Tooltip>
<TooltipTrigger asChild>
<Button variant="ghost" size="icon" onClick={onClick} className="h-8 w-8 shrink-0">
<PanelRightClose className="h-4 w-4" />
<PanelRight className="h-4 w-4" />
<span className="sr-only">Collapse panel</span>
</Button>
</TooltipTrigger>
<TooltipContent side="left">Collapse panel</TooltipContent>
<TooltipContent side="bottom">Collapse panel</TooltipContent>
</Tooltip>
);
}
@ -70,7 +70,11 @@ export function RightPanelExpandButton() {
const editorState = useAtomValue(editorPanelAtom);
const hitlEditState = useAtomValue(hitlEditPanelAtom);
const reportOpen = reportState.isOpen && !!reportState.reportId;
const editorOpen = editorState.isOpen && !!editorState.documentId;
const editorOpen =
editorState.isOpen &&
(editorState.kind === "document"
? !!editorState.documentId
: !!editorState.localFilePath);
const hitlEditOpen = hitlEditState.isOpen && !!hitlEditState.onSave;
const hasContent = documentsOpen || reportOpen || editorOpen || hitlEditOpen;
@ -90,7 +94,7 @@ export function RightPanelExpandButton() {
<span className="sr-only">Expand panel</span>
</Button>
</TooltipTrigger>
<TooltipContent side="left">Expand panel</TooltipContent>
<TooltipContent side="bottom">Expand panel</TooltipContent>
</Tooltip>
</div>
);
@ -110,7 +114,11 @@ export function RightPanel({ documentsPanel }: RightPanelProps) {
const documentsOpen = documentsPanel?.open ?? false;
const reportOpen = reportState.isOpen && !!reportState.reportId;
const editorOpen = editorState.isOpen && !!editorState.documentId;
const editorOpen =
editorState.isOpen &&
(editorState.kind === "document"
? !!editorState.documentId
: !!editorState.localFilePath);
const hitlEditOpen = hitlEditState.isOpen && !!hitlEditState.onSave;
useEffect(() => {
@ -179,8 +187,10 @@ export function RightPanel({ documentsPanel }: RightPanelProps) {
{effectiveTab === "editor" && editorOpen && (
<div className="h-full flex flex-col">
<EditorPanelContent
documentId={editorState.documentId as number}
searchSpaceId={editorState.searchSpaceId as number}
kind={editorState.kind}
documentId={editorState.documentId ?? undefined}
localFilePath={editorState.localFilePath ?? undefined}
searchSpaceId={editorState.searchSpaceId ?? undefined}
title={editorState.title}
onClose={closeEditor}
/>

View file

@ -8,7 +8,7 @@ import {
ChevronLeft,
MessageCircleMore,
MoreHorizontal,
PenLine,
Pencil,
RotateCcwIcon,
Search,
Trash2,
@ -429,7 +429,7 @@ export function AllPrivateChatsSidebarContent({
<DropdownMenuItem
onClick={() => handleStartRename(thread.id, thread.title || "New Chat")}
>
<PenLine className="mr-2 h-4 w-4" />
<Pencil className="mr-2 h-4 w-4" />
<span>{t("rename") || "Rename"}</span>
</DropdownMenuItem>
)}

View file

@ -8,7 +8,7 @@ import {
ChevronLeft,
MessageCircleMore,
MoreHorizontal,
PenLine,
Pencil,
RotateCcwIcon,
Search,
Trash2,
@ -428,7 +428,7 @@ export function AllSharedChatsSidebarContent({
<DropdownMenuItem
onClick={() => handleStartRename(thread.id, thread.title || "New Chat")}
>
<PenLine className="mr-2 h-4 w-4" />
<Pencil className="mr-2 h-4 w-4" />
<span>{t("rename") || "Rename"}</span>
</DropdownMenuItem>
)}

View file

@ -1,6 +1,6 @@
"use client";
import { ArchiveIcon, MoreHorizontal, PenLine, RotateCcwIcon, Trash2 } from "lucide-react";
import { ArchiveIcon, MoreHorizontal, Pencil, RotateCcwIcon, Trash2 } from "lucide-react";
import { useTranslations } from "next-intl";
import { useCallback, useState } from "react";
import { Button } from "@/components/ui/button";
@ -106,7 +106,7 @@ export function ChatListItem({
onRename();
}}
>
<PenLine className="mr-2 h-4 w-4" />
<Pencil className="mr-2 h-4 w-4" />
<span>{t("rename") || "Rename"}</span>
</DropdownMenuItem>
)}

View file

@ -6,9 +6,14 @@ import {
ChevronLeft,
ChevronRight,
FileText,
Folder,
FolderPlus,
FolderClock,
Laptop,
Lock,
Paperclip,
Search,
Server,
Trash2,
Unplug,
Upload,
@ -58,8 +63,19 @@ import {
} from "@/components/ui/alert-dialog";
import { Avatar, AvatarFallback, AvatarGroup } from "@/components/ui/avatar";
import { Button } from "@/components/ui/button";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuLabel,
DropdownMenuSeparator,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/ui/drawer";
import { Input } from "@/components/ui/input";
import { Separator } from "@/components/ui/separator";
import { Spinner } from "@/components/ui/spinner";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import { useAnonymousMode, useIsAnonymous } from "@/contexts/anonymous-mode";
import { useLoginGate } from "@/contexts/login-gate";
@ -68,17 +84,39 @@ import type { DocumentTypeEnum } from "@/contracts/types/document.types";
import { useDebouncedValue } from "@/hooks/use-debounced-value";
import { useMediaQuery } from "@/hooks/use-media-query";
import { useElectronAPI } from "@/hooks/use-platform";
import { anonymousChatApiService } from "@/lib/apis/anonymous-chat-api.service";
import { documentsApiService } from "@/lib/apis/documents-api.service";
import { foldersApiService } from "@/lib/apis/folders-api.service";
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
import { authenticatedFetch } from "@/lib/auth-utils";
import { BACKEND_URL } from "@/lib/env-config";
import { uploadFolderScan } from "@/lib/folder-sync-upload";
import { getSupportedExtensionsSet } from "@/lib/supported-extensions";
import { queries } from "@/zero/queries/index";
import { LocalFilesystemBrowser } from "./LocalFilesystemBrowser";
import { SidebarSlideOutPanel } from "./SidebarSlideOutPanel";
const NON_DELETABLE_DOCUMENT_TYPES: readonly string[] = ["SURFSENSE_DOCS"];
const LOCAL_FILESYSTEM_TRUST_KEY = "surfsense.local-filesystem-trust.v1";
const MAX_LOCAL_FILESYSTEM_ROOTS = 5;
type FilesystemSettings = {
mode: "cloud" | "desktop_local_folder";
localRootPaths: string[];
updatedAt: string;
};
interface WatchedFolderEntry {
path: string;
name: string;
excludePatterns: string[];
fileExtensions: string[] | null;
rootFolderId: number | null;
searchSpaceId: number;
active: boolean;
}
const getFolderDisplayName = (rootPath: string): string =>
rootPath.split(/[\\/]/).at(-1) || rootPath;
const SHOWCASE_CONNECTORS = [
{ type: "GOOGLE_DRIVE_CONNECTOR", label: "Google Drive" },
@ -133,12 +171,119 @@ function AuthenticatedDocumentsSidebar({
const [search, setSearch] = useState("");
const debouncedSearch = useDebouncedValue(search, 250);
const [localSearch, setLocalSearch] = useState("");
const debouncedLocalSearch = useDebouncedValue(localSearch, 250);
const localSearchInputRef = useRef<HTMLInputElement>(null);
const [activeTypes, setActiveTypes] = useState<DocumentTypeEnum[]>([]);
const [filesystemSettings, setFilesystemSettings] = useState<FilesystemSettings | null>(null);
const [localTrustDialogOpen, setLocalTrustDialogOpen] = useState(false);
const [pendingLocalPath, setPendingLocalPath] = useState<string | null>(null);
const [watchedFolderIds, setWatchedFolderIds] = useState<Set<number>>(new Set());
const [folderWatchOpen, setFolderWatchOpen] = useAtom(folderWatchDialogOpenAtom);
const [watchInitialFolder, setWatchInitialFolder] = useAtom(folderWatchInitialFolderAtom);
const isElectron = typeof window !== "undefined" && !!window.electronAPI;
useEffect(() => {
if (!electronAPI?.getAgentFilesystemSettings) return;
let mounted = true;
electronAPI
.getAgentFilesystemSettings()
.then((settings: FilesystemSettings) => {
if (!mounted) return;
setFilesystemSettings(settings);
})
.catch(() => {
if (!mounted) return;
setFilesystemSettings({
mode: "cloud",
localRootPaths: [],
updatedAt: new Date().toISOString(),
});
});
return () => {
mounted = false;
};
}, [electronAPI]);
const hasLocalFilesystemTrust = useCallback(() => {
try {
return window.localStorage.getItem(LOCAL_FILESYSTEM_TRUST_KEY) === "true";
} catch {
return false;
}
}, []);
const localRootPaths = filesystemSettings?.localRootPaths ?? [];
const canAddMoreLocalRoots = localRootPaths.length < MAX_LOCAL_FILESYSTEM_ROOTS;
const applyLocalRootPath = useCallback(
async (path: string) => {
if (!electronAPI?.setAgentFilesystemSettings) return;
const nextLocalRootPaths = [...localRootPaths, path]
.filter((rootPath, index, allPaths) => allPaths.indexOf(rootPath) === index)
.slice(0, MAX_LOCAL_FILESYSTEM_ROOTS);
if (nextLocalRootPaths.length === localRootPaths.length) return;
const updated = await electronAPI.setAgentFilesystemSettings({
mode: "desktop_local_folder",
localRootPaths: nextLocalRootPaths,
});
setFilesystemSettings(updated);
},
[electronAPI, localRootPaths]
);
const runPickLocalRoot = useCallback(async () => {
if (!electronAPI?.pickAgentFilesystemRoot) return;
const picked = await electronAPI.pickAgentFilesystemRoot();
if (!picked) return;
await applyLocalRootPath(picked);
}, [applyLocalRootPath, electronAPI]);
const handlePickFilesystemRoot = useCallback(async () => {
if (!canAddMoreLocalRoots) return;
if (hasLocalFilesystemTrust()) {
await runPickLocalRoot();
return;
}
if (!electronAPI?.pickAgentFilesystemRoot) return;
const picked = await electronAPI.pickAgentFilesystemRoot();
if (!picked) return;
setPendingLocalPath(picked);
setLocalTrustDialogOpen(true);
}, [canAddMoreLocalRoots, electronAPI, hasLocalFilesystemTrust, runPickLocalRoot]);
const handleRemoveFilesystemRoot = useCallback(
async (rootPathToRemove: string) => {
if (!electronAPI?.setAgentFilesystemSettings) return;
const updated = await electronAPI.setAgentFilesystemSettings({
mode: "desktop_local_folder",
localRootPaths: localRootPaths.filter((rootPath) => rootPath !== rootPathToRemove),
});
setFilesystemSettings(updated);
},
[electronAPI, localRootPaths]
);
const handleClearFilesystemRoots = useCallback(async () => {
if (!electronAPI?.setAgentFilesystemSettings) return;
const updated = await electronAPI.setAgentFilesystemSettings({
mode: "desktop_local_folder",
localRootPaths: [],
});
setFilesystemSettings(updated);
}, [electronAPI]);
const handleFilesystemTabChange = useCallback(
async (tab: "cloud" | "local") => {
if (!electronAPI?.setAgentFilesystemSettings) return;
const updated = await electronAPI.setAgentFilesystemSettings({
mode: tab === "cloud" ? "cloud" : "desktop_local_folder",
});
setFilesystemSettings(updated);
},
[electronAPI]
);
// AI File Sort state
const { data: searchSpaces, refetch: refetchSearchSpaces } = useAtomValue(searchSpacesAtom);
const activeSearchSpace = useMemo(
@ -196,7 +341,7 @@ function AuthenticatedDocumentsSidebar({
if (!electronAPI?.getWatchedFolders) return;
const api = electronAPI;
const folders = await api.getWatchedFolders();
const folders = (await api.getWatchedFolders()) as WatchedFolderEntry[];
if (folders.length === 0) {
try {
@ -214,9 +359,11 @@ function AuthenticatedDocumentsSidebar({
active: true,
});
}
const recovered = await api.getWatchedFolders();
const recovered = (await api.getWatchedFolders()) as WatchedFolderEntry[];
const ids = new Set(
recovered.filter((f) => f.rootFolderId != null).map((f) => f.rootFolderId as number)
recovered
.filter((f: WatchedFolderEntry) => f.rootFolderId != null)
.map((f: WatchedFolderEntry) => f.rootFolderId as number)
);
setWatchedFolderIds(ids);
return;
@ -226,7 +373,9 @@ function AuthenticatedDocumentsSidebar({
}
const ids = new Set(
folders.filter((f) => f.rootFolderId != null).map((f) => f.rootFolderId as number)
folders
.filter((f: WatchedFolderEntry) => f.rootFolderId != null)
.map((f: WatchedFolderEntry) => f.rootFolderId as number)
);
setWatchedFolderIds(ids);
}, [searchSpaceId, electronAPI]);
@ -375,8 +524,8 @@ function AuthenticatedDocumentsSidebar({
async (folder: FolderDisplay) => {
if (!electronAPI) return;
const watchedFolders = await electronAPI.getWatchedFolders();
const matched = watchedFolders.find((wf) => wf.rootFolderId === folder.id);
const watchedFolders = (await electronAPI.getWatchedFolders()) as WatchedFolderEntry[];
const matched = watchedFolders.find((wf: WatchedFolderEntry) => wf.rootFolderId === folder.id);
if (!matched) {
toast.error("This folder is not being watched");
return;
@ -405,8 +554,8 @@ function AuthenticatedDocumentsSidebar({
async (folder: FolderDisplay) => {
if (!electronAPI) return;
const watchedFolders = await electronAPI.getWatchedFolders();
const matched = watchedFolders.find((wf) => wf.rootFolderId === folder.id);
const watchedFolders = (await electronAPI.getWatchedFolders()) as WatchedFolderEntry[];
const matched = watchedFolders.find((wf: WatchedFolderEntry) => wf.rootFolderId === folder.id);
if (!matched) {
toast.error("This folder is not being watched");
return;
@ -438,8 +587,10 @@ function AuthenticatedDocumentsSidebar({
if (!confirm(`Delete folder "${folder.name}" and all its contents?`)) return;
try {
if (electronAPI) {
const watchedFolders = await electronAPI.getWatchedFolders();
const matched = watchedFolders.find((wf) => wf.rootFolderId === folder.id);
const watchedFolders = (await electronAPI.getWatchedFolders()) as WatchedFolderEntry[];
const matched = watchedFolders.find(
(wf: WatchedFolderEntry) => wf.rootFolderId === folder.id
);
if (matched) {
await electronAPI.removeWatchedFolder(matched.path);
}
@ -836,59 +987,11 @@ function AuthenticatedDocumentsSidebar({
return () => document.removeEventListener("keydown", handleEscape);
}, [open, onOpenChange, isMobile, setRightPanelCollapsed]);
const documentsContent = (
<>
<div className="shrink-0 flex h-14 items-center px-4">
<div className="flex w-full items-center justify-between">
<div className="flex items-center gap-2">
{isMobile && (
<Button
variant="ghost"
size="icon"
className="h-8 w-8 rounded-full"
onClick={() => onOpenChange(false)}
>
<ChevronLeft className="h-4 w-4 text-muted-foreground" />
<span className="sr-only">{tSidebar("close") || "Close"}</span>
</Button>
)}
<h2 className="select-none text-lg font-semibold">{t("title") || "Documents"}</h2>
</div>
<div className="flex items-center gap-1">
{!isMobile && onDockedChange && (
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="ghost"
size="icon"
className="h-8 w-8 rounded-full"
onClick={() => {
if (isDocked) {
onDockedChange(false);
onOpenChange(false);
} else {
onDockedChange(true);
}
}}
>
{isDocked ? (
<ChevronLeft className="h-4 w-4 text-muted-foreground" />
) : (
<ChevronRight className="h-4 w-4 text-muted-foreground" />
)}
<span className="sr-only">{isDocked ? "Collapse panel" : "Expand panel"}</span>
</Button>
</TooltipTrigger>
<TooltipContent className="z-80">
{isDocked ? "Collapse panel" : "Expand panel"}
</TooltipContent>
</Tooltip>
)}
{headerAction}
</div>
</div>
</div>
const showFilesystemTabs = !isMobile && !!electronAPI && !!filesystemSettings;
const currentFilesystemTab = filesystemSettings?.mode === "desktop_local_folder" ? "local" : "cloud";
const cloudContent = (
<>
{/* Connected tools strip */}
<div className="shrink-0 mx-4 mt-4 mb-4 flex select-none items-center gap-2 rounded-lg border bg-muted/50 transition-colors hover:bg-muted/80">
<button
@ -1039,6 +1142,231 @@ function AuthenticatedDocumentsSidebar({
/>
</div>
</div>
</>
);
const localContent = (
<div className="flex min-h-0 flex-1 flex-col select-none">
<div className="mx-4 mt-4 mb-3">
<div className="flex h-7 w-full items-stretch rounded-lg border bg-muted/50 text-[11px] text-muted-foreground">
{localRootPaths.length > 0 ? (
<DropdownMenu>
<DropdownMenuTrigger asChild>
<button
type="button"
className="min-w-0 flex-1 flex items-center gap-1 rounded-l-lg px-2 text-left transition-colors hover:bg-muted/80 focus-visible:outline-none focus-visible:ring-0 focus-visible:ring-offset-0"
title={localRootPaths.join("\n")}
aria-label="Manage selected folders"
>
<Folder className="size-3 shrink-0 text-muted-foreground" />
<span className="truncate">
{localRootPaths.length === 1
? "1 folder selected"
: `${localRootPaths.length} folders selected`}
</span>
</button>
</DropdownMenuTrigger>
<DropdownMenuContent align="start" className="w-56 select-none p-0.5">
<DropdownMenuLabel className="px-1.5 pt-1.5 pb-0.5 text-xs font-medium text-muted-foreground">
Selected folders
</DropdownMenuLabel>
<DropdownMenuSeparator className="mx-1 my-0.5" />
{localRootPaths.map((rootPath) => (
<DropdownMenuItem
key={rootPath}
onClick={() => {
void handleRemoveFilesystemRoot(rootPath);
}}
className="group h-8 gap-1.5 px-1.5 text-sm text-foreground"
>
<Folder className="size-3.5 text-muted-foreground" />
<span className="min-w-0 flex-1 truncate">
{getFolderDisplayName(rootPath)}
</span>
<X className="size-3 text-muted-foreground transition-colors group-hover:text-foreground" />
</DropdownMenuItem>
))}
<DropdownMenuSeparator className="mx-1 my-0.5" />
<DropdownMenuItem
variant="destructive"
className="h-8 px-1.5 text-xs text-destructive focus:text-destructive"
onClick={() => {
void handleClearFilesystemRoots();
}}
>
Clear all folders
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
) : (
<div
className="min-w-0 flex-1 flex items-center gap-1 px-2"
title="No local folders selected"
>
<Folder className="size-3 shrink-0 text-muted-foreground" />
<span className="truncate">No local folders selected</span>
</div>
)}
<Separator
orientation="vertical"
className="data-[orientation=vertical]:h-3 self-center bg-border"
/>
<button
type="button"
className="flex w-8 items-center justify-center rounded-r-lg text-muted-foreground transition-colors hover:bg-muted/80 hover:text-foreground focus-visible:outline-none focus-visible:ring-0 focus-visible:ring-offset-0 disabled:opacity-50"
onClick={() => {
void handlePickFilesystemRoot();
}}
disabled={!canAddMoreLocalRoots}
aria-label="Add folder"
title="Add folder"
>
<FolderPlus className="size-3.5" />
</button>
</div>
</div>
<div className="mx-4 mb-2">
<div className="relative flex-1 min-w-0">
<div className="pointer-events-none absolute inset-y-0 left-0 flex items-center pl-3 text-muted-foreground">
<Search size={13} aria-hidden="true" />
</div>
<Input
ref={localSearchInputRef}
className="peer h-8 w-full pl-8 pr-8 text-sm bg-sidebar border-border/60 select-none focus:select-text"
value={localSearch}
onChange={(e) => setLocalSearch(e.target.value)}
placeholder="Search local files"
type="text"
aria-label="Search local files"
/>
{Boolean(localSearch) && (
<button
type="button"
className="absolute inset-y-0 right-0 flex h-full w-8 items-center justify-center rounded-r-md text-muted-foreground hover:text-foreground transition-colors"
aria-label="Clear local search"
onClick={() => {
setLocalSearch("");
localSearchInputRef.current?.focus();
}}
>
<X size={13} strokeWidth={2} aria-hidden="true" />
</button>
)}
</div>
</div>
<LocalFilesystemBrowser
rootPaths={localRootPaths}
searchSpaceId={searchSpaceId}
searchQuery={debouncedLocalSearch.trim() || undefined}
onOpenFile={(localFilePath) => {
openEditorPanel({
kind: "local_file",
localFilePath,
title: localFilePath.split("/").pop() || localFilePath,
searchSpaceId,
});
}}
/>
</div>
);
const documentsContent = (
<>
<div className="shrink-0 flex h-14 items-center px-4">
<div className="flex w-full items-center justify-between">
<div className="flex items-center gap-3">
{isMobile && (
<Button
variant="ghost"
size="icon"
className="h-8 w-8 rounded-full"
onClick={() => onOpenChange(false)}
>
<ChevronLeft className="h-4 w-4 text-muted-foreground" />
<span className="sr-only">{tSidebar("close") || "Close"}</span>
</Button>
)}
<h2 className="select-none text-lg font-semibold">{t("title") || "Documents"}</h2>
{showFilesystemTabs && (
<Tabs
value={currentFilesystemTab}
onValueChange={(value) => {
void handleFilesystemTabChange(value === "local" ? "local" : "cloud");
}}
>
<TabsList className="h-6 gap-0 rounded-md bg-muted/60 p-0.5 select-none">
<TabsTrigger
value="cloud"
className="h-5 gap-1 px-1.5 text-[11px] select-none focus-visible:ring-0 focus-visible:ring-offset-0 data-[state=active]:bg-muted-foreground/25 data-[state=active]:text-foreground data-[state=active]:shadow-none"
title="Cloud"
>
<Server className="size-3" />
<span>Cloud</span>
</TabsTrigger>
<TabsTrigger
value="local"
className="h-5 gap-1 px-1.5 text-[11px] select-none focus-visible:ring-0 focus-visible:ring-offset-0 data-[state=active]:bg-muted-foreground/25 data-[state=active]:text-foreground data-[state=active]:shadow-none"
title="Local"
>
<Laptop className="size-3" />
<span>Local</span>
</TabsTrigger>
</TabsList>
</Tabs>
)}
</div>
<div className="flex items-center gap-1">
{!isMobile && onDockedChange && (
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="ghost"
size="icon"
className="h-8 w-8 rounded-full"
onClick={() => {
if (isDocked) {
onDockedChange(false);
onOpenChange(false);
} else {
onDockedChange(true);
}
}}
>
{isDocked ? (
<ChevronLeft className="h-4 w-4 text-muted-foreground" />
) : (
<ChevronRight className="h-4 w-4 text-muted-foreground" />
)}
<span className="sr-only">{isDocked ? "Collapse panel" : "Expand panel"}</span>
</Button>
</TooltipTrigger>
<TooltipContent className="z-80">
{isDocked ? "Collapse panel" : "Expand panel"}
</TooltipContent>
</Tooltip>
)}
{headerAction}
</div>
</div>
</div>
{showFilesystemTabs ? (
<Tabs
value={currentFilesystemTab}
onValueChange={(value) => {
void handleFilesystemTabChange(value === "local" ? "local" : "cloud");
}}
className="flex min-h-0 flex-1 flex-col"
>
<TabsContent value="cloud" className="mt-0 flex min-h-0 flex-1 flex-col">
{cloudContent}
</TabsContent>
<TabsContent value="local" className="mt-0 flex min-h-0 flex-1 flex-col">
{localContent}
</TabsContent>
</Tabs>
) : (
cloudContent
)}
{versionDocId !== null && (
<VersionHistoryDialog
@ -1062,6 +1390,48 @@ function AuthenticatedDocumentsSidebar({
onSuccess={refreshWatchedIds}
/>
)}
<AlertDialog
open={localTrustDialogOpen}
onOpenChange={(nextOpen) => {
setLocalTrustDialogOpen(nextOpen);
if (!nextOpen) setPendingLocalPath(null);
}}
>
<AlertDialogContent className="sm:max-w-md select-none">
<AlertDialogHeader>
<AlertDialogTitle>Trust this workspace?</AlertDialogTitle>
<AlertDialogDescription>
Local mode can read and edit files inside the folders you select. Continue only if
you trust this workspace and its contents.
</AlertDialogDescription>
{pendingLocalPath && (
<AlertDialogDescription className="mt-1 whitespace-pre-wrap break-words font-mono text-xs">
Folder path: {pendingLocalPath}
</AlertDialogDescription>
)}
</AlertDialogHeader>
<AlertDialogFooter>
<AlertDialogCancel>Cancel</AlertDialogCancel>
<AlertDialogAction
onClick={async () => {
try {
window.localStorage.setItem(LOCAL_FILESYSTEM_TRUST_KEY, "true");
} catch {}
setLocalTrustDialogOpen(false);
const path = pendingLocalPath;
setPendingLocalPath(null);
if (path) {
await applyLocalRootPath(path);
} else {
await runPickLocalRoot();
}
}}
>
I trust this workspace
</AlertDialogAction>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialog>
<FolderPickerDialog
open={folderPickerOpen}
@ -1312,24 +1682,12 @@ function AnonymousDocumentsSidebar({
setIsUploading(true);
try {
const formData = new FormData();
formData.append("file", file);
const res = await fetch(`${BACKEND_URL}/api/v1/public/anon-chat/upload`, {
method: "POST",
credentials: "include",
body: formData,
});
if (res.status === 409) {
gate("upload more documents");
const result = await anonymousChatApiService.uploadDocument(file);
if (!result.ok) {
if (result.reason === "quota_exceeded") gate("upload more documents");
return;
}
if (!res.ok) {
const body = await res.json().catch(() => ({}));
throw new Error(body.detail || `Upload failed: ${res.status}`);
}
const data = await res.json();
const data = result.data;
if (anonMode.isAnonymous) {
anonMode.setUploadedDoc({
filename: data.filename,

View file

@ -0,0 +1,314 @@
"use client";
import { ChevronDown, ChevronRight, FileText, Folder } from "lucide-react";
import { useCallback, useEffect, useMemo, useState } from "react";
import { DEFAULT_EXCLUDE_PATTERNS } from "@/components/sources/FolderWatchDialog";
import { Spinner } from "@/components/ui/spinner";
import { useElectronAPI } from "@/hooks/use-platform";
import { getSupportedExtensionsSet } from "@/lib/supported-extensions";
interface LocalFilesystemBrowserProps {
rootPaths: string[];
searchSpaceId: number;
searchQuery?: string;
onOpenFile: (fullPath: string) => void;
}
interface LocalFolderFileEntry {
relativePath: string;
fullPath: string;
size: number;
mtimeMs: number;
}
type RootLoadState = {
loading: boolean;
error: string | null;
files: LocalFolderFileEntry[];
};
interface LocalFolderNode {
key: string;
name: string;
folders: Map<string, LocalFolderNode>;
files: LocalFolderFileEntry[];
}
type LocalRootMount = {
mount: string;
rootPath: string;
};
const getFolderDisplayName = (rootPath: string): string =>
rootPath.split(/[\\/]/).at(-1) || rootPath;
function createFolderNode(key: string, name: string): LocalFolderNode {
return {
key,
name,
folders: new Map(),
files: [],
};
}
function getFileName(pathValue: string): string {
return pathValue.split(/[\\/]/).at(-1) || pathValue;
}
function toVirtualPath(relativePath: string): string {
const normalized = relativePath.replace(/\\/g, "/").replace(/^\/+/, "");
return `/${normalized}`;
}
function normalizeRootPathForLookup(rootPath: string, isWindows: boolean): string {
const normalized = rootPath.replace(/\\/g, "/").replace(/\/+$/, "");
return isWindows ? normalized.toLowerCase() : normalized;
}
function toMountedVirtualPath(mount: string, relativePath: string): string {
return `/${mount}${toVirtualPath(relativePath)}`;
}
export function LocalFilesystemBrowser({
rootPaths,
searchSpaceId,
searchQuery,
onOpenFile,
}: LocalFilesystemBrowserProps) {
const electronAPI = useElectronAPI();
const [rootStateMap, setRootStateMap] = useState<Record<string, RootLoadState>>({});
const [expandedFolderKeys, setExpandedFolderKeys] = useState<Set<string>>(new Set());
const [mountByRootKey, setMountByRootKey] = useState<Map<string, string>>(new Map());
const supportedExtensions = useMemo(() => Array.from(getSupportedExtensionsSet()), []);
const isWindowsPlatform = electronAPI?.versions.platform === "win32";
useEffect(() => {
if (!electronAPI?.listFolderFiles) return;
let cancelled = false;
for (const rootPath of rootPaths) {
setRootStateMap((prev) => ({
...prev,
[rootPath]: {
loading: true,
error: null,
files: prev[rootPath]?.files ?? [],
},
}));
}
void Promise.all(
rootPaths.map(async (rootPath) => {
try {
const files = (await electronAPI.listFolderFiles({
path: rootPath,
name: getFolderDisplayName(rootPath),
excludePatterns: DEFAULT_EXCLUDE_PATTERNS,
fileExtensions: supportedExtensions,
rootFolderId: null,
searchSpaceId,
active: true,
})) as LocalFolderFileEntry[];
if (cancelled) return;
setRootStateMap((prev) => ({
...prev,
[rootPath]: {
loading: false,
error: null,
files,
},
}));
} catch (error) {
if (cancelled) return;
setRootStateMap((prev) => ({
...prev,
[rootPath]: {
loading: false,
error: error instanceof Error ? error.message : "Failed to read folder",
files: [],
},
}));
}
})
);
return () => {
cancelled = true;
};
}, [electronAPI, rootPaths, searchSpaceId, supportedExtensions]);
useEffect(() => {
if (!electronAPI?.getAgentFilesystemMounts) {
setMountByRootKey(new Map());
return;
}
let cancelled = false;
void electronAPI
.getAgentFilesystemMounts()
.then((mounts: LocalRootMount[]) => {
if (cancelled) return;
const next = new Map<string, string>();
for (const entry of mounts) {
next.set(normalizeRootPathForLookup(entry.rootPath, isWindowsPlatform), entry.mount);
}
setMountByRootKey(next);
})
.catch(() => {
if (cancelled) return;
setMountByRootKey(new Map());
});
return () => {
cancelled = true;
};
}, [electronAPI, isWindowsPlatform, rootPaths]);
const treeByRoot = useMemo(() => {
const query = searchQuery?.trim().toLowerCase() ?? "";
const hasQuery = query.length > 0;
return rootPaths.map((rootPath) => {
const rootNode = createFolderNode(rootPath, getFolderDisplayName(rootPath));
const allFiles = rootStateMap[rootPath]?.files ?? [];
const files = hasQuery
? allFiles.filter((file) => {
const relativePath = file.relativePath.toLowerCase();
const fileName = getFileName(file.relativePath).toLowerCase();
return relativePath.includes(query) || fileName.includes(query);
})
: allFiles;
for (const file of files) {
const parts = file.relativePath.split(/[\\/]/).filter(Boolean);
let cursor = rootNode;
for (let i = 0; i < parts.length - 1; i++) {
const part = parts[i];
const folderKey = `${cursor.key}/${part}`;
if (!cursor.folders.has(part)) {
cursor.folders.set(part, createFolderNode(folderKey, part));
}
cursor = cursor.folders.get(part) as LocalFolderNode;
}
cursor.files.push(file);
}
return { rootPath, rootNode, matchCount: files.length, totalCount: allFiles.length };
});
}, [rootPaths, rootStateMap, searchQuery]);
const toggleFolder = useCallback((folderKey: string) => {
setExpandedFolderKeys((prev) => {
const next = new Set(prev);
if (next.has(folderKey)) {
next.delete(folderKey);
} else {
next.add(folderKey);
}
return next;
});
}, []);
const renderFolder = useCallback(
(folder: LocalFolderNode, depth: number, mount: string) => {
const isExpanded = expandedFolderKeys.has(folder.key);
const childFolders = Array.from(folder.folders.values()).sort((a, b) =>
a.name.localeCompare(b.name)
);
const files = [...folder.files].sort((a, b) => a.relativePath.localeCompare(b.relativePath));
return (
<div key={folder.key} className="select-none">
<button
type="button"
onClick={() => toggleFolder(folder.key)}
className="flex h-8 w-full items-center gap-1.5 rounded-md px-2 text-left text-sm transition-colors hover:bg-muted/60"
style={{ paddingInlineStart: `${depth * 12 + 8}px` }}
draggable={false}
>
{isExpanded ? (
<ChevronDown className="size-3.5 shrink-0 text-muted-foreground" />
) : (
<ChevronRight className="size-3.5 shrink-0 text-muted-foreground" />
)}
<Folder className="size-3.5 shrink-0 text-muted-foreground" />
<span className="truncate">{folder.name}</span>
</button>
{isExpanded && (
<>
{childFolders.map((childFolder) => renderFolder(childFolder, depth + 1, mount))}
{files.map((file) => (
<button
key={file.fullPath}
type="button"
onClick={() => onOpenFile(toMountedVirtualPath(mount, file.relativePath))}
className="flex h-8 w-full items-center gap-1.5 rounded-md px-2 text-left text-sm transition-colors hover:bg-muted/60"
style={{ paddingInlineStart: `${(depth + 1) * 12 + 22}px` }}
title={file.fullPath}
draggable={false}
>
<FileText className="size-3.5 shrink-0 text-muted-foreground" />
<span className="truncate">{getFileName(file.relativePath)}</span>
</button>
))}
</>
)}
</div>
);
},
[expandedFolderKeys, onOpenFile, toggleFolder]
);
if (rootPaths.length === 0) {
return (
<div className="flex flex-1 flex-col items-center justify-center gap-2 px-4 py-10 text-center text-muted-foreground">
<p className="text-sm font-medium">No local folder selected</p>
<p className="text-xs text-muted-foreground/80">
Add a local folder above to browse files in desktop mode.
</p>
</div>
);
}
return (
<div className="flex-1 min-h-0 overflow-y-auto px-2 py-2">
{treeByRoot.map(({ rootPath, rootNode, matchCount, totalCount }) => {
const state = rootStateMap[rootPath];
const rootKey = normalizeRootPathForLookup(rootPath, isWindowsPlatform);
const mount = mountByRootKey.get(rootKey);
if (!state || state.loading) {
return (
<div key={rootPath} className="flex h-16 items-center gap-2 px-3 text-sm text-muted-foreground">
<Spinner size="sm" />
<span>Loading {getFolderDisplayName(rootPath)}...</span>
</div>
);
}
if (state.error) {
return (
<div key={rootPath} className="rounded-md border border-destructive/20 bg-destructive/5 p-3">
<p className="text-sm font-medium text-destructive">Failed to load local folder</p>
<p className="mt-1 text-xs text-muted-foreground">{state.error}</p>
</div>
);
}
const isEmpty = totalCount === 0;
return (
<div key={rootPath} className="mb-1">
{mount ? renderFolder(rootNode, 0, mount) : null}
{!mount && (
<div className="px-3 pb-2 text-xs text-muted-foreground/80">
Unable to resolve mounted root for this folder.
</div>
)}
{isEmpty && (
<div className="px-3 pb-2 text-xs text-muted-foreground/80">
No supported files found in this folder.
</div>
)}
{!isEmpty && matchCount === 0 && searchQuery && (
<div className="px-3 pb-2 text-xs text-muted-foreground/80">
No matching files in this folder.
</div>
)}
</div>
);
})}
</div>
);
}

View file

@ -1,6 +1,6 @@
"use client";
import { CreditCard, PenSquare, Zap } from "lucide-react";
import { CreditCard, SquarePen, Zap } from "lucide-react";
import Link from "next/link";
import { useParams } from "next/navigation";
import { useTranslations } from "next-intl";
@ -139,7 +139,7 @@ export function Sidebar({
{/* New chat button */}
<div className={cn("flex flex-col gap-0.5 py-2", isCollapsed && "items-center")}>
<SidebarButton
icon={PenSquare}
icon={SquarePen}
label={t("new_chat")}
onClick={onNewChat}
isCollapsed={isCollapsed}

View file

@ -1,6 +1,6 @@
"use client";
import { PanelLeft, PanelLeftClose } from "lucide-react";
import { PanelLeft } from "lucide-react";
import { useTranslations } from "next-intl";
import { Button } from "@/components/ui/button";
import { ShortcutKbd } from "@/components/ui/shortcut-kbd";
@ -23,7 +23,7 @@ export function SidebarCollapseButton({
const button = (
<Button variant="ghost" size="icon" onClick={onToggle} className="h-8 w-8 shrink-0">
{isCollapsed ? <PanelLeft className="h-4 w-4" /> : <PanelLeftClose className="h-4 w-4" />}
<PanelLeft className="h-4 w-4" />
<span className="sr-only">{isCollapsed ? t("expand_sidebar") : t("collapse_sidebar")}</span>
</Button>
);

View file

@ -7,8 +7,8 @@ import {
ExternalLink,
Info,
Languages,
Laptop,
LogOut,
Monitor,
Moon,
Sun,
UserCog,
@ -49,7 +49,7 @@ const LANGUAGES = [
const THEMES = [
{ value: "light" as const, name: "Light", icon: Sun },
{ value: "dark" as const, name: "Dark", icon: Moon },
{ value: "system" as const, name: "System", icon: Laptop },
{ value: "system" as const, name: "System", icon: Monitor },
];
const LEARN_MORE_LINKS = [

View file

@ -1,6 +1,6 @@
"use client";
import { Download, FileQuestionMark, FileText, Loader2, PenLine, RefreshCw } from "lucide-react";
import { Download, FileQuestionMark, FileText, Loader2, Pencil, RefreshCw } from "lucide-react";
import { useRouter } from "next/navigation";
import { useCallback, useEffect, useRef, useState } from "react";
import { toast } from "sonner";
@ -258,7 +258,7 @@ export function DocumentTabContent({ documentId, searchSpaceId, title }: Documen
onClick={() => setIsEditing(true)}
className="gap-1.5"
>
<PenLine className="size-3.5" />
<Pencil className="size-3.5" />
Edit
</Button>
)}

View file

@ -1,7 +1,7 @@
"use client";
import { useAtomValue, useSetAtom } from "jotai";
import { ChevronDownIcon, XIcon } from "lucide-react";
import { Check, ChevronDownIcon, Copy, Pencil, XIcon } from "lucide-react";
import dynamic from "next/dynamic";
import { useCallback, useEffect, useRef, useState } from "react";
import { toast } from "sonner";
@ -116,6 +116,7 @@ export function ReportPanelContent({
const [exporting, setExporting] = useState<string | null>(null);
const [saving, setSaving] = useState(false);
const copyTimerRef = useRef<ReturnType<typeof setTimeout> | undefined>(undefined);
const changeCountRef = useRef(0);
useEffect(() => {
return () => {
@ -125,6 +126,7 @@ export function ReportPanelContent({
// Editor state — tracks the latest markdown from the Plate editor
const [editedMarkdown, setEditedMarkdown] = useState<string | null>(null);
const [isEditing, setIsEditing] = useState(false);
// Read-only when public (shareToken) OR shared (SEARCH_SPACE visibility)
const currentThreadState = useAtomValue(currentThreadAtom);
@ -188,8 +190,22 @@ export function ReportPanelContent({
// Reset edited markdown when switching versions or reports
useEffect(() => {
setEditedMarkdown(null);
setIsEditing(false);
changeCountRef.current = 0;
}, [activeReportId]);
const handleReportMarkdownChange = useCallback(
(nextMarkdown: string) => {
if (!isEditing) return;
changeCountRef.current += 1;
// Plate may emit an initial normalize/serialize change on mount.
if (changeCountRef.current <= 1) return;
const savedMarkdown = reportContent?.content ?? "";
setEditedMarkdown(nextMarkdown === savedMarkdown ? null : nextMarkdown);
},
[isEditing, reportContent?.content]
);
// Copy markdown content (uses latest editor content)
const handleCopy = useCallback(async () => {
if (!currentMarkdown) return;
@ -257,7 +273,7 @@ export function ReportPanelContent({
// Save edited report content
const handleSave = useCallback(async () => {
if (!currentMarkdown || !activeReportId) return;
if (!currentMarkdown || !activeReportId) return false;
setSaving(true);
try {
const response = await authenticatedFetch(
@ -278,9 +294,11 @@ export function ReportPanelContent({
setReportContent((prev) => (prev ? { ...prev, content: currentMarkdown } : prev));
setEditedMarkdown(null);
toast.success("Report saved successfully");
return true;
} catch (err) {
console.error("Error saving report:", err);
toast.error(err instanceof Error ? err.message : "Failed to save report");
return false;
} finally {
setSaving(false);
}
@ -288,26 +306,21 @@ export function ReportPanelContent({
const activeVersionIndex = versions.findIndex((v) => v.id === activeReportId);
const isPublic = !!shareToken;
const btnBg = isPublic ? "bg-main-panel" : "bg-sidebar";
const isResume = reportContent?.content_type === "typst";
const showReportEditingTier = !isResume;
const hasUnsavedChanges = editedMarkdown !== null;
const handleCancelEditing = useCallback(() => {
setEditedMarkdown(null);
changeCountRef.current = 0;
setIsEditing(false);
}, []);
return (
<>
{/* Action bar — always visible; buttons are disabled while loading */}
<div className="flex h-14 items-center justify-between px-4 shrink-0">
<div className="flex items-center gap-2">
{/* Copy button — hidden for Typst (resume) */}
{reportContent?.content_type !== "typst" && (
<Button
variant="outline"
size="sm"
onClick={handleCopy}
disabled={isLoading || !reportContent?.content}
className={`h-8 min-w-[80px] px-3.5 py-4 text-[15px] ${btnBg} select-none`}
>
{copied ? "Copied" : "Copy"}
</Button>
)}
{/* Export — plain button for resume (typst), dropdown for others */}
{reportContent?.content_type === "typst" ? (
<Button
@ -315,7 +328,7 @@ export function ReportPanelContent({
size="sm"
onClick={() => handleExport("pdf")}
disabled={isLoading || !reportContent?.content || exporting !== null}
className={`h-8 min-w-[100px] px-3.5 py-4 text-[15px] ${btnBg} select-none`}
className={`h-8 min-w-[100px] px-3.5 py-4 text-[15px] ${isPublic ? "bg-main-panel" : "bg-sidebar"} select-none`}
>
{exporting === "pdf" ? <Spinner size="xs" /> : "Download"}
</Button>
@ -326,7 +339,7 @@ export function ReportPanelContent({
variant="outline"
size="sm"
disabled={isLoading || !reportContent?.content}
className={`h-8 px-3.5 py-4 text-[15px] gap-1.5 ${btnBg} select-none`}
className={`h-8 px-3.5 py-4 text-[15px] gap-1.5 ${isPublic ? "bg-main-panel" : "bg-sidebar"} select-none`}
>
Export
<ChevronDownIcon className="size-3" />
@ -352,7 +365,7 @@ export function ReportPanelContent({
<Button
variant="outline"
size="sm"
className={`h-8 px-3.5 py-4 text-[15px] gap-1.5 ${btnBg} select-none`}
className={`h-8 px-3.5 py-4 text-[15px] gap-1.5 ${isPublic ? "bg-main-panel" : "bg-sidebar"} select-none`}
>
v{activeVersionIndex + 1}
<ChevronDownIcon className="size-3" />
@ -383,6 +396,75 @@ export function ReportPanelContent({
)}
</div>
{showReportEditingTier && (
<div className="flex h-10 items-center justify-between gap-2 border-t border-b px-4 shrink-0">
<div className="min-w-0 flex-1">
<p className="truncate text-sm text-muted-foreground">
{reportContent?.title || title}
</p>
</div>
<div className="flex items-center gap-1 shrink-0">
{!isEditing && (
<Button
variant="ghost"
size="icon"
className="size-6"
onClick={() => {
void handleCopy();
}}
disabled={isLoading || !reportContent?.content}
>
{copied ? <Check className="size-3.5" /> : <Copy className="size-3.5" />}
<span className="sr-only">
{copied ? "Copied report content" : "Copy report content"}
</span>
</Button>
)}
{!isReadOnly &&
(isEditing ? (
<>
<Button
variant="ghost"
size="sm"
className="h-6 px-2 text-xs"
onClick={handleCancelEditing}
disabled={saving}
>
Cancel
</Button>
<Button
variant="secondary"
size="sm"
className="relative h-6 w-[56px] px-0 text-xs"
onClick={async () => {
const saveSucceeded = await handleSave();
if (saveSucceeded) setIsEditing(false);
}}
disabled={saving || !hasUnsavedChanges}
>
<span className={saving ? "opacity-0" : ""}>Save</span>
{saving && <Spinner size="xs" className="absolute" />}
</Button>
</>
) : (
<Button
variant="ghost"
size="icon"
className="size-6"
onClick={() => {
setEditedMarkdown(null);
changeCountRef.current = 0;
setIsEditing(true);
}}
>
<Pencil className="size-3.5" />
<span className="sr-only">Edit report</span>
</Button>
))}
</div>
</div>
)}
{/* Report content — skeleton/error/viewer/editor shown only in this area */}
<div className="flex-1 overflow-hidden">
{isLoading ? (
@ -406,15 +488,16 @@ export function ReportPanelContent({
</div>
) : (
<PlateEditor
key={`report-${activeReportId}-${isEditing ? "editing" : "viewing"}`}
preset="full"
markdown={reportContent.content}
onMarkdownChange={setEditedMarkdown}
readOnly={false}
onMarkdownChange={handleReportMarkdownChange}
readOnly={!isEditing}
placeholder="Report content..."
editorVariant="default"
onSave={handleSave}
hasUnsavedChanges={editedMarkdown !== null}
isSaving={saving}
allowModeToggle={false}
reserveToolbarSpace
defaultEditing={isEditing}
className="[&_[role=toolbar]]:!bg-sidebar"
/>
)

View file

@ -2,7 +2,7 @@
import { useQuery, useQueryClient } from "@tanstack/react-query";
import { useAtomValue } from "jotai";
import { ArrowUp, ChevronDown, ClipboardCopy, Download, Info, Pen } from "lucide-react";
import { ArrowUp, ChevronDown, ClipboardCopy, Download, Info, Pencil } from "lucide-react";
import { useEffect, useRef, useState } from "react";
import { toast } from "sonner";
import { z } from "zod";
@ -247,7 +247,7 @@ export function TeamMemoryManager({ searchSpaceId }: TeamMemoryManagerProps) {
onClick={openInput}
className="absolute bottom-3 right-3 z-10 h-[54px] w-[54px] rounded-full border bg-muted/60 backdrop-blur-sm shadow-sm"
>
<Pen className="!h-5 !w-5" />
<Pencil className="!h-5 !w-5" />
</Button>
)}
</div>

View file

@ -1,7 +1,7 @@
"use client";
import { useAtom } from "jotai";
import { Brain, CircleUser, Globe, KeyRound, Monitor, ReceiptText, Sparkles } from "lucide-react";
import { Brain, CircleUser, Globe, Keyboard, KeyRound, Monitor, ReceiptText, Sparkles } from "lucide-react";
import dynamic from "next/dynamic";
import { useTranslations } from "next-intl";
import { useMemo } from "react";
@ -51,6 +51,13 @@ const DesktopContent = dynamic(
),
{ ssr: false }
);
const DesktopShortcutsContent = dynamic(
() =>
import("@/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent").then(
(m) => ({ default: m.DesktopShortcutsContent })
),
{ ssr: false }
);
const MemoryContent = dynamic(
() =>
import("@/app/dashboard/[search_space_id]/user-settings/components/MemoryContent").then(
@ -93,7 +100,18 @@ export function UserSettingsDialog() {
icon: <ReceiptText className="h-4 w-4" />,
},
...(isDesktop
? [{ value: "desktop", label: "Desktop", icon: <Monitor className="h-4 w-4" /> }]
? [
{
value: "desktop",
label: "App Preferences",
icon: <Monitor className="h-4 w-4" />,
},
{
value: "desktop-shortcuts",
label: "Hotkeys",
icon: <Keyboard className="h-4 w-4" />,
},
]
: []),
],
[t, isDesktop]
@ -116,6 +134,7 @@ export function UserSettingsDialog() {
{state.initialTab === "memory" && <MemoryContent />}
{state.initialTab === "purchases" && <PurchaseHistoryContent />}
{state.initialTab === "desktop" && <DesktopContent />}
{state.initialTab === "desktop-shortcuts" && <DesktopShortcutsContent />}
</div>
</SettingsDialog>
);

View file

@ -2,7 +2,7 @@
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
import { useSetAtom } from "jotai";
import { CornerDownLeftIcon, Pen } from "lucide-react";
import { CornerDownLeftIcon, Pencil } from "lucide-react";
import { useCallback, useEffect, useMemo, useState } from "react";
import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
import { PlateEditor } from "@/components/editor/plate-editor";
@ -222,7 +222,7 @@ function ApprovalCard({
});
}}
>
<Pen className="size-3.5" />
<Pencil className="size-3.5" />
Edit
</Button>
)}

View file

@ -2,7 +2,7 @@
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
import { useSetAtom } from "jotai";
import { CornerDownLeftIcon, Pen } from "lucide-react";
import { CornerDownLeftIcon, Pencil } from "lucide-react";
import { useCallback, useEffect, useState } from "react";
import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
import { PlateEditor } from "@/components/editor/plate-editor";
@ -241,7 +241,7 @@ function ApprovalCard({
});
}}
>
<Pen className="size-3.5" />
<Pencil className="size-3.5" />
Edit
</Button>
)}

View file

@ -2,7 +2,7 @@
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
import { useSetAtom } from "jotai";
import { CornerDownLeftIcon, FileIcon, Pen } from "lucide-react";
import { CornerDownLeftIcon, FileIcon, Pencil } from "lucide-react";
import { useCallback, useEffect, useMemo, useState } from "react";
import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
import { PlateEditor } from "@/components/editor/plate-editor";
@ -224,7 +224,7 @@ function ApprovalCard({
});
}}
>
<Pen className="size-3.5" />
<Pencil className="size-3.5" />
Edit
</Button>
)}

View file

@ -1,8 +1,9 @@
"use client";
import type { ToolCallMessagePartComponent } from "@assistant-ui/react";
import { CornerDownLeftIcon, Pen } from "lucide-react";
import { CornerDownLeftIcon, Pencil } from "lucide-react";
import { useCallback, useEffect, useMemo, useState } from "react";
import { toast } from "sonner";
import { TextShimmerLoader } from "@/components/prompt-kit/loader";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
@ -116,8 +117,8 @@ function GenericApprovalCard({
if (phase !== "pending" || !isMCPTool) return;
setProcessing();
onDecision({ type: "approve" });
connectorsApiService.trustMCPTool(mcpConnectorId, toolName).catch((err) => {
console.error("Failed to trust MCP tool:", err);
connectorsApiService.trustMCPTool(mcpConnectorId, toolName).catch(() => {
toast.error("Failed to save 'Always Allow' preference. The tool will still require approval next time.");
});
}, [phase, setProcessing, onDecision, isMCPTool, mcpConnectorId, toolName]);
@ -167,7 +168,7 @@ function GenericApprovalCard({
className="rounded-lg text-muted-foreground -mt-1 -mr-2"
onClick={() => setIsEditing(true)}
>
<Pen className="size-3.5" />
<Pencil className="size-3.5" />
Edit
</Button>
)}

View file

@ -2,7 +2,7 @@
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
import { useSetAtom } from "jotai";
import { CornerDownLeftIcon, Pen, UserIcon, UsersIcon } from "lucide-react";
import { CornerDownLeftIcon, Pencil, UserIcon, UsersIcon } from "lucide-react";
import { useCallback, useEffect, useMemo, useState } from "react";
import type { ExtraField } from "@/atoms/chat/hitl-edit-panel.atom";
import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
@ -251,7 +251,7 @@ function ApprovalCard({
});
}}
>
<Pen className="size-3.5" />
<Pencil className="size-3.5" />
Edit
</Button>
)}

View file

@ -2,7 +2,7 @@
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
import { useSetAtom } from "jotai";
import { CornerDownLeftIcon, MailIcon, Pen, UserIcon, UsersIcon } from "lucide-react";
import { CornerDownLeftIcon, MailIcon, Pencil, UserIcon, UsersIcon } from "lucide-react";
import { useCallback, useEffect, useMemo, useState } from "react";
import type { ExtraField } from "@/atoms/chat/hitl-edit-panel.atom";
import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
@ -250,7 +250,7 @@ function ApprovalCard({
});
}}
>
<Pen className="size-3.5" />
<Pencil className="size-3.5" />
Edit
</Button>
)}

View file

@ -2,7 +2,7 @@
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
import { useSetAtom } from "jotai";
import { CornerDownLeftIcon, MailIcon, Pen, UserIcon, UsersIcon } from "lucide-react";
import { CornerDownLeftIcon, MailIcon, Pencil, UserIcon, UsersIcon } from "lucide-react";
import { useCallback, useEffect, useState } from "react";
import type { ExtraField } from "@/atoms/chat/hitl-edit-panel.atom";
import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
@ -283,7 +283,7 @@ function ApprovalCard({
});
}}
>
<Pen className="size-3.5" />
<Pencil className="size-3.5" />
Edit
</Button>
)}

View file

@ -2,7 +2,7 @@
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
import { useSetAtom } from "jotai";
import { ClockIcon, CornerDownLeftIcon, GlobeIcon, MapPinIcon, Pen, UsersIcon } from "lucide-react";
import { ClockIcon, CornerDownLeftIcon, GlobeIcon, MapPinIcon, Pencil, UsersIcon } from "lucide-react";
import { useCallback, useEffect, useMemo, useState } from "react";
import type { ExtraField } from "@/atoms/chat/hitl-edit-panel.atom";
import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
@ -332,7 +332,7 @@ function ApprovalCard({
});
}}
>
<Pen className="size-3.5" />
<Pencil className="size-3.5" />
Edit
</Button>
)}

View file

@ -7,7 +7,7 @@ import {
ClockIcon,
CornerDownLeftIcon,
MapPinIcon,
Pen,
Pencil,
UsersIcon,
} from "lucide-react";
import { useCallback, useEffect, useState } from "react";
@ -415,7 +415,7 @@ function ApprovalCard({
});
}}
>
<Pen className="size-3.5" />
<Pencil className="size-3.5" />
Edit
</Button>
)}

View file

@ -2,7 +2,7 @@
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
import { useSetAtom } from "jotai";
import { CornerDownLeftIcon, FileIcon, Pen } from "lucide-react";
import { CornerDownLeftIcon, FileIcon, Pencil } from "lucide-react";
import { useCallback, useEffect, useMemo, useState } from "react";
import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
import { PlateEditor } from "@/components/editor/plate-editor";
@ -240,7 +240,7 @@ function ApprovalCard({
});
}}
>
<Pen className="size-3.5" />
<Pencil className="size-3.5" />
Edit
</Button>
)}

View file

@ -2,7 +2,7 @@
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
import { useSetAtom } from "jotai";
import { CornerDownLeftIcon, Pen } from "lucide-react";
import { CornerDownLeftIcon, Pencil } from "lucide-react";
import { useCallback, useEffect, useMemo, useState } from "react";
import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
import { PlateEditor } from "@/components/editor/plate-editor";
@ -257,7 +257,7 @@ function ApprovalCard({
});
}}
>
<Pen className="size-3.5" />
<Pencil className="size-3.5" />
Edit
</Button>
)}

View file

@ -2,7 +2,7 @@
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
import { useSetAtom } from "jotai";
import { CornerDownLeftIcon, Pen } from "lucide-react";
import { CornerDownLeftIcon, Pencil } from "lucide-react";
import { useCallback, useEffect, useState } from "react";
import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
import { PlateEditor } from "@/components/editor/plate-editor";
@ -273,7 +273,7 @@ function ApprovalCard({
});
}}
>
<Pen className="size-3.5" />
<Pencil className="size-3.5" />
Edit
</Button>
)}

View file

@ -2,7 +2,7 @@
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
import { useSetAtom } from "jotai";
import { CornerDownLeftIcon, Pen } from "lucide-react";
import { CornerDownLeftIcon, Pencil } from "lucide-react";
import { useCallback, useEffect, useMemo, useState } from "react";
import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
import { PlateEditor } from "@/components/editor/plate-editor";
@ -269,7 +269,7 @@ function ApprovalCard({
});
}}
>
<Pen className="size-3.5" />
<Pencil className="size-3.5" />
Edit
</Button>
)}

View file

@ -2,7 +2,7 @@
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
import { useSetAtom } from "jotai";
import { CornerDownLeftIcon, Pen } from "lucide-react";
import { CornerDownLeftIcon, Pencil } from "lucide-react";
import { useCallback, useEffect, useState } from "react";
import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
import { PlateEditor } from "@/components/editor/plate-editor";
@ -332,7 +332,7 @@ function ApprovalCard({
});
}}
>
<Pen className="size-3.5" />
<Pencil className="size-3.5" />
Edit
</Button>
)}

View file

@ -2,7 +2,7 @@
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
import { useSetAtom } from "jotai";
import { CornerDownLeftIcon, Pen } from "lucide-react";
import { CornerDownLeftIcon, Pencil } from "lucide-react";
import { useCallback, useEffect, useMemo, useState } from "react";
import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
import { PlateEditor } from "@/components/editor/plate-editor";
@ -219,7 +219,7 @@ function ApprovalCard({
});
}}
>
<Pen className="size-3.5" />
<Pencil className="size-3.5" />
Edit
</Button>
)}

View file

@ -2,7 +2,7 @@
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
import { useSetAtom } from "jotai";
import { CornerDownLeftIcon, Pen } from "lucide-react";
import { CornerDownLeftIcon, Pencil } from "lucide-react";
import { useCallback, useEffect, useState } from "react";
import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
import { PlateEditor } from "@/components/editor/plate-editor";
@ -196,7 +196,7 @@ function ApprovalCard({
});
}}
>
<Pen className="size-3.5" />
<Pencil className="size-3.5" />
Edit
</Button>
)}

View file

@ -2,7 +2,7 @@
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
import { useSetAtom } from "jotai";
import { CornerDownLeftIcon, FileIcon, Pen } from "lucide-react";
import { CornerDownLeftIcon, FileIcon, Pencil } from "lucide-react";
import { useCallback, useEffect, useMemo, useState } from "react";
import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
import { PlateEditor } from "@/components/editor/plate-editor";
@ -209,7 +209,7 @@ function ApprovalCard({
});
}}
>
<Pen className="size-3.5" />
<Pencil className="size-3.5" />
Edit
</Button>
)}

View file

@ -1,6 +1,6 @@
"use client";
import { BookOpenIcon, PenLineIcon } from "lucide-react";
import { BookOpenIcon, Pencil } from "lucide-react";
import { usePlateState } from "platejs/react";
import { ToolbarButton } from "./toolbar";
@ -13,7 +13,7 @@ export function ModeToolbarButton() {
tooltip={readOnly ? "Click to edit" : "Click to view"}
onClick={() => setReadOnly(!readOnly)}
>
{readOnly ? <BookOpenIcon /> : <PenLineIcon />}
{readOnly ? <BookOpenIcon /> : <Pencil />}
</ToolbarButton>
);
}

View file

@ -1,6 +1,7 @@
import {
BookOpen,
Brain,
FileUser,
FileText,
Film,
Globe,
@ -15,6 +16,7 @@ const TOOL_ICONS: Record<string, LucideIcon> = {
generate_podcast: Podcast,
generate_video_presentation: Film,
generate_report: FileText,
generate_resume: FileUser,
generate_image: ImageIcon,
scrape_webpage: ScanLine,
web_search: Globe,

View file

@ -0,0 +1,61 @@
export type AgentFilesystemMode = "cloud" | "desktop_local_folder";
export type ClientPlatform = "web" | "desktop";
export interface AgentFilesystemMountSelection {
mount_id: string;
root_path: string;
}
export interface AgentFilesystemSelection {
filesystem_mode: AgentFilesystemMode;
client_platform: ClientPlatform;
local_filesystem_mounts?: AgentFilesystemMountSelection[];
}
const DEFAULT_SELECTION: AgentFilesystemSelection = {
filesystem_mode: "cloud",
client_platform: "web",
};
export function getClientPlatform(): ClientPlatform {
if (typeof window === "undefined") return "web";
return window.electronAPI ? "desktop" : "web";
}
export async function getAgentFilesystemSelection(): Promise<AgentFilesystemSelection> {
const platform = getClientPlatform();
if (platform !== "desktop" || !window.electronAPI?.getAgentFilesystemSettings) {
return { ...DEFAULT_SELECTION, client_platform: platform };
}
try {
const settings = await window.electronAPI.getAgentFilesystemSettings();
if (settings.mode === "desktop_local_folder") {
const mounts = await window.electronAPI.getAgentFilesystemMounts?.();
const localFilesystemMounts =
mounts?.map((entry) => ({
mount_id: entry.mount,
root_path: entry.rootPath,
})) ?? [];
if (localFilesystemMounts.length === 0) {
return {
filesystem_mode: "cloud",
client_platform: "desktop",
};
}
return {
filesystem_mode: "desktop_local_folder",
client_platform: "desktop",
local_filesystem_mounts: localFilesystemMounts,
};
}
return {
filesystem_mode: "cloud",
client_platform: "desktop",
};
} catch {
return {
filesystem_mode: "cloud",
client_platform: "desktop",
};
}
}

View file

@ -12,6 +12,10 @@ import { ValidationError } from "../error";
const BASE = "/api/v1/public/anon-chat";
export type AnonUploadResult =
| { ok: true; data: { filename: string; size_bytes: number } }
| { ok: false; reason: "quota_exceeded" };
class AnonymousChatApiService {
private baseUrl: string;
@ -71,7 +75,7 @@ class AnonymousChatApiService {
});
};
uploadDocument = async (file: File): Promise<{ filename: string; size_bytes: number }> => {
uploadDocument = async (file: File): Promise<AnonUploadResult> => {
const formData = new FormData();
formData.append("file", file);
const res = await fetch(this.fullUrl("/upload"), {
@ -79,11 +83,15 @@ class AnonymousChatApiService {
credentials: "include",
body: formData,
});
if (res.status === 409) {
return { ok: false, reason: "quota_exceeded" };
}
if (!res.ok) {
const body = await res.json().catch(() => ({}));
throw new Error(body.detail || `Upload failed: ${res.status}`);
}
return res.json();
const data = await res.json();
return { ok: true, data };
};
getDocument = async (): Promise<{ filename: string; size_bytes: number } | null> => {

Some files were not shown because too many files have changed in this diff Show more