Merge pull request #846 from MODSetter/dev

feat: perf optimizations
This commit is contained in:
Rohan Verma 2026-02-27 17:24:20 -08:00 committed by GitHub
commit 4105bd0d7a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 424 additions and 67 deletions

View file

@ -430,7 +430,10 @@ async def search_knowledge_base_async(
connectors = _normalize_connectors(connectors_to_search, available_connectors)
perf.info(
"[kb_search] searching %d connectors: %s (space=%d, top_k=%d)",
len(connectors), connectors[:5], search_space_id, top_k,
len(connectors),
connectors[:5],
search_space_id,
top_k,
)
connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
@ -510,13 +513,17 @@ async def search_knowledge_base_async(
_, chunks = await connector_method(**kwargs)
perf.info(
"[kb_search] connector=%s results=%d in %.3fs",
connector, len(chunks), time.perf_counter() - t_conn,
connector,
len(chunks),
time.perf_counter() - t_conn,
)
return chunks
except Exception as e:
perf.warning(
"[kb_search] connector=%s FAILED in %.3fs: %s",
connector, time.perf_counter() - t_conn, e,
connector,
time.perf_counter() - t_conn,
e,
)
return []
@ -525,7 +532,8 @@ async def search_knowledge_base_async(
*[_search_one_connector(connector) for connector in connectors]
)
perf.info(
"[kb_search] all connectors gathered in %.3fs", time.perf_counter() - t_gather,
"[kb_search] all connectors gathered in %.3fs",
time.perf_counter() - t_gather,
)
for chunks in connector_results:
all_documents.extend(chunks)
@ -576,7 +584,11 @@ async def search_knowledge_base_async(
result = format_documents_for_context(deduplicated, max_chars=output_budget)
perf.info(
"[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d space=%d",
time.perf_counter() - t0, len(all_documents), len(deduplicated), len(result), search_space_id,
time.perf_counter() - t0,
len(all_documents),
len(deduplicated),
len(result),
search_space_id,
)
return result

View file

@ -147,7 +147,9 @@ class IndexingPipelineService:
await self.session.commit()
perf.info(
"[indexing] prepare_for_indexing in %.3fs input=%d output=%d",
time.perf_counter() - t0, len(connector_docs), len(documents),
time.perf_counter() - t0,
len(connector_docs),
len(documents),
)
return documents
except IntegrityError:
@ -185,7 +187,8 @@ class IndexingPipelineService:
)
perf.info(
"[indexing] summarize_document doc=%d in %.3fs",
document.id, time.perf_counter() - t_step,
document.id,
time.perf_counter() - t_step,
)
elif connector_doc.should_summarize and connector_doc.fallback_summary:
content = connector_doc.fallback_summary
@ -196,7 +199,8 @@ class IndexingPipelineService:
embedding = embed_text(content)
perf.debug(
"[indexing] embed_text (summary) doc=%d in %.3fs",
document.id, time.perf_counter() - t_step,
document.id,
time.perf_counter() - t_step,
)
await self.session.execute(
@ -213,7 +217,9 @@ class IndexingPipelineService:
]
perf.info(
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
document.id, len(chunks), time.perf_counter() - t_step,
document.id,
len(chunks),
time.perf_counter() - t_step,
)
document.content = content
@ -224,7 +230,9 @@ class IndexingPipelineService:
await self.session.commit()
perf.info(
"[indexing] index TOTAL doc=%d chunks=%d in %.3fs",
document.id, len(chunks), time.perf_counter() - t_index,
document.id,
len(chunks),
time.perf_counter() - t_index,
)
log_index_success(ctx, chunk_count=len(chunks))

View file

@ -76,7 +76,10 @@ class ChucksHybridSearchRetriever:
chunks = result.scalars().all()
perf.info(
"[chunk_search] vector_search DB query in %.3fs results=%d (total %.3fs) space=%d",
time.perf_counter() - t_db, len(chunks), time.perf_counter() - t0, search_space_id,
time.perf_counter() - t_db,
len(chunks),
time.perf_counter() - t0,
search_space_id,
)
return chunks
@ -139,7 +142,9 @@ class ChucksHybridSearchRetriever:
chunks = result.scalars().all()
perf.info(
"[chunk_search] full_text_search in %.3fs results=%d space=%d",
time.perf_counter() - t0, len(chunks), search_space_id,
time.perf_counter() - t0,
len(chunks),
search_space_id,
)
return chunks
@ -152,6 +157,7 @@ class ChucksHybridSearchRetriever:
document_type: str | None = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
query_embedding: list | None = None,
) -> list:
"""
Hybrid search that returns **documents** (not individual chunks).
@ -166,6 +172,7 @@ class ChucksHybridSearchRetriever:
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
start_date: Optional start date for filtering documents by updated_at
end_date: Optional end date for filtering documents by updated_at
query_embedding: Pre-computed embedding vector. If None, will be computed here.
Returns:
List of dictionaries containing document data and relevance scores. Each dict contains:
@ -183,14 +190,14 @@ class ChucksHybridSearchRetriever:
perf = get_perf_logger()
t0 = time.perf_counter()
# Get embedding for the query
embedding_model = config.embedding_model_instance
t_embed = time.perf_counter()
query_embedding = embedding_model.embed(query_text)
perf.debug(
"[chunk_search] hybrid_search embedding in %.3fs",
time.perf_counter() - t_embed,
)
if query_embedding is None:
embedding_model = config.embedding_model_instance
t_embed = time.perf_counter()
query_embedding = embedding_model.embed(query_text)
perf.debug(
"[chunk_search] hybrid_search embedding in %.3fs",
time.perf_counter() - t_embed,
)
# RRF constants
k = 60
@ -291,7 +298,10 @@ class ChucksHybridSearchRetriever:
chunks_with_scores = result.all()
perf.info(
"[chunk_search] hybrid_search RRF query in %.3fs results=%d space=%d type=%s",
time.perf_counter() - t_rrf, len(chunks_with_scores), search_space_id, document_type,
time.perf_counter() - t_rrf,
len(chunks_with_scores),
search_space_id,
document_type,
)
# If no results were found, return an empty list
@ -392,6 +402,9 @@ class ChucksHybridSearchRetriever:
perf.info(
"[chunk_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s",
time.perf_counter() - t0, len(final_docs), search_space_id, document_type,
time.perf_counter() - t0,
len(final_docs),
search_space_id,
document_type,
)
return final_docs

View file

@ -71,7 +71,9 @@ class DocumentHybridSearchRetriever:
documents = result.scalars().all()
perf.info(
"[doc_search] vector_search in %.3fs results=%d space=%d",
time.perf_counter() - t0, len(documents), search_space_id,
time.perf_counter() - t0,
len(documents),
search_space_id,
)
return documents
@ -133,7 +135,9 @@ class DocumentHybridSearchRetriever:
documents = result.scalars().all()
perf.info(
"[doc_search] full_text_search in %.3fs results=%d space=%d",
time.perf_counter() - t0, len(documents), search_space_id,
time.perf_counter() - t0,
len(documents),
search_space_id,
)
return documents
@ -146,6 +150,7 @@ class DocumentHybridSearchRetriever:
document_type: str | None = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
query_embedding: list | None = None,
) -> list:
"""
Hybrid search that returns **documents** (not individual chunks).
@ -160,7 +165,7 @@ class DocumentHybridSearchRetriever:
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
start_date: Optional start date for filtering documents by updated_at
end_date: Optional end date for filtering documents by updated_at
query_embedding: Pre-computed embedding vector. If None, will be computed here.
"""
from sqlalchemy import func, select, text
from sqlalchemy.orm import joinedload
@ -171,9 +176,9 @@ class DocumentHybridSearchRetriever:
perf = get_perf_logger()
t0 = time.perf_counter()
# Get embedding for the query
embedding_model = config.embedding_model_instance
query_embedding = embedding_model.embed(query_text)
if query_embedding is None:
embedding_model = config.embedding_model_instance
query_embedding = embedding_model.embed(query_text)
# RRF constants
k = 60
@ -325,6 +330,9 @@ class DocumentHybridSearchRetriever:
perf.info(
"[doc_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s",
time.perf_counter() - t0, len(final_docs), search_space_id, document_type,
time.perf_counter() - t0,
len(final_docs),
search_space_id,
document_type,
)
return final_docs

View file

@ -7,6 +7,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db import User, get_async_session
from app.schemas.chat_comments import (
CommentBatchRequest,
CommentBatchResponse,
CommentCreateRequest,
CommentListResponse,
CommentReplyResponse,
@ -19,6 +21,7 @@ from app.services.chat_comments_service import (
create_reply,
delete_comment,
get_comments_for_message,
get_comments_for_messages_batch,
get_user_mentions,
update_comment,
)
@ -27,6 +30,16 @@ from app.users import current_active_user
router = APIRouter()
@router.post("/messages/comments/batch", response_model=CommentBatchResponse)
async def batch_list_comments(
    request: CommentBatchRequest,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Batch-fetch comments for multiple messages in one request.

    Thin route wrapper: permission checks and the actual batching happen in
    the service layer; the response maps message_id -> comment list.
    """
    batch = await get_comments_for_messages_batch(session, request.message_ids, user)
    return batch
@router.get("/messages/{message_id}/comments", response_model=CommentListResponse)
async def list_comments(
message_id: int,

View file

@ -87,6 +87,18 @@ class CommentListResponse(BaseModel):
total_count: int
class CommentBatchRequest(BaseModel):
    """Request for batch-fetching comments for multiple messages."""

    # Bounded to 1..200 IDs per request, mirroring the frontend zod schema
    # (getBatchCommentsRequest) limits.
    message_ids: list[int] = Field(..., min_length=1, max_length=200)
class CommentBatchResponse(BaseModel):
    """Batch response keyed by message_id."""

    # Maps each requested message ID to its (possibly empty) comment list;
    # unknown IDs get an empty CommentListResponse rather than being omitted.
    comments_by_message: dict[int, CommentListResponse]
# =============================================================================
# Mention Schemas
# =============================================================================

View file

@ -22,6 +22,7 @@ from app.db import (
)
from app.schemas.chat_comments import (
AuthorResponse,
CommentBatchResponse,
CommentListResponse,
CommentReplyResponse,
CommentResponse,
@ -264,6 +265,146 @@ async def get_comments_for_message(
)
async def get_comments_for_messages_batch(
    session: AsyncSession,
    message_ids: list[int],
    user: User,
) -> CommentBatchResponse:
    """
    Batch-fetch comments for multiple messages in a single DB round-trip.

    Validates that all messages exist and belong to search spaces the user
    can read comments in, then loads all comments with eager-loaded authors
    and replies.

    Args:
        session: Async DB session.
        message_ids: Message IDs to fetch comments for. Duplicates are
            collapsed; IDs with no matching message yield an empty list.
        user: Requesting user; COMMENTS_READ is checked per search space.

    Returns:
        CommentBatchResponse with one entry per requested message ID.

    Raises:
        Whatever check_permission raises when the user lacks COMMENTS_READ
        in a search space containing one of the messages.
    """
    if not message_ids:
        return CommentBatchResponse(comments_by_message={})

    unique_ids = list(set(message_ids))

    # Load all messages (with threads) in one query so each message's
    # search space can be resolved for permission checks.
    result = await session.execute(
        select(NewChatMessage)
        .options(selectinload(NewChatMessage.thread))
        .filter(NewChatMessage.id.in_(unique_ids))
    )
    messages = result.scalars().all()
    msg_map = {m.id: m for m in messages}

    # Check read permission once per distinct search space, caching the
    # full permission set for the later can_delete computation.
    search_space_ids = {m.thread.search_space_id for m in messages}
    permissions_cache: dict[int, set] = {}
    for ss_id in search_space_ids:
        await check_permission(
            session,
            user,
            ss_id,
            Permission.COMMENTS_READ.value,
            "You don't have permission to read comments in this search space",
        )
        permissions_cache[ss_id] = await get_user_permissions(session, user.id, ss_id)

    # One query for all top-level comments across every message, with
    # authors, replies, and reply authors eagerly loaded.
    result = await session.execute(
        select(ChatComment)
        .options(
            selectinload(ChatComment.author),
            selectinload(ChatComment.replies).selectinload(ChatComment.author),
        )
        .filter(
            ChatComment.message_id.in_(unique_ids),
            ChatComment.parent_id.is_(None),
        )
        .order_by(ChatComment.created_at)
    )
    top_level_comments = result.scalars().all()

    # Resolve every @mention across all comments/replies in one lookup.
    all_mentioned_uuids: set[UUID] = set()
    for comment in top_level_comments:
        all_mentioned_uuids.update(parse_mentions(comment.content))
        for reply in comment.replies:
            all_mentioned_uuids.update(parse_mentions(reply.content))
    user_names = await get_user_names_for_mentions(session, all_mentioned_uuids)

    # Group comments by message. Every requested ID gets a bucket up front;
    # the query filtered on unique_ids, so direct indexing is safe.
    comments_by_msg: dict[int, list[ChatComment]] = {mid: [] for mid in unique_ids}
    for comment in top_level_comments:
        comments_by_msg[comment.message_id].append(comment)

    comments_by_message: dict[int, CommentListResponse] = {}
    for mid in unique_ids:
        msg = msg_map.get(mid)
        if msg is None:
            # Unknown message ID: return an empty list for this entry
            # rather than failing the whole batch.
            comments_by_message[mid] = CommentListResponse(comments=[], total_count=0)
            continue

        user_perms = permissions_cache.get(msg.thread.search_space_id, set())
        can_delete_any = has_permission(user_perms, Permission.COMMENTS_DELETE.value)

        comment_responses = [
            _build_comment_response(comment, user, user_names, can_delete_any)
            for comment in comments_by_msg[mid]
        ]
        comments_by_message[mid] = CommentListResponse(
            comments=comment_responses,
            total_count=len(comment_responses),
        )

    return CommentBatchResponse(comments_by_message=comments_by_message)


def _author_to_response(author) -> AuthorResponse | None:
    """Map a user ORM object to an AuthorResponse, or None if no author."""
    if not author:
        return None
    return AuthorResponse(
        id=author.id,
        display_name=author.display_name,
        avatar_url=author.avatar_url,
        email=author.email,
    )


def _build_reply_response(reply, user, user_names, can_delete_any) -> CommentReplyResponse:
    """Serialize one reply, computing the caller's edit/delete rights."""
    is_reply_author = reply.author_id == user.id if reply.author_id else False
    return CommentReplyResponse(
        id=reply.id,
        content=reply.content,
        content_rendered=render_mentions(reply.content, user_names),
        author=_author_to_response(reply.author),
        created_at=reply.created_at,
        updated_at=reply.updated_at,
        # An updated_at later than created_at means the reply was edited.
        is_edited=reply.updated_at > reply.created_at,
        can_edit=is_reply_author,
        can_delete=is_reply_author or can_delete_any,
    )


def _build_comment_response(comment, user, user_names, can_delete_any) -> CommentResponse:
    """Serialize a top-level comment with replies sorted by creation time."""
    replies = [
        _build_reply_response(reply, user, user_names, can_delete_any)
        for reply in sorted(comment.replies, key=lambda r: r.created_at)
    ]
    is_comment_author = comment.author_id == user.id if comment.author_id else False
    return CommentResponse(
        id=comment.id,
        message_id=comment.message_id,
        content=comment.content,
        content_rendered=render_mentions(comment.content, user_names),
        author=_author_to_response(comment.author),
        created_at=comment.created_at,
        updated_at=comment.updated_at,
        is_edited=comment.updated_at > comment.created_at,
        can_edit=is_comment_author,
        can_delete=is_comment_author or can_delete_any,
        reply_count=len(replies),
        replies=replies,
    )
async def create_comment(
session: AsyncSession,
message_id: int,

View file

@ -16,6 +16,7 @@ from app.db import (
Document,
SearchSourceConnector,
SearchSourceConnectorType,
async_session_maker,
)
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.retriever.documents_hybrid_search import DocumentHybridSearchRetriever
@ -248,6 +249,8 @@ class ConnectorService:
Returns:
List of combined and deduplicated document results
"""
from app.config import config
perf = get_perf_logger()
t0 = time.perf_counter()
@ -257,39 +260,48 @@ class ConnectorService:
# Get more results from each retriever for better fusion
retriever_top_k = top_k * 2
# IMPORTANT:
# These retrievers share the same AsyncSession. AsyncSession does not permit
# concurrent awaits that require DB IO on the same session/connection.
# Running these in parallel can raise:
# "This session is provisioning a new connection; concurrent operations are not permitted"
#
# So we run them sequentially.
t_chunk = time.perf_counter()
chunk_results = await self.chunk_retriever.hybrid_search(
query_text=query_text,
top_k=retriever_top_k,
search_space_id=search_space_id,
document_type=document_type,
start_date=start_date,
end_date=end_date,
)
# Pre-compute the embedding once so both retrievers reuse it.
t_embed = time.perf_counter()
query_embedding = config.embedding_model_instance.embed(query_text)
perf.info(
"[connector_svc] _combined_rrf chunk_retriever in %.3fs results=%d type=%s",
time.perf_counter() - t_chunk, len(chunk_results), document_type,
"[connector_svc] _combined_rrf embedding in %.3fs type=%s",
time.perf_counter() - t_embed,
document_type,
)
t_doc = time.perf_counter()
doc_results = await self.document_retriever.hybrid_search(
query_text=query_text,
top_k=retriever_top_k,
search_space_id=search_space_id,
document_type=document_type,
start_date=start_date,
end_date=end_date,
search_kwargs = {
"query_text": query_text,
"top_k": retriever_top_k,
"search_space_id": search_space_id,
"document_type": document_type,
"start_date": start_date,
"end_date": end_date,
"query_embedding": query_embedding,
}
# Run chunk and document retrievers in parallel using separate DB sessions
# so they don't contend on a shared AsyncSession connection.
async def _run_chunk_search() -> list[dict[str, Any]]:
async with async_session_maker() as session:
retriever = ChucksHybridSearchRetriever(session)
return await retriever.hybrid_search(**search_kwargs)
async def _run_doc_search() -> list[dict[str, Any]]:
async with async_session_maker() as session:
retriever = DocumentHybridSearchRetriever(session)
return await retriever.hybrid_search(**search_kwargs)
t_parallel = time.perf_counter()
chunk_results, doc_results = await asyncio.gather(
_run_chunk_search(), _run_doc_search()
)
perf.info(
"[connector_svc] _combined_rrf doc_retriever in %.3fs results=%d type=%s",
time.perf_counter() - t_doc, len(doc_results), document_type,
"[connector_svc] _combined_rrf parallel retrievers in %.3fs "
"chunk_results=%d doc_results=%d type=%s",
time.perf_counter() - t_parallel,
len(chunk_results),
len(doc_results),
document_type,
)
# Helper to extract document_id from our doc-grouped result
@ -353,7 +365,10 @@ class ConnectorService:
perf.info(
"[connector_svc] _combined_rrf_search TOTAL in %.3fs results=%d type=%s space=%d",
time.perf_counter() - t0, len(combined_results), document_type, search_space_id,
time.perf_counter() - t0,
len(combined_results),
document_type,
search_space_id,
)
return combined_results

View file

@ -437,14 +437,16 @@ class ChatLiteLLMRouter(BaseChatModel):
except ContextWindowExceededError as e:
perf.warning(
"[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count, time.perf_counter() - t0,
msg_count,
time.perf_counter() - t0,
)
raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e):
perf.warning(
"[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count, time.perf_counter() - t0,
msg_count,
time.perf_counter() - t0,
)
raise ContextOverflowError(str(e)) from e
raise
@ -500,14 +502,16 @@ class ChatLiteLLMRouter(BaseChatModel):
except ContextWindowExceededError as e:
perf.warning(
"[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count, time.perf_counter() - t0,
msg_count,
time.perf_counter() - t0,
)
raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e):
perf.warning(
"[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count, time.perf_counter() - t0,
msg_count,
time.perf_counter() - t0,
)
raise ContextOverflowError(str(e)) from e
raise
@ -608,14 +612,16 @@ class ChatLiteLLMRouter(BaseChatModel):
except ContextWindowExceededError as e:
perf.warning(
"[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count, time.perf_counter() - t0,
msg_count,
time.perf_counter() - t0,
)
raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e):
perf.warning(
"[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count, time.perf_counter() - t0,
msg_count,
time.perf_counter() - t0,
)
raise ContextOverflowError(str(e)) from e
raise
@ -623,7 +629,8 @@ class ChatLiteLLMRouter(BaseChatModel):
t_first_chunk = time.perf_counter()
perf.info(
"[llm_router] _astream connection established msgs=%d in %.3fs",
msg_count, t_first_chunk - t0,
msg_count,
t_first_chunk - t0,
)
chunk_count = 0
@ -645,7 +652,8 @@ class ChatLiteLLMRouter(BaseChatModel):
perf.info(
"[llm_router] _astream completed chunks=%d total=%.3fs",
chunk_count, time.perf_counter() - t0,
chunk_count,
time.perf_counter() - t0,
)
def _convert_messages(self, messages: list[BaseMessage]) -> list[dict]:

View file

@ -65,6 +65,7 @@ import type { ThinkingStep } from "@/components/tool-ui/deepagent-thinking";
import { Button } from "@/components/ui/button";
import { Spinner } from "@/components/ui/spinner";
import type { Document } from "@/contracts/types/document.types";
import { useBatchCommentsPreload } from "@/hooks/use-comments";
import { useCommentsElectric } from "@/hooks/use-comments-electric";
import { documentsApiService } from "@/lib/apis/documents-api.service";
import { cn } from "@/lib/utils";
@ -309,6 +310,22 @@ const Composer: FC = () => {
// Sync comments for the entire thread via Electric SQL (one subscription per thread)
useCommentsElectric(threadId);
// Batch-prefetch comments for all assistant messages so individual useComments
// hooks never fire their own network requests (eliminates N+1 API calls).
// Return a primitive string from the selector so useSyncExternalStore can
// compare snapshots by value and avoid infinite re-render loops.
const assistantIdsKey = useAssistantState(({ thread }) =>
thread.messages
.filter((m) => m.role === "assistant" && m.id?.startsWith("msg-"))
.map((m) => m.id!.replace("msg-", ""))
.join(",")
);
const assistantDbMessageIds = useMemo(
() => (assistantIdsKey ? assistantIdsKey.split(",").map(Number) : []),
[assistantIdsKey]
);
useBatchCommentsPreload(assistantDbMessageIds);
// Auto-focus editor on new chat page after mount
useEffect(() => {
if (isThreadEmpty && !hasAutoFocusedRef.current && editorRef.current) {

View file

@ -82,6 +82,22 @@ export const getCommentsResponse = z.object({
total_count: z.number(),
});
/**
 * Batch-fetch comments for multiple messages
 */
export const getBatchCommentsRequest = z.object({
  // Mirrors the backend CommentBatchRequest bounds: 1..200 IDs per call.
  message_ids: z.array(z.number()).min(1).max(200),
});

// Per-message comment list as it appears inside the batch response.
export const commentListResponse = z.object({
  comments: z.array(comment),
  total_count: z.number(),
});

export const getBatchCommentsResponse = z.object({
  // Keys are message IDs serialized as strings (JSON object keys).
  comments_by_message: z.record(z.string(), commentListResponse),
});
/**
* Create comment
*/
@ -145,6 +161,8 @@ export type MentionComment = z.infer<typeof mentionComment>;
export type Mention = z.infer<typeof mention>;
export type GetCommentsRequest = z.infer<typeof getCommentsRequest>;
export type GetCommentsResponse = z.infer<typeof getCommentsResponse>;
export type GetBatchCommentsRequest = z.infer<typeof getBatchCommentsRequest>;
export type GetBatchCommentsResponse = z.infer<typeof getBatchCommentsResponse>;
export type CreateCommentRequest = z.infer<typeof createCommentRequest>;
export type CreateCommentResponse = z.infer<typeof createCommentResponse>;
export type CreateReplyRequest = z.infer<typeof createReplyRequest>;

View file

@ -1,4 +1,5 @@
import { useQuery } from "@tanstack/react-query";
import { useQuery, useQueryClient } from "@tanstack/react-query";
import { useEffect, useRef } from "react";
import { chatCommentsApiService } from "@/lib/apis/chat-comments-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys";
@ -7,12 +8,84 @@ interface UseCommentsOptions {
enabled?: boolean;
}
// ---------------------------------------------------------------------------
// Module-level coordination: when a batch request is in-flight, individual
// useComments queryFns piggy-back on it instead of making their own requests.
// ---------------------------------------------------------------------------
let _batchInflight: Promise<void> | null = null;
let _batchTargetIds = new Set<number>();
/**
 * Fetch comments for a single message. When a batch prefetch covering this
 * messageId is in flight (see useBatchCommentsPreload), wait for it and
 * serve the seeded cache entry instead of firing an individual request.
 */
export function useComments({ messageId, enabled = true }: UseCommentsOptions) {
  const queryClient = useQueryClient();
  return useQuery({
    queryKey: cacheKeys.comments.byMessage(messageId),
    queryFn: async () => {
      // Yield one macro-task so the batch prefetch useEffect (which sets
      // _batchInflight) has a chance to fire before we decide to fetch.
      await new Promise<void>((r) => setTimeout(r, 0));
      if (_batchInflight && _batchTargetIds.has(messageId)) {
        // A batch covering this message is in flight: wait for it to seed
        // the cache, then resolve from cache.
        await _batchInflight;
        const cached = queryClient.getQueryData(cacheKeys.comments.byMessage(messageId));
        if (cached) return cached;
        // Fall through: the batch failed or didn't seed this message.
      }
      return chatCommentsApiService.getComments({ message_id: messageId });
    },
    enabled: enabled && !!messageId,
    staleTime: 30_000,
  });
}
/**
* Batch-fetch comments for all given message IDs in a single request, then
* seed the per-message React Query cache so individual useComments hooks
* resolve from cache instead of firing their own requests.
*/
/**
 * Batch-fetch comments for all given message IDs in a single request, then
 * seed the per-message React Query cache so individual useComments hooks
 * resolve from cache instead of firing their own requests.
 */
export function useBatchCommentsPreload(messageIds: number[]) {
  const queryClient = useQueryClient();
  const prevKeyRef = useRef<string>("");
  useEffect(() => {
    if (!messageIds.length) return;
    // Order-insensitive identity key: skip refetching when the same set of
    // IDs arrives again.
    const key = messageIds
      .slice()
      .sort((a, b) => a - b)
      .join(",");
    if (key === prevKeyRef.current) return;
    prevKeyRef.current = key;
    _batchTargetIds = new Set(messageIds);
    let cancelled = false;
    const promise = chatCommentsApiService
      .getBatchComments({ message_ids: messageIds })
      .then((data) => {
        if (cancelled) return;
        // Seed each message's cache entry; response keys are strings (JSON).
        for (const [msgIdStr, commentList] of Object.entries(data.comments_by_message)) {
          queryClient.setQueryData(cacheKeys.comments.byMessage(Number(msgIdStr)), commentList);
        }
      })
      .catch(() => {
        // Batch failed; individual queryFns will fall through to their own fetch
      })
      .finally(() => {
        // Clear the module-level handoff only if it still points at us.
        if (_batchInflight === promise) {
          _batchInflight = null;
          _batchTargetIds = new Set();
        }
      });
    _batchInflight = promise;
    // NOTE(review): if this effect re-runs with an unmemoized messageIds
    // array (same contents, new identity), the cleanup below cancels the
    // in-flight batch even though the `key` guard skips a refetch — callers
    // appear to pass a memoized array; verify against call sites.
    return () => {
      cancelled = true;
      if (_batchInflight === promise) {
        _batchInflight = null;
        _batchTargetIds = new Set();
      }
    };
  }, [messageIds, queryClient]);
}

View file

@ -8,8 +8,11 @@ import {
type DeleteCommentRequest,
deleteCommentRequest,
deleteCommentResponse,
type GetBatchCommentsRequest,
type GetCommentsRequest,
type GetMentionsRequest,
getBatchCommentsRequest,
getBatchCommentsResponse,
getCommentsRequest,
getCommentsResponse,
getMentionsRequest,
@ -22,6 +25,22 @@ import { ValidationError } from "@/lib/error";
import { baseApiService } from "./base-api.service";
class ChatCommentsApiService {
/**
* Batch-fetch comments for multiple messages in one request
*/
getBatchComments = async (request: GetBatchCommentsRequest) => {
const parsed = getBatchCommentsRequest.safeParse(request);
if (!parsed.success) {
const errorMessage = parsed.error.issues.map((issue) => issue.message).join(", ");
throw new ValidationError(`Invalid request: ${errorMessage}`);
}
return baseApiService.post("/api/v1/messages/comments/batch", getBatchCommentsResponse, {
body: { message_ids: parsed.data.message_ids },
});
};
/**
* Get comments for a message
*/