feat: enhance knowledge base search with date filtering

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-03-31 20:13:46 -07:00
parent 006dccbe4b
commit ad0e77c3d6
7 changed files with 660 additions and 12 deletions

View file

@@ -15,14 +15,19 @@ import logging
import re
import uuid
from collections.abc import Sequence
from datetime import UTC, datetime
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langgraph.runtime import Runtime
from litellm import token_counter
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range
from app.db import NATIVE_TO_LEGACY_DOCTYPE, Document, Folder, shielded_async_session
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.utils.document_converters import embed_texts
@@ -32,6 +37,23 @@ logger = logging.getLogger(__name__)
# Dedicated performance logger used for timing KB planner/search operations.
_perf_log = get_perf_logger()
class KBSearchPlan(BaseModel):
    """Structured internal plan for KB retrieval.

    Parsed from the planner LLM's JSON response (see
    `_parse_kb_search_plan_response`) and used to drive the hybrid search
    with an optimized query and an optional date filter.
    """

    # Rewritten retrieval query; validation requires a non-empty string.
    optimized_query: str = Field(
        min_length=1,
        description="Optimized retrieval query preserving the user's intent.",
    )
    # Lower bound of the optional date filter, as an ISO date/datetime string.
    start_date: str | None = Field(
        default=None,
        description="Optional ISO start date or datetime for KB search filtering.",
    )
    # Upper bound of the optional date filter, as an ISO date/datetime string.
    end_date: str | None = Field(
        default=None,
        description="Optional ISO end date or datetime for KB search filtering.",
    )
def _extract_text_from_message(message: BaseMessage) -> str:
"""Extract plain text from a message content."""
content = getattr(message, "content", "")
@@ -61,6 +83,212 @@ def _safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
return name
def _render_recent_conversation(
    messages: Sequence[BaseMessage],
    *,
    llm: BaseChatModel | None = None,
    user_text: str = "",
    max_messages: int = 6,
) -> str:
    """Render recent dialogue for internal planning under a token budget.

    Prefers the latest messages and uses the project's existing model-aware
    token budgeting hooks when available on the LLM (`_count_tokens`,
    `_get_max_input_tokens`). Falls back to the prior fixed-message heuristic
    if token counting is unavailable.

    Args:
        messages: Full conversation history to render.
        llm: Optional chat model whose token-counting hooks are probed.
        user_text: Latest user message; excluded from the rendered transcript.
        max_messages: Message cap used only by the legacy fallback rendering.

    Returns:
        A newline-joined ``role: text`` transcript (possibly clipped), or ``""``.
    """
    # Collect (role, whitespace-collapsed text) pairs; tool-calling assistant
    # turns and non user/assistant messages are skipped entirely.
    rendered: list[tuple[str, str]] = []
    for message in messages:
        role: str | None = None
        if isinstance(message, HumanMessage):
            role = "user"
        elif isinstance(message, AIMessage):
            if getattr(message, "tool_calls", None):
                continue
            role = "assistant"
        else:
            continue
        text = _extract_text_from_message(message).strip()
        if not text:
            continue
        text = re.sub(r"\s+", " ", text)
        rendered.append((role, text))
    if not rendered:
        return ""
    # Exclude the latest user message from "recent conversation" because it is
    # already passed separately as "Latest user message" in the planner prompt.
    # NOTE(review): the stored text was whitespace-collapsed above while
    # `user_text` is only stripped, so a multi-line latest message may not
    # match and would then be rendered twice — confirm against callers.
    if rendered and rendered[-1][0] == "user" and rendered[-1][1] == user_text.strip():
        rendered = rendered[:-1]
    if not rendered:
        return ""

    def _legacy_render() -> str:
        # Fixed-size fallback: last `max_messages` turns, each clipped to 400 chars.
        legacy_lines: list[str] = []
        for role, text in rendered[-max_messages:]:
            clipped = text[:400].rstrip() + "..." if len(text) > 400 else text
            legacy_lines.append(f"{role}: {clipped}")
        return "\n".join(legacy_lines)

    def _count_prompt_tokens(conversation_text: str) -> int | None:
        # Token count for the *entire* planner prompt that would embed this
        # conversation. Tries the LLM's private `_count_tokens` hook first,
        # then litellm's `token_counter`; returns None when neither works.
        prompt = _build_kb_planner_prompt(
            recent_conversation=conversation_text or "(none)",
            user_text=user_text,
        )
        message_payload = [{"role": "user", "content": prompt}]
        count_fn = getattr(llm, "_count_tokens", None) if llm is not None else None
        if callable(count_fn):
            try:
                return count_fn(message_payload)
            except Exception:
                pass
        # Resolve a model name for litellm from the LLM's profile metadata,
        # preferring `token_count_models` entries over `token_count_model`.
        profile = getattr(llm, "profile", None) if llm is not None else None
        model_names: list[str] = []
        if isinstance(profile, dict):
            tcms = profile.get("token_count_models")
            if isinstance(tcms, list):
                model_names.extend(
                    name for name in tcms if isinstance(name, str) and name
                )
            tcm = profile.get("token_count_model")
            if isinstance(tcm, str) and tcm and tcm not in model_names:
                model_names.append(tcm)
        model_name = model_names[0] if model_names else getattr(llm, "model", None)
        if not isinstance(model_name, str) or not model_name:
            return None
        try:
            return token_counter(messages=message_payload, model=model_name)
        except Exception:
            return None

    # Resolve the model's input window via the `_get_max_input_tokens` hook,
    # falling back to the profile's `max_input_tokens` entry.
    get_max_input_tokens = getattr(llm, "_get_max_input_tokens", None) if llm else None
    if callable(get_max_input_tokens):
        try:
            max_input_tokens = int(get_max_input_tokens())
        except Exception:
            max_input_tokens = None
    else:
        profile = getattr(llm, "profile", None) if llm is not None else None
        max_input_tokens = (
            profile.get("max_input_tokens")
            if isinstance(profile, dict)
            and isinstance(profile.get("max_input_tokens"), int)
            else None
        )
    if not isinstance(max_input_tokens, int) or max_input_tokens <= 0:
        return _legacy_render()
    # Reserve ~2% of the window (clamped to [256, 1024]) for model output.
    output_reserve = min(max(int(max_input_tokens * 0.02), 256), 1024)
    budget = max_input_tokens - output_reserve
    if budget <= 0:
        return _legacy_render()
    # Greedily prepend messages newest-first while the full prompt fits the
    # budget; any failed token count aborts to the legacy rendering.
    selected_lines: list[str] = []
    for role, text in reversed(rendered):
        candidate_line = f"{role}: {text}"
        candidate_lines = [candidate_line, *selected_lines]
        candidate_conversation = "\n".join(candidate_lines)
        token_count = _count_prompt_tokens(candidate_conversation)
        if token_count is None:
            return _legacy_render()
        if token_count <= budget:
            selected_lines = candidate_lines
            continue
        # If the full message does not fit, keep as much of this most-recent
        # older message as possible via binary search.
        lo, hi = 1, len(text)
        best_line: str | None = None
        while lo <= hi:
            mid = (lo + hi) // 2
            clipped_text = text[:mid].rstrip() + "..."
            clipped_line = f"{role}: {clipped_text}"
            clipped_conversation = "\n".join([clipped_line, *selected_lines])
            clipped_tokens = _count_prompt_tokens(clipped_conversation)
            if clipped_tokens is None:
                break
            if clipped_tokens <= budget:
                best_line = clipped_line
                lo = mid + 1
            else:
                hi = mid - 1
        if best_line is not None:
            selected_lines = [best_line, *selected_lines]
        break
    if not selected_lines:
        return _legacy_render()
    return "\n".join(selected_lines)
def _build_kb_planner_prompt(
*,
recent_conversation: str,
user_text: str,
) -> str:
"""Build a compact internal prompt for KB query rewriting and date scoping."""
today = datetime.now(UTC).date().isoformat()
return (
"You optimize internal knowledge-base search inputs for document retrieval.\n"
"Return JSON only with this exact shape:\n"
'{"optimized_query":"string","start_date":"ISO string or null","end_date":"ISO string or null"}\n\n'
"Rules:\n"
"- Preserve the user's intent.\n"
"- Rewrite the query to improve retrieval using concrete entities, acronyms, projects, tools, people, and document-specific terms when helpful.\n"
"- Keep the query concise and retrieval-focused.\n"
"- Only use date filters when the latest user request or recent dialogue clearly implies a time range.\n"
"- If you use date filters, prefer returning both bounds.\n"
"- If no date filter is useful, return null for both dates.\n"
"- Do not include markdown, prose, or explanations.\n\n"
f"Today's UTC date: {today}\n\n"
f"Recent conversation:\n{recent_conversation or '(none)'}\n\n"
f"Latest user message:\n{user_text}"
)
def _extract_json_payload(text: str) -> str:
"""Extract a JSON object from a raw LLM response."""
stripped = text.strip()
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
if fenced:
return fenced.group(1)
start = stripped.find("{")
end = stripped.rfind("}")
if start != -1 and end != -1 and end > start:
return stripped[start : end + 1]
return stripped
def _parse_kb_search_plan_response(response_text: str) -> KBSearchPlan:
    """Parse and validate the planner's JSON response.

    Raises ``json.JSONDecodeError`` on malformed JSON and pydantic's
    ``ValidationError`` when the payload does not match ``KBSearchPlan``.
    """
    raw_payload = _extract_json_payload(response_text)
    return KBSearchPlan.model_validate(json.loads(raw_payload))
def _normalize_optional_date_range(
    start_date: str | None,
    end_date: str | None,
) -> tuple[datetime | None, datetime | None]:
    """Normalize optional planner dates into a UTC datetime range.

    Empty/None inputs short-circuit to ``(None, None)`` so no filtering is
    applied; otherwise the parsed bounds are resolved into a full range.
    """
    start = parse_date_or_datetime(start_date) if start_date else None
    end = parse_date_or_datetime(end_date) if end_date else None
    if start is None and end is None:
        return None, None
    return resolve_date_range(start, end)
def _build_document_xml(
document: dict[str, Any],
matched_chunk_ids: set[int] | None = None,
@@ -264,6 +492,8 @@ async def search_knowledge_base(
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
top_k: int = 10,
start_date: datetime | None = None,
end_date: datetime | None = None,
) -> list[dict[str, Any]]:
"""Run a single unified hybrid search against the knowledge base.
@ -286,6 +516,8 @@ async def search_knowledge_base(
top_k=retriever_top_k,
search_space_id=search_space_id,
document_type=doc_types,
start_date=start_date,
end_date=end_date,
query_embedding=embedding.tolist(),
)
@@ -346,16 +578,71 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware):  # type: ignore[type-arg]
def __init__(
self,
*,
llm: BaseChatModel | None = None,
search_space_id: int,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
top_k: int = 10,
) -> None:
self.llm = llm
self.search_space_id = search_space_id
self.available_connectors = available_connectors
self.available_document_types = available_document_types
self.top_k = top_k
    async def _plan_search_inputs(
        self,
        *,
        messages: Sequence[BaseMessage],
        user_text: str,
    ) -> tuple[str, datetime | None, datetime | None]:
        """Rewrite the KB query and infer optional date filters with the LLM.

        Args:
            messages: Conversation history used as planning context.
            user_text: Raw latest user message.

        Returns:
            ``(query, start_date, end_date)``: the optimized query plus an
            optional datetime range. Falls back to ``(user_text, None, None)``
            when no LLM is configured or planning fails for any reason.
        """
        if self.llm is None:
            return user_text, None, None
        recent_conversation = _render_recent_conversation(
            messages,
            llm=self.llm,
            user_text=user_text,
        )
        prompt = _build_kb_planner_prompt(
            recent_conversation=recent_conversation,
            user_text=user_text,
        )
        loop = asyncio.get_running_loop()
        t0 = loop.time()
        try:
            response = await self.llm.ainvoke(
                [HumanMessage(content=prompt)],
                config={"tags": ["surfsense:internal"]},
            )
            plan = _parse_kb_search_plan_response(_extract_text_from_message(response))
            # Collapse whitespace; an all-whitespace rewrite falls back to
            # the raw user text rather than an empty query.
            optimized_query = (
                re.sub(r"\s+", " ", plan.optimized_query).strip() or user_text
            )
            start_date, end_date = _normalize_optional_date_range(
                plan.start_date,
                plan.end_date,
            )
            _perf_log.info(
                "[kb_fs_middleware] planner in %.3fs query=%r optimized=%r start=%s end=%s",
                loop.time() - t0,
                user_text[:80],
                optimized_query[:120],
                start_date.isoformat() if start_date else None,
                end_date.isoformat() if end_date else None,
            )
            return optimized_query, start_date, end_date
        except (json.JSONDecodeError, ValidationError, ValueError) as exc:
            # Malformed planner JSON/fields are expected occasionally;
            # degrade to the raw query instead of failing the request.
            logger.warning(
                "KB planner returned invalid output, using raw query: %s", exc
            )
        except Exception as exc:  # pragma: no cover - defensive fallback
            logger.warning("KB planner failed, using raw query: %s", exc)
        return user_text, None, None
def before_agent( # type: ignore[override]
self,
state: AgentState,
@@ -388,13 +675,19 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware):  # type: ignore[type-arg]
t0 = _perf_log and asyncio.get_event_loop().time()
existing_files = state.get("files")
planned_query, start_date, end_date = await self._plan_search_inputs(
messages=messages,
user_text=user_text,
)
search_results = await search_knowledge_base(
query=user_text,
query=planned_query,
search_space_id=self.search_space_id,
available_connectors=self.available_connectors,
available_document_types=self.available_document_types,
top_k=self.top_k,
start_date=start_date,
end_date=end_date,
)
new_files = await build_scoped_filesystem(
documents=search_results,
@@ -405,9 +698,10 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware):  # type: ignore[type-arg]
if t0 is not None:
_perf_log.info(
"[kb_fs_middleware] completed in %.3fs query=%r new_files=%d total=%d",
"[kb_fs_middleware] completed in %.3fs query=%r optimized=%r new_files=%d total=%d",
asyncio.get_event_loop().time() - t0,
user_text[:80],
planned_query[:120],
len(new_files),
len(new_files) + len(existing_files or {}),
)