diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/search_knowledge_base.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/search_knowledge_base.py index 9236e9121..ad47816f9 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/search_knowledge_base.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/search_knowledge_base.py @@ -33,6 +33,7 @@ from app.agents.chat.runtime.path_resolver import ( ) from app.db import Document, shielded_async_session from app.utils.perf import get_perf_logger +from app.utils.text_spans import char_span_to_line_range _perf_log = get_perf_logger() @@ -56,12 +57,16 @@ _TOOL_DESCRIPTION = ( ) -async def _resolve_virtual_paths( +async def _resolve_doc_context( results: list[dict[str, Any]], *, search_space_id: int, -) -> dict[int, str]: - """Resolve ``Document.id`` -> canonical virtual path for the search hits.""" +) -> tuple[dict[int, str], dict[int, str]]: + """Resolve ``Document.id`` -> (canonical virtual path, source_markdown). + + ``source_markdown`` is the canonical body the chunk spans index into; the + renderer uses it to turn a chunk's char span into a line range. + """ doc_ids = [ doc_id for doc_id in ( @@ -72,17 +77,24 @@ async def _resolve_virtual_paths( if isinstance(doc_id, int) ] if not doc_ids: - return {} + return {}, {} async with shielded_async_session() as session: index: PathIndex = await build_path_index(session, search_space_id) - folder_rows = await session.execute( - select(Document.id, Document.folder_id).where( + rows = await session.execute( + select( + Document.id, Document.folder_id, Document.source_markdown + ).where( Document.search_space_id == search_space_id, Document.id.in_(doc_ids), ) ) - folder_by_doc_id = {row.id: row.folder_id for row in folder_rows.all()} + folder_by_doc_id: dict[int, int | None] = {} + bodies: dict[int, str] = {} + for row in rows.all(): + folder_by_doc_id[row.id] = row.folder_id + if row.source_markdown: + bodies[row.id] = row.source_markdown paths: dict[int, str] = {} for doc in results: @@ -97,13 +109,69 @@ async def _resolve_virtual_paths( folder_id=folder_id if isinstance(folder_id, int) else None, index=index, ) - return paths + return paths, bodies + + +def _line_label(chunk: dict[str, Any], body: str | None) -> str: + """``[lines X-Y]`` for a span-bearing chunk, or '' when spans are absent.""" + start = chunk.get("start_char") + end = chunk.get("end_char") + if not body or not isinstance(start, int) or not isinstance(end, int): + return "" + start_line, end_line = char_span_to_line_range(body, start, end) + if start_line == end_line: + return f"[line {start_line}]" + return f"[lines {start_line}-{end_line}]" + + +def _render_passage(chunk: dict[str, Any], body: str | None) -> str | None: + """Render one matched chunk as an indented, line-annotated passage.""" + content = (chunk.get("content") or "").strip() + if not content: + return None + snippet = content[:_PER_DOC_SNIPPET_CHARS].strip() + if len(content) > _PER_DOC_SNIPPET_CHARS: + snippet += " ..." + indented = snippet.replace("\n", "\n ") + label = _line_label(chunk, body) + head = f"\n {label}" if label else "" + return f"{head}\n {indented}" + + +def _matched_passages(doc: dict[str, Any], body: str | None) -> str: + """Render the RRF-matched chunks; '' when none can be rendered.""" + by_id = { + c.get("chunk_id"): c + for c in (doc.get("chunks") or []) + if isinstance(c, dict) + } + rendered: list[str] = [] + for chunk_id in doc.get("matched_chunk_ids") or []: + chunk = by_id.get(chunk_id) + if chunk is None: + continue + passage = _render_passage(chunk, body) + if passage: + rendered.append(passage) + return "".join(rendered) + + +def _fallback_snippet(doc: dict[str, Any]) -> str: + """Top-of-document preview, used only when no matched chunk is available.""" + content = (doc.get("content") or "").strip() + if not content: + return "\n (no preview available; read the document for details)" + snippet = content[:_PER_DOC_SNIPPET_CHARS].strip() + if len(content) > _PER_DOC_SNIPPET_CHARS: + snippet += " ..." + return "\n " + snippet.replace("\n", "\n ") def _format_hits( results: list[dict[str, Any]], *, paths: dict[int, str], + bodies: dict[int, str], query: str, ) -> str: """Render search hits as a compact, model-readable block.""" @@ -124,21 +192,14 @@ def _format_hits( score = doc.get("score") score_str = f"{score:.3f}" if isinstance(score, int | float) else "n/a" path = paths.get(doc_id) if isinstance(doc_id, int) else None + body = bodies.get(doc_id) if isinstance(doc_id, int) else None header = f"\n{rank}. {title} (type={doc_type}, score={score_str})" + ( f"\n path: {path}" if path else "" ) - content = (doc.get("content") or "").strip() - if content: - snippet = content[:_PER_DOC_SNIPPET_CHARS].strip() - if len(content) > _PER_DOC_SNIPPET_CHARS: - snippet += " ..." - body = "\n " + snippet.replace("\n", "\n ") - else: - body = "\n (no preview available; read the document for details)" - - entry = header + body + passages = _matched_passages(doc, body) + entry = header + (passages or _fallback_snippet(doc)) if total + len(entry) > _MAX_TOTAL_CHARS: lines.append("\n") break @@ -204,8 +265,10 @@ def create_search_knowledge_base_tool( top_k=clamped_top_k, ) - paths = await _resolve_virtual_paths(results, search_space_id=_space_id) - rendered = _format_hits(results, paths=paths, query=cleaned_query) + paths, bodies = await _resolve_doc_context(results, search_space_id=_space_id) + rendered = _format_hits( + results, paths=paths, bodies=bodies, query=cleaned_query + ) matched = _matched_chunk_ids(results) _perf_log.info(