SurfSense/surfsense_backend/app/agents/shared/middleware/knowledge_search.py

1058 lines
39 KiB
Python
Raw Normal View History

"""Hybrid-search priority middleware for the SurfSense new chat agent.

This middleware runs ``before_agent`` on every turn and writes:

* ``state["kb_priority"]`` — the top-K most relevant documents for the
  current user message, used to render a ``<priority_documents>`` system
  message immediately before the user turn.
* ``state["kb_matched_chunk_ids"]`` — internal hand-off mapping
  (``Document.id`` → matched chunk IDs) consumed by
  :class:`KBPostgresBackend._load_file_data` when the agent first reads each
  document, so the XML wrapper can flag matched sections in
  ``<chunk_index>``.

The previous "scoped filesystem" behaviour (synthetic ``ls`` + state
``files`` seeding) is intentionally removed: documents are now lazy-loaded
from Postgres on demand, with the full workspace tree rendered separately
by :class:`KnowledgeTreeMiddleware`.

In anonymous mode the middleware skips hybrid search entirely and emits a
single-entry priority list pointing at the Redis-loaded document
(``state["kb_anon_doc"]``).
"""
from __future__ import annotations
import asyncio
import json
import logging
import re
from collections.abc import Sequence
from datetime import UTC, datetime
2026-03-28 16:39:46 -07:00
from typing import Any
2026-04-28 09:22:19 -07:00
from langchain.agents import create_agent
2026-03-28 16:39:46 -07:00
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
2026-04-28 09:22:19 -07:00
from langchain_core.runnables import Runnable
2026-03-28 16:39:46 -07:00
from langgraph.runtime import Runtime
from litellm import token_counter
from pydantic import BaseModel, Field, ValidationError
2026-03-28 16:39:46 -07:00
from sqlalchemy import select
from app.agents.multi_agent_chat.shared.date_filters import (
parse_date_or_datetime,
resolve_date_range,
)
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
SurfSenseFilesystemState,
)
from app.agents.shared.feature_flags import get_flags
from app.agents.shared.filesystem_selection import FilesystemMode
from app.agents.shared.path_resolver import (
PathIndex,
build_path_index,
doc_to_virtual_path,
)
from app.db import (
NATIVE_TO_LEGACY_DOCTYPE,
Chunk,
Document,
Folder,
shielded_async_session,
)
2026-03-28 16:39:46 -07:00
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.utils.document_converters import embed_texts
from app.utils.perf import get_perf_logger
# Module logger for warnings and debug output emitted by this middleware.
logger = logging.getLogger(__name__)
# Dedicated logger for the ``[kb_priority]`` timing lines written below.
_perf_log = get_perf_logger()
class KBSearchPlan(BaseModel):
    """Structured internal plan for KB retrieval.

    This is the schema the planner LLM's JSON reply is validated against
    (see ``_parse_kb_search_plan_response``). Dates are kept as raw strings
    here; they are parsed later by ``_normalize_optional_date_range``.
    """

    # Rewritten search query; must be non-empty (``min_length=1``).
    optimized_query: str = Field(
        min_length=1,
        description="Optimized retrieval query preserving the user's intent.",
    )
    # Optional lower bound of the date filter, as an ISO string.
    start_date: str | None = Field(
        default=None,
        description="Optional ISO start date or datetime for KB search filtering.",
    )
    # Optional upper bound of the date filter, as an ISO string.
    end_date: str | None = Field(
        default=None,
        description="Optional ISO end date or datetime for KB search filtering.",
    )
    # When true the caller browses documents by recency instead of running
    # the relevance-ranked hybrid search.
    is_recency_query: bool = Field(
        default=False,
        description=(
            "True when the user's intent is primarily about recency or temporal "
            "ordering (e.g. 'latest', 'newest', 'most recent', 'last uploaded') "
            "rather than topical relevance."
        ),
    )
2026-03-28 16:39:46 -07:00
def _extract_text_from_message(message: BaseMessage) -> str:
content = getattr(message, "content", "")
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict) and item.get("type") == "text":
parts.append(str(item.get("text", "")))
return "\n".join(p for p in parts if p)
return str(content)
def _render_recent_conversation(
messages: Sequence[BaseMessage],
*,
llm: BaseChatModel | None = None,
user_text: str = "",
max_messages: int = 6,
) -> str:
"""Render recent dialogue for internal planning under a token budget.
Filters to ``HumanMessage`` and ``AIMessage`` (without tool_calls) so that
injected ``SystemMessage`` artefacts (priority list, workspace tree,
file-write contract) don't pollute the planner prompt.
"""
rendered: list[tuple[str, str]] = []
for message in messages:
role: str | None = None
if isinstance(message, HumanMessage):
role = "user"
elif isinstance(message, AIMessage):
if getattr(message, "tool_calls", None):
continue
role = "assistant"
else:
continue
text = _extract_text_from_message(message).strip()
if not text:
continue
text = re.sub(r"\s+", " ", text)
rendered.append((role, text))
if not rendered:
return ""
if rendered and rendered[-1][0] == "user" and rendered[-1][1] == user_text.strip():
rendered = rendered[:-1]
if not rendered:
return ""
def _legacy_render() -> str:
legacy_lines: list[str] = []
for role, text in rendered[-max_messages:]:
clipped = text[:400].rstrip() + "..." if len(text) > 400 else text
legacy_lines.append(f"{role}: {clipped}")
return "\n".join(legacy_lines)
def _count_prompt_tokens(conversation_text: str) -> int | None:
prompt = _build_kb_planner_prompt(
recent_conversation=conversation_text or "(none)",
user_text=user_text,
)
message_payload = [{"role": "user", "content": prompt}]
count_fn = getattr(llm, "_count_tokens", None) if llm is not None else None
if callable(count_fn):
try:
return count_fn(message_payload)
except Exception:
pass
profile = getattr(llm, "profile", None) if llm is not None else None
model_names: list[str] = []
if isinstance(profile, dict):
tcms = profile.get("token_count_models")
if isinstance(tcms, list):
model_names.extend(
name for name in tcms if isinstance(name, str) and name
)
tcm = profile.get("token_count_model")
if isinstance(tcm, str) and tcm and tcm not in model_names:
model_names.append(tcm)
model_name = model_names[0] if model_names else getattr(llm, "model", None)
if not isinstance(model_name, str) or not model_name:
return None
try:
return token_counter(messages=message_payload, model=model_name)
except Exception:
return None
get_max_input_tokens = getattr(llm, "_get_max_input_tokens", None) if llm else None
if callable(get_max_input_tokens):
try:
max_input_tokens = int(get_max_input_tokens())
except Exception:
max_input_tokens = None
else:
profile = getattr(llm, "profile", None) if llm is not None else None
max_input_tokens = (
profile.get("max_input_tokens")
if isinstance(profile, dict)
and isinstance(profile.get("max_input_tokens"), int)
else None
)
if not isinstance(max_input_tokens, int) or max_input_tokens <= 0:
return _legacy_render()
output_reserve = min(max(int(max_input_tokens * 0.02), 256), 1024)
budget = max_input_tokens - output_reserve
if budget <= 0:
return _legacy_render()
selected_lines: list[str] = []
for role, text in reversed(rendered):
candidate_line = f"{role}: {text}"
candidate_lines = [candidate_line, *selected_lines]
candidate_conversation = "\n".join(candidate_lines)
token_count = _count_prompt_tokens(candidate_conversation)
if token_count is None:
return _legacy_render()
if token_count <= budget:
selected_lines = candidate_lines
continue
lo, hi = 1, len(text)
best_line: str | None = None
while lo <= hi:
mid = (lo + hi) // 2
clipped_text = text[:mid].rstrip() + "..."
clipped_line = f"{role}: {clipped_text}"
clipped_conversation = "\n".join([clipped_line, *selected_lines])
clipped_tokens = _count_prompt_tokens(clipped_conversation)
if clipped_tokens is None:
break
if clipped_tokens <= budget:
best_line = clipped_line
lo = mid + 1
else:
hi = mid - 1
if best_line is not None:
selected_lines = [best_line, *selected_lines]
break
if not selected_lines:
return _legacy_render()
return "\n".join(selected_lines)
def _build_kb_planner_prompt(
*,
recent_conversation: str,
user_text: str,
) -> str:
today = datetime.now(UTC).date().isoformat()
return (
"You optimize internal knowledge-base search inputs for document retrieval.\n"
"Return JSON only with this exact shape:\n"
2026-04-14 01:43:30 -07:00
'{"optimized_query":"string","start_date":"ISO string or null","end_date":"ISO string or null","is_recency_query":bool}\n\n'
"Rules:\n"
"- Preserve the user's intent.\n"
"- Rewrite the query to improve retrieval using concrete entities, acronyms, projects, tools, people, and document-specific terms when helpful.\n"
"- Keep the query concise and retrieval-focused.\n"
"- Only use date filters when the latest user request or recent dialogue clearly implies a time range.\n"
"- If you use date filters, prefer returning both bounds.\n"
"- If no date filter is useful, return null for both dates.\n"
2026-04-14 01:43:30 -07:00
'- Set "is_recency_query" to true ONLY when the user\'s primary intent is about '
"recency or temporal ordering rather than topical relevance. Examples: "
'"latest file", "newest upload", "most recent document", "what did I save last", '
'"show me files from today", "last thing I added". '
"When true, results will be sorted by date instead of relevance.\n"
"- Do not include markdown, prose, or explanations.\n\n"
f"Today's UTC date: {today}\n\n"
f"Recent conversation:\n{recent_conversation or '(none)'}\n\n"
f"Latest user message:\n{user_text}"
)
def _extract_json_payload(text: str) -> str:
stripped = text.strip()
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
if fenced:
return fenced.group(1)
start = stripped.find("{")
end = stripped.rfind("}")
if start != -1 and end != -1 and end > start:
return stripped[start : end + 1]
return stripped
def _parse_kb_search_plan_response(response_text: str) -> KBSearchPlan:
    """Parse the planner's raw reply into a validated :class:`KBSearchPlan`.

    Raises ``json.JSONDecodeError`` when the extracted payload is not valid
    JSON and pydantic's ``ValidationError`` when it does not match the
    schema.
    """
    raw_json = _extract_json_payload(response_text)
    decoded = json.loads(raw_json)
    return KBSearchPlan.model_validate(decoded)
def _normalize_optional_date_range(
start_date: str | None,
end_date: str | None,
) -> tuple[datetime | None, datetime | None]:
parsed_start = parse_date_or_datetime(start_date) if start_date else None
parsed_end = parse_date_or_datetime(end_date) if end_date else None
if parsed_start is None and parsed_end is None:
return None, None
return resolve_date_range(parsed_start, parsed_end)
2026-03-28 16:39:46 -07:00
def _resolve_search_types(
available_connectors: list[str] | None,
available_document_types: list[str] | None,
) -> list[str] | None:
types: set[str] = set()
if available_document_types:
types.update(available_document_types)
if available_connectors:
types.update(available_connectors)
if not types:
return None
expanded: set[str] = set(types)
for t in types:
legacy = NATIVE_TO_LEGACY_DOCTYPE.get(t)
if legacy:
expanded.add(legacy)
return list(expanded) if expanded else None
2026-04-14 01:43:30 -07:00
# Upper bound on the number of chunks ``browse_recent_documents`` loads per
# document (the lowest chunk ids are kept).
_RECENCY_MAX_CHUNKS_PER_DOC = 5
async def browse_recent_documents(
    *,
    search_space_id: int,
    document_type: list[str] | None = None,
    top_k: int = 10,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
) -> list[dict[str, Any]]:
    """Return documents ordered by recency (newest first), no relevance ranking.

    Documents are filtered by search space, skip those whose status state is
    ``"deleting"``, and may be narrowed by document type and by an
    ``updated_at`` window. Type names that are not ``DocumentType`` members
    are ignored. Each result carries up to ``_RECENCY_MAX_CHUNKS_PER_DOC``
    chunks (lowest chunk ids first), a fixed ``score`` of ``0.0`` and an
    empty ``matched_chunk_ids`` list.
    """
    import contextlib

    from sqlalchemy import func

    from app.db import DocumentType

    async with shielded_async_session() as session:
        filters = [
            Document.search_space_id == search_space_id,
            func.coalesce(Document.status["state"].astext, "ready") != "deleting",
        ]

        if document_type is not None:
            wanted = []
            for entry in document_type:
                if not isinstance(entry, str):
                    wanted.append(entry)
                    continue
                # Names without a matching enum member are skipped.
                with contextlib.suppress(KeyError):
                    wanted.append(DocumentType[entry])
            if len(wanted) == 1:
                filters.append(Document.document_type == wanted[0])
            elif wanted:
                filters.append(Document.document_type.in_(wanted))

        if start_date is not None:
            filters.append(Document.updated_at >= start_date)
        if end_date is not None:
            filters.append(Document.updated_at <= end_date)

        newest_first = (
            select(Document)
            .where(*filters)
            .order_by(Document.updated_at.desc())
            .limit(top_k)
        )
        recent_docs = (await session.execute(newest_first)).scalars().unique().all()
        if not recent_docs:
            return []

        # Rank chunks inside each document so only the first few are loaded.
        ranked = (
            select(
                Chunk.id.label("chunk_id"),
                Chunk.document_id,
                Chunk.content,
                func.row_number()
                .over(partition_by=Chunk.document_id, order_by=Chunk.id)
                .label("rn"),
            )
            .where(Chunk.document_id.in_([d.id for d in recent_docs]))
            .subquery("numbered")
        )
        leading_chunks = (
            select(ranked.c.chunk_id, ranked.c.document_id, ranked.c.content)
            .where(ranked.c.rn <= _RECENCY_MAX_CHUNKS_PER_DOC)
            .order_by(ranked.c.document_id, ranked.c.chunk_id)
        )
        chunk_rows = (await session.execute(leading_chunks)).all()

        chunks_by_doc: dict[int, list[dict[str, Any]]] = {
            d.id: [] for d in recent_docs
        }
        for row in chunk_rows:
            bucket = chunks_by_doc.get(row.document_id)
            if bucket is not None:
                bucket.append({"chunk_id": row.chunk_id, "content": row.content})

        results: list[dict[str, Any]] = []
        for doc in recent_docs:
            doc_chunks = chunks_by_doc.get(doc.id, [])
            type_value = (
                doc.document_type.value
                if getattr(doc, "document_type", None)
                else None
            )
            joined = "\n\n".join(
                c["content"] for c in doc_chunks if c.get("content")
            )
            results.append(
                {
                    "document_id": doc.id,
                    "content": joined,
                    "score": 0.0,
                    "chunks": doc_chunks,
                    "matched_chunk_ids": [],
                    "document": {
                        "id": doc.id,
                        "title": doc.title,
                        "document_type": type_value,
                        "metadata": doc.document_metadata or {},
                        "folder_id": getattr(doc, "folder_id", None),
                    },
                    "source": type_value,
                }
            )
        return results
2026-03-28 16:39:46 -07:00
async def search_knowledge_base(
    *,
    query: str,
    search_space_id: int,
    available_connectors: list[str] | None = None,
    available_document_types: list[str] | None = None,
    top_k: int = 10,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
) -> list[dict[str, Any]]:
    """Run a single unified hybrid search against the knowledge base.

    Args:
        query: Search text. Empty or whitespace-only queries return ``[]``
            without embedding or touching the database.
        search_space_id: Search space to query.
        available_connectors: Connector names to restrict the search to.
        available_document_types: Document-type names to restrict the
            search to; merged with ``available_connectors``.
        top_k: Maximum number of results. Non-positive values return ``[]``.
        start_date: Optional lower bound passed to the retriever.
        end_date: Optional upper bound passed to the retriever.

    Returns:
        At most ``top_k`` retriever results, in retriever order.
    """
    if not query or not query.strip() or top_k <= 0:
        return []

    # ``embed_texts`` is synchronous; run it off the event loop.
    [embedding] = await asyncio.to_thread(embed_texts, [query])

    doc_types = _resolve_search_types(available_connectors, available_document_types)
    # Over-fetch up to 3x (capped at 30), but never ask the retriever for
    # fewer than ``top_k`` rows; the cap alone would silently truncate
    # requests with ``top_k`` above 30.
    retriever_top_k = max(top_k, min(top_k * 3, 30))

    async with shielded_async_session() as session:
        retriever = ChucksHybridSearchRetriever(session)
        results = await retriever.hybrid_search(
            query_text=query,
            top_k=retriever_top_k,
            search_space_id=search_space_id,
            document_type=doc_types,
            start_date=start_date,
            end_date=end_date,
            query_embedding=embedding.tolist(),
        )
    return results[:top_k]
async def fetch_mentioned_documents(
    *,
    document_ids: list[int],
    search_space_id: int,
) -> list[dict[str, Any]]:
    """Fetch explicitly mentioned documents.

    Args:
        document_ids: Ids of the documents the user mentioned. Duplicates
            are collapsed so each document appears once, at the position of
            its first mention.
        search_space_id: Only documents in this search space are returned.

    Returns:
        One result dict per document found, in mention order, carrying all
        of the document's chunks, a fixed ``score`` of ``1.0`` and the
        ``_user_mentioned`` marker. Ids that do not resolve are skipped.
    """
    if not document_ids:
        return []

    # A repeated id would otherwise yield a repeated priority entry.
    ordered_ids = list(dict.fromkeys(document_ids))

    async with shielded_async_session() as session:
        doc_result = await session.execute(
            select(Document).where(
                Document.id.in_(ordered_ids),
                Document.search_space_id == search_space_id,
            )
        )
        docs = {doc.id: doc for doc in doc_result.scalars().all()}
        if not docs:
            return []

        chunk_result = await session.execute(
            select(Chunk.id, Chunk.content, Chunk.document_id)
            .where(Chunk.document_id.in_(list(docs.keys())))
            .order_by(Chunk.document_id, Chunk.id)
        )
        chunks_by_doc: dict[int, list[dict[str, Any]]] = {doc_id: [] for doc_id in docs}
        for row in chunk_result.all():
            if row.document_id in chunks_by_doc:
                chunks_by_doc[row.document_id].append(
                    {"chunk_id": row.id, "content": row.content}
                )

        results: list[dict[str, Any]] = []
        for doc_id in ordered_ids:
            doc = docs.get(doc_id)
            if doc is None:
                continue
            doc_type_value = (
                doc.document_type.value
                if getattr(doc, "document_type", None)
                else None
            )
            results.append(
                {
                    "document_id": doc.id,
                    "content": "",
                    "score": 1.0,
                    "chunks": chunks_by_doc.get(doc.id, []),
                    "matched_chunk_ids": [],
                    "document": {
                        "id": doc.id,
                        "title": doc.title,
                        "document_type": doc_type_value,
                        "metadata": doc.document_metadata or {},
                        "folder_id": getattr(doc, "folder_id", None),
                    },
                    "source": doc_type_value,
                    "_user_mentioned": True,
                }
            )
        return results
def _render_priority_message(priority: list[dict[str, Any]]) -> SystemMessage:
    """Render the priority list as a single ``<priority_documents>`` system message.

    Each entry becomes one ``- <path> (score=<x.xxx|n/a>)`` line, suffixed
    with ``[USER-MENTIONED]`` when the entry's ``mentioned`` flag is truthy.
    An empty list renders a placeholder line instead.
    """

    def _describe(item: dict[str, Any]) -> str:
        raw_score = item.get("score")
        shown = f"{raw_score:.3f}" if isinstance(raw_score, int | float) else "n/a"
        suffix = " [USER-MENTIONED]" if item.get("mentioned") else ""
        return f"- {item.get('path', '')} (score={shown}){suffix}"

    if priority:
        listing = "\n".join(map(_describe, priority))
    else:
        listing = "(no priority documents for this turn)"

    text = (
        "<priority_documents>\n"
        "These documents are most relevant to the latest user message; "
        "read them first. Matched sections are flagged inside each "
        "document's <chunk_index>.\n"
        f"{listing}\n"
        "</priority_documents>"
    )
    return SystemMessage(content=text)
class KnowledgePriorityMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """Compute hybrid-search priority hints for the current turn.

    On each turn the middleware plans a retrieval query (optionally through
    a planner LLM), runs either the hybrid search or a recency browse,
    merges user-mentioned documents and folders in front, and writes
    ``kb_priority`` / ``kb_matched_chunk_ids`` into the agent state. When
    ``inject_system_message`` is set it also inserts a
    ``<priority_documents>`` system message before the last message.
    """

    tools = ()
    state_schema = SurfSenseFilesystemState

    def __init__(
        self,
        *,
        llm: BaseChatModel | None = None,
        planner_llm: BaseChatModel | None = None,
        search_space_id: int,
        filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
        available_connectors: list[str] | None = None,
        available_document_types: list[str] | None = None,
        top_k: int = 10,
        mentioned_document_ids: list[int] | None = None,
        inject_system_message: bool = True,  # For backwards compatibility
    ) -> None:
        """Store configuration; no I/O and no planner compilation happens here.

        Args:
            llm: The user's chat model; used as planner fallback.
            planner_llm: Optional cheaper model for the planning call.
            search_space_id: Search space all queries are scoped to.
            filesystem_mode: The middleware is a no-op outside ``CLOUD``.
            available_connectors: Connector names to restrict search to.
            available_document_types: Document-type names to restrict
                search to.
            top_k: Maximum number of search results per turn.
            mentioned_document_ids: Legacy constructor-injected mentions,
                consumed on the first turn only.
            inject_system_message: Whether to insert the rendered priority
                list into ``messages``.
        """
        self.llm = llm
        # The planner LLM handles short, structured internal tasks (query
        # rewriting, date extraction, recency classification). When an
        # operator marks a global config ``is_planner: true`` we route
        # those calls to a cheap/fast model (e.g. gpt-4o-mini, Haiku, Azure
        # gpt-5.x-nano) instead of the user's chat LLM — those classification
        # tasks don't need frontier-tier capability. Falls back to the chat
        # LLM when no planner config is wired up so deployments without one
        # keep working unchanged.
        self.planner_llm = planner_llm or llm
        self.search_space_id = search_space_id
        self.filesystem_mode = filesystem_mode
        self.available_connectors = available_connectors
        self.available_document_types = available_document_types
        self.top_k = top_k
        self.mentioned_document_ids = mentioned_document_ids or []
        self.inject_system_message = inject_system_message
        # The kb-planner private Runnable is compiled lazily on first use by
        # ``_build_kb_planner_runnable`` and memoized here, so we don't pay
        # the ``create_agent`` compile cost (50-200ms) on every turn.
        # Disabled by default behind ``enable_kb_planner_runnable``; when
        # off the planner falls back to the legacy ``planner_llm.ainvoke``
        # path.
        self._planner: Runnable | None = None
        self._planner_compile_failed = False

    def _build_kb_planner_runnable(self) -> Runnable | None:
        """Compile the kb-planner private :class:`Runnable` once.

        Returns ``None`` when the feature flag is disabled, when the LLM is
        unavailable, or when ``create_agent`` raises (we fall back to the
        legacy ``planner_llm.ainvoke`` path in that case). Compilation happens
        lazily on first call, then memoized via ``self._planner``.

        The compiled agent is constructed without tools — the planner's
        contract is "answer with structured JSON" — but it inherits the
        :class:`RetryAfterMiddleware` so transient rate-limit errors
        from the planner LLM call don't fail the whole turn.
        """
        if self._planner is not None or self._planner_compile_failed:
            return self._planner
        if self.planner_llm is None:
            return None
        flags = get_flags()
        if not flags.enable_kb_planner_runnable or flags.disable_new_agent_stack:
            return None

        from app.agents.shared.middleware.retry_after import RetryAfterMiddleware

        try:
            self._planner = create_agent(
                self.planner_llm,
                tools=[],
                middleware=[RetryAfterMiddleware(max_retries=2)],
            )
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning(
                "kb-planner Runnable compile failed; falling back to planner_llm.ainvoke: %s",
                exc,
            )
            # Remember the failure so we don't retry compilation every turn.
            self._planner_compile_failed = True
            self._planner = None
        return self._planner

    async def _plan_search_inputs(
        self,
        *,
        messages: Sequence[BaseMessage],
        user_text: str,
    ) -> tuple[str, datetime | None, datetime | None, bool]:
        """Ask the planner LLM for an optimized query, date range and recency flag.

        Returns ``(query, start_date, end_date, is_recency)``. Without a
        planner LLM, or when the planner call or its output parsing fails,
        the raw ``user_text`` is returned with no dates and ``False``.
        """
        if self.planner_llm is None:
            return user_text, None, None, False

        recent_conversation = _render_recent_conversation(
            messages,
            llm=self.planner_llm,
            user_text=user_text,
        )
        prompt = _build_kb_planner_prompt(
            recent_conversation=recent_conversation,
            user_text=user_text,
        )
        loop = asyncio.get_running_loop()
        t0 = loop.time()

        # Prefer the compiled-once planner Runnable when enabled; otherwise
        # fall back to ``planner_llm.ainvoke``. The ``surfsense:internal``
        # tag is preserved on both paths so ``_stream_agent_events`` still
        # suppresses the planner's intermediate events from the UI.
        planner = self._build_kb_planner_runnable()

        try:
            if planner is not None:
                planner_state = await planner.ainvoke(
                    {"messages": [HumanMessage(content=prompt)]},
                    config={"tags": ["surfsense:internal"]},
                )
                response_messages = (
                    planner_state.get("messages", [])
                    if isinstance(planner_state, dict)
                    else []
                )
                # An empty reply fails JSON parsing below and triggers the
                # raw-query fallback.
                response = (
                    response_messages[-1]
                    if response_messages
                    else AIMessage(content="")
                )
            else:
                response = await self.planner_llm.ainvoke(
                    [HumanMessage(content=prompt)],
                    config={"tags": ["surfsense:internal"]},
                )

            plan = _parse_kb_search_plan_response(_extract_text_from_message(response))
            optimized_query = (
                re.sub(r"\s+", " ", plan.optimized_query).strip() or user_text
            )
            start_date, end_date = _normalize_optional_date_range(
                plan.start_date,
                plan.end_date,
            )
            is_recency = plan.is_recency_query
            _perf_log.info(
                "[kb_priority] planner in %.3fs query=%r optimized=%r "
                "start=%s end=%s recency=%s",
                loop.time() - t0,
                user_text[:80],
                optimized_query[:120],
                start_date.isoformat() if start_date else None,
                end_date.isoformat() if end_date else None,
                is_recency,
            )
            return optimized_query, start_date, end_date, is_recency
        except (json.JSONDecodeError, ValidationError, ValueError) as exc:
            logger.warning(
                "KB planner returned invalid output, using raw query: %s", exc
            )
        except Exception as exc:  # pragma: no cover - defensive fallback
            logger.warning("KB planner failed, using raw query: %s", exc)

        return user_text, None, None, False

    def before_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Synchronous entry point; delegates to :meth:`abefore_agent`.

        When called from a thread that already runs an event loop the hook
        cannot block on the coroutine, so it returns ``None`` (no state
        update). Otherwise it drives ``abefore_agent`` to completion.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop in this thread: safe to run the coroutine here.
            return asyncio.run(self.abefore_agent(state, runtime))
        logger.debug(
            "kb_priority: sync before_agent called inside a running loop; skipping"
        )
        return None

    async def abefore_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Compute the priority state update for the latest human message.

        Returns ``None`` (no update) outside cloud filesystem mode, when
        there are no messages, or when the latest human message has no text.
        """
        if self.filesystem_mode != FilesystemMode.CLOUD:
            return None

        messages = state.get("messages") or []
        if not messages:
            return None

        last_human: HumanMessage | None = None
        for msg in reversed(messages):
            if isinstance(msg, HumanMessage):
                last_human = msg
                break
        if last_human is None:
            return None

        user_text = _extract_text_from_message(last_human).strip()
        if not user_text:
            return None

        # Anonymous mode: a single pre-loaded document, no hybrid search.
        anon_doc = state.get("kb_anon_doc")
        if anon_doc:
            return self._anon_priority(state, anon_doc)

        return await self._authenticated_priority(state, messages, user_text, runtime)

    def _anon_priority(
        self,
        state: AgentState,
        anon_doc: dict[str, Any],
    ) -> dict[str, Any]:
        """Build the single-entry priority update for the anonymous document."""
        path = str(anon_doc.get("path") or "")
        title = str(anon_doc.get("title") or "uploaded_document")
        priority = [
            {
                "path": path,
                "score": 1.0,
                "document_id": None,
                "title": title,
                "mentioned": True,
            }
        ]
        update: dict[str, Any] = {
            "kb_priority": priority,
            "kb_matched_chunk_ids": {},
        }
        if self.inject_system_message:
            new_messages = list(state.get("messages") or [])
            # Place the hint immediately before the last message.
            insert_at = max(len(new_messages) - 1, 0)
            new_messages.insert(insert_at, _render_priority_message(priority))
            update["messages"] = new_messages
        return update

    async def _authenticated_priority(
        self,
        state: AgentState,
        messages: Sequence[BaseMessage],
        user_text: str,
        runtime: Runtime[Any] | None = None,
    ) -> dict[str, Any]:
        """Plan, search, merge mentions and build the priority state update."""
        # ``get_running_loop`` is the documented choice inside coroutines and
        # matches ``_plan_search_inputs``.
        loop = asyncio.get_running_loop()
        t0 = loop.time()
        (
            planned_query,
            start_date,
            end_date,
            is_recency,
        ) = await self._plan_search_inputs(
            messages=messages,
            user_text=user_text,
        )

        # Per-turn ``mentioned_document_ids`` flow:
        # 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the
        #    streaming task supplies a fresh :class:`SurfSenseContextSchema`
        #    on every ``astream_events`` call, so this list is naturally
        #    scoped to the current turn. Allows cross-turn graph reuse via
        #    ``agent_cache``.
        # 2. Legacy fallback (cache disabled / context not propagated): the
        #    constructor-injected ``self.mentioned_document_ids`` list. We
        #    drain it after the first read so a cached graph (no Phase 1.5
        #    wiring) doesn't keep replaying the same mentions on every
        #    turn.
        #
        # CRITICAL: distinguish "context absent" (legacy caller, no field at
        # all) from "context provided but empty" (turn with no mentions).
        # ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in
        # Python, so a naive ``if ctx_mentions:`` would fall through to the
        # legacy closure on every no-mention follow-up turn — replaying the
        # mentions baked in by turn 1's cache-miss build. Always drain the
        # closure once the runtime path has fired so a cached middleware
        # instance can never resurrect stale state.
        mention_ids: list[int] = []
        ctx = getattr(runtime, "context", None) if runtime is not None else None
        # Test for presence, not truthiness, for the same reason as above.
        ctx_mentions = (
            getattr(ctx, "mentioned_document_ids", None) if ctx is not None else None
        )
        if ctx_mentions is not None:
            # Runtime path is authoritative — even an empty list means
            # "this turn has no mentions", NOT "look at the closure".
            mention_ids = list(ctx_mentions)
            if self.mentioned_document_ids:
                self.mentioned_document_ids = []
        elif self.mentioned_document_ids:
            mention_ids = list(self.mentioned_document_ids)
            self.mentioned_document_ids = []

        # Folder mentions live alongside doc mentions on the runtime
        # context. They never feed hybrid search (folders aren't
        # embedded) — they're surfaced purely as ``[USER-MENTIONED]``
        # priority entries so the agent walks the folder with ``ls`` /
        # ``find_documents`` instead of ignoring it. Cloud filesystem
        # mode only.
        folder_mention_ids: list[int] = []
        if (
            ctx is not None
            and getattr(self, "filesystem_mode", FilesystemMode.CLOUD)
            == FilesystemMode.CLOUD
        ):
            ctx_folders = getattr(ctx, "mentioned_folder_ids", None)
            if ctx_folders:
                folder_mention_ids = list(ctx_folders)

        mentioned_results: list[dict[str, Any]] = []
        if mention_ids:
            mentioned_results = await fetch_mentioned_documents(
                document_ids=mention_ids,
                search_space_id=self.search_space_id,
            )

        if is_recency:
            doc_types = _resolve_search_types(
                self.available_connectors, self.available_document_types
            )
            search_results = await browse_recent_documents(
                search_space_id=self.search_space_id,
                document_type=doc_types,
                top_k=self.top_k,
                start_date=start_date,
                end_date=end_date,
            )
        else:
            search_results = await search_knowledge_base(
                query=planned_query,
                search_space_id=self.search_space_id,
                available_connectors=self.available_connectors,
                available_document_types=self.available_document_types,
                top_k=self.top_k,
                start_date=start_date,
                end_date=end_date,
            )

        # Mentioned documents come first; search hits that repeat one of
        # them are dropped.
        seen_doc_ids: set[int] = set()
        merged: list[dict[str, Any]] = []
        for doc in mentioned_results:
            doc_id = (doc.get("document") or {}).get("id")
            if isinstance(doc_id, int):
                seen_doc_ids.add(doc_id)
            merged.append(doc)
        for doc in search_results:
            doc_id = (doc.get("document") or {}).get("id")
            if isinstance(doc_id, int) and doc_id in seen_doc_ids:
                continue
            merged.append(doc)

        priority, matched_chunk_ids = await self._materialize_priority(merged)

        if folder_mention_ids:
            folder_entries = await self._materialize_folder_priority(folder_mention_ids)
            priority = folder_entries + priority

        _perf_log.info(
            "[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d folders=%d",
            loop.time() - t0,
            user_text[:80],
            len(priority),
            len(mentioned_results),
            len(folder_mention_ids),
        )

        update: dict[str, Any] = {
            "kb_priority": priority,
            "kb_matched_chunk_ids": matched_chunk_ids,
        }
        if self.inject_system_message:
            new_messages = list(messages)
            # Place the hint immediately before the last message.
            insert_at = max(len(new_messages) - 1, 0)
            new_messages.insert(insert_at, _render_priority_message(priority))
            update["messages"] = new_messages
        return update

    async def _materialize_folder_priority(
        self, folder_ids: list[int]
    ) -> list[dict[str, Any]]:
        """Resolve user-mentioned folder ids to ``<priority_documents>`` entries.

        Each entry uses the canonical ``/documents/Folder/Sub/`` virtual
        path (matching ``KnowledgeTreeMiddleware`` and the agent's
        ``ls`` adapter) and is flagged ``mentioned=True`` so the
        rendered line carries ``[USER-MENTIONED]``. ``score`` is left
        ``None`` so the renderer prints ``n/a`` — folders aren't
        ranked, the agent decides which children to read.

        Folder ids missing from the path index are dropped; repeated ids
        produce a single entry.
        """
        if not folder_ids:
            return []

        async with shielded_async_session() as session:
            index: PathIndex = await build_path_index(session, self.search_space_id)
            folder_rows = await session.execute(
                select(Folder.id, Folder.name).where(
                    Folder.search_space_id == self.search_space_id,
                    Folder.id.in_(folder_ids),
                )
            )
            folder_titles: dict[int, str] = {
                row.id: row.name for row in folder_rows.all()
            }

        entries: list[dict[str, Any]] = []
        seen: set[int] = set()
        for folder_id in folder_ids:
            if folder_id in seen:
                continue
            seen.add(folder_id)
            base = index.folder_paths.get(folder_id)
            if base is None:
                logger.debug(
                    "kb_priority: dropping folder id=%s (missing from path index)",
                    folder_id,
                )
                continue
            # Folder paths always end with a slash.
            path = base if base.endswith("/") else f"{base}/"
            entries.append(
                {
                    "path": path,
                    "score": None,
                    "document_id": None,
                    "folder_id": folder_id,
                    "title": folder_titles.get(folder_id, ""),
                    "mentioned": True,
                }
            )
        return entries

    @staticmethod
    def _coerce_chunk_ids(raw_ids: Sequence[Any]) -> list[int]:
        """Convert chunk ids to ``int``, skipping values that are not numeric."""
        coerced: list[int] = []
        for cid in raw_ids:
            if not isinstance(cid, int | str):
                continue
            try:
                coerced.append(int(cid))
            except ValueError:
                # A malformed id must not fail the whole turn.
                logger.debug("kb_priority: ignoring non-numeric chunk id %r", cid)
        return coerced

    async def _materialize_priority(
        self, merged: list[dict[str, Any]]
    ) -> tuple[list[dict[str, Any]], dict[int, list[int]]]:
        """Resolve canonical paths and matched chunk ids for the priority list.

        Returns ``(priority, matched_chunk_ids)`` where ``priority`` has one
        entry per item of ``merged`` (same order) and ``matched_chunk_ids``
        maps document id to the chunk ids the search matched.
        """
        priority: list[dict[str, Any]] = []
        matched_chunk_ids: dict[int, list[int]] = {}
        if not merged:
            return priority, matched_chunk_ids

        async with shielded_async_session() as session:
            index: PathIndex = await build_path_index(session, self.search_space_id)
            doc_ids = [
                (doc.get("document") or {}).get("id")
                for doc in merged
                if isinstance(doc, dict)
            ]
            doc_ids = [doc_id for doc_id in doc_ids if isinstance(doc_id, int)]
            # Read folder ids from the database rather than trusting the
            # search payload.
            folder_by_doc_id: dict[int, int | None] = {}
            if doc_ids:
                folder_rows = await session.execute(
                    select(Document.id, Document.folder_id).where(
                        Document.search_space_id == self.search_space_id,
                        Document.id.in_(doc_ids),
                    )
                )
                folder_by_doc_id = {row.id: row.folder_id for row in folder_rows.all()}

        for doc in merged:
            doc_meta = doc.get("document") or {}
            doc_id = doc_meta.get("id")
            title = doc_meta.get("title") or "untitled"
            folder_id = (
                folder_by_doc_id.get(doc_id)
                if isinstance(doc_id, int)
                else doc_meta.get("folder_id")
            )
            path = doc_to_virtual_path(
                doc_id=doc_id if isinstance(doc_id, int) else None,
                title=str(title),
                folder_id=folder_id if isinstance(folder_id, int) else None,
                index=index,
            )
            priority.append(
                {
                    "path": path,
                    "score": float(doc.get("score") or 0.0),
                    "document_id": doc_id if isinstance(doc_id, int) else None,
                    "title": str(title),
                    "mentioned": bool(doc.get("_user_mentioned")),
                }
            )
            if isinstance(doc_id, int):
                chunk_ids = doc.get("matched_chunk_ids") or []
                if chunk_ids:
                    matched_chunk_ids[doc_id] = self._coerce_chunk_ids(chunk_ids)
        return priority, matched_chunk_ids
# Public API of this module; underscore-prefixed helpers are internal.
__all__ = [
    "KnowledgePriorityMiddleware",
    "browse_recent_documents",
    "fetch_mentioned_documents",
    "search_knowledge_base",
]