mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-04 13:22:41 +02:00
feat: enhance knowledge base search and document retrieval
- Introduced a mechanism to identify degenerate queries that lack meaningful search signals, improving search accuracy.
- Implemented a fallback that browses recent documents when a query is degenerate, ensuring relevant results are still returned.
- Added limits on the number of chunks fetched per document to optimize performance and prevent excessive data loading.
- Updated the ConnectorService to accept a reusable query embedding, improving search efficiency.
- Enhanced the LLM router service to support context-window fallbacks, improving robustness when a model's context-window limit is hit.
This commit is contained in:
parent
b08e8da40c
commit
40a091f8cc
7 changed files with 476 additions and 100 deletions
|
|
@ -10,6 +10,7 @@ This module provides:
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
|
@ -22,6 +23,149 @@ from app.db import async_session_maker
|
|||
from app.services.connector_service import ConnectorService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
# Live-search connectors hit external APIs directly; they need neither the
# local database nor an embedding, so the available_document_types filter
# never applies to them.
_LIVE_SEARCH_CONNECTORS: set[str] = {
    "TAVILY_API",
    "SEARXNG_API",
    "LINKUP_API",
    "BAIDU_SEARCH_API",
}

# Detects queries with no usable search signal: strings made up solely of
# wildcards, punctuation, or whitespace. For such input,
# plainto_tsquery('english', '*') yields an empty tsquery and the embedding
# of '*' is pure noise, so both keyword and semantic ranking collapse into
# arbitrary order where large documents (many chunks) win by chance.
_DEGENERATE_QUERY_RE = re.compile(r"^[\s*?_.#@!\-/\\]+$")

# Chunk cap per document for the recency-based browse fallback: favour
# breadth (many documents) over depth (many chunks per document).
_BROWSE_MAX_CHUNKS_PER_DOC = 5
|
||||
|
||||
|
||||
def _is_degenerate_query(query: str) -> bool:
    """Return True when *query* carries no meaningful search signal.

    Catches wildcard patterns (``*``, ``**``), empty / whitespace-only
    strings, and single-character non-word tokens. Such queries make both
    keyword search (empty tsquery) and semantic search (meaningless
    embedding) return effectively random results.
    """
    candidate = query.strip()
    if not candidate:
        return True
    # Only wildcards, punctuation, and whitespace — no real search terms.
    return re.match(r"^[\s*?_.#@!\-/\\]+$", candidate) is not None
|
||||
|
||||
|
||||
async def _browse_recent_documents(
    search_space_id: int,
    document_type: str | None,
    top_k: int,
    start_date: datetime | None,
    end_date: datetime | None,
) -> list[dict[str, Any]]:
    """Return the most-recent documents (recency-ordered, no search ranking).

    Used as a fallback when the search query is degenerate (e.g. ``*``) and
    semantic / keyword search would produce arbitrary results. Returns
    document-grouped dicts in the same shape as ``_combined_rrf_search``
    so the rest of the pipeline works unchanged.

    Args:
        search_space_id: Search space whose documents are browsed.
        document_type: Optional ``DocumentType`` name (str) or enum member to
            filter by. An unknown string name yields an empty result rather
            than raising.
        top_k: Maximum number of documents to return.
        start_date: Optional lower bound on ``Document.updated_at``.
        end_date: Optional upper bound on ``Document.updated_at``.

    Returns:
        Up to ``top_k`` document-grouped dicts, newest first, each carrying
        at most ``_BROWSE_MAX_CHUNKS_PER_DOC`` chunks.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import joinedload

    from app.db import Chunk, Document, DocumentType

    perf = get_perf_logger()
    t0 = time.perf_counter()

    base_conditions = [Document.search_space_id == search_space_id]

    if document_type is not None:
        if isinstance(document_type, str):
            # Keep the try body minimal: only the enum lookup can raise.
            try:
                doc_type_enum = DocumentType[document_type]
            except KeyError:
                # Unknown type name: nothing can match, short-circuit.
                return []
            base_conditions.append(Document.document_type == doc_type_enum)
        else:
            base_conditions.append(Document.document_type == document_type)

    if start_date is not None:
        base_conditions.append(Document.updated_at >= start_date)
    if end_date is not None:
        base_conditions.append(Document.updated_at <= end_date)

    async with async_session_maker() as session:
        doc_query = (
            select(Document)
            .options(joinedload(Document.search_space))
            .where(*base_conditions)
            .order_by(Document.updated_at.desc())
            .limit(top_k)
        )
        result = await session.execute(doc_query)
        documents = result.scalars().unique().all()

        if not documents:
            return []

        doc_ids = [d.id for d in documents]

        chunk_query = (
            select(Chunk)
            .where(Chunk.document_id.in_(doc_ids))
            .order_by(Chunk.document_id, Chunk.id)
        )
        chunk_result = await session.execute(chunk_query)
        raw_chunks = chunk_result.scalars().all()

        # Cap chunks per document so one huge document cannot dominate the
        # output. The bucket length doubles as the per-document counter, so
        # no separate count dict is needed.
        doc_chunks: dict[int, list[dict]] = {d.id: [] for d in documents}
        for chunk in raw_chunks:
            bucket = doc_chunks[chunk.document_id]
            if len(bucket) < _BROWSE_MAX_CHUNKS_PER_DOC:
                bucket.append({"chunk_id": chunk.id, "content": chunk.content})

        results: list[dict[str, Any]] = []
        for doc in documents:
            chunks_list = doc_chunks.get(doc.id, [])
            # Hoist the shared document-type expression (used twice below).
            doc_type_value = (
                doc.document_type.value
                if getattr(doc, "document_type", None)
                else None
            )
            results.append(
                {
                    "document_id": doc.id,
                    "content": "\n\n".join(
                        c["content"] for c in chunks_list if c.get("content")
                    ),
                    # Browse results carry no relevance signal, hence 0.0.
                    "score": 0.0,
                    "chunks": chunks_list,
                    "document": {
                        "id": doc.id,
                        "title": doc.title,
                        "document_type": doc_type_value,
                        "metadata": doc.document_metadata or {},
                    },
                    "source": doc_type_value,
                }
            )

        perf.info(
            "[kb_browse] recency browse in %.3fs docs=%d space=%d type=%s",
            time.perf_counter() - t0,
            len(results),
            search_space_id,
            document_type,
        )
        return results
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Connector Constants and Normalization
|
||||
# =============================================================================
|
||||
|
|
@ -184,9 +328,23 @@ _CHARS_PER_TOKEN = 4
|
|||
# Hard-floor / ceiling so the budget is always sensible regardless of what
|
||||
# the model reports.
|
||||
_MIN_TOOL_OUTPUT_CHARS = 20_000 # ~5K tokens
|
||||
_MAX_TOOL_OUTPUT_CHARS = 400_000 # ~100K tokens
|
||||
_MAX_TOOL_OUTPUT_CHARS = 200_000 # ~50K tokens
|
||||
_MAX_CHUNK_CHARS = 8_000
|
||||
|
||||
# Rank-adaptive per-document budget allocation.
|
||||
# Top-ranked (most relevant) documents get a larger share of the budget so
|
||||
# we pack as much high-quality context as possible.
|
||||
#
|
||||
# fraction(rank) = _TOP_DOC_BUDGET_FRACTION / (1 + rank * _RANK_DECAY)
|
||||
#
|
||||
# Examples (128K budget, 8K chunk cap):
|
||||
# rank 0 → 40% → 6 chunks | rank 3 → 19% → 3 chunks
|
||||
# rank 1 → 30% → 4 chunks | rank 10 → 10% → 3 chunks (floor)
|
||||
# rank 2 → 24% → 3 chunks |
|
||||
_TOP_DOC_BUDGET_FRACTION = 0.40
|
||||
_RANK_DECAY = 0.35
|
||||
_MIN_CHUNKS_PER_DOC = 3
|
||||
|
||||
|
||||
def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
|
||||
"""Derive a character budget from the model's context window.
|
||||
|
|
@ -208,18 +366,24 @@ def format_documents_for_context(
|
|||
*,
|
||||
max_chars: int = _MAX_TOOL_OUTPUT_CHARS,
|
||||
max_chunk_chars: int = _MAX_CHUNK_CHARS,
|
||||
max_chunks_per_doc: int = 0,
|
||||
) -> str:
|
||||
"""
|
||||
Format retrieved documents into a readable context string for the LLM.
|
||||
|
||||
Documents are added in order (highest relevance first) until the character
|
||||
budget is reached. Individual chunks are capped at ``max_chunk_chars`` so
|
||||
a single oversized chunk cannot monopolize the output.
|
||||
budget is reached. Individual chunks are capped at ``max_chunk_chars`` and
|
||||
each document is limited to a dynamically computed chunk cap so a single
|
||||
large document cannot monopolize the output while still maximising the use
|
||||
of available context space.
|
||||
|
||||
Args:
|
||||
documents: List of document dictionaries from connector search
|
||||
max_chars: Approximate character budget for the entire output.
|
||||
max_chunk_chars: Per-chunk character cap (content is tail-truncated).
|
||||
max_chunks_per_doc: Maximum chunks per document. ``0`` (default) means
|
||||
auto-compute per document using a rank-adaptive formula so
|
||||
higher-ranked documents receive more chunks.
|
||||
|
||||
Returns:
|
||||
Formatted string with document contents and metadata
|
||||
|
|
@ -342,7 +506,23 @@ def format_documents_for_context(
|
|||
"<document_content>",
|
||||
]
|
||||
|
||||
for ch in g["chunks"]:
|
||||
# Rank-adaptive per-document chunk cap: top results get more chunks.
|
||||
if max_chunks_per_doc > 0:
|
||||
chunks_allowed = max_chunks_per_doc
|
||||
else:
|
||||
doc_fraction = _TOP_DOC_BUDGET_FRACTION / (1 + doc_idx * _RANK_DECAY)
|
||||
max_doc_chars = int(max_chars * doc_fraction)
|
||||
xml_overhead = 500
|
||||
chunks_allowed = max(
|
||||
(max_doc_chars - xml_overhead) // max(max_chunk_chars, 1),
|
||||
_MIN_CHUNKS_PER_DOC,
|
||||
)
|
||||
|
||||
chunks = g["chunks"]
|
||||
if len(chunks) > chunks_allowed:
|
||||
chunks = chunks[:chunks_allowed]
|
||||
|
||||
for ch in chunks:
|
||||
ch_content = ch["content"]
|
||||
if max_chunk_chars and len(ch_content) > max_chunk_chars:
|
||||
ch_content = ch_content[:max_chunk_chars] + "\n...(truncated)"
|
||||
|
|
@ -359,9 +539,11 @@ def format_documents_for_context(
|
|||
doc_xml = "\n".join(doc_lines)
|
||||
doc_len = len(doc_xml)
|
||||
|
||||
# Always include at least the first document; afterwards enforce budget.
|
||||
if doc_idx > 0 and total_chars + doc_len > max_chars:
|
||||
if total_chars + doc_len > max_chars:
|
||||
remaining = total_docs - doc_idx
|
||||
if doc_idx == 0:
|
||||
parts.append(doc_xml)
|
||||
total_chars += doc_len
|
||||
parts.append(
|
||||
f"<!-- Output truncated: {remaining} more document(s) omitted "
|
||||
f"(budget {max_chars} chars). Refine your query or reduce top_k "
|
||||
|
|
@ -372,7 +554,15 @@ def format_documents_for_context(
|
|||
parts.append(doc_xml)
|
||||
total_chars += doc_len
|
||||
|
||||
return "\n".join(parts).strip()
|
||||
result = "\n".join(parts).strip()
|
||||
|
||||
# Hard safety net: if the result is still over budget (e.g. a single massive
|
||||
# first document), forcibly truncate with a closing comment.
|
||||
if len(result) > max_chars:
|
||||
truncation_msg = "\n<!-- ...output forcibly truncated to fit context window -->"
|
||||
result = result[: max_chars - len(truncation_msg)] + truncation_msg
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -390,6 +580,7 @@ async def search_knowledge_base_async(
|
|||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
available_connectors: list[str] | None = None,
|
||||
available_document_types: list[str] | None = None,
|
||||
max_input_tokens: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
|
|
@ -408,6 +599,9 @@ async def search_knowledge_base_async(
|
|||
end_date: Optional end datetime (UTC) for filtering documents
|
||||
available_connectors: Optional list of connectors actually available in the search space.
|
||||
If provided, only these connectors will be searched.
|
||||
available_document_types: Optional list of document types that actually have indexed
|
||||
data. When provided, local connectors whose document type is
|
||||
absent are skipped entirely (no embedding / DB round-trip).
|
||||
max_input_tokens: Model context window size (tokens). Used to dynamically
|
||||
size the output so it fits within the model's limits.
|
||||
|
||||
|
|
@ -428,6 +622,23 @@ async def search_knowledge_base_async(
|
|||
)
|
||||
|
||||
connectors = _normalize_connectors(connectors_to_search, available_connectors)
|
||||
|
||||
# --- Optimization 1: skip local connectors that have zero indexed documents ---
|
||||
if available_document_types:
|
||||
doc_types_set = set(available_document_types)
|
||||
before_count = len(connectors)
|
||||
connectors = [
|
||||
c for c in connectors if c in _LIVE_SEARCH_CONNECTORS or c in doc_types_set
|
||||
]
|
||||
skipped = before_count - len(connectors)
|
||||
if skipped:
|
||||
perf.info(
|
||||
"[kb_search] skipped %d empty connectors (had %d, now %d)",
|
||||
skipped,
|
||||
before_count,
|
||||
len(connectors),
|
||||
)
|
||||
|
||||
perf.info(
|
||||
"[kb_search] searching %d connectors: %s (space=%d, top_k=%d)",
|
||||
len(connectors),
|
||||
|
|
@ -436,81 +647,126 @@ async def search_knowledge_base_async(
|
|||
top_k,
|
||||
)
|
||||
|
||||
connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
|
||||
"YOUTUBE_VIDEO": ("search_youtube", True, True, {}),
|
||||
"EXTENSION": ("search_extension", True, True, {}),
|
||||
"CRAWLED_URL": ("search_crawled_urls", True, True, {}),
|
||||
"FILE": ("search_files", True, True, {}),
|
||||
"SLACK_CONNECTOR": ("search_slack", True, True, {}),
|
||||
"TEAMS_CONNECTOR": ("search_teams", True, True, {}),
|
||||
"NOTION_CONNECTOR": ("search_notion", True, True, {}),
|
||||
"GITHUB_CONNECTOR": ("search_github", True, True, {}),
|
||||
"LINEAR_CONNECTOR": ("search_linear", True, True, {}),
|
||||
# --- Fast-path: degenerate queries (*, **, empty, etc.) ---
|
||||
# Semantic embedding of '*' is noise and plainto_tsquery('english', '*')
|
||||
# yields an empty tsquery, so both retrieval signals are useless.
|
||||
# Fall back to a recency-ordered browse that returns diverse results.
|
||||
if _is_degenerate_query(query):
|
||||
perf.info(
|
||||
"[kb_search] degenerate query %r detected - falling back to recency browse",
|
||||
query,
|
||||
)
|
||||
local_connectors = [c for c in connectors if c not in _LIVE_SEARCH_CONNECTORS]
|
||||
if not local_connectors:
|
||||
local_connectors = [None] # type: ignore[list-item]
|
||||
|
||||
browse_results = await asyncio.gather(
|
||||
*[
|
||||
_browse_recent_documents(
|
||||
search_space_id=search_space_id,
|
||||
document_type=c,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
for c in local_connectors
|
||||
]
|
||||
)
|
||||
for docs in browse_results:
|
||||
all_documents.extend(docs)
|
||||
|
||||
# Skip dedup + formatting below (browse already returns unique docs)
|
||||
# but still cap output budget.
|
||||
output_budget = _compute_tool_output_budget(max_input_tokens)
|
||||
result = format_documents_for_context(
|
||||
all_documents,
|
||||
max_chars=output_budget,
|
||||
max_chunks_per_doc=_BROWSE_MAX_CHUNKS_PER_DOC,
|
||||
)
|
||||
perf.info(
|
||||
"[kb_search] TOTAL (browse) in %.3fs total_docs=%d output_chars=%d "
|
||||
"budget=%d space=%d",
|
||||
time.perf_counter() - t0,
|
||||
len(all_documents),
|
||||
len(result),
|
||||
output_budget,
|
||||
search_space_id,
|
||||
)
|
||||
return result
|
||||
|
||||
# Specs for live-search connectors (external APIs, no local DB/embedding).
|
||||
live_connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
|
||||
"TAVILY_API": ("search_tavily", False, True, {}),
|
||||
"SEARXNG_API": ("search_searxng", False, True, {}),
|
||||
"LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}),
|
||||
"BAIDU_SEARCH_API": ("search_baidu", False, True, {}),
|
||||
"DISCORD_CONNECTOR": ("search_discord", True, True, {}),
|
||||
"JIRA_CONNECTOR": ("search_jira", True, True, {}),
|
||||
"GOOGLE_CALENDAR_CONNECTOR": ("search_google_calendar", True, True, {}),
|
||||
"AIRTABLE_CONNECTOR": ("search_airtable", True, True, {}),
|
||||
"GOOGLE_GMAIL_CONNECTOR": ("search_google_gmail", True, True, {}),
|
||||
"GOOGLE_DRIVE_FILE": ("search_google_drive", True, True, {}),
|
||||
"CONFLUENCE_CONNECTOR": ("search_confluence", True, True, {}),
|
||||
"CLICKUP_CONNECTOR": ("search_clickup", True, True, {}),
|
||||
"LUMA_CONNECTOR": ("search_luma", True, True, {}),
|
||||
"ELASTICSEARCH_CONNECTOR": ("search_elasticsearch", True, True, {}),
|
||||
"NOTE": ("search_notes", True, True, {}),
|
||||
"BOOKSTACK_CONNECTOR": ("search_bookstack", True, True, {}),
|
||||
"CIRCLEBACK": ("search_circleback", True, True, {}),
|
||||
"OBSIDIAN_CONNECTOR": ("search_obsidian", True, True, {}),
|
||||
# Composio connectors
|
||||
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": (
|
||||
"search_composio_google_drive",
|
||||
True,
|
||||
True,
|
||||
{},
|
||||
),
|
||||
"COMPOSIO_GMAIL_CONNECTOR": ("search_composio_gmail", True, True, {}),
|
||||
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": (
|
||||
"search_composio_google_calendar",
|
||||
True,
|
||||
True,
|
||||
{},
|
||||
),
|
||||
}
|
||||
|
||||
# Keep a conservative cap to avoid overloading DB/external services.
|
||||
# --- Optimization 2: compute the query embedding once, share across all local searches ---
|
||||
precomputed_embedding: list[float] | None = None
|
||||
has_local_connectors = any(c not in _LIVE_SEARCH_CONNECTORS for c in connectors)
|
||||
if has_local_connectors:
|
||||
from app.config import config as app_config
|
||||
|
||||
t_embed = time.perf_counter()
|
||||
precomputed_embedding = app_config.embedding_model_instance.embed(query)
|
||||
perf.info(
|
||||
"[kb_search] shared embedding computed in %.3fs",
|
||||
time.perf_counter() - t_embed,
|
||||
)
|
||||
|
||||
max_parallel_searches = 4
|
||||
semaphore = asyncio.Semaphore(max_parallel_searches)
|
||||
|
||||
async def _search_one_connector(connector: str) -> list[dict[str, Any]]:
|
||||
spec = connector_specs.get(connector)
|
||||
if spec is None:
|
||||
return []
|
||||
is_live = connector in _LIVE_SEARCH_CONNECTORS
|
||||
|
||||
method_name, includes_date_range, includes_top_k, extra_kwargs = spec
|
||||
kwargs: dict[str, Any] = {
|
||||
"user_query": query,
|
||||
"search_space_id": search_space_id,
|
||||
**extra_kwargs,
|
||||
}
|
||||
if includes_top_k:
|
||||
kwargs["top_k"] = top_k
|
||||
if includes_date_range:
|
||||
kwargs["start_date"] = resolved_start_date
|
||||
kwargs["end_date"] = resolved_end_date
|
||||
if is_live:
|
||||
spec = live_connector_specs.get(connector)
|
||||
if spec is None:
|
||||
return []
|
||||
method_name, includes_date_range, includes_top_k, extra_kwargs = spec
|
||||
kwargs: dict[str, Any] = {
|
||||
"user_query": query,
|
||||
"search_space_id": search_space_id,
|
||||
**extra_kwargs,
|
||||
}
|
||||
if includes_top_k:
|
||||
kwargs["top_k"] = top_k
|
||||
if includes_date_range:
|
||||
kwargs["start_date"] = resolved_start_date
|
||||
kwargs["end_date"] = resolved_end_date
|
||||
|
||||
try:
|
||||
t_conn = time.perf_counter()
|
||||
async with semaphore, async_session_maker() as isolated_session:
|
||||
svc = ConnectorService(isolated_session, search_space_id)
|
||||
_, chunks = await getattr(svc, method_name)(**kwargs)
|
||||
perf.info(
|
||||
"[kb_search] connector=%s results=%d in %.3fs",
|
||||
connector,
|
||||
len(chunks),
|
||||
time.perf_counter() - t_conn,
|
||||
)
|
||||
return chunks
|
||||
except Exception as e:
|
||||
perf.warning("[kb_search] connector=%s FAILED: %s", connector, e)
|
||||
return []
|
||||
|
||||
# --- Optimization 3: call _combined_rrf_search directly with shared embedding ---
|
||||
try:
|
||||
# Use isolated session per connector. Shared AsyncSession cannot safely
|
||||
# run concurrent DB operations.
|
||||
t_conn = time.perf_counter()
|
||||
async with semaphore, async_session_maker() as isolated_session:
|
||||
isolated_connector_service = ConnectorService(
|
||||
isolated_session, search_space_id
|
||||
svc = ConnectorService(isolated_session, search_space_id)
|
||||
chunks = await svc._combined_rrf_search(
|
||||
query_text=query,
|
||||
search_space_id=search_space_id,
|
||||
document_type=connector,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
query_embedding=precomputed_embedding,
|
||||
)
|
||||
connector_method = getattr(isolated_connector_service, method_name)
|
||||
_, chunks = await connector_method(**kwargs)
|
||||
perf.info(
|
||||
"[kb_search] connector=%s results=%d in %.3fs",
|
||||
connector,
|
||||
|
|
@ -519,12 +775,7 @@ async def search_knowledge_base_async(
|
|||
)
|
||||
return chunks
|
||||
except Exception as e:
|
||||
perf.warning(
|
||||
"[kb_search] connector=%s FAILED in %.3fs: %s",
|
||||
connector,
|
||||
time.perf_counter() - t_conn,
|
||||
e,
|
||||
)
|
||||
perf.warning("[kb_search] connector=%s FAILED: %s", connector, e)
|
||||
return []
|
||||
|
||||
t_gather = time.perf_counter()
|
||||
|
|
@ -582,12 +833,24 @@ async def search_knowledge_base_async(
|
|||
|
||||
output_budget = _compute_tool_output_budget(max_input_tokens)
|
||||
result = format_documents_for_context(deduplicated, max_chars=output_budget)
|
||||
|
||||
if len(result) > output_budget:
|
||||
perf.warning(
|
||||
"[kb_search] output STILL exceeds budget after format (%d > %d), "
|
||||
"hard truncation should have fired",
|
||||
len(result),
|
||||
output_budget,
|
||||
)
|
||||
|
||||
perf.info(
|
||||
"[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d space=%d",
|
||||
"[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d "
|
||||
"budget=%d max_input_tokens=%s space=%d",
|
||||
time.perf_counter() - t0,
|
||||
len(all_documents),
|
||||
len(deduplicated),
|
||||
len(result),
|
||||
output_budget,
|
||||
max_input_tokens,
|
||||
search_space_id,
|
||||
)
|
||||
return result
|
||||
|
|
@ -628,11 +891,15 @@ class SearchKnowledgeBaseInput(BaseModel):
|
|||
"""Input schema for the search_knowledge_base tool."""
|
||||
|
||||
query: str = Field(
|
||||
description="The search query - be specific and include key terms"
|
||||
description=(
|
||||
"The search query - use specific natural language terms. "
|
||||
"NEVER use wildcards like '*' or '**'; instead describe what you want "
|
||||
"(e.g. 'recent meeting notes' or 'project architecture overview')."
|
||||
),
|
||||
)
|
||||
top_k: int = Field(
|
||||
default=10,
|
||||
description="Number of results to retrieve (default: 10)",
|
||||
description="Number of results to retrieve (default: 10). Keep ≤20 for focused searches.",
|
||||
)
|
||||
start_date: str | None = Field(
|
||||
default=None,
|
||||
|
|
@ -695,6 +962,10 @@ Focus searches on these types for best results."""
|
|||
Use this tool to find documents, notes, files, web pages, and other content that may help answer the user's question.
|
||||
|
||||
IMPORTANT:
|
||||
- Always craft specific, descriptive search queries using natural language keywords.
|
||||
Good: "quarterly sales report Q3", "Python API authentication design".
|
||||
Bad: "*", "**", "everything", single characters. Wildcard/empty queries yield poor results.
|
||||
- Prefer multiple focused searches over a single broad one with high top_k.
|
||||
- If the user requests a specific source type (e.g. "my notes", "Slack messages"), pass `connectors_to_search=[...]` using the enums below.
|
||||
- If `connectors_to_search` is omitted/empty, the system will search broadly.
|
||||
- Only connectors that are enabled/configured for this search space are available.{doc_types_info}
|
||||
|
|
@ -710,6 +981,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
|
|||
|
||||
# Capture for closure
|
||||
_available_connectors = available_connectors
|
||||
_available_document_types = available_document_types
|
||||
|
||||
async def _search_knowledge_base_impl(
|
||||
query: str,
|
||||
|
|
@ -739,6 +1011,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
|
|||
start_date=parsed_start,
|
||||
end_date=parsed_end,
|
||||
available_connectors=_available_connectors,
|
||||
available_document_types=_available_document_types,
|
||||
max_input_tokens=max_input_tokens,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -145,10 +145,12 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
thread_id=deps["thread_id"],
|
||||
connector_service=deps.get("connector_service"),
|
||||
available_connectors=deps.get("available_connectors"),
|
||||
available_document_types=deps.get("available_document_types"),
|
||||
),
|
||||
requires=["search_space_id", "thread_id"],
|
||||
# connector_service and available_connectors are optional —
|
||||
# when missing, source_strategy="kb_search" degrades gracefully to "provided"
|
||||
# connector_service, available_connectors, and available_document_types
|
||||
# are optional — when missing, source_strategy="kb_search" degrades
|
||||
# gracefully to "provided"
|
||||
),
|
||||
# Link preview tool - fetches Open Graph metadata for URLs
|
||||
ToolDefinition(
|
||||
|
|
|
|||
|
|
@ -559,6 +559,7 @@ def create_generate_report_tool(
|
|||
thread_id: int | None = None,
|
||||
connector_service: ConnectorService | None = None,
|
||||
available_connectors: list[str] | None = None,
|
||||
available_document_types: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the generate_report tool with injected dependencies.
|
||||
|
|
@ -838,6 +839,7 @@ def create_generate_report_tool(
|
|||
connector_service=kb_connector_svc,
|
||||
top_k=10,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
)
|
||||
|
||||
kb_results = await asyncio.gather(
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ from datetime import datetime
|
|||
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_MAX_FETCH_CHUNKS_PER_DOC = 30
|
||||
|
||||
|
||||
class ChucksHybridSearchRetriever:
|
||||
def __init__(self, db_session):
|
||||
|
|
@ -346,8 +348,9 @@ class ChucksHybridSearchRetriever:
|
|||
if not doc_ids:
|
||||
return []
|
||||
|
||||
# Fetch ALL chunks for selected documents in a single query so the final prompt can cite
|
||||
# any chunk from those documents.
|
||||
# Fetch chunks for selected documents. We cap per document to avoid
|
||||
# loading hundreds of chunks for a single large file while still
|
||||
# ensuring the chunks that matched the RRF query are always included.
|
||||
chunk_query = (
|
||||
select(Chunk)
|
||||
.options(joinedload(Chunk.document))
|
||||
|
|
@ -357,7 +360,20 @@ class ChucksHybridSearchRetriever:
|
|||
.order_by(Chunk.document_id, Chunk.id)
|
||||
)
|
||||
chunks_result = await self.db_session.execute(chunk_query)
|
||||
all_chunks = chunks_result.scalars().all()
|
||||
raw_chunks = chunks_result.scalars().all()
|
||||
|
||||
matched_chunk_ids: set[int] = {
|
||||
item["chunk_id"] for item in serialized_chunk_results
|
||||
}
|
||||
|
||||
doc_chunk_counts: dict[int, int] = {}
|
||||
all_chunks: list = []
|
||||
for chunk in raw_chunks:
|
||||
did = chunk.document_id
|
||||
count = doc_chunk_counts.get(did, 0)
|
||||
if chunk.id in matched_chunk_ids or count < _MAX_FETCH_CHUNKS_PER_DOC:
|
||||
all_chunks.append(chunk)
|
||||
doc_chunk_counts[did] = count + 1
|
||||
|
||||
# Assemble final doc-grouped results in the same order as doc_ids
|
||||
doc_map: dict[int, dict] = {
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ from datetime import datetime
|
|||
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_MAX_FETCH_CHUNKS_PER_DOC = 30
|
||||
|
||||
|
||||
class DocumentHybridSearchRetriever:
|
||||
def __init__(self, db_session):
|
||||
|
|
@ -279,7 +281,8 @@ class DocumentHybridSearchRetriever:
|
|||
# Collect document IDs for chunk fetching
|
||||
doc_ids: list[int] = [doc.id for doc, _score in documents_with_scores]
|
||||
|
||||
# Fetch ALL chunks for these documents in a single query
|
||||
# Fetch chunks for these documents, capped per document to avoid
|
||||
# loading hundreds of chunks for a single large file.
|
||||
chunks_query = (
|
||||
select(Chunk)
|
||||
.options(joinedload(Chunk.document))
|
||||
|
|
@ -287,7 +290,16 @@ class DocumentHybridSearchRetriever:
|
|||
.order_by(Chunk.document_id, Chunk.id)
|
||||
)
|
||||
chunks_result = await self.db_session.execute(chunks_query)
|
||||
chunks = chunks_result.scalars().all()
|
||||
raw_chunks = chunks_result.scalars().all()
|
||||
|
||||
doc_chunk_counts: dict[int, int] = {}
|
||||
chunks: list = []
|
||||
for chunk in raw_chunks:
|
||||
did = chunk.document_id
|
||||
count = doc_chunk_counts.get(did, 0)
|
||||
if count < _MAX_FETCH_CHUNKS_PER_DOC:
|
||||
chunks.append(chunk)
|
||||
doc_chunk_counts[did] = count + 1
|
||||
|
||||
# Assemble doc-grouped results
|
||||
doc_map: dict[int, dict] = {
|
||||
|
|
|
|||
|
|
@ -224,6 +224,7 @@ class ConnectorService:
|
|||
top_k: int = 20,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
query_embedding: list[float] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Perform combined search using both chunk-based and document-based hybrid search,
|
||||
|
|
@ -260,14 +261,15 @@ class ConnectorService:
|
|||
# Get more results from each retriever for better fusion
|
||||
retriever_top_k = top_k * 2
|
||||
|
||||
# Pre-compute the embedding once so both retrievers reuse it.
|
||||
t_embed = time.perf_counter()
|
||||
query_embedding = config.embedding_model_instance.embed(query_text)
|
||||
perf.info(
|
||||
"[connector_svc] _combined_rrf embedding in %.3fs type=%s",
|
||||
time.perf_counter() - t_embed,
|
||||
document_type,
|
||||
)
|
||||
# Reuse caller-provided embedding or compute once for both retrievers.
|
||||
if query_embedding is None:
|
||||
t_embed = time.perf_counter()
|
||||
query_embedding = config.embedding_model_instance.embed(query_text)
|
||||
perf.info(
|
||||
"[connector_svc] _combined_rrf embedding in %.3fs type=%s",
|
||||
time.perf_counter() - t_embed,
|
||||
document_type,
|
||||
)
|
||||
|
||||
search_kwargs = {
|
||||
"query_text": query_text,
|
||||
|
|
|
|||
|
|
@ -159,26 +159,95 @@ class LLMRouterService:
|
|||
# Merge with provided settings
|
||||
final_settings = {**default_settings, **instance._router_settings}
|
||||
|
||||
# Build a "auto-large" fallback group with deployments whose context
|
||||
# window exceeds the smallest deployment. This lets the router
|
||||
# automatically fall back to a bigger-context model when gpt-4o (128K)
|
||||
# hits ContextWindowExceededError.
|
||||
full_model_list, ctx_fallbacks = cls._build_context_fallback_groups(model_list)
|
||||
|
||||
try:
|
||||
instance._router = Router(
|
||||
model_list=model_list,
|
||||
routing_strategy=final_settings.get(
|
||||
router_kwargs: dict[str, Any] = {
|
||||
"model_list": full_model_list,
|
||||
"routing_strategy": final_settings.get(
|
||||
"routing_strategy", "usage-based-routing"
|
||||
),
|
||||
num_retries=final_settings.get("num_retries", 3),
|
||||
allowed_fails=final_settings.get("allowed_fails", 3),
|
||||
cooldown_time=final_settings.get("cooldown_time", 60),
|
||||
set_verbose=False, # Disable verbose logging in production
|
||||
)
|
||||
"num_retries": final_settings.get("num_retries", 3),
|
||||
"allowed_fails": final_settings.get("allowed_fails", 3),
|
||||
"cooldown_time": final_settings.get("cooldown_time", 60),
|
||||
"set_verbose": False,
|
||||
}
|
||||
if ctx_fallbacks:
|
||||
router_kwargs["context_window_fallbacks"] = ctx_fallbacks
|
||||
|
||||
instance._router = Router(**router_kwargs)
|
||||
instance._initialized = True
|
||||
logger.info(
|
||||
f"LLM Router initialized with {len(model_list)} deployments, "
|
||||
f"strategy: {final_settings.get('routing_strategy')}"
|
||||
"LLM Router initialized with %d deployments, "
|
||||
"strategy: %s, context_window_fallbacks: %s",
|
||||
len(model_list),
|
||||
final_settings.get("routing_strategy"),
|
||||
ctx_fallbacks or "none",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize LLM Router: {e}")
|
||||
instance._router = None
|
||||
|
||||
@classmethod
def _build_context_fallback_groups(
    cls, model_list: list[dict]
) -> tuple[list[dict], list[dict[str, list[str]]] | None]:
    """Create an ``auto-large`` model group for context-window fallbacks.

    Uses ``litellm.get_model_info`` to discover each deployment's context
    window. Every deployment whose ``max_input_tokens`` is strictly larger
    than the smallest discovered window is duplicated into an
    ``auto-large`` group, and the returned fallback config instructs the
    Router to retry ``auto`` requests on ``auto-large`` when a
    ``ContextWindowExceededError`` occurs.

    Returns:
        ``(full_model_list, context_window_fallbacks)`` where
        ``full_model_list`` is the original entries plus any ``auto-large``
        duplicates, and ``context_window_fallbacks`` is ``None`` when no
        deployment has a larger window than the minimum (no useful
        fallback exists).
    """
    from litellm import get_model_info

    def _base_model_of(deployment: dict) -> str:
        # Prefer the explicit base_model hint; fall back to the model id.
        params = deployment.get("litellm_params", {})
        return params.get("base_model") or params.get("model", "")

    # Discover context windows; deployments litellm does not know about
    # are simply left out of the map.
    ctx_map: dict[str, int] = {}
    for deployment in model_list:
        name = _base_model_of(deployment)
        try:
            window = get_model_info(name).get("max_input_tokens")
        except Exception:
            continue
        if isinstance(window, int) and window > 0:
            ctx_map[name] = window

    if not ctx_map:
        return model_list, None

    min_ctx = min(ctx_map.values())

    # Duplicate every deployment with a strictly larger window into the
    # 'auto-large' group; unknown windows default to 0 and are excluded.
    large_deployments: list[dict] = [
        {**deployment, "model_name": "auto-large"}
        for deployment in model_list
        if ctx_map.get(_base_model_of(deployment), 0) > min_ctx
    ]

    if not large_deployments:
        return model_list, None

    logger.info(
        "Context-window fallback: %d large-context deployments "
        "(min_ctx=%d) added to 'auto-large' group",
        len(large_deployments),
        min_ctx,
    )
    return model_list + large_deployments, [{"auto": ["auto-large"]}]
|
||||
|
||||
@classmethod
|
||||
def _config_to_deployment(cls, config: dict) -> dict | None:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue