feat: enhance knowledge base search and document retrieval

- Introduced a mechanism to identify degenerate queries that lack meaningful search signals, improving search accuracy.
- Implemented a fallback method for browsing recent documents when queries are degenerate, ensuring relevant results are returned.
- Added limits on the number of chunks fetched per document to optimize performance and prevent excessive data loading.
- Updated the ConnectorService to allow for reusable query embeddings, enhancing efficiency in search operations.
- Enhanced LLM router service to support context window fallbacks, improving robustness during context window limitations.
This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-02-28 19:40:24 -08:00
parent b08e8da40c
commit 40a091f8cc
7 changed files with 476 additions and 100 deletions

View file

@ -10,6 +10,7 @@ This module provides:
import asyncio import asyncio
import json import json
import re
import time import time
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
@ -22,6 +23,149 @@ from app.db import async_session_maker
from app.services.connector_service import ConnectorService from app.services.connector_service import ConnectorService
from app.utils.perf import get_perf_logger from app.utils.perf import get_perf_logger
# Connectors that call external live-search APIs (no local DB / embedding needed).
# These are never filtered by available_document_types: the connector filter
# in search_knowledge_base_async keeps any name found in this set even when it
# is absent from the list of indexed document types.
_LIVE_SEARCH_CONNECTORS: set[str] = {
    "TAVILY_API",
    "SEARXNG_API",
    "LINKUP_API",
    "BAIDU_SEARCH_API",
}

# Patterns that indicate the query has no meaningful search signal.
# plainto_tsquery('english', '*') produces an empty tsquery and an embedding
# of '*' is random noise, so both keyword and semantic search degrade to
# arbitrary ordering — large documents (many chunks) dominate by chance.
#
# The pattern accepts a string made up entirely of whitespace and the
# characters * ? _ . # @ ! - / \ (one or more of them, nothing else).
_DEGENERATE_QUERY_RE = re.compile(
    r"^[\s*?_.#@!\-/\\]+$"  # only wildcards, punctuation, whitespace
)

# Max chunks per document when doing a recency-based browse instead of
# a real search. We want breadth (many docs) over depth (many chunks).
_BROWSE_MAX_CHUNKS_PER_DOC = 5
def _is_degenerate_query(query: str) -> bool:
"""Return True when the query carries no meaningful search signal.
Catches wildcard patterns (``*``, ``**``), empty / whitespace-only
strings, and single-character non-word tokens. These queries cause
both keyword search (empty tsquery) and semantic search (meaningless
embedding) to return effectively random results.
"""
stripped = query.strip()
if not stripped:
return True
return bool(_DEGENERATE_QUERY_RE.match(stripped))
async def _browse_recent_documents(
    search_space_id: int,
    document_type: str | None,
    top_k: int,
    start_date: datetime | None,
    end_date: datetime | None,
) -> list[dict[str, Any]]:
    """List the most recently updated documents of a search space.

    This is the fallback for degenerate queries (e.g. ``*``), where keyword
    and semantic ranking carry no signal. Documents are ordered purely by
    ``updated_at`` (newest first) and every result gets a score of ``0.0``.
    The returned dicts are document-grouped and intended to mirror the shape
    produced by ``_combined_rrf_search`` so downstream formatting is shared.

    Args:
        search_space_id: Search space whose documents are listed.
        document_type: Optional document type filter. A string is resolved as
            a ``DocumentType`` member name; any other non-``None`` value is
            compared against the column as-is.
        top_k: Maximum number of documents to return.
        start_date: If given, only documents with ``updated_at >= start_date``.
        end_date: If given, only documents with ``updated_at <= end_date``.

    Returns:
        One dict per document, newest first, each carrying at most
        ``_BROWSE_MAX_CHUNKS_PER_DOC`` chunks (lowest chunk ids first).
        An empty list when ``document_type`` names no ``DocumentType`` member
        or when no document matches the filters.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import joinedload

    from app.db import Chunk, Document, DocumentType

    perf = get_perf_logger()
    started_at = time.perf_counter()

    filters = [Document.search_space_id == search_space_id]
    if document_type is not None:
        type_filter = document_type
        if isinstance(document_type, str):
            try:
                type_filter = DocumentType[document_type]
            except KeyError:
                # Unknown type name: nothing can match, so skip the DB entirely.
                return []
        filters.append(Document.document_type == type_filter)
    if start_date is not None:
        filters.append(Document.updated_at >= start_date)
    if end_date is not None:
        filters.append(Document.updated_at <= end_date)

    async with async_session_maker() as session:
        recent_docs_stmt = (
            select(Document)
            .options(joinedload(Document.search_space))
            .where(*filters)
            .order_by(Document.updated_at.desc())
            .limit(top_k)
        )
        documents = (await session.execute(recent_docs_stmt)).scalars().unique().all()
        if not documents:
            return []

        chunks_stmt = (
            select(Chunk)
            .where(Chunk.document_id.in_([doc.id for doc in documents]))
            .order_by(Chunk.document_id, Chunk.id)
        )
        fetched_chunks = (await session.execute(chunks_stmt)).scalars().all()

        # Rows are turned into plain dicts while the session is still open.
        # Chunks arrive ordered by (document_id, id), so keeping the first N
        # seen per document keeps the N lowest chunk ids.
        kept_by_doc: dict[int, list[dict]] = {doc.id: [] for doc in documents}
        for chunk in fetched_chunks:
            kept = kept_by_doc[chunk.document_id]
            if len(kept) < _BROWSE_MAX_CHUNKS_PER_DOC:
                kept.append({"chunk_id": chunk.id, "content": chunk.content})

        results: list[dict[str, Any]] = []
        for doc in documents:
            kept = kept_by_doc.get(doc.id, [])
            if getattr(doc, "document_type", None):
                type_value = doc.document_type.value
            else:
                type_value = None
            # Chunks with empty/None content are left out of the joined text
            # but remain in the "chunks" list.
            joined_content = "\n\n".join(
                item["content"] for item in kept if item.get("content")
            )
            results.append(
                {
                    "document_id": doc.id,
                    "content": joined_content,
                    "score": 0.0,
                    "chunks": kept,
                    "document": {
                        "id": doc.id,
                        "title": doc.title,
                        "document_type": type_value,
                        "metadata": doc.document_metadata or {},
                    },
                    "source": type_value,
                }
            )

    perf.info(
        "[kb_browse] recency browse in %.3fs docs=%d space=%d type=%s",
        time.perf_counter() - started_at,
        len(results),
        search_space_id,
        document_type,
    )
    return results
# ============================================================================= # =============================================================================
# Connector Constants and Normalization # Connector Constants and Normalization
# ============================================================================= # =============================================================================
@ -184,9 +328,23 @@ _CHARS_PER_TOKEN = 4
# Hard-floor / ceiling so the budget is always sensible regardless of what # Hard-floor / ceiling so the budget is always sensible regardless of what
# the model reports. # the model reports.
_MIN_TOOL_OUTPUT_CHARS = 20_000 # ~5K tokens _MIN_TOOL_OUTPUT_CHARS = 20_000 # ~5K tokens
_MAX_TOOL_OUTPUT_CHARS = 400_000 # ~100K tokens _MAX_TOOL_OUTPUT_CHARS = 200_000 # ~50K tokens
_MAX_CHUNK_CHARS = 8_000 _MAX_CHUNK_CHARS = 8_000
# Rank-adaptive per-document budget allocation.
# Top-ranked (most relevant) documents get a larger share of the budget so
# we pack as much high-quality context as possible.
#
# fraction(rank) = _TOP_DOC_BUDGET_FRACTION / (1 + rank * _RANK_DECAY)
#
# Examples (128K-char budget, 8K-char chunk cap, 500-char XML overhead):
#   rank 0 → 40% → 6 chunks   |   rank 3  → 20% → 3 chunks
#   rank 1 → 30% → 4 chunks   |   rank 10 →  9% → 3 chunks (raised to the floor)
#   rank 2 → 24% → 3 chunks   |
_TOP_DOC_BUDGET_FRACTION = 0.40
_RANK_DECAY = 0.35
_MIN_CHUNKS_PER_DOC = 3
def _compute_tool_output_budget(max_input_tokens: int | None) -> int: def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
"""Derive a character budget from the model's context window. """Derive a character budget from the model's context window.
@ -208,18 +366,24 @@ def format_documents_for_context(
*, *,
max_chars: int = _MAX_TOOL_OUTPUT_CHARS, max_chars: int = _MAX_TOOL_OUTPUT_CHARS,
max_chunk_chars: int = _MAX_CHUNK_CHARS, max_chunk_chars: int = _MAX_CHUNK_CHARS,
max_chunks_per_doc: int = 0,
) -> str: ) -> str:
""" """
Format retrieved documents into a readable context string for the LLM. Format retrieved documents into a readable context string for the LLM.
Documents are added in order (highest relevance first) until the character Documents are added in order (highest relevance first) until the character
budget is reached. Individual chunks are capped at ``max_chunk_chars`` so budget is reached. Individual chunks are capped at ``max_chunk_chars`` and
a single oversized chunk cannot monopolize the output. each document is limited to a dynamically computed chunk cap so a single
large document cannot monopolize the output while still maximising the use
of available context space.
Args: Args:
documents: List of document dictionaries from connector search documents: List of document dictionaries from connector search
max_chars: Approximate character budget for the entire output. max_chars: Approximate character budget for the entire output.
max_chunk_chars: Per-chunk character cap (content is tail-truncated). max_chunk_chars: Per-chunk character cap (content is tail-truncated).
max_chunks_per_doc: Maximum chunks per document. ``0`` (default) means
auto-compute per document using a rank-adaptive formula so
higher-ranked documents receive more chunks.
Returns: Returns:
Formatted string with document contents and metadata Formatted string with document contents and metadata
@ -342,7 +506,23 @@ def format_documents_for_context(
"<document_content>", "<document_content>",
] ]
for ch in g["chunks"]: # Rank-adaptive per-document chunk cap: top results get more chunks.
if max_chunks_per_doc > 0:
chunks_allowed = max_chunks_per_doc
else:
doc_fraction = _TOP_DOC_BUDGET_FRACTION / (1 + doc_idx * _RANK_DECAY)
max_doc_chars = int(max_chars * doc_fraction)
xml_overhead = 500
chunks_allowed = max(
(max_doc_chars - xml_overhead) // max(max_chunk_chars, 1),
_MIN_CHUNKS_PER_DOC,
)
chunks = g["chunks"]
if len(chunks) > chunks_allowed:
chunks = chunks[:chunks_allowed]
for ch in chunks:
ch_content = ch["content"] ch_content = ch["content"]
if max_chunk_chars and len(ch_content) > max_chunk_chars: if max_chunk_chars and len(ch_content) > max_chunk_chars:
ch_content = ch_content[:max_chunk_chars] + "\n...(truncated)" ch_content = ch_content[:max_chunk_chars] + "\n...(truncated)"
@ -359,9 +539,11 @@ def format_documents_for_context(
doc_xml = "\n".join(doc_lines) doc_xml = "\n".join(doc_lines)
doc_len = len(doc_xml) doc_len = len(doc_xml)
# Always include at least the first document; afterwards enforce budget. if total_chars + doc_len > max_chars:
if doc_idx > 0 and total_chars + doc_len > max_chars:
remaining = total_docs - doc_idx remaining = total_docs - doc_idx
if doc_idx == 0:
parts.append(doc_xml)
total_chars += doc_len
parts.append( parts.append(
f"<!-- Output truncated: {remaining} more document(s) omitted " f"<!-- Output truncated: {remaining} more document(s) omitted "
f"(budget {max_chars} chars). Refine your query or reduce top_k " f"(budget {max_chars} chars). Refine your query or reduce top_k "
@ -372,7 +554,15 @@ def format_documents_for_context(
parts.append(doc_xml) parts.append(doc_xml)
total_chars += doc_len total_chars += doc_len
return "\n".join(parts).strip() result = "\n".join(parts).strip()
# Hard safety net: if the result is still over budget (e.g. a single massive
# first document), forcibly truncate with a closing comment.
if len(result) > max_chars:
truncation_msg = "\n<!-- ...output forcibly truncated to fit context window -->"
result = result[: max_chars - len(truncation_msg)] + truncation_msg
return result
# ============================================================================= # =============================================================================
@ -390,6 +580,7 @@ async def search_knowledge_base_async(
start_date: datetime | None = None, start_date: datetime | None = None,
end_date: datetime | None = None, end_date: datetime | None = None,
available_connectors: list[str] | None = None, available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
max_input_tokens: int | None = None, max_input_tokens: int | None = None,
) -> str: ) -> str:
""" """
@ -408,6 +599,9 @@ async def search_knowledge_base_async(
end_date: Optional end datetime (UTC) for filtering documents end_date: Optional end datetime (UTC) for filtering documents
available_connectors: Optional list of connectors actually available in the search space. available_connectors: Optional list of connectors actually available in the search space.
If provided, only these connectors will be searched. If provided, only these connectors will be searched.
available_document_types: Optional list of document types that actually have indexed
data. When provided, local connectors whose document type is
absent are skipped entirely (no embedding / DB round-trip).
max_input_tokens: Model context window size (tokens). Used to dynamically max_input_tokens: Model context window size (tokens). Used to dynamically
size the output so it fits within the model's limits. size the output so it fits within the model's limits.
@ -428,6 +622,23 @@ async def search_knowledge_base_async(
) )
connectors = _normalize_connectors(connectors_to_search, available_connectors) connectors = _normalize_connectors(connectors_to_search, available_connectors)
# --- Optimization 1: skip local connectors that have zero indexed documents ---
if available_document_types:
doc_types_set = set(available_document_types)
before_count = len(connectors)
connectors = [
c for c in connectors if c in _LIVE_SEARCH_CONNECTORS or c in doc_types_set
]
skipped = before_count - len(connectors)
if skipped:
perf.info(
"[kb_search] skipped %d empty connectors (had %d, now %d)",
skipped,
before_count,
len(connectors),
)
perf.info( perf.info(
"[kb_search] searching %d connectors: %s (space=%d, top_k=%d)", "[kb_search] searching %d connectors: %s (space=%d, top_k=%d)",
len(connectors), len(connectors),
@ -436,59 +647,84 @@ async def search_knowledge_base_async(
top_k, top_k,
) )
connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = { # --- Fast-path: degenerate queries (*, **, empty, etc.) ---
"YOUTUBE_VIDEO": ("search_youtube", True, True, {}), # Semantic embedding of '*' is noise and plainto_tsquery('english', '*')
"EXTENSION": ("search_extension", True, True, {}), # yields an empty tsquery, so both retrieval signals are useless.
"CRAWLED_URL": ("search_crawled_urls", True, True, {}), # Fall back to a recency-ordered browse that returns diverse results.
"FILE": ("search_files", True, True, {}), if _is_degenerate_query(query):
"SLACK_CONNECTOR": ("search_slack", True, True, {}), perf.info(
"TEAMS_CONNECTOR": ("search_teams", True, True, {}), "[kb_search] degenerate query %r detected - falling back to recency browse",
"NOTION_CONNECTOR": ("search_notion", True, True, {}), query,
"GITHUB_CONNECTOR": ("search_github", True, True, {}), )
"LINEAR_CONNECTOR": ("search_linear", True, True, {}), local_connectors = [c for c in connectors if c not in _LIVE_SEARCH_CONNECTORS]
if not local_connectors:
local_connectors = [None] # type: ignore[list-item]
browse_results = await asyncio.gather(
*[
_browse_recent_documents(
search_space_id=search_space_id,
document_type=c,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
for c in local_connectors
]
)
for docs in browse_results:
all_documents.extend(docs)
# Skip dedup + formatting below (browse already returns unique docs)
# but still cap output budget.
output_budget = _compute_tool_output_budget(max_input_tokens)
result = format_documents_for_context(
all_documents,
max_chars=output_budget,
max_chunks_per_doc=_BROWSE_MAX_CHUNKS_PER_DOC,
)
perf.info(
"[kb_search] TOTAL (browse) in %.3fs total_docs=%d output_chars=%d "
"budget=%d space=%d",
time.perf_counter() - t0,
len(all_documents),
len(result),
output_budget,
search_space_id,
)
return result
# Specs for live-search connectors (external APIs, no local DB/embedding).
live_connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
"TAVILY_API": ("search_tavily", False, True, {}), "TAVILY_API": ("search_tavily", False, True, {}),
"SEARXNG_API": ("search_searxng", False, True, {}), "SEARXNG_API": ("search_searxng", False, True, {}),
"LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}), "LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}),
"BAIDU_SEARCH_API": ("search_baidu", False, True, {}), "BAIDU_SEARCH_API": ("search_baidu", False, True, {}),
"DISCORD_CONNECTOR": ("search_discord", True, True, {}),
"JIRA_CONNECTOR": ("search_jira", True, True, {}),
"GOOGLE_CALENDAR_CONNECTOR": ("search_google_calendar", True, True, {}),
"AIRTABLE_CONNECTOR": ("search_airtable", True, True, {}),
"GOOGLE_GMAIL_CONNECTOR": ("search_google_gmail", True, True, {}),
"GOOGLE_DRIVE_FILE": ("search_google_drive", True, True, {}),
"CONFLUENCE_CONNECTOR": ("search_confluence", True, True, {}),
"CLICKUP_CONNECTOR": ("search_clickup", True, True, {}),
"LUMA_CONNECTOR": ("search_luma", True, True, {}),
"ELASTICSEARCH_CONNECTOR": ("search_elasticsearch", True, True, {}),
"NOTE": ("search_notes", True, True, {}),
"BOOKSTACK_CONNECTOR": ("search_bookstack", True, True, {}),
"CIRCLEBACK": ("search_circleback", True, True, {}),
"OBSIDIAN_CONNECTOR": ("search_obsidian", True, True, {}),
# Composio connectors
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": (
"search_composio_google_drive",
True,
True,
{},
),
"COMPOSIO_GMAIL_CONNECTOR": ("search_composio_gmail", True, True, {}),
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": (
"search_composio_google_calendar",
True,
True,
{},
),
} }
# Keep a conservative cap to avoid overloading DB/external services. # --- Optimization 2: compute the query embedding once, share across all local searches ---
precomputed_embedding: list[float] | None = None
has_local_connectors = any(c not in _LIVE_SEARCH_CONNECTORS for c in connectors)
if has_local_connectors:
from app.config import config as app_config
t_embed = time.perf_counter()
precomputed_embedding = app_config.embedding_model_instance.embed(query)
perf.info(
"[kb_search] shared embedding computed in %.3fs",
time.perf_counter() - t_embed,
)
max_parallel_searches = 4 max_parallel_searches = 4
semaphore = asyncio.Semaphore(max_parallel_searches) semaphore = asyncio.Semaphore(max_parallel_searches)
async def _search_one_connector(connector: str) -> list[dict[str, Any]]: async def _search_one_connector(connector: str) -> list[dict[str, Any]]:
spec = connector_specs.get(connector) is_live = connector in _LIVE_SEARCH_CONNECTORS
if is_live:
spec = live_connector_specs.get(connector)
if spec is None: if spec is None:
return [] return []
method_name, includes_date_range, includes_top_k, extra_kwargs = spec method_name, includes_date_range, includes_top_k, extra_kwargs = spec
kwargs: dict[str, Any] = { kwargs: dict[str, Any] = {
"user_query": query, "user_query": query,
@ -502,15 +738,10 @@ async def search_knowledge_base_async(
kwargs["end_date"] = resolved_end_date kwargs["end_date"] = resolved_end_date
try: try:
# Use isolated session per connector. Shared AsyncSession cannot safely
# run concurrent DB operations.
t_conn = time.perf_counter() t_conn = time.perf_counter()
async with semaphore, async_session_maker() as isolated_session: async with semaphore, async_session_maker() as isolated_session:
isolated_connector_service = ConnectorService( svc = ConnectorService(isolated_session, search_space_id)
isolated_session, search_space_id _, chunks = await getattr(svc, method_name)(**kwargs)
)
connector_method = getattr(isolated_connector_service, method_name)
_, chunks = await connector_method(**kwargs)
perf.info( perf.info(
"[kb_search] connector=%s results=%d in %.3fs", "[kb_search] connector=%s results=%d in %.3fs",
connector, connector,
@ -519,12 +750,32 @@ async def search_knowledge_base_async(
) )
return chunks return chunks
except Exception as e: except Exception as e:
perf.warning( perf.warning("[kb_search] connector=%s FAILED: %s", connector, e)
"[kb_search] connector=%s FAILED in %.3fs: %s", return []
connector,
time.perf_counter() - t_conn, # --- Optimization 3: call _combined_rrf_search directly with shared embedding ---
e, try:
t_conn = time.perf_counter()
async with semaphore, async_session_maker() as isolated_session:
svc = ConnectorService(isolated_session, search_space_id)
chunks = await svc._combined_rrf_search(
query_text=query,
search_space_id=search_space_id,
document_type=connector,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
query_embedding=precomputed_embedding,
) )
perf.info(
"[kb_search] connector=%s results=%d in %.3fs",
connector,
len(chunks),
time.perf_counter() - t_conn,
)
return chunks
except Exception as e:
perf.warning("[kb_search] connector=%s FAILED: %s", connector, e)
return [] return []
t_gather = time.perf_counter() t_gather = time.perf_counter()
@ -582,12 +833,24 @@ async def search_knowledge_base_async(
output_budget = _compute_tool_output_budget(max_input_tokens) output_budget = _compute_tool_output_budget(max_input_tokens)
result = format_documents_for_context(deduplicated, max_chars=output_budget) result = format_documents_for_context(deduplicated, max_chars=output_budget)
if len(result) > output_budget:
perf.warning(
"[kb_search] output STILL exceeds budget after format (%d > %d), "
"hard truncation should have fired",
len(result),
output_budget,
)
perf.info( perf.info(
"[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d space=%d", "[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d "
"budget=%d max_input_tokens=%s space=%d",
time.perf_counter() - t0, time.perf_counter() - t0,
len(all_documents), len(all_documents),
len(deduplicated), len(deduplicated),
len(result), len(result),
output_budget,
max_input_tokens,
search_space_id, search_space_id,
) )
return result return result
@ -628,11 +891,15 @@ class SearchKnowledgeBaseInput(BaseModel):
"""Input schema for the search_knowledge_base tool.""" """Input schema for the search_knowledge_base tool."""
query: str = Field( query: str = Field(
description="The search query - be specific and include key terms" description=(
"The search query - use specific natural language terms. "
"NEVER use wildcards like '*' or '**'; instead describe what you want "
"(e.g. 'recent meeting notes' or 'project architecture overview')."
),
) )
top_k: int = Field( top_k: int = Field(
default=10, default=10,
description="Number of results to retrieve (default: 10)", description="Number of results to retrieve (default: 10). Keep ≤20 for focused searches.",
) )
start_date: str | None = Field( start_date: str | None = Field(
default=None, default=None,
@ -695,6 +962,10 @@ Focus searches on these types for best results."""
Use this tool to find documents, notes, files, web pages, and other content that may help answer the user's question. Use this tool to find documents, notes, files, web pages, and other content that may help answer the user's question.
IMPORTANT: IMPORTANT:
- Always craft specific, descriptive search queries using natural language keywords.
Good: "quarterly sales report Q3", "Python API authentication design".
Bad: "*", "**", "everything", single characters. Wildcard/empty queries yield poor results.
- Prefer multiple focused searches over a single broad one with high top_k.
- If the user requests a specific source type (e.g. "my notes", "Slack messages"), pass `connectors_to_search=[...]` using the enums below. - If the user requests a specific source type (e.g. "my notes", "Slack messages"), pass `connectors_to_search=[...]` using the enums below.
- If `connectors_to_search` is omitted/empty, the system will search broadly. - If `connectors_to_search` is omitted/empty, the system will search broadly.
- Only connectors that are enabled/configured for this search space are available.{doc_types_info} - Only connectors that are enabled/configured for this search space are available.{doc_types_info}
@ -710,6 +981,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
# Capture for closure # Capture for closure
_available_connectors = available_connectors _available_connectors = available_connectors
_available_document_types = available_document_types
async def _search_knowledge_base_impl( async def _search_knowledge_base_impl(
query: str, query: str,
@ -739,6 +1011,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
start_date=parsed_start, start_date=parsed_start,
end_date=parsed_end, end_date=parsed_end,
available_connectors=_available_connectors, available_connectors=_available_connectors,
available_document_types=_available_document_types,
max_input_tokens=max_input_tokens, max_input_tokens=max_input_tokens,
) )

View file

@ -145,10 +145,12 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
thread_id=deps["thread_id"], thread_id=deps["thread_id"],
connector_service=deps.get("connector_service"), connector_service=deps.get("connector_service"),
available_connectors=deps.get("available_connectors"), available_connectors=deps.get("available_connectors"),
available_document_types=deps.get("available_document_types"),
), ),
requires=["search_space_id", "thread_id"], requires=["search_space_id", "thread_id"],
# connector_service and available_connectors are optional — # connector_service, available_connectors, and available_document_types
# when missing, source_strategy="kb_search" degrades gracefully to "provided" # are optional — when missing, source_strategy="kb_search" degrades
# gracefully to "provided"
), ),
# Link preview tool - fetches Open Graph metadata for URLs # Link preview tool - fetches Open Graph metadata for URLs
ToolDefinition( ToolDefinition(

View file

@ -559,6 +559,7 @@ def create_generate_report_tool(
thread_id: int | None = None, thread_id: int | None = None,
connector_service: ConnectorService | None = None, connector_service: ConnectorService | None = None,
available_connectors: list[str] | None = None, available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
): ):
""" """
Factory function to create the generate_report tool with injected dependencies. Factory function to create the generate_report tool with injected dependencies.
@ -838,6 +839,7 @@ def create_generate_report_tool(
connector_service=kb_connector_svc, connector_service=kb_connector_svc,
top_k=10, top_k=10,
available_connectors=available_connectors, available_connectors=available_connectors,
available_document_types=available_document_types,
) )
kb_results = await asyncio.gather( kb_results = await asyncio.gather(

View file

@ -3,6 +3,8 @@ from datetime import datetime
from app.utils.perf import get_perf_logger from app.utils.perf import get_perf_logger
_MAX_FETCH_CHUNKS_PER_DOC = 30
class ChucksHybridSearchRetriever: class ChucksHybridSearchRetriever:
def __init__(self, db_session): def __init__(self, db_session):
@ -346,8 +348,9 @@ class ChucksHybridSearchRetriever:
if not doc_ids: if not doc_ids:
return [] return []
# Fetch ALL chunks for selected documents in a single query so the final prompt can cite # Fetch chunks for selected documents. We cap per document to avoid
# any chunk from those documents. # loading hundreds of chunks for a single large file while still
# ensuring the chunks that matched the RRF query are always included.
chunk_query = ( chunk_query = (
select(Chunk) select(Chunk)
.options(joinedload(Chunk.document)) .options(joinedload(Chunk.document))
@ -357,7 +360,20 @@ class ChucksHybridSearchRetriever:
.order_by(Chunk.document_id, Chunk.id) .order_by(Chunk.document_id, Chunk.id)
) )
chunks_result = await self.db_session.execute(chunk_query) chunks_result = await self.db_session.execute(chunk_query)
all_chunks = chunks_result.scalars().all() raw_chunks = chunks_result.scalars().all()
matched_chunk_ids: set[int] = {
item["chunk_id"] for item in serialized_chunk_results
}
doc_chunk_counts: dict[int, int] = {}
all_chunks: list = []
for chunk in raw_chunks:
did = chunk.document_id
count = doc_chunk_counts.get(did, 0)
if chunk.id in matched_chunk_ids or count < _MAX_FETCH_CHUNKS_PER_DOC:
all_chunks.append(chunk)
doc_chunk_counts[did] = count + 1
# Assemble final doc-grouped results in the same order as doc_ids # Assemble final doc-grouped results in the same order as doc_ids
doc_map: dict[int, dict] = { doc_map: dict[int, dict] = {

View file

@ -3,6 +3,8 @@ from datetime import datetime
from app.utils.perf import get_perf_logger from app.utils.perf import get_perf_logger
_MAX_FETCH_CHUNKS_PER_DOC = 30
class DocumentHybridSearchRetriever: class DocumentHybridSearchRetriever:
def __init__(self, db_session): def __init__(self, db_session):
@ -279,7 +281,8 @@ class DocumentHybridSearchRetriever:
# Collect document IDs for chunk fetching # Collect document IDs for chunk fetching
doc_ids: list[int] = [doc.id for doc, _score in documents_with_scores] doc_ids: list[int] = [doc.id for doc, _score in documents_with_scores]
# Fetch ALL chunks for these documents in a single query # Fetch chunks for these documents, capped per document to avoid
# loading hundreds of chunks for a single large file.
chunks_query = ( chunks_query = (
select(Chunk) select(Chunk)
.options(joinedload(Chunk.document)) .options(joinedload(Chunk.document))
@ -287,7 +290,16 @@ class DocumentHybridSearchRetriever:
.order_by(Chunk.document_id, Chunk.id) .order_by(Chunk.document_id, Chunk.id)
) )
chunks_result = await self.db_session.execute(chunks_query) chunks_result = await self.db_session.execute(chunks_query)
chunks = chunks_result.scalars().all() raw_chunks = chunks_result.scalars().all()
doc_chunk_counts: dict[int, int] = {}
chunks: list = []
for chunk in raw_chunks:
did = chunk.document_id
count = doc_chunk_counts.get(did, 0)
if count < _MAX_FETCH_CHUNKS_PER_DOC:
chunks.append(chunk)
doc_chunk_counts[did] = count + 1
# Assemble doc-grouped results # Assemble doc-grouped results
doc_map: dict[int, dict] = { doc_map: dict[int, dict] = {

View file

@ -224,6 +224,7 @@ class ConnectorService:
top_k: int = 20, top_k: int = 20,
start_date: datetime | None = None, start_date: datetime | None = None,
end_date: datetime | None = None, end_date: datetime | None = None,
query_embedding: list[float] | None = None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """
Perform combined search using both chunk-based and document-based hybrid search, Perform combined search using both chunk-based and document-based hybrid search,
@ -260,7 +261,8 @@ class ConnectorService:
# Get more results from each retriever for better fusion # Get more results from each retriever for better fusion
retriever_top_k = top_k * 2 retriever_top_k = top_k * 2
# Pre-compute the embedding once so both retrievers reuse it. # Reuse caller-provided embedding or compute once for both retrievers.
if query_embedding is None:
t_embed = time.perf_counter() t_embed = time.perf_counter()
query_embedding = config.embedding_model_instance.embed(query_text) query_embedding = config.embedding_model_instance.embed(query_text)
perf.info( perf.info(

View file

@ -159,26 +159,95 @@ class LLMRouterService:
# Merge with provided settings # Merge with provided settings
final_settings = {**default_settings, **instance._router_settings} final_settings = {**default_settings, **instance._router_settings}
# Build an "auto-large" fallback group with deployments whose context
# window exceeds the smallest deployment. This lets the router
# automatically fall back to a bigger-context model when gpt-4o (128K)
# hits ContextWindowExceededError.
full_model_list, ctx_fallbacks = cls._build_context_fallback_groups(model_list)
try: try:
instance._router = Router( router_kwargs: dict[str, Any] = {
model_list=model_list, "model_list": full_model_list,
routing_strategy=final_settings.get( "routing_strategy": final_settings.get(
"routing_strategy", "usage-based-routing" "routing_strategy", "usage-based-routing"
), ),
num_retries=final_settings.get("num_retries", 3), "num_retries": final_settings.get("num_retries", 3),
allowed_fails=final_settings.get("allowed_fails", 3), "allowed_fails": final_settings.get("allowed_fails", 3),
cooldown_time=final_settings.get("cooldown_time", 60), "cooldown_time": final_settings.get("cooldown_time", 60),
set_verbose=False, # Disable verbose logging in production "set_verbose": False,
) }
if ctx_fallbacks:
router_kwargs["context_window_fallbacks"] = ctx_fallbacks
instance._router = Router(**router_kwargs)
instance._initialized = True instance._initialized = True
logger.info( logger.info(
f"LLM Router initialized with {len(model_list)} deployments, " "LLM Router initialized with %d deployments, "
f"strategy: {final_settings.get('routing_strategy')}" "strategy: %s, context_window_fallbacks: %s",
len(model_list),
final_settings.get("routing_strategy"),
ctx_fallbacks or "none",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize LLM Router: {e}") logger.error(f"Failed to initialize LLM Router: {e}")
instance._router = None instance._router = None
@classmethod
def _build_context_fallback_groups(
    cls, model_list: list[dict]
) -> tuple[list[dict], list[dict[str, list[str]]] | None]:
    """Derive an ``auto-large`` deployment group for context-window fallbacks.

    Each deployment's context window is looked up with
    ``litellm.get_model_info`` using its ``base_model`` (or ``model`` when no
    base model is set). Deployments whose ``max_input_tokens`` is strictly
    larger than the smallest known window are copied under the model name
    ``auto-large``. Deployments whose lookup fails, or that report no positive
    integer window, are never copied.

    The second return value is meant to be handed to the Router as
    ``context_window_fallbacks`` so that requests to ``auto`` can be retried
    on ``auto-large``.

    Args:
        model_list: Deployment dicts; ``litellm_params`` is read from each.

    Returns:
        ``(full_model_list, context_window_fallbacks)``. When no window could
        be determined, or no deployment exceeds the smallest window, this is
        the original ``model_list`` and ``None``. Otherwise it is the original
        entries followed by the ``auto-large`` copies, together with
        ``[{"auto": ["auto-large"]}]``.
    """
    from litellm import get_model_info

    def _lookup_name(deployment: dict) -> str:
        # Prefer the explicit base model; fall back to the routed model name.
        litellm_params = deployment.get("litellm_params", {})
        return litellm_params.get("base_model") or litellm_params.get("model", "")

    window_by_model: dict[str, int] = {}
    for deployment in model_list:
        name = _lookup_name(deployment)
        try:
            window = get_model_info(name).get("max_input_tokens")
            if isinstance(window, int) and window > 0:
                window_by_model[name] = window
        except Exception:
            # Best effort: a model the lookup cannot resolve is left out of
            # the comparison rather than failing router initialisation.
            continue

    if not window_by_model:
        return model_list, None

    smallest_window = min(window_by_model.values())

    # Unknown models default to 0 here, so they can never qualify as "large".
    large_copies: list[dict] = [
        {**deployment, "model_name": "auto-large"}
        for deployment in model_list
        if window_by_model.get(_lookup_name(deployment), 0) > smallest_window
    ]

    if not large_copies:
        return model_list, None

    logger.info(
        "Context-window fallback: %d large-context deployments "
        "(min_ctx=%d) added to 'auto-large' group",
        len(large_copies),
        smallest_window,
    )
    return model_list + large_copies, [{"auto": ["auto-large"]}]
@classmethod @classmethod
def _config_to_deployment(cls, config: dict) -> dict | None: def _config_to_deployment(cls, config: dict) -> dict | None:
""" """