This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-03-27 17:08:08 -07:00
commit 947def5c4a
65 changed files with 10278 additions and 7590 deletions

View file

@ -14,6 +14,20 @@ from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__)
def _is_date_only(value: str) -> bool:
"""Return True when *value* looks like a bare date (YYYY-MM-DD) with no time component."""
return len(value) <= 10 and "T" not in value
def _build_time_body(value: str, context: dict[str, Any] | Any) -> dict[str, str]:
"""Build a Google Calendar start/end body using ``date`` for all-day
events and ``dateTime`` for timed events."""
if _is_date_only(value):
return {"date": value}
tz = context.get("timezone", "UTC") if isinstance(context, dict) else "UTC"
return {"dateTime": value, "timeZone": tz}
def create_update_calendar_event_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
@ -255,25 +269,13 @@ def create_update_calendar_event_tool(
if final_new_summary is not None:
update_body["summary"] = final_new_summary
if final_new_start_datetime is not None:
tz = (
context.get("timezone", "UTC")
if isinstance(context, dict)
else "UTC"
update_body["start"] = _build_time_body(
final_new_start_datetime, context
)
update_body["start"] = {
"dateTime": final_new_start_datetime,
"timeZone": tz,
}
if final_new_end_datetime is not None:
tz = (
context.get("timezone", "UTC")
if isinstance(context, dict)
else "UTC"
update_body["end"] = _build_time_body(
final_new_end_datetime, context
)
update_body["end"] = {
"dateTime": final_new_end_datetime,
"timeZone": tz,
}
if final_new_description is not None:
update_body["description"] = final_new_description
if final_new_location is not None:

View file

@ -2,13 +2,14 @@
from .change_tracker import categorize_change, fetch_all_changes, get_start_page_token
from .client import GoogleDriveClient
from .content_extractor import download_and_process_file
from .content_extractor import download_and_extract_content, download_and_process_file
from .credentials import get_valid_credentials, validate_credentials
from .folder_manager import get_file_by_id, get_files_in_folder, list_folder_contents
__all__ = [
"GoogleDriveClient",
"categorize_change",
"download_and_extract_content",
"download_and_process_file",
"fetch_all_changes",
"get_file_by_id",

View file

@ -84,22 +84,50 @@ async def get_changes(
return [], None, f"Error getting changes: {e!s}"
async def _is_descendant_of(
client: GoogleDriveClient,
parent_ids: list[str],
target_folder_id: str,
max_depth: int = 20,
) -> bool:
"""Walk up the parent chain to check if any ancestor is *target_folder_id*."""
visited: set[str] = set()
to_check = list(parent_ids)
for _ in range(max_depth):
if not to_check:
return False
current = to_check.pop(0)
if current in visited:
continue
visited.add(current)
if current == target_folder_id:
return True
try:
service = await client.get_service()
meta = (
service.files()
.get(fileId=current, fields="parents", supportsAllDrives=True)
.execute()
)
grandparents = meta.get("parents", [])
to_check.extend(grandparents)
except Exception:
continue
return False
async def _filter_changes_by_folder(
client: GoogleDriveClient,
changes: list[dict[str, Any]],
folder_id: str,
) -> list[dict[str, Any]]:
"""
Filter changes to only include files within the specified folder.
Args:
client: GoogleDriveClient instance
changes: List of changes from API
folder_id: Folder ID to filter by
Returns:
Filtered list of changes
"""
"""Filter changes to only include files within the specified folder
(direct children or nested descendants)."""
filtered = []
for change in changes:
@ -108,14 +136,10 @@ async def _filter_changes_by_folder(
filtered.append(change)
continue
# Check if file is in the folder (or subfolder)
parents = file.get("parents", [])
if folder_id in parents:
filtered.append(change)
else:
# Check if any parent is a descendant of folder_id
# This is a simplified check - full implementation would traverse hierarchy
# For now, we'll include it and let indexer validate
elif await _is_descendant_of(client, parents, folder_id):
filtered.append(change)
return filtered

View file

@ -1,9 +1,15 @@
"""Google Drive API client."""
import asyncio
import io
import logging
import threading
import time
from typing import Any
import httplib2
from google.oauth2.credentials import Credentials
from google_auth_httplib2 import AuthorizedHttp
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseUpload
@ -12,6 +18,14 @@ from sqlalchemy.ext.asyncio import AsyncSession
from .credentials import get_valid_credentials
from .file_types import GOOGLE_DOC, GOOGLE_SHEET
logger = logging.getLogger(__name__)
def _build_thread_http(credentials: Credentials) -> AuthorizedHttp:
"""Create a per-thread HTTP transport so concurrent downloads don't share
the same ``httplib2.Http`` (which is not thread-safe)."""
return AuthorizedHttp(credentials, http=httplib2.Http())
class GoogleDriveClient:
"""Client for Google Drive API operations."""
@ -34,7 +48,9 @@ class GoogleDriveClient:
self.session = session
self.connector_id = connector_id
self._credentials = credentials
self._resolved_credentials: Credentials | None = None
self.service = None
self._service_lock = asyncio.Lock()
async def get_service(self):
"""
@ -49,6 +65,10 @@ class GoogleDriveClient:
if self.service:
return self.service
async with self._service_lock:
if self.service:
return self.service
try:
if self._credentials:
credentials = self._credentials
@ -56,6 +76,7 @@ class GoogleDriveClient:
credentials = await get_valid_credentials(
self.session, self.connector_id
)
self._resolved_credentials = credentials
self.service = build("drive", "v3", credentials=credentials)
return self.service
except Exception as e:
@ -134,6 +155,33 @@ class GoogleDriveClient:
except Exception as e:
return None, f"Error getting file metadata: {e!s}"
@staticmethod
def _sync_download_file(
service, file_id: str, credentials: Credentials,
) -> tuple[bytes | None, str | None]:
"""Blocking download — runs on a worker thread via ``to_thread``."""
thread = threading.current_thread().name
t0 = time.monotonic()
logger.info(f"[download] START file={file_id} thread={thread}")
try:
from googleapiclient.http import MediaIoBaseDownload
http = _build_thread_http(credentials)
request = service.files().get_media(fileId=file_id)
request.http = http
fh = io.BytesIO()
downloader = MediaIoBaseDownload(fh, request)
done = False
while not done:
_, done = downloader.next_chunk()
return fh.getvalue(), None
except HttpError as e:
return None, f"HTTP error downloading file: {e.resp.status}"
except Exception as e:
return None, f"Error downloading file: {e!s}"
finally:
logger.info(f"[download] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s")
async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]:
"""
Download binary file content.
@ -144,27 +192,76 @@ class GoogleDriveClient:
Returns:
Tuple of (file content bytes, error message)
"""
try:
service = await self.get_service()
request = service.files().get_media(fileId=file_id)
return await asyncio.to_thread(
self._sync_download_file, service, file_id, self._resolved_credentials,
)
import io
fh = io.BytesIO()
@staticmethod
def _sync_download_file_to_disk(
service, file_id: str, dest_path: str, chunksize: int,
credentials: Credentials,
) -> str | None:
"""Blocking download-to-disk — runs on a worker thread via ``to_thread``."""
thread = threading.current_thread().name
t0 = time.monotonic()
logger.info(f"[download-to-disk] START file={file_id} thread={thread}")
try:
from googleapiclient.http import MediaIoBaseDownload
downloader = MediaIoBaseDownload(fh, request)
http = _build_thread_http(credentials)
request = service.files().get_media(fileId=file_id)
request.http = http
with open(dest_path, "wb") as fh:
downloader = MediaIoBaseDownload(fh, request, chunksize=chunksize)
done = False
while not done:
_, done = downloader.next_chunk()
return fh.getvalue(), None
return None
except HttpError as e:
return None, f"HTTP error downloading file: {e.resp.status}"
return f"HTTP error downloading file: {e.resp.status}"
except Exception as e:
return None, f"Error downloading file: {e!s}"
return f"Error downloading file: {e!s}"
finally:
logger.info(f"[download-to-disk] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s")
async def download_file_to_disk(
self, file_id: str, dest_path: str, chunksize: int = 5 * 1024 * 1024,
) -> str | None:
"""Stream file directly to disk in chunks, avoiding full in-memory buffering.
Returns error message on failure, None on success.
"""
service = await self.get_service()
return await asyncio.to_thread(
self._sync_download_file_to_disk,
service, file_id, dest_path, chunksize, self._resolved_credentials,
)
@staticmethod
def _sync_export_google_file(
service, file_id: str, mime_type: str, credentials: Credentials,
) -> tuple[bytes | None, str | None]:
"""Blocking export — runs on a worker thread via ``to_thread``."""
thread = threading.current_thread().name
t0 = time.monotonic()
logger.info(f"[export] START file={file_id} thread={thread}")
try:
http = _build_thread_http(credentials)
content = (
service.files()
.export(fileId=file_id, mimeType=mime_type)
.execute(http=http)
)
if not isinstance(content, bytes):
content = content.encode("utf-8")
return content, None
except HttpError as e:
return None, f"HTTP error exporting file: {e.resp.status}"
except Exception as e:
return None, f"Error exporting file: {e!s}"
finally:
logger.info(f"[export] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s")
async def export_google_file(
self, file_id: str, mime_type: str
@ -179,24 +276,12 @@ class GoogleDriveClient:
Returns:
Tuple of (exported content as bytes, error message)
"""
try:
service = await self.get_service()
content = (
service.files().export(fileId=file_id, mimeType=mime_type).execute()
return await asyncio.to_thread(
self._sync_export_google_file, service, file_id, mime_type,
self._resolved_credentials,
)
# Content is already bytes from the API
# Keep as bytes to support both text and binary formats (like PDF)
if not isinstance(content, bytes):
content = content.encode("utf-8")
return content, None
except HttpError as e:
return None, f"HTTP error exporting file: {e.resp.status}"
except Exception as e:
return None, f"Error exporting file: {e!s}"
async def create_file(
self,
name: str,

View file

@ -1,8 +1,11 @@
"""Content extraction for Google Drive files."""
import asyncio
import logging
import os
import tempfile
import threading
import time
from pathlib import Path
from typing import Any
@ -12,11 +15,182 @@ from app.db import Log
from app.services.task_logging_service import TaskLoggingService
from .client import GoogleDriveClient
from .file_types import get_export_mime_type, is_google_workspace_file, should_skip_file
from .file_types import (
get_export_mime_type,
get_extension_from_mime,
is_google_workspace_file,
should_skip_file,
)
logger = logging.getLogger(__name__)
async def download_and_extract_content(
client: GoogleDriveClient,
file: dict[str, Any],
) -> tuple[str | None, dict[str, Any], str | None]:
"""Download a Google Drive file and extract its content as markdown.
ETL only -- no DB writes, no indexing, no summarization.
Returns:
(markdown_content, drive_metadata, error_message)
On success error_message is None.
"""
file_id = file.get("id")
file_name = file.get("name", "Unknown")
mime_type = file.get("mimeType", "")
if should_skip_file(mime_type):
return None, {}, f"Skipping {mime_type}"
logger.info(f"Downloading file for content extraction: {file_name} ({mime_type})")
drive_metadata: dict[str, Any] = {
"google_drive_file_id": file_id,
"google_drive_file_name": file_name,
"google_drive_mime_type": mime_type,
"source_connector": "google_drive",
}
if "modifiedTime" in file:
drive_metadata["modified_time"] = file["modifiedTime"]
if "createdTime" in file:
drive_metadata["created_time"] = file["createdTime"]
if "size" in file:
drive_metadata["file_size"] = file["size"]
if "webViewLink" in file:
drive_metadata["web_view_link"] = file["webViewLink"]
if "md5Checksum" in file:
drive_metadata["md5_checksum"] = file["md5Checksum"]
if is_google_workspace_file(mime_type):
export_ext = get_extension_from_mime(get_export_mime_type(mime_type) or "")
drive_metadata["exported_as"] = export_ext.lstrip(".") if export_ext else "pdf"
drive_metadata["original_workspace_type"] = mime_type.split(".")[-1]
temp_file_path = None
try:
if is_google_workspace_file(mime_type):
export_mime = get_export_mime_type(mime_type)
if not export_mime:
return None, drive_metadata, f"Cannot export Google Workspace type: {mime_type}"
content_bytes, error = await client.export_google_file(file_id, export_mime)
if error:
return None, drive_metadata, error
extension = get_extension_from_mime(export_mime) or ".pdf"
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp:
tmp.write(content_bytes)
temp_file_path = tmp.name
else:
extension = (
Path(file_name).suffix
or get_extension_from_mime(mime_type)
or ".bin"
)
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp:
temp_file_path = tmp.name
error = await client.download_file_to_disk(file_id, temp_file_path)
if error:
return None, drive_metadata, error
markdown = await _parse_file_to_markdown(temp_file_path, file_name)
return markdown, drive_metadata, None
except Exception as e:
logger.warning(f"Failed to extract content from {file_name}: {e!s}")
return None, drive_metadata, str(e)
finally:
if temp_file_path and os.path.exists(temp_file_path):
try:
os.unlink(temp_file_path)
except Exception:
pass
async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
"""Parse a local file to markdown using the configured ETL service."""
lower = filename.lower()
if lower.endswith((".md", ".markdown", ".txt")):
with open(file_path, encoding="utf-8") as f:
return f.read()
if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")):
from app.config import config as app_config
from litellm import atranscription
stt_service_type = (
"local"
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
else "external"
)
if stt_service_type == "local":
from app.services.stt_service import stt_service
t0 = time.monotonic()
logger.info(f"[local-stt] START file={filename} thread={threading.current_thread().name}")
result = await asyncio.to_thread(stt_service.transcribe_file, file_path)
logger.info(f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s")
text = result.get("text", "")
else:
with open(file_path, "rb") as audio_file:
kwargs: dict[str, Any] = {
"model": app_config.STT_SERVICE,
"file": audio_file,
"api_key": app_config.STT_SERVICE_API_KEY,
}
if app_config.STT_SERVICE_API_BASE:
kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
resp = await atranscription(**kwargs)
text = resp.get("text", "")
if not text:
raise ValueError("Transcription returned empty text")
return f"# Transcription of {filename}\n\n{text}"
# Document files -- use configured ETL service
from app.config import config as app_config
if app_config.ETL_SERVICE == "UNSTRUCTURED":
from langchain_unstructured import UnstructuredLoader
from app.utils.document_converters import convert_document_to_markdown
loader = UnstructuredLoader(
file_path,
mode="elements",
post_processors=[],
languages=["eng"],
include_orig_elements=False,
include_metadata=False,
strategy="auto",
)
docs = await loader.aload()
return await convert_document_to_markdown(docs)
if app_config.ETL_SERVICE == "LLAMACLOUD":
from app.tasks.document_processors.file_processors import (
parse_with_llamacloud_retry,
)
result = await parse_with_llamacloud_retry(file_path=file_path, estimated_pages=50)
markdown_documents = await result.aget_markdown_documents(split_by_page=False)
if not markdown_documents:
raise RuntimeError(f"LlamaCloud returned no documents for {filename}")
return markdown_documents[0].text
if app_config.ETL_SERVICE == "DOCLING":
from docling.document_converter import DocumentConverter
converter = DocumentConverter()
t0 = time.monotonic()
logger.info(f"[docling] START file={filename} thread={threading.current_thread().name}")
result = await asyncio.to_thread(converter.convert, file_path)
logger.info(f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s")
return result.document.export_to_markdown()
raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
async def download_and_process_file(
client: GoogleDriveClient,
file: dict[str, Any],
@ -68,14 +242,17 @@ async def download_and_process_file(
if error:
return None, error
extension = ".pdf" if export_mime == "application/pdf" else ".txt"
extension = get_extension_from_mime(export_mime) or ".pdf"
else:
content_bytes, error = await client.download_file(file_id)
if error:
return None, error
# Preserve original file extension
extension = Path(file_name).suffix or ".bin"
extension = (
Path(file_name).suffix
or get_extension_from_mime(mime_type)
or ".bin"
)
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp_file:
tmp_file.write(content_bytes)
@ -113,7 +290,12 @@ async def download_and_process_file(
connector_info["metadata"]["md5_checksum"] = file["md5Checksum"]
if is_google_workspace_file(mime_type):
connector_info["metadata"]["exported_as"] = "pdf"
export_ext = get_extension_from_mime(
get_export_mime_type(mime_type) or ""
)
connector_info["metadata"]["exported_as"] = (
export_ext.lstrip(".") if export_ext else "pdf"
)
connector_info["metadata"]["original_workspace_type"] = mime_type.split(
"."
)[-1]

View file

@ -7,11 +7,34 @@ GOOGLE_FOLDER = "application/vnd.google-apps.folder"
GOOGLE_SHORTCUT = "application/vnd.google-apps.shortcut"
EXPORT_FORMATS = {
GOOGLE_DOC: "application/pdf",
GOOGLE_SHEET: "application/pdf",
GOOGLE_DOC: "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
GOOGLE_SHEET: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
GOOGLE_SLIDE: "application/pdf",
}
MIME_TO_EXTENSION: dict[str, str] = {
"application/pdf": ".pdf",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
"application/vnd.ms-excel": ".xls",
"application/msword": ".doc",
"application/vnd.ms-powerpoint": ".ppt",
"text/plain": ".txt",
"text/csv": ".csv",
"text/html": ".html",
"text/markdown": ".md",
"application/json": ".json",
"application/xml": ".xml",
"image/png": ".png",
"image/jpeg": ".jpg",
}
def get_extension_from_mime(mime_type: str) -> str | None:
"""Return a file extension (with leading dot) for a MIME type, or None."""
return MIME_TO_EXTENSION.get(mime_type)
def is_google_workspace_file(mime_type: str) -> bool:
"""Check if file is a Google Workspace file that needs export."""

View file

@ -3,10 +3,17 @@ import hashlib
from app.indexing_pipeline.connector_document import ConnectorDocument
def compute_identifier_hash(
document_type_value: str, unique_id: str, search_space_id: int
) -> str:
"""Return a stable SHA-256 hash from raw identity components."""
combined = f"{document_type_value}:{unique_id}:{search_space_id}"
return hashlib.sha256(combined.encode("utf-8")).hexdigest()
def compute_unique_identifier_hash(doc: ConnectorDocument) -> str:
"""Return a stable SHA-256 hash identifying a document by its source identity."""
combined = f"{doc.document_type.value}:{doc.unique_id}:{doc.search_space_id}"
return hashlib.sha256(combined.encode("utf-8")).hexdigest()
return compute_identifier_hash(doc.document_type.value, doc.unique_id, doc.search_space_id)
def compute_content_hash(doc: ConnectorDocument) -> str:

View file

@ -1,17 +1,21 @@
import asyncio
import contextlib
import logging
import time
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime
from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Chunk, Document, DocumentStatus
from app.db import NATIVE_TO_LEGACY_DOCTYPE, Chunk, Document, DocumentStatus
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_chunker import chunk_text
from app.indexing_pipeline.document_embedder import embed_texts
from app.indexing_pipeline.document_hashing import (
compute_content_hash,
compute_identifier_hash,
compute_unique_identifier_hash,
)
from app.indexing_pipeline.document_persistence import (
@ -54,6 +58,62 @@ class IndexingPipelineService:
def __init__(self, session: AsyncSession) -> None:
self.session = session
async def migrate_legacy_docs(
self, connector_docs: list[ConnectorDocument]
) -> None:
"""Migrate legacy Composio documents to their native Google type.
For each ConnectorDocument whose document_type has a Composio equivalent
in NATIVE_TO_LEGACY_DOCTYPE, look up the old document by legacy hash and
update its unique_identifier_hash and document_type so that
prepare_for_indexing() can find it under the native hash.
"""
for doc in connector_docs:
legacy_type = NATIVE_TO_LEGACY_DOCTYPE.get(doc.document_type.value)
if not legacy_type:
continue
legacy_hash = compute_identifier_hash(
legacy_type, doc.unique_id, doc.search_space_id
)
result = await self.session.execute(
select(Document).filter(
Document.unique_identifier_hash == legacy_hash
)
)
existing = result.scalars().first()
if existing is None:
continue
native_hash = compute_identifier_hash(
doc.document_type.value, doc.unique_id, doc.search_space_id
)
existing.unique_identifier_hash = native_hash
existing.document_type = doc.document_type
await self.session.commit()
async def index_batch(
self, connector_docs: list[ConnectorDocument], llm
) -> list[Document]:
"""Convenience method: prepare_for_indexing then index each document.
Indexers that need heartbeat callbacks or custom per-document logic
should call prepare_for_indexing() + index() directly instead.
"""
doc_map = {
compute_unique_identifier_hash(cd): cd for cd in connector_docs
}
documents = await self.prepare_for_indexing(connector_docs)
results: list[Document] = []
for document in documents:
connector_doc = doc_map.get(document.unique_identifier_hash)
if connector_doc is None:
continue
result = await self.index(document, connector_doc, llm)
results.append(result)
return results
async def prepare_for_indexing(
self, connector_docs: list[ConnectorDocument]
) -> list[Document]:
@ -200,13 +260,14 @@ class IndexingPipelineService:
)
t_step = time.perf_counter()
chunk_texts = chunk_text(
chunk_texts = await asyncio.to_thread(
chunk_text,
connector_doc.source_markdown,
use_code_chunker=connector_doc.should_use_code_chunker,
)
texts_to_embed = [content, *chunk_texts]
embeddings = embed_texts(texts_to_embed)
embeddings = await asyncio.to_thread(embed_texts, texts_to_embed)
summary_embedding, *chunk_embeddings = embeddings
chunks = [
@ -268,3 +329,126 @@ class IndexingPipelineService:
await self.session.refresh(document)
return document
async def index_batch_parallel(
self,
connector_docs: list[ConnectorDocument],
get_llm: Callable[[AsyncSession], Awaitable],
*,
max_concurrency: int = 4,
on_heartbeat: Callable[[int], Awaitable[None]] | None = None,
heartbeat_interval: float = 30.0,
) -> tuple[list[Document], int, int]:
"""Index documents in parallel with bounded concurrency.
Phase 1 (serial): prepare_for_indexing using self.session.
Phase 2 (parallel): index each document in an isolated session,
bounded by a semaphore to avoid overwhelming APIs/DB.
"""
logger = logging.getLogger(__name__)
perf = get_perf_logger()
t_total = time.perf_counter()
doc_map = {
compute_unique_identifier_hash(cd): cd for cd in connector_docs
}
documents = await self.prepare_for_indexing(connector_docs)
if not documents:
return [], 0, 0
from app.tasks.celery_tasks import get_celery_session_maker
sem = asyncio.Semaphore(max_concurrency)
lock = asyncio.Lock()
indexed_count = 0
failed_count = 0
results: list[Document] = []
last_heartbeat = time.time()
async def _index_one(document: Document) -> Document | Exception:
nonlocal indexed_count, failed_count, last_heartbeat
connector_doc = doc_map.get(document.unique_identifier_hash)
if connector_doc is None:
logger.warning(
"No matching ConnectorDocument for document %s, skipping",
document.id,
)
async with lock:
failed_count += 1
return document
async with sem:
session_maker = get_celery_session_maker()
async with session_maker() as isolated_session:
try:
refetched = await isolated_session.get(
Document, document.id
)
if refetched is None:
async with lock:
failed_count += 1
return document
llm = await get_llm(isolated_session)
iso_pipeline = IndexingPipelineService(isolated_session)
result = await iso_pipeline.index(
refetched, connector_doc, llm
)
async with lock:
if DocumentStatus.is_state(
result.status, DocumentStatus.READY
):
indexed_count += 1
else:
failed_count += 1
if on_heartbeat:
now = time.time()
if now - last_heartbeat >= heartbeat_interval:
await on_heartbeat(indexed_count)
last_heartbeat = now
return result
except Exception as exc:
logger.error(
"Parallel index failed for doc %s: %s",
document.id,
exc,
exc_info=True,
)
async with lock:
failed_count += 1
return exc
tasks = [_index_one(doc) for doc in documents]
t_parallel = time.perf_counter()
outcomes = await asyncio.gather(*tasks, return_exceptions=True)
perf.info(
"[indexing] index_batch_parallel gather docs=%d concurrency=%d "
"indexed=%d failed=%d in %.3fs",
len(documents),
max_concurrency,
indexed_count,
failed_count,
time.perf_counter() - t_parallel,
)
for outcome in outcomes:
if isinstance(outcome, Document):
results.append(outcome)
elif isinstance(outcome, Exception):
pass
perf.info(
"[indexing] index_batch_parallel TOTAL input=%d prepared=%d "
"indexed=%d failed=%d in %.3fs",
len(connector_docs),
len(documents),
indexed_count,
failed_count,
time.perf_counter() - t_total,
)
return results, indexed_count, failed_count

View file

@ -1,19 +1,43 @@
"""
Editor routes for document editing with markdown (Plate.js frontend).
Includes multi-format export (PDF, DOCX, HTML, LaTeX, EPUB, ODT, plain text).
"""
import asyncio
import io
import logging
import os
import re
import tempfile
from datetime import UTC, datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException
import pypandoc
import typst
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.db import Document, DocumentType, Permission, User, get_async_session
from app.routes.reports_routes import (
ExportFormat,
_FILE_EXTENSIONS,
_MEDIA_TYPES,
_normalize_latex_delimiters,
_strip_wrapping_code_fences,
)
from app.templates.export_helpers import (
get_html_css_path,
get_reference_docx_path,
get_typst_template_path,
)
from app.users import current_active_user
from app.utils.rbac import check_permission
logger = logging.getLogger(__name__)
router = APIRouter()
@ -212,3 +236,153 @@ async def save_document(
"message": "Document saved and will be reindexed in the background",
"updated_at": document.updated_at.isoformat(),
}
@router.get(
"/search-spaces/{search_space_id}/documents/{document_id}/export"
)
async def export_document(
search_space_id: int,
document_id: int,
format: ExportFormat = Query(
ExportFormat.PDF,
description="Export format: pdf, docx, html, latex, epub, odt, or plain",
),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Export a document in the requested format (reuses the report export pipeline)."""
await check_permission(
session,
user,
search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space",
)
result = await session.execute(
select(Document)
.options(selectinload(Document.chunks))
.filter(
Document.id == document_id,
Document.search_space_id == search_space_id,
)
)
document = result.scalars().first()
if not document:
raise HTTPException(status_code=404, detail="Document not found")
# Resolve markdown content (same priority as editor-content endpoint)
markdown_content: str | None = document.source_markdown
if markdown_content is None and document.blocknote_document:
from app.utils.blocknote_to_markdown import blocknote_to_markdown
markdown_content = blocknote_to_markdown(document.blocknote_document)
if markdown_content is None:
chunks = sorted(document.chunks, key=lambda c: c.id)
if chunks:
markdown_content = "\n\n".join(chunk.content for chunk in chunks)
if not markdown_content or not markdown_content.strip():
raise HTTPException(
status_code=400, detail="Document has no content to export"
)
markdown_content = _strip_wrapping_code_fences(markdown_content)
markdown_content = _normalize_latex_delimiters(markdown_content)
doc_title = document.title or "Document"
formatted_date = (
document.created_at.strftime("%B %d, %Y") if document.created_at else ""
)
input_fmt = "gfm+tex_math_dollars"
meta_args = ["-M", f"title:{doc_title}", "-M", f"date:{formatted_date}"]
def _convert_and_read() -> bytes:
if format == ExportFormat.PDF:
typst_template = str(get_typst_template_path())
typst_markup: str = pypandoc.convert_text(
markdown_content,
"typst",
format=input_fmt,
extra_args=[
"--standalone",
f"--template={typst_template}",
"-V", "mainfont:Libertinus Serif",
"-V", "codefont:DejaVu Sans Mono",
*meta_args,
],
)
return typst.compile(typst_markup.encode("utf-8"))
if format == ExportFormat.DOCX:
return _pandoc_to_tempfile(
format.value,
["--standalone", f"--reference-doc={get_reference_docx_path()}", *meta_args],
)
if format == ExportFormat.HTML:
html_str: str = pypandoc.convert_text(
markdown_content,
"html5",
format=input_fmt,
extra_args=[
"--standalone", "--embed-resources",
f"--css={get_html_css_path()}",
"--syntax-highlighting=pygments",
*meta_args,
],
)
return html_str.encode("utf-8")
if format == ExportFormat.EPUB:
return _pandoc_to_tempfile("epub3", ["--standalone", *meta_args])
if format == ExportFormat.ODT:
return _pandoc_to_tempfile("odt", ["--standalone", *meta_args])
if format == ExportFormat.LATEX:
tex_str: str = pypandoc.convert_text(
markdown_content, "latex", format=input_fmt,
extra_args=["--standalone", *meta_args],
)
return tex_str.encode("utf-8")
plain_str: str = pypandoc.convert_text(
markdown_content, "plain", format=input_fmt,
extra_args=["--wrap=auto", "--columns=80"],
)
return plain_str.encode("utf-8")
def _pandoc_to_tempfile(output_format: str, extra_args: list[str]) -> bytes:
fd, tmp_path = tempfile.mkstemp(suffix=f".{output_format}")
os.close(fd)
try:
pypandoc.convert_text(
markdown_content, output_format, format=input_fmt,
extra_args=extra_args, outputfile=tmp_path,
)
with open(tmp_path, "rb") as f:
return f.read()
finally:
os.unlink(tmp_path)
try:
loop = asyncio.get_running_loop()
output = await loop.run_in_executor(None, _convert_and_read)
except Exception as e:
logger.exception("Document export failed")
raise HTTPException(status_code=500, detail=f"Export failed: {e!s}") from e
safe_title = (
"".join(c if c.isalnum() or c in " -_" else "_" for c in doc_title)
.strip()[:80]
or "document"
)
ext = _FILE_EXTENSIONS[format]
return StreamingResponse(
io.BytesIO(output),
media_type=_MEDIA_TYPES[format],
headers={"Content-Disposition": f'attachment; filename="{safe_title}.{ext}"'},
)

View file

@ -2329,7 +2329,7 @@ async def run_google_drive_indexing(
try:
from app.tasks.connector_indexers.google_drive_indexer import (
index_google_drive_files,
index_google_drive_single_file,
index_google_drive_selected_files,
)
# Parse the structured data
@ -2402,25 +2402,23 @@ async def run_google_drive_indexing(
exc_info=True,
)
# Index each individual file
for file in items.files:
# Index all selected files together via the parallel pipeline
if items.files:
try:
indexed_count, error_message = await index_google_drive_single_file(
file_tuples = [(f.id, f.name) for f in items.files]
indexed_count, _skipped, file_errors = await index_google_drive_selected_files(
session,
connector_id,
search_space_id,
user_id,
file_id=file.id,
file_name=file.name,
files=file_tuples,
)
if error_message:
errors.append(f"File '{file.name}': {error_message}")
else:
total_indexed += indexed_count
errors.extend(file_errors)
except Exception as e:
errors.append(f"File '{file.name}': {e!s}")
errors.append(f"File batch indexing: {e!s}")
logger.error(
f"Error indexing file {file.name} ({file.id}): {e}",
f"Error batch indexing files: {e}",
exc_info=True,
)

View file

@ -209,8 +209,8 @@ class GoogleCalendarKBSyncService:
)
calendar_id = (document.document_metadata or {}).get(
"calendar_id", "primary"
)
"calendar_id"
) or "primary"
live_event = await loop.run_in_executor(
None,
lambda: (

View file

@ -1,49 +1,74 @@
"""
Confluence connector indexer.
Provides real-time document status updates during indexing using a two-phase approach:
- Phase 1: Create all documents with PENDING status (visible in UI immediately)
- Phase 2: Process each document one by one (PENDING PROCESSING READY/FAILED)
"""
"""Confluence connector indexer using the unified parallel indexing pipeline."""
import contextlib
import time
from collections.abc import Awaitable, Callable
from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from .base import (
calculate_date_range,
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_connector_by_id,
get_current_timestamp,
logger,
safe_set_chunks,
update_connector_last_indexed,
)
# Type hint for heartbeat callback
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds
HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
page: dict,
full_content: str,
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
) -> ConnectorDocument:
"""Map a raw Confluence page dict to a ConnectorDocument."""
page_id = page.get("id", "")
page_title = page.get("title", "")
space_id = page.get("spaceId", "")
comment_count = len(page.get("comments", []))
metadata = {
"page_id": page_id,
"page_title": page_title,
"space_id": space_id,
"comment_count": comment_count,
"connector_id": connector_id,
"document_type": "Confluence Page",
"connector_type": "Confluence",
}
fallback_summary = (
f"Confluence Page: {page_title}\n\nSpace ID: {space_id}\n\n{full_content}"
)
return ConnectorDocument(
title=page_title,
source_markdown=full_content,
unique_id=page_id,
document_type=DocumentType.CONFLUENCE_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=enable_summary,
fallback_summary=fallback_summary,
metadata=metadata,
)
async def index_confluence_pages(
session: AsyncSession,
connector_id: int,
@ -53,26 +78,9 @@ async def index_confluence_pages(
end_date: str | None = None,
update_last_indexed: bool = True,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, str | None]:
"""
Index Confluence pages and comments.
Args:
session: Database session
connector_id: ID of the Confluence connector
search_space_id: ID of the search space to store documents in
user_id: User ID
start_date: Start date for indexing (YYYY-MM-DD format)
end_date: End date for indexing (YYYY-MM-DD format)
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
Returns:
Tuple containing (number of documents indexed, error message or None)
"""
) -> tuple[int, int, str | None]:
"""Index Confluence pages and comments."""
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="confluence_pages_indexing",
source="connector_indexing_task",
@ -86,7 +94,6 @@ async def index_confluence_pages(
)
try:
# Get the connector from the database
connector = await get_connector_by_id(
session, connector_id, SearchSourceConnectorType.CONFLUENCE_CONNECTOR
)
@ -98,9 +105,8 @@ async def index_confluence_pages(
"Connector not found",
{"error_type": "ConnectorNotFound"},
)
return 0, f"Connector with ID {connector_id} not found"
return 0, 0, f"Connector with ID {connector_id} not found"
# Initialize Confluence OAuth client
await task_logger.log_task_progress(
log_entry,
f"Initializing Confluence OAuth client for connector {connector_id}",
@ -114,7 +120,6 @@ async def index_confluence_pages(
)
)
# Calculate date range
start_date_str, end_date_str = calculate_date_range(
connector, start_date, end_date, default_days_back=365
)
@ -129,19 +134,14 @@ async def index_confluence_pages(
},
)
# Get pages within date range
try:
pages, error = await confluence_client.get_pages_by_date_range(
start_date=start_date_str, end_date=end_date_str, include_comments=True
)
if error:
# Don't treat "No pages found" as an error that should stop indexing
if "No pages found" in error:
logger.info(f"No Confluence pages found: {error}")
logger.info(
"No pages found is not a critical error, continuing with update"
)
if update_last_indexed:
await update_connector_last_indexed(
session, connector, update_last_indexed
@ -156,11 +156,10 @@ async def index_confluence_pages(
f"No Confluence pages found in date range {start_date_str} to {end_date_str}",
{"pages_found": 0},
)
# Close client before returning
if confluence_client:
with contextlib.suppress(Exception):
await confluence_client.close()
return 0, None
return 0, 0, None
else:
logger.error(f"Failed to get Confluence pages: {error}")
await task_logger.log_task_failure(
@ -169,36 +168,35 @@ async def index_confluence_pages(
"API Error",
{"error_type": "APIError"},
)
# Close client on error
if confluence_client:
with contextlib.suppress(Exception):
await confluence_client.close()
return 0, f"Failed to get Confluence pages: {error}"
return 0, 0, f"Failed to get Confluence pages: {error}"
logger.info(f"Retrieved {len(pages)} pages from Confluence API")
except Exception as e:
logger.error(f"Error fetching Confluence pages: {e!s}", exc_info=True)
# Close client on error
if confluence_client:
with contextlib.suppress(Exception):
await confluence_client.close()
return 0, f"Error fetching Confluence pages: {e!s}"
return 0, 0, f"Error fetching Confluence pages: {e!s}"
if not pages:
logger.info("No Confluence pages found for the specified date range")
if update_last_indexed:
await update_connector_last_indexed(
session, connector, update_last_indexed
)
await session.commit()
if confluence_client:
with contextlib.suppress(Exception):
await confluence_client.close()
return 0, 0, None
# =======================================================================
# PHASE 1: Analyze all pages, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
documents_indexed = 0
documents_skipped = 0
documents_failed = 0
duplicate_content_count = 0
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
pages_to_process = [] # List of dicts with document and page data
new_documents_created = False
connector_docs: list[ConnectorDocument] = []
for page in pages:
try:
@ -213,12 +211,10 @@ async def index_confluence_pages(
documents_skipped += 1
continue
# Extract page content
page_content = ""
if page.get("body") and page["body"].get("storage"):
page_content = page["body"]["storage"].get("value", "")
# Add comments to content
comments = page.get("comments", [])
comments_content = ""
if comments:
@ -235,61 +231,25 @@ async def index_confluence_pages(
comments_content += f"**Comment by {comment_author}** ({comment_date}):\n{comment_body}\n\n"
# Combine page content with comments
full_content = f"# {page_title}\n\n{page_content}{comments_content}"
if not full_content.strip():
if not page_content.strip() and not comments:
logger.warning(f"Skipping page with no content: {page_title}")
documents_skipped += 1
continue
# Generate unique identifier hash for this Confluence page
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.CONFLUENCE_CONNECTOR, page_id, search_space_id
doc = _build_connector_doc(
page,
full_content,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=connector.enable_summary,
)
# Generate content hash
content_hash = generate_content_hash(full_content, search_space_id)
# Check if document with this unique identifier already exists
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
comment_count = len(comments)
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
pages_to_process.append(
{
"document": existing_document,
"is_new": False,
"full_content": full_content,
"page_content": page_content,
"content_hash": content_hash,
"page_id": page_id,
"page_title": page_title,
"space_id": space_id,
"comment_count": comment_count,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash(
session, content_hash
session, compute_content_hash(doc)
)
if duplicate_by_content:
@ -302,151 +262,29 @@ async def index_confluence_pages(
documents_skipped += 1
continue
# Create new document with PENDING status (visible in UI immediately)
document = Document(
search_space_id=search_space_id,
title=page_title,
document_type=DocumentType.CONFLUENCE_CONNECTOR,
document_metadata={
"page_id": page_id,
"page_title": page_title,
"space_id": space_id,
"comment_count": comment_count,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
pages_to_process.append(
{
"document": document,
"is_new": True,
"full_content": full_content,
"page_content": page_content,
"content_hash": content_hash,
"page_id": page_id,
"page_title": page_title,
"space_id": space_id,
"comment_count": comment_count,
}
)
connector_docs.append(doc)
except Exception as e:
logger.error(f"Error in Phase 1 for page: {e!s}", exc_info=True)
documents_failed += 1
logger.error(f"Error building ConnectorDocument for page: {e!s}", exc_info=True)
documents_skipped += 1
continue
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([p for p in pages_to_process if p['is_new']])} pending documents"
)
await session.commit()
pipeline = IndexingPipelineService(session)
await pipeline.migrate_legacy_docs(connector_docs)
# =======================================================================
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(pages_to_process)} documents")
async def _get_llm(s: AsyncSession):
return await get_user_long_context_llm(s, user_id, search_space_id)
for item in pages_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
connector_docs,
_get_llm,
max_concurrency=3,
on_heartbeat=on_heartbeat_callback,
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
)
if user_llm and connector.enable_summary:
document_metadata = {
"page_title": item["page_title"],
"page_id": item["page_id"],
"space_id": item["space_id"],
"comment_count": item["comment_count"],
"document_type": "Confluence Page",
"connector_type": "Confluence",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["full_content"], user_llm, document_metadata
)
else:
summary_content = f"Confluence Page: {item['page_title']}\n\nSpace ID: {item['space_id']}\n\n{item['full_content']}"
summary_embedding = embed_text(summary_content)
# Process chunks - using the full page content with comments
chunks = await create_document_chunks(item["full_content"])
# Update document to READY with actual content
document.title = item["page_title"]
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"page_id": item["page_id"],
"page_title": item["page_title"],
"space_id": item["space_id"],
"comment_count": item["comment_count"],
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Confluence pages processed so far"
)
await session.commit()
except Exception as e:
logger.error(
f"Error processing page {item.get('page_title', 'Unknown')}: {e!s}",
exc_info=True,
)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
)
documents_failed += 1
continue # Skip this page and continue with others
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
# This ensures the UI shows "Last indexed" instead of "Never indexed"
await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit to ensure all documents are persisted (safety net)
logger.info(
f"Final commit: Total {documents_indexed} Confluence pages processed"
)
@ -456,7 +294,6 @@ async def index_confluence_pages(
"Successfully committed all Confluence document changes to database"
)
except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if (
"duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower()
@ -467,11 +304,9 @@ async def index_confluence_pages(
f"Rolling back and continuing. Error: {e!s}"
)
await session.rollback()
# Don't fail the entire task - some documents may have been successfully indexed
else:
raise
# Build warning message if there were issues
warning_parts = []
if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate")
@ -479,7 +314,6 @@ async def index_confluence_pages(
warning_parts.append(f"{documents_failed} failed")
warning_message = ", ".join(warning_parts) if warning_parts else None
# Log success
await task_logger.log_task_success(
log_entry,
f"Successfully completed Confluence indexing for connector {connector_id}",
@ -490,22 +324,19 @@ async def index_confluence_pages(
"duplicate_content_count": duplicate_content_count,
},
)
logger.info(
f"Confluence indexing completed: {documents_indexed} ready, "
f"{documents_skipped} skipped, {documents_failed} failed "
f"({duplicate_content_count} duplicate content)"
)
# Close the client connection
if confluence_client:
await confluence_client.close()
return documents_indexed, warning_message
return documents_indexed, documents_skipped, warning_message
except SQLAlchemyError as db_error:
await session.rollback()
# Close client if it exists
if confluence_client:
with contextlib.suppress(Exception):
await confluence_client.close()
@ -516,10 +347,9 @@ async def index_confluence_pages(
{"error_type": "SQLAlchemyError"},
)
logger.error(f"Database error: {db_error!s}", exc_info=True)
return 0, f"Database error: {db_error!s}"
return 0, 0, f"Database error: {db_error!s}"
except Exception as e:
await session.rollback()
# Close client if it exists
if confluence_client:
with contextlib.suppress(Exception):
await confluence_client.close()
@ -530,4 +360,4 @@ async def index_confluence_pages(
{"error_type": type(e).__name__},
)
logger.error(f"Failed to index Confluence pages: {e!s}", exc_info=True)
return 0, f"Failed to index Confluence pages: {e!s}"
return 0, 0, f"Failed to index Confluence pages: {e!s}"

View file

@ -1,12 +1,10 @@
"""
Google Calendar connector indexer.
Implements 2-phase document status updates for real-time UI feedback:
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
- Phase 2: Process each document: pending processing ready/failed
Uses the shared IndexingPipelineService for document deduplication,
summarization, chunking, and embedding.
"""
import time
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta
@ -15,29 +13,22 @@ from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.google_calendar_connector import GoogleCalendarConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from app.utils.google_credentials import (
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
build_composio_credentials,
)
from .base import (
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_connector_by_id,
get_current_timestamp,
logger,
parse_date_flexible,
safe_set_chunks,
update_connector_last_indexed,
)
@ -46,13 +37,60 @@ ACCEPTED_CALENDAR_CONNECTOR_TYPES = {
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
}
# Type hint for heartbeat callback
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds
HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
event: dict,
event_markdown: str,
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
) -> ConnectorDocument:
"""Map a raw Google Calendar API event dict to a ConnectorDocument."""
event_id = event.get("id", "")
event_summary = event.get("summary", "No Title")
calendar_id = event.get("calendarId", "")
start = event.get("start", {})
end = event.get("end", {})
start_time = start.get("dateTime") or start.get("date", "")
end_time = end.get("dateTime") or end.get("date", "")
location = event.get("location", "")
metadata = {
"event_id": event_id,
"event_summary": event_summary,
"calendar_id": calendar_id,
"start_time": start_time,
"end_time": end_time,
"location": location,
"connector_id": connector_id,
"document_type": "Google Calendar Event",
"connector_type": "Google Calendar",
}
fallback_summary = (
f"Google Calendar Event: {event_summary}\n\n{event_markdown}"
)
return ConnectorDocument(
title=event_summary,
source_markdown=event_markdown,
unique_id=event_id,
document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=enable_summary,
fallback_summary=fallback_summary,
metadata=metadata,
)
async def index_google_calendar_events(
session: AsyncSession,
connector_id: int,
@ -82,7 +120,6 @@ async def index_google_calendar_events(
"""
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="google_calendar_events_indexing",
source="connector_indexing_task",
@ -96,7 +133,7 @@ async def index_google_calendar_events(
)
try:
# Accept both native and Composio Calendar connectors
# ── Connector lookup ──────────────────────────────────────────
connector = None
for ct in ACCEPTED_CALENDAR_CONNECTOR_TYPES:
connector = await get_connector_by_id(session, connector_id, ct)
@ -112,7 +149,7 @@ async def index_google_calendar_events(
)
return 0, 0, f"Connector with ID {connector_id} not found"
# Build credentials based on connector type
# ── Credential building ───────────────────────────────────────
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
@ -184,6 +221,7 @@ async def index_google_calendar_events(
)
return 0, 0, "Google Calendar credentials not found in connector config"
# ── Calendar client init ──────────────────────────────────────
await task_logger.log_task_progress(
log_entry,
f"Initializing Google Calendar client for connector {connector_id}",
@ -203,36 +241,26 @@ async def index_google_calendar_events(
if end_date == "undefined" or end_date == "":
end_date = None
# Calculate date range
# For calendar connectors, allow future dates to index upcoming events
# ── Date range calculation ────────────────────────────────────
if start_date is None or end_date is None:
# Fall back to calculating dates based on last_indexed_at
# Default to today (users can manually select future dates if needed)
calculated_end_date = datetime.now()
# Use last_indexed_at as start date if available, otherwise use 30 days ago
if connector.last_indexed_at:
# Convert dates to be comparable (both timezone-naive)
last_indexed_naive = (
connector.last_indexed_at.replace(tzinfo=None)
if connector.last_indexed_at.tzinfo
else connector.last_indexed_at
)
# Allow future dates - use last_indexed_at as start date
calculated_start_date = last_indexed_naive
logger.info(
f"Using last_indexed_at ({calculated_start_date.strftime('%Y-%m-%d')}) as start date"
)
else:
calculated_start_date = datetime.now() - timedelta(
days=365
) # Use 365 days as default for calendar events (matches frontend)
calculated_start_date = datetime.now() - timedelta(days=365)
logger.info(
f"No last_indexed_at found, using {calculated_start_date.strftime('%Y-%m-%d')} (365 days ago) as start date"
)
# Use calculated dates if not provided
start_date_str = (
start_date if start_date else calculated_start_date.strftime("%Y-%m-%d")
)
@ -240,19 +268,14 @@ async def index_google_calendar_events(
end_date if end_date else calculated_end_date.strftime("%Y-%m-%d")
)
else:
# Use provided dates (including future dates)
start_date_str = start_date
end_date_str = end_date
# FIX: Ensure end_date is at least 1 day after start_date to avoid
# "start_date must be strictly before end_date" errors when dates are the same
# (e.g., when last_indexed_at is today)
if start_date_str == end_date_str:
logger.info(
f"Start date ({start_date_str}) equals end date ({end_date_str}), "
"adjusting end date to next day to ensure valid date range"
)
# Parse end_date and add 1 day
try:
end_dt = parse_date_flexible(end_date_str)
except ValueError:
@ -264,6 +287,7 @@ async def index_google_calendar_events(
end_date_str = end_dt.strftime("%Y-%m-%d")
logger.info(f"Adjusted end date to {end_date_str}")
# ── Fetch events ──────────────────────────────────────────────
await task_logger.log_task_progress(
log_entry,
f"Fetching Google Calendar events from {start_date_str} to {end_date_str}",
@ -274,27 +298,19 @@ async def index_google_calendar_events(
},
)
# Get events within date range from primary calendar
try:
events, error = await calendar_client.get_all_primary_calendar_events(
start_date=start_date_str, end_date=end_date_str
)
if error:
# Don't treat "No events found" as an error that should stop indexing
if "No events found" in error:
logger.info(f"No Google Calendar events found: {error}")
logger.info(
"No events found is not a critical error, continuing with update"
)
if update_last_indexed:
await update_connector_last_indexed(
session, connector, update_last_indexed
)
await session.commit()
logger.info(
f"Updated last_indexed_at to {connector.last_indexed_at} despite no events found"
)
await task_logger.log_task_success(
log_entry,
@ -304,7 +320,6 @@ async def index_google_calendar_events(
return 0, 0, None
else:
logger.error(f"Failed to get Google Calendar events: {error}")
# Check if this is an authentication error that requires re-authentication
error_message = error
error_type = "APIError"
if (
@ -329,28 +344,15 @@ async def index_google_calendar_events(
logger.error(f"Error fetching Google Calendar events: {e!s}", exc_info=True)
return 0, 0, f"Error fetching Google Calendar events: {e!s}"
documents_indexed = 0
# ── Build ConnectorDocuments ──────────────────────────────────
connector_docs: list[ConnectorDocument] = []
documents_skipped = 0
documents_failed = 0 # Track events that failed processing
duplicate_content_count = (
0 # Track events skipped due to duplicate content_hash
)
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
# =======================================================================
# PHASE 1: Analyze all events, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
events_to_process = [] # List of dicts with document and event data
new_documents_created = False
duplicate_content_count = 0
for event in events:
try:
event_id = event.get("id")
event_summary = event.get("summary", "No Title")
calendar_id = event.get("calendarId", "")
if not event_id:
logger.warning(f"Skipping event with missing ID: {event_summary}")
@ -363,246 +365,55 @@ async def index_google_calendar_events(
documents_skipped += 1
continue
start = event.get("start", {})
end = event.get("end", {})
start_time = start.get("dateTime") or start.get("date", "")
end_time = end.get("dateTime") or end.get("date", "")
location = event.get("location", "")
description = event.get("description", "")
# Generate unique identifier hash for this Google Calendar event
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.GOOGLE_CALENDAR_CONNECTOR, event_id, search_space_id
doc = _build_connector_doc(
event,
event_markdown,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=connector.enable_summary,
)
# Generate content hash
content_hash = generate_content_hash(event_markdown, search_space_id)
# Check if document with this unique identifier already exists
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
# Fallback: legacy Composio hash
if not existing_document:
legacy_hash = generate_unique_identifier_hash(
DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
event_id,
search_space_id,
)
existing_document = await check_document_by_unique_identifier(
session, legacy_hash
)
if existing_document:
existing_document.unique_identifier_hash = (
unique_identifier_hash
)
if (
existing_document.document_type
== DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
existing_document.document_type = (
DocumentType.GOOGLE_CALENDAR_CONNECTOR
)
logger.info(
f"Migrated legacy Composio Calendar document: {event_id}"
)
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
events_to_process.append(
{
"document": existing_document,
"is_new": False,
"event_markdown": event_markdown,
"content_hash": content_hash,
"event_id": event_id,
"event_summary": event_summary,
"calendar_id": calendar_id,
"start_time": start_time,
"end_time": end_time,
"location": location,
"description": description,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash(
session, content_hash
duplicate = await check_duplicate_document_by_hash(
session, compute_content_hash(doc)
)
if duplicate_by_content:
# A document with the same content already exists (likely from Composio connector)
if duplicate:
logger.info(
f"Event {event_summary} already indexed by another connector "
f"(existing document ID: {duplicate_by_content.id}, "
f"type: {duplicate_by_content.document_type}). Skipping to avoid duplicate content."
f"Event {doc.title} already indexed by another connector "
f"(existing document ID: {duplicate.id}, "
f"type: {duplicate.document_type}). Skipping."
)
duplicate_content_count += 1
documents_skipped += 1
continue
# Create new document with PENDING status (visible in UI immediately)
document = Document(
search_space_id=search_space_id,
title=event_summary,
document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
document_metadata={
"event_id": event_id,
"event_summary": event_summary,
"calendar_id": calendar_id,
"start_time": start_time,
"end_time": end_time,
"location": location,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
events_to_process.append(
{
"document": document,
"is_new": True,
"event_markdown": event_markdown,
"content_hash": content_hash,
"event_id": event_id,
"event_summary": event_summary,
"calendar_id": calendar_id,
"start_time": start_time,
"end_time": end_time,
"location": location,
"description": description,
}
)
connector_docs.append(doc)
except Exception as e:
logger.error(f"Error in Phase 1 for event: {e!s}", exc_info=True)
documents_failed += 1
logger.error(f"Error building ConnectorDocument for event: {e!s}", exc_info=True)
documents_skipped += 1
continue
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([e for e in events_to_process if e['is_new']])} pending documents"
)
await session.commit()
# ── Pipeline: migrate legacy docs + parallel index ─────────────
pipeline = IndexingPipelineService(session)
# =======================================================================
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(events_to_process)} documents")
await pipeline.migrate_legacy_docs(connector_docs)
for item in events_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
async def _get_llm(s):
return await get_user_long_context_llm(s, user_id, search_space_id)
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
connector_docs,
_get_llm,
max_concurrency=3,
on_heartbeat=on_heartbeat_callback,
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
)
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"event_id": item["event_id"],
"event_summary": item["event_summary"],
"calendar_id": item["calendar_id"],
"start_time": item["start_time"],
"end_time": item["end_time"],
"location": item["location"] or "No location",
"document_type": "Google Calendar Event",
"connector_type": "Google Calendar",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["event_markdown"], user_llm, document_metadata_for_summary
)
else:
summary_content = f"Google Calendar Event: {item['event_summary']}\n\n{item['event_markdown']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["event_markdown"])
# Update document to READY with actual content
document.title = item["event_summary"]
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"event_id": item["event_id"],
"event_summary": item["event_summary"],
"calendar_id": item["calendar_id"],
"start_time": item["start_time"],
"end_time": item["end_time"],
"location": item["location"],
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Google Calendar events processed so far"
)
await session.commit()
except Exception as e:
logger.error(f"Error processing Calendar event: {e!s}", exc_info=True)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
)
documents_failed += 1
continue
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
# ── Finalize ──────────────────────────────────────────────────
await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit for any remaining documents not yet committed in batches
logger.info(
f"Final commit: Total {documents_indexed} Google Calendar events processed"
)
@ -612,22 +423,18 @@ async def index_google_calendar_events(
"Successfully committed all Google Calendar document changes to database"
)
except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if (
"duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower()
):
logger.warning(
f"Duplicate content_hash detected during final commit. "
f"This may occur if the same event was indexed by multiple connectors. "
f"Rolling back and continuing. Error: {e!s}"
)
await session.rollback()
# Don't fail the entire task - some documents may have been successfully indexed
else:
raise
# Build warning message if there were issues
warning_parts = []
if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate")

View file

@ -1,12 +1,10 @@
"""
Google Gmail connector indexer.
Implements 2-phase document status updates for real-time UI feedback:
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
- Phase 2: Process each document: pending processing ready/failed
Uses the shared IndexingPipelineService for document deduplication,
summarization, chunking, and embedding.
"""
import time
from collections.abc import Awaitable, Callable
from datetime import datetime
@ -15,21 +13,12 @@ from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.google_gmail_connector import GoogleGmailConnector
from app.db import (
Document,
DocumentStatus,
DocumentType,
SearchSourceConnectorType,
)
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from app.utils.google_credentials import (
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
build_composio_credentials,
@ -37,12 +26,9 @@ from app.utils.google_credentials import (
from .base import (
calculate_date_range,
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_connector_by_id,
get_current_timestamp,
logger,
safe_set_chunks,
update_connector_last_indexed,
)
@ -51,13 +37,70 @@ ACCEPTED_GMAIL_CONNECTOR_TYPES = {
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
}
# Type hint for heartbeat callback
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds
HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
message: dict,
markdown_content: str,
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
) -> ConnectorDocument:
"""Map a raw Gmail API message dict to a ConnectorDocument."""
message_id = message.get("id", "")
thread_id = message.get("threadId", "")
payload = message.get("payload", {})
headers = payload.get("headers", [])
subject = "No Subject"
sender = "Unknown Sender"
date_str = "Unknown Date"
for header in headers:
name = header.get("name", "").lower()
value = header.get("value", "")
if name == "subject":
subject = value
elif name == "from":
sender = value
elif name == "date":
date_str = value
metadata = {
"message_id": message_id,
"thread_id": thread_id,
"subject": subject,
"sender": sender,
"date": date_str,
"connector_id": connector_id,
"document_type": "Gmail Message",
"connector_type": "Google Gmail",
}
fallback_summary = (
f"Google Gmail Message: {subject}\n\n"
f"From: {sender}\nDate: {date_str}\n\n"
f"{markdown_content}"
)
return ConnectorDocument(
title=subject,
source_markdown=markdown_content,
unique_id=message_id,
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=enable_summary,
fallback_summary=fallback_summary,
metadata=metadata,
)
async def index_google_gmail_messages(
session: AsyncSession,
connector_id: int,
@ -80,7 +123,7 @@ async def index_google_gmail_messages(
start_date: Start date for filtering messages (YYYY-MM-DD format)
end_date: End date for filtering messages (YYYY-MM-DD format)
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
max_messages: Maximum number of messages to fetch (default: 100)
max_messages: Maximum number of messages to fetch (default: 1000)
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
Returns:
@ -88,7 +131,6 @@ async def index_google_gmail_messages(
"""
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="google_gmail_messages_indexing",
source="connector_indexing_task",
@ -103,7 +145,7 @@ async def index_google_gmail_messages(
)
try:
# Accept both native and Composio Gmail connectors
# ── Connector lookup ──────────────────────────────────────────
connector = None
for ct in ACCEPTED_GMAIL_CONNECTOR_TYPES:
connector = await get_connector_by_id(session, connector_id, ct)
@ -117,7 +159,7 @@ async def index_google_gmail_messages(
)
return 0, 0, error_msg
# Build credentials based on connector type
# ── Credential building ───────────────────────────────────────
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
@ -189,6 +231,7 @@ async def index_google_gmail_messages(
)
return 0, 0, "Google gmail credentials not found in connector config"
# ── Gmail client init ─────────────────────────────────────────
await task_logger.log_task_progress(
log_entry,
f"Initializing Google gmail client for connector {connector_id}",
@ -199,14 +242,11 @@ async def index_google_gmail_messages(
credentials, session, user_id, connector_id
)
# Calculate date range using last_indexed_at if dates not provided
# This ensures Gmail uses the same date logic as other connectors
# (uses last_indexed_at → now, or 365 days back for first-time indexing)
calculated_start_date, calculated_end_date = calculate_date_range(
connector, start_date, end_date, default_days_back=365
)
# Fetch recent Google gmail messages
# ── Fetch messages ────────────────────────────────────────────
logger.info(
f"Fetching emails for connector {connector_id} "
f"from {calculated_start_date} to {calculated_end_date}"
@ -218,7 +258,6 @@ async def index_google_gmail_messages(
)
if error:
# Check if this is an authentication error that requires re-authentication
error_message = error
error_type = "APIError"
if (
@ -243,286 +282,74 @@ async def index_google_gmail_messages(
logger.info(f"Found {len(messages)} Google gmail messages to index")
documents_indexed = 0
# ── Build ConnectorDocuments ──────────────────────────────────
connector_docs: list[ConnectorDocument] = []
documents_skipped = 0
documents_failed = 0 # Track messages that failed processing
duplicate_content_count = (
0 # Track messages skipped due to duplicate content_hash
)
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
# =======================================================================
# PHASE 1: Analyze all messages, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
messages_to_process = [] # List of dicts with document and message data
new_documents_created = False
duplicate_content_count = 0
for message in messages:
try:
# Extract message information
message_id = message.get("id", "")
thread_id = message.get("threadId", "")
# Extract headers for subject and sender
payload = message.get("payload", {})
headers = payload.get("headers", [])
subject = "No Subject"
sender = "Unknown Sender"
date_str = "Unknown Date"
for header in headers:
name = header.get("name", "").lower()
value = header.get("value", "")
if name == "subject":
subject = value
elif name == "from":
sender = value
elif name == "date":
date_str = value
if not message_id:
logger.warning(f"Skipping message with missing ID: {subject}")
logger.warning("Skipping message with missing ID")
documents_skipped += 1
continue
# Format message to markdown
markdown_content = gmail_connector.format_message_to_markdown(message)
if not markdown_content.strip():
logger.warning(f"Skipping message with no content: {subject}")
logger.warning(f"Skipping message with no content: {message_id}")
documents_skipped += 1
continue
# Generate unique identifier hash for this Gmail message
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.GOOGLE_GMAIL_CONNECTOR, message_id, search_space_id
doc = _build_connector_doc(
message,
markdown_content,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=connector.enable_summary,
)
# Generate content hash
content_hash = generate_content_hash(markdown_content, search_space_id)
# Check if document with this unique identifier already exists
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
# Fallback: legacy Composio hash
if not existing_document:
legacy_hash = generate_unique_identifier_hash(
DocumentType.COMPOSIO_GMAIL_CONNECTOR,
message_id,
search_space_id,
)
existing_document = await check_document_by_unique_identifier(
session, legacy_hash
)
if existing_document:
existing_document.unique_identifier_hash = (
unique_identifier_hash
)
if (
existing_document.document_type
== DocumentType.COMPOSIO_GMAIL_CONNECTOR
):
existing_document.document_type = (
DocumentType.GOOGLE_GMAIL_CONNECTOR
)
logger.info(
f"Migrated legacy Composio Gmail document: {message_id}"
)
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
messages_to_process.append(
{
"document": existing_document,
"is_new": False,
"markdown_content": markdown_content,
"content_hash": content_hash,
"message_id": message_id,
"thread_id": thread_id,
"subject": subject,
"sender": sender,
"date_str": date_str,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash(
session, content_hash
duplicate = await check_duplicate_document_by_hash(
session, compute_content_hash(doc)
)
if duplicate_by_content:
if duplicate:
logger.info(
f"Gmail message {subject} already indexed by another connector "
f"(existing document ID: {duplicate_by_content.id}, "
f"type: {duplicate_by_content.document_type}). Skipping."
f"Gmail message {doc.title} already indexed by another connector "
f"(existing document ID: {duplicate.id}, "
f"type: {duplicate.document_type}). Skipping."
)
duplicate_content_count += 1
documents_skipped += 1
continue
# Create new document with PENDING status (visible in UI immediately)
document = Document(
search_space_id=search_space_id,
title=subject,
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
document_metadata={
"message_id": message_id,
"thread_id": thread_id,
"subject": subject,
"sender": sender,
"date": date_str,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
messages_to_process.append(
{
"document": document,
"is_new": True,
"markdown_content": markdown_content,
"content_hash": content_hash,
"message_id": message_id,
"thread_id": thread_id,
"subject": subject,
"sender": sender,
"date_str": date_str,
}
)
connector_docs.append(doc)
except Exception as e:
logger.error(f"Error in Phase 1 for message: {e!s}", exc_info=True)
documents_failed += 1
logger.error(f"Error building ConnectorDocument for message: {e!s}", exc_info=True)
documents_skipped += 1
continue
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([m for m in messages_to_process if m['is_new']])} pending documents"
)
await session.commit()
# ── Pipeline: migrate legacy docs + parallel index ─────────────
pipeline = IndexingPipelineService(session)
# =======================================================================
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(messages_to_process)} documents")
await pipeline.migrate_legacy_docs(connector_docs)
for item in messages_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
async def _get_llm(s):
return await get_user_long_context_llm(s, user_id, search_space_id)
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
connector_docs,
_get_llm,
max_concurrency=3,
on_heartbeat=on_heartbeat_callback,
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
)
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"message_id": item["message_id"],
"thread_id": item["thread_id"],
"subject": item["subject"],
"sender": item["sender"],
"date": item["date_str"],
"document_type": "Gmail Message",
"connector_type": "Google Gmail",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["markdown_content"],
user_llm,
document_metadata_for_summary,
)
else:
summary_content = f"Google Gmail Message: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}\n\n{item['markdown_content']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["markdown_content"])
# Update document to READY with actual content
document.title = item["subject"]
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"message_id": item["message_id"],
"thread_id": item["thread_id"],
"subject": item["subject"],
"sender": item["sender"],
"date": item["date_str"],
"connector_id": connector_id,
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Gmail messages processed so far"
)
await session.commit()
except Exception as e:
logger.error(f"Error processing Gmail message: {e!s}", exc_info=True)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
)
documents_failed += 1
continue
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
# ── Finalize ──────────────────────────────────────────────────
await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit for any remaining documents not yet committed in batches
logger.info(f"Final commit: Total {documents_indexed} Gmail messages processed")
try:
await session.commit()
@ -530,22 +357,18 @@ async def index_google_gmail_messages(
"Successfully committed all Google Gmail document changes to database"
)
except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if (
"duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower()
):
logger.warning(
f"Duplicate content_hash detected during final commit. "
f"This may occur if the same message was indexed by multiple connectors. "
f"Rolling back and continuing. Error: {e!s}"
)
await session.rollback()
# Don't fail the entire task - some documents may have been successfully indexed
else:
raise
# Build warning message if there were issues
warning_parts = []
if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate")
@ -555,7 +378,6 @@ async def index_google_gmail_messages(
total_processed = documents_indexed
# Log success
await task_logger.log_task_success(
log_entry,
f"Successfully completed Google Gmail indexing for connector {connector_id}",

View file

@ -1,49 +1,80 @@
"""
Jira connector indexer.
Provides real-time document status updates during indexing using a two-phase approach:
- Phase 1: Create all documents with PENDING status (visible in UI immediately)
- Phase 2: Process each document one by one (PENDING PROCESSING READY/FAILED)
"""
"""Jira connector indexer using the unified parallel indexing pipeline."""
import contextlib
import time
from collections.abc import Awaitable, Callable
from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.jira_history import JiraHistoryConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from .base import (
calculate_date_range,
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_connector_by_id,
get_current_timestamp,
logger,
safe_set_chunks,
update_connector_last_indexed,
)
# Type hint for heartbeat callback
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds - update notification every 30 seconds
HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
issue: dict,
formatted_issue: dict,
issue_content: str,
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
) -> ConnectorDocument:
"""Map a raw Jira issue dict to a ConnectorDocument."""
issue_id = issue.get("key", "")
issue_identifier = issue.get("key", "")
issue_title = issue.get("id", "")
state = formatted_issue.get("status", "Unknown")
priority = formatted_issue.get("priority", "Unknown")
comment_count = len(formatted_issue.get("comments", []))
metadata = {
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": state,
"priority": priority,
"comment_count": comment_count,
"connector_id": connector_id,
"document_type": "Jira Issue",
"connector_type": "Jira",
}
fallback_summary = (
f"Jira Issue {issue_identifier}: {issue_title}\n\n"
f"Status: {state}\n\n{issue_content}"
)
return ConnectorDocument(
title=f"{issue_identifier}: {issue_title}",
source_markdown=issue_content,
unique_id=issue_id,
document_type=DocumentType.JIRA_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=enable_summary,
fallback_summary=fallback_summary,
metadata=metadata,
)
async def index_jira_issues(
session: AsyncSession,
connector_id: int,
@ -53,26 +84,9 @@ async def index_jira_issues(
end_date: str | None = None,
update_last_indexed: bool = True,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, str | None]:
"""
Index Jira issues and comments.
Args:
session: Database session
connector_id: ID of the Jira connector
search_space_id: ID of the search space to store documents in
user_id: User ID
start_date: Start date for indexing (YYYY-MM-DD format)
end_date: End date for indexing (YYYY-MM-DD format)
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
Returns:
Tuple containing (number of documents indexed, error message or None)
"""
) -> tuple[int, int, str | None]:
"""Index Jira issues and comments."""
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="jira_issues_indexing",
source="connector_indexing_task",
@ -86,7 +100,6 @@ async def index_jira_issues(
)
try:
# Get the connector from the database
connector = await get_connector_by_id(
session, connector_id, SearchSourceConnectorType.JIRA_CONNECTOR
)
@ -98,24 +111,15 @@ async def index_jira_issues(
"Connector not found",
{"error_type": "ConnectorNotFound"},
)
return 0, f"Connector with ID {connector_id} not found"
return 0, 0, f"Connector with ID {connector_id} not found"
# Initialize Jira client with internal refresh capability
# Token refresh will happen automatically when needed
await task_logger.log_task_progress(
log_entry,
f"Initializing Jira client for connector {connector_id}",
{"stage": "client_initialization"},
)
logger.info(f"Initializing Jira client for connector {connector_id}")
# Create connector with session and connector_id for internal refresh
# Token refresh will happen automatically when needed
jira_client = JiraHistoryConnector(session=session, connector_id=connector_id)
# Calculate date range
# Handle "undefined" strings from frontend
if start_date == "undefined" or start_date == "":
start_date = None
if end_date == "undefined" or end_date == "":
@ -135,19 +139,14 @@ async def index_jira_issues(
},
)
# Get issues within date range
try:
issues, error = await jira_client.get_issues_by_date_range(
start_date=start_date_str, end_date=end_date_str, include_comments=True
)
if error:
# Don't treat "No issues found" as an error that should stop indexing
if "No issues found" in error:
logger.info(f"No Jira issues found: {error}")
logger.info(
"No issues found is not a critical error, continuing with update"
)
if update_last_indexed:
await update_connector_last_indexed(
session, connector, update_last_indexed
@ -162,7 +161,8 @@ async def index_jira_issues(
f"No Jira issues found in date range {start_date_str} to {end_date_str}",
{"issues_found": 0},
)
return 0, None
await jira_client.close()
return 0, 0, None
else:
logger.error(f"Failed to get Jira issues: {error}")
await task_logger.log_task_failure(
@ -171,29 +171,30 @@ async def index_jira_issues(
"API Error",
{"error_type": "APIError"},
)
return 0, f"Failed to get Jira issues: {error}"
await jira_client.close()
return 0, 0, f"Failed to get Jira issues: {error}"
logger.info(f"Retrieved {len(issues)} issues from Jira API")
except Exception as e:
logger.error(f"Error fetching Jira issues: {e!s}", exc_info=True)
return 0, f"Error fetching Jira issues: {e!s}"
await jira_client.close()
return 0, 0, f"Error fetching Jira issues: {e!s}"
# =======================================================================
# PHASE 1: Analyze all issues, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
documents_indexed = 0
if not issues:
logger.info("No Jira issues found for the specified date range")
if update_last_indexed:
await update_connector_last_indexed(
session, connector, update_last_indexed
)
await session.commit()
await jira_client.close()
return 0, 0, None
connector_docs: list[ConnectorDocument] = []
documents_skipped = 0
documents_failed = 0
duplicate_content_count = 0
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
issues_to_process = [] # List of dicts with document and issue data
new_documents_created = False
for issue in issues:
try:
issue_id = issue.get("key")
@ -207,10 +208,7 @@ async def index_jira_issues(
documents_skipped += 1
continue
# Format the issue for better readability
formatted_issue = jira_client.format_issue(issue)
# Convert to markdown
issue_content = jira_client.format_issue_to_markdown(formatted_issue)
if not issue_content:
@ -220,53 +218,19 @@ async def index_jira_issues(
documents_skipped += 1
continue
# Generate unique identifier hash for this Jira issue
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.JIRA_CONNECTOR, issue_id, search_space_id
doc = _build_connector_doc(
issue,
formatted_issue,
issue_content,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=connector.enable_summary,
)
# Generate content hash
content_hash = generate_content_hash(issue_content, search_space_id)
# Check if document with this unique identifier already exists
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
comment_count = len(formatted_issue.get("comments", []))
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
issues_to_process.append(
{
"document": existing_document,
"is_new": False,
"issue_content": issue_content,
"content_hash": content_hash,
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"formatted_issue": formatted_issue,
"comment_count": comment_count,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash(
session, content_hash
session, compute_content_hash(doc)
)
if duplicate_by_content:
@ -279,160 +243,37 @@ async def index_jira_issues(
documents_skipped += 1
continue
# Create new document with PENDING status (visible in UI immediately)
document = Document(
search_space_id=search_space_id,
title=f"{issue_identifier}: {issue_title}",
document_type=DocumentType.JIRA_CONNECTOR,
document_metadata={
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": formatted_issue.get("status", "Unknown"),
"comment_count": comment_count,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
issues_to_process.append(
{
"document": document,
"is_new": True,
"issue_content": issue_content,
"content_hash": content_hash,
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"formatted_issue": formatted_issue,
"comment_count": comment_count,
}
)
except Exception as e:
logger.error(f"Error in Phase 1 for issue: {e!s}", exc_info=True)
documents_failed += 1
continue
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([i for i in issues_to_process if i['is_new']])} pending documents"
)
await session.commit()
# =======================================================================
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(issues_to_process)} documents")
for item in issues_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm and connector.enable_summary:
document_metadata = {
"issue_key": item["issue_identifier"],
"issue_title": item["issue_title"],
"status": item["formatted_issue"].get("status", "Unknown"),
"priority": item["formatted_issue"].get("priority", "Unknown"),
"comment_count": item["comment_count"],
"document_type": "Jira Issue",
"connector_type": "Jira",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["issue_content"], user_llm, document_metadata
)
else:
summary_content = f"Jira Issue {item['issue_identifier']}: {item['issue_title']}\n\n{item['issue_content']}"
summary_embedding = embed_text(summary_content)
# Process chunks - using the full issue content with comments
chunks = await create_document_chunks(item["issue_content"])
# Update document to READY with actual content
document.title = f"{item['issue_identifier']}: {item['issue_title']}"
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"issue_id": item["issue_id"],
"issue_identifier": item["issue_identifier"],
"issue_title": item["issue_title"],
"state": item["formatted_issue"].get("status", "Unknown"),
"comment_count": item["comment_count"],
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Jira issues processed so far"
)
await session.commit()
connector_docs.append(doc)
except Exception as e:
logger.error(
f"Error processing issue {item.get('issue_identifier', 'Unknown')}: {e!s}",
f"Error building ConnectorDocument for issue {issue_identifier}: {e!s}",
exc_info=True,
)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
)
documents_failed += 1
continue # Skip this issue and continue with others
documents_skipped += 1
continue
pipeline = IndexingPipelineService(session)
await pipeline.migrate_legacy_docs(connector_docs)
async def _get_llm(s: AsyncSession):
return await get_user_long_context_llm(s, user_id, search_space_id)
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
connector_docs,
_get_llm,
max_concurrency=3,
on_heartbeat=on_heartbeat_callback,
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
)
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
# This ensures the UI shows "Last indexed" instead of "Never indexed"
await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit to ensure all documents are persisted (safety net)
logger.info(f"Final commit: Total {documents_indexed} Jira issues processed")
try:
await session.commit()
logger.info("Successfully committed all JIRA document changes to database")
except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if (
"duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower()
@ -447,7 +288,6 @@ async def index_jira_issues(
else:
raise
# Build warning message if there were issues
warning_parts = []
if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate")
@ -455,7 +295,6 @@ async def index_jira_issues(
warning_parts.append(f"{documents_failed} failed")
warning_message = ", ".join(warning_parts) if warning_parts else None
# Log success
await task_logger.log_task_success(
log_entry,
f"Successfully completed JIRA indexing for connector {connector_id}",
@ -466,17 +305,13 @@ async def index_jira_issues(
"duplicate_content_count": duplicate_content_count,
},
)
logger.info(
f"JIRA indexing completed: {documents_indexed} ready, "
f"{documents_skipped} skipped, {documents_failed} failed "
f"({duplicate_content_count} duplicate content)"
)
# Clean up the connector
await jira_client.close()
return documents_indexed, warning_message
return documents_indexed, documents_skipped, warning_message
except SQLAlchemyError as db_error:
await session.rollback()
@ -487,11 +322,10 @@ async def index_jira_issues(
{"error_type": "SQLAlchemyError"},
)
logger.error(f"Database error: {db_error!s}", exc_info=True)
# Clean up the connector in case of error
if "jira_client" in locals():
with contextlib.suppress(Exception):
await jira_client.close()
return 0, f"Database error: {db_error!s}"
return 0, 0, f"Database error: {db_error!s}"
except Exception as e:
await session.rollback()
await task_logger.log_task_failure(
@ -501,8 +335,7 @@ async def index_jira_issues(
{"error_type": type(e).__name__},
)
logger.error(f"Failed to index JIRA issues: {e!s}", exc_info=True)
# Clean up the connector in case of error
if "jira_client" in locals():
with contextlib.suppress(Exception):
await jira_client.close()
return 0, f"Failed to index JIRA issues: {e!s}"
return 0, 0, f"Failed to index JIRA issues: {e!s}"

View file

@ -1,48 +1,84 @@
"""
Linear connector indexer.
Implements 2-phase document status updates for real-time UI feedback:
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
- Phase 2: Process each document: pending processing ready/failed
Uses the shared IndexingPipelineService for document deduplication,
summarization, chunking, and embedding with bounded parallel indexing.
"""
import time
from collections.abc import Awaitable, Callable
from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.linear_connector import LinearConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from .base import (
calculate_date_range,
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_connector_by_id,
get_current_timestamp,
logger,
safe_set_chunks,
update_connector_last_indexed,
)
# Type hint for heartbeat callback
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds - update notification every 30 seconds
HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
issue: dict,
formatted_issue: dict,
issue_content: str,
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
) -> ConnectorDocument:
"""Map a raw Linear issue dict to a ConnectorDocument."""
issue_id = issue.get("id", "")
issue_identifier = issue.get("identifier", "")
issue_title = issue.get("title", "")
state = formatted_issue.get("state", "Unknown")
priority = formatted_issue.get("priority", "Unknown")
comment_count = len(formatted_issue.get("comments", []))
metadata = {
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": state,
"priority": priority,
"comment_count": comment_count,
"connector_id": connector_id,
"document_type": "Linear Issue",
"connector_type": "Linear",
}
fallback_summary = (
f"Linear Issue {issue_identifier}: {issue_title}\n\n"
f"Status: {state}\n\n{issue_content}"
)
return ConnectorDocument(
title=f"{issue_identifier}: {issue_title}",
source_markdown=issue_content,
unique_id=issue_id,
document_type=DocumentType.LINEAR_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=enable_summary,
fallback_summary=fallback_summary,
metadata=metadata,
)
async def index_linear_issues(
session: AsyncSession,
connector_id: int,
@ -52,26 +88,15 @@ async def index_linear_issues(
end_date: str | None = None,
update_last_indexed: bool = True,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, str | None]:
) -> tuple[int, int, str | None]:
"""
Index Linear issues and comments.
Args:
session: Database session
connector_id: ID of the Linear connector
search_space_id: ID of the search space to store documents in
user_id: ID of the user
start_date: Start date for indexing (YYYY-MM-DD format)
end_date: End date for indexing (YYYY-MM-DD format)
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
Returns:
Tuple containing (number of documents indexed, error message or None)
Tuple of (indexed_count, skipped_count, warning_or_error_message)
"""
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="linear_issues_indexing",
source="connector_indexing_task",
@ -85,7 +110,7 @@ async def index_linear_issues(
)
try:
# Get the connector
# ── Connector lookup ──────────────────────────────────────────
await task_logger.log_task_progress(
log_entry,
f"Retrieving Linear connector {connector_id} from database",
@ -104,11 +129,11 @@ async def index_linear_issues(
{"error_type": "ConnectorNotFound"},
)
return (
0,
0,
f"Connector with ID {connector_id} not found or is not a Linear connector",
)
# Check if access_token exists (support both new OAuth format and old API key format)
if not connector.config.get("access_token") and not connector.config.get(
"LINEAR_API_KEY"
):
@ -118,26 +143,22 @@ async def index_linear_issues(
"Missing Linear access token",
{"error_type": "MissingToken"},
)
return 0, "Linear access token not found in connector config"
return 0, 0, "Linear access token not found in connector config"
# Initialize Linear client with internal refresh capability
# ── Client init ───────────────────────────────────────────────
await task_logger.log_task_progress(
log_entry,
f"Initializing Linear client for connector {connector_id}",
{"stage": "client_initialization"},
)
# Create connector with session and connector_id for internal refresh
# Token refresh will happen automatically when needed
linear_client = LinearConnector(session=session, connector_id=connector_id)
# Handle 'undefined' string from frontend (treat as None)
if start_date == "undefined" or start_date == "":
start_date = None
if end_date == "undefined" or end_date == "":
end_date = None
# Calculate date range
start_date_str, end_date_str = calculate_date_range(
connector, start_date, end_date, default_days_back=365
)
@ -154,37 +175,34 @@ async def index_linear_issues(
},
)
# Get issues within date range
# ── Fetch issues ──────────────────────────────────────────────
try:
issues, error = await linear_client.get_issues_by_date_range(
start_date=start_date_str, end_date=end_date_str, include_comments=True
start_date=start_date_str,
end_date=end_date_str,
include_comments=True,
)
if error:
# Don't treat "No issues found" as an error that should stop indexing
if "No issues found" in error:
logger.info(f"No Linear issues found: {error}")
logger.info(
"No issues found is not a critical error, continuing with update"
)
if update_last_indexed:
await update_connector_last_indexed(
session, connector, update_last_indexed
)
await session.commit()
logger.info(
f"Updated last_indexed_at to {connector.last_indexed_at} despite no issues found"
)
return 0, None
return 0, 0, None
else:
logger.error(f"Failed to get Linear issues: {error}")
return 0, f"Failed to get Linear issues: {error}"
return 0, 0, f"Failed to get Linear issues: {error}"
logger.info(f"Retrieved {len(issues)} issues from Linear API")
except Exception as e:
logger.error(f"Exception when calling Linear API: {e!s}", exc_info=True)
return 0, f"Failed to get Linear issues: {e!s}"
logger.error(
f"Exception when calling Linear API: {e!s}", exc_info=True
)
return 0, 0, f"Failed to get Linear issues: {e!s}"
if not issues:
logger.info("No Linear issues found for the specified date range")
@ -193,19 +211,12 @@ async def index_linear_issues(
session, connector, update_last_indexed
)
await session.commit()
logger.info(
f"Updated last_indexed_at to {connector.last_indexed_at} despite no issues found"
)
return 0, None # Return None instead of error message when no issues found
return 0, 0, None
# Track the number of documents indexed
documents_indexed = 0
# ── Build ConnectorDocuments ──────────────────────────────────
connector_docs: list[ConnectorDocument] = []
documents_skipped = 0
documents_failed = 0 # Track issues that failed processing
skipped_issues = []
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
duplicate_content_count = 0
await task_logger.log_task_progress(
log_entry,
@ -213,13 +224,6 @@ async def index_linear_issues(
{"stage": "process_issues", "total_issues": len(issues)},
)
# =======================================================================
# PHASE 1: Analyze all issues, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
issues_to_process = [] # List of dicts with document and issue data
new_documents_created = False
for issue in issues:
try:
issue_id = issue.get("id", "")
@ -230,271 +234,102 @@ async def index_linear_issues(
logger.warning(
f"Skipping issue with missing ID or title: {issue_id or 'Unknown'}"
)
skipped_issues.append(
f"{issue_identifier or 'Unknown'} (missing data)"
)
documents_skipped += 1
continue
# Format the issue first to get well-structured data
formatted_issue = linear_client.format_issue(issue)
# Convert issue to markdown format
issue_content = linear_client.format_issue_to_markdown(formatted_issue)
issue_content = linear_client.format_issue_to_markdown(
formatted_issue
)
if not issue_content:
logger.warning(
f"Skipping issue with no content: {issue_identifier} - {issue_title}"
)
skipped_issues.append(f"{issue_identifier} (no content)")
documents_skipped += 1
continue
# Generate unique identifier hash for this Linear issue
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.LINEAR_CONNECTOR, issue_id, search_space_id
)
# Generate content hash
content_hash = generate_content_hash(issue_content, search_space_id)
# Check if document with this unique identifier already exists
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
state = formatted_issue.get("state", "Unknown")
description = formatted_issue.get("description", "")
comment_count = len(formatted_issue.get("comments", []))
priority = formatted_issue.get("priority", "Unknown")
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
logger.info(
f"Document for Linear issue {issue_identifier} unchanged. Skipping."
)
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
issues_to_process.append(
{
"document": existing_document,
"is_new": False,
"issue_content": issue_content,
"content_hash": content_hash,
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": state,
"description": description,
"comment_count": comment_count,
"priority": priority,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash(
session, content_hash
)
if duplicate_by_content:
logger.info(
f"Linear issue {issue_identifier} already indexed by another connector "
f"(existing document ID: {duplicate_by_content.id}, "
f"type: {duplicate_by_content.document_type}). Skipping."
)
documents_skipped += 1
continue
# Create new document with PENDING status (visible in UI immediately)
document = Document(
search_space_id=search_space_id,
title=f"{issue_identifier}: {issue_title}",
document_type=DocumentType.LINEAR_CONNECTOR,
document_metadata={
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": state,
"comment_count": comment_count,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
doc = _build_connector_doc(
issue,
formatted_issue,
issue_content,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
issues_to_process.append(
{
"document": document,
"is_new": True,
"issue_content": issue_content,
"content_hash": content_hash,
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": state,
"description": description,
"comment_count": comment_count,
"priority": priority,
}
search_space_id=search_space_id,
user_id=user_id,
enable_summary=connector.enable_summary,
)
except Exception as e:
logger.error(f"Error in Phase 1 for issue: {e!s}", exc_info=True)
documents_failed += 1
with session.no_autoflush:
duplicate = await check_duplicate_document_by_hash(
session, compute_content_hash(doc)
)
if duplicate:
logger.info(
f"Linear issue {doc.title} already indexed by another connector "
f"(existing document ID: {duplicate.id}, "
f"type: {duplicate.document_type}). Skipping."
)
duplicate_content_count += 1
documents_skipped += 1
continue
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([i for i in issues_to_process if i['is_new']])} pending documents"
)
await session.commit()
# =======================================================================
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(issues_to_process)} documents")
for item in issues_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"issue_id": item["issue_identifier"],
"issue_title": item["issue_title"],
"state": item["state"],
"priority": item["priority"],
"comment_count": item["comment_count"],
"document_type": "Linear Issue",
"connector_type": "Linear",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["issue_content"], user_llm, document_metadata_for_summary
)
else:
summary_content = f"Linear Issue {item['issue_identifier']}: {item['issue_title']}\n\nStatus: {item['state']}\n\n{item['issue_content']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["issue_content"])
# Update document to READY with actual content
document.title = f"{item['issue_identifier']}: {item['issue_title']}"
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"issue_id": item["issue_id"],
"issue_identifier": item["issue_identifier"],
"issue_title": item["issue_title"],
"state": item["state"],
"comment_count": item["comment_count"],
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Linear issues processed so far"
)
await session.commit()
connector_docs.append(doc)
except Exception as e:
logger.error(
f"Error processing issue {item.get('issue_identifier', 'Unknown')}: {e!s}",
f"Error building ConnectorDocument for issue: {e!s}",
exc_info=True,
)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
)
skipped_issues.append(
f"{item.get('issue_identifier', 'Unknown')} (processing error)"
)
documents_failed += 1
documents_skipped += 1
continue
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
# ── Pipeline: migrate legacy docs + parallel index ────────────
pipeline = IndexingPipelineService(session)
await pipeline.migrate_legacy_docs(connector_docs)
async def _get_llm(s):
return await get_user_long_context_llm(s, user_id, search_space_id)
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
connector_docs,
_get_llm,
max_concurrency=3,
on_heartbeat=on_heartbeat_callback,
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
)
# ── Finalize ──────────────────────────────────────────────────
await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit for any remaining documents not yet committed in batches
logger.info(f"Final commit: Total {documents_indexed} Linear issues processed")
logger.info(
f"Final commit: Total {documents_indexed} Linear issues processed"
)
try:
await session.commit()
logger.info(
"Successfully committed all Linear document changes to database"
)
except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if (
"duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower()
):
logger.warning(
f"Duplicate content_hash detected during final commit. "
f"This may occur if the same issue was indexed by multiple connectors. "
f"Rolling back and continuing. Error: {e!s}"
)
await session.rollback()
else:
raise
# Build warning message if there were issues
warning_parts = []
warning_parts: list[str] = []
if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate")
if documents_failed > 0:
warning_parts.append(f"{documents_failed} failed")
warning_message = ", ".join(warning_parts) if warning_parts else None
# Log success
await task_logger.log_task_success(
log_entry,
f"Successfully completed Linear indexing for connector {connector_id}",
@ -503,7 +338,7 @@ async def index_linear_issues(
"documents_indexed": documents_indexed,
"documents_skipped": documents_skipped,
"documents_failed": documents_failed,
"skipped_issues_count": len(skipped_issues),
"duplicate_content_count": duplicate_content_count,
},
)
@ -511,7 +346,7 @@ async def index_linear_issues(
f"Linear indexing completed: {documents_indexed} ready, "
f"{documents_skipped} skipped, {documents_failed} failed"
)
return documents_indexed, warning_message
return documents_indexed, documents_skipped, warning_message
except SQLAlchemyError as db_error:
await session.rollback()
@ -522,7 +357,7 @@ async def index_linear_issues(
{"error_type": "SQLAlchemyError"},
)
logger.error(f"Database error: {db_error!s}", exc_info=True)
return 0, f"Database error: {db_error!s}"
return 0, 0, f"Database error: {db_error!s}"
except Exception as e:
await session.rollback()
await task_logger.log_task_failure(
@ -532,4 +367,4 @@ async def index_linear_issues(
{"error_type": type(e).__name__},
)
logger.error(f"Failed to index Linear issues: {e!s}", exc_info=True)
return 0, f"Failed to index Linear issues: {e!s}"
return 0, 0, f"Failed to index Linear issues: {e!s}"

View file

@ -1,12 +1,10 @@
"""
Notion connector indexer.
Implements real-time document status updates using a two-phase approach:
- Phase 1: Create all documents with PENDING status (visible in UI immediately)
- Phase 2: Process each document one by one (pending processing ready/failed)
Uses the shared IndexingPipelineService for document deduplication,
summarization, chunking, and embedding with bounded parallel indexing.
"""
import time
from collections.abc import Awaitable, Callable
from datetime import datetime
@ -14,42 +12,64 @@ from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.notion_history import NotionHistoryConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from app.utils.notion_utils import process_blocks
from .base import (
build_document_metadata_string,
calculate_date_range,
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_connector_by_id,
get_current_timestamp,
logger,
safe_set_chunks,
update_connector_last_indexed,
)
# Type alias for retry callback
# Signature: async callback(retry_reason, attempt, max_attempts, wait_seconds) -> None
RetryCallbackType = Callable[[str, int, int, float], Awaitable[None]]
# Type alias for heartbeat callback
# Signature: async callback(indexed_count) -> None
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds - update notification every 30 seconds
HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
page: dict,
markdown_content: str,
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
) -> ConnectorDocument:
"""Map a raw Notion page dict to a ConnectorDocument."""
page_id = page.get("page_id", "")
page_title = page.get("title", f"Untitled page ({page_id})")
metadata = {
"page_title": page_title,
"page_id": page_id,
"connector_id": connector_id,
"document_type": "Notion Page",
"connector_type": "Notion",
}
fallback_summary = f"Notion Page: {page_title}\n\n{markdown_content}"
return ConnectorDocument(
title=page_title,
source_markdown=markdown_content,
unique_id=page_id,
document_type=DocumentType.NOTION_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=enable_summary,
fallback_summary=fallback_summary,
metadata=metadata,
)
async def index_notion_pages(
session: AsyncSession,
connector_id: int,
@ -60,30 +80,15 @@ async def index_notion_pages(
update_last_indexed: bool = True,
on_retry_callback: RetryCallbackType | None = None,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, str | None]:
) -> tuple[int, int, str | None]:
"""
Index Notion pages from all accessible pages.
Args:
session: Database session
connector_id: ID of the Notion connector
search_space_id: ID of the search space to store documents in
user_id: ID of the user
start_date: Start date for indexing (YYYY-MM-DD format)
end_date: End date for indexing (YYYY-MM-DD format)
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
on_retry_callback: Optional callback for retry progress notifications.
Signature: async callback(retry_reason, attempt, max_attempts, wait_seconds)
retry_reason is one of: 'rate_limit', 'server_error', 'timeout'
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
Called periodically with (indexed_count) to prevent task appearing stuck.
Returns:
Tuple containing (number of documents indexed, error message or None)
Tuple of (indexed_count, skipped_count, warning_or_error_message)
"""
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="notion_pages_indexing",
source="connector_indexing_task",
@ -97,7 +102,7 @@ async def index_notion_pages(
)
try:
# Get the connector
# ── Connector lookup ──────────────────────────────────────────
await task_logger.log_task_progress(
log_entry,
f"Retrieving Notion connector {connector_id} from database",
@ -116,11 +121,11 @@ async def index_notion_pages(
{"error_type": "ConnectorNotFound"},
)
return (
0,
0,
f"Connector with ID {connector_id} not found or is not a Notion connector",
)
# Check if access_token exists (support both new OAuth format and old integration token format)
if not connector.config.get("access_token") and not connector.config.get(
"NOTION_INTEGRATION_TOKEN"
):
@ -130,9 +135,9 @@ async def index_notion_pages(
"Missing Notion access token",
{"error_type": "MissingToken"},
)
return 0, "Notion access token not found in connector config"
return 0, 0, "Notion access token not found in connector config"
# Initialize Notion client with internal refresh capability
# ── Client init ───────────────────────────────────────────────
await task_logger.log_task_progress(
log_entry,
f"Initializing Notion client for connector {connector_id}",
@ -141,18 +146,15 @@ async def index_notion_pages(
logger.info(f"Initializing Notion client for connector {connector_id}")
# Handle 'undefined' string from frontend (treat as None)
if start_date == "undefined" or start_date == "":
start_date = None
if end_date == "undefined" or end_date == "":
end_date = None
# Calculate date range using the shared utility function
start_date_str, end_date_str = calculate_date_range(
connector, start_date, end_date, default_days_back=365
)
# Convert YYYY-MM-DD to ISO format for Notion API
start_date_iso = datetime.strptime(start_date_str, "%Y-%m-%d").strftime(
"%Y-%m-%dT%H:%M:%SZ"
)
@ -160,13 +162,10 @@ async def index_notion_pages(
"%Y-%m-%dT%H:%M:%SZ"
)
# Create connector with session and connector_id for internal refresh
# Token refresh will happen automatically when needed
notion_client = NotionHistoryConnector(
session=session, connector_id=connector_id
)
# Set retry callback if provided (for user notifications during rate limits)
if on_retry_callback:
notion_client.set_retry_callback(on_retry_callback)
@ -182,21 +181,19 @@ async def index_notion_pages(
},
)
# Get all pages
# ── Fetch pages ───────────────────────────────────────────────
try:
pages = await notion_client.get_all_pages(
start_date=start_date_iso, end_date=end_date_iso
)
logger.info(f"Found {len(pages)} Notion pages")
# Get count of pages that had unsupported content skipped
pages_with_skipped_content = notion_client.get_skipped_content_count()
if pages_with_skipped_content > 0:
logger.info(
f"{pages_with_skipped_content} pages had Notion AI content skipped (not available via API)"
)
# Check if using legacy integration token and log warning
if notion_client.is_using_legacy_token():
logger.warning(
f"Connector {connector_id} is using legacy integration token. "
@ -204,8 +201,6 @@ async def index_notion_pages(
)
except Exception as e:
error_str = str(e)
# Check if this is an unsupported block type error (transcription, ai_block, etc.)
# These are known Notion API limitations and should be logged as warnings, not errors
unsupported_block_errors = [
"transcription is not supported",
"ai_block is not supported",
@ -216,7 +211,6 @@ async def index_notion_pages(
)
if is_unsupported_block_error:
# Log as warning since this is a known Notion API limitation
logger.warning(
f"Notion API limitation for connector {connector_id}: {error_str}. "
"This is a known issue with Notion AI blocks (transcription, ai_block) "
@ -229,7 +223,6 @@ async def index_notion_pages(
{"error_type": "UnsupportedBlockType", "is_known_limitation": True},
)
else:
# Log as error for other failures
logger.error(
f"Error fetching Notion pages for connector {connector_id}: {error_str}",
exc_info=True,
@ -242,7 +235,7 @@ async def index_notion_pages(
)
await notion_client.close()
return 0, f"Failed to get Notion pages: {e!s}"
return 0, 0, f"Failed to get Notion pages: {e!s}"
if not pages:
await task_logger.log_task_success(
@ -252,21 +245,17 @@ async def index_notion_pages(
{"pages_found": 0},
)
logger.info("No Notion pages found to index")
# CRITICAL: Update timestamp even when no pages found so Zero syncs
await update_connector_last_indexed(session, connector, update_last_indexed)
await update_connector_last_indexed(
session, connector, update_last_indexed
)
await session.commit()
await notion_client.close()
return 0, None # Success with 0 pages, not an error
return 0, 0, None
# Track the number of documents indexed
documents_indexed = 0
# ── Build ConnectorDocuments ──────────────────────────────────
connector_docs: list[ConnectorDocument] = []
documents_skipped = 0
documents_failed = 0
duplicate_content_count = 0
skipped_pages = []
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
await task_logger.log_task_progress(
log_entry,
@ -274,13 +263,6 @@ async def index_notion_pages(
{"stage": "process_pages", "total_pages": len(pages)},
)
# =======================================================================
# PHASE 1: Analyze all pages, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
pages_to_process = [] # List of dicts with document and page data
new_documents_created = False
for page in pages:
try:
page_id = page.get("page_id")
@ -293,225 +275,71 @@ async def index_notion_pages(
if not page_content:
logger.info(f"No content found in page {page_title}. Skipping.")
skipped_pages.append(f"{page_title} (no content)")
documents_skipped += 1
continue
# Convert page content to markdown format
markdown_content = f"# Notion Page: {page_title}\n\n"
markdown_content += process_blocks(page_content)
# Format document metadata
metadata_sections = [
("METADATA", [f"PAGE_TITLE: {page_title}", f"PAGE_ID: {page_id}"]),
(
"CONTENT",
[
"FORMAT: markdown",
"TEXT_START",
markdown_content,
"TEXT_END",
],
),
]
# Build the document string
combined_document_string = build_document_metadata_string(
metadata_sections
if not markdown_content.strip():
logger.warning(
f"Skipping page with empty markdown: {page_title}"
)
# Generate unique identifier hash for this Notion page
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.NOTION_CONNECTOR, page_id, search_space_id
)
# Generate content hash
content_hash = generate_content_hash(
combined_document_string, search_space_id
)
# Check if document with this unique identifier already exists
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
pages_to_process.append(
{
"document": existing_document,
"is_new": False,
"markdown_content": markdown_content,
"content_hash": content_hash,
"page_id": page_id,
"page_title": page_title,
}
doc = _build_connector_doc(
page,
markdown_content,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=connector.enable_summary,
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash(
session, content_hash
duplicate = await check_duplicate_document_by_hash(
session, compute_content_hash(doc)
)
if duplicate_by_content:
if duplicate:
logger.info(
f"Notion page {page_title} already indexed by another connector "
f"(existing document ID: {duplicate_by_content.id}, "
f"type: {duplicate_by_content.document_type}). Skipping."
f"Notion page {doc.title} already indexed by another connector "
f"(existing document ID: {duplicate.id}, "
f"type: {duplicate.document_type}). Skipping."
)
duplicate_content_count += 1
documents_skipped += 1
continue
# Create new document with PENDING status (visible in UI immediately)
document = Document(
search_space_id=search_space_id,
title=page_title,
document_type=DocumentType.NOTION_CONNECTOR,
document_metadata={
"page_title": page_title,
"page_id": page_id,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
pages_to_process.append(
{
"document": document,
"is_new": True,
"markdown_content": markdown_content,
"content_hash": content_hash,
"page_id": page_id,
"page_title": page_title,
}
)
connector_docs.append(doc)
except Exception as e:
logger.error(f"Error in Phase 1 for page: {e!s}", exc_info=True)
documents_failed += 1
continue
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([p for p in pages_to_process if p['is_new']])} pending documents"
)
await session.commit()
# =======================================================================
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(pages_to_process)} documents")
for item in pages_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"page_title": item["page_title"],
"page_id": item["page_id"],
"document_type": "Notion Page",
"connector_type": "Notion",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["markdown_content"],
user_llm,
document_metadata_for_summary,
)
else:
summary_content = f"Notion Page: {item['page_title']}\n\n{item['markdown_content']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["markdown_content"])
# Update document to READY with actual content
document.title = item["page_title"]
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"page_title": item["page_title"],
"page_id": item["page_id"],
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Notion pages processed so far"
)
await session.commit()
except Exception as e:
logger.error(f"Error processing Notion page: {e!s}", exc_info=True)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
f"Error building ConnectorDocument for page: {e!s}",
exc_info=True,
)
skipped_pages.append(f"{item['page_title']} (processing error)")
documents_failed += 1
documents_skipped += 1
continue
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
# ── Pipeline: migrate legacy docs + parallel index ────────────
pipeline = IndexingPipelineService(session)
await pipeline.migrate_legacy_docs(connector_docs)
async def _get_llm(s):
return await get_user_long_context_llm(s, user_id, search_space_id)
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
connector_docs,
_get_llm,
max_concurrency=3,
on_heartbeat=on_heartbeat_callback,
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
)
# ── Finalize ──────────────────────────────────────────────────
await update_connector_last_indexed(session, connector, update_last_indexed)
total_processed = documents_indexed
# Final commit to ensure all documents are persisted (safety net)
logger.info(f"Final commit: Total {documents_indexed} documents processed")
try:
await session.commit()
@ -519,59 +347,53 @@ async def index_notion_pages(
"Successfully committed all Notion document changes to database"
)
except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if (
"duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower()
):
logger.warning(
f"Duplicate content_hash detected during final commit. "
f"This may occur if the same page was indexed by multiple connectors. "
f"Rolling back and continuing. Error: {e!s}"
)
await session.rollback()
# Don't fail the entire task - some documents may have been successfully indexed
else:
raise
# Get final count of pages with skipped Notion AI content
# ── Build warning / notification message ──────────────────────
pages_with_skipped_ai_content = notion_client.get_skipped_content_count()
# Build warning message if there were issues
warning_parts = []
warning_parts: list[str] = []
if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate")
if documents_failed > 0:
warning_parts.append(f"{documents_failed} failed")
warning_message = ", ".join(warning_parts) if warning_parts else None
# Prepare result message with user-friendly notification about skipped content
result_message = None
if skipped_pages:
result_message = f"Processed {total_processed} pages. Skipped {len(skipped_pages)} pages: {', '.join(skipped_pages)}"
else:
result_message = f"Processed {total_processed} pages."
# Add user-friendly message about skipped Notion AI content
notification_parts: list[str] = []
if pages_with_skipped_ai_content > 0:
result_message += (
" Audio transcriptions and AI summaries from Notion aren't accessible "
"via their API - all other content was saved."
notification_parts.append(
"Some Notion AI content couldn't be synced (API limitation)"
)
if notion_client.is_using_legacy_token():
notification_parts.append(
"Using legacy token. Reconnect with OAuth for better reliability."
)
if warning_parts:
notification_parts.append(", ".join(warning_parts))
user_notification_message = (
" ".join(notification_parts) if notification_parts else None
)
# Log success
await task_logger.log_task_success(
log_entry,
f"Successfully completed Notion indexing for connector {connector_id}",
{
"pages_processed": total_processed,
"pages_processed": documents_indexed,
"documents_indexed": documents_indexed,
"documents_skipped": documents_skipped,
"documents_failed": documents_failed,
"duplicate_content_count": duplicate_content_count,
"skipped_pages_count": len(skipped_pages),
"pages_with_skipped_ai_content": pages_with_skipped_ai_content,
"result_message": result_message,
},
)
@ -581,35 +403,9 @@ async def index_notion_pages(
f"({duplicate_content_count} duplicate content)"
)
# Clean up the async client
await notion_client.close()
# Build user-friendly notification messages
# This will be shown in the notification to inform users
notification_parts = []
if pages_with_skipped_ai_content > 0:
notification_parts.append(
"Some Notion AI content couldn't be synced (API limitation)"
)
if notion_client.is_using_legacy_token():
notification_parts.append(
"Using legacy token. Reconnect with OAuth for better reliability."
)
# Include warning message if there were issues
if warning_message:
notification_parts.append(warning_message)
user_notification_message = (
" ".join(notification_parts) if notification_parts else None
)
return (
total_processed,
user_notification_message,
)
return documents_indexed, documents_skipped, user_notification_message
except SQLAlchemyError as db_error:
await session.rollback()
@ -622,10 +418,9 @@ async def index_notion_pages(
logger.error(
f"Database error during Notion indexing: {db_error!s}", exc_info=True
)
# Clean up the async client in case of error
if "notion_client" in locals():
await notion_client.close()
return 0, f"Database error: {db_error!s}"
return 0, 0, f"Database error: {db_error!s}"
except Exception as e:
await session.rollback()
await task_logger.log_task_failure(
@ -635,7 +430,6 @@ async def index_notion_pages(
{"error_type": type(e).__name__},
)
logger.error(f"Failed to index Notion pages: {e!s}", exc_info=True)
# Clean up the async client in case of error
if "notion_client" in locals():
await notion_client.close()
return 0, f"Failed to index Notion pages: {e!s}"
return 0, 0, f"Failed to index Notion pages: {e!s}"

View file

@ -1,5 +1,6 @@
import hashlib
import logging
import threading
import warnings
import numpy as np
@ -11,6 +12,12 @@ from app.prompts import SUMMARY_PROMPT_TEMPLATE
logger = logging.getLogger(__name__)
# HuggingFace fast tokenizers (Rust-backed) are not thread-safe — concurrent
# access from multiple threads causes "RuntimeError: Already borrowed".
# This reentrant lock serialises tokenizer + embedding model access so that
# asyncio.to_thread calls from index_batch_parallel don't collide.
_embedding_lock = threading.RLock()
def _get_embedding_max_tokens() -> int:
"""Get the max token limit for the configured embedding model.
@ -36,6 +43,7 @@ def truncate_for_embedding(text: str) -> str:
if len(text) // 3 <= max_tokens:
return text
with _embedding_lock:
tokenizer = config.embedding_model_instance.get_tokenizer()
tokens = tokenizer.encode(text)
if len(tokens) <= max_tokens:
@ -52,6 +60,7 @@ def embed_text(text: str) -> np.ndarray:
"""Truncate text to fit and embed it. Drop-in replacement for
``config.embedding_model_instance.embed(text)`` that never exceeds the
model's context window."""
with _embedding_lock:
return config.embedding_model_instance.embed(truncate_for_embedding(text))
@ -66,6 +75,7 @@ def embed_texts(texts: list[str]) -> list[np.ndarray]:
"""
if not texts:
return []
with _embedding_lock:
truncated = [truncate_for_embedding(t) for t in texts]
if config.is_local_embedding_model:
return [config.embedding_model_instance.embed(t) for t in truncated]

View file

@ -0,0 +1,111 @@
"""Integration tests: Calendar indexer builds ConnectorDocuments that flow through the pipeline."""
import pytest
from sqlalchemy import select
from app.config import config as app_config
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
def _cal_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
return ConnectorDocument(
title=f"Event {unique_id}",
source_markdown=f"## Calendar Event\n\nDetails for {unique_id}",
unique_id=unique_id,
document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=True,
fallback_summary=f"Calendar: Event {unique_id}",
metadata={
"event_id": unique_id,
"start_time": "2025-01-15T10:00:00",
"end_time": "2025-01-15T11:00:00",
"document_type": "Google Calendar Event",
},
)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_calendar_pipeline_creates_ready_document(
db_session, db_search_space, db_connector, db_user, mocker
):
"""A Calendar ConnectorDocument flows through prepare + index to a READY document."""
space_id = db_search_space.id
doc = _cal_doc(
unique_id="evt-1",
search_space_id=space_id,
connector_id=db_connector.id,
user_id=str(db_user.id),
)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([doc])
assert len(prepared) == 1
await service.index(prepared[0], doc, llm=mocker.Mock())
result = await db_session.execute(
select(Document).filter(Document.search_space_id == space_id)
)
row = result.scalars().first()
assert row is not None
assert row.document_type == DocumentType.GOOGLE_CALENDAR_CONNECTOR
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_calendar_legacy_doc_migrated(
db_session, db_search_space, db_connector, db_user, mocker
):
"""A legacy Composio Calendar doc is migrated and reused."""
space_id = db_search_space.id
user_id = str(db_user.id)
evt_id = "evt-legacy-cal"
legacy_hash = compute_identifier_hash(
DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR.value, evt_id, space_id
)
legacy_doc = Document(
title="Old Calendar Event",
document_type=DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
content="old summary",
content_hash=f"ch-{legacy_hash[:12]}",
unique_identifier_hash=legacy_hash,
source_markdown="## Old event",
search_space_id=space_id,
created_by_id=user_id,
embedding=[0.1] * _EMBEDDING_DIM,
status={"state": "ready"},
)
db_session.add(legacy_doc)
await db_session.flush()
original_id = legacy_doc.id
connector_doc = _cal_doc(
unique_id=evt_id,
search_space_id=space_id,
connector_id=db_connector.id,
user_id=user_id,
)
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
result = await db_session.execute(select(Document).filter(Document.id == original_id))
row = result.scalars().first()
assert row.document_type == DocumentType.GOOGLE_CALENDAR_CONNECTOR
native_hash = compute_identifier_hash(
DocumentType.GOOGLE_CALENDAR_CONNECTOR.value, evt_id, space_id
)
assert row.unique_identifier_hash == native_hash

View file

@ -0,0 +1,170 @@
"""Integration tests: Drive indexer builds ConnectorDocuments that flow through the pipeline."""
import pytest
from sqlalchemy import select
from app.config import config as app_config
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
def _drive_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
return ConnectorDocument(
title=f"File {unique_id}.pdf",
source_markdown=f"## Document Content\n\nText from file {unique_id}",
unique_id=unique_id,
document_type=DocumentType.GOOGLE_DRIVE_FILE,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=True,
fallback_summary=f"File: {unique_id}.pdf",
metadata={
"google_drive_file_id": unique_id,
"google_drive_file_name": f"{unique_id}.pdf",
"document_type": "Google Drive File",
},
)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_drive_pipeline_creates_ready_document(
db_session, db_search_space, db_connector, db_user, mocker
):
"""A Drive ConnectorDocument flows through prepare + index to a READY document."""
space_id = db_search_space.id
doc = _drive_doc(
unique_id="file-abc",
search_space_id=space_id,
connector_id=db_connector.id,
user_id=str(db_user.id),
)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([doc])
assert len(prepared) == 1
await service.index(prepared[0], doc, llm=mocker.Mock())
result = await db_session.execute(
select(Document).filter(Document.search_space_id == space_id)
)
row = result.scalars().first()
assert row is not None
assert row.document_type == DocumentType.GOOGLE_DRIVE_FILE
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_drive_legacy_doc_migrated(
db_session, db_search_space, db_connector, db_user, mocker
):
"""A legacy Composio Drive doc is migrated and reused."""
space_id = db_search_space.id
user_id = str(db_user.id)
file_id = "file-legacy-drive"
legacy_hash = compute_identifier_hash(
DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR.value, file_id, space_id
)
legacy_doc = Document(
title="Old Drive File",
document_type=DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
content="old file summary",
content_hash=f"ch-{legacy_hash[:12]}",
unique_identifier_hash=legacy_hash,
source_markdown="## Old file content",
search_space_id=space_id,
created_by_id=user_id,
embedding=[0.1] * _EMBEDDING_DIM,
status={"state": "ready"},
)
db_session.add(legacy_doc)
await db_session.flush()
original_id = legacy_doc.id
connector_doc = _drive_doc(
unique_id=file_id,
search_space_id=space_id,
connector_id=db_connector.id,
user_id=user_id,
)
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
result = await db_session.execute(select(Document).filter(Document.id == original_id))
row = result.scalars().first()
assert row.document_type == DocumentType.GOOGLE_DRIVE_FILE
native_hash = compute_identifier_hash(
DocumentType.GOOGLE_DRIVE_FILE.value, file_id, space_id
)
assert row.unique_identifier_hash == native_hash
async def test_should_skip_file_skips_failed_document(
db_session, db_search_space, db_user,
):
"""A FAILED document with unchanged md5 must be skipped — user can manually retry via Quick Index."""
import importlib
import sys
import types
pkg = "app.tasks.connector_indexers"
stub = pkg not in sys.modules
if stub:
mod = types.ModuleType(pkg)
mod.__path__ = ["app/tasks/connector_indexers"]
mod.__package__ = pkg
sys.modules[pkg] = mod
try:
gdm = importlib.import_module(
"app.tasks.connector_indexers.google_drive_indexer"
)
_should_skip_file = gdm._should_skip_file
finally:
if stub:
sys.modules.pop(pkg, None)
space_id = db_search_space.id
file_id = "file-failed-drive"
md5 = "abc123deadbeef"
doc_hash = compute_identifier_hash(
DocumentType.GOOGLE_DRIVE_FILE.value, file_id, space_id
)
failed_doc = Document(
title="Failed File.pdf",
document_type=DocumentType.GOOGLE_DRIVE_FILE,
content="LLM rate limit exceeded",
content_hash=f"ch-{doc_hash[:12]}",
unique_identifier_hash=doc_hash,
source_markdown="## Real content",
search_space_id=space_id,
created_by_id=str(db_user.id),
embedding=[0.1] * _EMBEDDING_DIM,
status=DocumentStatus.failed("LLM rate limit exceeded"),
document_metadata={
"google_drive_file_id": file_id,
"google_drive_file_name": "Failed File.pdf",
"md5_checksum": md5,
},
)
db_session.add(failed_doc)
await db_session.flush()
incoming_file = {"id": file_id, "name": "Failed File.pdf", "mimeType": "application/pdf", "md5Checksum": md5}
should_skip, msg = await _should_skip_file(db_session, incoming_file, space_id)
assert should_skip, "FAILED documents must be skipped during automatic sync"
assert "failed" in msg.lower()

View file

@ -0,0 +1,116 @@
"""Integration tests: Gmail indexer builds ConnectorDocuments that flow through the pipeline."""
import pytest
from sqlalchemy import select
from app.config import config as app_config
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import (
compute_identifier_hash,
compute_unique_identifier_hash,
)
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
def _gmail_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
"""Build a Gmail-style ConnectorDocument like the real indexer does."""
return ConnectorDocument(
title=f"Subject for {unique_id}",
source_markdown=f"## Email\n\nBody of {unique_id}",
unique_id=unique_id,
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=True,
fallback_summary=f"Gmail: Subject for {unique_id}",
metadata={
"message_id": unique_id,
"from": "sender@example.com",
"document_type": "Gmail Message",
},
)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_gmail_pipeline_creates_ready_document(
db_session, db_search_space, db_connector, db_user, mocker
):
"""A Gmail ConnectorDocument flows through prepare + index to a READY document."""
space_id = db_search_space.id
doc = _gmail_doc(
unique_id="msg-pipeline-1",
search_space_id=space_id,
connector_id=db_connector.id,
user_id=str(db_user.id),
)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([doc])
assert len(prepared) == 1
await service.index(prepared[0], doc, llm=mocker.Mock())
result = await db_session.execute(
select(Document).filter(Document.search_space_id == space_id)
)
row = result.scalars().first()
assert row is not None
assert row.document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
assert row.source_markdown == doc.source_markdown
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_gmail_legacy_doc_migrated_then_reused(
db_session, db_search_space, db_connector, db_user, mocker
):
"""A legacy Composio Gmail doc is migrated then reused by the pipeline."""
space_id = db_search_space.id
user_id = str(db_user.id)
msg_id = "msg-legacy-gmail"
legacy_hash = compute_identifier_hash(
DocumentType.COMPOSIO_GMAIL_CONNECTOR.value, msg_id, space_id
)
legacy_doc = Document(
title="Old Gmail",
document_type=DocumentType.COMPOSIO_GMAIL_CONNECTOR,
content="old summary",
content_hash=f"ch-{legacy_hash[:12]}",
unique_identifier_hash=legacy_hash,
source_markdown="## Old content",
search_space_id=space_id,
created_by_id=user_id,
embedding=[0.1] * _EMBEDDING_DIM,
status={"state": "ready"},
)
db_session.add(legacy_doc)
await db_session.flush()
original_id = legacy_doc.id
connector_doc = _gmail_doc(
unique_id=msg_id,
search_space_id=space_id,
connector_id=db_connector.id,
user_id=user_id,
)
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
prepared = await service.prepare_for_indexing([connector_doc])
assert len(prepared) == 1
assert prepared[0].id == original_id
assert prepared[0].document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR
native_hash = compute_identifier_hash(
DocumentType.GOOGLE_GMAIL_CONNECTOR.value, msg_id, space_id
)
assert prepared[0].unique_identifier_hash == native_hash

View file

@ -0,0 +1,55 @@
"""Integration tests for IndexingPipelineService.index_batch()."""
import pytest
from sqlalchemy import select
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
pytestmark = pytest.mark.integration
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_index_batch_creates_ready_documents(
db_session, db_search_space, make_connector_document, mocker
):
"""index_batch prepares and indexes a batch, resulting in READY documents."""
space_id = db_search_space.id
docs = [
make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-batch-1",
search_space_id=space_id,
source_markdown="## Email 1\n\nBody",
),
make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-batch-2",
search_space_id=space_id,
source_markdown="## Email 2\n\nDifferent body",
),
]
service = IndexingPipelineService(session=db_session)
results = await service.index_batch(docs, llm=mocker.Mock())
assert len(results) == 2
result = await db_session.execute(
select(Document).filter(Document.search_space_id == space_id)
)
rows = result.scalars().all()
assert len(rows) == 2
for row in rows:
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
assert row.content is not None
assert row.embedding is not None
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_index_batch_empty_returns_empty(db_session, mocker):
"""index_batch with empty input returns an empty list."""
service = IndexingPipelineService(session=db_session)
results = await service.index_batch([], llm=mocker.Mock())
assert results == []

View file

@ -0,0 +1,92 @@
"""Integration tests for IndexingPipelineService.migrate_legacy_docs()."""
import pytest
from sqlalchemy import select
from app.config import config as app_config
from app.db import Document, DocumentType
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
async def test_legacy_composio_gmail_doc_migrated_in_db(
db_session, db_search_space, db_user, make_connector_document
):
"""A Composio Gmail doc in the DB gets its hash and type updated to native."""
space_id = db_search_space.id
user_id = str(db_user.id)
unique_id = "msg-legacy-123"
legacy_hash = compute_identifier_hash(
DocumentType.COMPOSIO_GMAIL_CONNECTOR.value, unique_id, space_id
)
native_hash = compute_identifier_hash(
DocumentType.GOOGLE_GMAIL_CONNECTOR.value, unique_id, space_id
)
legacy_doc = Document(
title="Old Gmail",
document_type=DocumentType.COMPOSIO_GMAIL_CONNECTOR,
content="legacy content",
content_hash=f"ch-{legacy_hash[:12]}",
unique_identifier_hash=legacy_hash,
search_space_id=space_id,
created_by_id=user_id,
embedding=[0.1] * _EMBEDDING_DIM,
status={"state": "ready"},
)
db_session.add(legacy_doc)
await db_session.flush()
doc_id = legacy_doc.id
connector_doc = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id=unique_id,
search_space_id=space_id,
)
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
result = await db_session.execute(select(Document).filter(Document.id == doc_id))
reloaded = result.scalars().first()
assert reloaded.unique_identifier_hash == native_hash
assert reloaded.document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR
async def test_no_legacy_doc_is_noop(
db_session, db_search_space, make_connector_document
):
"""When no legacy document exists, migrate_legacy_docs does nothing."""
connector_doc = make_connector_document(
document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
unique_id="evt-no-legacy",
search_space_id=db_search_space.id,
)
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
result = await db_session.execute(
select(Document).filter(Document.search_space_id == db_search_space.id)
)
assert result.scalars().all() == []
async def test_non_google_type_is_skipped(
db_session, db_search_space, make_connector_document
):
"""migrate_legacy_docs skips ConnectorDocuments that are not Google types."""
connector_doc = make_connector_document(
document_type=DocumentType.CLICKUP_CONNECTOR,
unique_id="task-1",
search_space_id=db_search_space.id,
)
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])

View file

@ -0,0 +1,34 @@
"""Pre-register the connector_indexers package to bypass a circular import
in its ``__init__.py`` (airtable_indexer -> routes -> connector_indexers).
This lets tests import individual indexer modules (e.g.
``google_drive_indexer``) without triggering the full package init.
"""
import sys
import types
from pathlib import Path
_BACKEND = Path(__file__).resolve().parents[3]
def _stub_package(dotted: str, fs_dir: Path) -> None:
if dotted not in sys.modules:
mod = types.ModuleType(dotted)
mod.__path__ = [str(fs_dir)]
mod.__package__ = dotted
sys.modules[dotted] = mod
parts = dotted.split(".")
if len(parts) > 1:
parent_dotted = ".".join(parts[:-1])
parent = sys.modules.get(parent_dotted)
if parent is not None:
setattr(parent, parts[-1], sys.modules[dotted])
_stub_package("app.tasks", _BACKEND / "app" / "tasks")
_stub_package(
"app.tasks.connector_indexers",
_BACKEND / "app" / "tasks" / "connector_indexers",
)

View file

@ -0,0 +1,373 @@
"""Tests for Confluence indexer migrated to the unified parallel pipeline."""
from unittest.mock import AsyncMock, MagicMock
import pytest
import app.tasks.connector_indexers.confluence_indexer as _mod
from app.db import DocumentType
from app.tasks.connector_indexers.confluence_indexer import (
_build_connector_doc,
index_confluence_pages,
)
pytestmark = pytest.mark.unit
_USER_ID = "00000000-0000-0000-0000-000000000001"
_CONNECTOR_ID = 42
_SEARCH_SPACE_ID = 1
def _make_page(
page_id: str = "p1",
title: str = "Home",
space_id: str = "S1",
body: str = "<p>Hello</p>",
comments=None,
):
return {
"id": page_id,
"title": title,
"spaceId": space_id,
"body": {"storage": {"value": body}},
"comments": comments or [],
}
def _to_markdown(page: dict) -> str:
page_title = page.get("title", "")
page_content = page.get("body", {}).get("storage", {}).get("value", "")
comments = page.get("comments", [])
comments_content = ""
if comments:
comments_content = "\n\n## Comments\n\n"
for comment in comments:
comment_body = (
comment.get("body", {}).get("storage", {}).get("value", "")
)
comment_author = comment.get("version", {}).get("authorId", "Unknown")
comment_date = comment.get("version", {}).get("createdAt", "")
comments_content += (
f"**Comment by {comment_author}** ({comment_date}):\n"
f"{comment_body}\n\n"
)
return f"# {page_title}\n\n{page_content}{comments_content}"
# ---------------------------------------------------------------------------
# Slice 1: _build_connector_doc tracer bullet
# ---------------------------------------------------------------------------
async def test_build_connector_doc_produces_correct_fields():
page = _make_page(
page_id="abc-123",
title="Engineering Handbook",
space_id="ENG",
comments=[{"id": "c1"}],
)
markdown = _to_markdown(page)
doc = _build_connector_doc(
page,
markdown,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert doc.title == "Engineering Handbook"
assert doc.unique_id == "abc-123"
assert doc.document_type == DocumentType.CONFLUENCE_CONNECTOR
assert doc.source_markdown == markdown
assert doc.search_space_id == _SEARCH_SPACE_ID
assert doc.connector_id == _CONNECTOR_ID
assert doc.created_by_id == _USER_ID
assert doc.should_summarize is True
assert doc.metadata["page_id"] == "abc-123"
assert doc.metadata["page_title"] == "Engineering Handbook"
assert doc.metadata["space_id"] == "ENG"
assert doc.metadata["comment_count"] == 1
assert doc.metadata["connector_id"] == _CONNECTOR_ID
assert doc.metadata["document_type"] == "Confluence Page"
assert doc.metadata["connector_type"] == "Confluence"
assert doc.fallback_summary is not None
assert "Engineering Handbook" in doc.fallback_summary
assert markdown in doc.fallback_summary
async def test_build_connector_doc_summary_disabled():
doc = _build_connector_doc(
_make_page(),
_to_markdown(_make_page()),
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=False,
)
assert doc.should_summarize is False
# ---------------------------------------------------------------------------
# Shared fixtures for Slices 2-7
# ---------------------------------------------------------------------------
def _mock_connector(enable_summary: bool = True):
c = MagicMock()
c.config = {"access_token": "tok"}
c.enable_summary = enable_summary
c.last_indexed_at = None
return c
def _mock_confluence_client(pages=None, error=None):
client = MagicMock()
client.get_pages_by_date_range = AsyncMock(
return_value=(pages if pages is not None else [], error),
)
client.close = AsyncMock()
return client
@pytest.fixture
def confluence_mocks(monkeypatch):
mock_session = AsyncMock()
mock_session.no_autoflush = MagicMock()
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
)
confluence_client = _mock_confluence_client(pages=[_make_page()])
monkeypatch.setattr(
_mod, "ConfluenceHistoryConnector", MagicMock(return_value=confluence_client),
)
monkeypatch.setattr(
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod, "update_connector_last_indexed", AsyncMock(),
)
monkeypatch.setattr(
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
mock_task_logger = MagicMock()
mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock())
mock_task_logger.log_task_progress = AsyncMock()
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
)
return {
"session": mock_session,
"connector": mock_connector,
"confluence_client": confluence_client,
"task_logger": mock_task_logger,
"pipeline_mock": pipeline_mock,
"batch_mock": batch_mock,
}
async def _run_index(mocks, **overrides):
return await index_confluence_pages(
session=mocks["session"],
connector_id=overrides.get("connector_id", _CONNECTOR_ID),
search_space_id=overrides.get("search_space_id", _SEARCH_SPACE_ID),
user_id=overrides.get("user_id", _USER_ID),
start_date=overrides.get("start_date", "2025-01-01"),
end_date=overrides.get("end_date", "2025-12-31"),
update_last_indexed=overrides.get("update_last_indexed", True),
on_heartbeat_callback=overrides.get("on_heartbeat_callback"),
)
# ---------------------------------------------------------------------------
# Slice 2: Full pipeline wiring
# ---------------------------------------------------------------------------
async def test_one_page_calls_pipeline_and_returns_indexed_count(confluence_mocks):
indexed, skipped, warning = await _run_index(confluence_mocks)
assert indexed == 1
assert skipped == 0
assert warning is None
confluence_mocks["batch_mock"].assert_called_once()
connector_docs = confluence_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert connector_docs[0].document_type == DocumentType.CONFLUENCE_CONNECTOR
async def test_pipeline_called_with_max_concurrency_3(confluence_mocks):
await _run_index(confluence_mocks)
call_kwargs = confluence_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("max_concurrency") == 3
async def test_migrate_legacy_docs_called_before_indexing(confluence_mocks):
await _run_index(confluence_mocks)
confluence_mocks["pipeline_mock"].migrate_legacy_docs.assert_called_once()
# ---------------------------------------------------------------------------
# Slice 3: Page skipping (missing id/title/content)
# ---------------------------------------------------------------------------
async def test_pages_with_missing_id_are_skipped(confluence_mocks):
pages = [
_make_page(page_id="p1", title="Valid"),
_make_page(page_id="", title="Missing id"),
]
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
pages,
None,
)
_, skipped, _ = await _run_index(confluence_mocks)
connector_docs = confluence_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
async def test_pages_with_missing_title_are_skipped(confluence_mocks):
pages = [
_make_page(page_id="p1", title="Valid"),
_make_page(page_id="p2", title=""),
]
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
pages,
None,
)
_, skipped, _ = await _run_index(confluence_mocks)
connector_docs = confluence_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
async def test_pages_with_no_content_are_skipped(confluence_mocks):
pages = [
_make_page(page_id="p1", title="Valid", body="<p>ok</p>"),
_make_page(page_id="p2", title="Empty", body=""),
]
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
pages,
None,
)
_, skipped, _ = await _run_index(confluence_mocks)
connector_docs = confluence_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 4: Duplicate content skipping
# ---------------------------------------------------------------------------
async def test_duplicate_content_pages_are_skipped(confluence_mocks, monkeypatch):
pages = [
_make_page(page_id="p1", title="One"),
_make_page(page_id="p2", title="Two"),
]
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
pages,
None,
)
call_count = 0
async def _check_dup(session, content_hash):
nonlocal call_count
call_count += 1
if call_count == 2:
dup = MagicMock()
dup.id = 99
dup.document_type = "OTHER"
return dup
return None
monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)
_, skipped, _ = await _run_index(confluence_mocks)
connector_docs = confluence_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 5: Heartbeat callback forwarding
# ---------------------------------------------------------------------------
async def test_heartbeat_callback_forwarded_to_pipeline(confluence_mocks):
heartbeat_cb = AsyncMock()
await _run_index(confluence_mocks, on_heartbeat_callback=heartbeat_cb)
call_kwargs = confluence_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("on_heartbeat") is heartbeat_cb
# ---------------------------------------------------------------------------
# Slice 6: Empty pages and no-data success tuple
# ---------------------------------------------------------------------------
async def test_empty_pages_returns_zero_tuple(confluence_mocks):
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
[],
None,
)
indexed, skipped, warning = await _run_index(confluence_mocks)
assert indexed == 0
assert skipped == 0
assert warning is None
confluence_mocks["batch_mock"].assert_not_called()
async def test_no_pages_error_message_returns_success_tuple(confluence_mocks):
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
[],
"No pages found in date range",
)
indexed, skipped, warning = await _run_index(confluence_mocks)
assert indexed == 0
assert skipped == 0
assert warning is None
async def test_api_error_still_returns_3_tuple(confluence_mocks):
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
[],
"API exploded",
)
result = await _run_index(confluence_mocks)
assert len(result) == 3
assert result[0] == 0
assert result[1] == 0
assert "Failed to get Confluence pages" in result[2]
# ---------------------------------------------------------------------------
# Slice 7: Failed docs warning
# ---------------------------------------------------------------------------
async def test_failed_docs_warning_in_result(confluence_mocks):
confluence_mocks["batch_mock"].return_value = ([], 0, 2)
_, _, warning = await _run_index(confluence_mocks)
assert warning is not None
assert "2 failed" in warning

View file

@ -0,0 +1,671 @@
"""Tests for parallel download + indexing in the Google Drive indexer."""
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.tasks.connector_indexers.google_drive_indexer import (
_download_files_parallel,
_index_full_scan,
_index_selected_files,
_index_with_delta_sync,
)
pytestmark = pytest.mark.unit
_USER_ID = "00000000-0000-0000-0000-000000000001"
_CONNECTOR_ID = 42
_SEARCH_SPACE_ID = 1
def _make_file_dict(file_id: str, name: str, mime: str = "text/plain") -> dict:
return {"id": file_id, "name": name, "mimeType": mime}
def _mock_extract_ok(file_id: str, file_name: str):
"""Return a successful (markdown, metadata, None) tuple."""
return (
f"# Content of {file_name}",
{"google_drive_file_id": file_id, "google_drive_file_name": file_name},
None,
)
@pytest.fixture
def mock_drive_client():
return MagicMock()
@pytest.fixture
def patch_extract(monkeypatch):
"""Provide a helper to set the download_and_extract_content mock."""
def _patch(side_effect=None, return_value=None):
mock = AsyncMock(side_effect=side_effect, return_value=return_value)
monkeypatch.setattr(
"app.tasks.connector_indexers.google_drive_indexer.download_and_extract_content",
mock,
)
return mock
return _patch
async def test_single_file_returns_one_connector_document(
mock_drive_client, patch_extract,
):
"""Tracer bullet: downloading one file produces one ConnectorDocument."""
patch_extract(return_value=_mock_extract_ok("f1", "test.txt"))
docs, failed = await _download_files_parallel(
mock_drive_client,
[_make_file_dict("f1", "test.txt")],
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert len(docs) == 1
assert failed == 0
assert docs[0].title == "test.txt"
assert docs[0].unique_id == "f1"
async def test_multiple_files_all_produce_documents(
mock_drive_client, patch_extract,
):
"""All files are downloaded and converted to ConnectorDocuments."""
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
patch_extract(
side_effect=[_mock_extract_ok(f"f{i}", f"file{i}.txt") for i in range(3)]
)
docs, failed = await _download_files_parallel(
mock_drive_client,
files,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert len(docs) == 3
assert failed == 0
assert {d.unique_id for d in docs} == {"f0", "f1", "f2"}
async def test_one_download_exception_does_not_block_others(
mock_drive_client, patch_extract,
):
"""A RuntimeError in one download still lets the other files succeed."""
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
patch_extract(
side_effect=[
_mock_extract_ok("f0", "file0.txt"),
RuntimeError("network timeout"),
_mock_extract_ok("f2", "file2.txt"),
]
)
docs, failed = await _download_files_parallel(
mock_drive_client,
files,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert len(docs) == 2
assert failed == 1
assert {d.unique_id for d in docs} == {"f0", "f2"}
async def test_etl_error_counts_as_download_failure(
mock_drive_client, patch_extract,
):
"""download_and_extract_content returning an error is counted as failed."""
files = [_make_file_dict("f0", "good.txt"), _make_file_dict("f1", "bad.txt")]
patch_extract(
side_effect=[
_mock_extract_ok("f0", "good.txt"),
(None, {}, "ETL failed"),
]
)
docs, failed = await _download_files_parallel(
mock_drive_client,
files,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert len(docs) == 1
assert failed == 1
async def test_concurrency_bounded_by_semaphore(
mock_drive_client, monkeypatch,
):
"""Peak concurrent downloads never exceeds max_concurrency."""
lock = asyncio.Lock()
active = 0
peak = 0
async def _slow_extract(client, file):
nonlocal active, peak
async with lock:
active += 1
peak = max(peak, active)
await asyncio.sleep(0.05)
async with lock:
active -= 1
fid = file["id"]
return _mock_extract_ok(fid, file["name"])
monkeypatch.setattr(
"app.tasks.connector_indexers.google_drive_indexer.download_and_extract_content",
_slow_extract,
)
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(6)]
docs, failed = await _download_files_parallel(
mock_drive_client,
files,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
max_concurrency=2,
)
assert len(docs) == 6
assert failed == 0
assert peak <= 2, f"Peak concurrency was {peak}, expected <= 2"
async def test_heartbeat_fires_during_parallel_downloads(
mock_drive_client, monkeypatch,
):
"""on_heartbeat is called at least once when downloads take time."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
monkeypatch.setattr(_mod, "HEARTBEAT_INTERVAL_SECONDS", 0)
async def _slow_extract(client, file):
await asyncio.sleep(0.05)
return _mock_extract_ok(file["id"], file["name"])
monkeypatch.setattr(
"app.tasks.connector_indexers.google_drive_indexer.download_and_extract_content",
_slow_extract,
)
heartbeat_calls: list[int] = []
async def _on_heartbeat(count: int):
heartbeat_calls.append(count)
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
docs, failed = await _download_files_parallel(
mock_drive_client,
files,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
on_heartbeat=_on_heartbeat,
)
assert len(docs) == 3
assert failed == 0
assert len(heartbeat_calls) >= 1, "Heartbeat should have fired at least once"
# ---------------------------------------------------------------------------
# Slice 6, 6b, 6c -- _index_full_scan three-phase pipeline
# ---------------------------------------------------------------------------
def _folder_dict(file_id: str, name: str) -> dict:
return {"id": file_id, "name": name, "mimeType": "application/vnd.google-apps.folder"}
@pytest.fixture
def full_scan_mocks(mock_drive_client, monkeypatch):
"""Wire up all mocks needed to call _index_full_scan in isolation."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
mock_session = AsyncMock()
mock_connector = MagicMock()
mock_task_logger = MagicMock()
mock_task_logger.log_task_progress = AsyncMock()
mock_log_entry = MagicMock()
skip_results: dict[str, tuple[bool, str | None]] = {}
async def _fake_skip(session, file, search_space_id):
return skip_results.get(file["id"], (False, None))
monkeypatch.setattr(_mod, "_should_skip_file", _fake_skip)
download_mock = AsyncMock(return_value=([], 0))
monkeypatch.setattr(_mod, "_download_files_parallel", download_mock)
batch_mock = AsyncMock(return_value=([], 0, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
)
monkeypatch.setattr(
_mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock()),
)
return {
"drive_client": mock_drive_client,
"session": mock_session,
"connector": mock_connector,
"task_logger": mock_task_logger,
"log_entry": mock_log_entry,
"skip_results": skip_results,
"download_mock": download_mock,
"batch_mock": batch_mock,
"pipeline_mock": pipeline_mock,
}
async def _run_full_scan(mocks, *, max_files=500, include_subfolders=False):
return await _index_full_scan(
mocks["drive_client"],
mocks["session"],
mocks["connector"],
_CONNECTOR_ID,
_SEARCH_SPACE_ID,
_USER_ID,
"folder-root",
"My Folder",
mocks["task_logger"],
mocks["log_entry"],
max_files,
include_subfolders=include_subfolders,
enable_summary=True,
)
async def test_full_scan_three_phase_counts(full_scan_mocks, monkeypatch):
"""Full scan collects files serially, downloads and indexes in parallel,
and returns correct (indexed, skipped) with renames counted as indexed."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
page_files = [
_folder_dict("folder1", "SubFolder"),
_make_file_dict("skip1", "unchanged.txt"),
_make_file_dict("rename1", "renamed.txt"),
_make_file_dict("new1", "new1.txt"),
_make_file_dict("new2", "new2.txt"),
]
monkeypatch.setattr(
_mod, "get_files_in_folder",
AsyncMock(return_value=(page_files, None, None)),
)
full_scan_mocks["skip_results"]["skip1"] = (True, "unchanged")
full_scan_mocks["skip_results"]["rename1"] = (True, "File renamed: 'old''renamed.txt'")
mock_docs = [MagicMock(), MagicMock()]
full_scan_mocks["download_mock"].return_value = (mock_docs, 0)
full_scan_mocks["batch_mock"].return_value = ([], 2, 0)
indexed, skipped = await _run_full_scan(full_scan_mocks)
assert indexed == 3 # 1 renamed + 2 from batch
assert skipped == 1 # 1 unchanged
full_scan_mocks["download_mock"].assert_called_once()
call_files = full_scan_mocks["download_mock"].call_args[0][1]
assert len(call_files) == 2
assert {f["id"] for f in call_files} == {"new1", "new2"}
async def test_full_scan_respects_max_files(full_scan_mocks, monkeypatch):
"""Only max_files non-folder files are processed; the rest are ignored."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
page_files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(10)]
monkeypatch.setattr(
_mod, "get_files_in_folder",
AsyncMock(return_value=(page_files, None, None)),
)
full_scan_mocks["download_mock"].return_value = ([], 0)
full_scan_mocks["batch_mock"].return_value = ([], 0, 0)
await _run_full_scan(full_scan_mocks, max_files=3)
download_call_files = full_scan_mocks["download_mock"].call_args[0][1]
assert len(download_call_files) == 3
async def test_full_scan_uses_max_concurrency_3_for_indexing(
full_scan_mocks, monkeypatch,
):
"""index_batch_parallel is called with max_concurrency=3."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
page_files = [_make_file_dict("f1", "file1.txt")]
monkeypatch.setattr(
_mod, "get_files_in_folder",
AsyncMock(return_value=(page_files, None, None)),
)
mock_docs = [MagicMock()]
full_scan_mocks["download_mock"].return_value = (mock_docs, 0)
full_scan_mocks["batch_mock"].return_value = ([], 1, 0)
await _run_full_scan(full_scan_mocks)
call_kwargs = full_scan_mocks["batch_mock"].call_args
assert call_kwargs[1].get("max_concurrency") == 3 or (
len(call_kwargs[0]) > 2 and call_kwargs[0][2] == 3
)
# ---------------------------------------------------------------------------
# Slice 7 -- _index_with_delta_sync three-phase pipeline
# ---------------------------------------------------------------------------
async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
"""Removed/trashed changes call _remove_document; the rest go through
_download_files_parallel and index_batch_parallel."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
changes = [
{"fileId": "del1", "removed": True},
{"fileId": "del2", "file": {"id": "del2", "trashed": True}},
{"fileId": "trash1", "file": {"id": "trash1", "trashed": True}},
{"fileId": "mod1", "file": _make_file_dict("mod1", "modified1.txt")},
{"fileId": "mod2", "file": _make_file_dict("mod2", "modified2.txt")},
]
monkeypatch.setattr(
_mod, "fetch_all_changes",
AsyncMock(return_value=(changes, "new-token", None)),
)
change_types = {
"del1": "removed",
"del2": "removed",
"trash1": "trashed",
"mod1": "modified",
"mod2": "modified",
}
monkeypatch.setattr(
_mod, "categorize_change",
lambda change: change_types[change["fileId"]],
)
remove_calls: list[str] = []
async def _fake_remove(session, file_id, search_space_id):
remove_calls.append(file_id)
monkeypatch.setattr(_mod, "_remove_document", _fake_remove)
monkeypatch.setattr(
_mod, "_should_skip_file",
AsyncMock(return_value=(False, None)),
)
mock_docs = [MagicMock(), MagicMock()]
download_mock = AsyncMock(return_value=(mock_docs, 0))
monkeypatch.setattr(_mod, "_download_files_parallel", download_mock)
batch_mock = AsyncMock(return_value=([], 2, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
)
monkeypatch.setattr(
_mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock()),
)
mock_session = AsyncMock()
mock_task_logger = MagicMock()
mock_task_logger.log_task_progress = AsyncMock()
indexed, skipped = await _index_with_delta_sync(
MagicMock(),
mock_session,
MagicMock(),
_CONNECTOR_ID,
_SEARCH_SPACE_ID,
_USER_ID,
"folder-root",
"start-token-abc",
mock_task_logger,
MagicMock(),
max_files=500,
enable_summary=True,
)
assert sorted(remove_calls) == ["del1", "del2", "trash1"]
download_mock.assert_called_once()
downloaded_files = download_mock.call_args[0][1]
assert len(downloaded_files) == 2
assert {f["id"] for f in downloaded_files} == {"mod1", "mod2"}
assert indexed == 2
assert skipped == 0
# ---------------------------------------------------------------------------
# _index_selected_files -- parallel indexing of user-selected files
# ---------------------------------------------------------------------------
@pytest.fixture
def selected_files_mocks(mock_drive_client, monkeypatch):
"""Wire up mocks for _index_selected_files tests."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
mock_session = AsyncMock()
get_file_results: dict[str, tuple[dict | None, str | None]] = {}
async def _fake_get_file(client, file_id):
return get_file_results.get(file_id, (None, f"Not configured: {file_id}"))
monkeypatch.setattr(_mod, "get_file_by_id", _fake_get_file)
skip_results: dict[str, tuple[bool, str | None]] = {}
async def _fake_skip(session, file, search_space_id):
return skip_results.get(file["id"], (False, None))
monkeypatch.setattr(_mod, "_should_skip_file", _fake_skip)
download_and_index_mock = AsyncMock(return_value=(0, 0))
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
return {
"drive_client": mock_drive_client,
"session": mock_session,
"get_file_results": get_file_results,
"skip_results": skip_results,
"download_and_index_mock": download_and_index_mock,
}
async def _run_selected(mocks, file_ids):
return await _index_selected_files(
mocks["drive_client"],
mocks["session"],
file_ids,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
async def test_selected_files_single_file_indexed(selected_files_mocks):
"""Tracer bullet: one file fetched, not skipped, indexed via parallel pipeline."""
selected_files_mocks["get_file_results"]["f1"] = (
_make_file_dict("f1", "report.pdf"),
None,
)
selected_files_mocks["download_and_index_mock"].return_value = (1, 0)
indexed, skipped, errors = await _run_selected(
selected_files_mocks, [("f1", "report.pdf")],
)
assert indexed == 1
assert skipped == 0
assert errors == []
selected_files_mocks["download_and_index_mock"].assert_called_once()
async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
"""get_file_by_id failing for one file collects an error; others still indexed."""
selected_files_mocks["get_file_results"]["f1"] = (
_make_file_dict("f1", "first.txt"), None,
)
selected_files_mocks["get_file_results"]["f2"] = (None, "HTTP 404")
selected_files_mocks["get_file_results"]["f3"] = (
_make_file_dict("f3", "third.txt"), None,
)
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
indexed, skipped, errors = await _run_selected(
selected_files_mocks,
[("f1", "first.txt"), ("f2", "mid.txt"), ("f3", "third.txt")],
)
assert indexed == 2
assert skipped == 0
assert len(errors) == 1
assert "mid.txt" in errors[0]
assert "HTTP 404" in errors[0]
async def test_selected_files_skip_rename_counting(selected_files_mocks):
"""Unchanged files are skipped, renames counted as indexed,
and only new files are sent to _download_and_index."""
for fid, fname in [("s1", "unchanged.txt"), ("r1", "renamed.txt"),
("n1", "new1.txt"), ("n2", "new2.txt")]:
selected_files_mocks["get_file_results"][fid] = (
_make_file_dict(fid, fname), None,
)
selected_files_mocks["skip_results"]["s1"] = (True, "unchanged")
selected_files_mocks["skip_results"]["r1"] = (True, "File renamed: 'old' \u2192 'renamed.txt'")
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
indexed, skipped, errors = await _run_selected(
selected_files_mocks,
[("s1", "unchanged.txt"), ("r1", "renamed.txt"),
("n1", "new1.txt"), ("n2", "new2.txt")],
)
assert indexed == 3 # 1 renamed + 2 batch
assert skipped == 1 # 1 unchanged
assert errors == []
mock = selected_files_mocks["download_and_index_mock"]
mock.assert_called_once()
call_files = mock.call_args[1].get("files") if "files" in (mock.call_args[1] or {}) else mock.call_args[0][2]
assert len(call_files) == 2
assert {f["id"] for f in call_files} == {"n1", "n2"}
# ---------------------------------------------------------------------------
# asyncio.to_thread verification — prove blocking calls run in parallel
# ---------------------------------------------------------------------------
async def test_client_download_file_runs_in_thread_parallel():
"""Calling download_file concurrently via asyncio.gather should overlap
blocking work on separate threads, proving to_thread is effective.
Strategy: patch _sync_download_file with a blocking time.sleep(0.2).
Launch 3 concurrent calls. Serial would take >=0.6s; parallel < 0.4s.
"""
from app.connectors.google_drive.client import GoogleDriveClient
BLOCK_SECONDS = 0.2
NUM_CALLS = 3
def _blocking_download(service, file_id, credentials):
time.sleep(BLOCK_SECONDS)
return b"fake-content", None
client = GoogleDriveClient.__new__(GoogleDriveClient)
client.service = MagicMock()
client._resolved_credentials = MagicMock()
client._service_lock = asyncio.Lock()
with patch.object(
GoogleDriveClient, "_sync_download_file", staticmethod(_blocking_download),
):
start = time.monotonic()
results = await asyncio.gather(
*(client.download_file(f"file-{i}") for i in range(NUM_CALLS))
)
elapsed = time.monotonic() - start
for content, error in results:
assert content == b"fake-content"
assert error is None
serial_minimum = BLOCK_SECONDS * NUM_CALLS
assert elapsed < serial_minimum, (
f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — "
f"downloads are not running in parallel"
)
async def test_client_export_google_file_runs_in_thread_parallel():
"""Same strategy for export_google_file — verify to_thread parallelism."""
from app.connectors.google_drive.client import GoogleDriveClient
BLOCK_SECONDS = 0.2
NUM_CALLS = 3
def _blocking_export(service, file_id, mime_type, credentials):
time.sleep(BLOCK_SECONDS)
return b"exported", None
client = GoogleDriveClient.__new__(GoogleDriveClient)
client.service = MagicMock()
client._resolved_credentials = MagicMock()
client._service_lock = asyncio.Lock()
with patch.object(
GoogleDriveClient, "_sync_export_google_file", staticmethod(_blocking_export),
):
start = time.monotonic()
results = await asyncio.gather(
*(client.export_google_file(f"file-{i}", "application/pdf")
for i in range(NUM_CALLS))
)
elapsed = time.monotonic() - start
for content, error in results:
assert content == b"exported"
assert error is None
serial_minimum = BLOCK_SECONDS * NUM_CALLS
assert elapsed < serial_minimum, (
f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — "
f"exports are not running in parallel"
)

View file

@ -0,0 +1,372 @@
"""Tests for Jira indexer migrated to the unified parallel pipeline."""
from unittest.mock import AsyncMock, MagicMock
import pytest
import app.tasks.connector_indexers.jira_indexer as _mod
from app.db import DocumentType
from app.tasks.connector_indexers.jira_indexer import (
_build_connector_doc,
index_jira_issues,
)
pytestmark = pytest.mark.unit
_USER_ID = "00000000-0000-0000-0000-000000000001"
_CONNECTOR_ID = 42
_SEARCH_SPACE_ID = 1
def _make_issue(
issue_key: str = "ENG-1",
issue_id: str = "10001",
title: str = "Fix login",
):
return {"key": issue_key, "id": issue_id, "title": title}
def _make_formatted_issue(
issue_key: str = "ENG-1",
issue_id: str = "10001",
title: str = "Fix login",
status: str = "In Progress",
priority: str = "High",
comments=None,
):
return {
"key": issue_key,
"id": issue_id,
"title": title,
"status": status,
"priority": priority,
"comments": comments or [],
}
# ---------------------------------------------------------------------------
# Slice 1: _build_connector_doc tracer bullet
# ---------------------------------------------------------------------------
async def test_build_connector_doc_produces_correct_fields():
issue = _make_issue(issue_key="ENG-42", issue_id="4242", title="Fix auth bug")
formatted = _make_formatted_issue(
issue_key="ENG-42",
issue_id="4242",
title="Fix auth bug",
status="Done",
priority="Urgent",
comments=[{"id": "c1"}],
)
markdown = "# ENG-42: Fix auth bug\n\nBody"
doc = _build_connector_doc(
issue,
formatted,
markdown,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert doc.title == "ENG-42: 4242"
assert doc.unique_id == "ENG-42"
assert doc.document_type == DocumentType.JIRA_CONNECTOR
assert doc.source_markdown == markdown
assert doc.search_space_id == _SEARCH_SPACE_ID
assert doc.connector_id == _CONNECTOR_ID
assert doc.created_by_id == _USER_ID
assert doc.should_summarize is True
assert doc.metadata["issue_id"] == "ENG-42"
assert doc.metadata["issue_identifier"] == "ENG-42"
assert doc.metadata["issue_title"] == "4242"
assert doc.metadata["state"] == "Done"
assert doc.metadata["priority"] == "Urgent"
assert doc.metadata["comment_count"] == 1
assert doc.metadata["connector_id"] == _CONNECTOR_ID
assert doc.metadata["document_type"] == "Jira Issue"
assert doc.metadata["connector_type"] == "Jira"
assert doc.fallback_summary is not None
assert "ENG-42" in doc.fallback_summary
assert markdown in doc.fallback_summary
async def test_build_connector_doc_summary_disabled():
doc = _build_connector_doc(
_make_issue(),
_make_formatted_issue(),
"# content",
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=False,
)
assert doc.should_summarize is False
# ---------------------------------------------------------------------------
# Shared fixtures for Slices 2-7
# ---------------------------------------------------------------------------
def _mock_connector(enable_summary: bool = True):
c = MagicMock()
c.config = {"access_token": "tok"}
c.enable_summary = enable_summary
c.last_indexed_at = None
return c
def _mock_jira_client(issues=None, error=None):
client = MagicMock()
client.get_issues_by_date_range = AsyncMock(
return_value=(issues if issues is not None else [], error),
)
client.format_issue = MagicMock(
side_effect=lambda i: _make_formatted_issue(
issue_key=i.get("key", ""),
issue_id=i.get("id", ""),
title=i.get("title", ""),
)
)
client.format_issue_to_markdown = MagicMock(
side_effect=lambda fi: f"# {fi.get('key', '')}: {fi.get('id', '')}\n\nContent"
)
client.close = AsyncMock()
return client
@pytest.fixture
def jira_mocks(monkeypatch):
mock_session = AsyncMock()
mock_session.no_autoflush = MagicMock()
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
)
jira_client = _mock_jira_client(issues=[_make_issue()])
monkeypatch.setattr(
_mod, "JiraHistoryConnector", MagicMock(return_value=jira_client),
)
monkeypatch.setattr(
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod, "update_connector_last_indexed", AsyncMock(),
)
monkeypatch.setattr(
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
mock_task_logger = MagicMock()
mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock())
mock_task_logger.log_task_progress = AsyncMock()
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
)
return {
"session": mock_session,
"connector": mock_connector,
"jira_client": jira_client,
"task_logger": mock_task_logger,
"pipeline_mock": pipeline_mock,
"batch_mock": batch_mock,
}
async def _run_index(mocks, **overrides):
return await index_jira_issues(
session=mocks["session"],
connector_id=overrides.get("connector_id", _CONNECTOR_ID),
search_space_id=overrides.get("search_space_id", _SEARCH_SPACE_ID),
user_id=overrides.get("user_id", _USER_ID),
start_date=overrides.get("start_date", "2025-01-01"),
end_date=overrides.get("end_date", "2025-12-31"),
update_last_indexed=overrides.get("update_last_indexed", True),
on_heartbeat_callback=overrides.get("on_heartbeat_callback"),
)
# ---------------------------------------------------------------------------
# Slice 2: Full pipeline wiring
# ---------------------------------------------------------------------------
async def test_one_issue_calls_pipeline_and_returns_indexed_count(jira_mocks):
indexed, skipped, warning = await _run_index(jira_mocks)
assert indexed == 1
assert skipped == 0
assert warning is None
jira_mocks["batch_mock"].assert_called_once()
connector_docs = jira_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert connector_docs[0].document_type == DocumentType.JIRA_CONNECTOR
async def test_pipeline_called_with_max_concurrency_3(jira_mocks):
await _run_index(jira_mocks)
call_kwargs = jira_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("max_concurrency") == 3
async def test_migrate_legacy_docs_called_before_indexing(jira_mocks):
await _run_index(jira_mocks)
jira_mocks["pipeline_mock"].migrate_legacy_docs.assert_called_once()
# ---------------------------------------------------------------------------
# Slice 3: Issue skipping (missing key/title/content)
# ---------------------------------------------------------------------------
async def test_issues_with_missing_key_are_skipped(jira_mocks):
issues = [
_make_issue(issue_key="ENG-1", issue_id="10001"),
{"key": "", "id": "10002", "title": "No key"},
]
jira_mocks["jira_client"].get_issues_by_date_range.return_value = (issues, None)
_, skipped, _ = await _run_index(jira_mocks)
connector_docs = jira_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
async def test_issues_with_missing_title_are_skipped(jira_mocks):
issues = [
_make_issue(issue_key="ENG-1", issue_id="10001"),
{"key": "ENG-2", "id": "", "title": "Missing id used as title"},
]
jira_mocks["jira_client"].get_issues_by_date_range.return_value = (issues, None)
_, skipped, _ = await _run_index(jira_mocks)
connector_docs = jira_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
async def test_issues_with_no_content_are_skipped(jira_mocks):
issues = [
_make_issue(issue_key="ENG-1", issue_id="10001"),
_make_issue(issue_key="ENG-2", issue_id="10002"),
]
jira_mocks["jira_client"].get_issues_by_date_range.return_value = (issues, None)
jira_mocks["jira_client"].format_issue_to_markdown.side_effect = [
"# ENG-1: 10001\n\nContent",
"",
]
_, skipped, _ = await _run_index(jira_mocks)
connector_docs = jira_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 4: Duplicate content skipping
# ---------------------------------------------------------------------------
async def test_duplicate_content_issues_are_skipped(jira_mocks, monkeypatch):
issues = [
_make_issue(issue_key="ENG-1", issue_id="10001"),
_make_issue(issue_key="ENG-2", issue_id="10002"),
]
jira_mocks["jira_client"].get_issues_by_date_range.return_value = (issues, None)
call_count = 0
async def _check_dup(session, content_hash):
nonlocal call_count
call_count += 1
if call_count == 2:
dup = MagicMock()
dup.id = 99
dup.document_type = "OTHER"
return dup
return None
monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)
_, skipped, _ = await _run_index(jira_mocks)
connector_docs = jira_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 5: Heartbeat callback forwarding
# ---------------------------------------------------------------------------
async def test_heartbeat_callback_forwarded_to_pipeline(jira_mocks):
heartbeat_cb = AsyncMock()
await _run_index(jira_mocks, on_heartbeat_callback=heartbeat_cb)
call_kwargs = jira_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("on_heartbeat") is heartbeat_cb
# ---------------------------------------------------------------------------
# Slice 6: Empty issues and no-data success tuple
# ---------------------------------------------------------------------------
async def test_empty_issues_returns_zero_tuple(jira_mocks):
jira_mocks["jira_client"].get_issues_by_date_range.return_value = ([], None)
indexed, skipped, warning = await _run_index(jira_mocks)
assert indexed == 0
assert skipped == 0
assert warning is None
jira_mocks["batch_mock"].assert_not_called()
async def test_no_issues_error_message_returns_success_tuple(jira_mocks):
jira_mocks["jira_client"].get_issues_by_date_range.return_value = (
[],
"No issues found in date range",
)
indexed, skipped, warning = await _run_index(jira_mocks)
assert indexed == 0
assert skipped == 0
assert warning is None
async def test_api_error_still_returns_3_tuple(jira_mocks):
jira_mocks["jira_client"].get_issues_by_date_range.return_value = (
[],
"API exploded",
)
result = await _run_index(jira_mocks)
assert len(result) == 3
assert result[0] == 0
assert result[1] == 0
assert "Failed to get Jira issues" in result[2]
# ---------------------------------------------------------------------------
# Slice 7: Failed docs warning
# ---------------------------------------------------------------------------
async def test_failed_docs_warning_in_result(jira_mocks):
jira_mocks["batch_mock"].return_value = ([], 0, 2)
_, _, warning = await _run_index(jira_mocks)
assert warning is not None
assert "2 failed" in warning

View file

@ -0,0 +1,355 @@
"""Tests for Linear indexer migrated to the unified parallel pipeline."""
from unittest.mock import AsyncMock, MagicMock
import pytest
import app.tasks.connector_indexers.linear_indexer as _mod
from app.db import DocumentType
from app.tasks.connector_indexers.linear_indexer import (
_build_connector_doc,
index_linear_issues,
)
pytestmark = pytest.mark.unit
_USER_ID = "00000000-0000-0000-0000-000000000001"
_CONNECTOR_ID = 42
_SEARCH_SPACE_ID = 1
def _make_issue(
issue_id: str = "issue-1",
identifier: str = "ENG-1",
title: str = "Fix bug",
):
return {"id": issue_id, "identifier": identifier, "title": title}
def _make_formatted_issue(
issue_id: str = "issue-1",
identifier: str = "ENG-1",
title: str = "Fix bug",
state: str = "In Progress",
priority: str = "High",
comments=None,
):
return {
"id": issue_id,
"identifier": identifier,
"title": title,
"state": state,
"priority": priority,
"description": "Some description",
"comments": comments or [],
}
# ---------------------------------------------------------------------------
# Slice 1: _build_connector_doc tracer bullet
# ---------------------------------------------------------------------------
async def test_build_connector_doc_produces_correct_fields():
"""Tracer bullet: a Linear issue produces a ConnectorDocument with correct fields."""
issue = _make_issue(issue_id="abc-123", identifier="ENG-42", title="Fix login bug")
formatted = _make_formatted_issue(
issue_id="abc-123",
identifier="ENG-42",
title="Fix login bug",
state="Done",
priority="Urgent",
comments=[{"id": "c1"}],
)
markdown = "# ENG-42: Fix login bug\n\nDescription here"
doc = _build_connector_doc(
issue,
formatted,
markdown,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert doc.title == "ENG-42: Fix login bug"
assert doc.unique_id == "abc-123"
assert doc.document_type == DocumentType.LINEAR_CONNECTOR
assert doc.source_markdown == markdown
assert doc.search_space_id == _SEARCH_SPACE_ID
assert doc.connector_id == _CONNECTOR_ID
assert doc.created_by_id == _USER_ID
assert doc.should_summarize is True
assert doc.metadata["issue_id"] == "abc-123"
assert doc.metadata["issue_identifier"] == "ENG-42"
assert doc.metadata["issue_title"] == "Fix login bug"
assert doc.metadata["state"] == "Done"
assert doc.metadata["priority"] == "Urgent"
assert doc.metadata["comment_count"] == 1
assert doc.metadata["connector_id"] == _CONNECTOR_ID
assert doc.metadata["document_type"] == "Linear Issue"
assert doc.metadata["connector_type"] == "Linear"
assert doc.fallback_summary is not None
assert "ENG-42" in doc.fallback_summary
assert markdown in doc.fallback_summary
async def test_build_connector_doc_summary_disabled():
"""When enable_summary is False, should_summarize is False."""
doc = _build_connector_doc(
_make_issue(),
_make_formatted_issue(),
"# content",
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=False,
)
assert doc.should_summarize is False
# ---------------------------------------------------------------------------
# Shared fixtures for Slices 2-6
# ---------------------------------------------------------------------------
def _mock_connector(enable_summary: bool = True):
c = MagicMock()
c.config = {"access_token": "tok"}
c.enable_summary = enable_summary
c.last_indexed_at = None
return c
def _mock_linear_client(issues=None, error=None):
client = MagicMock()
client.get_issues_by_date_range = AsyncMock(
return_value=(issues if issues is not None else [], error),
)
client.format_issue = MagicMock(side_effect=lambda i: _make_formatted_issue(
issue_id=i.get("id", ""),
identifier=i.get("identifier", ""),
title=i.get("title", ""),
))
client.format_issue_to_markdown = MagicMock(
side_effect=lambda fi: f"# {fi.get('identifier', '')}: {fi.get('title', '')}\n\nContent"
)
return client
@pytest.fixture
def linear_mocks(monkeypatch):
"""Wire up all external boundary mocks for index_linear_issues."""
mock_session = AsyncMock()
mock_session.no_autoflush = MagicMock()
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
)
linear_client = _mock_linear_client(issues=[_make_issue()])
monkeypatch.setattr(
_mod, "LinearConnector", MagicMock(return_value=linear_client),
)
monkeypatch.setattr(
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod, "update_connector_last_indexed", AsyncMock(),
)
monkeypatch.setattr(
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
mock_task_logger = MagicMock()
mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock())
mock_task_logger.log_task_progress = AsyncMock()
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
)
return {
"session": mock_session,
"connector": mock_connector,
"linear_client": linear_client,
"task_logger": mock_task_logger,
"pipeline_mock": pipeline_mock,
"batch_mock": batch_mock,
}
async def _run_index(mocks, **overrides):
return await index_linear_issues(
session=mocks["session"],
connector_id=overrides.get("connector_id", _CONNECTOR_ID),
search_space_id=overrides.get("search_space_id", _SEARCH_SPACE_ID),
user_id=overrides.get("user_id", _USER_ID),
start_date=overrides.get("start_date", "2025-01-01"),
end_date=overrides.get("end_date", "2025-12-31"),
update_last_indexed=overrides.get("update_last_indexed", True),
on_heartbeat_callback=overrides.get("on_heartbeat_callback"),
)
# ---------------------------------------------------------------------------
# Slice 2: Full pipeline wiring
# ---------------------------------------------------------------------------
async def test_one_issue_calls_pipeline_and_returns_indexed_count(linear_mocks):
"""One valid issue is passed to the pipeline and the indexed count is returned."""
indexed, skipped, warning = await _run_index(linear_mocks)
assert indexed == 1
assert skipped == 0
assert warning is None
linear_mocks["batch_mock"].assert_called_once()
call_args = linear_mocks["batch_mock"].call_args
connector_docs = call_args[0][0]
assert len(connector_docs) == 1
assert connector_docs[0].document_type == DocumentType.LINEAR_CONNECTOR
async def test_pipeline_called_with_max_concurrency_3(linear_mocks):
"""index_batch_parallel is called with max_concurrency=3."""
await _run_index(linear_mocks)
call_kwargs = linear_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("max_concurrency") == 3
async def test_migrate_legacy_docs_called_before_indexing(linear_mocks):
"""migrate_legacy_docs is called on the pipeline before index_batch_parallel."""
await _run_index(linear_mocks)
linear_mocks["pipeline_mock"].migrate_legacy_docs.assert_called_once()
# ---------------------------------------------------------------------------
# Slice 3: Issue skipping (missing ID / title)
# ---------------------------------------------------------------------------
async def test_issues_with_missing_id_are_skipped(linear_mocks):
"""Issues without id are skipped and not passed to the pipeline."""
issues = [
_make_issue(issue_id="valid-1", identifier="ENG-1", title="Valid"),
{"id": "", "identifier": "ENG-2", "title": "No ID"},
]
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
indexed, skipped, _ = await _run_index(linear_mocks)
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert connector_docs[0].unique_id == "valid-1"
assert skipped == 1
async def test_issues_with_missing_title_are_skipped(linear_mocks):
"""Issues without title are skipped."""
issues = [
_make_issue(issue_id="valid-1", identifier="ENG-1", title="Valid"),
{"id": "id-2", "identifier": "ENG-2", "title": ""},
]
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
indexed, skipped, _ = await _run_index(linear_mocks)
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 4: Duplicate content skipping
# ---------------------------------------------------------------------------
async def test_duplicate_content_issues_are_skipped(linear_mocks, monkeypatch):
"""Issues whose content hash matches an existing document are skipped."""
issues = [
_make_issue(issue_id="new-1", identifier="ENG-1", title="New"),
_make_issue(issue_id="dup-1", identifier="ENG-2", title="Dup"),
]
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
call_count = 0
async def _check_dup(session, content_hash):
nonlocal call_count
call_count += 1
if call_count == 2:
dup = MagicMock()
dup.id = 99
dup.document_type = "OTHER"
return dup
return None
monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)
indexed, skipped, _ = await _run_index(linear_mocks)
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 5: Heartbeat callback forwarding
# ---------------------------------------------------------------------------
async def test_heartbeat_callback_forwarded_to_pipeline(linear_mocks):
"""on_heartbeat_callback is passed through to index_batch_parallel."""
heartbeat_cb = AsyncMock()
await _run_index(linear_mocks, on_heartbeat_callback=heartbeat_cb)
call_kwargs = linear_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("on_heartbeat") is heartbeat_cb
# ---------------------------------------------------------------------------
# Slice 6: Empty issues early return
# ---------------------------------------------------------------------------
async def test_empty_issues_returns_zero_tuple(linear_mocks):
"""When no issues are found, returns (0, 0, None) and pipeline is not called."""
linear_mocks["linear_client"].get_issues_by_date_range.return_value = ([], None)
indexed, skipped, warning = await _run_index(linear_mocks)
assert indexed == 0
assert skipped == 0
assert warning is None
linear_mocks["batch_mock"].assert_not_called()
async def test_failed_docs_warning_in_result(linear_mocks):
"""When documents fail indexing, the warning includes the count."""
linear_mocks["batch_mock"].return_value = ([], 0, 2)
_, _, warning = await _run_index(linear_mocks)
assert warning is not None
assert "2 failed" in warning

View file

@ -0,0 +1,345 @@
"""Tests for Notion indexer migrated to the unified parallel pipeline."""
from unittest.mock import AsyncMock, MagicMock
import pytest
import app.tasks.connector_indexers.notion_indexer as _mod
from app.db import DocumentType
from app.tasks.connector_indexers.notion_indexer import (
_build_connector_doc,
index_notion_pages,
)
pytestmark = pytest.mark.unit
_USER_ID = "00000000-0000-0000-0000-000000000001"
_CONNECTOR_ID = 42
_SEARCH_SPACE_ID = 1
def _make_page(page_id: str = "page-1", title: str = "Test Page", content=None):
if content is None:
content = [{"type": "paragraph", "content": "Hello world", "children": []}]
return {"page_id": page_id, "title": title, "content": content}
# ---------------------------------------------------------------------------
# Slice 1: _build_connector_doc tracer bullet
# ---------------------------------------------------------------------------
async def test_build_connector_doc_produces_correct_fields():
"""Tracer bullet: a single Notion page produces a ConnectorDocument with correct fields."""
page = _make_page(page_id="abc-123", title="My Notion Page")
markdown = "# My Notion Page\n\nHello world"
doc = _build_connector_doc(
page,
markdown,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert doc.title == "My Notion Page"
assert doc.unique_id == "abc-123"
assert doc.document_type == DocumentType.NOTION_CONNECTOR
assert doc.source_markdown == markdown
assert doc.search_space_id == _SEARCH_SPACE_ID
assert doc.connector_id == _CONNECTOR_ID
assert doc.created_by_id == _USER_ID
assert doc.should_summarize is True
assert doc.metadata["page_title"] == "My Notion Page"
assert doc.metadata["page_id"] == "abc-123"
assert doc.metadata["connector_id"] == _CONNECTOR_ID
assert doc.metadata["document_type"] == "Notion Page"
assert doc.metadata["connector_type"] == "Notion"
assert doc.fallback_summary is not None
assert "My Notion Page" in doc.fallback_summary
assert markdown in doc.fallback_summary
async def test_build_connector_doc_summary_disabled():
"""When enable_summary is False, should_summarize is False."""
doc = _build_connector_doc(
_make_page(),
"# content",
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=False,
)
assert doc.should_summarize is False
# ---------------------------------------------------------------------------
# Shared fixtures for Slices 2-7 (full index_notion_pages tests)
# ---------------------------------------------------------------------------
def _mock_connector(enable_summary: bool = True):
c = MagicMock()
c.config = {"access_token": "tok"}
c.enable_summary = enable_summary
c.last_indexed_at = None
return c
def _mock_notion_client(pages=None, skipped_count=0, legacy_token=False):
client = MagicMock()
client.get_all_pages = AsyncMock(return_value=pages if pages is not None else [])
client.get_skipped_content_count = MagicMock(return_value=skipped_count)
client.is_using_legacy_token = MagicMock(return_value=legacy_token)
client.close = AsyncMock()
client.set_retry_callback = MagicMock()
return client
@pytest.fixture
def notion_mocks(monkeypatch):
"""Wire up all external boundary mocks for index_notion_pages."""
mock_session = AsyncMock()
mock_session.no_autoflush = MagicMock()
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
)
notion_client = _mock_notion_client(pages=[_make_page()])
monkeypatch.setattr(
_mod, "NotionHistoryConnector", MagicMock(return_value=notion_client),
)
monkeypatch.setattr(
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod, "update_connector_last_indexed", AsyncMock(),
)
monkeypatch.setattr(
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
monkeypatch.setattr(
_mod, "process_blocks", MagicMock(return_value="Converted markdown content"),
)
mock_task_logger = MagicMock()
mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock())
mock_task_logger.log_task_progress = AsyncMock()
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
)
return {
"session": mock_session,
"connector": mock_connector,
"notion_client": notion_client,
"task_logger": mock_task_logger,
"pipeline_mock": pipeline_mock,
"batch_mock": batch_mock,
}
async def _run_index(mocks, **overrides):
return await index_notion_pages(
session=mocks["session"],
connector_id=overrides.get("connector_id", _CONNECTOR_ID),
search_space_id=overrides.get("search_space_id", _SEARCH_SPACE_ID),
user_id=overrides.get("user_id", _USER_ID),
start_date=overrides.get("start_date", "2025-01-01"),
end_date=overrides.get("end_date", "2025-12-31"),
update_last_indexed=overrides.get("update_last_indexed", True),
on_retry_callback=overrides.get("on_retry_callback"),
on_heartbeat_callback=overrides.get("on_heartbeat_callback"),
)
# ---------------------------------------------------------------------------
# Slice 2: Full pipeline wiring
# ---------------------------------------------------------------------------
async def test_one_page_calls_pipeline_and_returns_indexed_count(notion_mocks):
"""One valid page is passed to the pipeline and the indexed count is returned."""
indexed, skipped, warning = await _run_index(notion_mocks)
assert indexed == 1
assert skipped == 0
assert warning is None
notion_mocks["batch_mock"].assert_called_once()
call_args = notion_mocks["batch_mock"].call_args
connector_docs = call_args[0][0]
assert len(connector_docs) == 1
assert connector_docs[0].document_type == DocumentType.NOTION_CONNECTOR
async def test_pipeline_called_with_max_concurrency_3(notion_mocks):
"""index_batch_parallel is called with max_concurrency=3."""
await _run_index(notion_mocks)
call_kwargs = notion_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("max_concurrency") == 3
async def test_migrate_legacy_docs_called_before_indexing(notion_mocks):
"""migrate_legacy_docs is called on the pipeline before index_batch_parallel."""
await _run_index(notion_mocks)
notion_mocks["pipeline_mock"].migrate_legacy_docs.assert_called_once()
# ---------------------------------------------------------------------------
# Slice 3: Page skipping (no content / missing ID)
# ---------------------------------------------------------------------------
async def test_pages_with_missing_id_are_skipped(notion_mocks, monkeypatch):
"""Pages without page_id are skipped and not passed to the pipeline."""
pages = [
_make_page(page_id="valid-1"),
{"title": "No ID page", "content": [{"type": "paragraph", "content": "text", "children": []}]},
]
notion_mocks["notion_client"].get_all_pages.return_value = pages
_, skipped, _ = await _run_index(notion_mocks)
connector_docs = notion_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert connector_docs[0].unique_id == "valid-1"
assert skipped == 1
async def test_pages_with_no_content_are_skipped(notion_mocks, monkeypatch):
"""Pages with empty content are skipped."""
pages = [
_make_page(page_id="valid-1"),
_make_page(page_id="empty-1", content=[]),
]
notion_mocks["notion_client"].get_all_pages.return_value = pages
_, skipped, _ = await _run_index(notion_mocks)
connector_docs = notion_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 4: Duplicate content skipping
# ---------------------------------------------------------------------------
async def test_duplicate_content_pages_are_skipped(notion_mocks, monkeypatch):
"""Pages whose content hash matches an existing document are skipped."""
pages = [
_make_page(page_id="new-1"),
_make_page(page_id="dup-1"),
]
notion_mocks["notion_client"].get_all_pages.return_value = pages
call_count = 0
async def _check_dup(session, content_hash):
nonlocal call_count
call_count += 1
if call_count == 2:
dup = MagicMock()
dup.id = 99
dup.document_type = "OTHER"
return dup
return None
monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)
_, skipped, _ = await _run_index(notion_mocks)
connector_docs = notion_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 5: Heartbeat callback forwarding
# ---------------------------------------------------------------------------
async def test_heartbeat_callback_forwarded_to_pipeline(notion_mocks):
"""on_heartbeat_callback is passed through to index_batch_parallel."""
heartbeat_cb = AsyncMock()
await _run_index(notion_mocks, on_heartbeat_callback=heartbeat_cb)
call_kwargs = notion_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("on_heartbeat") is heartbeat_cb
# ---------------------------------------------------------------------------
# Slice 6: Notion-specific warning messages
# ---------------------------------------------------------------------------
async def test_skipped_ai_content_warning_in_result(notion_mocks):
"""When Notion AI content was skipped, the warning message includes it."""
notion_mocks["notion_client"].get_skipped_content_count.return_value = 3
_, _, warning = await _run_index(notion_mocks)
assert warning is not None
assert "API limitation" in warning
async def test_legacy_token_warning_in_result(notion_mocks):
"""When using legacy token, the warning message includes a notice."""
notion_mocks["notion_client"].is_using_legacy_token.return_value = True
_, _, warning = await _run_index(notion_mocks)
assert warning is not None
assert "legacy token" in warning.lower()
async def test_failed_docs_warning_in_result(notion_mocks):
"""When documents fail indexing, the warning includes the count."""
notion_mocks["batch_mock"].return_value = ([], 0, 2)
_, _, warning = await _run_index(notion_mocks)
assert warning is not None
assert "2 failed" in warning
# ---------------------------------------------------------------------------
# Slice 7: Empty pages early return
# ---------------------------------------------------------------------------
async def test_empty_pages_returns_zero_tuple(notion_mocks):
"""When no pages are found, returns (0, 0, None) and updates last_indexed."""
notion_mocks["notion_client"].get_all_pages.return_value = []
indexed, skipped, warning = await _run_index(notion_mocks)
assert indexed == 0
assert skipped == 0
assert warning is None
notion_mocks["batch_mock"].assert_not_called()

View file

@ -3,6 +3,7 @@ import pytest
from app.db import DocumentType
from app.indexing_pipeline.document_hashing import (
compute_content_hash,
compute_identifier_hash,
compute_unique_identifier_hash,
)
@ -61,3 +62,23 @@ def test_different_content_produces_different_content_hash(make_connector_docume
doc_a = make_connector_document(source_markdown="Original content")
doc_b = make_connector_document(source_markdown="Updated content")
assert compute_content_hash(doc_a) != compute_content_hash(doc_b)
def test_compute_identifier_hash_matches_connector_doc_hash(make_connector_document):
"""Raw-args hash equals ConnectorDocument hash for equivalent inputs."""
doc = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-123",
search_space_id=5,
)
raw_hash = compute_identifier_hash("GOOGLE_GMAIL_CONNECTOR", "msg-123", 5)
assert raw_hash == compute_unique_identifier_hash(doc)
def test_compute_identifier_hash_differs_for_different_inputs():
"""Different arguments produce different hashes."""
h1 = compute_identifier_hash("GOOGLE_DRIVE_FILE", "file-1", 1)
h2 = compute_identifier_hash("GOOGLE_DRIVE_FILE", "file-2", 1)
h3 = compute_identifier_hash("GOOGLE_DRIVE_FILE", "file-1", 2)
h4 = compute_identifier_hash("COMPOSIO_GOOGLE_DRIVE_CONNECTOR", "file-1", 1)
assert len({h1, h2, h3, h4}) == 4

View file

@ -0,0 +1,82 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.db import Document, DocumentType
from app.indexing_pipeline.document_hashing import compute_unique_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
pytestmark = pytest.mark.unit
@pytest.fixture
def mock_session():
return AsyncMock()
@pytest.fixture
def pipeline(mock_session):
return IndexingPipelineService(mock_session)
async def test_calls_prepare_then_index_per_document(
pipeline, make_connector_document
):
"""index_batch calls prepare_for_indexing, then index() for each returned doc."""
doc1 = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-1",
search_space_id=1,
)
doc2 = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-2",
search_space_id=1,
)
orm1 = MagicMock(spec=Document)
orm1.unique_identifier_hash = compute_unique_identifier_hash(doc1)
orm2 = MagicMock(spec=Document)
orm2.unique_identifier_hash = compute_unique_identifier_hash(doc2)
mock_llm = MagicMock()
pipeline.prepare_for_indexing = AsyncMock(return_value=[orm1, orm2])
pipeline.index = AsyncMock(side_effect=lambda doc, cdoc, llm: doc)
results = await pipeline.index_batch([doc1, doc2], mock_llm)
pipeline.prepare_for_indexing.assert_awaited_once_with([doc1, doc2])
assert pipeline.index.await_count == 2
assert results == [orm1, orm2]
async def test_empty_input_returns_empty(pipeline):
"""Empty connector_docs list returns empty result."""
pipeline.prepare_for_indexing = AsyncMock(return_value=[])
results = await pipeline.index_batch([], MagicMock())
assert results == []
async def test_skips_document_without_matching_connector_doc(
pipeline, make_connector_document
):
"""If prepare returns a doc whose hash has no matching ConnectorDocument, it's skipped."""
doc1 = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-1",
search_space_id=1,
)
orphan_orm = MagicMock(spec=Document)
orphan_orm.unique_identifier_hash = "nonexistent-hash"
pipeline.prepare_for_indexing = AsyncMock(return_value=[orphan_orm])
pipeline.index = AsyncMock()
results = await pipeline.index_batch([doc1], MagicMock())
pipeline.index.assert_not_awaited()
assert results == []

View file

@ -0,0 +1,186 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.config import config as app_config
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.document_hashing import compute_unique_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.unit
@pytest.fixture
def mock_session():
session = AsyncMock()
session.refresh = AsyncMock()
return session
@pytest.fixture
def pipeline(mock_session):
return IndexingPipelineService(mock_session)
def _make_orm_doc(connector_doc, doc_id):
"""Create a MagicMock Document bound to a ConnectorDocument's hash."""
doc = MagicMock(spec=Document)
doc.id = doc_id
doc.unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
doc.status = DocumentStatus.pending()
return doc
async def test_index_calls_embed_and_chunk_via_to_thread(
pipeline, make_connector_document, monkeypatch
):
"""index() runs embed_texts and chunk_text via asyncio.to_thread, not blocking the loop."""
to_thread_calls = []
original_to_thread = asyncio.to_thread
async def tracking_to_thread(func, *args, **kwargs):
to_thread_calls.append(func.__name__)
return await original_to_thread(func, *args, **kwargs)
monkeypatch.setattr(asyncio, "to_thread", tracking_to_thread)
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.summarize_document",
AsyncMock(return_value="Summary."),
)
mock_chunk = MagicMock(return_value=["chunk1"])
mock_chunk.__name__ = "chunk_text"
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.chunk_text",
mock_chunk,
)
mock_embed = MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts])
mock_embed.__name__ = "embed_texts"
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.embed_texts",
mock_embed,
)
connector_doc = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-1",
search_space_id=1,
)
document = MagicMock(spec=Document)
document.id = 1
document.status = DocumentStatus.pending()
await pipeline.index(document, connector_doc, llm=MagicMock())
assert "chunk_text" in to_thread_calls
assert "embed_texts" in to_thread_calls
def _mock_session_factory(orm_docs_by_id):
"""Replace get_celery_session_maker with a two-level callable.
get_celery_session_maker() -> session_maker
session_maker() -> async context manager yielding a mock session
"""
def _get_maker():
def _make_session():
session = MagicMock()
session.get = AsyncMock(
side_effect=lambda model, doc_id: orm_docs_by_id.get(doc_id)
)
ctx = MagicMock()
ctx.__aenter__ = AsyncMock(return_value=session)
ctx.__aexit__ = AsyncMock(return_value=False)
return ctx
return _make_session
return _get_maker
async def test_batch_parallel_indexes_all_documents(
pipeline, make_connector_document, monkeypatch
):
"""index_batch_parallel indexes all documents and returns correct counts."""
docs = [
make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id=f"msg-{i}",
search_space_id=1,
)
for i in range(3)
]
orm_docs = [_make_orm_doc(cd, doc_id=i + 1) for i, cd in enumerate(docs)]
pipeline.prepare_for_indexing = AsyncMock(return_value=orm_docs)
orm_by_id = {d.id: d for d in orm_docs}
monkeypatch.setattr(
"app.tasks.celery_tasks.get_celery_session_maker",
_mock_session_factory(orm_by_id),
)
index_calls = []
async def fake_index(self, document, connector_doc, llm):
index_calls.append(document.id)
document.status = DocumentStatus.ready()
return document
monkeypatch.setattr(IndexingPipelineService, "index", fake_index)
async def mock_get_llm(session):
return MagicMock()
_, indexed, failed = await pipeline.index_batch_parallel(
docs, mock_get_llm, max_concurrency=2
)
assert indexed == 3
assert failed == 0
assert sorted(index_calls) == [1, 2, 3]
async def test_batch_parallel_one_failure_does_not_affect_others(
pipeline, make_connector_document, monkeypatch
):
"""One document failure doesn't prevent other documents from being indexed."""
docs = [
make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id=f"msg-{i}",
search_space_id=1,
)
for i in range(3)
]
orm_docs = [_make_orm_doc(cd, doc_id=i + 1) for i, cd in enumerate(docs)]
pipeline.prepare_for_indexing = AsyncMock(return_value=orm_docs)
orm_by_id = {d.id: d for d in orm_docs}
monkeypatch.setattr(
"app.tasks.celery_tasks.get_celery_session_maker",
_mock_session_factory(orm_by_id),
)
async def failing_index(self, document, connector_doc, llm):
if document.id == 2:
raise RuntimeError("LLM exploded")
document.status = DocumentStatus.ready()
return document
monkeypatch.setattr(IndexingPipelineService, "index", failing_index)
async def mock_get_llm(session):
return MagicMock()
_, indexed, failed = await pipeline.index_batch_parallel(
docs, mock_get_llm, max_concurrency=4
)
assert indexed == 2
assert failed == 1

View file

@ -0,0 +1,127 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.db import Document, DocumentType
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
pytestmark = pytest.mark.unit
@pytest.fixture
def mock_session():
session = AsyncMock()
return session
@pytest.fixture
def pipeline(mock_session):
return IndexingPipelineService(mock_session)
def _make_execute_side_effect(doc_by_hash: dict):
"""Return a side_effect for session.execute that resolves documents by hash."""
async def _side_effect(stmt):
result = MagicMock()
for h, doc in doc_by_hash.items():
if h in str(stmt.compile(compile_kwargs={"literal_binds": True})):
result.scalars.return_value.first.return_value = doc
return result
result.scalars.return_value.first.return_value = None
return result
return _side_effect
async def test_updates_hash_and_type_for_legacy_document(
pipeline, mock_session, make_connector_document
):
"""Legacy Composio document gets unique_identifier_hash and document_type updated."""
doc = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-abc",
search_space_id=1,
)
legacy_hash = compute_identifier_hash("COMPOSIO_GMAIL_CONNECTOR", "msg-abc", 1)
native_hash = compute_identifier_hash("GOOGLE_GMAIL_CONNECTOR", "msg-abc", 1)
existing = MagicMock(spec=Document)
existing.unique_identifier_hash = legacy_hash
existing.document_type = DocumentType.COMPOSIO_GMAIL_CONNECTOR
result_mock = MagicMock()
result_mock.scalars.return_value.first.return_value = existing
mock_session.execute = AsyncMock(return_value=result_mock)
await pipeline.migrate_legacy_docs([doc])
assert existing.unique_identifier_hash == native_hash
assert existing.document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR
mock_session.commit.assert_awaited_once()
async def test_noop_when_no_legacy_document_exists(
pipeline, mock_session, make_connector_document
):
"""No updates when no legacy Composio document is found in DB."""
doc = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-xyz",
search_space_id=1,
)
result_mock = MagicMock()
result_mock.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=result_mock)
await pipeline.migrate_legacy_docs([doc])
mock_session.commit.assert_awaited_once()
async def test_skips_non_google_doc_types(
pipeline, mock_session, make_connector_document
):
"""Non-Google doc types have no legacy mapping and trigger no DB query."""
doc = make_connector_document(
document_type=DocumentType.SLACK_CONNECTOR,
unique_id="slack-123",
search_space_id=1,
)
await pipeline.migrate_legacy_docs([doc])
mock_session.execute.assert_not_awaited()
mock_session.commit.assert_awaited_once()
async def test_handles_all_three_google_types(
pipeline, mock_session, make_connector_document
):
"""Each native Google type correctly maps to its Composio legacy type."""
mappings = [
(DocumentType.GOOGLE_GMAIL_CONNECTOR, "COMPOSIO_GMAIL_CONNECTOR"),
(DocumentType.GOOGLE_CALENDAR_CONNECTOR, "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"),
(DocumentType.GOOGLE_DRIVE_FILE, "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"),
]
for native_type, expected_legacy in mappings:
doc = make_connector_document(
document_type=native_type,
unique_id="id-1",
search_space_id=1,
)
existing = MagicMock(spec=Document)
existing.document_type = DocumentType(expected_legacy)
result_mock = MagicMock()
result_mock.scalars.return_value.first.return_value = existing
mock_session.execute = AsyncMock(return_value=result_mock)
mock_session.commit = AsyncMock()
await pipeline.migrate_legacy_docs([doc])
assert existing.document_type == native_type

8862
surfsense_backend/uv.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,8 +1,14 @@
import { loader } from "fumadocs-core/source";
import type { Metadata } from "next";
import { changelog } from "@/.source/server";
import { formatDate } from "@/lib/utils";
import { getMDXComponents } from "@/mdx-components";
export const metadata: Metadata = {
title: "Changelog | SurfSense",
description: "See what's new in SurfSense.",
};
const source = loader({
baseUrl: "/changelog",
source: changelog.toFumadocsSource(),

View file

@ -1,6 +1,11 @@
import React from "react";
import type { Metadata } from "next";
import { ContactFormGridWithDetails } from "@/components/contact/contact-form";
export const metadata: Metadata = {
title: "Contact | SurfSense",
description: "Get in touch with the SurfSense team.",
};
const page = () => {
return (
<div>

View file

@ -1,6 +1,11 @@
import React from "react";
import type { Metadata } from "next";
import PricingBasic from "@/components/pricing/pricing-section";
export const metadata: Metadata = {
title: "Pricing | SurfSense",
description: "Explore SurfSense plans and pricing options.",
};
const page = () => {
return (
<div>

View file

@ -473,14 +473,14 @@ export function DocumentsTableShell({
}, [deletableSelectedIds, bulkDeleteDocuments, deleteDocument]);
const bulkDeleteBar = hasDeletableSelection ? (
<div className="flex items-center justify-center py-1.5 border-b border-border/50 bg-destructive/5 shrink-0 animate-in fade-in slide-in-from-top-1 duration-150">
<div className="absolute inset-x-0 top-0 z-10 flex items-center justify-center py-1 pointer-events-none animate-in fade-in duration-150">
<button
type="button"
onClick={() => setBulkDeleteConfirmOpen(true)}
className="flex items-center gap-1.5 px-3 py-1 rounded-md bg-destructive text-destructive-foreground shadow-sm text-xs font-medium hover:bg-destructive/90 transition-colors"
className="pointer-events-auto flex items-center gap-1.5 px-3 py-1 rounded-md bg-destructive text-destructive-foreground shadow-lg text-xs font-medium hover:bg-destructive/90 transition-colors"
>
<Trash2 size={12} />
Delete ({deletableSelectedIds.length} selected)
Delete {deletableSelectedIds.length} {deletableSelectedIds.length === 1 ? "item" : "items"}
</button>
</div>
) : null;
@ -526,7 +526,6 @@ export function DocumentsTableShell({
</TableRow>
</TableHeader>
</Table>
{bulkDeleteBar}
{loading ? (
<div className="flex-1 overflow-auto">
<Table className="table-fixed w-full">
@ -594,7 +593,8 @@ export function DocumentsTableShell({
)}
</div>
) : (
<div ref={desktopScrollRef} className="flex-1 overflow-auto">
<div ref={desktopScrollRef} className="flex-1 overflow-auto relative">
{bulkDeleteBar}
<Table className="table-fixed w-full">
<TableBody>
{sorted.map((doc) => {
@ -788,9 +788,6 @@ export function DocumentsTableShell({
)}
</div>
{/* Mobile bulk delete bar */}
<div className="md:hidden">{bulkDeleteBar}</div>
{/* Mobile Card View */}
{loading ? (
<div className="md:hidden divide-y divide-border/50 flex-1 overflow-auto">
@ -846,8 +843,9 @@ export function DocumentsTableShell({
) : (
<div
ref={mobileScrollRef}
className="md:hidden divide-y divide-border/50 flex-1 overflow-auto"
className="md:hidden divide-y divide-border/50 flex-1 overflow-auto relative"
>
{bulkDeleteBar}
{sorted.map((doc) => {
const isMentioned = mentionedDocIds?.has(doc.id) ?? false;
const statusState = doc.status?.state ?? "ready";

View file

@ -595,6 +595,7 @@ function CreateInviteDialog({
});
} catch (error) {
console.error("Failed to create invite:", error);
toast.error("Failed to create invite. Please try again.");
} finally {
setCreating(false);
}

View file

@ -29,9 +29,6 @@ export const createDocumentMutationAtom = atomWithMutation((get) => {
queryClient.invalidateQueries({
queryKey: cacheKeys.documents.globalQueryParams(documentsQueryParams),
});
queryClient.invalidateQueries({
queryKey: cacheKeys.documents.typeCounts(searchSpaceId ?? undefined),
});
},
};
});
@ -75,9 +72,6 @@ export const updateDocumentMutationAtom = atomWithMutation((get) => {
queryClient.invalidateQueries({
queryKey: cacheKeys.documents.document(String(request.id)),
});
queryClient.invalidateQueries({
queryKey: cacheKeys.documents.typeCounts(searchSpaceId ?? undefined),
});
},
};
});
@ -109,9 +103,6 @@ export const deleteDocumentMutationAtom = atomWithMutation((get) => {
queryClient.invalidateQueries({
queryKey: cacheKeys.documents.document(String(request.id)),
});
queryClient.invalidateQueries({
queryKey: cacheKeys.documents.typeCounts(searchSpaceId ?? undefined),
});
},
};
});

View file

@ -1,38 +0,0 @@
import { atomWithQuery } from "jotai-tanstack-query";
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
import type { SearchDocumentsRequest } from "@/contracts/types/document.types";
import { documentsApiService } from "@/lib/apis/documents-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys";
import { globalDocumentsQueryParamsAtom } from "./ui.atoms";
export const documentsAtom = atomWithQuery((get) => {
const searchSpaceId = get(activeSearchSpaceIdAtom);
const queryParams = get(globalDocumentsQueryParamsAtom);
return {
queryKey: cacheKeys.documents.globalQueryParams(queryParams),
enabled: !!searchSpaceId,
queryFn: async () => {
return documentsApiService.getDocuments({
queryParams: queryParams,
});
},
};
});
export const documentTypeCountsAtom = atomWithQuery((get) => {
const searchSpaceId = get(activeSearchSpaceIdAtom);
return {
queryKey: cacheKeys.documents.typeCounts(searchSpaceId ?? undefined),
enabled: !!searchSpaceId,
staleTime: 10 * 60 * 1000, // 10 minutes
queryFn: async () => {
return documentsApiService.getDocumentTypeCounts({
queryParams: {
search_space_id: searchSpaceId ?? undefined,
},
});
},
};
});

View file

@ -4,7 +4,7 @@ import { useAtomValue, useSetAtom } from "jotai";
import { AlertTriangle, Cable, Settings } from "lucide-react";
import { forwardRef, useEffect, useImperativeHandle, useMemo, useState } from "react";
import { createPortal } from "react-dom";
import { documentTypeCountsAtom } from "@/atoms/documents/document-query.atoms";
import { useZeroDocumentTypeCounts } from "@/hooks/use-zero-document-type-counts";
import { statusInboxItemsAtom } from "@/atoms/inbox/status-inbox.atom";
import {
globalNewLLMConfigsAtom,
@ -72,9 +72,9 @@ export const ConnectorIndicator = forwardRef<ConnectorIndicatorHandle, Connector
const llmConfigLoading = preferencesLoading || globalConfigsLoading;
// Fetch document type counts via the lightweight /type-counts endpoint (cached 10 min)
const { data: documentTypeCounts, isFetching: documentTypesLoading } =
useAtomValue(documentTypeCountsAtom);
// Real-time document type counts via Zero (updates instantly as docs are indexed)
const documentTypeCounts = useZeroDocumentTypeCounts(searchSpaceId);
const documentTypesLoading = documentTypeCounts === undefined;
// Read status inbox items from shared atom (populated by LayoutDataProvider)
// instead of creating a duplicate useInbox("status") hook.

View file

@ -867,6 +867,9 @@ export const useConnectorDialog = () => {
setIsOpen(false);
setIsFromOAuth(false);
setIndexingConfig(null);
setIndexingConnector(null);
setIndexingConnectorConfig(null);
refreshConnectors();
queryClient.invalidateQueries({
@ -898,6 +901,9 @@ export const useConnectorDialog = () => {
const handleSkipIndexing = useCallback(() => {
setIsOpen(false);
setIsFromOAuth(false);
setIndexingConfig(null);
setIndexingConnector(null);
setIndexingConnectorConfig(null);
}, [setIsOpen]);
// Handle starting edit mode

View file

@ -1,6 +1,5 @@
"use client";
import { FolderPlus } from "lucide-react";
import { useCallback, useEffect, useRef, useState } from "react";
import { Button } from "@/components/ui/button";
import {
@ -52,22 +51,25 @@ export function CreateFolderDialog({
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="max-w-sm">
<DialogHeader>
<DialogTitle className="flex items-center gap-2">
<FolderPlus className="size-5 text-muted-foreground" />
<DialogContent className="select-none max-w-[90vw] sm:max-w-sm p-4 sm:p-5 data-[state=open]:animate-none data-[state=closed]:animate-none">
<DialogHeader className="space-y-2 pb-2">
<div className="flex items-center gap-2 sm:gap-3">
<div className="flex-1 min-w-0">
<DialogTitle className="text-base sm:text-lg">
{isSubfolder ? "New subfolder" : "New folder"}
</DialogTitle>
<DialogDescription>
<DialogDescription className="text-xs sm:text-sm mt-0.5">
{isSubfolder
? `Create a new folder inside "${parentFolderName}".`
: "Create a new folder at the root level."}
</DialogDescription>
</div>
</div>
</DialogHeader>
<form onSubmit={handleSubmit} className="flex flex-col gap-4">
<form onSubmit={handleSubmit} className="flex flex-col gap-3 sm:gap-4">
<div className="flex flex-col gap-2">
<Label htmlFor="folder-name">Folder name</Label>
<Label htmlFor="folder-name" className="text-sm">Folder name</Label>
<Input
ref={inputRef}
id="folder-name"
@ -76,14 +78,24 @@ export function CreateFolderDialog({
onChange={(e) => setName(e.target.value)}
maxLength={255}
autoComplete="off"
className="text-sm h-9 sm:h-10"
/>
</div>
<DialogFooter>
<Button type="button" variant="outline" onClick={() => onOpenChange(false)}>
<DialogFooter className="flex-row justify-end gap-2 pt-2 sm:pt-3">
<Button
type="button"
variant="secondary"
onClick={() => onOpenChange(false)}
className="h-8 sm:h-9 text-xs sm:text-sm"
>
Cancel
</Button>
<Button type="submit" disabled={!name.trim()}>
<Button
type="submit"
disabled={!name.trim()}
className="h-8 sm:h-9 text-xs sm:text-sm"
>
Create
</Button>
</DialogFooter>

View file

@ -1,25 +1,32 @@
"use client";
import { Eye, MoreHorizontal, Move, Pencil, Trash2 } from "lucide-react";
import React, { useCallback } from "react";
import { AlertCircle, Clock, Download, Eye, MoreHorizontal, Move, PenLine, Trash2 } from "lucide-react";
import React, { useCallback, useRef, useState } from "react";
import { useDrag } from "react-dnd";
import { getDocumentTypeIcon } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentTypeIcon";
import { ExportContextItems, ExportDropdownItems } from "@/components/shared/ExportMenuItems";
import { Button } from "@/components/ui/button";
import { Checkbox } from "@/components/ui/checkbox";
import {
ContextMenu,
ContextMenuContent,
ContextMenuItem,
ContextMenuSeparator,
ContextMenuSub,
ContextMenuSubContent,
ContextMenuSubTrigger,
ContextMenuTrigger,
} from "@/components/ui/context-menu";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuSeparator,
DropdownMenuSub,
DropdownMenuSubContent,
DropdownMenuSubTrigger,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { Spinner } from "@/components/ui/spinner";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import type { DocumentTypeEnum } from "@/contracts/types/document.types";
import { cn } from "@/lib/utils";
import { DND_TYPES } from "./FolderNode";
@ -41,6 +48,9 @@ interface DocumentNodeProps {
onEdit: (doc: DocumentNodeDoc) => void;
onDelete: (doc: DocumentNodeDoc) => void;
onMove: (doc: DocumentNodeDoc) => void;
onExport?: (doc: DocumentNodeDoc, format: string) => void;
contextMenuOpen?: boolean;
onContextMenuOpenChange?: (open: boolean) => void;
}
export const DocumentNode = React.memo(function DocumentNode({
@ -52,6 +62,9 @@ export const DocumentNode = React.memo(function DocumentNode({
onEdit,
onDelete,
onMove,
onExport,
contextMenuOpen,
onContextMenuOpenChange,
}: DocumentNodeProps) {
const statusState = doc.status?.state ?? "ready";
const isSelectable = statusState !== "pending" && statusState !== "processing";
@ -74,48 +87,90 @@ export const DocumentNode = React.memo(function DocumentNode({
);
const isProcessing = statusState === "pending" || statusState === "processing";
const [dropdownOpen, setDropdownOpen] = useState(false);
const [exporting, setExporting] = useState<string | null>(null);
const rowRef = useRef<HTMLButtonElement>(null);
const handleExport = useCallback(
(format: string) => {
if (!onExport) return;
setExporting(format);
onExport(doc, format);
setTimeout(() => setExporting(null), 2000);
},
[doc, onExport]
);
const attachRef = useCallback(
(node: HTMLButtonElement | null) => {
(rowRef as React.MutableRefObject<HTMLButtonElement | null>).current = node;
drag(node);
},
[drag]
);
return (
<ContextMenu>
<ContextMenu onOpenChange={onContextMenuOpenChange}>
<ContextMenuTrigger asChild>
{/* biome-ignore lint/a11y/useSemanticElements: div required for drag ref */}
<div
ref={drag}
role="button"
tabIndex={0}
<button
type="button"
ref={attachRef}
className={cn(
"group flex h-8 items-center gap-1.5 rounded-md px-1 text-sm hover:bg-accent/50 cursor-pointer select-none",
"group flex h-8 w-full items-center gap-2.5 rounded-md px-1 text-sm hover:bg-accent/50 cursor-pointer select-none text-left",
isMentioned && "bg-accent/30",
isDragging && "opacity-40"
)}
style={{ paddingLeft: `${depth * 16 + 4}px` }}
onClick={handleCheckChange}
onKeyDown={(e) => {
if (e.key === "Enter" || e.key === " ") {
e.preventDefault();
handleCheckChange();
}
}}
>
{isSelectable ? (
{(() => {
if (statusState === "pending") {
return (
<Tooltip>
<TooltipTrigger asChild>
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
<Clock className="h-3.5 w-3.5 text-muted-foreground/60" />
</span>
</TooltipTrigger>
<TooltipContent side="top">Pending - waiting to be synced</TooltipContent>
</Tooltip>
);
}
if (statusState === "processing") {
return (
<Tooltip>
<TooltipTrigger asChild>
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
<Spinner size="xs" className="text-primary" />
</span>
</TooltipTrigger>
<TooltipContent side="top">Syncing</TooltipContent>
</Tooltip>
);
}
if (statusState === "failed") {
return (
<Tooltip>
<TooltipTrigger asChild>
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
<AlertCircle className="h-3.5 w-3.5 text-destructive" />
</span>
</TooltipTrigger>
<TooltipContent side="top" className="max-w-xs">
{doc.status?.reason || "Processing failed"}
</TooltipContent>
</Tooltip>
);
}
return (
<Checkbox
checked={isMentioned}
onCheckedChange={handleCheckChange}
onClick={(e) => e.stopPropagation()}
className="h-3.5 w-3.5 shrink-0"
/>
) : (
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
<span
className={cn(
"h-2 w-2 rounded-full",
statusState === "processing" && "animate-pulse bg-amber-500",
statusState === "pending" && "bg-muted-foreground/40",
statusState === "failed" && "bg-destructive"
)}
/>
</span>
)}
);
})()}
<span className="flex-1 min-w-0 truncate">{doc.title}</span>
@ -126,25 +181,28 @@ export const DocumentNode = React.memo(function DocumentNode({
)}
</span>
<DropdownMenu>
<DropdownMenu open={dropdownOpen} onOpenChange={setDropdownOpen}>
<DropdownMenuTrigger asChild>
<Button
variant="ghost"
size="icon"
className="h-6 w-6 shrink-0 opacity-0 group-hover:opacity-100 transition-opacity"
className={cn(
"hidden sm:inline-flex h-6 w-6 shrink-0 hover:bg-transparent",
dropdownOpen ? "opacity-100 bg-accent hover:bg-accent" : "opacity-0 group-hover:opacity-100"
)}
onClick={(e) => e.stopPropagation()}
>
<MoreHorizontal className="h-3.5 w-3.5" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end" className="w-44">
<DropdownMenuContent align="end" className="w-40">
<DropdownMenuItem onClick={() => onPreview(doc)}>
<Eye className="mr-2 h-4 w-4" />
Open
</DropdownMenuItem>
{isEditable && (
<DropdownMenuItem onClick={() => onEdit(doc)}>
<Pencil className="mr-2 h-4 w-4" />
<PenLine className="mr-2 h-4 w-4" />
Edit
</DropdownMenuItem>
)}
@ -152,7 +210,17 @@ export const DocumentNode = React.memo(function DocumentNode({
<Move className="mr-2 h-4 w-4" />
Move to...
</DropdownMenuItem>
<DropdownMenuSeparator />
{onExport && (
<DropdownMenuSub>
<DropdownMenuSubTrigger>
<Download className="mr-2 h-4 w-4" />
Export
</DropdownMenuSubTrigger>
<DropdownMenuSubContent className="min-w-[180px]">
<ExportDropdownItems onExport={handleExport} exporting={exporting} />
</DropdownMenuSubContent>
</DropdownMenuSub>
)}
<DropdownMenuItem
className="text-destructive focus:text-destructive"
disabled={isProcessing}
@ -163,17 +231,18 @@ export const DocumentNode = React.memo(function DocumentNode({
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</div>
</button>
</ContextMenuTrigger>
<ContextMenuContent className="w-44">
{contextMenuOpen && (
<ContextMenuContent className="w-40">
<ContextMenuItem onClick={() => onPreview(doc)}>
<Eye className="mr-2 h-4 w-4" />
Open
</ContextMenuItem>
{isEditable && (
<ContextMenuItem onClick={() => onEdit(doc)}>
<Pencil className="mr-2 h-4 w-4" />
<PenLine className="mr-2 h-4 w-4" />
Edit
</ContextMenuItem>
)}
@ -181,7 +250,17 @@ export const DocumentNode = React.memo(function DocumentNode({
<Move className="mr-2 h-4 w-4" />
Move to...
</ContextMenuItem>
<ContextMenuSeparator />
{onExport && (
<ContextMenuSub>
<ContextMenuSubTrigger>
<Download className="mr-2 h-4 w-4" />
Export
</ContextMenuSubTrigger>
<ContextMenuSubContent className="min-w-[180px]">
<ExportContextItems onExport={handleExport} exporting={exporting} />
</ContextMenuSubContent>
</ContextMenuSub>
)}
<ContextMenuItem
className="text-destructive focus:text-destructive"
disabled={isProcessing}
@ -191,6 +270,7 @@ export const DocumentNode = React.memo(function DocumentNode({
Delete
</ContextMenuItem>
</ContextMenuContent>
)}
</ContextMenu>
);
});

View file

@ -8,7 +8,7 @@ import {
FolderPlus,
MoreHorizontal,
Move,
Pencil,
PenLine,
Trash2,
} from "lucide-react";
import React, { useCallback, useEffect, useRef, useState } from "react";
@ -18,14 +18,12 @@ import {
ContextMenu,
ContextMenuContent,
ContextMenuItem,
ContextMenuSeparator,
ContextMenuTrigger,
} from "@/components/ui/context-menu";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuSeparator,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { cn } from "@/lib/utils";
@ -66,6 +64,8 @@ interface FolderNodeProps {
onReorderFolder?: (folderId: number, beforePos: string | null, afterPos: string | null) => void;
siblingPositions?: { before: string | null; after: string | null };
disabledDropIds?: Set<number>;
contextMenuOpen?: boolean;
onContextMenuOpenChange?: (open: boolean) => void;
}
function getDropZone(
@ -99,6 +99,8 @@ export const FolderNode = React.memo(function FolderNode({
onReorderFolder,
siblingPositions,
disabledDropIds,
contextMenuOpen,
onContextMenuOpenChange,
}: FolderNodeProps) {
const [renameValue, setRenameValue] = useState(folder.name);
const inputRef = useRef<HTMLInputElement>(null);
@ -213,7 +215,7 @@ export const FolderNode = React.memo(function FolderNode({
const FolderIcon = isExpanded ? FolderOpen : Folder;
return (
<ContextMenu>
<ContextMenu onOpenChange={onContextMenuOpenChange}>
<ContextMenuTrigger asChild disabled={isRenaming}>
{/* biome-ignore lint/a11y/useSemanticElements: div required for drag/drop refs */}
<div
@ -261,7 +263,8 @@ export const FolderNode = React.memo(function FolderNode({
onBlur={handleRenameSubmit}
onKeyDown={handleRenameKeyDown}
onClick={(e) => e.stopPropagation()}
className="flex-1 min-w-0 rounded border border-primary bg-background px-1 py-0.5 text-sm outline-none"
placeholder="Enter folder name"
className="flex-1 min-w-0 bg-transparent px-1 py-0.5 text-sm outline-none caret-primary placeholder:text-muted-foreground/50"
/>
) : (
<span className="flex-1 min-w-0 truncate">{folder.name}</span>
@ -279,13 +282,13 @@ export const FolderNode = React.memo(function FolderNode({
<Button
variant="ghost"
size="icon"
className="h-6 w-6 shrink-0 opacity-0 group-hover:opacity-100 transition-opacity"
className="hidden sm:inline-flex h-6 w-6 shrink-0 opacity-0 group-hover:opacity-100 transition-opacity"
onClick={(e) => e.stopPropagation()}
>
<MoreHorizontal className="h-3.5 w-3.5" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end" className="w-48">
<DropdownMenuContent align="end" className="w-40">
<DropdownMenuItem
onClick={(e) => {
e.stopPropagation();
@ -301,7 +304,7 @@ export const FolderNode = React.memo(function FolderNode({
startRename();
}}
>
<Pencil className="mr-2 h-4 w-4" />
<PenLine className="mr-2 h-4 w-4" />
Rename
</DropdownMenuItem>
<DropdownMenuItem
@ -313,7 +316,6 @@ export const FolderNode = React.memo(function FolderNode({
<Move className="mr-2 h-4 w-4" />
Move to...
</DropdownMenuItem>
<DropdownMenuSeparator />
<DropdownMenuItem
className="text-destructive focus:text-destructive"
onClick={(e) => {
@ -330,21 +332,20 @@ export const FolderNode = React.memo(function FolderNode({
</div>
</ContextMenuTrigger>
{!isRenaming && (
<ContextMenuContent className="w-48">
{!isRenaming && contextMenuOpen && (
<ContextMenuContent className="w-40">
<ContextMenuItem onClick={() => onCreateSubfolder(folder.id)}>
<FolderPlus className="mr-2 h-4 w-4" />
New subfolder
</ContextMenuItem>
<ContextMenuItem onClick={() => startRename()}>
<Pencil className="mr-2 h-4 w-4" />
<PenLine className="mr-2 h-4 w-4" />
Rename
</ContextMenuItem>
<ContextMenuItem onClick={() => onMove(folder)}>
<Move className="mr-2 h-4 w-4" />
Move to...
</ContextMenuItem>
<ContextMenuSeparator />
<ContextMenuItem
className="text-destructive focus:text-destructive"
onClick={() => onDelete(folder)}

View file

@ -124,10 +124,18 @@ export function FolderPickerDialog({
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="max-w-sm">
<DialogHeader>
<DialogTitle>{title}</DialogTitle>
{description && <DialogDescription>{description}</DialogDescription>}
<DialogContent className="select-none max-w-[90vw] sm:max-w-sm p-4 sm:p-5 data-[state=open]:animate-none data-[state=closed]:animate-none">
<DialogHeader className="space-y-2 pb-2">
<div className="flex items-center gap-2 sm:gap-3">
<div className="flex-1 min-w-0">
<DialogTitle className="text-base sm:text-lg">{title}</DialogTitle>
{description && (
<DialogDescription className="text-xs sm:text-sm mt-0.5">
{description}
</DialogDescription>
)}
</div>
</div>
</DialogHeader>
<div className="max-h-[300px] overflow-y-auto rounded-md border p-1">
@ -147,11 +155,17 @@ export function FolderPickerDialog({
{renderPickerLevel(null, 1)}
</div>
<DialogFooter>
<Button variant="outline" onClick={() => onOpenChange(false)}>
<DialogFooter className="flex-row justify-end gap-2 pt-2 sm:pt-3">
<Button
variant="secondary"
onClick={() => onOpenChange(false)}
className="h-8 sm:h-9 text-xs sm:text-sm"
>
Cancel
</Button>
<Button onClick={handleConfirm}>Move here</Button>
<Button onClick={handleConfirm} className="h-8 sm:h-9 text-xs sm:text-sm">
Move here
</Button>
</DialogFooter>
</DialogContent>
</Dialog>

View file

@ -1,8 +1,8 @@
"use client";
import { useAtom } from "jotai";
import { TreePine } from "lucide-react";
import { useCallback, useMemo } from "react";
import { CirclePlus } from "lucide-react";
import { useCallback, useMemo, useState } from "react";
import { DndProvider } from "react-dnd";
import { HTML5Backend } from "react-dnd-html5-backend";
import { renamingFolderIdAtom } from "@/atoms/documents/folder.atoms";
@ -28,6 +28,7 @@ interface FolderTreeViewProps {
onEditDocument: (doc: DocumentNodeDoc) => void;
onDeleteDocument: (doc: DocumentNodeDoc) => void;
onMoveDocument: (doc: DocumentNodeDoc) => void;
onExportDocument?: (doc: DocumentNodeDoc, format: string) => void;
activeTypes: DocumentTypeEnum[];
onDropIntoFolder?: (
itemType: "folder" | "document",
@ -62,6 +63,7 @@ export function FolderTreeView({
onEditDocument,
onDeleteDocument,
onMoveDocument,
onExportDocument,
activeTypes,
onDropIntoFolder,
onReorderFolder,
@ -80,6 +82,8 @@ export function FolderTreeView({
return counts;
}, [folders, foldersByParent, docsByFolder]);
const [openContextMenuId, setOpenContextMenuId] = useState<string | null>(null);
// Single subscription for rename state — derived boolean passed to each FolderNode
const [renamingFolderId, setRenamingFolderId] = useAtom(renamingFolderIdAtom);
const handleStartRename = useCallback(
@ -157,6 +161,8 @@ export function FolderTreeView({
onDropIntoFolder={onDropIntoFolder}
onReorderFolder={onReorderFolder}
siblingPositions={siblingPositions}
contextMenuOpen={openContextMenuId === `folder-${f.id}`}
onContextMenuOpenChange={(open) => setOpenContextMenuId(open ? `folder-${f.id}` : null)}
/>
);
@ -177,6 +183,9 @@ export function FolderTreeView({
onEdit={onEditDocument}
onDelete={onDeleteDocument}
onMove={onMoveDocument}
onExport={onExportDocument}
contextMenuOpen={openContextMenuId === `doc-${d.id}`}
onContextMenuOpenChange={(open) => setOpenContextMenuId(open ? `doc-${d.id}` : null)}
/>
);
}
@ -189,7 +198,7 @@ export function FolderTreeView({
if (treeNodes.length === 0 && folders.length === 0 && documents.length === 0) {
return (
<div className="flex flex-1 flex-col items-center justify-center gap-3 px-4 py-12 text-muted-foreground">
<TreePine className="h-10 w-10" />
<CirclePlus className="h-10 w-10 rotate-45" />
<p className="text-sm">No documents yet</p>
</div>
);
@ -198,7 +207,7 @@ export function FolderTreeView({
if (treeNodes.length === 0 && activeTypes.length > 0) {
return (
<div className="flex flex-1 flex-col items-center justify-center gap-3 px-4 py-12 text-muted-foreground">
<TreePine className="h-10 w-10" />
<CirclePlus className="h-10 w-10 rotate-45" />
<p className="text-sm">No matching documents</p>
</div>
);

View file

@ -185,7 +185,7 @@ function DateTimePickerField({
type="time"
value={time}
onChange={handleTimeChange}
className="w-[120px] text-sm shrink-0 pl-1.5 [&::-webkit-calendar-picker-indicator]:order-first [&::-webkit-calendar-picker-indicator]:mr-1"
className="w-[120px] text-sm shrink-0 appearance-none [&::-webkit-calendar-picker-indicator]:hidden [&::-webkit-calendar-picker-indicator]:appearance-none"
/>
</div>
);

View file

@ -396,10 +396,13 @@ export function AllPrivateChatsSidebarContent({
variant="ghost"
size="icon"
className={cn(
"h-6 w-6 shrink-0",
"h-6 w-6 shrink-0 hover:bg-transparent",
isMobile
? "opacity-0 pointer-events-none absolute"
: openDropdownId === thread.id
? "opacity-100"
: "md:opacity-0 md:group-hover:opacity-100 md:focus:opacity-100",
openDropdownId === thread.id && "bg-accent hover:bg-accent",
"transition-opacity"
)}
disabled={isBusy}

View file

@ -396,10 +396,13 @@ export function AllSharedChatsSidebarContent({
variant="ghost"
size="icon"
className={cn(
"h-6 w-6 shrink-0",
"h-6 w-6 shrink-0 hover:bg-transparent",
isMobile
? "opacity-0 pointer-events-none absolute"
: openDropdownId === thread.id
? "opacity-100"
: "md:opacity-0 md:group-hover:opacity-100 md:focus:opacity-100",
openDropdownId === thread.id && "bg-accent hover:bg-accent",
"transition-opacity"
)}
disabled={isBusy}

View file

@ -79,14 +79,21 @@ export function ChatListItem({
: "bg-gradient-to-l from-sidebar from-60% to-transparent group-hover/item:from-accent",
isMobile
? "opacity-0"
: isActive
: isActive || dropdownOpen
? "opacity-100"
: "opacity-0 group-hover/item:opacity-100"
)}
>
<DropdownMenu open={dropdownOpen} onOpenChange={setDropdownOpen}>
<DropdownMenuTrigger asChild>
<Button variant="ghost" size="icon" className="pointer-events-auto h-6 w-6">
<Button
variant="ghost"
size="icon"
className={cn(
"pointer-events-auto h-6 w-6 hover:bg-transparent",
dropdownOpen && "bg-accent hover:bg-accent"
)}
>
<MoreHorizontal className="h-3.5 w-3.5 text-muted-foreground" />
<span className="sr-only">{t("more_options")}</span>
</Button>

View file

@ -7,6 +7,7 @@ import { useParams } from "next/navigation";
import { useTranslations } from "next-intl";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner";
import { EXPORT_FILE_EXTENSIONS } from "@/components/shared/ExportMenuItems";
import { DocumentsFilters } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsFilters";
import {
DocumentsTableShell,
@ -33,6 +34,7 @@ import { useDocumentSearch } from "@/hooks/use-document-search";
import { useDocuments } from "@/hooks/use-documents";
import { useMediaQuery } from "@/hooks/use-media-query";
import { foldersApiService } from "@/lib/apis/folders-api.service";
import { authenticatedFetch } from "@/lib/auth-utils";
import { queries } from "@/zero/queries/index";
import { SidebarSlideOutPanel } from "./SidebarSlideOutPanel";
@ -234,6 +236,43 @@ export function DocumentsSidebar({
setFolderPickerOpen(true);
}, []);
const handleExportDocument = useCallback(
async (doc: DocumentNodeDoc, format: string) => {
const safeTitle =
doc.title
.replace(/[^a-zA-Z0-9 _-]/g, "_")
.trim()
.slice(0, 80) || "document";
const ext = EXPORT_FILE_EXTENSIONS[format] ?? format;
try {
const response = await authenticatedFetch(
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${doc.id}/export?format=${format}`,
{ method: "GET" }
);
if (!response.ok) {
const errorData = await response.json().catch(() => ({ detail: "Export failed" }));
throw new Error(errorData.detail || "Export failed");
}
const blob = await response.blob();
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = `${safeTitle}.${ext}`;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
} catch (err) {
console.error(`Export ${format} failed:`, err);
toast.error(err instanceof Error ? err.message : `Export failed`);
}
},
[searchSpaceId]
);
const handleFolderPickerSelect = useCallback(
async (targetFolderId: number | null) => {
if (!folderPickerTarget) return;
@ -606,6 +645,7 @@ export function DocumentsSidebar({
}}
onDeleteDocument={(doc) => handleDeleteDocument(doc.id)}
onMoveDocument={handleMoveDocument}
onExportDocument={handleExportDocument}
activeTypes={activeTypes}
onDropIntoFolder={handleDropIntoFolder}
onReorderFolder={handleReorderFolder}
@ -617,7 +657,7 @@ export function DocumentsSidebar({
open={folderPickerOpen}
onOpenChange={setFolderPickerOpen}
folders={treeFolders}
title={folderPickerTarget?.type === "folder" ? "Move folder to..." : "Move document to..."}
title={folderPickerTarget?.type === "folder" ? "Move folder to" : "Move document to"}
description="Select a destination folder, or choose Root to move to the top level."
disabledFolderIds={folderPickerTarget?.disabledIds}
onSelect={handleFolderPickerSelect}

View file

@ -199,7 +199,7 @@ export function ChatShareButton({ thread, onVisibilityChange, className }: ChatS
className={cn(
"w-full flex items-center gap-2.5 px-2.5 py-2 rounded-md transition-all",
"hover:bg-accent/50 dark:hover:bg-white/10 cursor-pointer",
"focus:outline-none",
"focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2",
isSelected && "bg-accent/80 dark:bg-white/10"
)}
>
@ -248,7 +248,7 @@ export function ChatShareButton({ thread, onVisibilityChange, className }: ChatS
className={cn(
"w-full flex items-center gap-2.5 px-2.5 py-2 rounded-md transition-all",
"hover:bg-accent/50 dark:hover:bg-white/10 cursor-pointer",
"focus:outline-none",
"focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2",
"disabled:opacity-50 disabled:cursor-not-allowed"
)}
>

View file

@ -7,7 +7,7 @@ import { useTheme } from "next-themes";
import { useCallback, useEffect, useRef, useState } from "react";
import { createPortal } from "react-dom";
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
import { documentTypeCountsAtom } from "@/atoms/documents/document-query.atoms";
import { useZeroDocumentTypeCounts } from "@/hooks/use-zero-document-type-counts";
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
import { useIsMobile } from "@/hooks/use-mobile";
@ -452,8 +452,8 @@ export function OnboardingTour() {
enabled: !!searchSpaceId,
});
// Get document type counts
const { data: documentTypeCounts } = useAtomValue(documentTypeCountsAtom);
// Real-time document type counts via Zero
const documentTypeCounts = useZeroDocumentTypeCounts(searchSpaceId);
// Get connectors
const { data: connectors = [] } = useAtomValue(connectorsAtom);

View file

@ -15,10 +15,9 @@ import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuLabel,
DropdownMenuSeparator,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { ExportDropdownItems, EXPORT_FILE_EXTENSIONS } from "@/components/shared/ExportMenuItems";
import { useMediaQuery } from "@/hooks/use-media-query";
import { baseApiService } from "@/lib/apis/base-api.service";
import { authenticatedFetch } from "@/lib/auth-utils";
@ -198,19 +197,6 @@ export function ReportPanelContent({
}
}, [currentMarkdown]);
// Maps backend format values to download file extensions
const FILE_EXTENSIONS: Record<string, string> = {
pdf: "pdf",
docx: "docx",
html: "html",
latex: "tex",
epub: "epub",
odt: "odt",
plain: "txt",
md: "md",
};
// Export report
const handleExport = useCallback(
async (format: string) => {
setExporting(format);
@ -219,7 +205,7 @@ export function ReportPanelContent({
.replace(/[^a-zA-Z0-9 _-]/g, "_")
.trim()
.slice(0, 80) || "report";
const ext = FILE_EXTENSIONS[format] ?? format;
const ext = EXPORT_FILE_EXTENSIONS[format] ?? format;
try {
if (format === "md") {
if (!currentMarkdown) return;
@ -329,68 +315,11 @@ export function ReportPanelContent({
align="start"
className={`min-w-[200px] select-none${insideDrawer ? " z-[100]" : ""}`}
>
{!shareToken && (
<>
<DropdownMenuLabel className="text-xs text-muted-foreground">
Documents
</DropdownMenuLabel>
<DropdownMenuItem
onClick={() => handleExport("pdf")}
disabled={exporting !== null}
>
PDF (.pdf)
</DropdownMenuItem>
<DropdownMenuItem
onClick={() => handleExport("docx")}
disabled={exporting !== null}
>
Word (.docx)
</DropdownMenuItem>
<DropdownMenuItem
onClick={() => handleExport("odt")}
disabled={exporting !== null}
>
OpenDocument (.odt)
</DropdownMenuItem>
<DropdownMenuSeparator />
<DropdownMenuLabel className="text-xs text-muted-foreground">
Web &amp; E-Book
</DropdownMenuLabel>
<DropdownMenuItem
onClick={() => handleExport("html")}
disabled={exporting !== null}
>
HTML (.html)
</DropdownMenuItem>
<DropdownMenuItem
onClick={() => handleExport("epub")}
disabled={exporting !== null}
>
EPUB (.epub)
</DropdownMenuItem>
<DropdownMenuSeparator />
<DropdownMenuLabel className="text-xs text-muted-foreground">
Source &amp; Plain
</DropdownMenuLabel>
<DropdownMenuItem
onClick={() => handleExport("latex")}
disabled={exporting !== null}
>
LaTeX (.tex)
</DropdownMenuItem>
</>
)}
<DropdownMenuItem onClick={() => handleExport("md")} disabled={exporting !== null}>
Markdown (.md)
</DropdownMenuItem>
{!shareToken && (
<DropdownMenuItem
onClick={() => handleExport("plain")}
disabled={exporting !== null}
>
Plain Text (.txt)
</DropdownMenuItem>
)}
<ExportDropdownItems
onExport={handleExport}
exporting={exporting}
showAllFormats={!shareToken}
/>
</DropdownMenuContent>
</DropdownMenu>

View file

@ -27,6 +27,7 @@ export function GeneralSettingsManager({ searchSpaceId }: GeneralSettingsManager
const {
data: searchSpace,
isLoading: loading,
isError,
refetch: fetchSearchSpace,
} = useQuery({
queryKey: cacheKeys.searchSpaces.detail(searchSpaceId.toString()),
@ -104,6 +105,17 @@ export function GeneralSettingsManager({ searchSpaceId }: GeneralSettingsManager
);
}
if (isError) {
return (
<div className="flex flex-col items-center justify-center gap-3 py-8 text-center">
<p className="text-sm text-destructive">Failed to load settings.</p>
<Button variant="outline" size="sm" onClick={() => fetchSearchSpace()}>
Retry
</Button>
</div>
);
}
return (
<div className="space-y-4 md:space-y-6">
<Alert className="bg-muted/50 py-3 md:py-4">

View file

@ -0,0 +1,142 @@
"use client";
import { Loader2 } from "lucide-react";
import { DropdownMenuItem, DropdownMenuLabel, DropdownMenuSeparator } from "@/components/ui/dropdown-menu";
import { ContextMenuItem } from "@/components/ui/context-menu";
export const EXPORT_FILE_EXTENSIONS: Record<string, string> = {
pdf: "pdf",
docx: "docx",
html: "html",
latex: "tex",
epub: "epub",
odt: "odt",
plain: "txt",
md: "md",
};
interface ExportMenuItemsProps {
onExport: (format: string) => void;
exporting: string | null;
/** Hide server-side formats (PDF, DOCX, etc.) — only show md */
showAllFormats?: boolean;
}
export function ExportDropdownItems({
onExport,
exporting,
showAllFormats = true,
}: ExportMenuItemsProps) {
const handle = (format: string) => (e: React.MouseEvent) => {
e.stopPropagation();
onExport(format);
};
return (
<>
{showAllFormats && (
<>
<DropdownMenuLabel className="text-xs text-muted-foreground">
Documents
</DropdownMenuLabel>
<DropdownMenuItem onClick={handle("pdf")} disabled={exporting !== null}>
{exporting === "pdf" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
PDF (.pdf)
</DropdownMenuItem>
<DropdownMenuItem onClick={handle("docx")} disabled={exporting !== null}>
{exporting === "docx" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
Word (.docx)
</DropdownMenuItem>
<DropdownMenuItem onClick={handle("odt")} disabled={exporting !== null}>
{exporting === "odt" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
OpenDocument (.odt)
</DropdownMenuItem>
<DropdownMenuSeparator />
<DropdownMenuLabel className="text-xs text-muted-foreground">
Web &amp; E-Book
</DropdownMenuLabel>
<DropdownMenuItem onClick={handle("html")} disabled={exporting !== null}>
{exporting === "html" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
HTML (.html)
</DropdownMenuItem>
<DropdownMenuItem onClick={handle("epub")} disabled={exporting !== null}>
{exporting === "epub" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
EPUB (.epub)
</DropdownMenuItem>
<DropdownMenuSeparator />
<DropdownMenuLabel className="text-xs text-muted-foreground">
Source &amp; Plain
</DropdownMenuLabel>
<DropdownMenuItem onClick={handle("latex")} disabled={exporting !== null}>
{exporting === "latex" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
LaTeX (.tex)
</DropdownMenuItem>
</>
)}
<DropdownMenuItem onClick={handle("md")} disabled={exporting !== null}>
{exporting === "md" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
Markdown (.md)
</DropdownMenuItem>
{showAllFormats && (
<DropdownMenuItem onClick={handle("plain")} disabled={exporting !== null}>
{exporting === "plain" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
Plain Text (.txt)
</DropdownMenuItem>
)}
</>
);
}
export function ExportContextItems({
onExport,
exporting,
showAllFormats = true,
}: ExportMenuItemsProps) {
const handle = (format: string) => (e: React.MouseEvent) => {
e.stopPropagation();
onExport(format);
};
return (
<>
{showAllFormats && (
<>
<ContextMenuItem onClick={handle("pdf")} disabled={exporting !== null}>
{exporting === "pdf" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
PDF (.pdf)
</ContextMenuItem>
<ContextMenuItem onClick={handle("docx")} disabled={exporting !== null}>
{exporting === "docx" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
Word (.docx)
</ContextMenuItem>
<ContextMenuItem onClick={handle("odt")} disabled={exporting !== null}>
{exporting === "odt" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
OpenDocument (.odt)
</ContextMenuItem>
<ContextMenuItem onClick={handle("html")} disabled={exporting !== null}>
{exporting === "html" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
HTML (.html)
</ContextMenuItem>
<ContextMenuItem onClick={handle("epub")} disabled={exporting !== null}>
{exporting === "epub" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
EPUB (.epub)
</ContextMenuItem>
<ContextMenuItem onClick={handle("latex")} disabled={exporting !== null}>
{exporting === "latex" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
LaTeX (.tex)
</ContextMenuItem>
</>
)}
<ContextMenuItem onClick={handle("md")} disabled={exporting !== null}>
{exporting === "md" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
Markdown (.md)
</ContextMenuItem>
{showAllFormats && (
<ContextMenuItem onClick={handle("plain")} disabled={exporting !== null}>
{exporting === "plain" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
Plain Text (.txt)
</ContextMenuItem>
)}
</>
);
}

View file

@ -253,6 +253,12 @@ function ApprovalCard({
String(effectiveNewDescription ?? "") !== (event?.description ?? "");
const buildFinalArgs = useCallback(() => {
const base = {
event_id: event?.event_id,
document_id: event?.document_id,
connector_id: account?.id,
};
if (pendingEdits) {
const attendeesArr = pendingEdits.attendees
? pendingEdits.attendees
@ -260,22 +266,38 @@ function ApprovalCard({
.map((e) => e.trim())
.filter(Boolean)
: null;
const origAttendees = event?.attendees?.map((a) => a.email) ?? [];
return {
event_id: event?.event_id,
document_id: event?.document_id,
connector_id: account?.id,
new_summary: pendingEdits.summary || null,
new_description: pendingEdits.description || null,
new_start_datetime: pendingEdits.start_datetime || null,
new_end_datetime: pendingEdits.end_datetime || null,
new_location: pendingEdits.location || null,
new_attendees: attendeesArr,
...base,
new_summary:
pendingEdits.summary && pendingEdits.summary !== (event?.summary ?? "")
? pendingEdits.summary
: null,
new_description:
pendingEdits.description !== (event?.description ?? "")
? pendingEdits.description || null
: null,
new_start_datetime:
pendingEdits.start_datetime && pendingEdits.start_datetime !== (event?.start ?? "")
? pendingEdits.start_datetime
: null,
new_end_datetime:
pendingEdits.end_datetime && pendingEdits.end_datetime !== (event?.end ?? "")
? pendingEdits.end_datetime
: null,
new_location:
pendingEdits.location !== (event?.location ?? "")
? pendingEdits.location || null
: null,
new_attendees:
attendeesArr && attendeesArr.join(",") !== origAttendees.join(",")
? attendeesArr
: null,
};
}
return {
event_id: event?.event_id,
document_id: event?.document_id,
connector_id: account?.id,
...base,
new_summary: actionArgs.new_summary ?? null,
new_description: actionArgs.new_description ?? null,
new_start_datetime: actionArgs.new_start_datetime ?? null,

View file

@ -1,7 +1,7 @@
"use client";
import { AnimatePresence, motion } from "motion/react";
import { useCallback, useEffect, useState } from "react";
import { useCallback, useEffect, useRef, useState } from "react";
import { createPortal } from "react-dom";
function isVideoSrc(src: string) {
@ -17,6 +17,12 @@ function ExpandedMediaOverlay({
alt: string;
onClose: () => void;
}) {
const overlayRef = useRef<HTMLDivElement>(null);
useEffect(() => {
overlayRef.current?.focus();
}, []);
useEffect(() => {
const handleKey = (e: KeyboardEvent) => {
if (e.key === "Escape") onClose();
@ -52,12 +58,20 @@ function ExpandedMediaOverlay({
return createPortal(
<motion.div
role="dialog"
aria-modal="true"
aria-label="Expanded media view"
tabIndex={-1}
ref={overlayRef}
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
transition={{ duration: 0.2 }}
className="fixed inset-0 z-100 flex items-center justify-center bg-black/70 p-4 backdrop-blur-sm sm:p-8"
onClick={onClose}
onKeyDown={(e) => {
if (e.key === "Escape") onClose();
}}
>
{mediaElement}
</motion.div>,

View file

@ -5,10 +5,10 @@ import {
FileText,
Film,
Globe,
ImageIcon,
type LucideIcon,
Podcast,
ScanLine,
Sparkles,
Wrench,
} from "lucide-react";
@ -17,7 +17,7 @@ const TOOL_ICONS: Record<string, LucideIcon> = {
generate_podcast: Podcast,
generate_video_presentation: Film,
generate_report: FileText,
generate_image: Sparkles,
generate_image: ImageIcon,
scrape_webpage: ScanLine,
web_search: Globe,
search_surfsense_docs: BookOpen,

View file

@ -0,0 +1,31 @@
"use client";
import { useQuery } from "@rocicorp/zero/react";
import { useMemo } from "react";
import { queries } from "@/zero/queries";
/**
* Real-time document type counts derived from Zero's live document sync.
* Updates instantly as documents are created, deleted, or change type.
*/
export function useZeroDocumentTypeCounts(
searchSpaceId: number | string | null
): Record<string, number> | undefined {
const numericId = searchSpaceId != null ? Number(searchSpaceId) : null;
const [zeroDocuments] = useQuery(
queries.documents.bySpace({ searchSpaceId: numericId ?? -1 })
);
return useMemo(() => {
if (!zeroDocuments || numericId == null) return undefined;
const counts: Record<string, number> = {};
for (const doc of zeroDocuments) {
if (doc.id != null && doc.title != null && doc.title !== "") {
counts[doc.documentType] = (counts[doc.documentType] || 0) + 1;
}
}
return counts;
}, [zeroDocuments, numericId]);
}

View file

@ -17,7 +17,6 @@ export const cacheKeys = {
withQueryParams: (queries: GetDocumentsRequest["queryParams"]) =>
["documents-with-queries", ...(queries ? Object.values(queries) : [])] as const,
document: (documentId: string) => ["document", documentId] as const,
typeCounts: (searchSpaceId?: string) => ["documents", "type-counts", searchSpaceId] as const,
byChunk: (chunkId: string) => ["documents", "by-chunk", chunkId] as const,
},
logs: {