mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-16 21:05:20 +02:00
feat: updated file management for main agent
This commit is contained in:
parent
8d50f90060
commit
05ca4c0b9f
27 changed files with 5054 additions and 1803 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -7,4 +7,5 @@ node_modules/
|
|||
.pnpm-store
|
||||
.DS_Store
|
||||
deepagents/
|
||||
debug.log
|
||||
debug.log
|
||||
opencode/
|
||||
|
|
@ -28,13 +28,76 @@ from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
|||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from app.agents.new_chat.document_xml import build_document_xml
|
||||
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
|
||||
from app.agents.new_chat.middleware.knowledge_search import (
|
||||
build_scoped_filesystem,
|
||||
search_knowledge_base,
|
||||
)
|
||||
from app.agents.new_chat.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
build_path_index,
|
||||
doc_to_virtual_path,
|
||||
)
|
||||
from app.db import shielded_async_session
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
try:
    from deepagents.backends.utils import create_file_data
except Exception:  # pragma: no cover - defensive
    # Fallback definition used only when the deepagents helper cannot be
    # imported; it produces a minimal file-data dict from raw text.

    def create_file_data(content: str) -> dict[str, Any]:
        """Wrap ``content`` in a dict whose ``"content"`` key holds its lines.

        The text is split on ``"\\n"`` only, so an empty string yields a
        single empty line rather than an empty list.
        """
        line_list = content.split("\n")
        return {"content": line_list}
|
||||
|
||||
|
||||
async def _build_autocomplete_filesystem(
|
||||
*,
|
||||
documents: Any,
|
||||
search_space_id: int,
|
||||
) -> tuple[dict[str, Any], dict[int, str]]:
|
||||
"""Build a ``state['files']``-shaped dict from KB search results.
|
||||
|
||||
This is the autocomplete-specific replacement for the previous
|
||||
``build_scoped_filesystem`` helper. It uses the canonical path resolver
|
||||
so paths line up with the rest of the system, including collision
|
||||
suffixes for duplicate titles.
|
||||
"""
|
||||
files: dict[str, Any] = {}
|
||||
doc_id_to_path: dict[int, str] = {}
|
||||
|
||||
if not documents:
|
||||
return files, doc_id_to_path
|
||||
|
||||
async with shielded_async_session() as session:
|
||||
index = await build_path_index(session, search_space_id)
|
||||
|
||||
for document in documents:
|
||||
if not isinstance(document, dict):
|
||||
continue
|
||||
meta = document.get("document") or {}
|
||||
doc_id = meta.get("id")
|
||||
if not isinstance(doc_id, int):
|
||||
continue
|
||||
title = str(meta.get("title") or "untitled")
|
||||
folder_id = meta.get("folder_id")
|
||||
path = doc_to_virtual_path(
|
||||
doc_id=doc_id, title=title, folder_id=folder_id, index=index
|
||||
)
|
||||
chunk_ids = document.get("matched_chunk_ids") or []
|
||||
try:
|
||||
matched_set = {int(c) for c in chunk_ids}
|
||||
except (TypeError, ValueError):
|
||||
matched_set = set()
|
||||
xml = build_document_xml(document, matched_chunk_ids=matched_set)
|
||||
files[path] = create_file_data(xml)
|
||||
doc_id_to_path[doc_id] = path
|
||||
|
||||
if not files:
|
||||
# Ensure the synthetic /documents folder is visible even when empty.
|
||||
files.setdefault(f"{DOCUMENTS_ROOT}/.placeholder", create_file_data(""))
|
||||
|
||||
return files, doc_id_to_path
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KB_TOP_K = 10
|
||||
|
|
@ -174,7 +237,7 @@ async def precompute_kb_filesystem(
|
|||
if not search_results:
|
||||
return _KBResult()
|
||||
|
||||
new_files, _ = await build_scoped_filesystem(
|
||||
new_files, _ = await _build_autocomplete_filesystem(
|
||||
documents=search_results,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
|
|
@ -215,13 +278,12 @@ async def precompute_kb_filesystem(
|
|||
class AutocompleteFilesystemMiddleware(SurfSenseFilesystemMiddleware):
|
||||
"""Filesystem middleware for autocomplete — read-only exploration only.
|
||||
|
||||
Strips ``save_document`` (permanent KB persistence) and passes
|
||||
``search_space_id=None`` so ``write_file`` / ``edit_file`` stay ephemeral.
|
||||
Passes ``search_space_id=None`` so the new persistence pipeline is
|
||||
bypassed; the autocomplete flow only reads, never commits to Postgres.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(search_space_id=None, created_by_id=None)
|
||||
self.tools = [t for t in self.tools if t.name != "save_document"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -34,12 +34,15 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemSelection
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.agents.new_chat.middleware import (
|
||||
AnonymousDocumentMiddleware,
|
||||
DedupHITLToolCallsMiddleware,
|
||||
FileIntentMiddleware,
|
||||
KnowledgeBaseSearchMiddleware,
|
||||
KnowledgeBasePersistenceMiddleware,
|
||||
KnowledgePriorityMiddleware,
|
||||
KnowledgeTreeMiddleware,
|
||||
MemoryInjectionMiddleware,
|
||||
SurfSenseFilesystemMiddleware,
|
||||
)
|
||||
|
|
@ -246,7 +249,12 @@ async def create_surfsense_deep_agent(
|
|||
"""
|
||||
_t_agent_total = time.perf_counter()
|
||||
filesystem_selection = filesystem_selection or FilesystemSelection()
|
||||
backend_resolver = build_backend_resolver(filesystem_selection)
|
||||
backend_resolver = build_backend_resolver(
|
||||
filesystem_selection,
|
||||
search_space_id=search_space_id
|
||||
if filesystem_selection.mode == FilesystemMode.CLOUD
|
||||
else None,
|
||||
)
|
||||
|
||||
# Discover available connectors and document types for this search space
|
||||
available_connectors: list[str] | None = None
|
||||
|
|
@ -299,7 +307,9 @@ async def create_surfsense_deep_agent(
|
|||
modified_disabled_tools = list(disabled_tools) if disabled_tools else []
|
||||
modified_disabled_tools.extend(get_connector_gated_tools(available_connectors))
|
||||
|
||||
# Remove direct KB search tool; we now pre-seed a scoped filesystem via middleware.
|
||||
# Remove direct KB search tool; KnowledgePriorityMiddleware now runs hybrid
|
||||
# search per turn and surfaces hits as a <priority_documents> hint plus
|
||||
# `<chunk_index matched="true">` markers inside lazy-loaded XML.
|
||||
if "search_knowledge_base" not in modified_disabled_tools:
|
||||
modified_disabled_tools.append("search_knowledge_base")
|
||||
|
||||
|
|
@ -365,6 +375,11 @@ async def create_surfsense_deep_agent(
|
|||
)
|
||||
|
||||
# General-purpose subagent middleware
|
||||
# Subagent omits AnonymousDocumentMiddleware, KnowledgeTreeMiddleware,
|
||||
# KnowledgePriorityMiddleware, and KnowledgeBasePersistenceMiddleware - it
|
||||
# inherits state and tools from the parent, but should not (a) re-load
|
||||
# anon docs / re-render the tree / re-run hybrid search, or (b) commit at
|
||||
# its own completion (only the top-level agent's aafter_agent commits).
|
||||
gp_middleware = [
|
||||
TodoListMiddleware(),
|
||||
_memory_middleware,
|
||||
|
|
@ -389,19 +404,35 @@ async def create_surfsense_deep_agent(
|
|||
}
|
||||
|
||||
# Main agent middleware
|
||||
# Order: AnonDoc -> Tree -> Priority -> FileIntent -> Filesystem -> Persistence -> ...
|
||||
# before_agent hooks run in declared order; later injections sit closer to
|
||||
# the latest human turn. Tree (large + cacheable) is injected earliest so
|
||||
# provider-side prefix caching has more material to hit; FileIntent (most
|
||||
# actionable per-turn contract) is injected closest to the user message.
|
||||
deepagent_middleware = [
|
||||
TodoListMiddleware(),
|
||||
_memory_middleware,
|
||||
FileIntentMiddleware(llm=llm),
|
||||
KnowledgeBaseSearchMiddleware(
|
||||
AnonymousDocumentMiddleware(
|
||||
anon_session_id=anon_session_id,
|
||||
)
|
||||
if filesystem_selection.mode == FilesystemMode.CLOUD
|
||||
else None,
|
||||
KnowledgeTreeMiddleware(
|
||||
search_space_id=search_space_id,
|
||||
filesystem_mode=filesystem_selection.mode,
|
||||
llm=llm,
|
||||
)
|
||||
if filesystem_selection.mode == FilesystemMode.CLOUD
|
||||
else None,
|
||||
KnowledgePriorityMiddleware(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
filesystem_mode=filesystem_selection.mode,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
anon_session_id=anon_session_id,
|
||||
),
|
||||
FileIntentMiddleware(llm=llm),
|
||||
SurfSenseFilesystemMiddleware(
|
||||
backend=backend_resolver,
|
||||
filesystem_mode=filesystem_selection.mode,
|
||||
|
|
@ -409,12 +440,20 @@ async def create_surfsense_deep_agent(
|
|||
created_by_id=user_id,
|
||||
thread_id=thread_id,
|
||||
),
|
||||
KnowledgeBasePersistenceMiddleware(
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
filesystem_mode=filesystem_selection.mode,
|
||||
)
|
||||
if filesystem_selection.mode == FilesystemMode.CLOUD
|
||||
else None,
|
||||
SubAgentMiddleware(backend=StateBackend, subagents=[general_purpose_spec]),
|
||||
create_safe_summarization_middleware(llm, StateBackend),
|
||||
PatchToolCallsMiddleware(),
|
||||
DedupHITLToolCallsMiddleware(agent_tools=tools),
|
||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
||||
]
|
||||
deepagent_middleware = [m for m in deepagent_middleware if m is not None]
|
||||
|
||||
# Combine system_prompt with BASE_AGENT_PROMPT (same as create_deep_agent)
|
||||
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
|
||||
|
|
|
|||
103
surfsense_backend/app/agents/new_chat/document_xml.py
Normal file
103
surfsense_backend/app/agents/new_chat/document_xml.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
"""Shared XML builder for KB documents.
|
||||
|
||||
Produces the citation-friendly XML used by every read of a knowledge-base
|
||||
document (lazy-loaded by :class:`KBPostgresBackend` and synthetic anonymous
|
||||
files). The XML carries a ``<chunk_index>`` near the top so the LLM can jump
|
||||
directly to matched-chunk line ranges via ``read_file(offset=…, limit=…)``.
|
||||
|
||||
Extracted from the original ``knowledge_search.py`` so the backend, the
|
||||
priority middleware, and any future renderer share a single implementation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
|
||||
def build_document_xml(
    document: dict[str, Any],
    matched_chunk_ids: set[int] | None = None,
) -> str:
    """Build citation-friendly XML with a ``<chunk_index>`` for smart seeking.

    The output has three parts: a ``<document_metadata>`` header, a
    ``<chunk_index>`` whose entries give the 1-based line range of every chunk
    in the final text, and a ``<document_content>`` section holding the chunks.

    Args:
        document: Dict shape produced by hybrid search / lazy-load helpers.
            Expected keys: ``document`` (with ``id``, ``title``,
            ``document_type``, ``metadata``) and ``chunks``
            (list of ``{chunk_id, content}``). A ``document`` or ``metadata``
            value that is not a dict is treated as empty.
        matched_chunk_ids: Optional set of chunk IDs to flag as
            ``matched="true"`` in the chunk index.

    Returns:
        The XML text, lines joined with ``"\\n"``.
    """
    matched = matched_chunk_ids or set()

    # Normalise both mappings once so every lookup below is safe even when
    # the caller hands over an unexpected shape.
    raw_meta = document.get("document")
    doc_meta = raw_meta if isinstance(raw_meta, dict) else {}
    raw_metadata = doc_meta.get("metadata")
    metadata = raw_metadata if isinstance(raw_metadata, dict) else {}

    document_id = doc_meta.get("id", document.get("document_id", "unknown"))
    document_type = doc_meta.get("document_type", document.get("source", "UNKNOWN"))
    title = doc_meta.get("title") or metadata.get("title") or "Untitled Document"
    url = (
        metadata.get("url") or metadata.get("source") or metadata.get("page_url") or ""
    )
    metadata_json = json.dumps(metadata, ensure_ascii=False)

    # NOTE: values are embedded in CDATA as-is; a literal "]]>" inside a
    # title, URL or chunk is not escaped.
    metadata_lines: list[str] = [
        "<document>",
        "<document_metadata>",
        f" <document_id>{document_id}</document_id>",
        f" <document_type>{document_type}</document_type>",
        f" <title><![CDATA[{title}]]></title>",
        f" <url><![CDATA[{url}]]></url>",
        f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
        "</document_metadata>",
        "",
    ]

    chunks = document.get("chunks") or []
    chunk_entries: list[tuple[int | None, str]] = []
    if isinstance(chunks, list):
        for chunk in chunks:
            if not isinstance(chunk, dict):
                continue
            chunk_id = chunk.get("chunk_id") or chunk.get("id")
            chunk_content = str(chunk.get("content", "")).strip()
            if not chunk_content:
                continue
            if chunk_id is None:
                xml = f" <chunk><![CDATA[{chunk_content}]]></chunk>"
            else:
                xml = f" <chunk id='{chunk_id}'><![CDATA[{chunk_content}]]></chunk>"
            chunk_entries.append((chunk_id, xml))

    # Header entries may span several physical lines (for example a title
    # containing a newline), so count real lines rather than list items;
    # otherwise every range in the index would be shifted.
    header_line_count = sum(entry.count("\n") + 1 for entry in metadata_lines)
    # <chunk_index>, one line per entry, </chunk_index>, blank, <document_content>
    index_overhead = 1 + len(chunk_entries) + 1 + 1 + 1
    first_chunk_line = header_line_count + index_overhead + 1

    current_line = first_chunk_line
    index_entry_lines: list[str] = []
    for cid, xml_str in chunk_entries:
        num_lines = xml_str.count("\n") + 1
        end_line = current_line + num_lines - 1
        matched_attr = ' matched="true"' if cid is not None and cid in matched else ""
        if cid is not None:
            index_entry_lines.append(
                f' <entry chunk_id="{cid}" lines="{current_line}-{end_line}"{matched_attr}/>'
            )
        else:
            index_entry_lines.append(
                f' <entry lines="{current_line}-{end_line}"{matched_attr}/>'
            )
        current_line = end_line + 1

    lines = metadata_lines.copy()
    lines.append("<chunk_index>")
    lines.extend(index_entry_lines)
    lines.append("</chunk_index>")
    lines.append("")
    lines.append("<document_content>")
    lines.extend(xml_str for _, xml_str in chunk_entries)
    lines.extend(["</document_content>", "</document>"])
    return "\n".join(lines)
|
||||
|
||||
|
||||
__all__ = ["build_document_xml"]
|
||||
|
|
@ -5,10 +5,12 @@ from __future__ import annotations
|
|||
from collections.abc import Callable
|
||||
from functools import lru_cache
|
||||
|
||||
from deepagents.backends.protocol import BackendProtocol
|
||||
from deepagents.backends.state import StateBackend
|
||||
from langgraph.prebuilt.tool_node import ToolRuntime
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||
from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend
|
||||
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
|
||||
MultiRootLocalFolderBackend,
|
||||
)
|
||||
|
|
@ -23,8 +25,20 @@ def _cached_multi_root_backend(
|
|||
|
||||
def build_backend_resolver(
|
||||
selection: FilesystemSelection,
|
||||
) -> Callable[[ToolRuntime], StateBackend | MultiRootLocalFolderBackend]:
|
||||
"""Create deepagents backend resolver for the selected filesystem mode."""
|
||||
*,
|
||||
search_space_id: int | None = None,
|
||||
) -> Callable[[ToolRuntime], BackendProtocol]:
|
||||
"""Create deepagents backend resolver for the selected filesystem mode.
|
||||
|
||||
In cloud mode the resolver returns a fresh :class:`KBPostgresBackend`
|
||||
bound to the current ``runtime`` so the backend can read staging state
|
||||
(``staged_dirs``, ``pending_moves``, ``files`` cache, ``kb_anon_doc``,
|
||||
``kb_matched_chunk_ids``) for each tool call. When no ``search_space_id``
|
||||
is provided, the resolver falls back to :class:`StateBackend` (used by
|
||||
sub-agents and tests that don't need DB-backed reads).
|
||||
|
||||
Desktop-local mode unchanged.
|
||||
"""
|
||||
|
||||
if selection.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER and selection.local_mounts:
|
||||
|
||||
|
|
@ -36,7 +50,14 @@ def build_backend_resolver(
|
|||
|
||||
return _resolve_local
|
||||
|
||||
def _resolve_cloud(runtime: ToolRuntime) -> StateBackend:
|
||||
if search_space_id is not None:
|
||||
|
||||
def _resolve_kb(runtime: ToolRuntime) -> BackendProtocol:
|
||||
return KBPostgresBackend(search_space_id, runtime)
|
||||
|
||||
return _resolve_kb
|
||||
|
||||
def _resolve_state(runtime: ToolRuntime) -> StateBackend:
|
||||
return StateBackend(runtime)
|
||||
|
||||
return _resolve_cloud
|
||||
return _resolve_state
|
||||
|
|
|
|||
113
surfsense_backend/app/agents/new_chat/filesystem_state.py
Normal file
113
surfsense_backend/app/agents/new_chat/filesystem_state.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
"""LangGraph state schema additions used by the SurfSense filesystem agent.
|
||||
|
||||
This schema extends deepagents' upstream :class:`FilesystemState` with the
|
||||
extra fields needed to implement Postgres-backed virtual filesystem semantics:
|
||||
|
||||
* ``cwd`` — current working directory (per-thread checkpointed).
|
||||
* ``staged_dirs`` — pending mkdir requests (cloud only).
|
||||
* ``pending_moves`` — pending move_file requests (cloud only).
|
||||
* ``doc_id_by_path`` — virtual_path -> Document.id, populated by lazy reads.
|
||||
* ``dirty_paths`` — paths whose state file content differs from DB.
|
||||
* ``kb_priority`` — top-K priority hints rendered into a system message.
|
||||
* ``kb_matched_chunk_ids`` — internal hand-off for matched-chunk highlighting.
|
||||
* ``kb_anon_doc`` — Redis-loaded anonymous document (if any).
|
||||
* ``tree_version`` — bumped by persistence; invalidates the tree render cache.
|
||||
|
||||
Tools mutate these fields ONLY via ``Command(update=...)`` returns; the
|
||||
reducers in :mod:`app.agents.new_chat.state_reducers` handle merging.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Any, NotRequired
|
||||
|
||||
from deepagents.middleware.filesystem import FilesystemState
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.new_chat.state_reducers import (
|
||||
_add_unique_reducer,
|
||||
_dict_merge_with_tombstones_reducer,
|
||||
_list_append_reducer,
|
||||
_replace_reducer,
|
||||
)
|
||||
|
||||
|
||||
class PendingMove(TypedDict):
    """A staged move_file operation pending end-of-turn commit."""

    # Path the file is moved from.
    source: str
    # Path the file is moved to.
    dest: str
    # NOTE(review): presumably permits replacing an existing entry at ``dest``;
    # confirm against the commit routine that consumes ``pending_moves``.
    overwrite: bool
|
||||
|
||||
|
||||
class KbPriorityEntry(TypedDict, total=False):
    """One entry of the ``kb_priority`` state list; every key is optional."""

    # Virtual path of the document the hint points at.
    path: str
    # NOTE(review): presumably the search relevance score; scale not visible here.
    score: float
    # ``None`` is allowed, e.g. for content that has no database row.
    document_id: int | None
    title: str
    # NOTE(review): presumably set for documents the user explicitly mentioned;
    # confirm against the middleware that fills ``kb_priority``.
    mentioned: bool
|
||||
|
||||
|
||||
class KbAnonDoc(TypedDict, total=False):
    """In-memory anonymous-session document loaded from Redis."""

    # Virtual path under which the document is exposed.
    path: str
    # Display title of the uploaded document.
    title: str
    # Full text content of the document.
    content: str
    # Chunk dicts for the document; each is an arbitrary string-keyed mapping.
    chunks: list[dict[str, Any]]
|
||||
|
||||
|
||||
class SurfSenseFilesystemState(FilesystemState):
    """Filesystem state used by the SurfSense agent (cloud + desktop).

    Extends deepagents' :class:`FilesystemState` (which provides ``files``)
    with cloud-mode staging fields and search-priority hints. All extra fields
    are reducer-backed so that ``Command(update=...)`` payloads merge cleanly
    across agent steps and across checkpoints.

    Every field is ``NotRequired``, so a state dict may omit any of them. The
    reducer attached to each field is imported from
    ``app.agents.new_chat.state_reducers``; the merge semantics are defined
    there and are not visible in this module.
    """

    cwd: NotRequired[Annotated[str, _replace_reducer]]
    """Current working directory.

    Defaults to ``"/documents"`` in cloud mode and ``"/"`` (or first mount) in
    desktop mode. Initialized once per thread by ``KnowledgeTreeMiddleware``.
    """

    staged_dirs: NotRequired[Annotated[list[str], _add_unique_reducer]]
    """mkdir paths staged for end-of-turn folder creation (cloud only)."""

    pending_moves: NotRequired[Annotated[list[PendingMove], _list_append_reducer]]
    """move_file ops staged for end-of-turn commit (cloud only)."""

    doc_id_by_path: NotRequired[
        Annotated[dict[str, int], _dict_merge_with_tombstones_reducer]
    ]
    """virtual_path -> ``Document.id`` for lazily loaded files.

    Populated on first read of a KB document. Used by edit_file/move_file/
    aafter_agent to map paths back to a real DB row. ``None`` values delete
    the key (tombstones).
    """

    dirty_paths: NotRequired[Annotated[list[str], _add_unique_reducer]]
    """Paths whose ``state["files"]`` content has been modified this turn."""

    kb_priority: NotRequired[Annotated[list[KbPriorityEntry], _replace_reducer]]
    """Top-K priority hints rendered as a system message before the user turn."""

    kb_matched_chunk_ids: NotRequired[Annotated[dict[int, list[int]], _replace_reducer]]
    """Internal: ``Document.id`` -> list of matched chunk IDs from hybrid search."""

    kb_anon_doc: NotRequired[Annotated[KbAnonDoc | None, _replace_reducer]]
    """Anonymous-session document loaded from Redis (read-only, no DB row)."""

    tree_version: NotRequired[Annotated[int, _replace_reducer]]
    """Monotonically increasing counter; bumped when commits change the KB tree."""
|
||||
|
||||
|
||||
__all__ = [
|
||||
"KbAnonDoc",
|
||||
"KbPriorityEntry",
|
||||
"PendingMove",
|
||||
"SurfSenseFilesystemState",
|
||||
]
|
||||
|
|
@ -1,5 +1,8 @@
|
|||
"""Middleware components for the SurfSense new chat agent."""
|
||||
|
||||
from app.agents.new_chat.middleware.anonymous_document import (
|
||||
AnonymousDocumentMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.middleware.dedup_tool_calls import (
|
||||
DedupHITLToolCallsMiddleware,
|
||||
)
|
||||
|
|
@ -9,17 +12,30 @@ from app.agents.new_chat.middleware.file_intent import (
|
|||
from app.agents.new_chat.middleware.filesystem import (
|
||||
SurfSenseFilesystemMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.middleware.kb_persistence import (
|
||||
KnowledgeBasePersistenceMiddleware,
|
||||
commit_staged_filesystem_state,
|
||||
)
|
||||
from app.agents.new_chat.middleware.knowledge_search import (
|
||||
KnowledgeBaseSearchMiddleware,
|
||||
KnowledgePriorityMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.middleware.knowledge_tree import (
|
||||
KnowledgeTreeMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.middleware.memory_injection import (
|
||||
MemoryInjectionMiddleware,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AnonymousDocumentMiddleware",
|
||||
"DedupHITLToolCallsMiddleware",
|
||||
"FileIntentMiddleware",
|
||||
"KnowledgeBasePersistenceMiddleware",
|
||||
"KnowledgeBaseSearchMiddleware",
|
||||
"KnowledgePriorityMiddleware",
|
||||
"KnowledgeTreeMiddleware",
|
||||
"MemoryInjectionMiddleware",
|
||||
"SurfSenseFilesystemMiddleware",
|
||||
"commit_staged_filesystem_state",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,91 @@
|
|||
"""Lightweight middleware that loads the anonymous-session document into state.
|
||||
|
||||
Anonymous chats receive a single uploaded document via Redis (no DB row,
|
||||
read-only). This middleware loads it once on the first turn into
|
||||
``state['kb_anon_doc']`` so:
|
||||
|
||||
* :class:`KnowledgeTreeMiddleware` can render the synthetic ``/documents``
|
||||
view without touching the DB.
|
||||
* :class:`KnowledgePriorityMiddleware` skips hybrid search and emits a
|
||||
degenerate priority list.
|
||||
* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / ``_load_file_data``)
|
||||
recognises the synthetic path.
|
||||
|
||||
The middleware is a no-op when ``anon_session_id`` is not provided or when
|
||||
the document is already cached in state.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
||||
from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT, safe_filename
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnonymousDocumentMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """Load the anonymous user's uploaded document from Redis into state.

    The hook does nothing when no ``anon_session_id`` was supplied or when
    ``state['kb_anon_doc']`` is already populated.
    """

    tools = ()
    state_schema = SurfSenseFilesystemState

    def __init__(self, *, anon_session_id: str | None) -> None:
        """Remember the anonymous session id used to build the Redis key."""
        self.anon_session_id = anon_session_id

    async def abefore_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Return a ``kb_anon_doc`` state update, or ``None`` when not needed.

        Args:
            state: Current agent state; only ``kb_anon_doc`` is read.
            runtime: Unused.

        Returns:
            ``{"kb_anon_doc": ...}`` when a document was loaded, else ``None``.
        """
        del runtime
        if not self.anon_session_id:
            return None
        if state.get("kb_anon_doc"):
            # Already cached in state; skip the Redis round trip.
            return None

        anon_doc = await self._load_anon_document()
        if anon_doc is None:
            return None
        return {"kb_anon_doc": anon_doc}

    async def _load_anon_document(self) -> dict[str, Any] | None:
        """Read ``anon:doc:<session_id>`` from Redis.

        Returns:
            A dict with ``path``, ``title``, ``content`` and ``chunks`` keys,
            or ``None`` when the key is missing, Redis or JSON decoding
            fails, or the decoded payload is not a JSON object.
        """
        try:
            import redis.asyncio as aioredis  # local import to keep cold paths cheap

            from app.config import config

            redis_client = aioredis.from_url(
                config.REDIS_APP_URL, decode_responses=True
            )
            try:
                redis_key = f"anon:doc:{self.anon_session_id}"
                data = await redis_client.get(redis_key)
                if not data:
                    return None
                payload = json.loads(data)
            finally:
                await redis_client.aclose()
        except Exception as exc:
            logger.warning("Failed to load anonymous document from Redis: %s", exc)
            return None

        # json.loads can return any JSON type; only an object has the
        # ``filename`` / ``content`` keys read below. Without this check a
        # list or scalar payload would raise outside the guarded block.
        if not isinstance(payload, dict):
            logger.warning(
                "Failed to load anonymous document from Redis: "
                "unexpected payload type %s",
                type(payload).__name__,
            )
            return None

        title = str(payload.get("filename") or "uploaded_document")
        content = str(payload.get("content") or "")
        path = f"{DOCUMENTS_ROOT}/{safe_filename(title)}"
        return {
            "path": path,
            "title": title,
            "content": content,
            # The whole document is exposed as one chunk with id -1.
            "chunks": [{"chunk_id": -1, "content": content}] if content else [],
        }
|
||||
|
||||
|
||||
__all__ = ["AnonymousDocumentMiddleware"]
|
||||
|
|
@ -21,7 +21,7 @@ from typing import Any
|
|||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
|
|
@ -217,8 +217,19 @@ def _build_recent_conversation(
|
|||
messages: list[BaseMessage], *, max_messages: int = 6
|
||||
) -> str:
|
||||
rows: list[str] = []
|
||||
for msg in messages[-max_messages:]:
|
||||
role = "user" if isinstance(msg, HumanMessage) else "assistant"
|
||||
filtered: list[tuple[str, BaseMessage]] = []
|
||||
for msg in messages:
|
||||
role: str | None = None
|
||||
if isinstance(msg, HumanMessage):
|
||||
role = "user"
|
||||
elif isinstance(msg, AIMessage):
|
||||
if getattr(msg, "tool_calls", None):
|
||||
continue
|
||||
role = "assistant"
|
||||
else:
|
||||
continue
|
||||
filtered.append((role, msg))
|
||||
for role, msg in filtered[-max_messages:]:
|
||||
text = re.sub(r"\s+", " ", _extract_text_from_message(msg)).strip()
|
||||
if text:
|
||||
rows.append(f"{role}: {text[:280]}")
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,622 @@
|
|||
"""End-of-turn persistence for the cloud-mode SurfSense filesystem.
|
||||
|
||||
This middleware runs ``aafter_agent`` once per turn (cloud only). It commits
|
||||
all staged folder creations, file moves, and content writes/edits to
|
||||
Postgres in a single ordered pass:
|
||||
|
||||
1. Materialize ``staged_dirs`` into ``Folder`` rows.
|
||||
2. Apply ``pending_moves`` in order (chained moves resolved via
|
||||
``doc_id_by_path``).
|
||||
3. Normalize ``dirty_paths`` through ``pending_moves`` so write-then-move
|
||||
sequences commit at the final path.
|
||||
4. Commit content writes / edits for ``/documents/*`` paths, skipping
|
||||
``temp_*`` basenames.
|
||||
|
||||
The commit body is exposed as a free function ``commit_staged_filesystem_state``
|
||||
so the optional stream-task fallback (``stream_new_chat.py``) can call the
|
||||
exact same routine when ``aafter_agent`` was skipped (e.g. client disconnect).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from fractional_indexing import generate_key_between
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.callbacks import dispatch_custom_event
|
||||
from langgraph.runtime import Runtime
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
||||
from app.agents.new_chat.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
parse_documents_path,
|
||||
safe_folder_segment,
|
||||
virtual_path_to_doc,
|
||||
)
|
||||
from app.agents.new_chat.state_reducers import _CLEAR
|
||||
from app.db import (
|
||||
Chunk,
|
||||
Document,
|
||||
DocumentType,
|
||||
Folder,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.indexing_pipeline.document_chunker import chunk_text
|
||||
from app.utils.document_converters import (
|
||||
embed_texts,
|
||||
generate_content_hash,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_TEMP_PREFIX = "temp_"
|
||||
|
||||
|
||||
def _basename(path: str) -> str:
|
||||
return path.rsplit("/", 1)[-1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Folder helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _ensure_folder_hierarchy(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
search_space_id: int,
|
||||
created_by_id: str | None,
|
||||
folder_parts: list[str],
|
||||
) -> int | None:
|
||||
"""Ensure a chain of folder names exists under the search space.
|
||||
|
||||
Returns the leaf folder id, or ``None`` if ``folder_parts`` is empty
|
||||
(i.e. a document directly under ``/documents/``).
|
||||
"""
|
||||
if not folder_parts:
|
||||
return None
|
||||
parent_id: int | None = None
|
||||
for raw in folder_parts:
|
||||
name = safe_folder_segment(str(raw))
|
||||
query = select(Folder).where(
|
||||
Folder.search_space_id == search_space_id,
|
||||
Folder.name == name,
|
||||
)
|
||||
if parent_id is None:
|
||||
query = query.where(Folder.parent_id.is_(None))
|
||||
else:
|
||||
query = query.where(Folder.parent_id == parent_id)
|
||||
result = await session.execute(query)
|
||||
folder = result.scalar_one_or_none()
|
||||
if folder is None:
|
||||
sibling_query = (
|
||||
select(Folder.position).order_by(Folder.position.desc()).limit(1)
|
||||
)
|
||||
sibling_query = sibling_query.where(
|
||||
Folder.search_space_id == search_space_id
|
||||
)
|
||||
if parent_id is None:
|
||||
sibling_query = sibling_query.where(Folder.parent_id.is_(None))
|
||||
else:
|
||||
sibling_query = sibling_query.where(Folder.parent_id == parent_id)
|
||||
sibling_result = await session.execute(sibling_query)
|
||||
last_position = sibling_result.scalar_one_or_none()
|
||||
folder = Folder(
|
||||
name=name,
|
||||
position=generate_key_between(last_position, None),
|
||||
parent_id=parent_id,
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=created_by_id,
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(folder)
|
||||
await session.flush()
|
||||
parent_id = folder.id
|
||||
return parent_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Document helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _create_document(
    session: AsyncSession,
    *,
    virtual_path: str,
    content: str,
    search_space_id: int,
    created_by_id: str | None,
) -> Document:
    """Create a NOTE ``Document`` and its ``Chunk`` rows for ``virtual_path``.

    Validation (path shape, path-hash collision, content-hash collision) runs
    before the folder hierarchy is ensured, so a rejected create does not
    leave freshly created folders behind in the session.

    Args:
        session: Open async session; the caller owns the commit.
        virtual_path: Target path under ``/documents``.
        content: Full document text; also stored as ``source_markdown``.
        search_space_id: Search space that owns the document.
        created_by_id: Creating user id, or ``None``.

    Returns:
        The flushed ``Document`` (its ``id`` is populated).

    Raises:
        ValueError: If the path has no title component, or a document with
            the same path hash or content hash already exists in this
            search space.
    """
    folder_parts, title = parse_documents_path(virtual_path)
    if not title:
        raise ValueError(f"invalid /documents path '{virtual_path}'")

    unique_identifier_hash = generate_unique_identifier_hash(
        DocumentType.NOTE,
        virtual_path,
        search_space_id,
    )
    # Guard against the unique_identifier_hash constraint: another row at the
    # same virtual_path (this search space) already owns the hash. Callers are
    # expected to upsert via the wrapper, but this defends against bypasses
    # and gives a clean ValueError instead of a session-poisoning IntegrityError.
    path_collision = await session.execute(
        select(Document.id).where(
            Document.search_space_id == search_space_id,
            Document.unique_identifier_hash == unique_identifier_hash,
        )
    )
    if path_collision.scalar_one_or_none() is not None:
        raise ValueError(
            f"a document already exists at path '{virtual_path}' "
            "(unique_identifier_hash collision)"
        )

    content_hash = generate_content_hash(content, search_space_id)
    content_collision = await session.execute(
        select(Document.id).where(
            Document.search_space_id == search_space_id,
            Document.content_hash == content_hash,
        )
    )
    if content_collision.scalar_one_or_none() is not None:
        raise ValueError(
            f"a document with identical content already exists for path '{virtual_path}'"
        )

    # Only touch the folder tree once the create is known to be acceptable.
    folder_id = await _ensure_folder_hierarchy(
        session,
        search_space_id=search_space_id,
        created_by_id=created_by_id,
        folder_parts=folder_parts,
    )

    doc = Document(
        title=title,
        document_type=DocumentType.NOTE,
        document_metadata={"virtual_path": virtual_path},
        content=content,
        content_hash=content_hash,
        unique_identifier_hash=unique_identifier_hash,
        source_markdown=content,
        search_space_id=search_space_id,
        folder_id=folder_id,
        created_by_id=created_by_id,
        updated_at=datetime.now(UTC),
    )
    session.add(doc)
    # Flush so ``doc.id`` exists before chunk rows reference it.
    await session.flush()

    doc.embedding = embed_texts([content])[0]
    chunks = chunk_text(content)
    if chunks:
        chunk_embeddings = embed_texts(chunks)
        session.add_all(
            [
                Chunk(document_id=doc.id, content=text, embedding=embedding)
                for text, embedding in zip(chunks, chunk_embeddings, strict=True)
            ]
        )
    return doc
|
||||
|
||||
|
||||
async def _update_document(
    session: AsyncSession,
    *,
    doc_id: int,
    content: str,
    virtual_path: str,
    search_space_id: int,
) -> Document | None:
    """Update an existing Document's content, hashes, embedding and chunks.

    Mirrors the content-hash guard used when creating documents: if another
    document in the same search space already has identical content, the
    update is skipped (and logged) instead of letting a hash collision fail
    the surrounding commit.

    Args:
        session: Open async session; the caller owns the commit.
        doc_id: Primary key of the document to update.
        content: New full text; also stored as ``source_markdown``.
        virtual_path: Path recorded in metadata and used for the path hash.
        search_space_id: Search space the document must belong to.

    Returns:
        The updated ``Document``, or ``None`` when the document does not
        exist in this search space or the update was skipped because of a
        content collision.
    """
    result = await session.execute(
        select(Document).where(
            Document.id == doc_id,
            Document.search_space_id == search_space_id,
        )
    )
    document = result.scalar_one_or_none()
    if document is None:
        return None

    content_hash = generate_content_hash(content, search_space_id)
    # Check before mutating anything so a skipped update leaves the row intact.
    content_collision = await session.execute(
        select(Document.id)
        .where(
            Document.search_space_id == search_space_id,
            Document.content_hash == content_hash,
            Document.id != doc_id,
        )
        .limit(1)
    )
    if content_collision.scalar_one_or_none() is not None:
        logger.warning(
            "kb_persistence: skipping %s update: a document with identical "
            "content already exists",
            virtual_path,
        )
        return None

    document.content = content
    document.source_markdown = content
    document.content_hash = content_hash
    document.updated_at = datetime.now(UTC)
    # Reassign a copy so the metadata column is replaced, not mutated in place.
    metadata = dict(document.document_metadata or {})
    metadata["virtual_path"] = virtual_path
    document.document_metadata = metadata
    # NOTE(review): the path hash is always derived with DocumentType.NOTE,
    # whatever the document's own type is — confirm that is intended for
    # non-NOTE documents edited through the agent.
    document.unique_identifier_hash = generate_unique_identifier_hash(
        DocumentType.NOTE,
        virtual_path,
        search_space_id,
    )

    document.embedding = embed_texts([content])[0]

    # Replace the chunk set wholesale: drop old rows, then re-chunk and embed.
    await session.execute(delete(Chunk).where(Chunk.document_id == document.id))
    chunks = chunk_text(content)
    if chunks:
        chunk_embeddings = embed_texts(chunks)
        session.add_all(
            [
                Chunk(document_id=document.id, content=text, embedding=embedding)
                for text, embedding in zip(chunks, chunk_embeddings, strict=True)
            ]
        )
    return document
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Move helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _apply_move(
    session: AsyncSession,
    *,
    search_space_id: int,
    created_by_id: str | None,
    move: dict[str, Any],
    doc_id_by_path: dict[str, int],
    doc_id_path_tombstones: dict[str, int | None],
) -> dict[str, Any] | None:
    """Apply one staged move and keep the in-memory path maps in sync.

    The move is ignored (``None`` is returned) when source or dest is
    missing, both are equal, either lies outside ``DOCUMENTS_ROOT``, the
    source document cannot be found, or dest has no title component.
    On success the document row is retitled and refiled, ``doc_id_by_path``
    and ``doc_id_path_tombstones`` are updated in place so later moves in
    the same batch can chain, and a summary dict is returned.
    """
    source = str(move.get("source") or "")
    dest = str(move.get("dest") or "")
    if not (source and dest) or source == dest:
        return None

    root_prefix = DOCUMENTS_ROOT + "/"
    if not (source.startswith(root_prefix) and dest.startswith(root_prefix)):
        return None

    # Prefer the id tracked for this thread; fall back to resolving the path.
    document: Document | None = None
    tracked_id: int | None = doc_id_by_path.get(source)
    if tracked_id is not None:
        lookup = await session.execute(
            select(Document).where(
                Document.id == tracked_id,
                Document.search_space_id == search_space_id,
            )
        )
        document = lookup.scalar_one_or_none()
    if document is None:
        document = await virtual_path_to_doc(
            session,
            search_space_id=search_space_id,
            virtual_path=source,
        )
    if document is None:
        logger.info(
            "kb_persistence: skipping move %s -> %s (source not found)",
            source,
            dest,
        )
        return None

    folder_parts, new_title = parse_documents_path(dest)
    if not new_title:
        return None
    target_folder_id = await _ensure_folder_hierarchy(
        session,
        search_space_id=search_space_id,
        created_by_id=created_by_id,
        folder_parts=folder_parts,
    )

    document.title = new_title
    document.folder_id = target_folder_id
    # Assign a fresh dict so the metadata column is replaced, not mutated.
    new_metadata = dict(document.document_metadata or {})
    new_metadata["virtual_path"] = dest
    document.document_metadata = new_metadata
    document.unique_identifier_hash = generate_unique_identifier_hash(
        DocumentType.NOTE,
        dest,
        search_space_id,
    )
    document.updated_at = datetime.now(UTC)

    moved_id = document.id
    doc_id_by_path.pop(source, None)
    doc_id_by_path[dest] = moved_id
    # Tombstone the old path (None) and record the new owner of dest.
    doc_id_path_tombstones[source] = None
    doc_id_path_tombstones[dest] = moved_id
    return {"id": moved_id, "source": source, "dest": dest, "title": new_title}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Commit body
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def commit_staged_filesystem_state(
    state: dict[str, Any] | AgentState,
    *,
    search_space_id: int,
    created_by_id: str | None,
    filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
    dispatch_events: bool = True,
) -> dict[str, Any] | None:
    """Commit all staged filesystem changes; return the state delta for reducers.

    Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent`
    and the optional stream-task fallback.

    Work is done in one session, in this order: staged directories, pending
    moves, then dirty files (update when a document already exists at the
    path, create otherwise). Any exception inside the session is logged and
    ``None`` is returned, leaving the staged state untouched.

    Args:
        state: Agent state (a dict, or an object exposing ``values``).
        search_space_id: Search space to persist into.
        created_by_id: User id recorded on created rows, or ``None``.
        filesystem_mode: Nothing is persisted unless this is ``CLOUD``.
        dispatch_events: Emit ``document_created`` / ``document_updated``
            custom events after a successful commit.

    Returns:
        A state delta clearing the staging lists (plus temp-file removals,
        ``doc_id_by_path`` updates and a bumped ``tree_version`` when
        applicable), or ``None`` when there is nothing to do or the commit
        failed.
    """
    if filesystem_mode != FilesystemMode.CLOUD:
        return None

    state_dict: dict[str, Any] = (
        dict(state)
        if isinstance(state, dict)
        else dict(getattr(state, "values", {}) or {})
    )

    files: dict[str, Any] = state_dict.get("files") or {}
    staged_dirs: list[str] = list(state_dict.get("staged_dirs") or [])
    pending_moves: list[dict[str, Any]] = list(state_dict.get("pending_moves") or [])
    dirty_paths: list[str] = list(state_dict.get("dirty_paths") or [])
    doc_id_by_path: dict[str, int] = dict(state_dict.get("doc_id_by_path") or {})
    kb_anon_doc = state_dict.get("kb_anon_doc")

    if kb_anon_doc:
        # With an anonymous document in state nothing is persisted: only
        # clear the staging lists and drop temp files.
        temp_paths = [
            p
            for p in files
            if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX)
        ]
        return {
            "dirty_paths": [_CLEAR],
            "staged_dirs": [_CLEAR],
            "pending_moves": [_CLEAR],
            "files": dict.fromkeys(temp_paths),
        }

    if not (staged_dirs or pending_moves or dirty_paths):
        return None

    committed_creates: list[dict[str, Any]] = []
    committed_updates: list[dict[str, Any]] = []
    discarded: list[str] = []
    applied_moves: list[dict[str, Any]] = []
    doc_id_path_tombstones: dict[str, int | None] = {}
    tree_changed = False
    documents_prefix = DOCUMENTS_ROOT + "/"

    def _event_payload(doc: Document, path: str) -> dict[str, Any]:
        """Build the event payload shared by create and update events."""
        return {
            "id": doc.id,
            "title": doc.title,
            "documentType": DocumentType.NOTE.value,
            "searchSpaceId": search_space_id,
            "folderId": doc.folder_id,
            "createdById": str(created_by_id) if created_by_id else None,
            "virtualPath": path,
        }

    try:
        async with shielded_async_session() as session:
            for folder_path in staged_dirs:
                if not isinstance(folder_path, str):
                    continue
                # Require the "/" boundary so sibling roots that merely share
                # the prefix (e.g. "/documents-old") are not treated as KB dirs.
                if not folder_path.startswith(documents_prefix):
                    continue
                rel = folder_path[len(DOCUMENTS_ROOT) :].strip("/")
                folder_parts_full = [p for p in rel.split("/") if p]
                if not folder_parts_full:
                    continue
                await _ensure_folder_hierarchy(
                    session,
                    search_space_id=search_space_id,
                    created_by_id=created_by_id,
                    folder_parts=folder_parts_full,
                )
                tree_changed = True

            for move in pending_moves:
                applied = await _apply_move(
                    session,
                    search_space_id=search_space_id,
                    created_by_id=created_by_id,
                    move=move,
                    doc_id_by_path=doc_id_by_path,
                    doc_id_path_tombstones=doc_id_path_tombstones,
                )
                if applied:
                    applied_moves.append(applied)
                    tree_changed = True

            # Skip malformed entries (no source or no dest) instead of raising
            # KeyError, which would abort the whole commit.
            move_alias = {
                m["source"]: m["dest"]
                for m in pending_moves
                if m.get("source") and m.get("dest")
            }

            def _final_path(path: str) -> str:
                """Follow staged moves from ``path``; ``seen`` breaks cycles."""
                seen: set[str] = set()
                while path in move_alias and path not in seen:
                    seen.add(path)
                    path = move_alias[path]
                return path

            # De-duplicated, order-preserving list of final KB paths to persist.
            kb_dirty_seen: set[str] = set()
            kb_dirty: list[str] = []
            for raw in dirty_paths:
                if not isinstance(raw, str):
                    continue
                final = _final_path(raw)
                if not final.startswith(documents_prefix):
                    continue
                if final in kb_dirty_seen:
                    continue
                kb_dirty_seen.add(final)
                kb_dirty.append(final)

            for path in kb_dirty:
                if _basename(path).startswith(_TEMP_PREFIX):
                    discarded.append(path)
                    continue
                file_data = files.get(path)
                if not isinstance(file_data, dict):
                    continue
                raw_content = file_data.get("content")
                # Content is normally a list of lines; tolerate a plain string
                # so it is not joined character by character.
                if isinstance(raw_content, str):
                    content = raw_content
                else:
                    content = "\n".join(raw_content or [])
                doc_id = doc_id_by_path.get(path)
                if doc_id is None:
                    # The in-memory ``doc_id_by_path`` is per-thread and starts
                    # empty in every new chat. If the agent writes to a path
                    # that already exists in the DB (e.g. a previous chat's
                    # ``notes.md``), we must NOT try to INSERT — it would hit
                    # ``unique_identifier_hash`` (path-derived). Look up the
                    # existing doc and update it in place instead.
                    existing = await virtual_path_to_doc(
                        session,
                        search_space_id=search_space_id,
                        virtual_path=path,
                    )
                    if existing is not None:
                        doc_id = existing.id
                        doc_id_by_path[path] = existing.id
                if doc_id is not None:
                    updated = await _update_document(
                        session,
                        doc_id=doc_id,
                        content=content,
                        virtual_path=path,
                        search_space_id=search_space_id,
                    )
                    if updated is None:
                        logger.warning(
                            "kb_persistence: update of %s (doc_id=%s) was not applied",
                            path,
                            doc_id,
                        )
                    else:
                        committed_updates.append(_event_payload(updated, path))
                else:
                    try:
                        new_doc = await _create_document(
                            session,
                            virtual_path=path,
                            content=content,
                            search_space_id=search_space_id,
                            created_by_id=created_by_id,
                        )
                    except ValueError as exc:
                        logger.warning(
                            "kb_persistence: skipping %s create: %s", path, exc
                        )
                        continue
                    doc_id_by_path[path] = new_doc.id
                    committed_creates.append(_event_payload(new_doc, path))
                    tree_changed = True

            await session.commit()
    except Exception:  # pragma: no cover - rollback safety net
        logger.exception(
            "kb_persistence: commit failed (search_space=%s)", search_space_id
        )
        return None

    if dispatch_events:
        # Event dispatch is best-effort: a failure is logged, never raised.
        for payload in committed_creates:
            try:
                dispatch_custom_event("document_created", payload)
            except Exception:
                logger.exception(
                    "kb_persistence: failed to dispatch document_created event"
                )
        for payload in committed_updates:
            try:
                dispatch_custom_event("document_updated", payload)
            except Exception:
                logger.exception(
                    "kb_persistence: failed to dispatch document_updated event"
                )

    temp_paths = [
        p for p in files if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX)
    ]

    doc_id_update: dict[str, int | None] = {**doc_id_path_tombstones}
    for payload in committed_creates:
        doc_id_update[str(payload.get("virtualPath") or "")] = int(payload["id"])

    delta: dict[str, Any] = {
        "dirty_paths": [_CLEAR],
        "staged_dirs": [_CLEAR],
        "pending_moves": [_CLEAR],
    }
    if temp_paths:
        delta["files"] = dict.fromkeys(temp_paths)
    if doc_id_update:
        delta["doc_id_by_path"] = doc_id_update
    if tree_changed:
        delta["tree_version"] = int(state_dict.get("tree_version") or 0) + 1

    logger.info(
        "kb_persistence: commit (search_space=%s) creates=%d updates=%d "
        "moves=%d staged_dirs=%d discarded=%d",
        search_space_id,
        len(committed_creates),
        len(committed_updates),
        len(applied_moves),
        len(staged_dirs),
        len(discarded),
    )
    return delta
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class KnowledgeBasePersistenceMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """End-of-turn cloud persistence for the SurfSense filesystem agent.

    The middleware contributes no tools; its only hook is ``aafter_agent``,
    which delegates to :func:`commit_staged_filesystem_state`.
    """

    # No tools are exposed by this middleware.
    tools = ()
    state_schema = SurfSenseFilesystemState

    def __init__(
        self,
        *,
        search_space_id: int,
        created_by_id: str | None,
        filesystem_mode: FilesystemMode,
    ) -> None:
        """Store the persistence target and the filesystem mode."""
        self.search_space_id = search_space_id
        self.created_by_id = created_by_id
        self.filesystem_mode = filesystem_mode

    async def aafter_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Commit staged filesystem changes and return the resulting state delta.

        Returns ``None`` without doing any work unless the mode is ``CLOUD``.
        """
        # ``runtime`` is part of the hook signature but is not used here.
        del runtime
        if self.filesystem_mode != FilesystemMode.CLOUD:
            return None
        return await commit_staged_filesystem_state(
            state,
            search_space_id=self.search_space_id,
            created_by_id=self.created_by_id,
            filesystem_mode=self.filesystem_mode,
        )
|
||||
|
||||
|
||||
__all__ = [
|
||||
"KnowledgeBasePersistenceMiddleware",
|
||||
"commit_staged_filesystem_state",
|
||||
]
|
||||
|
|
@ -0,0 +1,963 @@
|
|||
"""Postgres-backed virtual filesystem for the SurfSense agent (cloud mode).
|
||||
|
||||
The backend is **strictly conforming** to deepagents'
|
||||
:class:`BackendProtocol`. It returns ``WriteResult`` / ``EditResult`` / list
|
||||
shapes exactly as upstream expects (no extra fields). All side-state
|
||||
plumbing — ``dirty_paths``, ``doc_id_by_path``, ``staged_dirs``,
|
||||
``pending_moves``, ``files`` cache — is appended by the overridden tool
|
||||
wrappers in :class:`SurfSenseFilesystemMiddleware` via ``Command.update``.
|
||||
|
||||
The backend never writes to Postgres. End-of-turn persistence is handled by
|
||||
:class:`KnowledgeBasePersistenceMiddleware`. This module is purely a
|
||||
read-side view and a state-merging helper.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import fnmatch
|
||||
import logging
|
||||
import re
|
||||
from datetime import UTC
|
||||
from typing import Any
|
||||
|
||||
from deepagents.backends.protocol import (
|
||||
BackendProtocol,
|
||||
EditResult,
|
||||
FileDownloadResponse,
|
||||
FileInfo,
|
||||
FileUploadResponse,
|
||||
GrepMatch,
|
||||
WriteResult,
|
||||
)
|
||||
from deepagents.backends.utils import (
|
||||
create_file_data,
|
||||
file_data_to_string,
|
||||
format_read_response,
|
||||
perform_string_replacement,
|
||||
update_file_data,
|
||||
)
|
||||
from langchain.tools import ToolRuntime
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.document_xml import build_document_xml
|
||||
from app.agents.new_chat.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
build_path_index,
|
||||
doc_to_virtual_path,
|
||||
virtual_path_to_doc,
|
||||
)
|
||||
from app.db import Chunk, Document, shielded_async_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Basename prefix marking temporary files; listings include such files even
# when they live outside ``DOCUMENTS_ROOT``.
_TEMP_PREFIX = "temp_"
# Grep result caps — presumably applied by the grep implementation further
# down in this module; verify there.
_GREP_MAX_TOTAL_MATCHES = 50
_GREP_MAX_PER_DOC = 5
|
||||
|
||||
|
||||
def _basename(path: str) -> str:
|
||||
return path.rsplit("/", 1)[-1]
|
||||
|
||||
|
||||
def _is_under(child: str, parent: str) -> bool:
|
||||
"""Return True iff ``child`` is at-or-under ``parent`` (directory semantics)."""
|
||||
if parent == "/":
|
||||
return child.startswith("/")
|
||||
return child == parent or child.startswith(parent.rstrip("/") + "/")
|
||||
|
||||
|
||||
def paginate_listing(
    infos: list[FileInfo],
    *,
    offset: int = 0,
    limit: int | None = None,
) -> list[FileInfo]:
    """Return a page of a listing produced by :meth:`KBPostgresBackend.als_info`.

    A negative ``offset`` is treated as 0. A ``limit`` of ``None`` or any
    negative value means "no upper bound". The result is always a new list.
    """
    start = max(offset, 0)
    if limit is None or limit < 0:
        return list(infos[start:])
    return list(infos[start : start + limit])
|
||||
|
||||
|
||||
class KBPostgresBackend(BackendProtocol):
|
||||
"""Lazy, read-only Postgres view for ``/documents/*`` virtual paths.
|
||||
|
||||
The backend exposes a virtual ``/documents/`` namespace mirroring the
|
||||
``Folder``/``Document`` graph. Reads materialize XML on first access and
|
||||
cache it via the overriding tool wrappers (NOT here). Writes never touch
|
||||
the DB — they return ``files_update`` deltas that the wrappers turn into
|
||||
Command updates, and the persistence middleware commits them at end of
|
||||
turn.
|
||||
"""
|
||||
|
||||
_IMAGE_EXTENSIONS = frozenset({".png", ".jpg", ".jpeg", ".gif", ".webp"})
|
||||
|
||||
def __init__(self, search_space_id: int, runtime: ToolRuntime) -> None:
    """Bind the backend to one search space and the current tool runtime.

    ``runtime`` is kept so agent state can be read through ``self.state``.
    """
    self.search_space_id = search_space_id
    self.runtime = runtime
|
||||
|
||||
@property
def state(self) -> dict[str, Any]:
    """Agent state taken from the runtime; an empty dict when absent or falsy."""
    current = getattr(self.runtime, "state", {})
    if current:
        return current
    return {}
|
||||
|
||||
# ------------------------------------------------------------------ helpers
|
||||
|
||||
def _state_files(self) -> dict[str, Any]:
|
||||
return dict(self.state.get("files") or {})
|
||||
|
||||
def _staged_dirs(self) -> list[str]:
|
||||
return list(self.state.get("staged_dirs") or [])
|
||||
|
||||
def _pending_moves(self) -> list[dict[str, Any]]:
|
||||
return list(self.state.get("pending_moves") or [])
|
||||
|
||||
def _kb_anon_doc(self) -> dict[str, Any] | None:
|
||||
anon = self.state.get("kb_anon_doc")
|
||||
return anon if isinstance(anon, dict) else None
|
||||
|
||||
def _matched_chunk_ids(self, doc_id: int) -> set[int]:
|
||||
mapping = self.state.get("kb_matched_chunk_ids") or {}
|
||||
try:
|
||||
return set(mapping.get(doc_id, []) or [])
|
||||
except TypeError:
|
||||
return set()
|
||||
|
||||
@staticmethod
|
||||
def _file_data_size(file_data: dict[str, Any]) -> int:
|
||||
try:
|
||||
return len("\n".join(file_data.get("content") or []))
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def _normalize_listing_path(self, path: str) -> str:
|
||||
if not path:
|
||||
return DOCUMENTS_ROOT
|
||||
if path == "/":
|
||||
return path
|
||||
return path.rstrip("/") if path != "/" else path
|
||||
|
||||
def _moved_view_paths(
|
||||
self,
|
||||
existing: dict[str, dict[str, Any]],
|
||||
) -> tuple[set[str], dict[str, str]]:
|
||||
"""Apply ``pending_moves`` to a path set and return ``(removed, alias)``.
|
||||
|
||||
Removed paths should disappear from listings; ``alias[source] = dest``
|
||||
means a virtual entry should appear at ``dest`` even if no DB row is
|
||||
yet there.
|
||||
"""
|
||||
removed: set[str] = set()
|
||||
alias: dict[str, str] = {}
|
||||
for move in self._pending_moves():
|
||||
src = move.get("source")
|
||||
dst = move.get("dest")
|
||||
if not src or not dst:
|
||||
continue
|
||||
removed.add(src)
|
||||
alias[src] = dst
|
||||
existing.pop(src, None)
|
||||
return removed, alias
|
||||
|
||||
# ------------------------------------------------------------------ ls/read
|
||||
|
||||
async def als_info(self, path: str) -> list[FileInfo]:  # type: ignore[override]
    """List the entries under ``path``, merging DB rows with staged state.

    Sources are merged in this order, first writer wins per path:
    the anonymous document, DB documents, destinations of pending moves,
    staged directories, DB/synthesized subdirectories, and finally files
    cached in state. The result is sorted directories first, then by path.
    """
    normalized = self._normalize_listing_path(path)
    infos: list[FileInfo] = []
    # Paths already emitted; used to de-duplicate across the sources.
    seen: set[str] = set()

    anon = self._kb_anon_doc()
    if anon:
        anon_path = str(anon.get("path") or "")
        if (
            anon_path
            and _is_under(anon_path, normalized)
            and anon_path != normalized
            and anon_path not in seen
        ):
            infos.append(
                FileInfo(
                    path=anon_path,
                    is_dir=False,
                    size=len(str(anon.get("content") or "")),
                    modified_at="",
                )
            )
            seen.add(anon_path)

    # ``files`` is a copy of the state cache; ``_moved_view_paths`` removes
    # the sources of pending moves from it.
    files = self._state_files()
    moved_removed, moved_alias = self._moved_view_paths(files)

    if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/":
        try:
            async with shielded_async_session() as session:
                db_infos, subdir_paths = await self._list_db_directory(
                    session, normalized
                )
        except Exception as exc:  # pragma: no cover - defensive
            # A DB failure degrades to a state-only listing.
            logger.warning("KBPostgresBackend.als_info DB error: %s", exc)
            db_infos, subdir_paths = [], set()

        for info in db_infos:
            p = info.get("path", "")
            # Hide DB rows whose path is the source of a pending move.
            if not p or p in seen or p in moved_removed:
                continue
            infos.append(info)
            seen.add(p)

        for src, dst in moved_alias.items():
            if src not in seen:
                if not _is_under(dst, normalized):
                    continue
                rel = (
                    dst[len(normalized) :].lstrip("/")
                    if normalized != "/"
                    else dst.lstrip("/")
                )
                if "/" in rel:
                    # Destination is deeper than one level: surface only
                    # the immediate child directory.
                    subdir_paths.add(
                        (normalized.rstrip("/") + "/" + rel.split("/", 1)[0])
                        if normalized != "/"
                        else "/" + rel.split("/", 1)[0]
                    )
                    continue
                if dst in seen:
                    continue
                fd = files.get(dst)
                size = self._file_data_size(fd) if isinstance(fd, dict) else 0
                infos.append(
                    FileInfo(
                        path=dst,
                        is_dir=False,
                        size=int(size),
                        modified_at=fd.get("modified_at", "")
                        if isinstance(fd, dict)
                        else "",
                    )
                )
                seen.add(dst)

        for staged in self._staged_dirs():
            if not staged or not staged.startswith(DOCUMENTS_ROOT):
                continue
            if staged == normalized:
                continue
            if not _is_under(staged, normalized):
                continue
            rel = (
                staged[len(normalized) :].lstrip("/")
                if normalized != "/"
                else staged.lstrip("/")
            )
            if not rel:
                continue
            # Only the first segment below ``normalized`` is listed.
            first = rel.split("/", 1)[0]
            immediate = (
                normalized.rstrip("/") + "/" + first
                if normalized != "/"
                else "/" + first
            )
            subdir_paths.add(immediate)

        for sub in sorted(subdir_paths):
            if sub in seen:
                continue
            infos.append(FileInfo(path=sub, is_dir=True, size=0, modified_at=""))
            seen.add(sub)

    # Files cached in state that are not represented yet.
    for path_key, fd in files.items():
        if not isinstance(path_key, str) or path_key in seen:
            continue
        if not _is_under(path_key, normalized) or path_key == normalized:
            continue
        if normalized == "/":
            rel = path_key.lstrip("/")
        else:
            rel = path_key[len(normalized) :].lstrip("/")
        if not rel:
            continue
        if "/" in rel:
            # Nested file: emit its top-level directory once instead.
            first = rel.split("/", 1)[0]
            immediate = (
                normalized.rstrip("/") + "/" + first
                if normalized != "/"
                else "/" + first
            )
            if immediate not in seen:
                infos.append(
                    FileInfo(path=immediate, is_dir=True, size=0, modified_at="")
                )
                seen.add(immediate)
            continue
        # Direct children are listed only when they are KB paths or temp files.
        include = path_key.startswith(DOCUMENTS_ROOT) or _basename(
            path_key
        ).startswith(_TEMP_PREFIX)
        if not include:
            continue
        size = self._file_data_size(fd) if isinstance(fd, dict) else 0
        infos.append(
            FileInfo(
                path=path_key,
                is_dir=False,
                size=int(size),
                modified_at=fd.get("modified_at", "")
                if isinstance(fd, dict)
                else "",
            )
        )
        seen.add(path_key)

    # Directories first (False sorts before True), then alphabetical by path.
    infos.sort(key=lambda fi: (not fi.get("is_dir", False), fi.get("path", "")))
    return infos
|
||||
|
||||
def ls_info(self, path: str) -> list[FileInfo]:  # type: ignore[override]
    """Synchronous wrapper around :meth:`als_info`.

    Uses ``asyncio.run``, which raises ``RuntimeError`` when called from a
    thread that already has a running event loop.
    """
    return asyncio.run(self.als_info(path))
|
||||
|
||||
async def _list_db_directory(
    self,
    session: AsyncSession,
    normalized_path: str,
) -> tuple[list[FileInfo], set[str]]:
    """Return the immediate Documents and sub-folders at ``normalized_path``.

    The result is ``(file_infos, subdirectory_paths)``. For ``/`` only the
    synthetic ``DOCUMENTS_ROOT`` directory is reported. Paths outside
    ``DOCUMENTS_ROOT``, or that match no known folder, yield empty results.
    """
    if normalized_path == "/":
        return [], {DOCUMENTS_ROOT}
    if not normalized_path.startswith(DOCUMENTS_ROOT):
        return [], set()

    index = await build_path_index(session, self.search_space_id)

    # ``None`` means "documents directly under the root" (no folder).
    target_folder_id: int | None = None
    if normalized_path != DOCUMENTS_ROOT:
        matching_ids = [
            folder_id
            for folder_id, folder_path in index.folder_paths.items()
            if folder_path == normalized_path
        ]
        if not matching_ids:
            return [], set()
        target_folder_id = matching_ids[0]

    if target_folder_id is not None:
        folder_filter = Document.folder_id == target_folder_id
    else:
        folder_filter = Document.folder_id.is_(None)
    statement = (
        select(Document.id, Document.title, Document.folder_id, Document.updated_at)
        .where(Document.search_space_id == self.search_space_id)
        .where(folder_filter)
    )
    rows = (await session.execute(statement)).all()

    file_infos: list[FileInfo] = []
    for row in rows:
        doc_path = doc_to_virtual_path(
            doc_id=row.id,
            title=str(row.title or "untitled"),
            folder_id=row.folder_id,
            index=index,
        )
        modified = ""
        if row.updated_at is not None:
            # Formatting problems leave ``modified`` empty instead of failing.
            with contextlib.suppress(Exception):
                modified = row.updated_at.astimezone(UTC).isoformat()
        file_infos.append(
            FileInfo(path=doc_path, is_dir=False, size=0, modified_at=modified)
        )

    # A folder is an immediate child when exactly one segment follows the base.
    base = normalized_path.rstrip("/")
    child_prefix = base + "/"
    subdirs: set[str] = set()
    for folder_path in index.folder_paths.values():
        if folder_path == normalized_path:
            continue
        if not folder_path.startswith(child_prefix):
            continue
        remainder = folder_path[len(child_prefix) :]
        if "/" not in remainder:
            subdirs.add(child_prefix + remainder)
    return file_infos, subdirs
|
||||
|
||||
async def aread(  # type: ignore[override]
    self,
    file_path: str,
    offset: int = 0,
    limit: int = 2000,
) -> str:
    """Read ``file_path``, preferring the state cache over a DB load.

    Returns the formatted read response, or an ``Error: ...`` string when
    the path is neither cached in state nor loadable.
    """
    cached = self._state_files().get(file_path)
    if cached is not None:
        return format_read_response(cached, offset, limit)

    loaded = await self._load_file_data(file_path)
    if loaded is None:
        return f"Error: File '{file_path}' not found"
    return format_read_response(loaded[0], offset, limit)
|
||||
|
||||
def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:  # type: ignore[override]
    """Synchronous wrapper around :meth:`aread`.

    Uses ``asyncio.run``, which raises ``RuntimeError`` when called from a
    thread that already has a running event loop.
    """
    return asyncio.run(self.aread(file_path, offset, limit))
|
||||
|
||||
async def _load_file_data(
    self,
    path: str,
) -> tuple[dict[str, Any], int | None] | None:
    """Materialize a virtual KB document as a deepagents ``FileData``.

    Returns ``(file_data, doc_id)``, or ``None`` when ``path`` maps to no
    known document. For the synthetic anonymous document ``doc_id`` is
    ``None`` so callers do not track it as a DB-backed file.
    """
    anon = self._kb_anon_doc()
    if anon and str(anon.get("path") or "") == path:
        anon_payload = {
            "document_id": -1,
            "chunks": list(anon.get("chunks") or []),
            "matched_chunk_ids": [],
            "document": {
                "id": -1,
                "title": anon.get("title") or "uploaded_document",
                "document_type": "FILE",
                "metadata": {"source": "anonymous_upload"},
            },
            "source": "FILE",
        }
        anon_xml = build_document_xml(anon_payload, matched_chunk_ids=set())
        return create_file_data(anon_xml), None

    if not path.startswith(DOCUMENTS_ROOT):
        return None

    async with shielded_async_session() as session:
        document = await virtual_path_to_doc(
            session,
            search_space_id=self.search_space_id,
            virtual_path=path,
        )
        if document is None:
            return None
        chunk_result = await session.execute(
            select(Chunk.id, Chunk.content)
            .where(Chunk.document_id == document.id)
            .order_by(Chunk.id)
        )
        chunks = [
            {"chunk_id": row.id, "content": row.content}
            for row in chunk_result.all()
        ]

        # Fall back to "UNKNOWN" when the row carries no document type.
        if getattr(document, "document_type", None) is not None:
            type_label = document.document_type.value
        else:
            type_label = "UNKNOWN"

        doc_payload = {
            "document_id": document.id,
            "chunks": chunks,
            "matched_chunk_ids": list(self._matched_chunk_ids(document.id)),
            "document": {
                "id": document.id,
                "title": document.title,
                "document_type": type_label,
                "metadata": dict(document.document_metadata or {}),
            },
            "source": type_label,
        }
        xml = build_document_xml(
            doc_payload,
            matched_chunk_ids=self._matched_chunk_ids(document.id),
        )
        return create_file_data(xml), document.id
|
||||
|
||||
# ------------------------------------------------------------------ writes
|
||||
|
||||
async def awrite(self, file_path: str, content: str) -> WriteResult:  # type: ignore[override]
    """Stage a brand-new file in the agent state.

    Existing paths are never overwritten: when ``file_path`` is already a key
    of the state files mapping, an error ``WriteResult`` is returned and
    nothing is staged. Otherwise the content is wrapped with
    ``create_file_data`` and returned as a ``files_update`` for that path.
    """
    if file_path not in self._state_files():
        staged = {file_path: create_file_data(content)}
        return WriteResult(path=file_path, files_update=staged)

    message = (
        f"Cannot write to {file_path} because it already exists. "
        "Read and then make an edit, or write to a new path."
    )
    return WriteResult(error=message)
|
||||
|
||||
def write(self, file_path: str, content: str) -> WriteResult:  # type: ignore[override]
    """Blocking counterpart of :meth:`awrite`.

    Runs the coroutine to completion with ``asyncio.run``, so it must be
    called from a thread that has no event loop already running.
    """
    pending = self.awrite(file_path, content)
    return asyncio.run(pending)
|
||||
|
||||
async def aedit(  # type: ignore[override]
    self,
    file_path: str,
    old_string: str,
    new_string: str,
    replace_all: bool = False,
) -> EditResult:
    """Apply a string replacement to a file and return the staged update.

    The file is taken from the state files mapping when present; otherwise
    it is fetched through ``_load_file_data``. A missing file or a failed
    replacement (reported by ``perform_string_replacement`` as a string)
    is returned as an error ``EditResult``.
    """
    current = self._state_files().get(file_path)
    if current is None:
        fetched = await self._load_file_data(file_path)
        if fetched is None:
            return EditResult(error=f"Error: File '{file_path}' not found")
        # The loader returns (file_data, document_id); only the data is needed.
        current, _doc_id = fetched

    outcome = perform_string_replacement(
        file_data_to_string(current), old_string, new_string, replace_all
    )
    if isinstance(outcome, str):
        # A plain string from the helper is an error message.
        return EditResult(error=outcome)

    updated_text, hit_count = outcome
    refreshed = update_file_data(current, updated_text)
    return EditResult(
        path=file_path,
        files_update={file_path: refreshed},
        occurrences=int(hit_count),
    )
|
||||
|
||||
def edit(  # type: ignore[override]
    self,
    file_path: str,
    old_string: str,
    new_string: str,
    replace_all: bool = False,
) -> EditResult:
    """Blocking counterpart of :meth:`aedit`.

    Runs the coroutine to completion with ``asyncio.run``, so it must be
    called from a thread that has no event loop already running.
    """
    pending = self.aedit(file_path, old_string, new_string, replace_all)
    return asyncio.run(pending)
|
||||
|
||||
# ------------------------------------------------------------------ glob/grep
|
||||
|
||||
async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:  # type: ignore[override]
    """Return files under ``path`` whose path matches the glob ``pattern``.

    State files are considered first, then (for ``/`` and anything under
    ``DOCUMENTS_ROOT``) every document of the search space, mapped to its
    virtual path. A path matches when the translated glob matches either
    its path relative to the listing root or its full path. Database
    errors are logged and the state-only results are still returned.
    Results are sorted by path.
    """
    root = self._normalize_listing_path(path)
    state_files = self._state_files()
    hidden, _ = self._moved_view_paths(state_files)
    matcher = re.compile(fnmatch.translate(pattern))

    def _selected(candidate_path: str) -> bool:
        """True when the path lies under the root and satisfies the glob."""
        if not _is_under(candidate_path, root):
            return False
        if root != "/":
            relative = candidate_path[len(root) :].lstrip("/")
        else:
            relative = candidate_path.lstrip("/")
        return bool(matcher.match(relative) or matcher.match(candidate_path))

    found: list[FileInfo] = []
    emitted: set[str] = set()

    for state_path, payload in state_files.items():
        if state_path in hidden:
            continue
        if not _selected(state_path) or state_path in emitted:
            continue
        is_mapping = isinstance(payload, dict)
        found.append(
            FileInfo(
                path=state_path,
                is_dir=False,
                size=int(self._file_data_size(payload) if is_mapping else 0),
                modified_at=payload.get("modified_at", "") if is_mapping else "",
            )
        )
        emitted.add(state_path)

    if root == "/" or root.startswith(DOCUMENTS_ROOT):
        try:
            async with shielded_async_session() as session:
                path_index = await build_path_index(session, self.search_space_id)
                doc_query = select(
                    Document.id, Document.title, Document.folder_id
                ).where(Document.search_space_id == self.search_space_id)
                fetched = await session.execute(doc_query)
                for doc in fetched.all():
                    virtual = doc_to_virtual_path(
                        doc_id=doc.id,
                        title=str(doc.title or "untitled"),
                        folder_id=doc.folder_id,
                        index=path_index,
                    )
                    # State entries win over DB rows; moved-away paths stay hidden.
                    if virtual in emitted or virtual in hidden:
                        continue
                    if not _selected(virtual):
                        continue
                    found.append(
                        FileInfo(path=virtual, is_dir=False, size=0, modified_at="")
                    )
                    emitted.add(virtual)
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning("KBPostgresBackend.aglob_info DB error: %s", exc)

    found.sort(key=lambda info: info.get("path", ""))
    return found
|
||||
|
||||
def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:  # type: ignore[override]
    """Blocking counterpart of :meth:`aglob_info`.

    Runs the coroutine to completion with ``asyncio.run``, so it must be
    called from a thread that has no event loop already running.
    """
    pending = self.aglob_info(pattern, path)
    return asyncio.run(pending)
|
||||
|
||||
async def agrep_raw(  # type: ignore[override]
    self,
    pattern: str,
    path: str | None = None,
    glob: str | None = None,
) -> list[GrepMatch] | str:
    """Search state files and knowledge-base chunks for ``pattern``.

    Args:
        pattern: Literal text to look for. An empty pattern is rejected.
        path: Listing root; defaults to ``/``.
        glob: Optional glob applied to each candidate's basename.

    Returns:
        A list of ``GrepMatch`` entries, capped at ``_GREP_MAX_TOTAL_MATCHES``
        overall and ``_GREP_MAX_PER_DOC`` chunks per database document, or
        an error string when ``pattern`` is empty.

    State files are scanned line by line with a case-sensitive substring
    test. Database chunks are matched with a case-insensitive ``ILIKE``
    and reported with ``line=0`` plus a whitespace-collapsed snippet of at
    most 240 characters. Database errors are logged and the matches
    gathered so far are returned.
    """
    if not pattern:
        return "Error: pattern cannot be empty"

    normalized = self._normalize_listing_path(path or "/")
    matches: list[GrepMatch] = []

    files = self._state_files()
    moved_removed, _ = self._moved_view_paths(files)
    glob_re = re.compile(fnmatch.translate(glob)) if glob else None

    # Pass 1: files already materialised in agent state.
    for path_key, fd in files.items():
        if path_key in moved_removed:
            continue
        if not _is_under(path_key, normalized):
            continue
        if glob_re is not None and not glob_re.match(_basename(path_key)):
            continue
        if not isinstance(fd, dict):
            continue
        for line_no, line in enumerate(fd.get("content") or [], 1):
            if pattern in line:
                matches.append(
                    GrepMatch(path=path_key, line=int(line_no), text=str(line))
                )
                if len(matches) >= _GREP_MAX_TOTAL_MATCHES:
                    return matches

    # Pass 2: chunks stored in the database.
    if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/":
        # The pattern is literal text, so LIKE metacharacters must be escaped;
        # otherwise "_" matches any character and "%" any run of characters.
        like_literal = (
            pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )
        try:
            async with shielded_async_session() as session:
                index = await build_path_index(session, self.search_space_id)
                sub = (
                    select(Chunk.document_id, Chunk.id, Chunk.content)
                    .join(Document, Document.id == Chunk.document_id)
                    .where(Document.search_space_id == self.search_space_id)
                    .where(Chunk.content.ilike(f"%{like_literal}%", escape="\\"))
                    .order_by(Chunk.document_id, Chunk.id)
                )
                chunk_rows = await session.execute(sub)

                # Remaining budget after the state-file matches.
                budget = _GREP_MAX_TOTAL_MATCHES - len(matches)
                buffered = 0
                per_doc: dict[int, int] = {}
                doc_id_to_path: dict[int, str] = {}
                needed_doc_ids: set[int] = set()
                chunk_buffer: list[tuple[int, int, str]] = []
                for row in chunk_rows.all():
                    per_doc.setdefault(row.document_id, 0)
                    if per_doc[row.document_id] >= _GREP_MAX_PER_DOC:
                        continue
                    per_doc[row.document_id] += 1
                    buffered += 1
                    chunk_buffer.append((row.document_id, row.id, row.content))
                    needed_doc_ids.add(row.document_id)
                    if buffered >= budget:
                        break

                if needed_doc_ids:
                    doc_rows = await session.execute(
                        select(
                            Document.id, Document.title, Document.folder_id
                        ).where(Document.id.in_(list(needed_doc_ids)))
                    )
                    for row in doc_rows.all():
                        doc_id_to_path[row.id] = doc_to_virtual_path(
                            doc_id=row.id,
                            title=str(row.title or "untitled"),
                            folder_id=row.folder_id,
                            index=index,
                        )

                for doc_id, chunk_id, content in chunk_buffer:
                    candidate = doc_id_to_path.get(doc_id)
                    if not candidate or candidate in moved_removed:
                        continue
                    if not _is_under(candidate, normalized):
                        continue
                    if glob_re is not None and not glob_re.match(
                        _basename(candidate)
                    ):
                        continue
                    snippet = " ".join(str(content).split())[:240]
                    matches.append(
                        GrepMatch(
                            path=candidate,
                            line=0,
                            text=(
                                f"<chunk-match in {candidate} chunk_id={chunk_id}>: "
                                f"{snippet}"
                            ),
                        )
                    )
                    if len(matches) >= _GREP_MAX_TOTAL_MATCHES:
                        break
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning("KBPostgresBackend.agrep_raw DB error: %s", exc)

    return matches
|
||||
|
||||
def grep_raw(  # type: ignore[override]
    self,
    pattern: str,
    path: str | None = None,
    glob: str | None = None,
) -> list[GrepMatch] | str:
    """Blocking counterpart of :meth:`agrep_raw`.

    Runs the coroutine to completion with ``asyncio.run``, so it must be
    called from a thread that has no event loop already running.
    """
    pending = self.agrep_raw(pattern, path, glob)
    return asyncio.run(pending)
|
||||
|
||||
# ------------------------------------------------------------------ list_tree (helper)
|
||||
|
||||
async def alist_tree_listing(
    self,
    path: str = DOCUMENTS_ROOT,
    *,
    max_depth: int | None = 8,
    page_size: int = 500,
    include_files: bool = True,
    include_dirs: bool = True,
) -> dict[str, Any]:
    """Recursive tree listing for cloud mode.

    Mirrors the shape returned by :class:`MultiRootLocalFolderBackend.list_tree`:
    ``{"entries": [{path, is_dir, size, modified_at, depth}, ...], "truncated": bool}``.

    Args:
        path: Listing root; must be ``/`` or lie under ``DOCUMENTS_ROOT``.
        max_depth: Entries deeper than this are dropped; ``None`` disables
            the limit.
        page_size: Maximum number of entries; hitting it ends the listing
            with ``truncated=True``.
        include_files: Emit documents, the anonymous document and state files.
        include_dirs: Emit database folders and staged directories.

    Returns:
        The listing dict, ``{"error": ...}`` for a path outside the allowed
        roots, or an empty listing when the database query fails.
    """
    normalized = self._normalize_listing_path(path or DOCUMENTS_ROOT)
    if not normalized.startswith(DOCUMENTS_ROOT) and normalized != "/":
        return {"error": "Error: path must be under /documents/"}

    entries: list[dict[str, Any]] = []
    # Paths already emitted, for O(1) duplicate checks.
    seen_paths: set[str] = set()
    truncated = False

    try:
        async with shielded_async_session() as session:
            index = await build_path_index(session, self.search_space_id)
            doc_rows_raw = await session.execute(
                select(
                    Document.id,
                    Document.title,
                    Document.folder_id,
                    Document.updated_at,
                ).where(Document.search_space_id == self.search_space_id)
            )
            doc_rows = list(doc_rows_raw.all())
    except Exception as exc:  # pragma: no cover
        logger.warning("KBPostgresBackend.alist_tree_listing DB error: %s", exc)
        return {"entries": [], "truncated": False}

    files = self._state_files()
    moved_removed, _ = self._moved_view_paths(files)
    anon = self._kb_anon_doc()
    anon_path = str(anon.get("path") or "") if anon else ""

    def _depth_of(p: str) -> int:
        """Count path segments below the documents root (or below ``/``)."""
        if p == DOCUMENTS_ROOT:
            return 0
        rel_root = (
            p[len(DOCUMENTS_ROOT) :].lstrip("/")
            if normalized.startswith(DOCUMENTS_ROOT)
            else p.lstrip("/")
        )
        return len([part for part in rel_root.split("/") if part])

    def _add_entry(entry: dict[str, Any]) -> bool:
        """Append ``entry`` unless the page is full; report whether it fit."""
        nonlocal truncated
        if len(entries) >= page_size:
            truncated = True
            return False
        entries.append(entry)
        seen_paths.add(entry["path"])
        return True

    if include_dirs:
        for _fid, fpath in sorted(index.folder_paths.items(), key=lambda kv: kv[1]):
            if not _is_under(fpath, normalized):
                continue
            depth = _depth_of(fpath)
            if max_depth is not None and depth > max_depth:
                continue
            if not _add_entry(
                {
                    "path": fpath,
                    "is_dir": True,
                    "size": 0,
                    "modified_at": "",
                    "depth": depth,
                }
            ):
                return {"entries": entries, "truncated": True}
        for staged in self._staged_dirs():
            if not _is_under(staged, normalized):
                continue
            depth = _depth_of(staged)
            if max_depth is not None and depth > max_depth:
                continue
            if staged in seen_paths:
                continue
            if not _add_entry(
                {
                    "path": staged,
                    "is_dir": True,
                    "size": 0,
                    "modified_at": "",
                    "depth": depth,
                }
            ):
                return {"entries": entries, "truncated": True}

    # NOTE(review): the anonymous-document and state-file sections below are
    # assumed to be gated by ``include_files`` - confirm against the original
    # indentation.
    if include_files:
        for row in sorted(doc_rows, key=lambda r: str(r.title or "")):
            candidate = doc_to_virtual_path(
                doc_id=row.id,
                title=str(row.title or "untitled"),
                folder_id=row.folder_id,
                index=index,
            )
            if candidate in moved_removed:
                continue
            if not _is_under(candidate, normalized):
                continue
            depth = _depth_of(candidate)
            if max_depth is not None and depth > max_depth:
                continue
            modified = ""
            if row.updated_at is not None:
                # Best effort: an unconvertible timestamp leaves it empty.
                with contextlib.suppress(Exception):
                    modified = row.updated_at.astimezone(UTC).isoformat()
            if not _add_entry(
                {
                    "path": candidate,
                    "is_dir": False,
                    "size": 0,
                    "modified_at": modified,
                    "depth": depth,
                }
            ):
                return {"entries": entries, "truncated": True}

        if anon_path and _is_under(anon_path, normalized):
            depth = _depth_of(anon_path)
            if (max_depth is None or depth <= max_depth) and not _add_entry(
                {
                    "path": anon_path,
                    "is_dir": False,
                    "size": len(str(anon.get("content") or "")),
                    "modified_at": "",
                    "depth": depth,
                }
            ):
                return {"entries": entries, "truncated": True}

        # NOTE(review): unlike aglob_info/agrep_raw, state files listed in
        # ``moved_removed`` are not skipped here - confirm this is intended.
        for path_key, fd in files.items():
            if not isinstance(path_key, str):
                continue
            if not _is_under(path_key, normalized):
                continue
            if path_key in seen_paths:
                continue
            if not (
                path_key.startswith(DOCUMENTS_ROOT)
                or _basename(path_key).startswith(_TEMP_PREFIX)
            ):
                continue
            depth = _depth_of(path_key)
            if max_depth is not None and depth > max_depth:
                continue
            size = self._file_data_size(fd) if isinstance(fd, dict) else 0
            if not _add_entry(
                {
                    "path": path_key,
                    "is_dir": False,
                    "size": int(size),
                    "modified_at": fd.get("modified_at", "")
                    if isinstance(fd, dict)
                    else "",
                    "depth": depth,
                }
            ):
                return {"entries": entries, "truncated": True}

    return {"entries": entries, "truncated": truncated}
|
||||
|
||||
# ------------------------------------------------------------------ uploads (unsupported)
|
||||
|
||||
def upload_files(  # type: ignore[override]
    self, files: list[tuple[str, bytes]]
) -> list[FileUploadResponse]:
    """Reject uploads: this backend has no upload support.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("KBPostgresBackend does not support upload_files.")
|
||||
|
||||
def download_files(  # type: ignore[override]
    self, paths: list[str]
) -> list[FileDownloadResponse]:
    """Return the UTF-8 bytes of each requested file held in agent state.

    Only the state files mapping is consulted. A path that is absent
    (or mapped to ``None``) yields a response with ``content=None`` and
    ``error="file_not_found"``. Responses keep the order of ``paths``.
    """
    state_files = self._state_files()

    def _respond(target: str) -> FileDownloadResponse:
        payload = state_files.get(target)
        if payload is None:
            return FileDownloadResponse(
                path=target, content=None, error="file_not_found"
            )
        text = file_data_to_string(payload)
        return FileDownloadResponse(
            path=target,
            content=text.encode("utf-8"),
            error=None,
        )

    return [_respond(target) for target in paths]
|
||||
|
||||
|
||||
# --- module-level small helpers ---------------------------------------------
|
||||
|
||||
|
||||
async def list_tree_listing(
    backend: KBPostgresBackend,
    path: str,
    *,
    max_depth: int | None = 8,
    page_size: int = 500,
    include_files: bool = True,
    include_dirs: bool = True,
) -> dict[str, Any]:
    """Async helper used by the overridden ``list_tree`` tool wrapper.

    Forwards every option unchanged to ``backend.alist_tree_listing`` and
    returns its result.
    """
    options = {
        "max_depth": max_depth,
        "page_size": page_size,
        "include_files": include_files,
        "include_dirs": include_dirs,
    }
    listing = await backend.alist_tree_listing(path, **options)
    return listing
|
||||
|
||||
|
||||
__all__ = [
|
||||
"KBPostgresBackend",
|
||||
"list_tree_listing",
|
||||
"paginate_listing",
|
||||
]
|
||||
|
|
@ -1,10 +1,24 @@
|
|||
"""Knowledge-base pre-search middleware for the SurfSense new chat agent.
|
||||
"""Hybrid-search priority middleware for the SurfSense new chat agent.
|
||||
|
||||
This middleware runs before the main agent loop and seeds a virtual filesystem
|
||||
(`files` state) with relevant documents retrieved via hybrid search. On each
|
||||
turn the filesystem is *expanded* — new results merge with documents loaded
|
||||
during prior turns — and a synthetic ``ls`` result is injected into the message
|
||||
history so the LLM is immediately aware of the current filesystem structure.
|
||||
This middleware runs ``before_agent`` on every turn and writes:
|
||||
|
||||
* ``state["kb_priority"]`` — the top-K most relevant documents for the
|
||||
current user message, used to render a ``<priority_documents>`` system
|
||||
message immediately before the user turn.
|
||||
* ``state["kb_matched_chunk_ids"]`` — internal hand-off mapping
|
||||
(``Document.id`` → matched chunk IDs) consumed by
|
||||
:class:`KBPostgresBackend._load_file_data` when the agent first reads each
|
||||
document, so the XML wrapper can flag matched sections in
|
||||
``<chunk_index>``.
|
||||
|
||||
The previous "scoped filesystem" behaviour (synthetic ``ls`` + state
|
||||
``files`` seeding) is intentionally removed: documents are now lazy-loaded
|
||||
from Postgres on demand, with the full workspace tree rendered separately
|
||||
by :class:`KnowledgeTreeMiddleware`.
|
||||
|
||||
In anonymous mode the middleware skips hybrid search entirely and emits a
|
||||
single-entry priority list pointing at the Redis-loaded document
|
||||
(``state["kb_anon_doc"]``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -13,27 +27,30 @@ import asyncio
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from litellm import token_counter
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
||||
from app.agents.new_chat.path_resolver import (
|
||||
PathIndex,
|
||||
build_path_index,
|
||||
doc_to_virtual_path,
|
||||
)
|
||||
from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range
|
||||
from app.db import (
|
||||
NATIVE_TO_LEGACY_DOCTYPE,
|
||||
Chunk,
|
||||
Document,
|
||||
Folder,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
|
|
@ -70,7 +87,6 @@ class KBSearchPlan(BaseModel):
|
|||
|
||||
|
||||
def _extract_text_from_message(message: BaseMessage) -> str:
|
||||
"""Extract plain text from a message content."""
|
||||
content = getattr(message, "content", "")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
|
@ -85,19 +101,6 @@ def _extract_text_from_message(message: BaseMessage) -> str:
|
|||
return str(content)
|
||||
|
||||
|
||||
def _safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
|
||||
"""Convert arbitrary text into a filesystem-safe filename."""
|
||||
name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
|
||||
name = re.sub(r"\s+", " ", name)
|
||||
if not name:
|
||||
name = fallback
|
||||
if len(name) > 180:
|
||||
name = name[:180].rstrip()
|
||||
if not name.lower().endswith(".xml"):
|
||||
name = f"{name}.xml"
|
||||
return name
|
||||
|
||||
|
||||
def _render_recent_conversation(
|
||||
messages: Sequence[BaseMessage],
|
||||
*,
|
||||
|
|
@ -107,10 +110,9 @@ def _render_recent_conversation(
|
|||
) -> str:
|
||||
"""Render recent dialogue for internal planning under a token budget.
|
||||
|
||||
Prefers the latest messages and uses the project's existing model-aware
|
||||
token budgeting hooks when available on the LLM (`_count_tokens`,
|
||||
`_get_max_input_tokens`). Falls back to the prior fixed-message heuristic
|
||||
if token counting is unavailable.
|
||||
Filters to ``HumanMessage`` and ``AIMessage`` (without tool_calls) so that
|
||||
injected ``SystemMessage`` artefacts (priority list, workspace tree,
|
||||
file-write contract) don't pollute the planner prompt.
|
||||
"""
|
||||
rendered: list[tuple[str, str]] = []
|
||||
for message in messages:
|
||||
|
|
@ -133,8 +135,6 @@ def _render_recent_conversation(
|
|||
if not rendered:
|
||||
return ""
|
||||
|
||||
# Exclude the latest user message from "recent conversation" because it is
|
||||
# already passed separately as "Latest user message" in the planner prompt.
|
||||
if rendered and rendered[-1][0] == "user" and rendered[-1][1] == user_text.strip():
|
||||
rendered = rendered[:-1]
|
||||
|
||||
|
|
@ -216,8 +216,6 @@ def _render_recent_conversation(
|
|||
selected_lines = candidate_lines
|
||||
continue
|
||||
|
||||
# If the full message does not fit, keep as much of this most-recent
|
||||
# older message as possible via binary search.
|
||||
lo, hi = 1, len(text)
|
||||
best_line: str | None = None
|
||||
while lo <= hi:
|
||||
|
|
@ -249,7 +247,6 @@ def _build_kb_planner_prompt(
|
|||
recent_conversation: str,
|
||||
user_text: str,
|
||||
) -> str:
|
||||
"""Build a compact internal prompt for KB query rewriting and date scoping."""
|
||||
today = datetime.now(UTC).date().isoformat()
|
||||
return (
|
||||
"You optimize internal knowledge-base search inputs for document retrieval.\n"
|
||||
|
|
@ -275,12 +272,10 @@ def _build_kb_planner_prompt(
|
|||
|
||||
|
||||
def _extract_json_payload(text: str) -> str:
|
||||
"""Extract a JSON object from a raw LLM response."""
|
||||
stripped = text.strip()
|
||||
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
|
||||
if fenced:
|
||||
return fenced.group(1)
|
||||
|
||||
start = stripped.find("{")
|
||||
end = stripped.rfind("}")
|
||||
if start != -1 and end != -1 and end > start:
|
||||
|
|
@ -289,7 +284,6 @@ def _extract_json_payload(text: str) -> str:
|
|||
|
||||
|
||||
def _parse_kb_search_plan_response(response_text: str) -> KBSearchPlan:
    """Decode the planner's reply into a validated ``KBSearchPlan``.

    The JSON object is first isolated with ``_extract_json_payload``, then
    parsed and handed to ``KBSearchPlan.model_validate``.
    """
    raw_json = _extract_json_payload(response_text)
    return KBSearchPlan.model_validate(json.loads(raw_json))
|
||||
|
||||
|
|
@ -298,212 +292,19 @@ def _normalize_optional_date_range(
|
|||
start_date: str | None,
|
||||
end_date: str | None,
|
||||
) -> tuple[datetime | None, datetime | None]:
|
||||
"""Normalize optional planner dates into a UTC datetime range."""
|
||||
parsed_start = parse_date_or_datetime(start_date) if start_date else None
|
||||
parsed_end = parse_date_or_datetime(end_date) if end_date else None
|
||||
|
||||
if parsed_start is None and parsed_end is None:
|
||||
return None, None
|
||||
|
||||
resolved_start, resolved_end = resolve_date_range(parsed_start, parsed_end)
|
||||
return resolved_start, resolved_end
|
||||
|
||||
|
||||
def _build_document_xml(
|
||||
document: dict[str, Any],
|
||||
matched_chunk_ids: set[int] | None = None,
|
||||
) -> str:
|
||||
"""Build citation-friendly XML with a ``<chunk_index>`` for smart seeking.
|
||||
|
||||
The ``<chunk_index>`` at the top of each document lists every chunk with its
|
||||
line range inside ``<document_content>`` and flags chunks that directly
|
||||
matched the search query (``matched="true"``). This lets the LLM jump
|
||||
straight to the most relevant section via ``read_file(offset=…, limit=…)``
|
||||
instead of reading sequentially from the start.
|
||||
"""
|
||||
matched = matched_chunk_ids or set()
|
||||
|
||||
doc_meta = document.get("document") or {}
|
||||
metadata = (doc_meta.get("metadata") or {}) if isinstance(doc_meta, dict) else {}
|
||||
document_id = doc_meta.get("id", document.get("document_id", "unknown"))
|
||||
document_type = doc_meta.get("document_type", document.get("source", "UNKNOWN"))
|
||||
title = doc_meta.get("title") or metadata.get("title") or "Untitled Document"
|
||||
url = (
|
||||
metadata.get("url") or metadata.get("source") or metadata.get("page_url") or ""
|
||||
)
|
||||
metadata_json = json.dumps(metadata, ensure_ascii=False)
|
||||
|
||||
# --- 1. Metadata header (fixed structure) ---
|
||||
metadata_lines: list[str] = [
|
||||
"<document>",
|
||||
"<document_metadata>",
|
||||
f" <document_id>{document_id}</document_id>",
|
||||
f" <document_type>{document_type}</document_type>",
|
||||
f" <title><![CDATA[{title}]]></title>",
|
||||
f" <url><![CDATA[{url}]]></url>",
|
||||
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
|
||||
"</document_metadata>",
|
||||
"",
|
||||
]
|
||||
|
||||
# --- 2. Pre-build chunk XML strings to compute line counts ---
|
||||
chunks = document.get("chunks") or []
|
||||
chunk_entries: list[tuple[int | None, str]] = [] # (chunk_id, xml_string)
|
||||
if isinstance(chunks, list):
|
||||
for chunk in chunks:
|
||||
if not isinstance(chunk, dict):
|
||||
continue
|
||||
chunk_id = chunk.get("chunk_id") or chunk.get("id")
|
||||
chunk_content = str(chunk.get("content", "")).strip()
|
||||
if not chunk_content:
|
||||
continue
|
||||
if chunk_id is None:
|
||||
xml = f" <chunk><![CDATA[{chunk_content}]]></chunk>"
|
||||
else:
|
||||
xml = f" <chunk id='{chunk_id}'><![CDATA[{chunk_content}]]></chunk>"
|
||||
chunk_entries.append((chunk_id, xml))
|
||||
|
||||
# --- 3. Compute line numbers for every chunk ---
|
||||
# Layout (1-indexed lines for read_file):
|
||||
# metadata_lines -> len(metadata_lines) lines
|
||||
# <chunk_index> -> 1 line
|
||||
# index entries -> len(chunk_entries) lines
|
||||
# </chunk_index> -> 1 line
|
||||
# (empty line) -> 1 line
|
||||
# <document_content> -> 1 line
|
||||
# chunk xml lines…
|
||||
# </document_content> -> 1 line
|
||||
# </document> -> 1 line
|
||||
index_overhead = (
|
||||
1 + len(chunk_entries) + 1 + 1 + 1
|
||||
) # tags + empty + <document_content>
|
||||
first_chunk_line = len(metadata_lines) + index_overhead + 1 # 1-indexed
|
||||
|
||||
current_line = first_chunk_line
|
||||
index_entry_lines: list[str] = []
|
||||
for cid, xml_str in chunk_entries:
|
||||
num_lines = xml_str.count("\n") + 1
|
||||
end_line = current_line + num_lines - 1
|
||||
matched_attr = ' matched="true"' if cid is not None and cid in matched else ""
|
||||
if cid is not None:
|
||||
index_entry_lines.append(
|
||||
f' <entry chunk_id="{cid}" lines="{current_line}-{end_line}"{matched_attr}/>'
|
||||
)
|
||||
else:
|
||||
index_entry_lines.append(
|
||||
f' <entry lines="{current_line}-{end_line}"{matched_attr}/>'
|
||||
)
|
||||
current_line = end_line + 1
|
||||
|
||||
# --- 4. Assemble final XML ---
|
||||
lines = metadata_lines.copy()
|
||||
lines.append("<chunk_index>")
|
||||
lines.extend(index_entry_lines)
|
||||
lines.append("</chunk_index>")
|
||||
lines.append("")
|
||||
lines.append("<document_content>")
|
||||
for _, xml_str in chunk_entries:
|
||||
lines.append(xml_str)
|
||||
lines.extend(["</document_content>", "</document>"])
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
async def _get_folder_paths(
    session: AsyncSession, search_space_id: int
) -> dict[int, str]:
    """Map every folder id of a search space to its virtual ``/documents`` path.

    Each folder's path is built by walking ``parent_id`` links up to the
    root. The walk stops at a missing parent or when a folder is seen a
    second time, so a cyclic parent chain cannot loop forever. Folder names
    are sanitised with ``_safe_filename`` (minus its ``.xml`` suffix).
    """
    query = select(Folder.id, Folder.name, Folder.parent_id).where(
        Folder.search_space_id == search_space_id
    )
    fetched = await session.execute(query)
    folders = {row.id: (row.name, row.parent_id) for row in fetched.all()}

    resolved: dict[int, str] = {}
    for folder_id in folders:
        segments: list[str] = []
        walked: set[int] = set()
        node: int | None = folder_id
        while node is not None and node in folders and node not in walked:
            walked.add(node)
            name, parent = folders[node]
            safe = _safe_filename(str(name), fallback="folder")
            segments.append(safe.removesuffix(".xml"))
            node = parent
        # Segments were collected leaf-first; flip them to root-first.
        segments.reverse()
        if segments:
            resolved[folder_id] = "/documents/" + "/".join(segments)
        else:
            resolved[folder_id] = "/documents"
    return resolved
|
||||
|
||||
|
||||
def _build_synthetic_ls(
    existing_files: dict[str, Any] | None,
    new_files: dict[str, Any],
    *,
    mentioned_paths: set[str] | None = None,
) -> tuple[AIMessage, ToolMessage]:
    """Build a synthetic ``ls("/documents")`` tool call and its result.

    Args:
        existing_files: Files already present in state (may be ``None``).
        new_files: Files added this turn; they override ``existing_files``.
        mentioned_paths: Paths the user explicitly selected.

    Returns:
        An ``AIMessage`` carrying the ``ls`` tool call and the matching
        ``ToolMessage``. The listing puts mentioned paths first, then other
        new paths, then older paths. Each path appears exactly once, so a
        mentioned document that is not in ``new_files`` is no longer
        repeated in the older group. When any paths were mentioned, a
        separate header names them; the listing itself stays free of tags
        so paths can be passed straight to ``read_file``.
    """
    _mentioned = mentioned_paths or set()
    merged: dict[str, Any] = {**(existing_files or {}), **new_files}
    new_set = set(new_files)

    mentioned_list: list[str] = []
    new_non_mentioned: list[str] = []
    old_paths: list[str] = []
    for p, v in merged.items():
        # Only live entries under /documents/ are listed; None marks removal.
        if not p.startswith("/documents/") or v is None:
            continue
        if p in _mentioned:
            mentioned_list.append(p)
        elif p in new_set:
            new_non_mentioned.append(p)
        else:
            old_paths.append(p)
    ordered = mentioned_list + new_non_mentioned + old_paths

    parts: list[str] = []
    if mentioned_list:
        parts.append(
            "USER-MENTIONED documents (read these thoroughly before answering):"
        )
        parts.extend(f" {p}" for p in mentioned_list)
        parts.append("")
    parts.append(str(ordered) if ordered else "No documents found.")

    tool_call_id = f"auto_ls_{uuid.uuid4().hex[:12]}"
    ai_msg = AIMessage(
        content="",
        tool_calls=[{"name": "ls", "args": {"path": "/documents"}, "id": tool_call_id}],
    )
    tool_msg = ToolMessage(
        content="\n".join(parts),
        tool_call_id=tool_call_id,
    )
    return ai_msg, tool_msg
|
||||
return resolve_date_range(parsed_start, parsed_end)
|
||||
|
||||
|
||||
def _resolve_search_types(
|
||||
available_connectors: list[str] | None,
|
||||
available_document_types: list[str] | None,
|
||||
) -> list[str] | None:
|
||||
"""Build a flat list of document-type strings for the chunk retriever.
|
||||
|
||||
Includes legacy equivalents from ``NATIVE_TO_LEGACY_DOCTYPE`` so that
|
||||
old documents indexed under Composio names are still found.
|
||||
|
||||
Returns ``None`` when no filtering is desired (search all types).
|
||||
"""
|
||||
types: set[str] = set()
|
||||
if available_document_types:
|
||||
types.update(available_document_types)
|
||||
|
|
@ -531,13 +332,8 @@ async def browse_recent_documents(
|
|||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return documents ordered by recency (newest first), no relevance ranking.
|
||||
|
||||
Used when the user's intent is temporal ("latest file", "most recent upload")
|
||||
and hybrid search would produce poor results because the query has no
|
||||
meaningful topical signal.
|
||||
"""
|
||||
from sqlalchemy import func, select
|
||||
"""Return documents ordered by recency (newest first), no relevance ranking."""
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.db import DocumentType
|
||||
|
||||
|
|
@ -581,7 +377,6 @@ async def browse_recent_documents(
|
|||
return []
|
||||
|
||||
doc_ids = [d.id for d in documents]
|
||||
|
||||
numbered = (
|
||||
select(
|
||||
Chunk.id.label("chunk_id"),
|
||||
|
|
@ -632,6 +427,7 @@ async def browse_recent_documents(
|
|||
else None
|
||||
),
|
||||
"metadata": metadata,
|
||||
"folder_id": getattr(doc, "folder_id", None),
|
||||
},
|
||||
"source": (
|
||||
doc.document_type.value
|
||||
|
|
@ -640,12 +436,6 @@ async def browse_recent_documents(
|
|||
),
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"browse_recent_documents: %d docs returned for space=%d",
|
||||
len(results),
|
||||
search_space_id,
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
|
|
@ -659,17 +449,11 @@ async def search_knowledge_base(
|
|||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Run a single unified hybrid search against the knowledge base.
|
||||
|
||||
Uses one ``ChucksHybridSearchRetriever`` call across all document types
|
||||
instead of fanning out per-connector. This reduces the number of DB
|
||||
queries from ~10 to 2 (one RRF query + one chunk fetch).
|
||||
"""
|
||||
"""Run a single unified hybrid search against the knowledge base."""
|
||||
if not query:
|
||||
return []
|
||||
|
||||
[embedding] = embed_texts([query])
|
||||
|
||||
doc_types = _resolve_search_types(available_connectors, available_document_types)
|
||||
retriever_top_k = min(top_k * 3, 30)
|
||||
|
||||
|
|
@ -693,14 +477,7 @@ async def fetch_mentioned_documents(
|
|||
document_ids: list[int],
|
||||
search_space_id: int,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch explicitly mentioned documents with *all* their chunks.
|
||||
|
||||
Returns the same dict structure as ``search_knowledge_base`` so results
|
||||
can be merged directly into ``build_scoped_filesystem``. Unlike search
|
||||
results, every chunk is included (no top-K limiting) and none are marked
|
||||
as ``matched`` since the entire document is relevant by virtue of the
|
||||
user's explicit mention.
|
||||
"""
|
||||
"""Fetch explicitly mentioned documents."""
|
||||
if not document_ids:
|
||||
return []
|
||||
|
||||
|
|
@ -750,6 +527,7 @@ async def fetch_mentioned_documents(
|
|||
else None
|
||||
),
|
||||
"metadata": metadata,
|
||||
"folder_id": getattr(doc, "folder_id", None),
|
||||
},
|
||||
"source": (
|
||||
doc.document_type.value
|
||||
|
|
@ -762,96 +540,36 @@ async def fetch_mentioned_documents(
|
|||
return results
|
||||
|
||||
|
||||
async def build_scoped_filesystem(
    *,
    documents: Sequence[dict[str, Any]],
    search_space_id: int,
) -> tuple[dict[str, dict[str, str]], dict[int, str]]:
    """Turn search results into a files mapping keyed by virtual path.

    Each document is rendered to XML and stored under its folder path
    (``/documents`` when the folder is unknown).  When the computed path is
    already taken, the document id is appended to the file stem.

    Returns:
        ``(files, doc_id_to_path)`` where ``doc_id_to_path`` only contains
        documents whose id is an ``int``.
    """
    async with shielded_async_session() as session:
        folder_paths = await _get_folder_paths(session, search_space_id)

        # Collect the integer ids we need folder information for.
        wanted_ids: list[int] = []
        for entry in documents:
            if not isinstance(entry, dict):
                continue
            candidate = (entry.get("document") or {}).get("id")
            if isinstance(candidate, int):
                wanted_ids.append(candidate)

        folder_by_doc_id: dict[int, int | None] = {}
        if wanted_ids:
            folder_query = select(Document.id, Document.folder_id).where(
                Document.search_space_id == search_space_id,
                Document.id.in_(wanted_ids),
            )
            result = await session.execute(folder_query)
            for row in result.all():
                if row.id is not None:
                    folder_by_doc_id[row.id] = row.folder_id

    files: dict[str, dict[str, str]] = {}
    doc_id_to_path: dict[int, str] = {}
    for document in documents:
        meta = document.get("document") or {}
        doc_id = meta.get("id")
        has_int_id = isinstance(doc_id, int)

        folder_id = folder_by_doc_id.get(doc_id) if has_int_id else None
        base_folder = folder_paths.get(folder_id, "/documents")
        file_name = _safe_filename(str(meta.get("title") or "untitled"))

        path = f"{base_folder}/{file_name}"
        if path in files:
            # Path already used by an earlier document: disambiguate by id.
            path = f"{base_folder}/{file_name.removesuffix('.xml')} ({doc_id}).xml"

        xml_content = _build_document_xml(
            document,
            matched_chunk_ids=set(document.get("matched_chunk_ids") or []),
        )
        files[path] = {
            "content": xml_content.split("\n"),
            "encoding": "utf-8",
            "created_at": "",
            "modified_at": "",
        }
        if has_int_id:
            doc_id_to_path[doc_id] = path
    return files, doc_id_to_path
|
||||
def _render_priority_message(priority: list[dict[str, Any]]) -> SystemMessage:
    """Wrap the priority entries in one ``<priority_documents>`` system message.

    Each entry becomes a bullet with its path, its score (three decimals, or
    ``n/a`` when the score is not numeric) and a ``[USER-MENTIONED]`` marker
    when the entry is flagged as mentioned.  An empty list yields a
    placeholder body.
    """

    def _bullet(entry: dict[str, Any]) -> str:
        score = entry.get("score")
        shown_score = f"{score:.3f}" if isinstance(score, (int, float)) else "n/a"
        marker = " [USER-MENTIONED]" if entry.get("mentioned") else ""
        return f"- {entry.get('path', '')} (score={shown_score}){marker}"

    if priority:
        body = "\n".join(_bullet(entry) for entry in priority)
    else:
        body = "(no priority documents for this turn)"

    content = (
        "<priority_documents>\n"
        "These documents are most relevant to the latest user message; "
        "read them first. Matched sections are flagged inside each "
        "document's <chunk_index>.\n"
        f"{body}\n"
        "</priority_documents>"
    )
    return SystemMessage(content=content)
|
||||
|
||||
|
||||
def _build_anon_scoped_filesystem(
    documents: Sequence[dict[str, Any]],
) -> dict[str, dict[str, str]]:
    """Build the files mapping for anonymous uploads without touching the DB.

    Every document is placed directly under ``/documents``.  When a path is
    already taken, the document id (or ``dup`` when the id key is absent) is
    appended to the file stem.
    """
    scoped: dict[str, dict[str, str]] = {}
    for entry in documents:
        meta = entry.get("document") or {}
        file_name = _safe_filename(str(meta.get("title") or "untitled"))

        target = f"/documents/{file_name}"
        if target in scoped:
            suffix = meta.get("id", "dup")
            target = f"/documents/{file_name.removesuffix('.xml')} ({suffix}).xml"

        xml_text = _build_document_xml(
            entry,
            matched_chunk_ids=set(entry.get("matched_chunk_ids") or []),
        )
        scoped[target] = {
            "content": xml_text.split("\n"),
            "encoding": "utf-8",
            "created_at": "",
            "modified_at": "",
        }
    return scoped
|
||||
|
||||
|
||||
class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Pre-agent middleware that always searches the KB and seeds a scoped filesystem."""
|
||||
class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Compute hybrid-search priority hints for the current turn."""
|
||||
|
||||
tools = ()
|
||||
state_schema = SurfSenseFilesystemState
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -863,7 +581,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
available_document_types: list[str] | None = None,
|
||||
top_k: int = 10,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
anon_session_id: str | None = None,
|
||||
) -> None:
|
||||
self.llm = llm
|
||||
self.search_space_id = search_space_id
|
||||
|
|
@ -872,7 +589,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
self.available_document_types = available_document_types
|
||||
self.top_k = top_k
|
||||
self.mentioned_document_ids = mentioned_document_ids or []
|
||||
self.anon_session_id = anon_session_id
|
||||
|
||||
async def _plan_search_inputs(
|
||||
self,
|
||||
|
|
@ -880,10 +596,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
messages: Sequence[BaseMessage],
|
||||
user_text: str,
|
||||
) -> tuple[str, datetime | None, datetime | None, bool]:
|
||||
"""Rewrite the KB query and infer optional date filters with the LLM.
|
||||
|
||||
Returns (optimized_query, start_date, end_date, is_recency_query).
|
||||
"""
|
||||
if self.llm is None:
|
||||
return user_text, None, None, False
|
||||
|
||||
|
|
@ -914,7 +626,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
)
|
||||
is_recency = plan.is_recency_query
|
||||
_perf_log.info(
|
||||
"[kb_fs_middleware] planner in %.3fs query=%r optimized=%r "
|
||||
"[kb_priority] planner in %.3fs query=%r optimized=%r "
|
||||
"start=%s end=%s recency=%s",
|
||||
loop.time() - t0,
|
||||
user_text[:80],
|
||||
|
|
@ -946,106 +658,68 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
pass
|
||||
return asyncio.run(self.abefore_agent(state, runtime))
|
||||
|
||||
async def _load_anon_document(self) -> dict[str, Any] | None:
    """Fetch the anonymous session's uploaded document from Redis.

    Returns a search-result-shaped dict (single chunk, id ``-1``, flagged as
    user-mentioned), or ``None`` when there is no anonymous session, the key
    is missing/empty, or anything fails along the way (the failure is logged
    as a warning).
    """
    if not self.anon_session_id:
        return None
    try:
        import redis.asyncio as aioredis

        from app.config import config

        redis_client = aioredis.from_url(
            config.REDIS_APP_URL, decode_responses=True
        )
        try:
            raw = await redis_client.get(f"anon:doc:{self.anon_session_id}")
            if not raw:
                return None
            payload = json.loads(raw)
            text = payload.get("content", "")
            return {
                "document_id": -1,
                "content": text,
                "score": 1.0,
                "chunks": [{"chunk_id": -1, "content": text}],
                "matched_chunk_ids": [-1],
                "document": {
                    "id": -1,
                    "title": payload.get("filename", "uploaded_document"),
                    "document_type": "FILE",
                    "metadata": {"source": "anonymous_upload"},
                },
                "source": "FILE",
                "_user_mentioned": True,
            }
        finally:
            # Always release the connection, including on the early return.
            await redis_client.aclose()
    except Exception as exc:
        logger.warning("Failed to load anonymous document from Redis: %s", exc)
        return None
|
||||
|
||||
async def abefore_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
del runtime
|
||||
if self.filesystem_mode != FilesystemMode.CLOUD:
|
||||
return None
|
||||
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
if self.filesystem_mode != FilesystemMode.CLOUD:
|
||||
# Local-folder mode should not seed cloud KB documents into filesystem.
|
||||
return None
|
||||
|
||||
last_human = None
|
||||
last_human: HumanMessage | None = None
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, HumanMessage):
|
||||
last_human = msg
|
||||
break
|
||||
if last_human is None:
|
||||
return None
|
||||
|
||||
user_text = _extract_text_from_message(last_human).strip()
|
||||
if not user_text:
|
||||
return None
|
||||
|
||||
t0 = _perf_log and asyncio.get_event_loop().time()
|
||||
existing_files = state.get("files")
|
||||
anon_doc = state.get("kb_anon_doc")
|
||||
if anon_doc:
|
||||
return self._anon_priority(state, anon_doc)
|
||||
|
||||
# --- Anonymous session: load Redis doc and skip DB queries ---
|
||||
if self.anon_session_id:
|
||||
merged: list[dict[str, Any]] = []
|
||||
anon_doc = await self._load_anon_document()
|
||||
if anon_doc:
|
||||
merged.append(anon_doc)
|
||||
return await self._authenticated_priority(state, messages, user_text)
|
||||
|
||||
if merged:
|
||||
new_files = _build_anon_scoped_filesystem(merged)
|
||||
mentioned_paths = set(new_files.keys())
|
||||
else:
|
||||
new_files = {}
|
||||
mentioned_paths = set()
|
||||
def _anon_priority(
    self,
    state: AgentState,
    anon_doc: dict[str, Any],
) -> dict[str, Any]:
    """Build the state update for an anonymous session's single document.

    The document becomes the only priority entry (score ``1.0``, flagged as
    mentioned), and the rendered priority message is placed immediately
    before the last message of the conversation.
    """
    entry = {
        "path": str(anon_doc.get("path") or ""),
        "score": 1.0,
        "document_id": None,
        "title": str(anon_doc.get("title") or "uploaded_document"),
        "mentioned": True,
    }
    priority = [entry]

    history = list(state.get("messages") or [])
    # Index of the last message (0 for an empty history).
    split_at = max(len(history) - 1, 0)
    hint = _render_priority_message(priority)
    new_messages = [*history[:split_at], hint, *history[split_at:]]

    return {
        "kb_priority": priority,
        "kb_matched_chunk_ids": {},
        "messages": new_messages,
    }
|
||||
|
||||
ai_msg, tool_msg = _build_synthetic_ls(
|
||||
existing_files,
|
||||
new_files,
|
||||
mentioned_paths=mentioned_paths,
|
||||
)
|
||||
if t0 is not None:
|
||||
_perf_log.info(
|
||||
"[kb_fs_middleware] anon completed in %.3fs new_files=%d",
|
||||
asyncio.get_event_loop().time() - t0,
|
||||
len(new_files),
|
||||
)
|
||||
return {"files": new_files, "messages": [ai_msg, tool_msg]}
|
||||
|
||||
# --- Authenticated session: full KB search ---
|
||||
async def _authenticated_priority(
|
||||
self,
|
||||
state: AgentState,
|
||||
messages: Sequence[BaseMessage],
|
||||
user_text: str,
|
||||
) -> dict[str, Any]:
|
||||
t0 = asyncio.get_event_loop().time()
|
||||
(
|
||||
planned_query,
|
||||
start_date,
|
||||
|
|
@ -1056,7 +730,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
user_text=user_text,
|
||||
)
|
||||
|
||||
# --- 1. Fetch mentioned documents (user-selected, all chunks) ---
|
||||
mentioned_results: list[dict[str, Any]] = []
|
||||
if self.mentioned_document_ids:
|
||||
mentioned_results = await fetch_mentioned_documents(
|
||||
|
|
@ -1065,7 +738,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
)
|
||||
self.mentioned_document_ids = []
|
||||
|
||||
# --- 2. Run KB search (recency browse or hybrid) ---
|
||||
if is_recency:
|
||||
doc_types = _resolve_search_types(
|
||||
self.available_connectors, self.available_document_types
|
||||
|
|
@ -1088,48 +760,108 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
end_date=end_date,
|
||||
)
|
||||
|
||||
# --- 3. Merge: mentioned first, then search (dedup by doc id) ---
|
||||
seen_doc_ids: set[int] = set()
|
||||
merged_auth: list[dict[str, Any]] = []
|
||||
merged: list[dict[str, Any]] = []
|
||||
for doc in mentioned_results:
|
||||
doc_id = (doc.get("document") or {}).get("id")
|
||||
if doc_id is not None:
|
||||
if isinstance(doc_id, int):
|
||||
seen_doc_ids.add(doc_id)
|
||||
merged_auth.append(doc)
|
||||
merged.append(doc)
|
||||
for doc in search_results:
|
||||
doc_id = (doc.get("document") or {}).get("id")
|
||||
if doc_id is not None and doc_id in seen_doc_ids:
|
||||
if isinstance(doc_id, int) and doc_id in seen_doc_ids:
|
||||
continue
|
||||
merged_auth.append(doc)
|
||||
merged.append(doc)
|
||||
|
||||
# --- 4. Build scoped filesystem ---
|
||||
new_files, doc_id_to_path = await build_scoped_filesystem(
|
||||
documents=merged_auth,
|
||||
search_space_id=self.search_space_id,
|
||||
priority, matched_chunk_ids = await self._materialize_priority(merged)
|
||||
|
||||
new_messages = list(messages)
|
||||
insert_at = max(len(new_messages) - 1, 0)
|
||||
new_messages.insert(insert_at, _render_priority_message(priority))
|
||||
|
||||
_perf_log.info(
|
||||
"[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d",
|
||||
asyncio.get_event_loop().time() - t0,
|
||||
user_text[:80],
|
||||
len(priority),
|
||||
len(mentioned_results),
|
||||
)
|
||||
|
||||
mentioned_doc_ids = {
|
||||
(d.get("document") or {}).get("id") for d in mentioned_results
|
||||
}
|
||||
mentioned_paths = {
|
||||
doc_id_to_path[did] for did in mentioned_doc_ids if did in doc_id_to_path
|
||||
return {
|
||||
"kb_priority": priority,
|
||||
"kb_matched_chunk_ids": matched_chunk_ids,
|
||||
"messages": new_messages,
|
||||
}
|
||||
|
||||
ai_msg, tool_msg = _build_synthetic_ls(
|
||||
existing_files,
|
||||
new_files,
|
||||
mentioned_paths=mentioned_paths,
|
||||
)
|
||||
async def _materialize_priority(
|
||||
self, merged: list[dict[str, Any]]
|
||||
) -> tuple[list[dict[str, Any]], dict[int, list[int]]]:
|
||||
"""Resolve canonical paths and matched chunk ids for the priority list."""
|
||||
priority: list[dict[str, Any]] = []
|
||||
matched_chunk_ids: dict[int, list[int]] = {}
|
||||
|
||||
if t0 is not None:
|
||||
_perf_log.info(
|
||||
"[kb_fs_middleware] completed in %.3fs query=%r optimized=%r "
|
||||
"mentioned=%d new_files=%d total=%d",
|
||||
asyncio.get_event_loop().time() - t0,
|
||||
user_text[:80],
|
||||
planned_query[:120],
|
||||
len(mentioned_results),
|
||||
len(new_files),
|
||||
len(new_files) + len(existing_files or {}),
|
||||
if not merged:
|
||||
return priority, matched_chunk_ids
|
||||
|
||||
async with shielded_async_session() as session:
|
||||
index: PathIndex = await build_path_index(session, self.search_space_id)
|
||||
doc_ids = [
|
||||
(doc.get("document") or {}).get("id")
|
||||
for doc in merged
|
||||
if isinstance(doc, dict)
|
||||
]
|
||||
doc_ids = [doc_id for doc_id in doc_ids if isinstance(doc_id, int)]
|
||||
folder_by_doc_id: dict[int, int | None] = {}
|
||||
if doc_ids:
|
||||
folder_rows = await session.execute(
|
||||
select(Document.id, Document.folder_id).where(
|
||||
Document.search_space_id == self.search_space_id,
|
||||
Document.id.in_(doc_ids),
|
||||
)
|
||||
)
|
||||
folder_by_doc_id = {row.id: row.folder_id for row in folder_rows.all()}
|
||||
|
||||
for doc in merged:
|
||||
doc_meta = doc.get("document") or {}
|
||||
doc_id = doc_meta.get("id")
|
||||
title = doc_meta.get("title") or "untitled"
|
||||
folder_id = (
|
||||
folder_by_doc_id.get(doc_id)
|
||||
if isinstance(doc_id, int)
|
||||
else doc_meta.get("folder_id")
|
||||
)
|
||||
return {"files": new_files, "messages": [ai_msg, tool_msg]}
|
||||
path = doc_to_virtual_path(
|
||||
doc_id=doc_id if isinstance(doc_id, int) else None,
|
||||
title=str(title),
|
||||
folder_id=folder_id if isinstance(folder_id, int) else None,
|
||||
index=index,
|
||||
)
|
||||
priority.append(
|
||||
{
|
||||
"path": path,
|
||||
"score": float(doc.get("score") or 0.0),
|
||||
"document_id": doc_id if isinstance(doc_id, int) else None,
|
||||
"title": str(title),
|
||||
"mentioned": bool(doc.get("_user_mentioned")),
|
||||
}
|
||||
)
|
||||
if isinstance(doc_id, int):
|
||||
chunk_ids = doc.get("matched_chunk_ids") or []
|
||||
if chunk_ids:
|
||||
matched_chunk_ids[doc_id] = [
|
||||
int(cid) for cid in chunk_ids if isinstance(cid, (int, str))
|
||||
]
|
||||
return priority, matched_chunk_ids
|
||||
|
||||
|
||||
# Backwards-compatible alias for any external imports.
|
||||
KnowledgeBaseSearchMiddleware = KnowledgePriorityMiddleware
|
||||
|
||||
|
||||
__all__ = [
|
||||
"KnowledgeBaseSearchMiddleware",
|
||||
"KnowledgePriorityMiddleware",
|
||||
"browse_recent_documents",
|
||||
"fetch_mentioned_documents",
|
||||
"search_knowledge_base",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,272 @@
|
|||
"""Workspace-tree middleware for the SurfSense agent.
|
||||
|
||||
Renders the full ``Folder``+``Document`` tree under ``/documents/`` once per
|
||||
turn (cloud only), caches it by ``(search_space_id, tree_version)``, and
|
||||
injects the result as a ``<workspace_tree>`` system message immediately
|
||||
before the latest human turn.
|
||||
|
||||
The render is bounded by two truncation layers:
|
||||
|
||||
1. **Entry cap** — at most ``MAX_TREE_ENTRIES`` lines. The remainder is
|
||||
replaced with a "use ls" hint.
|
||||
2. **Token cap** — at most ``MAX_TREE_TOKENS`` tokens (using the LLM's
|
||||
token-count profile when available). If the entry-truncated tree still
|
||||
exceeds the token cap we fall back to a root-only summary.
|
||||
|
||||
Anonymous mode renders only ``state['kb_anon_doc']`` (no DB calls).
|
||||
|
||||
This middleware also performs a one-time initialization of ``state['cwd']``
|
||||
to ``"/documents"`` so subsequent middlewares and tools always see a valid
|
||||
cwd in cloud mode.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
||||
from app.agents.new_chat.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
PathIndex,
|
||||
build_path_index,
|
||||
doc_to_virtual_path,
|
||||
)
|
||||
from app.db import Document, shielded_async_session
|
||||
|
||||
try:
|
||||
from litellm import token_counter
|
||||
except Exception: # pragma: no cover - optional dep
|
||||
token_counter = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MAX_TREE_ENTRIES = 500
|
||||
MAX_TREE_TOKENS = 4000
|
||||
|
||||
|
||||
def _approx_tokens(text: str) -> int:
    """Estimate a token count at roughly four characters per token.

    The estimate is rounded up and is never less than 1, so the empty
    string still counts as one token.
    """
    # Ceiling division without floats: -(-n // 4) == ceil(n / 4).
    estimate = -(-len(text) // 4)
    return estimate if estimate > 1 else 1
|
||||
|
||||
|
||||
def _count_tokens(text: str, *, llm: BaseChatModel | None) -> int:
    """Count the tokens of ``text`` as a single user message.

    Strategies are tried in order, each falling through on failure:

    1. the model's own ``_count_tokens`` callable, when it has one;
    2. ``token_counter`` with a model name taken from the model's
       ``profile`` (``token_count_models`` first, then
       ``token_count_model``) or, failing that, its ``model`` attribute;
    3. the character-based estimate from ``_approx_tokens``.
    """
    if llm is None:
        return _approx_tokens(text)

    native_counter = getattr(llm, "_count_tokens", None)
    if callable(native_counter):
        try:
            return int(native_counter([{"role": "user", "content": text}]))
        except Exception:
            # Best effort: fall through to the next strategy.
            pass

    candidates: list[str] = []
    profile = getattr(llm, "profile", None)
    if isinstance(profile, dict):
        listed = profile.get("token_count_models")
        if isinstance(listed, list):
            for name in listed:
                if isinstance(name, str) and name:
                    candidates.append(name)
        single = profile.get("token_count_model")
        if isinstance(single, str) and single and single not in candidates:
            candidates.append(single)

    if candidates:
        model_name = candidates[0]
    else:
        model_name = getattr(llm, "model", None)

    usable_name = isinstance(model_name, str) and bool(model_name)
    if not usable_name or token_counter is None:
        return _approx_tokens(text)

    try:
        counted = token_counter(
            messages=[{"role": "user", "content": text}],
            model=model_name,
        )
        return int(counted)
    except Exception:
        return _approx_tokens(text)
|
||||
|
||||
|
||||
class KnowledgeTreeMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """Inject the workspace folder/document tree into the agent's context.

    In cloud filesystem mode, each run inserts a ``<workspace_tree>`` system
    message immediately before the last message in ``state["messages"]`` and
    initialises ``state["cwd"]`` to ``DOCUMENTS_ROOT`` when it is unset.
    Rendered knowledge-base trees are memoised per
    ``(search_space_id, tree_version)``.
    """

    tools = ()
    state_schema = SurfSenseFilesystemState

    def __init__(
        self,
        *,
        search_space_id: int,
        filesystem_mode: FilesystemMode,
        llm: BaseChatModel | None = None,
        max_entries: int = MAX_TREE_ENTRIES,
        max_tokens: int = MAX_TREE_TOKENS,
    ) -> None:
        """Store configuration and create an empty render cache.

        Args:
            search_space_id: Search space whose folders/documents are rendered.
            filesystem_mode: The middleware is a no-op unless this is
                ``FilesystemMode.CLOUD``.
            llm: Optional model used for token counting.
            max_entries: Maximum number of tree lines before truncation.
            max_tokens: Token budget above which a root-only summary is used.
        """
        self.search_space_id = search_space_id
        self.filesystem_mode = filesystem_mode
        self.llm = llm
        self.max_entries = max_entries
        self.max_tokens = max_tokens
        # (search_space_id, tree_version, anonymous?) -> rendered tree text.
        self._cache: dict[tuple[int, int, bool], str] = {}

    async def abefore_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Return a state update carrying the tree message (and initial cwd).

        Returns ``None`` outside cloud mode.  Otherwise the update always
        contains ``messages`` and contains ``cwd`` only when the state had
        no cwd yet.
        """
        del runtime
        if self.filesystem_mode != FilesystemMode.CLOUD:
            return None

        update: dict[str, Any] = {}
        if not state.get("cwd"):
            update["cwd"] = DOCUMENTS_ROOT

        anon_doc = state.get("kb_anon_doc")
        if anon_doc:
            tree_msg = self._render_anon_tree(anon_doc)
        else:
            tree_msg = await self._render_kb_tree(state)

        messages = list(state.get("messages") or [])
        # Place the tree just before the last message (index 0 when empty).
        insert_at = max(len(messages) - 1, 0)
        messages.insert(insert_at, SystemMessage(content=tree_msg))
        update["messages"] = messages
        return update

    def before_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Synchronous entry point; delegates to :meth:`abefore_agent`.

        ``asyncio.run`` cannot be used from inside a running event loop, so
        in that situation this returns ``None`` and contributes nothing.
        """
        try:
            loop = asyncio.get_running_loop()
            if loop.is_running():
                return None
        except RuntimeError:
            # No running loop in this thread: safe to start one below.
            pass
        return asyncio.run(self.abefore_agent(state, runtime))

    # ------------------------------------------------------------------ render

    def _render_anon_tree(self, anon_doc: dict[str, Any]) -> str:
        """Render the one-document tree for an anonymous session (no DB)."""
        path = str(anon_doc.get("path") or "")
        title = str(anon_doc.get("title") or "uploaded_document")
        return (
            "<workspace_tree>\n"
            "Anonymous session — only one read-only document is available.\n"
            f"{DOCUMENTS_ROOT}/\n"
            f" {path} — {title}\n"
            "</workspace_tree>"
        )

    async def _render_kb_tree(self, state: AgentState) -> str:
        """Render (or fetch from cache) the tree for the search space.

        A DB failure is logged and yields an ``(unavailable)`` placeholder,
        which is deliberately not cached so a later turn can retry.
        """
        version = int(state.get("tree_version") or 0)
        cache_key = (self.search_space_id, version, False)
        cached = self._cache.get(cache_key)
        if cached is not None:
            return cached

        try:
            async with shielded_async_session() as session:
                index = await build_path_index(session, self.search_space_id)
                doc_rows = await session.execute(
                    select(Document.id, Document.title, Document.folder_id).where(
                        Document.search_space_id == self.search_space_id
                    )
                )
                docs = list(doc_rows.all())
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning("knowledge_tree: DB error %s", exc)
            return "<workspace_tree>\n(unavailable)\n</workspace_tree>"

        rendered = self._format_tree(index, docs)
        self._cache[cache_key] = rendered
        return rendered

    def _format_tree(self, index: PathIndex, docs: list[Any]) -> str:
        """Format folders and documents as an indented tree.

        At most ``max_entries`` lines are emitted, followed by a hint line
        when entries remain.  If the result exceeds ``max_tokens`` the
        root-only summary from :meth:`_format_root_summary` is returned.
        """
        folder_paths = sorted(set(index.folder_paths.values()))
        # Set for O(1) "is this path a folder?" checks inside the loop.
        folder_set = set(folder_paths)
        doc_paths = sorted(
            doc_to_virtual_path(
                doc_id=row.id,
                title=str(row.title or "untitled"),
                folder_id=row.folder_id,
                index=index,
            )
            for row in docs
        )
        # Sort by path *components*, not by the raw string.  A plain string
        # sort places any sibling whose name continues with a character
        # below "/" (space, "-", ".", "(") between a folder and its
        # children, so the indented output would show those children under
        # the wrong entry.  Component-wise ordering keeps every folder
        # immediately followed by its own subtree.
        all_paths = sorted(
            folder_set | set(doc_paths) | {DOCUMENTS_ROOT},
            key=lambda p: p.split("/"),
        )

        lines: list[str] = []
        for path in all_paths:
            depth = (
                0
                if path == DOCUMENTS_ROOT
                else len([p for p in path[len(DOCUMENTS_ROOT) :].split("/") if p])
            )
            indent = " " * depth
            is_dir = path == DOCUMENTS_ROOT or path in folder_set
            display = (
                path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents"
            )
            if is_dir:
                lines.append(f"{indent}{display}/")
            else:
                lines.append(f"{indent}{display}")
            if len(lines) >= self.max_entries:
                remaining = len(all_paths) - len(lines)
                if remaining > 0:
                    lines.append(
                        f"... {remaining} more entries — use "
                        "ls('/documents/<folder>', offset, limit) to expand"
                    )
                break

        body = "\n".join(lines)
        rendered = f"<workspace_tree>\n{body}\n</workspace_tree>"

        token_count = _count_tokens(rendered, llm=self.llm)
        if token_count <= self.max_tokens:
            return rendered

        return self._format_root_summary(folder_paths, doc_paths)

    def _format_root_summary(
        self, folder_paths: list[str], doc_paths: list[str]
    ) -> str:
        """Summarise the tree as top-level folders with document counts.

        Documents directly under the root are reported as a single "loose"
        count; folders without documents are listed with a count of zero.
        """
        top_level: dict[str, int] = {}
        loose_docs = 0
        for path in doc_paths:
            rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
            if "/" in rel:
                top = rel.split("/", 1)[0]
                top_level[top] = top_level.get(top, 0) + 1
            else:
                loose_docs += 1
        for path in folder_paths:
            rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
            if not rel:
                continue
            top = rel.split("/", 1)[0]
            # Make sure empty top-level folders still appear in the summary.
            top_level.setdefault(top, 0)

        lines = [DOCUMENTS_ROOT + "/"]
        for name in sorted(top_level):
            count = top_level[name]
            lines.append(f" {name}/ ({count} document{'s' if count != 1 else ''})")
        if loose_docs:
            lines.append(
                f" ({loose_docs} loose document{'s' if loose_docs != 1 else ''})"
            )
        lines.append(
            "Tree is large; use list_tree('/documents/<folder>') to drill in "
            "or ls('/documents/<folder>', offset, limit) for paginated listings."
        )
        return "<workspace_tree>\n" + "\n".join(lines) + "\n</workspace_tree>"
|
||||
|
||||
|
||||
__all__ = ["KnowledgeTreeMiddleware"]
|
||||
351
surfsense_backend/app/agents/new_chat/path_resolver.py
Normal file
351
surfsense_backend/app/agents/new_chat/path_resolver.py
Normal file
|
|
@ -0,0 +1,351 @@
|
|||
"""Canonical virtual-path resolver for SurfSense knowledge-base documents.
|
||||
|
||||
This module is the single source of truth for mapping ``Document`` rows to
|
||||
virtual paths under ``/documents/`` and back. It is used by:
|
||||
|
||||
* :class:`KnowledgeTreeMiddleware` (rendering the workspace tree)
|
||||
* :class:`KnowledgePriorityMiddleware` (computing priority paths)
|
||||
* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / move operations)
|
||||
* :class:`KnowledgeBasePersistenceMiddleware` (resolving moves and creates)
|
||||
|
||||
Centralising the logic ensures that title-collision suffixes, folder paths,
|
||||
and ``unique_identifier_hash`` lookups never drift between renders and
|
||||
commits.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType, Folder
|
||||
from app.utils.document_converters import generate_unique_identifier_hash
|
||||
|
||||
DOCUMENTS_ROOT = "/documents"
|
||||
"""Root virtual folder for all KB documents."""
|
||||
|
||||
_INVALID_FILENAME_CHARS = re.compile(r"[\\/:*?\"<>|]+")
_WHITESPACE_RUN = re.compile(r"\s+")


def safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
    """Return ``value`` as a sanitized filename that ends in ``.xml``.

    Runs of forbidden characters become a single underscore, surrounding
    whitespace is trimmed and inner whitespace runs collapse to one space.
    An empty result is replaced by ``fallback``; names longer than 180
    characters are cut before the extension is ensured.
    """
    cleaned = _WHITESPACE_RUN.sub(
        " ", _INVALID_FILENAME_CHARS.sub("_", value).strip()
    )
    candidate = cleaned or fallback
    if len(candidate) > 180:
        candidate = candidate[:180].rstrip()
    if candidate.lower().endswith(".xml"):
        return candidate
    return f"{candidate}.xml"
|
||||
|
||||
|
||||
def safe_folder_segment(value: str, *, fallback: str = "folder") -> str:
    """Return ``value`` cleaned up for use as one path segment.

    Runs of forbidden characters become a single underscore, surrounding
    whitespace is trimmed and inner whitespace runs collapse to one space.
    An empty result yields ``fallback``; longer results are cut at 180
    characters.
    """
    segment = _WHITESPACE_RUN.sub(
        " ", _INVALID_FILENAME_CHARS.sub("_", value).strip()
    )
    if not segment:
        return fallback
    return segment[:180].rstrip() if len(segment) > 180 else segment
|
||||
|
||||
|
||||
def _suffix_with_doc_id(filename: str, doc_id: int | None) -> str:
    """Append `` (<doc_id>)`` to the stem of ``filename``.

    The result always ends in ``.xml``; an existing ``.xml`` extension (any
    letter case) is dropped first.  With ``doc_id=None`` the filename is
    returned untouched.
    """
    if doc_id is None:
        return filename
    has_extension = filename.lower().endswith(".xml")
    stem = filename[:-4] if has_extension else filename
    return f"{stem} ({doc_id}).xml"
|
||||
|
||||
|
||||
_SUFFIX_PATTERN = re.compile(r"\s\((\d+)\)\.xml$", re.IGNORECASE)


def parse_doc_id_suffix(filename: str) -> tuple[str, int | None]:
    """Split a filename into its stem and an optional trailing document id.

    A name ending in ``" (<digits>).xml"`` yields ``(stem, doc_id)``.  Any
    other name yields ``(stem, None)``, where the stem is the name without a
    trailing ``.xml`` extension (if it has one).
    """
    found = _SUFFIX_PATTERN.search(filename)
    if found is not None:
        return filename[: found.start()], int(found.group(1))
    stem = filename[:-4] if filename.lower().endswith(".xml") else filename
    return stem, None
|
||||
|
||||
|
||||
@dataclass
class PathIndex:
    """Snapshot of folder paths and taken document paths for one search space.

    Callers build it once and pass it to :func:`doc_to_virtual_path`, so
    collision handling is deterministic and folder paths are not looked up
    again for every document.
    """

    folder_paths: dict[int, str] = field(default_factory=dict)
    """Maps ``Folder.id`` to its absolute virtual path under ``/documents``."""

    occupants: dict[str, int] = field(default_factory=dict)
    """Maps a virtual path to the ``Document.id`` holding it in this render."""
|
||||
|
||||
|
||||
async def _build_folder_paths(
    session: AsyncSession,
    search_space_id: int,
) -> dict[int, str]:
    """Map every folder of a search space to its absolute virtual path.

    All ``Folder`` rows of the search space are loaded with a single query.
    For each folder the parent chain is then walked upwards, collecting one
    sanitized segment per ancestor. The walk stops at a root folder, at a
    parent id that is not among the loaded rows, or when a folder id repeats
    (guard against cyclic parent links).

    Returns:
        ``Folder.id`` -> path rooted at ``DOCUMENTS_ROOT``.
    """
    query = select(Folder.id, Folder.name, Folder.parent_id).where(
        Folder.search_space_id == search_space_id
    )
    fetched = await session.execute(query)
    folders = {row.id: (row.name, row.parent_id) for row in fetched.all()}

    resolved: dict[int, str] = {}
    for folder_id in folders:
        segments: list[str] = []
        seen_ids: set[int] = set()
        current: int | None = folder_id
        while current is not None and current in folders and current not in seen_ids:
            seen_ids.add(current)
            name, parent_id = folders[current]
            segments.append(safe_folder_segment(str(name)))
            current = parent_id
        # Segments were gathered leaf-first; paths are written root-first.
        segments.reverse()
        if segments:
            resolved[folder_id] = f"{DOCUMENTS_ROOT}/" + "/".join(segments)
        else:
            resolved[folder_id] = DOCUMENTS_ROOT
    return resolved
|
||||
|
||||
|
||||
async def build_path_index(
    session: AsyncSession,
    search_space_id: int,
    *,
    populate_occupants: bool = True,
) -> PathIndex:
    """Build a :class:`PathIndex` for a search space.

    Args:
        session: Async database session used for the folder and document
            queries.
        search_space_id: Search space whose folders/documents are indexed.
        populate_occupants: Whether the occupancy map is pre-seeded from
            existing ``Document`` rows. Most callers want this so that
            :func:`doc_to_virtual_path` can detect collisions across the
            whole space; the persistence middleware sets this to ``False``
            when it is iterating to decide where to place fresh documents.

    Returns:
        A ``PathIndex`` with ``folder_paths`` always filled and ``occupants``
        filled only when ``populate_occupants`` is true.
    """
    folder_paths = await _build_folder_paths(session, search_space_id)
    occupants: dict[str, int] = {}
    if populate_occupants:
        rows = await session.execute(
            select(Document.id, Document.title, Document.folder_id)
            .where(Document.search_space_id == search_space_id)
            # Without an explicit order the database may return rows in any
            # order, so which of two same-named documents keeps the bare
            # filename could differ between calls. Ordering by id makes the
            # lowest id the stable owner of the unsuffixed path.
            .order_by(Document.id)
        )
        for row in rows.all():
            base = folder_paths.get(row.folder_id, DOCUMENTS_ROOT)
            filename = safe_filename(str(row.title or "untitled"))
            path = f"{base}/{filename}"
            if path in occupants and occupants[path] != row.id:
                # Path already claimed by another document: disambiguate
                # this one with its id.
                path = f"{base}/{_suffix_with_doc_id(filename, row.id)}"
            occupants[path] = row.id
    return PathIndex(folder_paths=folder_paths, occupants=occupants)
|
||||
|
||||
|
||||
def doc_to_virtual_path(
    *,
    doc_id: int | None,
    title: str,
    folder_id: int | None,
    index: PathIndex,
) -> str:
    """Compute the canonical virtual path of a document.

    The path is ``<folder path>/<safe filename>``. When that path is already
    recorded in ``index.occupants`` for a different document, the filename is
    disambiguated with the document id instead.

    Side effect: when ``doc_id`` is not ``None`` the chosen path is recorded
    in ``index.occupants``, so later calls against the same index see this
    assignment.

    Args:
        doc_id: Id of the document, or ``None`` if it has none yet.
        title: Document title; an empty title is rendered as ``untitled``.
        folder_id: Containing folder; ids missing from
            ``index.folder_paths`` fall back to ``DOCUMENTS_ROOT``.
        index: Occupancy snapshot, mutated in place.

    Returns:
        The virtual path assigned to the document.
    """
    folder_path = index.folder_paths.get(folder_id, DOCUMENTS_ROOT)
    filename = safe_filename(str(title or "untitled"))
    candidate = f"{folder_path}/{filename}"

    current_owner = index.occupants.get(candidate)
    taken_by_other = current_owner is not None and current_owner != doc_id
    if taken_by_other:
        candidate = f"{folder_path}/{_suffix_with_doc_id(filename, doc_id)}"

    if doc_id is not None:
        index.occupants[candidate] = doc_id
    return candidate
|
||||
|
||||
|
||||
async def virtual_path_to_doc(
    session: AsyncSession,
    *,
    search_space_id: int,
    virtual_path: str,
) -> Document | None:
    """Resolve a virtual path back to a ``Document`` row.

    Resolution order:
    1. ``Document.unique_identifier_hash`` lookup (fast path for paths created
       by SurfSense itself — every NOTE write goes through this hash).
    2. If the basename carries a ``" (<doc_id>).xml"`` disambiguation suffix,
       try a direct id lookup constrained to the search space.
    3. Title-from-basename + folder-resolution lookup (exact title match).
    4. Scan of every document in the resolved folder, matching
       ``safe_filename(title)`` against the basename, as a last resort.

    Returns ``None`` when the path is empty, does not start with
    ``DOCUMENTS_ROOT``, names no file, or no step finds a document.
    """
    # NOTE(review): this is a plain prefix test, so a path such as
    # "<DOCUMENTS_ROOT>foo/x" is also accepted — confirm callers normalize.
    if not virtual_path or not virtual_path.startswith(DOCUMENTS_ROOT):
        return None

    # Step 1: hash lookup on the full virtual path.
    unique_hash = generate_unique_identifier_hash(
        DocumentType.NOTE,
        virtual_path,
        search_space_id,
    )
    result = await session.execute(
        select(Document).where(
            Document.search_space_id == search_space_id,
            Document.unique_identifier_hash == unique_hash,
        )
    )
    document = result.scalar_one_or_none()
    if document is not None:
        return document

    # Split the part below the documents root into folder chain + basename.
    rel = virtual_path[len(DOCUMENTS_ROOT) :].lstrip("/")
    if not rel:
        return None
    parts = [p for p in rel.split("/") if p]
    if not parts:
        return None
    basename = parts[-1]
    folder_parts = parts[:-1]

    # Step 2: direct id lookup when the basename ends in " (<doc_id>).xml".
    stem, suffix_doc_id = parse_doc_id_suffix(basename)
    if suffix_doc_id is not None:
        result = await session.execute(
            select(Document).where(
                Document.search_space_id == search_space_id,
                Document.id == suffix_doc_id,
            )
        )
        document = result.scalar_one_or_none()
        if document is not None:
            return document

    # Step 3: exact title match inside the resolved folder. A ``None``
    # folder id restricts the match to documents without a folder.
    folder_id = await _resolve_folder_id(
        session, search_space_id=search_space_id, folder_parts=folder_parts
    )
    title_candidates: list[str] = []
    raw_title = stem
    title_candidates.append(raw_title)
    if raw_title.endswith(".xml"):
        title_candidates.append(raw_title[:-4])

    # ``dict.fromkeys`` removes duplicate candidates while keeping order.
    for candidate in dict.fromkeys(title_candidates):
        if not candidate:
            continue
        query = select(Document).where(
            Document.search_space_id == search_space_id,
            Document.title == candidate,
        )
        if folder_id is None:
            query = query.where(Document.folder_id.is_(None))
        else:
            query = query.where(Document.folder_id == folder_id)
        result = await session.execute(query)
        document = result.scalars().first()
        if document is not None:
            return document

    # Fallback: title-as-string lookup misses when the real DB title contains
    # characters that ``safe_filename`` lossily replaces (``:``, ``/``, ``*``,
    # etc.) — common for connector-imported docs (Google Calendar/Drive etc.).
    # The workspace tree shows the lossy filename, so the agent passes that
    # filename back here. Scan all documents in the resolved folder and match
    # by ``safe_filename(title)`` to recover the original document.
    folder_scan = select(Document).where(
        Document.search_space_id == search_space_id,
    )
    if folder_id is None:
        folder_scan = folder_scan.where(Document.folder_id.is_(None))
    else:
        folder_scan = folder_scan.where(Document.folder_id == folder_id)
    result = await session.execute(folder_scan)
    for candidate_doc in result.scalars().all():
        encoded = safe_filename(str(candidate_doc.title or "untitled"))
        if encoded == basename:
            return candidate_doc
    return None
|
||||
|
||||
|
||||
async def _resolve_folder_id(
    session: AsyncSession,
    *,
    search_space_id: int,
    folder_parts: list[str],
) -> int | None:
    """Walk a chain of folder names down from the root and return the leaf id.

    Each name is sanitized with ``safe_folder_segment`` and looked up among
    the children of the previously matched folder (root-level folders for the
    first name).

    Returns:
        The id of the last folder in the chain, or ``None`` when the chain is
        empty or any segment has no matching folder.
    """
    current_id: int | None = None
    for segment in folder_parts:
        if current_id is None:
            parent_filter = Folder.parent_id.is_(None)
        else:
            parent_filter = Folder.parent_id == current_id
        stmt = select(Folder.id).where(
            Folder.search_space_id == search_space_id,
            Folder.name == safe_folder_segment(segment),
            parent_filter,
        )
        found = (await session.execute(stmt)).first()
        if found is None:
            return None
        current_id = found[0]
    return current_id
|
||||
|
||||
|
||||
def parse_documents_path(virtual_path: str) -> tuple[list[str], str]:
    """Split a ``/documents/...`` path into folder names and a document title.

    The title is the last path segment with any trailing
    ``" (<doc_id>).xml"`` disambiguation suffix or plain ``.xml`` extension
    removed.

    Returns:
        ``(folder_parts, title)``; ``([], "")`` when the path is empty, does
        not start with ``DOCUMENTS_ROOT``, or has no segment below the root.
    """
    if not virtual_path or not virtual_path.startswith(DOCUMENTS_ROOT):
        return [], ""

    below_root = virtual_path[len(DOCUMENTS_ROOT) :].strip("/")
    segments = [seg for seg in below_root.split("/") if seg]
    if not segments:
        return [], ""

    *folder_parts, basename = segments
    title, _ = parse_doc_id_suffix(basename)
    if title.endswith(".xml"):
        title = title[:-4]
    return folder_parts, title
|
||||
|
||||
|
||||
# Names exported by ``from <module> import *``; underscore-prefixed helpers
# are intentionally left out.
__all__ = [
    "DOCUMENTS_ROOT",
    "PathIndex",
    "build_path_index",
    "doc_to_virtual_path",
    "parse_doc_id_suffix",
    "parse_documents_path",
    "safe_filename",
    "safe_folder_segment",
    "virtual_path_to_doc",
]
|
||||
201
surfsense_backend/app/agents/new_chat/state_reducers.py
Normal file
201
surfsense_backend/app/agents/new_chat/state_reducers.py
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
"""Reducers and sentinels for SurfSense filesystem state.
|
||||
|
||||
These reducers back the extra state fields used by the cloud-mode filesystem
|
||||
agent (`cwd`, `staged_dirs`, `pending_moves`, `dirty_paths`, `doc_id_by_path`,
|
||||
`kb_priority`, `kb_matched_chunk_ids`, `kb_anon_doc`, `tree_version`).
|
||||
|
||||
Tools mutate these fields ONLY via `Command(update={...})` returns; the
|
||||
reducers are responsible for merging successive updates atomically and for
|
||||
honouring an explicit reset sentinel (`_CLEAR`) so that a single update can
|
||||
both reset and reseed a list (used by `move_file` / `aafter_agent`).
|
||||
|
||||
The sentinel is intentionally a plain string constant rather than a custom
|
||||
object so that LangGraph's checkpointer (which serializes raw `Command.update`
|
||||
deltas via ``ormsgpack`` BEFORE reducers are applied) can round-trip writes
|
||||
that contain it. The token uses a NUL-bracketed form that cannot collide with
|
||||
any real virtual path, document title, or dict key produced by the agent.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Final, TypeVar
|
||||
|
||||
# Recognized by value (string equality in ``_is_clear``), not by identity, so
# a copy of the token that went through serialization is still detected.
_CLEAR: Final[str] = "\x00__SURFSENSE_FILESYSTEM_CLEAR__\x00"
"""Reset sentinel; pass it inside a list/dict update to request a reset.

For list reducers: ``[_CLEAR, *items]`` resets the field then appends ``items``.
For dict reducers: ``{_CLEAR: True, **items}`` resets the field then merges ``items``.

Because the value is a plain string with embedded NUL bytes, it is natively
serializable by ``ormsgpack`` (used by LangGraph's PostgreSQL checkpointer)
yet still distinct from any real path / key produced by application code.
"""
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _replace_reducer[T](left: T | None, right: T | None) -> T | None:
|
||||
"""Replace `left` outright with `right`. ``None`` on the right is honored as a reset."""
|
||||
return right
|
||||
|
||||
|
||||
def _is_clear(value: Any) -> bool:
    """Return ``True`` only when ``value`` is a string equal to ``_CLEAR``."""
    if not isinstance(value, str):
        return False
    return value == _CLEAR
|
||||
|
||||
|
||||
def _add_unique_reducer(
    left: list[Any] | None,
    right: list[Any] | None,
) -> list[Any]:
    """Append items from ``right`` to ``left`` while preserving uniqueness.

    Semantics:
    - If ``right`` is ``None`` or empty, a copy of ``left`` is returned.
    - If ``right`` contains the ``_CLEAR`` sentinel anywhere, ``left`` is
      discarded and the result is reseeded with only the items that appear
      AFTER the LAST occurrence of ``_CLEAR`` (deduplicated, preserving
      first-seen order). This gives a single-update "reset and reseed"
      capability.
    - Otherwise, items from ``right`` are appended to ``left`` (order preserved
      from first seen) while skipping values that are already present.

    Unhashable items (e.g. dicts) are supported: they are deduplicated with a
    linear equality scan, and their presence in ``left`` does not disable
    deduplication of the hashable items.

    Args:
        left: Current field value, or ``None`` if unset.
        right: Incoming update, optionally containing ``_CLEAR``.

    Returns:
        A new list; neither input is mutated.
    """
    if not right:
        return list(left or [])

    last_clear = -1
    for index, item in enumerate(right):
        if _is_clear(item):
            last_clear = index

    # A reset drops ``left`` entirely; only the tail after the last sentinel
    # is considered. Without a sentinel the slice is all of ``right``.
    result: list[Any] = [] if last_clear >= 0 else list(left or [])
    incoming = right[last_clear + 1 :]

    # Seed the fast-path membership set one element at a time. Building it
    # with ``set(result)`` would fail as a whole on the first unhashable
    # element and lose track of every hashable one, letting duplicates in.
    seen: set[Any] = set()
    for existing in result:
        try:
            seen.add(existing)
        except TypeError:
            # Unhashable: covered by the linear scan in the loop below.
            continue

    for item in incoming:
        if _is_clear(item):
            continue
        try:
            if item in seen:
                continue
            seen.add(item)
        except TypeError:
            # Unhashable item: fall back to an equality scan of the result.
            if item in result:
                continue
        result.append(item)
    return result
|
||||
|
||||
|
||||
def _list_append_reducer(
    left: list[Any] | None,
    right: list[Any] | None,
) -> list[Any]:
    """Concatenate ``right`` onto ``left``, keeping order and duplicates.

    The ``_CLEAR`` sentinel is honoured the same way as in
    :func:`_add_unique_reducer`: when present, ``left`` is dropped and only
    the items after its last occurrence are kept. No deduplication happens,
    which suits queues where repeated entries matter (e.g. ``pending_moves``).

    Returns:
        A new list; a copy of ``left`` when ``right`` is ``None`` or empty.
    """
    if not right:
        return list(left or [])

    reset_at = max(
        (position for position, entry in enumerate(right) if _is_clear(entry)),
        default=-1,
    )
    if reset_at < 0:
        return [*(left or []), *(entry for entry in right if not _is_clear(entry))]
    return [entry for entry in right[reset_at + 1 :] if not _is_clear(entry)]
|
||||
|
||||
|
||||
def _dict_merge_with_tombstones_reducer(
    left: dict[Any, Any] | None,
    right: dict[Any, Any] | None,
) -> dict[Any, Any]:
    """Merge ``right`` into ``left`` with tombstone and reset support.

    * A key whose value is ``None`` is a tombstone: it is removed from the
      merged result instead of being stored.
    * When the ``_CLEAR`` key is present in ``right`` (its value is not
      inspected), the merge starts from ``{}`` instead of ``left``, so one
      update can atomically clear and reseed the dictionary. The sentinel key
      itself is never stored.

    Returns:
        A new dict; a copy of ``left`` when ``right`` is ``None``.
    """
    if right is None:
        return dict(left or {})

    reset_requested = _CLEAR in right or any(_is_clear(key) for key in right)
    merged: dict[Any, Any] = {} if reset_requested or left is None else dict(left)

    for key, value in right.items():
        if reset_requested and _is_clear(key):
            continue
        if value is None:
            merged.pop(key, None)
        else:
            merged[key] = value
    return merged
|
||||
|
||||
|
||||
def _initial_filesystem_state() -> dict[str, Any]:
|
||||
"""Default empty values for SurfSense filesystem state fields.
|
||||
|
||||
Consumers should always treat these fields as ``state.get(key) or
|
||||
DEFAULT`` so that fresh threads (without checkpointed state) work
|
||||
correctly.
|
||||
"""
|
||||
return {
|
||||
"cwd": "/documents",
|
||||
"staged_dirs": [],
|
||||
"pending_moves": [],
|
||||
"doc_id_by_path": {},
|
||||
"dirty_paths": [],
|
||||
"kb_priority": [],
|
||||
"kb_matched_chunk_ids": {},
|
||||
"kb_anon_doc": None,
|
||||
"tree_version": 0,
|
||||
}
|
||||
|
||||
|
||||
# Every name here starts with an underscore, which a star-import would skip
# by default; listing them explicitly makes them importable that way.
__all__ = [
    "_CLEAR",
    "_add_unique_reducer",
    "_dict_merge_with_tombstones_reducer",
    "_initial_filesystem_state",
    "_list_append_reducer",
    "_replace_reducer",
]
|
||||
|
|
@ -332,7 +332,7 @@ _TOOL_INSTRUCTIONS["scrape_webpage"] = """
|
|||
* When a user asks to "get", "fetch", "pull", "grab", "scrape", or "read" content from a URL
|
||||
* When the user wants live/dynamic data from a specific webpage (e.g., tables, scores, stats, prices)
|
||||
* When a URL was mentioned earlier in the conversation and the user asks for its actual content
|
||||
* When preloaded `/documents/` data is insufficient and the user wants more
|
||||
* When `/documents/` knowledge-base data is insufficient and the user wants more
|
||||
- Trigger scenarios:
|
||||
* "Read this article and summarize it"
|
||||
* "What does this page say about X?"
|
||||
|
|
|
|||
|
|
@ -20,7 +20,12 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import ImageGeneration, ImageGenerationConfig, SearchSpace
|
||||
from app.db import (
|
||||
ImageGeneration,
|
||||
ImageGenerationConfig,
|
||||
SearchSpace,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
|
|
@ -70,8 +75,13 @@ def create_generate_image_tool(
|
|||
|
||||
Args:
|
||||
search_space_id: The search space ID (for config resolution)
|
||||
db_session: Async database session
|
||||
db_session: Reserved for compatibility with the tool registry.
|
||||
The streaming task's ``AsyncSession`` is shared by every tool;
|
||||
because AsyncSession is not concurrency-safe, parallel tool calls
|
||||
would interleave flushes (e.g. podcast + image in the same step)
|
||||
and poison the transaction. This tool opens its own session.
|
||||
"""
|
||||
del db_session # use a fresh per-call session, see below
|
||||
|
||||
@tool
|
||||
async def generate_image(
|
||||
|
|
@ -93,110 +103,119 @@ def create_generate_image_tool(
|
|||
A dictionary containing the generated image(s) for display in the chat.
|
||||
"""
|
||||
try:
|
||||
# Resolve the image generation config from the search space preference
|
||||
result = await db_session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
return {"error": "Search space not found"}
|
||||
|
||||
config_id = (
|
||||
search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
)
|
||||
|
||||
# Build generation kwargs
|
||||
# NOTE: size, quality, and style are intentionally NOT passed.
|
||||
# Different models support different values for these params
|
||||
# (e.g. DALL-E 3 wants "hd"/"standard" for quality while
|
||||
# gpt-image-1 wants "high"/"medium"/"low"; size options also
|
||||
# differ). Letting the model use its own defaults avoids errors.
|
||||
gen_kwargs: dict[str, Any] = {}
|
||||
if n is not None and n > 1:
|
||||
gen_kwargs["n"] = n
|
||||
|
||||
# Call litellm based on config type
|
||||
if is_image_gen_auto_mode(config_id):
|
||||
if not ImageGenRouterService.is_initialized():
|
||||
return {
|
||||
"error": "No image generation models configured. "
|
||||
"Please add an image model in Settings > Image Models."
|
||||
}
|
||||
response = await ImageGenRouterService.aimage_generation(
|
||||
prompt=prompt, model="auto", **gen_kwargs
|
||||
# Use a per-call session so concurrent tool calls don't share an
|
||||
# AsyncSession (which is not concurrency-safe). The streaming
|
||||
# task's session is shared across every tool; without isolation,
|
||||
# autoflushes from a concurrent writer poison this tool too.
|
||||
async with shielded_async_session() as session:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
elif config_id < 0:
|
||||
cfg = _get_global_image_gen_config(config_id)
|
||||
if not cfg:
|
||||
return {"error": f"Image generation config {config_id} not found"}
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
return {"error": "Search space not found"}
|
||||
|
||||
model_string = _build_model_string(
|
||||
cfg.get("provider", ""),
|
||||
cfg["model_name"],
|
||||
cfg.get("custom_provider"),
|
||||
config_id = (
|
||||
search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
)
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
if cfg.get("api_base"):
|
||||
gen_kwargs["api_base"] = cfg["api_base"]
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
gen_kwargs.update(cfg["litellm_params"])
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
else:
|
||||
# Positive ID = user-created ImageGenerationConfig
|
||||
cfg_result = await db_session.execute(
|
||||
select(ImageGenerationConfig).filter(
|
||||
ImageGenerationConfig.id == config_id
|
||||
# Build generation kwargs
|
||||
# NOTE: size, quality, and style are intentionally NOT passed.
|
||||
# Different models support different values for these params
|
||||
# (e.g. DALL-E 3 wants "hd"/"standard" for quality while
|
||||
# gpt-image-1 wants "high"/"medium"/"low"; size options also
|
||||
# differ). Letting the model use its own defaults avoids errors.
|
||||
gen_kwargs: dict[str, Any] = {}
|
||||
if n is not None and n > 1:
|
||||
gen_kwargs["n"] = n
|
||||
|
||||
# Call litellm based on config type
|
||||
if is_image_gen_auto_mode(config_id):
|
||||
if not ImageGenRouterService.is_initialized():
|
||||
return {
|
||||
"error": "No image generation models configured. "
|
||||
"Please add an image model in Settings > Image Models."
|
||||
}
|
||||
response = await ImageGenRouterService.aimage_generation(
|
||||
prompt=prompt, model="auto", **gen_kwargs
|
||||
)
|
||||
)
|
||||
db_cfg = cfg_result.scalars().first()
|
||||
if not db_cfg:
|
||||
return {"error": f"Image generation config {config_id} not found"}
|
||||
elif config_id < 0:
|
||||
cfg = _get_global_image_gen_config(config_id)
|
||||
if not cfg:
|
||||
return {
|
||||
"error": f"Image generation config {config_id} not found"
|
||||
}
|
||||
|
||||
model_string = _build_model_string(
|
||||
db_cfg.provider.value,
|
||||
db_cfg.model_name,
|
||||
db_cfg.custom_provider,
|
||||
)
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
if db_cfg.api_base:
|
||||
gen_kwargs["api_base"] = db_cfg.api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
gen_kwargs.update(db_cfg.litellm_params)
|
||||
model_string = _build_model_string(
|
||||
cfg.get("provider", ""),
|
||||
cfg["model_name"],
|
||||
cfg.get("custom_provider"),
|
||||
)
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
if cfg.get("api_base"):
|
||||
gen_kwargs["api_base"] = cfg["api_base"]
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
gen_kwargs.update(cfg["litellm_params"])
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=prompt, model=model_string, **gen_kwargs
|
||||
response = await aimage_generation(
|
||||
prompt=prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
else:
|
||||
# Positive ID = user-created ImageGenerationConfig
|
||||
cfg_result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(
|
||||
ImageGenerationConfig.id == config_id
|
||||
)
|
||||
)
|
||||
db_cfg = cfg_result.scalars().first()
|
||||
if not db_cfg:
|
||||
return {
|
||||
"error": f"Image generation config {config_id} not found"
|
||||
}
|
||||
|
||||
model_string = _build_model_string(
|
||||
db_cfg.provider.value,
|
||||
db_cfg.model_name,
|
||||
db_cfg.custom_provider,
|
||||
)
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
if db_cfg.api_base:
|
||||
gen_kwargs["api_base"] = db_cfg.api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
gen_kwargs.update(db_cfg.litellm_params)
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
|
||||
# Parse the response and store in DB
|
||||
response_dict = (
|
||||
response.model_dump()
|
||||
if hasattr(response, "model_dump")
|
||||
else dict(response)
|
||||
)
|
||||
|
||||
# Parse the response and store in DB
|
||||
response_dict = (
|
||||
response.model_dump()
|
||||
if hasattr(response, "model_dump")
|
||||
else dict(response)
|
||||
)
|
||||
# Generate a random access token for this image
|
||||
access_token = generate_image_token()
|
||||
|
||||
# Generate a random access token for this image
|
||||
access_token = generate_image_token()
|
||||
|
||||
# Save to image_generations table for history
|
||||
db_image_gen = ImageGeneration(
|
||||
prompt=prompt,
|
||||
model=getattr(response, "_hidden_params", {}).get("model"),
|
||||
n=n,
|
||||
image_generation_config_id=config_id,
|
||||
response_data=response_dict,
|
||||
search_space_id=search_space_id,
|
||||
access_token=access_token,
|
||||
)
|
||||
db_session.add(db_image_gen)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(db_image_gen)
|
||||
# Save to image_generations table for history
|
||||
db_image_gen = ImageGeneration(
|
||||
prompt=prompt,
|
||||
model=getattr(response, "_hidden_params", {}).get("model"),
|
||||
n=n,
|
||||
image_generation_config_id=config_id,
|
||||
response_data=response_dict,
|
||||
search_space_id=search_space_id,
|
||||
access_token=access_token,
|
||||
)
|
||||
session.add(db_image_gen)
|
||||
await session.commit()
|
||||
await session.refresh(db_image_gen)
|
||||
db_image_gen_id = db_image_gen.id
|
||||
|
||||
# Extract image URLs from response
|
||||
images = response_dict.get("data", [])
|
||||
|
|
@ -217,7 +236,7 @@ def create_generate_image_tool(
|
|||
backend_url = config.BACKEND_URL or "http://localhost:8000"
|
||||
image_url = (
|
||||
f"{backend_url}/api/v1/image-generations/"
|
||||
f"{db_image_gen.id}/image?token={access_token}"
|
||||
f"{db_image_gen_id}/image?token={access_token}"
|
||||
)
|
||||
else:
|
||||
return {"error": "No displayable image data in the response"}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from typing import Any
|
|||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Podcast, PodcastStatus
|
||||
from app.db import Podcast, PodcastStatus, shielded_async_session
|
||||
|
||||
|
||||
def create_generate_podcast_tool(
|
||||
|
|
@ -27,12 +27,16 @@ def create_generate_podcast_tool(
|
|||
|
||||
Args:
|
||||
search_space_id: The user's search space ID
|
||||
db_session: Database session for creating the podcast record
|
||||
db_session: Reserved for future read-side use; the row is written via a
|
||||
fresh, tool-local session so parallel tool calls (e.g. podcast +
|
||||
video presentation in the same agent step) don't share an
|
||||
``AsyncSession`` (which is not concurrency-safe).
|
||||
thread_id: The chat thread ID for associating the podcast
|
||||
|
||||
Returns:
|
||||
A configured tool function for generating podcasts
|
||||
"""
|
||||
del db_session # writes use a fresh tool-local session, see below
|
||||
|
||||
@tool
|
||||
async def generate_podcast(
|
||||
|
|
@ -64,32 +68,40 @@ def create_generate_podcast_tool(
|
|||
- message: Status message (or "error" field if status is failed)
|
||||
"""
|
||||
try:
|
||||
podcast = Podcast(
|
||||
title=podcast_title,
|
||||
status=PodcastStatus.PENDING,
|
||||
search_space_id=search_space_id,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
db_session.add(podcast)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(podcast)
|
||||
# Open a fresh session per call. The streaming task's session is
|
||||
# shared between every tool, and ``AsyncSession`` is NOT safe for
|
||||
# concurrent use: when the LLM emits parallel tool calls, two
|
||||
# concurrent ``add()`` / ``commit()`` paths interleave and the
|
||||
# second one hits "Session.add() during flush" → the transaction
|
||||
# is poisoned for both tools.
|
||||
async with shielded_async_session() as session:
|
||||
podcast = Podcast(
|
||||
title=podcast_title,
|
||||
status=PodcastStatus.PENDING,
|
||||
search_space_id=search_space_id,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
session.add(podcast)
|
||||
await session.commit()
|
||||
await session.refresh(podcast)
|
||||
podcast_id = podcast.id
|
||||
|
||||
from app.tasks.celery_tasks.podcast_tasks import (
|
||||
generate_content_podcast_task,
|
||||
)
|
||||
|
||||
task = generate_content_podcast_task.delay(
|
||||
podcast_id=podcast.id,
|
||||
podcast_id=podcast_id,
|
||||
source_content=source_content,
|
||||
search_space_id=search_space_id,
|
||||
user_prompt=user_prompt,
|
||||
)
|
||||
|
||||
print(f"[generate_podcast] Created podcast {podcast.id}, task: {task.id}")
|
||||
print(f"[generate_podcast] Created podcast {podcast_id}, task: {task.id}")
|
||||
|
||||
return {
|
||||
"status": PodcastStatus.PENDING.value,
|
||||
"podcast_id": podcast.id,
|
||||
"podcast_id": podcast_id,
|
||||
"title": podcast_title,
|
||||
"message": "Podcast generation started. This may take a few minutes.",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from typing import Any
|
|||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import VideoPresentation, VideoPresentationStatus
|
||||
from app.db import VideoPresentation, VideoPresentationStatus, shielded_async_session
|
||||
|
||||
|
||||
def create_generate_video_presentation_tool(
|
||||
|
|
@ -23,8 +23,11 @@ def create_generate_video_presentation_tool(
|
|||
Factory function to create the generate_video_presentation tool with injected dependencies.
|
||||
|
||||
Pre-creates video presentation record with pending status so the ID is available
|
||||
immediately for frontend polling.
|
||||
immediately for frontend polling. The row is written via a fresh, tool-local
|
||||
session so parallel tool calls (e.g. video + podcast in the same agent step)
|
||||
don't share an ``AsyncSession`` (which is not concurrency-safe).
|
||||
"""
|
||||
del db_session # writes use a fresh tool-local session, see below
|
||||
|
||||
@tool
|
||||
async def generate_video_presentation(
|
||||
|
|
@ -42,34 +45,40 @@ def create_generate_video_presentation_tool(
|
|||
user_prompt: Optional style/tone instructions.
|
||||
"""
|
||||
try:
|
||||
video_pres = VideoPresentation(
|
||||
title=video_title,
|
||||
status=VideoPresentationStatus.PENDING,
|
||||
search_space_id=search_space_id,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
db_session.add(video_pres)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(video_pres)
|
||||
# See podcast.py for the rationale: parallel tool calls share the
|
||||
# streaming session, and AsyncSession is not concurrency-safe —
|
||||
# interleaved flushes produce "Session.add() during flush" and
|
||||
# poison the transaction for every concurrent tool.
|
||||
async with shielded_async_session() as session:
|
||||
video_pres = VideoPresentation(
|
||||
title=video_title,
|
||||
status=VideoPresentationStatus.PENDING,
|
||||
search_space_id=search_space_id,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
session.add(video_pres)
|
||||
await session.commit()
|
||||
await session.refresh(video_pres)
|
||||
video_pres_id = video_pres.id
|
||||
|
||||
from app.tasks.celery_tasks.video_presentation_tasks import (
|
||||
generate_video_presentation_task,
|
||||
)
|
||||
|
||||
task = generate_video_presentation_task.delay(
|
||||
video_presentation_id=video_pres.id,
|
||||
video_presentation_id=video_pres_id,
|
||||
source_content=source_content,
|
||||
search_space_id=search_space_id,
|
||||
user_prompt=user_prompt,
|
||||
)
|
||||
|
||||
print(
|
||||
f"[generate_video_presentation] Created video presentation {video_pres.id}, task: {task.id}"
|
||||
f"[generate_video_presentation] Created video presentation {video_pres_id}, task: {task.id}"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": VideoPresentationStatus.PENDING.value,
|
||||
"video_presentation_id": video_pres.id,
|
||||
"video_presentation_id": video_pres_id,
|
||||
"title": video_title,
|
||||
"message": "Video presentation generation started. This may take a few minutes.",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ from sqlalchemy.orm import selectinload
|
|||
|
||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemSelection
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||
from app.agents.new_chat.llm_config import (
|
||||
AgentConfig,
|
||||
create_chat_litellm_from_agent_config,
|
||||
|
|
@ -42,6 +42,9 @@ from app.agents.new_chat.memory_extraction import (
|
|||
extract_and_save_memory,
|
||||
extract_and_save_team_memory,
|
||||
)
|
||||
from app.agents.new_chat.middleware.kb_persistence import (
|
||||
commit_staged_filesystem_state,
|
||||
)
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
NewChatMessage,
|
||||
|
|
@ -258,6 +261,10 @@ async def _stream_agent_events(
|
|||
initial_step_id: str | None = None,
|
||||
initial_step_title: str = "",
|
||||
initial_step_items: list[str] | None = None,
|
||||
*,
|
||||
fallback_commit_search_space_id: int | None = None,
|
||||
fallback_commit_created_by_id: str | None = None,
|
||||
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Shared async generator that streams and formats astream_events from the agent.
|
||||
|
||||
|
|
@ -1280,6 +1287,40 @@ async def _stream_agent_events(
|
|||
|
||||
state = await agent.aget_state(config)
|
||||
state_values = getattr(state, "values", {}) or {}
|
||||
|
||||
# Safety net: if astream_events was cancelled before
|
||||
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work
|
||||
# (dirty_paths / staged_dirs / pending_moves) will still be in the
|
||||
# checkpointed state. Run the SAME shared commit helper here so the
|
||||
# turn's writes don't get lost on client disconnect, then push the
|
||||
# delta back into the graph using `as_node=...` so reducers fire as if
|
||||
# the after_agent hook produced it.
|
||||
if (
|
||||
fallback_commit_filesystem_mode == FilesystemMode.CLOUD
|
||||
and fallback_commit_search_space_id is not None
|
||||
and (
|
||||
(state_values.get("dirty_paths") or [])
|
||||
or (state_values.get("staged_dirs") or [])
|
||||
or (state_values.get("pending_moves") or [])
|
||||
)
|
||||
):
|
||||
try:
|
||||
delta = await commit_staged_filesystem_state(
|
||||
state_values,
|
||||
search_space_id=fallback_commit_search_space_id,
|
||||
created_by_id=fallback_commit_created_by_id,
|
||||
filesystem_mode=fallback_commit_filesystem_mode,
|
||||
dispatch_events=False,
|
||||
)
|
||||
if delta:
|
||||
await agent.aupdate_state(
|
||||
config,
|
||||
delta,
|
||||
as_node="KnowledgeBasePersistenceMiddleware.after_agent",
|
||||
)
|
||||
except Exception as exc:
|
||||
_perf_log.warning("[stream_new_chat] safety-net commit failed: %s", exc)
|
||||
|
||||
contract_state = state_values.get("file_operation_contract") or {}
|
||||
contract_turn_id = contract_state.get("turn_id")
|
||||
current_turn_id = config.get("configurable", {}).get("turn_id", "")
|
||||
|
|
@ -1814,6 +1855,13 @@ async def stream_new_chat(
|
|||
initial_step_id=initial_step_id,
|
||||
initial_step_title=initial_title,
|
||||
initial_step_items=initial_items,
|
||||
fallback_commit_search_space_id=search_space_id,
|
||||
fallback_commit_created_by_id=user_id,
|
||||
fallback_commit_filesystem_mode=(
|
||||
filesystem_selection.mode
|
||||
if filesystem_selection
|
||||
else FilesystemMode.CLOUD
|
||||
),
|
||||
):
|
||||
if not _first_event_logged:
|
||||
_perf_log.info(
|
||||
|
|
@ -2251,6 +2299,13 @@ async def stream_resume_chat(
|
|||
streaming_service=streaming_service,
|
||||
result=stream_result,
|
||||
step_prefix="thinking-resume",
|
||||
fallback_commit_search_space_id=search_space_id,
|
||||
fallback_commit_created_by_id=user_id,
|
||||
fallback_commit_filesystem_mode=(
|
||||
filesystem_selection.mode
|
||||
if filesystem_selection
|
||||
else FilesystemMode.CLOUD
|
||||
),
|
||||
):
|
||||
if not _first_event_logged:
|
||||
_perf_log.info(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,198 @@
|
|||
"""Tests for canonical virtual-path resolver helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
PathIndex,
|
||||
doc_to_virtual_path,
|
||||
parse_doc_id_suffix,
|
||||
parse_documents_path,
|
||||
safe_filename,
|
||||
safe_folder_segment,
|
||||
virtual_path_to_doc,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class TestSafeFilename:
    """Checks on the basename that ``safe_filename`` produces."""

    def test_appends_xml_extension(self):
        """A plain title comes back ending in ``.xml``."""
        encoded = safe_filename("notes")
        assert encoded.endswith(".xml")

    def test_strips_invalid_chars(self):
        """No forward slash is left in the encoded name."""
        encoded = safe_filename("a/b\\c.xml")
        assert "/" not in encoded

    def test_falls_back_when_empty(self):
        """Empty and separator-only titles still yield an ``.xml`` name."""
        from_empty = safe_filename("")
        assert from_empty.endswith(".xml")

        # Accepts the exact "untitled.xml" fallback or any other .xml name.
        from_slashes = safe_filename("///")
        assert from_slashes == "untitled.xml" or from_slashes.endswith(".xml")
|
||||
|
||||
|
||||
class TestSafeFolderSegment:
    """Checks on the single folder segment that ``safe_folder_segment`` returns."""

    def test_strips_path_separators(self):
        """A segment built from ``a/b`` contains no forward slash."""
        segment = safe_folder_segment("a/b")
        assert "/" not in segment

    def test_falls_back(self):
        """An empty name is replaced by the literal ``folder``."""
        segment = safe_folder_segment("")
        assert segment == "folder"
|
||||
|
||||
|
||||
class TestParseDocIdSuffix:
    """Checks on splitting a basename into its stem and ``(<id>)`` suffix."""

    def test_parses_suffix(self):
        """``My Doc (42).xml`` splits into the stem and the integer 42."""
        name_part, parsed_id = parse_doc_id_suffix("My Doc (42).xml")
        assert (name_part, parsed_id) == ("My Doc", 42)

    def test_no_suffix_returns_none(self):
        """Without a parenthesised id the second element is ``None``."""
        name_part, parsed_id = parse_doc_id_suffix("My Doc.xml")
        assert name_part == "My Doc"
        assert parsed_id is None

    def test_no_xml_extension(self):
        """A name without ``.xml`` comes back unchanged with no id."""
        name_part, parsed_id = parse_doc_id_suffix("plain")
        assert name_part == "plain"
        assert parsed_id is None
|
||||
|
||||
|
||||
class TestDocToVirtualPath:
    """Checks on the virtual path assigned to a document by ``doc_to_virtual_path``."""

    def test_root_when_no_folder(self):
        """A folderless document lands directly under the documents root."""
        path_index = PathIndex()
        resolved = doc_to_virtual_path(
            doc_id=1,
            title="Hello",
            folder_id=None,
            index=path_index,
        )
        assert resolved == f"{DOCUMENTS_ROOT}/Hello.xml"
        # The index records which document now occupies the path.
        assert path_index.occupants[resolved] == 1

    def test_collision_picks_doc_id_suffix(self):
        """When the plain path is taken, the doc id is appended in parentheses."""
        taken = {f"{DOCUMENTS_ROOT}/Hello.xml": 7}
        path_index = PathIndex(occupants=taken)
        resolved = doc_to_virtual_path(
            doc_id=8,
            title="Hello",
            folder_id=None,
            index=path_index,
        )
        assert resolved == f"{DOCUMENTS_ROOT}/Hello (8).xml"
        assert path_index.occupants[resolved] == 8

    def test_uses_folder_path_when_known(self):
        """A folder id present in the index places the file under that folder."""
        known_folders = {5: f"{DOCUMENTS_ROOT}/notes"}
        path_index = PathIndex(folder_paths=known_folders)
        resolved = doc_to_virtual_path(
            doc_id=2,
            title="A",
            folder_id=5,
            index=path_index,
        )
        assert resolved == f"{DOCUMENTS_ROOT}/notes/A.xml"
|
||||
|
||||
|
||||
class TestParseDocumentsPath:
    """Checks on splitting a documents path into folder parts and a title."""

    def test_extracts_folder_parts_and_title(self):
        """Intermediate segments become folder parts; the stem becomes the title."""
        virtual_path = f"{DOCUMENTS_ROOT}/foo/bar/baz.xml"
        folder_parts, doc_title = parse_documents_path(virtual_path)
        assert folder_parts == ["foo", "bar"]
        assert doc_title == "baz"

    def test_strips_doc_id_suffix(self):
        """A trailing ``(12)`` collision suffix is not part of the title."""
        virtual_path = f"{DOCUMENTS_ROOT}/foo/My Doc (12).xml"
        folder_parts, doc_title = parse_documents_path(virtual_path)
        assert folder_parts == ["foo"]
        assert doc_title == "My Doc"

    def test_non_documents_returns_empty(self):
        """A path outside the documents root yields an empty result."""
        outcome = parse_documents_path("/other/x.xml")
        assert outcome == ([], "")
|
||||
|
||||
|
||||
def _result_from_scalars(rows: list):
|
||||
"""Build a fake SQLAlchemy ``Result`` whose ``.scalars().all()`` and
|
||||
``.scalars().first()`` yield ``rows``."""
|
||||
scalars = MagicMock()
|
||||
scalars.all.return_value = list(rows)
|
||||
scalars.first.return_value = rows[0] if rows else None
|
||||
result = MagicMock()
|
||||
result.scalars.return_value = scalars
|
||||
result.scalar_one_or_none.return_value = None
|
||||
result.first.return_value = None
|
||||
return result
|
||||
|
||||
|
||||
def _result_from_one(value):
|
||||
result = MagicMock()
|
||||
result.scalar_one_or_none.return_value = value
|
||||
return result
|
||||
|
||||
|
||||
class TestVirtualPathToDoc:
    """Resolver lookups when the basename went through ``safe_filename``.

    ``safe_filename`` is lossy (for example ``:`` becomes ``_``), and the
    workspace tree shows that encoded basename. When the agent hands the
    basename back to a tool (move/edit/read), ``virtual_path_to_doc`` still
    has to locate the document whose real title differs from it.
    """

    @pytest.mark.asyncio
    async def test_falls_back_to_safe_filename_match_when_title_lossy(self):
        """A title containing a colon is found through its encoded basename."""
        # Google Calendar-style title: the colon is rewritten to ``_``, so a
        # plain title-equality query cannot match this row.
        original_title = "Calendar: Happy birthday!"
        encoded_basename = safe_filename(original_title)
        assert encoded_basename == "Calendar_ Happy birthday!.xml"

        target_doc = SimpleNamespace(id=42, title=original_title, folder_id=None)

        # One canned result per awaited ``session.execute``, in the order the
        # resolver is expected to issue its queries.
        canned_results = [
            _result_from_one(None),  # 1) unique_identifier_hash: no match
            _result_from_scalars([]),  # 2) literal title: no match (lossy)
            _result_from_scalars([target_doc]),  # 3) folder scan: the row
        ]
        session = MagicMock()
        session.execute = AsyncMock(side_effect=canned_results)

        resolved = await virtual_path_to_doc(
            session,
            search_space_id=5,
            virtual_path=f"{DOCUMENTS_ROOT}/{encoded_basename}",
        )
        assert resolved is target_doc

    @pytest.mark.asyncio
    async def test_returns_none_when_no_doc_matches_safe_filename(self):
        """The folder scan returning only unrelated titles resolves to ``None``."""
        unrelated_doc = SimpleNamespace(id=1, title="Something else", folder_id=None)
        canned_results = [
            _result_from_one(None),
            _result_from_scalars([]),
            _result_from_scalars([unrelated_doc]),
        ]
        session = MagicMock()
        session.execute = AsyncMock(side_effect=canned_results)

        resolved = await virtual_path_to_doc(
            session,
            search_space_id=5,
            virtual_path=f"{DOCUMENTS_ROOT}/Calendar_ Happy birthday!.xml",
        )
        assert resolved is None

    @pytest.mark.asyncio
    async def test_literal_title_match_short_circuits_fallback(self):
        """A literal title hit stops the lookup after two queries."""
        # The folder-scan fallback must NOT run once the literal title query
        # hits: it saves a query and avoids choosing the wrong document when
        # two rows share a lossy encoding.
        target_doc = SimpleNamespace(id=7, title="Plain Note", folder_id=None)
        canned_results = [
            _result_from_one(None),
            _result_from_scalars([target_doc]),
        ]
        session = MagicMock()
        session.execute = AsyncMock(side_effect=canned_results)

        resolved = await virtual_path_to_doc(
            session,
            search_space_id=5,
            virtual_path=f"{DOCUMENTS_ROOT}/Plain Note.xml",
        )
        assert resolved is target_doc
        assert session.execute.await_count == 2
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
"""Tests for SurfSense filesystem state reducers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.state_reducers import (
|
||||
_CLEAR,
|
||||
_add_unique_reducer,
|
||||
_dict_merge_with_tombstones_reducer,
|
||||
_initial_filesystem_state,
|
||||
_list_append_reducer,
|
||||
_replace_reducer,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class TestReplaceReducer:
    """Checks that ``_replace_reducer`` always yields its right-hand value."""

    def test_right_wins_outright(self):
        """The right value replaces the left one."""
        merged = _replace_reducer("a", "b")
        assert merged == "b"

    def test_none_right_returns_none(self):
        """A ``None`` on the right is returned as-is."""
        merged = _replace_reducer("a", None)
        assert merged is None

    def test_none_left_returns_right(self):
        """A ``None`` on the left does not affect the outcome."""
        merged = _replace_reducer(None, "b")
        assert merged == "b"
|
||||
|
||||
|
||||
class TestAddUniqueReducer:
    """Checks on ``_add_unique_reducer``: ordered append, de-duplication, ``_CLEAR``."""

    def test_appends_unique_items(self):
        """New items are appended after the existing ones."""
        merged = _add_unique_reducer(["a"], ["b", "c"])
        assert merged == ["a", "b", "c"]

    def test_dedupes_against_left(self):
        """An item already on the left is not added again."""
        merged = _add_unique_reducer(["a", "b"], ["b", "c"])
        assert merged == ["a", "b", "c"]

    def test_dedupes_within_right(self):
        """Repeats inside the right-hand list collapse to one entry."""
        merged = _add_unique_reducer([], ["a", "a", "b"])
        assert merged == ["a", "b"]

    def test_clear_anywhere_resets_and_reseeds_with_after_items(self):
        """Only the items AFTER the LAST ``_CLEAR`` survive."""
        update = ["a", _CLEAR, "b", "c"]
        merged = _add_unique_reducer(["x", "y"], update)
        assert merged == ["b", "c"]

    def test_multiple_clears_use_last(self):
        """With several ``_CLEAR`` markers, the last one decides."""
        update = [_CLEAR, "a", _CLEAR, "b"]
        merged = _add_unique_reducer(["x"], update)
        assert merged == ["b"]

    def test_clear_only_resets_to_empty(self):
        """A lone ``_CLEAR`` empties the list."""
        merged = _add_unique_reducer(["x", "y"], [_CLEAR])
        assert merged == []

    def test_empty_right_keeps_left(self):
        """An empty list or ``None`` on the right leaves the left untouched."""
        with_empty = _add_unique_reducer(["a"], [])
        assert with_empty == ["a"]
        with_none = _add_unique_reducer(["a"], None)
        assert with_none == ["a"]
|
||||
|
||||
|
||||
class TestListAppendReducer:
    """Checks on ``_list_append_reducer``: plain append plus ``_CLEAR`` reset."""

    def test_preserves_order_and_duplicates(self):
        """Items are appended in order and equal entries are kept."""
        existing = [{"a": 1}]
        incoming = [{"b": 2}, {"a": 1}]
        merged = _list_append_reducer(existing, incoming)
        assert merged == [{"a": 1}, {"b": 2}, {"a": 1}]

    def test_clear_resets_keeping_after_items(self):
        """Everything up to and including ``_CLEAR`` is dropped."""
        existing = [{"a": 1}]
        incoming = [{"old": 1}, _CLEAR, {"new": 2}]
        merged = _list_append_reducer(existing, incoming)
        assert merged == [{"new": 2}]
|
||||
|
||||
|
||||
class TestDictMergeWithTombstones:
    """Checks on ``_dict_merge_with_tombstones_reducer``.

    ``None`` values act as deletions and a ``_CLEAR`` key discards the
    left-hand dict before the remaining entries are merged.
    """

    def test_merges_keys(self):
        """Disjoint keys from both sides end up in the result."""
        merged = _dict_merge_with_tombstones_reducer({"a": 1}, {"b": 2})
        assert merged == {"a": 1, "b": 2}

    def test_none_value_deletes_key(self):
        """A ``None`` value removes that key from the left-hand dict."""
        merged = _dict_merge_with_tombstones_reducer({"a": 1, "b": 2}, {"a": None})
        assert merged == {"b": 2}

    def test_clear_resets_then_merges(self):
        """``_CLEAR`` drops the left side; the other keys are then merged."""
        update = {_CLEAR: True, "c": 3}
        merged = _dict_merge_with_tombstones_reducer({"a": 1, "b": 2}, update)
        assert merged == {"c": 3}

    def test_clear_keeps_only_post_clear_non_none(self):
        """After ``_CLEAR``, ``None``-valued keys are still left out."""
        update = {_CLEAR: True, "b": 2, "c": None}
        merged = _dict_merge_with_tombstones_reducer({"a": 1}, update)
        assert merged == {"b": 2}

    def test_none_left_handled(self):
        """A ``None`` left side behaves like an empty dict."""
        merged = _dict_merge_with_tombstones_reducer(None, {"a": 1, "b": None})
        assert merged == {"a": 1}
|
||||
|
||||
|
||||
class TestInitialFilesystemState:
    """Checks the default values returned by ``_initial_filesystem_state``."""

    def test_default_shape(self):
        """Every tracked key starts at its empty/default value."""
        state = _initial_filesystem_state()

        # Keys compared by equality against their expected defaults.
        expected_defaults = {
            "cwd": "/documents",
            "staged_dirs": [],
            "pending_moves": [],
            "doc_id_by_path": {},
            "dirty_paths": [],
            "kb_priority": [],
            "kb_matched_chunk_ids": {},
            "tree_version": 0,
        }
        for key, expected in expected_defaults.items():
            assert state[key] == expected

        # The anonymous document slot must be exactly ``None``.
        assert state["kb_anon_doc"] is None
|
||||
|
|
@ -36,11 +36,18 @@ def test_backend_resolver_returns_multi_root_backend_for_single_root(tmp_path: P
|
|||
def test_backend_resolver_uses_cloud_mode_by_default():
|
||||
resolver = build_backend_resolver(FilesystemSelection())
|
||||
backend = resolver(_RuntimeStub())
|
||||
# StateBackend class name check keeps this test decoupled
|
||||
# from internal deepagents runtime class identity.
|
||||
# When no search_space_id is provided we fall back to StateBackend so
|
||||
# sub-agents / tests without DB access still work.
|
||||
assert backend.__class__.__name__ == "StateBackend"
|
||||
|
||||
|
||||
def test_backend_resolver_uses_kb_postgres_in_cloud_with_search_space():
    """With a ``search_space_id``, the default selection resolves to ``KBPostgresBackend``."""
    selection = FilesystemSelection()
    resolve_backend = build_backend_resolver(selection, search_space_id=42)
    resolved = resolve_backend(_RuntimeStub())
    # Compared by class name, not by class identity.
    assert resolved.__class__.__name__ == "KBPostgresBackend"
    assert resolved.search_space_id == 42
|
||||
|
||||
|
||||
def test_backend_resolver_returns_multi_root_backend_for_multiple_roots(tmp_path: Path):
|
||||
root_one = tmp_path / "resume"
|
||||
root_two = tmp_path / "notes"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,204 @@
|
|||
"""Unit tests for the SurfSense filesystem middleware new behaviors.
|
||||
|
||||
Covers:
|
||||
* cloud cwd defaults to ``/documents`` and relative paths resolve under it
|
||||
* cloud writes outside ``/documents/`` are rejected unless basename starts
|
||||
with ``temp_``
|
||||
* cloud writes/edits to the anonymous document are rejected (read-only)
|
||||
* helper methods on the middleware (``_resolve_relative``,
|
||||
``_check_cloud_write_namespace``, ``_default_cwd``)
|
||||
|
||||
These tests use ``__new__`` to bypass the heavy ``__init__`` and exercise
|
||||
the helper methods directly so the test surface stays narrow and fast.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware.filesystem import (
|
||||
SurfSenseFilesystemMiddleware,
|
||||
_build_filesystem_system_prompt,
|
||||
_build_tool_descriptions,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _make_middleware(mode: FilesystemMode = FilesystemMode.CLOUD):
    """Create a ``SurfSenseFilesystemMiddleware`` without running ``__init__``.

    Only ``_filesystem_mode`` is set on the bare instance, so callers can
    exercise the helper methods directly.
    """
    instance = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
    instance._filesystem_mode = mode
    return instance
|
||||
|
||||
|
||||
def _runtime(state: dict | None = None) -> SimpleNamespace:
|
||||
return SimpleNamespace(state=state or {})
|
||||
|
||||
|
||||
class TestCloudCwdDefaults:
    """Checks on ``_default_cwd`` and ``_current_cwd`` for each filesystem mode."""

    def test_default_cwd_in_cloud_is_documents_root(self):
        """Cloud mode starts in ``/documents``."""
        middleware = _make_middleware()
        assert middleware._default_cwd() == "/documents"

    def test_default_cwd_in_desktop_is_root(self):
        """Desktop local-folder mode starts in ``/``."""
        middleware = _make_middleware(FilesystemMode.DESKTOP_LOCAL_FOLDER)
        assert middleware._default_cwd() == "/"

    def test_current_cwd_uses_state_when_set(self):
        """A ``cwd`` stored in state takes precedence over the default."""
        middleware = _make_middleware()
        stub = _runtime({"cwd": "/documents/notes"})
        assert middleware._current_cwd(stub) == "/documents/notes"

    def test_current_cwd_falls_back_to_default(self):
        """With no ``cwd`` in state, the mode default is used."""
        middleware = _make_middleware()
        stub = _runtime({})
        assert middleware._current_cwd(stub) == "/documents"

    def test_current_cwd_ignores_invalid(self):
        """A non-absolute ``cwd`` in state is discarded for the default."""
        middleware = _make_middleware()
        stub = _runtime({"cwd": "not-absolute"})
        assert middleware._current_cwd(stub) == "/documents"
|
||||
|
||||
|
||||
class TestRelativePathResolution:
    """Checks on ``_resolve_relative`` against the ``cwd`` held in state."""

    def test_relative_path_resolves_against_cwd(self):
        """A bare filename is joined onto the current directory."""
        middleware = _make_middleware()
        stub = _runtime({"cwd": "/documents/projects"})
        resolved = middleware._resolve_relative("notes.md", stub)
        assert resolved == "/documents/projects/notes.md"

    def test_relative_path_with_dotdot(self):
        """``..`` steps up one directory before joining."""
        middleware = _make_middleware()
        stub = _runtime({"cwd": "/documents/a/b"})
        resolved = middleware._resolve_relative("../c.md", stub)
        assert resolved == "/documents/a/c.md"

    def test_absolute_path_is_kept(self):
        """An absolute path is returned unchanged, whatever the cwd."""
        middleware = _make_middleware()
        stub = _runtime({"cwd": "/documents"})
        resolved = middleware._resolve_relative("/other/x.md", stub)
        assert resolved == "/other/x.md"

    def test_empty_path_returns_cwd(self):
        """An empty path resolves to the current directory itself."""
        middleware = _make_middleware()
        stub = _runtime({"cwd": "/documents/projects"})
        resolved = middleware._resolve_relative("", stub)
        assert resolved == "/documents/projects"
|
||||
|
||||
|
||||
class TestCloudWriteNamespacePolicy:
    """Checks on ``_check_cloud_write_namespace``.

    The helper returns ``None`` when a write is allowed and an error string
    when it is rejected.
    """

    def test_documents_path_allowed(self):
        """A file under ``/documents`` may be written."""
        middleware = _make_middleware()
        stub = _runtime()
        verdict = middleware._check_cloud_write_namespace("/documents/foo.md", stub)
        assert verdict is None

    def test_documents_root_allowed(self):
        """The ``/documents`` root itself is an allowed target."""
        middleware = _make_middleware()
        stub = _runtime()
        verdict = middleware._check_cloud_write_namespace("/documents", stub)
        assert verdict is None

    def test_temp_basename_anywhere_allowed(self):
        """A ``temp_`` basename is allowed in any directory."""
        middleware = _make_middleware()
        stub = _runtime()
        scratch_paths = (
            "/temp_scratch.md",
            "/foo/temp_x.md",
            "/documents/temp_x.md",
        )
        for scratch_path in scratch_paths:
            assert middleware._check_cloud_write_namespace(scratch_path, stub) is None

    def test_other_paths_rejected(self):
        """A non-``temp_`` file outside ``/documents`` gets an error message."""
        middleware = _make_middleware()
        stub = _runtime()
        verdict = middleware._check_cloud_write_namespace("/foo/bar.md", stub)
        assert verdict is not None
        assert "must target /documents" in verdict

    def test_anon_doc_path_is_read_only(self):
        """The path of the anonymous document in state cannot be written."""
        middleware = _make_middleware()
        anon_doc = {
            "path": "/documents/uploaded.xml",
            "title": "uploaded",
            "content": "",
            "chunks": [],
        }
        stub = _runtime({"kb_anon_doc": anon_doc})
        verdict = middleware._check_cloud_write_namespace(
            "/documents/uploaded.xml", stub
        )
        assert verdict is not None
        assert "read-only" in verdict

    def test_desktop_mode_skips_namespace_policy(self):
        """In desktop local-folder mode any path is accepted."""
        middleware = _make_middleware(FilesystemMode.DESKTOP_LOCAL_FOLDER)
        stub = _runtime()
        verdict = middleware._check_cloud_write_namespace("/random/path.md", stub)
        assert verdict is None
|
||||
|
||||
|
||||
class TestModeSpecificPrompts:
    """The system prompt and tool descriptions must cover only the active mode.

    Text about the other mode wastes tokens and confuses the model with
    rules it cannot use in the current session.
    """

    def test_cloud_prompt_omits_desktop_section(self):
        """The cloud prompt has the cloud rules and none of the desktop wording."""
        prompt = _build_filesystem_system_prompt(
            FilesystemMode.CLOUD, sandbox_available=False
        )
        for desktop_phrase in ("Local Folder Mode", "mount-prefixed"):
            assert desktop_phrase not in prompt
        for cloud_phrase in ("Persistence Rules", "/documents", "temp_"):
            assert cloud_phrase in prompt

    def test_desktop_prompt_omits_cloud_persistence_rules(self):
        """The desktop prompt has the desktop wording and none of the cloud sections."""
        prompt = _build_filesystem_system_prompt(
            FilesystemMode.DESKTOP_LOCAL_FOLDER, sandbox_available=False
        )
        cloud_phrases = ("Persistence Rules", "Workspace Tree", "<priority_documents>")
        for cloud_phrase in cloud_phrases:
            assert cloud_phrase not in prompt
        for desktop_phrase in ("Local Folder Mode", "mount-prefixed"):
            assert desktop_phrase in prompt

    def test_cloud_tool_descs_omit_desktop_phrases(self):
        """Cloud tool descriptions carry no desktop hints and no mode qualifier."""
        descs = _build_tool_descriptions(FilesystemMode.CLOUD)
        tool_names = (
            "write_file",
            "edit_file",
            "move_file",
            "mkdir",
            "list_tree",
            "grep",
        )
        for name in tool_names:
            text = descs[name]
            assert "Desktop" not in text, f"{name} leaks desktop hints"
            assert "Cloud mode:" not in text, f"{name} qualifies a cloud-only desc"

    def test_desktop_tool_descs_omit_cloud_phrases(self):
        """Desktop tool descriptions mention no cloud namespace or ``temp_`` rule."""
        descs = _build_tool_descriptions(FilesystemMode.DESKTOP_LOCAL_FOLDER)
        tool_names = (
            "write_file",
            "edit_file",
            "move_file",
            "mkdir",
            "list_tree",
            "grep",
        )
        for name in tool_names:
            text = descs[name]
            assert "Cloud" not in text, f"{name} leaks cloud hints"
            assert "/documents/" not in text, f"{name} mentions cloud namespace"
            assert "temp_" not in text, f"{name} mentions cloud temp_ semantics"

    def test_sandbox_addendum_appended_when_available(self):
        """With a sandbox available the prompt includes the code-execution text."""
        prompt = _build_filesystem_system_prompt(
            FilesystemMode.CLOUD, sandbox_available=True
        )
        for sandbox_phrase in ("execute_code", "Code Execution"):
            assert sandbox_phrase in prompt

    def test_sandbox_addendum_absent_when_unavailable(self):
        """Without a sandbox the prompt never mentions ``execute_code``."""
        prompt = _build_filesystem_system_prompt(
            FilesystemMode.CLOUD, sandbox_available=False
        )
        assert "execute_code" not in prompt
|
||||
|
|
@ -11,25 +11,6 @@ from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
|
|||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _BackendWithRawRead:
|
||||
def __init__(self, content: str) -> None:
|
||||
self._content = content
|
||||
|
||||
def read(self, file_path: str, offset: int = 0, limit: int = 200000) -> str:
|
||||
del file_path, offset, limit
|
||||
return " 1\tline1\n 2\tline2"
|
||||
|
||||
async def aread(self, file_path: str, offset: int = 0, limit: int = 200000) -> str:
|
||||
return self.read(file_path, offset, limit)
|
||||
|
||||
def read_raw(self, file_path: str) -> str:
|
||||
del file_path
|
||||
return self._content
|
||||
|
||||
async def aread_raw(self, file_path: str) -> str:
|
||||
return self.read_raw(file_path)
|
||||
|
||||
|
||||
class _RuntimeNoSuggestedPath:
|
||||
state = {"file_operation_contract": {}}
|
||||
|
||||
|
|
@ -39,40 +20,19 @@ class _RuntimeWithSuggestedPath:
|
|||
self.state = {"file_operation_contract": {"suggested_path": suggested_path}}
|
||||
|
||||
|
||||
def test_verify_written_content_prefers_raw_sync() -> None:
|
||||
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||
expected = "line1\nline2"
|
||||
backend = _BackendWithRawRead(expected)
|
||||
|
||||
verify_error = middleware._verify_written_content_sync(
|
||||
backend=backend,
|
||||
path="/note.md",
|
||||
expected_content=expected,
|
||||
)
|
||||
|
||||
assert verify_error is None
|
||||
|
||||
|
||||
def test_contract_suggested_path_falls_back_to_notes_md() -> None:
|
||||
def test_contract_suggested_path_falls_back_to_documents_notes_md() -> None:
|
||||
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||
middleware._filesystem_mode = FilesystemMode.CLOUD
|
||||
suggested = middleware._get_contract_suggested_path(_RuntimeNoSuggestedPath()) # type: ignore[arg-type]
|
||||
assert suggested == "/notes.md"
|
||||
# Cloud default cwd is /documents so the fallback lands in the KB.
|
||||
assert suggested == "/documents/notes.md"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_written_content_prefers_raw_async() -> None:
|
||||
def test_contract_suggested_path_falls_back_to_root_notes_md_in_desktop() -> None:
|
||||
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||
expected = "line1\nline2"
|
||||
backend = _BackendWithRawRead(expected)
|
||||
|
||||
verify_error = await middleware._verify_written_content_async(
|
||||
backend=backend,
|
||||
path="/note.md",
|
||||
expected_content=expected,
|
||||
)
|
||||
|
||||
assert verify_error is None
|
||||
middleware._filesystem_mode = FilesystemMode.DESKTOP_LOCAL_FOLDER
|
||||
suggested = middleware._get_contract_suggested_path(_RuntimeNoSuggestedPath()) # type: ignore[arg-type]
|
||||
assert suggested == "/notes.md"
|
||||
|
||||
|
||||
def test_normalize_local_mount_path_prefixes_default_mount(tmp_path: Path) -> None:
|
||||
|
|
|
|||
|
|
@ -5,10 +5,10 @@ import json
|
|||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from app.agents.new_chat.document_xml import build_document_xml as _build_document_xml
|
||||
from app.agents.new_chat.middleware.knowledge_search import (
|
||||
KBSearchPlan,
|
||||
KnowledgeBaseSearchMiddleware,
|
||||
_build_document_xml,
|
||||
_normalize_optional_date_range,
|
||||
_parse_kb_search_plan_response,
|
||||
_render_recent_conversation,
|
||||
|
|
@ -248,17 +248,10 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
|||
captured.update(kwargs)
|
||||
return []
|
||||
|
||||
async def fake_build_scoped_filesystem(**kwargs):
|
||||
return {}, {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||
fake_search_knowledge_base,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
|
||||
fake_build_scoped_filesystem,
|
||||
)
|
||||
|
||||
llm = FakeLLM(
|
||||
json.dumps(
|
||||
|
|
@ -298,17 +291,10 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
|||
captured.update(kwargs)
|
||||
return []
|
||||
|
||||
async def fake_build_scoped_filesystem(**kwargs):
|
||||
return {}, {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||
fake_search_knowledge_base,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
|
||||
fake_build_scoped_filesystem,
|
||||
)
|
||||
|
||||
middleware = KnowledgeBaseSearchMiddleware(
|
||||
llm=FakeLLM("not json"),
|
||||
|
|
@ -334,17 +320,10 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
|||
captured.update(kwargs)
|
||||
return []
|
||||
|
||||
async def fake_build_scoped_filesystem(**kwargs):
|
||||
return {}, {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||
fake_search_knowledge_base,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
|
||||
fake_build_scoped_filesystem,
|
||||
)
|
||||
|
||||
middleware = KnowledgeBaseSearchMiddleware(
|
||||
llm=FakeLLM(
|
||||
|
|
@ -386,9 +365,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
|||
search_called = True
|
||||
return []
|
||||
|
||||
async def fake_build_scoped_filesystem(**kwargs):
|
||||
return {}, {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.browse_recent_documents",
|
||||
fake_browse_recent_documents,
|
||||
|
|
@ -397,10 +373,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
|||
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||
fake_search_knowledge_base,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
|
||||
fake_build_scoped_filesystem,
|
||||
)
|
||||
|
||||
llm = FakeLLM(
|
||||
json.dumps(
|
||||
|
|
@ -440,9 +412,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
|||
search_captured.update(kwargs)
|
||||
return []
|
||||
|
||||
async def fake_build_scoped_filesystem(**kwargs):
|
||||
return {}, {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.browse_recent_documents",
|
||||
fake_browse_recent_documents,
|
||||
|
|
@ -451,10 +420,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
|||
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||
fake_search_knowledge_base,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
|
||||
fake_build_scoped_filesystem,
|
||||
)
|
||||
|
||||
llm = FakeLLM(
|
||||
json.dumps(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue