feat: made agent file sytem optimized

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-03-28 16:39:46 -07:00
parent ee0b59c0fa
commit 2cc2d339e6
67 changed files with 8011 additions and 5591 deletions

View file

@ -0,0 +1,17 @@
"""Middleware components for the SurfSense new chat agent."""
from app.agents.new_chat.middleware.dedup_tool_calls import (
DedupHITLToolCallsMiddleware,
)
from app.agents.new_chat.middleware.filesystem import (
SurfSenseFilesystemMiddleware,
)
from app.agents.new_chat.middleware.knowledge_search import (
KnowledgeBaseSearchMiddleware,
)
__all__ = [
"DedupHITLToolCallsMiddleware",
"KnowledgeBaseSearchMiddleware",
"SurfSenseFilesystemMiddleware",
]

View file

@ -0,0 +1,694 @@
"""Custom filesystem middleware for the SurfSense agent.
This middleware customizes prompts and persists write/edit operations for
`/documents/*` files into SurfSense's `Document`/`Chunk` tables.
"""
from __future__ import annotations
import asyncio
import re
from datetime import UTC, datetime
from typing import Annotated, Any
from deepagents import FilesystemMiddleware
from deepagents.backends.protocol import EditResult, WriteResult
from deepagents.backends.utils import validate_path
from deepagents.middleware.filesystem import FilesystemState
from fractional_indexing import generate_key_between
from langchain.tools import ToolRuntime
from langchain_core.callbacks import dispatch_custom_event
from langchain_core.messages import ToolMessage
from langchain_core.tools import BaseTool, StructuredTool
from langgraph.types import Command
from sqlalchemy import delete, select
from app.db import Chunk, Document, DocumentType, Folder, shielded_async_session
from app.indexing_pipeline.document_chunker import chunk_text
from app.utils.document_converters import (
embed_texts,
generate_content_hash,
generate_unique_identifier_hash,
)
# =============================================================================
# System Prompt (injected into every model call by wrap_model_call)
# =============================================================================
SURFSENSE_FILESYSTEM_SYSTEM_PROMPT = """## Following Conventions
- Read files before editing understand existing content before making changes.
- Mimic existing style, naming conventions, and patterns.
## Filesystem Tools `ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep`, `save_document`
All file paths must start with a `/`.
- ls: list files and directories at a given path.
- read_file: read a file from the filesystem.
- write_file: create a temporary file in the session (not persisted).
- edit_file: edit a file in the session (not persisted for /documents/ files).
- glob: find files matching a pattern (e.g., "**/*.xml").
- grep: search for text within files.
- save_document: **permanently** save a new document to the user's knowledge
base. Use only when the user explicitly asks to save/create a document.
## Reading Documents Efficiently
Documents are formatted as XML. Each document contains:
- `<document_metadata>` title, type, URL, etc.
- `<chunk_index>` a table of every chunk with its **line range** and a
`matched="true"` flag for chunks that matched the search query.
- `<document_content>` the actual chunks in original document order.
**Workflow**: when reading a large document, read the first ~20 lines to see
the `<chunk_index>`, identify chunks marked `matched="true"`, then use
`read_file(path, offset=<start_line>, limit=<lines>)` to jump directly to
those sections instead of reading the entire file sequentially.
Use `<chunk id='...'>` values as citation IDs in your answers.
"""
# =============================================================================
# Per-Tool Descriptions (shown to the LLM as the tool's docstring)
# =============================================================================
SURFSENSE_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path.
"""
SURFSENSE_READ_FILE_TOOL_DESCRIPTION = """Reads a file from the filesystem.
Usage:
- By default, reads up to 100 lines from the beginning.
- Use `offset` and `limit` for pagination when files are large.
- Results include line numbers.
- Documents contain a `<chunk_index>` near the top listing every chunk with
its line range and a `matched="true"` flag for search-relevant chunks.
Read the index first, then jump to matched chunks with
`read_file(path, offset=<start_line>, limit=<num_lines>)`.
- Use chunk IDs (`<chunk id='...'>`) as citations in answers.
"""
SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new file to the in-memory filesystem (session-only).
Use this to create scratch/working files during the conversation. Files created
here are ephemeral and will not be saved to the user's knowledge base.
To permanently save a document to the user's knowledge base, use the
`save_document` tool instead.
"""
SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files.
IMPORTANT:
- Read the file before editing.
- Preserve exact indentation and formatting.
- Edits to documents under `/documents/` are session-only (not persisted to the
database) because those files use an XML citation wrapper around the original
content.
"""
SURFSENSE_GLOB_TOOL_DESCRIPTION = """Find files matching a glob pattern.
Supports standard glob patterns: `*`, `**`, `?`.
Returns absolute file paths.
"""
SURFSENSE_GREP_TOOL_DESCRIPTION = """Search for a literal text pattern across files.
Use this to locate relevant document files/chunks before reading full files.
"""
SURFSENSE_SAVE_DOCUMENT_TOOL_DESCRIPTION = """Permanently saves a document to the user's knowledge base.
This is an expensive operation it creates a new Document record in the
database, chunks the content, and generates embeddings for search.
Use ONLY when the user explicitly asks to save/create/store a document.
Do NOT use this for scratch work; use `write_file` for temporary files.
Args:
title: The document title (e.g., "Meeting Notes 2025-06-01").
content: The plain-text or markdown content to save. Do NOT include XML
citation wrappers pass only the actual document text.
folder_path: Optional folder path under /documents/ (e.g., "Work/Notes").
Folders are created automatically if they don't exist.
"""
class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
"""SurfSense-specific filesystem middleware with DB persistence for docs."""
def __init__(
self,
*,
search_space_id: int | None = None,
created_by_id: str | None = None,
tool_token_limit_before_evict: int | None = 20000,
) -> None:
self._search_space_id = search_space_id
self._created_by_id = created_by_id
super().__init__(
system_prompt=SURFSENSE_FILESYSTEM_SYSTEM_PROMPT,
custom_tool_descriptions={
"ls": SURFSENSE_LIST_FILES_TOOL_DESCRIPTION,
"read_file": SURFSENSE_READ_FILE_TOOL_DESCRIPTION,
"write_file": SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION,
"edit_file": SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION,
"glob": SURFSENSE_GLOB_TOOL_DESCRIPTION,
"grep": SURFSENSE_GREP_TOOL_DESCRIPTION,
},
tool_token_limit_before_evict=tool_token_limit_before_evict,
)
# Remove the execute tool (no sandbox backend)
self.tools = [t for t in self.tools if t.name != "execute"]
self.tools.append(self._create_save_document_tool())
@staticmethod
def _run_async_blocking(coro: Any) -> Any:
"""Run async coroutine from sync code path when no event loop is running."""
try:
loop = asyncio.get_running_loop()
if loop.is_running():
return "Error: sync filesystem persistence not supported inside an active event loop."
except RuntimeError:
pass
return asyncio.run(coro)
@staticmethod
def _parse_virtual_path(file_path: str) -> tuple[list[str], str]:
"""Parse /documents/... path into folder parts and a document title."""
if not file_path.startswith("/documents/"):
return [], ""
rel = file_path[len("/documents/") :].strip("/")
if not rel:
return [], ""
parts = [part for part in rel.split("/") if part]
file_name = parts[-1]
title = file_name[:-4] if file_name.lower().endswith(".xml") else file_name
return parts[:-1], title
async def _ensure_folder_hierarchy(
self,
*,
folder_parts: list[str],
search_space_id: int,
) -> int | None:
"""Ensure folder hierarchy exists and return leaf folder ID."""
if not folder_parts:
return None
async with shielded_async_session() as session:
parent_id: int | None = None
for name in folder_parts:
result = await session.execute(
select(Folder).where(
Folder.search_space_id == search_space_id,
Folder.parent_id == parent_id
if parent_id is not None
else Folder.parent_id.is_(None),
Folder.name == name,
)
)
folder = result.scalar_one_or_none()
if folder is None:
sibling_result = await session.execute(
select(Folder.position)
.where(
Folder.search_space_id == search_space_id,
Folder.parent_id == parent_id
if parent_id is not None
else Folder.parent_id.is_(None),
)
.order_by(Folder.position.desc())
.limit(1)
)
last_position = sibling_result.scalar_one_or_none()
folder = Folder(
name=name,
position=generate_key_between(last_position, None),
parent_id=parent_id,
search_space_id=search_space_id,
created_by_id=self._created_by_id,
updated_at=datetime.now(UTC),
)
session.add(folder)
await session.flush()
parent_id = folder.id
await session.commit()
return parent_id
async def _persist_new_document(
self, *, file_path: str, content: str
) -> dict[str, Any] | str:
"""Persist a new NOTE document from a newly written file.
Returns a dict with document metadata on success, or an error string.
"""
if self._search_space_id is None:
return {}
folder_parts, title = self._parse_virtual_path(file_path)
if not title:
return "Error: write_file for document persistence requires path under /documents/<name>.xml"
folder_id = await self._ensure_folder_hierarchy(
folder_parts=folder_parts,
search_space_id=self._search_space_id,
)
async with shielded_async_session() as session:
content_hash = generate_content_hash(content, self._search_space_id)
existing = await session.execute(
select(Document.id).where(Document.content_hash == content_hash)
)
if existing.scalar_one_or_none() is not None:
return "Error: A document with identical content already exists."
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.NOTE,
file_path,
self._search_space_id,
)
doc = Document(
title=title,
document_type=DocumentType.NOTE,
document_metadata={"virtual_path": file_path},
content=content,
content_hash=content_hash,
unique_identifier_hash=unique_identifier_hash,
source_markdown=content,
search_space_id=self._search_space_id,
folder_id=folder_id,
created_by_id=self._created_by_id,
updated_at=datetime.now(UTC),
)
session.add(doc)
await session.flush()
summary_embedding = embed_texts([content])[0]
doc.embedding = summary_embedding
chunk_texts = chunk_text(content)
if chunk_texts:
chunk_embeddings = embed_texts(chunk_texts)
chunks = [
Chunk(document_id=doc.id, content=text, embedding=embedding)
for text, embedding in zip(
chunk_texts, chunk_embeddings, strict=True
)
]
session.add_all(chunks)
await session.commit()
return {
"id": doc.id,
"title": title,
"documentType": DocumentType.NOTE.value,
"searchSpaceId": self._search_space_id,
"folderId": folder_id,
"createdById": str(self._created_by_id)
if self._created_by_id
else None,
}
async def _persist_edited_document(
self, *, file_path: str, updated_content: str
) -> str | None:
"""Persist edits for an existing NOTE document and recreate chunks."""
if self._search_space_id is None:
return None
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.NOTE,
file_path,
self._search_space_id,
)
doc_id_from_xml: int | None = None
match = re.search(r"<document_id>\s*(\d+)\s*</document_id>", updated_content)
if match:
doc_id_from_xml = int(match.group(1))
async with shielded_async_session() as session:
doc_result = await session.execute(
select(Document).where(
Document.search_space_id == self._search_space_id,
Document.unique_identifier_hash == unique_identifier_hash,
)
)
document = doc_result.scalar_one_or_none()
if document is None and doc_id_from_xml is not None:
by_id_result = await session.execute(
select(Document).where(
Document.search_space_id == self._search_space_id,
Document.id == doc_id_from_xml,
)
)
document = by_id_result.scalar_one_or_none()
if document is None:
return "Error: Could not map edited file to an existing document."
document.content = updated_content
document.source_markdown = updated_content
document.content_hash = generate_content_hash(
updated_content, self._search_space_id
)
document.updated_at = datetime.now(UTC)
if not document.document_metadata:
document.document_metadata = {}
document.document_metadata["virtual_path"] = file_path
summary_embedding = embed_texts([updated_content])[0]
document.embedding = summary_embedding
await session.execute(delete(Chunk).where(Chunk.document_id == document.id))
chunk_texts = chunk_text(updated_content)
if chunk_texts:
chunk_embeddings = embed_texts(chunk_texts)
session.add_all(
[
Chunk(
document_id=document.id, content=text, embedding=embedding
)
for text, embedding in zip(
chunk_texts, chunk_embeddings, strict=True
)
]
)
await session.commit()
return None
def _create_save_document_tool(self) -> BaseTool:
"""Create save_document tool that persists a new document to the KB."""
def sync_save_document(
title: Annotated[str, "Title for the new document."],
content: Annotated[
str,
"Plain-text or markdown content to save. Do NOT include XML wrappers.",
],
runtime: ToolRuntime[None, FilesystemState],
folder_path: Annotated[
str,
"Optional folder path under /documents/ (e.g. 'Work/Notes'). Created automatically.",
] = "",
) -> Command | str:
if not content.strip():
return "Error: content cannot be empty."
file_name = re.sub(r'[\\/:*?"<>|]+', "_", title).strip() or "untitled"
if not file_name.lower().endswith(".xml"):
file_name = f"{file_name}.xml"
folder = folder_path.strip().strip("/") if folder_path else ""
virtual_path = (
f"/documents/{folder}/{file_name}"
if folder
else f"/documents/{file_name}"
)
persist_result = self._run_async_blocking(
self._persist_new_document(file_path=virtual_path, content=content)
)
if isinstance(persist_result, str):
return persist_result
if isinstance(persist_result, dict) and persist_result.get("id"):
dispatch_custom_event("document_created", persist_result)
return f"Document '{title}' saved to knowledge base (path: {virtual_path})."
async def async_save_document(
title: Annotated[str, "Title for the new document."],
content: Annotated[
str,
"Plain-text or markdown content to save. Do NOT include XML wrappers.",
],
runtime: ToolRuntime[None, FilesystemState],
folder_path: Annotated[
str,
"Optional folder path under /documents/ (e.g. 'Work/Notes'). Created automatically.",
] = "",
) -> Command | str:
if not content.strip():
return "Error: content cannot be empty."
file_name = re.sub(r'[\\/:*?"<>|]+', "_", title).strip() or "untitled"
if not file_name.lower().endswith(".xml"):
file_name = f"{file_name}.xml"
folder = folder_path.strip().strip("/") if folder_path else ""
virtual_path = (
f"/documents/{folder}/{file_name}"
if folder
else f"/documents/{file_name}"
)
persist_result = await self._persist_new_document(
file_path=virtual_path, content=content
)
if isinstance(persist_result, str):
return persist_result
if isinstance(persist_result, dict) and persist_result.get("id"):
dispatch_custom_event("document_created", persist_result)
return f"Document '{title}' saved to knowledge base (path: {virtual_path})."
return StructuredTool.from_function(
name="save_document",
description=SURFSENSE_SAVE_DOCUMENT_TOOL_DESCRIPTION,
func=sync_save_document,
coroutine=async_save_document,
)
def _create_write_file_tool(self) -> BaseTool:
"""Create write_file — ephemeral for /documents/*, persisted otherwise."""
tool_description = (
self._custom_tool_descriptions.get("write_file")
or SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION
)
def sync_write_file(
file_path: Annotated[
str,
"Absolute path where the file should be created. Must be absolute, not relative.",
],
content: Annotated[
str,
"The text content to write to the file. This parameter is required.",
],
runtime: ToolRuntime[None, FilesystemState],
) -> Command | str:
resolved_backend = self._get_backend(runtime)
try:
validated_path = validate_path(file_path)
except ValueError as exc:
return f"Error: {exc}"
res: WriteResult = resolved_backend.write(validated_path, content)
if res.error:
return res.error
if not self._is_kb_document(validated_path):
persist_result = self._run_async_blocking(
self._persist_new_document(
file_path=validated_path, content=content
)
)
if isinstance(persist_result, str):
return persist_result
if isinstance(persist_result, dict) and persist_result.get("id"):
dispatch_custom_event("document_created", persist_result)
if res.files_update is not None:
return Command(
update={
"files": res.files_update,
"messages": [
ToolMessage(
content=f"Updated file {res.path}",
tool_call_id=runtime.tool_call_id,
)
],
}
)
return f"Updated file {res.path}"
async def async_write_file(
file_path: Annotated[
str,
"Absolute path where the file should be created. Must be absolute, not relative.",
],
content: Annotated[
str,
"The text content to write to the file. This parameter is required.",
],
runtime: ToolRuntime[None, FilesystemState],
) -> Command | str:
resolved_backend = self._get_backend(runtime)
try:
validated_path = validate_path(file_path)
except ValueError as exc:
return f"Error: {exc}"
res: WriteResult = await resolved_backend.awrite(validated_path, content)
if res.error:
return res.error
if not self._is_kb_document(validated_path):
persist_result = await self._persist_new_document(
file_path=validated_path,
content=content,
)
if isinstance(persist_result, str):
return persist_result
if isinstance(persist_result, dict) and persist_result.get("id"):
dispatch_custom_event("document_created", persist_result)
if res.files_update is not None:
return Command(
update={
"files": res.files_update,
"messages": [
ToolMessage(
content=f"Updated file {res.path}",
tool_call_id=runtime.tool_call_id,
)
],
}
)
return f"Updated file {res.path}"
return StructuredTool.from_function(
name="write_file",
description=tool_description,
func=sync_write_file,
coroutine=async_write_file,
)
@staticmethod
def _is_kb_document(path: str) -> bool:
"""Return True for paths under /documents/ (KB-sourced, XML-wrapped)."""
return path.startswith("/documents/")
def _create_edit_file_tool(self) -> BaseTool:
"""Create edit_file with DB persistence (skipped for KB documents)."""
tool_description = (
self._custom_tool_descriptions.get("edit_file")
or SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION
)
def sync_edit_file(
file_path: Annotated[
str,
"Absolute path to the file to edit. Must be absolute, not relative.",
],
old_string: Annotated[
str,
"The exact text to find and replace. Must be unique in the file unless replace_all is True.",
],
new_string: Annotated[
str,
"The text to replace old_string with. Must be different from old_string.",
],
runtime: ToolRuntime[None, FilesystemState],
*,
replace_all: Annotated[
bool,
"If True, replace all occurrences of old_string. If False (default), old_string must be unique.",
] = False,
) -> Command | str:
resolved_backend = self._get_backend(runtime)
try:
validated_path = validate_path(file_path)
except ValueError as exc:
return f"Error: {exc}"
res: EditResult = resolved_backend.edit(
validated_path,
old_string,
new_string,
replace_all=replace_all,
)
if res.error:
return res.error
if not self._is_kb_document(validated_path):
read_result = resolved_backend.read(
validated_path, offset=0, limit=200000
)
if read_result.error or read_result.file_data is None:
return f"Error: could not reload edited file '{validated_path}' for persistence."
updated_content = read_result.file_data["content"]
persist_result = self._run_async_blocking(
self._persist_edited_document(
file_path=validated_path,
updated_content=updated_content,
)
)
if isinstance(persist_result, str):
return persist_result
if res.files_update is not None:
return Command(
update={
"files": res.files_update,
"messages": [
ToolMessage(
content=f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'",
tool_call_id=runtime.tool_call_id,
)
],
}
)
return f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'"
async def async_edit_file(
file_path: Annotated[
str,
"Absolute path to the file to edit. Must be absolute, not relative.",
],
old_string: Annotated[
str,
"The exact text to find and replace. Must be unique in the file unless replace_all is True.",
],
new_string: Annotated[
str,
"The text to replace old_string with. Must be different from old_string.",
],
runtime: ToolRuntime[None, FilesystemState],
*,
replace_all: Annotated[
bool,
"If True, replace all occurrences of old_string. If False (default), old_string must be unique.",
] = False,
) -> Command | str:
resolved_backend = self._get_backend(runtime)
try:
validated_path = validate_path(file_path)
except ValueError as exc:
return f"Error: {exc}"
res: EditResult = await resolved_backend.aedit(
validated_path,
old_string,
new_string,
replace_all=replace_all,
)
if res.error:
return res.error
if not self._is_kb_document(validated_path):
read_result = await resolved_backend.aread(
validated_path, offset=0, limit=200000
)
if read_result.error or read_result.file_data is None:
return f"Error: could not reload edited file '{validated_path}' for persistence."
updated_content = read_result.file_data["content"]
persist_error = await self._persist_edited_document(
file_path=validated_path,
updated_content=updated_content,
)
if persist_error:
return persist_error
if res.files_update is not None:
return Command(
update={
"files": res.files_update,
"messages": [
ToolMessage(
content=f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'",
tool_call_id=runtime.tool_call_id,
)
],
}
)
return f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'"
return StructuredTool.from_function(
name="edit_file",
description=tool_description,
func=sync_edit_file,
coroutine=async_edit_file,
)

View file

@ -0,0 +1,414 @@
"""Knowledge-base pre-search middleware for the SurfSense new chat agent.
This middleware runs before the main agent loop and seeds a virtual filesystem
(`files` state) with relevant documents retrieved via hybrid search. On each
turn the filesystem is *expanded* new results merge with documents loaded
during prior turns and a synthetic ``ls`` result is injected into the message
history so the LLM is immediately aware of the current filesystem structure.
"""
from __future__ import annotations
import asyncio
import json
import logging
import re
import uuid
from collections.abc import Sequence
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langgraph.runtime import Runtime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import NATIVE_TO_LEGACY_DOCTYPE, Document, Folder, shielded_async_session
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.utils.document_converters import embed_texts
from app.utils.perf import get_perf_logger
logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
def _extract_text_from_message(message: BaseMessage) -> str:
"""Extract plain text from a message content."""
content = getattr(message, "content", "")
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict) and item.get("type") == "text":
parts.append(str(item.get("text", "")))
return "\n".join(p for p in parts if p)
return str(content)
def _safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
"""Convert arbitrary text into a filesystem-safe filename."""
name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
name = re.sub(r"\s+", " ", name)
if not name:
name = fallback
if len(name) > 180:
name = name[:180].rstrip()
if not name.lower().endswith(".xml"):
name = f"{name}.xml"
return name
def _build_document_xml(
document: dict[str, Any],
matched_chunk_ids: set[int] | None = None,
) -> str:
"""Build citation-friendly XML with a ``<chunk_index>`` for smart seeking.
The ``<chunk_index>`` at the top of each document lists every chunk with its
line range inside ``<document_content>`` and flags chunks that directly
matched the search query (``matched="true"``). This lets the LLM jump
straight to the most relevant section via ``read_file(offset=, limit=)``
instead of reading sequentially from the start.
"""
matched = matched_chunk_ids or set()
doc_meta = document.get("document") or {}
metadata = (doc_meta.get("metadata") or {}) if isinstance(doc_meta, dict) else {}
document_id = doc_meta.get("id", document.get("document_id", "unknown"))
document_type = doc_meta.get("document_type", document.get("source", "UNKNOWN"))
title = doc_meta.get("title") or metadata.get("title") or "Untitled Document"
url = (
metadata.get("url") or metadata.get("source") or metadata.get("page_url") or ""
)
metadata_json = json.dumps(metadata, ensure_ascii=False)
# --- 1. Metadata header (fixed structure) ---
metadata_lines: list[str] = [
"<document>",
"<document_metadata>",
f" <document_id>{document_id}</document_id>",
f" <document_type>{document_type}</document_type>",
f" <title><![CDATA[{title}]]></title>",
f" <url><![CDATA[{url}]]></url>",
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
"</document_metadata>",
"",
]
# --- 2. Pre-build chunk XML strings to compute line counts ---
chunks = document.get("chunks") or []
chunk_entries: list[tuple[int | None, str]] = [] # (chunk_id, xml_string)
if isinstance(chunks, list):
for chunk in chunks:
if not isinstance(chunk, dict):
continue
chunk_id = chunk.get("chunk_id") or chunk.get("id")
chunk_content = str(chunk.get("content", "")).strip()
if not chunk_content:
continue
if chunk_id is None:
xml = f" <chunk><![CDATA[{chunk_content}]]></chunk>"
else:
xml = f" <chunk id='{chunk_id}'><![CDATA[{chunk_content}]]></chunk>"
chunk_entries.append((chunk_id, xml))
# --- 3. Compute line numbers for every chunk ---
# Layout (1-indexed lines for read_file):
# metadata_lines -> len(metadata_lines) lines
# <chunk_index> -> 1 line
# index entries -> len(chunk_entries) lines
# </chunk_index> -> 1 line
# (empty line) -> 1 line
# <document_content> -> 1 line
# chunk xml lines…
# </document_content> -> 1 line
# </document> -> 1 line
index_overhead = (
1 + len(chunk_entries) + 1 + 1 + 1
) # tags + empty + <document_content>
first_chunk_line = len(metadata_lines) + index_overhead + 1 # 1-indexed
current_line = first_chunk_line
index_entry_lines: list[str] = []
for cid, xml_str in chunk_entries:
num_lines = xml_str.count("\n") + 1
end_line = current_line + num_lines - 1
matched_attr = ' matched="true"' if cid is not None and cid in matched else ""
if cid is not None:
index_entry_lines.append(
f' <entry chunk_id="{cid}" lines="{current_line}-{end_line}"{matched_attr}/>'
)
else:
index_entry_lines.append(
f' <entry lines="{current_line}-{end_line}"{matched_attr}/>'
)
current_line = end_line + 1
# --- 4. Assemble final XML ---
lines = metadata_lines.copy()
lines.append("<chunk_index>")
lines.extend(index_entry_lines)
lines.append("</chunk_index>")
lines.append("")
lines.append("<document_content>")
for _, xml_str in chunk_entries:
lines.append(xml_str)
lines.extend(["</document_content>", "</document>"])
return "\n".join(lines)
async def _get_folder_paths(
session: AsyncSession, search_space_id: int
) -> dict[int, str]:
"""Return a map of folder_id -> virtual folder path under /documents."""
result = await session.execute(
select(Folder.id, Folder.name, Folder.parent_id).where(
Folder.search_space_id == search_space_id
)
)
rows = result.all()
by_id = {row.id: {"name": row.name, "parent_id": row.parent_id} for row in rows}
cache: dict[int, str] = {}
def resolve_path(folder_id: int) -> str:
if folder_id in cache:
return cache[folder_id]
parts: list[str] = []
cursor: int | None = folder_id
visited: set[int] = set()
while cursor is not None and cursor in by_id and cursor not in visited:
visited.add(cursor)
entry = by_id[cursor]
parts.append(
_safe_filename(str(entry["name"]), fallback="folder").removesuffix(
".xml"
)
)
cursor = entry["parent_id"]
parts.reverse()
path = "/documents/" + "/".join(parts) if parts else "/documents"
cache[folder_id] = path
return path
for folder_id in by_id:
resolve_path(folder_id)
return cache
def _build_synthetic_ls(
existing_files: dict[str, Any] | None,
new_files: dict[str, Any],
) -> tuple[AIMessage, ToolMessage]:
"""Build a synthetic ls("/documents") tool-call + result for the LLM context.
Paths are listed with *new* (rank-ordered) files first, then existing files
that were already in state from prior turns.
"""
merged: dict[str, Any] = {**(existing_files or {}), **new_files}
doc_paths = [
p for p, v in merged.items() if p.startswith("/documents/") and v is not None
]
new_set = set(new_files)
new_paths = [p for p in doc_paths if p in new_set]
old_paths = [p for p in doc_paths if p not in new_set]
ordered = new_paths + old_paths
tool_call_id = f"auto_ls_{uuid.uuid4().hex[:12]}"
ai_msg = AIMessage(
content="",
tool_calls=[{"name": "ls", "args": {"path": "/documents"}, "id": tool_call_id}],
)
tool_msg = ToolMessage(
content=str(ordered) if ordered else "No documents found.",
tool_call_id=tool_call_id,
)
return ai_msg, tool_msg
def _resolve_search_types(
available_connectors: list[str] | None,
available_document_types: list[str] | None,
) -> list[str] | None:
"""Build a flat list of document-type strings for the chunk retriever.
Includes legacy equivalents from ``NATIVE_TO_LEGACY_DOCTYPE`` so that
old documents indexed under Composio names are still found.
Returns ``None`` when no filtering is desired (search all types).
"""
types: set[str] = set()
if available_document_types:
types.update(available_document_types)
if available_connectors:
types.update(available_connectors)
if not types:
return None
expanded: set[str] = set(types)
for t in types:
legacy = NATIVE_TO_LEGACY_DOCTYPE.get(t)
if legacy:
expanded.add(legacy)
return list(expanded) if expanded else None
async def search_knowledge_base(
*,
query: str,
search_space_id: int,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
top_k: int = 10,
) -> list[dict[str, Any]]:
"""Run a single unified hybrid search against the knowledge base.
Uses one ``ChucksHybridSearchRetriever`` call across all document types
instead of fanning out per-connector. This reduces the number of DB
queries from ~10 to 2 (one RRF query + one chunk fetch).
"""
if not query:
return []
[embedding] = embed_texts([query])
doc_types = _resolve_search_types(available_connectors, available_document_types)
retriever_top_k = min(top_k * 3, 30)
async with shielded_async_session() as session:
retriever = ChucksHybridSearchRetriever(session)
results = await retriever.hybrid_search(
query_text=query,
top_k=retriever_top_k,
search_space_id=search_space_id,
document_type=doc_types,
query_embedding=embedding.tolist(),
)
return results[:top_k]
async def build_scoped_filesystem(
*,
documents: Sequence[dict[str, Any]],
search_space_id: int,
) -> dict[str, dict[str, str]]:
"""Build a StateBackend-compatible files dict from search results."""
async with shielded_async_session() as session:
folder_paths = await _get_folder_paths(session, search_space_id)
doc_ids = [
(doc.get("document") or {}).get("id")
for doc in documents
if isinstance(doc, dict)
]
doc_ids = [doc_id for doc_id in doc_ids if isinstance(doc_id, int)]
folder_by_doc_id: dict[int, int | None] = {}
if doc_ids:
doc_rows = await session.execute(
select(Document.id, Document.folder_id).where(
Document.search_space_id == search_space_id,
Document.id.in_(doc_ids),
)
)
folder_by_doc_id = {
row.id: row.folder_id for row in doc_rows.all() if row.id is not None
}
files: dict[str, dict[str, str]] = {}
for document in documents:
doc_meta = document.get("document") or {}
title = str(doc_meta.get("title") or "untitled")
doc_id = doc_meta.get("id")
folder_id = folder_by_doc_id.get(doc_id) if isinstance(doc_id, int) else None
base_folder = folder_paths.get(folder_id, "/documents")
file_name = _safe_filename(title)
path = f"{base_folder}/{file_name}"
matched_ids = set(document.get("matched_chunk_ids") or [])
xml_content = _build_document_xml(document, matched_chunk_ids=matched_ids)
files[path] = {
"content": xml_content.split("\n"),
"encoding": "utf-8",
"created_at": "",
"modified_at": "",
}
return files
class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Pre-agent middleware that always searches the KB and seeds a scoped filesystem."""
tools = ()
def __init__(
self,
*,
search_space_id: int,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
top_k: int = 10,
) -> None:
self.search_space_id = search_space_id
self.available_connectors = available_connectors
self.available_document_types = available_document_types
self.top_k = top_k
def before_agent( # type: ignore[override]
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
try:
loop = asyncio.get_running_loop()
if loop.is_running():
return None
except RuntimeError:
pass
return asyncio.run(self.abefore_agent(state, runtime))
async def abefore_agent( # type: ignore[override]
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
del runtime
messages = state.get("messages") or []
if not messages:
return None
last_message = messages[-1]
if not isinstance(last_message, HumanMessage):
return None
user_text = _extract_text_from_message(last_message).strip()
if not user_text:
return None
t0 = _perf_log and asyncio.get_event_loop().time()
existing_files = state.get("files")
search_results = await search_knowledge_base(
query=user_text,
search_space_id=self.search_space_id,
available_connectors=self.available_connectors,
available_document_types=self.available_document_types,
top_k=self.top_k,
)
new_files = await build_scoped_filesystem(
documents=search_results,
search_space_id=self.search_space_id,
)
ai_msg, tool_msg = _build_synthetic_ls(existing_files, new_files)
if t0 is not None:
_perf_log.info(
"[kb_fs_middleware] completed in %.3fs query=%r new_files=%d total=%d",
asyncio.get_event_loop().time() - t0,
user_text[:80],
len(new_files),
len(new_files) + len(existing_files or {}),
)
return {"files": new_files, "messages": [ai_msg, tool_msg]}