mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-04-27 09:46:25 +02:00)
feat: made agent file system optimized
parent ee0b59c0fa
commit 2cc2d339e6
67 changed files with 8011 additions and 5591 deletions
17
surfsense_backend/app/agents/new_chat/middleware/__init__.py
Normal file
@@ -0,0 +1,17 @@
"""Middleware components for the SurfSense new chat agent."""
|
||||
|
||||
from app.agents.new_chat.middleware.dedup_tool_calls import (
|
||||
DedupHITLToolCallsMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.middleware.filesystem import (
|
||||
SurfSenseFilesystemMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.middleware.knowledge_search import (
|
||||
KnowledgeBaseSearchMiddleware,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DedupHITLToolCallsMiddleware",
|
||||
"KnowledgeBaseSearchMiddleware",
|
||||
"SurfSenseFilesystemMiddleware",
|
||||
]
|
||||
694
surfsense_backend/app/agents/new_chat/middleware/filesystem.py
Normal file
@@ -0,0 +1,694 @@
"""Custom filesystem middleware for the SurfSense agent.
|
||||
|
||||
This middleware customizes prompts and persists write/edit operations for
|
||||
`/documents/*` files into SurfSense's `Document`/`Chunk` tables.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
from deepagents import FilesystemMiddleware
|
||||
from deepagents.backends.protocol import EditResult, WriteResult
|
||||
from deepagents.backends.utils import validate_path
|
||||
from deepagents.middleware.filesystem import FilesystemState
|
||||
from fractional_indexing import generate_key_between
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.callbacks import dispatch_custom_event
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import BaseTool, StructuredTool
|
||||
from langgraph.types import Command
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from app.db import Chunk, Document, DocumentType, Folder, shielded_async_session
|
||||
from app.indexing_pipeline.document_chunker import chunk_text
|
||||
from app.utils.document_converters import (
|
||||
embed_texts,
|
||||
generate_content_hash,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
# System Prompt (injected into every model call by wrap_model_call)
|
||||
# =============================================================================
|
||||
|
||||
SURFSENSE_FILESYSTEM_SYSTEM_PROMPT = """## Following Conventions
|
||||
|
||||
- Read files before editing — understand existing content before making changes.
|
||||
- Mimic existing style, naming conventions, and patterns.
|
||||
|
||||
## Filesystem Tools `ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep`, `save_document`
|
||||
|
||||
All file paths must start with a `/`.
|
||||
- ls: list files and directories at a given path.
|
||||
- read_file: read a file from the filesystem.
|
||||
- write_file: create a temporary file in the session (not persisted).
|
||||
- edit_file: edit a file in the session (not persisted for /documents/ files).
|
||||
- glob: find files matching a pattern (e.g., "**/*.xml").
|
||||
- grep: search for text within files.
|
||||
- save_document: **permanently** save a new document to the user's knowledge
|
||||
base. Use only when the user explicitly asks to save/create a document.
|
||||
|
||||
## Reading Documents Efficiently
|
||||
|
||||
Documents are formatted as XML. Each document contains:
|
||||
- `<document_metadata>` — title, type, URL, etc.
|
||||
- `<chunk_index>` — a table of every chunk with its **line range** and a
|
||||
`matched="true"` flag for chunks that matched the search query.
|
||||
- `<document_content>` — the actual chunks in original document order.
|
||||
|
||||
**Workflow**: when reading a large document, read the first ~20 lines to see
|
||||
the `<chunk_index>`, identify chunks marked `matched="true"`, then use
|
||||
`read_file(path, offset=<start_line>, limit=<lines>)` to jump directly to
|
||||
those sections instead of reading the entire file sequentially.
|
||||
|
||||
Use `<chunk id='...'>` values as citation IDs in your answers.
|
||||
"""
|
||||
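
# Illustrative sketch of the XML document layout the prompt above refers to
# (placeholder values; the concrete structure is produced by the
# knowledge-search middleware's _build_document_xml helper):
#
#   <document>
#   <document_metadata>
#    <document_id>42</document_id>
#    <document_type>NOTE</document_type>
#    <title><![CDATA[Quarterly Plan]]></title>
#    <url><![CDATA[]]></url>
#    <metadata_json><![CDATA[{}]]></metadata_json>
#   </document_metadata>
#
#   <chunk_index>
#    <entry chunk_id="7" lines="16-16" matched="true"/>
#    <entry chunk_id="8" lines="17-17"/>
#   </chunk_index>
#
#   <document_content>
#    <chunk id='7'><![CDATA[Revenue grew 12% quarter over quarter.]]></chunk>
#    <chunk id='8'><![CDATA[Headcount stayed flat.]]></chunk>
#   </document_content>
#   </document>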

# =============================================================================
# Per-Tool Descriptions (shown to the LLM as the tool's docstring)
# =============================================================================

SURFSENSE_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path.
"""

SURFSENSE_READ_FILE_TOOL_DESCRIPTION = """Reads a file from the filesystem.

Usage:
- By default, reads up to 100 lines from the beginning.
- Use `offset` and `limit` for pagination when files are large.
- Results include line numbers.
- Documents contain a `<chunk_index>` near the top listing every chunk with
  its line range and a `matched="true"` flag for search-relevant chunks.
  Read the index first, then jump to matched chunks with
  `read_file(path, offset=<start_line>, limit=<num_lines>)`.
- Use chunk IDs (`<chunk id='...'>`) as citations in answers.
"""

SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new file to the in-memory filesystem (session-only).

Use this to create scratch/working files during the conversation. Files created
here are ephemeral and will not be saved to the user's knowledge base.

To permanently save a document to the user's knowledge base, use the
`save_document` tool instead.
"""

SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files.

IMPORTANT:
- Read the file before editing.
- Preserve exact indentation and formatting.
- Edits to documents under `/documents/` are session-only (not persisted to the
  database) because those files use an XML citation wrapper around the original
  content.
"""

SURFSENSE_GLOB_TOOL_DESCRIPTION = """Find files matching a glob pattern.

Supports standard glob patterns: `*`, `**`, `?`.
Returns absolute file paths.
"""

SURFSENSE_GREP_TOOL_DESCRIPTION = """Search for a literal text pattern across files.

Use this to locate relevant document files/chunks before reading full files.
"""

SURFSENSE_SAVE_DOCUMENT_TOOL_DESCRIPTION = """Permanently saves a document to the user's knowledge base.

This is an expensive operation — it creates a new Document record in the
database, chunks the content, and generates embeddings for search.

Use ONLY when the user explicitly asks to save/create/store a document.
Do NOT use this for scratch work; use `write_file` for temporary files.

Args:
    title: The document title (e.g., "Meeting Notes 2025-06-01").
    content: The plain-text or markdown content to save. Do NOT include XML
        citation wrappers — pass only the actual document text.
    folder_path: Optional folder path under /documents/ (e.g., "Work/Notes").
        Folders are created automatically if they don't exist.
"""


class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
    """SurfSense-specific filesystem middleware with DB persistence for docs."""

    def __init__(
        self,
        *,
        search_space_id: int | None = None,
        created_by_id: str | None = None,
        tool_token_limit_before_evict: int | None = 20000,
    ) -> None:
        self._search_space_id = search_space_id
        self._created_by_id = created_by_id
        super().__init__(
            system_prompt=SURFSENSE_FILESYSTEM_SYSTEM_PROMPT,
            custom_tool_descriptions={
                "ls": SURFSENSE_LIST_FILES_TOOL_DESCRIPTION,
                "read_file": SURFSENSE_READ_FILE_TOOL_DESCRIPTION,
                "write_file": SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION,
                "edit_file": SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION,
                "glob": SURFSENSE_GLOB_TOOL_DESCRIPTION,
                "grep": SURFSENSE_GREP_TOOL_DESCRIPTION,
            },
            tool_token_limit_before_evict=tool_token_limit_before_evict,
        )
        # Remove the execute tool (no sandbox backend)
        self.tools = [t for t in self.tools if t.name != "execute"]
        self.tools.append(self._create_save_document_tool())

    @staticmethod
    def _run_async_blocking(coro: Any) -> Any:
        """Run async coroutine from sync code path when no event loop is running."""
        try:
            loop = asyncio.get_running_loop()
            if loop.is_running():
                return "Error: sync filesystem persistence not supported inside an active event loop."
        except RuntimeError:
            pass
        return asyncio.run(coro)

    @staticmethod
    def _parse_virtual_path(file_path: str) -> tuple[list[str], str]:
        """Parse /documents/... path into folder parts and a document title."""
        if not file_path.startswith("/documents/"):
            return [], ""
        rel = file_path[len("/documents/") :].strip("/")
        if not rel:
            return [], ""
        parts = [part for part in rel.split("/") if part]
        file_name = parts[-1]
        title = file_name[:-4] if file_name.lower().endswith(".xml") else file_name
        return parts[:-1], title
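
    # For example (illustrative only):
    #   _parse_virtual_path("/documents/Work/Notes/Plan.xml") -> (["Work", "Notes"], "Plan")
    #   _parse_virtual_path("/tmp/scratch.txt") -> ([], "")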

    async def _ensure_folder_hierarchy(
        self,
        *,
        folder_parts: list[str],
        search_space_id: int,
    ) -> int | None:
        """Ensure folder hierarchy exists and return leaf folder ID."""
        if not folder_parts:
            return None
        async with shielded_async_session() as session:
            parent_id: int | None = None
            for name in folder_parts:
                result = await session.execute(
                    select(Folder).where(
                        Folder.search_space_id == search_space_id,
                        Folder.parent_id == parent_id
                        if parent_id is not None
                        else Folder.parent_id.is_(None),
                        Folder.name == name,
                    )
                )
                folder = result.scalar_one_or_none()
                if folder is None:
                    sibling_result = await session.execute(
                        select(Folder.position)
                        .where(
                            Folder.search_space_id == search_space_id,
                            Folder.parent_id == parent_id
                            if parent_id is not None
                            else Folder.parent_id.is_(None),
                        )
                        .order_by(Folder.position.desc())
                        .limit(1)
                    )
                    last_position = sibling_result.scalar_one_or_none()
                    folder = Folder(
                        name=name,
                        position=generate_key_between(last_position, None),
                        parent_id=parent_id,
                        search_space_id=search_space_id,
                        created_by_id=self._created_by_id,
                        updated_at=datetime.now(UTC),
                    )
                    session.add(folder)
                    await session.flush()
                parent_id = folder.id
            await session.commit()
        return parent_id

    async def _persist_new_document(
        self, *, file_path: str, content: str
    ) -> dict[str, Any] | str:
        """Persist a new NOTE document from a newly written file.

        Returns a dict with document metadata on success, or an error string.
        """
        if self._search_space_id is None:
            return {}
        folder_parts, title = self._parse_virtual_path(file_path)
        if not title:
            return "Error: write_file for document persistence requires path under /documents/<name>.xml"
        folder_id = await self._ensure_folder_hierarchy(
            folder_parts=folder_parts,
            search_space_id=self._search_space_id,
        )
        async with shielded_async_session() as session:
            content_hash = generate_content_hash(content, self._search_space_id)
            existing = await session.execute(
                select(Document.id).where(Document.content_hash == content_hash)
            )
            if existing.scalar_one_or_none() is not None:
                return "Error: A document with identical content already exists."
            unique_identifier_hash = generate_unique_identifier_hash(
                DocumentType.NOTE,
                file_path,
                self._search_space_id,
            )
            doc = Document(
                title=title,
                document_type=DocumentType.NOTE,
                document_metadata={"virtual_path": file_path},
                content=content,
                content_hash=content_hash,
                unique_identifier_hash=unique_identifier_hash,
                source_markdown=content,
                search_space_id=self._search_space_id,
                folder_id=folder_id,
                created_by_id=self._created_by_id,
                updated_at=datetime.now(UTC),
            )
            session.add(doc)
            await session.flush()

            summary_embedding = embed_texts([content])[0]
            doc.embedding = summary_embedding
            chunk_texts = chunk_text(content)
            if chunk_texts:
                chunk_embeddings = embed_texts(chunk_texts)
                chunks = [
                    Chunk(document_id=doc.id, content=text, embedding=embedding)
                    for text, embedding in zip(
                        chunk_texts, chunk_embeddings, strict=True
                    )
                ]
                session.add_all(chunks)
            await session.commit()

            return {
                "id": doc.id,
                "title": title,
                "documentType": DocumentType.NOTE.value,
                "searchSpaceId": self._search_space_id,
                "folderId": folder_id,
                "createdById": str(self._created_by_id)
                if self._created_by_id
                else None,
            }

    async def _persist_edited_document(
        self, *, file_path: str, updated_content: str
    ) -> str | None:
        """Persist edits for an existing NOTE document and recreate chunks."""
        if self._search_space_id is None:
            return None
        unique_identifier_hash = generate_unique_identifier_hash(
            DocumentType.NOTE,
            file_path,
            self._search_space_id,
        )
        doc_id_from_xml: int | None = None
        match = re.search(r"<document_id>\s*(\d+)\s*</document_id>", updated_content)
        if match:
            doc_id_from_xml = int(match.group(1))
        async with shielded_async_session() as session:
            doc_result = await session.execute(
                select(Document).where(
                    Document.search_space_id == self._search_space_id,
                    Document.unique_identifier_hash == unique_identifier_hash,
                )
            )
            document = doc_result.scalar_one_or_none()
            if document is None and doc_id_from_xml is not None:
                by_id_result = await session.execute(
                    select(Document).where(
                        Document.search_space_id == self._search_space_id,
                        Document.id == doc_id_from_xml,
                    )
                )
                document = by_id_result.scalar_one_or_none()
            if document is None:
                return "Error: Could not map edited file to an existing document."

            document.content = updated_content
            document.source_markdown = updated_content
            document.content_hash = generate_content_hash(
                updated_content, self._search_space_id
            )
            document.updated_at = datetime.now(UTC)
            if not document.document_metadata:
                document.document_metadata = {}
            document.document_metadata["virtual_path"] = file_path

            summary_embedding = embed_texts([updated_content])[0]
            document.embedding = summary_embedding

            await session.execute(delete(Chunk).where(Chunk.document_id == document.id))
            chunk_texts = chunk_text(updated_content)
            if chunk_texts:
                chunk_embeddings = embed_texts(chunk_texts)
                session.add_all(
                    [
                        Chunk(
                            document_id=document.id, content=text, embedding=embedding
                        )
                        for text, embedding in zip(
                            chunk_texts, chunk_embeddings, strict=True
                        )
                    ]
                )
            await session.commit()
        return None

    def _create_save_document_tool(self) -> BaseTool:
        """Create save_document tool that persists a new document to the KB."""

        def sync_save_document(
            title: Annotated[str, "Title for the new document."],
            content: Annotated[
                str,
                "Plain-text or markdown content to save. Do NOT include XML wrappers.",
            ],
            runtime: ToolRuntime[None, FilesystemState],
            folder_path: Annotated[
                str,
                "Optional folder path under /documents/ (e.g. 'Work/Notes'). Created automatically.",
            ] = "",
        ) -> Command | str:
            if not content.strip():
                return "Error: content cannot be empty."
            file_name = re.sub(r'[\\/:*?"<>|]+', "_", title).strip() or "untitled"
            if not file_name.lower().endswith(".xml"):
                file_name = f"{file_name}.xml"
            folder = folder_path.strip().strip("/") if folder_path else ""
            virtual_path = (
                f"/documents/{folder}/{file_name}"
                if folder
                else f"/documents/{file_name}"
            )

            persist_result = self._run_async_blocking(
                self._persist_new_document(file_path=virtual_path, content=content)
            )
            if isinstance(persist_result, str):
                return persist_result
            if isinstance(persist_result, dict) and persist_result.get("id"):
                dispatch_custom_event("document_created", persist_result)
            return f"Document '{title}' saved to knowledge base (path: {virtual_path})."

        async def async_save_document(
            title: Annotated[str, "Title for the new document."],
            content: Annotated[
                str,
                "Plain-text or markdown content to save. Do NOT include XML wrappers.",
            ],
            runtime: ToolRuntime[None, FilesystemState],
            folder_path: Annotated[
                str,
                "Optional folder path under /documents/ (e.g. 'Work/Notes'). Created automatically.",
            ] = "",
        ) -> Command | str:
            if not content.strip():
                return "Error: content cannot be empty."
            file_name = re.sub(r'[\\/:*?"<>|]+', "_", title).strip() or "untitled"
            if not file_name.lower().endswith(".xml"):
                file_name = f"{file_name}.xml"
            folder = folder_path.strip().strip("/") if folder_path else ""
            virtual_path = (
                f"/documents/{folder}/{file_name}"
                if folder
                else f"/documents/{file_name}"
            )

            persist_result = await self._persist_new_document(
                file_path=virtual_path, content=content
            )
            if isinstance(persist_result, str):
                return persist_result
            if isinstance(persist_result, dict) and persist_result.get("id"):
                dispatch_custom_event("document_created", persist_result)
            return f"Document '{title}' saved to knowledge base (path: {virtual_path})."

        return StructuredTool.from_function(
            name="save_document",
            description=SURFSENSE_SAVE_DOCUMENT_TOOL_DESCRIPTION,
            func=sync_save_document,
            coroutine=async_save_document,
        )

    def _create_write_file_tool(self) -> BaseTool:
        """Create write_file — ephemeral for /documents/*, persisted otherwise."""
        tool_description = (
            self._custom_tool_descriptions.get("write_file")
            or SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION
        )

        def sync_write_file(
            file_path: Annotated[
                str,
                "Absolute path where the file should be created. Must be absolute, not relative.",
            ],
            content: Annotated[
                str,
                "The text content to write to the file. This parameter is required.",
            ],
            runtime: ToolRuntime[None, FilesystemState],
        ) -> Command | str:
            resolved_backend = self._get_backend(runtime)
            try:
                validated_path = validate_path(file_path)
            except ValueError as exc:
                return f"Error: {exc}"
            res: WriteResult = resolved_backend.write(validated_path, content)
            if res.error:
                return res.error

            if not self._is_kb_document(validated_path):
                persist_result = self._run_async_blocking(
                    self._persist_new_document(
                        file_path=validated_path, content=content
                    )
                )
                if isinstance(persist_result, str):
                    return persist_result
                if isinstance(persist_result, dict) and persist_result.get("id"):
                    dispatch_custom_event("document_created", persist_result)

            if res.files_update is not None:
                return Command(
                    update={
                        "files": res.files_update,
                        "messages": [
                            ToolMessage(
                                content=f"Updated file {res.path}",
                                tool_call_id=runtime.tool_call_id,
                            )
                        ],
                    }
                )
            return f"Updated file {res.path}"

        async def async_write_file(
            file_path: Annotated[
                str,
                "Absolute path where the file should be created. Must be absolute, not relative.",
            ],
            content: Annotated[
                str,
                "The text content to write to the file. This parameter is required.",
            ],
            runtime: ToolRuntime[None, FilesystemState],
        ) -> Command | str:
            resolved_backend = self._get_backend(runtime)
            try:
                validated_path = validate_path(file_path)
            except ValueError as exc:
                return f"Error: {exc}"
            res: WriteResult = await resolved_backend.awrite(validated_path, content)
            if res.error:
                return res.error

            if not self._is_kb_document(validated_path):
                persist_result = await self._persist_new_document(
                    file_path=validated_path,
                    content=content,
                )
                if isinstance(persist_result, str):
                    return persist_result
                if isinstance(persist_result, dict) and persist_result.get("id"):
                    dispatch_custom_event("document_created", persist_result)

            if res.files_update is not None:
                return Command(
                    update={
                        "files": res.files_update,
                        "messages": [
                            ToolMessage(
                                content=f"Updated file {res.path}",
                                tool_call_id=runtime.tool_call_id,
                            )
                        ],
                    }
                )
            return f"Updated file {res.path}"

        return StructuredTool.from_function(
            name="write_file",
            description=tool_description,
            func=sync_write_file,
            coroutine=async_write_file,
        )

    @staticmethod
    def _is_kb_document(path: str) -> bool:
        """Return True for paths under /documents/ (KB-sourced, XML-wrapped)."""
        return path.startswith("/documents/")

    def _create_edit_file_tool(self) -> BaseTool:
        """Create edit_file with DB persistence (skipped for KB documents)."""
        tool_description = (
            self._custom_tool_descriptions.get("edit_file")
            or SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION
        )

        def sync_edit_file(
            file_path: Annotated[
                str,
                "Absolute path to the file to edit. Must be absolute, not relative.",
            ],
            old_string: Annotated[
                str,
                "The exact text to find and replace. Must be unique in the file unless replace_all is True.",
            ],
            new_string: Annotated[
                str,
                "The text to replace old_string with. Must be different from old_string.",
            ],
            runtime: ToolRuntime[None, FilesystemState],
            *,
            replace_all: Annotated[
                bool,
                "If True, replace all occurrences of old_string. If False (default), old_string must be unique.",
            ] = False,
        ) -> Command | str:
            resolved_backend = self._get_backend(runtime)
            try:
                validated_path = validate_path(file_path)
            except ValueError as exc:
                return f"Error: {exc}"
            res: EditResult = resolved_backend.edit(
                validated_path,
                old_string,
                new_string,
                replace_all=replace_all,
            )
            if res.error:
                return res.error

            if not self._is_kb_document(validated_path):
                read_result = resolved_backend.read(
                    validated_path, offset=0, limit=200000
                )
                if read_result.error or read_result.file_data is None:
                    return f"Error: could not reload edited file '{validated_path}' for persistence."
                updated_content = read_result.file_data["content"]
                persist_result = self._run_async_blocking(
                    self._persist_edited_document(
                        file_path=validated_path,
                        updated_content=updated_content,
                    )
                )
                if isinstance(persist_result, str):
                    return persist_result

            if res.files_update is not None:
                return Command(
                    update={
                        "files": res.files_update,
                        "messages": [
                            ToolMessage(
                                content=f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'",
                                tool_call_id=runtime.tool_call_id,
                            )
                        ],
                    }
                )
            return f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'"

        async def async_edit_file(
            file_path: Annotated[
                str,
                "Absolute path to the file to edit. Must be absolute, not relative.",
            ],
            old_string: Annotated[
                str,
                "The exact text to find and replace. Must be unique in the file unless replace_all is True.",
            ],
            new_string: Annotated[
                str,
                "The text to replace old_string with. Must be different from old_string.",
            ],
            runtime: ToolRuntime[None, FilesystemState],
            *,
            replace_all: Annotated[
                bool,
                "If True, replace all occurrences of old_string. If False (default), old_string must be unique.",
            ] = False,
        ) -> Command | str:
            resolved_backend = self._get_backend(runtime)
            try:
                validated_path = validate_path(file_path)
            except ValueError as exc:
                return f"Error: {exc}"
            res: EditResult = await resolved_backend.aedit(
                validated_path,
                old_string,
                new_string,
                replace_all=replace_all,
            )
            if res.error:
                return res.error

            if not self._is_kb_document(validated_path):
                read_result = await resolved_backend.aread(
                    validated_path, offset=0, limit=200000
                )
                if read_result.error or read_result.file_data is None:
                    return f"Error: could not reload edited file '{validated_path}' for persistence."
                updated_content = read_result.file_data["content"]
                persist_error = await self._persist_edited_document(
                    file_path=validated_path,
                    updated_content=updated_content,
                )
                if persist_error:
                    return persist_error

            if res.files_update is not None:
                return Command(
                    update={
                        "files": res.files_update,
                        "messages": [
                            ToolMessage(
                                content=f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'",
                                tool_call_id=runtime.tool_call_id,
                            )
                        ],
                    }
                )
            return f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'"

        return StructuredTool.from_function(
            name="edit_file",
            description=tool_description,
            func=sync_edit_file,
            coroutine=async_edit_file,
        )
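
# Illustrative wiring sketch (placeholder IDs; how the agent consumes the
# middleware list is outside this diff). Only the constructors shown in this
# commit are used here:
#
#   middleware = [
#       KnowledgeBaseSearchMiddleware(search_space_id=1, top_k=10),
#       SurfSenseFilesystemMiddleware(search_space_id=1, created_by_id="user-uuid"),
#   ]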
414
surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py
Normal file
@@ -0,0 +1,414 @@
"""Knowledge-base pre-search middleware for the SurfSense new chat agent.
|
||||
|
||||
This middleware runs before the main agent loop and seeds a virtual filesystem
|
||||
(`files` state) with relevant documents retrieved via hybrid search. On each
|
||||
turn the filesystem is *expanded* — new results merge with documents loaded
|
||||
during prior turns — and a synthetic ``ls`` result is injected into the message
|
||||
history so the LLM is immediately aware of the current filesystem structure.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import NATIVE_TO_LEGACY_DOCTYPE, Document, Folder, shielded_async_session
|
||||
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
from app.utils.document_converters import embed_texts
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
def _extract_text_from_message(message: BaseMessage) -> str:
|
||||
"""Extract plain text from a message content."""
|
||||
content = getattr(message, "content", "")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict) and item.get("type") == "text":
|
||||
parts.append(str(item.get("text", "")))
|
||||
return "\n".join(p for p in parts if p)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
|
||||
"""Convert arbitrary text into a filesystem-safe filename."""
|
||||
name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
|
||||
name = re.sub(r"\s+", " ", name)
|
||||
if not name:
|
||||
name = fallback
|
||||
if len(name) > 180:
|
||||
name = name[:180].rstrip()
|
||||
if not name.lower().endswith(".xml"):
|
||||
name = f"{name}.xml"
|
||||
return name
|
||||
|
||||
|
||||
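
# For example (illustrative): _safe_filename("Q3 Report: Draft?") returns
# "Q3 Report_ Draft_.xml"; reserved characters become "_" and the ".xml"
# suffix is appended when missing.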


def _build_document_xml(
    document: dict[str, Any],
    matched_chunk_ids: set[int] | None = None,
) -> str:
    """Build citation-friendly XML with a ``<chunk_index>`` for smart seeking.

    The ``<chunk_index>`` at the top of each document lists every chunk with its
    line range inside ``<document_content>`` and flags chunks that directly
    matched the search query (``matched="true"``). This lets the LLM jump
    straight to the most relevant section via ``read_file(offset=…, limit=…)``
    instead of reading sequentially from the start.
    """
    matched = matched_chunk_ids or set()

    doc_meta = document.get("document") or {}
    metadata = (doc_meta.get("metadata") or {}) if isinstance(doc_meta, dict) else {}
    document_id = doc_meta.get("id", document.get("document_id", "unknown"))
    document_type = doc_meta.get("document_type", document.get("source", "UNKNOWN"))
    title = doc_meta.get("title") or metadata.get("title") or "Untitled Document"
    url = (
        metadata.get("url") or metadata.get("source") or metadata.get("page_url") or ""
    )
    metadata_json = json.dumps(metadata, ensure_ascii=False)

    # --- 1. Metadata header (fixed structure) ---
    metadata_lines: list[str] = [
        "<document>",
        "<document_metadata>",
        f" <document_id>{document_id}</document_id>",
        f" <document_type>{document_type}</document_type>",
        f" <title><![CDATA[{title}]]></title>",
        f" <url><![CDATA[{url}]]></url>",
        f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
        "</document_metadata>",
        "",
    ]

    # --- 2. Pre-build chunk XML strings to compute line counts ---
    chunks = document.get("chunks") or []
    chunk_entries: list[tuple[int | None, str]] = []  # (chunk_id, xml_string)
    if isinstance(chunks, list):
        for chunk in chunks:
            if not isinstance(chunk, dict):
                continue
            chunk_id = chunk.get("chunk_id") or chunk.get("id")
            chunk_content = str(chunk.get("content", "")).strip()
            if not chunk_content:
                continue
            if chunk_id is None:
                xml = f" <chunk><![CDATA[{chunk_content}]]></chunk>"
            else:
                xml = f" <chunk id='{chunk_id}'><![CDATA[{chunk_content}]]></chunk>"
            chunk_entries.append((chunk_id, xml))

    # --- 3. Compute line numbers for every chunk ---
    # Layout (1-indexed lines for read_file):
    #   metadata_lines      -> len(metadata_lines) lines
    #   <chunk_index>       -> 1 line
    #   index entries       -> len(chunk_entries) lines
    #   </chunk_index>      -> 1 line
    #   (empty line)        -> 1 line
    #   <document_content>  -> 1 line
    #   chunk xml lines…
    #   </document_content> -> 1 line
    #   </document>         -> 1 line
    index_overhead = (
        1 + len(chunk_entries) + 1 + 1 + 1
    )  # tags + empty + <document_content>
    first_chunk_line = len(metadata_lines) + index_overhead + 1  # 1-indexed

    current_line = first_chunk_line
    index_entry_lines: list[str] = []
    for cid, xml_str in chunk_entries:
        num_lines = xml_str.count("\n") + 1
        end_line = current_line + num_lines - 1
        matched_attr = ' matched="true"' if cid is not None and cid in matched else ""
        if cid is not None:
            index_entry_lines.append(
                f' <entry chunk_id="{cid}" lines="{current_line}-{end_line}"{matched_attr}/>'
            )
        else:
            index_entry_lines.append(
                f' <entry lines="{current_line}-{end_line}"{matched_attr}/>'
            )
        current_line = end_line + 1

    # --- 4. Assemble final XML ---
    lines = metadata_lines.copy()
    lines.append("<chunk_index>")
    lines.extend(index_entry_lines)
    lines.append("</chunk_index>")
    lines.append("")
    lines.append("<document_content>")
    for _, xml_str in chunk_entries:
        lines.append(xml_str)
    lines.extend(["</document_content>", "</document>"])
    return "\n".join(lines)


async def _get_folder_paths(
    session: AsyncSession, search_space_id: int
) -> dict[int, str]:
    """Return a map of folder_id -> virtual folder path under /documents."""
    result = await session.execute(
        select(Folder.id, Folder.name, Folder.parent_id).where(
            Folder.search_space_id == search_space_id
        )
    )
    rows = result.all()
    by_id = {row.id: {"name": row.name, "parent_id": row.parent_id} for row in rows}

    cache: dict[int, str] = {}

    def resolve_path(folder_id: int) -> str:
        if folder_id in cache:
            return cache[folder_id]
        parts: list[str] = []
        cursor: int | None = folder_id
        visited: set[int] = set()
        while cursor is not None and cursor in by_id and cursor not in visited:
            visited.add(cursor)
            entry = by_id[cursor]
            parts.append(
                _safe_filename(str(entry["name"]), fallback="folder").removesuffix(
                    ".xml"
                )
            )
            cursor = entry["parent_id"]
        parts.reverse()
        path = "/documents/" + "/".join(parts) if parts else "/documents"
        cache[folder_id] = path
        return path

    for folder_id in by_id:
        resolve_path(folder_id)
    return cache


def _build_synthetic_ls(
    existing_files: dict[str, Any] | None,
    new_files: dict[str, Any],
) -> tuple[AIMessage, ToolMessage]:
    """Build a synthetic ls("/documents") tool-call + result for the LLM context.

    Paths are listed with *new* (rank-ordered) files first, then existing files
    that were already in state from prior turns.
    """
    merged: dict[str, Any] = {**(existing_files or {}), **new_files}
    doc_paths = [
        p for p, v in merged.items() if p.startswith("/documents/") and v is not None
    ]

    new_set = set(new_files)
    new_paths = [p for p in doc_paths if p in new_set]
    old_paths = [p for p in doc_paths if p not in new_set]
    ordered = new_paths + old_paths

    tool_call_id = f"auto_ls_{uuid.uuid4().hex[:12]}"
    ai_msg = AIMessage(
        content="",
        tool_calls=[{"name": "ls", "args": {"path": "/documents"}, "id": tool_call_id}],
    )
    tool_msg = ToolMessage(
        content=str(ordered) if ordered else "No documents found.",
        tool_call_id=tool_call_id,
    )
    return ai_msg, tool_msg
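
# Illustrative output (paths are placeholders): the synthetic pair looks like an
# assistant tool call ls("/documents") answered by a ToolMessage whose content
# is the ordered path list, e.g.
#   "['/documents/Work/Plan.xml', '/documents/Old note.xml']"
# with newly retrieved files listed before files kept from earlier turns.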


def _resolve_search_types(
    available_connectors: list[str] | None,
    available_document_types: list[str] | None,
) -> list[str] | None:
    """Build a flat list of document-type strings for the chunk retriever.

    Includes legacy equivalents from ``NATIVE_TO_LEGACY_DOCTYPE`` so that
    old documents indexed under Composio names are still found.

    Returns ``None`` when no filtering is desired (search all types).
    """
    types: set[str] = set()
    if available_document_types:
        types.update(available_document_types)
    if available_connectors:
        types.update(available_connectors)
    if not types:
        return None

    expanded: set[str] = set(types)
    for t in types:
        legacy = NATIVE_TO_LEGACY_DOCTYPE.get(t)
        if legacy:
            expanded.add(legacy)
    return list(expanded) if expanded else None


async def search_knowledge_base(
    *,
    query: str,
    search_space_id: int,
    available_connectors: list[str] | None = None,
    available_document_types: list[str] | None = None,
    top_k: int = 10,
) -> list[dict[str, Any]]:
    """Run a single unified hybrid search against the knowledge base.

    Uses one ``ChucksHybridSearchRetriever`` call across all document types
    instead of fanning out per-connector. This reduces the number of DB
    queries from ~10 to 2 (one RRF query + one chunk fetch).
    """
    if not query:
        return []

    [embedding] = embed_texts([query])

    doc_types = _resolve_search_types(available_connectors, available_document_types)
    retriever_top_k = min(top_k * 3, 30)

    async with shielded_async_session() as session:
        retriever = ChucksHybridSearchRetriever(session)
        results = await retriever.hybrid_search(
            query_text=query,
            top_k=retriever_top_k,
            search_space_id=search_space_id,
            document_type=doc_types,
            query_embedding=embedding.tolist(),
        )

    return results[:top_k]
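
# A minimal call sketch (illustrative values):
#
#   results = await search_knowledge_base(
#       query="quarterly revenue",
#       search_space_id=1,
#       top_k=5,
#   )
#
# Each result dict carries a "document" mapping plus its "chunks", as consumed
# by build_scoped_filesystem below.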


async def build_scoped_filesystem(
    *,
    documents: Sequence[dict[str, Any]],
    search_space_id: int,
) -> dict[str, dict[str, str]]:
    """Build a StateBackend-compatible files dict from search results."""
    async with shielded_async_session() as session:
        folder_paths = await _get_folder_paths(session, search_space_id)
        doc_ids = [
            (doc.get("document") or {}).get("id")
            for doc in documents
            if isinstance(doc, dict)
        ]
        doc_ids = [doc_id for doc_id in doc_ids if isinstance(doc_id, int)]
        folder_by_doc_id: dict[int, int | None] = {}
        if doc_ids:
            doc_rows = await session.execute(
                select(Document.id, Document.folder_id).where(
                    Document.search_space_id == search_space_id,
                    Document.id.in_(doc_ids),
                )
            )
            folder_by_doc_id = {
                row.id: row.folder_id for row in doc_rows.all() if row.id is not None
            }

    files: dict[str, dict[str, str]] = {}
    for document in documents:
        doc_meta = document.get("document") or {}
        title = str(doc_meta.get("title") or "untitled")
        doc_id = doc_meta.get("id")
        folder_id = folder_by_doc_id.get(doc_id) if isinstance(doc_id, int) else None
        base_folder = folder_paths.get(folder_id, "/documents")
        file_name = _safe_filename(title)
        path = f"{base_folder}/{file_name}"
        matched_ids = set(document.get("matched_chunk_ids") or [])
        xml_content = _build_document_xml(document, matched_chunk_ids=matched_ids)
        files[path] = {
            "content": xml_content.split("\n"),
            "encoding": "utf-8",
            "created_at": "",
            "modified_at": "",
        }
    return files


class KnowledgeBaseSearchMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """Pre-agent middleware that always searches the KB and seeds a scoped filesystem."""

    tools = ()

    def __init__(
        self,
        *,
        search_space_id: int,
        available_connectors: list[str] | None = None,
        available_document_types: list[str] | None = None,
        top_k: int = 10,
    ) -> None:
        self.search_space_id = search_space_id
        self.available_connectors = available_connectors
        self.available_document_types = available_document_types
        self.top_k = top_k

    def before_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        try:
            loop = asyncio.get_running_loop()
            if loop.is_running():
                return None
        except RuntimeError:
            pass
        return asyncio.run(self.abefore_agent(state, runtime))

    async def abefore_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        del runtime
        messages = state.get("messages") or []
        if not messages:
            return None
        last_message = messages[-1]
        if not isinstance(last_message, HumanMessage):
            return None

        user_text = _extract_text_from_message(last_message).strip()
        if not user_text:
            return None

        t0 = _perf_log and asyncio.get_event_loop().time()
        existing_files = state.get("files")

        search_results = await search_knowledge_base(
            query=user_text,
            search_space_id=self.search_space_id,
            available_connectors=self.available_connectors,
            available_document_types=self.available_document_types,
            top_k=self.top_k,
        )
        new_files = await build_scoped_filesystem(
            documents=search_results,
            search_space_id=self.search_space_id,
        )

        ai_msg, tool_msg = _build_synthetic_ls(existing_files, new_files)

        if t0 is not None:
            _perf_log.info(
                "[kb_fs_middleware] completed in %.3fs query=%r new_files=%d total=%d",
                asyncio.get_event_loop().time() - t0,
                user_text[:80],
                len(new_files),
                len(new_files) + len(existing_files or {}),
            )
        return {"files": new_files, "messages": [ai_msg, tool_msg]}