diff --git a/surfsense_backend/alembic/versions/124_add_ai_file_sort_enabled.py b/surfsense_backend/alembic/versions/124_add_ai_file_sort_enabled.py new file mode 100644 index 000000000..b77eb9337 --- /dev/null +++ b/surfsense_backend/alembic/versions/124_add_ai_file_sort_enabled.py @@ -0,0 +1,44 @@ +"""124_add_ai_file_sort_enabled + +Revision ID: 124 +Revises: 123 +Create Date: 2026-04-14 + +Adds ai_file_sort_enabled boolean column to searchspaces. +Defaults to False so AI file sorting is opt-in per search space. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "124" +down_revision: str | None = "123" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + conn = op.get_bind() + existing_columns = [ + col["name"] for col in sa.inspect(conn).get_columns("searchspaces") + ] + + if "ai_file_sort_enabled" not in existing_columns: + op.add_column( + "searchspaces", + sa.Column( + "ai_file_sort_enabled", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + ) + + +def downgrade() -> None: + op.drop_column("searchspaces", "ai_file_sort_enabled") diff --git a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py b/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py index bc6f7fd9e..61494ff1a 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py +++ b/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py @@ -93,7 +93,8 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg] @staticmethod def _dedup( - state: AgentState, dedup_keys: dict[str, str] # type: ignore[type-arg] + state: AgentState, + dedup_keys: dict[str, str], # type: ignore[type-arg] ) -> dict[str, Any] | None: messages = state.get("messages") if not messages: diff --git a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py index 314810eff..bcd544d61 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py +++ b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py @@ -593,7 +593,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): runtime: ToolRuntime[None, FilesystemState], timeout: int | None, ) -> str: - sandbox, is_new = await get_or_create_sandbox(self._thread_id) + sandbox, _is_new = await get_or_create_sandbox(self._thread_id) # NOTE: sync_files_to_sandbox is intentionally disabled. # The virtual FS contains XML-wrapped KB documents whose paths # would double-nest under SANDBOX_DOCUMENTS_ROOT (e.g. diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py index 06ed4ad80..0460da11d 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py +++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py @@ -58,6 +58,14 @@ class KBSearchPlan(BaseModel): default=None, description="Optional ISO end date or datetime for KB search filtering.", ) + is_recency_query: bool = Field( + default=False, + description=( + "True when the user's intent is primarily about recency or temporal " + "ordering (e.g. 'latest', 'newest', 'most recent', 'last uploaded') " + "rather than topical relevance." 
+ ), + ) def _extract_text_from_message(message: BaseMessage) -> str: @@ -245,7 +253,7 @@ def _build_kb_planner_prompt( return ( "You optimize internal knowledge-base search inputs for document retrieval.\n" "Return JSON only with this exact shape:\n" - '{"optimized_query":"string","start_date":"ISO string or null","end_date":"ISO string or null"}\n\n' + '{"optimized_query":"string","start_date":"ISO string or null","end_date":"ISO string or null","is_recency_query":bool}\n\n' "Rules:\n" "- Preserve the user's intent.\n" "- Rewrite the query to improve retrieval using concrete entities, acronyms, projects, tools, people, and document-specific terms when helpful.\n" @@ -253,6 +261,11 @@ def _build_kb_planner_prompt( "- Only use date filters when the latest user request or recent dialogue clearly implies a time range.\n" "- If you use date filters, prefer returning both bounds.\n" "- If no date filter is useful, return null for both dates.\n" + '- Set "is_recency_query" to true ONLY when the user\'s primary intent is about ' + "recency or temporal ordering rather than topical relevance. Examples: " + '"latest file", "newest upload", "most recent document", "what did I save last", ' + '"show me files from today", "last thing I added". ' + "When true, results will be sorted by date instead of relevance.\n" "- Do not include markdown, prose, or explanations.\n\n" f"Today's UTC date: {today}\n\n" f"Recent conversation:\n{recent_conversation or '(none)'}\n\n" @@ -506,6 +519,135 @@ def _resolve_search_types( return list(expanded) if expanded else None +_RECENCY_MAX_CHUNKS_PER_DOC = 5 + + +async def browse_recent_documents( + *, + search_space_id: int, + document_type: list[str] | None = None, + top_k: int = 10, + start_date: datetime | None = None, + end_date: datetime | None = None, +) -> list[dict[str, Any]]: + """Return documents ordered by recency (newest first), no relevance ranking. + + Used when the user's intent is temporal ("latest file", "most recent upload") + and hybrid search would produce poor results because the query has no + meaningful topical signal. 
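+
+    Illustrative call (a sketch; the argument values here are hypothetical):
+
+        results = await browse_recent_documents(
+            search_space_id=42,
+            document_type=["FILE"],
+            top_k=5,
+        )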
+ """ + from sqlalchemy import func, select + + from app.db import DocumentType + + async with shielded_async_session() as session: + base_conditions = [ + Document.search_space_id == search_space_id, + func.coalesce(Document.status["state"].astext, "ready") != "deleting", + ] + + if document_type is not None: + import contextlib + + doc_type_enums = [] + for dt in document_type: + if isinstance(dt, str): + with contextlib.suppress(KeyError): + doc_type_enums.append(DocumentType[dt]) + else: + doc_type_enums.append(dt) + if doc_type_enums: + if len(doc_type_enums) == 1: + base_conditions.append(Document.document_type == doc_type_enums[0]) + else: + base_conditions.append(Document.document_type.in_(doc_type_enums)) + + if start_date is not None: + base_conditions.append(Document.updated_at >= start_date) + if end_date is not None: + base_conditions.append(Document.updated_at <= end_date) + + doc_query = ( + select(Document) + .where(*base_conditions) + .order_by(Document.updated_at.desc()) + .limit(top_k) + ) + result = await session.execute(doc_query) + documents = result.scalars().unique().all() + + if not documents: + return [] + + doc_ids = [d.id for d in documents] + + numbered = ( + select( + Chunk.id.label("chunk_id"), + Chunk.document_id, + Chunk.content, + func.row_number() + .over(partition_by=Chunk.document_id, order_by=Chunk.id) + .label("rn"), + ) + .where(Chunk.document_id.in_(doc_ids)) + .subquery("numbered") + ) + + chunk_query = ( + select(numbered.c.chunk_id, numbered.c.document_id, numbered.c.content) + .where(numbered.c.rn <= _RECENCY_MAX_CHUNKS_PER_DOC) + .order_by(numbered.c.document_id, numbered.c.chunk_id) + ) + chunk_result = await session.execute(chunk_query) + fetched_chunks = chunk_result.all() + + doc_chunks: dict[int, list[dict[str, Any]]] = {d.id: [] for d in documents} + for row in fetched_chunks: + if row.document_id in doc_chunks: + doc_chunks[row.document_id].append( + {"chunk_id": row.chunk_id, "content": row.content} + ) + + results: list[dict[str, Any]] = [] + for doc in documents: + chunks_list = doc_chunks.get(doc.id, []) + metadata = doc.document_metadata or {} + results.append( + { + "document_id": doc.id, + "content": "\n\n".join( + c["content"] for c in chunks_list if c.get("content") + ), + "score": 0.0, + "chunks": chunks_list, + "matched_chunk_ids": [], + "document": { + "id": doc.id, + "title": doc.title, + "document_type": ( + doc.document_type.value + if getattr(doc, "document_type", None) + else None + ), + "metadata": metadata, + }, + "source": ( + doc.document_type.value + if getattr(doc, "document_type", None) + else None + ), + } + ) + + logger.info( + "browse_recent_documents: %d docs returned for space=%d", + len(results), + search_space_id, + ) + return results + + async def search_knowledge_base( *, query: str, @@ -704,10 +846,13 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] *, messages: Sequence[BaseMessage], user_text: str, - ) -> tuple[str, datetime | None, datetime | None]: - """Rewrite the KB query and infer optional date filters with the LLM.""" + ) -> tuple[str, datetime | None, datetime | None, bool]: + """Rewrite the KB query and infer optional date filters with the LLM. + + Returns (optimized_query, start_date, end_date, is_recency_query). 
+ """ if self.llm is None: - return user_text, None, None + return user_text, None, None, False recent_conversation = _render_recent_conversation( messages, @@ -734,15 +879,18 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] plan.start_date, plan.end_date, ) + is_recency = plan.is_recency_query _perf_log.info( - "[kb_fs_middleware] planner in %.3fs query=%r optimized=%r start=%s end=%s", + "[kb_fs_middleware] planner in %.3fs query=%r optimized=%r " + "start=%s end=%s recency=%s", loop.time() - t0, user_text[:80], optimized_query[:120], start_date.isoformat() if start_date else None, end_date.isoformat() if end_date else None, + is_recency, ) - return optimized_query, start_date, end_date + return optimized_query, start_date, end_date, is_recency except (json.JSONDecodeError, ValidationError, ValueError) as exc: logger.warning( "KB planner returned invalid output, using raw query: %s", exc @@ -750,7 +898,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] except Exception as exc: # pragma: no cover - defensive fallback logger.warning("KB planner failed, using raw query: %s", exc) - return user_text, None, None + return user_text, None, None, False def before_agent( # type: ignore[override] self, @@ -789,7 +937,12 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] t0 = _perf_log and asyncio.get_event_loop().time() existing_files = state.get("files") - planned_query, start_date, end_date = await self._plan_search_inputs( + ( + planned_query, + start_date, + end_date, + is_recency, + ) = await self._plan_search_inputs( messages=messages, user_text=user_text, ) @@ -805,16 +958,28 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] # messages within the same agent instance. self.mentioned_document_ids = [] - # --- 2. Run KB hybrid search --- - search_results = await search_knowledge_base( - query=planned_query, - search_space_id=self.search_space_id, - available_connectors=self.available_connectors, - available_document_types=self.available_document_types, - top_k=self.top_k, - start_date=start_date, - end_date=end_date, - ) + # --- 2. Run KB search (recency browse or hybrid) --- + if is_recency: + doc_types = _resolve_search_types( + self.available_connectors, self.available_document_types + ) + search_results = await browse_recent_documents( + search_space_id=self.search_space_id, + document_type=doc_types, + top_k=self.top_k, + start_date=start_date, + end_date=end_date, + ) + else: + search_results = await search_knowledge_base( + query=planned_query, + search_space_id=self.search_space_id, + available_connectors=self.available_connectors, + available_document_types=self.available_document_types, + top_k=self.top_k, + start_date=start_date, + end_date=end_date, + ) # --- 3. 
Merge: mentioned first, then search (dedup by doc id) --- seen_doc_ids: set[int] = set() diff --git a/surfsense_backend/app/agents/new_chat/sandbox.py b/surfsense_backend/app/agents/new_chat/sandbox.py index 04af466ca..efac7aae8 100644 --- a/surfsense_backend/app/agents/new_chat/sandbox.py +++ b/surfsense_backend/app/agents/new_chat/sandbox.py @@ -138,7 +138,9 @@ def _find_or_create(thread_id: str) -> tuple[_TimeoutAwareSandbox, bool]: try: client.delete(sandbox) except Exception: - logger.debug("Could not delete broken sandbox %s", sandbox.id, exc_info=True) + logger.debug( + "Could not delete broken sandbox %s", sandbox.id, exc_info=True + ) sandbox = client.create(_sandbox_create_params(labels)) is_new = True logger.info("Created replacement sandbox: %s", sandbox.id) @@ -203,6 +205,7 @@ async def get_or_create_sandbox( def _schedule_sandbox_delete(sandbox: _TimeoutAwareSandbox) -> None: """Best-effort background deletion of an evicted sandbox.""" + def _delete() -> None: try: client = _get_client() diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py index b76f4d757..095413bdb 100644 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py @@ -2,10 +2,10 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.confluence_history import ConfluenceHistoryConnector from app.services.confluence import ConfluenceToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py index 070efaf57..7c03c2760 100644 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py @@ -2,10 +2,10 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.confluence_history import ConfluenceHistoryConnector from app.services.confluence import ConfluenceToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py index c80df9710..791d0d8c5 100644 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py @@ -2,10 +2,10 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.confluence_history import ConfluenceHistoryConnector from app.services.confluence import ConfluenceToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py 
index 6e2578334..22d8a8a27 100644 --- a/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py @@ -5,10 +5,10 @@ from pathlib import Path from typing import Any, Literal from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.dropbox.client import DropboxClient from app.db import SearchSourceConnector, SearchSourceConnectorType diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py index 620b39aa2..12559b57a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py @@ -2,11 +2,11 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy import String, and_, cast, func from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.dropbox.client import DropboxClient from app.db import ( Document, diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py index 974f9b4af..0bd044695 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py @@ -6,9 +6,9 @@ from email.mime.text import MIMEText from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py index a1c713f0a..c3f0999f4 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py @@ -6,9 +6,9 @@ from email.mime.text import MIMEText from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py index cab97ee8a..1f1f6227a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py @@ -4,9 +4,9 @@ from datetime import datetime from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py 
b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py index 1d53ac9ce..91178cd21 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py @@ -6,9 +6,9 @@ from email.mime.text import MIMEText from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py index 45ff6dfb9..259f52bba 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py @@ -150,7 +150,9 @@ def create_update_calendar_event_tool( final_new_end_datetime = result.params.get( "new_end_datetime", new_end_datetime ) - final_new_description = result.params.get("new_description", new_description) + final_new_description = result.params.get( + "new_description", new_description + ) final_new_location = result.params.get("new_location", new_location) final_new_attendees = result.params.get("new_attendees", new_attendees) diff --git a/surfsense_backend/app/agents/new_chat/tools/hitl.py b/surfsense_backend/app/agents/new_chat/tools/hitl.py index a1ac90dc7..64ace547c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/hitl.py +++ b/surfsense_backend/app/agents/new_chat/tools/hitl.py @@ -58,7 +58,9 @@ def _parse_decision(approval: Any) -> tuple[str, dict[str, Any]]: raise ValueError("No approval decision received") decision = decisions[0] - decision_type: str = decision.get("type") or decision.get("decision_type") or "approve" + decision_type: str = ( + decision.get("type") or decision.get("decision_type") or "approve" + ) edited_params: dict[str, Any] = {} edited_action = decision.get("edited_action") diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py index 0b3332694..8b40dde65 100644 --- a/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py @@ -3,10 +3,10 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.jira_history import JiraHistoryConnector from app.services.jira import JiraToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py index 52d4556a5..6466c80ea 100644 --- a/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py @@ -3,10 +3,10 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.jira_history 
import JiraHistoryConnector from app.services.jira import JiraToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py index 9c676fea3..f6e586a2e 100644 --- a/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py @@ -3,10 +3,10 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.jira_history import JiraHistoryConnector from app.services.jira import JiraToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py index d8005bd5c..ff254e133 100644 --- a/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py @@ -2,9 +2,9 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.linear_connector import LinearAPIError, LinearConnector from app.services.linear import LinearToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py index d8bc88d82..29ef0cdf2 100644 --- a/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py @@ -2,9 +2,9 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.linear_connector import LinearAPIError, LinearConnector from app.services.linear import LinearToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py index 7f6d952e5..f35d0dddd 100644 --- a/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py @@ -2,9 +2,9 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.linear_connector import LinearAPIError, LinearConnector from app.services.linear import LinearKBSyncService, LinearToolMetadataService @@ -157,9 +157,13 @@ def create_update_linear_issue_tool( final_issue_id = result.params.get("issue_id", issue_id) final_document_id = result.params.get("document_id", document_id) final_new_title = result.params.get("new_title", new_title) - final_new_description = result.params.get("new_description", new_description) + final_new_description = result.params.get( + "new_description", new_description + ) final_new_state_id = result.params.get("new_state_id", new_state_id) - final_new_assignee_id = 
result.params.get("new_assignee_id", new_assignee_id) + final_new_assignee_id = result.params.get( + "new_assignee_id", new_assignee_id + ) final_new_priority = result.params.get("new_priority", new_priority) final_new_label_ids: list[str] | None = result.params.get( "new_label_ids", new_label_ids diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py index 396f3fe0d..6efffe960 100644 --- a/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py @@ -2,9 +2,9 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector from app.services.notion import NotionToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py index 92e395624..07f7583d2 100644 --- a/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py @@ -2,9 +2,9 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector from app.services.notion.tool_metadata_service import NotionToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py index ee7b8f256..85c08177c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py @@ -2,9 +2,9 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector from app.services.notion import NotionToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py index 5050c7885..21272e01d 100644 --- a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py @@ -5,10 +5,10 @@ from pathlib import Path from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.onedrive.client import OneDriveClient from app.db import SearchSourceConnector, SearchSourceConnectorType diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py index 6997e1d52..a7f13b5df 100644 --- a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py +++ 
b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py @@ -2,11 +2,11 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy import String, and_, cast, func from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.onedrive.client import OneDriveClient from app.db import ( Document, diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index e69d28ac2..82d77f847 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -1321,6 +1321,10 @@ class SearchSpace(BaseModel, TimestampMixin): Integer, nullable=True, default=0 ) # For vision/screenshot analysis, defaults to Auto mode + ai_file_sort_enabled = Column( + Boolean, nullable=False, default=False, server_default="false" + ) + user_id = Column( UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False ) diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py index 22c552e5c..e6b2458f3 100644 --- a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py +++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py @@ -422,6 +422,8 @@ class IndexingPipelineService: ) log_index_success(ctx, chunk_count=len(chunks)) + await self._enqueue_ai_sort_if_enabled(document) + except RETRYABLE_LLM_ERRORS as e: log_retryable_llm_error(ctx, e) await rollback_and_persist_failure( @@ -457,6 +459,29 @@ class IndexingPipelineService: return document + async def _enqueue_ai_sort_if_enabled(self, document: Document) -> None: + """Fire-and-forget: enqueue incremental AI sort if the search space has it enabled.""" + try: + from app.db import SearchSpace + + result = await self.session.execute( + select(SearchSpace.ai_file_sort_enabled).where( + SearchSpace.id == document.search_space_id + ) + ) + enabled = result.scalar() + if not enabled: + return + + from app.tasks.celery_tasks.document_tasks import ai_sort_document_task + + user_id = str(document.created_by_id) if document.created_by_id else "" + ai_sort_document_task.delay(document.search_space_id, user_id, document.id) + except Exception: + logging.getLogger(__name__).warning( + "Failed to enqueue AI sort for document %s", document.id, exc_info=True + ) + async def index_batch_parallel( self, connector_docs: list[ConnectorDocument], diff --git a/surfsense_backend/app/routes/export_routes.py b/surfsense_backend/app/routes/export_routes.py index 641c7fedb..4f2b545a3 100644 --- a/surfsense_backend/app/routes/export_routes.py +++ b/surfsense_backend/app/routes/export_routes.py @@ -20,7 +20,9 @@ router = APIRouter() @router.get("/search-spaces/{search_space_id}/export") async def export_knowledge_base( search_space_id: int, - folder_id: int | None = Query(None, description="Export only this folder's subtree"), + folder_id: int | None = Query( + None, description="Export only this folder's subtree" + ), session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 7fa33ba1c..828137518 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -216,6 +216,7 @@ async def read_search_spaces( 
user_id=space.user_id, citations_enabled=space.citations_enabled, qna_custom_instructions=space.qna_custom_instructions, + ai_file_sort_enabled=space.ai_file_sort_enabled, member_count=member_count, is_owner=is_owner, ) @@ -384,6 +385,42 @@ async def edit_team_memory( return db_search_space +@router.post("/searchspaces/{search_space_id}/ai-sort") +async def trigger_ai_sort( + search_space_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Trigger a full AI file sort for all documents in the search space.""" + try: + await check_permission( + session, + user, + search_space_id, + Permission.SETTINGS_UPDATE.value, + "You don't have permission to trigger AI sort on this search space", + ) + + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + db_search_space = result.scalars().first() + if not db_search_space: + raise HTTPException(status_code=404, detail="Search space not found") + + from app.tasks.celery_tasks.document_tasks import ai_sort_search_space_task + + ai_sort_search_space_task.delay(search_space_id, str(user.id)) + return {"message": "AI sort started"} + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to trigger AI sort: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to trigger AI sort: {e!s}" + ) from e + + @router.delete("/searchspaces/{search_space_id}", response_model=dict) async def delete_search_space( search_space_id: int, diff --git a/surfsense_backend/app/schemas/search_space.py b/surfsense_backend/app/schemas/search_space.py index e3b50be65..77e34ea4b 100644 --- a/surfsense_backend/app/schemas/search_space.py +++ b/surfsense_backend/app/schemas/search_space.py @@ -22,6 +22,7 @@ class SearchSpaceUpdate(BaseModel): citations_enabled: bool | None = None qna_custom_instructions: str | None = None shared_memory_md: str | None = None + ai_file_sort_enabled: bool | None = None class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel): @@ -31,6 +32,7 @@ class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel): citations_enabled: bool qna_custom_instructions: str | None = None shared_memory_md: str | None = None + ai_file_sort_enabled: bool = False model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/services/ai_file_sort_service.py b/surfsense_backend/app/services/ai_file_sort_service.py new file mode 100644 index 000000000..2f04131a6 --- /dev/null +++ b/surfsense_backend/app/services/ai_file_sort_service.py @@ -0,0 +1,329 @@ +"""AI File Sort Service: builds connector-type/date/category/subcategory folder paths.""" + +from __future__ import annotations + +import json +import logging +import re +from datetime import UTC, datetime + +from langchain_core.messages import HumanMessage +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy.orm import selectinload + +from app.db import ( + Chunk, + Document, + DocumentType, + SearchSourceConnector, + SearchSourceConnectorType, +) +from app.services.folder_service import ensure_folder_hierarchy_with_depth_validation + +logger = logging.getLogger(__name__) + +_DOCTYPE_TO_CONNECTOR_LABEL: dict[str, str] = { + DocumentType.EXTENSION: "Browser Extension", + DocumentType.CRAWLED_URL: "Web Crawl", + DocumentType.FILE: "File Upload", + DocumentType.SLACK_CONNECTOR: "Slack", + DocumentType.TEAMS_CONNECTOR: "Teams", + DocumentType.ONEDRIVE_FILE: "OneDrive", + 
DocumentType.NOTION_CONNECTOR: "Notion", + DocumentType.YOUTUBE_VIDEO: "YouTube", + DocumentType.GITHUB_CONNECTOR: "GitHub", + DocumentType.LINEAR_CONNECTOR: "Linear", + DocumentType.DISCORD_CONNECTOR: "Discord", + DocumentType.JIRA_CONNECTOR: "Jira", + DocumentType.CONFLUENCE_CONNECTOR: "Confluence", + DocumentType.CLICKUP_CONNECTOR: "ClickUp", + DocumentType.GOOGLE_CALENDAR_CONNECTOR: "Google Calendar", + DocumentType.GOOGLE_GMAIL_CONNECTOR: "Gmail", + DocumentType.GOOGLE_DRIVE_FILE: "Google Drive", + DocumentType.AIRTABLE_CONNECTOR: "Airtable", + DocumentType.LUMA_CONNECTOR: "Luma", + DocumentType.ELASTICSEARCH_CONNECTOR: "Elasticsearch", + DocumentType.BOOKSTACK_CONNECTOR: "BookStack", + DocumentType.CIRCLEBACK: "Circleback", + DocumentType.OBSIDIAN_CONNECTOR: "Obsidian", + DocumentType.NOTE: "Notes", + DocumentType.DROPBOX_FILE: "Dropbox", + DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: "Google Drive (Composio)", + DocumentType.COMPOSIO_GMAIL_CONNECTOR: "Gmail (Composio)", + DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: "Google Calendar (Composio)", + DocumentType.LOCAL_FOLDER_FILE: "Local Folder", +} + +_CONNECTOR_TYPE_LABEL: dict[str, str] = { + SearchSourceConnectorType.SERPER_API: "Serper Search", + SearchSourceConnectorType.TAVILY_API: "Tavily Search", + SearchSourceConnectorType.SEARXNG_API: "SearXNG Search", + SearchSourceConnectorType.LINKUP_API: "Linkup Search", + SearchSourceConnectorType.BAIDU_SEARCH_API: "Baidu Search", + SearchSourceConnectorType.SLACK_CONNECTOR: "Slack", + SearchSourceConnectorType.TEAMS_CONNECTOR: "Teams", + SearchSourceConnectorType.ONEDRIVE_CONNECTOR: "OneDrive", + SearchSourceConnectorType.NOTION_CONNECTOR: "Notion", + SearchSourceConnectorType.GITHUB_CONNECTOR: "GitHub", + SearchSourceConnectorType.LINEAR_CONNECTOR: "Linear", + SearchSourceConnectorType.DISCORD_CONNECTOR: "Discord", + SearchSourceConnectorType.JIRA_CONNECTOR: "Jira", + SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "Confluence", + SearchSourceConnectorType.CLICKUP_CONNECTOR: "ClickUp", + SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: "Google Calendar", + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: "Gmail", + SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR: "Google Drive", + SearchSourceConnectorType.AIRTABLE_CONNECTOR: "Airtable", + SearchSourceConnectorType.LUMA_CONNECTOR: "Luma", + SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: "Elasticsearch", + SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: "Web Crawl", + SearchSourceConnectorType.BOOKSTACK_CONNECTOR: "BookStack", + SearchSourceConnectorType.CIRCLEBACK_CONNECTOR: "Circleback", + SearchSourceConnectorType.OBSIDIAN_CONNECTOR: "Obsidian", + SearchSourceConnectorType.MCP_CONNECTOR: "MCP", + SearchSourceConnectorType.DROPBOX_CONNECTOR: "Dropbox", + SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: "Google Drive (Composio)", + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR: "Gmail (Composio)", + SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: "Google Calendar (Composio)", +} + +_MAX_CONTENT_CHARS = 4000 +_MAX_CHUNKS_FOR_CONTEXT = 5 + +_CATEGORY_PROMPT = ( + "Based on the document information below, classify it into a broad category " + "and a more specific subcategory.\n\n" + "Rules:\n" + "- category: 1-2 word broad theme (e.g. Science, Finance, Engineering, Communication, Media)\n" + "- subcategory: 1-2 word specific topic within the category " + "(e.g. Physics, Tax Reports, Backend, Team Updates)\n" + "- Use nouns only. 
Do not include generic terms like 'General' or 'Miscellaneous'.\n\n" + "Title: {title}\n\n" + "Content: {summary}\n\n" + 'Respond with ONLY a JSON object: {{"category": "...", "subcategory": "..."}}' +) + +_SAFE_NAME_RE = re.compile(r"[^a-zA-Z0-9 _\-()]") +_FALLBACK_CATEGORY = "Uncategorized" +_FALLBACK_SUBCATEGORY = "General" + + +def resolve_root_folder_label( + document: Document, connector: SearchSourceConnector | None +) -> str: + if connector is not None: + return _CONNECTOR_TYPE_LABEL.get( + connector.connector_type, str(connector.connector_type) + ) + return _DOCTYPE_TO_CONNECTOR_LABEL.get( + document.document_type, str(document.document_type) + ) + + +def resolve_date_folder(document: Document) -> str: + ts = document.updated_at or document.created_at + if ts is None: + ts = datetime.now(UTC) + return ts.strftime("%Y-%m-%d") + + +def sanitize_category_folder_name( + value: str | None, fallback: str = _FALLBACK_CATEGORY +) -> str: + if not value or not value.strip(): + return fallback + cleaned = _SAFE_NAME_RE.sub("", value.strip()) + cleaned = " ".join(cleaned.split()) + if not cleaned: + return fallback + return cleaned[:50] + + +async def _resolve_document_text( + session: AsyncSession, + document: Document, +) -> str: + """Build the best available text representation for taxonomy generation. + + Prefers ``document.content``; falls back to joining the first few chunks + when content is empty or too short to be useful. + """ + text = (document.content or "").strip() + if len(text) >= 100: + return text[:_MAX_CONTENT_CHARS] + + stmt = ( + select(Chunk.content) + .where(Chunk.document_id == document.id) + .order_by(Chunk.id) + .limit(_MAX_CHUNKS_FOR_CONTEXT) + ) + result = await session.execute(stmt) + chunk_texts = [row[0] for row in result.all() if row[0]] + if chunk_texts: + combined = "\n\n".join(chunk_texts) + return combined[:_MAX_CONTENT_CHARS] + + return text[:_MAX_CONTENT_CHARS] + + +def _get_cached_taxonomy(document: Document) -> tuple[str, str] | None: + """Return (category, subcategory) from document metadata cache, or None.""" + meta = document.document_metadata + if not isinstance(meta, dict): + return None + cat = meta.get("ai_sort_category") + subcat = meta.get("ai_sort_subcategory") + if cat and subcat and isinstance(cat, str) and isinstance(subcat, str): + return cat, subcat + return None + + +def _set_cached_taxonomy(document: Document, category: str, subcategory: str) -> None: + """Persist the AI taxonomy on document metadata for deterministic re-sorts.""" + meta = dict(document.document_metadata or {}) + meta["ai_sort_category"] = category + meta["ai_sort_subcategory"] = subcategory + document.document_metadata = meta + + +async def generate_ai_taxonomy( + title: str, + summary_or_content: str, + llm, +) -> tuple[str, str]: + """Return (category, subcategory) using a single structured LLM call.""" + text = (summary_or_content or "").strip() + if not text: + return _FALLBACK_CATEGORY, _FALLBACK_SUBCATEGORY + + if len(text) > _MAX_CONTENT_CHARS: + text = text[:_MAX_CONTENT_CHARS] + + prompt = _CATEGORY_PROMPT.format(title=title or "Untitled", summary=text) + try: + result = await llm.ainvoke([HumanMessage(content=prompt)]) + raw = result.content.strip() + if raw.startswith("```"): + raw = raw.split("\n", 1)[-1].rsplit("```", 1)[0].strip() + parsed = json.loads(raw) + category = sanitize_category_folder_name( + parsed.get("category"), _FALLBACK_CATEGORY + ) + subcategory = sanitize_category_folder_name( + parsed.get("subcategory"), _FALLBACK_SUBCATEGORY + ) + 
return category, subcategory + except Exception: + logger.warning("AI taxonomy generation failed, using fallback", exc_info=True) + return _FALLBACK_CATEGORY, _FALLBACK_SUBCATEGORY + + +def _build_path_segments( + root_label: str, + date_label: str, + category: str, + subcategory: str, +) -> list[dict]: + return [ + {"name": root_label, "metadata": {"ai_sort": True, "ai_sort_level": 1}}, + {"name": date_label, "metadata": {"ai_sort": True, "ai_sort_level": 2}}, + {"name": category, "metadata": {"ai_sort": True, "ai_sort_level": 3}}, + {"name": subcategory, "metadata": {"ai_sort": True, "ai_sort_level": 4}}, + ] + + +async def _resolve_taxonomy( + session: AsyncSession, + document: Document, + llm, +) -> tuple[str, str]: + """Return (category, subcategory), reusing cached values when available.""" + cached = _get_cached_taxonomy(document) + if cached is not None: + return cached + + content_text = await _resolve_document_text(session, document) + category, subcategory = await generate_ai_taxonomy( + document.title, content_text, llm + ) + _set_cached_taxonomy(document, category, subcategory) + return category, subcategory + + +async def ai_sort_document( + session: AsyncSession, + document: Document, + llm, +) -> Document: + """Sort a single document into the 4-level AI folder hierarchy.""" + connector: SearchSourceConnector | None = None + if document.connector_id is not None: + connector = await session.get(SearchSourceConnector, document.connector_id) + + root_label = resolve_root_folder_label(document, connector) + date_label = resolve_date_folder(document) + + category, subcategory = await _resolve_taxonomy(session, document, llm) + + segments = _build_path_segments(root_label, date_label, category, subcategory) + + leaf_folder = await ensure_folder_hierarchy_with_depth_validation( + session, + document.search_space_id, + segments, + ) + + document.folder_id = leaf_folder.id + await session.flush() + return document + + +async def ai_sort_all_documents( + session: AsyncSession, + search_space_id: int, + llm, +) -> tuple[int, int]: + """Sort all documents in a search space. 
Returns (sorted_count, failed_count).""" + stmt = ( + select(Document) + .where(Document.search_space_id == search_space_id) + .options(selectinload(Document.connector)) + ) + result = await session.execute(stmt) + documents = list(result.scalars().all()) + + sorted_count = 0 + failed_count = 0 + + for doc in documents: + try: + connector = doc.connector + root_label = resolve_root_folder_label(doc, connector) + date_label = resolve_date_folder(doc) + + category, subcategory = await _resolve_taxonomy(session, doc, llm) + segments = _build_path_segments( + root_label, date_label, category, subcategory + ) + + leaf_folder = await ensure_folder_hierarchy_with_depth_validation( + session, + search_space_id, + segments, + ) + doc.folder_id = leaf_folder.id + sorted_count += 1 + except Exception: + logger.error("Failed to AI-sort document %s", doc.id, exc_info=True) + failed_count += 1 + + await session.commit() + logger.info( + "AI sort complete for search_space=%d: sorted=%d, failed=%d", + search_space_id, + sorted_count, + failed_count, + ) + return sorted_count, failed_count diff --git a/surfsense_backend/app/services/folder_service.py b/surfsense_backend/app/services/folder_service.py index fc1fb8a75..f5b608600 100644 --- a/surfsense_backend/app/services/folder_service.py +++ b/surfsense_backend/app/services/folder_service.py @@ -142,6 +142,58 @@ async def generate_folder_position( return generate_key_between(last_position, None) +async def ensure_folder_hierarchy_with_depth_validation( + session: AsyncSession, + search_space_id: int, + path_segments: list[dict], +) -> Folder: + """Create or return a nested folder chain, validating depth at each step. + + Each item in ``path_segments`` is a dict with: + - ``name`` (str): folder display name + - ``metadata`` (dict | None): optional ``folder_metadata`` JSONB payload + + Returns the deepest (leaf) Folder in the chain. 
+ """ + parent_id: int | None = None + current_folder: Folder | None = None + + for segment in path_segments: + name = segment["name"] + metadata = segment.get("metadata") + + stmt = select(Folder).where( + Folder.search_space_id == search_space_id, + Folder.name == name, + Folder.parent_id == parent_id + if parent_id is not None + else Folder.parent_id.is_(None), + ) + result = await session.execute(stmt) + folder = result.scalar_one_or_none() + + if folder is None: + await validate_folder_depth(session, parent_id, subtree_depth=0) + position = await generate_folder_position( + session, search_space_id, parent_id + ) + folder = Folder( + name=name, + search_space_id=search_space_id, + parent_id=parent_id, + position=position, + folder_metadata=metadata, + ) + session.add(folder) + await session.flush() + + current_folder = folder + parent_id = folder.id + + assert current_folder is not None, "path_segments must not be empty" + return current_folder + + async def get_folder_subtree_ids(session: AsyncSession, folder_id: int) -> list[int]: """Return all folder IDs in the subtree rooted at folder_id (inclusive).""" result = await session.execute( diff --git a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py index fc946b4bc..719cfb940 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py @@ -4,6 +4,7 @@ import asyncio import contextlib import logging import os +import time from uuid import UUID from app.celery_app import celery_app @@ -1551,3 +1552,121 @@ async def _index_uploaded_folder_files_async( heartbeat_task.cancel() if notification_id is not None: _stop_heartbeat(notification_id) + + +# ===== AI File Sort tasks ===== + +AI_SORT_LOCK_TTL_SECONDS = 600 # 10 minutes +_ai_sort_redis = None + + +def _get_ai_sort_redis(): + import redis + + global _ai_sort_redis + if _ai_sort_redis is None: + _ai_sort_redis = redis.from_url(config.REDIS_APP_URL, decode_responses=True) + return _ai_sort_redis + + +def _ai_sort_lock_key(search_space_id: int) -> str: + return f"ai_sort:search_space:{search_space_id}:lock" + + +@celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1) +def ai_sort_search_space_task(self, search_space_id: int, user_id: str): + """Full AI sort for all documents in a search space.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(_ai_sort_search_space_async(search_space_id, user_id)) + finally: + loop.close() + + +async def _ai_sort_search_space_async(search_space_id: int, user_id: str): + r = _get_ai_sort_redis() + lock_key = _ai_sort_lock_key(search_space_id) + + if not r.set(lock_key, "running", nx=True, ex=AI_SORT_LOCK_TTL_SECONDS): + logger.info( + "AI sort already running for search_space=%d, skipping", + search_space_id, + ) + return + + t_start = time.perf_counter() + try: + from app.services.ai_file_sort_service import ai_sort_all_documents + from app.services.llm_service import get_document_summary_llm + + async with get_celery_session_maker()() as session: + llm = await get_document_summary_llm( + session, search_space_id, disable_streaming=True + ) + if llm is None: + logger.warning( + "No LLM configured for search_space=%d, skipping AI sort", + search_space_id, + ) + return + + sorted_count, failed_count = await ai_sort_all_documents( + session, search_space_id, llm + ) + elapsed = time.perf_counter() - t_start + logger.info( + "AI sort search_space=%d 
done in %.1fs: sorted=%d failed=%d", + search_space_id, + elapsed, + sorted_count, + failed_count, + ) + finally: + r.delete(lock_key) + + +@celery_app.task( + name="ai_sort_document", bind=True, max_retries=2, default_retry_delay=10 +) +def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int): + """Incremental AI sort for a single document after indexing.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + _ai_sort_document_async(search_space_id, user_id, document_id) + ) + finally: + loop.close() + + +async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int): + from app.db import Document + from app.services.ai_file_sort_service import ai_sort_document + from app.services.llm_service import get_document_summary_llm + + async with get_celery_session_maker()() as session: + document = await session.get(Document, document_id) + if document is None: + logger.warning("Document %d not found, skipping AI sort", document_id) + return + + llm = await get_document_summary_llm( + session, search_space_id, disable_streaming=True + ) + if llm is None: + logger.warning( + "No LLM for search_space=%d, skipping AI sort of doc=%d", + search_space_id, + document_id, + ) + return + + await ai_sort_document(session, document, llm) + await session.commit() + logger.info( + "AI sorted document=%d into search_space=%d", + document_id, + search_space_id, + ) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 336ede751..47a270568 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -61,6 +61,7 @@ from app.services.new_streaming_service import VercelStreamingService from app.utils.content_utils import bootstrap_history_from_db from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap +_background_tasks: set[asyncio.Task] = set() _perf_log = get_perf_logger() @@ -1552,7 +1553,7 @@ async def stream_new_chat( # Shared threads write to team memory; private threads write to user memory. 
if not stream_result.agent_called_update_memory: if visibility == ChatVisibility.SEARCH_SPACE: - asyncio.create_task( + task = asyncio.create_task( extract_and_save_team_memory( user_message=user_query, search_space_id=search_space_id, @@ -1560,14 +1561,18 @@ async def stream_new_chat( author_display_name=current_user_display_name, ) ) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) elif user_id: - asyncio.create_task( + task = asyncio.create_task( extract_and_save_memory( user_message=user_query, user_id=user_id, llm=llm, ) ) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) # Finish the step and message yield streaming_service.format_finish_step() diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py index d8f95da63..21cdbd29f 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py @@ -961,6 +961,7 @@ async def index_google_drive_files( vision_llm = None if connector_enable_vision_llm: from app.services.llm_service import get_vision_llm + vision_llm = await get_vision_llm(session, search_space_id) drive_client = GoogleDriveClient( session, connector_id, credentials=pre_built_credentials @@ -1168,6 +1169,7 @@ async def index_google_drive_single_file( vision_llm = None if connector_enable_vision_llm: from app.services.llm_service import get_vision_llm + vision_llm = await get_vision_llm(session, search_space_id) drive_client = GoogleDriveClient( session, connector_id, credentials=pre_built_credentials @@ -1306,6 +1308,7 @@ async def index_google_drive_selected_files( vision_llm = None if connector_enable_vision_llm: from app.services.llm_service import get_vision_llm + vision_llm = await get_vision_llm(session, search_space_id) drive_client = GoogleDriveClient( session, connector_id, credentials=pre_built_credentials diff --git a/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py b/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py index 2d5f9648d..b6797f77a 100644 --- a/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py @@ -1360,7 +1360,9 @@ async def index_uploaded_files( try: content, content_hash = await _compute_file_content_hash( - temp_path, filename, search_space_id, + temp_path, + filename, + search_space_id, vision_llm=vision_llm_instance, ) except Exception as e: diff --git a/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py index aa654a9a9..2def799f3 100644 --- a/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py @@ -656,6 +656,7 @@ async def index_onedrive_files( vision_llm = None if connector_enable_vision_llm: from app.services.llm_service import get_vision_llm + vision_llm = await get_vision_llm(session, search_space_id) onedrive_client = OneDriveClient(session, connector_id) diff --git a/surfsense_backend/scripts/create_sandbox_snapshot.py b/surfsense_backend/scripts/create_sandbox_snapshot.py index 97ed6dfe8..82d452174 100644 --- a/surfsense_backend/scripts/create_sandbox_snapshot.py +++ b/surfsense_backend/scripts/create_sandbox_snapshot.py @@ -21,7 +21,11 @@ from pathlib import Path from dotenv 
import load_dotenv _here = Path(__file__).parent -for candidate in [_here / "../surfsense_backend/.env", _here / ".env", _here / "../.env"]: +for candidate in [ + _here / "../surfsense_backend/.env", + _here / ".env", + _here / "../.env", +]: if candidate.exists(): load_dotenv(candidate) break @@ -57,7 +61,10 @@ def main() -> None: api_key = os.environ.get("DAYTONA_API_KEY") if not api_key: print("ERROR: DAYTONA_API_KEY is not set.", file=sys.stderr) - print("Add it to surfsense_backend/.env or export it in your shell.", file=sys.stderr) + print( + "Add it to surfsense_backend/.env or export it in your shell.", + file=sys.stderr, + ) sys.exit(1) daytona = Daytona() @@ -67,7 +74,7 @@ def main() -> None: print(f"Deleting existing snapshot '{SNAPSHOT_NAME}' …") daytona.snapshot.delete(existing) print(f"Deleted '{SNAPSHOT_NAME}'. Waiting for removal to propagate …") - for attempt in range(30): + for _attempt in range(30): time.sleep(2) try: daytona.snapshot.get(SNAPSHOT_NAME) @@ -75,7 +82,9 @@ def main() -> None: print(f"Confirmed '{SNAPSHOT_NAME}' is gone.\n") break else: - print(f"WARNING: '{SNAPSHOT_NAME}' may still exist after 60s. Proceeding anyway.\n") + print( + f"WARNING: '{SNAPSHOT_NAME}' may still exist after 60s. Proceeding anyway.\n" + ) except Exception: pass diff --git a/surfsense_backend/tests/unit/etl_pipeline/test_etl_pipeline_service.py b/surfsense_backend/tests/unit/etl_pipeline/test_etl_pipeline_service.py index 1a94d4263..bb8d7ef83 100644 --- a/surfsense_backend/tests/unit/etl_pipeline/test_etl_pipeline_service.py +++ b/surfsense_backend/tests/unit/etl_pipeline/test_etl_pipeline_service.py @@ -431,7 +431,9 @@ async def test_llamacloud_heif_accepted_only_with_azure_di(tmp_path, mocker): mocker.patch("app.config.config.AZURE_DI_ENDPOINT", None, create=True) mocker.patch("app.config.config.AZURE_DI_KEY", None, create=True) - with pytest.raises(EtlUnsupportedFileError, match="document parser does not support this format"): + with pytest.raises( + EtlUnsupportedFileError, match="document parser does not support this format" + ): await EtlPipelineService().extract( EtlRequest(file_path=str(heif_file), filename="photo.heif") ) diff --git a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py index a8cf5c93b..1aaf5d127 100644 --- a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py +++ b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py @@ -6,6 +6,7 @@ import pytest from langchain_core.messages import AIMessage, HumanMessage from app.agents.new_chat.middleware.knowledge_search import ( + KBSearchPlan, KnowledgeBaseSearchMiddleware, _build_document_xml, _normalize_optional_date_range, @@ -366,3 +367,146 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: assert captured["query"] == "deel founders guide summary" assert captured["start_date"] is None assert captured["end_date"] is None + + async def test_middleware_routes_to_recency_browse_when_flagged( + self, + monkeypatch, + ): + """When the planner sets is_recency_query=true, browse_recent_documents + is called instead of search_knowledge_base.""" + browse_captured: dict = {} + search_called = False + + async def fake_browse_recent_documents(**kwargs): + browse_captured.update(kwargs) + return [] + + async def fake_search_knowledge_base(**kwargs): + nonlocal search_called + search_called = True + return [] + + async def fake_build_scoped_filesystem(**kwargs): + return {}, {} + + monkeypatch.setattr( + 
"app.agents.new_chat.middleware.knowledge_search.browse_recent_documents", + fake_browse_recent_documents, + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + fake_search_knowledge_base, + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem", + fake_build_scoped_filesystem, + ) + + llm = FakeLLM( + json.dumps( + { + "optimized_query": "latest uploaded file", + "start_date": None, + "end_date": None, + "is_recency_query": True, + } + ) + ) + middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=42) + + result = await middleware.abefore_agent( + {"messages": [HumanMessage(content="what's my latest file?")]}, + runtime=None, + ) + + assert result is not None + assert browse_captured["search_space_id"] == 42 + assert not search_called + + async def test_middleware_uses_hybrid_search_when_not_recency( + self, + monkeypatch, + ): + """When is_recency_query is false (default), hybrid search is used.""" + search_captured: dict = {} + browse_called = False + + async def fake_browse_recent_documents(**kwargs): + nonlocal browse_called + browse_called = True + return [] + + async def fake_search_knowledge_base(**kwargs): + search_captured.update(kwargs) + return [] + + async def fake_build_scoped_filesystem(**kwargs): + return {}, {} + + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.browse_recent_documents", + fake_browse_recent_documents, + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + fake_search_knowledge_base, + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem", + fake_build_scoped_filesystem, + ) + + llm = FakeLLM( + json.dumps( + { + "optimized_query": "quarterly revenue report analysis", + "start_date": None, + "end_date": None, + "is_recency_query": False, + } + ) + ) + middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=42) + + await middleware.abefore_agent( + {"messages": [HumanMessage(content="find the quarterly revenue report")]}, + runtime=None, + ) + + assert search_captured["query"] == "quarterly revenue report analysis" + assert not browse_called + + +# ── KBSearchPlan schema ──────────────────────────────────────────────── + + +class TestKBSearchPlanSchema: + def test_is_recency_query_defaults_to_false(self): + plan = KBSearchPlan(optimized_query="test query") + assert plan.is_recency_query is False + + def test_is_recency_query_parses_true(self): + plan = _parse_kb_search_plan_response( + json.dumps( + { + "optimized_query": "latest uploaded file", + "start_date": None, + "end_date": None, + "is_recency_query": True, + } + ) + ) + assert plan.is_recency_query is True + assert plan.optimized_query == "latest uploaded file" + + def test_missing_is_recency_query_defaults_to_false(self): + plan = _parse_kb_search_plan_response( + json.dumps( + { + "optimized_query": "meeting notes", + "start_date": None, + "end_date": None, + } + ) + ) + assert plan.is_recency_query is False diff --git a/surfsense_backend/tests/unit/services/test_ai_file_sort_service.py b/surfsense_backend/tests/unit/services/test_ai_file_sort_service.py new file mode 100644 index 000000000..860c2ffa2 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_ai_file_sort_service.py @@ -0,0 +1,275 @@ +"""Unit tests for AI file sort service: folder label resolution, date extraction, category sanitization.""" + +from datetime import UTC, datetime +from 
unittest.mock import AsyncMock, MagicMock + +import pytest + +pytestmark = pytest.mark.unit + + +# ── resolve_root_folder_label ── + + +def _make_document(document_type: str, connector_id=None): + doc = MagicMock() + doc.document_type = document_type + doc.connector_id = connector_id + return doc + + +def _make_connector(connector_type: str): + conn = MagicMock() + conn.connector_type = connector_type + return conn + + +def test_root_label_uses_connector_type_when_available(): + from app.services.ai_file_sort_service import resolve_root_folder_label + + doc = _make_document("FILE", connector_id=1) + conn = _make_connector("GOOGLE_DRIVE_CONNECTOR") + assert resolve_root_folder_label(doc, conn) == "Google Drive" + + +def test_root_label_falls_back_to_document_type(): + from app.services.ai_file_sort_service import resolve_root_folder_label + + doc = _make_document("SLACK_CONNECTOR") + assert resolve_root_folder_label(doc, None) == "Slack" + + +def test_root_label_unknown_doctype_returns_raw_value(): + from app.services.ai_file_sort_service import resolve_root_folder_label + + doc = _make_document("UNKNOWN_TYPE") + assert resolve_root_folder_label(doc, None) == "UNKNOWN_TYPE" + + +# ── resolve_date_folder ── + + +def test_date_folder_from_updated_at(): + from app.services.ai_file_sort_service import resolve_date_folder + + doc = MagicMock() + doc.updated_at = datetime(2025, 3, 15, 10, 30, 0, tzinfo=UTC) + doc.created_at = datetime(2025, 1, 1, 0, 0, 0, tzinfo=UTC) + assert resolve_date_folder(doc) == "2025-03-15" + + +def test_date_folder_falls_back_to_created_at(): + from app.services.ai_file_sort_service import resolve_date_folder + + doc = MagicMock() + doc.updated_at = None + doc.created_at = datetime(2024, 12, 25, 23, 59, 0, tzinfo=UTC) + assert resolve_date_folder(doc) == "2024-12-25" + + +def test_date_folder_both_none_uses_today(): + from app.services.ai_file_sort_service import resolve_date_folder + + doc = MagicMock() + doc.updated_at = None + doc.created_at = None + result = resolve_date_folder(doc) + today = datetime.now(UTC).strftime("%Y-%m-%d") + assert result == today + + +# ── sanitize_category_folder_name ── + + +def test_sanitize_normal_value(): + from app.services.ai_file_sort_service import sanitize_category_folder_name + + assert sanitize_category_folder_name("Machine Learning") == "Machine Learning" + + +def test_sanitize_strips_special_chars(): + from app.services.ai_file_sort_service import sanitize_category_folder_name + + assert sanitize_category_folder_name("Tax/Reports!") == "TaxReports" + + +def test_sanitize_empty_returns_fallback(): + from app.services.ai_file_sort_service import sanitize_category_folder_name + + assert sanitize_category_folder_name("") == "Uncategorized" + assert sanitize_category_folder_name(None) == "Uncategorized" + + +def test_sanitize_truncates_long_names(): + from app.services.ai_file_sort_service import sanitize_category_folder_name + + long_name = "A" * 100 + result = sanitize_category_folder_name(long_name) + assert len(result) <= 50 + + +# ── generate_ai_taxonomy ── + + +@pytest.mark.asyncio +async def test_generate_ai_taxonomy_parses_json(): + from app.services.ai_file_sort_service import generate_ai_taxonomy + + mock_llm = AsyncMock() + mock_result = MagicMock() + mock_result.content = '{"category": "Science", "subcategory": "Physics"}' + mock_llm.ainvoke.return_value = mock_result + + cat, sub = await generate_ai_taxonomy( + "Physics Paper", "Some science document about physics", mock_llm + ) + assert cat == "Science" + assert sub 
== "Physics" + + +@pytest.mark.asyncio +async def test_generate_ai_taxonomy_handles_markdown_code_block(): + from app.services.ai_file_sort_service import generate_ai_taxonomy + + mock_llm = AsyncMock() + mock_result = MagicMock() + mock_result.content = ( + '```json\n{"category": "Finance", "subcategory": "Tax Reports"}\n```' + ) + mock_llm.ainvoke.return_value = mock_result + + cat, sub = await generate_ai_taxonomy("Tax Doc", "A tax report document", mock_llm) + assert cat == "Finance" + assert sub == "Tax Reports" + + +@pytest.mark.asyncio +async def test_generate_ai_taxonomy_includes_title_in_prompt(): + from app.services.ai_file_sort_service import generate_ai_taxonomy + + mock_llm = AsyncMock() + mock_result = MagicMock() + mock_result.content = '{"category": "Engineering", "subcategory": "Backend"}' + mock_llm.ainvoke.return_value = mock_result + + await generate_ai_taxonomy("API Design Guide", "content about REST APIs", mock_llm) + + prompt_text = mock_llm.ainvoke.call_args[0][0][0].content + assert "API Design Guide" in prompt_text + assert "content about REST APIs" in prompt_text + + +@pytest.mark.asyncio +async def test_generate_ai_taxonomy_fallback_on_error(): + from app.services.ai_file_sort_service import generate_ai_taxonomy + + mock_llm = AsyncMock() + mock_llm.ainvoke.side_effect = RuntimeError("LLM down") + + cat, sub = await generate_ai_taxonomy("Title", "some content", mock_llm) + assert cat == "Uncategorized" + assert sub == "General" + + +@pytest.mark.asyncio +async def test_generate_ai_taxonomy_fallback_on_empty_content(): + from app.services.ai_file_sort_service import generate_ai_taxonomy + + mock_llm = AsyncMock() + cat, sub = await generate_ai_taxonomy("Title", "", mock_llm) + assert cat == "Uncategorized" + assert sub == "General" + mock_llm.ainvoke.assert_not_called() + + +@pytest.mark.asyncio +async def test_generate_ai_taxonomy_fallback_on_invalid_json(): + from app.services.ai_file_sort_service import generate_ai_taxonomy + + mock_llm = AsyncMock() + mock_result = MagicMock() + mock_result.content = "not valid json at all" + mock_llm.ainvoke.return_value = mock_result + + cat, sub = await generate_ai_taxonomy("Title", "some content", mock_llm) + assert cat == "Uncategorized" + assert sub == "General" + + +# ── taxonomy caching ── + + +def test_get_cached_taxonomy_returns_none_when_no_metadata(): + from app.services.ai_file_sort_service import _get_cached_taxonomy + + doc = MagicMock() + doc.document_metadata = None + assert _get_cached_taxonomy(doc) is None + + +def test_get_cached_taxonomy_returns_none_when_keys_missing(): + from app.services.ai_file_sort_service import _get_cached_taxonomy + + doc = MagicMock() + doc.document_metadata = {"some_other_key": "value"} + assert _get_cached_taxonomy(doc) is None + + +def test_get_cached_taxonomy_returns_cached_values(): + from app.services.ai_file_sort_service import _get_cached_taxonomy + + doc = MagicMock() + doc.document_metadata = { + "ai_sort_category": "Finance", + "ai_sort_subcategory": "Tax Reports", + } + assert _get_cached_taxonomy(doc) == ("Finance", "Tax Reports") + + +def test_set_cached_taxonomy_persists_on_metadata(): + from app.services.ai_file_sort_service import _set_cached_taxonomy + + doc = MagicMock() + doc.document_metadata = {"existing_key": "keep_me"} + _set_cached_taxonomy(doc, "Science", "Physics") + assert doc.document_metadata["ai_sort_category"] == "Science" + assert doc.document_metadata["ai_sort_subcategory"] == "Physics" + assert doc.document_metadata["existing_key"] == "keep_me" 
+ + +def test_set_cached_taxonomy_creates_metadata_when_none(): + from app.services.ai_file_sort_service import _set_cached_taxonomy + + doc = MagicMock() + doc.document_metadata = None + _set_cached_taxonomy(doc, "Engineering", "Backend") + assert doc.document_metadata == { + "ai_sort_category": "Engineering", + "ai_sort_subcategory": "Backend", + } + + +# ── _build_path_segments ── + + +def test_build_path_segments_structure(): + from app.services.ai_file_sort_service import _build_path_segments + + segments = _build_path_segments("Google Drive", "2025-03-15", "Science", "Physics") + assert len(segments) == 4 + assert segments[0] == { + "name": "Google Drive", + "metadata": {"ai_sort": True, "ai_sort_level": 1}, + } + assert segments[1] == { + "name": "2025-03-15", + "metadata": {"ai_sort": True, "ai_sort_level": 2}, + } + assert segments[2] == { + "name": "Science", + "metadata": {"ai_sort": True, "ai_sort_level": 3}, + } + assert segments[3] == { + "name": "Physics", + "metadata": {"ai_sort": True, "ai_sort_level": 4}, + } diff --git a/surfsense_backend/tests/unit/services/test_ai_sort_task_dedupe.py b/surfsense_backend/tests/unit/services/test_ai_sort_task_dedupe.py new file mode 100644 index 000000000..fd9018514 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_ai_sort_task_dedupe.py @@ -0,0 +1,43 @@ +"""Unit tests for AI sort task Redis deduplication lock.""" + +from unittest.mock import MagicMock, patch + +import pytest + +pytestmark = pytest.mark.unit + + +def test_lock_key_format(): + from app.tasks.celery_tasks.document_tasks import _ai_sort_lock_key + + key = _ai_sort_lock_key(42) + assert key == "ai_sort:search_space:42:lock" + + +def test_lock_prevents_duplicate_run(): + """When the Redis lock already exists, the task should skip execution.""" + + mock_redis = MagicMock() + mock_redis.set.return_value = False # Lock already held + + with ( + patch( + "app.tasks.celery_tasks.document_tasks._get_ai_sort_redis", + return_value=mock_redis, + ), + patch( + "app.tasks.celery_tasks.document_tasks.get_celery_session_maker" + ) as mock_session_maker, + ): + import asyncio + + from app.tasks.celery_tasks.document_tasks import _ai_sort_search_space_async + + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(_ai_sort_search_space_async(1, "user-123")) + finally: + loop.close() + + # Session maker should never be called since lock was not acquired + mock_session_maker.assert_not_called() diff --git a/surfsense_backend/tests/unit/services/test_folder_hierarchy.py b/surfsense_backend/tests/unit/services/test_folder_hierarchy.py new file mode 100644 index 000000000..9077f6b0e --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_folder_hierarchy.py @@ -0,0 +1,87 @@ +"""Unit tests for ensure_folder_hierarchy_with_depth_validation.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +pytestmark = pytest.mark.unit + + +@pytest.mark.asyncio +async def test_creates_missing_folders_in_chain(): + """Should create all folders when none exist.""" + from app.services.folder_service import ( + ensure_folder_hierarchy_with_depth_validation, + ) + + session = AsyncMock() + # All lookups return None (no existing folders) + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + session.execute.return_value = mock_result + + folder_instances = [] + + def track_add(obj): + folder_instances.append(obj) + + session.add = track_add + + with ( + patch( + "app.services.folder_service.validate_folder_depth", 
new_callable=AsyncMock + ), + patch( + "app.services.folder_service.generate_folder_position", + new_callable=AsyncMock, + return_value="a0", + ), + ): + # Mock flush to assign IDs + call_count = 0 + + async def mock_flush(): + nonlocal call_count + call_count += 1 + if folder_instances: + folder_instances[-1].id = call_count + + session.flush = mock_flush + + segments = [ + {"name": "Slack", "metadata": {"ai_sort": True, "ai_sort_level": 1}}, + {"name": "2025-03-15", "metadata": {"ai_sort": True, "ai_sort_level": 2}}, + ] + + result = await ensure_folder_hierarchy_with_depth_validation( + session, 1, segments + ) + + assert len(folder_instances) == 2 + assert folder_instances[0].name == "Slack" + assert folder_instances[1].name == "2025-03-15" + assert result is folder_instances[-1] + + +@pytest.mark.asyncio +async def test_reuses_existing_folder(): + """When a folder already exists, it should be reused, not created.""" + from app.services.folder_service import ( + ensure_folder_hierarchy_with_depth_validation, + ) + + session = AsyncMock() + + existing_folder = MagicMock() + existing_folder.id = 42 + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = existing_folder + session.execute.return_value = mock_result + + segments = [{"name": "Existing", "metadata": None}] + + result = await ensure_folder_hierarchy_with_depth_validation(session, 1, segments) + + assert result is existing_folder + session.add.assert_not_called() diff --git a/surfsense_web/components/documents/DocumentsFilters.tsx b/surfsense_web/components/documents/DocumentsFilters.tsx index d43f3680b..08af96dde 100644 --- a/surfsense_web/components/documents/DocumentsFilters.tsx +++ b/surfsense_web/components/documents/DocumentsFilters.tsx @@ -1,6 +1,8 @@ "use client"; +import { IconBinaryTree, IconBinaryTreeFilled } from "@tabler/icons-react"; import { FolderPlus, ListFilter, Search, Upload, X } from "lucide-react"; +import { AnimatePresence, motion } from "motion/react"; import { useTranslations } from "next-intl"; import React, { useCallback, useMemo, useRef, useState } from "react"; import { useDocumentUploadDialog } from "@/components/assistant-ui/document-upload-popup"; @@ -10,6 +12,7 @@ import { Input } from "@/components/ui/input"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { ToggleGroup, ToggleGroupItem } from "@/components/ui/toggle-group"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import { cn } from "@/lib/utils"; import type { DocumentTypeEnum } from "@/contracts/types/document.types"; import { getDocumentTypeIcon, getDocumentTypeLabel } from "./DocumentTypeIcon"; @@ -20,6 +23,9 @@ export function DocumentsFilters({ onToggleType, activeTypes, onCreateFolder, + aiSortEnabled = false, + aiSortBusy = false, + onToggleAiSort, }: { typeCounts: Partial<Record<DocumentTypeEnum, number>>; onSearch: (v: string) => void; @@ -27,6 +33,9 @@ export function DocumentsFilters({ onToggleType: (type: DocumentTypeEnum, checked: boolean) => void; activeTypes: DocumentTypeEnum[]; onCreateFolder?: () => void; + aiSortEnabled?: boolean; + aiSortBusy?: boolean; + onToggleAiSort?: () => void; }) { const t = useTranslations("documents"); const id = React.useId(); @@ -64,7 +73,7 @@ export function DocumentsFilters({ return (
- {/* Filter + New Folder Toggle Group */} + {/* New Folder + Filter Toggle Group */} {onCreateFolder && ( @@ -172,6 +181,70 @@ export function DocumentsFilters({ + {/* AI Sort Toggle */} + {onToggleAiSort && ( + + + + + + {aiSortBusy + ? "AI sort in progress..." + : aiSortEnabled + ? "AI sort active — click to disable" + : "Enable AI sort"} + + + )} + {/* Search Input */}
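Stepping back to the backend for a moment: test_ai_sort_task_dedupe.py earlier in this patch fixes the lock key format ("ai_sort:search_space:{id}:lock") and the skip-when-held behaviour, but not the acquisition mechanics. A sketch of the acquire-or-skip pattern those tests imply, assuming a synchronous redis-py client; the one-hour TTL and the stored value are assumptions, only the key format comes from the tests:

import redis


def _ai_sort_lock_key(search_space_id: int) -> str:
    return f"ai_sort:search_space:{search_space_id}:lock"


def try_acquire_ai_sort_lock(client: redis.Redis, search_space_id: int) -> bool:
    # SET with nx=True is atomic: only the first caller gets a truthy result;
    # concurrent duplicates see a falsy result and skip the sort entirely.
    # ex=3600 bounds the lock lifetime if the task dies without releasing it.
    return bool(client.set(_ai_sort_lock_key(search_space_id), "1", nx=True, ex=3600))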
diff --git a/surfsense_web/components/documents/FolderTreeView.tsx b/surfsense_web/components/documents/FolderTreeView.tsx index 4988e87e7..ca833949c 100644 --- a/surfsense_web/components/documents/FolderTreeView.tsx +++ b/surfsense_web/components/documents/FolderTreeView.tsx @@ -90,6 +90,8 @@ export function FolderTreeView({ const [openContextMenuId, setOpenContextMenuId] = useState(null); + const [manuallyCollapsedAiIds, setManuallyCollapsedAiIds] = useState<Set<number>>(new Set()); + // Single subscription for rename state — derived boolean passed to each FolderNode const [renamingFolderId, setRenamingFolderId] = useAtom(renamingFolderIdAtom); const handleStartRename = useCallback( @@ -98,6 +100,38 @@ ); const handleCancelRename = useCallback(() => setRenamingFolderId(null), [setRenamingFolderId]); + const aiSortFolderLevels = useMemo(() => { + const map = new Map(); + for (const f of folders) { + if (f.metadata?.ai_sort === true && typeof f.metadata?.ai_sort_level === "number") { + map.set(f.id, f.metadata.ai_sort_level as number); + } + } + return map; + }, [folders]); + + const handleToggleExpand = useCallback( + (folderId: number) => { + const aiLevel = aiSortFolderLevels.get(folderId); + if (aiLevel !== undefined && aiLevel < 4) { + // AI-auto-expanded folder: only toggle the manual-collapse set. + // Calling onToggleExpand would add it to expandedIds and fight auto-expand. + setManuallyCollapsedAiIds((prev) => { + const next = new Set(prev); + if (next.has(folderId)) { + next.delete(folderId); + } else { + next.add(folderId); + } + return next; + }); + return; + } + onToggleExpand(folderId); + }, + [onToggleExpand, aiSortFolderLevels] + ); + const effectiveActiveTypes = useMemo(() => { if ( activeTypes.includes("FILE" as DocumentTypeEnum) && @@ -212,9 +246,16 @@ function renderLevel(parentId: number | null, depth: number): React.ReactNode[] { const key = parentId ?? "root"; - const childFolders = (foldersByParent[key] ?? []) - .slice() - .sort((a, b) => a.position.localeCompare(b.position)); + const childFolders = (foldersByParent[key] ?? []).slice().sort((a, b) => { + const aIsDate = + a.metadata?.ai_sort === true && a.metadata?.ai_sort_level === 2; + const bIsDate = + b.metadata?.ai_sort === true && b.metadata?.ai_sort_level === 2; + if (aIsDate && bIsDate) { + return b.name.localeCompare(a.name); + } + return a.position.localeCompare(b.position); + }); const visibleFolders = hasDescendantMatch ? childFolders.filter((f) => hasDescendantMatch[f.id]) : childFolders; @@ -226,6 +267,32 @@ const nodes: React.ReactNode[] = []; + if (parentId === null) { + const processingDocs = childDocs.filter((d) => { + const state = d.status?.state; + return state === "pending" || state === "processing"; + }); + for (const d of processingDocs) { + nodes.push( + setOpenContextMenuId(open ? `doc-${d.id}` : null)} + /> + ); + } + } + for (let i = 0; i < visibleFolders.length; i++) { const f = visibleFolders[i]; const siblingPositions = { before: i > 0 ? visibleFolders[i - 1].position : null, after: i < visibleFolders.length - 1 ?
visibleFolders[i + 1].position : null, }; - const isAutoExpanded = !!searchQuery && !!hasDescendantMatch?.[f.id]; - const isExpanded = expandedIds.has(f.id) || isAutoExpanded; + const isSearchAutoExpanded = !!searchQuery && !!hasDescendantMatch?.[f.id]; + const isAiAutoExpandCandidate = + f.metadata?.ai_sort === true && + typeof f.metadata?.ai_sort_level === "number" && + (f.metadata.ai_sort_level as number) < 4; + const isManuallyCollapsed = manuallyCollapsedAiIds.has(f.id); + const isExpanded = isManuallyCollapsed + ? isSearchAutoExpanded + : expandedIds.has(f.id) || isSearchAutoExpanded || isAiAutoExpandCandidate; nodes.push( { + const state = d.status?.state; + return state !== "pending" && state !== "processing"; + }) + : childDocs; + + for (const d of remainingDocs) { nodes.push( , , , etc. tags correctly. +// remarkMdx treats { } as JSX expression delimiters and does NOT support +// HTML comments (<!-- ... -->). Arbitrary markdown from document conversions +// (e.g. PDF-to-markdown via Azure/DocIntel) can contain constructs that +// break the MDX parser. This module sanitises them before deserialization. // --------------------------------------------------------------------------- const FENCED_OR_INLINE_CODE = /(```[\s\S]*?```|`[^`\n]+`)/g; -export function escapeMdxExpressions(md: string): string { +// Strip HTML comments that MDX cannot parse. +// PDF converters emit <!-- PageBreak -->, <!-- PageNumber -->, etc. +// MDX uses JSX-style comments and chokes on HTML comments, causing the +// parser to stop at the first occurrence. +// - <!-- PageBreak --> becomes a thematic break (---) +// - All other HTML comments are removed +function stripHtmlComments(md: string): string { + return md + .replace(/<!--\s*PageBreak\s*-->/gi, "\n---\n") + .replace(/<!--[\s\S]*?-->/g, ""); +} + +// Convert <figure> ... </figure> blocks to plain text blockquotes. +// <figure> with arbitrary text content is not valid JSX, causing the MDX +// parser to fail.
+function convertFigureBlocks(md: string): string { + return md.replace(/<figure[^>]*>([\s\S]*?)<\/figure>/gi, (_match, inner: string) => { + const trimmed = (inner as string).trim(); + if (!trimmed) return ""; + const quoted = trimmed + .split("\n") + .map((line) => `> ${line}`) + .join("\n"); + return `\n${quoted}\n`; + }); +} + +// Escape unescaped { and } outside of fenced/inline code so remarkMdx +// treats them as literal characters rather than JSX expression delimiters. +function escapeCurlyBraces(md: string): string { const parts = md.split(FENCED_OR_INLINE_CODE); return parts .map((part, i) => { - // Odd indices are code blocks / inline code – leave untouched if (i % 2 === 1) return part; - // Escape { and } that are NOT already escaped (no preceding \) return part.replace(/(?(null); const isElectron = typeof window !== "undefined" && !!window.electronAPI; + // AI File Sort state + const { data: searchSpaces, refetch: refetchSearchSpaces } = useAtomValue(searchSpacesAtom); + const activeSearchSpace = useMemo( + () => searchSpaces?.find((s) => s.id === searchSpaceId), + [searchSpaces, searchSpaceId] + ); + const aiSortEnabled = activeSearchSpace?.ai_file_sort_enabled ?? false; + const [aiSortBusy, setAiSortBusy] = useState(false); + const [aiSortConfirmOpen, setAiSortConfirmOpen] = useState(false); + + const handleToggleAiSort = useCallback(() => { + if (aiSortEnabled) { + // Disable: just update the setting, no confirmation needed + setAiSortBusy(true); + searchSpacesApiService + .updateSearchSpace({ id: searchSpaceId, data: { ai_file_sort_enabled: false } }) + .then(() => { + refetchSearchSpaces(); + toast.success("AI file sorting disabled"); + }) + .catch(() => toast.error("Failed to disable AI file sorting")) + .finally(() => setAiSortBusy(false)); + } else { + setAiSortConfirmOpen(true); + } + }, [aiSortEnabled, searchSpaceId, refetchSearchSpaces]); + + const handleConfirmEnableAiSort = useCallback(() => { + setAiSortConfirmOpen(false); + setAiSortBusy(true); + searchSpacesApiService + .updateSearchSpace({ id: searchSpaceId, data: { ai_file_sort_enabled: true } }) + .then(() => searchSpacesApiService.triggerAiSort(searchSpaceId)) + .then(() => { + refetchSearchSpaces(); + toast.success("AI file sorting enabled — organizing your documents in the background"); + }) + .catch(() => toast.error("Failed to enable AI file sorting")) + .finally(() => setAiSortBusy(false)); + }, [searchSpaceId, refetchSearchSpaces]); + const handleWatchLocalFolder = useCallback(async () => { const api = window.electronAPI; if (!api?.selectFolder) return; @@ -905,6 +948,9 @@ export function DocumentsSidebar({ onToggleType={onToggleType} activeTypes={activeTypes} onCreateFolder={() => handleCreateFolder(null)} + aiSortEnabled={aiSortEnabled} + aiSortBusy={aiSortBusy} + onToggleAiSort={handleToggleAiSort} />
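Note the ordering in handleConfirmEnableAiSort above: the setting is persisted first, and only then is the sort triggered, so the backend can assume the flag is already true when the task is enqueued. A hedged sketch of what that trigger route plausibly looks like; the path and the {"message": ...} response shape come from the frontend service below, while the router wiring, the Celery task name, and any auth/ownership checks are assumptions:

from celery import Celery
from fastapi import APIRouter

celery_app = Celery("surfsense")  # stand-in for the app's real Celery instance
router = APIRouter()


@celery_app.task(name="ai_sort_search_space")  # hypothetical task name
def ai_sort_search_space_task(search_space_id: int) -> None:
    """Runs the sort under the Redis dedup lock (see test_ai_sort_task_dedupe.py)."""


@router.post("/api/v1/searchspaces/{search_space_id}/ai-sort")
async def trigger_ai_sort(search_space_id: int) -> dict[str, str]:
    # Fire-and-forget: enqueue and return immediately; the lock inside the
    # task makes repeated triggers no-ops while a sort is already running.
    ai_sort_search_space_task.delay(search_space_id)
    return {"message": "AI sort scheduled"}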
@@ -1066,6 +1112,25 @@ export function DocumentsSidebar({ + + + + + Enable AI File Sorting? + + All documents in this search space will be organized into folders by + connector type, date, and AI-generated categories. New documents will + also be sorted automatically. You can disable this at any time. + + + + Cancel + + Enable + + + + ); diff --git a/surfsense_web/contracts/types/search-space.types.ts b/surfsense_web/contracts/types/search-space.types.ts index 7b4fefb62..7449f82b1 100644 --- a/surfsense_web/contracts/types/search-space.types.ts +++ b/surfsense_web/contracts/types/search-space.types.ts @@ -10,6 +10,7 @@ export const searchSpace = z.object({ citations_enabled: z.boolean(), qna_custom_instructions: z.string().nullable(), shared_memory_md: z.string().nullable().optional(), + ai_file_sort_enabled: z.boolean().optional().default(false), member_count: z.number(), is_owner: z.boolean(), }); @@ -56,6 +57,7 @@ export const updateSearchSpaceRequest = z.object({ citations_enabled: true, qna_custom_instructions: true, shared_memory_md: true, + ai_file_sort_enabled: true, }) .partial(), }); diff --git a/surfsense_web/lib/apis/search-spaces-api.service.ts b/surfsense_web/lib/apis/search-spaces-api.service.ts index 3e2006e46..e593245f8 100644 --- a/surfsense_web/lib/apis/search-spaces-api.service.ts +++ b/surfsense_web/lib/apis/search-spaces-api.service.ts @@ -1,3 +1,4 @@ +import { z } from "zod"; import { type CreateSearchSpaceRequest, createSearchSpaceRequest, @@ -117,6 +118,17 @@ class SearchSpacesApiService { return baseApiService.delete(`/api/v1/searchspaces/${request.id}`, deleteSearchSpaceResponse); }; + /** + * Trigger AI file sorting for all documents in a search space + */ + triggerAiSort = async (searchSpaceId: number) => { + return baseApiService.post( + `/api/v1/searchspaces/${searchSpaceId}/ai-sort`, + z.object({ message: z.string() }), + {} + ); + }; + /** * Leave a search space (remove own membership) * This is used by non-owners to leave a shared search space
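Taken together, the helpers exercised in this patch compose into a four-level sort path per document: connector root, date, AI category, AI subcategory (the ai_sort_level 1-4 metadata the folder tests assert). A worked sketch of that composition, with the caveats that the attribute names document.title/document.content, the folder_id attachment, and the exact point where sanitization happens are all assumptions rather than code from this PR:

from app.services.ai_file_sort_service import (
    _build_path_segments,
    generate_ai_taxonomy,
    resolve_date_folder,
    resolve_root_folder_label,
    sanitize_category_folder_name,
)
from app.services.folder_service import ensure_folder_hierarchy_with_depth_validation


async def sort_one_document(session, search_space_id, document, connector, llm):
    root = resolve_root_folder_label(document, connector)  # e.g. "Google Drive"
    day = resolve_date_folder(document)  # e.g. "2025-03-15"
    category, subcategory = await generate_ai_taxonomy(document.title, document.content, llm)
    segments = _build_path_segments(
        root,
        day,
        sanitize_category_folder_name(category),
        sanitize_category_folder_name(subcategory),
    )
    # Reuses folders that already exist, creates the missing ones, and
    # returns the level-4 leaf the document should land in.
    leaf = await ensure_folder_hierarchy_with_depth_validation(session, search_space_id, segments)
    document.folder_id = leaf.id  # assumed attachment mechanism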