feat: enhance knowledge base search with date filtering

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-03-31 20:13:46 -07:00
parent 006dccbe4b
commit ad0e77c3d6
7 changed files with 660 additions and 12 deletions

View file

@ -447,6 +447,7 @@ async def create_surfsense_deep_agent(
deepagent_middleware = [ deepagent_middleware = [
TodoListMiddleware(), TodoListMiddleware(),
KnowledgeBaseSearchMiddleware( KnowledgeBaseSearchMiddleware(
llm=llm,
search_space_id=search_space_id, search_space_id=search_space_id,
available_connectors=available_connectors, available_connectors=available_connectors,
available_document_types=available_document_types, available_document_types=available_document_types,

View file

@ -15,14 +15,19 @@ import logging
import re import re
import uuid import uuid
from collections.abc import Sequence from collections.abc import Sequence
from datetime import UTC, datetime
from typing import Any from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from litellm import token_counter
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range
from app.db import NATIVE_TO_LEGACY_DOCTYPE, Document, Folder, shielded_async_session from app.db import NATIVE_TO_LEGACY_DOCTYPE, Document, Folder, shielded_async_session
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.utils.document_converters import embed_texts from app.utils.document_converters import embed_texts
@ -32,6 +37,23 @@ logger = logging.getLogger(__name__)
_perf_log = get_perf_logger() _perf_log = get_perf_logger()
class KBSearchPlan(BaseModel):
"""Structured internal plan for KB retrieval."""
optimized_query: str = Field(
min_length=1,
description="Optimized retrieval query preserving the user's intent.",
)
start_date: str | None = Field(
default=None,
description="Optional ISO start date or datetime for KB search filtering.",
)
end_date: str | None = Field(
default=None,
description="Optional ISO end date or datetime for KB search filtering.",
)
def _extract_text_from_message(message: BaseMessage) -> str: def _extract_text_from_message(message: BaseMessage) -> str:
"""Extract plain text from a message content.""" """Extract plain text from a message content."""
content = getattr(message, "content", "") content = getattr(message, "content", "")
@ -61,6 +83,212 @@ def _safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
return name return name
def _render_recent_conversation(
messages: Sequence[BaseMessage],
*,
llm: BaseChatModel | None = None,
user_text: str = "",
max_messages: int = 6,
) -> str:
"""Render recent dialogue for internal planning under a token budget.
Prefers the latest messages and uses the project's existing model-aware
token budgeting hooks when available on the LLM (`_count_tokens`,
`_get_max_input_tokens`). Falls back to the prior fixed-message heuristic
if token counting is unavailable.
"""
rendered: list[tuple[str, str]] = []
for message in messages:
role: str | None = None
if isinstance(message, HumanMessage):
role = "user"
elif isinstance(message, AIMessage):
if getattr(message, "tool_calls", None):
continue
role = "assistant"
else:
continue
text = _extract_text_from_message(message).strip()
if not text:
continue
text = re.sub(r"\s+", " ", text)
rendered.append((role, text))
if not rendered:
return ""
# Exclude the latest user message from "recent conversation" because it is
# already passed separately as "Latest user message" in the planner prompt.
if rendered and rendered[-1][0] == "user" and rendered[-1][1] == user_text.strip():
rendered = rendered[:-1]
if not rendered:
return ""
def _legacy_render() -> str:
legacy_lines: list[str] = []
for role, text in rendered[-max_messages:]:
clipped = text[:400].rstrip() + "..." if len(text) > 400 else text
legacy_lines.append(f"{role}: {clipped}")
return "\n".join(legacy_lines)
def _count_prompt_tokens(conversation_text: str) -> int | None:
prompt = _build_kb_planner_prompt(
recent_conversation=conversation_text or "(none)",
user_text=user_text,
)
message_payload = [{"role": "user", "content": prompt}]
count_fn = getattr(llm, "_count_tokens", None) if llm is not None else None
if callable(count_fn):
try:
return count_fn(message_payload)
except Exception:
pass
profile = getattr(llm, "profile", None) if llm is not None else None
model_names: list[str] = []
if isinstance(profile, dict):
tcms = profile.get("token_count_models")
if isinstance(tcms, list):
model_names.extend(
name for name in tcms if isinstance(name, str) and name
)
tcm = profile.get("token_count_model")
if isinstance(tcm, str) and tcm and tcm not in model_names:
model_names.append(tcm)
model_name = model_names[0] if model_names else getattr(llm, "model", None)
if not isinstance(model_name, str) or not model_name:
return None
try:
return token_counter(messages=message_payload, model=model_name)
except Exception:
return None
get_max_input_tokens = getattr(llm, "_get_max_input_tokens", None) if llm else None
if callable(get_max_input_tokens):
try:
max_input_tokens = int(get_max_input_tokens())
except Exception:
max_input_tokens = None
else:
profile = getattr(llm, "profile", None) if llm is not None else None
max_input_tokens = (
profile.get("max_input_tokens")
if isinstance(profile, dict)
and isinstance(profile.get("max_input_tokens"), int)
else None
)
if not isinstance(max_input_tokens, int) or max_input_tokens <= 0:
return _legacy_render()
output_reserve = min(max(int(max_input_tokens * 0.02), 256), 1024)
budget = max_input_tokens - output_reserve
if budget <= 0:
return _legacy_render()
selected_lines: list[str] = []
for role, text in reversed(rendered):
candidate_line = f"{role}: {text}"
candidate_lines = [candidate_line, *selected_lines]
candidate_conversation = "\n".join(candidate_lines)
token_count = _count_prompt_tokens(candidate_conversation)
if token_count is None:
return _legacy_render()
if token_count <= budget:
selected_lines = candidate_lines
continue
# If the full message does not fit, keep as much of this most-recent
# older message as possible via binary search.
lo, hi = 1, len(text)
best_line: str | None = None
while lo <= hi:
mid = (lo + hi) // 2
clipped_text = text[:mid].rstrip() + "..."
clipped_line = f"{role}: {clipped_text}"
clipped_conversation = "\n".join([clipped_line, *selected_lines])
clipped_tokens = _count_prompt_tokens(clipped_conversation)
if clipped_tokens is None:
break
if clipped_tokens <= budget:
best_line = clipped_line
lo = mid + 1
else:
hi = mid - 1
if best_line is not None:
selected_lines = [best_line, *selected_lines]
break
if not selected_lines:
return _legacy_render()
return "\n".join(selected_lines)
def _build_kb_planner_prompt(
*,
recent_conversation: str,
user_text: str,
) -> str:
"""Build a compact internal prompt for KB query rewriting and date scoping."""
today = datetime.now(UTC).date().isoformat()
return (
"You optimize internal knowledge-base search inputs for document retrieval.\n"
"Return JSON only with this exact shape:\n"
'{"optimized_query":"string","start_date":"ISO string or null","end_date":"ISO string or null"}\n\n'
"Rules:\n"
"- Preserve the user's intent.\n"
"- Rewrite the query to improve retrieval using concrete entities, acronyms, projects, tools, people, and document-specific terms when helpful.\n"
"- Keep the query concise and retrieval-focused.\n"
"- Only use date filters when the latest user request or recent dialogue clearly implies a time range.\n"
"- If you use date filters, prefer returning both bounds.\n"
"- If no date filter is useful, return null for both dates.\n"
"- Do not include markdown, prose, or explanations.\n\n"
f"Today's UTC date: {today}\n\n"
f"Recent conversation:\n{recent_conversation or '(none)'}\n\n"
f"Latest user message:\n{user_text}"
)
def _extract_json_payload(text: str) -> str:
"""Extract a JSON object from a raw LLM response."""
stripped = text.strip()
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
if fenced:
return fenced.group(1)
start = stripped.find("{")
end = stripped.rfind("}")
if start != -1 and end != -1 and end > start:
return stripped[start : end + 1]
return stripped
def _parse_kb_search_plan_response(response_text: str) -> KBSearchPlan:
"""Parse and validate the planner's JSON response."""
payload = json.loads(_extract_json_payload(response_text))
return KBSearchPlan.model_validate(payload)
def _normalize_optional_date_range(
start_date: str | None,
end_date: str | None,
) -> tuple[datetime | None, datetime | None]:
"""Normalize optional planner dates into a UTC datetime range."""
parsed_start = parse_date_or_datetime(start_date) if start_date else None
parsed_end = parse_date_or_datetime(end_date) if end_date else None
if parsed_start is None and parsed_end is None:
return None, None
resolved_start, resolved_end = resolve_date_range(parsed_start, parsed_end)
return resolved_start, resolved_end
def _build_document_xml( def _build_document_xml(
document: dict[str, Any], document: dict[str, Any],
matched_chunk_ids: set[int] | None = None, matched_chunk_ids: set[int] | None = None,
@ -264,6 +492,8 @@ async def search_knowledge_base(
available_connectors: list[str] | None = None, available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None, available_document_types: list[str] | None = None,
top_k: int = 10, top_k: int = 10,
start_date: datetime | None = None,
end_date: datetime | None = None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Run a single unified hybrid search against the knowledge base. """Run a single unified hybrid search against the knowledge base.
@ -286,6 +516,8 @@ async def search_knowledge_base(
top_k=retriever_top_k, top_k=retriever_top_k,
search_space_id=search_space_id, search_space_id=search_space_id,
document_type=doc_types, document_type=doc_types,
start_date=start_date,
end_date=end_date,
query_embedding=embedding.tolist(), query_embedding=embedding.tolist(),
) )
@ -346,16 +578,71 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
def __init__( def __init__(
self, self,
*, *,
llm: BaseChatModel | None = None,
search_space_id: int, search_space_id: int,
available_connectors: list[str] | None = None, available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None, available_document_types: list[str] | None = None,
top_k: int = 10, top_k: int = 10,
) -> None: ) -> None:
self.llm = llm
self.search_space_id = search_space_id self.search_space_id = search_space_id
self.available_connectors = available_connectors self.available_connectors = available_connectors
self.available_document_types = available_document_types self.available_document_types = available_document_types
self.top_k = top_k self.top_k = top_k
async def _plan_search_inputs(
self,
*,
messages: Sequence[BaseMessage],
user_text: str,
) -> tuple[str, datetime | None, datetime | None]:
"""Rewrite the KB query and infer optional date filters with the LLM."""
if self.llm is None:
return user_text, None, None
recent_conversation = _render_recent_conversation(
messages,
llm=self.llm,
user_text=user_text,
)
prompt = _build_kb_planner_prompt(
recent_conversation=recent_conversation,
user_text=user_text,
)
loop = asyncio.get_running_loop()
t0 = loop.time()
try:
response = await self.llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal"]},
)
plan = _parse_kb_search_plan_response(_extract_text_from_message(response))
optimized_query = (
re.sub(r"\s+", " ", plan.optimized_query).strip() or user_text
)
start_date, end_date = _normalize_optional_date_range(
plan.start_date,
plan.end_date,
)
_perf_log.info(
"[kb_fs_middleware] planner in %.3fs query=%r optimized=%r start=%s end=%s",
loop.time() - t0,
user_text[:80],
optimized_query[:120],
start_date.isoformat() if start_date else None,
end_date.isoformat() if end_date else None,
)
return optimized_query, start_date, end_date
except (json.JSONDecodeError, ValidationError, ValueError) as exc:
logger.warning(
"KB planner returned invalid output, using raw query: %s", exc
)
except Exception as exc: # pragma: no cover - defensive fallback
logger.warning("KB planner failed, using raw query: %s", exc)
return user_text, None, None
def before_agent( # type: ignore[override] def before_agent( # type: ignore[override]
self, self,
state: AgentState, state: AgentState,
@ -388,13 +675,19 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
t0 = _perf_log and asyncio.get_event_loop().time() t0 = _perf_log and asyncio.get_event_loop().time()
existing_files = state.get("files") existing_files = state.get("files")
planned_query, start_date, end_date = await self._plan_search_inputs(
messages=messages,
user_text=user_text,
)
search_results = await search_knowledge_base( search_results = await search_knowledge_base(
query=user_text, query=planned_query,
search_space_id=self.search_space_id, search_space_id=self.search_space_id,
available_connectors=self.available_connectors, available_connectors=self.available_connectors,
available_document_types=self.available_document_types, available_document_types=self.available_document_types,
top_k=self.top_k, top_k=self.top_k,
start_date=start_date,
end_date=end_date,
) )
new_files = await build_scoped_filesystem( new_files = await build_scoped_filesystem(
documents=search_results, documents=search_results,
@ -405,9 +698,10 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
if t0 is not None: if t0 is not None:
_perf_log.info( _perf_log.info(
"[kb_fs_middleware] completed in %.3fs query=%r new_files=%d total=%d", "[kb_fs_middleware] completed in %.3fs query=%r optimized=%r new_files=%d total=%d",
asyncio.get_event_loop().time() - t0, asyncio.get_event_loop().time() - t0,
user_text[:80], user_text[:80],
planned_query[:120],
len(new_files), len(new_files),
len(new_files) + len(existing_files or {}), len(new_files) + len(existing_files or {}),
) )

View file

@ -310,7 +310,7 @@ class GoogleGmailConnector:
Fetch recent messages from Gmail within specified date range. Fetch recent messages from Gmail within specified date range.
Args: Args:
max_results: Maximum number of messages to fetch (default: 50) max_results: Maximum number of messages to fetch (default: 50)
start_date: Start date in YYYY-MM-DD format (default: 30 days ago) start_date: Start date in YYYY-MM-DD format (default: 3 days ago)
end_date: End date in YYYY-MM-DD format (default: today) end_date: End date in YYYY-MM-DD format (default: today)
Returns: Returns:
Tuple containing (messages list with details, error message or None) Tuple containing (messages list with details, error message or None)
@ -334,8 +334,8 @@ class GoogleGmailConnector:
start_query = start_dt.strftime("%Y/%m/%d") start_query = start_dt.strftime("%Y/%m/%d")
query_parts.append(f"after:{start_query}") query_parts.append(f"after:{start_query}")
else: else:
# Default to 30 days ago # Default to 3 days ago
cutoff_date = datetime.now() - timedelta(days=30) cutoff_date = datetime.now() - timedelta(days=3)
date_query = cutoff_date.strftime("%Y/%m/%d") date_query = cutoff_date.strftime("%Y/%m/%d")
query_parts.append(f"after:{date_query}") query_parts.append(f"after:{date_query}")

View file

@ -152,7 +152,9 @@ class _FakeReconciliationStripeClient:
class TestStripeCheckoutSessionCreation: class TestStripeCheckoutSessionCreation:
async def test_get_status_reflects_backend_toggle(self, client, headers, monkeypatch): async def test_get_status_reflects_backend_toggle(
self, client, headers, monkeypatch
):
monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGE_BUYING_ENABLED", False) monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGE_BUYING_ENABLED", False)
disabled_response = await client.get("/api/v1/stripe/status", headers=headers) disabled_response = await client.get("/api/v1/stripe/status", headers=headers)
assert disabled_response.status_code == 200, disabled_response.text assert disabled_response.status_code == 200, disabled_response.text
@ -237,7 +239,9 @@ class TestStripeCheckoutSessionCreation:
) )
assert response.status_code == 503, response.text assert response.status_code == 503, response.text
assert response.json()["detail"] == "Page purchases are temporarily unavailable." assert (
response.json()["detail"] == "Page purchases are temporarily unavailable."
)
purchase_count = await _fetchrow("SELECT COUNT(*) AS count FROM page_purchases") purchase_count = await _fetchrow("SELECT COUNT(*) AS count FROM page_purchases")
assert purchase_count is not None assert purchase_count is not None

View file

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import uuid import uuid
from datetime import UTC, datetime from datetime import UTC, datetime, timedelta
import pytest_asyncio import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@ -22,6 +22,7 @@ def _make_document(
content: str, content: str,
search_space_id: int, search_space_id: int,
created_by_id: str, created_by_id: str,
updated_at: datetime | None = None,
) -> Document: ) -> Document:
uid = uuid.uuid4().hex[:12] uid = uuid.uuid4().hex[:12]
return Document( return Document(
@ -34,7 +35,7 @@ def _make_document(
search_space_id=search_space_id, search_space_id=search_space_id,
created_by_id=created_by_id, created_by_id=created_by_id,
embedding=DUMMY_EMBEDDING, embedding=DUMMY_EMBEDDING,
updated_at=datetime.now(UTC), updated_at=updated_at or datetime.now(UTC),
status={"state": "ready"}, status={"state": "ready"},
) )
@ -104,3 +105,54 @@ async def seed_large_doc(
"search_space": db_search_space, "search_space": db_search_space,
"user": db_user, "user": db_user,
} }
@pytest_asyncio.fixture
async def seed_date_filtered_docs(
db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
):
"""Insert matching docs with different timestamps for date-filter tests."""
user_id = str(db_user.id)
space_id = db_search_space.id
now = datetime.now(UTC)
recent_doc = _make_document(
title="Recent OCV Notes",
document_type=DocumentType.FILE,
content="ocv meeting decisions and action items",
search_space_id=space_id,
created_by_id=user_id,
updated_at=now,
)
old_doc = _make_document(
title="Old OCV Notes",
document_type=DocumentType.FILE,
content="ocv meeting decisions and action items",
search_space_id=space_id,
created_by_id=user_id,
updated_at=now - timedelta(days=730),
)
db_session.add_all([recent_doc, old_doc])
await db_session.flush()
db_session.add_all(
[
_make_chunk(
content="ocv meeting decisions and action items recent",
document_id=recent_doc.id,
),
_make_chunk(
content="ocv meeting decisions and action items old",
document_id=old_doc.id,
),
]
)
await db_session.flush()
return {
"recent_doc": recent_doc,
"old_doc": old_doc,
"search_space": db_search_space,
"user": db_user,
}

View file

@ -0,0 +1,62 @@
"""Integration smoke tests for KB search query/date scoping."""
from __future__ import annotations
from contextlib import asynccontextmanager
from datetime import UTC, datetime, timedelta
import numpy as np
import pytest
from app.agents.new_chat.middleware.knowledge_search import search_knowledge_base
from .conftest import DUMMY_EMBEDDING
pytestmark = pytest.mark.integration
async def test_search_knowledge_base_applies_date_filters(
db_session,
seed_date_filtered_docs,
monkeypatch,
):
"""Date filters should remove older matching documents from scoped KB results."""
@asynccontextmanager
async def fake_shielded_async_session():
yield db_session
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.shielded_async_session",
fake_shielded_async_session,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.embed_texts",
lambda texts: [np.array(DUMMY_EMBEDDING) for _ in texts],
)
space_id = seed_date_filtered_docs["search_space"].id
recent_cutoff = datetime.now(UTC) - timedelta(days=30)
unfiltered_results = await search_knowledge_base(
query="ocv meeting decisions",
search_space_id=space_id,
available_document_types=["FILE"],
top_k=10,
)
filtered_results = await search_knowledge_base(
query="ocv meeting decisions",
search_space_id=space_id,
available_document_types=["FILE"],
top_k=10,
start_date=recent_cutoff,
end_date=datetime.now(UTC),
)
unfiltered_ids = {result["document"]["id"] for result in unfiltered_results}
filtered_ids = {result["document"]["id"] for result in filtered_results}
assert seed_date_filtered_docs["recent_doc"].id in unfiltered_ids
assert seed_date_filtered_docs["old_doc"].id in unfiltered_ids
assert seed_date_filtered_docs["recent_doc"].id in filtered_ids
assert seed_date_filtered_docs["old_doc"].id not in filtered_ids

View file

@ -1,12 +1,16 @@
"""Unit tests for knowledge_search middleware helpers. """Unit tests for knowledge_search middleware helpers."""
These test pure functions that don't require a database. import json
"""
import pytest import pytest
from langchain_core.messages import AIMessage, HumanMessage
from app.agents.new_chat.middleware.knowledge_search import ( from app.agents.new_chat.middleware.knowledge_search import (
KnowledgeBaseSearchMiddleware,
_build_document_xml, _build_document_xml,
_normalize_optional_date_range,
_parse_kb_search_plan_response,
_render_recent_conversation,
_resolve_search_types, _resolve_search_types,
) )
@ -131,3 +135,234 @@ class TestBuildDocumentXml:
line for line in lines if "<![CDATA[" in line and "<chunk" in line line for line in lines if "<![CDATA[" in line and "<chunk" in line
] ]
assert len(chunk_lines) == 3 assert len(chunk_lines) == 3
# ── planner parsing / date normalization ───────────────────────────────
class TestPlannerHelpers:
def test_parse_kb_search_plan_response_accepts_plain_json(self):
plan = _parse_kb_search_plan_response(
json.dumps(
{
"optimized_query": "ocv meeting decisions summary",
"start_date": "2026-03-01",
"end_date": "2026-03-31",
}
)
)
assert plan.optimized_query == "ocv meeting decisions summary"
assert plan.start_date == "2026-03-01"
assert plan.end_date == "2026-03-31"
def test_parse_kb_search_plan_response_accepts_fenced_json(self):
plan = _parse_kb_search_plan_response(
"""```json
{"optimized_query":"deel founders guide","start_date":null,"end_date":null}
```"""
)
assert plan.optimized_query == "deel founders guide"
assert plan.start_date is None
assert plan.end_date is None
def test_normalize_optional_date_range_returns_none_when_absent(self):
start_date, end_date = _normalize_optional_date_range(None, None)
assert start_date is None
assert end_date is None
def test_normalize_optional_date_range_resolves_single_bound(self):
start_date, end_date = _normalize_optional_date_range("2026-03-01", None)
assert start_date is not None
assert end_date is not None
assert start_date.date().isoformat() == "2026-03-01"
assert end_date >= start_date
class FakeLLM:
def __init__(self, response_text: str):
self.response_text = response_text
self.calls: list[dict] = []
async def ainvoke(self, messages, config=None):
self.calls.append({"messages": messages, "config": config})
return AIMessage(content=self.response_text)
class FakeBudgetLLM:
def __init__(self, *, max_input_tokens: int):
self._max_input_tokens_value = max_input_tokens
def _get_max_input_tokens(self) -> int:
return self._max_input_tokens_value
def _count_tokens(self, messages) -> int:
# Deterministic, simple proxy for tests: count characters as tokens.
return sum(len(msg.get("content", "")) for msg in messages)
class TestKnowledgeBaseSearchMiddlewarePlanner:
def test_render_recent_conversation_prefers_latest_messages_under_budget(self):
messages = [
HumanMessage(content="old user context " * 40),
AIMessage(content="old assistant answer " * 35),
HumanMessage(content="recent user context " * 20),
AIMessage(content="recent assistant answer " * 18),
HumanMessage(content="latest question"),
]
rendered = _render_recent_conversation(
messages,
llm=FakeBudgetLLM(max_input_tokens=900),
user_text="latest question",
)
assert "recent user context" in rendered
assert "recent assistant answer" in rendered
assert "latest question" not in rendered
assert rendered.index("recent user context") < rendered.index(
"recent assistant answer"
)
def test_render_recent_conversation_falls_back_to_legacy_without_budgeting(self):
messages = [
HumanMessage(content="message one"),
AIMessage(content="message two"),
HumanMessage(content="latest question"),
]
rendered = _render_recent_conversation(
messages,
llm=None,
user_text="latest question",
)
assert "user: message one" in rendered
assert "assistant: message two" in rendered
assert "latest question" not in rendered
async def test_middleware_uses_optimized_query_and_dates(self, monkeypatch):
captured: dict = {}
async def fake_search_knowledge_base(**kwargs):
captured.update(kwargs)
return []
async def fake_build_scoped_filesystem(**kwargs):
return {}
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
fake_search_knowledge_base,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
fake_build_scoped_filesystem,
)
llm = FakeLLM(
json.dumps(
{
"optimized_query": "ocv meeting decisions action items",
"start_date": "2026-03-01",
"end_date": "2026-03-31",
}
)
)
middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=37)
result = await middleware.abefore_agent(
{
"messages": [
HumanMessage(content="what happened in our OCV meeting last month?")
]
},
runtime=None,
)
assert result is not None
assert captured["query"] == "ocv meeting decisions action items"
assert captured["start_date"] is not None
assert captured["end_date"] is not None
assert captured["start_date"].date().isoformat() == "2026-03-01"
assert captured["end_date"].date().isoformat() == "2026-03-31"
assert llm.calls[0]["config"] == {"tags": ["surfsense:internal"]}
async def test_middleware_falls_back_when_planner_returns_invalid_json(
self,
monkeypatch,
):
captured: dict = {}
async def fake_search_knowledge_base(**kwargs):
captured.update(kwargs)
return []
async def fake_build_scoped_filesystem(**kwargs):
return {}
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
fake_search_knowledge_base,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
fake_build_scoped_filesystem,
)
middleware = KnowledgeBaseSearchMiddleware(
llm=FakeLLM("not json"),
search_space_id=37,
)
await middleware.abefore_agent(
{"messages": [HumanMessage(content="summarize founders guide by deel")]},
runtime=None,
)
assert captured["query"] == "summarize founders guide by deel"
assert captured["start_date"] is None
assert captured["end_date"] is None
async def test_middleware_passes_none_dates_when_planner_returns_nulls(
self,
monkeypatch,
):
captured: dict = {}
async def fake_search_knowledge_base(**kwargs):
captured.update(kwargs)
return []
async def fake_build_scoped_filesystem(**kwargs):
return {}
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
fake_search_knowledge_base,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
fake_build_scoped_filesystem,
)
middleware = KnowledgeBaseSearchMiddleware(
llm=FakeLLM(
json.dumps(
{
"optimized_query": "deel founders guide summary",
"start_date": None,
"end_date": None,
}
)
),
search_space_id=37,
)
await middleware.abefore_agent(
{"messages": [HumanMessage(content="summarize founders guide by deel")]},
runtime=None,
)
assert captured["query"] == "deel founders guide summary"
assert captured["start_date"] is None
assert captured["end_date"] is None