diff --git a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py index 19f21bbc6..9394d68b4 100644 --- a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py +++ b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py @@ -430,7 +430,10 @@ async def search_knowledge_base_async( connectors = _normalize_connectors(connectors_to_search, available_connectors) perf.info( "[kb_search] searching %d connectors: %s (space=%d, top_k=%d)", - len(connectors), connectors[:5], search_space_id, top_k, + len(connectors), + connectors[:5], + search_space_id, + top_k, ) connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = { @@ -510,13 +513,17 @@ async def search_knowledge_base_async( _, chunks = await connector_method(**kwargs) perf.info( "[kb_search] connector=%s results=%d in %.3fs", - connector, len(chunks), time.perf_counter() - t_conn, + connector, + len(chunks), + time.perf_counter() - t_conn, ) return chunks except Exception as e: perf.warning( "[kb_search] connector=%s FAILED in %.3fs: %s", - connector, time.perf_counter() - t_conn, e, + connector, + time.perf_counter() - t_conn, + e, ) return [] @@ -525,7 +532,8 @@ async def search_knowledge_base_async( *[_search_one_connector(connector) for connector in connectors] ) perf.info( - "[kb_search] all connectors gathered in %.3fs", time.perf_counter() - t_gather, + "[kb_search] all connectors gathered in %.3fs", + time.perf_counter() - t_gather, ) for chunks in connector_results: all_documents.extend(chunks) @@ -576,7 +584,11 @@ async def search_knowledge_base_async( result = format_documents_for_context(deduplicated, max_chars=output_budget) perf.info( "[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d space=%d", - time.perf_counter() - t0, len(all_documents), len(deduplicated), len(result), search_space_id, + time.perf_counter() - t0, + len(all_documents), + len(deduplicated), + len(result), + search_space_id, ) return result diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py index e6d7977cc..9460f900c 100644 --- a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py +++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py @@ -147,7 +147,9 @@ class IndexingPipelineService: await self.session.commit() perf.info( "[indexing] prepare_for_indexing in %.3fs input=%d output=%d", - time.perf_counter() - t0, len(connector_docs), len(documents), + time.perf_counter() - t0, + len(connector_docs), + len(documents), ) return documents except IntegrityError: @@ -185,7 +187,8 @@ class IndexingPipelineService: ) perf.info( "[indexing] summarize_document doc=%d in %.3fs", - document.id, time.perf_counter() - t_step, + document.id, + time.perf_counter() - t_step, ) elif connector_doc.should_summarize and connector_doc.fallback_summary: content = connector_doc.fallback_summary @@ -196,7 +199,8 @@ class IndexingPipelineService: embedding = embed_text(content) perf.debug( "[indexing] embed_text (summary) doc=%d in %.3fs", - document.id, time.perf_counter() - t_step, + document.id, + time.perf_counter() - t_step, ) await self.session.execute( @@ -213,7 +217,9 @@ class IndexingPipelineService: ] perf.info( "[indexing] chunk+embed doc=%d chunks=%d in %.3fs", - document.id, len(chunks), time.perf_counter() - t_step, + document.id, + len(chunks), + time.perf_counter() - t_step, ) document.content = content @@ -224,7 +230,9 @@ class IndexingPipelineService: await self.session.commit() perf.info( "[indexing] index TOTAL doc=%d chunks=%d in %.3fs", - document.id, len(chunks), time.perf_counter() - t_index, + document.id, + len(chunks), + time.perf_counter() - t_index, ) log_index_success(ctx, chunk_count=len(chunks)) diff --git a/surfsense_backend/app/retriever/chunks_hybrid_search.py b/surfsense_backend/app/retriever/chunks_hybrid_search.py index ed3f63acc..38ecba96c 100644 --- a/surfsense_backend/app/retriever/chunks_hybrid_search.py +++ b/surfsense_backend/app/retriever/chunks_hybrid_search.py @@ -76,7 +76,10 @@ class ChucksHybridSearchRetriever: chunks = result.scalars().all() perf.info( "[chunk_search] vector_search DB query in %.3fs results=%d (total %.3fs) space=%d", - time.perf_counter() - t_db, len(chunks), time.perf_counter() - t0, search_space_id, + time.perf_counter() - t_db, + len(chunks), + time.perf_counter() - t0, + search_space_id, ) return chunks @@ -139,7 +142,9 @@ class ChucksHybridSearchRetriever: chunks = result.scalars().all() perf.info( "[chunk_search] full_text_search in %.3fs results=%d space=%d", - time.perf_counter() - t0, len(chunks), search_space_id, + time.perf_counter() - t0, + len(chunks), + search_space_id, ) return chunks @@ -152,6 +157,7 @@ class ChucksHybridSearchRetriever: document_type: str | None = None, start_date: datetime | None = None, end_date: datetime | None = None, + query_embedding: list | None = None, ) -> list: """ Hybrid search that returns **documents** (not individual chunks). @@ -166,6 +172,7 @@ class ChucksHybridSearchRetriever: document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL") start_date: Optional start date for filtering documents by updated_at end_date: Optional end date for filtering documents by updated_at + query_embedding: Pre-computed embedding vector. If None, will be computed here. Returns: List of dictionaries containing document data and relevance scores. Each dict contains: @@ -183,14 +190,14 @@ class ChucksHybridSearchRetriever: perf = get_perf_logger() t0 = time.perf_counter() - # Get embedding for the query - embedding_model = config.embedding_model_instance - t_embed = time.perf_counter() - query_embedding = embedding_model.embed(query_text) - perf.debug( - "[chunk_search] hybrid_search embedding in %.3fs", - time.perf_counter() - t_embed, - ) + if query_embedding is None: + embedding_model = config.embedding_model_instance + t_embed = time.perf_counter() + query_embedding = embedding_model.embed(query_text) + perf.debug( + "[chunk_search] hybrid_search embedding in %.3fs", + time.perf_counter() - t_embed, + ) # RRF constants k = 60 @@ -291,7 +298,10 @@ class ChucksHybridSearchRetriever: chunks_with_scores = result.all() perf.info( "[chunk_search] hybrid_search RRF query in %.3fs results=%d space=%d type=%s", - time.perf_counter() - t_rrf, len(chunks_with_scores), search_space_id, document_type, + time.perf_counter() - t_rrf, + len(chunks_with_scores), + search_space_id, + document_type, ) # If no results were found, return an empty list @@ -392,6 +402,9 @@ class ChucksHybridSearchRetriever: perf.info( "[chunk_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s", - time.perf_counter() - t0, len(final_docs), search_space_id, document_type, + time.perf_counter() - t0, + len(final_docs), + search_space_id, + document_type, ) return final_docs diff --git a/surfsense_backend/app/retriever/documents_hybrid_search.py b/surfsense_backend/app/retriever/documents_hybrid_search.py index 608e1c2e6..f4daf8e26 100644 --- a/surfsense_backend/app/retriever/documents_hybrid_search.py +++ b/surfsense_backend/app/retriever/documents_hybrid_search.py @@ -71,7 +71,9 @@ class DocumentHybridSearchRetriever: documents = result.scalars().all() perf.info( "[doc_search] vector_search in %.3fs results=%d space=%d", - time.perf_counter() - t0, len(documents), search_space_id, + time.perf_counter() - t0, + len(documents), + search_space_id, ) return documents @@ -133,7 +135,9 @@ class DocumentHybridSearchRetriever: documents = result.scalars().all() perf.info( "[doc_search] full_text_search in %.3fs results=%d space=%d", - time.perf_counter() - t0, len(documents), search_space_id, + time.perf_counter() - t0, + len(documents), + search_space_id, ) return documents @@ -146,6 +150,7 @@ class DocumentHybridSearchRetriever: document_type: str | None = None, start_date: datetime | None = None, end_date: datetime | None = None, + query_embedding: list | None = None, ) -> list: """ Hybrid search that returns **documents** (not individual chunks). @@ -160,7 +165,7 @@ class DocumentHybridSearchRetriever: document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL") start_date: Optional start date for filtering documents by updated_at end_date: Optional end date for filtering documents by updated_at - + query_embedding: Pre-computed embedding vector. If None, will be computed here. """ from sqlalchemy import func, select, text from sqlalchemy.orm import joinedload @@ -171,9 +176,9 @@ class DocumentHybridSearchRetriever: perf = get_perf_logger() t0 = time.perf_counter() - # Get embedding for the query - embedding_model = config.embedding_model_instance - query_embedding = embedding_model.embed(query_text) + if query_embedding is None: + embedding_model = config.embedding_model_instance + query_embedding = embedding_model.embed(query_text) # RRF constants k = 60 @@ -325,6 +330,9 @@ class DocumentHybridSearchRetriever: perf.info( "[doc_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s", - time.perf_counter() - t0, len(final_docs), search_space_id, document_type, + time.perf_counter() - t0, + len(final_docs), + search_space_id, + document_type, ) return final_docs diff --git a/surfsense_backend/app/routes/chat_comments_routes.py b/surfsense_backend/app/routes/chat_comments_routes.py index 1c21c0f4a..f5a8fd0af 100644 --- a/surfsense_backend/app/routes/chat_comments_routes.py +++ b/surfsense_backend/app/routes/chat_comments_routes.py @@ -7,6 +7,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.db import User, get_async_session from app.schemas.chat_comments import ( + CommentBatchRequest, + CommentBatchResponse, CommentCreateRequest, CommentListResponse, CommentReplyResponse, @@ -19,6 +21,7 @@ from app.services.chat_comments_service import ( create_reply, delete_comment, get_comments_for_message, + get_comments_for_messages_batch, get_user_mentions, update_comment, ) @@ -27,6 +30,16 @@ from app.users import current_active_user router = APIRouter() +@router.post("/messages/comments/batch", response_model=CommentBatchResponse) +async def batch_list_comments( + request: CommentBatchRequest, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Batch-fetch comments for multiple messages in one request.""" + return await get_comments_for_messages_batch(session, request.message_ids, user) + + @router.get("/messages/{message_id}/comments", response_model=CommentListResponse) async def list_comments( message_id: int, diff --git a/surfsense_backend/app/schemas/chat_comments.py b/surfsense_backend/app/schemas/chat_comments.py index b87ee58a4..984e8b812 100644 --- a/surfsense_backend/app/schemas/chat_comments.py +++ b/surfsense_backend/app/schemas/chat_comments.py @@ -87,6 +87,18 @@ class CommentListResponse(BaseModel): total_count: int +class CommentBatchRequest(BaseModel): + """Request for batch-fetching comments for multiple messages.""" + + message_ids: list[int] = Field(..., min_length=1, max_length=200) + + +class CommentBatchResponse(BaseModel): + """Batch response keyed by message_id.""" + + comments_by_message: dict[int, CommentListResponse] + + # ============================================================================= # Mention Schemas # ============================================================================= diff --git a/surfsense_backend/app/services/chat_comments_service.py b/surfsense_backend/app/services/chat_comments_service.py index c9ca920f6..c2bb65aee 100644 --- a/surfsense_backend/app/services/chat_comments_service.py +++ b/surfsense_backend/app/services/chat_comments_service.py @@ -22,6 +22,7 @@ from app.db import ( ) from app.schemas.chat_comments import ( AuthorResponse, + CommentBatchResponse, CommentListResponse, CommentReplyResponse, CommentResponse, @@ -264,6 +265,146 @@ async def get_comments_for_message( ) +async def get_comments_for_messages_batch( + session: AsyncSession, + message_ids: list[int], + user: User, +) -> CommentBatchResponse: + """ + Batch-fetch comments for multiple messages in a single DB round-trip. + + Validates that all messages exist and belong to search spaces the user + can read comments in, then loads all comments with eager-loaded authors + and replies. + """ + if not message_ids: + return CommentBatchResponse(comments_by_message={}) + + unique_ids = list(set(message_ids)) + + result = await session.execute( + select(NewChatMessage) + .options(selectinload(NewChatMessage.thread)) + .filter(NewChatMessage.id.in_(unique_ids)) + ) + messages = result.scalars().all() + msg_map = {m.id: m for m in messages} + + search_space_ids = {m.thread.search_space_id for m in messages} + permissions_cache: dict[int, set] = {} + for ss_id in search_space_ids: + await check_permission( + session, + user, + ss_id, + Permission.COMMENTS_READ.value, + "You don't have permission to read comments in this search space", + ) + permissions_cache[ss_id] = await get_user_permissions(session, user.id, ss_id) + + result = await session.execute( + select(ChatComment) + .options( + selectinload(ChatComment.author), + selectinload(ChatComment.replies).selectinload(ChatComment.author), + ) + .filter( + ChatComment.message_id.in_(unique_ids), + ChatComment.parent_id.is_(None), + ) + .order_by(ChatComment.created_at) + ) + top_level_comments = result.scalars().all() + + all_mentioned_uuids: set[UUID] = set() + for comment in top_level_comments: + all_mentioned_uuids.update(parse_mentions(comment.content)) + for reply in comment.replies: + all_mentioned_uuids.update(parse_mentions(reply.content)) + + user_names = await get_user_names_for_mentions(session, all_mentioned_uuids) + + comments_by_msg: dict[int, list[ChatComment]] = {mid: [] for mid in unique_ids} + for comment in top_level_comments: + comments_by_msg.setdefault(comment.message_id, []).append(comment) + + comments_by_message: dict[int, CommentListResponse] = {} + for mid in unique_ids: + msg = msg_map.get(mid) + if msg is None: + comments_by_message[mid] = CommentListResponse(comments=[], total_count=0) + continue + + ss_id = msg.thread.search_space_id + user_perms = permissions_cache.get(ss_id, set()) + can_delete_any = has_permission(user_perms, Permission.COMMENTS_DELETE.value) + + comment_responses = [] + for comment in comments_by_msg.get(mid, []): + author = None + if comment.author: + author = AuthorResponse( + id=comment.author.id, + display_name=comment.author.display_name, + avatar_url=comment.author.avatar_url, + email=comment.author.email, + ) + + replies = [] + for reply in sorted(comment.replies, key=lambda r: r.created_at): + reply_author = None + if reply.author: + reply_author = AuthorResponse( + id=reply.author.id, + display_name=reply.author.display_name, + avatar_url=reply.author.avatar_url, + email=reply.author.email, + ) + is_reply_author = ( + reply.author_id == user.id if reply.author_id else False + ) + replies.append( + CommentReplyResponse( + id=reply.id, + content=reply.content, + content_rendered=render_mentions(reply.content, user_names), + author=reply_author, + created_at=reply.created_at, + updated_at=reply.updated_at, + is_edited=reply.updated_at > reply.created_at, + can_edit=is_reply_author, + can_delete=is_reply_author or can_delete_any, + ) + ) + + is_comment_author = ( + comment.author_id == user.id if comment.author_id else False + ) + comment_responses.append( + CommentResponse( + id=comment.id, + message_id=comment.message_id, + content=comment.content, + content_rendered=render_mentions(comment.content, user_names), + author=author, + created_at=comment.created_at, + updated_at=comment.updated_at, + is_edited=comment.updated_at > comment.created_at, + can_edit=is_comment_author, + can_delete=is_comment_author or can_delete_any, + reply_count=len(replies), + replies=replies, + ) + ) + + comments_by_message[mid] = CommentListResponse( + comments=comment_responses, + total_count=len(comment_responses), + ) + + return CommentBatchResponse(comments_by_message=comments_by_message) + + async def create_comment( session: AsyncSession, message_id: int, diff --git a/surfsense_backend/app/services/connector_service.py b/surfsense_backend/app/services/connector_service.py index fa91de391..157e0bab5 100644 --- a/surfsense_backend/app/services/connector_service.py +++ b/surfsense_backend/app/services/connector_service.py @@ -16,6 +16,7 @@ from app.db import ( Document, SearchSourceConnector, SearchSourceConnectorType, + async_session_maker, ) from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever from app.retriever.documents_hybrid_search import DocumentHybridSearchRetriever @@ -248,6 +249,8 @@ class ConnectorService: Returns: List of combined and deduplicated document results """ + from app.config import config + perf = get_perf_logger() t0 = time.perf_counter() @@ -257,39 +260,48 @@ class ConnectorService: # Get more results from each retriever for better fusion retriever_top_k = top_k * 2 - # IMPORTANT: - # These retrievers share the same AsyncSession. AsyncSession does not permit - # concurrent awaits that require DB IO on the same session/connection. - # Running these in parallel can raise: - # "This session is provisioning a new connection; concurrent operations are not permitted" - # - # So we run them sequentially. - t_chunk = time.perf_counter() - chunk_results = await self.chunk_retriever.hybrid_search( - query_text=query_text, - top_k=retriever_top_k, - search_space_id=search_space_id, - document_type=document_type, - start_date=start_date, - end_date=end_date, - ) + # Pre-compute the embedding once so both retrievers reuse it. + t_embed = time.perf_counter() + query_embedding = config.embedding_model_instance.embed(query_text) perf.info( - "[connector_svc] _combined_rrf chunk_retriever in %.3fs results=%d type=%s", - time.perf_counter() - t_chunk, len(chunk_results), document_type, + "[connector_svc] _combined_rrf embedding in %.3fs type=%s", + time.perf_counter() - t_embed, + document_type, ) - t_doc = time.perf_counter() - doc_results = await self.document_retriever.hybrid_search( - query_text=query_text, - top_k=retriever_top_k, - search_space_id=search_space_id, - document_type=document_type, - start_date=start_date, - end_date=end_date, + search_kwargs = { + "query_text": query_text, + "top_k": retriever_top_k, + "search_space_id": search_space_id, + "document_type": document_type, + "start_date": start_date, + "end_date": end_date, + "query_embedding": query_embedding, + } + + # Run chunk and document retrievers in parallel using separate DB sessions + # so they don't contend on a shared AsyncSession connection. + async def _run_chunk_search() -> list[dict[str, Any]]: + async with async_session_maker() as session: + retriever = ChucksHybridSearchRetriever(session) + return await retriever.hybrid_search(**search_kwargs) + + async def _run_doc_search() -> list[dict[str, Any]]: + async with async_session_maker() as session: + retriever = DocumentHybridSearchRetriever(session) + return await retriever.hybrid_search(**search_kwargs) + + t_parallel = time.perf_counter() + chunk_results, doc_results = await asyncio.gather( + _run_chunk_search(), _run_doc_search() ) perf.info( - "[connector_svc] _combined_rrf doc_retriever in %.3fs results=%d type=%s", - time.perf_counter() - t_doc, len(doc_results), document_type, + "[connector_svc] _combined_rrf parallel retrievers in %.3fs " + "chunk_results=%d doc_results=%d type=%s", + time.perf_counter() - t_parallel, + len(chunk_results), + len(doc_results), + document_type, ) # Helper to extract document_id from our doc-grouped result @@ -353,7 +365,10 @@ class ConnectorService: perf.info( "[connector_svc] _combined_rrf_search TOTAL in %.3fs results=%d type=%s space=%d", - time.perf_counter() - t0, len(combined_results), document_type, search_space_id, + time.perf_counter() - t0, + len(combined_results), + document_type, + search_space_id, ) return combined_results diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index e9b84c5cd..7839e4014 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -437,14 +437,16 @@ class ChatLiteLLMRouter(BaseChatModel): except ContextWindowExceededError as e: perf.warning( "[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs", - msg_count, time.perf_counter() - t0, + msg_count, + time.perf_counter() - t0, ) raise ContextOverflowError(str(e)) from e except LiteLLMBadRequestError as e: if _is_context_overflow_error(e): perf.warning( "[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs", - msg_count, time.perf_counter() - t0, + msg_count, + time.perf_counter() - t0, ) raise ContextOverflowError(str(e)) from e raise @@ -500,14 +502,16 @@ class ChatLiteLLMRouter(BaseChatModel): except ContextWindowExceededError as e: perf.warning( "[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs", - msg_count, time.perf_counter() - t0, + msg_count, + time.perf_counter() - t0, ) raise ContextOverflowError(str(e)) from e except LiteLLMBadRequestError as e: if _is_context_overflow_error(e): perf.warning( "[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs", - msg_count, time.perf_counter() - t0, + msg_count, + time.perf_counter() - t0, ) raise ContextOverflowError(str(e)) from e raise @@ -608,14 +612,16 @@ class ChatLiteLLMRouter(BaseChatModel): except ContextWindowExceededError as e: perf.warning( "[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs", - msg_count, time.perf_counter() - t0, + msg_count, + time.perf_counter() - t0, ) raise ContextOverflowError(str(e)) from e except LiteLLMBadRequestError as e: if _is_context_overflow_error(e): perf.warning( "[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs", - msg_count, time.perf_counter() - t0, + msg_count, + time.perf_counter() - t0, ) raise ContextOverflowError(str(e)) from e raise @@ -623,7 +629,8 @@ class ChatLiteLLMRouter(BaseChatModel): t_first_chunk = time.perf_counter() perf.info( "[llm_router] _astream connection established msgs=%d in %.3fs", - msg_count, t_first_chunk - t0, + msg_count, + t_first_chunk - t0, ) chunk_count = 0 @@ -645,7 +652,8 @@ class ChatLiteLLMRouter(BaseChatModel): perf.info( "[llm_router] _astream completed chunks=%d total=%.3fs", - chunk_count, time.perf_counter() - t0, + chunk_count, + time.perf_counter() - t0, ) def _convert_messages(self, messages: list[BaseMessage]) -> list[dict]: diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index cd0b4971c..98fa5b436 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -65,6 +65,7 @@ import type { ThinkingStep } from "@/components/tool-ui/deepagent-thinking"; import { Button } from "@/components/ui/button"; import { Spinner } from "@/components/ui/spinner"; import type { Document } from "@/contracts/types/document.types"; +import { useBatchCommentsPreload } from "@/hooks/use-comments"; import { useCommentsElectric } from "@/hooks/use-comments-electric"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { cn } from "@/lib/utils"; @@ -309,6 +310,22 @@ const Composer: FC = () => { // Sync comments for the entire thread via Electric SQL (one subscription per thread) useCommentsElectric(threadId); + // Batch-prefetch comments for all assistant messages so individual useComments + // hooks never fire their own network requests (eliminates N+1 API calls). + // Return a primitive string from the selector so useSyncExternalStore can + // compare snapshots by value and avoid infinite re-render loops. + const assistantIdsKey = useAssistantState(({ thread }) => + thread.messages + .filter((m) => m.role === "assistant" && m.id?.startsWith("msg-")) + .map((m) => m.id!.replace("msg-", "")) + .join(",") + ); + const assistantDbMessageIds = useMemo( + () => (assistantIdsKey ? assistantIdsKey.split(",").map(Number) : []), + [assistantIdsKey] + ); + useBatchCommentsPreload(assistantDbMessageIds); + // Auto-focus editor on new chat page after mount useEffect(() => { if (isThreadEmpty && !hasAutoFocusedRef.current && editorRef.current) { diff --git a/surfsense_web/contracts/types/chat-comments.types.ts b/surfsense_web/contracts/types/chat-comments.types.ts index 46e064a4e..cdeca0a44 100644 --- a/surfsense_web/contracts/types/chat-comments.types.ts +++ b/surfsense_web/contracts/types/chat-comments.types.ts @@ -82,6 +82,22 @@ export const getCommentsResponse = z.object({ total_count: z.number(), }); +/** + * Batch-fetch comments for multiple messages + */ +export const getBatchCommentsRequest = z.object({ + message_ids: z.array(z.number()).min(1).max(200), +}); + +export const commentListResponse = z.object({ + comments: z.array(comment), + total_count: z.number(), +}); + +export const getBatchCommentsResponse = z.object({ + comments_by_message: z.record(z.string(), commentListResponse), +}); + /** * Create comment */ @@ -145,6 +161,8 @@ export type MentionComment = z.infer; export type Mention = z.infer; export type GetCommentsRequest = z.infer; export type GetCommentsResponse = z.infer; +export type GetBatchCommentsRequest = z.infer; +export type GetBatchCommentsResponse = z.infer; export type CreateCommentRequest = z.infer; export type CreateCommentResponse = z.infer; export type CreateReplyRequest = z.infer; diff --git a/surfsense_web/hooks/use-comments.ts b/surfsense_web/hooks/use-comments.ts index 4f027d67c..562f7ae02 100644 --- a/surfsense_web/hooks/use-comments.ts +++ b/surfsense_web/hooks/use-comments.ts @@ -1,4 +1,5 @@ -import { useQuery } from "@tanstack/react-query"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { useEffect, useRef } from "react"; import { chatCommentsApiService } from "@/lib/apis/chat-comments-api.service"; import { cacheKeys } from "@/lib/query-client/cache-keys"; @@ -7,12 +8,84 @@ interface UseCommentsOptions { enabled?: boolean; } +// --------------------------------------------------------------------------- +// Module-level coordination: when a batch request is in-flight, individual +// useComments queryFns piggy-back on it instead of making their own requests. +// --------------------------------------------------------------------------- +let _batchInflight: Promise | null = null; +let _batchTargetIds = new Set(); + export function useComments({ messageId, enabled = true }: UseCommentsOptions) { + const queryClient = useQueryClient(); + return useQuery({ queryKey: cacheKeys.comments.byMessage(messageId), queryFn: async () => { + // Yield one macro-task so the batch prefetch useEffect (which sets + // _batchInflight) has a chance to fire before we decide to fetch. + await new Promise((r) => setTimeout(r, 0)); + + if (_batchInflight && _batchTargetIds.has(messageId)) { + await _batchInflight; + const cached = queryClient.getQueryData(cacheKeys.comments.byMessage(messageId)); + if (cached) return cached; + } + return chatCommentsApiService.getComments({ message_id: messageId }); }, enabled: enabled && !!messageId, + staleTime: 30_000, }); } + +/** + * Batch-fetch comments for all given message IDs in a single request, then + * seed the per-message React Query cache so individual useComments hooks + * resolve from cache instead of firing their own requests. + */ +export function useBatchCommentsPreload(messageIds: number[]) { + const queryClient = useQueryClient(); + const prevKeyRef = useRef(""); + + useEffect(() => { + if (!messageIds.length) return; + + const key = messageIds + .slice() + .sort((a, b) => a - b) + .join(","); + if (key === prevKeyRef.current) return; + prevKeyRef.current = key; + + _batchTargetIds = new Set(messageIds); + let cancelled = false; + + const promise = chatCommentsApiService + .getBatchComments({ message_ids: messageIds }) + .then((data) => { + if (cancelled) return; + for (const [msgIdStr, commentList] of Object.entries(data.comments_by_message)) { + queryClient.setQueryData(cacheKeys.comments.byMessage(Number(msgIdStr)), commentList); + } + }) + .catch(() => { + // Batch failed; individual queryFns will fall through to their own fetch + }) + .finally(() => { + if (_batchInflight === promise) { + _batchInflight = null; + _batchTargetIds = new Set(); + } + }); + + _batchInflight = promise; + + return () => { + cancelled = true; + if (_batchInflight === promise) { + _batchInflight = null; + _batchTargetIds = new Set(); + } + }; + }, [messageIds, queryClient]); +} diff --git a/surfsense_web/lib/apis/chat-comments-api.service.ts b/surfsense_web/lib/apis/chat-comments-api.service.ts index 952de7a25..f1ec7a5d9 100644 --- a/surfsense_web/lib/apis/chat-comments-api.service.ts +++ b/surfsense_web/lib/apis/chat-comments-api.service.ts @@ -8,8 +8,11 @@ import { type DeleteCommentRequest, deleteCommentRequest, deleteCommentResponse, + type GetBatchCommentsRequest, type GetCommentsRequest, type GetMentionsRequest, + getBatchCommentsRequest, + getBatchCommentsResponse, getCommentsRequest, getCommentsResponse, getMentionsRequest, @@ -22,6 +25,22 @@ import { ValidationError } from "@/lib/error"; import { baseApiService } from "./base-api.service"; class ChatCommentsApiService { + /** + * Batch-fetch comments for multiple messages in one request + */ + getBatchComments = async (request: GetBatchCommentsRequest) => { + const parsed = getBatchCommentsRequest.safeParse(request); + + if (!parsed.success) { + const errorMessage = parsed.error.issues.map((issue) => issue.message).join(", "); + throw new ValidationError(`Invalid request: ${errorMessage}`); + } + + return baseApiService.post("/api/v1/messages/comments/batch", getBatchCommentsResponse, { + body: { message_ids: parsed.data.message_ids }, + }); + }; + /** * Get comments for a message */