diff --git a/surfsense_backend/alembic/versions/124_add_ai_file_sort_enabled.py b/surfsense_backend/alembic/versions/124_add_ai_file_sort_enabled.py new file mode 100644 index 000000000..b77eb9337 --- /dev/null +++ b/surfsense_backend/alembic/versions/124_add_ai_file_sort_enabled.py @@ -0,0 +1,44 @@ +"""124_add_ai_file_sort_enabled + +Revision ID: 124 +Revises: 123 +Create Date: 2026-04-14 + +Adds ai_file_sort_enabled boolean column to searchspaces. +Defaults to False so AI file sorting is opt-in per search space. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "124" +down_revision: str | None = "123" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + conn = op.get_bind() + existing_columns = [ + col["name"] for col in sa.inspect(conn).get_columns("searchspaces") + ] + + if "ai_file_sort_enabled" not in existing_columns: + op.add_column( + "searchspaces", + sa.Column( + "ai_file_sort_enabled", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + ) + + +def downgrade() -> None: + op.drop_column("searchspaces", "ai_file_sort_enabled") diff --git a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py b/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py index bc6f7fd9e..61494ff1a 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py +++ b/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py @@ -93,7 +93,8 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg] @staticmethod def _dedup( - state: AgentState, dedup_keys: dict[str, str] # type: ignore[type-arg] + state: AgentState, + dedup_keys: dict[str, str], # type: ignore[type-arg] ) -> dict[str, Any] | None: messages = state.get("messages") if not messages: diff --git 
a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py index af5a6925b..bcd544d61 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py +++ b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py @@ -9,6 +9,7 @@ from __future__ import annotations import asyncio import logging import re +import secrets from datetime import UTC, datetime from typing import Annotated, Any @@ -27,6 +28,7 @@ from sqlalchemy import delete, select from app.agents.new_chat.sandbox import ( _evict_sandbox_cache, + delete_sandbox, get_or_create_sandbox, is_sandbox_enabled, ) @@ -552,7 +554,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): @staticmethod def _wrap_as_python(code: str) -> str: """Wrap Python code in a shell invocation for the sandbox.""" - return f"python3 << 'PYEOF'\n{code}\nPYEOF" + sentinel = f"_PYEOF_{secrets.token_hex(8)}" + return f"python3 << '{sentinel}'\n{code}\n{sentinel}" async def _execute_in_sandbox( self, @@ -572,7 +575,10 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): self._thread_id, first_err, ) - _evict_sandbox_cache(self._thread_id) + try: + await delete_sandbox(self._thread_id) + except Exception: + _evict_sandbox_cache(self._thread_id) try: return await self._try_sandbox_execute(command, runtime, timeout) except Exception: @@ -587,7 +593,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): runtime: ToolRuntime[None, FilesystemState], timeout: int | None, ) -> str: - sandbox, is_new = await get_or_create_sandbox(self._thread_id) + sandbox, _is_new = await get_or_create_sandbox(self._thread_id) + # NOTE: sync_files_to_sandbox is intentionally disabled. + # The virtual FS contains XML-wrapped KB documents whose paths + # would double-nest under SANDBOX_DOCUMENTS_ROOT (e.g. + # /home/daytona/documents/documents/Report.xml) and uploading + # all KB docs on the first execute_code call adds significant + # latency. 
Re-enable once path mapping is fixed and upload is + # limited to user-created scratch files. # files = runtime.state.get("files") or {} # await sync_files_to_sandbox(self._thread_id, files, sandbox, is_new) result = await sandbox.aexecute(command, timeout=timeout) diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py index 06ed4ad80..0460da11d 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py +++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py @@ -58,6 +58,14 @@ class KBSearchPlan(BaseModel): default=None, description="Optional ISO end date or datetime for KB search filtering.", ) + is_recency_query: bool = Field( + default=False, + description=( + "True when the user's intent is primarily about recency or temporal " + "ordering (e.g. 'latest', 'newest', 'most recent', 'last uploaded') " + "rather than topical relevance." + ), + ) def _extract_text_from_message(message: BaseMessage) -> str: @@ -245,7 +253,7 @@ def _build_kb_planner_prompt( return ( "You optimize internal knowledge-base search inputs for document retrieval.\n" "Return JSON only with this exact shape:\n" - '{"optimized_query":"string","start_date":"ISO string or null","end_date":"ISO string or null"}\n\n' + '{"optimized_query":"string","start_date":"ISO string or null","end_date":"ISO string or null","is_recency_query":bool}\n\n' "Rules:\n" "- Preserve the user's intent.\n" "- Rewrite the query to improve retrieval using concrete entities, acronyms, projects, tools, people, and document-specific terms when helpful.\n" @@ -253,6 +261,11 @@ def _build_kb_planner_prompt( "- Only use date filters when the latest user request or recent dialogue clearly implies a time range.\n" "- If you use date filters, prefer returning both bounds.\n" "- If no date filter is useful, return null for both dates.\n" + '- Set "is_recency_query" to true ONLY when the 
user\'s primary intent is about ' + "recency or temporal ordering rather than topical relevance. Examples: " + '"latest file", "newest upload", "most recent document", "what did I save last", ' + '"show me files from today", "last thing I added". ' + "When true, results will be sorted by date instead of relevance.\n" "- Do not include markdown, prose, or explanations.\n\n" f"Today's UTC date: {today}\n\n" f"Recent conversation:\n{recent_conversation or '(none)'}\n\n" @@ -506,6 +519,135 @@ def _resolve_search_types( return list(expanded) if expanded else None +_RECENCY_MAX_CHUNKS_PER_DOC = 5 + + +async def browse_recent_documents( + *, + search_space_id: int, + document_type: list[str] | None = None, + top_k: int = 10, + start_date: datetime | None = None, + end_date: datetime | None = None, +) -> list[dict[str, Any]]: + """Return documents ordered by recency (newest first), no relevance ranking. + + Used when the user's intent is temporal ("latest file", "most recent upload") + and hybrid search would produce poor results because the query has no + meaningful topical signal. 
+ """ + from sqlalchemy import func, select + + from app.db import DocumentType + + async with shielded_async_session() as session: + base_conditions = [ + Document.search_space_id == search_space_id, + func.coalesce(Document.status["state"].astext, "ready") != "deleting", + ] + + if document_type is not None: + import contextlib + + doc_type_enums = [] + for dt in document_type: + if isinstance(dt, str): + with contextlib.suppress(KeyError): + doc_type_enums.append(DocumentType[dt]) + else: + doc_type_enums.append(dt) + if doc_type_enums: + if len(doc_type_enums) == 1: + base_conditions.append(Document.document_type == doc_type_enums[0]) + else: + base_conditions.append(Document.document_type.in_(doc_type_enums)) + + if start_date is not None: + base_conditions.append(Document.updated_at >= start_date) + if end_date is not None: + base_conditions.append(Document.updated_at <= end_date) + + doc_query = ( + select(Document) + .where(*base_conditions) + .order_by(Document.updated_at.desc()) + .limit(top_k) + ) + result = await session.execute(doc_query) + documents = result.scalars().unique().all() + + if not documents: + return [] + + doc_ids = [d.id for d in documents] + + numbered = ( + select( + Chunk.id.label("chunk_id"), + Chunk.document_id, + Chunk.content, + func.row_number() + .over(partition_by=Chunk.document_id, order_by=Chunk.id) + .label("rn"), + ) + .where(Chunk.document_id.in_(doc_ids)) + .subquery("numbered") + ) + + chunk_query = ( + select(numbered.c.chunk_id, numbered.c.document_id, numbered.c.content) + .where(numbered.c.rn <= _RECENCY_MAX_CHUNKS_PER_DOC) + .order_by(numbered.c.document_id, numbered.c.chunk_id) + ) + chunk_result = await session.execute(chunk_query) + fetched_chunks = chunk_result.all() + + doc_chunks: dict[int, list[dict[str, Any]]] = {d.id: [] for d in documents} + for row in fetched_chunks: + if row.document_id in doc_chunks: + doc_chunks[row.document_id].append( + {"chunk_id": row.chunk_id, "content": row.content} + ) + + 
results: list[dict[str, Any]] = [] + for doc in documents: + chunks_list = doc_chunks.get(doc.id, []) + metadata = doc.document_metadata or {} + results.append( + { + "document_id": doc.id, + "content": "\n\n".join( + c["content"] for c in chunks_list if c.get("content") + ), + "score": 0.0, + "chunks": chunks_list, + "matched_chunk_ids": [], + "document": { + "id": doc.id, + "title": doc.title, + "document_type": ( + doc.document_type.value + if getattr(doc, "document_type", None) + else None + ), + "metadata": metadata, + }, + "source": ( + doc.document_type.value + if getattr(doc, "document_type", None) + else None + ), + } + ) + + logger.info( + "browse_recent_documents: %d docs returned for space=%d", + len(results), + search_space_id, + ) + return results + + async def search_knowledge_base( *, query: str, @@ -704,10 +846,13 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] *, messages: Sequence[BaseMessage], user_text: str, - ) -> tuple[str, datetime | None, datetime | None]: - """Rewrite the KB query and infer optional date filters with the LLM.""" + ) -> tuple[str, datetime | None, datetime | None, bool]: + """Rewrite the KB query and infer optional date filters with the LLM. + + Returns (optimized_query, start_date, end_date, is_recency_query). 
+ """ if self.llm is None: - return user_text, None, None + return user_text, None, None, False recent_conversation = _render_recent_conversation( messages, @@ -734,15 +879,18 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] plan.start_date, plan.end_date, ) + is_recency = plan.is_recency_query _perf_log.info( - "[kb_fs_middleware] planner in %.3fs query=%r optimized=%r start=%s end=%s", + "[kb_fs_middleware] planner in %.3fs query=%r optimized=%r " + "start=%s end=%s recency=%s", loop.time() - t0, user_text[:80], optimized_query[:120], start_date.isoformat() if start_date else None, end_date.isoformat() if end_date else None, + is_recency, ) - return optimized_query, start_date, end_date + return optimized_query, start_date, end_date, is_recency except (json.JSONDecodeError, ValidationError, ValueError) as exc: logger.warning( "KB planner returned invalid output, using raw query: %s", exc @@ -750,7 +898,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] except Exception as exc: # pragma: no cover - defensive fallback logger.warning("KB planner failed, using raw query: %s", exc) - return user_text, None, None + return user_text, None, None, False def before_agent( # type: ignore[override] self, @@ -789,7 +937,12 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] t0 = _perf_log and asyncio.get_event_loop().time() existing_files = state.get("files") - planned_query, start_date, end_date = await self._plan_search_inputs( + ( + planned_query, + start_date, + end_date, + is_recency, + ) = await self._plan_search_inputs( messages=messages, user_text=user_text, ) @@ -805,16 +958,28 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] # messages within the same agent instance. self.mentioned_document_ids = [] - # --- 2. 
Run KB hybrid search --- - search_results = await search_knowledge_base( - query=planned_query, - search_space_id=self.search_space_id, - available_connectors=self.available_connectors, - available_document_types=self.available_document_types, - top_k=self.top_k, - start_date=start_date, - end_date=end_date, - ) + # --- 2. Run KB search (recency browse or hybrid) --- + if is_recency: + doc_types = _resolve_search_types( + self.available_connectors, self.available_document_types + ) + search_results = await browse_recent_documents( + search_space_id=self.search_space_id, + document_type=doc_types, + top_k=self.top_k, + start_date=start_date, + end_date=end_date, + ) + else: + search_results = await search_knowledge_base( + query=planned_query, + search_space_id=self.search_space_id, + available_connectors=self.available_connectors, + available_document_types=self.available_document_types, + top_k=self.top_k, + start_date=start_date, + end_date=end_date, + ) # --- 3. Merge: mentioned first, then search (dedup by doc id) --- seen_doc_ids: set[int] = set() diff --git a/surfsense_backend/app/agents/new_chat/sandbox.py b/surfsense_backend/app/agents/new_chat/sandbox.py index 614a1b1b9..efac7aae8 100644 --- a/surfsense_backend/app/agents/new_chat/sandbox.py +++ b/surfsense_backend/app/agents/new_chat/sandbox.py @@ -16,6 +16,7 @@ import contextlib import logging import os import shutil +import threading from pathlib import Path from daytona import ( @@ -55,9 +56,16 @@ class _TimeoutAwareSandbox(DaytonaSandbox): ) -> ExecuteResponse: # type: ignore[override] return await asyncio.to_thread(self.execute, command, timeout=timeout) + def download_file(self, path: str) -> bytes: + """Download a file from the sandbox filesystem.""" + return self._sandbox.fs.download_file(path) + _daytona_client: Daytona | None = None +_client_lock = threading.Lock() _sandbox_cache: dict[str, _TimeoutAwareSandbox] = {} +_sandbox_locks: dict[str, asyncio.Lock] = {} +_sandbox_locks_mu = 
asyncio.Lock() _seeded_files: dict[str, dict[str, str]] = {} _SANDBOX_CACHE_MAX_SIZE = 20 THREAD_LABEL_KEY = "surfsense_thread" @@ -70,14 +78,15 @@ def is_sandbox_enabled() -> bool: def _get_client() -> Daytona: global _daytona_client - if _daytona_client is None: - config = DaytonaConfig( - api_key=os.environ.get("DAYTONA_API_KEY", ""), - api_url=os.environ.get("DAYTONA_API_URL", "https://app.daytona.io/api"), - target=os.environ.get("DAYTONA_TARGET", "us"), - ) - _daytona_client = Daytona(config) - return _daytona_client + with _client_lock: + if _daytona_client is None: + config = DaytonaConfig( + api_key=os.environ.get("DAYTONA_API_KEY", ""), + api_url=os.environ.get("DAYTONA_API_URL", "https://app.daytona.io/api"), + target=os.environ.get("DAYTONA_TARGET", "us"), + ) + _daytona_client = Daytona(config) + return _daytona_client def _sandbox_create_params( @@ -129,14 +138,16 @@ def _find_or_create(thread_id: str) -> tuple[_TimeoutAwareSandbox, bool]: try: client.delete(sandbox) except Exception: - logger.debug("Could not delete broken sandbox %s", sandbox.id, exc_info=True) + logger.debug( + "Could not delete broken sandbox %s", sandbox.id, exc_info=True + ) sandbox = client.create(_sandbox_create_params(labels)) is_new = True logger.info("Created replacement sandbox: %s", sandbox.id) elif sandbox.state != SandboxState.STARTED: sandbox.wait_for_sandbox_start(timeout=60) - except Exception: + except DaytonaError: logger.info("No existing sandbox for thread %s — creating one", thread_id) sandbox = client.create(_sandbox_create_params(labels)) is_new = True @@ -145,6 +156,16 @@ def _find_or_create(thread_id: str) -> tuple[_TimeoutAwareSandbox, bool]: return _TimeoutAwareSandbox(sandbox=sandbox), is_new +async def _get_thread_lock(key: str) -> asyncio.Lock: + """Return a per-thread asyncio lock, creating one if needed.""" + async with _sandbox_locks_mu: + lock = _sandbox_locks.get(key) + if lock is None: + lock = asyncio.Lock() + _sandbox_locks[key] = lock + return 
lock + + async def get_or_create_sandbox( thread_id: int | str, ) -> tuple[_TimeoutAwareSandbox, bool]: @@ -152,25 +173,52 @@ async def get_or_create_sandbox( Uses an in-process cache keyed by thread_id so subsequent messages in the same conversation reuse the sandbox object without an API call. + A per-thread async lock prevents duplicate sandbox creation from + concurrent requests. Returns: Tuple of (sandbox, is_new). *is_new* is True when a fresh sandbox was created, signalling that file tracking should be reset. """ key = str(thread_id) - cached = _sandbox_cache.get(key) - if cached is not None: - logger.info("Reusing cached sandbox for thread %s", key) - return cached, False - sandbox, is_new = await asyncio.to_thread(_find_or_create, key) - _sandbox_cache[key] = sandbox + lock = await _get_thread_lock(key) - if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE: - oldest_key = next(iter(_sandbox_cache)) - _sandbox_cache.pop(oldest_key, None) - logger.debug("Evicted oldest sandbox cache entry: %s", oldest_key) + async with lock: + cached = _sandbox_cache.get(key) + if cached is not None: + logger.info("Reusing cached sandbox for thread %s", key) + return cached, False + sandbox, is_new = await asyncio.to_thread(_find_or_create, key) + _sandbox_cache[key] = sandbox - return sandbox, is_new + if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE: + oldest_key = next(iter(_sandbox_cache)) + if oldest_key != key: + evicted = _sandbox_cache.pop(oldest_key, None) + _seeded_files.pop(oldest_key, None) + logger.debug("Evicted sandbox cache entry: %s", oldest_key) + if evicted is not None: + _schedule_sandbox_delete(evicted) + + return sandbox, is_new + + +def _schedule_sandbox_delete(sandbox: _TimeoutAwareSandbox) -> None: + """Best-effort background deletion of an evicted sandbox.""" + + def _delete() -> None: + try: + client = _get_client() + client.delete(sandbox._sandbox) + logger.info("Deleted evicted sandbox: %s", sandbox._sandbox.id) + except Exception: + 
logger.debug("Could not delete evicted sandbox", exc_info=True) + + try: + loop = asyncio.get_running_loop() + loop.run_in_executor(None, _delete) + except RuntimeError: + pass async def sync_files_to_sandbox( diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py index b76f4d757..095413bdb 100644 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py @@ -2,10 +2,10 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.confluence_history import ConfluenceHistoryConnector from app.services.confluence import ConfluenceToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py index 070efaf57..7c03c2760 100644 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py @@ -2,10 +2,10 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.confluence_history import ConfluenceHistoryConnector from app.services.confluence import ConfluenceToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py index c80df9710..791d0d8c5 100644 --- 
a/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py @@ -2,10 +2,10 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.confluence_history import ConfluenceHistoryConnector from app.services.confluence import ConfluenceToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py index 6e2578334..22d8a8a27 100644 --- a/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py @@ -5,10 +5,10 @@ from pathlib import Path from typing import Any, Literal from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.dropbox.client import DropboxClient from app.db import SearchSourceConnector, SearchSourceConnectorType diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py index 620b39aa2..12559b57a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py @@ -2,11 +2,11 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy import String, and_, cast, func from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from 
app.agents.new_chat.tools.hitl import request_approval from app.connectors.dropbox.client import DropboxClient from app.db import ( Document, diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py index 974f9b4af..0bd044695 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py @@ -6,9 +6,9 @@ from email.mime.text import MIMEText from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py index a1c713f0a..c3f0999f4 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py @@ -6,9 +6,9 @@ from email.mime.text import MIMEText from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py index cab97ee8a..1f1f6227a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py @@ -4,9 +4,9 @@ from datetime import datetime from typing import Any from langchain_core.tools import tool -from 
app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py index 1d53ac9ce..91178cd21 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py @@ -6,9 +6,9 @@ from email.mime.text import MIMEText from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py index 45ff6dfb9..259f52bba 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py @@ -150,7 +150,9 @@ def create_update_calendar_event_tool( final_new_end_datetime = result.params.get( "new_end_datetime", new_end_datetime ) - final_new_description = result.params.get("new_description", new_description) + final_new_description = result.params.get( + "new_description", new_description + ) final_new_location = result.params.get("new_location", new_location) final_new_attendees = result.params.get("new_attendees", new_attendees) diff --git a/surfsense_backend/app/agents/new_chat/tools/hitl.py b/surfsense_backend/app/agents/new_chat/tools/hitl.py index a1ac90dc7..64ace547c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/hitl.py +++ 
b/surfsense_backend/app/agents/new_chat/tools/hitl.py @@ -58,7 +58,9 @@ def _parse_decision(approval: Any) -> tuple[str, dict[str, Any]]: raise ValueError("No approval decision received") decision = decisions[0] - decision_type: str = decision.get("type") or decision.get("decision_type") or "approve" + decision_type: str = ( + decision.get("type") or decision.get("decision_type") or "approve" + ) edited_params: dict[str, Any] = {} edited_action = decision.get("edited_action") diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py index 0b3332694..8b40dde65 100644 --- a/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py @@ -3,10 +3,10 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.jira_history import JiraHistoryConnector from app.services.jira import JiraToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py index 52d4556a5..6466c80ea 100644 --- a/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py @@ -3,10 +3,10 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.jira_history import JiraHistoryConnector from app.services.jira import JiraToolMetadataService diff --git 
a/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py index 9c676fea3..f6e586a2e 100644 --- a/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py @@ -3,10 +3,10 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.jira_history import JiraHistoryConnector from app.services.jira import JiraToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py index d8005bd5c..ff254e133 100644 --- a/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py @@ -2,9 +2,9 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.linear_connector import LinearAPIError, LinearConnector from app.services.linear import LinearToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py index d8bc88d82..29ef0cdf2 100644 --- a/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py @@ -2,9 +2,9 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession 
+from app.agents.new_chat.tools.hitl import request_approval from app.connectors.linear_connector import LinearAPIError, LinearConnector from app.services.linear import LinearToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py index 7f6d952e5..f35d0dddd 100644 --- a/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py @@ -2,9 +2,9 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.linear_connector import LinearAPIError, LinearConnector from app.services.linear import LinearKBSyncService, LinearToolMetadataService @@ -157,9 +157,13 @@ def create_update_linear_issue_tool( final_issue_id = result.params.get("issue_id", issue_id) final_document_id = result.params.get("document_id", document_id) final_new_title = result.params.get("new_title", new_title) - final_new_description = result.params.get("new_description", new_description) + final_new_description = result.params.get( + "new_description", new_description + ) final_new_state_id = result.params.get("new_state_id", new_state_id) - final_new_assignee_id = result.params.get("new_assignee_id", new_assignee_id) + final_new_assignee_id = result.params.get( + "new_assignee_id", new_assignee_id + ) final_new_priority = result.params.get("new_priority", new_priority) final_new_label_ids: list[str] | None = result.params.get( "new_label_ids", new_label_ids diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py index 396f3fe0d..6efffe960 100644 --- a/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py +++ 
b/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py @@ -2,9 +2,9 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector from app.services.notion import NotionToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py index 92e395624..07f7583d2 100644 --- a/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py @@ -2,9 +2,9 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector from app.services.notion.tool_metadata_service import NotionToolMetadataService diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py index ee7b8f256..85c08177c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py @@ -2,9 +2,9 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector from app.services.notion import NotionToolMetadataService diff --git 
a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py index 5050c7885..21272e01d 100644 --- a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py @@ -5,10 +5,10 @@ from pathlib import Path from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.onedrive.client import OneDriveClient from app.db import SearchSourceConnector, SearchSourceConnectorType diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py index 6997e1d52..a7f13b5df 100644 --- a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py @@ -2,11 +2,11 @@ import logging from typing import Any from langchain_core.tools import tool -from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy import String, and_, cast, func from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.onedrive.client import OneDriveClient from app.db import ( Document, diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index b9fbe8845..61bdd65cb 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -1383,6 +1383,10 @@ class SearchSpace(BaseModel, TimestampMixin): Integer, nullable=True, default=0 ) # For vision/screenshot analysis, defaults to Auto mode + ai_file_sort_enabled = Column( + Boolean, nullable=False, default=False, server_default="false" + ) + user_id = Column( 
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False ) diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py index 22c552e5c..e6b2458f3 100644 --- a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py +++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py @@ -422,6 +422,8 @@ class IndexingPipelineService: ) log_index_success(ctx, chunk_count=len(chunks)) + await self._enqueue_ai_sort_if_enabled(document) + except RETRYABLE_LLM_ERRORS as e: log_retryable_llm_error(ctx, e) await rollback_and_persist_failure( @@ -457,6 +459,29 @@ class IndexingPipelineService: return document + async def _enqueue_ai_sort_if_enabled(self, document: Document) -> None: + """Fire-and-forget: enqueue incremental AI sort if the search space has it enabled.""" + try: + from app.db import SearchSpace + + result = await self.session.execute( + select(SearchSpace.ai_file_sort_enabled).where( + SearchSpace.id == document.search_space_id + ) + ) + enabled = result.scalar() + if not enabled: + return + + from app.tasks.celery_tasks.document_tasks import ai_sort_document_task + + user_id = str(document.created_by_id) if document.created_by_id else "" + ai_sort_document_task.delay(document.search_space_id, user_id, document.id) + except Exception: + logging.getLogger(__name__).warning( + "Failed to enqueue AI sort for document %s", document.id, exc_info=True + ) + async def index_batch_parallel( self, connector_docs: list[ConnectorDocument], diff --git a/surfsense_backend/app/routes/export_routes.py b/surfsense_backend/app/routes/export_routes.py index 641c7fedb..4f2b545a3 100644 --- a/surfsense_backend/app/routes/export_routes.py +++ b/surfsense_backend/app/routes/export_routes.py @@ -20,7 +20,9 @@ router = APIRouter() @router.get("/search-spaces/{search_space_id}/export") async def export_knowledge_base( search_space_id: int, - folder_id: 
int | None = Query(None, description="Export only this folder's subtree"), + folder_id: int | None = Query( + None, description="Export only this folder's subtree" + ), session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): diff --git a/surfsense_backend/app/routes/sandbox_routes.py b/surfsense_backend/app/routes/sandbox_routes.py index 2c12c3a1e..f656e8d76 100644 --- a/surfsense_backend/app/routes/sandbox_routes.py +++ b/surfsense_backend/app/routes/sandbox_routes.py @@ -86,9 +86,8 @@ async def download_sandbox_file( # Fall back to live sandbox download try: - sandbox = await get_or_create_sandbox(thread_id) - raw_sandbox = sandbox._sandbox - content: bytes = await asyncio.to_thread(raw_sandbox.fs.download_file, path) + sandbox, _ = await get_or_create_sandbox(thread_id) + content: bytes = await asyncio.to_thread(sandbox.download_file, path) except Exception as exc: logger.warning("Sandbox file download failed for %s: %s", path, exc) raise HTTPException( diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 7fa33ba1c..828137518 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -216,6 +216,7 @@ async def read_search_spaces( user_id=space.user_id, citations_enabled=space.citations_enabled, qna_custom_instructions=space.qna_custom_instructions, + ai_file_sort_enabled=space.ai_file_sort_enabled, member_count=member_count, is_owner=is_owner, ) @@ -384,6 +385,42 @@ async def edit_team_memory( return db_search_space +@router.post("/searchspaces/{search_space_id}/ai-sort") +async def trigger_ai_sort( + search_space_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Trigger a full AI file sort for all documents in the search space.""" + try: + await check_permission( + session, + user, + search_space_id, + 
Permission.SETTINGS_UPDATE.value, + "You don't have permission to trigger AI sort on this search space", + ) + + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + db_search_space = result.scalars().first() + if not db_search_space: + raise HTTPException(status_code=404, detail="Search space not found") + + from app.tasks.celery_tasks.document_tasks import ai_sort_search_space_task + + ai_sort_search_space_task.delay(search_space_id, str(user.id)) + return {"message": "AI sort started"} + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to trigger AI sort: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to trigger AI sort: {e!s}" + ) from e + + @router.delete("/searchspaces/{search_space_id}", response_model=dict) async def delete_search_space( search_space_id: int, diff --git a/surfsense_backend/app/schemas/search_space.py b/surfsense_backend/app/schemas/search_space.py index e3b50be65..77e34ea4b 100644 --- a/surfsense_backend/app/schemas/search_space.py +++ b/surfsense_backend/app/schemas/search_space.py @@ -22,6 +22,7 @@ class SearchSpaceUpdate(BaseModel): citations_enabled: bool | None = None qna_custom_instructions: str | None = None shared_memory_md: str | None = None + ai_file_sort_enabled: bool | None = None class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel): @@ -31,6 +32,7 @@ class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel): citations_enabled: bool qna_custom_instructions: str | None = None shared_memory_md: str | None = None + ai_file_sort_enabled: bool = False model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/services/ai_file_sort_service.py b/surfsense_backend/app/services/ai_file_sort_service.py new file mode 100644 index 000000000..2f04131a6 --- /dev/null +++ b/surfsense_backend/app/services/ai_file_sort_service.py @@ -0,0 +1,329 @@ +"""AI File Sort Service: builds 
connector-type/date/category/subcategory folder paths.""" + +from __future__ import annotations + +import json +import logging +import re +from datetime import UTC, datetime + +from langchain_core.messages import HumanMessage +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy.orm import selectinload + +from app.db import ( + Chunk, + Document, + DocumentType, + SearchSourceConnector, + SearchSourceConnectorType, +) +from app.services.folder_service import ensure_folder_hierarchy_with_depth_validation + +logger = logging.getLogger(__name__) + +_DOCTYPE_TO_CONNECTOR_LABEL: dict[str, str] = { + DocumentType.EXTENSION: "Browser Extension", + DocumentType.CRAWLED_URL: "Web Crawl", + DocumentType.FILE: "File Upload", + DocumentType.SLACK_CONNECTOR: "Slack", + DocumentType.TEAMS_CONNECTOR: "Teams", + DocumentType.ONEDRIVE_FILE: "OneDrive", + DocumentType.NOTION_CONNECTOR: "Notion", + DocumentType.YOUTUBE_VIDEO: "YouTube", + DocumentType.GITHUB_CONNECTOR: "GitHub", + DocumentType.LINEAR_CONNECTOR: "Linear", + DocumentType.DISCORD_CONNECTOR: "Discord", + DocumentType.JIRA_CONNECTOR: "Jira", + DocumentType.CONFLUENCE_CONNECTOR: "Confluence", + DocumentType.CLICKUP_CONNECTOR: "ClickUp", + DocumentType.GOOGLE_CALENDAR_CONNECTOR: "Google Calendar", + DocumentType.GOOGLE_GMAIL_CONNECTOR: "Gmail", + DocumentType.GOOGLE_DRIVE_FILE: "Google Drive", + DocumentType.AIRTABLE_CONNECTOR: "Airtable", + DocumentType.LUMA_CONNECTOR: "Luma", + DocumentType.ELASTICSEARCH_CONNECTOR: "Elasticsearch", + DocumentType.BOOKSTACK_CONNECTOR: "BookStack", + DocumentType.CIRCLEBACK: "Circleback", + DocumentType.OBSIDIAN_CONNECTOR: "Obsidian", + DocumentType.NOTE: "Notes", + DocumentType.DROPBOX_FILE: "Dropbox", + DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: "Google Drive (Composio)", + DocumentType.COMPOSIO_GMAIL_CONNECTOR: "Gmail (Composio)", + DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: "Google Calendar (Composio)", + 
DocumentType.LOCAL_FOLDER_FILE: "Local Folder", +} + +_CONNECTOR_TYPE_LABEL: dict[str, str] = { + SearchSourceConnectorType.SERPER_API: "Serper Search", + SearchSourceConnectorType.TAVILY_API: "Tavily Search", + SearchSourceConnectorType.SEARXNG_API: "SearXNG Search", + SearchSourceConnectorType.LINKUP_API: "Linkup Search", + SearchSourceConnectorType.BAIDU_SEARCH_API: "Baidu Search", + SearchSourceConnectorType.SLACK_CONNECTOR: "Slack", + SearchSourceConnectorType.TEAMS_CONNECTOR: "Teams", + SearchSourceConnectorType.ONEDRIVE_CONNECTOR: "OneDrive", + SearchSourceConnectorType.NOTION_CONNECTOR: "Notion", + SearchSourceConnectorType.GITHUB_CONNECTOR: "GitHub", + SearchSourceConnectorType.LINEAR_CONNECTOR: "Linear", + SearchSourceConnectorType.DISCORD_CONNECTOR: "Discord", + SearchSourceConnectorType.JIRA_CONNECTOR: "Jira", + SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "Confluence", + SearchSourceConnectorType.CLICKUP_CONNECTOR: "ClickUp", + SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: "Google Calendar", + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: "Gmail", + SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR: "Google Drive", + SearchSourceConnectorType.AIRTABLE_CONNECTOR: "Airtable", + SearchSourceConnectorType.LUMA_CONNECTOR: "Luma", + SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: "Elasticsearch", + SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: "Web Crawl", + SearchSourceConnectorType.BOOKSTACK_CONNECTOR: "BookStack", + SearchSourceConnectorType.CIRCLEBACK_CONNECTOR: "Circleback", + SearchSourceConnectorType.OBSIDIAN_CONNECTOR: "Obsidian", + SearchSourceConnectorType.MCP_CONNECTOR: "MCP", + SearchSourceConnectorType.DROPBOX_CONNECTOR: "Dropbox", + SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: "Google Drive (Composio)", + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR: "Gmail (Composio)", + SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: "Google Calendar (Composio)", +} + +_MAX_CONTENT_CHARS = 4000 
+_MAX_CHUNKS_FOR_CONTEXT = 5 + +_CATEGORY_PROMPT = ( + "Based on the document information below, classify it into a broad category " + "and a more specific subcategory.\n\n" + "Rules:\n" + "- category: 1-2 word broad theme (e.g. Science, Finance, Engineering, Communication, Media)\n" + "- subcategory: 1-2 word specific topic within the category " + "(e.g. Physics, Tax Reports, Backend, Team Updates)\n" + "- Use nouns only. Do not include generic terms like 'General' or 'Miscellaneous'.\n\n" + "Title: {title}\n\n" + "Content: {summary}\n\n" + 'Respond with ONLY a JSON object: {{"category": "...", "subcategory": "..."}}' +) + +_SAFE_NAME_RE = re.compile(r"[^a-zA-Z0-9 _\-()]") +_FALLBACK_CATEGORY = "Uncategorized" +_FALLBACK_SUBCATEGORY = "General" + + +def resolve_root_folder_label( + document: Document, connector: SearchSourceConnector | None +) -> str: + if connector is not None: + return _CONNECTOR_TYPE_LABEL.get( + connector.connector_type, str(connector.connector_type) + ) + return _DOCTYPE_TO_CONNECTOR_LABEL.get( + document.document_type, str(document.document_type) + ) + + +def resolve_date_folder(document: Document) -> str: + ts = document.updated_at or document.created_at + if ts is None: + ts = datetime.now(UTC) + return ts.strftime("%Y-%m-%d") + + +def sanitize_category_folder_name( + value: str | None, fallback: str = _FALLBACK_CATEGORY +) -> str: + if not value or not value.strip(): + return fallback + cleaned = _SAFE_NAME_RE.sub("", value.strip()) + cleaned = " ".join(cleaned.split()) + if not cleaned: + return fallback + return cleaned[:50] + + +async def _resolve_document_text( + session: AsyncSession, + document: Document, +) -> str: + """Build the best available text representation for taxonomy generation. + + Prefers ``document.content``; falls back to joining the first few chunks + when content is empty or too short to be useful. 
+ """ + text = (document.content or "").strip() + if len(text) >= 100: + return text[:_MAX_CONTENT_CHARS] + + stmt = ( + select(Chunk.content) + .where(Chunk.document_id == document.id) + .order_by(Chunk.id) + .limit(_MAX_CHUNKS_FOR_CONTEXT) + ) + result = await session.execute(stmt) + chunk_texts = [row[0] for row in result.all() if row[0]] + if chunk_texts: + combined = "\n\n".join(chunk_texts) + return combined[:_MAX_CONTENT_CHARS] + + return text[:_MAX_CONTENT_CHARS] + + +def _get_cached_taxonomy(document: Document) -> tuple[str, str] | None: + """Return (category, subcategory) from document metadata cache, or None.""" + meta = document.document_metadata + if not isinstance(meta, dict): + return None + cat = meta.get("ai_sort_category") + subcat = meta.get("ai_sort_subcategory") + if cat and subcat and isinstance(cat, str) and isinstance(subcat, str): + return cat, subcat + return None + + +def _set_cached_taxonomy(document: Document, category: str, subcategory: str) -> None: + """Persist the AI taxonomy on document metadata for deterministic re-sorts.""" + meta = dict(document.document_metadata or {}) + meta["ai_sort_category"] = category + meta["ai_sort_subcategory"] = subcategory + document.document_metadata = meta + + +async def generate_ai_taxonomy( + title: str, + summary_or_content: str, + llm, +) -> tuple[str, str]: + """Return (category, subcategory) using a single structured LLM call.""" + text = (summary_or_content or "").strip() + if not text: + return _FALLBACK_CATEGORY, _FALLBACK_SUBCATEGORY + + if len(text) > _MAX_CONTENT_CHARS: + text = text[:_MAX_CONTENT_CHARS] + + prompt = _CATEGORY_PROMPT.format(title=title or "Untitled", summary=text) + try: + result = await llm.ainvoke([HumanMessage(content=prompt)]) + raw = result.content.strip() + if raw.startswith("```"): + raw = raw.split("\n", 1)[-1].rsplit("```", 1)[0].strip() + parsed = json.loads(raw) + category = sanitize_category_folder_name( + parsed.get("category"), _FALLBACK_CATEGORY + ) + 
subcategory = sanitize_category_folder_name( + parsed.get("subcategory"), _FALLBACK_SUBCATEGORY + ) + return category, subcategory + except Exception: + logger.warning("AI taxonomy generation failed, using fallback", exc_info=True) + return _FALLBACK_CATEGORY, _FALLBACK_SUBCATEGORY + + +def _build_path_segments( + root_label: str, + date_label: str, + category: str, + subcategory: str, +) -> list[dict]: + return [ + {"name": root_label, "metadata": {"ai_sort": True, "ai_sort_level": 1}}, + {"name": date_label, "metadata": {"ai_sort": True, "ai_sort_level": 2}}, + {"name": category, "metadata": {"ai_sort": True, "ai_sort_level": 3}}, + {"name": subcategory, "metadata": {"ai_sort": True, "ai_sort_level": 4}}, + ] + + +async def _resolve_taxonomy( + session: AsyncSession, + document: Document, + llm, +) -> tuple[str, str]: + """Return (category, subcategory), reusing cached values when available.""" + cached = _get_cached_taxonomy(document) + if cached is not None: + return cached + + content_text = await _resolve_document_text(session, document) + category, subcategory = await generate_ai_taxonomy( + document.title, content_text, llm + ) + _set_cached_taxonomy(document, category, subcategory) + return category, subcategory + + +async def ai_sort_document( + session: AsyncSession, + document: Document, + llm, +) -> Document: + """Sort a single document into the 4-level AI folder hierarchy.""" + connector: SearchSourceConnector | None = None + if document.connector_id is not None: + connector = await session.get(SearchSourceConnector, document.connector_id) + + root_label = resolve_root_folder_label(document, connector) + date_label = resolve_date_folder(document) + + category, subcategory = await _resolve_taxonomy(session, document, llm) + + segments = _build_path_segments(root_label, date_label, category, subcategory) + + leaf_folder = await ensure_folder_hierarchy_with_depth_validation( + session, + document.search_space_id, + segments, + ) + + document.folder_id = 
leaf_folder.id + await session.flush() + return document + + +async def ai_sort_all_documents( + session: AsyncSession, + search_space_id: int, + llm, +) -> tuple[int, int]: + """Sort all documents in a search space. Returns (sorted_count, failed_count).""" + stmt = ( + select(Document) + .where(Document.search_space_id == search_space_id) + .options(selectinload(Document.connector)) + ) + result = await session.execute(stmt) + documents = list(result.scalars().all()) + + sorted_count = 0 + failed_count = 0 + + for doc in documents: + try: + connector = doc.connector + root_label = resolve_root_folder_label(doc, connector) + date_label = resolve_date_folder(doc) + + category, subcategory = await _resolve_taxonomy(session, doc, llm) + segments = _build_path_segments( + root_label, date_label, category, subcategory + ) + + leaf_folder = await ensure_folder_hierarchy_with_depth_validation( + session, + search_space_id, + segments, + ) + doc.folder_id = leaf_folder.id + sorted_count += 1 + except Exception: + logger.error("Failed to AI-sort document %s", doc.id, exc_info=True) + failed_count += 1 + + await session.commit() + logger.info( + "AI sort complete for search_space=%d: sorted=%d, failed=%d", + search_space_id, + sorted_count, + failed_count, + ) + return sorted_count, failed_count diff --git a/surfsense_backend/app/services/folder_service.py b/surfsense_backend/app/services/folder_service.py index fc1fb8a75..f5b608600 100644 --- a/surfsense_backend/app/services/folder_service.py +++ b/surfsense_backend/app/services/folder_service.py @@ -142,6 +142,58 @@ async def generate_folder_position( return generate_key_between(last_position, None) +async def ensure_folder_hierarchy_with_depth_validation( + session: AsyncSession, + search_space_id: int, + path_segments: list[dict], +) -> Folder: + """Create or return a nested folder chain, validating depth at each step. 
+ + Each item in ``path_segments`` is a dict with: + - ``name`` (str): folder display name + - ``metadata`` (dict | None): optional ``folder_metadata`` JSONB payload + + Returns the deepest (leaf) Folder in the chain. + """ + parent_id: int | None = None + current_folder: Folder | None = None + + for segment in path_segments: + name = segment["name"] + metadata = segment.get("metadata") + + stmt = select(Folder).where( + Folder.search_space_id == search_space_id, + Folder.name == name, + Folder.parent_id == parent_id + if parent_id is not None + else Folder.parent_id.is_(None), + ) + result = await session.execute(stmt) + folder = result.scalar_one_or_none() + + if folder is None: + await validate_folder_depth(session, parent_id, subtree_depth=0) + position = await generate_folder_position( + session, search_space_id, parent_id + ) + folder = Folder( + name=name, + search_space_id=search_space_id, + parent_id=parent_id, + position=position, + folder_metadata=metadata, + ) + session.add(folder) + await session.flush() + + current_folder = folder + parent_id = folder.id + + assert current_folder is not None, "path_segments must not be empty" + return current_folder + + async def get_folder_subtree_ids(session: AsyncSession, folder_id: int) -> list[int]: """Return all folder IDs in the subtree rooted at folder_id (inclusive).""" result = await session.execute( diff --git a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py index fc946b4bc..719cfb940 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py @@ -4,6 +4,7 @@ import asyncio import contextlib import logging import os +import time from uuid import UUID from app.celery_app import celery_app @@ -1551,3 +1552,121 @@ async def _index_uploaded_folder_files_async( heartbeat_task.cancel() if notification_id is not None: _stop_heartbeat(notification_id) + + +# ===== AI 
File Sort tasks ===== + +AI_SORT_LOCK_TTL_SECONDS = 600 # 10 minutes +_ai_sort_redis = None + + +def _get_ai_sort_redis(): + import redis + + global _ai_sort_redis + if _ai_sort_redis is None: + _ai_sort_redis = redis.from_url(config.REDIS_APP_URL, decode_responses=True) + return _ai_sort_redis + + +def _ai_sort_lock_key(search_space_id: int) -> str: + return f"ai_sort:search_space:{search_space_id}:lock" + + +@celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1) +def ai_sort_search_space_task(self, search_space_id: int, user_id: str): + """Full AI sort for all documents in a search space.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(_ai_sort_search_space_async(search_space_id, user_id)) + finally: + loop.close() + + +async def _ai_sort_search_space_async(search_space_id: int, user_id: str): + r = _get_ai_sort_redis() + lock_key = _ai_sort_lock_key(search_space_id) + + if not r.set(lock_key, "running", nx=True, ex=AI_SORT_LOCK_TTL_SECONDS): + logger.info( + "AI sort already running for search_space=%d, skipping", + search_space_id, + ) + return + + t_start = time.perf_counter() + try: + from app.services.ai_file_sort_service import ai_sort_all_documents + from app.services.llm_service import get_document_summary_llm + + async with get_celery_session_maker()() as session: + llm = await get_document_summary_llm( + session, search_space_id, disable_streaming=True + ) + if llm is None: + logger.warning( + "No LLM configured for search_space=%d, skipping AI sort", + search_space_id, + ) + return + + sorted_count, failed_count = await ai_sort_all_documents( + session, search_space_id, llm + ) + elapsed = time.perf_counter() - t_start + logger.info( + "AI sort search_space=%d done in %.1fs: sorted=%d failed=%d", + search_space_id, + elapsed, + sorted_count, + failed_count, + ) + finally: + r.delete(lock_key) + + +@celery_app.task( + name="ai_sort_document", bind=True, max_retries=2, 
default_retry_delay=10 +) +def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int): + """Incremental AI sort for a single document after indexing.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + _ai_sort_document_async(search_space_id, user_id, document_id) + ) + finally: + loop.close() + + +async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int): + from app.db import Document + from app.services.ai_file_sort_service import ai_sort_document + from app.services.llm_service import get_document_summary_llm + + async with get_celery_session_maker()() as session: + document = await session.get(Document, document_id) + if document is None: + logger.warning("Document %d not found, skipping AI sort", document_id) + return + + llm = await get_document_summary_llm( + session, search_space_id, disable_streaming=True + ) + if llm is None: + logger.warning( + "No LLM for search_space=%d, skipping AI sort of doc=%d", + search_space_id, + document_id, + ) + return + + await ai_sort_document(session, document, llm) + await session.commit() + logger.info( + "AI sorted document=%d into search_space=%d", + document_id, + search_space_id, + ) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 364a14bad..4530f5046 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -61,6 +61,7 @@ from app.services.new_streaming_service import VercelStreamingService from app.utils.content_utils import bootstrap_history_from_db from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap +_background_tasks: set[asyncio.Task] = set() _perf_log = get_perf_logger() @@ -142,7 +143,7 @@ class StreamResult: accumulated_text: str = "" is_interrupted: bool = False interrupt_value: dict[str, Any] | None = None - sandbox_files: 
list[str] = field(default_factory=list) # unused, kept for compat + sandbox_files: list[str] = field(default_factory=list) agent_called_update_memory: bool = False @@ -440,7 +441,7 @@ async def _stream_agent_events( status="in_progress", items=last_active_step_items, ) - elif tool_name == "execute": + elif tool_name in ("execute", "execute_code"): cmd = ( tool_input.get("command", "") if isinstance(tool_input, dict) @@ -738,7 +739,7 @@ async def _stream_agent_events( status="completed", items=completed_items, ) - elif tool_name == "execute": + elif tool_name in ("execute", "execute_code"): raw_text = ( tool_output.get("result", "") if isinstance(tool_output, dict) @@ -985,7 +986,7 @@ async def _stream_agent_events( if isinstance(tool_output, dict) else {"result": tool_output}, ) - elif tool_name == "execute": + elif tool_name in ("execute", "execute_code"): raw_text = ( tool_output.get("result", "") if isinstance(tool_output, dict) @@ -1598,7 +1599,7 @@ async def stream_new_chat( # Shared threads write to team memory; private threads write to user memory. 
if not stream_result.agent_called_update_memory: if visibility == ChatVisibility.SEARCH_SPACE: - asyncio.create_task( + task = asyncio.create_task( extract_and_save_team_memory( user_message=user_query, search_space_id=search_space_id, @@ -1606,14 +1607,18 @@ async def stream_new_chat( author_display_name=current_user_display_name, ) ) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) elif user_id: - asyncio.create_task( + task = asyncio.create_task( extract_and_save_memory( user_message=user_query, user_id=user_id, llm=llm, ) ) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) # Finish the step and message yield streaming_service.format_finish_step() @@ -1663,6 +1668,21 @@ async def stream_new_chat( with contextlib.suppress(Exception): await session.close() + # Persist any sandbox-produced files to local storage so they + # remain downloadable after the Daytona sandbox auto-deletes. + if stream_result and stream_result.sandbox_files: + with contextlib.suppress(Exception): + from app.agents.new_chat.sandbox import ( + is_sandbox_enabled, + persist_and_delete_sandbox, + ) + + if is_sandbox_enabled(): + with anyio.CancelScope(shield=True): + await persist_and_delete_sandbox( + chat_id, stream_result.sandbox_files + ) + # Break circular refs held by the agent graph, tools, and LLM # wrappers so the GC can reclaim them in a single pass. 
agent = llm = connector_service = None diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py index d8f95da63..21cdbd29f 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py @@ -961,6 +961,7 @@ async def index_google_drive_files( vision_llm = None if connector_enable_vision_llm: from app.services.llm_service import get_vision_llm + vision_llm = await get_vision_llm(session, search_space_id) drive_client = GoogleDriveClient( session, connector_id, credentials=pre_built_credentials @@ -1168,6 +1169,7 @@ async def index_google_drive_single_file( vision_llm = None if connector_enable_vision_llm: from app.services.llm_service import get_vision_llm + vision_llm = await get_vision_llm(session, search_space_id) drive_client = GoogleDriveClient( session, connector_id, credentials=pre_built_credentials @@ -1306,6 +1308,7 @@ async def index_google_drive_selected_files( vision_llm = None if connector_enable_vision_llm: from app.services.llm_service import get_vision_llm + vision_llm = await get_vision_llm(session, search_space_id) drive_client = GoogleDriveClient( session, connector_id, credentials=pre_built_credentials diff --git a/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py b/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py index 2d5f9648d..b6797f77a 100644 --- a/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py @@ -1360,7 +1360,9 @@ async def index_uploaded_files( try: content, content_hash = await _compute_file_content_hash( - temp_path, filename, search_space_id, + temp_path, + filename, + search_space_id, vision_llm=vision_llm_instance, ) except Exception as e: diff --git 
a/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py index aa654a9a9..2def799f3 100644 --- a/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py @@ -656,6 +656,7 @@ async def index_onedrive_files( vision_llm = None if connector_enable_vision_llm: from app.services.llm_service import get_vision_llm + vision_llm = await get_vision_llm(session, search_space_id) onedrive_client = OneDriveClient(session, connector_id) diff --git a/surfsense_backend/scripts/create_sandbox_snapshot.py b/surfsense_backend/scripts/create_sandbox_snapshot.py index 97ed6dfe8..82d452174 100644 --- a/surfsense_backend/scripts/create_sandbox_snapshot.py +++ b/surfsense_backend/scripts/create_sandbox_snapshot.py @@ -21,7 +21,11 @@ from pathlib import Path from dotenv import load_dotenv _here = Path(__file__).parent -for candidate in [_here / "../surfsense_backend/.env", _here / ".env", _here / "../.env"]: +for candidate in [ + _here / "../surfsense_backend/.env", + _here / ".env", + _here / "../.env", +]: if candidate.exists(): load_dotenv(candidate) break @@ -57,7 +61,10 @@ def main() -> None: api_key = os.environ.get("DAYTONA_API_KEY") if not api_key: print("ERROR: DAYTONA_API_KEY is not set.", file=sys.stderr) - print("Add it to surfsense_backend/.env or export it in your shell.", file=sys.stderr) + print( + "Add it to surfsense_backend/.env or export it in your shell.", + file=sys.stderr, + ) sys.exit(1) daytona = Daytona() @@ -67,7 +74,7 @@ def main() -> None: print(f"Deleting existing snapshot '{SNAPSHOT_NAME}' …") daytona.snapshot.delete(existing) print(f"Deleted '{SNAPSHOT_NAME}'. 
Waiting for removal to propagate …") - for attempt in range(30): + for _attempt in range(30): time.sleep(2) try: daytona.snapshot.get(SNAPSHOT_NAME) @@ -75,7 +82,9 @@ def main() -> None: print(f"Confirmed '{SNAPSHOT_NAME}' is gone.\n") break else: - print(f"WARNING: '{SNAPSHOT_NAME}' may still exist after 60s. Proceeding anyway.\n") + print( + f"WARNING: '{SNAPSHOT_NAME}' may still exist after 60s. Proceeding anyway.\n" + ) except Exception: pass diff --git a/surfsense_backend/tests/unit/etl_pipeline/test_etl_pipeline_service.py b/surfsense_backend/tests/unit/etl_pipeline/test_etl_pipeline_service.py index 1a94d4263..bb8d7ef83 100644 --- a/surfsense_backend/tests/unit/etl_pipeline/test_etl_pipeline_service.py +++ b/surfsense_backend/tests/unit/etl_pipeline/test_etl_pipeline_service.py @@ -431,7 +431,9 @@ async def test_llamacloud_heif_accepted_only_with_azure_di(tmp_path, mocker): mocker.patch("app.config.config.AZURE_DI_ENDPOINT", None, create=True) mocker.patch("app.config.config.AZURE_DI_KEY", None, create=True) - with pytest.raises(EtlUnsupportedFileError, match="document parser does not support this format"): + with pytest.raises( + EtlUnsupportedFileError, match="document parser does not support this format" + ): await EtlPipelineService().extract( EtlRequest(file_path=str(heif_file), filename="photo.heif") ) diff --git a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py index a8cf5c93b..1aaf5d127 100644 --- a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py +++ b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py @@ -6,6 +6,7 @@ import pytest from langchain_core.messages import AIMessage, HumanMessage from app.agents.new_chat.middleware.knowledge_search import ( + KBSearchPlan, KnowledgeBaseSearchMiddleware, _build_document_xml, _normalize_optional_date_range, @@ -366,3 +367,146 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: assert 
captured["query"] == "deel founders guide summary" assert captured["start_date"] is None assert captured["end_date"] is None + + async def test_middleware_routes_to_recency_browse_when_flagged( + self, + monkeypatch, + ): + """When the planner sets is_recency_query=true, browse_recent_documents + is called instead of search_knowledge_base.""" + browse_captured: dict = {} + search_called = False + + async def fake_browse_recent_documents(**kwargs): + browse_captured.update(kwargs) + return [] + + async def fake_search_knowledge_base(**kwargs): + nonlocal search_called + search_called = True + return [] + + async def fake_build_scoped_filesystem(**kwargs): + return {}, {} + + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.browse_recent_documents", + fake_browse_recent_documents, + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + fake_search_knowledge_base, + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem", + fake_build_scoped_filesystem, + ) + + llm = FakeLLM( + json.dumps( + { + "optimized_query": "latest uploaded file", + "start_date": None, + "end_date": None, + "is_recency_query": True, + } + ) + ) + middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=42) + + result = await middleware.abefore_agent( + {"messages": [HumanMessage(content="what's my latest file?")]}, + runtime=None, + ) + + assert result is not None + assert browse_captured["search_space_id"] == 42 + assert not search_called + + async def test_middleware_uses_hybrid_search_when_not_recency( + self, + monkeypatch, + ): + """When is_recency_query is false (default), hybrid search is used.""" + search_captured: dict = {} + browse_called = False + + async def fake_browse_recent_documents(**kwargs): + nonlocal browse_called + browse_called = True + return [] + + async def fake_search_knowledge_base(**kwargs): + search_captured.update(kwargs) + return [] + + 
async def fake_build_scoped_filesystem(**kwargs): + return {}, {} + + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.browse_recent_documents", + fake_browse_recent_documents, + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + fake_search_knowledge_base, + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem", + fake_build_scoped_filesystem, + ) + + llm = FakeLLM( + json.dumps( + { + "optimized_query": "quarterly revenue report analysis", + "start_date": None, + "end_date": None, + "is_recency_query": False, + } + ) + ) + middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=42) + + await middleware.abefore_agent( + {"messages": [HumanMessage(content="find the quarterly revenue report")]}, + runtime=None, + ) + + assert search_captured["query"] == "quarterly revenue report analysis" + assert not browse_called + + +# ── KBSearchPlan schema ──────────────────────────────────────────────── + + +class TestKBSearchPlanSchema: + def test_is_recency_query_defaults_to_false(self): + plan = KBSearchPlan(optimized_query="test query") + assert plan.is_recency_query is False + + def test_is_recency_query_parses_true(self): + plan = _parse_kb_search_plan_response( + json.dumps( + { + "optimized_query": "latest uploaded file", + "start_date": None, + "end_date": None, + "is_recency_query": True, + } + ) + ) + assert plan.is_recency_query is True + assert plan.optimized_query == "latest uploaded file" + + def test_missing_is_recency_query_defaults_to_false(self): + plan = _parse_kb_search_plan_response( + json.dumps( + { + "optimized_query": "meeting notes", + "start_date": None, + "end_date": None, + } + ) + ) + assert plan.is_recency_query is False diff --git a/surfsense_backend/tests/unit/services/test_ai_file_sort_service.py b/surfsense_backend/tests/unit/services/test_ai_file_sort_service.py new file mode 100644 index 000000000..860c2ffa2 
--- /dev/null +++ b/surfsense_backend/tests/unit/services/test_ai_file_sort_service.py @@ -0,0 +1,275 @@ +"""Unit tests for AI file sort service: folder label resolution, date extraction, category sanitization.""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock + +import pytest + +pytestmark = pytest.mark.unit + + +# ── resolve_root_folder_label ── + + +def _make_document(document_type: str, connector_id=None): + doc = MagicMock() + doc.document_type = document_type + doc.connector_id = connector_id + return doc + + +def _make_connector(connector_type: str): + conn = MagicMock() + conn.connector_type = connector_type + return conn + + +def test_root_label_uses_connector_type_when_available(): + from app.services.ai_file_sort_service import resolve_root_folder_label + + doc = _make_document("FILE", connector_id=1) + conn = _make_connector("GOOGLE_DRIVE_CONNECTOR") + assert resolve_root_folder_label(doc, conn) == "Google Drive" + + +def test_root_label_falls_back_to_document_type(): + from app.services.ai_file_sort_service import resolve_root_folder_label + + doc = _make_document("SLACK_CONNECTOR") + assert resolve_root_folder_label(doc, None) == "Slack" + + +def test_root_label_unknown_doctype_returns_raw_value(): + from app.services.ai_file_sort_service import resolve_root_folder_label + + doc = _make_document("UNKNOWN_TYPE") + assert resolve_root_folder_label(doc, None) == "UNKNOWN_TYPE" + + +# ── resolve_date_folder ── + + +def test_date_folder_from_updated_at(): + from app.services.ai_file_sort_service import resolve_date_folder + + doc = MagicMock() + doc.updated_at = datetime(2025, 3, 15, 10, 30, 0, tzinfo=UTC) + doc.created_at = datetime(2025, 1, 1, 0, 0, 0, tzinfo=UTC) + assert resolve_date_folder(doc) == "2025-03-15" + + +def test_date_folder_falls_back_to_created_at(): + from app.services.ai_file_sort_service import resolve_date_folder + + doc = MagicMock() + doc.updated_at = None + doc.created_at = datetime(2024, 12, 
25, 23, 59, 0, tzinfo=UTC) + assert resolve_date_folder(doc) == "2024-12-25" + + +def test_date_folder_both_none_uses_today(): + from app.services.ai_file_sort_service import resolve_date_folder + + doc = MagicMock() + doc.updated_at = None + doc.created_at = None + result = resolve_date_folder(doc) + today = datetime.now(UTC).strftime("%Y-%m-%d") + assert result == today + + +# ── sanitize_category_folder_name ── + + +def test_sanitize_normal_value(): + from app.services.ai_file_sort_service import sanitize_category_folder_name + + assert sanitize_category_folder_name("Machine Learning") == "Machine Learning" + + +def test_sanitize_strips_special_chars(): + from app.services.ai_file_sort_service import sanitize_category_folder_name + + assert sanitize_category_folder_name("Tax/Reports!") == "TaxReports" + + +def test_sanitize_empty_returns_fallback(): + from app.services.ai_file_sort_service import sanitize_category_folder_name + + assert sanitize_category_folder_name("") == "Uncategorized" + assert sanitize_category_folder_name(None) == "Uncategorized" + + +def test_sanitize_truncates_long_names(): + from app.services.ai_file_sort_service import sanitize_category_folder_name + + long_name = "A" * 100 + result = sanitize_category_folder_name(long_name) + assert len(result) <= 50 + + +# ── generate_ai_taxonomy ── + + +@pytest.mark.asyncio +async def test_generate_ai_taxonomy_parses_json(): + from app.services.ai_file_sort_service import generate_ai_taxonomy + + mock_llm = AsyncMock() + mock_result = MagicMock() + mock_result.content = '{"category": "Science", "subcategory": "Physics"}' + mock_llm.ainvoke.return_value = mock_result + + cat, sub = await generate_ai_taxonomy( + "Physics Paper", "Some science document about physics", mock_llm + ) + assert cat == "Science" + assert sub == "Physics" + + +@pytest.mark.asyncio +async def test_generate_ai_taxonomy_handles_markdown_code_block(): + from app.services.ai_file_sort_service import generate_ai_taxonomy + + 
mock_llm = AsyncMock() + mock_result = MagicMock() + mock_result.content = ( + '```json\n{"category": "Finance", "subcategory": "Tax Reports"}\n```' + ) + mock_llm.ainvoke.return_value = mock_result + + cat, sub = await generate_ai_taxonomy("Tax Doc", "A tax report document", mock_llm) + assert cat == "Finance" + assert sub == "Tax Reports" + + +@pytest.mark.asyncio +async def test_generate_ai_taxonomy_includes_title_in_prompt(): + from app.services.ai_file_sort_service import generate_ai_taxonomy + + mock_llm = AsyncMock() + mock_result = MagicMock() + mock_result.content = '{"category": "Engineering", "subcategory": "Backend"}' + mock_llm.ainvoke.return_value = mock_result + + await generate_ai_taxonomy("API Design Guide", "content about REST APIs", mock_llm) + + prompt_text = mock_llm.ainvoke.call_args[0][0][0].content + assert "API Design Guide" in prompt_text + assert "content about REST APIs" in prompt_text + + +@pytest.mark.asyncio +async def test_generate_ai_taxonomy_fallback_on_error(): + from app.services.ai_file_sort_service import generate_ai_taxonomy + + mock_llm = AsyncMock() + mock_llm.ainvoke.side_effect = RuntimeError("LLM down") + + cat, sub = await generate_ai_taxonomy("Title", "some content", mock_llm) + assert cat == "Uncategorized" + assert sub == "General" + + +@pytest.mark.asyncio +async def test_generate_ai_taxonomy_fallback_on_empty_content(): + from app.services.ai_file_sort_service import generate_ai_taxonomy + + mock_llm = AsyncMock() + cat, sub = await generate_ai_taxonomy("Title", "", mock_llm) + assert cat == "Uncategorized" + assert sub == "General" + mock_llm.ainvoke.assert_not_called() + + +@pytest.mark.asyncio +async def test_generate_ai_taxonomy_fallback_on_invalid_json(): + from app.services.ai_file_sort_service import generate_ai_taxonomy + + mock_llm = AsyncMock() + mock_result = MagicMock() + mock_result.content = "not valid json at all" + mock_llm.ainvoke.return_value = mock_result + + cat, sub = await 
generate_ai_taxonomy("Title", "some content", mock_llm) + assert cat == "Uncategorized" + assert sub == "General" + + +# ── taxonomy caching ── + + +def test_get_cached_taxonomy_returns_none_when_no_metadata(): + from app.services.ai_file_sort_service import _get_cached_taxonomy + + doc = MagicMock() + doc.document_metadata = None + assert _get_cached_taxonomy(doc) is None + + +def test_get_cached_taxonomy_returns_none_when_keys_missing(): + from app.services.ai_file_sort_service import _get_cached_taxonomy + + doc = MagicMock() + doc.document_metadata = {"some_other_key": "value"} + assert _get_cached_taxonomy(doc) is None + + +def test_get_cached_taxonomy_returns_cached_values(): + from app.services.ai_file_sort_service import _get_cached_taxonomy + + doc = MagicMock() + doc.document_metadata = { + "ai_sort_category": "Finance", + "ai_sort_subcategory": "Tax Reports", + } + assert _get_cached_taxonomy(doc) == ("Finance", "Tax Reports") + + +def test_set_cached_taxonomy_persists_on_metadata(): + from app.services.ai_file_sort_service import _set_cached_taxonomy + + doc = MagicMock() + doc.document_metadata = {"existing_key": "keep_me"} + _set_cached_taxonomy(doc, "Science", "Physics") + assert doc.document_metadata["ai_sort_category"] == "Science" + assert doc.document_metadata["ai_sort_subcategory"] == "Physics" + assert doc.document_metadata["existing_key"] == "keep_me" + + +def test_set_cached_taxonomy_creates_metadata_when_none(): + from app.services.ai_file_sort_service import _set_cached_taxonomy + + doc = MagicMock() + doc.document_metadata = None + _set_cached_taxonomy(doc, "Engineering", "Backend") + assert doc.document_metadata == { + "ai_sort_category": "Engineering", + "ai_sort_subcategory": "Backend", + } + + +# ── _build_path_segments ── + + +def test_build_path_segments_structure(): + from app.services.ai_file_sort_service import _build_path_segments + + segments = _build_path_segments("Google Drive", "2025-03-15", "Science", "Physics") + assert 
len(segments) == 4 + assert segments[0] == { + "name": "Google Drive", + "metadata": {"ai_sort": True, "ai_sort_level": 1}, + } + assert segments[1] == { + "name": "2025-03-15", + "metadata": {"ai_sort": True, "ai_sort_level": 2}, + } + assert segments[2] == { + "name": "Science", + "metadata": {"ai_sort": True, "ai_sort_level": 3}, + } + assert segments[3] == { + "name": "Physics", + "metadata": {"ai_sort": True, "ai_sort_level": 4}, + } diff --git a/surfsense_backend/tests/unit/services/test_ai_sort_task_dedupe.py b/surfsense_backend/tests/unit/services/test_ai_sort_task_dedupe.py new file mode 100644 index 000000000..fd9018514 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_ai_sort_task_dedupe.py @@ -0,0 +1,43 @@ +"""Unit tests for AI sort task Redis deduplication lock.""" + +from unittest.mock import MagicMock, patch + +import pytest + +pytestmark = pytest.mark.unit + + +def test_lock_key_format(): + from app.tasks.celery_tasks.document_tasks import _ai_sort_lock_key + + key = _ai_sort_lock_key(42) + assert key == "ai_sort:search_space:42:lock" + + +def test_lock_prevents_duplicate_run(): + """When the Redis lock already exists, the task should skip execution.""" + + mock_redis = MagicMock() + mock_redis.set.return_value = False # Lock already held + + with ( + patch( + "app.tasks.celery_tasks.document_tasks._get_ai_sort_redis", + return_value=mock_redis, + ), + patch( + "app.tasks.celery_tasks.document_tasks.get_celery_session_maker" + ) as mock_session_maker, + ): + import asyncio + + from app.tasks.celery_tasks.document_tasks import _ai_sort_search_space_async + + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(_ai_sort_search_space_async(1, "user-123")) + finally: + loop.close() + + # Session maker should never be called since lock was not acquired + mock_session_maker.assert_not_called() diff --git a/surfsense_backend/tests/unit/services/test_folder_hierarchy.py 
b/surfsense_backend/tests/unit/services/test_folder_hierarchy.py new file mode 100644 index 000000000..9077f6b0e --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_folder_hierarchy.py @@ -0,0 +1,87 @@ +"""Unit tests for ensure_folder_hierarchy_with_depth_validation.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +pytestmark = pytest.mark.unit + + +@pytest.mark.asyncio +async def test_creates_missing_folders_in_chain(): + """Should create all folders when none exist.""" + from app.services.folder_service import ( + ensure_folder_hierarchy_with_depth_validation, + ) + + session = AsyncMock() + # All lookups return None (no existing folders) + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + session.execute.return_value = mock_result + + folder_instances = [] + + def track_add(obj): + folder_instances.append(obj) + + session.add = track_add + + with ( + patch( + "app.services.folder_service.validate_folder_depth", new_callable=AsyncMock + ), + patch( + "app.services.folder_service.generate_folder_position", + new_callable=AsyncMock, + return_value="a0", + ), + ): + # Mock flush to assign IDs + call_count = 0 + + async def mock_flush(): + nonlocal call_count + call_count += 1 + if folder_instances: + folder_instances[-1].id = call_count + + session.flush = mock_flush + + segments = [ + {"name": "Slack", "metadata": {"ai_sort": True, "ai_sort_level": 1}}, + {"name": "2025-03-15", "metadata": {"ai_sort": True, "ai_sort_level": 2}}, + ] + + result = await ensure_folder_hierarchy_with_depth_validation( + session, 1, segments + ) + + assert len(folder_instances) == 2 + assert folder_instances[0].name == "Slack" + assert folder_instances[1].name == "2025-03-15" + assert result is folder_instances[-1] + + +@pytest.mark.asyncio +async def test_reuses_existing_folder(): + """When a folder already exists, it should be reused, not created.""" + from app.services.folder_service import ( + 
ensure_folder_hierarchy_with_depth_validation, + ) + + session = AsyncMock() + + existing_folder = MagicMock() + existing_folder.id = 42 + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = existing_folder + session.execute.return_value = mock_result + + segments = [{"name": "Existing", "metadata": None}] + + result = await ensure_folder_hierarchy_with_depth_validation(session, 1, segments) + + assert result is existing_folder + session.add.assert_not_called() diff --git a/surfsense_web/app/dashboard/error.tsx b/surfsense_web/app/dashboard/error.tsx new file mode 100644 index 000000000..4d872a69f --- /dev/null +++ b/surfsense_web/app/dashboard/error.tsx @@ -0,0 +1,44 @@ +"use client"; + +import Link from "next/link"; +import { useEffect } from "react"; + +export default function DashboardError({ + error, + reset, +}: { + error: globalThis.Error & { digest?: string }; + reset: () => void; +}) { + useEffect(() => { + import("posthog-js") + .then(({ default: posthog }) => { + posthog.captureException(error); + }) + .catch(() => {}); + }, [error]); + + return ( +
+

Something went wrong

+

+ An error occurred in this section. Your dashboard is still available. +

+
+ + + Go to dashboard home + +
+
+ ); +} diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index 4f3dd7c00..8d3e90b7d 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -53,6 +53,7 @@ import { useElectronAPI } from "@/hooks/use-platform"; import { useTokenUsage } from "@/components/assistant-ui/token-usage-context"; import { getProviderIcon } from "@/lib/provider-icons"; import { cn } from "@/lib/utils"; +import { openSafeNavigationHref, resolveSafeNavigationHref } from "@/components/tool-ui/shared/media"; // Captured once at module load — survives client-side navigations that strip the query param. const IS_QUICK_ASSIST_WINDOW = @@ -482,6 +483,7 @@ const AssistantMessageInner: FC = () => { generate_image: GenerateImageToolUI, update_memory: UpdateMemoryToolUI, execute: SandboxExecuteToolUI, + execute_code: SandboxExecuteToolUI, create_notion_page: CreateNotionPageToolUI, update_notion_page: UpdateNotionPageToolUI, delete_notion_page: DeleteNotionPageToolUI, diff --git a/surfsense_web/components/assistant-ui/connector-popup/tabs/active-connectors-tab.tsx b/surfsense_web/components/assistant-ui/connector-popup/tabs/active-connectors-tab.tsx index c08871445..eb161f485 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/tabs/active-connectors-tab.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/tabs/active-connectors-tab.tsx @@ -2,7 +2,7 @@ import { Search, Unplug } from "lucide-react"; import type { FC } from "react"; -import { getDocumentTypeLabel } from "@/components/documents/DocumentTypeIcon"; +import { getDocumentTypeLabel } from "@/lib/documents/document-type-labels"; import { Button } from "@/components/ui/button"; import { Spinner } from "@/components/ui/spinner"; import { TabsContent } from "@/components/ui/tabs"; diff --git a/surfsense_web/components/assistant-ui/thread.tsx 
b/surfsense_web/components/assistant-ui/thread.tsx index 59797fc72..dcb1e3e9e 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -60,7 +60,7 @@ import { } from "@/components/assistant-ui/inline-mention-editor"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { UserMessage } from "@/components/assistant-ui/user-message"; -import { SLIDEOUT_PANEL_OPENED_EVENT } from "@/components/layout/ui/sidebar/SidebarSlideOutPanel"; +import { SLIDEOUT_PANEL_OPENED_EVENT } from "@/lib/layout-events"; import { DocumentMentionPicker, type DocumentMentionPickerRef, diff --git a/surfsense_web/components/chat-comments/comment-item/comment-item.tsx b/surfsense_web/components/chat-comments/comment-item/comment-item.tsx index eb374ba49..359439d07 100644 --- a/surfsense_web/components/chat-comments/comment-item/comment-item.tsx +++ b/surfsense_web/components/chat-comments/comment-item/comment-item.tsx @@ -9,6 +9,7 @@ import { Button } from "@/components/ui/button"; import { cn } from "@/lib/utils"; import { CommentComposer } from "../comment-composer/comment-composer"; import { CommentActions } from "./comment-actions"; +import { convertRenderedToDisplay } from "@/lib/comments/utils"; import type { CommentItemProps } from "./types"; function getInitials(name: string | null, email: string): string { @@ -69,10 +70,6 @@ function formatTimestamp(dateString: string): string { ); } -export function convertRenderedToDisplay(contentRendered: string): string { - // Convert @{DisplayName} format to @DisplayName for editing - return contentRendered.replace(/@\{([^}]+)\}/g, "@$1"); -} function renderMentions(content: string): React.ReactNode { // Match @{DisplayName} format from backend diff --git a/surfsense_web/components/documents/DocumentTypeIcon.tsx b/surfsense_web/components/documents/DocumentTypeIcon.tsx index 5c03d96fa..4c6507081 100644 --- 
a/surfsense_web/components/documents/DocumentTypeIcon.tsx +++ b/surfsense_web/components/documents/DocumentTypeIcon.tsx @@ -4,52 +4,12 @@ import type React from "react"; import { useEffect, useRef, useState } from "react"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; +import { getDocumentTypeLabel } from "@/lib/documents/document-type-labels"; export function getDocumentTypeIcon(type: string, className?: string): React.ReactNode { return getConnectorIcon(type, className); } -export function getDocumentTypeLabel(type: string): string { - const labelMap: Record = { - EXTENSION: "Extension", - CRAWLED_URL: "Web Page", - FILE: "File", - SLACK_CONNECTOR: "Slack", - TEAMS_CONNECTOR: "Microsoft Teams", - ONEDRIVE_FILE: "OneDrive", - DROPBOX_FILE: "Dropbox", - NOTION_CONNECTOR: "Notion", - YOUTUBE_VIDEO: "YouTube Video", - GITHUB_CONNECTOR: "GitHub", - LINEAR_CONNECTOR: "Linear", - DISCORD_CONNECTOR: "Discord", - JIRA_CONNECTOR: "Jira", - CONFLUENCE_CONNECTOR: "Confluence", - CLICKUP_CONNECTOR: "ClickUp", - GOOGLE_CALENDAR_CONNECTOR: "Google Calendar", - GOOGLE_GMAIL_CONNECTOR: "Gmail", - GOOGLE_DRIVE_FILE: "Google Drive", - AIRTABLE_CONNECTOR: "Airtable", - LUMA_CONNECTOR: "Luma", - ELASTICSEARCH_CONNECTOR: "Elasticsearch", - BOOKSTACK_CONNECTOR: "BookStack", - CIRCLEBACK: "Circleback", - OBSIDIAN_CONNECTOR: "Obsidian", - LOCAL_FOLDER_FILE: "Local Folder", - SURFSENSE_DOCS: "SurfSense Docs", - NOTE: "Note", - COMPOSIO_GOOGLE_DRIVE_CONNECTOR: "Composio Google Drive", - COMPOSIO_GMAIL_CONNECTOR: "Composio Gmail", - COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: "Composio Google Calendar", - }; - return ( - labelMap[type] || - type - .split("_") - .map((word) => word.charAt(0) + word.slice(1).toLowerCase()) - .join(" ") - ); -} export function DocumentTypeChip({ type, className }: { type: string; className?: string }) { const icon = getDocumentTypeIcon(type, "h-4 w-4"); diff --git 
a/surfsense_web/components/documents/DocumentsFilters.tsx b/surfsense_web/components/documents/DocumentsFilters.tsx index d43f3680b..2b7cf0f10 100644 --- a/surfsense_web/components/documents/DocumentsFilters.tsx +++ b/surfsense_web/components/documents/DocumentsFilters.tsx @@ -1,6 +1,8 @@ "use client"; +import { IconBinaryTree, IconBinaryTreeFilled } from "@tabler/icons-react"; import { FolderPlus, ListFilter, Search, Upload, X } from "lucide-react"; +import { AnimatePresence, motion } from "motion/react"; import { useTranslations } from "next-intl"; import React, { useCallback, useMemo, useRef, useState } from "react"; import { useDocumentUploadDialog } from "@/components/assistant-ui/document-upload-popup"; @@ -10,8 +12,10 @@ import { Input } from "@/components/ui/input"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { ToggleGroup, ToggleGroupItem } from "@/components/ui/toggle-group"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import { cn } from "@/lib/utils"; import type { DocumentTypeEnum } from "@/contracts/types/document.types"; -import { getDocumentTypeIcon, getDocumentTypeLabel } from "./DocumentTypeIcon"; +import { getDocumentTypeIcon } from "./DocumentTypeIcon"; +import { getDocumentTypeLabel } from "@/lib/documents/document-type-labels"; export function DocumentsFilters({ typeCounts: typeCountsRecord, @@ -20,6 +24,9 @@ export function DocumentsFilters({ onToggleType, activeTypes, onCreateFolder, + aiSortEnabled = false, + aiSortBusy = false, + onToggleAiSort, }: { typeCounts: Partial>; onSearch: (v: string) => void; @@ -27,6 +34,9 @@ export function DocumentsFilters({ onToggleType: (type: DocumentTypeEnum, checked: boolean) => void; activeTypes: DocumentTypeEnum[]; onCreateFolder?: () => void; + aiSortEnabled?: boolean; + aiSortBusy?: boolean; + onToggleAiSort?: () => void; }) { const t = useTranslations("documents"); const id = React.useId(); @@ -64,7 +74,7 @@ export 
function DocumentsFilters({ return (
- {/* Filter + New Folder Toggle Group */} + {/* New Folder + Filter Toggle Group */} {onCreateFolder && ( @@ -172,6 +182,70 @@ export function DocumentsFilters({ + {/* AI Sort Toggle */} + {onToggleAiSort && ( + + + + + + {aiSortBusy + ? "AI sort in progress..." + : aiSortEnabled + ? "AI sort active — click to disable" + : "Enable AI sort"} + + + )} + {/* Search Input */}
diff --git a/surfsense_web/components/documents/FolderTreeView.tsx b/surfsense_web/components/documents/FolderTreeView.tsx index 4988e87e7..ca833949c 100644 --- a/surfsense_web/components/documents/FolderTreeView.tsx +++ b/surfsense_web/components/documents/FolderTreeView.tsx @@ -90,6 +90,8 @@ export function FolderTreeView({ const [openContextMenuId, setOpenContextMenuId] = useState(null); + const [manuallyCollapsedAiIds, setManuallyCollapsedAiIds] = useState>(new Set()); + // Single subscription for rename state — derived boolean passed to each FolderNode const [renamingFolderId, setRenamingFolderId] = useAtom(renamingFolderIdAtom); const handleStartRename = useCallback( @@ -98,6 +100,38 @@ export function FolderTreeView({ ); const handleCancelRename = useCallback(() => setRenamingFolderId(null), [setRenamingFolderId]); + const aiSortFolderLevels = useMemo(() => { + const map = new Map(); + for (const f of folders) { + if (f.metadata?.ai_sort === true && typeof f.metadata?.ai_sort_level === "number") { + map.set(f.id, f.metadata.ai_sort_level as number); + } + } + return map; + }, [folders]); + + const handleToggleExpand = useCallback( + (folderId: number) => { + const aiLevel = aiSortFolderLevels.get(folderId); + if (aiLevel !== undefined && aiLevel < 4) { + // AI-auto-expanded folder: only toggle the manual-collapse set. + // Calling onToggleExpand would add it to expandedIds and fight auto-expand. + setManuallyCollapsedAiIds((prev) => { + const next = new Set(prev); + if (next.has(folderId)) { + next.delete(folderId); + } else { + next.add(folderId); + } + return next; + }); + return; + } + onToggleExpand(folderId); + }, + [onToggleExpand, aiSortFolderLevels] + ); + const effectiveActiveTypes = useMemo(() => { if ( activeTypes.includes("FILE" as DocumentTypeEnum) && @@ -212,9 +246,16 @@ export function FolderTreeView({ function renderLevel(parentId: number | null, depth: number): React.ReactNode[] { const key = parentId ?? 
"root"; - const childFolders = (foldersByParent[key] ?? []) - .slice() - .sort((a, b) => a.position.localeCompare(b.position)); + const childFolders = (foldersByParent[key] ?? []).slice().sort((a, b) => { + const aIsDate = + a.metadata?.ai_sort === true && a.metadata?.ai_sort_level === 2; + const bIsDate = + b.metadata?.ai_sort === true && b.metadata?.ai_sort_level === 2; + if (aIsDate && bIsDate) { + return b.name.localeCompare(a.name); + } + return a.position.localeCompare(b.position); + }); const visibleFolders = hasDescendantMatch ? childFolders.filter((f) => hasDescendantMatch[f.id]) : childFolders; @@ -226,6 +267,32 @@ export function FolderTreeView({ const nodes: React.ReactNode[] = []; + if (parentId === null) { + const processingDocs = childDocs.filter((d) => { + const state = d.status?.state; + return state === "pending" || state === "processing"; + }); + for (const d of processingDocs) { + nodes.push( + setOpenContextMenuId(open ? `doc-${d.id}` : null)} + /> + ); + } + } + for (let i = 0; i < visibleFolders.length; i++) { const f = visibleFolders[i]; const siblingPositions = { @@ -233,8 +300,15 @@ export function FolderTreeView({ after: i < visibleFolders.length - 1 ? visibleFolders[i + 1].position : null, }; - const isAutoExpanded = !!searchQuery && !!hasDescendantMatch?.[f.id]; - const isExpanded = expandedIds.has(f.id) || isAutoExpanded; + const isSearchAutoExpanded = !!searchQuery && !!hasDescendantMatch?.[f.id]; + const isAiAutoExpandCandidate = + f.metadata?.ai_sort === true && + typeof f.metadata?.ai_sort_level === "number" && + (f.metadata.ai_sort_level as number) < 4; + const isManuallyCollapsed = manuallyCollapsedAiIds.has(f.id); + const isExpanded = isManuallyCollapsed + ? 
isSearchAutoExpanded + : expandedIds.has(f.id) || isSearchAutoExpanded || isAiAutoExpandCandidate; nodes.push( { + const state = d.status?.state; + return state !== "pending" && state !== "processing"; + }) + : childDocs; + + for (const d of remainingDocs) { nodes.push( , , , etc. tags correctly. +// remarkMdx treats { } as JSX expression delimiters and does NOT support +// HTML comments (). Arbitrary markdown from document conversions +// (e.g. PDF-to-markdown via Azure/DocIntel) can contain constructs that +// break the MDX parser. This module sanitises them before deserialization. // --------------------------------------------------------------------------- const FENCED_OR_INLINE_CODE = /(```[\s\S]*?```|`[^`\n]+`)/g; -export function escapeMdxExpressions(md: string): string { +// Strip HTML comments that MDX cannot parse. +// PDF converters emit , , etc. +// MDX uses JSX-style comments and chokes on HTML comments, causing the +// parser to stop at the first occurrence. +// - becomes a thematic break (---) +// - All other HTML comments are removed +function stripHtmlComments(md: string): string { + return md + .replace(//gi, "\n---\n") + .replace(//g, ""); +} + +// Convert
<figure>...</figure>
blocks to plain text blockquotes. +// <figure>
with arbitrary text content is not valid JSX, causing the MDX +// parser to fail. +function convertFigureBlocks(md: string): string { + return md.replace(/]*>([\s\S]*?)<\/figure>/gi, (_match, inner: string) => { + const trimmed = (inner as string).trim(); + if (!trimmed) return ""; + const quoted = trimmed + .split("\n") + .map((line) => `> ${line}`) + .join("\n"); + return `\n${quoted}\n`; + }); +} + +// Escape unescaped { and } outside of fenced/inline code so remarkMdx +// treats them as literal characters rather than JSX expression delimiters. +function escapeCurlyBraces(md: string): string { const parts = md.split(FENCED_OR_INLINE_CODE); return parts .map((part, i) => { - // Odd indices are code blocks / inline code – leave untouched if (i % 2 === 1) return part; - // Escape { and } that are NOT already escaped (no preceding \) return part.replace(/(?(null); const isElectron = typeof window !== "undefined" && !!window.electronAPI; + // AI File Sort state + const { data: searchSpaces, refetch: refetchSearchSpaces } = useAtomValue(searchSpacesAtom); + const activeSearchSpace = useMemo( + () => searchSpaces?.find((s) => s.id === searchSpaceId), + [searchSpaces, searchSpaceId] + ); + const aiSortEnabled = activeSearchSpace?.ai_file_sort_enabled ?? 
false; + const [aiSortBusy, setAiSortBusy] = useState(false); + const [aiSortConfirmOpen, setAiSortConfirmOpen] = useState(false); + + const handleToggleAiSort = useCallback(() => { + if (aiSortEnabled) { + // Disable: just update the setting, no confirmation needed + setAiSortBusy(true); + searchSpacesApiService + .updateSearchSpace({ id: searchSpaceId, data: { ai_file_sort_enabled: false } }) + .then(() => { + refetchSearchSpaces(); + toast.success("AI file sorting disabled"); + }) + .catch(() => toast.error("Failed to disable AI file sorting")) + .finally(() => setAiSortBusy(false)); + } else { + setAiSortConfirmOpen(true); + } + }, [aiSortEnabled, searchSpaceId, refetchSearchSpaces]); + + const handleConfirmEnableAiSort = useCallback(() => { + setAiSortConfirmOpen(false); + setAiSortBusy(true); + searchSpacesApiService + .updateSearchSpace({ id: searchSpaceId, data: { ai_file_sort_enabled: true } }) + .then(() => searchSpacesApiService.triggerAiSort(searchSpaceId)) + .then(() => { + refetchSearchSpaces(); + toast.success("AI file sorting enabled — organizing your documents in the background"); + }) + .catch(() => toast.error("Failed to enable AI file sorting")) + .finally(() => setAiSortBusy(false)); + }, [searchSpaceId, refetchSearchSpaces]); + const handleWatchLocalFolder = useCallback(async () => { const api = window.electronAPI; if (!api?.selectFolder) return; @@ -905,6 +948,9 @@ export function DocumentsSidebar({ onToggleType={onToggleType} activeTypes={activeTypes} onCreateFolder={() => handleCreateFolder(null)} + aiSortEnabled={aiSortEnabled} + aiSortBusy={aiSortBusy} + onToggleAiSort={handleToggleAiSort} />
@@ -1066,6 +1112,25 @@ export function DocumentsSidebar({ + + + + + Enable AI File Sorting? + + All documents in this search space will be organized into folders by + connector type, date, and AI-generated categories. New documents will + also be sorted automatically. You can disable this at any time. + + + + Cancel + + Enable + + + + ); diff --git a/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx index 371f3dc6d..5d8b530c0 100644 --- a/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx @@ -22,8 +22,8 @@ import { useParams, useRouter } from "next/navigation"; import { useTranslations } from "next-intl"; import { useCallback, useDeferredValue, useEffect, useMemo, useRef, useState } from "react"; import { setTargetCommentIdAtom } from "@/atoms/chat/current-thread.atom"; -import { convertRenderedToDisplay } from "@/components/chat-comments/comment-item/comment-item"; -import { getDocumentTypeLabel } from "@/components/documents/DocumentTypeIcon"; +import { convertRenderedToDisplay } from "@/lib/comments/utils"; +import { getDocumentTypeLabel } from "@/lib/documents/document-type-labels"; import { Tabs, TabsList, TabsTrigger } from "@/components/ui/animated-tabs"; import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; import { Button } from "@/components/ui/button"; diff --git a/surfsense_web/components/layout/ui/sidebar/SidebarSlideOutPanel.tsx b/surfsense_web/components/layout/ui/sidebar/SidebarSlideOutPanel.tsx index 5195082cd..661f76ed3 100644 --- a/surfsense_web/components/layout/ui/sidebar/SidebarSlideOutPanel.tsx +++ b/surfsense_web/components/layout/ui/sidebar/SidebarSlideOutPanel.tsx @@ -3,8 +3,8 @@ import { AnimatePresence, motion } from "motion/react"; import { useCallback, useEffect } from "react"; import { useMediaQuery } from "@/hooks/use-media-query"; +import { 
SLIDEOUT_PANEL_OPENED_EVENT } from "@/lib/layout-events"; -export const SLIDEOUT_PANEL_OPENED_EVENT = "slideout-panel-opened"; interface SidebarSlideOutPanelProps { open: boolean; diff --git a/surfsense_web/components/new-chat/chat-header.tsx b/surfsense_web/components/new-chat/chat-header.tsx index 0c5253c6c..4716418ee 100644 --- a/surfsense_web/components/new-chat/chat-header.tsx +++ b/surfsense_web/components/new-chat/chat-header.tsx @@ -44,21 +44,28 @@ export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) { const [isVisionGlobal, setIsVisionGlobal] = useState(false); const [visionDialogMode, setVisionDialogMode] = useState<"create" | "edit" | "view">("view"); + // Default provider for create dialogs + const [defaultLLMProvider, setDefaultLLMProvider] = useState(); + const [defaultImageProvider, setDefaultImageProvider] = useState(); + const [defaultVisionProvider, setDefaultVisionProvider] = useState(); + // LLM handlers const handleEditLLMConfig = useCallback( (config: NewLLMConfigPublic | GlobalNewLLMConfig, global: boolean) => { setSelectedConfig(config); setIsGlobal(global); setDialogMode(global ? 
"view" : "edit"); + setDefaultLLMProvider(undefined); setDialogOpen(true); }, [] ); - const handleAddNewLLM = useCallback(() => { + const handleAddNewLLM = useCallback((provider?: string) => { setSelectedConfig(null); setIsGlobal(false); setDialogMode("create"); + setDefaultLLMProvider(provider); setDialogOpen(true); }, []); @@ -68,10 +75,11 @@ export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) { }, []); // Image model handlers - const handleAddImageModel = useCallback(() => { + const handleAddImageModel = useCallback((provider?: string) => { setSelectedImageConfig(null); setIsImageGlobal(false); setImageDialogMode("create"); + setDefaultImageProvider(provider); setImageDialogOpen(true); }, []); @@ -80,6 +88,7 @@ export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) { setSelectedImageConfig(config); setIsImageGlobal(global); setImageDialogMode(global ? "view" : "edit"); + setDefaultImageProvider(undefined); setImageDialogOpen(true); }, [] @@ -91,10 +100,11 @@ export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) { }, []); // Vision model handlers - const handleAddVisionModel = useCallback(() => { + const handleAddVisionModel = useCallback((provider?: string) => { setSelectedVisionConfig(null); setIsVisionGlobal(false); setVisionDialogMode("create"); + setDefaultVisionProvider(provider); setVisionDialogOpen(true); }, []); @@ -103,6 +113,7 @@ export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) { setSelectedVisionConfig(config); setIsVisionGlobal(global); setVisionDialogMode(global ? "view" : "edit"); + setDefaultVisionProvider(undefined); setVisionDialogOpen(true); }, [] @@ -131,6 +142,7 @@ export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) { isGlobal={isGlobal} searchSpaceId={searchSpaceId} mode={dialogMode} + defaultProvider={defaultLLMProvider} />
); diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 00e37491d..26937e18b 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -1,8 +1,20 @@ "use client"; +import type React from "react"; import { useAtomValue } from "jotai"; -import { Bot, Check, ChevronDown, Edit3, ImageIcon, Plus, ScanEye, Search, Zap } from "lucide-react"; -import { type UIEvent, useCallback, useMemo, useState } from "react"; +import { + Bot, + Check, + ChevronDown, + Edit3, + Eye, + ImageIcon, + Layers, + Plus, + Search, + Zap, +} from "lucide-react"; +import { Fragment, useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { globalImageGenConfigsAtom, @@ -22,17 +34,16 @@ import { import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { - Command, - CommandEmpty, - CommandGroup, - CommandInput, - CommandItem, - CommandList, - CommandSeparator, -} from "@/components/ui/command"; + Drawer, + DrawerContent, + DrawerHandle, + DrawerHeader, + DrawerTitle, + DrawerTrigger, +} from "@/components/ui/drawer"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Spinner } from "@/components/ui/spinner"; -import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; +import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import type { GlobalImageGenConfig, GlobalNewLLMConfig, @@ -41,16 +52,209 @@ import type { NewLLMConfigPublic, VisionLLMConfig, } from "@/contracts/types/new-llm-config.types"; +import { useIsMobile } from "@/hooks/use-mobile"; import { getProviderIcon } from "@/lib/provider-icons"; import { cn } from "@/lib/utils"; +// ─── Helpers ──────────────────────────────────────────────────────── + +const PROVIDER_NAMES: Record = { + OPENAI: "OpenAI", + ANTHROPIC: 
"Anthropic", + GOOGLE: "Google", + AZURE: "Azure", + AZURE_OPENAI: "Azure OpenAI", + AWS_BEDROCK: "AWS Bedrock", + BEDROCK: "Bedrock", + DEEPSEEK: "DeepSeek", + MISTRAL: "Mistral", + COHERE: "Cohere", + GROQ: "Groq", + OLLAMA: "Ollama", + TOGETHER_AI: "Together AI", + FIREWORKS_AI: "Fireworks AI", + REPLICATE: "Replicate", + HUGGINGFACE: "HuggingFace", + PERPLEXITY: "Perplexity", + XAI: "xAI", + OPENROUTER: "OpenRouter", + CEREBRAS: "Cerebras", + SAMBANOVA: "SambaNova", + VERTEX_AI: "Vertex AI", + MINIMAX: "MiniMax", + MOONSHOT: "Moonshot", + ZHIPU: "Zhipu", + DEEPINFRA: "DeepInfra", + CLOUDFLARE: "Cloudflare", + DATABRICKS: "Databricks", + NSCALE: "NScale", + RECRAFT: "Recraft", + XINFERENCE: "XInference", + CUSTOM: "Custom", + AI21: "AI21", + ALIBABA_QWEN: "Qwen", + ANYSCALE: "Anyscale", + COMETAPI: "CometAPI", +}; + +// Provider keys valid per model type, matching backend enums +// (LiteLLMProvider, ImageGenProvider, VisionProvider in db.py) +const LLM_PROVIDER_KEYS: string[] = [ + "OPENAI", + "ANTHROPIC", + "GOOGLE", + "AZURE_OPENAI", + "BEDROCK", + "VERTEX_AI", + "GROQ", + "DEEPSEEK", + "XAI", + "MISTRAL", + "COHERE", + "OPENROUTER", + "TOGETHER_AI", + "FIREWORKS_AI", + "REPLICATE", + "PERPLEXITY", + "OLLAMA", + "CEREBRAS", + "SAMBANOVA", + "DEEPINFRA", + "AI21", + "ALIBABA_QWEN", + "MOONSHOT", + "ZHIPU", + "MINIMAX", + "HUGGINGFACE", + "CLOUDFLARE", + "DATABRICKS", + "ANYSCALE", + "COMETAPI", + "GITHUB_MODELS", + "CUSTOM", +]; + +const IMAGE_PROVIDER_KEYS: string[] = [ + "OPENAI", + "AZURE_OPENAI", + "GOOGLE", + "VERTEX_AI", + "BEDROCK", + "RECRAFT", + "OPENROUTER", + "XINFERENCE", + "NSCALE", +]; + +const VISION_PROVIDER_KEYS: string[] = [ + "OPENAI", + "ANTHROPIC", + "GOOGLE", + "AZURE_OPENAI", + "VERTEX_AI", + "BEDROCK", + "XAI", + "OPENROUTER", + "OLLAMA", + "GROQ", + "TOGETHER_AI", + "FIREWORKS_AI", + "DEEPSEEK", + "MISTRAL", + "CUSTOM", +]; + +const PROVIDER_KEYS_BY_TAB: Record = { + llm: LLM_PROVIDER_KEYS, + image: IMAGE_PROVIDER_KEYS, + vision: 
VISION_PROVIDER_KEYS, +}; + +function formatProviderName(provider: string): string { + const key = provider.toUpperCase(); + return ( + PROVIDER_NAMES[key] ?? + provider.charAt(0).toUpperCase() + + provider.slice(1).toLowerCase().replace(/_/g, " ") + ); +} + +function normalizeText(input: string): string { + return input + .normalize("NFD") + .replace(/\p{Diacritic}/gu, "") + .toLowerCase() + .replace(/[^a-z0-9]+/g, " ") + .trim(); +} + +interface ConfigBase { + id: number; + name: string; + model_name: string; + provider: string; +} + +function filterAndScore( + configs: T[], + selectedProvider: string, + searchQuery: string, +): T[] { + let result = configs; + + if (selectedProvider !== "all") { + result = result.filter( + (c) => c.provider.toUpperCase() === selectedProvider, + ); + } + + if (!searchQuery.trim()) return result; + + const normalized = normalizeText(searchQuery); + const tokens = normalized.split(/\s+/).filter(Boolean); + + const scored = result.map((c) => { + const aggregate = normalizeText( + [c.name, c.model_name, c.provider].join(" "), + ); + let score = 0; + if (aggregate.includes(normalized)) score += 5; + for (const token of tokens) { + if (aggregate.includes(token)) score += 1; + } + return { config: c, score }; + }); + + return scored + .filter((s) => s.score > 0) + .sort((a, b) => b.score - a.score) + .map((s) => s.config); +} + +interface DisplayItem { + config: ConfigBase & Record; + isGlobal: boolean; + isAutoMode: boolean; +} + +// ─── Component ────────────────────────────────────────────────────── + interface ModelSelectorProps { - onEditLLM: (config: NewLLMConfigPublic | GlobalNewLLMConfig, isGlobal: boolean) => void; - onAddNewLLM: () => void; - onEditImage?: (config: ImageGenerationConfig | GlobalImageGenConfig, isGlobal: boolean) => void; - onAddNewImage?: () => void; - onEditVision?: (config: VisionLLMConfig | GlobalVisionLLMConfig, isGlobal: boolean) => void; - onAddNewVision?: () => void; + onEditLLM: ( + config: 
NewLLMConfigPublic | GlobalNewLLMConfig, + isGlobal: boolean, + ) => void; + onAddNewLLM: (provider?: string) => void; + onEditImage?: ( + config: ImageGenerationConfig | GlobalImageGenConfig, + isGlobal: boolean, + ) => void; + onAddNewImage?: (provider?: string) => void; + onEditVision?: ( + config: VisionLLMConfig | GlobalVisionLLMConfig, + isGlobal: boolean, + ) => void; + onAddNewVision?: (provider?: string) => void; className?: string; } @@ -64,40 +268,69 @@ export function ModelSelector({ className, }: ModelSelectorProps) { const [open, setOpen] = useState(false); - const [activeTab, setActiveTab] = useState<"llm" | "image" | "vision">("llm"); - const [llmSearchQuery, setLlmSearchQuery] = useState(""); - const [imageSearchQuery, setImageSearchQuery] = useState(""); - const [visionSearchQuery, setVisionSearchQuery] = useState(""); - const [llmScrollPos, setLlmScrollPos] = useState<"top" | "middle" | "bottom">("top"); - const [imageScrollPos, setImageScrollPos] = useState<"top" | "middle" | "bottom">("top"); - const [visionScrollPos, setVisionScrollPos] = useState<"top" | "middle" | "bottom">("top"); - const handleListScroll = useCallback( - (setter: typeof setLlmScrollPos) => (e: UIEvent) => { - const el = e.currentTarget; - const atTop = el.scrollTop <= 2; - const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; - setter(atTop ? "top" : atBottom ? 
"bottom" : "middle"); - }, - [] + const [activeTab, setActiveTab] = useState<"llm" | "image" | "vision">( + "llm", ); + const [searchQuery, setSearchQuery] = useState(""); + const [selectedProvider, setSelectedProvider] = useState("all"); + const [focusedIndex, setFocusedIndex] = useState(-1); + const [showScrollIndicator, setShowScrollIndicator] = useState(true); + const providerSidebarRef = useRef(null); + const modelListRef = useRef(null); + const searchInputRef = useRef(null); + const isMobile = useIsMobile(); - // LLM data - const { data: llmUserConfigs, isLoading: llmUserLoading } = useAtomValue(newLLMConfigsAtom); + // Reset search + provider when tab changes + useEffect(() => { + setSelectedProvider("all"); + setSearchQuery(""); + setFocusedIndex(-1); + }, [activeTab]); + + // Reset on open + useEffect(() => { + if (open) { + setSearchQuery(""); + setSelectedProvider("all"); + } + }, [open]); + + // Cmd/Ctrl+M shortcut + useEffect(() => { + const handler = (e: KeyboardEvent) => { + if ((e.metaKey || e.ctrlKey) && e.key === "m") { + e.preventDefault(); + setOpen((prev) => !prev); + } + }; + document.addEventListener("keydown", handler); + return () => document.removeEventListener("keydown", handler); + }, []); + + // Focus search input on open + useEffect(() => { + if (open && !isMobile) { + requestAnimationFrame(() => searchInputRef.current?.focus()); + } + }, [open, isMobile, activeTab]); + + // ─── Data ─── + const { data: llmUserConfigs, isLoading: llmUserLoading } = + useAtomValue(newLLMConfigsAtom); const { data: llmGlobalConfigs, isLoading: llmGlobalLoading } = useAtomValue(globalNewLLMConfigsAtom); - const { data: preferences, isLoading: prefsLoading } = useAtomValue(llmPreferencesAtom); + const { data: preferences, isLoading: prefsLoading } = + useAtomValue(llmPreferencesAtom); const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom); - const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); - - // Image data + 
const { mutateAsync: updatePreferences } = useAtomValue( + updateLLMPreferencesMutationAtom, + ); const { data: imageGlobalConfigs, isLoading: imageGlobalLoading } = useAtomValue(globalImageGenConfigsAtom); - const { data: imageUserConfigs, isLoading: imageUserLoading } = useAtomValue(imageGenConfigsAtom); - - // Vision data - const { data: visionGlobalConfigs, isLoading: visionGlobalLoading } = useAtomValue( - globalVisionLLMConfigsAtom - ); + const { data: imageUserConfigs, isLoading: imageUserLoading } = + useAtomValue(imageGenConfigsAtom); + const { data: visionGlobalConfigs, isLoading: visionGlobalLoading } = + useAtomValue(globalVisionLLMConfigsAtom); const { data: visionUserConfigs, isLoading: visionUserLoading } = useAtomValue(visionLLMConfigsAtom); @@ -110,133 +343,220 @@ export function ModelSelector({ visionGlobalLoading || visionUserLoading; - // ─── LLM current config ─── + // ─── Current selected configs ─── const currentLLMConfig = useMemo(() => { if (!preferences) return null; - const agentLlmId = preferences.agent_llm_id; - if (agentLlmId === null || agentLlmId === undefined) return null; - if (agentLlmId <= 0) { - return llmGlobalConfigs?.find((c) => c.id === agentLlmId) ?? null; - } - return llmUserConfigs?.find((c) => c.id === agentLlmId) ?? null; + const id = preferences.agent_llm_id; + if (id === null || id === undefined) return null; + if (id <= 0) return llmGlobalConfigs?.find((c) => c.id === id) ?? null; + return llmUserConfigs?.find((c) => c.id === id) ?? 
null; }, [preferences, llmGlobalConfigs, llmUserConfigs]); const isLLMAutoMode = - currentLLMConfig && "is_auto_mode" in currentLLMConfig && currentLLMConfig.is_auto_mode; + currentLLMConfig && + "is_auto_mode" in currentLLMConfig && + currentLLMConfig.is_auto_mode; - // ─── Image current config ─── const currentImageConfig = useMemo(() => { if (!preferences) return null; const id = preferences.image_generation_config_id; if (id === null || id === undefined) return null; - const globalMatch = imageGlobalConfigs?.find((c) => c.id === id); - if (globalMatch) return globalMatch; - return imageUserConfigs?.find((c) => c.id === id) ?? null; + return ( + imageGlobalConfigs?.find((c) => c.id === id) ?? + imageUserConfigs?.find((c) => c.id === id) ?? + null + ); }, [preferences, imageGlobalConfigs, imageUserConfigs]); const isImageAutoMode = - currentImageConfig && "is_auto_mode" in currentImageConfig && currentImageConfig.is_auto_mode; + currentImageConfig && + "is_auto_mode" in currentImageConfig && + currentImageConfig.is_auto_mode; - // ─── Vision current config ─── const currentVisionConfig = useMemo(() => { if (!preferences) return null; const id = preferences.vision_llm_config_id; if (id === null || id === undefined) return null; - const globalMatch = visionGlobalConfigs?.find((c) => c.id === id); - if (globalMatch) return globalMatch; - return visionUserConfigs?.find((c) => c.id === id) ?? null; + return ( + visionGlobalConfigs?.find((c) => c.id === id) ?? + visionUserConfigs?.find((c) => c.id === id) ?? 
+ null + ); }, [preferences, visionGlobalConfigs, visionUserConfigs]); - const isVisionAutoMode = useMemo(() => { - return ( - currentVisionConfig && - "is_auto_mode" in currentVisionConfig && - currentVisionConfig.is_auto_mode + const isVisionAutoMode = + currentVisionConfig && + "is_auto_mode" in currentVisionConfig && + currentVisionConfig.is_auto_mode; + + // ─── Filtered configs (separate global / user for section headers) ─── + const filteredLLMGlobal = useMemo( + () => + filterAndScore(llmGlobalConfigs ?? [], selectedProvider, searchQuery), + [llmGlobalConfigs, selectedProvider, searchQuery], + ); + const filteredLLMUser = useMemo( + () => + filterAndScore(llmUserConfigs ?? [], selectedProvider, searchQuery), + [llmUserConfigs, selectedProvider, searchQuery], + ); + const filteredImageGlobal = useMemo( + () => + filterAndScore( + imageGlobalConfigs ?? [], + selectedProvider, + searchQuery, + ), + [imageGlobalConfigs, selectedProvider, searchQuery], + ); + const filteredImageUser = useMemo( + () => + filterAndScore( + imageUserConfigs ?? [], + selectedProvider, + searchQuery, + ), + [imageUserConfigs, selectedProvider, searchQuery], + ); + const filteredVisionGlobal = useMemo( + () => + filterAndScore( + visionGlobalConfigs ?? [], + selectedProvider, + searchQuery, + ), + [visionGlobalConfigs, selectedProvider, searchQuery], + ); + const filteredVisionUser = useMemo( + () => + filterAndScore( + visionUserConfigs ?? 
[], + selectedProvider, + searchQuery, + ), + [visionUserConfigs, selectedProvider, searchQuery], + ); + + // Combined display list for keyboard navigation + const currentDisplayItems: DisplayItem[] = useMemo(() => { + const toItems = ( + configs: ConfigBase[], + isGlobal: boolean, + ): DisplayItem[] => + configs.map((c) => ({ + config: c as ConfigBase & Record, + isGlobal, + isAutoMode: + isGlobal && + "is_auto_mode" in c && + !!(c as Record).is_auto_mode, + })); + + switch (activeTab) { + case "llm": + return [ + ...toItems(filteredLLMGlobal, true), + ...toItems(filteredLLMUser, false), + ]; + case "image": + return [ + ...toItems(filteredImageGlobal, true), + ...toItems(filteredImageUser, false), + ]; + case "vision": + return [ + ...toItems(filteredVisionGlobal, true), + ...toItems(filteredVisionUser, false), + ]; + } + }, [ + activeTab, + filteredLLMGlobal, + filteredLLMUser, + filteredImageGlobal, + filteredImageUser, + filteredVisionGlobal, + filteredVisionUser, + ]); + + // ─── Provider sidebar data ─── + // Collect which providers actually have configured models for the active tab + const configuredProviderSet = useMemo(() => { + const configs = + activeTab === "llm" + ? [ + ...(llmGlobalConfigs ?? []), + ...(llmUserConfigs ?? []), + ] + : activeTab === "image" + ? [ + ...(imageGlobalConfigs ?? []), + ...(imageUserConfigs ?? []), + ] + : [ + ...(visionGlobalConfigs ?? []), + ...(visionUserConfigs ?? []), + ]; + const set = new Set(); + for (const c of configs) { + if (c.provider) set.add(c.provider.toUpperCase()); + } + return set; + }, [ + activeTab, + llmGlobalConfigs, + llmUserConfigs, + imageGlobalConfigs, + imageUserConfigs, + visionGlobalConfigs, + visionUserConfigs, + ]); + + // Show only providers valid for the active tab; configured ones first + const activeProviders = useMemo(() => { + const tabKeys = PROVIDER_KEYS_BY_TAB[activeTab] ?? 
LLM_PROVIDER_KEYS; + const configured = tabKeys.filter((p) => + configuredProviderSet.has(p), ); - }, [currentVisionConfig]); - - // ─── LLM filtering ─── - const filteredLLMGlobal = useMemo(() => { - if (!llmGlobalConfigs) return []; - if (!llmSearchQuery) return llmGlobalConfigs; - const q = llmSearchQuery.toLowerCase(); - return llmGlobalConfigs.filter( - (c) => - c.name.toLowerCase().includes(q) || - c.model_name.toLowerCase().includes(q) || - c.provider.toLowerCase().includes(q) + const unconfigured = tabKeys.filter( + (p) => !configuredProviderSet.has(p), ); - }, [llmGlobalConfigs, llmSearchQuery]); + return ["all", ...configured, ...unconfigured]; + }, [activeTab, configuredProviderSet]); - const filteredLLMUser = useMemo(() => { - if (!llmUserConfigs) return []; - if (!llmSearchQuery) return llmUserConfigs; - const q = llmSearchQuery.toLowerCase(); - return llmUserConfigs.filter( - (c) => - c.name.toLowerCase().includes(q) || - c.model_name.toLowerCase().includes(q) || - c.provider.toLowerCase().includes(q) - ); - }, [llmUserConfigs, llmSearchQuery]); + const providerModelCounts = useMemo(() => { + const allConfigs = + activeTab === "llm" + ? [ + ...(llmGlobalConfigs ?? []), + ...(llmUserConfigs ?? []), + ] + : activeTab === "image" + ? [ + ...(imageGlobalConfigs ?? []), + ...(imageUserConfigs ?? []), + ] + : [ + ...(visionGlobalConfigs ?? []), + ...(visionUserConfigs ?? []), + ]; + const counts: Record = { all: allConfigs.length }; + for (const c of allConfigs) { + const p = c.provider.toUpperCase(); + counts[p] = (counts[p] || 0) + 1; + } + return counts; + }, [ + activeTab, + llmGlobalConfigs, + llmUserConfigs, + imageGlobalConfigs, + imageUserConfigs, + visionGlobalConfigs, + visionUserConfigs, + ]); - const totalLLMModels = (llmGlobalConfigs?.length ?? 0) + (llmUserConfigs?.length ?? 
0); - - // ─── Image filtering ─── - const filteredImageGlobal = useMemo(() => { - if (!imageGlobalConfigs) return []; - if (!imageSearchQuery) return imageGlobalConfigs; - const q = imageSearchQuery.toLowerCase(); - return imageGlobalConfigs.filter( - (c) => - c.name.toLowerCase().includes(q) || - c.model_name.toLowerCase().includes(q) || - c.provider.toLowerCase().includes(q) - ); - }, [imageGlobalConfigs, imageSearchQuery]); - - const filteredImageUser = useMemo(() => { - if (!imageUserConfigs) return []; - if (!imageSearchQuery) return imageUserConfigs; - const q = imageSearchQuery.toLowerCase(); - return imageUserConfigs.filter( - (c) => - c.name.toLowerCase().includes(q) || - c.model_name.toLowerCase().includes(q) || - c.provider.toLowerCase().includes(q) - ); - }, [imageUserConfigs, imageSearchQuery]); - - const totalImageModels = (imageGlobalConfigs?.length ?? 0) + (imageUserConfigs?.length ?? 0); - - // ─── Vision filtering ─── - const filteredVisionGlobal = useMemo(() => { - if (!visionGlobalConfigs) return []; - if (!visionSearchQuery) return visionGlobalConfigs; - const q = visionSearchQuery.toLowerCase(); - return visionGlobalConfigs.filter( - (c) => - c.name.toLowerCase().includes(q) || - c.model_name.toLowerCase().includes(q) || - c.provider.toLowerCase().includes(q) - ); - }, [visionGlobalConfigs, visionSearchQuery]); - - const filteredVisionUser = useMemo(() => { - if (!visionUserConfigs) return []; - if (!visionSearchQuery) return visionUserConfigs; - const q = visionSearchQuery.toLowerCase(); - return visionUserConfigs.filter( - (c) => - c.name.toLowerCase().includes(q) || - c.model_name.toLowerCase().includes(q) || - c.provider.toLowerCase().includes(q) - ); - }, [visionUserConfigs, visionSearchQuery]); - - const totalVisionModels = (visionGlobalConfigs?.length ?? 0) + (visionUserConfigs?.length ?? 
0); - - // ─── Handlers ─── + // ─── Selection handlers ─── const handleSelectLLM = useCallback( async (config: NewLLMConfigPublic | GlobalNewLLMConfig) => { if (currentLLMConfig?.id === config.id) { @@ -254,21 +574,11 @@ export function ModelSelector({ }); toast.success(`Switched to ${config.name}`); setOpen(false); - } catch (error) { - console.error("Failed to switch model:", error); + } catch { toast.error("Failed to switch model"); } }, - [currentLLMConfig, searchSpaceId, updatePreferences] - ); - - const handleEditLLMConfig = useCallback( - (e: React.MouseEvent, config: NewLLMConfigPublic | GlobalNewLLMConfig, isGlobal: boolean) => { - e.stopPropagation(); - onEditLLM(config, isGlobal); - setOpen(false); - }, - [onEditLLM] + [currentLLMConfig, searchSpaceId, updatePreferences], ); const handleSelectImage = useCallback( @@ -292,7 +602,7 @@ export function ModelSelector({ toast.error("Failed to switch image model"); } }, - [currentImageConfig, searchSpaceId, updatePreferences] + [currentImageConfig, searchSpaceId, updatePreferences], ); const handleSelectVision = useCallback( @@ -316,667 +626,687 @@ export function ModelSelector({ toast.error("Failed to switch vision model"); } }, - [currentVisionConfig, searchSpaceId, updatePreferences] + [currentVisionConfig, searchSpaceId, updatePreferences], ); - return ( - - - + + + {isAll + ? "All Models" + : formatProviderName( + provider, + )} + {isConfigured + ? ` (${count})` + : " — not configured"} + + + + ); + })} +
+ {!isMobile && showScrollIndicator && ( +
+ +
+ )} +
+ ); + }; + + // ─── Render: Model card ─── + const getSelectedId = () => { + switch (activeTab) { + case "llm": + return currentLLMConfig?.id; + case "image": + return currentImageConfig?.id; + case "vision": + return currentVisionConfig?.id; + } + }; + + const renderModelCard = (item: DisplayItem, index: number) => { + const { config, isAutoMode } = item; + const isSelected = getSelectedId() === config.id; + const isFocused = focusedIndex === index; + const hasCitations = + "citations_enabled" in config && !!config.citations_enabled; + + return ( +
handleSelectItem(item)} + onMouseEnter={() => setFocusedIndex(index)} + className={cn( + "group flex items-start gap-2.5 px-2.5 py-2 rounded-lg cursor-pointer", + "transition-all duration-150 mx-1", + "hover:bg-accent/40 active:scale-[0.99]", + isSelected && "bg-primary/6 dark:bg-primary/8", + isFocused && "bg-accent/50 ring-1 ring-primary/20", + )} + > + {/* Provider icon */} +
+ {getProviderIcon(config.provider as string, { + isAutoMode, + className: "size-5", + })} +
+ + {/* Model info */} +
+
+ + {config.name} + + {isAutoMode && ( + + Recommended + + )} +
+
+ + {isAutoMode + ? "Auto Mode" + : (config.model_name as string)} + + {!isAutoMode && hasCitations && ( + + Citations + + )} +
+
+ + {/* Actions */} +
+ {!isAutoMode && ( + + )} + {isSelected && ( + + )} +
+
+ ); + }; + + // ─── Render: Full content ─── + const renderContent = () => { + const globalItems = currentDisplayItems.filter((i) => i.isGlobal); + const userItems = currentDisplayItems.filter((i) => !i.isGlobal); + const globalStartIdx = 0; + const userStartIdx = globalItems.length; + + const addHandler = + activeTab === "llm" + ? onAddNewLLM + : activeTab === "image" + ? onAddNewImage + : onAddNewVision; + const addLabel = + activeTab === "llm" + ? "Add Model" + : activeTab === "image" + ? "Add Image Model" + : "Add Vision Model"; + + return ( +
+ {/* Tab header */} +
+
+ {( + [ + { + value: "llm" as const, + icon: Zap, + label: "LLM", + }, + { + value: "image" as const, + icon: ImageIcon, + label: "Image", + }, + { + value: "vision" as const, + icon: Eye, + label: "Vision", + }, + ] as const + ).map(({ value, icon: Icon, label }) => ( + + ))} +
+
+ + {/* Two-pane layout */} +
+ {/* Provider sidebar */} + {renderProviderSidebar()} + + {/* Main content */} +
+ {/* Search */} +
+ + + setSearchQuery(e.target.value) + } + onKeyDown={handleKeyDown} + autoFocus={!isMobile} + role="combobox" + aria-expanded={true} + aria-controls="model-selector-list" + className={cn( + "w-full pl-8 pr-3 py-1.5 text-xs rounded-lg", + "bg-secondary/30 border border-border/40", + "focus:outline-none focus:ring-2 focus:ring-primary/20 focus:border-primary/40", + "placeholder:text-muted-foreground/50", + "transition-[box-shadow,border-color] duration-200", + )} + /> +
+ + {/* Provider header when filtered */} + {selectedProvider !== "all" && ( +
+ {getProviderIcon(selectedProvider, { + className: "size-4", + })} + + {formatProviderName(selectedProvider)} + + + {configuredProviderSet.has( + selectedProvider, + ) + ? `${providerModelCounts[selectedProvider] || 0} models` + : "Not configured"} + +
+ )} + + {/* Model list */} +
+ {currentDisplayItems.length === 0 ? ( +
+ {selectedProvider !== "all" && + !configuredProviderSet.has( + selectedProvider, + ) ? ( + <> +
+ {getProviderIcon( + selectedProvider, + { + className: + "size-10", + }, + )} +
+

+ No{" "} + {formatProviderName( + selectedProvider, + )}{" "} + models configured +

+

+ Add a model with this + provider to get started +

+ {addHandler && ( + + )} + + ) : ( + <> + +

+ No models found +

+

+ Try a different search + term +

+ + )} +
+ ) : ( + <> + {globalItems.length > 0 && ( + <> +
+ Global Models +
+ {globalItems.map((item, i) => + renderModelCard( + item, + globalStartIdx + i, + ), + )} + + )} + {globalItems.length > 0 && + userItems.length > 0 && ( +
+ )} + {userItems.length > 0 && ( + <> +
+ Your Configurations +
+ {userItems.map((item, i) => + renderModelCard( + item, + userStartIdx + i, + ), + )} + + )} + + )} +
+ + {/* Add model button */} + {addHandler && ( +
+ +
+ )} +
+
+
+ ); + }; + + // ─── Trigger button ─── + const triggerButton = ( + - +
+ {/* Image */} + {currentImageConfig ? ( + <> + {getProviderIcon(currentImageConfig.provider, { + isAutoMode: isImageAutoMode ?? false, + })} + + {currentImageConfig.name} + + + ) : ( + + )} +
+ {/* Vision */} + {currentVisionConfig ? ( + <> + {getProviderIcon(currentVisionConfig.provider, { + isAutoMode: isVisionAutoMode ?? false, + })} + + {currentVisionConfig.name} + + + ) : ( + + )} + + )} + + + ); + // ─── Shell: Drawer on mobile, Popover on desktop ─── + if (isMobile) { + return ( + + {triggerButton} + + + + Select Model + +
+ {renderContent()} +
+
+
+ ); + } + + return ( + + {triggerButton} e.preventDefault()} > - setActiveTab(v as "llm" | "image" | "vision")} - className="w-full" - > -
- - - - LLM - - - - Image - - - - Vision - - -
- - {/* ─── LLM Tab ─── */} - - - {totalLLMModels > 3 && ( -
- -
- )} - - - -
- {llmGlobalConfigs?.length || llmUserConfigs?.length ? ( - <> - -

No models found

-

Try a different search term

- - ) : ( -

No models found

- )} -
-
- - {/* Global LLM Configs */} - {filteredLLMGlobal.length > 0 && ( - -
- Global Models -
- {filteredLLMGlobal.map((config) => { - const isSelected = currentLLMConfig?.id === config.id; - const isAutoMode = "is_auto_mode" in config && config.is_auto_mode; - return ( - handleSelectLLM(config)} - className={cn( - "mx-2 rounded-lg mb-1 cursor-pointer group transition-all", - "hover:bg-accent/50 dark:hover:bg-white/[0.06]", - isSelected && "bg-accent/80 dark:bg-white/[0.06]", - isAutoMode && "" - )} - > -
-
-
- {getProviderIcon(config.provider, { isAutoMode })} -
-
-
- {config.name} - {isAutoMode && ( - - Recommended - - )} - {isSelected && ( - - )} -
-
- - {isAutoMode ? "Auto Mode" : config.model_name} - - {!isAutoMode && config.citations_enabled && ( - - Citations - - )} -
-
-
- {!isAutoMode && ( - - )} -
-
- ); - })} -
- )} - - {filteredLLMGlobal.length > 0 && filteredLLMUser.length > 0 && ( - - )} - - {/* User LLM Configs */} - {filteredLLMUser.length > 0 && ( - -
- Your Configurations -
- {filteredLLMUser.map((config) => { - const isSelected = currentLLMConfig?.id === config.id; - return ( - handleSelectLLM(config)} - className={cn( - "mx-2 rounded-lg mb-1 cursor-pointer group transition-all", - "hover:bg-accent/50 dark:hover:bg-white/[0.06]", - isSelected && "bg-accent/80 dark:bg-white/[0.06]" - )} - > -
-
-
{getProviderIcon(config.provider)}
-
-
- {config.name} - {isSelected && ( - - )} -
-
- - {config.model_name} - - {config.citations_enabled && ( - - Citations - - )} -
-
-
- -
-
- ); - })} -
- )} - - {/* Add New LLM Config */} -
- -
-
-
-
- - {/* ─── Image Tab ─── */} - - - {totalImageModels > 3 && ( -
- -
- )} - - -
- {imageGlobalConfigs?.length || imageUserConfigs?.length ? ( - <> - -

No image models found

-

Try a different search term

- - ) : ( -

No image models found

- )} -
-
- - {/* Global Image Configs */} - {filteredImageGlobal.length > 0 && ( - -
- Global Image Models -
- {filteredImageGlobal.map((config) => { - const isSelected = currentImageConfig?.id === config.id; - const isAuto = "is_auto_mode" in config && config.is_auto_mode; - return ( - handleSelectImage(config.id)} - className={cn( - "mx-2 rounded-lg mb-1 cursor-pointer group transition-all hover:bg-accent/50 dark:hover:bg-white/[0.06]", - isSelected && "bg-accent/80 dark:bg-white/[0.06]", - isAuto && "" - )} - > -
-
- {getProviderIcon(config.provider, { isAutoMode: isAuto })} -
-
-
- {config.name} - {isAuto && ( - - Recommended - - )} - {isSelected && } -
- - {isAuto ? "Auto Mode" : config.model_name} - -
- {onEditImage && !isAuto && ( - - )} -
-
- ); - })} -
- )} - - {/* User Image Configs */} - {filteredImageUser.length > 0 && ( - <> - {filteredImageGlobal.length > 0 && ( - - )} - -
- Your Image Models -
- {filteredImageUser.map((config) => { - const isSelected = currentImageConfig?.id === config.id; - return ( - handleSelectImage(config.id)} - className={cn( - "mx-2 rounded-lg mb-1 cursor-pointer group transition-all hover:bg-accent/50 dark:hover:bg-white/[0.06]", - isSelected && "bg-accent/80 dark:bg-white/[0.06]" - )} - > -
-
{getProviderIcon(config.provider)}
-
-
- {config.name} - {isSelected && ( - - )} -
- - {config.model_name} - -
- {onEditImage && ( - - )} -
-
- ); - })} -
- - )} - - {/* Add New Image Config */} - {onAddNewImage && ( -
- -
- )} -
-
-
- - {/* ─── Vision Tab ─── */} - - - {totalVisionModels > 3 && ( -
- -
- )} - - -
- {visionGlobalConfigs?.length || visionUserConfigs?.length ? ( - <> - -

No vision models found

-

Try a different search term

- - ) : ( -

No vision models found

- )} -
-
- - {filteredVisionGlobal.length > 0 && ( - -
- Global Vision Models -
- {filteredVisionGlobal.map((config) => { - const isSelected = currentVisionConfig?.id === config.id; - const isAuto = "is_auto_mode" in config && config.is_auto_mode; - return ( - handleSelectVision(config.id)} - className={cn( - "mx-2 rounded-lg mb-1 cursor-pointer group transition-all hover:bg-accent/50 dark:hover:bg-white/[0.06]", - isSelected && "bg-accent/80 dark:bg-white/[0.06]" - )} - > -
-
- {getProviderIcon(config.provider, { isAutoMode: isAuto })} -
-
-
- {config.name} - {isAuto && ( - - Recommended - - )} - {isSelected && } -
- - {isAuto ? "Auto Mode" : config.model_name} - -
- {onEditVision && !isAuto && ( - - )} -
-
- ); - })} -
- )} - - {filteredVisionUser.length > 0 && ( - <> - {filteredVisionGlobal.length > 0 && ( - - )} - -
- Your Vision Models -
- {filteredVisionUser.map((config) => { - const isSelected = currentVisionConfig?.id === config.id; - return ( - handleSelectVision(config.id)} - className={cn( - "mx-2 rounded-lg mb-1 cursor-pointer group transition-all hover:bg-accent/50 dark:hover:bg-white/[0.06]", - isSelected && "bg-accent/80 dark:bg-white/[0.06]" - )} - > -
-
{getProviderIcon(config.provider)}
-
-
- {config.name} - {isSelected && ( - - )} -
- - {config.model_name} - -
- {onEditVision && ( - - )} -
-
- ); - })} -
- - )} - - {onAddNewVision && ( -
- -
- )} -
-
-
-
+ {renderContent()}
); diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx index 43177e383..9d75e9a1d 100644 --- a/surfsense_web/components/pricing/pricing-section.tsx +++ b/surfsense_web/components/pricing/pricing-section.tsx @@ -1,4 +1,9 @@ +"use client"; +import React, { useRef, useEffect, useState } from "react"; +import { AnimatePresence, motion } from "motion/react"; +import { IconPlus } from "@tabler/icons-react"; import { Pricing } from "@/components/pricing"; +import { cn } from "@/lib/utils"; const demoPlans = [ { @@ -59,13 +64,280 @@ const demoPlans = [ }, ]; +interface FAQItem { + question: string; + answer: string; +} + +interface FAQSection { + title: string; + items: FAQItem[]; +} + +const faqData: FAQSection[] = [ + { + title: "Pages & Billing", + items: [ + { + question: "What exactly is a \"page\" in SurfSense?", + answer: + "A page is a simple billing unit that measures how much content you add to your knowledge base. For PDFs, one page equals one real PDF page. For other document types like Word, PowerPoint, and Excel files, pages are automatically estimated based on the file. Every file uses at least 1 page.", + }, + { + question: "How does the Pay As You Go plan work?", + answer: + "There's no monthly subscription. When you need more pages, simply purchase 1,000-page packs at $1 each. Purchased pages are added to your account immediately so you can keep indexing right away. You only pay when you actually need more.", + }, + { + question: "What happens if I run out of pages?", + answer: + "SurfSense checks your remaining pages before processing each file. If you don't have enough, the upload is paused and you'll be notified. You can purchase additional page packs at any time to continue. For cloud connector syncs, a small overage may be allowed so your sync doesn't partially fail.", + }, + { + question: "If I delete a document, do I get my pages back?", + answer: + "No. 
Deleting a document removes it from your knowledge base, but the pages it used are not refunded. Pages track your total usage over time, not how much is currently stored. So be mindful of what you index. Once pages are spent, they're spent even if you later remove the document.", + }, + ], + }, + { + title: "File Types & Connectors", + items: [ + { + question: "Which file types count toward my page limit?", + answer: + "Page limits only apply to document files that need processing, including PDFs, Word documents (DOC, DOCX, ODT, RTF), presentations (PPT, PPTX, ODP), spreadsheets (XLS, XLSX, ODS), ebooks (EPUB), and images (JPG, PNG, TIFF, WebP, BMP). Plain text files, code files, Markdown, CSV, TSV, HTML, audio, and video files do not consume any pages.", + }, + { + question: "How are pages consumed?", + answer: + "Pages are deducted whenever a document file is successfully indexed into your knowledge base, whether through direct uploads or file-based connector syncs (Google Drive, OneDrive, Dropbox, Local Folder). SurfSense checks your remaining pages before processing and only charges you after the file is indexed. Duplicate documents are automatically detected and won't cost you extra pages.", + }, + { + question: "Do connectors like Slack, Notion, or Gmail use pages?", + answer: + "No. Connectors that work with structured text data like Slack, Discord, Notion, Confluence, Jira, Linear, ClickUp, GitHub, Gmail, Google Calendar, Microsoft Teams, Airtable, Elasticsearch, Web Crawler, BookStack, Obsidian, and Luma do not use pages at all. Page limits only apply to file-based connectors that need document processing, such as Google Drive, OneDrive, Dropbox, and Local Folder syncs.", + }, + ], + }, + { + title: "Self-Hosting", + items: [ + { + question: "Can I self-host SurfSense with unlimited pages?", + answer: + "Yes! When self-hosting, you have full control over your page limits. 
The default self-hosted setup gives you effectively unlimited pages, so you can index as much data as your infrastructure supports.", + }, + ], + }, +]; + +const GridLineHorizontal = ({ + className, + offset, +}: { + className?: string; + offset?: string; +}) => { + return ( +
+ ); +}; + +const GridLineVertical = ({ + className, + offset, +}: { + className?: string; + offset?: string; +}) => { + return ( +
+ ); +}; + +function PricingFAQ() { + const [activeId, setActiveId] = useState(null); + const containerRef = useRef(null); + + useEffect(() => { + function handleClickOutside(event: MouseEvent) { + if ( + containerRef.current && + !containerRef.current.contains(event.target as Node) + ) { + setActiveId(null); + } + } + + document.addEventListener("mousedown", handleClickOutside); + return () => document.removeEventListener("mousedown", handleClickOutside); + }, []); + + const toggleQuestion = (id: string) => { + setActiveId(activeId === id ? null : id); + }; + + return ( +
+
+

+ Frequently Asked Questions +

+

+ Everything you need to know about SurfSense pages and billing. + Can't find what you need? Reach out at{" "} + + rohan@surfsense.com + +

+
+ +
+ {faqData.map((section) => ( +
+

+ {section.title} +

+
+ {section.items.map((item, index) => { + const id = `${section.title}-${index}`; + const isActive = activeId === id; + + return ( +
+ {isActive && ( +
+ + + + +
+ )} + + + {isActive && ( + +

+ {item.answer} +

+
+ )} +
+
+ ); + })} +
+
+ ))} +
+
+ ); +} + function PricingBasic() { return ( - + <> + + + ); } diff --git a/surfsense_web/components/shared/image-config-dialog.tsx b/surfsense_web/components/shared/image-config-dialog.tsx index 2ae53ccca..1b94f0a35 100644 --- a/surfsense_web/components/shared/image-config-dialog.tsx +++ b/surfsense_web/components/shared/image-config-dialog.tsx @@ -48,6 +48,7 @@ interface ImageConfigDialogProps { isGlobal: boolean; searchSpaceId: number; mode: "create" | "edit" | "view"; + defaultProvider?: string; } const INITIAL_FORM = { @@ -67,6 +68,7 @@ export function ImageConfigDialog({ isGlobal, searchSpaceId, mode, + defaultProvider, }: ImageConfigDialogProps) { const [isSubmitting, setIsSubmitting] = useState(false); const [formData, setFormData] = useState(INITIAL_FORM); @@ -87,11 +89,11 @@ export function ImageConfigDialog({ api_version: config.api_version || "", }); } else if (mode === "create") { - setFormData(INITIAL_FORM); + setFormData({ ...INITIAL_FORM, provider: defaultProvider ?? "" }); } setScrollPos("top"); } - }, [open, mode, config, isGlobal]); + }, [open, mode, config, isGlobal, defaultProvider]); const { mutateAsync: createConfig } = useAtomValue(createImageGenConfigMutationAtom); const { mutateAsync: updateConfig } = useAtomValue(updateImageGenConfigMutationAtom); diff --git a/surfsense_web/components/shared/model-config-dialog.tsx b/surfsense_web/components/shared/model-config-dialog.tsx index 4d2373b49..1a3c8e4a0 100644 --- a/surfsense_web/components/shared/model-config-dialog.tsx +++ b/surfsense_web/components/shared/model-config-dialog.tsx @@ -28,6 +28,7 @@ interface ModelConfigDialogProps { isGlobal: boolean; searchSpaceId: number; mode: "create" | "edit" | "view"; + defaultProvider?: string; } export function ModelConfigDialog({ @@ -37,6 +38,7 @@ export function ModelConfigDialog({ isGlobal, searchSpaceId, mode, + defaultProvider, }: ModelConfigDialogProps) { const [isSubmitting, setIsSubmitting] = useState(false); const [scrollPos, setScrollPos] = 
useState<"top" | "middle" | "bottom">("top"); @@ -194,10 +196,12 @@ export function ModelConfigDialog({ {mode === "create" ? ( ) : isGlobal && config ? (
diff --git a/surfsense_web/components/shared/vision-config-dialog.tsx b/surfsense_web/components/shared/vision-config-dialog.tsx index 7332d3dcd..7f96b0594 100644 --- a/surfsense_web/components/shared/vision-config-dialog.tsx +++ b/surfsense_web/components/shared/vision-config-dialog.tsx @@ -49,6 +49,7 @@ interface VisionConfigDialogProps { isGlobal: boolean; searchSpaceId: number; mode: "create" | "edit" | "view"; + defaultProvider?: string; } const INITIAL_FORM = { @@ -68,6 +69,7 @@ export function VisionConfigDialog({ isGlobal, searchSpaceId, mode, + defaultProvider, }: VisionConfigDialogProps) { const [isSubmitting, setIsSubmitting] = useState(false); const [formData, setFormData] = useState(INITIAL_FORM); @@ -87,11 +89,11 @@ export function VisionConfigDialog({ api_version: (config as VisionLLMConfig).api_version || "", }); } else if (mode === "create") { - setFormData(INITIAL_FORM); + setFormData({ ...INITIAL_FORM, provider: defaultProvider ?? "" }); } setScrollPos("top"); } - }, [open, mode, config, isGlobal]); + }, [open, mode, config, isGlobal, defaultProvider]); const { mutateAsync: createConfig } = useAtomValue(createVisionLLMConfigMutationAtom); const { mutateAsync: updateConfig } = useAtomValue(updateVisionLLMConfigMutationAtom); diff --git a/surfsense_web/components/tool-ui/image/index.tsx b/surfsense_web/components/tool-ui/image/index.tsx index 9c39f4928..0536ede66 100644 --- a/surfsense_web/components/tool-ui/image/index.tsx +++ b/surfsense_web/components/tool-ui/image/index.tsx @@ -288,7 +288,7 @@ export function Image({ alt={alt} width={0} height={0} - sizes="100vw" + sizes={`(max-width: ${maxWidth}) 100vw, ${maxWidth}`} loading="eager" className={cn( "w-full h-auto transition-transform duration-300", @@ -307,7 +307,7 @@ export function Image({ src={src} alt={alt} fill - sizes="(max-width: 512px) 100vw, 512px" + sizes={`(max-width: ${maxWidth}) 100vw, ${maxWidth}`} className={cn( "transition-transform duration-300", fit === "cover" ? 
"object-cover" : "object-contain", diff --git a/surfsense_web/contracts/types/search-space.types.ts b/surfsense_web/contracts/types/search-space.types.ts index 7b4fefb62..7449f82b1 100644 --- a/surfsense_web/contracts/types/search-space.types.ts +++ b/surfsense_web/contracts/types/search-space.types.ts @@ -10,6 +10,7 @@ export const searchSpace = z.object({ citations_enabled: z.boolean(), qna_custom_instructions: z.string().nullable(), shared_memory_md: z.string().nullable().optional(), + ai_file_sort_enabled: z.boolean().optional().default(false), member_count: z.number(), is_owner: z.boolean(), }); @@ -56,6 +57,7 @@ export const updateSearchSpaceRequest = z.object({ citations_enabled: true, qna_custom_instructions: true, shared_memory_md: true, + ai_file_sort_enabled: true, }) .partial(), }); diff --git a/surfsense_web/lib/apis/search-spaces-api.service.ts b/surfsense_web/lib/apis/search-spaces-api.service.ts index 3e2006e46..e593245f8 100644 --- a/surfsense_web/lib/apis/search-spaces-api.service.ts +++ b/surfsense_web/lib/apis/search-spaces-api.service.ts @@ -1,3 +1,4 @@ +import { z } from "zod"; import { type CreateSearchSpaceRequest, createSearchSpaceRequest, @@ -117,6 +118,17 @@ class SearchSpacesApiService { return baseApiService.delete(`/api/v1/searchspaces/${request.id}`, deleteSearchSpaceResponse); }; + /** + * Trigger AI file sorting for all documents in a search space + */ + triggerAiSort = async (searchSpaceId: number) => { + return baseApiService.post( + `/api/v1/searchspaces/${searchSpaceId}/ai-sort`, + z.object({ message: z.string() }), + {} + ); + }; + /** * Leave a search space (remove own membership) * This is used by non-owners to leave a shared search space diff --git a/surfsense_web/lib/comments/utils.ts b/surfsense_web/lib/comments/utils.ts new file mode 100644 index 000000000..acbcee506 --- /dev/null +++ b/surfsense_web/lib/comments/utils.ts @@ -0,0 +1,4 @@ +export function convertRenderedToDisplay(contentRendered: string): string { + // 
Convert @{DisplayName} format to @DisplayName for editing + return contentRendered.replace(/@\{([^}]+)\}/g, "@$1"); +} diff --git a/surfsense_web/lib/documents/document-type-labels.ts b/surfsense_web/lib/documents/document-type-labels.ts new file mode 100644 index 000000000..844961886 --- /dev/null +++ b/surfsense_web/lib/documents/document-type-labels.ts @@ -0,0 +1,41 @@ +export function getDocumentTypeLabel(type: string): string { + const labelMap: Record = { + EXTENSION: "Extension", + CRAWLED_URL: "Web Page", + FILE: "File", + SLACK_CONNECTOR: "Slack", + TEAMS_CONNECTOR: "Microsoft Teams", + ONEDRIVE_FILE: "OneDrive", + DROPBOX_FILE: "Dropbox", + NOTION_CONNECTOR: "Notion", + YOUTUBE_VIDEO: "YouTube Video", + GITHUB_CONNECTOR: "GitHub", + LINEAR_CONNECTOR: "Linear", + DISCORD_CONNECTOR: "Discord", + JIRA_CONNECTOR: "Jira", + CONFLUENCE_CONNECTOR: "Confluence", + CLICKUP_CONNECTOR: "ClickUp", + GOOGLE_CALENDAR_CONNECTOR: "Google Calendar", + GOOGLE_GMAIL_CONNECTOR: "Gmail", + GOOGLE_DRIVE_FILE: "Google Drive", + AIRTABLE_CONNECTOR: "Airtable", + LUMA_CONNECTOR: "Luma", + ELASTICSEARCH_CONNECTOR: "Elasticsearch", + BOOKSTACK_CONNECTOR: "BookStack", + CIRCLEBACK: "Circleback", + OBSIDIAN_CONNECTOR: "Obsidian", + LOCAL_FOLDER_FILE: "Local Folder", + SURFSENSE_DOCS: "SurfSense Docs", + NOTE: "Note", + COMPOSIO_GOOGLE_DRIVE_CONNECTOR: "Composio Google Drive", + COMPOSIO_GMAIL_CONNECTOR: "Composio Gmail", + COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: "Composio Google Calendar", + }; + return ( + labelMap[type] || + type + .split("_") + .map((word) => word.charAt(0) + word.slice(1).toLowerCase()) + .join(" ") + ); +} diff --git a/surfsense_web/lib/layout-events.ts b/surfsense_web/lib/layout-events.ts new file mode 100644 index 000000000..45c52f7a4 --- /dev/null +++ b/surfsense_web/lib/layout-events.ts @@ -0,0 +1 @@ +export const SLIDEOUT_PANEL_OPENED_EVENT = "slideout-panel-opened";