Mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-04-27 09:46:25 +02:00)
feat: added ai file sorting

parent fa0b47dfca · commit 4bee367d4a
51 changed files with 1703 additions and 72 deletions
@@ -93,7 +93,8 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware):  # type: ignore[type-arg]
     @staticmethod
     def _dedup(
-        state: AgentState, dedup_keys: dict[str, str]  # type: ignore[type-arg]
+        state: AgentState,
+        dedup_keys: dict[str, str],  # type: ignore[type-arg]
     ) -> dict[str, Any] | None:
         messages = state.get("messages")
         if not messages:
@@ -593,7 +593,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
         runtime: ToolRuntime[None, FilesystemState],
         timeout: int | None,
     ) -> str:
-        sandbox, is_new = await get_or_create_sandbox(self._thread_id)
+        sandbox, _is_new = await get_or_create_sandbox(self._thread_id)
         # NOTE: sync_files_to_sandbox is intentionally disabled.
         # The virtual FS contains XML-wrapped KB documents whose paths
         # would double-nest under SANDBOX_DOCUMENTS_ROOT (e.g.
@@ -58,6 +58,14 @@ class KBSearchPlan(BaseModel):
         default=None,
         description="Optional ISO end date or datetime for KB search filtering.",
     )
+    is_recency_query: bool = Field(
+        default=False,
+        description=(
+            "True when the user's intent is primarily about recency or temporal "
+            "ordering (e.g. 'latest', 'newest', 'most recent', 'last uploaded') "
+            "rather than topical relevance."
+        ),
+    )


 def _extract_text_from_message(message: BaseMessage) -> str:
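Note on the new field: planner output is parsed straight into this model, and is_recency_query falls back to False whenever the LLM omits it, so existing three-field responses keep working. A minimal standalone sketch of that behavior, assuming the other KBSearchPlan fields visible in this file (the model shape here is trimmed for illustration):

import json

from pydantic import BaseModel, Field


class KBSearchPlan(BaseModel):
    # Trimmed illustration of the real model; only fields named in the
    # diff are included.
    optimized_query: str
    start_date: str | None = None
    end_date: str | None = None
    is_recency_query: bool = Field(default=False)


raw = (
    '{"optimized_query": "recently uploaded files", "start_date": null,'
    ' "end_date": null, "is_recency_query": true}'
)
plan = KBSearchPlan.model_validate(json.loads(raw))
assert plan.is_recency_query is True

# A legacy three-field response still validates; the flag defaults to False.
legacy = KBSearchPlan.model_validate({"optimized_query": "quarterly report"})
assert legacy.is_recency_query is False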
@@ -245,7 +253,7 @@ def _build_kb_planner_prompt(
     return (
         "You optimize internal knowledge-base search inputs for document retrieval.\n"
         "Return JSON only with this exact shape:\n"
-        '{"optimized_query":"string","start_date":"ISO string or null","end_date":"ISO string or null"}\n\n'
+        '{"optimized_query":"string","start_date":"ISO string or null","end_date":"ISO string or null","is_recency_query":bool}\n\n'
         "Rules:\n"
         "- Preserve the user's intent.\n"
         "- Rewrite the query to improve retrieval using concrete entities, acronyms, projects, tools, people, and document-specific terms when helpful.\n"
@@ -253,6 +261,11 @@ def _build_kb_planner_prompt(
         "- Only use date filters when the latest user request or recent dialogue clearly implies a time range.\n"
         "- If you use date filters, prefer returning both bounds.\n"
         "- If no date filter is useful, return null for both dates.\n"
+        '- Set "is_recency_query" to true ONLY when the user\'s primary intent is about '
+        "recency or temporal ordering rather than topical relevance. Examples: "
+        '"latest file", "newest upload", "most recent document", "what did I save last", '
+        '"show me files from today", "last thing I added". '
+        "When true, results will be sorted by date instead of relevance.\n"
        "- Do not include markdown, prose, or explanations.\n\n"
        f"Today's UTC date: {today}\n\n"
        f"Recent conversation:\n{recent_conversation or '(none)'}\n\n"
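Illustrative example (not part of the diff): for a request like "what did I save last?", these rules steer the planner toward output of the form

{"optimized_query": "recently saved documents", "start_date": null, "end_date": null, "is_recency_query": true}

which the middleware then routes to the date-ordered browse path instead of hybrid search.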
@@ -506,6 +519,135 @@ def _resolve_search_types(
     return list(expanded) if expanded else None


+_RECENCY_MAX_CHUNKS_PER_DOC = 5
+
+
+async def browse_recent_documents(
+    *,
+    search_space_id: int,
+    document_type: list[str] | None = None,
+    top_k: int = 10,
+    start_date: datetime | None = None,
+    end_date: datetime | None = None,
+) -> list[dict[str, Any]]:
+    """Return documents ordered by recency (newest first), no relevance ranking.
+
+    Used when the user's intent is temporal ("latest file", "most recent upload")
+    and hybrid search would produce poor results because the query has no
+    meaningful topical signal.
+    """
+    from sqlalchemy import func, select
+
+    from app.db import DocumentType
+
+    async with shielded_async_session() as session:
+        base_conditions = [
+            Document.search_space_id == search_space_id,
+            func.coalesce(Document.status["state"].astext, "ready") != "deleting",
+        ]
+
+        if document_type is not None:
+            import contextlib
+
+            doc_type_enums = []
+            for dt in document_type:
+                if isinstance(dt, str):
+                    with contextlib.suppress(KeyError):
+                        doc_type_enums.append(DocumentType[dt])
+                else:
+                    doc_type_enums.append(dt)
+            if doc_type_enums:
+                if len(doc_type_enums) == 1:
+                    base_conditions.append(Document.document_type == doc_type_enums[0])
+                else:
+                    base_conditions.append(Document.document_type.in_(doc_type_enums))
+
+        if start_date is not None:
+            base_conditions.append(Document.updated_at >= start_date)
+        if end_date is not None:
+            base_conditions.append(Document.updated_at <= end_date)
+
+        doc_query = (
+            select(Document)
+            .where(*base_conditions)
+            .order_by(Document.updated_at.desc())
+            .limit(top_k)
+        )
+        result = await session.execute(doc_query)
+        documents = result.scalars().unique().all()
+
+        if not documents:
+            return []
+
+        doc_ids = [d.id for d in documents]
+
+        numbered = (
+            select(
+                Chunk.id.label("chunk_id"),
+                Chunk.document_id,
+                Chunk.content,
+                func.row_number()
+                .over(partition_by=Chunk.document_id, order_by=Chunk.id)
+                .label("rn"),
+            )
+            .where(Chunk.document_id.in_(doc_ids))
+            .subquery("numbered")
+        )
+
+        chunk_query = (
+            select(numbered.c.chunk_id, numbered.c.document_id, numbered.c.content)
+            .where(numbered.c.rn <= _RECENCY_MAX_CHUNKS_PER_DOC)
+            .order_by(numbered.c.document_id, numbered.c.chunk_id)
+        )
+        chunk_result = await session.execute(chunk_query)
+        fetched_chunks = chunk_result.all()
+
+        doc_chunks: dict[int, list[dict[str, Any]]] = {d.id: [] for d in documents}
+        for row in fetched_chunks:
+            if row.document_id in doc_chunks:
+                doc_chunks[row.document_id].append(
+                    {"chunk_id": row.chunk_id, "content": row.content}
+                )
+
+        results: list[dict[str, Any]] = []
+        for doc in documents:
+            chunks_list = doc_chunks.get(doc.id, [])
+            metadata = doc.document_metadata or {}
+            results.append(
+                {
+                    "document_id": doc.id,
+                    "content": "\n\n".join(
+                        c["content"] for c in chunks_list if c.get("content")
+                    ),
+                    "score": 0.0,
+                    "chunks": chunks_list,
+                    "matched_chunk_ids": [],
+                    "document": {
+                        "id": doc.id,
+                        "title": doc.title,
+                        "document_type": (
+                            doc.document_type.value
+                            if getattr(doc, "document_type", None)
+                            else None
+                        ),
+                        "metadata": metadata,
+                    },
+                    "source": (
+                        doc.document_type.value
+                        if getattr(doc, "document_type", None)
+                        else None
+                    ),
+                }
+            )
+
+    logger.info(
+        "browse_recent_documents: %d docs returned for space=%d",
+        len(results),
+        search_space_id,
+    )
+    return results
+
+
 async def search_knowledge_base(
     *,
     query: str,
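How this reads at a call site: the numbered subquery ranks chunks per document with a row_number() window, so each document comes back with at most _RECENCY_MAX_CHUNKS_PER_DOC preview chunks in a single round trip. A hedged usage sketch (the space id, type name, and one-week window are illustrative values, and whether DocumentType has a FILE member depends on app.db):

from datetime import datetime, timedelta, timezone


async def show_last_weeks_files() -> None:
    # Illustrative call; assumes browse_recent_documents is importable
    # and the app's DB session/models are configured.
    now = datetime.now(timezone.utc)
    docs = await browse_recent_documents(
        search_space_id=1,          # hypothetical search space
        document_type=["FILE"],     # assumed DocumentType member name
        top_k=5,
        start_date=now - timedelta(days=7),
        end_date=now,
    )
    # Newest first; score is a constant 0.0 because no relevance ranking ran.
    for d in docs:
        print(d["document"]["title"], len(d["chunks"]))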
@@ -704,10 +846,13 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware):  # type: ignore[type-arg]
         *,
         messages: Sequence[BaseMessage],
         user_text: str,
-    ) -> tuple[str, datetime | None, datetime | None]:
-        """Rewrite the KB query and infer optional date filters with the LLM."""
+    ) -> tuple[str, datetime | None, datetime | None, bool]:
+        """Rewrite the KB query and infer optional date filters with the LLM.
+
+        Returns (optimized_query, start_date, end_date, is_recency_query).
+        """
         if self.llm is None:
-            return user_text, None, None
+            return user_text, None, None, False

         recent_conversation = _render_recent_conversation(
             messages,
@@ -734,15 +879,18 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware):  # type: ignore[type-arg]
                 plan.start_date,
                 plan.end_date,
             )
+            is_recency = plan.is_recency_query
             _perf_log.info(
-                "[kb_fs_middleware] planner in %.3fs query=%r optimized=%r start=%s end=%s",
+                "[kb_fs_middleware] planner in %.3fs query=%r optimized=%r "
+                "start=%s end=%s recency=%s",
                 loop.time() - t0,
                 user_text[:80],
                 optimized_query[:120],
                 start_date.isoformat() if start_date else None,
                 end_date.isoformat() if end_date else None,
+                is_recency,
             )
-            return optimized_query, start_date, end_date
+            return optimized_query, start_date, end_date, is_recency
         except (json.JSONDecodeError, ValidationError, ValueError) as exc:
             logger.warning(
                 "KB planner returned invalid output, using raw query: %s", exc
@@ -750,7 +898,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware):  # type: ignore[type-arg]
         except Exception as exc:  # pragma: no cover - defensive fallback
             logger.warning("KB planner failed, using raw query: %s", exc)

-        return user_text, None, None
+        return user_text, None, None, False

     def before_agent(  # type: ignore[override]
         self,
@@ -789,7 +937,12 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware):  # type: ignore[type-arg]

         t0 = _perf_log and asyncio.get_event_loop().time()
         existing_files = state.get("files")
-        planned_query, start_date, end_date = await self._plan_search_inputs(
+        (
+            planned_query,
+            start_date,
+            end_date,
+            is_recency,
+        ) = await self._plan_search_inputs(
             messages=messages,
             user_text=user_text,
         )
@@ -805,16 +958,28 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware):  # type: ignore[type-arg]
         # messages within the same agent instance.
         self.mentioned_document_ids = []

-        # --- 2. Run KB hybrid search ---
-        search_results = await search_knowledge_base(
-            query=planned_query,
-            search_space_id=self.search_space_id,
-            available_connectors=self.available_connectors,
-            available_document_types=self.available_document_types,
-            top_k=self.top_k,
-            start_date=start_date,
-            end_date=end_date,
-        )
+        # --- 2. Run KB search (recency browse or hybrid) ---
+        if is_recency:
+            doc_types = _resolve_search_types(
+                self.available_connectors, self.available_document_types
+            )
+            search_results = await browse_recent_documents(
+                search_space_id=self.search_space_id,
+                document_type=doc_types,
+                top_k=self.top_k,
+                start_date=start_date,
+                end_date=end_date,
+            )
+        else:
+            search_results = await search_knowledge_base(
+                query=planned_query,
+                search_space_id=self.search_space_id,
+                available_connectors=self.available_connectors,
+                available_document_types=self.available_document_types,
+                top_k=self.top_k,
+                start_date=start_date,
+                end_date=end_date,
+            )

         # --- 3. Merge: mentioned first, then search (dedup by doc id) ---
         seen_doc_ids: set[int] = set()
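Both branches return the same result shape (document_id, chunks, score), so the merge step that follows can dedup by document id without knowing which path produced a hit. A minimal sketch of that merge pattern (simplified; the real logic lives in before_agent and pulls mentioned documents from agent state):

from typing import Any


def merge_results(
    mentioned: list[dict[str, Any]],
    searched: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    # Mentioned documents take priority; search hits fill in behind them.
    seen_doc_ids: set[int] = set()
    merged: list[dict[str, Any]] = []
    for item in [*mentioned, *searched]:
        doc_id = item["document_id"]
        if doc_id in seen_doc_ids:
            continue
        seen_doc_ids.add(doc_id)
        merged.append(item)
    return merged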