mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-04 05:12:38 +02:00
commit
4105bd0d7a
13 changed files with 424 additions and 67 deletions
|
|
@ -430,7 +430,10 @@ async def search_knowledge_base_async(
|
|||
connectors = _normalize_connectors(connectors_to_search, available_connectors)
|
||||
perf.info(
|
||||
"[kb_search] searching %d connectors: %s (space=%d, top_k=%d)",
|
||||
len(connectors), connectors[:5], search_space_id, top_k,
|
||||
len(connectors),
|
||||
connectors[:5],
|
||||
search_space_id,
|
||||
top_k,
|
||||
)
|
||||
|
||||
connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
|
||||
|
|
@ -510,13 +513,17 @@ async def search_knowledge_base_async(
|
|||
_, chunks = await connector_method(**kwargs)
|
||||
perf.info(
|
||||
"[kb_search] connector=%s results=%d in %.3fs",
|
||||
connector, len(chunks), time.perf_counter() - t_conn,
|
||||
connector,
|
||||
len(chunks),
|
||||
time.perf_counter() - t_conn,
|
||||
)
|
||||
return chunks
|
||||
except Exception as e:
|
||||
perf.warning(
|
||||
"[kb_search] connector=%s FAILED in %.3fs: %s",
|
||||
connector, time.perf_counter() - t_conn, e,
|
||||
connector,
|
||||
time.perf_counter() - t_conn,
|
||||
e,
|
||||
)
|
||||
return []
|
||||
|
||||
|
|
@ -525,7 +532,8 @@ async def search_knowledge_base_async(
|
|||
*[_search_one_connector(connector) for connector in connectors]
|
||||
)
|
||||
perf.info(
|
||||
"[kb_search] all connectors gathered in %.3fs", time.perf_counter() - t_gather,
|
||||
"[kb_search] all connectors gathered in %.3fs",
|
||||
time.perf_counter() - t_gather,
|
||||
)
|
||||
for chunks in connector_results:
|
||||
all_documents.extend(chunks)
|
||||
|
|
@ -576,7 +584,11 @@ async def search_knowledge_base_async(
|
|||
result = format_documents_for_context(deduplicated, max_chars=output_budget)
|
||||
perf.info(
|
||||
"[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d space=%d",
|
||||
time.perf_counter() - t0, len(all_documents), len(deduplicated), len(result), search_space_id,
|
||||
time.perf_counter() - t0,
|
||||
len(all_documents),
|
||||
len(deduplicated),
|
||||
len(result),
|
||||
search_space_id,
|
||||
)
|
||||
return result
|
||||
|
||||
|
|
|
|||
|
|
@ -147,7 +147,9 @@ class IndexingPipelineService:
|
|||
await self.session.commit()
|
||||
perf.info(
|
||||
"[indexing] prepare_for_indexing in %.3fs input=%d output=%d",
|
||||
time.perf_counter() - t0, len(connector_docs), len(documents),
|
||||
time.perf_counter() - t0,
|
||||
len(connector_docs),
|
||||
len(documents),
|
||||
)
|
||||
return documents
|
||||
except IntegrityError:
|
||||
|
|
@ -185,7 +187,8 @@ class IndexingPipelineService:
|
|||
)
|
||||
perf.info(
|
||||
"[indexing] summarize_document doc=%d in %.3fs",
|
||||
document.id, time.perf_counter() - t_step,
|
||||
document.id,
|
||||
time.perf_counter() - t_step,
|
||||
)
|
||||
elif connector_doc.should_summarize and connector_doc.fallback_summary:
|
||||
content = connector_doc.fallback_summary
|
||||
|
|
@ -196,7 +199,8 @@ class IndexingPipelineService:
|
|||
embedding = embed_text(content)
|
||||
perf.debug(
|
||||
"[indexing] embed_text (summary) doc=%d in %.3fs",
|
||||
document.id, time.perf_counter() - t_step,
|
||||
document.id,
|
||||
time.perf_counter() - t_step,
|
||||
)
|
||||
|
||||
await self.session.execute(
|
||||
|
|
@ -213,7 +217,9 @@ class IndexingPipelineService:
|
|||
]
|
||||
perf.info(
|
||||
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
|
||||
document.id, len(chunks), time.perf_counter() - t_step,
|
||||
document.id,
|
||||
len(chunks),
|
||||
time.perf_counter() - t_step,
|
||||
)
|
||||
|
||||
document.content = content
|
||||
|
|
@ -224,7 +230,9 @@ class IndexingPipelineService:
|
|||
await self.session.commit()
|
||||
perf.info(
|
||||
"[indexing] index TOTAL doc=%d chunks=%d in %.3fs",
|
||||
document.id, len(chunks), time.perf_counter() - t_index,
|
||||
document.id,
|
||||
len(chunks),
|
||||
time.perf_counter() - t_index,
|
||||
)
|
||||
log_index_success(ctx, chunk_count=len(chunks))
|
||||
|
||||
|
|
|
|||
|
|
@ -76,7 +76,10 @@ class ChucksHybridSearchRetriever:
|
|||
chunks = result.scalars().all()
|
||||
perf.info(
|
||||
"[chunk_search] vector_search DB query in %.3fs results=%d (total %.3fs) space=%d",
|
||||
time.perf_counter() - t_db, len(chunks), time.perf_counter() - t0, search_space_id,
|
||||
time.perf_counter() - t_db,
|
||||
len(chunks),
|
||||
time.perf_counter() - t0,
|
||||
search_space_id,
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
|
@ -139,7 +142,9 @@ class ChucksHybridSearchRetriever:
|
|||
chunks = result.scalars().all()
|
||||
perf.info(
|
||||
"[chunk_search] full_text_search in %.3fs results=%d space=%d",
|
||||
time.perf_counter() - t0, len(chunks), search_space_id,
|
||||
time.perf_counter() - t0,
|
||||
len(chunks),
|
||||
search_space_id,
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
|
@ -152,6 +157,7 @@ class ChucksHybridSearchRetriever:
|
|||
document_type: str | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
query_embedding: list | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Hybrid search that returns **documents** (not individual chunks).
|
||||
|
|
@ -166,6 +172,7 @@ class ChucksHybridSearchRetriever:
|
|||
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
|
||||
start_date: Optional start date for filtering documents by updated_at
|
||||
end_date: Optional end date for filtering documents by updated_at
|
||||
query_embedding: Pre-computed embedding vector. If None, will be computed here.
|
||||
|
||||
Returns:
|
||||
List of dictionaries containing document data and relevance scores. Each dict contains:
|
||||
|
|
@ -183,14 +190,14 @@ class ChucksHybridSearchRetriever:
|
|||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
t_embed = time.perf_counter()
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
perf.debug(
|
||||
"[chunk_search] hybrid_search embedding in %.3fs",
|
||||
time.perf_counter() - t_embed,
|
||||
)
|
||||
if query_embedding is None:
|
||||
embedding_model = config.embedding_model_instance
|
||||
t_embed = time.perf_counter()
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
perf.debug(
|
||||
"[chunk_search] hybrid_search embedding in %.3fs",
|
||||
time.perf_counter() - t_embed,
|
||||
)
|
||||
|
||||
# RRF constants
|
||||
k = 60
|
||||
|
|
@ -291,7 +298,10 @@ class ChucksHybridSearchRetriever:
|
|||
chunks_with_scores = result.all()
|
||||
perf.info(
|
||||
"[chunk_search] hybrid_search RRF query in %.3fs results=%d space=%d type=%s",
|
||||
time.perf_counter() - t_rrf, len(chunks_with_scores), search_space_id, document_type,
|
||||
time.perf_counter() - t_rrf,
|
||||
len(chunks_with_scores),
|
||||
search_space_id,
|
||||
document_type,
|
||||
)
|
||||
|
||||
# If no results were found, return an empty list
|
||||
|
|
@ -392,6 +402,9 @@ class ChucksHybridSearchRetriever:
|
|||
|
||||
perf.info(
|
||||
"[chunk_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s",
|
||||
time.perf_counter() - t0, len(final_docs), search_space_id, document_type,
|
||||
time.perf_counter() - t0,
|
||||
len(final_docs),
|
||||
search_space_id,
|
||||
document_type,
|
||||
)
|
||||
return final_docs
|
||||
|
|
|
|||
|
|
@ -71,7 +71,9 @@ class DocumentHybridSearchRetriever:
|
|||
documents = result.scalars().all()
|
||||
perf.info(
|
||||
"[doc_search] vector_search in %.3fs results=%d space=%d",
|
||||
time.perf_counter() - t0, len(documents), search_space_id,
|
||||
time.perf_counter() - t0,
|
||||
len(documents),
|
||||
search_space_id,
|
||||
)
|
||||
|
||||
return documents
|
||||
|
|
@ -133,7 +135,9 @@ class DocumentHybridSearchRetriever:
|
|||
documents = result.scalars().all()
|
||||
perf.info(
|
||||
"[doc_search] full_text_search in %.3fs results=%d space=%d",
|
||||
time.perf_counter() - t0, len(documents), search_space_id,
|
||||
time.perf_counter() - t0,
|
||||
len(documents),
|
||||
search_space_id,
|
||||
)
|
||||
|
||||
return documents
|
||||
|
|
@ -146,6 +150,7 @@ class DocumentHybridSearchRetriever:
|
|||
document_type: str | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
query_embedding: list | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Hybrid search that returns **documents** (not individual chunks).
|
||||
|
|
@ -160,7 +165,7 @@ class DocumentHybridSearchRetriever:
|
|||
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
|
||||
start_date: Optional start date for filtering documents by updated_at
|
||||
end_date: Optional end date for filtering documents by updated_at
|
||||
|
||||
query_embedding: Pre-computed embedding vector. If None, will be computed here.
|
||||
"""
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
|
@ -171,9 +176,9 @@ class DocumentHybridSearchRetriever:
|
|||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
if query_embedding is None:
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
||||
# RRF constants
|
||||
k = 60
|
||||
|
|
@ -325,6 +330,9 @@ class DocumentHybridSearchRetriever:
|
|||
|
||||
perf.info(
|
||||
"[doc_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s",
|
||||
time.perf_counter() - t0, len(final_docs), search_space_id, document_type,
|
||||
time.perf_counter() - t0,
|
||||
len(final_docs),
|
||||
search_space_id,
|
||||
document_type,
|
||||
)
|
||||
return final_docs
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.db import User, get_async_session
|
||||
from app.schemas.chat_comments import (
|
||||
CommentBatchRequest,
|
||||
CommentBatchResponse,
|
||||
CommentCreateRequest,
|
||||
CommentListResponse,
|
||||
CommentReplyResponse,
|
||||
|
|
@ -19,6 +21,7 @@ from app.services.chat_comments_service import (
|
|||
create_reply,
|
||||
delete_comment,
|
||||
get_comments_for_message,
|
||||
get_comments_for_messages_batch,
|
||||
get_user_mentions,
|
||||
update_comment,
|
||||
)
|
||||
|
|
@ -27,6 +30,16 @@ from app.users import current_active_user
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/messages/comments/batch", response_model=CommentBatchResponse)
async def batch_list_comments(
    request: CommentBatchRequest,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Return the comments of several messages in a single request.

    The message IDs come from the request body; the lookup itself is
    delegated to ``get_comments_for_messages_batch``.
    """
    batch = await get_comments_for_messages_batch(
        session=session,
        message_ids=request.message_ids,
        user=user,
    )
    return batch
|
||||
|
||||
|
||||
@router.get("/messages/{message_id}/comments", response_model=CommentListResponse)
|
||||
async def list_comments(
|
||||
message_id: int,
|
||||
|
|
|
|||
|
|
@ -87,6 +87,18 @@ class CommentListResponse(BaseModel):
|
|||
total_count: int
|
||||
|
||||
|
||||
class CommentBatchRequest(BaseModel):
    """Body of the batch comment lookup: the message IDs to fetch comments for."""

    # Between 1 and 200 IDs per request (enforced by the Field bounds).
    message_ids: list[int] = Field(..., min_length=1, max_length=200)
|
||||
|
||||
|
||||
class CommentBatchResponse(BaseModel):
    """Result of a batch comment lookup."""

    # Maps each message_id to the comment list for that message.
    comments_by_message: dict[int, CommentListResponse]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Mention Schemas
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from app.db import (
|
|||
)
|
||||
from app.schemas.chat_comments import (
|
||||
AuthorResponse,
|
||||
CommentBatchResponse,
|
||||
CommentListResponse,
|
||||
CommentReplyResponse,
|
||||
CommentResponse,
|
||||
|
|
@ -264,6 +265,146 @@ async def get_comments_for_message(
|
|||
)
|
||||
|
||||
|
||||
async def get_comments_for_messages_batch(
    session: AsyncSession,
    message_ids: list[int],
    user: User,
) -> CommentBatchResponse:
    """
    Batch-fetch comments for multiple messages with a fixed number of queries.

    Messages are loaded together with their thread, ``COMMENTS_READ`` is
    checked once per distinct search space, and all top-level comments are
    loaded with eager-loaded authors and replies.

    Requested IDs that do not match an existing message are NOT an error:
    they are returned with an empty comment list.

    Args:
        session: Database session used for every query.
        message_ids: IDs of the messages to load comments for. Duplicates are
            ignored; the response keys follow first-occurrence order.
        user: The requesting user; drives the permission checks and the
            ``can_edit`` / ``can_delete`` flags.

    Returns:
        CommentBatchResponse keyed by message ID, one entry per unique
        requested ID.

    Raises:
        Whatever ``check_permission`` raises when the user may not read
        comments in one of the involved search spaces.
    """
    if not message_ids:
        return CommentBatchResponse(comments_by_message={})

    # De-duplicate while keeping the caller's order (a plain set would make
    # the order of the response keys nondeterministic).
    unique_ids = list(dict.fromkeys(message_ids))

    result = await session.execute(
        select(NewChatMessage)
        .options(selectinload(NewChatMessage.thread))
        .filter(NewChatMessage.id.in_(unique_ids))
    )
    msg_map = {m.id: m for m in result.scalars().all()}

    # One permission check (and one permission lookup) per search space,
    # not per message.
    permissions_cache: dict[int, set] = {}
    for ss_id in {m.thread.search_space_id for m in msg_map.values()}:
        await check_permission(
            session,
            user,
            ss_id,
            Permission.COMMENTS_READ.value,
            "You don't have permission to read comments in this search space",
        )
        permissions_cache[ss_id] = await get_user_permissions(session, user.id, ss_id)

    # Only query comments of messages that were actually found, i.e. whose
    # search space went through the permission check above.
    top_level_comments = []
    if msg_map:
        result = await session.execute(
            select(ChatComment)
            .options(
                selectinload(ChatComment.author),
                selectinload(ChatComment.replies).selectinload(ChatComment.author),
            )
            .filter(
                ChatComment.message_id.in_(list(msg_map)),
                ChatComment.parent_id.is_(None),
            )
            .order_by(ChatComment.created_at)
        )
        top_level_comments = result.scalars().all()

    # Resolve every mentioned user in one lookup for the whole batch.
    all_mentioned_uuids: set[UUID] = set()
    for comment in top_level_comments:
        all_mentioned_uuids.update(parse_mentions(comment.content))
        for reply in comment.replies:
            all_mentioned_uuids.update(parse_mentions(reply.content))

    user_names = await get_user_names_for_mentions(session, all_mentioned_uuids)

    def _author_response(author):
        """Build an AuthorResponse from a loaded author, or None if absent."""
        if not author:
            return None
        return AuthorResponse(
            id=author.id,
            display_name=author.display_name,
            avatar_url=author.avatar_url,
            email=author.email,
        )

    def _is_own(item) -> bool:
        """Tell whether the requesting user authored the comment/reply."""
        return item.author_id == user.id if item.author_id else False

    def _reply_response(reply, can_delete_any: bool) -> CommentReplyResponse:
        """Serialize one reply, including the caller's edit/delete rights."""
        own = _is_own(reply)
        return CommentReplyResponse(
            id=reply.id,
            content=reply.content,
            content_rendered=render_mentions(reply.content, user_names),
            author=_author_response(reply.author),
            created_at=reply.created_at,
            updated_at=reply.updated_at,
            is_edited=reply.updated_at > reply.created_at,
            can_edit=own,
            can_delete=own or can_delete_any,
        )

    def _comment_response(comment, can_delete_any: bool) -> CommentResponse:
        """Serialize one top-level comment with its replies (oldest first)."""
        replies = [
            _reply_response(reply, can_delete_any)
            for reply in sorted(comment.replies, key=lambda r: r.created_at)
        ]
        own = _is_own(comment)
        return CommentResponse(
            id=comment.id,
            message_id=comment.message_id,
            content=comment.content,
            content_rendered=render_mentions(comment.content, user_names),
            author=_author_response(comment.author),
            created_at=comment.created_at,
            updated_at=comment.updated_at,
            is_edited=comment.updated_at > comment.created_at,
            can_edit=own,
            can_delete=own or can_delete_any,
            reply_count=len(replies),
            replies=replies,
        )

    # Group the (already created_at-ordered) comments by their message.
    comments_by_msg: dict[int, list[ChatComment]] = {}
    for comment in top_level_comments:
        comments_by_msg.setdefault(comment.message_id, []).append(comment)

    comments_by_message: dict[int, CommentListResponse] = {}
    for mid in unique_ids:
        msg = msg_map.get(mid)
        if msg is None:
            # Unknown message ID: report "no comments" instead of failing
            # the whole batch.
            comments_by_message[mid] = CommentListResponse(comments=[], total_count=0)
            continue

        user_perms = permissions_cache.get(msg.thread.search_space_id, set())
        can_delete_any = has_permission(user_perms, Permission.COMMENTS_DELETE.value)

        comment_responses = [
            _comment_response(comment, can_delete_any)
            for comment in comments_by_msg.get(mid, [])
        ]
        comments_by_message[mid] = CommentListResponse(
            comments=comment_responses,
            total_count=len(comment_responses),
        )

    return CommentBatchResponse(comments_by_message=comments_by_message)
|
||||
|
||||
|
||||
async def create_comment(
|
||||
session: AsyncSession,
|
||||
message_id: int,
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from app.db import (
|
|||
Document,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
async_session_maker,
|
||||
)
|
||||
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
from app.retriever.documents_hybrid_search import DocumentHybridSearchRetriever
|
||||
|
|
@ -248,6 +249,8 @@ class ConnectorService:
|
|||
Returns:
|
||||
List of combined and deduplicated document results
|
||||
"""
|
||||
from app.config import config
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
|
|
@ -257,39 +260,48 @@ class ConnectorService:
|
|||
# Get more results from each retriever for better fusion
|
||||
retriever_top_k = top_k * 2
|
||||
|
||||
# IMPORTANT:
|
||||
# These retrievers share the same AsyncSession. AsyncSession does not permit
|
||||
# concurrent awaits that require DB IO on the same session/connection.
|
||||
# Running these in parallel can raise:
|
||||
# "This session is provisioning a new connection; concurrent operations are not permitted"
|
||||
#
|
||||
# So we run them sequentially.
|
||||
t_chunk = time.perf_counter()
|
||||
chunk_results = await self.chunk_retriever.hybrid_search(
|
||||
query_text=query_text,
|
||||
top_k=retriever_top_k,
|
||||
search_space_id=search_space_id,
|
||||
document_type=document_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
# Pre-compute the embedding once so both retrievers reuse it.
|
||||
t_embed = time.perf_counter()
|
||||
query_embedding = config.embedding_model_instance.embed(query_text)
|
||||
perf.info(
|
||||
"[connector_svc] _combined_rrf chunk_retriever in %.3fs results=%d type=%s",
|
||||
time.perf_counter() - t_chunk, len(chunk_results), document_type,
|
||||
"[connector_svc] _combined_rrf embedding in %.3fs type=%s",
|
||||
time.perf_counter() - t_embed,
|
||||
document_type,
|
||||
)
|
||||
|
||||
t_doc = time.perf_counter()
|
||||
doc_results = await self.document_retriever.hybrid_search(
|
||||
query_text=query_text,
|
||||
top_k=retriever_top_k,
|
||||
search_space_id=search_space_id,
|
||||
document_type=document_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
search_kwargs = {
|
||||
"query_text": query_text,
|
||||
"top_k": retriever_top_k,
|
||||
"search_space_id": search_space_id,
|
||||
"document_type": document_type,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"query_embedding": query_embedding,
|
||||
}
|
||||
|
||||
# Run chunk and document retrievers in parallel using separate DB sessions
|
||||
# so they don't contend on a shared AsyncSession connection.
|
||||
async def _run_chunk_search() -> list[dict[str, Any]]:
|
||||
async with async_session_maker() as session:
|
||||
retriever = ChucksHybridSearchRetriever(session)
|
||||
return await retriever.hybrid_search(**search_kwargs)
|
||||
|
||||
async def _run_doc_search() -> list[dict[str, Any]]:
|
||||
async with async_session_maker() as session:
|
||||
retriever = DocumentHybridSearchRetriever(session)
|
||||
return await retriever.hybrid_search(**search_kwargs)
|
||||
|
||||
t_parallel = time.perf_counter()
|
||||
chunk_results, doc_results = await asyncio.gather(
|
||||
_run_chunk_search(), _run_doc_search()
|
||||
)
|
||||
perf.info(
|
||||
"[connector_svc] _combined_rrf doc_retriever in %.3fs results=%d type=%s",
|
||||
time.perf_counter() - t_doc, len(doc_results), document_type,
|
||||
"[connector_svc] _combined_rrf parallel retrievers in %.3fs "
|
||||
"chunk_results=%d doc_results=%d type=%s",
|
||||
time.perf_counter() - t_parallel,
|
||||
len(chunk_results),
|
||||
len(doc_results),
|
||||
document_type,
|
||||
)
|
||||
|
||||
# Helper to extract document_id from our doc-grouped result
|
||||
|
|
@ -353,7 +365,10 @@ class ConnectorService:
|
|||
|
||||
perf.info(
|
||||
"[connector_svc] _combined_rrf_search TOTAL in %.3fs results=%d type=%s space=%d",
|
||||
time.perf_counter() - t0, len(combined_results), document_type, search_space_id,
|
||||
time.perf_counter() - t0,
|
||||
len(combined_results),
|
||||
document_type,
|
||||
search_space_id,
|
||||
)
|
||||
return combined_results
|
||||
|
||||
|
|
|
|||
|
|
@ -437,14 +437,16 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
except ContextWindowExceededError as e:
|
||||
perf.warning(
|
||||
"[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
|
||||
msg_count, time.perf_counter() - t0,
|
||||
msg_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
except LiteLLMBadRequestError as e:
|
||||
if _is_context_overflow_error(e):
|
||||
perf.warning(
|
||||
"[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
|
||||
msg_count, time.perf_counter() - t0,
|
||||
msg_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
raise
|
||||
|
|
@ -500,14 +502,16 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
except ContextWindowExceededError as e:
|
||||
perf.warning(
|
||||
"[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
|
||||
msg_count, time.perf_counter() - t0,
|
||||
msg_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
except LiteLLMBadRequestError as e:
|
||||
if _is_context_overflow_error(e):
|
||||
perf.warning(
|
||||
"[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
|
||||
msg_count, time.perf_counter() - t0,
|
||||
msg_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
raise
|
||||
|
|
@ -608,14 +612,16 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
except ContextWindowExceededError as e:
|
||||
perf.warning(
|
||||
"[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
|
||||
msg_count, time.perf_counter() - t0,
|
||||
msg_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
except LiteLLMBadRequestError as e:
|
||||
if _is_context_overflow_error(e):
|
||||
perf.warning(
|
||||
"[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
|
||||
msg_count, time.perf_counter() - t0,
|
||||
msg_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
raise
|
||||
|
|
@ -623,7 +629,8 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
t_first_chunk = time.perf_counter()
|
||||
perf.info(
|
||||
"[llm_router] _astream connection established msgs=%d in %.3fs",
|
||||
msg_count, t_first_chunk - t0,
|
||||
msg_count,
|
||||
t_first_chunk - t0,
|
||||
)
|
||||
|
||||
chunk_count = 0
|
||||
|
|
@ -645,7 +652,8 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
|
||||
perf.info(
|
||||
"[llm_router] _astream completed chunks=%d total=%.3fs",
|
||||
chunk_count, time.perf_counter() - t0,
|
||||
chunk_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
|
||||
def _convert_messages(self, messages: list[BaseMessage]) -> list[dict]:
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ import type { ThinkingStep } from "@/components/tool-ui/deepagent-thinking";
|
|||
import { Button } from "@/components/ui/button";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import type { Document } from "@/contracts/types/document.types";
|
||||
import { useBatchCommentsPreload } from "@/hooks/use-comments";
|
||||
import { useCommentsElectric } from "@/hooks/use-comments-electric";
|
||||
import { documentsApiService } from "@/lib/apis/documents-api.service";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
|
@ -309,6 +310,22 @@ const Composer: FC = () => {
|
|||
// Sync comments for the entire thread via Electric SQL (one subscription per thread)
|
||||
useCommentsElectric(threadId);
|
||||
|
||||
// Batch-prefetch comments for all assistant messages so individual useComments
|
||||
// hooks never fire their own network requests (eliminates N+1 API calls).
|
||||
// Return a primitive string from the selector so useSyncExternalStore can
|
||||
// compare snapshots by value and avoid infinite re-render loops.
|
||||
const assistantIdsKey = useAssistantState(({ thread }) =>
|
||||
thread.messages
|
||||
.filter((m) => m.role === "assistant" && m.id?.startsWith("msg-"))
|
||||
.map((m) => m.id!.replace("msg-", ""))
|
||||
.join(",")
|
||||
);
|
||||
const assistantDbMessageIds = useMemo(
|
||||
() => (assistantIdsKey ? assistantIdsKey.split(",").map(Number) : []),
|
||||
[assistantIdsKey]
|
||||
);
|
||||
useBatchCommentsPreload(assistantDbMessageIds);
|
||||
|
||||
// Auto-focus editor on new chat page after mount
|
||||
useEffect(() => {
|
||||
if (isThreadEmpty && !hasAutoFocusedRef.current && editorRef.current) {
|
||||
|
|
|
|||
|
|
@ -82,6 +82,22 @@ export const getCommentsResponse = z.object({
|
|||
total_count: z.number(),
|
||||
});
|
||||
|
||||
/**
 * Batch-fetch comments for multiple messages.
 *
 * Request: between 1 and 200 numeric message IDs.
 */
export const getBatchCommentsRequest = z.object({
	message_ids: z.array(z.number()).min(1).max(200),
});

/**
 * Comment list of a single message, as embedded in the batch response.
 */
export const commentListResponse = z.object({
	comments: z.array(comment),
	total_count: z.number(),
});

/**
 * Batch response: one comment list per message, keyed by the message ID
 * (object keys are strings on the wire).
 */
export const getBatchCommentsResponse = z.object({
	comments_by_message: z.record(z.string(), commentListResponse),
});
|
||||
|
||||
/**
|
||||
* Create comment
|
||||
*/
|
||||
|
|
@ -145,6 +161,8 @@ export type MentionComment = z.infer<typeof mentionComment>;
|
|||
export type Mention = z.infer<typeof mention>;
|
||||
export type GetCommentsRequest = z.infer<typeof getCommentsRequest>;
|
||||
export type GetCommentsResponse = z.infer<typeof getCommentsResponse>;
|
||||
export type GetBatchCommentsRequest = z.infer<typeof getBatchCommentsRequest>;
|
||||
export type GetBatchCommentsResponse = z.infer<typeof getBatchCommentsResponse>;
|
||||
export type CreateCommentRequest = z.infer<typeof createCommentRequest>;
|
||||
export type CreateCommentResponse = z.infer<typeof createCommentResponse>;
|
||||
export type CreateReplyRequest = z.infer<typeof createReplyRequest>;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import { useQuery } from "@tanstack/react-query";
|
||||
import { useQuery, useQueryClient } from "@tanstack/react-query";
|
||||
import { useEffect, useRef } from "react";
|
||||
import { chatCommentsApiService } from "@/lib/apis/chat-comments-api.service";
|
||||
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||
|
||||
|
|
@ -7,12 +8,84 @@ interface UseCommentsOptions {
|
|||
enabled?: boolean;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Module-level coordination: when a batch request is in-flight, individual
|
||||
// useComments queryFns piggy-back on it instead of making their own requests.
|
||||
// ---------------------------------------------------------------------------
|
||||
let _batchInflight: Promise<void> | null = null;
|
||||
let _batchTargetIds = new Set<number>();
|
||||
|
||||
/**
 * Query the comments of one message.
 *
 * If a batch preload covering this message is in flight, the query waits for
 * it and resolves from the seeded cache; otherwise (or if the batch left no
 * cache entry) it fetches the comments of this message on its own.
 */
export function useComments({ messageId, enabled = true }: UseCommentsOptions) {
	const queryClient = useQueryClient();
	const queryKey = cacheKeys.comments.byMessage(messageId);

	const fetchComments = async () => {
		// Defer by one macro-task so the batch prefetch effect (which sets
		// _batchInflight) can run before we decide whether to fetch.
		await new Promise<void>((resolve) => setTimeout(resolve, 0));

		const coveredByBatch = _batchInflight !== null && _batchTargetIds.has(messageId);
		if (coveredByBatch) {
			await _batchInflight;
			const seeded = queryClient.getQueryData(queryKey);
			if (seeded) return seeded;
		}

		return chatCommentsApiService.getComments({ message_id: messageId });
	};

	return useQuery({
		queryKey,
		queryFn: fetchComments,
		enabled: enabled && !!messageId,
		staleTime: 30_000,
	});
}
|
||||
|
||||
/**
 * Batch-fetch comments for all given message IDs in a single request, then
 * seed the per-message React Query cache so individual useComments hooks
 * resolve from cache instead of firing their own requests.
 *
 * The request is keyed by the sorted ID list and is not repeated while that
 * key is unchanged. If the effect is cleaned up before the request settles
 * (dependency change, unmount, or React StrictMode's mount/unmount/mount
 * cycle), the response is ignored and the key is released so the next run
 * with the same IDs fetches again.
 */
export function useBatchCommentsPreload(messageIds: number[]) {
	const queryClient = useQueryClient();
	const prevKeyRef = useRef<string>("");

	useEffect(() => {
		if (!messageIds.length) return;

		const key = messageIds
			.slice()
			.sort((a, b) => a - b)
			.join(",");
		if (key === prevKeyRef.current) return;
		prevKeyRef.current = key;

		_batchTargetIds = new Set(messageIds);
		let cancelled = false;
		let settled = false;

		const promise = chatCommentsApiService
			.getBatchComments({ message_ids: messageIds })
			.then((data) => {
				if (cancelled) return;
				for (const [msgIdStr, commentList] of Object.entries(data.comments_by_message)) {
					queryClient.setQueryData(cacheKeys.comments.byMessage(Number(msgIdStr)), commentList);
				}
			})
			.catch(() => {
				// Batch failed; individual queryFns will fall through to their own fetch
			})
			.finally(() => {
				settled = true;
				if (_batchInflight === promise) {
					_batchInflight = null;
					_batchTargetIds = new Set();
				}
			});

		_batchInflight = promise;

		return () => {
			cancelled = true;
			// The response of an unsettled request is discarded above, so the
			// cache was never seeded for this key. Release the key; otherwise a
			// re-run with the same IDs would bail out on the key check and the
			// batch would never be fetched.
			if (!settled && prevKeyRef.current === key) {
				prevKeyRef.current = "";
			}
			if (_batchInflight === promise) {
				_batchInflight = null;
				_batchTargetIds = new Set();
			}
		};
	}, [messageIds, queryClient]);
}
|
||||
|
|
|
|||
|
|
@ -8,8 +8,11 @@ import {
|
|||
type DeleteCommentRequest,
|
||||
deleteCommentRequest,
|
||||
deleteCommentResponse,
|
||||
type GetBatchCommentsRequest,
|
||||
type GetCommentsRequest,
|
||||
type GetMentionsRequest,
|
||||
getBatchCommentsRequest,
|
||||
getBatchCommentsResponse,
|
||||
getCommentsRequest,
|
||||
getCommentsResponse,
|
||||
getMentionsRequest,
|
||||
|
|
@ -22,6 +25,22 @@ import { ValidationError } from "@/lib/error";
|
|||
import { baseApiService } from "./base-api.service";
|
||||
|
||||
class ChatCommentsApiService {
|
||||
/**
 * Batch-fetch comments for multiple messages in one request.
 *
 * Throws a ValidationError when the request does not satisfy
 * getBatchCommentsRequest; otherwise POSTs the IDs to the batch endpoint and
 * parses the response with getBatchCommentsResponse.
 */
getBatchComments = async (request: GetBatchCommentsRequest) => {
	const validation = getBatchCommentsRequest.safeParse(request);

	if (validation.success) {
		const { message_ids } = validation.data;
		return baseApiService.post("/api/v1/messages/comments/batch", getBatchCommentsResponse, {
			body: { message_ids },
		});
	}

	const errorMessage = validation.error.issues.map((issue) => issue.message).join(", ");
	throw new ValidationError(`Invalid request: ${errorMessage}`);
};
|
||||
|
||||
/**
|
||||
* Get comments for a message
|
||||
*/
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue