mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
feat: enhance knowledge base search with date filtering
This commit is contained in:
parent
006dccbe4b
commit
ad0e77c3d6
7 changed files with 660 additions and 12 deletions
|
|
@ -447,6 +447,7 @@ async def create_surfsense_deep_agent(
|
||||||
deepagent_middleware = [
|
deepagent_middleware = [
|
||||||
TodoListMiddleware(),
|
TodoListMiddleware(),
|
||||||
KnowledgeBaseSearchMiddleware(
|
KnowledgeBaseSearchMiddleware(
|
||||||
|
llm=llm,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
available_connectors=available_connectors,
|
available_connectors=available_connectors,
|
||||||
available_document_types=available_document_types,
|
available_document_types=available_document_types,
|
||||||
|
|
|
||||||
|
|
@ -15,14 +15,19 @@ import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
from litellm import token_counter
|
||||||
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range
|
||||||
from app.db import NATIVE_TO_LEGACY_DOCTYPE, Document, Folder, shielded_async_session
|
from app.db import NATIVE_TO_LEGACY_DOCTYPE, Document, Folder, shielded_async_session
|
||||||
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||||
from app.utils.document_converters import embed_texts
|
from app.utils.document_converters import embed_texts
|
||||||
|
|
@ -32,6 +37,23 @@ logger = logging.getLogger(__name__)
|
||||||
_perf_log = get_perf_logger()
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class KBSearchPlan(BaseModel):
    """Structured internal plan for KB retrieval."""

    # Rewritten search query; validation rejects an empty string (min_length=1).
    optimized_query: str = Field(
        min_length=1,
        description="Optimized retrieval query preserving the user's intent.",
    )
    # Lower bound of the optional date filter, kept as the raw string the
    # planner produced. None means "no lower bound requested".
    start_date: str | None = Field(
        default=None,
        description="Optional ISO start date or datetime for KB search filtering.",
    )
    # Upper bound of the optional date filter, kept as the raw string the
    # planner produced. None means "no upper bound requested".
    end_date: str | None = Field(
        default=None,
        description="Optional ISO end date or datetime for KB search filtering.",
    )
|
||||||
|
|
||||||
|
|
||||||
def _extract_text_from_message(message: BaseMessage) -> str:
|
def _extract_text_from_message(message: BaseMessage) -> str:
|
||||||
"""Extract plain text from a message content."""
|
"""Extract plain text from a message content."""
|
||||||
content = getattr(message, "content", "")
|
content = getattr(message, "content", "")
|
||||||
|
|
@ -61,6 +83,212 @@ def _safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def _render_recent_conversation(
    messages: Sequence[BaseMessage],
    *,
    llm: BaseChatModel | None = None,
    user_text: str = "",
    max_messages: int = 6,
) -> str:
    """Render recent dialogue for internal planning under a token budget.

    Prefers the latest messages and uses the project's existing model-aware
    token budgeting hooks when available on the LLM (`_count_tokens`,
    `_get_max_input_tokens`). Falls back to the prior fixed-message heuristic
    if token counting is unavailable.

    Args:
        messages: Conversation history. Only human messages and AI messages
            without tool calls are rendered; everything else is skipped.
        llm: Optional model used for token counting and for discovering the
            input-token limit. With no usable limit the legacy rendering is
            returned.
        user_text: Latest user message. A trailing user entry equal to its
            stripped form is dropped from the rendered history.
        max_messages: Number of trailing entries kept by the legacy rendering.

    Returns:
        Newline-joined ``"<role>: <text>"`` lines (oldest first), or ``""``
        when there is nothing to render.
    """
    rendered: list[tuple[str, str]] = []
    for message in messages:
        role: str | None = None
        if isinstance(message, HumanMessage):
            role = "user"
        elif isinstance(message, AIMessage):
            # AI messages that carry tool calls are left out of the dialogue.
            if getattr(message, "tool_calls", None):
                continue
            role = "assistant"
        else:
            continue

        text = _extract_text_from_message(message).strip()
        if not text:
            continue
        # Collapse every whitespace run so each entry stays on a single line.
        text = re.sub(r"\s+", " ", text)
        rendered.append((role, text))

    if not rendered:
        return ""

    # Exclude the latest user message from "recent conversation" because it is
    # already passed separately as "Latest user message" in the planner prompt.
    if rendered and rendered[-1][0] == "user" and rendered[-1][1] == user_text.strip():
        rendered = rendered[:-1]

    if not rendered:
        return ""

    def _legacy_render() -> str:
        # Fixed-size heuristic: last `max_messages` entries, each clipped to
        # 400 characters with a trailing "..." marker.
        legacy_lines: list[str] = []
        for role, text in rendered[-max_messages:]:
            clipped = text[:400].rstrip() + "..." if len(text) > 400 else text
            legacy_lines.append(f"{role}: {clipped}")
        return "\n".join(legacy_lines)

    def _count_prompt_tokens(conversation_text: str) -> int | None:
        # Tokens are counted for the full planner prompt (instructions +
        # conversation + latest user message), not for the conversation alone.
        prompt = _build_kb_planner_prompt(
            recent_conversation=conversation_text or "(none)",
            user_text=user_text,
        )
        message_payload = [{"role": "user", "content": prompt}]

        # First choice: the LLM's own `_count_tokens` hook, if it has one.
        count_fn = getattr(llm, "_count_tokens", None) if llm is not None else None
        if callable(count_fn):
            try:
                return count_fn(message_payload)
            except Exception:
                # Best effort: fall through to the litellm-based counter.
                pass

        # Second choice: litellm's token_counter, using a model name taken
        # from the LLM profile or, failing that, from the `model` attribute.
        profile = getattr(llm, "profile", None) if llm is not None else None
        model_names: list[str] = []
        if isinstance(profile, dict):
            tcms = profile.get("token_count_models")
            if isinstance(tcms, list):
                model_names.extend(
                    name for name in tcms if isinstance(name, str) and name
                )
            tcm = profile.get("token_count_model")
            if isinstance(tcm, str) and tcm and tcm not in model_names:
                model_names.append(tcm)
        model_name = model_names[0] if model_names else getattr(llm, "model", None)
        if not isinstance(model_name, str) or not model_name:
            return None
        try:
            return token_counter(messages=message_payload, model=model_name)
        except Exception:
            # None tells the caller that counting is unavailable.
            return None

    # Discover the input-token limit: the LLM hook if present, otherwise the
    # integer "max_input_tokens" entry of the profile dict.
    get_max_input_tokens = getattr(llm, "_get_max_input_tokens", None) if llm else None
    if callable(get_max_input_tokens):
        try:
            max_input_tokens = int(get_max_input_tokens())
        except Exception:
            max_input_tokens = None
    else:
        profile = getattr(llm, "profile", None) if llm is not None else None
        max_input_tokens = (
            profile.get("max_input_tokens")
            if isinstance(profile, dict)
            and isinstance(profile.get("max_input_tokens"), int)
            else None
        )

    if not isinstance(max_input_tokens, int) or max_input_tokens <= 0:
        return _legacy_render()

    # Hold back 2% of the limit, clamped to the range [256, 1024] tokens.
    output_reserve = min(max(int(max_input_tokens * 0.02), 256), 1024)
    budget = max_input_tokens - output_reserve
    if budget <= 0:
        return _legacy_render()

    # Walk from newest to oldest, prepending entries while the whole prompt
    # still fits the budget.
    selected_lines: list[str] = []
    for role, text in reversed(rendered):
        candidate_line = f"{role}: {text}"
        candidate_lines = [candidate_line, *selected_lines]
        candidate_conversation = "\n".join(candidate_lines)
        token_count = _count_prompt_tokens(candidate_conversation)
        if token_count is None:
            return _legacy_render()
        if token_count <= budget:
            selected_lines = candidate_lines
            continue

        # If the full message does not fit, keep as much of this most-recent
        # older message as possible via binary search.
        lo, hi = 1, len(text)
        best_line: str | None = None
        while lo <= hi:
            mid = (lo + hi) // 2
            clipped_text = text[:mid].rstrip() + "..."
            clipped_line = f"{role}: {clipped_text}"
            clipped_conversation = "\n".join([clipped_line, *selected_lines])
            clipped_tokens = _count_prompt_tokens(clipped_conversation)
            if clipped_tokens is None:
                break
            if clipped_tokens <= budget:
                best_line = clipped_line
                lo = mid + 1
            else:
                hi = mid - 1

        if best_line is not None:
            selected_lines = [best_line, *selected_lines]
        # The budget is exhausted at this point, so older messages are not tried.
        break

    # Nothing fit under the budget at all: use the legacy rendering instead.
    if not selected_lines:
        return _legacy_render()

    return "\n".join(selected_lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_kb_planner_prompt(
|
||||||
|
*,
|
||||||
|
recent_conversation: str,
|
||||||
|
user_text: str,
|
||||||
|
) -> str:
|
||||||
|
"""Build a compact internal prompt for KB query rewriting and date scoping."""
|
||||||
|
today = datetime.now(UTC).date().isoformat()
|
||||||
|
return (
|
||||||
|
"You optimize internal knowledge-base search inputs for document retrieval.\n"
|
||||||
|
"Return JSON only with this exact shape:\n"
|
||||||
|
'{"optimized_query":"string","start_date":"ISO string or null","end_date":"ISO string or null"}\n\n'
|
||||||
|
"Rules:\n"
|
||||||
|
"- Preserve the user's intent.\n"
|
||||||
|
"- Rewrite the query to improve retrieval using concrete entities, acronyms, projects, tools, people, and document-specific terms when helpful.\n"
|
||||||
|
"- Keep the query concise and retrieval-focused.\n"
|
||||||
|
"- Only use date filters when the latest user request or recent dialogue clearly implies a time range.\n"
|
||||||
|
"- If you use date filters, prefer returning both bounds.\n"
|
||||||
|
"- If no date filter is useful, return null for both dates.\n"
|
||||||
|
"- Do not include markdown, prose, or explanations.\n\n"
|
||||||
|
f"Today's UTC date: {today}\n\n"
|
||||||
|
f"Recent conversation:\n{recent_conversation or '(none)'}\n\n"
|
||||||
|
f"Latest user message:\n{user_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_json_payload(text: str) -> str:
|
||||||
|
"""Extract a JSON object from a raw LLM response."""
|
||||||
|
stripped = text.strip()
|
||||||
|
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
|
||||||
|
if fenced:
|
||||||
|
return fenced.group(1)
|
||||||
|
|
||||||
|
start = stripped.find("{")
|
||||||
|
end = stripped.rfind("}")
|
||||||
|
if start != -1 and end != -1 and end > start:
|
||||||
|
return stripped[start : end + 1]
|
||||||
|
return stripped
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_kb_search_plan_response(response_text: str) -> KBSearchPlan:
    """Turn the planner's raw reply into a validated ``KBSearchPlan``.

    The JSON object is located with ``_extract_json_payload``, decoded with
    ``json.loads`` and handed to ``KBSearchPlan.model_validate``. Errors
    raised by decoding or validation are not handled here and propagate to
    the caller.
    """
    json_text = _extract_json_payload(response_text)
    decoded = json.loads(json_text)
    return KBSearchPlan.model_validate(decoded)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_optional_date_range(
|
||||||
|
start_date: str | None,
|
||||||
|
end_date: str | None,
|
||||||
|
) -> tuple[datetime | None, datetime | None]:
|
||||||
|
"""Normalize optional planner dates into a UTC datetime range."""
|
||||||
|
parsed_start = parse_date_or_datetime(start_date) if start_date else None
|
||||||
|
parsed_end = parse_date_or_datetime(end_date) if end_date else None
|
||||||
|
|
||||||
|
if parsed_start is None and parsed_end is None:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
resolved_start, resolved_end = resolve_date_range(parsed_start, parsed_end)
|
||||||
|
return resolved_start, resolved_end
|
||||||
|
|
||||||
|
|
||||||
def _build_document_xml(
|
def _build_document_xml(
|
||||||
document: dict[str, Any],
|
document: dict[str, Any],
|
||||||
matched_chunk_ids: set[int] | None = None,
|
matched_chunk_ids: set[int] | None = None,
|
||||||
|
|
@ -264,6 +492,8 @@ async def search_knowledge_base(
|
||||||
available_connectors: list[str] | None = None,
|
available_connectors: list[str] | None = None,
|
||||||
available_document_types: list[str] | None = None,
|
available_document_types: list[str] | None = None,
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
|
start_date: datetime | None = None,
|
||||||
|
end_date: datetime | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""Run a single unified hybrid search against the knowledge base.
|
"""Run a single unified hybrid search against the knowledge base.
|
||||||
|
|
||||||
|
|
@ -286,6 +516,8 @@ async def search_knowledge_base(
|
||||||
top_k=retriever_top_k,
|
top_k=retriever_top_k,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
document_type=doc_types,
|
document_type=doc_types,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
query_embedding=embedding.tolist(),
|
query_embedding=embedding.tolist(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -346,16 +578,71 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
llm: BaseChatModel | None = None,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
available_connectors: list[str] | None = None,
|
available_connectors: list[str] | None = None,
|
||||||
available_document_types: list[str] | None = None,
|
available_document_types: list[str] | None = None,
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
self.llm = llm
|
||||||
self.search_space_id = search_space_id
|
self.search_space_id = search_space_id
|
||||||
self.available_connectors = available_connectors
|
self.available_connectors = available_connectors
|
||||||
self.available_document_types = available_document_types
|
self.available_document_types = available_document_types
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
||||||
|
async def _plan_search_inputs(
    self,
    *,
    messages: Sequence[BaseMessage],
    user_text: str,
) -> tuple[str, datetime | None, datetime | None]:
    """Ask the LLM for an optimized KB query and optional date bounds.

    Args:
        messages: Conversation history used to render planner context.
        user_text: Latest user message; also the fallback query.

    Returns:
        ``(query, start_date, end_date)``. When no LLM is configured, or the
        planner call / parsing fails, the raw ``user_text`` is returned with
        both dates set to None.
    """
    planner_llm = self.llm
    if planner_llm is None:
        return user_text, None, None

    planner_prompt = _build_kb_planner_prompt(
        recent_conversation=_render_recent_conversation(
            messages,
            llm=planner_llm,
            user_text=user_text,
        ),
        user_text=user_text,
    )
    event_loop = asyncio.get_running_loop()
    started_at = event_loop.time()

    try:
        reply = await planner_llm.ainvoke(
            [HumanMessage(content=planner_prompt)],
            config={"tags": ["surfsense:internal"]},
        )
        reply_text = _extract_text_from_message(reply)
        plan = _parse_kb_search_plan_response(reply_text)

        # Collapse whitespace; an all-whitespace rewrite falls back to the
        # user's own wording.
        collapsed_query = re.sub(r"\s+", " ", plan.optimized_query).strip()
        optimized_query = collapsed_query if collapsed_query else user_text

        start_date, end_date = _normalize_optional_date_range(
            plan.start_date,
            plan.end_date,
        )

        elapsed = event_loop.time() - started_at
        start_label = start_date.isoformat() if start_date else None
        end_label = end_date.isoformat() if end_date else None
        _perf_log.info(
            "[kb_fs_middleware] planner in %.3fs query=%r optimized=%r start=%s end=%s",
            elapsed,
            user_text[:80],
            optimized_query[:120],
            start_label,
            end_label,
        )
        return optimized_query, start_date, end_date
    except (json.JSONDecodeError, ValidationError, ValueError) as exc:
        # Malformed planner output: warn and fall through to the raw query.
        logger.warning(
            "KB planner returned invalid output, using raw query: %s", exc
        )
    except Exception as exc:  # pragma: no cover - defensive fallback
        logger.warning("KB planner failed, using raw query: %s", exc)

    return user_text, None, None
|
||||||
|
|
||||||
def before_agent( # type: ignore[override]
|
def before_agent( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
state: AgentState,
|
state: AgentState,
|
||||||
|
|
@ -388,13 +675,19 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
|
||||||
t0 = _perf_log and asyncio.get_event_loop().time()
|
t0 = _perf_log and asyncio.get_event_loop().time()
|
||||||
existing_files = state.get("files")
|
existing_files = state.get("files")
|
||||||
|
planned_query, start_date, end_date = await self._plan_search_inputs(
|
||||||
|
messages=messages,
|
||||||
|
user_text=user_text,
|
||||||
|
)
|
||||||
|
|
||||||
search_results = await search_knowledge_base(
|
search_results = await search_knowledge_base(
|
||||||
query=user_text,
|
query=planned_query,
|
||||||
search_space_id=self.search_space_id,
|
search_space_id=self.search_space_id,
|
||||||
available_connectors=self.available_connectors,
|
available_connectors=self.available_connectors,
|
||||||
available_document_types=self.available_document_types,
|
available_document_types=self.available_document_types,
|
||||||
top_k=self.top_k,
|
top_k=self.top_k,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
)
|
)
|
||||||
new_files = await build_scoped_filesystem(
|
new_files = await build_scoped_filesystem(
|
||||||
documents=search_results,
|
documents=search_results,
|
||||||
|
|
@ -405,9 +698,10 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
|
||||||
if t0 is not None:
|
if t0 is not None:
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[kb_fs_middleware] completed in %.3fs query=%r new_files=%d total=%d",
|
"[kb_fs_middleware] completed in %.3fs query=%r optimized=%r new_files=%d total=%d",
|
||||||
asyncio.get_event_loop().time() - t0,
|
asyncio.get_event_loop().time() - t0,
|
||||||
user_text[:80],
|
user_text[:80],
|
||||||
|
planned_query[:120],
|
||||||
len(new_files),
|
len(new_files),
|
||||||
len(new_files) + len(existing_files or {}),
|
len(new_files) + len(existing_files or {}),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -310,7 +310,7 @@ class GoogleGmailConnector:
|
||||||
Fetch recent messages from Gmail within specified date range.
|
Fetch recent messages from Gmail within specified date range.
|
||||||
Args:
|
Args:
|
||||||
max_results: Maximum number of messages to fetch (default: 50)
|
max_results: Maximum number of messages to fetch (default: 50)
|
||||||
start_date: Start date in YYYY-MM-DD format (default: 30 days ago)
|
start_date: Start date in YYYY-MM-DD format (default: 3 days ago)
|
||||||
end_date: End date in YYYY-MM-DD format (default: today)
|
end_date: End date in YYYY-MM-DD format (default: today)
|
||||||
Returns:
|
Returns:
|
||||||
Tuple containing (messages list with details, error message or None)
|
Tuple containing (messages list with details, error message or None)
|
||||||
|
|
@ -334,8 +334,8 @@ class GoogleGmailConnector:
|
||||||
start_query = start_dt.strftime("%Y/%m/%d")
|
start_query = start_dt.strftime("%Y/%m/%d")
|
||||||
query_parts.append(f"after:{start_query}")
|
query_parts.append(f"after:{start_query}")
|
||||||
else:
|
else:
|
||||||
# Default to 30 days ago
|
# Default to 3 days ago
|
||||||
cutoff_date = datetime.now() - timedelta(days=30)
|
cutoff_date = datetime.now() - timedelta(days=3)
|
||||||
date_query = cutoff_date.strftime("%Y/%m/%d")
|
date_query = cutoff_date.strftime("%Y/%m/%d")
|
||||||
query_parts.append(f"after:{date_query}")
|
query_parts.append(f"after:{date_query}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -152,7 +152,9 @@ class _FakeReconciliationStripeClient:
|
||||||
|
|
||||||
|
|
||||||
class TestStripeCheckoutSessionCreation:
|
class TestStripeCheckoutSessionCreation:
|
||||||
async def test_get_status_reflects_backend_toggle(self, client, headers, monkeypatch):
|
async def test_get_status_reflects_backend_toggle(
|
||||||
|
self, client, headers, monkeypatch
|
||||||
|
):
|
||||||
monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGE_BUYING_ENABLED", False)
|
monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGE_BUYING_ENABLED", False)
|
||||||
disabled_response = await client.get("/api/v1/stripe/status", headers=headers)
|
disabled_response = await client.get("/api/v1/stripe/status", headers=headers)
|
||||||
assert disabled_response.status_code == 200, disabled_response.text
|
assert disabled_response.status_code == 200, disabled_response.text
|
||||||
|
|
@ -237,7 +239,9 @@ class TestStripeCheckoutSessionCreation:
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 503, response.text
|
assert response.status_code == 503, response.text
|
||||||
assert response.json()["detail"] == "Page purchases are temporarily unavailable."
|
assert (
|
||||||
|
response.json()["detail"] == "Page purchases are temporarily unavailable."
|
||||||
|
)
|
||||||
|
|
||||||
purchase_count = await _fetchrow("SELECT COUNT(*) AS count FROM page_purchases")
|
purchase_count = await _fetchrow("SELECT COUNT(*) AS count FROM page_purchases")
|
||||||
assert purchase_count is not None
|
assert purchase_count is not None
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
@ -22,6 +22,7 @@ def _make_document(
|
||||||
content: str,
|
content: str,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
created_by_id: str,
|
created_by_id: str,
|
||||||
|
updated_at: datetime | None = None,
|
||||||
) -> Document:
|
) -> Document:
|
||||||
uid = uuid.uuid4().hex[:12]
|
uid = uuid.uuid4().hex[:12]
|
||||||
return Document(
|
return Document(
|
||||||
|
|
@ -34,7 +35,7 @@ def _make_document(
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
created_by_id=created_by_id,
|
created_by_id=created_by_id,
|
||||||
embedding=DUMMY_EMBEDDING,
|
embedding=DUMMY_EMBEDDING,
|
||||||
updated_at=datetime.now(UTC),
|
updated_at=updated_at or datetime.now(UTC),
|
||||||
status={"state": "ready"},
|
status={"state": "ready"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -104,3 +105,54 @@ async def seed_large_doc(
|
||||||
"search_space": db_search_space,
|
"search_space": db_search_space,
|
||||||
"user": db_user,
|
"user": db_user,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def seed_date_filtered_docs(
    db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
):
    """Seed two documents with identical content but different ``updated_at``.

    One document is stamped "now" and the other 730 days earlier, each with a
    single chunk, so date-filter tests can tell them apart. Returns a dict
    with both documents plus the search space and user.
    """
    reference_time = datetime.now(UTC)
    shared_fields = {
        "document_type": DocumentType.FILE,
        "content": "ocv meeting decisions and action items",
        "search_space_id": db_search_space.id,
        "created_by_id": str(db_user.id),
    }

    recent_doc = _make_document(
        title="Recent OCV Notes",
        updated_at=reference_time,
        **shared_fields,
    )
    old_doc = _make_document(
        title="Old OCV Notes",
        updated_at=reference_time - timedelta(days=730),
        **shared_fields,
    )

    # Flush so both documents receive ids before chunks reference them.
    db_session.add_all([recent_doc, old_doc])
    await db_session.flush()

    recent_chunk = _make_chunk(
        content="ocv meeting decisions and action items recent",
        document_id=recent_doc.id,
    )
    old_chunk = _make_chunk(
        content="ocv meeting decisions and action items old",
        document_id=old_doc.id,
    )
    db_session.add_all([recent_chunk, old_chunk])
    await db_session.flush()

    return {
        "recent_doc": recent_doc,
        "old_doc": old_doc,
        "search_space": db_search_space,
        "user": db_user,
    }
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
"""Integration smoke tests for KB search query/date scoping."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.knowledge_search import search_knowledge_base
|
||||||
|
|
||||||
|
from .conftest import DUMMY_EMBEDDING
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
|
async def test_search_knowledge_base_applies_date_filters(
    db_session,
    seed_date_filtered_docs,
    monkeypatch,
):
    """A date window must drop the old matching document but keep the recent one."""

    @asynccontextmanager
    async def fake_shielded_async_session():
        # Reuse the test session instead of opening a real shielded session.
        yield db_session

    def fake_embed_texts(texts):
        return [np.array(DUMMY_EMBEDDING) for _ in texts]

    monkeypatch.setattr(
        "app.agents.new_chat.middleware.knowledge_search.shielded_async_session",
        fake_shielded_async_session,
    )
    monkeypatch.setattr(
        "app.agents.new_chat.middleware.knowledge_search.embed_texts",
        fake_embed_texts,
    )

    recent_id = seed_date_filtered_docs["recent_doc"].id
    old_id = seed_date_filtered_docs["old_doc"].id
    base_kwargs = {
        "query": "ocv meeting decisions",
        "search_space_id": seed_date_filtered_docs["search_space"].id,
        "available_document_types": ["FILE"],
        "top_k": 10,
    }
    recent_cutoff = datetime.now(UTC) - timedelta(days=30)

    unfiltered_results = await search_knowledge_base(**base_kwargs)
    filtered_results = await search_knowledge_base(
        **base_kwargs,
        start_date=recent_cutoff,
        end_date=datetime.now(UTC),
    )

    unfiltered_ids = {hit["document"]["id"] for hit in unfiltered_results}
    filtered_ids = {hit["document"]["id"] for hit in filtered_results}

    # Without a window both documents match.
    assert recent_id in unfiltered_ids
    assert old_id in unfiltered_ids
    # With a 30-day window only the recent document survives.
    assert recent_id in filtered_ids
    assert old_id not in filtered_ids
|
||||||
|
|
@ -1,12 +1,16 @@
|
||||||
"""Unit tests for knowledge_search middleware helpers.
|
"""Unit tests for knowledge_search middleware helpers."""
|
||||||
|
|
||||||
These test pure functions that don't require a database.
|
import json
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
from app.agents.new_chat.middleware.knowledge_search import (
|
from app.agents.new_chat.middleware.knowledge_search import (
|
||||||
|
KnowledgeBaseSearchMiddleware,
|
||||||
_build_document_xml,
|
_build_document_xml,
|
||||||
|
_normalize_optional_date_range,
|
||||||
|
_parse_kb_search_plan_response,
|
||||||
|
_render_recent_conversation,
|
||||||
_resolve_search_types,
|
_resolve_search_types,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -131,3 +135,234 @@ class TestBuildDocumentXml:
|
||||||
line for line in lines if "<![CDATA[" in line and "<chunk" in line
|
line for line in lines if "<![CDATA[" in line and "<chunk" in line
|
||||||
]
|
]
|
||||||
assert len(chunk_lines) == 3
|
assert len(chunk_lines) == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ── planner parsing / date normalization ───────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestPlannerHelpers:
    """Planner response parsing and date-range normalization helpers."""

    def test_parse_kb_search_plan_response_accepts_plain_json(self):
        """Bare JSON text is parsed into a plan with all three fields."""
        raw_plan = {
            "optimized_query": "ocv meeting decisions summary",
            "start_date": "2026-03-01",
            "end_date": "2026-03-31",
        }

        plan = _parse_kb_search_plan_response(json.dumps(raw_plan))

        assert plan.optimized_query == "ocv meeting decisions summary"
        assert plan.start_date == "2026-03-01"
        assert plan.end_date == "2026-03-31"

    def test_parse_kb_search_plan_response_accepts_fenced_json(self):
        """JSON wrapped in a Markdown code fence is still accepted."""
        fenced_reply = (
            "```json\n"
            '{"optimized_query":"deel founders guide","start_date":null,"end_date":null}\n'
            "```"
        )

        plan = _parse_kb_search_plan_response(fenced_reply)

        assert plan.optimized_query == "deel founders guide"
        assert plan.start_date is None
        assert plan.end_date is None

    def test_normalize_optional_date_range_returns_none_when_absent(self):
        """No bounds in means no bounds out."""
        bounds = _normalize_optional_date_range(None, None)
        start_date, end_date = bounds
        assert start_date is None
        assert end_date is None

    def test_normalize_optional_date_range_resolves_single_bound(self):
        """Supplying only a start bound yields a complete, ordered range."""
        start_date, end_date = _normalize_optional_date_range("2026-03-01", None)
        assert start_date is not None
        assert end_date is not None
        assert start_date.date().isoformat() == "2026-03-01"
        assert end_date >= start_date
|
||||||
|
|
||||||
|
|
||||||
|
class FakeLLM:
    """Test double that records each ``ainvoke`` call and replies with canned text."""

    def __init__(self, response_text: str):
        self.calls: list[dict] = []
        self.response_text = response_text

    async def ainvoke(self, messages, config=None):
        """Record the call arguments and return the canned reply as an AIMessage."""
        call_record = {"messages": messages, "config": config}
        self.calls.append(call_record)
        return AIMessage(content=self.response_text)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeBudgetLLM:
    """Test double exposing the token-budget hooks with deterministic values."""

    def __init__(self, *, max_input_tokens: int):
        self._max_input_tokens_value = max_input_tokens

    def _get_max_input_tokens(self) -> int:
        """Return the limit supplied at construction time."""
        return self._max_input_tokens_value

    def _count_tokens(self, messages) -> int:
        """Count one token per character of each message's ``content``."""
        # Deterministic, simple proxy for tests: characters stand in for tokens.
        total = 0
        for msg in messages:
            total += len(msg.get("content", ""))
        return total
|
||||||
|
|
||||||
|
|
||||||
|
class TestKnowledgeBaseSearchMiddlewarePlanner:
|
||||||
|
def test_render_recent_conversation_prefers_latest_messages_under_budget(self):
    """Recent turns are rendered in order and the latest question is left out."""
    history = [
        HumanMessage(content="old user context " * 40),
        AIMessage(content="old assistant answer " * 35),
        HumanMessage(content="recent user context " * 20),
        AIMessage(content="recent assistant answer " * 18),
        HumanMessage(content="latest question"),
    ]
    budget_llm = FakeBudgetLLM(max_input_tokens=900)

    rendered = _render_recent_conversation(
        history,
        llm=budget_llm,
        user_text="latest question",
    )

    assert "recent user context" in rendered
    assert "recent assistant answer" in rendered
    assert "latest question" not in rendered

    # Chronological order must be preserved in the rendered text.
    user_pos = rendered.index("recent user context")
    assistant_pos = rendered.index("recent assistant answer")
    assert user_pos < assistant_pos
|
||||||
|
|
||||||
|
def test_render_recent_conversation_falls_back_to_legacy_without_budgeting(self):
|
||||||
|
messages = [
|
||||||
|
HumanMessage(content="message one"),
|
||||||
|
AIMessage(content="message two"),
|
||||||
|
HumanMessage(content="latest question"),
|
||||||
|
]
|
||||||
|
|
||||||
|
rendered = _render_recent_conversation(
|
||||||
|
messages,
|
||||||
|
llm=None,
|
||||||
|
user_text="latest question",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "user: message one" in rendered
|
||||||
|
assert "assistant: message two" in rendered
|
||||||
|
assert "latest question" not in rendered
|
||||||
|
|
||||||
|
async def test_middleware_uses_optimized_query_and_dates(self, monkeypatch):
|
||||||
|
captured: dict = {}
|
||||||
|
|
||||||
|
async def fake_search_knowledge_base(**kwargs):
|
||||||
|
captured.update(kwargs)
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def fake_build_scoped_filesystem(**kwargs):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||||
|
fake_search_knowledge_base,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
|
||||||
|
fake_build_scoped_filesystem,
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = FakeLLM(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"optimized_query": "ocv meeting decisions action items",
|
||||||
|
"start_date": "2026-03-01",
|
||||||
|
"end_date": "2026-03-31",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=37)
|
||||||
|
|
||||||
|
result = await middleware.abefore_agent(
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
HumanMessage(content="what happened in our OCV meeting last month?")
|
||||||
|
]
|
||||||
|
},
|
||||||
|
runtime=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert captured["query"] == "ocv meeting decisions action items"
|
||||||
|
assert captured["start_date"] is not None
|
||||||
|
assert captured["end_date"] is not None
|
||||||
|
assert captured["start_date"].date().isoformat() == "2026-03-01"
|
||||||
|
assert captured["end_date"].date().isoformat() == "2026-03-31"
|
||||||
|
assert llm.calls[0]["config"] == {"tags": ["surfsense:internal"]}
|
||||||
|
|
||||||
|
async def test_middleware_falls_back_when_planner_returns_invalid_json(
|
||||||
|
self,
|
||||||
|
monkeypatch,
|
||||||
|
):
|
||||||
|
captured: dict = {}
|
||||||
|
|
||||||
|
async def fake_search_knowledge_base(**kwargs):
|
||||||
|
captured.update(kwargs)
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def fake_build_scoped_filesystem(**kwargs):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||||
|
fake_search_knowledge_base,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
|
||||||
|
fake_build_scoped_filesystem,
|
||||||
|
)
|
||||||
|
|
||||||
|
middleware = KnowledgeBaseSearchMiddleware(
|
||||||
|
llm=FakeLLM("not json"),
|
||||||
|
search_space_id=37,
|
||||||
|
)
|
||||||
|
|
||||||
|
await middleware.abefore_agent(
|
||||||
|
{"messages": [HumanMessage(content="summarize founders guide by deel")]},
|
||||||
|
runtime=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert captured["query"] == "summarize founders guide by deel"
|
||||||
|
assert captured["start_date"] is None
|
||||||
|
assert captured["end_date"] is None
|
||||||
|
|
||||||
|
async def test_middleware_passes_none_dates_when_planner_returns_nulls(
|
||||||
|
self,
|
||||||
|
monkeypatch,
|
||||||
|
):
|
||||||
|
captured: dict = {}
|
||||||
|
|
||||||
|
async def fake_search_knowledge_base(**kwargs):
|
||||||
|
captured.update(kwargs)
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def fake_build_scoped_filesystem(**kwargs):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||||
|
fake_search_knowledge_base,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
|
||||||
|
fake_build_scoped_filesystem,
|
||||||
|
)
|
||||||
|
|
||||||
|
middleware = KnowledgeBaseSearchMiddleware(
|
||||||
|
llm=FakeLLM(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"optimized_query": "deel founders guide summary",
|
||||||
|
"start_date": None,
|
||||||
|
"end_date": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
),
|
||||||
|
search_space_id=37,
|
||||||
|
)
|
||||||
|
|
||||||
|
await middleware.abefore_agent(
|
||||||
|
{"messages": [HumanMessage(content="summarize founders guide by deel")]},
|
||||||
|
runtime=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert captured["query"] == "deel founders guide summary"
|
||||||
|
assert captured["start_date"] is None
|
||||||
|
assert captured["end_date"] is None
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue