mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): move middleware package to app/agents/shared (slice 5c)
Relocate the entire new_chat/middleware/ package to the shared kernel as one cohesive unit (it is live shared infrastructure: the multi-agent stack wraps nearly every middleware via multi_agent_chat/middleware/main_agent/*, and anonymous_agent consumes it too). Flip 69 live importers across both the package-path and submodule-path forms. Shims left for the frozen single-agent stack: a package __init__ re-export plus submodule shims for permission, skills_backends, and scoped_model_fallback (the three imported via submodule path by chat_deepagent/subagents). Cycle break: importing shared.middleware previously reached back into new_chat.tools at module load, which dragged in new_chat.__init__ -> chat_deepagent -> the middleware shim -> half-initialized shared.middleware. Made action_log's ToolDefinition import TYPE_CHECKING-only and tool_call_repair's INVALID_TOOL_NAME import function-local. These tools-package back-edges fully resolve in slice 6. Asset note: skills_backends._default_builtin_root now walks to app/agents/new_chat/skills/builtin (the skills/ tree migrates in slice 7).
This commit is contained in:
parent
6f488d9564
commit
227983a104
98 changed files with 1131 additions and 999 deletions
332
surfsense_backend/app/agents/shared/middleware/knowledge_tree.py
Normal file
332
surfsense_backend/app/agents/shared/middleware/knowledge_tree.py
Normal file
|
|
@ -0,0 +1,332 @@
|
|||
"""Workspace-tree middleware for the SurfSense agent.
|
||||
|
||||
Renders the full ``Folder``+``Document`` tree under ``/documents/`` once per
|
||||
turn (cloud only), caches it by ``(search_space_id, tree_version)``, and
|
||||
injects the result as a ``<workspace_tree>`` system message immediately
|
||||
before the latest human turn.
|
||||
|
||||
The render is bounded by two truncation layers:
|
||||
|
||||
1. **Entry cap** — at most ``MAX_TREE_ENTRIES`` lines. The remainder is
|
||||
replaced with a "use ls" hint.
|
||||
2. **Token cap** — at most ``MAX_TREE_TOKENS`` tokens (using the LLM's
|
||||
token-count profile when available). If the entry-truncated tree still
|
||||
exceeds the token cap we fall back to a root-only summary.
|
||||
|
||||
Anonymous mode renders only ``state['kb_anon_doc']`` (no DB calls).
|
||||
|
||||
This middleware also performs a one-time initialization of ``state['cwd']``
|
||||
to ``"/documents"`` so subsequent middlewares and tools always see a valid
|
||||
cwd in cloud mode.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.shared.filesystem_state import SurfSenseFilesystemState
|
||||
from app.agents.shared.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
PathIndex,
|
||||
build_path_index,
|
||||
doc_to_virtual_path,
|
||||
)
|
||||
from app.db import Document, shielded_async_session
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
try:
|
||||
from litellm import token_counter
|
||||
except Exception: # pragma: no cover - optional dep
|
||||
token_counter = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MAX_TREE_ENTRIES = 500
|
||||
MAX_TREE_TOKENS = 4000
|
||||
|
||||
|
||||
def _approx_tokens(text: str) -> int:
|
||||
"""Cheap fallback token estimate (1 token ~= 4 chars)."""
|
||||
return max(1, (len(text) + 3) // 4)
|
||||
|
||||
|
||||
def _count_tokens(text: str, *, llm: BaseChatModel | None) -> int:
|
||||
if llm is None:
|
||||
return _approx_tokens(text)
|
||||
count_fn = getattr(llm, "_count_tokens", None)
|
||||
if callable(count_fn):
|
||||
try:
|
||||
return int(count_fn([{"role": "user", "content": text}]))
|
||||
except Exception:
|
||||
pass
|
||||
profile = getattr(llm, "profile", None)
|
||||
model_names: list[str] = []
|
||||
if isinstance(profile, dict):
|
||||
tcms = profile.get("token_count_models")
|
||||
if isinstance(tcms, list):
|
||||
model_names.extend(name for name in tcms if isinstance(name, str) and name)
|
||||
tcm = profile.get("token_count_model")
|
||||
if isinstance(tcm, str) and tcm and tcm not in model_names:
|
||||
model_names.append(tcm)
|
||||
model_name = model_names[0] if model_names else getattr(llm, "model", None)
|
||||
if not isinstance(model_name, str) or not model_name or token_counter is None:
|
||||
return _approx_tokens(text)
|
||||
try:
|
||||
return int(
|
||||
token_counter(
|
||||
messages=[{"role": "user", "content": text}],
|
||||
model=model_name,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
return _approx_tokens(text)
|
||||
|
||||
|
||||
class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Inject the workspace folder/document tree into the agent's context."""
|
||||
|
||||
tools = ()
|
||||
state_schema = SurfSenseFilesystemState
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
search_space_id: int,
|
||||
filesystem_mode: FilesystemMode,
|
||||
llm: BaseChatModel | None = None,
|
||||
max_entries: int = MAX_TREE_ENTRIES,
|
||||
max_tokens: int = MAX_TREE_TOKENS,
|
||||
inject_system_message: bool = True, # For backwards compatibility
|
||||
) -> None:
|
||||
self.search_space_id = search_space_id
|
||||
self.filesystem_mode = filesystem_mode
|
||||
self.llm = llm
|
||||
self.max_entries = max_entries
|
||||
self.max_tokens = max_tokens
|
||||
self.inject_system_message = inject_system_message
|
||||
self._cache: dict[tuple[int, int, bool], str] = {}
|
||||
|
||||
async def abefore_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
del runtime
|
||||
if self.filesystem_mode != FilesystemMode.CLOUD:
|
||||
return None
|
||||
|
||||
start = time.perf_counter()
|
||||
update: dict[str, Any] = {}
|
||||
if not state.get("cwd"):
|
||||
update["cwd"] = DOCUMENTS_ROOT
|
||||
|
||||
anon_doc = state.get("kb_anon_doc")
|
||||
if anon_doc:
|
||||
tree_msg = self._render_anon_tree(anon_doc)
|
||||
cache_outcome = "anon"
|
||||
else:
|
||||
version = int(state.get("tree_version") or 0)
|
||||
cache_key = (self.search_space_id, version, False)
|
||||
cache_outcome = "hit" if cache_key in self._cache else "miss"
|
||||
tree_msg = await self._render_kb_tree(state)
|
||||
|
||||
update["workspace_tree_text"] = tree_msg
|
||||
|
||||
if self.inject_system_message:
|
||||
messages = list(state.get("messages") or [])
|
||||
insert_at = max(len(messages) - 1, 0)
|
||||
messages.insert(insert_at, SystemMessage(content=tree_msg))
|
||||
update["messages"] = messages
|
||||
|
||||
_perf_log.info(
|
||||
"[knowledge_tree] cache=%s chars=%d elapsed=%.3fs space=%d",
|
||||
cache_outcome,
|
||||
len(tree_msg),
|
||||
time.perf_counter() - start,
|
||||
self.search_space_id,
|
||||
)
|
||||
return update
|
||||
|
||||
def before_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
if loop.is_running():
|
||||
return None
|
||||
except RuntimeError:
|
||||
pass
|
||||
return asyncio.run(self.abefore_agent(state, runtime))
|
||||
|
||||
# ------------------------------------------------------------------ render
|
||||
|
||||
def _render_anon_tree(self, anon_doc: dict[str, Any]) -> str:
|
||||
path = str(anon_doc.get("path") or "")
|
||||
title = str(anon_doc.get("title") or "uploaded_document")
|
||||
return (
|
||||
"<workspace_tree>\n"
|
||||
"Anonymous session — only one read-only document is available.\n"
|
||||
f"{DOCUMENTS_ROOT}/\n"
|
||||
f" {path} — {title}\n"
|
||||
"</workspace_tree>"
|
||||
)
|
||||
|
||||
async def _render_kb_tree(self, state: AgentState) -> str:
|
||||
version = int(state.get("tree_version") or 0)
|
||||
cache_key = (self.search_space_id, version, False)
|
||||
cached = self._cache.get(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
try:
|
||||
async with shielded_async_session() as session:
|
||||
index = await build_path_index(session, self.search_space_id)
|
||||
doc_rows = await session.execute(
|
||||
select(Document.id, Document.title, Document.folder_id).where(
|
||||
Document.search_space_id == self.search_space_id
|
||||
)
|
||||
)
|
||||
docs = list(doc_rows.all())
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.warning("knowledge_tree: DB error %s", exc)
|
||||
return "<workspace_tree>\n(unavailable)\n</workspace_tree>"
|
||||
|
||||
rendered = self._format_tree(index, docs)
|
||||
self._cache[cache_key] = rendered
|
||||
return rendered
|
||||
|
||||
def _format_tree(self, index: PathIndex, docs: list[Any]) -> str:
|
||||
folder_paths = sorted(set(index.folder_paths.values()))
|
||||
doc_paths = sorted(
|
||||
doc_to_virtual_path(
|
||||
doc_id=row.id,
|
||||
title=str(row.title or "untitled"),
|
||||
folder_id=row.folder_id,
|
||||
index=index,
|
||||
)
|
||||
for row in docs
|
||||
)
|
||||
all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT]))
|
||||
|
||||
# Pre-compute which folders have at least one descendant (folder or doc).
|
||||
# A folder is "empty" iff no path in `all_paths` is strictly under it.
|
||||
# Used to emit an explicit "(empty)" marker so the LLM doesn't have to
|
||||
# infer emptiness from indentation alone.
|
||||
non_empty_folders = self._compute_non_empty_folders(folder_paths, doc_paths)
|
||||
|
||||
lines: list[str] = []
|
||||
for path in all_paths:
|
||||
depth = (
|
||||
0
|
||||
if path == DOCUMENTS_ROOT
|
||||
else len([p for p in path[len(DOCUMENTS_ROOT) :].split("/") if p])
|
||||
)
|
||||
indent = " " * depth
|
||||
is_dir = path == DOCUMENTS_ROOT or path in folder_paths
|
||||
display = (
|
||||
path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents"
|
||||
)
|
||||
if is_dir:
|
||||
if path != DOCUMENTS_ROOT and path not in non_empty_folders:
|
||||
lines.append(f"{indent}{display}/ (empty)")
|
||||
else:
|
||||
lines.append(f"{indent}{display}/")
|
||||
else:
|
||||
lines.append(f"{indent}{display}")
|
||||
if len(lines) >= self.max_entries:
|
||||
remaining = len(all_paths) - len(lines)
|
||||
if remaining > 0:
|
||||
lines.append(
|
||||
f"... {remaining} more entries — use "
|
||||
"ls('/documents/<folder>', offset, limit) to expand"
|
||||
)
|
||||
break
|
||||
|
||||
body = "\n".join(lines)
|
||||
rendered = f"<workspace_tree>\n{body}\n</workspace_tree>"
|
||||
|
||||
token_count = _count_tokens(rendered, llm=self.llm)
|
||||
if token_count <= self.max_tokens:
|
||||
return rendered
|
||||
|
||||
return self._format_root_summary(folder_paths, doc_paths)
|
||||
|
||||
@staticmethod
|
||||
def _compute_non_empty_folders(
|
||||
folder_paths: list[str], doc_paths: list[str]
|
||||
) -> set[str]:
|
||||
"""Return the set of folder paths that contain at least one descendant.
|
||||
|
||||
A folder is "non-empty" if any document path or any other folder path
|
||||
is strictly under it. Documents propagate emptiness up to every
|
||||
ancestor folder, while a sub-folder only marks its direct ancestors
|
||||
non-empty (so a chain of empty folders all read ``(empty)``).
|
||||
"""
|
||||
non_empty: set[str] = set()
|
||||
folder_set = set(folder_paths)
|
||||
|
||||
for doc_path in doc_paths:
|
||||
parent = doc_path.rsplit("/", 1)[0]
|
||||
while parent and parent != DOCUMENTS_ROOT:
|
||||
if parent in folder_set:
|
||||
non_empty.add(parent)
|
||||
parent = parent.rsplit("/", 1)[0]
|
||||
|
||||
for child in folder_paths:
|
||||
parent = child.rsplit("/", 1)[0]
|
||||
while parent and parent != DOCUMENTS_ROOT and parent in folder_set:
|
||||
non_empty.add(parent)
|
||||
parent = parent.rsplit("/", 1)[0]
|
||||
|
||||
return non_empty
|
||||
|
||||
def _format_root_summary(
|
||||
self, folder_paths: list[str], doc_paths: list[str]
|
||||
) -> str:
|
||||
top_level: dict[str, int] = {}
|
||||
loose_docs = 0
|
||||
for path in doc_paths:
|
||||
rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
|
||||
if "/" in rel:
|
||||
top = rel.split("/", 1)[0]
|
||||
top_level[top] = top_level.get(top, 0) + 1
|
||||
else:
|
||||
loose_docs += 1
|
||||
for path in folder_paths:
|
||||
rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
|
||||
if not rel:
|
||||
continue
|
||||
top = rel.split("/", 1)[0]
|
||||
top_level.setdefault(top, 0)
|
||||
|
||||
lines = [DOCUMENTS_ROOT + "/"]
|
||||
for name in sorted(top_level):
|
||||
count = top_level[name]
|
||||
lines.append(f" {name}/ ({count} document{'s' if count != 1 else ''})")
|
||||
if loose_docs:
|
||||
lines.append(
|
||||
f" ({loose_docs} loose document{'s' if loose_docs != 1 else ''})"
|
||||
)
|
||||
lines.append(
|
||||
"Tree is large; use list_tree('/documents/<folder>') to drill in "
|
||||
"or ls('/documents/<folder>', offset, limit) for paginated listings."
|
||||
)
|
||||
return "<workspace_tree>\n" + "\n".join(lines) + "\n</workspace_tree>"
|
||||
|
||||
|
||||
__all__ = ["KnowledgeTreeMiddleware"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue