mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
feat(new-chat): integrate filesystem flow into agent pipeline
This commit is contained in:
parent
42d2d2222e
commit
1eadecee23
10 changed files with 574 additions and 25 deletions
|
|
@ -33,9 +33,12 @@ from langgraph.types import Checkpointer
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemSelection
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.agents.new_chat.middleware import (
|
||||
DedupHITLToolCallsMiddleware,
|
||||
FileIntentMiddleware,
|
||||
KnowledgeBaseSearchMiddleware,
|
||||
MemoryInjectionMiddleware,
|
||||
SurfSenseFilesystemMiddleware,
|
||||
|
|
@ -164,6 +167,7 @@ async def create_surfsense_deep_agent(
|
|||
thread_visibility: ChatVisibility | None = None,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
anon_session_id: str | None = None,
|
||||
filesystem_selection: FilesystemSelection | None = None,
|
||||
):
|
||||
"""
|
||||
Create a SurfSense deep agent with configurable tools and prompts.
|
||||
|
|
@ -238,6 +242,8 @@ async def create_surfsense_deep_agent(
|
|||
)
|
||||
"""
|
||||
_t_agent_total = time.perf_counter()
|
||||
filesystem_selection = filesystem_selection or FilesystemSelection()
|
||||
backend_resolver = build_backend_resolver(filesystem_selection)
|
||||
|
||||
# Discover available connectors and document types for this search space
|
||||
available_connectors: list[str] | None = None
|
||||
|
|
@ -439,7 +445,10 @@ async def create_surfsense_deep_agent(
|
|||
gp_middleware = [
|
||||
TodoListMiddleware(),
|
||||
_memory_middleware,
|
||||
FileIntentMiddleware(llm=llm),
|
||||
SurfSenseFilesystemMiddleware(
|
||||
backend=backend_resolver,
|
||||
filesystem_mode=filesystem_selection.mode,
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
thread_id=thread_id,
|
||||
|
|
@ -460,15 +469,19 @@ async def create_surfsense_deep_agent(
|
|||
deepagent_middleware = [
|
||||
TodoListMiddleware(),
|
||||
_memory_middleware,
|
||||
FileIntentMiddleware(llm=llm),
|
||||
KnowledgeBaseSearchMiddleware(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
filesystem_mode=filesystem_selection.mode,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
anon_session_id=anon_session_id,
|
||||
),
|
||||
SurfSenseFilesystemMiddleware(
|
||||
backend=backend_resolver,
|
||||
filesystem_mode=filesystem_selection.mode,
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
thread_id=thread_id,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,15 @@ Context schema definitions for SurfSense agents.
|
|||
This module defines the custom state schema used by the SurfSense deep agent.
|
||||
"""
|
||||
|
||||
from typing import TypedDict
|
||||
from typing import NotRequired, TypedDict
|
||||
|
||||
|
||||
class FileOperationContractState(TypedDict):
|
||||
intent: str
|
||||
confidence: float
|
||||
suggested_path: str
|
||||
timestamp: str
|
||||
turn_id: str
|
||||
|
||||
|
||||
class SurfSenseContextSchema(TypedDict):
|
||||
|
|
@ -24,5 +32,8 @@ class SurfSenseContextSchema(TypedDict):
|
|||
"""
|
||||
|
||||
search_space_id: int
|
||||
file_operation_contract: NotRequired[FileOperationContractState]
|
||||
turn_id: NotRequired[str]
|
||||
request_id: NotRequired[str]
|
||||
# These are runtime-injected and won't be serialized
|
||||
# db_session and connector_service are passed when invoking the agent
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ from app.agents.new_chat.sandbox import (
|
|||
get_or_create_sandbox,
|
||||
is_sandbox_enabled,
|
||||
)
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.db import Chunk, Document, DocumentType, Folder, shielded_async_session
|
||||
from app.indexing_pipeline.document_chunker import chunk_text
|
||||
from app.utils.document_converters import (
|
||||
|
|
@ -50,6 +51,8 @@ SURFSENSE_FILESYSTEM_SYSTEM_PROMPT = """## Following Conventions
|
|||
|
||||
- Read files before editing — understand existing content before making changes.
|
||||
- Mimic existing style, naming conventions, and patterns.
|
||||
- Never claim a file was created/updated unless filesystem tool output confirms success.
|
||||
- If a file write/edit fails, explicitly report the failure.
|
||||
|
||||
## Filesystem Tools
|
||||
|
||||
|
|
@ -109,13 +112,20 @@ Usage:
|
|||
- Use chunk IDs (`<chunk id='...'>`) as citations in answers.
|
||||
"""
|
||||
|
||||
SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new file to the in-memory filesystem (session-only).
|
||||
SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new text file to the in-memory filesystem (session-only).
|
||||
|
||||
Use this to create scratch/working files during the conversation. Files created
|
||||
here are ephemeral and will not be saved to the user's knowledge base.
|
||||
|
||||
To permanently save a document to the user's knowledge base, use the
|
||||
`save_document` tool instead.
|
||||
|
||||
Supported outputs include common LLM-friendly text formats like markdown, json,
|
||||
yaml, csv, xml, html, css, sql, and code files.
|
||||
|
||||
When creating content from open-ended prompts, produce concrete and useful text,
|
||||
not placeholders. Avoid adding dates/timestamps unless the user explicitly asks
|
||||
for them.
|
||||
"""
|
||||
|
||||
SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files.
|
||||
|
|
@ -182,11 +192,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
def __init__(
|
||||
self,
|
||||
*,
|
||||
backend: Any = None,
|
||||
filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
||||
search_space_id: int | None = None,
|
||||
created_by_id: str | None = None,
|
||||
thread_id: int | str | None = None,
|
||||
tool_token_limit_before_evict: int | None = 20000,
|
||||
) -> None:
|
||||
self._filesystem_mode = filesystem_mode
|
||||
self._search_space_id = search_space_id
|
||||
self._created_by_id = created_by_id
|
||||
self._thread_id = thread_id
|
||||
|
|
@ -204,8 +217,15 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
" extract the data, write it as a clean file (CSV, JSON, etc.),"
|
||||
" and then run your code against it."
|
||||
)
|
||||
if filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
|
||||
system_prompt += (
|
||||
"\n\n## Local Folder Mode"
|
||||
"\n\nThis chat is running in desktop local-folder mode."
|
||||
" Keep all file operations local. Do not use save_document."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
backend=backend,
|
||||
system_prompt=system_prompt,
|
||||
custom_tool_descriptions={
|
||||
"ls": SURFSENSE_LIST_FILES_TOOL_DESCRIPTION,
|
||||
|
|
@ -219,7 +239,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
max_execute_timeout=self._MAX_EXECUTE_TIMEOUT,
|
||||
)
|
||||
self.tools = [t for t in self.tools if t.name != "execute"]
|
||||
self.tools.append(self._create_save_document_tool())
|
||||
if self._should_persist_documents():
|
||||
self.tools.append(self._create_save_document_tool())
|
||||
if self._sandbox_available:
|
||||
self.tools.append(self._create_execute_code_tool())
|
||||
|
||||
|
|
@ -637,15 +658,25 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
runtime: ToolRuntime[None, FilesystemState],
|
||||
) -> Command | str:
|
||||
resolved_backend = self._get_backend(runtime)
|
||||
target_path = self._resolve_write_target_path(file_path, runtime)
|
||||
try:
|
||||
validated_path = validate_path(file_path)
|
||||
validated_path = validate_path(target_path)
|
||||
except ValueError as exc:
|
||||
return f"Error: {exc}"
|
||||
res: WriteResult = resolved_backend.write(validated_path, content)
|
||||
if res.error:
|
||||
return res.error
|
||||
verify_error = self._verify_written_content_sync(
|
||||
backend=resolved_backend,
|
||||
path=validated_path,
|
||||
expected_content=content,
|
||||
)
|
||||
if verify_error:
|
||||
return verify_error
|
||||
|
||||
if not self._is_kb_document(validated_path):
|
||||
if self._should_persist_documents() and not self._is_kb_document(
|
||||
validated_path
|
||||
):
|
||||
persist_result = self._run_async_blocking(
|
||||
self._persist_new_document(
|
||||
file_path=validated_path, content=content
|
||||
|
|
@ -682,15 +713,25 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
runtime: ToolRuntime[None, FilesystemState],
|
||||
) -> Command | str:
|
||||
resolved_backend = self._get_backend(runtime)
|
||||
target_path = self._resolve_write_target_path(file_path, runtime)
|
||||
try:
|
||||
validated_path = validate_path(file_path)
|
||||
validated_path = validate_path(target_path)
|
||||
except ValueError as exc:
|
||||
return f"Error: {exc}"
|
||||
res: WriteResult = await resolved_backend.awrite(validated_path, content)
|
||||
if res.error:
|
||||
return res.error
|
||||
verify_error = await self._verify_written_content_async(
|
||||
backend=resolved_backend,
|
||||
path=validated_path,
|
||||
expected_content=content,
|
||||
)
|
||||
if verify_error:
|
||||
return verify_error
|
||||
|
||||
if not self._is_kb_document(validated_path):
|
||||
if self._should_persist_documents() and not self._is_kb_document(
|
||||
validated_path
|
||||
):
|
||||
persist_result = await self._persist_new_document(
|
||||
file_path=validated_path,
|
||||
content=content,
|
||||
|
|
@ -726,6 +767,124 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
"""Return True for paths under /documents/ (KB-sourced, XML-wrapped)."""
|
||||
return path.startswith("/documents/")
|
||||
|
||||
def _should_persist_documents(self) -> bool:
|
||||
"""Only cloud mode persists file content to Document/Chunk tables."""
|
||||
return self._filesystem_mode == FilesystemMode.CLOUD
|
||||
|
||||
@staticmethod
|
||||
def _get_contract_suggested_path(runtime: ToolRuntime[None, FilesystemState]) -> str:
|
||||
contract = runtime.state.get("file_operation_contract") or {}
|
||||
suggested = contract.get("suggested_path")
|
||||
if isinstance(suggested, str) and suggested.strip():
|
||||
return suggested.strip()
|
||||
return "/notes.md"
|
||||
|
||||
def _resolve_write_target_path(
|
||||
self,
|
||||
file_path: str,
|
||||
runtime: ToolRuntime[None, FilesystemState],
|
||||
) -> str:
|
||||
candidate = file_path.strip()
|
||||
if not candidate:
|
||||
return self._get_contract_suggested_path(runtime)
|
||||
if not candidate.startswith("/"):
|
||||
return f"/{candidate.lstrip('/')}"
|
||||
return candidate
|
||||
|
||||
@staticmethod
|
||||
def _is_error_text(value: str) -> bool:
|
||||
return value.startswith("Error:")
|
||||
|
||||
@staticmethod
|
||||
def _read_for_verification_sync(backend: Any, path: str) -> str:
|
||||
read_raw = getattr(backend, "read_raw", None)
|
||||
if callable(read_raw):
|
||||
return read_raw(path)
|
||||
return backend.read(path, offset=0, limit=200000)
|
||||
|
||||
@staticmethod
|
||||
async def _read_for_verification_async(backend: Any, path: str) -> str:
|
||||
aread_raw = getattr(backend, "aread_raw", None)
|
||||
if callable(aread_raw):
|
||||
return await aread_raw(path)
|
||||
return await backend.aread(path, offset=0, limit=200000)
|
||||
|
||||
def _verify_written_content_sync(
|
||||
self,
|
||||
*,
|
||||
backend: Any,
|
||||
path: str,
|
||||
expected_content: str,
|
||||
) -> str | None:
|
||||
actual = self._read_for_verification_sync(backend, path)
|
||||
if self._is_error_text(actual):
|
||||
return f"Error: could not verify written file '{path}'."
|
||||
if actual.rstrip() != expected_content.rstrip():
|
||||
return (
|
||||
"Error: file write verification failed; expected content was not fully written "
|
||||
f"to '{path}'."
|
||||
)
|
||||
return None
|
||||
|
||||
async def _verify_written_content_async(
|
||||
self,
|
||||
*,
|
||||
backend: Any,
|
||||
path: str,
|
||||
expected_content: str,
|
||||
) -> str | None:
|
||||
actual = await self._read_for_verification_async(backend, path)
|
||||
if self._is_error_text(actual):
|
||||
return f"Error: could not verify written file '{path}'."
|
||||
if actual.rstrip() != expected_content.rstrip():
|
||||
return (
|
||||
"Error: file write verification failed; expected content was not fully written "
|
||||
f"to '{path}'."
|
||||
)
|
||||
return None
|
||||
|
||||
def _verify_edited_content_sync(
|
||||
self,
|
||||
*,
|
||||
backend: Any,
|
||||
path: str,
|
||||
new_string: str,
|
||||
) -> tuple[str | None, str | None]:
|
||||
updated_content = self._read_for_verification_sync(backend, path)
|
||||
if self._is_error_text(updated_content):
|
||||
return (
|
||||
f"Error: could not verify edited file '{path}'.",
|
||||
None,
|
||||
)
|
||||
if new_string and new_string not in updated_content:
|
||||
return (
|
||||
"Error: edit verification failed; updated content was not found in "
|
||||
f"'{path}'.",
|
||||
None,
|
||||
)
|
||||
return None, updated_content
|
||||
|
||||
async def _verify_edited_content_async(
|
||||
self,
|
||||
*,
|
||||
backend: Any,
|
||||
path: str,
|
||||
new_string: str,
|
||||
) -> tuple[str | None, str | None]:
|
||||
updated_content = await self._read_for_verification_async(backend, path)
|
||||
if self._is_error_text(updated_content):
|
||||
return (
|
||||
f"Error: could not verify edited file '{path}'.",
|
||||
None,
|
||||
)
|
||||
if new_string and new_string not in updated_content:
|
||||
return (
|
||||
"Error: edit verification failed; updated content was not found in "
|
||||
f"'{path}'.",
|
||||
None,
|
||||
)
|
||||
return None, updated_content
|
||||
|
||||
def _create_edit_file_tool(self) -> BaseTool:
|
||||
"""Create edit_file with DB persistence (skipped for KB documents)."""
|
||||
tool_description = (
|
||||
|
|
@ -754,8 +913,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
] = False,
|
||||
) -> Command | str:
|
||||
resolved_backend = self._get_backend(runtime)
|
||||
target_path = self._resolve_write_target_path(file_path, runtime)
|
||||
try:
|
||||
validated_path = validate_path(file_path)
|
||||
validated_path = validate_path(target_path)
|
||||
except ValueError as exc:
|
||||
return f"Error: {exc}"
|
||||
res: EditResult = resolved_backend.edit(
|
||||
|
|
@ -767,13 +927,22 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
if res.error:
|
||||
return res.error
|
||||
|
||||
if not self._is_kb_document(validated_path):
|
||||
read_result = resolved_backend.read(
|
||||
validated_path, offset=0, limit=200000
|
||||
)
|
||||
if read_result.error or read_result.file_data is None:
|
||||
return f"Error: could not reload edited file '{validated_path}' for persistence."
|
||||
updated_content = read_result.file_data["content"]
|
||||
verify_error, updated_content = self._verify_edited_content_sync(
|
||||
backend=resolved_backend,
|
||||
path=validated_path,
|
||||
new_string=new_string,
|
||||
)
|
||||
if verify_error:
|
||||
return verify_error
|
||||
|
||||
if self._should_persist_documents() and not self._is_kb_document(
|
||||
validated_path
|
||||
):
|
||||
if updated_content is None:
|
||||
return (
|
||||
f"Error: could not reload edited file '{validated_path}' for "
|
||||
"persistence."
|
||||
)
|
||||
persist_result = self._run_async_blocking(
|
||||
self._persist_edited_document(
|
||||
file_path=validated_path,
|
||||
|
|
@ -818,8 +987,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
] = False,
|
||||
) -> Command | str:
|
||||
resolved_backend = self._get_backend(runtime)
|
||||
target_path = self._resolve_write_target_path(file_path, runtime)
|
||||
try:
|
||||
validated_path = validate_path(file_path)
|
||||
validated_path = validate_path(target_path)
|
||||
except ValueError as exc:
|
||||
return f"Error: {exc}"
|
||||
res: EditResult = await resolved_backend.aedit(
|
||||
|
|
@ -831,13 +1001,22 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
if res.error:
|
||||
return res.error
|
||||
|
||||
if not self._is_kb_document(validated_path):
|
||||
read_result = await resolved_backend.aread(
|
||||
validated_path, offset=0, limit=200000
|
||||
)
|
||||
if read_result.error or read_result.file_data is None:
|
||||
return f"Error: could not reload edited file '{validated_path}' for persistence."
|
||||
updated_content = read_result.file_data["content"]
|
||||
verify_error, updated_content = await self._verify_edited_content_async(
|
||||
backend=resolved_backend,
|
||||
path=validated_path,
|
||||
new_string=new_string,
|
||||
)
|
||||
if verify_error:
|
||||
return verify_error
|
||||
|
||||
if self._should_persist_documents() and not self._is_kb_document(
|
||||
validated_path
|
||||
):
|
||||
if updated_content is None:
|
||||
return (
|
||||
f"Error: could not reload edited file '{validated_path}' for "
|
||||
"persistence."
|
||||
)
|
||||
persist_error = await self._persist_edited_document(
|
||||
file_path=validated_path,
|
||||
updated_content=updated_content,
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.db import (
|
||||
NATIVE_TO_LEGACY_DOCTYPE,
|
||||
Chunk,
|
||||
|
|
@ -857,6 +858,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
*,
|
||||
llm: BaseChatModel | None = None,
|
||||
search_space_id: int,
|
||||
filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
||||
available_connectors: list[str] | None = None,
|
||||
available_document_types: list[str] | None = None,
|
||||
top_k: int = 10,
|
||||
|
|
@ -865,6 +867,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
) -> None:
|
||||
self.llm = llm
|
||||
self.search_space_id = search_space_id
|
||||
self.filesystem_mode = filesystem_mode
|
||||
self.available_connectors = available_connectors
|
||||
self.available_document_types = available_document_types
|
||||
self.top_k = top_k
|
||||
|
|
@ -996,6 +999,9 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
if self.filesystem_mode != FilesystemMode.CLOUD:
|
||||
# Local-folder mode should not seed cloud KB documents into filesystem.
|
||||
return None
|
||||
|
||||
last_human = None
|
||||
for msg in reversed(messages):
|
||||
|
|
|
|||
|
|
@ -141,6 +141,15 @@ def _http_exception_handler(request: Request, exc: HTTPException) -> JSONRespons
|
|||
exc.status_code,
|
||||
message,
|
||||
)
|
||||
elif exc.status_code >= 400:
|
||||
_error_logger.warning(
|
||||
"[%s] %s %s - HTTPException %d: %s",
|
||||
rid,
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc.status_code,
|
||||
message,
|
||||
)
|
||||
if should_sanitize:
|
||||
message = GENERIC_5XX_MESSAGE
|
||||
err_code = "INTERNAL_ERROR"
|
||||
|
|
@ -170,6 +179,15 @@ def _http_exception_handler(request: Request, exc: HTTPException) -> JSONRespons
|
|||
exc.status_code,
|
||||
detail,
|
||||
)
|
||||
elif exc.status_code >= 400:
|
||||
_error_logger.warning(
|
||||
"[%s] %s %s - HTTPException %d: %s",
|
||||
rid,
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc.status_code,
|
||||
detail,
|
||||
)
|
||||
if should_sanitize:
|
||||
detail = GENERIC_5XX_MESSAGE
|
||||
code = _status_to_code(exc.status_code, detail)
|
||||
|
|
|
|||
|
|
@ -339,6 +339,9 @@ class Config:
|
|||
# self-hosted: Full access to local file system connectors (Obsidian, etc.)
|
||||
# cloud: Only cloud-based connectors available
|
||||
DEPLOYMENT_MODE = os.getenv("SURFSENSE_DEPLOYMENT_MODE", "self-hosted")
|
||||
ENABLE_DESKTOP_LOCAL_FILESYSTEM = (
|
||||
os.getenv("ENABLE_DESKTOP_LOCAL_FILESYSTEM", "FALSE").upper() == "TRUE"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_self_hosted(cls) -> bool:
|
||||
|
|
|
|||
|
|
@ -22,6 +22,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import (
|
||||
ClientPlatform,
|
||||
FilesystemMode,
|
||||
FilesystemSelection,
|
||||
)
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ChatComment,
|
||||
ChatVisibility,
|
||||
|
|
@ -63,6 +69,51 @@ _background_tasks: set[asyncio.Task] = set()
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
def _resolve_filesystem_selection(
|
||||
*,
|
||||
mode: str,
|
||||
client_platform: str,
|
||||
local_root: str | None,
|
||||
) -> FilesystemSelection:
|
||||
"""Validate and normalize filesystem mode settings from request payload."""
|
||||
try:
|
||||
resolved_mode = FilesystemMode(mode)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail="Invalid filesystem_mode") from exc
|
||||
try:
|
||||
resolved_platform = ClientPlatform(client_platform)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail="Invalid client_platform") from exc
|
||||
|
||||
if resolved_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
|
||||
if not config.ENABLE_DESKTOP_LOCAL_FILESYSTEM:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Desktop local filesystem mode is disabled on this deployment.",
|
||||
)
|
||||
if resolved_platform != ClientPlatform.DESKTOP:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="desktop_local_folder mode is only available on desktop runtime.",
|
||||
)
|
||||
if not local_root or not local_root.strip():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="local_filesystem_root is required for desktop_local_folder mode.",
|
||||
)
|
||||
return FilesystemSelection(
|
||||
mode=resolved_mode,
|
||||
client_platform=resolved_platform,
|
||||
local_root_path=local_root.strip(),
|
||||
)
|
||||
|
||||
return FilesystemSelection(
|
||||
mode=FilesystemMode.CLOUD,
|
||||
client_platform=resolved_platform,
|
||||
local_root_path=None,
|
||||
)
|
||||
|
||||
|
||||
def _try_delete_sandbox(thread_id: int) -> None:
|
||||
"""Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
|
||||
from app.agents.new_chat.sandbox import (
|
||||
|
|
@ -474,6 +525,11 @@ async def get_thread_messages(
|
|||
|
||||
# Check thread-level access based on visibility
|
||||
await check_thread_access(session, thread, user)
|
||||
filesystem_selection = _resolve_filesystem_selection(
|
||||
mode=request.filesystem_mode,
|
||||
client_platform=request.client_platform,
|
||||
local_root=request.local_filesystem_root,
|
||||
)
|
||||
|
||||
# Get messages with their authors and token usage loaded
|
||||
messages_result = await session.execute(
|
||||
|
|
@ -1098,6 +1154,7 @@ async def list_agent_tools(
|
|||
@router.post("/new_chat")
|
||||
async def handle_new_chat(
|
||||
request: NewChatRequest,
|
||||
http_request: Request,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
|
|
@ -1133,6 +1190,11 @@ async def handle_new_chat(
|
|||
|
||||
# Check thread-level access based on visibility
|
||||
await check_thread_access(session, thread, user)
|
||||
filesystem_selection = _resolve_filesystem_selection(
|
||||
mode=request.filesystem_mode,
|
||||
client_platform=request.client_platform,
|
||||
local_root=request.local_filesystem_root,
|
||||
)
|
||||
|
||||
# Get search space to check LLM config preferences
|
||||
search_space_result = await session.execute(
|
||||
|
|
@ -1175,6 +1237,8 @@ async def handle_new_chat(
|
|||
thread_visibility=thread.visibility,
|
||||
current_user_display_name=user.display_name or "A team member",
|
||||
disabled_tools=request.disabled_tools,
|
||||
filesystem_selection=filesystem_selection,
|
||||
request_id=getattr(http_request.state, "request_id", "unknown"),
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
|
|
@ -1202,6 +1266,7 @@ async def handle_new_chat(
|
|||
async def regenerate_response(
|
||||
thread_id: int,
|
||||
request: RegenerateRequest,
|
||||
http_request: Request,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
|
|
@ -1247,6 +1312,11 @@ async def regenerate_response(
|
|||
|
||||
# Check thread-level access based on visibility
|
||||
await check_thread_access(session, thread, user)
|
||||
filesystem_selection = _resolve_filesystem_selection(
|
||||
mode=request.filesystem_mode,
|
||||
client_platform=request.client_platform,
|
||||
local_root=request.local_filesystem_root,
|
||||
)
|
||||
|
||||
# Get the checkpointer and state history
|
||||
checkpointer = await get_checkpointer()
|
||||
|
|
@ -1412,6 +1482,8 @@ async def regenerate_response(
|
|||
thread_visibility=thread.visibility,
|
||||
current_user_display_name=user.display_name or "A team member",
|
||||
disabled_tools=request.disabled_tools,
|
||||
filesystem_selection=filesystem_selection,
|
||||
request_id=getattr(http_request.state, "request_id", "unknown"),
|
||||
):
|
||||
yield chunk
|
||||
streaming_completed = True
|
||||
|
|
@ -1477,6 +1549,7 @@ async def regenerate_response(
|
|||
async def resume_chat(
|
||||
thread_id: int,
|
||||
request: ResumeRequest,
|
||||
http_request: Request,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
|
|
@ -1498,6 +1571,11 @@ async def resume_chat(
|
|||
)
|
||||
|
||||
await check_thread_access(session, thread, user)
|
||||
filesystem_selection = _resolve_filesystem_selection(
|
||||
mode=request.filesystem_mode,
|
||||
client_platform=request.client_platform,
|
||||
local_root=request.local_filesystem_root,
|
||||
)
|
||||
|
||||
search_space_result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
|
||||
|
|
@ -1526,6 +1604,8 @@ async def resume_chat(
|
|||
user_id=str(user.id),
|
||||
llm_config_id=llm_config_id,
|
||||
thread_visibility=thread.visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
request_id=getattr(http_request.state, "request_id", "unknown"),
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
|
|
|
|||
|
|
@ -184,6 +184,9 @@ class NewChatRequest(BaseModel):
|
|||
disabled_tools: list[str] | None = (
|
||||
None # Optional list of tool names the user has disabled from the UI
|
||||
)
|
||||
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
|
||||
client_platform: Literal["web", "desktop"] = "web"
|
||||
local_filesystem_root: str | None = None
|
||||
|
||||
|
||||
class RegenerateRequest(BaseModel):
|
||||
|
|
@ -204,6 +207,9 @@ class RegenerateRequest(BaseModel):
|
|||
mentioned_document_ids: list[int] | None = None
|
||||
mentioned_surfsense_doc_ids: list[int] | None = None
|
||||
disabled_tools: list[str] | None = None
|
||||
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
|
||||
client_platform: Literal["web", "desktop"] = "web"
|
||||
local_filesystem_root: str | None = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -227,6 +233,9 @@ class ResumeDecision(BaseModel):
|
|||
class ResumeRequest(BaseModel):
|
||||
search_space_id: int
|
||||
decisions: list[ResumeDecision]
|
||||
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
|
||||
client_platform: Literal["web", "desktop"] = "web"
|
||||
local_filesystem_root: str | None = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ from sqlalchemy.orm import selectinload
|
|||
|
||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemSelection
|
||||
from app.config import config
|
||||
from app.agents.new_chat.llm_config import (
|
||||
AgentConfig,
|
||||
create_chat_litellm_from_agent_config,
|
||||
|
|
@ -145,6 +147,85 @@ class StreamResult:
|
|||
interrupt_value: dict[str, Any] | None = None
|
||||
sandbox_files: list[str] = field(default_factory=list)
|
||||
agent_called_update_memory: bool = False
|
||||
request_id: str | None = None
|
||||
turn_id: str = ""
|
||||
filesystem_mode: str = "cloud"
|
||||
client_platform: str = "web"
|
||||
intent_detected: str = "chat_only"
|
||||
intent_confidence: float = 0.0
|
||||
write_attempted: bool = False
|
||||
write_succeeded: bool = False
|
||||
verification_succeeded: bool = False
|
||||
commit_gate_passed: bool = True
|
||||
commit_gate_reason: str = ""
|
||||
|
||||
|
||||
def _safe_float(value: Any, default: float = 0.0) -> float:
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def _tool_output_to_text(tool_output: Any) -> str:
|
||||
if isinstance(tool_output, dict):
|
||||
if isinstance(tool_output.get("result"), str):
|
||||
return tool_output["result"]
|
||||
if isinstance(tool_output.get("error"), str):
|
||||
return tool_output["error"]
|
||||
return json.dumps(tool_output, ensure_ascii=False)
|
||||
return str(tool_output)
|
||||
|
||||
|
||||
def _tool_output_has_error(tool_output: Any) -> bool:
|
||||
if isinstance(tool_output, dict):
|
||||
if tool_output.get("error"):
|
||||
return True
|
||||
result = tool_output.get("result")
|
||||
if isinstance(result, str) and result.strip().lower().startswith("error:"):
|
||||
return True
|
||||
return False
|
||||
if isinstance(tool_output, str):
|
||||
return tool_output.strip().lower().startswith("error:")
|
||||
return False
|
||||
|
||||
|
||||
def _contract_enforcement_active(result: StreamResult) -> bool:
|
||||
# Keep policy deterministic with no env-driven progression modes:
|
||||
# enforce the file-operation contract only in desktop local-folder mode.
|
||||
return result.filesystem_mode == "desktop_local_folder"
|
||||
|
||||
|
||||
def _evaluate_file_contract_outcome(result: StreamResult) -> tuple[bool, str]:
|
||||
if result.intent_detected != "file_write":
|
||||
return True, ""
|
||||
if not result.write_attempted:
|
||||
return False, "no_write_attempt"
|
||||
if not result.write_succeeded:
|
||||
return False, "write_failed"
|
||||
if not result.verification_succeeded:
|
||||
return False, "verification_failed"
|
||||
return True, ""
|
||||
|
||||
|
||||
def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
|
||||
payload: dict[str, Any] = {
|
||||
"stage": stage,
|
||||
"request_id": result.request_id or "unknown",
|
||||
"turn_id": result.turn_id or "unknown",
|
||||
"chat_id": result.turn_id.split(":", 1)[0] if ":" in result.turn_id else "unknown",
|
||||
"filesystem_mode": result.filesystem_mode,
|
||||
"client_platform": result.client_platform,
|
||||
"intent_detected": result.intent_detected,
|
||||
"intent_confidence": result.intent_confidence,
|
||||
"write_attempted": result.write_attempted,
|
||||
"write_succeeded": result.write_succeeded,
|
||||
"verification_succeeded": result.verification_succeeded,
|
||||
"commit_gate_passed": result.commit_gate_passed,
|
||||
"commit_gate_reason": result.commit_gate_reason or None,
|
||||
}
|
||||
payload.update(extra)
|
||||
_perf_log.info("[file_operation_contract] %s", json.dumps(payload, ensure_ascii=False))
|
||||
|
||||
|
||||
async def _stream_agent_events(
|
||||
|
|
@ -239,6 +320,8 @@ async def _stream_agent_events(
|
|||
tool_name = event.get("name", "unknown_tool")
|
||||
run_id = event.get("run_id", "")
|
||||
tool_input = event.get("data", {}).get("input", {})
|
||||
if tool_name in ("write_file", "edit_file"):
|
||||
result.write_attempted = True
|
||||
|
||||
if current_text_id is not None:
|
||||
yield streaming_service.format_text_end(current_text_id)
|
||||
|
|
@ -514,6 +597,14 @@ async def _stream_agent_events(
|
|||
else:
|
||||
tool_output = {"result": str(raw_output) if raw_output else "completed"}
|
||||
|
||||
if tool_name in ("write_file", "edit_file"):
|
||||
if _tool_output_has_error(tool_output):
|
||||
# Keep successful evidence if a previous write/edit in this turn succeeded.
|
||||
pass
|
||||
else:
|
||||
result.write_succeeded = True
|
||||
result.verification_succeeded = True
|
||||
|
||||
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
|
||||
original_step_id = tool_step_ids.get(
|
||||
run_id, f"{step_prefix}-unknown-{run_id[:8]}"
|
||||
|
|
@ -1143,10 +1234,59 @@ async def _stream_agent_events(
|
|||
if completion_event:
|
||||
yield completion_event
|
||||
|
||||
state = await agent.aget_state(config)
|
||||
state_values = getattr(state, "values", {}) or {}
|
||||
contract_state = state_values.get("file_operation_contract") or {}
|
||||
contract_turn_id = contract_state.get("turn_id")
|
||||
current_turn_id = config.get("configurable", {}).get("turn_id", "")
|
||||
intent_value = contract_state.get("intent")
|
||||
if (
|
||||
isinstance(intent_value, str)
|
||||
and intent_value in ("chat_only", "file_write", "file_read")
|
||||
and contract_turn_id == current_turn_id
|
||||
):
|
||||
result.intent_detected = intent_value
|
||||
if (
|
||||
isinstance(intent_value, str)
|
||||
and intent_value in (
|
||||
"chat_only",
|
||||
"file_write",
|
||||
"file_read",
|
||||
)
|
||||
and contract_turn_id != current_turn_id
|
||||
):
|
||||
# Ignore stale intent contracts from previous turns/checkpoints.
|
||||
result.intent_detected = "chat_only"
|
||||
result.intent_confidence = (
|
||||
_safe_float(contract_state.get("confidence"), default=0.0)
|
||||
if contract_turn_id == current_turn_id
|
||||
else 0.0
|
||||
)
|
||||
|
||||
if result.intent_detected == "file_write":
|
||||
result.commit_gate_passed, result.commit_gate_reason = (
|
||||
_evaluate_file_contract_outcome(result)
|
||||
)
|
||||
if not result.commit_gate_passed:
|
||||
if _contract_enforcement_active(result):
|
||||
gate_notice = (
|
||||
"I could not complete the requested file write because no successful "
|
||||
"write_file/edit_file operation was confirmed."
|
||||
)
|
||||
gate_text_id = streaming_service.generate_text_id()
|
||||
yield streaming_service.format_text_start(gate_text_id)
|
||||
yield streaming_service.format_text_delta(gate_text_id, gate_notice)
|
||||
yield streaming_service.format_text_end(gate_text_id)
|
||||
yield streaming_service.format_terminal_info(gate_notice, "error")
|
||||
accumulated_text = gate_notice
|
||||
else:
|
||||
result.commit_gate_passed = True
|
||||
result.commit_gate_reason = ""
|
||||
|
||||
result.accumulated_text = accumulated_text
|
||||
result.agent_called_update_memory = called_update_memory
|
||||
_log_file_contract("turn_outcome", result)
|
||||
|
||||
state = await agent.aget_state(config)
|
||||
is_interrupted = state.tasks and any(task.interrupts for task in state.tasks)
|
||||
if is_interrupted:
|
||||
result.is_interrupted = True
|
||||
|
|
@ -1167,6 +1307,8 @@ async def stream_new_chat(
|
|||
thread_visibility: ChatVisibility | None = None,
|
||||
current_user_display_name: str | None = None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
filesystem_selection: FilesystemSelection | None = None,
|
||||
request_id: str | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream chat responses from the new SurfSense deep agent.
|
||||
|
|
@ -1194,6 +1336,20 @@ async def stream_new_chat(
|
|||
streaming_service = VercelStreamingService()
|
||||
stream_result = StreamResult()
|
||||
_t_total = time.perf_counter()
|
||||
fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud"
|
||||
fs_platform = (
|
||||
filesystem_selection.client_platform.value if filesystem_selection else "web"
|
||||
)
|
||||
stream_result.request_id = request_id
|
||||
stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}"
|
||||
stream_result.filesystem_mode = fs_mode
|
||||
stream_result.client_platform = fs_platform
|
||||
_log_file_contract("turn_start", stream_result)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] filesystem_mode=%s client_platform=%s",
|
||||
fs_mode,
|
||||
fs_platform,
|
||||
)
|
||||
log_system_snapshot("stream_new_chat_START")
|
||||
|
||||
from app.services.token_tracking_service import start_turn
|
||||
|
|
@ -1329,6 +1485,7 @@ async def stream_new_chat(
|
|||
thread_visibility=visibility,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
filesystem_selection=filesystem_selection,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
|
||||
|
|
@ -1435,6 +1592,8 @@ async def stream_new_chat(
|
|||
# We will use this to simulate group chat functionality in the future
|
||||
"messages": langchain_messages,
|
||||
"search_space_id": search_space_id,
|
||||
"request_id": request_id or "unknown",
|
||||
"turn_id": stream_result.turn_id,
|
||||
}
|
||||
|
||||
_perf_log.info(
|
||||
|
|
@ -1464,6 +1623,8 @@ async def stream_new_chat(
|
|||
# Configure LangGraph with thread_id for memory
|
||||
# If checkpoint_id is provided, fork from that checkpoint (for edit/reload)
|
||||
configurable = {"thread_id": str(chat_id)}
|
||||
configurable["request_id"] = request_id or "unknown"
|
||||
configurable["turn_id"] = stream_result.turn_id
|
||||
if checkpoint_id:
|
||||
configurable["checkpoint_id"] = checkpoint_id
|
||||
|
||||
|
|
@ -1871,10 +2032,26 @@ async def stream_resume_chat(
|
|||
user_id: str | None = None,
|
||||
llm_config_id: int = -1,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
filesystem_selection: FilesystemSelection | None = None,
|
||||
request_id: str | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
streaming_service = VercelStreamingService()
|
||||
stream_result = StreamResult()
|
||||
_t_total = time.perf_counter()
|
||||
fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud"
|
||||
fs_platform = (
|
||||
filesystem_selection.client_platform.value if filesystem_selection else "web"
|
||||
)
|
||||
stream_result.request_id = request_id
|
||||
stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}"
|
||||
stream_result.filesystem_mode = fs_mode
|
||||
stream_result.client_platform = fs_platform
|
||||
_log_file_contract("turn_start", stream_result)
|
||||
_perf_log.info(
|
||||
"[stream_resume] filesystem_mode=%s client_platform=%s",
|
||||
fs_mode,
|
||||
fs_platform,
|
||||
)
|
||||
|
||||
from app.services.token_tracking_service import start_turn
|
||||
|
||||
|
|
@ -1991,6 +2168,7 @@ async def stream_resume_chat(
|
|||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
|
||||
|
|
@ -2009,7 +2187,11 @@ async def stream_resume_chat(
|
|||
from langgraph.types import Command
|
||||
|
||||
config = {
|
||||
"configurable": {"thread_id": str(chat_id)},
|
||||
"configurable": {
|
||||
"thread_id": str(chat_id),
|
||||
"request_id": request_id or "unknown",
|
||||
"turn_id": stream_result.turn_id,
|
||||
},
|
||||
"recursion_limit": 80,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,48 @@
|
|||
import pytest
|
||||
|
||||
from app.tasks.chat.stream_new_chat import (
|
||||
StreamResult,
|
||||
_contract_enforcement_active,
|
||||
_evaluate_file_contract_outcome,
|
||||
_tool_output_has_error,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_tool_output_error_detection():
    """Error-shaped tool outputs are flagged; success payloads are not."""
    error_payloads = [
        "Error: failed to write file",
        {"error": "boom"},
        {"result": "Error: disk is full"},
    ]
    for payload in error_payloads:
        assert _tool_output_has_error(payload)
    assert not _tool_output_has_error({"result": "Updated file /notes.md"})
|
||||
|
||||
|
||||
def test_file_write_contract_outcome_reasons():
    """Walk the write-contract gates in order, checking each failure reason."""
    result = StreamResult(intent_detected="file_write")
    # Each step: the reason expected NOW, then the flag that clears it.
    gate_walk = [
        ("no_write_attempt", "write_attempted"),
        ("write_failed", "write_succeeded"),
        ("verification_failed", "verification_succeeded"),
    ]
    for expected_reason, flag_to_set in gate_walk:
        passed, reason = _evaluate_file_contract_outcome(result)
        assert not passed
        assert reason == expected_reason
        setattr(result, flag_to_set, True)

    passed, reason = _evaluate_file_contract_outcome(result)
    assert passed
    assert reason == ""
|
||||
|
||||
|
||||
def test_contract_enforcement_local_only():
    """Enforcement is active only for the desktop local-folder mode."""
    expectations = {
        "desktop_local_folder": True,
        "cloud": False,
    }
    for mode, should_enforce in expectations.items():
        result = StreamResult(filesystem_mode=mode)
        assert _contract_enforcement_active(result) is should_enforce
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue