mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): colocate KB-search tool with its sole consumer; fix report ImportError
shared/tools/knowledge_base.py had exactly one production consumer: the report deliverable, which imported it via `from .knowledge_base import ...` -- a sibling path that did not exist, so the report KB-search path would raise ImportError at runtime. Move the module next to report.py (subagents/builtins/deliverables/tools/) which makes that relative import valid, and move its only dependency (shared/utils.py date helpers) to multi_agent_chat/shared/date_filters.py, shared between the KB tool and the knowledge_search middleware. Drop the now-unused knowledge-base re-exports from the shared/tools barrel and repoint the integration tests. import-all + error-contract stay green.
This commit is contained in:
parent
a7d7155039
commit
2db4ad479e
6 changed files with 12 additions and 17 deletions
|
|
@ -1,24 +1,14 @@
|
|||
"""Cross-agent shared tools and tool metadata.
|
||||
|
||||
Tool *implementations* live with the agents that own them (e.g. deliverable
|
||||
generators under ``subagents/builtins/deliverables/tools``). This package
|
||||
holds only the genuinely shared pieces: the display-metadata catalog and the
|
||||
knowledge-base helpers used across agents.
|
||||
generators and their knowledge-base search helper under
|
||||
``subagents/builtins/deliverables/tools``). This package holds only the
|
||||
genuinely shared piece: the display-metadata catalog.
|
||||
"""
|
||||
|
||||
from .catalog import TOOL_CATALOG, ToolMetadata
|
||||
from .knowledge_base import (
|
||||
CONNECTOR_DESCRIPTIONS,
|
||||
format_documents_for_context,
|
||||
search_knowledge_base_async,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Tool catalog (display metadata)
|
||||
"TOOL_CATALOG",
|
||||
"ToolMetadata",
|
||||
# Knowledge base utilities
|
||||
"CONNECTOR_DESCRIPTIONS",
|
||||
"format_documents_for_context",
|
||||
"search_knowledge_base_async",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,817 +0,0 @@
|
|||
"""
|
||||
Knowledge base search tool for the SurfSense agent.
|
||||
|
||||
This module provides:
|
||||
- Connector constants and normalization
|
||||
- Async knowledge base search across multiple connectors
|
||||
- Document formatting for LLM context
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import NATIVE_TO_LEGACY_DOCTYPE, shielded_async_session
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
# Connector types backed by external live-search APIs. The ``web_search``
# tool owns these, so they must never take part in a knowledge-base search.
_LIVE_SEARCH_CONNECTORS: set[str] = {
    "TAVILY_API",
    "LINKUP_API",
    "BAIDU_SEARCH_API",
}

# Recognises queries that carry no usable search signal: the whole string is
# made of wildcards, the listed punctuation characters and whitespace.
# Per the original author's note, such input yields an empty tsquery for
# keyword search and a noise embedding for semantic search, so ranking becomes
# arbitrary and documents with many chunks win by chance.
_DEGENERATE_QUERY_RE = re.compile(
    r"^[\s*?_.#@!\-/\\]+$"  # only wildcards, punctuation, whitespace
)

# Per-document chunk cap used by the recency browse that replaces a real
# search for degenerate queries: breadth (many documents) beats depth.
_BROWSE_MAX_CHUNKS_PER_DOC = 5


def _is_degenerate_query(query: str) -> bool:
    """Tell whether *query* is too empty to search with.

    A query is degenerate when, after trimming surrounding whitespace, it is
    either empty or consists only of characters accepted by
    ``_DEGENERATE_QUERY_RE`` (``*``, ``?``, ``_``, ``.``, ``#``, ``@``, ``!``,
    ``-``, ``/``, ``\\`` and whitespace).

    Args:
        query: Raw query text supplied by the caller.

    Returns:
        ``True`` for a degenerate query, ``False`` otherwise.
    """
    text = query.strip()
    # An empty string short-circuits; otherwise the regex decides.
    return not text or _DEGENERATE_QUERY_RE.match(text) is not None
|
||||
|
||||
|
||||
async def _browse_recent_documents(
    search_space_id: int,
    document_type: str | list[str] | None,
    top_k: int,
    start_date: datetime | None,
    end_date: datetime | None,
) -> list[dict[str, Any]]:
    """List the most recently updated documents of a search space.

    This is the recency-ordered fallback used when the query is degenerate
    (e.g. ``*``) and ranking would be meaningless. Each returned dict carries
    ``document_id``, ``content``, ``score`` (always ``0.0``), ``chunks``,
    ``document`` and ``source`` so downstream code can treat it like a
    ranked search hit.

    Args:
        search_space_id: Search space whose documents are listed.
        document_type: One type, a list of types, or ``None`` for no type
            filter. String entries are looked up by name in ``DocumentType``;
            names that do not exist are ignored.
        top_k: Maximum number of documents to return.
        start_date: Lower bound (inclusive) on ``Document.updated_at``.
        end_date: Upper bound (inclusive) on ``Document.updated_at``.

    Returns:
        Document dicts ordered by ``updated_at`` descending. Empty when no
        requested type resolves or no document matches.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import joinedload

    from app.db import Chunk, Document, DocumentType

    perf = get_perf_logger()
    started = time.perf_counter()

    filters = [Document.search_space_id == search_space_id]

    if document_type is not None:
        requested = (
            document_type if isinstance(document_type, list) else [document_type]
        )
        wanted_types = []
        for entry in requested:
            if not isinstance(entry, str):
                # Non-string entries are passed through unchanged.
                wanted_types.append(entry)
                continue
            # Unknown enum names are skipped instead of raising.
            with contextlib.suppress(KeyError):
                wanted_types.append(DocumentType[entry])

        if not wanted_types:
            return []
        if len(wanted_types) == 1:
            filters.append(Document.document_type == wanted_types[0])
        else:
            filters.append(Document.document_type.in_(wanted_types))

    if start_date is not None:
        filters.append(Document.updated_at >= start_date)
    if end_date is not None:
        filters.append(Document.updated_at <= end_date)

    async with shielded_async_session() as session:
        newest_first = (
            select(Document)
            .options(joinedload(Document.search_space))
            .where(*filters)
            .order_by(Document.updated_at.desc())
            .limit(top_k)
        )
        documents = (await session.execute(newest_first)).scalars().unique().all()

        if not documents:
            return []

        # Fetch the chunks of every selected document in one round-trip,
        # ordered so each document's chunks arrive in id order.
        chunks_stmt = (
            select(Chunk)
            .where(Chunk.document_id.in_([d.id for d in documents]))
            .order_by(Chunk.document_id, Chunk.id)
        )
        raw_chunks = (await session.execute(chunks_stmt)).scalars().all()

    # Keep only the first _BROWSE_MAX_CHUNKS_PER_DOC chunks of each document.
    chunks_by_doc: dict[int, list[dict]] = {d.id: [] for d in documents}
    for chunk in raw_chunks:
        kept = chunks_by_doc[chunk.document_id]
        if len(kept) < _BROWSE_MAX_CHUNKS_PER_DOC:
            kept.append({"chunk_id": chunk.id, "content": chunk.content})

    results: list[dict[str, Any]] = []
    for doc in documents:
        doc_chunks = chunks_by_doc.get(doc.id, [])
        type_value = (
            doc.document_type.value if getattr(doc, "document_type", None) else None
        )
        joined_content = "\n\n".join(
            c["content"] for c in doc_chunks if c.get("content")
        )
        results.append(
            {
                "document_id": doc.id,
                "content": joined_content,
                "score": 0.0,
                "chunks": doc_chunks,
                "document": {
                    "id": doc.id,
                    "title": doc.title,
                    "document_type": type_value,
                    "metadata": doc.document_metadata or {},
                },
                "source": type_value,
            }
        )

    perf.info(
        "[kb_browse] recency browse in %.3fs docs=%d space=%d type=%s",
        time.perf_counter() - started,
        len(results),
        search_space_id,
        document_type,
    )
    return results
|
||||
|
||||
|
||||
# =============================================================================
# Connector Constants and Normalization
# =============================================================================

# Canonical connector / document-type names accepted by the knowledge-base
# search. ``_normalize_connectors`` drops any requested value that is not in
# this list. Live-search connectors (see ``_LIVE_SEARCH_CONNECTORS``) are
# intentionally absent.
_ALL_CONNECTORS: list[str] = [
    "EXTENSION",
    "FILE",
    "SLACK_CONNECTOR",
    "TEAMS_CONNECTOR",
    "NOTION_CONNECTOR",
    "YOUTUBE_VIDEO",
    "GITHUB_CONNECTOR",
    "ELASTICSEARCH_CONNECTOR",
    "LINEAR_CONNECTOR",
    "JIRA_CONNECTOR",
    "CONFLUENCE_CONNECTOR",
    "CLICKUP_CONNECTOR",
    "GOOGLE_CALENDAR_CONNECTOR",
    "GOOGLE_GMAIL_CONNECTOR",
    "GOOGLE_DRIVE_FILE",
    "DISCORD_CONNECTOR",
    "AIRTABLE_CONNECTOR",
    "LUMA_CONNECTOR",
    "NOTE",
    "BOOKSTACK_CONNECTOR",
    "CRAWLED_URL",
    "CIRCLEBACK",
    "OBSIDIAN_CONNECTOR",
    "ONEDRIVE_FILE",
    "DROPBOX_FILE",
]

# Human-readable description of each connector type, intended for dynamic
# docstrings / prompts shown to the LLM.
# NOTE: ``WEBCRAWLER_CONNECTOR`` is a user-facing alias that
# ``_normalize_connectors`` maps to ``CRAWLED_URL``; it is described here but
# is not part of ``_ALL_CONNECTORS``.
CONNECTOR_DESCRIPTIONS: dict[str, str] = {
    "EXTENSION": "Web content saved via SurfSense browser extension (personal browsing history)",
    "FILE": "User-uploaded documents (PDFs, Word, etc.) (personal files)",
    "NOTE": "SurfSense Notes (notes created inside SurfSense)",
    "SLACK_CONNECTOR": "Slack conversations and shared content (personal workspace communications)",
    "TEAMS_CONNECTOR": "Microsoft Teams messages and conversations (personal Teams communications)",
    "NOTION_CONNECTOR": "Notion workspace pages and databases (personal knowledge management)",
    "YOUTUBE_VIDEO": "YouTube video transcripts and metadata (personally saved videos)",
    "GITHUB_CONNECTOR": "GitHub repository content and issues (personal repositories and interactions)",
    "ELASTICSEARCH_CONNECTOR": "Elasticsearch indexed documents and data (personal Elasticsearch instances)",
    "LINEAR_CONNECTOR": "Linear project issues and discussions (personal project management)",
    "JIRA_CONNECTOR": "Jira project issues, tickets, and comments (personal project tracking)",
    "CONFLUENCE_CONNECTOR": "Confluence pages and comments (personal project documentation)",
    "CLICKUP_CONNECTOR": "ClickUp tasks and project data (personal task management)",
    "GOOGLE_CALENDAR_CONNECTOR": "Google Calendar events, meetings, and schedules (personal calendar)",
    "GOOGLE_GMAIL_CONNECTOR": "Google Gmail emails and conversations (personal emails)",
    "GOOGLE_DRIVE_FILE": "Google Drive files and documents (personal cloud storage)",
    "DISCORD_CONNECTOR": "Discord server conversations and shared content (personal community)",
    "AIRTABLE_CONNECTOR": "Airtable records, tables, and database content (personal data)",
    "LUMA_CONNECTOR": "Luma events and meetings",
    "WEBCRAWLER_CONNECTOR": "Webpages indexed by SurfSense (personally selected websites)",
    "CRAWLED_URL": "Webpages indexed by SurfSense (personally selected websites)",
    "BOOKSTACK_CONNECTOR": "BookStack pages (personal documentation)",
    "CIRCLEBACK": "Circleback meeting notes, transcripts, and action items",
    "OBSIDIAN_CONNECTOR": "Obsidian vault notes and markdown files (personal notes)",
    "ONEDRIVE_FILE": "Microsoft OneDrive files and documents (personal cloud storage)",
    "DROPBOX_FILE": "Dropbox files and documents (cloud storage)",
}
|
||||
|
||||
|
||||
def _normalize_connectors(
    connectors_to_search: list[str] | None,
    available_connectors: list[str] | None = None,
) -> list[str]:
    """Turn a model-supplied connector list into canonical connector names.

    Rules applied, in order:

    - Each entry is trimmed and upper-cased; blank entries are ignored.
    - The alias ``WEBCRAWLER_CONNECTOR`` becomes ``CRAWLED_URL``.
    - Entries not in ``_ALL_CONNECTORS`` are dropped.
    - When ``available_connectors`` is given, entries outside it are dropped.
    - Live-search connectors (``_LIVE_SEARCH_CONNECTORS``) are never returned.
    - Duplicates are removed, keeping first-seen order.
    - If nothing was requested, or nothing survives filtering, the result is
      the default pool: ``available_connectors`` when given, otherwise
      ``_ALL_CONNECTORS``, minus live-search connectors.

    Args:
        connectors_to_search: Connector names requested by the model.
        available_connectors: Connectors that exist in the search space.

    Returns:
        Connector names to search.
    """

    def default_pool() -> list[str]:
        # Everything the caller could search when no usable request was made.
        pool = available_connectors if available_connectors else _ALL_CONNECTORS
        return [name for name in pool if name not in _LIVE_SEARCH_CONNECTORS]

    allowed = (
        set(available_connectors) if available_connectors else set(_ALL_CONNECTORS)
    )
    allowed -= _LIVE_SEARCH_CONNECTORS

    if not connectors_to_search:
        return default_pool()

    # Canonicalise every requested name first (alias mapping included).
    canonical: list[str] = []
    for raw in connectors_to_search:
        name = (raw or "").strip().upper()
        if name:
            canonical.append("CRAWLED_URL" if name == "WEBCRAWLER_CONNECTOR" else name)

    # A dict keeps insertion order, giving an order-preserving de-duplication
    # of the names that are both known and allowed.
    selected: dict[str, None] = {}
    for name in canonical:
        if name in _ALL_CONNECTORS and name in allowed:
            selected.setdefault(name, None)

    return list(selected) if selected else default_pool()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Document Formatting
|
||||
# =============================================================================
|
||||
|
||||
|
||||
# Fraction of the model's context window (in characters) that a single tool
|
||||
# result is allowed to occupy. The remainder is reserved for system prompt,
|
||||
# conversation history, and model output. With ~4 chars/token this gives a
|
||||
# tool result ≈ 25 % of the context budget in tokens.
|
||||
_TOOL_OUTPUT_CONTEXT_FRACTION = 0.25
|
||||
_CHARS_PER_TOKEN = 4
|
||||
|
||||
# Hard-floor / ceiling so the budget is always sensible regardless of what
|
||||
# the model reports.
|
||||
_MIN_TOOL_OUTPUT_CHARS = 20_000 # ~5K tokens
|
||||
_MAX_TOOL_OUTPUT_CHARS = 200_000 # ~50K tokens
|
||||
_MAX_CHUNK_CHARS = 8_000
|
||||
|
||||
# Rank-adaptive per-document budget allocation.
|
||||
# Top-ranked (most relevant) documents get a larger share of the budget so
|
||||
# we pack as much high-quality context as possible.
|
||||
#
|
||||
# fraction(rank) = _TOP_DOC_BUDGET_FRACTION / (1 + rank * _RANK_DECAY)
|
||||
#
|
||||
# Examples (128K budget, 8K chunk cap):
|
||||
# rank 0 → 40% → 6 chunks | rank 3 → 19% → 3 chunks
|
||||
# rank 1 → 30% → 4 chunks | rank 10 → 10% → 3 chunks (floor)
|
||||
# rank 2 → 24% → 3 chunks |
|
||||
_TOP_DOC_BUDGET_FRACTION = 0.40
|
||||
_RANK_DECAY = 0.35
|
||||
_MIN_CHUNKS_PER_DOC = 3
|
||||
|
||||
|
||||
def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
|
||||
"""Derive a character budget from the model's context window.
|
||||
|
||||
Uses ``litellm.get_model_info`` via the value already resolved by
|
||||
``ChatLiteLLMRouter`` / ``ChatLiteLLM`` and passed through the dependency
|
||||
chain as ``max_input_tokens``. Falls back to a conservative default when
|
||||
the value is unavailable.
|
||||
"""
|
||||
if max_input_tokens is None or max_input_tokens <= 0:
|
||||
return _MIN_TOOL_OUTPUT_CHARS # conservative fallback
|
||||
|
||||
budget = int(max_input_tokens * _CHARS_PER_TOKEN * _TOOL_OUTPUT_CONTEXT_FRACTION)
|
||||
return max(_MIN_TOOL_OUTPUT_CHARS, min(budget, _MAX_TOOL_OUTPUT_CHARS))
|
||||
|
||||
|
||||
_INTERNAL_METADATA_KEYS: frozenset[str] = frozenset(
|
||||
{
|
||||
"message_id",
|
||||
"thread_id",
|
||||
"event_id",
|
||||
"calendar_id",
|
||||
"google_drive_file_id",
|
||||
"onedrive_file_id",
|
||||
"dropbox_file_id",
|
||||
"page_id",
|
||||
"issue_id",
|
||||
"connector_id",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def format_documents_for_context(
|
||||
documents: list[dict[str, Any]],
|
||||
*,
|
||||
max_chars: int = _MAX_TOOL_OUTPUT_CHARS,
|
||||
max_chunk_chars: int = _MAX_CHUNK_CHARS,
|
||||
max_chunks_per_doc: int = 0,
|
||||
) -> str:
|
||||
"""
|
||||
Format retrieved documents into a readable context string for the LLM.
|
||||
|
||||
Documents are added in order (highest relevance first) until the character
|
||||
budget is reached. Individual chunks are capped at ``max_chunk_chars`` and
|
||||
each document is limited to a dynamically computed chunk cap so a single
|
||||
large document cannot monopolize the output while still maximising the use
|
||||
of available context space.
|
||||
|
||||
Args:
|
||||
documents: List of document dictionaries from connector search
|
||||
max_chars: Approximate character budget for the entire output.
|
||||
max_chunk_chars: Per-chunk character cap (content is tail-truncated).
|
||||
max_chunks_per_doc: Maximum chunks per document. ``0`` (default) means
|
||||
auto-compute per document using a rank-adaptive formula so
|
||||
higher-ranked documents receive more chunks.
|
||||
|
||||
Returns:
|
||||
Formatted string with document contents and metadata
|
||||
"""
|
||||
if not documents:
|
||||
return ""
|
||||
|
||||
# Group chunks by document id (preferred) to produce the XML structure.
|
||||
#
|
||||
# IMPORTANT: ConnectorService returns **document-grouped** results of the form:
|
||||
# {
|
||||
# "document": {...},
|
||||
# "chunks": [{"chunk_id": 123, "content": "..."}, ...],
|
||||
# "source": "NOTION_CONNECTOR" | "FILE" | ...
|
||||
# }
|
||||
#
|
||||
# We must preserve chunk_id so citations like [citation:123] are possible.
|
||||
grouped: dict[str, dict[str, Any]] = {}
|
||||
|
||||
for doc in documents:
|
||||
document_info = (doc.get("document") or {}) if isinstance(doc, dict) else {}
|
||||
metadata = (
|
||||
(document_info.get("metadata") or {})
|
||||
if isinstance(document_info, dict)
|
||||
else {}
|
||||
)
|
||||
if not metadata and isinstance(doc, dict):
|
||||
# Some result shapes may place metadata at the top level.
|
||||
metadata = doc.get("metadata") or {}
|
||||
|
||||
source = (
|
||||
(doc.get("source") if isinstance(doc, dict) else None)
|
||||
or document_info.get("document_type")
|
||||
or metadata.get("document_type")
|
||||
or "UNKNOWN"
|
||||
)
|
||||
|
||||
# Document identity (prefer document_id; otherwise fall back to type+title+url)
|
||||
document_id_val = document_info.get("id")
|
||||
title = (
|
||||
document_info.get("title") or metadata.get("title") or "Untitled Document"
|
||||
)
|
||||
url = (
|
||||
metadata.get("url")
|
||||
or metadata.get("source")
|
||||
or metadata.get("page_url")
|
||||
or ""
|
||||
)
|
||||
|
||||
doc_key = (
|
||||
str(document_id_val)
|
||||
if document_id_val is not None
|
||||
else f"{source}::{title}::{url}"
|
||||
)
|
||||
|
||||
if doc_key not in grouped:
|
||||
grouped[doc_key] = {
|
||||
"document_id": document_id_val
|
||||
if document_id_val is not None
|
||||
else doc_key,
|
||||
"document_type": metadata.get("document_type") or source,
|
||||
"title": title,
|
||||
"url": url,
|
||||
"metadata": metadata,
|
||||
"chunks": [],
|
||||
}
|
||||
|
||||
# Prefer document-grouped chunks if available
|
||||
chunks_list = doc.get("chunks") if isinstance(doc, dict) else None
|
||||
if isinstance(chunks_list, list) and chunks_list:
|
||||
for ch in chunks_list:
|
||||
if not isinstance(ch, dict):
|
||||
continue
|
||||
chunk_id = ch.get("chunk_id") or ch.get("id")
|
||||
content = (ch.get("content") or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
grouped[doc_key]["chunks"].append(
|
||||
{"chunk_id": chunk_id, "content": content}
|
||||
)
|
||||
continue
|
||||
|
||||
# Fallback: treat this as a flat chunk-like object
|
||||
if not isinstance(doc, dict):
|
||||
continue
|
||||
chunk_id = doc.get("chunk_id") or doc.get("id")
|
||||
content = (doc.get("content") or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
grouped[doc_key]["chunks"].append({"chunk_id": chunk_id, "content": content})
|
||||
|
||||
# Live search connectors whose results should be cited by URL rather than
|
||||
# a numeric chunk_id (the numeric IDs are meaningless auto-incremented counters).
|
||||
live_search_connectors = {
|
||||
"TAVILY_API",
|
||||
"LINKUP_API",
|
||||
"BAIDU_SEARCH_API",
|
||||
}
|
||||
|
||||
# Render XML expected by citation instructions, respecting the char budget.
|
||||
parts: list[str] = []
|
||||
total_chars = 0
|
||||
total_docs = len(grouped)
|
||||
|
||||
for doc_idx, g in enumerate(grouped.values()):
|
||||
metadata_clean = {
|
||||
k: v for k, v in g["metadata"].items() if k not in _INTERNAL_METADATA_KEYS
|
||||
}
|
||||
metadata_json = json.dumps(metadata_clean, ensure_ascii=False)
|
||||
is_live_search = g["document_type"] in live_search_connectors
|
||||
|
||||
doc_lines: list[str] = [
|
||||
"<document>",
|
||||
"<document_metadata>",
|
||||
f" <document_id>{g['document_id']}</document_id>",
|
||||
f" <document_type>{g['document_type']}</document_type>",
|
||||
f" <title><![CDATA[{g['title']}]]></title>",
|
||||
f" <url><![CDATA[{g['url']}]]></url>",
|
||||
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
|
||||
"</document_metadata>",
|
||||
"",
|
||||
"<document_content>",
|
||||
]
|
||||
|
||||
# Rank-adaptive per-document chunk cap: top results get more chunks.
|
||||
if max_chunks_per_doc > 0:
|
||||
chunks_allowed = max_chunks_per_doc
|
||||
else:
|
||||
doc_fraction = _TOP_DOC_BUDGET_FRACTION / (1 + doc_idx * _RANK_DECAY)
|
||||
max_doc_chars = int(max_chars * doc_fraction)
|
||||
xml_overhead = 500
|
||||
chunks_allowed = max(
|
||||
(max_doc_chars - xml_overhead) // max(max_chunk_chars, 1),
|
||||
_MIN_CHUNKS_PER_DOC,
|
||||
)
|
||||
|
||||
chunks = g["chunks"]
|
||||
if len(chunks) > chunks_allowed:
|
||||
chunks = chunks[:chunks_allowed]
|
||||
|
||||
for ch in chunks:
|
||||
ch_content = ch["content"]
|
||||
if max_chunk_chars and len(ch_content) > max_chunk_chars:
|
||||
ch_content = ch_content[:max_chunk_chars] + "\n...(truncated)"
|
||||
ch_id = g["url"] if (is_live_search and g["url"]) else ch["chunk_id"]
|
||||
if ch_id is None:
|
||||
doc_lines.append(f" <chunk><![CDATA[{ch_content}]]></chunk>")
|
||||
else:
|
||||
doc_lines.append(
|
||||
f" <chunk id='{ch_id}'><![CDATA[{ch_content}]]></chunk>"
|
||||
)
|
||||
|
||||
doc_lines.extend(["</document_content>", "</document>", ""])
|
||||
|
||||
doc_xml = "\n".join(doc_lines)
|
||||
doc_len = len(doc_xml)
|
||||
|
||||
if total_chars + doc_len > max_chars:
|
||||
remaining = total_docs - doc_idx
|
||||
if doc_idx == 0:
|
||||
parts.append(doc_xml)
|
||||
total_chars += doc_len
|
||||
parts.append(
|
||||
f"<!-- Output truncated: {remaining} more document(s) omitted "
|
||||
f"(budget {max_chars} chars). Refine your query or reduce top_k "
|
||||
f"to retrieve different results. -->"
|
||||
)
|
||||
break
|
||||
|
||||
parts.append(doc_xml)
|
||||
total_chars += doc_len
|
||||
|
||||
result = "\n".join(parts).strip()
|
||||
|
||||
# Hard safety net: if the result is still over budget (e.g. a single massive
|
||||
# first document), forcibly truncate with a closing comment.
|
||||
if len(result) > max_chars:
|
||||
truncation_msg = "\n<!-- ...output forcibly truncated to fit context window -->"
|
||||
result = result[: max_chars - len(truncation_msg)] + truncation_msg
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Knowledge Base Search
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def search_knowledge_base_async(
    query: str,
    search_space_id: int,
    db_session: AsyncSession,
    connector_service: ConnectorService,
    connectors_to_search: list[str] | None = None,
    top_k: int = 10,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
    available_connectors: list[str] | None = None,
    available_document_types: list[str] | None = None,
    max_input_tokens: int | None = None,
) -> str:
    """
    Search the knowledge base and return the hits as LLM-ready context text.

    Delegates retrieval to ``search_knowledge_base_raw_async`` and formatting
    to ``format_documents_for_context``.

    Args:
        query: The search query.
        search_space_id: The user's search space ID.
        db_session: Database session (forwarded to the raw search).
        connector_service: Connector service (forwarded to the raw search).
        connectors_to_search: Optional connector types to search; all when omitted.
        top_k: Number of results per connector.
        start_date: Optional start datetime for filtering documents.
        end_date: Optional end datetime for filtering documents.
        available_connectors: Optional list of connectors that exist in the
            search space; when given, only these are searched.
        available_document_types: Optional list of document types that have
            indexed data; connectors without a matching type are skipped.
        max_input_tokens: Model context window size in tokens, used to size
            the output budget.

    Returns:
        Formatted search results, or a fixed "no documents" message when the
        raw search returns nothing.
    """
    perf = get_perf_logger()
    started = time.perf_counter()

    documents = await search_knowledge_base_raw_async(
        query=query,
        search_space_id=search_space_id,
        db_session=db_session,
        connector_service=connector_service,
        connectors_to_search=connectors_to_search,
        top_k=top_k,
        start_date=start_date,
        end_date=end_date,
        available_connectors=available_connectors,
        available_document_types=available_document_types,
    )
    if not documents:
        return "No documents found in the knowledge base. The search space has no indexed content yet."

    # Degenerate queries are served by a recency browse, which uses the fixed
    # browse cap; real searches let the formatter pick a rank-adaptive cap (0).
    if _is_degenerate_query(query):
        chunk_cap = _BROWSE_MAX_CHUNKS_PER_DOC
    else:
        chunk_cap = 0

    budget = _compute_tool_output_budget(max_input_tokens)
    formatted = format_documents_for_context(
        documents,
        max_chars=budget,
        max_chunks_per_doc=chunk_cap,
    )
    output_chars = len(formatted)

    if output_chars > budget:
        perf.warning(
            "[kb_search] output STILL exceeds budget after format (%d > %d), "
            "hard truncation should have fired",
            output_chars,
            budget,
        )

    # Both document counters report the de-duplicated count; the raw total is
    # not available at this level.
    doc_count = len(documents)
    perf.info(
        "[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d "
        "budget=%d max_input_tokens=%s space=%d",
        time.perf_counter() - started,
        doc_count,
        doc_count,
        output_chars,
        budget,
        max_input_tokens,
        search_space_id,
    )
    return formatted
|
||||
|
||||
|
||||
def _content_fingerprint(document: dict[str, Any]) -> int | None:
    """Hash a document's text for de-duplication when it has no id.

    Uses the stripped, non-empty chunk contents joined with ``||``; falls back
    to the flat ``content`` field. Returns ``None`` when there is no text.
    """
    chunks = document.get("chunks")
    if isinstance(chunks, list):
        chunk_texts = [
            text
            for chunk in chunks
            if isinstance(chunk, dict)
            and (text := (chunk.get("content") or "").strip())
        ]
        if chunk_texts:
            return hash("||".join(chunk_texts))
    flat_content = (document.get("content") or "").strip()
    return hash(flat_content) if flat_content else None


def _dedupe_documents(documents: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Drop repeated documents, keeping the first occurrence of each.

    Documents are matched by ``document.id`` when present, otherwise by
    content fingerprint. Documents with neither are always kept.
    """
    seen_doc_ids: set[Any] = set()
    seen_content_hashes: set[int] = set()
    unique_docs: list[dict[str, Any]] = []

    for doc in documents:
        doc_id = (doc.get("document", {}) or {}).get("id")
        if doc_id is not None:
            if doc_id not in seen_doc_ids:
                seen_doc_ids.add(doc_id)
                unique_docs.append(doc)
            continue

        content_hash = _content_fingerprint(doc)
        if content_hash is not None:
            if content_hash in seen_content_hashes:
                continue
            seen_content_hashes.add(content_hash)
        unique_docs.append(doc)

    return unique_docs


async def search_knowledge_base_raw_async(
    query: str,
    search_space_id: int,
    db_session: AsyncSession,
    connector_service: ConnectorService,
    connectors_to_search: list[str] | None = None,
    top_k: int = 10,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
    available_connectors: list[str] | None = None,
    available_document_types: list[str] | None = None,
    query_embedding: list[float] | None = None,
) -> list[dict[str, Any]]:
    """Search knowledge base and return raw document dicts (no XML formatting).

    Args:
        query: The search query. Degenerate queries (see
            ``_is_degenerate_query``) trigger a recency browse instead of a
            ranked search.
        search_space_id: Search space to query.
        db_session: Unused; kept for signature compatibility.
        connector_service: Unused; kept for signature compatibility. Each
            connector search opens its own session and service.
        connectors_to_search: Connector types requested by the caller.
        top_k: Number of results per connector.
        start_date: Optional start of the date filter.
        end_date: Optional end of the date filter.
        available_connectors: Connectors that exist in the search space.
        available_document_types: Document types that have indexed data;
            connectors whose type (or legacy equivalent) is absent are skipped.
        query_embedding: Precomputed embedding of ``query``; computed here
            when omitted.

    Returns:
        De-duplicated document dicts sorted by ``score`` descending. Empty
        when no connector remains after filtering.
    """
    perf = get_perf_logger()
    t0 = time.perf_counter()
    all_documents: list[dict[str, Any]] = []

    # Preserve the public signature for compatibility even if values are unused.
    _ = (db_session, connector_service)

    from app.agents.shared.utils import resolve_date_range

    resolved_start_date, resolved_end_date = resolve_date_range(
        start_date=start_date,
        end_date=end_date,
    )

    connectors = _normalize_connectors(connectors_to_search, available_connectors)

    if available_document_types:
        doc_types_set = set(available_document_types)
        connectors = [
            c
            for c in connectors
            if c in doc_types_set
            or NATIVE_TO_LEGACY_DOCTYPE.get(c, "") in doc_types_set
        ]

    if not connectors:
        return []

    if _is_degenerate_query(query):
        perf.info(
            "[kb_search_raw] degenerate query %r detected - recency browse",
            query,
        )
        # ``connectors`` is guaranteed non-empty here. A connector that has a
        # legacy document type is browsed under both names in one call.
        browse_targets: list[str | list[str]] = [
            [connector, NATIVE_TO_LEGACY_DOCTYPE[connector]]
            if connector in NATIVE_TO_LEGACY_DOCTYPE
            else connector
            for connector in connectors
        ]
        browse_results = await asyncio.gather(
            *[
                _browse_recent_documents(
                    search_space_id=search_space_id,
                    document_type=target,
                    top_k=top_k,
                    start_date=resolved_start_date,
                    end_date=resolved_end_date,
                )
                for target in browse_targets
            ]
        )
        for docs in browse_results:
            all_documents.extend(docs)
    else:
        if query_embedding is None:
            from app.config import config as app_config

            # NOTE(review): this is a synchronous call inside a coroutine; if
            # the embedding model is slow it blocks the event loop. Confirm
            # whether it is safe to move to a worker thread.
            query_embedding = app_config.embedding_model_instance.embed(query)

        max_parallel_searches = 4
        semaphore = asyncio.Semaphore(max_parallel_searches)

        async def _search_one_connector(connector: str) -> list[dict[str, Any]]:
            # Best effort: one failing connector must not sink the whole
            # search, so the error is logged and treated as "no results".
            try:
                async with semaphore, shielded_async_session() as isolated_session:
                    svc = ConnectorService(isolated_session, search_space_id)
                    return await svc._combined_rrf_search(
                        query_text=query,
                        search_space_id=search_space_id,
                        document_type=connector,
                        top_k=top_k,
                        start_date=resolved_start_date,
                        end_date=resolved_end_date,
                        query_embedding=query_embedding,
                    )
            except Exception as exc:
                perf.warning("[kb_search_raw] connector=%s FAILED: %s", connector, exc)
                return []

        connector_results = await asyncio.gather(
            *[_search_one_connector(connector) for connector in connectors]
        )
        for docs in connector_results:
            all_documents.extend(docs)

    deduplicated = _dedupe_documents(all_documents)

    # A missing or ``None`` score sorts as 0 instead of raising TypeError.
    deduplicated.sort(key=lambda doc: doc.get("score") or 0, reverse=True)
    perf.info(
        "[kb_search_raw] done in %.3fs total=%d deduped=%d",
        time.perf_counter() - t0,
        len(all_documents),
        len(deduplicated),
    )
    return deduplicated
|
||||
Loading…
Add table
Add a link
Reference in a new issue