Merge pull request #846 from MODSetter/dev

feat: perf optimizations
This commit is contained in:
Rohan Verma 2026-02-27 17:24:20 -08:00 committed by GitHub
commit 4105bd0d7a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 424 additions and 67 deletions

View file

@ -430,7 +430,10 @@ async def search_knowledge_base_async(
connectors = _normalize_connectors(connectors_to_search, available_connectors) connectors = _normalize_connectors(connectors_to_search, available_connectors)
perf.info( perf.info(
"[kb_search] searching %d connectors: %s (space=%d, top_k=%d)", "[kb_search] searching %d connectors: %s (space=%d, top_k=%d)",
len(connectors), connectors[:5], search_space_id, top_k, len(connectors),
connectors[:5],
search_space_id,
top_k,
) )
connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = { connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
@ -510,13 +513,17 @@ async def search_knowledge_base_async(
_, chunks = await connector_method(**kwargs) _, chunks = await connector_method(**kwargs)
perf.info( perf.info(
"[kb_search] connector=%s results=%d in %.3fs", "[kb_search] connector=%s results=%d in %.3fs",
connector, len(chunks), time.perf_counter() - t_conn, connector,
len(chunks),
time.perf_counter() - t_conn,
) )
return chunks return chunks
except Exception as e: except Exception as e:
perf.warning( perf.warning(
"[kb_search] connector=%s FAILED in %.3fs: %s", "[kb_search] connector=%s FAILED in %.3fs: %s",
connector, time.perf_counter() - t_conn, e, connector,
time.perf_counter() - t_conn,
e,
) )
return [] return []
@ -525,7 +532,8 @@ async def search_knowledge_base_async(
*[_search_one_connector(connector) for connector in connectors] *[_search_one_connector(connector) for connector in connectors]
) )
perf.info( perf.info(
"[kb_search] all connectors gathered in %.3fs", time.perf_counter() - t_gather, "[kb_search] all connectors gathered in %.3fs",
time.perf_counter() - t_gather,
) )
for chunks in connector_results: for chunks in connector_results:
all_documents.extend(chunks) all_documents.extend(chunks)
@ -576,7 +584,11 @@ async def search_knowledge_base_async(
result = format_documents_for_context(deduplicated, max_chars=output_budget) result = format_documents_for_context(deduplicated, max_chars=output_budget)
perf.info( perf.info(
"[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d space=%d", "[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d space=%d",
time.perf_counter() - t0, len(all_documents), len(deduplicated), len(result), search_space_id, time.perf_counter() - t0,
len(all_documents),
len(deduplicated),
len(result),
search_space_id,
) )
return result return result

View file

@ -147,7 +147,9 @@ class IndexingPipelineService:
await self.session.commit() await self.session.commit()
perf.info( perf.info(
"[indexing] prepare_for_indexing in %.3fs input=%d output=%d", "[indexing] prepare_for_indexing in %.3fs input=%d output=%d",
time.perf_counter() - t0, len(connector_docs), len(documents), time.perf_counter() - t0,
len(connector_docs),
len(documents),
) )
return documents return documents
except IntegrityError: except IntegrityError:
@ -185,7 +187,8 @@ class IndexingPipelineService:
) )
perf.info( perf.info(
"[indexing] summarize_document doc=%d in %.3fs", "[indexing] summarize_document doc=%d in %.3fs",
document.id, time.perf_counter() - t_step, document.id,
time.perf_counter() - t_step,
) )
elif connector_doc.should_summarize and connector_doc.fallback_summary: elif connector_doc.should_summarize and connector_doc.fallback_summary:
content = connector_doc.fallback_summary content = connector_doc.fallback_summary
@ -196,7 +199,8 @@ class IndexingPipelineService:
embedding = embed_text(content) embedding = embed_text(content)
perf.debug( perf.debug(
"[indexing] embed_text (summary) doc=%d in %.3fs", "[indexing] embed_text (summary) doc=%d in %.3fs",
document.id, time.perf_counter() - t_step, document.id,
time.perf_counter() - t_step,
) )
await self.session.execute( await self.session.execute(
@ -213,7 +217,9 @@ class IndexingPipelineService:
] ]
perf.info( perf.info(
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs", "[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
document.id, len(chunks), time.perf_counter() - t_step, document.id,
len(chunks),
time.perf_counter() - t_step,
) )
document.content = content document.content = content
@ -224,7 +230,9 @@ class IndexingPipelineService:
await self.session.commit() await self.session.commit()
perf.info( perf.info(
"[indexing] index TOTAL doc=%d chunks=%d in %.3fs", "[indexing] index TOTAL doc=%d chunks=%d in %.3fs",
document.id, len(chunks), time.perf_counter() - t_index, document.id,
len(chunks),
time.perf_counter() - t_index,
) )
log_index_success(ctx, chunk_count=len(chunks)) log_index_success(ctx, chunk_count=len(chunks))

View file

@ -76,7 +76,10 @@ class ChucksHybridSearchRetriever:
chunks = result.scalars().all() chunks = result.scalars().all()
perf.info( perf.info(
"[chunk_search] vector_search DB query in %.3fs results=%d (total %.3fs) space=%d", "[chunk_search] vector_search DB query in %.3fs results=%d (total %.3fs) space=%d",
time.perf_counter() - t_db, len(chunks), time.perf_counter() - t0, search_space_id, time.perf_counter() - t_db,
len(chunks),
time.perf_counter() - t0,
search_space_id,
) )
return chunks return chunks
@ -139,7 +142,9 @@ class ChucksHybridSearchRetriever:
chunks = result.scalars().all() chunks = result.scalars().all()
perf.info( perf.info(
"[chunk_search] full_text_search in %.3fs results=%d space=%d", "[chunk_search] full_text_search in %.3fs results=%d space=%d",
time.perf_counter() - t0, len(chunks), search_space_id, time.perf_counter() - t0,
len(chunks),
search_space_id,
) )
return chunks return chunks
@ -152,6 +157,7 @@ class ChucksHybridSearchRetriever:
document_type: str | None = None, document_type: str | None = None,
start_date: datetime | None = None, start_date: datetime | None = None,
end_date: datetime | None = None, end_date: datetime | None = None,
query_embedding: list | None = None,
) -> list: ) -> list:
""" """
Hybrid search that returns **documents** (not individual chunks). Hybrid search that returns **documents** (not individual chunks).
@ -166,6 +172,7 @@ class ChucksHybridSearchRetriever:
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL") document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
start_date: Optional start date for filtering documents by updated_at start_date: Optional start date for filtering documents by updated_at
end_date: Optional end date for filtering documents by updated_at end_date: Optional end date for filtering documents by updated_at
query_embedding: Pre-computed embedding vector. If None, will be computed here.
Returns: Returns:
List of dictionaries containing document data and relevance scores. Each dict contains: List of dictionaries containing document data and relevance scores. Each dict contains:
@ -183,14 +190,14 @@ class ChucksHybridSearchRetriever:
perf = get_perf_logger() perf = get_perf_logger()
t0 = time.perf_counter() t0 = time.perf_counter()
# Get embedding for the query if query_embedding is None:
embedding_model = config.embedding_model_instance embedding_model = config.embedding_model_instance
t_embed = time.perf_counter() t_embed = time.perf_counter()
query_embedding = embedding_model.embed(query_text) query_embedding = embedding_model.embed(query_text)
perf.debug( perf.debug(
"[chunk_search] hybrid_search embedding in %.3fs", "[chunk_search] hybrid_search embedding in %.3fs",
time.perf_counter() - t_embed, time.perf_counter() - t_embed,
) )
# RRF constants # RRF constants
k = 60 k = 60
@ -291,7 +298,10 @@ class ChucksHybridSearchRetriever:
chunks_with_scores = result.all() chunks_with_scores = result.all()
perf.info( perf.info(
"[chunk_search] hybrid_search RRF query in %.3fs results=%d space=%d type=%s", "[chunk_search] hybrid_search RRF query in %.3fs results=%d space=%d type=%s",
time.perf_counter() - t_rrf, len(chunks_with_scores), search_space_id, document_type, time.perf_counter() - t_rrf,
len(chunks_with_scores),
search_space_id,
document_type,
) )
# If no results were found, return an empty list # If no results were found, return an empty list
@ -392,6 +402,9 @@ class ChucksHybridSearchRetriever:
perf.info( perf.info(
"[chunk_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s", "[chunk_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s",
time.perf_counter() - t0, len(final_docs), search_space_id, document_type, time.perf_counter() - t0,
len(final_docs),
search_space_id,
document_type,
) )
return final_docs return final_docs

View file

@ -71,7 +71,9 @@ class DocumentHybridSearchRetriever:
documents = result.scalars().all() documents = result.scalars().all()
perf.info( perf.info(
"[doc_search] vector_search in %.3fs results=%d space=%d", "[doc_search] vector_search in %.3fs results=%d space=%d",
time.perf_counter() - t0, len(documents), search_space_id, time.perf_counter() - t0,
len(documents),
search_space_id,
) )
return documents return documents
@ -133,7 +135,9 @@ class DocumentHybridSearchRetriever:
documents = result.scalars().all() documents = result.scalars().all()
perf.info( perf.info(
"[doc_search] full_text_search in %.3fs results=%d space=%d", "[doc_search] full_text_search in %.3fs results=%d space=%d",
time.perf_counter() - t0, len(documents), search_space_id, time.perf_counter() - t0,
len(documents),
search_space_id,
) )
return documents return documents
@ -146,6 +150,7 @@ class DocumentHybridSearchRetriever:
document_type: str | None = None, document_type: str | None = None,
start_date: datetime | None = None, start_date: datetime | None = None,
end_date: datetime | None = None, end_date: datetime | None = None,
query_embedding: list | None = None,
) -> list: ) -> list:
""" """
Hybrid search that returns **documents** (not individual chunks). Hybrid search that returns **documents** (not individual chunks).
@ -160,7 +165,7 @@ class DocumentHybridSearchRetriever:
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL") document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
start_date: Optional start date for filtering documents by updated_at start_date: Optional start date for filtering documents by updated_at
end_date: Optional end date for filtering documents by updated_at end_date: Optional end date for filtering documents by updated_at
query_embedding: Pre-computed embedding vector. If None, will be computed here.
""" """
from sqlalchemy import func, select, text from sqlalchemy import func, select, text
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
@ -171,9 +176,9 @@ class DocumentHybridSearchRetriever:
perf = get_perf_logger() perf = get_perf_logger()
t0 = time.perf_counter() t0 = time.perf_counter()
# Get embedding for the query if query_embedding is None:
embedding_model = config.embedding_model_instance embedding_model = config.embedding_model_instance
query_embedding = embedding_model.embed(query_text) query_embedding = embedding_model.embed(query_text)
# RRF constants # RRF constants
k = 60 k = 60
@ -325,6 +330,9 @@ class DocumentHybridSearchRetriever:
perf.info( perf.info(
"[doc_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s", "[doc_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s",
time.perf_counter() - t0, len(final_docs), search_space_id, document_type, time.perf_counter() - t0,
len(final_docs),
search_space_id,
document_type,
) )
return final_docs return final_docs

View file

@ -7,6 +7,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db import User, get_async_session from app.db import User, get_async_session
from app.schemas.chat_comments import ( from app.schemas.chat_comments import (
CommentBatchRequest,
CommentBatchResponse,
CommentCreateRequest, CommentCreateRequest,
CommentListResponse, CommentListResponse,
CommentReplyResponse, CommentReplyResponse,
@ -19,6 +21,7 @@ from app.services.chat_comments_service import (
create_reply, create_reply,
delete_comment, delete_comment,
get_comments_for_message, get_comments_for_message,
get_comments_for_messages_batch,
get_user_mentions, get_user_mentions,
update_comment, update_comment,
) )
@ -27,6 +30,16 @@ from app.users import current_active_user
router = APIRouter() router = APIRouter()
@router.post("/messages/comments/batch", response_model=CommentBatchResponse)
async def batch_list_comments(
request: CommentBatchRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Batch-fetch comments for multiple messages in one request."""
return await get_comments_for_messages_batch(session, request.message_ids, user)
@router.get("/messages/{message_id}/comments", response_model=CommentListResponse) @router.get("/messages/{message_id}/comments", response_model=CommentListResponse)
async def list_comments( async def list_comments(
message_id: int, message_id: int,

View file

@ -87,6 +87,18 @@ class CommentListResponse(BaseModel):
total_count: int total_count: int
class CommentBatchRequest(BaseModel):
"""Request for batch-fetching comments for multiple messages."""
message_ids: list[int] = Field(..., min_length=1, max_length=200)
class CommentBatchResponse(BaseModel):
"""Batch response keyed by message_id."""
comments_by_message: dict[int, CommentListResponse]
# ============================================================================= # =============================================================================
# Mention Schemas # Mention Schemas
# ============================================================================= # =============================================================================

View file

@ -22,6 +22,7 @@ from app.db import (
) )
from app.schemas.chat_comments import ( from app.schemas.chat_comments import (
AuthorResponse, AuthorResponse,
CommentBatchResponse,
CommentListResponse, CommentListResponse,
CommentReplyResponse, CommentReplyResponse,
CommentResponse, CommentResponse,
@ -264,6 +265,146 @@ async def get_comments_for_message(
) )
async def get_comments_for_messages_batch(
session: AsyncSession,
message_ids: list[int],
user: User,
) -> CommentBatchResponse:
"""
Batch-fetch comments for multiple messages in a single DB round-trip.
Validates that all messages exist and belong to search spaces the user
can read comments in, then loads all comments with eager-loaded authors
and replies.
"""
if not message_ids:
return CommentBatchResponse(comments_by_message={})
unique_ids = list(set(message_ids))
result = await session.execute(
select(NewChatMessage)
.options(selectinload(NewChatMessage.thread))
.filter(NewChatMessage.id.in_(unique_ids))
)
messages = result.scalars().all()
msg_map = {m.id: m for m in messages}
search_space_ids = {m.thread.search_space_id for m in messages}
permissions_cache: dict[int, set] = {}
for ss_id in search_space_ids:
await check_permission(
session,
user,
ss_id,
Permission.COMMENTS_READ.value,
"You don't have permission to read comments in this search space",
)
permissions_cache[ss_id] = await get_user_permissions(session, user.id, ss_id)
result = await session.execute(
select(ChatComment)
.options(
selectinload(ChatComment.author),
selectinload(ChatComment.replies).selectinload(ChatComment.author),
)
.filter(
ChatComment.message_id.in_(unique_ids),
ChatComment.parent_id.is_(None),
)
.order_by(ChatComment.created_at)
)
top_level_comments = result.scalars().all()
all_mentioned_uuids: set[UUID] = set()
for comment in top_level_comments:
all_mentioned_uuids.update(parse_mentions(comment.content))
for reply in comment.replies:
all_mentioned_uuids.update(parse_mentions(reply.content))
user_names = await get_user_names_for_mentions(session, all_mentioned_uuids)
comments_by_msg: dict[int, list[ChatComment]] = {mid: [] for mid in unique_ids}
for comment in top_level_comments:
comments_by_msg.setdefault(comment.message_id, []).append(comment)
comments_by_message: dict[int, CommentListResponse] = {}
for mid in unique_ids:
msg = msg_map.get(mid)
if msg is None:
comments_by_message[mid] = CommentListResponse(comments=[], total_count=0)
continue
ss_id = msg.thread.search_space_id
user_perms = permissions_cache.get(ss_id, set())
can_delete_any = has_permission(user_perms, Permission.COMMENTS_DELETE.value)
comment_responses = []
for comment in comments_by_msg.get(mid, []):
author = None
if comment.author:
author = AuthorResponse(
id=comment.author.id,
display_name=comment.author.display_name,
avatar_url=comment.author.avatar_url,
email=comment.author.email,
)
replies = []
for reply in sorted(comment.replies, key=lambda r: r.created_at):
reply_author = None
if reply.author:
reply_author = AuthorResponse(
id=reply.author.id,
display_name=reply.author.display_name,
avatar_url=reply.author.avatar_url,
email=reply.author.email,
)
is_reply_author = (
reply.author_id == user.id if reply.author_id else False
)
replies.append(
CommentReplyResponse(
id=reply.id,
content=reply.content,
content_rendered=render_mentions(reply.content, user_names),
author=reply_author,
created_at=reply.created_at,
updated_at=reply.updated_at,
is_edited=reply.updated_at > reply.created_at,
can_edit=is_reply_author,
can_delete=is_reply_author or can_delete_any,
)
)
is_comment_author = (
comment.author_id == user.id if comment.author_id else False
)
comment_responses.append(
CommentResponse(
id=comment.id,
message_id=comment.message_id,
content=comment.content,
content_rendered=render_mentions(comment.content, user_names),
author=author,
created_at=comment.created_at,
updated_at=comment.updated_at,
is_edited=comment.updated_at > comment.created_at,
can_edit=is_comment_author,
can_delete=is_comment_author or can_delete_any,
reply_count=len(replies),
replies=replies,
)
)
comments_by_message[mid] = CommentListResponse(
comments=comment_responses,
total_count=len(comment_responses),
)
return CommentBatchResponse(comments_by_message=comments_by_message)
async def create_comment( async def create_comment(
session: AsyncSession, session: AsyncSession,
message_id: int, message_id: int,

View file

@ -16,6 +16,7 @@ from app.db import (
Document, Document,
SearchSourceConnector, SearchSourceConnector,
SearchSourceConnectorType, SearchSourceConnectorType,
async_session_maker,
) )
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.retriever.documents_hybrid_search import DocumentHybridSearchRetriever from app.retriever.documents_hybrid_search import DocumentHybridSearchRetriever
@ -248,6 +249,8 @@ class ConnectorService:
Returns: Returns:
List of combined and deduplicated document results List of combined and deduplicated document results
""" """
from app.config import config
perf = get_perf_logger() perf = get_perf_logger()
t0 = time.perf_counter() t0 = time.perf_counter()
@ -257,39 +260,48 @@ class ConnectorService:
# Get more results from each retriever for better fusion # Get more results from each retriever for better fusion
retriever_top_k = top_k * 2 retriever_top_k = top_k * 2
# IMPORTANT: # Pre-compute the embedding once so both retrievers reuse it.
# These retrievers share the same AsyncSession. AsyncSession does not permit t_embed = time.perf_counter()
# concurrent awaits that require DB IO on the same session/connection. query_embedding = config.embedding_model_instance.embed(query_text)
# Running these in parallel can raise:
# "This session is provisioning a new connection; concurrent operations are not permitted"
#
# So we run them sequentially.
t_chunk = time.perf_counter()
chunk_results = await self.chunk_retriever.hybrid_search(
query_text=query_text,
top_k=retriever_top_k,
search_space_id=search_space_id,
document_type=document_type,
start_date=start_date,
end_date=end_date,
)
perf.info( perf.info(
"[connector_svc] _combined_rrf chunk_retriever in %.3fs results=%d type=%s", "[connector_svc] _combined_rrf embedding in %.3fs type=%s",
time.perf_counter() - t_chunk, len(chunk_results), document_type, time.perf_counter() - t_embed,
document_type,
) )
t_doc = time.perf_counter() search_kwargs = {
doc_results = await self.document_retriever.hybrid_search( "query_text": query_text,
query_text=query_text, "top_k": retriever_top_k,
top_k=retriever_top_k, "search_space_id": search_space_id,
search_space_id=search_space_id, "document_type": document_type,
document_type=document_type, "start_date": start_date,
start_date=start_date, "end_date": end_date,
end_date=end_date, "query_embedding": query_embedding,
}
# Run chunk and document retrievers in parallel using separate DB sessions
# so they don't contend on a shared AsyncSession connection.
async def _run_chunk_search() -> list[dict[str, Any]]:
async with async_session_maker() as session:
retriever = ChucksHybridSearchRetriever(session)
return await retriever.hybrid_search(**search_kwargs)
async def _run_doc_search() -> list[dict[str, Any]]:
async with async_session_maker() as session:
retriever = DocumentHybridSearchRetriever(session)
return await retriever.hybrid_search(**search_kwargs)
t_parallel = time.perf_counter()
chunk_results, doc_results = await asyncio.gather(
_run_chunk_search(), _run_doc_search()
) )
perf.info( perf.info(
"[connector_svc] _combined_rrf doc_retriever in %.3fs results=%d type=%s", "[connector_svc] _combined_rrf parallel retrievers in %.3fs "
time.perf_counter() - t_doc, len(doc_results), document_type, "chunk_results=%d doc_results=%d type=%s",
time.perf_counter() - t_parallel,
len(chunk_results),
len(doc_results),
document_type,
) )
# Helper to extract document_id from our doc-grouped result # Helper to extract document_id from our doc-grouped result
@ -353,7 +365,10 @@ class ConnectorService:
perf.info( perf.info(
"[connector_svc] _combined_rrf_search TOTAL in %.3fs results=%d type=%s space=%d", "[connector_svc] _combined_rrf_search TOTAL in %.3fs results=%d type=%s space=%d",
time.perf_counter() - t0, len(combined_results), document_type, search_space_id, time.perf_counter() - t0,
len(combined_results),
document_type,
search_space_id,
) )
return combined_results return combined_results

View file

@ -437,14 +437,16 @@ class ChatLiteLLMRouter(BaseChatModel):
except ContextWindowExceededError as e: except ContextWindowExceededError as e:
perf.warning( perf.warning(
"[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs", "[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count, time.perf_counter() - t0, msg_count,
time.perf_counter() - t0,
) )
raise ContextOverflowError(str(e)) from e raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e: except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e): if _is_context_overflow_error(e):
perf.warning( perf.warning(
"[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs", "[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count, time.perf_counter() - t0, msg_count,
time.perf_counter() - t0,
) )
raise ContextOverflowError(str(e)) from e raise ContextOverflowError(str(e)) from e
raise raise
@ -500,14 +502,16 @@ class ChatLiteLLMRouter(BaseChatModel):
except ContextWindowExceededError as e: except ContextWindowExceededError as e:
perf.warning( perf.warning(
"[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs", "[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count, time.perf_counter() - t0, msg_count,
time.perf_counter() - t0,
) )
raise ContextOverflowError(str(e)) from e raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e: except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e): if _is_context_overflow_error(e):
perf.warning( perf.warning(
"[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs", "[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count, time.perf_counter() - t0, msg_count,
time.perf_counter() - t0,
) )
raise ContextOverflowError(str(e)) from e raise ContextOverflowError(str(e)) from e
raise raise
@ -608,14 +612,16 @@ class ChatLiteLLMRouter(BaseChatModel):
except ContextWindowExceededError as e: except ContextWindowExceededError as e:
perf.warning( perf.warning(
"[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs", "[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count, time.perf_counter() - t0, msg_count,
time.perf_counter() - t0,
) )
raise ContextOverflowError(str(e)) from e raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e: except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e): if _is_context_overflow_error(e):
perf.warning( perf.warning(
"[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs", "[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count, time.perf_counter() - t0, msg_count,
time.perf_counter() - t0,
) )
raise ContextOverflowError(str(e)) from e raise ContextOverflowError(str(e)) from e
raise raise
@ -623,7 +629,8 @@ class ChatLiteLLMRouter(BaseChatModel):
t_first_chunk = time.perf_counter() t_first_chunk = time.perf_counter()
perf.info( perf.info(
"[llm_router] _astream connection established msgs=%d in %.3fs", "[llm_router] _astream connection established msgs=%d in %.3fs",
msg_count, t_first_chunk - t0, msg_count,
t_first_chunk - t0,
) )
chunk_count = 0 chunk_count = 0
@ -645,7 +652,8 @@ class ChatLiteLLMRouter(BaseChatModel):
perf.info( perf.info(
"[llm_router] _astream completed chunks=%d total=%.3fs", "[llm_router] _astream completed chunks=%d total=%.3fs",
chunk_count, time.perf_counter() - t0, chunk_count,
time.perf_counter() - t0,
) )
def _convert_messages(self, messages: list[BaseMessage]) -> list[dict]: def _convert_messages(self, messages: list[BaseMessage]) -> list[dict]:

View file

@ -65,6 +65,7 @@ import type { ThinkingStep } from "@/components/tool-ui/deepagent-thinking";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { Spinner } from "@/components/ui/spinner"; import { Spinner } from "@/components/ui/spinner";
import type { Document } from "@/contracts/types/document.types"; import type { Document } from "@/contracts/types/document.types";
import { useBatchCommentsPreload } from "@/hooks/use-comments";
import { useCommentsElectric } from "@/hooks/use-comments-electric"; import { useCommentsElectric } from "@/hooks/use-comments-electric";
import { documentsApiService } from "@/lib/apis/documents-api.service"; import { documentsApiService } from "@/lib/apis/documents-api.service";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
@ -309,6 +310,22 @@ const Composer: FC = () => {
// Sync comments for the entire thread via Electric SQL (one subscription per thread) // Sync comments for the entire thread via Electric SQL (one subscription per thread)
useCommentsElectric(threadId); useCommentsElectric(threadId);
// Batch-prefetch comments for all assistant messages so individual useComments
// hooks never fire their own network requests (eliminates N+1 API calls).
// Return a primitive string from the selector so useSyncExternalStore can
// compare snapshots by value and avoid infinite re-render loops.
const assistantIdsKey = useAssistantState(({ thread }) =>
thread.messages
.filter((m) => m.role === "assistant" && m.id?.startsWith("msg-"))
.map((m) => m.id!.replace("msg-", ""))
.join(",")
);
const assistantDbMessageIds = useMemo(
() => (assistantIdsKey ? assistantIdsKey.split(",").map(Number) : []),
[assistantIdsKey]
);
useBatchCommentsPreload(assistantDbMessageIds);
// Auto-focus editor on new chat page after mount // Auto-focus editor on new chat page after mount
useEffect(() => { useEffect(() => {
if (isThreadEmpty && !hasAutoFocusedRef.current && editorRef.current) { if (isThreadEmpty && !hasAutoFocusedRef.current && editorRef.current) {

View file

@ -82,6 +82,22 @@ export const getCommentsResponse = z.object({
total_count: z.number(), total_count: z.number(),
}); });
/**
* Batch-fetch comments for multiple messages
*/
export const getBatchCommentsRequest = z.object({
message_ids: z.array(z.number()).min(1).max(200),
});
export const commentListResponse = z.object({
comments: z.array(comment),
total_count: z.number(),
});
export const getBatchCommentsResponse = z.object({
comments_by_message: z.record(z.string(), commentListResponse),
});
/** /**
* Create comment * Create comment
*/ */
@ -145,6 +161,8 @@ export type MentionComment = z.infer<typeof mentionComment>;
export type Mention = z.infer<typeof mention>; export type Mention = z.infer<typeof mention>;
export type GetCommentsRequest = z.infer<typeof getCommentsRequest>; export type GetCommentsRequest = z.infer<typeof getCommentsRequest>;
export type GetCommentsResponse = z.infer<typeof getCommentsResponse>; export type GetCommentsResponse = z.infer<typeof getCommentsResponse>;
export type GetBatchCommentsRequest = z.infer<typeof getBatchCommentsRequest>;
export type GetBatchCommentsResponse = z.infer<typeof getBatchCommentsResponse>;
export type CreateCommentRequest = z.infer<typeof createCommentRequest>; export type CreateCommentRequest = z.infer<typeof createCommentRequest>;
export type CreateCommentResponse = z.infer<typeof createCommentResponse>; export type CreateCommentResponse = z.infer<typeof createCommentResponse>;
export type CreateReplyRequest = z.infer<typeof createReplyRequest>; export type CreateReplyRequest = z.infer<typeof createReplyRequest>;

View file

@ -1,4 +1,5 @@
import { useQuery } from "@tanstack/react-query"; import { useQuery, useQueryClient } from "@tanstack/react-query";
import { useEffect, useRef } from "react";
import { chatCommentsApiService } from "@/lib/apis/chat-comments-api.service"; import { chatCommentsApiService } from "@/lib/apis/chat-comments-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys"; import { cacheKeys } from "@/lib/query-client/cache-keys";
@ -7,12 +8,84 @@ interface UseCommentsOptions {
enabled?: boolean; enabled?: boolean;
} }
// ---------------------------------------------------------------------------
// Module-level coordination: when a batch request is in-flight, individual
// useComments queryFns piggy-back on it instead of making their own requests.
// ---------------------------------------------------------------------------
let _batchInflight: Promise<void> | null = null;
let _batchTargetIds = new Set<number>();
export function useComments({ messageId, enabled = true }: UseCommentsOptions) { export function useComments({ messageId, enabled = true }: UseCommentsOptions) {
const queryClient = useQueryClient();
return useQuery({ return useQuery({
queryKey: cacheKeys.comments.byMessage(messageId), queryKey: cacheKeys.comments.byMessage(messageId),
queryFn: async () => { queryFn: async () => {
// Yield one macro-task so the batch prefetch useEffect (which sets
// _batchInflight) has a chance to fire before we decide to fetch.
await new Promise<void>((r) => setTimeout(r, 0));
if (_batchInflight && _batchTargetIds.has(messageId)) {
await _batchInflight;
const cached = queryClient.getQueryData(cacheKeys.comments.byMessage(messageId));
if (cached) return cached;
}
return chatCommentsApiService.getComments({ message_id: messageId }); return chatCommentsApiService.getComments({ message_id: messageId });
}, },
enabled: enabled && !!messageId, enabled: enabled && !!messageId,
staleTime: 30_000,
}); });
} }
/**
* Batch-fetch comments for all given message IDs in a single request, then
* seed the per-message React Query cache so individual useComments hooks
* resolve from cache instead of firing their own requests.
*/
export function useBatchCommentsPreload(messageIds: number[]) {
const queryClient = useQueryClient();
const prevKeyRef = useRef<string>("");
useEffect(() => {
if (!messageIds.length) return;
const key = messageIds
.slice()
.sort((a, b) => a - b)
.join(",");
if (key === prevKeyRef.current) return;
prevKeyRef.current = key;
_batchTargetIds = new Set(messageIds);
let cancelled = false;
const promise = chatCommentsApiService
.getBatchComments({ message_ids: messageIds })
.then((data) => {
if (cancelled) return;
for (const [msgIdStr, commentList] of Object.entries(data.comments_by_message)) {
queryClient.setQueryData(cacheKeys.comments.byMessage(Number(msgIdStr)), commentList);
}
})
.catch(() => {
// Batch failed; individual queryFns will fall through to their own fetch
})
.finally(() => {
if (_batchInflight === promise) {
_batchInflight = null;
_batchTargetIds = new Set();
}
});
_batchInflight = promise;
return () => {
cancelled = true;
if (_batchInflight === promise) {
_batchInflight = null;
_batchTargetIds = new Set();
}
};
}, [messageIds, queryClient]);
}

View file

@ -8,8 +8,11 @@ import {
type DeleteCommentRequest, type DeleteCommentRequest,
deleteCommentRequest, deleteCommentRequest,
deleteCommentResponse, deleteCommentResponse,
type GetBatchCommentsRequest,
type GetCommentsRequest, type GetCommentsRequest,
type GetMentionsRequest, type GetMentionsRequest,
getBatchCommentsRequest,
getBatchCommentsResponse,
getCommentsRequest, getCommentsRequest,
getCommentsResponse, getCommentsResponse,
getMentionsRequest, getMentionsRequest,
@ -22,6 +25,22 @@ import { ValidationError } from "@/lib/error";
import { baseApiService } from "./base-api.service"; import { baseApiService } from "./base-api.service";
class ChatCommentsApiService { class ChatCommentsApiService {
/**
* Batch-fetch comments for multiple messages in one request
*/
getBatchComments = async (request: GetBatchCommentsRequest) => {
const parsed = getBatchCommentsRequest.safeParse(request);
if (!parsed.success) {
const errorMessage = parsed.error.issues.map((issue) => issue.message).join(", ");
throw new ValidationError(`Invalid request: ${errorMessage}`);
}
return baseApiService.post("/api/v1/messages/comments/batch", getBatchCommentsResponse, {
body: { message_ids: parsed.data.message_ids },
});
};
/** /**
* Get comments for a message * Get comments for a message
*/ */