mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
Relocate the entire new_chat/middleware/ package to the shared kernel as one cohesive unit (it is live shared infrastructure: the multi-agent stack wraps nearly every middleware via multi_agent_chat/middleware/main_agent/*, and anonymous_agent consumes it too). Flip 69 live importers across both the package-path and submodule-path forms. Shims left for the frozen single-agent stack: a package __init__ re-export plus submodule shims for permission, skills_backends, and scoped_model_fallback (the three imported via submodule path by chat_deepagent/subagents). Cycle break: importing shared.middleware previously reached back into new_chat.tools at module load, which dragged in new_chat.__init__ -> chat_deepagent -> the middleware shim -> half-initialized shared.middleware. Made action_log's ToolDefinition import TYPE_CHECKING-only and tool_call_repair's INVALID_TOOL_NAME import function-local. These tools-package back-edges fully resolve in slice 6. Asset note: skills_backends._default_builtin_root now walks to app/agents/new_chat/skills/builtin (the skills/ tree migrates in slice 7).
1057 lines
39 KiB
Python
1057 lines
39 KiB
Python
"""Hybrid-search priority middleware for the SurfSense new chat agent.
|
|
|
|
This middleware runs ``before_agent`` on every turn and writes:
|
|
|
|
* ``state["kb_priority"]`` — the top-K most relevant documents for the
|
|
current user message, used to render a ``<priority_documents>`` system
|
|
message immediately before the user turn.
|
|
* ``state["kb_matched_chunk_ids"]`` — internal hand-off mapping
|
|
(``Document.id`` → matched chunk IDs) consumed by
|
|
:class:`KBPostgresBackend._load_file_data` when the agent first reads each
|
|
document, so the XML wrapper can flag matched sections in
|
|
``<chunk_index>``.
|
|
|
|
The previous "scoped filesystem" behaviour (synthetic ``ls`` + state
|
|
``files`` seeding) is intentionally removed: documents are now lazy-loaded
|
|
from Postgres on demand, with the full workspace tree rendered separately
|
|
by :class:`KnowledgeTreeMiddleware`.
|
|
|
|
In anonymous mode the middleware skips hybrid search entirely and emits a
|
|
single-entry priority list pointing at the Redis-loaded document
|
|
(``state["kb_anon_doc"]``).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import re
|
|
from collections.abc import Sequence
|
|
from datetime import UTC, datetime
|
|
from typing import Any
|
|
|
|
from langchain.agents import create_agent
|
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
|
from langchain_core.language_models import BaseChatModel
|
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
|
from langchain_core.runnables import Runnable
|
|
from langgraph.runtime import Runtime
|
|
from litellm import token_counter
|
|
from pydantic import BaseModel, Field, ValidationError
|
|
from sqlalchemy import select
|
|
|
|
from app.agents.shared.feature_flags import get_flags
|
|
from app.agents.shared.filesystem_selection import FilesystemMode
|
|
from app.agents.shared.filesystem_state import SurfSenseFilesystemState
|
|
from app.agents.shared.path_resolver import (
|
|
PathIndex,
|
|
build_path_index,
|
|
doc_to_virtual_path,
|
|
)
|
|
from app.agents.shared.utils import parse_date_or_datetime, resolve_date_range
|
|
from app.db import (
|
|
NATIVE_TO_LEGACY_DOCTYPE,
|
|
Chunk,
|
|
Document,
|
|
Folder,
|
|
shielded_async_session,
|
|
)
|
|
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
|
from app.utils.document_converters import embed_texts
|
|
from app.utils.perf import get_perf_logger
|
|
|
|
# Module logger for warnings/debug output (planner fallbacks, dropped folders).
logger = logging.getLogger(__name__)
# Separate performance logger; receives the ``[kb_priority]`` timing lines.
_perf_log = get_perf_logger()
|
|
|
|
|
|
class KBSearchPlan(BaseModel):
    """Structured internal plan for KB retrieval.

    Validated from the planner LLM's JSON reply by
    ``_parse_kb_search_plan_response``; the field names match the JSON shape
    that the planner prompt asks the model to return.
    """

    # Rewritten retrieval query. ``min_length=1`` rejects an empty string; the
    # middleware additionally falls back to the raw user text when the value
    # collapses to nothing after whitespace normalisation.
    optimized_query: str = Field(
        min_length=1,
        description="Optimized retrieval query preserving the user's intent.",
    )
    # Optional date bounds kept as raw strings; they are parsed and resolved
    # later by ``_normalize_optional_date_range``, not by this model.
    start_date: str | None = Field(
        default=None,
        description="Optional ISO start date or datetime for KB search filtering.",
    )
    end_date: str | None = Field(
        default=None,
        description="Optional ISO end date or datetime for KB search filtering.",
    )
    # When true, the middleware browses newest-first instead of running a
    # relevance-ranked hybrid search.
    is_recency_query: bool = Field(
        default=False,
        description=(
            "True when the user's intent is primarily about recency or temporal "
            "ordering (e.g. 'latest', 'newest', 'most recent', 'last uploaded') "
            "rather than topical relevance."
        ),
    )
|
|
|
|
|
|
def _extract_text_from_message(message: BaseMessage) -> str:
|
|
content = getattr(message, "content", "")
|
|
if isinstance(content, str):
|
|
return content
|
|
if isinstance(content, list):
|
|
parts: list[str] = []
|
|
for item in content:
|
|
if isinstance(item, str):
|
|
parts.append(item)
|
|
elif isinstance(item, dict) and item.get("type") == "text":
|
|
parts.append(str(item.get("text", "")))
|
|
return "\n".join(p for p in parts if p)
|
|
return str(content)
|
|
|
|
|
|
def _render_recent_conversation(
|
|
messages: Sequence[BaseMessage],
|
|
*,
|
|
llm: BaseChatModel | None = None,
|
|
user_text: str = "",
|
|
max_messages: int = 6,
|
|
) -> str:
|
|
"""Render recent dialogue for internal planning under a token budget.
|
|
|
|
Filters to ``HumanMessage`` and ``AIMessage`` (without tool_calls) so that
|
|
injected ``SystemMessage`` artefacts (priority list, workspace tree,
|
|
file-write contract) don't pollute the planner prompt.
|
|
"""
|
|
rendered: list[tuple[str, str]] = []
|
|
for message in messages:
|
|
role: str | None = None
|
|
if isinstance(message, HumanMessage):
|
|
role = "user"
|
|
elif isinstance(message, AIMessage):
|
|
if getattr(message, "tool_calls", None):
|
|
continue
|
|
role = "assistant"
|
|
else:
|
|
continue
|
|
|
|
text = _extract_text_from_message(message).strip()
|
|
if not text:
|
|
continue
|
|
text = re.sub(r"\s+", " ", text)
|
|
rendered.append((role, text))
|
|
|
|
if not rendered:
|
|
return ""
|
|
|
|
if rendered and rendered[-1][0] == "user" and rendered[-1][1] == user_text.strip():
|
|
rendered = rendered[:-1]
|
|
|
|
if not rendered:
|
|
return ""
|
|
|
|
def _legacy_render() -> str:
|
|
legacy_lines: list[str] = []
|
|
for role, text in rendered[-max_messages:]:
|
|
clipped = text[:400].rstrip() + "..." if len(text) > 400 else text
|
|
legacy_lines.append(f"{role}: {clipped}")
|
|
return "\n".join(legacy_lines)
|
|
|
|
def _count_prompt_tokens(conversation_text: str) -> int | None:
|
|
prompt = _build_kb_planner_prompt(
|
|
recent_conversation=conversation_text or "(none)",
|
|
user_text=user_text,
|
|
)
|
|
message_payload = [{"role": "user", "content": prompt}]
|
|
|
|
count_fn = getattr(llm, "_count_tokens", None) if llm is not None else None
|
|
if callable(count_fn):
|
|
try:
|
|
return count_fn(message_payload)
|
|
except Exception:
|
|
pass
|
|
|
|
profile = getattr(llm, "profile", None) if llm is not None else None
|
|
model_names: list[str] = []
|
|
if isinstance(profile, dict):
|
|
tcms = profile.get("token_count_models")
|
|
if isinstance(tcms, list):
|
|
model_names.extend(
|
|
name for name in tcms if isinstance(name, str) and name
|
|
)
|
|
tcm = profile.get("token_count_model")
|
|
if isinstance(tcm, str) and tcm and tcm not in model_names:
|
|
model_names.append(tcm)
|
|
model_name = model_names[0] if model_names else getattr(llm, "model", None)
|
|
if not isinstance(model_name, str) or not model_name:
|
|
return None
|
|
try:
|
|
return token_counter(messages=message_payload, model=model_name)
|
|
except Exception:
|
|
return None
|
|
|
|
get_max_input_tokens = getattr(llm, "_get_max_input_tokens", None) if llm else None
|
|
if callable(get_max_input_tokens):
|
|
try:
|
|
max_input_tokens = int(get_max_input_tokens())
|
|
except Exception:
|
|
max_input_tokens = None
|
|
else:
|
|
profile = getattr(llm, "profile", None) if llm is not None else None
|
|
max_input_tokens = (
|
|
profile.get("max_input_tokens")
|
|
if isinstance(profile, dict)
|
|
and isinstance(profile.get("max_input_tokens"), int)
|
|
else None
|
|
)
|
|
|
|
if not isinstance(max_input_tokens, int) or max_input_tokens <= 0:
|
|
return _legacy_render()
|
|
|
|
output_reserve = min(max(int(max_input_tokens * 0.02), 256), 1024)
|
|
budget = max_input_tokens - output_reserve
|
|
if budget <= 0:
|
|
return _legacy_render()
|
|
|
|
selected_lines: list[str] = []
|
|
for role, text in reversed(rendered):
|
|
candidate_line = f"{role}: {text}"
|
|
candidate_lines = [candidate_line, *selected_lines]
|
|
candidate_conversation = "\n".join(candidate_lines)
|
|
token_count = _count_prompt_tokens(candidate_conversation)
|
|
if token_count is None:
|
|
return _legacy_render()
|
|
if token_count <= budget:
|
|
selected_lines = candidate_lines
|
|
continue
|
|
|
|
lo, hi = 1, len(text)
|
|
best_line: str | None = None
|
|
while lo <= hi:
|
|
mid = (lo + hi) // 2
|
|
clipped_text = text[:mid].rstrip() + "..."
|
|
clipped_line = f"{role}: {clipped_text}"
|
|
clipped_conversation = "\n".join([clipped_line, *selected_lines])
|
|
clipped_tokens = _count_prompt_tokens(clipped_conversation)
|
|
if clipped_tokens is None:
|
|
break
|
|
if clipped_tokens <= budget:
|
|
best_line = clipped_line
|
|
lo = mid + 1
|
|
else:
|
|
hi = mid - 1
|
|
|
|
if best_line is not None:
|
|
selected_lines = [best_line, *selected_lines]
|
|
break
|
|
|
|
if not selected_lines:
|
|
return _legacy_render()
|
|
|
|
return "\n".join(selected_lines)
|
|
|
|
|
|
def _build_kb_planner_prompt(
|
|
*,
|
|
recent_conversation: str,
|
|
user_text: str,
|
|
) -> str:
|
|
today = datetime.now(UTC).date().isoformat()
|
|
return (
|
|
"You optimize internal knowledge-base search inputs for document retrieval.\n"
|
|
"Return JSON only with this exact shape:\n"
|
|
'{"optimized_query":"string","start_date":"ISO string or null","end_date":"ISO string or null","is_recency_query":bool}\n\n'
|
|
"Rules:\n"
|
|
"- Preserve the user's intent.\n"
|
|
"- Rewrite the query to improve retrieval using concrete entities, acronyms, projects, tools, people, and document-specific terms when helpful.\n"
|
|
"- Keep the query concise and retrieval-focused.\n"
|
|
"- Only use date filters when the latest user request or recent dialogue clearly implies a time range.\n"
|
|
"- If you use date filters, prefer returning both bounds.\n"
|
|
"- If no date filter is useful, return null for both dates.\n"
|
|
'- Set "is_recency_query" to true ONLY when the user\'s primary intent is about '
|
|
"recency or temporal ordering rather than topical relevance. Examples: "
|
|
'"latest file", "newest upload", "most recent document", "what did I save last", '
|
|
'"show me files from today", "last thing I added". '
|
|
"When true, results will be sorted by date instead of relevance.\n"
|
|
"- Do not include markdown, prose, or explanations.\n\n"
|
|
f"Today's UTC date: {today}\n\n"
|
|
f"Recent conversation:\n{recent_conversation or '(none)'}\n\n"
|
|
f"Latest user message:\n{user_text}"
|
|
)
|
|
|
|
|
|
def _extract_json_payload(text: str) -> str:
|
|
stripped = text.strip()
|
|
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
|
|
if fenced:
|
|
return fenced.group(1)
|
|
start = stripped.find("{")
|
|
end = stripped.rfind("}")
|
|
if start != -1 and end != -1 and end > start:
|
|
return stripped[start : end + 1]
|
|
return stripped
|
|
|
|
|
|
def _parse_kb_search_plan_response(response_text: str) -> KBSearchPlan:
    """Decode the planner's reply into a validated :class:`KBSearchPlan`.

    Raises:
        json.JSONDecodeError: when the extracted payload is not valid JSON.
        pydantic.ValidationError: when the JSON does not match the plan schema.
    """
    raw_json = _extract_json_payload(response_text)
    decoded = json.loads(raw_json)
    return KBSearchPlan.model_validate(decoded)
|
|
|
|
|
|
def _normalize_optional_date_range(
|
|
start_date: str | None,
|
|
end_date: str | None,
|
|
) -> tuple[datetime | None, datetime | None]:
|
|
parsed_start = parse_date_or_datetime(start_date) if start_date else None
|
|
parsed_end = parse_date_or_datetime(end_date) if end_date else None
|
|
|
|
if parsed_start is None and parsed_end is None:
|
|
return None, None
|
|
|
|
return resolve_date_range(parsed_start, parsed_end)
|
|
|
|
|
|
def _resolve_search_types(
|
|
available_connectors: list[str] | None,
|
|
available_document_types: list[str] | None,
|
|
) -> list[str] | None:
|
|
types: set[str] = set()
|
|
if available_document_types:
|
|
types.update(available_document_types)
|
|
if available_connectors:
|
|
types.update(available_connectors)
|
|
if not types:
|
|
return None
|
|
|
|
expanded: set[str] = set(types)
|
|
for t in types:
|
|
legacy = NATIVE_TO_LEGACY_DOCTYPE.get(t)
|
|
if legacy:
|
|
expanded.add(legacy)
|
|
return list(expanded) if expanded else None
|
|
|
|
|
|
# Cap on the chunks attached to each document by ``browse_recent_documents``
# (the recency path): only the first N chunks by ``Chunk.id`` are returned.
_RECENCY_MAX_CHUNKS_PER_DOC = 5
|
|
|
|
|
|
async def browse_recent_documents(
    *,
    search_space_id: int,
    document_type: list[str] | None = None,
    top_k: int = 10,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
) -> list[dict[str, Any]]:
    """Return documents ordered by recency (newest first), no relevance ranking.

    Documents are limited to ``search_space_id`` and exclude any whose
    ``status["state"]`` is ``"deleting"`` (a missing state counts as
    ``"ready"``). They are optionally filtered by type and by an inclusive
    ``updated_at`` window. Each of the ``top_k`` newest documents gets up to
    ``_RECENCY_MAX_CHUNKS_PER_DOC`` chunks (lowest ``Chunk.id`` first).
    Every result has ``score=0.0`` and an empty ``matched_chunk_ids``.

    String entries in ``document_type`` are looked up by ``DocumentType``
    member name, and unknown names are skipped. NOTE(review): if none of
    the names resolve, no type filter is applied at all; confirm that is
    intended.
    """
    import contextlib

    from sqlalchemy import func

    from app.db import DocumentType

    async with shielded_async_session() as session:
        conditions = [
            Document.search_space_id == search_space_id,
            func.coalesce(Document.status["state"].astext, "ready") != "deleting",
        ]

        if document_type is not None:
            resolved_types = []
            for type_name in document_type:
                if not isinstance(type_name, str):
                    resolved_types.append(type_name)
                    continue
                with contextlib.suppress(KeyError):
                    resolved_types.append(DocumentType[type_name])
            if len(resolved_types) == 1:
                conditions.append(Document.document_type == resolved_types[0])
            elif resolved_types:
                conditions.append(Document.document_type.in_(resolved_types))

        if start_date is not None:
            conditions.append(Document.updated_at >= start_date)
        if end_date is not None:
            conditions.append(Document.updated_at <= end_date)

        newest_first = (
            select(Document)
            .where(*conditions)
            .order_by(Document.updated_at.desc())
            .limit(top_k)
        )
        documents = (await session.execute(newest_first)).scalars().unique().all()
        if not documents:
            return []

        # Number each document's chunks by id so the outer query can keep
        # only the first few per document.
        ranked = (
            select(
                Chunk.id.label("chunk_id"),
                Chunk.document_id,
                Chunk.content,
                func.row_number()
                .over(partition_by=Chunk.document_id, order_by=Chunk.id)
                .label("rn"),
            )
            .where(Chunk.document_id.in_([doc.id for doc in documents]))
            .subquery("numbered")
        )
        first_chunks = (
            select(ranked.c.chunk_id, ranked.c.document_id, ranked.c.content)
            .where(ranked.c.rn <= _RECENCY_MAX_CHUNKS_PER_DOC)
            .order_by(ranked.c.document_id, ranked.c.chunk_id)
        )
        chunk_rows = (await session.execute(first_chunks)).all()

        chunks_by_doc: dict[int, list[dict[str, Any]]] = {
            doc.id: [] for doc in documents
        }
        for row in chunk_rows:
            bucket = chunks_by_doc.get(row.document_id)
            if bucket is not None:
                bucket.append({"chunk_id": row.chunk_id, "content": row.content})

        results: list[dict[str, Any]] = []
        for doc in documents:
            doc_chunks = chunks_by_doc.get(doc.id, [])
            type_value = (
                doc.document_type.value
                if getattr(doc, "document_type", None)
                else None
            )
            results.append(
                {
                    "document_id": doc.id,
                    "content": "\n\n".join(
                        chunk["content"]
                        for chunk in doc_chunks
                        if chunk.get("content")
                    ),
                    "score": 0.0,
                    "chunks": doc_chunks,
                    "matched_chunk_ids": [],
                    "document": {
                        "id": doc.id,
                        "title": doc.title,
                        "document_type": type_value,
                        "metadata": doc.document_metadata or {},
                        "folder_id": getattr(doc, "folder_id", None),
                    },
                    "source": type_value,
                }
            )
        return results
|
|
|
|
|
|
async def search_knowledge_base(
    *,
    query: str,
    search_space_id: int,
    available_connectors: list[str] | None = None,
    available_document_types: list[str] | None = None,
    top_k: int = 10,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
) -> list[dict[str, Any]]:
    """Run a single unified hybrid search against the knowledge base.

    ``query`` is embedded in a worker thread (``asyncio.to_thread``) and
    passed, together with the raw text, the resolved type filter and the
    optional date bounds, to ``ChucksHybridSearchRetriever.hybrid_search``.

    Returns:
        At most ``top_k`` results in retriever order, or ``[]`` for an empty
        query or a non-positive ``top_k``.
    """
    # A negative ``top_k`` would otherwise slice from the end
    # (``results[:-k]``) and return almost everything.
    if not query or top_k <= 0:
        return []

    [embedding] = await asyncio.to_thread(embed_texts, [query])
    doc_types = _resolve_search_types(available_connectors, available_document_types)
    # Over-fetch 3x (capped at 30), but never request fewer rows than the
    # caller asked for: the bare ``min(top_k * 3, 30)`` capped results at 30
    # whenever ``top_k > 30``.
    retriever_top_k = max(top_k, min(top_k * 3, 30))

    async with shielded_async_session() as session:
        retriever = ChucksHybridSearchRetriever(session)
        results = await retriever.hybrid_search(
            query_text=query,
            top_k=retriever_top_k,
            search_space_id=search_space_id,
            document_type=doc_types,
            start_date=start_date,
            end_date=end_date,
            query_embedding=embedding.tolist(),
        )

    return results[:top_k]
|
|
|
|
|
|
async def fetch_mentioned_documents(
    *,
    document_ids: list[int],
    search_space_id: int,
) -> list[dict[str, Any]]:
    """Fetch explicitly mentioned documents.

    Each requested document in ``search_space_id`` is loaded with all of its
    chunks, ordered by ``Chunk.id``. Results follow the order of
    ``document_ids``. Ids that are missing or belong to another search space
    are skipped, and a repeated id is returned only once.

    Each entry has ``score=1.0``, empty ``content`` and ``matched_chunk_ids``,
    and ``_user_mentioned=True``.
    """
    if not document_ids:
        return []

    # Keep first-mention order but drop repeats; a doc mentioned twice would
    # otherwise appear twice in the returned list.
    unique_ids = list(dict.fromkeys(document_ids))

    async with shielded_async_session() as session:
        doc_result = await session.execute(
            select(Document).where(
                Document.id.in_(unique_ids),
                Document.search_space_id == search_space_id,
            )
        )
        docs = {doc.id: doc for doc in doc_result.scalars().all()}

        if not docs:
            return []

        chunk_result = await session.execute(
            select(Chunk.id, Chunk.content, Chunk.document_id)
            .where(Chunk.document_id.in_(list(docs.keys())))
            .order_by(Chunk.document_id, Chunk.id)
        )
        chunks_by_doc: dict[int, list[dict[str, Any]]] = {doc_id: [] for doc_id in docs}
        for row in chunk_result.all():
            if row.document_id in chunks_by_doc:
                chunks_by_doc[row.document_id].append(
                    {"chunk_id": row.id, "content": row.content}
                )

        results: list[dict[str, Any]] = []
        for doc_id in unique_ids:
            doc = docs.get(doc_id)
            if doc is None:
                continue
            type_value = (
                doc.document_type.value
                if getattr(doc, "document_type", None)
                else None
            )
            results.append(
                {
                    "document_id": doc.id,
                    "content": "",
                    "score": 1.0,
                    "chunks": chunks_by_doc.get(doc.id, []),
                    "matched_chunk_ids": [],
                    "document": {
                        "id": doc.id,
                        "title": doc.title,
                        "document_type": type_value,
                        "metadata": doc.document_metadata or {},
                        "folder_id": getattr(doc, "folder_id", None),
                    },
                    "source": type_value,
                    "_user_mentioned": True,
                }
            )
        return results
|
|
|
|
|
|
def _render_priority_message(priority: list[dict[str, Any]]) -> SystemMessage:
    """Render the priority list as a single ``<priority_documents>`` system message.

    Each entry becomes ``- <path> (score=<x.xxx or n/a>)``. User-mentioned
    entries get a trailing `` [USER-MENTIONED]`` marker. An empty list
    renders a placeholder line instead.
    """

    def _format_entry(entry: dict[str, Any]) -> str:
        raw_score = entry.get("score")
        shown_score = (
            f"{raw_score:.3f}" if isinstance(raw_score, int | float) else "n/a"
        )
        marker = " [USER-MENTIONED]" if entry.get("mentioned") else ""
        return f"- {entry.get('path', '')} (score={shown_score}){marker}"

    if priority:
        body = "\n".join(_format_entry(entry) for entry in priority)
    else:
        body = "(no priority documents for this turn)"

    header = (
        "<priority_documents>\n"
        "These documents are most relevant to the latest user message; "
        "read them first. Matched sections are flagged inside each "
        "document's <chunk_index>.\n"
    )
    return SystemMessage(content=f"{header}{body}\n</priority_documents>")
|
|
|
|
|
|
class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|
"""Compute hybrid-search priority hints for the current turn."""
|
|
|
|
tools = ()
|
|
state_schema = SurfSenseFilesystemState
|
|
|
|
def __init__(
    self,
    *,
    llm: BaseChatModel | None = None,
    planner_llm: BaseChatModel | None = None,
    search_space_id: int,
    filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
    available_connectors: list[str] | None = None,
    available_document_types: list[str] | None = None,
    top_k: int = 10,
    mentioned_document_ids: list[int] | None = None,
    inject_system_message: bool = True,  # For backwards compatibility
) -> None:
    """Configure priority computation for one search space.

    Args:
        llm: Chat model; also used for planning when ``planner_llm`` is unset.
        planner_llm: Model for short internal planning calls (query
            rewriting, date extraction, recency classification).
        search_space_id: Search space every lookup is scoped to.
        filesystem_mode: The hook only does work in ``FilesystemMode.CLOUD``.
        available_connectors: Connector types used as a search filter.
        available_document_types: Document types used as a search filter.
        top_k: Number of search results requested per turn.
        mentioned_document_ids: Legacy constructor-injected mentions; they
            are cleared the first time a turn is processed.
        inject_system_message: Also insert the rendered priority list into
            ``messages`` (in addition to ``state["kb_priority"]``).
    """
    # Retrieval scope and filters.
    self.search_space_id = search_space_id
    self.filesystem_mode = filesystem_mode
    self.available_connectors = available_connectors
    self.available_document_types = available_document_types
    self.top_k = top_k

    # Models. An operator can mark a global config ``is_planner: true`` to
    # route the classification-style planner calls to a cheap/fast model
    # (e.g. gpt-4o-mini, Haiku, Azure gpt-5.x-nano) instead of the user's
    # chat LLM. Without a planner config the chat LLM is used, so existing
    # deployments keep working unchanged.
    self.llm = llm
    self.planner_llm = planner_llm or llm

    # Legacy per-turn inputs and output behaviour.
    self.mentioned_document_ids = mentioned_document_ids or []
    self.inject_system_message = inject_system_message

    # The kb-planner private Runnable is compiled lazily, once, so the
    # ``create_agent`` compile cost (50-200ms) is not paid on every turn.
    # It sits behind ``enable_kb_planner_runnable``; when that flag is off
    # the legacy ``planner_llm.ainvoke`` path is used instead.
    self._planner: Runnable | None = None
    self._planner_compile_failed = False
|
|
|
|
def _build_kb_planner_runnable(self) -> Runnable | None:
    """Compile the kb-planner private :class:`Runnable` once.

    Returns ``None`` in these cases:

    * no planner LLM is configured;
    * ``enable_kb_planner_runnable`` is off, or ``disable_new_agent_stack``
      is on;
    * preparing the agent fails, whether importing the retry middleware or
      calling ``create_agent``.

    In all of them the caller falls back to the legacy
    ``planner_llm.ainvoke`` path. A successful compile is memoized in
    ``self._planner``. A failed compile is remembered in
    ``self._planner_compile_failed`` and is not retried.

    The compiled agent has no tools, because the planner's contract is
    "answer with structured JSON". It does get :class:`RetryAfterMiddleware`
    so that transient rate-limit errors from the planner LLM call don't fail
    the whole turn.
    """
    if self._planner is not None or self._planner_compile_failed:
        return self._planner
    if self.planner_llm is None:
        return None
    flags = get_flags()
    if not flags.enable_kb_planner_runnable or flags.disable_new_agent_stack:
        return None

    try:
        # Imported lazily and inside the ``try``: an import-time failure now
        # degrades to the legacy path, as documented, instead of propagating
        # to the caller.
        from app.agents.shared.middleware.retry_after import RetryAfterMiddleware

        self._planner = create_agent(
            self.planner_llm,
            tools=[],
            middleware=[RetryAfterMiddleware(max_retries=2)],
        )
    except Exception as exc:  # pragma: no cover - defensive
        logger.warning(
            "kb-planner Runnable compile failed; falling back to planner_llm.ainvoke: %s",
            exc,
        )
        self._planner_compile_failed = True
        self._planner = None
    return self._planner
|
|
|
|
async def _plan_search_inputs(
    self,
    *,
    messages: Sequence[BaseMessage],
    user_text: str,
) -> tuple[str, datetime | None, datetime | None, bool]:
    """Ask the planner LLM to rewrite the query and extract search filters.

    Returns:
        ``(query, start_date, end_date, is_recency_query)``. When no planner
        LLM is configured, or planning fails in any way, returns
        ``(user_text, None, None, False)`` so retrieval still runs on the
        raw user text.
    """
    if self.planner_llm is None:
        return user_text, None, None, False

    recent_conversation = _render_recent_conversation(
        messages,
        llm=self.planner_llm,
        user_text=user_text,
    )
    prompt = _build_kb_planner_prompt(
        recent_conversation=recent_conversation,
        user_text=user_text,
    )
    loop = asyncio.get_running_loop()
    t0 = loop.time()

    try:
        # Prefer the compiled-once planner Runnable when enabled; otherwise
        # fall back to ``planner_llm.ainvoke``. The ``surfsense:internal``
        # tag is preserved on both paths so ``_stream_agent_events`` still
        # suppresses the planner's intermediate events from the UI.
        #
        # The planner is resolved inside the ``try``. Previously a failure
        # here (e.g. reading feature flags) escaped and failed the whole
        # turn instead of falling back to the raw query.
        planner = self._build_kb_planner_runnable()
        if planner is not None:
            planner_state = await planner.ainvoke(
                {"messages": [HumanMessage(content=prompt)]},
                config={"tags": ["surfsense:internal"]},
            )
            response_messages = (
                planner_state.get("messages", [])
                if isinstance(planner_state, dict)
                else []
            )
            response = (
                response_messages[-1]
                if response_messages
                else AIMessage(content="")
            )
        else:
            response = await self.planner_llm.ainvoke(
                [HumanMessage(content=prompt)],
                config={"tags": ["surfsense:internal"]},
            )

        plan = _parse_kb_search_plan_response(_extract_text_from_message(response))
        optimized_query = (
            re.sub(r"\s+", " ", plan.optimized_query).strip() or user_text
        )
        start_date, end_date = _normalize_optional_date_range(
            plan.start_date,
            plan.end_date,
        )
        is_recency = plan.is_recency_query
        _perf_log.info(
            "[kb_priority] planner in %.3fs query=%r optimized=%r "
            "start=%s end=%s recency=%s",
            loop.time() - t0,
            user_text[:80],
            optimized_query[:120],
            start_date.isoformat() if start_date else None,
            end_date.isoformat() if end_date else None,
            is_recency,
        )
        return optimized_query, start_date, end_date, is_recency
    except (json.JSONDecodeError, ValidationError, ValueError) as exc:
        logger.warning(
            "KB planner returned invalid output, using raw query: %s", exc
        )
    except Exception as exc:  # pragma: no cover - defensive fallback
        logger.warning("KB planner failed, using raw query: %s", exc)

    return user_text, None, None, False
|
|
|
|
def before_agent(  # type: ignore[override]
    self,
    state: AgentState,
    runtime: Runtime[Any],
) -> dict[str, Any] | None:
    """Synchronous hook; drives :meth:`abefore_agent` when it safely can.

    If an event loop is already running in this thread, nothing is done and
    ``None`` is returned, since ``asyncio.run`` cannot be called from inside
    a running loop. Otherwise the async hook runs to completion via
    ``asyncio.run`` and its result is returned.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        pass  # No running loop in this thread: safe to start one below.
    else:
        return None
    return asyncio.run(self.abefore_agent(state, runtime))
|
|
|
|
async def abefore_agent(  # type: ignore[override]
    self,
    state: AgentState,
    runtime: Runtime[Any],
) -> dict[str, Any] | None:
    """Compute this turn's priority hints before the agent runs.

    Returns ``None`` (no state update) in these cases:

    * the filesystem mode is not cloud;
    * there is no ``HumanMessage`` in ``state["messages"]``;
    * the latest human message has no text.

    When ``state["kb_anon_doc"]`` is set, the anonymous single-document path
    is used. Otherwise the planner + search path runs.
    """
    if self.filesystem_mode != FilesystemMode.CLOUD:
        return None

    messages = state.get("messages") or []
    latest_human = next(
        (msg for msg in reversed(messages) if isinstance(msg, HumanMessage)),
        None,
    )
    if latest_human is None:
        return None

    user_text = _extract_text_from_message(latest_human).strip()
    if not user_text:
        return None

    anon_doc = state.get("kb_anon_doc")
    if anon_doc:
        return self._anon_priority(state, anon_doc)
    return await self._authenticated_priority(state, messages, user_text, runtime)
|
|
|
|
def _anon_priority(
    self,
    state: AgentState,
    anon_doc: dict[str, Any],
) -> dict[str, Any]:
    """Priority hints for an anonymous session; no hybrid search is run.

    The document described by ``anon_doc`` becomes the only entry. It has
    ``score=1.0``, is flagged as user-mentioned, and its title defaults to
    ``"uploaded_document"``. ``kb_matched_chunk_ids`` is left empty. When
    ``inject_system_message`` is set, the rendered list is inserted just
    before the last message.
    """
    entry = {
        "path": str(anon_doc.get("path") or ""),
        "score": 1.0,
        "document_id": None,
        "title": str(anon_doc.get("title") or "uploaded_document"),
        "mentioned": True,
    }
    priority = [entry]
    update: dict[str, Any] = {"kb_priority": priority, "kb_matched_chunk_ids": {}}
    if not self.inject_system_message:
        return update

    history = list(state.get("messages") or [])
    history.insert(max(len(history) - 1, 0), _render_priority_message(priority))
    update["messages"] = history
    return update
|
|
|
|
async def _authenticated_priority(
    self,
    state: AgentState,
    messages: Sequence[BaseMessage],
    user_text: str,
    runtime: Runtime[Any] | None = None,
) -> dict[str, Any]:
    """Build the priority list for a non-anonymous turn.

    Steps:

    1. Plan the query (rewrite, date range, recency flag).
    2. Resolve this turn's mentioned document and folder ids.
    3. Fetch the mentioned documents.
    4. Run either newest-first browsing (recency queries) or hybrid search.
    5. Merge the results with mentions first, each document appearing once.
    6. Materialize paths, then prepend the mentioned folders.

    Returns:
        A state update with ``kb_priority`` and ``kb_matched_chunk_ids``.
        When ``inject_system_message`` is set, it also contains ``messages``
        with the rendered list inserted before the last message.
    """
    # ``get_running_loop`` (not the legacy ``get_event_loop``), matching
    # ``_plan_search_inputs``; this always runs inside a coroutine.
    loop = asyncio.get_running_loop()
    t0 = loop.time()
    (
        planned_query,
        start_date,
        end_date,
        is_recency,
    ) = await self._plan_search_inputs(
        messages=messages,
        user_text=user_text,
    )

    # Per-turn ``mentioned_document_ids`` flow:
    # 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the
    #    streaming task supplies a fresh :class:`SurfSenseContextSchema`
    #    on every ``astream_events`` call, so this list is naturally
    #    scoped to the current turn. Allows cross-turn graph reuse via
    #    ``agent_cache``.
    # 2. Legacy fallback (cache disabled / context not propagated): the
    #    constructor-injected ``self.mentioned_document_ids`` list. We
    #    drain it after the first read so a cached graph (no Phase 1.5
    #    wiring) doesn't keep replaying the same mentions on every
    #    turn.
    #
    # CRITICAL: distinguish "context absent" (legacy caller, no field at
    # all) from "context provided but empty" (turn with no mentions).
    # ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in
    # Python, so a naive ``if ctx_mentions:`` would fall through to the
    # legacy closure on every no-mention follow-up turn — replaying the
    # mentions baked in by turn 1's cache-miss build. Always drain the
    # closure once the runtime path has fired so a cached middleware
    # instance can never resurrect stale state.
    mention_ids: list[int] = []
    ctx = getattr(runtime, "context", None) if runtime is not None else None
    ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None
    if ctx_mentions is not None:
        # Runtime path is authoritative — even an empty list means
        # "this turn has no mentions", NOT "look at the closure".
        mention_ids = list(ctx_mentions)
        if self.mentioned_document_ids:
            self.mentioned_document_ids = []
    elif self.mentioned_document_ids:
        mention_ids = list(self.mentioned_document_ids)
        self.mentioned_document_ids = []

    # Folder mentions live alongside doc mentions on the runtime
    # context. They never feed hybrid search (folders aren't
    # embedded) — they're surfaced purely as ``[USER-MENTIONED]``
    # priority entries so the agent walks the folder with ``ls`` /
    # ``find_documents`` instead of ignoring it. Cloud filesystem
    # mode only.
    folder_mention_ids: list[int] = []
    if (
        ctx is not None
        and getattr(self, "filesystem_mode", FilesystemMode.CLOUD)
        == FilesystemMode.CLOUD
    ):
        ctx_folders = getattr(ctx, "mentioned_folder_ids", None)
        if ctx_folders:
            folder_mention_ids = list(ctx_folders)

    mentioned_results: list[dict[str, Any]] = []
    if mention_ids:
        mentioned_results = await fetch_mentioned_documents(
            document_ids=mention_ids,
            search_space_id=self.search_space_id,
        )

    if is_recency:
        doc_types = _resolve_search_types(
            self.available_connectors, self.available_document_types
        )
        search_results = await browse_recent_documents(
            search_space_id=self.search_space_id,
            document_type=doc_types,
            top_k=self.top_k,
            start_date=start_date,
            end_date=end_date,
        )
    else:
        search_results = await search_knowledge_base(
            query=planned_query,
            search_space_id=self.search_space_id,
            available_connectors=self.available_connectors,
            available_document_types=self.available_document_types,
            top_k=self.top_k,
            start_date=start_date,
            end_date=end_date,
        )

    # Mentioned documents come first, in mention order. De-duplicate by
    # document id across the whole list, so that a repeated mention (or a
    # repeated search hit) yields one priority entry; search hits for
    # already-mentioned documents are dropped as before. Entries without an
    # int id are always kept.
    seen_doc_ids: set[int] = set()
    merged: list[dict[str, Any]] = []
    for doc in (*mentioned_results, *search_results):
        doc_id = (doc.get("document") or {}).get("id")
        if isinstance(doc_id, int):
            if doc_id in seen_doc_ids:
                continue
            seen_doc_ids.add(doc_id)
        merged.append(doc)

    priority, matched_chunk_ids = await self._materialize_priority(merged)

    if folder_mention_ids:
        folder_entries = await self._materialize_folder_priority(folder_mention_ids)
        priority = folder_entries + priority

    _perf_log.info(
        "[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d folders=%d",
        loop.time() - t0,
        user_text[:80],
        len(priority),
        len(mentioned_results),
        len(folder_mention_ids),
    )

    update: dict[str, Any] = {
        "kb_priority": priority,
        "kb_matched_chunk_ids": matched_chunk_ids,
    }
    if self.inject_system_message:
        new_messages = list(messages)
        insert_at = max(len(new_messages) - 1, 0)
        new_messages.insert(insert_at, _render_priority_message(priority))
        update["messages"] = new_messages
    return update
|
|
|
|
async def _materialize_folder_priority(
    self, folder_ids: list[int]
) -> list[dict[str, Any]]:
    """Turn user-mentioned folder ids into ``<priority_documents>`` rows.

    Each row points at the folder's canonical ``/documents/Folder/Sub/``
    virtual path, the same shape ``KnowledgeTreeMiddleware`` and the
    agent's ``ls`` adapter use, as resolved by ``build_path_index``. The
    path is always slash-terminated. Rows are flagged ``mentioned=True``
    so the rendered line carries ``[USER-MENTIONED]``.

    ``score`` and ``document_id`` are ``None`` because folders are not
    ranked documents; the renderer prints ``n/a`` for the missing score
    and the agent decides which children to read. ``title`` is the
    folder's name, or ``""`` when the name lookup returns nothing.

    Duplicate ids produce a single row, in first-seen order. Ids absent
    from the path index are dropped with a debug log.
    """
    if not folder_ids:
        return []

    async with shielded_async_session() as session:
        index: PathIndex = await build_path_index(session, self.search_space_id)

        # Only folders that belong to this search space are looked up.
        name_query = select(Folder.id, Folder.name).where(
            Folder.search_space_id == self.search_space_id,
            Folder.id.in_(folder_ids),
        )
        name_result = await session.execute(name_query)
        names_by_id: dict[int, str] = {
            record.id: record.name for record in name_result.all()
        }

        rows: list[dict[str, Any]] = []
        # dict.fromkeys keeps first-seen order while dropping repeats.
        for fid in dict.fromkeys(folder_ids):
            base_path = index.folder_paths.get(fid)
            if base_path is None:
                logger.debug(
                    "kb_priority: dropping folder id=%s (missing from path index)",
                    fid,
                )
                continue
            slash_path = base_path if base_path.endswith("/") else f"{base_path}/"
            rows.append(
                {
                    "path": slash_path,
                    "score": None,
                    "document_id": None,
                    "folder_id": fid,
                    "title": names_by_id.get(fid, ""),
                    "mentioned": True,
                }
            )
        return rows
|
|
|
|
async def _materialize_priority(
|
|
self, merged: list[dict[str, Any]]
|
|
) -> tuple[list[dict[str, Any]], dict[int, list[int]]]:
|
|
"""Resolve canonical paths and matched chunk ids for the priority list."""
|
|
priority: list[dict[str, Any]] = []
|
|
matched_chunk_ids: dict[int, list[int]] = {}
|
|
|
|
if not merged:
|
|
return priority, matched_chunk_ids
|
|
|
|
async with shielded_async_session() as session:
|
|
index: PathIndex = await build_path_index(session, self.search_space_id)
|
|
doc_ids = [
|
|
(doc.get("document") or {}).get("id")
|
|
for doc in merged
|
|
if isinstance(doc, dict)
|
|
]
|
|
doc_ids = [doc_id for doc_id in doc_ids if isinstance(doc_id, int)]
|
|
folder_by_doc_id: dict[int, int | None] = {}
|
|
if doc_ids:
|
|
folder_rows = await session.execute(
|
|
select(Document.id, Document.folder_id).where(
|
|
Document.search_space_id == self.search_space_id,
|
|
Document.id.in_(doc_ids),
|
|
)
|
|
)
|
|
folder_by_doc_id = {row.id: row.folder_id for row in folder_rows.all()}
|
|
|
|
for doc in merged:
|
|
doc_meta = doc.get("document") or {}
|
|
doc_id = doc_meta.get("id")
|
|
title = doc_meta.get("title") or "untitled"
|
|
folder_id = (
|
|
folder_by_doc_id.get(doc_id)
|
|
if isinstance(doc_id, int)
|
|
else doc_meta.get("folder_id")
|
|
)
|
|
path = doc_to_virtual_path(
|
|
doc_id=doc_id if isinstance(doc_id, int) else None,
|
|
title=str(title),
|
|
folder_id=folder_id if isinstance(folder_id, int) else None,
|
|
index=index,
|
|
)
|
|
priority.append(
|
|
{
|
|
"path": path,
|
|
"score": float(doc.get("score") or 0.0),
|
|
"document_id": doc_id if isinstance(doc_id, int) else None,
|
|
"title": str(title),
|
|
"mentioned": bool(doc.get("_user_mentioned")),
|
|
}
|
|
)
|
|
if isinstance(doc_id, int):
|
|
chunk_ids = doc.get("matched_chunk_ids") or []
|
|
if chunk_ids:
|
|
matched_chunk_ids[doc_id] = [
|
|
int(cid) for cid in chunk_ids if isinstance(cid, int | str)
|
|
]
|
|
return priority, matched_chunk_ids
|
|
|
|
|
|
# Older public name for the priority middleware. It stays bound to the same
# class so existing ``import KnowledgeBaseSearchMiddleware`` callers keep working.
KnowledgeBaseSearchMiddleware = KnowledgePriorityMiddleware

# Public surface of this module (both middleware names plus the search helpers).
__all__ = [
    "KnowledgeBaseSearchMiddleware",
    "KnowledgePriorityMiddleware",
    "browse_recent_documents",
    "fetch_mentioned_documents",
    "search_knowledge_base",
]
|