From 40a091f8cc2dafac4c2e5ceea4ff2f321251d9d0 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Sat, 28 Feb 2026 19:40:24 -0800 Subject: [PATCH] feat: enhance knowledge base search and document retrieval - Introduced a mechanism to identify degenerate queries that lack meaningful search signals, improving search accuracy. - Implemented a fallback method for browsing recent documents when queries are degenerate, ensuring relevant results are returned. - Added limits on the number of chunks fetched per document to optimize performance and prevent excessive data loading. - Updated the ConnectorService to allow for reusable query embeddings, enhancing efficiency in search operations. - Enhanced LLM router service to support context window fallbacks, improving robustness during context window limitations. --- .../agents/new_chat/tools/knowledge_base.py | 423 ++++++++++++++---- .../app/agents/new_chat/tools/registry.py | 6 +- .../app/agents/new_chat/tools/report.py | 2 + .../app/retriever/chunks_hybrid_search.py | 22 +- .../app/retriever/documents_hybrid_search.py | 16 +- .../app/services/connector_service.py | 18 +- .../app/services/llm_router_service.py | 89 +++- 7 files changed, 476 insertions(+), 100 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py index 9394d68b4..16cad80e5 100644 --- a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py +++ b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py @@ -10,6 +10,7 @@ This module provides: import asyncio import json +import re import time from datetime import datetime from typing import Any @@ -22,6 +23,149 @@ from app.db import async_session_maker from app.services.connector_service import ConnectorService from app.utils.perf import get_perf_logger +# Connectors that call external live-search APIs (no local DB / embedding needed). +# These are never filtered by available_document_types. +_LIVE_SEARCH_CONNECTORS: set[str] = { + "TAVILY_API", + "SEARXNG_API", + "LINKUP_API", + "BAIDU_SEARCH_API", +} + +# Patterns that indicate the query has no meaningful search signal. +# plainto_tsquery('english', '*') produces an empty tsquery and an embedding +# of '*' is random noise, so both keyword and semantic search degrade to +# arbitrary ordering — large documents (many chunks) dominate by chance. +_DEGENERATE_QUERY_RE = re.compile( + r"^[\s*?_.#@!\-/\\]+$" # only wildcards, punctuation, whitespace +) + +# Max chunks per document when doing a recency-based browse instead of +# a real search. We want breadth (many docs) over depth (many chunks). +_BROWSE_MAX_CHUNKS_PER_DOC = 5 + + +def _is_degenerate_query(query: str) -> bool: + """Return True when the query carries no meaningful search signal. + + Catches wildcard patterns (``*``, ``**``), empty / whitespace-only + strings, and single-character non-word tokens. These queries cause + both keyword search (empty tsquery) and semantic search (meaningless + embedding) to return effectively random results. + """ + stripped = query.strip() + if not stripped: + return True + return bool(_DEGENERATE_QUERY_RE.match(stripped)) + + +async def _browse_recent_documents( + search_space_id: int, + document_type: str | None, + top_k: int, + start_date: datetime | None, + end_date: datetime | None, +) -> list[dict[str, Any]]: + """Return the most-recent documents (recency-ordered, no search ranking). + + Used as a fallback when the search query is degenerate (e.g. 
``*``) and + semantic / keyword search would produce arbitrary results. Returns + document-grouped dicts in the same shape as ``_combined_rrf_search`` + so the rest of the pipeline works unchanged. + """ + from sqlalchemy import select + from sqlalchemy.orm import joinedload + + from app.db import Chunk, Document, DocumentType + + perf = get_perf_logger() + t0 = time.perf_counter() + + base_conditions = [Document.search_space_id == search_space_id] + + if document_type is not None: + if isinstance(document_type, str): + try: + doc_type_enum = DocumentType[document_type] + base_conditions.append(Document.document_type == doc_type_enum) + except KeyError: + return [] + else: + base_conditions.append(Document.document_type == document_type) + + if start_date is not None: + base_conditions.append(Document.updated_at >= start_date) + if end_date is not None: + base_conditions.append(Document.updated_at <= end_date) + + async with async_session_maker() as session: + doc_query = ( + select(Document) + .options(joinedload(Document.search_space)) + .where(*base_conditions) + .order_by(Document.updated_at.desc()) + .limit(top_k) + ) + result = await session.execute(doc_query) + documents = result.scalars().unique().all() + + if not documents: + return [] + + doc_ids = [d.id for d in documents] + + chunk_query = ( + select(Chunk) + .where(Chunk.document_id.in_(doc_ids)) + .order_by(Chunk.document_id, Chunk.id) + ) + chunk_result = await session.execute(chunk_query) + raw_chunks = chunk_result.scalars().all() + + doc_chunk_counts: dict[int, int] = {} + doc_chunks: dict[int, list[dict]] = {d.id: [] for d in documents} + for chunk in raw_chunks: + did = chunk.document_id + count = doc_chunk_counts.get(did, 0) + if count < _BROWSE_MAX_CHUNKS_PER_DOC: + doc_chunks[did].append({"chunk_id": chunk.id, "content": chunk.content}) + doc_chunk_counts[did] = count + 1 + + results: list[dict[str, Any]] = [] + for doc in documents: + chunks_list = doc_chunks.get(doc.id, []) + results.append( + { + "document_id": doc.id, + "content": "\n\n".join( + c["content"] for c in chunks_list if c.get("content") + ), + "score": 0.0, + "chunks": chunks_list, + "document": { + "id": doc.id, + "title": doc.title, + "document_type": doc.document_type.value + if getattr(doc, "document_type", None) + else None, + "metadata": doc.document_metadata or {}, + }, + "source": doc.document_type.value + if getattr(doc, "document_type", None) + else None, + } + ) + + perf.info( + "[kb_browse] recency browse in %.3fs docs=%d space=%d type=%s", + time.perf_counter() - t0, + len(results), + search_space_id, + document_type, + ) + return results + + # ============================================================================= # Connector Constants and Normalization # ============================================================================= @@ -184,9 +328,23 @@ _CHARS_PER_TOKEN = 4 # Hard-floor / ceiling so the budget is always sensible regardless of what # the model reports. _MIN_TOOL_OUTPUT_CHARS = 20_000 # ~5K tokens -_MAX_TOOL_OUTPUT_CHARS = 400_000 # ~100K tokens +_MAX_TOOL_OUTPUT_CHARS = 200_000 # ~50K tokens _MAX_CHUNK_CHARS = 8_000 +# Rank-adaptive per-document budget allocation. +# Top-ranked (most relevant) documents get a larger share of the budget so +# we pack as much high-quality context as possible. 
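For reference, the degenerate-query guard this hunk introduces is easy to exercise in isolation. A minimal sketch, copying the regex and helper verbatim from the patch and checking the cases the docstring names:

```python
import re

# Copied from the patch: matches strings made only of wildcards,
# punctuation, and whitespace.
_DEGENERATE_QUERY_RE = re.compile(r"^[\s*?_.#@!\-/\\]+$")


def _is_degenerate_query(query: str) -> bool:
    stripped = query.strip()
    if not stripped:
        return True
    return bool(_DEGENERATE_QUERY_RE.match(stripped))


for q in ("*", "**", "", "   ", "?!", "recent meeting notes"):
    print(repr(q), _is_degenerate_query(q))
# '*' True, '**' True, '' True, '   ' True, '?!' True,
# 'recent meeting notes' False
```

Any query flagged True is routed to `_browse_recent_documents` above instead of hybrid search, so the tool returns recent, diverse documents rather than chunk-count noise.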
+#
+#   fraction(rank) = _TOP_DOC_BUDGET_FRACTION / (1 + rank * _RANK_DECAY)
+#
+# Examples (128K budget, 8K chunk cap):
+#   rank 0 → 40% → 6 chunks | rank 3  → 19% → 3 chunks
+#   rank 1 → 30% → 4 chunks | rank 10 → 10% → 3 chunks (floor)
+#   rank 2 → 24% → 3 chunks |
+_TOP_DOC_BUDGET_FRACTION = 0.40
+_RANK_DECAY = 0.35
+_MIN_CHUNKS_PER_DOC = 3
+

 def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
     """Derive a character budget from the model's context window.
@@ -208,18 +366,24 @@ def format_documents_for_context(
     *,
     max_chars: int = _MAX_TOOL_OUTPUT_CHARS,
     max_chunk_chars: int = _MAX_CHUNK_CHARS,
+    max_chunks_per_doc: int = 0,
 ) -> str:
     """
     Format retrieved documents into a readable context string for the LLM.

     Documents are added in order (highest relevance first) until the character
-    budget is reached. Individual chunks are capped at ``max_chunk_chars`` so
-    a single oversized chunk cannot monopolize the output.
+    budget is reached. Individual chunks are capped at ``max_chunk_chars`` and
+    each document is limited to a dynamically computed chunk cap so a single
+    large document cannot monopolize the output while still maximizing the use
+    of available context space.

     Args:
         documents: List of document dictionaries from connector search
         max_chars: Approximate character budget for the entire output.
         max_chunk_chars: Per-chunk character cap (content is tail-truncated).
+        max_chunks_per_doc: Maximum chunks per document. ``0`` (default) means
+            auto-compute per document using a rank-adaptive formula so
+            higher-ranked documents receive more chunks.

     Returns:
         Formatted string with document contents and metadata
@@ -342,7 +506,23 @@ def format_documents_for_context(
             "",
         ]

-        for ch in g["chunks"]:
+        # Rank-adaptive per-document chunk cap: top results get more chunks.
+        if max_chunks_per_doc > 0:
+            chunks_allowed = max_chunks_per_doc
+        else:
+            doc_fraction = _TOP_DOC_BUDGET_FRACTION / (1 + doc_idx * _RANK_DECAY)
+            max_doc_chars = int(max_chars * doc_fraction)
+            xml_overhead = 500
+            chunks_allowed = max(
+                (max_doc_chars - xml_overhead) // max(max_chunk_chars, 1),
+                _MIN_CHUNKS_PER_DOC,
+            )
+
+        chunks = g["chunks"]
+        if len(chunks) > chunks_allowed:
+            chunks = chunks[:chunks_allowed]
+
+        for ch in chunks:
             ch_content = ch["content"]
             if max_chunk_chars and len(ch_content) > max_chunk_chars:
                 ch_content = ch_content[:max_chunk_chars] + "\n...(truncated)"
             doc_lines.append(f"<chunk id={ch['chunk_id']}>{ch_content}</chunk>")
@@ -359,9 +539,11 @@ def format_documents_for_context(
         doc_xml = "\n".join(doc_lines)
         doc_len = len(doc_xml)

-        # Always include at least the first document; afterwards enforce budget.
-        if doc_idx > 0 and total_chars + doc_len > max_chars:
+        if total_chars + doc_len > max_chars:
             remaining = total_docs - doc_idx
+            if doc_idx == 0:
+                parts.append(doc_xml)
+                total_chars += doc_len
             parts.append(
                 f"<omitted_documents count={remaining} reason=\"context budget\"/>"
             )
@@ ... @@ def format_documents_for_context(
     result = "\n".join(parts)
+
+    # Final guard: hard-truncate if the output still exceeds the budget.
+    if len(result) > max_chars:
+        truncation_msg = "\n<truncated reason=\"context budget exceeded\"/>"
+        result = result[: max_chars - len(truncation_msg)] + truncation_msg
+
+    return result


 # =============================================================================
@@ -390,6 +580,7 @@ async def search_knowledge_base_async(
     start_date: datetime | None = None,
     end_date: datetime | None = None,
     available_connectors: list[str] | None = None,
+    available_document_types: list[str] | None = None,
     max_input_tokens: int | None = None,
 ) -> str:
     """
@@ -408,6 +599,9 @@ async def search_knowledge_base_async(
         end_date: Optional end datetime (UTC) for filtering documents
         available_connectors: Optional list of connectors actually available in the
             search space. If provided, only these connectors will be searched.
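The rank-adaptive allocation above is easy to sanity-check against the table in its comment. A minimal standalone sketch, using the constants and the 500-char `xml_overhead` value from the hunk:

```python
_TOP_DOC_BUDGET_FRACTION = 0.40
_RANK_DECAY = 0.35
_MIN_CHUNKS_PER_DOC = 3


def chunks_allowed(rank: int, max_chars: int = 128_000, max_chunk_chars: int = 8_000) -> int:
    # Mirrors the per-document cap computed in format_documents_for_context.
    doc_fraction = _TOP_DOC_BUDGET_FRACTION / (1 + rank * _RANK_DECAY)
    max_doc_chars = int(max_chars * doc_fraction)
    return max((max_doc_chars - 500) // max_chunk_chars, _MIN_CHUNKS_PER_DOC)


print([chunks_allowed(r) for r in (0, 1, 2, 3, 10)])
# [6, 4, 3, 3, 3]  (the _MIN_CHUNKS_PER_DOC floor binds at rank 10)
```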
+ available_document_types: Optional list of document types that actually have indexed + data. When provided, local connectors whose document type is + absent are skipped entirely (no embedding / DB round-trip). max_input_tokens: Model context window size (tokens). Used to dynamically size the output so it fits within the model's limits. @@ -428,6 +622,23 @@ async def search_knowledge_base_async( ) connectors = _normalize_connectors(connectors_to_search, available_connectors) + + # --- Optimization 1: skip local connectors that have zero indexed documents --- + if available_document_types: + doc_types_set = set(available_document_types) + before_count = len(connectors) + connectors = [ + c for c in connectors if c in _LIVE_SEARCH_CONNECTORS or c in doc_types_set + ] + skipped = before_count - len(connectors) + if skipped: + perf.info( + "[kb_search] skipped %d empty connectors (had %d, now %d)", + skipped, + before_count, + len(connectors), + ) + perf.info( "[kb_search] searching %d connectors: %s (space=%d, top_k=%d)", len(connectors), @@ -436,81 +647,126 @@ async def search_knowledge_base_async( top_k, ) - connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = { - "YOUTUBE_VIDEO": ("search_youtube", True, True, {}), - "EXTENSION": ("search_extension", True, True, {}), - "CRAWLED_URL": ("search_crawled_urls", True, True, {}), - "FILE": ("search_files", True, True, {}), - "SLACK_CONNECTOR": ("search_slack", True, True, {}), - "TEAMS_CONNECTOR": ("search_teams", True, True, {}), - "NOTION_CONNECTOR": ("search_notion", True, True, {}), - "GITHUB_CONNECTOR": ("search_github", True, True, {}), - "LINEAR_CONNECTOR": ("search_linear", True, True, {}), + # --- Fast-path: degenerate queries (*, **, empty, etc.) --- + # Semantic embedding of '*' is noise and plainto_tsquery('english', '*') + # yields an empty tsquery, so both retrieval signals are useless. + # Fall back to a recency-ordered browse that returns diverse results. + if _is_degenerate_query(query): + perf.info( + "[kb_search] degenerate query %r detected - falling back to recency browse", + query, + ) + local_connectors = [c for c in connectors if c not in _LIVE_SEARCH_CONNECTORS] + if not local_connectors: + local_connectors = [None] # type: ignore[list-item] + + browse_results = await asyncio.gather( + *[ + _browse_recent_documents( + search_space_id=search_space_id, + document_type=c, + top_k=top_k, + start_date=resolved_start_date, + end_date=resolved_end_date, + ) + for c in local_connectors + ] + ) + for docs in browse_results: + all_documents.extend(docs) + + # Skip dedup + formatting below (browse already returns unique docs) + # but still cap output budget. + output_budget = _compute_tool_output_budget(max_input_tokens) + result = format_documents_for_context( + all_documents, + max_chars=output_budget, + max_chunks_per_doc=_BROWSE_MAX_CHUNKS_PER_DOC, + ) + perf.info( + "[kb_search] TOTAL (browse) in %.3fs total_docs=%d output_chars=%d " + "budget=%d space=%d", + time.perf_counter() - t0, + len(all_documents), + len(result), + output_budget, + search_space_id, + ) + return result + + # Specs for live-search connectors (external APIs, no local DB/embedding). 
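To see Optimization 1 on concrete inputs, here is the filter run standalone; the connector names are illustrative:

```python
# Live-search connectors bypass the filter; local connectors survive only if
# their document type actually has indexed data.
_LIVE_SEARCH_CONNECTORS = {"TAVILY_API", "SEARXNG_API", "LINKUP_API", "BAIDU_SEARCH_API"}

connectors = ["FILE", "NOTION_CONNECTOR", "SLACK_CONNECTOR", "TAVILY_API"]
available_document_types = ["FILE"]  # only FILE has indexed documents

doc_types_set = set(available_document_types)
connectors = [
    c for c in connectors if c in _LIVE_SEARCH_CONNECTORS or c in doc_types_set
]
print(connectors)
# ['FILE', 'TAVILY_API']; NOTION and SLACK are skipped entirely,
# saving an embedding call and a DB round-trip for each.
```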
+ live_connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = { "TAVILY_API": ("search_tavily", False, True, {}), "SEARXNG_API": ("search_searxng", False, True, {}), "LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}), "BAIDU_SEARCH_API": ("search_baidu", False, True, {}), - "DISCORD_CONNECTOR": ("search_discord", True, True, {}), - "JIRA_CONNECTOR": ("search_jira", True, True, {}), - "GOOGLE_CALENDAR_CONNECTOR": ("search_google_calendar", True, True, {}), - "AIRTABLE_CONNECTOR": ("search_airtable", True, True, {}), - "GOOGLE_GMAIL_CONNECTOR": ("search_google_gmail", True, True, {}), - "GOOGLE_DRIVE_FILE": ("search_google_drive", True, True, {}), - "CONFLUENCE_CONNECTOR": ("search_confluence", True, True, {}), - "CLICKUP_CONNECTOR": ("search_clickup", True, True, {}), - "LUMA_CONNECTOR": ("search_luma", True, True, {}), - "ELASTICSEARCH_CONNECTOR": ("search_elasticsearch", True, True, {}), - "NOTE": ("search_notes", True, True, {}), - "BOOKSTACK_CONNECTOR": ("search_bookstack", True, True, {}), - "CIRCLEBACK": ("search_circleback", True, True, {}), - "OBSIDIAN_CONNECTOR": ("search_obsidian", True, True, {}), - # Composio connectors - "COMPOSIO_GOOGLE_DRIVE_CONNECTOR": ( - "search_composio_google_drive", - True, - True, - {}, - ), - "COMPOSIO_GMAIL_CONNECTOR": ("search_composio_gmail", True, True, {}), - "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": ( - "search_composio_google_calendar", - True, - True, - {}, - ), } - # Keep a conservative cap to avoid overloading DB/external services. + # --- Optimization 2: compute the query embedding once, share across all local searches --- + precomputed_embedding: list[float] | None = None + has_local_connectors = any(c not in _LIVE_SEARCH_CONNECTORS for c in connectors) + if has_local_connectors: + from app.config import config as app_config + + t_embed = time.perf_counter() + precomputed_embedding = app_config.embedding_model_instance.embed(query) + perf.info( + "[kb_search] shared embedding computed in %.3fs", + time.perf_counter() - t_embed, + ) + max_parallel_searches = 4 semaphore = asyncio.Semaphore(max_parallel_searches) async def _search_one_connector(connector: str) -> list[dict[str, Any]]: - spec = connector_specs.get(connector) - if spec is None: - return [] + is_live = connector in _LIVE_SEARCH_CONNECTORS - method_name, includes_date_range, includes_top_k, extra_kwargs = spec - kwargs: dict[str, Any] = { - "user_query": query, - "search_space_id": search_space_id, - **extra_kwargs, - } - if includes_top_k: - kwargs["top_k"] = top_k - if includes_date_range: - kwargs["start_date"] = resolved_start_date - kwargs["end_date"] = resolved_end_date + if is_live: + spec = live_connector_specs.get(connector) + if spec is None: + return [] + method_name, includes_date_range, includes_top_k, extra_kwargs = spec + kwargs: dict[str, Any] = { + "user_query": query, + "search_space_id": search_space_id, + **extra_kwargs, + } + if includes_top_k: + kwargs["top_k"] = top_k + if includes_date_range: + kwargs["start_date"] = resolved_start_date + kwargs["end_date"] = resolved_end_date + try: + t_conn = time.perf_counter() + async with semaphore, async_session_maker() as isolated_session: + svc = ConnectorService(isolated_session, search_space_id) + _, chunks = await getattr(svc, method_name)(**kwargs) + perf.info( + "[kb_search] connector=%s results=%d in %.3fs", + connector, + len(chunks), + time.perf_counter() - t_conn, + ) + return chunks + except Exception as e: + perf.warning("[kb_search] connector=%s FAILED: %s", connector, e) + 
return [] + + # --- Optimization 3: call _combined_rrf_search directly with shared embedding --- try: - # Use isolated session per connector. Shared AsyncSession cannot safely - # run concurrent DB operations. t_conn = time.perf_counter() async with semaphore, async_session_maker() as isolated_session: - isolated_connector_service = ConnectorService( - isolated_session, search_space_id + svc = ConnectorService(isolated_session, search_space_id) + chunks = await svc._combined_rrf_search( + query_text=query, + search_space_id=search_space_id, + document_type=connector, + top_k=top_k, + start_date=resolved_start_date, + end_date=resolved_end_date, + query_embedding=precomputed_embedding, ) - connector_method = getattr(isolated_connector_service, method_name) - _, chunks = await connector_method(**kwargs) perf.info( "[kb_search] connector=%s results=%d in %.3fs", connector, @@ -519,12 +775,7 @@ async def search_knowledge_base_async( ) return chunks except Exception as e: - perf.warning( - "[kb_search] connector=%s FAILED in %.3fs: %s", - connector, - time.perf_counter() - t_conn, - e, - ) + perf.warning("[kb_search] connector=%s FAILED: %s", connector, e) return [] t_gather = time.perf_counter() @@ -582,12 +833,24 @@ async def search_knowledge_base_async( output_budget = _compute_tool_output_budget(max_input_tokens) result = format_documents_for_context(deduplicated, max_chars=output_budget) + + if len(result) > output_budget: + perf.warning( + "[kb_search] output STILL exceeds budget after format (%d > %d), " + "hard truncation should have fired", + len(result), + output_budget, + ) + perf.info( - "[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d space=%d", + "[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d " + "budget=%d max_input_tokens=%s space=%d", time.perf_counter() - t0, len(all_documents), len(deduplicated), len(result), + output_budget, + max_input_tokens, search_space_id, ) return result @@ -628,11 +891,15 @@ class SearchKnowledgeBaseInput(BaseModel): """Input schema for the search_knowledge_base tool.""" query: str = Field( - description="The search query - be specific and include key terms" + description=( + "The search query - use specific natural language terms. " + "NEVER use wildcards like '*' or '**'; instead describe what you want " + "(e.g. 'recent meeting notes' or 'project architecture overview')." + ), ) top_k: int = Field( default=10, - description="Number of results to retrieve (default: 10)", + description="Number of results to retrieve (default: 10). Keep ≤20 for focused searches.", ) start_date: str | None = Field( default=None, @@ -695,6 +962,10 @@ Focus searches on these types for best results.""" Use this tool to find documents, notes, files, web pages, and other content that may help answer the user's question. IMPORTANT: +- Always craft specific, descriptive search queries using natural language keywords. + Good: "quarterly sales report Q3", "Python API authentication design". + Bad: "*", "**", "everything", single characters. Wildcard/empty queries yield poor results. +- Prefer multiple focused searches over a single broad one with high top_k. - If the user requests a specific source type (e.g. "my notes", "Slack messages"), pass `connectors_to_search=[...]` using the enums below. - If `connectors_to_search` is omitted/empty, the system will search broadly. 
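A concrete illustration of these rules, using the input schema above. Note that `connectors_to_search` is assumed to be a field of the schema (the visible hunk only shows `query`, `top_k`, and the date fields):

```python
# Well-formed call: specific natural-language query, scoped sources.
good = SearchKnowledgeBaseInput(
    query="Python API authentication design",
    top_k=10,
    connectors_to_search=["NOTION_CONNECTOR", "GITHUB_CONNECTOR"],  # assumed field
)

# Degenerate query: the schema accepts it, but the backend now detects it via
# _is_degenerate_query and falls back to the recency browse instead of
# returning chunk-count noise.
bad = SearchKnowledgeBaseInput(query="**", top_k=50)
```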
- Only connectors that are enabled/configured for this search space are available.{doc_types_info} @@ -710,6 +981,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type # Capture for closure _available_connectors = available_connectors + _available_document_types = available_document_types async def _search_knowledge_base_impl( query: str, @@ -739,6 +1011,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type start_date=parsed_start, end_date=parsed_end, available_connectors=_available_connectors, + available_document_types=_available_document_types, max_input_tokens=max_input_tokens, ) diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index f36f0de13..99cb09b38 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -145,10 +145,12 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ thread_id=deps["thread_id"], connector_service=deps.get("connector_service"), available_connectors=deps.get("available_connectors"), + available_document_types=deps.get("available_document_types"), ), requires=["search_space_id", "thread_id"], - # connector_service and available_connectors are optional — - # when missing, source_strategy="kb_search" degrades gracefully to "provided" + # connector_service, available_connectors, and available_document_types + # are optional — when missing, source_strategy="kb_search" degrades + # gracefully to "provided" ), # Link preview tool - fetches Open Graph metadata for URLs ToolDefinition( diff --git a/surfsense_backend/app/agents/new_chat/tools/report.py b/surfsense_backend/app/agents/new_chat/tools/report.py index 0896fea4b..5212c2c3b 100644 --- a/surfsense_backend/app/agents/new_chat/tools/report.py +++ b/surfsense_backend/app/agents/new_chat/tools/report.py @@ -559,6 +559,7 @@ def create_generate_report_tool( thread_id: int | None = None, connector_service: ConnectorService | None = None, available_connectors: list[str] | None = None, + available_document_types: list[str] | None = None, ): """ Factory function to create the generate_report tool with injected dependencies. @@ -838,6 +839,7 @@ def create_generate_report_tool( connector_service=kb_connector_svc, top_k=10, available_connectors=available_connectors, + available_document_types=available_document_types, ) kb_results = await asyncio.gather( diff --git a/surfsense_backend/app/retriever/chunks_hybrid_search.py b/surfsense_backend/app/retriever/chunks_hybrid_search.py index 38ecba96c..4787e8147 100644 --- a/surfsense_backend/app/retriever/chunks_hybrid_search.py +++ b/surfsense_backend/app/retriever/chunks_hybrid_search.py @@ -3,6 +3,8 @@ from datetime import datetime from app.utils.perf import get_perf_logger +_MAX_FETCH_CHUNKS_PER_DOC = 30 + class ChucksHybridSearchRetriever: def __init__(self, db_session): @@ -346,8 +348,9 @@ class ChucksHybridSearchRetriever: if not doc_ids: return [] - # Fetch ALL chunks for selected documents in a single query so the final prompt can cite - # any chunk from those documents. + # Fetch chunks for selected documents. We cap per document to avoid + # loading hundreds of chunks for a single large file while still + # ensuring the chunks that matched the RRF query are always included. 
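The hunk below implements that guarantee. Run on toy data, the cap behaves like this (cap lowered from 30 to 2 so the effect is visible):

```python
from types import SimpleNamespace

_MAX_FETCH_CHUNKS_PER_DOC = 2  # the patch uses 30

raw_chunks = [SimpleNamespace(id=i, document_id=7) for i in range(1, 6)]
matched_chunk_ids = {5}  # chunk 5 matched the RRF query but sits past the cap

doc_chunk_counts: dict[int, int] = {}
kept: list[int] = []
for chunk in raw_chunks:
    count = doc_chunk_counts.get(chunk.document_id, 0)
    # Matched chunks are always kept, even once the per-document cap is hit.
    if chunk.id in matched_chunk_ids or count < _MAX_FETCH_CHUNKS_PER_DOC:
        kept.append(chunk.id)
        doc_chunk_counts[chunk.document_id] = count + 1
print(kept)  # [1, 2, 5]
```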
chunk_query = ( select(Chunk) .options(joinedload(Chunk.document)) @@ -357,7 +360,20 @@ class ChucksHybridSearchRetriever: .order_by(Chunk.document_id, Chunk.id) ) chunks_result = await self.db_session.execute(chunk_query) - all_chunks = chunks_result.scalars().all() + raw_chunks = chunks_result.scalars().all() + + matched_chunk_ids: set[int] = { + item["chunk_id"] for item in serialized_chunk_results + } + + doc_chunk_counts: dict[int, int] = {} + all_chunks: list = [] + for chunk in raw_chunks: + did = chunk.document_id + count = doc_chunk_counts.get(did, 0) + if chunk.id in matched_chunk_ids or count < _MAX_FETCH_CHUNKS_PER_DOC: + all_chunks.append(chunk) + doc_chunk_counts[did] = count + 1 # Assemble final doc-grouped results in the same order as doc_ids doc_map: dict[int, dict] = { diff --git a/surfsense_backend/app/retriever/documents_hybrid_search.py b/surfsense_backend/app/retriever/documents_hybrid_search.py index f4daf8e26..69e97384f 100644 --- a/surfsense_backend/app/retriever/documents_hybrid_search.py +++ b/surfsense_backend/app/retriever/documents_hybrid_search.py @@ -3,6 +3,8 @@ from datetime import datetime from app.utils.perf import get_perf_logger +_MAX_FETCH_CHUNKS_PER_DOC = 30 + class DocumentHybridSearchRetriever: def __init__(self, db_session): @@ -279,7 +281,8 @@ class DocumentHybridSearchRetriever: # Collect document IDs for chunk fetching doc_ids: list[int] = [doc.id for doc, _score in documents_with_scores] - # Fetch ALL chunks for these documents in a single query + # Fetch chunks for these documents, capped per document to avoid + # loading hundreds of chunks for a single large file. chunks_query = ( select(Chunk) .options(joinedload(Chunk.document)) @@ -287,7 +290,16 @@ class DocumentHybridSearchRetriever: .order_by(Chunk.document_id, Chunk.id) ) chunks_result = await self.db_session.execute(chunks_query) - chunks = chunks_result.scalars().all() + raw_chunks = chunks_result.scalars().all() + + doc_chunk_counts: dict[int, int] = {} + chunks: list = [] + for chunk in raw_chunks: + did = chunk.document_id + count = doc_chunk_counts.get(did, 0) + if count < _MAX_FETCH_CHUNKS_PER_DOC: + chunks.append(chunk) + doc_chunk_counts[did] = count + 1 # Assemble doc-grouped results doc_map: dict[int, dict] = { diff --git a/surfsense_backend/app/services/connector_service.py b/surfsense_backend/app/services/connector_service.py index 157e0bab5..0aa48eccd 100644 --- a/surfsense_backend/app/services/connector_service.py +++ b/surfsense_backend/app/services/connector_service.py @@ -224,6 +224,7 @@ class ConnectorService: top_k: int = 20, start_date: datetime | None = None, end_date: datetime | None = None, + query_embedding: list[float] | None = None, ) -> list[dict[str, Any]]: """ Perform combined search using both chunk-based and document-based hybrid search, @@ -260,14 +261,15 @@ class ConnectorService: # Get more results from each retriever for better fusion retriever_top_k = top_k * 2 - # Pre-compute the embedding once so both retrievers reuse it. - t_embed = time.perf_counter() - query_embedding = config.embedding_model_instance.embed(query_text) - perf.info( - "[connector_svc] _combined_rrf embedding in %.3fs type=%s", - time.perf_counter() - t_embed, - document_type, - ) + # Reuse caller-provided embedding or compute once for both retrievers. 
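The reuse pattern this new parameter enables, condensed into a runnable sketch; the stub names (`embed_stub`, `search_one_stub`) are assumptions standing in for the real embedding model and `_combined_rrf_search`:

```python
import asyncio


def embed_stub(query: str) -> list[float]:
    # Stand-in for config.embedding_model_instance.embed(query); called once.
    return [0.1, 0.2, 0.3]


async def search_one_stub(
    connector: str, embedding: list[float], sem: asyncio.Semaphore
) -> list[dict]:
    async with sem:  # at most 4 searches in flight, as in the patch
        await asyncio.sleep(0.01)  # stand-in for _combined_rrf_search(...)
        return [{"connector": connector, "embedding_reused": True}]


async def main() -> None:
    embedding = embed_stub("project architecture overview")  # embed once
    sem = asyncio.Semaphore(4)
    chunk_lists = await asyncio.gather(
        *(search_one_stub(c, embedding, sem) for c in ("FILE", "NOTE", "SLACK_CONNECTOR"))
    )
    all_documents = [doc for chunks in chunk_lists for doc in chunks]
    print(len(all_documents))  # 3, with no per-connector embedding calls

asyncio.run(main())
```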
+        if query_embedding is None:
+            t_embed = time.perf_counter()
+            query_embedding = config.embedding_model_instance.embed(query_text)
+            perf.info(
+                "[connector_svc] _combined_rrf embedding in %.3fs type=%s",
+                time.perf_counter() - t_embed,
+                document_type,
+            )

         search_kwargs = {
             "query_text": query_text,
diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py
index 2465834f4..e8c0d2d47 100644
--- a/surfsense_backend/app/services/llm_router_service.py
+++ b/surfsense_backend/app/services/llm_router_service.py
@@ -159,26 +159,95 @@ class LLMRouterService:
         # Merge with provided settings
         final_settings = {**default_settings, **instance._router_settings}

+        # Build an "auto-large" fallback group from deployments whose context
+        # window exceeds that of the smallest deployment. This lets the router
+        # automatically fall back to a bigger-context model when gpt-4o (128K)
+        # hits ContextWindowExceededError.
+        full_model_list, ctx_fallbacks = cls._build_context_fallback_groups(model_list)
+
         try:
-            instance._router = Router(
-                model_list=model_list,
-                routing_strategy=final_settings.get(
+            router_kwargs: dict[str, Any] = {
+                "model_list": full_model_list,
+                "routing_strategy": final_settings.get(
                     "routing_strategy", "usage-based-routing"
                 ),
-                num_retries=final_settings.get("num_retries", 3),
-                allowed_fails=final_settings.get("allowed_fails", 3),
-                cooldown_time=final_settings.get("cooldown_time", 60),
-                set_verbose=False,  # Disable verbose logging in production
-            )
+                "num_retries": final_settings.get("num_retries", 3),
+                "allowed_fails": final_settings.get("allowed_fails", 3),
+                "cooldown_time": final_settings.get("cooldown_time", 60),
+                "set_verbose": False,
+            }
+            if ctx_fallbacks:
+                router_kwargs["context_window_fallbacks"] = ctx_fallbacks
+
+            instance._router = Router(**router_kwargs)
             instance._initialized = True
             logger.info(
-                f"LLM Router initialized with {len(model_list)} deployments, "
-                f"strategy: {final_settings.get('routing_strategy')}"
+                "LLM Router initialized with %d deployments, "
+                "strategy: %s, context_window_fallbacks: %s",
+                len(model_list),
+                final_settings.get("routing_strategy"),
+                ctx_fallbacks or "none",
             )
         except Exception as e:
             logger.error(f"Failed to initialize LLM Router: {e}")
             instance._router = None

+    @classmethod
+    def _build_context_fallback_groups(
+        cls, model_list: list[dict]
+    ) -> tuple[list[dict], list[dict[str, list[str]]] | None]:
+        """Create an ``auto-large`` model group for context-window fallbacks.
+
+        Uses ``litellm.get_model_info`` to discover the context window of each
+        deployment. Deployments whose ``max_input_tokens`` exceeds the smallest
+        window are duplicated into an ``auto-large`` group. The returned
+        fallback config tells the Router: on ``ContextWindowExceededError`` for
+        ``auto``, retry with ``auto-large``.
+
+        Returns:
+            (full_model_list, context_window_fallbacks) — ``full_model_list``
+            contains the original entries plus any ``auto-large`` duplicates.
+            ``context_window_fallbacks`` is ``None`` when every deployment has
+            the same context size (no useful fallback).
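Assuming litellm's published `Router` API, the configuration this method produces behaves as sketched below; the deployment names are illustrative, not from the patch:

```python
from litellm import Router

model_list = [
    {"model_name": "auto", "litellm_params": {"model": "openai/gpt-4o-mini"}},
    {"model_name": "auto", "litellm_params": {"model": "gemini/gemini-1.5-pro"}},
    # Duplicate of the larger-context deployment, as _build_context_fallback_groups adds.
    {"model_name": "auto-large", "litellm_params": {"model": "gemini/gemini-1.5-pro"}},
]

router = Router(
    model_list=model_list,
    context_window_fallbacks=[{"auto": ["auto-large"]}],
    num_retries=3,
)

# Requests target "auto"; on ContextWindowExceededError litellm retries the
# same request against the "auto-large" group.
resp = router.completion(model="auto", messages=[{"role": "user", "content": "hi"}])
```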
+ """ + from litellm import get_model_info + + ctx_map: dict[str, int] = {} + for dep in model_list: + params = dep.get("litellm_params", {}) + base_model = params.get("base_model") or params.get("model", "") + try: + info = get_model_info(base_model) + ctx = info.get("max_input_tokens") + if isinstance(ctx, int) and ctx > 0: + ctx_map[base_model] = ctx + except Exception: + continue + + if not ctx_map: + return model_list, None + + min_ctx = min(ctx_map.values()) + + large_deployments: list[dict] = [] + for dep in model_list: + params = dep.get("litellm_params", {}) + base_model = params.get("base_model") or params.get("model", "") + if ctx_map.get(base_model, 0) > min_ctx: + dup = {**dep, "model_name": "auto-large"} + large_deployments.append(dup) + + if not large_deployments: + return model_list, None + + logger.info( + "Context-window fallback: %d large-context deployments " + "(min_ctx=%d) added to 'auto-large' group", + len(large_deployments), + min_ctx, + ) + return model_list + large_deployments, [{"auto": ["auto-large"]}] + @classmethod def _config_to_deployment(cls, config: dict) -> dict | None: """