mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-02 19:55:18 +02:00
Merge branch 'dev' of https://github.com/MODSetter/SurfSense into dev
This commit is contained in:
commit
947def5c4a
65 changed files with 10278 additions and 7590 deletions
|
|
@ -14,6 +14,20 @@ from app.services.google_calendar import GoogleCalendarToolMetadataService
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_date_only(value: str) -> bool:
|
||||
"""Return True when *value* looks like a bare date (YYYY-MM-DD) with no time component."""
|
||||
return len(value) <= 10 and "T" not in value
|
||||
|
||||
|
||||
def _build_time_body(value: str, context: dict[str, Any] | Any) -> dict[str, str]:
    """Build a Google Calendar start/end body for *value*.

    Args:
        value: Either a bare date or a date-time string.
        context: Optional mapping; when it is a dict, its ``"timezone"``
            entry is used for timed events. Any other object is ignored.

    Returns:
        ``{"date": value}`` when ``_is_date_only(value)`` is true, otherwise
        ``{"dateTime": value, "timeZone": tz}`` where ``tz`` falls back to
        ``"UTC"`` when no usable timezone is available.
    """
    if _is_date_only(value):
        return {"date": value}

    configured_tz = context.get("timezone") if isinstance(context, dict) else None
    # A key that is present but None/empty must not leak into the request
    # body; treat it the same as a missing key.
    return {"dateTime": value, "timeZone": configured_tz or "UTC"}
|
||||
|
||||
|
||||
def create_update_calendar_event_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
|
|
@ -255,25 +269,13 @@ def create_update_calendar_event_tool(
|
|||
if final_new_summary is not None:
|
||||
update_body["summary"] = final_new_summary
|
||||
if final_new_start_datetime is not None:
|
||||
tz = (
|
||||
context.get("timezone", "UTC")
|
||||
if isinstance(context, dict)
|
||||
else "UTC"
|
||||
update_body["start"] = _build_time_body(
|
||||
final_new_start_datetime, context
|
||||
)
|
||||
update_body["start"] = {
|
||||
"dateTime": final_new_start_datetime,
|
||||
"timeZone": tz,
|
||||
}
|
||||
if final_new_end_datetime is not None:
|
||||
tz = (
|
||||
context.get("timezone", "UTC")
|
||||
if isinstance(context, dict)
|
||||
else "UTC"
|
||||
update_body["end"] = _build_time_body(
|
||||
final_new_end_datetime, context
|
||||
)
|
||||
update_body["end"] = {
|
||||
"dateTime": final_new_end_datetime,
|
||||
"timeZone": tz,
|
||||
}
|
||||
if final_new_description is not None:
|
||||
update_body["description"] = final_new_description
|
||||
if final_new_location is not None:
|
||||
|
|
|
|||
|
|
@ -2,13 +2,14 @@
|
|||
|
||||
from .change_tracker import categorize_change, fetch_all_changes, get_start_page_token
|
||||
from .client import GoogleDriveClient
|
||||
from .content_extractor import download_and_process_file
|
||||
from .content_extractor import download_and_extract_content, download_and_process_file
|
||||
from .credentials import get_valid_credentials, validate_credentials
|
||||
from .folder_manager import get_file_by_id, get_files_in_folder, list_folder_contents
|
||||
|
||||
__all__ = [
|
||||
"GoogleDriveClient",
|
||||
"categorize_change",
|
||||
"download_and_extract_content",
|
||||
"download_and_process_file",
|
||||
"fetch_all_changes",
|
||||
"get_file_by_id",
|
||||
|
|
|
|||
|
|
@ -84,22 +84,50 @@ async def get_changes(
|
|||
return [], None, f"Error getting changes: {e!s}"
|
||||
|
||||
|
||||
async def _is_descendant_of(
    client: GoogleDriveClient,
    parent_ids: list[str],
    target_folder_id: str,
    max_depth: int = 20,
) -> bool:
    """Walk up the parent chain to check if any ancestor is *target_folder_id*.

    Performs a breadth-first walk starting from *parent_ids*. Each folder id
    is examined at most once; its own parents are fetched with a
    ``files().get(fields="parents")`` request and queued.

    Args:
        client: Drive client used to obtain the API service.
        parent_ids: Direct parents of the file being tested.
        target_folder_id: Folder id to look for among the ancestors.
        max_depth: Upper bound on the number of folder ids examined.

    Returns:
        True if *target_folder_id* is reached within the budget, else False.
        A folder whose metadata cannot be fetched is skipped rather than
        aborting the walk.
    """
    # Local imports: the module's top-level import block is not visible here.
    import logging
    from collections import deque

    log = logging.getLogger(__name__)
    seen: set[str] = set()
    pending: deque[str] = deque()

    def _enqueue(folder_ids) -> None:
        # De-duplicate at enqueue time so repeated ids (shared ancestors)
        # never consume part of the ``max_depth`` budget.
        for folder_id in folder_ids:
            if folder_id not in seen:
                seen.add(folder_id)
                pending.append(folder_id)

    _enqueue(parent_ids)

    for _ in range(max_depth):
        if not pending:
            return False

        current = pending.popleft()
        if current == target_folder_id:
            return True

        try:
            service = await client.get_service()
            # NOTE(review): ``execute()`` is a synchronous call made from a
            # coroutine; confirm whether it should be moved off the loop.
            meta = (
                service.files()
                .get(fileId=current, fields="parents", supportsAllDrives=True)
                .execute()
            )
            grandparents = meta.get("parents", [])
        except Exception as exc:
            # Best effort: an unreadable ancestor only ends this branch.
            log.debug("Could not fetch parents of %s: %s", current, exc)
            continue

        _enqueue(grandparents)

    return False
|
||||
|
||||
|
||||
async def _filter_changes_by_folder(
|
||||
client: GoogleDriveClient,
|
||||
changes: list[dict[str, Any]],
|
||||
folder_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Filter changes to only include files within the specified folder.
|
||||
|
||||
Args:
|
||||
client: GoogleDriveClient instance
|
||||
changes: List of changes from API
|
||||
folder_id: Folder ID to filter by
|
||||
|
||||
Returns:
|
||||
Filtered list of changes
|
||||
"""
|
||||
"""Filter changes to only include files within the specified folder
|
||||
(direct children or nested descendants)."""
|
||||
filtered = []
|
||||
|
||||
for change in changes:
|
||||
|
|
@ -108,14 +136,10 @@ async def _filter_changes_by_folder(
|
|||
filtered.append(change)
|
||||
continue
|
||||
|
||||
# Check if file is in the folder (or subfolder)
|
||||
parents = file.get("parents", [])
|
||||
if folder_id in parents:
|
||||
filtered.append(change)
|
||||
else:
|
||||
# Check if any parent is a descendant of folder_id
|
||||
# This is a simplified check - full implementation would traverse hierarchy
|
||||
# For now, we'll include it and let indexer validate
|
||||
elif await _is_descendant_of(client, parents, folder_id):
|
||||
filtered.append(change)
|
||||
|
||||
return filtered
|
||||
|
|
|
|||
|
|
@ -1,9 +1,15 @@
|
|||
"""Google Drive API client."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httplib2
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google_auth_httplib2 import AuthorizedHttp
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
from googleapiclient.http import MediaIoBaseUpload
|
||||
|
|
@ -12,6 +18,14 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from .credentials import get_valid_credentials
|
||||
from .file_types import GOOGLE_DOC, GOOGLE_SHEET
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _build_thread_http(credentials: Credentials) -> AuthorizedHttp:
    """Return an authorized transport backed by a brand-new ``httplib2.Http``.

    A fresh ``Http`` object is created on every call so that concurrent
    downloads never share one (``httplib2.Http`` is not thread-safe).
    """
    fresh_transport = httplib2.Http()
    return AuthorizedHttp(credentials, http=fresh_transport)
|
||||
|
||||
|
||||
class GoogleDriveClient:
|
||||
"""Client for Google Drive API operations."""
|
||||
|
|
@ -34,7 +48,9 @@ class GoogleDriveClient:
|
|||
self.session = session
|
||||
self.connector_id = connector_id
|
||||
self._credentials = credentials
|
||||
self._resolved_credentials: Credentials | None = None
|
||||
self.service = None
|
||||
self._service_lock = asyncio.Lock()
|
||||
|
||||
async def get_service(self):
|
||||
"""
|
||||
|
|
@ -49,6 +65,10 @@ class GoogleDriveClient:
|
|||
if self.service:
|
||||
return self.service
|
||||
|
||||
async with self._service_lock:
|
||||
if self.service:
|
||||
return self.service
|
||||
|
||||
try:
|
||||
if self._credentials:
|
||||
credentials = self._credentials
|
||||
|
|
@ -56,6 +76,7 @@ class GoogleDriveClient:
|
|||
credentials = await get_valid_credentials(
|
||||
self.session, self.connector_id
|
||||
)
|
||||
self._resolved_credentials = credentials
|
||||
self.service = build("drive", "v3", credentials=credentials)
|
||||
return self.service
|
||||
except Exception as e:
|
||||
|
|
@ -134,6 +155,33 @@ class GoogleDriveClient:
|
|||
except Exception as e:
|
||||
return None, f"Error getting file metadata: {e!s}"
|
||||
|
||||
@staticmethod
def _sync_download_file(
    service, file_id: str, credentials: Credentials,
) -> tuple[bytes | None, str | None]:
    """Download a file's media into memory (blocking).

    Intended to be run on a worker thread via ``asyncio.to_thread``. The
    request is bound to its own transport built from *credentials*.

    Returns:
        ``(content, None)`` on success, or ``(None, message)`` when the
        download raised; the start and end of the call are always logged.
    """
    worker_name = threading.current_thread().name
    started_at = time.monotonic()
    logger.info(f"[download] START file={file_id} thread={worker_name}")
    try:
        from googleapiclient.http import MediaIoBaseDownload

        transport = _build_thread_http(credentials)
        media_request = service.files().get_media(fileId=file_id)
        media_request.http = transport

        buffer = io.BytesIO()
        downloader = MediaIoBaseDownload(buffer, media_request)
        # Pull chunks until the downloader reports completion.
        while True:
            _, finished = downloader.next_chunk()
            if finished:
                break
        return buffer.getvalue(), None
    except HttpError as e:
        return None, f"HTTP error downloading file: {e.resp.status}"
    except Exception as e:
        return None, f"Error downloading file: {e!s}"
    finally:
        logger.info(f"[download] END file={file_id} thread={worker_name} elapsed={time.monotonic() - started_at:.2f}s")
|
||||
|
||||
async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]:
|
||||
"""
|
||||
Download binary file content.
|
||||
|
|
@ -144,27 +192,76 @@ class GoogleDriveClient:
|
|||
Returns:
|
||||
Tuple of (file content bytes, error message)
|
||||
"""
|
||||
try:
|
||||
service = await self.get_service()
|
||||
request = service.files().get_media(fileId=file_id)
|
||||
return await asyncio.to_thread(
|
||||
self._sync_download_file, service, file_id, self._resolved_credentials,
|
||||
)
|
||||
|
||||
import io
|
||||
|
||||
fh = io.BytesIO()
|
||||
@staticmethod
|
||||
def _sync_download_file_to_disk(
|
||||
service, file_id: str, dest_path: str, chunksize: int,
|
||||
credentials: Credentials,
|
||||
) -> str | None:
|
||||
"""Blocking download-to-disk — runs on a worker thread via ``to_thread``."""
|
||||
thread = threading.current_thread().name
|
||||
t0 = time.monotonic()
|
||||
logger.info(f"[download-to-disk] START file={file_id} thread={thread}")
|
||||
try:
|
||||
from googleapiclient.http import MediaIoBaseDownload
|
||||
|
||||
downloader = MediaIoBaseDownload(fh, request)
|
||||
|
||||
http = _build_thread_http(credentials)
|
||||
request = service.files().get_media(fileId=file_id)
|
||||
request.http = http
|
||||
with open(dest_path, "wb") as fh:
|
||||
downloader = MediaIoBaseDownload(fh, request, chunksize=chunksize)
|
||||
done = False
|
||||
while not done:
|
||||
_, done = downloader.next_chunk()
|
||||
|
||||
return fh.getvalue(), None
|
||||
|
||||
return None
|
||||
except HttpError as e:
|
||||
return None, f"HTTP error downloading file: {e.resp.status}"
|
||||
return f"HTTP error downloading file: {e.resp.status}"
|
||||
except Exception as e:
|
||||
return None, f"Error downloading file: {e!s}"
|
||||
return f"Error downloading file: {e!s}"
|
||||
finally:
|
||||
logger.info(f"[download-to-disk] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s")
|
||||
|
||||
async def download_file_to_disk(
|
||||
self, file_id: str, dest_path: str, chunksize: int = 5 * 1024 * 1024,
|
||||
) -> str | None:
|
||||
"""Stream file directly to disk in chunks, avoiding full in-memory buffering.
|
||||
|
||||
Returns error message on failure, None on success.
|
||||
"""
|
||||
service = await self.get_service()
|
||||
return await asyncio.to_thread(
|
||||
self._sync_download_file_to_disk,
|
||||
service, file_id, dest_path, chunksize, self._resolved_credentials,
|
||||
)
|
||||
|
||||
@staticmethod
def _sync_export_google_file(
    service, file_id: str, mime_type: str, credentials: Credentials,
) -> tuple[bytes | None, str | None]:
    """Export a file to *mime_type* (blocking).

    Intended to be run on a worker thread via ``asyncio.to_thread``. The
    export request is executed on its own transport built from
    *credentials*.

    Returns:
        ``(content_bytes, None)`` on success, where non-bytes content is
        encoded as UTF-8, or ``(None, message)`` when the export raised.
        The start and end of the call are always logged.
    """
    worker_name = threading.current_thread().name
    started_at = time.monotonic()
    logger.info(f"[export] START file={file_id} thread={worker_name}")
    try:
        transport = _build_thread_http(credentials)
        export_request = service.files().export(fileId=file_id, mimeType=mime_type)
        payload = export_request.execute(http=transport)
        if isinstance(payload, bytes):
            return payload, None
        return payload.encode("utf-8"), None
    except HttpError as e:
        return None, f"HTTP error exporting file: {e.resp.status}"
    except Exception as e:
        return None, f"Error exporting file: {e!s}"
    finally:
        logger.info(f"[export] END file={file_id} thread={worker_name} elapsed={time.monotonic() - started_at:.2f}s")
|
||||
|
||||
async def export_google_file(
|
||||
self, file_id: str, mime_type: str
|
||||
|
|
@ -179,24 +276,12 @@ class GoogleDriveClient:
|
|||
Returns:
|
||||
Tuple of (exported content as bytes, error message)
|
||||
"""
|
||||
try:
|
||||
service = await self.get_service()
|
||||
content = (
|
||||
service.files().export(fileId=file_id, mimeType=mime_type).execute()
|
||||
return await asyncio.to_thread(
|
||||
self._sync_export_google_file, service, file_id, mime_type,
|
||||
self._resolved_credentials,
|
||||
)
|
||||
|
||||
# Content is already bytes from the API
|
||||
# Keep as bytes to support both text and binary formats (like PDF)
|
||||
if not isinstance(content, bytes):
|
||||
content = content.encode("utf-8")
|
||||
|
||||
return content, None
|
||||
|
||||
except HttpError as e:
|
||||
return None, f"HTTP error exporting file: {e.resp.status}"
|
||||
except Exception as e:
|
||||
return None, f"Error exporting file: {e!s}"
|
||||
|
||||
async def create_file(
|
||||
self,
|
||||
name: str,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
"""Content extraction for Google Drive files."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -12,11 +15,182 @@ from app.db import Log
|
|||
from app.services.task_logging_service import TaskLoggingService
|
||||
|
||||
from .client import GoogleDriveClient
|
||||
from .file_types import get_export_mime_type, is_google_workspace_file, should_skip_file
|
||||
from .file_types import (
|
||||
get_export_mime_type,
|
||||
get_extension_from_mime,
|
||||
is_google_workspace_file,
|
||||
should_skip_file,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def download_and_extract_content(
    client: GoogleDriveClient,
    file: dict[str, Any],
) -> tuple[str | None, dict[str, Any], str | None]:
    """Download a Google Drive file and extract its content as markdown.

    ETL only -- no DB writes, no indexing, no summarization.

    Google Workspace files are exported to the configured export MIME type
    and written to a temporary file; all other files are streamed straight
    to a temporary file. The temporary file is parsed and then removed.

    Args:
        client: Drive client used for export/download.
        file: Drive file resource; ``id``, ``name`` and ``mimeType`` are
            read, plus optional ``modifiedTime``, ``createdTime``, ``size``,
            ``webViewLink`` and ``md5Checksum``.

    Returns:
        (markdown_content, drive_metadata, error_message)
        On success error_message is None. Skipped MIME types return empty
        metadata; every other failure returns the collected metadata.
    """
    file_id = file.get("id")
    file_name = file.get("name", "Unknown")
    mime_type = file.get("mimeType", "")

    if should_skip_file(mime_type):
        return None, {}, f"Skipping {mime_type}"

    logger.info(f"Downloading file for content extraction: {file_name} ({mime_type})")

    is_workspace = is_google_workspace_file(mime_type)
    export_mime = get_export_mime_type(mime_type) if is_workspace else None

    drive_metadata: dict[str, Any] = {
        "google_drive_file_id": file_id,
        "google_drive_file_name": file_name,
        "google_drive_mime_type": mime_type,
        "source_connector": "google_drive",
    }
    # Copy optional Drive fields only when the API returned them.
    optional_fields = (
        ("modifiedTime", "modified_time"),
        ("createdTime", "created_time"),
        ("size", "file_size"),
        ("webViewLink", "web_view_link"),
        ("md5Checksum", "md5_checksum"),
    )
    for source_key, metadata_key in optional_fields:
        if source_key in file:
            drive_metadata[metadata_key] = file[source_key]
    if is_workspace:
        export_ext = get_extension_from_mime(export_mime or "")
        drive_metadata["exported_as"] = export_ext.lstrip(".") if export_ext else "pdf"
        drive_metadata["original_workspace_type"] = mime_type.split(".")[-1]

    temp_file_path = None
    try:
        if is_workspace:
            if not export_mime:
                return None, drive_metadata, f"Cannot export Google Workspace type: {mime_type}"
            content_bytes, error = await client.export_google_file(file_id, export_mime)
            if error:
                return None, drive_metadata, error
            extension = get_extension_from_mime(export_mime) or ".pdf"

            with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp:
                # Record the path before writing so a failed write still
                # gets cleaned up in ``finally``.
                temp_file_path = tmp.name
                tmp.write(content_bytes)

            # The parser picks its strategy from the name's suffix. A
            # Workspace title is free text (e.g. "notes.md"), so make the
            # name reflect the exported format rather than the title.
            parse_name = (
                file_name
                if file_name.lower().endswith(extension)
                else f"{file_name}{extension}"
            )
        else:
            extension = (
                Path(file_name).suffix
                or get_extension_from_mime(mime_type)
                or ".bin"
            )
            with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp:
                temp_file_path = tmp.name

            error = await client.download_file_to_disk(file_id, temp_file_path)
            if error:
                return None, drive_metadata, error
            parse_name = file_name

        markdown = await _parse_file_to_markdown(temp_file_path, parse_name)
        return markdown, drive_metadata, None

    except Exception as e:
        logger.warning(f"Failed to extract content from {file_name}: {e!s}")
        return None, drive_metadata, str(e)
    finally:
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.unlink(temp_file_path)
            except OSError as cleanup_error:
                # Cleanup is best effort; keep a trace instead of hiding it.
                logger.debug(
                    "Could not remove temp file %s: %s", temp_file_path, cleanup_error
                )
|
||||
|
||||
|
||||
async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
|
||||
"""Parse a local file to markdown using the configured ETL service."""
|
||||
lower = filename.lower()
|
||||
|
||||
if lower.endswith((".md", ".markdown", ".txt")):
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")):
|
||||
from app.config import config as app_config
|
||||
from litellm import atranscription
|
||||
|
||||
stt_service_type = (
|
||||
"local"
|
||||
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
|
||||
else "external"
|
||||
)
|
||||
if stt_service_type == "local":
|
||||
from app.services.stt_service import stt_service
|
||||
t0 = time.monotonic()
|
||||
logger.info(f"[local-stt] START file={filename} thread={threading.current_thread().name}")
|
||||
result = await asyncio.to_thread(stt_service.transcribe_file, file_path)
|
||||
logger.info(f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s")
|
||||
text = result.get("text", "")
|
||||
else:
|
||||
with open(file_path, "rb") as audio_file:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": app_config.STT_SERVICE,
|
||||
"file": audio_file,
|
||||
"api_key": app_config.STT_SERVICE_API_KEY,
|
||||
}
|
||||
if app_config.STT_SERVICE_API_BASE:
|
||||
kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
|
||||
resp = await atranscription(**kwargs)
|
||||
text = resp.get("text", "")
|
||||
|
||||
if not text:
|
||||
raise ValueError("Transcription returned empty text")
|
||||
return f"# Transcription of {filename}\n\n{text}"
|
||||
|
||||
# Document files -- use configured ETL service
|
||||
from app.config import config as app_config
|
||||
|
||||
if app_config.ETL_SERVICE == "UNSTRUCTURED":
|
||||
from langchain_unstructured import UnstructuredLoader
|
||||
from app.utils.document_converters import convert_document_to_markdown
|
||||
|
||||
loader = UnstructuredLoader(
|
||||
file_path,
|
||||
mode="elements",
|
||||
post_processors=[],
|
||||
languages=["eng"],
|
||||
include_orig_elements=False,
|
||||
include_metadata=False,
|
||||
strategy="auto",
|
||||
)
|
||||
docs = await loader.aload()
|
||||
return await convert_document_to_markdown(docs)
|
||||
|
||||
if app_config.ETL_SERVICE == "LLAMACLOUD":
|
||||
from app.tasks.document_processors.file_processors import (
|
||||
parse_with_llamacloud_retry,
|
||||
)
|
||||
|
||||
result = await parse_with_llamacloud_retry(file_path=file_path, estimated_pages=50)
|
||||
markdown_documents = await result.aget_markdown_documents(split_by_page=False)
|
||||
if not markdown_documents:
|
||||
raise RuntimeError(f"LlamaCloud returned no documents for {filename}")
|
||||
return markdown_documents[0].text
|
||||
|
||||
if app_config.ETL_SERVICE == "DOCLING":
|
||||
from docling.document_converter import DocumentConverter
|
||||
|
||||
converter = DocumentConverter()
|
||||
t0 = time.monotonic()
|
||||
logger.info(f"[docling] START file={filename} thread={threading.current_thread().name}")
|
||||
result = await asyncio.to_thread(converter.convert, file_path)
|
||||
logger.info(f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s")
|
||||
return result.document.export_to_markdown()
|
||||
|
||||
raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
|
||||
|
||||
|
||||
async def download_and_process_file(
|
||||
client: GoogleDriveClient,
|
||||
file: dict[str, Any],
|
||||
|
|
@ -68,14 +242,17 @@ async def download_and_process_file(
|
|||
if error:
|
||||
return None, error
|
||||
|
||||
extension = ".pdf" if export_mime == "application/pdf" else ".txt"
|
||||
extension = get_extension_from_mime(export_mime) or ".pdf"
|
||||
else:
|
||||
content_bytes, error = await client.download_file(file_id)
|
||||
if error:
|
||||
return None, error
|
||||
|
||||
# Preserve original file extension
|
||||
extension = Path(file_name).suffix or ".bin"
|
||||
extension = (
|
||||
Path(file_name).suffix
|
||||
or get_extension_from_mime(mime_type)
|
||||
or ".bin"
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp_file:
|
||||
tmp_file.write(content_bytes)
|
||||
|
|
@ -113,7 +290,12 @@ async def download_and_process_file(
|
|||
connector_info["metadata"]["md5_checksum"] = file["md5Checksum"]
|
||||
|
||||
if is_google_workspace_file(mime_type):
|
||||
connector_info["metadata"]["exported_as"] = "pdf"
|
||||
export_ext = get_extension_from_mime(
|
||||
get_export_mime_type(mime_type) or ""
|
||||
)
|
||||
connector_info["metadata"]["exported_as"] = (
|
||||
export_ext.lstrip(".") if export_ext else "pdf"
|
||||
)
|
||||
connector_info["metadata"]["original_workspace_type"] = mime_type.split(
|
||||
"."
|
||||
)[-1]
|
||||
|
|
|
|||
|
|
@ -7,11 +7,34 @@ GOOGLE_FOLDER = "application/vnd.google-apps.folder"
|
|||
GOOGLE_SHORTCUT = "application/vnd.google-apps.shortcut"
|
||||
|
||||
EXPORT_FORMATS = {
|
||||
GOOGLE_DOC: "application/pdf",
|
||||
GOOGLE_SHEET: "application/pdf",
|
||||
GOOGLE_DOC: "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
GOOGLE_SHEET: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
GOOGLE_SLIDE: "application/pdf",
|
||||
}
|
||||
|
||||
MIME_TO_EXTENSION: dict[str, str] = {
|
||||
"application/pdf": ".pdf",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
|
||||
"application/vnd.ms-excel": ".xls",
|
||||
"application/msword": ".doc",
|
||||
"application/vnd.ms-powerpoint": ".ppt",
|
||||
"text/plain": ".txt",
|
||||
"text/csv": ".csv",
|
||||
"text/html": ".html",
|
||||
"text/markdown": ".md",
|
||||
"application/json": ".json",
|
||||
"application/xml": ".xml",
|
||||
"image/png": ".png",
|
||||
"image/jpeg": ".jpg",
|
||||
}
|
||||
|
||||
|
||||
def get_extension_from_mime(mime_type: str) -> str | None:
|
||||
"""Return a file extension (with leading dot) for a MIME type, or None."""
|
||||
return MIME_TO_EXTENSION.get(mime_type)
|
||||
|
||||
|
||||
def is_google_workspace_file(mime_type: str) -> bool:
|
||||
"""Check if file is a Google Workspace file that needs export."""
|
||||
|
|
|
|||
|
|
@ -3,10 +3,17 @@ import hashlib
|
|||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
|
||||
|
||||
def compute_identifier_hash(
    document_type_value: str, unique_id: str, search_space_id: int
) -> str:
    """Hash a document's identity components into a SHA-256 hex digest.

    The three components are joined with ``:`` in the order document type,
    unique id, search space id, encoded as UTF-8 and hashed, so equal
    inputs always produce the same digest.
    """
    hasher = hashlib.sha256()
    hasher.update(
        f"{document_type_value}:{unique_id}:{search_space_id}".encode("utf-8")
    )
    return hasher.hexdigest()
|
||||
|
||||
|
||||
def compute_unique_identifier_hash(doc: ConnectorDocument) -> str:
|
||||
"""Return a stable SHA-256 hash identifying a document by its source identity."""
|
||||
combined = f"{doc.document_type.value}:{doc.unique_id}:{doc.search_space_id}"
|
||||
return hashlib.sha256(combined.encode("utf-8")).hexdigest()
|
||||
return compute_identifier_hash(doc.document_type.value, doc.unique_id, doc.search_space_id)
|
||||
|
||||
|
||||
def compute_content_hash(doc: ConnectorDocument) -> str:
|
||||
|
|
|
|||
|
|
@ -1,17 +1,21 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Chunk, Document, DocumentStatus
|
||||
from app.db import NATIVE_TO_LEGACY_DOCTYPE, Chunk, Document, DocumentStatus
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_chunker import chunk_text
|
||||
from app.indexing_pipeline.document_embedder import embed_texts
|
||||
from app.indexing_pipeline.document_hashing import (
|
||||
compute_content_hash,
|
||||
compute_identifier_hash,
|
||||
compute_unique_identifier_hash,
|
||||
)
|
||||
from app.indexing_pipeline.document_persistence import (
|
||||
|
|
@ -54,6 +58,62 @@ class IndexingPipelineService:
|
|||
def __init__(self, session: AsyncSession) -> None:
|
||||
self.session = session
|
||||
|
||||
async def migrate_legacy_docs(
    self, connector_docs: list[ConnectorDocument]
) -> None:
    """Re-key legacy documents so they are found under their native type.

    For every connector document whose type has an entry in
    ``NATIVE_TO_LEGACY_DOCTYPE``, the stored document matching the legacy
    identifier hash (if any) gets its ``unique_identifier_hash`` and
    ``document_type`` rewritten to the native values. Documents without a
    legacy mapping or without a stored legacy row are left untouched. The
    session is committed once after all documents have been processed.
    """
    for connector_doc in connector_docs:
        native_type = connector_doc.document_type
        legacy_type = NATIVE_TO_LEGACY_DOCTYPE.get(native_type.value)
        if legacy_type:
            legacy_hash = compute_identifier_hash(
                legacy_type, connector_doc.unique_id, connector_doc.search_space_id
            )
            lookup = select(Document).filter(
                Document.unique_identifier_hash == legacy_hash
            )
            legacy_row = (await self.session.execute(lookup)).scalars().first()
            if legacy_row is not None:
                legacy_row.unique_identifier_hash = compute_identifier_hash(
                    native_type.value,
                    connector_doc.unique_id,
                    connector_doc.search_space_id,
                )
                legacy_row.document_type = native_type

    await self.session.commit()
|
||||
|
||||
async def index_batch(
    self, connector_docs: list[ConnectorDocument], llm
) -> list[Document]:
    """Prepare *connector_docs* and index each prepared document in turn.

    Prepared documents are matched back to their connector document by
    unique identifier hash; a prepared document with no match is skipped.

    Indexers that need heartbeat callbacks or custom per-document logic
    should call prepare_for_indexing() + index() directly instead.

    Returns:
        The results of ``index()`` in the order the documents were prepared.
    """
    by_hash = {}
    for connector_doc in connector_docs:
        by_hash[compute_unique_identifier_hash(connector_doc)] = connector_doc

    prepared = await self.prepare_for_indexing(connector_docs)

    indexed: list[Document] = []
    for document in prepared:
        source = by_hash.get(document.unique_identifier_hash)
        if source is not None:
            indexed.append(await self.index(document, source, llm))
    return indexed
|
||||
|
||||
async def prepare_for_indexing(
|
||||
self, connector_docs: list[ConnectorDocument]
|
||||
) -> list[Document]:
|
||||
|
|
@ -200,13 +260,14 @@ class IndexingPipelineService:
|
|||
)
|
||||
|
||||
t_step = time.perf_counter()
|
||||
chunk_texts = chunk_text(
|
||||
chunk_texts = await asyncio.to_thread(
|
||||
chunk_text,
|
||||
connector_doc.source_markdown,
|
||||
use_code_chunker=connector_doc.should_use_code_chunker,
|
||||
)
|
||||
|
||||
texts_to_embed = [content, *chunk_texts]
|
||||
embeddings = embed_texts(texts_to_embed)
|
||||
embeddings = await asyncio.to_thread(embed_texts, texts_to_embed)
|
||||
summary_embedding, *chunk_embeddings = embeddings
|
||||
|
||||
chunks = [
|
||||
|
|
@ -268,3 +329,126 @@ class IndexingPipelineService:
|
|||
await self.session.refresh(document)
|
||||
|
||||
return document
|
||||
|
||||
async def index_batch_parallel(
    self,
    connector_docs: list[ConnectorDocument],
    get_llm: Callable[[AsyncSession], Awaitable],
    *,
    max_concurrency: int = 4,
    on_heartbeat: Callable[[int], Awaitable[None]] | None = None,
    heartbeat_interval: float = 30.0,
) -> tuple[list[Document], int, int]:
    """Index documents in parallel with bounded concurrency.

    Phase 1 (serial): prepare_for_indexing using self.session.
    Phase 2 (parallel): index each document in an isolated session,
    bounded by a semaphore to avoid overwhelming APIs/DB.

    Args:
        connector_docs: Source documents to prepare and index.
        get_llm: Awaitable factory called with each isolated session.
        max_concurrency: Maximum documents indexed at once; values below 1
            are treated as 1.
        on_heartbeat: Optional callback awaited with the current indexed
            count, at most once per *heartbeat_interval* seconds. A failing
            callback is logged and does not affect the document's outcome.
        heartbeat_interval: Minimum seconds between heartbeat calls.

    Returns:
        ``(documents, indexed_count, failed_count)``. ``documents`` holds
        every ``Document`` a worker returned, which includes documents that
        were skipped or did not reach READY; workers that raised contribute
        only to ``failed_count``.
    """
    logger = logging.getLogger(__name__)
    perf = get_perf_logger()
    t_total = time.perf_counter()

    doc_map = {
        compute_unique_identifier_hash(cd): cd for cd in connector_docs
    }
    documents = await self.prepare_for_indexing(connector_docs)

    if not documents:
        return [], 0, 0

    from app.tasks.celery_tasks import get_celery_session_maker

    # A non-positive limit would deadlock (0) or raise (negative).
    sem = asyncio.Semaphore(max(1, max_concurrency))
    lock = asyncio.Lock()
    indexed_count = 0
    failed_count = 0
    # Monotonic clock: interval checks must not be affected by clock changes.
    last_heartbeat = time.monotonic()

    async def _maybe_heartbeat() -> None:
        """Call the heartbeat callback if the interval has elapsed."""
        nonlocal last_heartbeat
        if not on_heartbeat:
            return
        async with lock:
            now = time.monotonic()
            if now - last_heartbeat < heartbeat_interval:
                return
            last_heartbeat = now
            progress = indexed_count
        # Awaited outside the lock so a slow callback does not stall the
        # other workers' counter updates.
        try:
            await on_heartbeat(progress)
        except Exception:
            logger.warning("Heartbeat callback failed", exc_info=True)

    async def _index_one(document: Document) -> Document | Exception:
        nonlocal indexed_count, failed_count

        connector_doc = doc_map.get(document.unique_identifier_hash)
        if connector_doc is None:
            logger.warning(
                "No matching ConnectorDocument for document %s, skipping",
                document.id,
            )
            async with lock:
                failed_count += 1
            return document

        async with sem:
            session_maker = get_celery_session_maker()
            async with session_maker() as isolated_session:
                try:
                    refetched = await isolated_session.get(
                        Document, document.id
                    )
                    if refetched is None:
                        async with lock:
                            failed_count += 1
                        return document

                    llm = await get_llm(isolated_session)
                    iso_pipeline = IndexingPipelineService(isolated_session)
                    result = await iso_pipeline.index(
                        refetched, connector_doc, llm
                    )
                    is_ready = DocumentStatus.is_state(
                        result.status, DocumentStatus.READY
                    )
                except Exception as exc:
                    logger.error(
                        "Parallel index failed for doc %s: %s",
                        document.id,
                        exc,
                        exc_info=True,
                    )
                    async with lock:
                        failed_count += 1
                    return exc

                # Each document is counted exactly once, before the
                # heartbeat, so a heartbeat error cannot double count it.
                async with lock:
                    if is_ready:
                        indexed_count += 1
                    else:
                        failed_count += 1

                await _maybe_heartbeat()
                return result

    tasks = [_index_one(doc) for doc in documents]
    t_parallel = time.perf_counter()
    outcomes = await asyncio.gather(*tasks, return_exceptions=True)
    perf.info(
        "[indexing] index_batch_parallel gather docs=%d concurrency=%d "
        "indexed=%d failed=%d in %.3fs",
        len(documents),
        max_concurrency,
        indexed_count,
        failed_count,
        time.perf_counter() - t_parallel,
    )

    # Exceptions were already logged and counted inside the workers.
    results: list[Document] = [
        outcome for outcome in outcomes if isinstance(outcome, Document)
    ]

    perf.info(
        "[indexing] index_batch_parallel TOTAL input=%d prepared=%d "
        "indexed=%d failed=%d in %.3fs",
        len(connector_docs),
        len(documents),
        indexed_count,
        failed_count,
        time.perf_counter() - t_total,
    )
    return results, indexed_count, failed_count
|
||||
|
|
|
|||
|
|
@ -1,19 +1,43 @@
|
|||
"""
|
||||
Editor routes for document editing with markdown (Plate.js frontend).
|
||||
Includes multi-format export (PDF, DOCX, HTML, LaTeX, EPUB, ODT, plain text).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
import pypandoc
|
||||
import typst
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.db import Document, DocumentType, Permission, User, get_async_session
|
||||
from app.routes.reports_routes import (
|
||||
ExportFormat,
|
||||
_FILE_EXTENSIONS,
|
||||
_MEDIA_TYPES,
|
||||
_normalize_latex_delimiters,
|
||||
_strip_wrapping_code_fences,
|
||||
)
|
||||
from app.templates.export_helpers import (
|
||||
get_html_css_path,
|
||||
get_reference_docx_path,
|
||||
get_typst_template_path,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
|
|
@ -212,3 +236,153 @@ async def save_document(
|
|||
"message": "Document saved and will be reindexed in the background",
|
||||
"updated_at": document.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
@router.get(
    "/search-spaces/{search_space_id}/documents/{document_id}/export"
)
async def export_document(
    search_space_id: int,
    document_id: int,
    format: ExportFormat = Query(
        ExportFormat.PDF,
        description="Export format: pdf, docx, html, latex, epub, odt, or plain",
    ),
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Export a document in the requested format (reuses the report export pipeline).

    The markdown to export is resolved in this order: ``source_markdown``,
    then the BlockNote document converted to markdown, then the document's
    chunks joined in ascending ``id`` order.

    Args:
        search_space_id: Search space the document must belong to.
        document_id: ID of the document to export.
        format: Target export format; defaults to PDF.
        session: Database session (injected).
        user: Authenticated user (injected).

    Returns:
        A ``StreamingResponse`` carrying the converted bytes as an attachment.

    Raises:
        HTTPException: 404 if the document is not found in the search space,
            400 if it has no exportable content, 500 if conversion fails.
            ``check_permission`` is awaited first and is expected to reject
            callers without document read permission.
    """
    await check_permission(
        session,
        user,
        search_space_id,
        Permission.DOCUMENTS_READ.value,
        "You don't have permission to read documents in this search space",
    )

    result = await session.execute(
        select(Document)
        .options(selectinload(Document.chunks))
        .filter(
            Document.id == document_id,
            Document.search_space_id == search_space_id,
        )
    )
    document = result.scalars().first()
    if not document:
        raise HTTPException(status_code=404, detail="Document not found")

    # Resolve markdown content (same priority as editor-content endpoint)
    markdown_content: str | None = document.source_markdown
    if markdown_content is None and document.blocknote_document:
        from app.utils.blocknote_to_markdown import blocknote_to_markdown

        markdown_content = blocknote_to_markdown(document.blocknote_document)
    if markdown_content is None:
        chunks = sorted(document.chunks, key=lambda c: c.id)
        if chunks:
            markdown_content = "\n\n".join(chunk.content for chunk in chunks)

    if not markdown_content or not markdown_content.strip():
        raise HTTPException(
            status_code=400, detail="Document has no content to export"
        )

    markdown_content = _strip_wrapping_code_fences(markdown_content)
    markdown_content = _normalize_latex_delimiters(markdown_content)

    doc_title = document.title or "Document"
    formatted_date = (
        document.created_at.strftime("%B %d, %Y") if document.created_at else ""
    )
    input_fmt = "gfm+tex_math_dollars"
    meta_args = ["-M", f"title:{doc_title}", "-M", f"date:{formatted_date}"]

    def _convert_and_read() -> bytes:
        """Run the blocking conversion for ``format`` and return the output bytes."""
        if format == ExportFormat.PDF:
            # PDF is produced in two steps: pandoc -> Typst markup -> typst compile.
            typst_template = str(get_typst_template_path())
            typst_markup: str = pypandoc.convert_text(
                markdown_content,
                "typst",
                format=input_fmt,
                extra_args=[
                    "--standalone",
                    f"--template={typst_template}",
                    "-V", "mainfont:Libertinus Serif",
                    "-V", "codefont:DejaVu Sans Mono",
                    *meta_args,
                ],
            )
            return typst.compile(typst_markup.encode("utf-8"))

        if format == ExportFormat.DOCX:
            return _pandoc_to_tempfile(
                format.value,
                ["--standalone", f"--reference-doc={get_reference_docx_path()}", *meta_args],
            )

        if format == ExportFormat.HTML:
            html_str: str = pypandoc.convert_text(
                markdown_content,
                "html5",
                format=input_fmt,
                extra_args=[
                    "--standalone", "--embed-resources",
                    f"--css={get_html_css_path()}",
                    "--syntax-highlighting=pygments",
                    *meta_args,
                ],
            )
            return html_str.encode("utf-8")

        if format == ExportFormat.EPUB:
            return _pandoc_to_tempfile("epub3", ["--standalone", *meta_args])

        if format == ExportFormat.ODT:
            return _pandoc_to_tempfile("odt", ["--standalone", *meta_args])

        if format == ExportFormat.LATEX:
            tex_str: str = pypandoc.convert_text(
                markdown_content, "latex", format=input_fmt,
                extra_args=["--standalone", *meta_args],
            )
            return tex_str.encode("utf-8")

        # Any remaining format falls through to plain text.
        plain_str: str = pypandoc.convert_text(
            markdown_content, "plain", format=input_fmt,
            extra_args=["--wrap=auto", "--columns=80"],
        )
        return plain_str.encode("utf-8")

    def _pandoc_to_tempfile(output_format: str, extra_args: list[str]) -> bytes:
        """Convert via a temporary output file and return its bytes.

        The temp file is always removed, even when conversion fails.
        """
        fd, tmp_path = tempfile.mkstemp(suffix=f".{output_format}")
        # Only the path is needed; pandoc writes the file itself.
        os.close(fd)
        try:
            pypandoc.convert_text(
                markdown_content, output_format, format=input_fmt,
                extra_args=extra_args, outputfile=tmp_path,
            )
            with open(tmp_path, "rb") as f:
                return f.read()
        finally:
            os.unlink(tmp_path)

    try:
        # Conversion is blocking, so run it off the event loop.
        loop = asyncio.get_running_loop()
        output = await loop.run_in_executor(None, _convert_and_read)
    except Exception as e:
        logger.exception("Document export failed")
        raise HTTPException(status_code=500, detail=f"Export failed: {e!s}") from e

    safe_title = (
        "".join(c if c.isalnum() or c in " -_" else "_" for c in doc_title)
        .strip()[:80]
        or "document"
    )
    ext = _FILE_EXTENSIONS[format]

    # str.isalnum() accepts non-ASCII letters/digits, but Starlette encodes
    # header values as latin-1, so such a title would fail while building the
    # response. Send an ASCII-only fallback in ``filename`` and, when the title
    # is not pure ASCII, the real name via the RFC 5987 ``filename*`` parameter.
    ascii_title = "".join(c if c.isascii() else "_" for c in safe_title)
    content_disposition = f'attachment; filename="{ascii_title}.{ext}"'
    if ascii_title != safe_title:
        from urllib.parse import quote

        encoded_name = quote(f"{safe_title}.{ext}", safe="")
        content_disposition += f"; filename*=UTF-8''{encoded_name}"

    return StreamingResponse(
        io.BytesIO(output),
        media_type=_MEDIA_TYPES[format],
        headers={"Content-Disposition": content_disposition},
    )
|
||||
|
|
|
|||
|
|
@ -2329,7 +2329,7 @@ async def run_google_drive_indexing(
|
|||
try:
|
||||
from app.tasks.connector_indexers.google_drive_indexer import (
|
||||
index_google_drive_files,
|
||||
index_google_drive_single_file,
|
||||
index_google_drive_selected_files,
|
||||
)
|
||||
|
||||
# Parse the structured data
|
||||
|
|
@ -2402,25 +2402,23 @@ async def run_google_drive_indexing(
|
|||
exc_info=True,
|
||||
)
|
||||
|
||||
# Index each individual file
|
||||
for file in items.files:
|
||||
# Index all selected files together via the parallel pipeline
|
||||
if items.files:
|
||||
try:
|
||||
indexed_count, error_message = await index_google_drive_single_file(
|
||||
file_tuples = [(f.id, f.name) for f in items.files]
|
||||
indexed_count, _skipped, file_errors = await index_google_drive_selected_files(
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
file_id=file.id,
|
||||
file_name=file.name,
|
||||
files=file_tuples,
|
||||
)
|
||||
if error_message:
|
||||
errors.append(f"File '{file.name}': {error_message}")
|
||||
else:
|
||||
total_indexed += indexed_count
|
||||
errors.extend(file_errors)
|
||||
except Exception as e:
|
||||
errors.append(f"File '{file.name}': {e!s}")
|
||||
errors.append(f"File batch indexing: {e!s}")
|
||||
logger.error(
|
||||
f"Error indexing file {file.name} ({file.id}): {e}",
|
||||
f"Error batch indexing files: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -209,8 +209,8 @@ class GoogleCalendarKBSyncService:
|
|||
)
|
||||
|
||||
calendar_id = (document.document_metadata or {}).get(
|
||||
"calendar_id", "primary"
|
||||
)
|
||||
"calendar_id"
|
||||
) or "primary"
|
||||
live_event = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
|
|
|
|||
|
|
@ -1,49 +1,74 @@
|
|||
"""
|
||||
Confluence connector indexer.
|
||||
|
||||
Provides real-time document status updates during indexing using a two-phase approach:
|
||||
- Phase 1: Create all documents with PENDING status (visible in UI immediately)
|
||||
- Phase 2: Process each document one by one (PENDING → PROCESSING → READY/FAILED)
|
||||
"""
|
||||
"""Confluence connector indexer using the unified parallel indexing pipeline."""
|
||||
|
||||
import contextlib
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.db import DocumentType, SearchSourceConnectorType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import compute_content_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
||||
from .base import (
|
||||
calculate_date_range,
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document_by_hash,
|
||||
get_connector_by_id,
|
||||
get_current_timestamp,
|
||||
logger,
|
||||
safe_set_chunks,
|
||||
update_connector_last_indexed,
|
||||
)
|
||||
|
||||
# Type hint for heartbeat callback
|
||||
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||
|
||||
# Heartbeat interval in seconds
|
||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||
|
||||
|
||||
def _build_connector_doc(
    page: dict,
    full_content: str,
    *,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    enable_summary: bool,
) -> ConnectorDocument:
    """Map a raw Confluence page dict to a ConnectorDocument.

    Args:
        page: Page dict; the keys read here are ``id``, ``title``,
            ``spaceId`` and ``comments``. Missing keys default to empty values.
        full_content: Markdown body stored as the document's source markdown
            and embedded in the fallback summary.
        connector_id: ID of the connector the page came from.
        search_space_id: Search space the document belongs to.
        user_id: ID recorded as the document creator.
        enable_summary: Passed through as ``should_summarize``.

    Returns:
        A ``ConnectorDocument`` of type ``CONFLUENCE_CONNECTOR`` keyed by the
        page ID.
    """
    page_id = page.get("id", "")
    page_title = page.get("title", "")
    space_id = page.get("spaceId", "")
    # "comments" may be missing or explicitly None; both count as zero
    # comments instead of raising TypeError from len(None).
    comment_count = len(page.get("comments") or [])

    metadata = {
        "page_id": page_id,
        "page_title": page_title,
        "space_id": space_id,
        "comment_count": comment_count,
        "connector_id": connector_id,
        "document_type": "Confluence Page",
        "connector_type": "Confluence",
    }

    # Text used in place of a generated summary.
    fallback_summary = (
        f"Confluence Page: {page_title}\n\nSpace ID: {space_id}\n\n{full_content}"
    )

    return ConnectorDocument(
        title=page_title,
        source_markdown=full_content,
        unique_id=page_id,
        document_type=DocumentType.CONFLUENCE_CONNECTOR,
        search_space_id=search_space_id,
        connector_id=connector_id,
        created_by_id=user_id,
        should_summarize=enable_summary,
        fallback_summary=fallback_summary,
        metadata=metadata,
    )
|
||||
|
||||
|
||||
async def index_confluence_pages(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
|
|
@ -53,26 +78,9 @@ async def index_confluence_pages(
|
|||
end_date: str | None = None,
|
||||
update_last_indexed: bool = True,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, str | None]:
|
||||
"""
|
||||
Index Confluence pages and comments.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Confluence connector
|
||||
search_space_id: ID of the search space to store documents in
|
||||
user_id: User ID
|
||||
start_date: Start date for indexing (YYYY-MM-DD format)
|
||||
end_date: End date for indexing (YYYY-MM-DD format)
|
||||
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
|
||||
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
|
||||
|
||||
Returns:
|
||||
Tuple containing (number of documents indexed, error message or None)
|
||||
"""
|
||||
) -> tuple[int, int, str | None]:
|
||||
"""Index Confluence pages and comments."""
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="confluence_pages_indexing",
|
||||
source="connector_indexing_task",
|
||||
|
|
@ -86,7 +94,6 @@ async def index_confluence_pages(
|
|||
)
|
||||
|
||||
try:
|
||||
# Get the connector from the database
|
||||
connector = await get_connector_by_id(
|
||||
session, connector_id, SearchSourceConnectorType.CONFLUENCE_CONNECTOR
|
||||
)
|
||||
|
|
@ -98,9 +105,8 @@ async def index_confluence_pages(
|
|||
"Connector not found",
|
||||
{"error_type": "ConnectorNotFound"},
|
||||
)
|
||||
return 0, f"Connector with ID {connector_id} not found"
|
||||
return 0, 0, f"Connector with ID {connector_id} not found"
|
||||
|
||||
# Initialize Confluence OAuth client
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Initializing Confluence OAuth client for connector {connector_id}",
|
||||
|
|
@ -114,7 +120,6 @@ async def index_confluence_pages(
|
|||
)
|
||||
)
|
||||
|
||||
# Calculate date range
|
||||
start_date_str, end_date_str = calculate_date_range(
|
||||
connector, start_date, end_date, default_days_back=365
|
||||
)
|
||||
|
|
@ -129,19 +134,14 @@ async def index_confluence_pages(
|
|||
},
|
||||
)
|
||||
|
||||
# Get pages within date range
|
||||
try:
|
||||
pages, error = await confluence_client.get_pages_by_date_range(
|
||||
start_date=start_date_str, end_date=end_date_str, include_comments=True
|
||||
)
|
||||
|
||||
if error:
|
||||
# Don't treat "No pages found" as an error that should stop indexing
|
||||
if "No pages found" in error:
|
||||
logger.info(f"No Confluence pages found: {error}")
|
||||
logger.info(
|
||||
"No pages found is not a critical error, continuing with update"
|
||||
)
|
||||
if update_last_indexed:
|
||||
await update_connector_last_indexed(
|
||||
session, connector, update_last_indexed
|
||||
|
|
@ -156,11 +156,10 @@ async def index_confluence_pages(
|
|||
f"No Confluence pages found in date range {start_date_str} to {end_date_str}",
|
||||
{"pages_found": 0},
|
||||
)
|
||||
# Close client before returning
|
||||
if confluence_client:
|
||||
with contextlib.suppress(Exception):
|
||||
await confluence_client.close()
|
||||
return 0, None
|
||||
return 0, 0, None
|
||||
else:
|
||||
logger.error(f"Failed to get Confluence pages: {error}")
|
||||
await task_logger.log_task_failure(
|
||||
|
|
@ -169,36 +168,35 @@ async def index_confluence_pages(
|
|||
"API Error",
|
||||
{"error_type": "APIError"},
|
||||
)
|
||||
# Close client on error
|
||||
if confluence_client:
|
||||
with contextlib.suppress(Exception):
|
||||
await confluence_client.close()
|
||||
return 0, f"Failed to get Confluence pages: {error}"
|
||||
return 0, 0, f"Failed to get Confluence pages: {error}"
|
||||
|
||||
logger.info(f"Retrieved {len(pages)} pages from Confluence API")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Confluence pages: {e!s}", exc_info=True)
|
||||
# Close client on error
|
||||
if confluence_client:
|
||||
with contextlib.suppress(Exception):
|
||||
await confluence_client.close()
|
||||
return 0, f"Error fetching Confluence pages: {e!s}"
|
||||
return 0, 0, f"Error fetching Confluence pages: {e!s}"
|
||||
|
||||
if not pages:
|
||||
logger.info("No Confluence pages found for the specified date range")
|
||||
if update_last_indexed:
|
||||
await update_connector_last_indexed(
|
||||
session, connector, update_last_indexed
|
||||
)
|
||||
await session.commit()
|
||||
if confluence_client:
|
||||
with contextlib.suppress(Exception):
|
||||
await confluence_client.close()
|
||||
return 0, 0, None
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 1: Analyze all pages, create pending documents
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
# =======================================================================
|
||||
documents_indexed = 0
|
||||
documents_skipped = 0
|
||||
documents_failed = 0
|
||||
duplicate_content_count = 0
|
||||
|
||||
# Heartbeat tracking - update notification periodically to prevent appearing stuck
|
||||
last_heartbeat_time = time.time()
|
||||
|
||||
pages_to_process = [] # List of dicts with document and page data
|
||||
new_documents_created = False
|
||||
connector_docs: list[ConnectorDocument] = []
|
||||
|
||||
for page in pages:
|
||||
try:
|
||||
|
|
@ -213,12 +211,10 @@ async def index_confluence_pages(
|
|||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Extract page content
|
||||
page_content = ""
|
||||
if page.get("body") and page["body"].get("storage"):
|
||||
page_content = page["body"]["storage"].get("value", "")
|
||||
|
||||
# Add comments to content
|
||||
comments = page.get("comments", [])
|
||||
comments_content = ""
|
||||
if comments:
|
||||
|
|
@ -235,61 +231,25 @@ async def index_confluence_pages(
|
|||
|
||||
comments_content += f"**Comment by {comment_author}** ({comment_date}):\n{comment_body}\n\n"
|
||||
|
||||
# Combine page content with comments
|
||||
full_content = f"# {page_title}\n\n{page_content}{comments_content}"
|
||||
|
||||
if not full_content.strip():
|
||||
if not page_content.strip() and not comments:
|
||||
logger.warning(f"Skipping page with no content: {page_title}")
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Generate unique identifier hash for this Confluence page
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.CONFLUENCE_CONNECTOR, page_id, search_space_id
|
||||
doc = _build_connector_doc(
|
||||
page,
|
||||
full_content,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=connector.enable_summary,
|
||||
)
|
||||
|
||||
# Generate content hash
|
||||
content_hash = generate_content_hash(full_content, search_space_id)
|
||||
|
||||
# Check if document with this unique identifier already exists
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, unique_identifier_hash
|
||||
)
|
||||
|
||||
comment_count = len(comments)
|
||||
|
||||
if existing_document:
|
||||
# Document exists - check if content has changed
|
||||
if existing_document.content_hash == content_hash:
|
||||
# Ensure status is ready (might have been stuck in processing/pending)
|
||||
if not DocumentStatus.is_state(
|
||||
existing_document.status, DocumentStatus.READY
|
||||
):
|
||||
existing_document.status = DocumentStatus.ready()
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Queue existing document for update (will be set to processing in Phase 2)
|
||||
pages_to_process.append(
|
||||
{
|
||||
"document": existing_document,
|
||||
"is_new": False,
|
||||
"full_content": full_content,
|
||||
"page_content": page_content,
|
||||
"content_hash": content_hash,
|
||||
"page_id": page_id,
|
||||
"page_title": page_title,
|
||||
"space_id": space_id,
|
||||
"comment_count": comment_count,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Document doesn't exist by unique_identifier_hash
|
||||
# Check if a document with the same content_hash exists (from another connector)
|
||||
with session.no_autoflush:
|
||||
duplicate_by_content = await check_duplicate_document_by_hash(
|
||||
session, content_hash
|
||||
session, compute_content_hash(doc)
|
||||
)
|
||||
|
||||
if duplicate_by_content:
|
||||
|
|
@ -302,151 +262,29 @@ async def index_confluence_pages(
|
|||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Create new document with PENDING status (visible in UI immediately)
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=page_title,
|
||||
document_type=DocumentType.CONFLUENCE_CONNECTOR,
|
||||
document_metadata={
|
||||
"page_id": page_id,
|
||||
"page_title": page_title,
|
||||
"space_id": space_id,
|
||||
"comment_count": comment_count,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
content="Pending...", # Placeholder until processed
|
||||
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
embedding=None,
|
||||
chunks=[], # Empty at creation - safe for async
|
||||
status=DocumentStatus.pending(), # Pending until processing starts
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
connector_id=connector_id,
|
||||
)
|
||||
session.add(document)
|
||||
new_documents_created = True
|
||||
|
||||
pages_to_process.append(
|
||||
{
|
||||
"document": document,
|
||||
"is_new": True,
|
||||
"full_content": full_content,
|
||||
"page_content": page_content,
|
||||
"content_hash": content_hash,
|
||||
"page_id": page_id,
|
||||
"page_title": page_title,
|
||||
"space_id": space_id,
|
||||
"comment_count": comment_count,
|
||||
}
|
||||
)
|
||||
connector_docs.append(doc)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Phase 1 for page: {e!s}", exc_info=True)
|
||||
documents_failed += 1
|
||||
logger.error(f"Error building ConnectorDocument for page: {e!s}", exc_info=True)
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Commit all pending documents - they all appear in UI now
|
||||
if new_documents_created:
|
||||
logger.info(
|
||||
f"Phase 1: Committing {len([p for p in pages_to_process if p['is_new']])} pending documents"
|
||||
)
|
||||
await session.commit()
|
||||
pipeline = IndexingPipelineService(session)
|
||||
await pipeline.migrate_legacy_docs(connector_docs)
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 2: Process each document one by one
|
||||
# Each document transitions: pending → processing → ready/failed
|
||||
# =======================================================================
|
||||
logger.info(f"Phase 2: Processing {len(pages_to_process)} documents")
|
||||
async def _get_llm(s: AsyncSession):
|
||||
return await get_user_long_context_llm(s, user_id, search_space_id)
|
||||
|
||||
for item in pages_to_process:
|
||||
# Send heartbeat periodically
|
||||
if on_heartbeat_callback:
|
||||
current_time = time.time()
|
||||
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat_callback(documents_indexed)
|
||||
last_heartbeat_time = current_time
|
||||
|
||||
document = item["document"]
|
||||
try:
|
||||
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
|
||||
document.status = DocumentStatus.processing()
|
||||
await session.commit()
|
||||
|
||||
# Heavy processing (LLM, embeddings, chunks)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
|
||||
connector_docs,
|
||||
_get_llm,
|
||||
max_concurrency=3,
|
||||
on_heartbeat=on_heartbeat_callback,
|
||||
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
|
||||
)
|
||||
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata = {
|
||||
"page_title": item["page_title"],
|
||||
"page_id": item["page_id"],
|
||||
"space_id": item["space_id"],
|
||||
"comment_count": item["comment_count"],
|
||||
"document_type": "Confluence Page",
|
||||
"connector_type": "Confluence",
|
||||
}
|
||||
(
|
||||
summary_content,
|
||||
summary_embedding,
|
||||
) = await generate_document_summary(
|
||||
item["full_content"], user_llm, document_metadata
|
||||
)
|
||||
else:
|
||||
summary_content = f"Confluence Page: {item['page_title']}\n\nSpace ID: {item['space_id']}\n\n{item['full_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
# Process chunks - using the full page content with comments
|
||||
chunks = await create_document_chunks(item["full_content"])
|
||||
|
||||
# Update document to READY with actual content
|
||||
document.title = item["page_title"]
|
||||
document.content = summary_content
|
||||
document.content_hash = item["content_hash"]
|
||||
document.embedding = summary_embedding
|
||||
document.document_metadata = {
|
||||
"page_id": item["page_id"],
|
||||
"page_title": item["page_title"],
|
||||
"space_id": item["space_id"],
|
||||
"comment_count": item["comment_count"],
|
||||
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"connector_id": connector_id,
|
||||
}
|
||||
await safe_set_chunks(session, document, chunks)
|
||||
document.updated_at = get_current_timestamp()
|
||||
document.status = DocumentStatus.ready()
|
||||
|
||||
documents_indexed += 1
|
||||
|
||||
# Batch commit every 10 documents (for ready status updates)
|
||||
if documents_indexed % 10 == 0:
|
||||
logger.info(
|
||||
f"Committing batch: {documents_indexed} Confluence pages processed so far"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing page {item.get('page_title', 'Unknown')}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Mark document as failed with reason (visible in UI)
|
||||
try:
|
||||
document.status = DocumentStatus.failed(str(e))
|
||||
document.updated_at = get_current_timestamp()
|
||||
except Exception as status_error:
|
||||
logger.error(
|
||||
f"Failed to update document status to failed: {status_error}"
|
||||
)
|
||||
documents_failed += 1
|
||||
continue # Skip this page and continue with others
|
||||
|
||||
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
|
||||
# This ensures the UI shows "Last indexed" instead of "Never indexed"
|
||||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||
|
||||
# Final commit to ensure all documents are persisted (safety net)
|
||||
logger.info(
|
||||
f"Final commit: Total {documents_indexed} Confluence pages processed"
|
||||
)
|
||||
|
|
@ -456,7 +294,6 @@ async def index_confluence_pages(
|
|||
"Successfully committed all Confluence document changes to database"
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle any remaining integrity errors gracefully (race conditions, etc.)
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in str(e).lower()
|
||||
or "uniqueviolationerror" in str(e).lower()
|
||||
|
|
@ -467,11 +304,9 @@ async def index_confluence_pages(
|
|||
f"Rolling back and continuing. Error: {e!s}"
|
||||
)
|
||||
await session.rollback()
|
||||
# Don't fail the entire task - some documents may have been successfully indexed
|
||||
else:
|
||||
raise
|
||||
|
||||
# Build warning message if there were issues
|
||||
warning_parts = []
|
||||
if duplicate_content_count > 0:
|
||||
warning_parts.append(f"{duplicate_content_count} duplicate")
|
||||
|
|
@ -479,7 +314,6 @@ async def index_confluence_pages(
|
|||
warning_parts.append(f"{documents_failed} failed")
|
||||
warning_message = ", ".join(warning_parts) if warning_parts else None
|
||||
|
||||
# Log success
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully completed Confluence indexing for connector {connector_id}",
|
||||
|
|
@ -490,22 +324,19 @@ async def index_confluence_pages(
|
|||
"duplicate_content_count": duplicate_content_count,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Confluence indexing completed: {documents_indexed} ready, "
|
||||
f"{documents_skipped} skipped, {documents_failed} failed "
|
||||
f"({duplicate_content_count} duplicate content)"
|
||||
)
|
||||
|
||||
# Close the client connection
|
||||
if confluence_client:
|
||||
await confluence_client.close()
|
||||
|
||||
return documents_indexed, warning_message
|
||||
return documents_indexed, documents_skipped, warning_message
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
# Close client if it exists
|
||||
if confluence_client:
|
||||
with contextlib.suppress(Exception):
|
||||
await confluence_client.close()
|
||||
|
|
@ -516,10 +347,9 @@ async def index_confluence_pages(
|
|||
{"error_type": "SQLAlchemyError"},
|
||||
)
|
||||
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
||||
return 0, f"Database error: {db_error!s}"
|
||||
return 0, 0, f"Database error: {db_error!s}"
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
# Close client if it exists
|
||||
if confluence_client:
|
||||
with contextlib.suppress(Exception):
|
||||
await confluence_client.close()
|
||||
|
|
@ -530,4 +360,4 @@ async def index_confluence_pages(
|
|||
{"error_type": type(e).__name__},
|
||||
)
|
||||
logger.error(f"Failed to index Confluence pages: {e!s}", exc_info=True)
|
||||
return 0, f"Failed to index Confluence pages: {e!s}"
|
||||
return 0, 0, f"Failed to index Confluence pages: {e!s}"
|
||||
|
|
|
|||
|
|
@ -1,12 +1,10 @@
|
|||
"""
|
||||
Google Calendar connector indexer.
|
||||
|
||||
Implements 2-phase document status updates for real-time UI feedback:
|
||||
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
|
||||
- Phase 2: Process each document: pending → processing → ready/failed
|
||||
Uses the shared IndexingPipelineService for document deduplication,
|
||||
summarization, chunking, and embedding.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
|
@ -15,29 +13,22 @@ from sqlalchemy.exc import SQLAlchemyError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.connectors.google_calendar_connector import GoogleCalendarConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.db import DocumentType, SearchSourceConnectorType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import compute_content_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
from app.utils.google_credentials import (
|
||||
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
|
||||
build_composio_credentials,
|
||||
)
|
||||
|
||||
from .base import (
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document_by_hash,
|
||||
get_connector_by_id,
|
||||
get_current_timestamp,
|
||||
logger,
|
||||
parse_date_flexible,
|
||||
safe_set_chunks,
|
||||
update_connector_last_indexed,
|
||||
)
|
||||
|
||||
|
|
@ -46,13 +37,60 @@ ACCEPTED_CALENDAR_CONNECTOR_TYPES = {
|
|||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
}
|
||||
|
||||
# Type hint for heartbeat callback
|
||||
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||
|
||||
# Heartbeat interval in seconds
|
||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||
|
||||
|
||||
def _build_connector_doc(
    event: dict,
    event_markdown: str,
    *,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    enable_summary: bool,
) -> ConnectorDocument:
    """Map a raw Google Calendar API event dict to a ConnectorDocument.

    Args:
        event: Event dict as returned by the calendar client. The keys read
            here are ``id``, ``summary``, ``calendarId``, ``start``, ``end``
            and ``location``; every one of them is optional.
        event_markdown: Markdown rendering of the event; stored as the
            document's ``source_markdown`` and embedded in the fallback
            summary.
        connector_id: ID of the connector the event was fetched through.
        search_space_id: Search space the resulting document belongs to.
        user_id: Stored on the document as ``created_by_id``.
        enable_summary: Forwarded as ``should_summarize``.

    Returns:
        A ``ConnectorDocument`` typed ``GOOGLE_CALENDAR_CONNECTOR`` whose
        ``unique_id`` is the event ID and whose title is the event summary
        (``"No Title"`` when the event has no ``summary`` key).
    """
    event_id = event.get("id", "")
    event_summary = event.get("summary", "No Title")
    calendar_id = event.get("calendarId", "")
    location = event.get("location", "")

    # ``or {}`` instead of a ``.get`` default: the default only applies when the
    # key is missing, while a key that is present but null would otherwise make
    # the lookups below raise AttributeError.
    start = event.get("start") or {}
    end = event.get("end") or {}

    # Timed events carry ``dateTime``; all-day events carry a bare ``date``.
    # The trailing ``or ""`` keeps the metadata value a string even when the
    # ``date`` key is present but null.
    start_time = start.get("dateTime") or start.get("date") or ""
    end_time = end.get("dateTime") or end.get("date") or ""

    metadata = {
        "event_id": event_id,
        "event_summary": event_summary,
        "calendar_id": calendar_id,
        "start_time": start_time,
        "end_time": end_time,
        "location": location,
        "connector_id": connector_id,
        "document_type": "Google Calendar Event",
        "connector_type": "Google Calendar",
    }

    # NOTE(review): presumably used by the indexing pipeline when no LLM
    # summary is produced -- confirm against IndexingPipelineService.
    fallback_summary = (
        f"Google Calendar Event: {event_summary}\n\n{event_markdown}"
    )

    return ConnectorDocument(
        title=event_summary,
        source_markdown=event_markdown,
        unique_id=event_id,
        document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
        search_space_id=search_space_id,
        connector_id=connector_id,
        created_by_id=user_id,
        should_summarize=enable_summary,
        fallback_summary=fallback_summary,
        metadata=metadata,
    )
|
||||
|
||||
|
||||
async def index_google_calendar_events(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
|
|
@ -82,7 +120,6 @@ async def index_google_calendar_events(
|
|||
"""
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="google_calendar_events_indexing",
|
||||
source="connector_indexing_task",
|
||||
|
|
@ -96,7 +133,7 @@ async def index_google_calendar_events(
|
|||
)
|
||||
|
||||
try:
|
||||
# Accept both native and Composio Calendar connectors
|
||||
# ── Connector lookup ──────────────────────────────────────────
|
||||
connector = None
|
||||
for ct in ACCEPTED_CALENDAR_CONNECTOR_TYPES:
|
||||
connector = await get_connector_by_id(session, connector_id, ct)
|
||||
|
|
@ -112,7 +149,7 @@ async def index_google_calendar_events(
|
|||
)
|
||||
return 0, 0, f"Connector with ID {connector_id} not found"
|
||||
|
||||
# Build credentials based on connector type
|
||||
# ── Credential building ───────────────────────────────────────
|
||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||
if not connected_account_id:
|
||||
|
|
@ -184,6 +221,7 @@ async def index_google_calendar_events(
|
|||
)
|
||||
return 0, 0, "Google Calendar credentials not found in connector config"
|
||||
|
||||
# ── Calendar client init ──────────────────────────────────────
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Initializing Google Calendar client for connector {connector_id}",
|
||||
|
|
@ -203,36 +241,26 @@ async def index_google_calendar_events(
|
|||
if end_date == "undefined" or end_date == "":
|
||||
end_date = None
|
||||
|
||||
# Calculate date range
|
||||
# For calendar connectors, allow future dates to index upcoming events
|
||||
# ── Date range calculation ────────────────────────────────────
|
||||
if start_date is None or end_date is None:
|
||||
# Fall back to calculating dates based on last_indexed_at
|
||||
# Default to today (users can manually select future dates if needed)
|
||||
calculated_end_date = datetime.now()
|
||||
|
||||
# Use last_indexed_at as start date if available, otherwise use 30 days ago
|
||||
if connector.last_indexed_at:
|
||||
# Convert dates to be comparable (both timezone-naive)
|
||||
last_indexed_naive = (
|
||||
connector.last_indexed_at.replace(tzinfo=None)
|
||||
if connector.last_indexed_at.tzinfo
|
||||
else connector.last_indexed_at
|
||||
)
|
||||
|
||||
# Allow future dates - use last_indexed_at as start date
|
||||
calculated_start_date = last_indexed_naive
|
||||
logger.info(
|
||||
f"Using last_indexed_at ({calculated_start_date.strftime('%Y-%m-%d')}) as start date"
|
||||
)
|
||||
else:
|
||||
calculated_start_date = datetime.now() - timedelta(
|
||||
days=365
|
||||
) # Use 365 days as default for calendar events (matches frontend)
|
||||
calculated_start_date = datetime.now() - timedelta(days=365)
|
||||
logger.info(
|
||||
f"No last_indexed_at found, using {calculated_start_date.strftime('%Y-%m-%d')} (365 days ago) as start date"
|
||||
)
|
||||
|
||||
# Use calculated dates if not provided
|
||||
start_date_str = (
|
||||
start_date if start_date else calculated_start_date.strftime("%Y-%m-%d")
|
||||
)
|
||||
|
|
@ -240,19 +268,14 @@ async def index_google_calendar_events(
|
|||
end_date if end_date else calculated_end_date.strftime("%Y-%m-%d")
|
||||
)
|
||||
else:
|
||||
# Use provided dates (including future dates)
|
||||
start_date_str = start_date
|
||||
end_date_str = end_date
|
||||
|
||||
# FIX: Ensure end_date is at least 1 day after start_date to avoid
|
||||
# "start_date must be strictly before end_date" errors when dates are the same
|
||||
# (e.g., when last_indexed_at is today)
|
||||
if start_date_str == end_date_str:
|
||||
logger.info(
|
||||
f"Start date ({start_date_str}) equals end date ({end_date_str}), "
|
||||
"adjusting end date to next day to ensure valid date range"
|
||||
)
|
||||
# Parse end_date and add 1 day
|
||||
try:
|
||||
end_dt = parse_date_flexible(end_date_str)
|
||||
except ValueError:
|
||||
|
|
@ -264,6 +287,7 @@ async def index_google_calendar_events(
|
|||
end_date_str = end_dt.strftime("%Y-%m-%d")
|
||||
logger.info(f"Adjusted end date to {end_date_str}")
|
||||
|
||||
# ── Fetch events ──────────────────────────────────────────────
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Fetching Google Calendar events from {start_date_str} to {end_date_str}",
|
||||
|
|
@ -274,27 +298,19 @@ async def index_google_calendar_events(
|
|||
},
|
||||
)
|
||||
|
||||
# Get events within date range from primary calendar
|
||||
try:
|
||||
events, error = await calendar_client.get_all_primary_calendar_events(
|
||||
start_date=start_date_str, end_date=end_date_str
|
||||
)
|
||||
|
||||
if error:
|
||||
# Don't treat "No events found" as an error that should stop indexing
|
||||
if "No events found" in error:
|
||||
logger.info(f"No Google Calendar events found: {error}")
|
||||
logger.info(
|
||||
"No events found is not a critical error, continuing with update"
|
||||
)
|
||||
if update_last_indexed:
|
||||
await update_connector_last_indexed(
|
||||
session, connector, update_last_indexed
|
||||
)
|
||||
await session.commit()
|
||||
logger.info(
|
||||
f"Updated last_indexed_at to {connector.last_indexed_at} despite no events found"
|
||||
)
|
||||
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
|
|
@ -304,7 +320,6 @@ async def index_google_calendar_events(
|
|||
return 0, 0, None
|
||||
else:
|
||||
logger.error(f"Failed to get Google Calendar events: {error}")
|
||||
# Check if this is an authentication error that requires re-authentication
|
||||
error_message = error
|
||||
error_type = "APIError"
|
||||
if (
|
||||
|
|
@ -329,28 +344,15 @@ async def index_google_calendar_events(
|
|||
logger.error(f"Error fetching Google Calendar events: {e!s}", exc_info=True)
|
||||
return 0, 0, f"Error fetching Google Calendar events: {e!s}"
|
||||
|
||||
documents_indexed = 0
|
||||
# ── Build ConnectorDocuments ──────────────────────────────────
|
||||
connector_docs: list[ConnectorDocument] = []
|
||||
documents_skipped = 0
|
||||
documents_failed = 0 # Track events that failed processing
|
||||
duplicate_content_count = (
|
||||
0 # Track events skipped due to duplicate content_hash
|
||||
)
|
||||
|
||||
# Heartbeat tracking - update notification periodically to prevent appearing stuck
|
||||
last_heartbeat_time = time.time()
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 1: Analyze all events, create pending documents
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
# =======================================================================
|
||||
events_to_process = [] # List of dicts with document and event data
|
||||
new_documents_created = False
|
||||
duplicate_content_count = 0
|
||||
|
||||
for event in events:
|
||||
try:
|
||||
event_id = event.get("id")
|
||||
event_summary = event.get("summary", "No Title")
|
||||
calendar_id = event.get("calendarId", "")
|
||||
|
||||
if not event_id:
|
||||
logger.warning(f"Skipping event with missing ID: {event_summary}")
|
||||
|
|
@ -363,246 +365,55 @@ async def index_google_calendar_events(
|
|||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
start = event.get("start", {})
|
||||
end = event.get("end", {})
|
||||
start_time = start.get("dateTime") or start.get("date", "")
|
||||
end_time = end.get("dateTime") or end.get("date", "")
|
||||
location = event.get("location", "")
|
||||
description = event.get("description", "")
|
||||
|
||||
# Generate unique identifier hash for this Google Calendar event
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.GOOGLE_CALENDAR_CONNECTOR, event_id, search_space_id
|
||||
doc = _build_connector_doc(
|
||||
event,
|
||||
event_markdown,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=connector.enable_summary,
|
||||
)
|
||||
|
||||
# Generate content hash
|
||||
content_hash = generate_content_hash(event_markdown, search_space_id)
|
||||
|
||||
# Check if document with this unique identifier already exists
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, unique_identifier_hash
|
||||
)
|
||||
|
||||
# Fallback: legacy Composio hash
|
||||
if not existing_document:
|
||||
legacy_hash = generate_unique_identifier_hash(
|
||||
DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
event_id,
|
||||
search_space_id,
|
||||
)
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, legacy_hash
|
||||
)
|
||||
if existing_document:
|
||||
existing_document.unique_identifier_hash = (
|
||||
unique_identifier_hash
|
||||
)
|
||||
if (
|
||||
existing_document.document_type
|
||||
== DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
existing_document.document_type = (
|
||||
DocumentType.GOOGLE_CALENDAR_CONNECTOR
|
||||
)
|
||||
logger.info(
|
||||
f"Migrated legacy Composio Calendar document: {event_id}"
|
||||
)
|
||||
|
||||
if existing_document:
|
||||
# Document exists - check if content has changed
|
||||
if existing_document.content_hash == content_hash:
|
||||
# Ensure status is ready (might have been stuck in processing/pending)
|
||||
if not DocumentStatus.is_state(
|
||||
existing_document.status, DocumentStatus.READY
|
||||
):
|
||||
existing_document.status = DocumentStatus.ready()
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Queue existing document for update (will be set to processing in Phase 2)
|
||||
events_to_process.append(
|
||||
{
|
||||
"document": existing_document,
|
||||
"is_new": False,
|
||||
"event_markdown": event_markdown,
|
||||
"content_hash": content_hash,
|
||||
"event_id": event_id,
|
||||
"event_summary": event_summary,
|
||||
"calendar_id": calendar_id,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"location": location,
|
||||
"description": description,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Document doesn't exist by unique_identifier_hash
|
||||
# Check if a document with the same content_hash exists (from another connector)
|
||||
with session.no_autoflush:
|
||||
duplicate_by_content = await check_duplicate_document_by_hash(
|
||||
session, content_hash
|
||||
duplicate = await check_duplicate_document_by_hash(
|
||||
session, compute_content_hash(doc)
|
||||
)
|
||||
|
||||
if duplicate_by_content:
|
||||
# A document with the same content already exists (likely from Composio connector)
|
||||
if duplicate:
|
||||
logger.info(
|
||||
f"Event {event_summary} already indexed by another connector "
|
||||
f"(existing document ID: {duplicate_by_content.id}, "
|
||||
f"type: {duplicate_by_content.document_type}). Skipping to avoid duplicate content."
|
||||
f"Event {doc.title} already indexed by another connector "
|
||||
f"(existing document ID: {duplicate.id}, "
|
||||
f"type: {duplicate.document_type}). Skipping."
|
||||
)
|
||||
duplicate_content_count += 1
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Create new document with PENDING status (visible in UI immediately)
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=event_summary,
|
||||
document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
document_metadata={
|
||||
"event_id": event_id,
|
||||
"event_summary": event_summary,
|
||||
"calendar_id": calendar_id,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"location": location,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
content="Pending...", # Placeholder until processed
|
||||
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
embedding=None,
|
||||
chunks=[], # Empty at creation - safe for async
|
||||
status=DocumentStatus.pending(), # Pending until processing starts
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
connector_id=connector_id,
|
||||
)
|
||||
session.add(document)
|
||||
new_documents_created = True
|
||||
|
||||
events_to_process.append(
|
||||
{
|
||||
"document": document,
|
||||
"is_new": True,
|
||||
"event_markdown": event_markdown,
|
||||
"content_hash": content_hash,
|
||||
"event_id": event_id,
|
||||
"event_summary": event_summary,
|
||||
"calendar_id": calendar_id,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"location": location,
|
||||
"description": description,
|
||||
}
|
||||
)
|
||||
connector_docs.append(doc)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Phase 1 for event: {e!s}", exc_info=True)
|
||||
documents_failed += 1
|
||||
logger.error(f"Error building ConnectorDocument for event: {e!s}", exc_info=True)
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Commit all pending documents - they all appear in UI now
|
||||
if new_documents_created:
|
||||
logger.info(
|
||||
f"Phase 1: Committing {len([e for e in events_to_process if e['is_new']])} pending documents"
|
||||
)
|
||||
await session.commit()
|
||||
# ── Pipeline: migrate legacy docs + parallel index ─────────────
|
||||
pipeline = IndexingPipelineService(session)
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 2: Process each document one by one
|
||||
# Each document transitions: pending → processing → ready/failed
|
||||
# =======================================================================
|
||||
logger.info(f"Phase 2: Processing {len(events_to_process)} documents")
|
||||
await pipeline.migrate_legacy_docs(connector_docs)
|
||||
|
||||
for item in events_to_process:
|
||||
# Send heartbeat periodically
|
||||
if on_heartbeat_callback:
|
||||
current_time = time.time()
|
||||
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat_callback(documents_indexed)
|
||||
last_heartbeat_time = current_time
|
||||
async def _get_llm(s):
|
||||
return await get_user_long_context_llm(s, user_id, search_space_id)
|
||||
|
||||
document = item["document"]
|
||||
try:
|
||||
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
|
||||
document.status = DocumentStatus.processing()
|
||||
await session.commit()
|
||||
|
||||
# Heavy processing (LLM, embeddings, chunks)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
|
||||
connector_docs,
|
||||
_get_llm,
|
||||
max_concurrency=3,
|
||||
on_heartbeat=on_heartbeat_callback,
|
||||
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
|
||||
)
|
||||
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"event_id": item["event_id"],
|
||||
"event_summary": item["event_summary"],
|
||||
"calendar_id": item["calendar_id"],
|
||||
"start_time": item["start_time"],
|
||||
"end_time": item["end_time"],
|
||||
"location": item["location"] or "No location",
|
||||
"document_type": "Google Calendar Event",
|
||||
"connector_type": "Google Calendar",
|
||||
}
|
||||
(
|
||||
summary_content,
|
||||
summary_embedding,
|
||||
) = await generate_document_summary(
|
||||
item["event_markdown"], user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
summary_content = f"Google Calendar Event: {item['event_summary']}\n\n{item['event_markdown']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["event_markdown"])
|
||||
|
||||
# Update document to READY with actual content
|
||||
document.title = item["event_summary"]
|
||||
document.content = summary_content
|
||||
document.content_hash = item["content_hash"]
|
||||
document.embedding = summary_embedding
|
||||
document.document_metadata = {
|
||||
"event_id": item["event_id"],
|
||||
"event_summary": item["event_summary"],
|
||||
"calendar_id": item["calendar_id"],
|
||||
"start_time": item["start_time"],
|
||||
"end_time": item["end_time"],
|
||||
"location": item["location"],
|
||||
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"connector_id": connector_id,
|
||||
}
|
||||
await safe_set_chunks(session, document, chunks)
|
||||
document.updated_at = get_current_timestamp()
|
||||
document.status = DocumentStatus.ready()
|
||||
|
||||
documents_indexed += 1
|
||||
|
||||
# Batch commit every 10 documents (for ready status updates)
|
||||
if documents_indexed % 10 == 0:
|
||||
logger.info(
|
||||
f"Committing batch: {documents_indexed} Google Calendar events processed so far"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Calendar event: {e!s}", exc_info=True)
|
||||
# Mark document as failed with reason (visible in UI)
|
||||
try:
|
||||
document.status = DocumentStatus.failed(str(e))
|
||||
document.updated_at = get_current_timestamp()
|
||||
except Exception as status_error:
|
||||
logger.error(
|
||||
f"Failed to update document status to failed: {status_error}"
|
||||
)
|
||||
documents_failed += 1
|
||||
continue
|
||||
|
||||
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
|
||||
# ── Finalize ──────────────────────────────────────────────────
|
||||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||
|
||||
# Final commit for any remaining documents not yet committed in batches
|
||||
logger.info(
|
||||
f"Final commit: Total {documents_indexed} Google Calendar events processed"
|
||||
)
|
||||
|
|
@ -612,22 +423,18 @@ async def index_google_calendar_events(
|
|||
"Successfully committed all Google Calendar document changes to database"
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle any remaining integrity errors gracefully (race conditions, etc.)
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in str(e).lower()
|
||||
or "uniqueviolationerror" in str(e).lower()
|
||||
):
|
||||
logger.warning(
|
||||
f"Duplicate content_hash detected during final commit. "
|
||||
f"This may occur if the same event was indexed by multiple connectors. "
|
||||
f"Rolling back and continuing. Error: {e!s}"
|
||||
)
|
||||
await session.rollback()
|
||||
# Don't fail the entire task - some documents may have been successfully indexed
|
||||
else:
|
||||
raise
|
||||
|
||||
# Build warning message if there were issues
|
||||
warning_parts = []
|
||||
if duplicate_content_count > 0:
|
||||
warning_parts.append(f"{duplicate_content_count} duplicate")
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,12 +1,10 @@
|
|||
"""
|
||||
Google Gmail connector indexer.
|
||||
|
||||
Implements 2-phase document status updates for real-time UI feedback:
|
||||
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
|
||||
- Phase 2: Process each document: pending → processing → ready/failed
|
||||
Uses the shared IndexingPipelineService for document deduplication,
|
||||
summarization, chunking, and embedding.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime
|
||||
|
||||
|
|
@ -15,21 +13,12 @@ from sqlalchemy.exc import SQLAlchemyError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||
from app.db import (
|
||||
Document,
|
||||
DocumentStatus,
|
||||
DocumentType,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
from app.db import DocumentType, SearchSourceConnectorType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import compute_content_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
from app.utils.google_credentials import (
|
||||
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
|
||||
build_composio_credentials,
|
||||
|
|
@ -37,12 +26,9 @@ from app.utils.google_credentials import (
|
|||
|
||||
from .base import (
|
||||
calculate_date_range,
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document_by_hash,
|
||||
get_connector_by_id,
|
||||
get_current_timestamp,
|
||||
logger,
|
||||
safe_set_chunks,
|
||||
update_connector_last_indexed,
|
||||
)
|
||||
|
||||
|
|
@ -51,13 +37,70 @@ ACCEPTED_GMAIL_CONNECTOR_TYPES = {
|
|||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
}
|
||||
|
||||
# Type hint for heartbeat callback
|
||||
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||
|
||||
# Heartbeat interval in seconds
|
||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||
|
||||
|
||||
def _build_connector_doc(
    message: dict,
    markdown_content: str,
    *,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    enable_summary: bool,
) -> ConnectorDocument:
    """Map a raw Gmail API message dict to a ConnectorDocument.

    Args:
        message: Message dict as returned by the Gmail client. The keys read
            here are ``id``, ``threadId`` and ``payload.headers`` (a list of
            ``{"name": ..., "value": ...}`` dicts); all are optional.
        markdown_content: Markdown rendering of the message; stored as the
            document's ``source_markdown`` and embedded in the fallback
            summary.
        connector_id: ID of the connector the message was fetched through.
        search_space_id: Search space the resulting document belongs to.
        user_id: Stored on the document as ``created_by_id``.
        enable_summary: Forwarded as ``should_summarize``.

    Returns:
        A ``ConnectorDocument`` typed ``GOOGLE_GMAIL_CONNECTOR`` whose
        ``unique_id`` is the message ID and whose title is the ``Subject``
        header (``"No Subject"`` when that header is absent).
    """
    message_id = message.get("id", "")
    thread_id = message.get("threadId", "")

    # ``or`` fallbacks instead of ``.get`` defaults: a key that is present but
    # null would otherwise break the header iteration below.
    payload = message.get("payload") or {}
    headers = payload.get("headers") or []

    # Header names are matched case-insensitively. Building the table in list
    # order means a repeated header keeps its last value, exactly as a
    # sequential scan would. Null names/values are coerced to "" so a missing
    # value never renders as the literal text "None".
    header_values = {
        (header.get("name") or "").lower(): header.get("value") or ""
        for header in headers
    }
    subject = header_values.get("subject", "No Subject")
    sender = header_values.get("from", "Unknown Sender")
    date_str = header_values.get("date", "Unknown Date")

    metadata = {
        "message_id": message_id,
        "thread_id": thread_id,
        "subject": subject,
        "sender": sender,
        "date": date_str,
        "connector_id": connector_id,
        "document_type": "Gmail Message",
        "connector_type": "Google Gmail",
    }

    # NOTE(review): presumably used by the indexing pipeline when no LLM
    # summary is produced -- confirm against IndexingPipelineService.
    fallback_summary = (
        f"Google Gmail Message: {subject}\n\n"
        f"From: {sender}\nDate: {date_str}\n\n"
        f"{markdown_content}"
    )

    return ConnectorDocument(
        title=subject,
        source_markdown=markdown_content,
        unique_id=message_id,
        document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
        search_space_id=search_space_id,
        connector_id=connector_id,
        created_by_id=user_id,
        should_summarize=enable_summary,
        fallback_summary=fallback_summary,
        metadata=metadata,
    )
|
||||
|
||||
|
||||
async def index_google_gmail_messages(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
|
|
@ -80,7 +123,7 @@ async def index_google_gmail_messages(
|
|||
start_date: Start date for filtering messages (YYYY-MM-DD format)
|
||||
end_date: End date for filtering messages (YYYY-MM-DD format)
|
||||
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
|
||||
max_messages: Maximum number of messages to fetch (default: 100)
|
||||
max_messages: Maximum number of messages to fetch (default: 1000)
|
||||
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
|
||||
|
||||
Returns:
|
||||
|
|
@ -88,7 +131,6 @@ async def index_google_gmail_messages(
|
|||
"""
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="google_gmail_messages_indexing",
|
||||
source="connector_indexing_task",
|
||||
|
|
@ -103,7 +145,7 @@ async def index_google_gmail_messages(
|
|||
)
|
||||
|
||||
try:
|
||||
# Accept both native and Composio Gmail connectors
|
||||
# ── Connector lookup ──────────────────────────────────────────
|
||||
connector = None
|
||||
for ct in ACCEPTED_GMAIL_CONNECTOR_TYPES:
|
||||
connector = await get_connector_by_id(session, connector_id, ct)
|
||||
|
|
@ -117,7 +159,7 @@ async def index_google_gmail_messages(
|
|||
)
|
||||
return 0, 0, error_msg
|
||||
|
||||
# Build credentials based on connector type
|
||||
# ── Credential building ───────────────────────────────────────
|
||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||
if not connected_account_id:
|
||||
|
|
@ -189,6 +231,7 @@ async def index_google_gmail_messages(
|
|||
)
|
||||
return 0, 0, "Google gmail credentials not found in connector config"
|
||||
|
||||
# ── Gmail client init ─────────────────────────────────────────
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Initializing Google gmail client for connector {connector_id}",
|
||||
|
|
@ -199,14 +242,11 @@ async def index_google_gmail_messages(
|
|||
credentials, session, user_id, connector_id
|
||||
)
|
||||
|
||||
# Calculate date range using last_indexed_at if dates not provided
|
||||
# This ensures Gmail uses the same date logic as other connectors
|
||||
# (uses last_indexed_at → now, or 365 days back for first-time indexing)
|
||||
calculated_start_date, calculated_end_date = calculate_date_range(
|
||||
connector, start_date, end_date, default_days_back=365
|
||||
)
|
||||
|
||||
# Fetch recent Google gmail messages
|
||||
# ── Fetch messages ────────────────────────────────────────────
|
||||
logger.info(
|
||||
f"Fetching emails for connector {connector_id} "
|
||||
f"from {calculated_start_date} to {calculated_end_date}"
|
||||
|
|
@ -218,7 +258,6 @@ async def index_google_gmail_messages(
|
|||
)
|
||||
|
||||
if error:
|
||||
# Check if this is an authentication error that requires re-authentication
|
||||
error_message = error
|
||||
error_type = "APIError"
|
||||
if (
|
||||
|
|
@ -243,286 +282,74 @@ async def index_google_gmail_messages(
|
|||
|
||||
logger.info(f"Found {len(messages)} Google gmail messages to index")
|
||||
|
||||
documents_indexed = 0
|
||||
# ── Build ConnectorDocuments ──────────────────────────────────
|
||||
connector_docs: list[ConnectorDocument] = []
|
||||
documents_skipped = 0
|
||||
documents_failed = 0 # Track messages that failed processing
|
||||
duplicate_content_count = (
|
||||
0 # Track messages skipped due to duplicate content_hash
|
||||
)
|
||||
|
||||
# Heartbeat tracking - update notification periodically to prevent appearing stuck
|
||||
last_heartbeat_time = time.time()
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 1: Analyze all messages, create pending documents
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
# =======================================================================
|
||||
messages_to_process = [] # List of dicts with document and message data
|
||||
new_documents_created = False
|
||||
duplicate_content_count = 0
|
||||
|
||||
for message in messages:
|
||||
try:
|
||||
# Extract message information
|
||||
message_id = message.get("id", "")
|
||||
thread_id = message.get("threadId", "")
|
||||
|
||||
# Extract headers for subject and sender
|
||||
payload = message.get("payload", {})
|
||||
headers = payload.get("headers", [])
|
||||
|
||||
subject = "No Subject"
|
||||
sender = "Unknown Sender"
|
||||
date_str = "Unknown Date"
|
||||
|
||||
for header in headers:
|
||||
name = header.get("name", "").lower()
|
||||
value = header.get("value", "")
|
||||
if name == "subject":
|
||||
subject = value
|
||||
elif name == "from":
|
||||
sender = value
|
||||
elif name == "date":
|
||||
date_str = value
|
||||
|
||||
if not message_id:
|
||||
logger.warning(f"Skipping message with missing ID: {subject}")
|
||||
logger.warning("Skipping message with missing ID")
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Format message to markdown
|
||||
markdown_content = gmail_connector.format_message_to_markdown(message)
|
||||
|
||||
if not markdown_content.strip():
|
||||
logger.warning(f"Skipping message with no content: {subject}")
|
||||
logger.warning(f"Skipping message with no content: {message_id}")
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Generate unique identifier hash for this Gmail message
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.GOOGLE_GMAIL_CONNECTOR, message_id, search_space_id
|
||||
doc = _build_connector_doc(
|
||||
message,
|
||||
markdown_content,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=connector.enable_summary,
|
||||
)
|
||||
|
||||
# Generate content hash
|
||||
content_hash = generate_content_hash(markdown_content, search_space_id)
|
||||
|
||||
# Check if document with this unique identifier already exists
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, unique_identifier_hash
|
||||
)
|
||||
|
||||
# Fallback: legacy Composio hash
|
||||
if not existing_document:
|
||||
legacy_hash = generate_unique_identifier_hash(
|
||||
DocumentType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
message_id,
|
||||
search_space_id,
|
||||
)
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, legacy_hash
|
||||
)
|
||||
if existing_document:
|
||||
existing_document.unique_identifier_hash = (
|
||||
unique_identifier_hash
|
||||
)
|
||||
if (
|
||||
existing_document.document_type
|
||||
== DocumentType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
existing_document.document_type = (
|
||||
DocumentType.GOOGLE_GMAIL_CONNECTOR
|
||||
)
|
||||
logger.info(
|
||||
f"Migrated legacy Composio Gmail document: {message_id}"
|
||||
)
|
||||
|
||||
if existing_document:
|
||||
# Document exists - check if content has changed
|
||||
if existing_document.content_hash == content_hash:
|
||||
# Ensure status is ready (might have been stuck in processing/pending)
|
||||
if not DocumentStatus.is_state(
|
||||
existing_document.status, DocumentStatus.READY
|
||||
):
|
||||
existing_document.status = DocumentStatus.ready()
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Queue existing document for update (will be set to processing in Phase 2)
|
||||
messages_to_process.append(
|
||||
{
|
||||
"document": existing_document,
|
||||
"is_new": False,
|
||||
"markdown_content": markdown_content,
|
||||
"content_hash": content_hash,
|
||||
"message_id": message_id,
|
||||
"thread_id": thread_id,
|
||||
"subject": subject,
|
||||
"sender": sender,
|
||||
"date_str": date_str,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Document doesn't exist by unique_identifier_hash
|
||||
# Check if a document with the same content_hash exists (from another connector)
|
||||
with session.no_autoflush:
|
||||
duplicate_by_content = await check_duplicate_document_by_hash(
|
||||
session, content_hash
|
||||
duplicate = await check_duplicate_document_by_hash(
|
||||
session, compute_content_hash(doc)
|
||||
)
|
||||
|
||||
if duplicate_by_content:
|
||||
if duplicate:
|
||||
logger.info(
|
||||
f"Gmail message {subject} already indexed by another connector "
|
||||
f"(existing document ID: {duplicate_by_content.id}, "
|
||||
f"type: {duplicate_by_content.document_type}). Skipping."
|
||||
f"Gmail message {doc.title} already indexed by another connector "
|
||||
f"(existing document ID: {duplicate.id}, "
|
||||
f"type: {duplicate.document_type}). Skipping."
|
||||
)
|
||||
duplicate_content_count += 1
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Create new document with PENDING status (visible in UI immediately)
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=subject,
|
||||
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
|
||||
document_metadata={
|
||||
"message_id": message_id,
|
||||
"thread_id": thread_id,
|
||||
"subject": subject,
|
||||
"sender": sender,
|
||||
"date": date_str,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
content="Pending...", # Placeholder until processed
|
||||
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
embedding=None,
|
||||
chunks=[], # Empty at creation - safe for async
|
||||
status=DocumentStatus.pending(), # Pending until processing starts
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
connector_id=connector_id,
|
||||
)
|
||||
session.add(document)
|
||||
new_documents_created = True
|
||||
|
||||
messages_to_process.append(
|
||||
{
|
||||
"document": document,
|
||||
"is_new": True,
|
||||
"markdown_content": markdown_content,
|
||||
"content_hash": content_hash,
|
||||
"message_id": message_id,
|
||||
"thread_id": thread_id,
|
||||
"subject": subject,
|
||||
"sender": sender,
|
||||
"date_str": date_str,
|
||||
}
|
||||
)
|
||||
connector_docs.append(doc)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Phase 1 for message: {e!s}", exc_info=True)
|
||||
documents_failed += 1
|
||||
logger.error(f"Error building ConnectorDocument for message: {e!s}", exc_info=True)
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Commit all pending documents - they all appear in UI now
|
||||
if new_documents_created:
|
||||
logger.info(
|
||||
f"Phase 1: Committing {len([m for m in messages_to_process if m['is_new']])} pending documents"
|
||||
)
|
||||
await session.commit()
|
||||
# ── Pipeline: migrate legacy docs + parallel index ─────────────
|
||||
pipeline = IndexingPipelineService(session)
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 2: Process each document one by one
|
||||
# Each document transitions: pending → processing → ready/failed
|
||||
# =======================================================================
|
||||
logger.info(f"Phase 2: Processing {len(messages_to_process)} documents")
|
||||
await pipeline.migrate_legacy_docs(connector_docs)
|
||||
|
||||
for item in messages_to_process:
|
||||
# Send heartbeat periodically
|
||||
if on_heartbeat_callback:
|
||||
current_time = time.time()
|
||||
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat_callback(documents_indexed)
|
||||
last_heartbeat_time = current_time
|
||||
async def _get_llm(s):
|
||||
return await get_user_long_context_llm(s, user_id, search_space_id)
|
||||
|
||||
document = item["document"]
|
||||
try:
|
||||
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
|
||||
document.status = DocumentStatus.processing()
|
||||
await session.commit()
|
||||
|
||||
# Heavy processing (LLM, embeddings, chunks)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
|
||||
connector_docs,
|
||||
_get_llm,
|
||||
max_concurrency=3,
|
||||
on_heartbeat=on_heartbeat_callback,
|
||||
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
|
||||
)
|
||||
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"message_id": item["message_id"],
|
||||
"thread_id": item["thread_id"],
|
||||
"subject": item["subject"],
|
||||
"sender": item["sender"],
|
||||
"date": item["date_str"],
|
||||
"document_type": "Gmail Message",
|
||||
"connector_type": "Google Gmail",
|
||||
}
|
||||
(
|
||||
summary_content,
|
||||
summary_embedding,
|
||||
) = await generate_document_summary(
|
||||
item["markdown_content"],
|
||||
user_llm,
|
||||
document_metadata_for_summary,
|
||||
)
|
||||
else:
|
||||
summary_content = f"Google Gmail Message: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}\n\n{item['markdown_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["markdown_content"])
|
||||
|
||||
# Update document to READY with actual content
|
||||
document.title = item["subject"]
|
||||
document.content = summary_content
|
||||
document.content_hash = item["content_hash"]
|
||||
document.embedding = summary_embedding
|
||||
document.document_metadata = {
|
||||
"message_id": item["message_id"],
|
||||
"thread_id": item["thread_id"],
|
||||
"subject": item["subject"],
|
||||
"sender": item["sender"],
|
||||
"date": item["date_str"],
|
||||
"connector_id": connector_id,
|
||||
}
|
||||
await safe_set_chunks(session, document, chunks)
|
||||
document.updated_at = get_current_timestamp()
|
||||
document.status = DocumentStatus.ready()
|
||||
|
||||
documents_indexed += 1
|
||||
|
||||
# Batch commit every 10 documents (for ready status updates)
|
||||
if documents_indexed % 10 == 0:
|
||||
logger.info(
|
||||
f"Committing batch: {documents_indexed} Gmail messages processed so far"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Gmail message: {e!s}", exc_info=True)
|
||||
# Mark document as failed with reason (visible in UI)
|
||||
try:
|
||||
document.status = DocumentStatus.failed(str(e))
|
||||
document.updated_at = get_current_timestamp()
|
||||
except Exception as status_error:
|
||||
logger.error(
|
||||
f"Failed to update document status to failed: {status_error}"
|
||||
)
|
||||
documents_failed += 1
|
||||
continue
|
||||
|
||||
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
|
||||
# ── Finalize ──────────────────────────────────────────────────
|
||||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||
|
||||
# Final commit for any remaining documents not yet committed in batches
|
||||
logger.info(f"Final commit: Total {documents_indexed} Gmail messages processed")
|
||||
try:
|
||||
await session.commit()
|
||||
|
|
@ -530,22 +357,18 @@ async def index_google_gmail_messages(
|
|||
"Successfully committed all Google Gmail document changes to database"
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle any remaining integrity errors gracefully (race conditions, etc.)
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in str(e).lower()
|
||||
or "uniqueviolationerror" in str(e).lower()
|
||||
):
|
||||
logger.warning(
|
||||
f"Duplicate content_hash detected during final commit. "
|
||||
f"This may occur if the same message was indexed by multiple connectors. "
|
||||
f"Rolling back and continuing. Error: {e!s}"
|
||||
)
|
||||
await session.rollback()
|
||||
# Don't fail the entire task - some documents may have been successfully indexed
|
||||
else:
|
||||
raise
|
||||
|
||||
# Build warning message if there were issues
|
||||
warning_parts = []
|
||||
if duplicate_content_count > 0:
|
||||
warning_parts.append(f"{duplicate_content_count} duplicate")
|
||||
|
|
@ -555,7 +378,6 @@ async def index_google_gmail_messages(
|
|||
|
||||
total_processed = documents_indexed
|
||||
|
||||
# Log success
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully completed Google Gmail indexing for connector {connector_id}",
|
||||
|
|
|
|||
|
|
@ -1,49 +1,80 @@
|
|||
"""
|
||||
Jira connector indexer.
|
||||
|
||||
Provides real-time document status updates during indexing using a two-phase approach:
|
||||
- Phase 1: Create all documents with PENDING status (visible in UI immediately)
|
||||
- Phase 2: Process each document one by one (PENDING → PROCESSING → READY/FAILED)
|
||||
"""
|
||||
"""Jira connector indexer using the unified parallel indexing pipeline."""
|
||||
|
||||
import contextlib
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.db import DocumentType, SearchSourceConnectorType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import compute_content_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
||||
from .base import (
|
||||
calculate_date_range,
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document_by_hash,
|
||||
get_connector_by_id,
|
||||
get_current_timestamp,
|
||||
logger,
|
||||
safe_set_chunks,
|
||||
update_connector_last_indexed,
|
||||
)
|
||||
|
||||
# Type hint for heartbeat callback
|
||||
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||
|
||||
# Heartbeat interval in seconds - update notification every 30 seconds
|
||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||
|
||||
|
||||
def _build_connector_doc(
    issue: dict,
    formatted_issue: dict,
    issue_content: str,
    *,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    enable_summary: bool,
) -> ConnectorDocument:
    """Map a raw Jira issue dict to a ConnectorDocument.

    Args:
        issue: Raw issue dict; ``key``, ``id`` and ``fields.summary`` are read.
        formatted_issue: Formatted view of the same issue; ``status``,
            ``priority`` and ``comments`` are read.
        issue_content: Markdown rendering of the issue, stored as the
            document's source markdown.
        connector_id: ID of the Jira connector that produced the issue.
        search_space_id: Search space the document belongs to.
        user_id: ID recorded as the document's creator.
        enable_summary: Whether the pipeline should summarize the document.

    Returns:
        A ConnectorDocument whose ``unique_id`` is the issue key.
    """
    # The issue key (e.g. "PROJ-123") serves as both id and identifier.
    issue_key = issue.get("key", "")

    # Use the human-readable summary as the title. The previous code used the
    # opaque numeric "id", which produced titles like "PROJ-123: 10042".
    # The "id" is kept as a fallback so the title is never emptier than before.
    # NOTE(review): assumes `issue` follows the Jira REST shape
    # (fields.summary) - confirm against the connector's payload.
    fields = issue.get("fields")
    summary = fields.get("summary") if isinstance(fields, dict) else None
    issue_title = summary or issue.get("id", "")

    state = formatted_issue.get("status", "Unknown")
    priority = formatted_issue.get("priority", "Unknown")
    # `or []` tolerates an explicit None for comments.
    comment_count = len(formatted_issue.get("comments") or [])

    metadata = {
        "issue_id": issue_key,
        "issue_identifier": issue_key,
        "issue_title": issue_title,
        "state": state,
        "priority": priority,
        "comment_count": comment_count,
        "connector_id": connector_id,
        "document_type": "Jira Issue",
        "connector_type": "Jira",
    }

    # Text used in place of an LLM summary.
    fallback_summary = (
        f"Jira Issue {issue_key}: {issue_title}\n\n"
        f"Status: {state}\n\n{issue_content}"
    )

    return ConnectorDocument(
        title=f"{issue_key}: {issue_title}",
        source_markdown=issue_content,
        unique_id=issue_key,
        document_type=DocumentType.JIRA_CONNECTOR,
        search_space_id=search_space_id,
        connector_id=connector_id,
        created_by_id=user_id,
        should_summarize=enable_summary,
        fallback_summary=fallback_summary,
        metadata=metadata,
    )
|
||||
|
||||
|
||||
async def index_jira_issues(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
|
|
@ -53,26 +84,9 @@ async def index_jira_issues(
|
|||
end_date: str | None = None,
|
||||
update_last_indexed: bool = True,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, str | None]:
|
||||
"""
|
||||
Index Jira issues and comments.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Jira connector
|
||||
search_space_id: ID of the search space to store documents in
|
||||
user_id: User ID
|
||||
start_date: Start date for indexing (YYYY-MM-DD format)
|
||||
end_date: End date for indexing (YYYY-MM-DD format)
|
||||
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
|
||||
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
|
||||
|
||||
Returns:
|
||||
Tuple containing (number of documents indexed, error message or None)
|
||||
"""
|
||||
) -> tuple[int, int, str | None]:
|
||||
"""Index Jira issues and comments."""
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="jira_issues_indexing",
|
||||
source="connector_indexing_task",
|
||||
|
|
@ -86,7 +100,6 @@ async def index_jira_issues(
|
|||
)
|
||||
|
||||
try:
|
||||
# Get the connector from the database
|
||||
connector = await get_connector_by_id(
|
||||
session, connector_id, SearchSourceConnectorType.JIRA_CONNECTOR
|
||||
)
|
||||
|
|
@ -98,24 +111,15 @@ async def index_jira_issues(
|
|||
"Connector not found",
|
||||
{"error_type": "ConnectorNotFound"},
|
||||
)
|
||||
return 0, f"Connector with ID {connector_id} not found"
|
||||
return 0, 0, f"Connector with ID {connector_id} not found"
|
||||
|
||||
# Initialize Jira client with internal refresh capability
|
||||
# Token refresh will happen automatically when needed
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Initializing Jira client for connector {connector_id}",
|
||||
{"stage": "client_initialization"},
|
||||
)
|
||||
|
||||
logger.info(f"Initializing Jira client for connector {connector_id}")
|
||||
|
||||
# Create connector with session and connector_id for internal refresh
|
||||
# Token refresh will happen automatically when needed
|
||||
jira_client = JiraHistoryConnector(session=session, connector_id=connector_id)
|
||||
|
||||
# Calculate date range
|
||||
# Handle "undefined" strings from frontend
|
||||
if start_date == "undefined" or start_date == "":
|
||||
start_date = None
|
||||
if end_date == "undefined" or end_date == "":
|
||||
|
|
@ -135,19 +139,14 @@ async def index_jira_issues(
|
|||
},
|
||||
)
|
||||
|
||||
# Get issues within date range
|
||||
try:
|
||||
issues, error = await jira_client.get_issues_by_date_range(
|
||||
start_date=start_date_str, end_date=end_date_str, include_comments=True
|
||||
)
|
||||
|
||||
if error:
|
||||
# Don't treat "No issues found" as an error that should stop indexing
|
||||
if "No issues found" in error:
|
||||
logger.info(f"No Jira issues found: {error}")
|
||||
logger.info(
|
||||
"No issues found is not a critical error, continuing with update"
|
||||
)
|
||||
if update_last_indexed:
|
||||
await update_connector_last_indexed(
|
||||
session, connector, update_last_indexed
|
||||
|
|
@ -162,7 +161,8 @@ async def index_jira_issues(
|
|||
f"No Jira issues found in date range {start_date_str} to {end_date_str}",
|
||||
{"issues_found": 0},
|
||||
)
|
||||
return 0, None
|
||||
await jira_client.close()
|
||||
return 0, 0, None
|
||||
else:
|
||||
logger.error(f"Failed to get Jira issues: {error}")
|
||||
await task_logger.log_task_failure(
|
||||
|
|
@ -171,29 +171,30 @@ async def index_jira_issues(
|
|||
"API Error",
|
||||
{"error_type": "APIError"},
|
||||
)
|
||||
return 0, f"Failed to get Jira issues: {error}"
|
||||
await jira_client.close()
|
||||
return 0, 0, f"Failed to get Jira issues: {error}"
|
||||
|
||||
logger.info(f"Retrieved {len(issues)} issues from Jira API")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Jira issues: {e!s}", exc_info=True)
|
||||
return 0, f"Error fetching Jira issues: {e!s}"
|
||||
await jira_client.close()
|
||||
return 0, 0, f"Error fetching Jira issues: {e!s}"
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 1: Analyze all issues, create pending documents
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
# =======================================================================
|
||||
documents_indexed = 0
|
||||
if not issues:
|
||||
logger.info("No Jira issues found for the specified date range")
|
||||
if update_last_indexed:
|
||||
await update_connector_last_indexed(
|
||||
session, connector, update_last_indexed
|
||||
)
|
||||
await session.commit()
|
||||
await jira_client.close()
|
||||
return 0, 0, None
|
||||
|
||||
connector_docs: list[ConnectorDocument] = []
|
||||
documents_skipped = 0
|
||||
documents_failed = 0
|
||||
duplicate_content_count = 0
|
||||
|
||||
# Heartbeat tracking - update notification periodically to prevent appearing stuck
|
||||
last_heartbeat_time = time.time()
|
||||
|
||||
issues_to_process = [] # List of dicts with document and issue data
|
||||
new_documents_created = False
|
||||
|
||||
for issue in issues:
|
||||
try:
|
||||
issue_id = issue.get("key")
|
||||
|
|
@ -207,10 +208,7 @@ async def index_jira_issues(
|
|||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Format the issue for better readability
|
||||
formatted_issue = jira_client.format_issue(issue)
|
||||
|
||||
# Convert to markdown
|
||||
issue_content = jira_client.format_issue_to_markdown(formatted_issue)
|
||||
|
||||
if not issue_content:
|
||||
|
|
@ -220,53 +218,19 @@ async def index_jira_issues(
|
|||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Generate unique identifier hash for this Jira issue
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.JIRA_CONNECTOR, issue_id, search_space_id
|
||||
doc = _build_connector_doc(
|
||||
issue,
|
||||
formatted_issue,
|
||||
issue_content,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=connector.enable_summary,
|
||||
)
|
||||
|
||||
# Generate content hash
|
||||
content_hash = generate_content_hash(issue_content, search_space_id)
|
||||
|
||||
# Check if document with this unique identifier already exists
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, unique_identifier_hash
|
||||
)
|
||||
|
||||
comment_count = len(formatted_issue.get("comments", []))
|
||||
|
||||
if existing_document:
|
||||
# Document exists - check if content has changed
|
||||
if existing_document.content_hash == content_hash:
|
||||
# Ensure status is ready (might have been stuck in processing/pending)
|
||||
if not DocumentStatus.is_state(
|
||||
existing_document.status, DocumentStatus.READY
|
||||
):
|
||||
existing_document.status = DocumentStatus.ready()
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Queue existing document for update (will be set to processing in Phase 2)
|
||||
issues_to_process.append(
|
||||
{
|
||||
"document": existing_document,
|
||||
"is_new": False,
|
||||
"issue_content": issue_content,
|
||||
"content_hash": content_hash,
|
||||
"issue_id": issue_id,
|
||||
"issue_identifier": issue_identifier,
|
||||
"issue_title": issue_title,
|
||||
"formatted_issue": formatted_issue,
|
||||
"comment_count": comment_count,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Document doesn't exist by unique_identifier_hash
|
||||
# Check if a document with the same content_hash exists (from another connector)
|
||||
with session.no_autoflush:
|
||||
duplicate_by_content = await check_duplicate_document_by_hash(
|
||||
session, content_hash
|
||||
session, compute_content_hash(doc)
|
||||
)
|
||||
|
||||
if duplicate_by_content:
|
||||
|
|
@ -279,160 +243,37 @@ async def index_jira_issues(
|
|||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Create new document with PENDING status (visible in UI immediately)
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=f"{issue_identifier}: {issue_title}",
|
||||
document_type=DocumentType.JIRA_CONNECTOR,
|
||||
document_metadata={
|
||||
"issue_id": issue_id,
|
||||
"issue_identifier": issue_identifier,
|
||||
"issue_title": issue_title,
|
||||
"state": formatted_issue.get("status", "Unknown"),
|
||||
"comment_count": comment_count,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
content="Pending...", # Placeholder until processed
|
||||
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
embedding=None,
|
||||
chunks=[], # Empty at creation - safe for async
|
||||
status=DocumentStatus.pending(), # Pending until processing starts
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
connector_id=connector_id,
|
||||
)
|
||||
session.add(document)
|
||||
new_documents_created = True
|
||||
|
||||
issues_to_process.append(
|
||||
{
|
||||
"document": document,
|
||||
"is_new": True,
|
||||
"issue_content": issue_content,
|
||||
"content_hash": content_hash,
|
||||
"issue_id": issue_id,
|
||||
"issue_identifier": issue_identifier,
|
||||
"issue_title": issue_title,
|
||||
"formatted_issue": formatted_issue,
|
||||
"comment_count": comment_count,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Phase 1 for issue: {e!s}", exc_info=True)
|
||||
documents_failed += 1
|
||||
continue
|
||||
|
||||
# Commit all pending documents - they all appear in UI now
|
||||
if new_documents_created:
|
||||
logger.info(
|
||||
f"Phase 1: Committing {len([i for i in issues_to_process if i['is_new']])} pending documents"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 2: Process each document one by one
|
||||
# Each document transitions: pending → processing → ready/failed
|
||||
# =======================================================================
|
||||
logger.info(f"Phase 2: Processing {len(issues_to_process)} documents")
|
||||
|
||||
for item in issues_to_process:
|
||||
# Send heartbeat periodically
|
||||
if on_heartbeat_callback:
|
||||
current_time = time.time()
|
||||
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat_callback(documents_indexed)
|
||||
last_heartbeat_time = current_time
|
||||
|
||||
document = item["document"]
|
||||
try:
|
||||
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
|
||||
document.status = DocumentStatus.processing()
|
||||
await session.commit()
|
||||
|
||||
# Heavy processing (LLM, embeddings, chunks)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata = {
|
||||
"issue_key": item["issue_identifier"],
|
||||
"issue_title": item["issue_title"],
|
||||
"status": item["formatted_issue"].get("status", "Unknown"),
|
||||
"priority": item["formatted_issue"].get("priority", "Unknown"),
|
||||
"comment_count": item["comment_count"],
|
||||
"document_type": "Jira Issue",
|
||||
"connector_type": "Jira",
|
||||
}
|
||||
(
|
||||
summary_content,
|
||||
summary_embedding,
|
||||
) = await generate_document_summary(
|
||||
item["issue_content"], user_llm, document_metadata
|
||||
)
|
||||
else:
|
||||
summary_content = f"Jira Issue {item['issue_identifier']}: {item['issue_title']}\n\n{item['issue_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
# Process chunks - using the full issue content with comments
|
||||
chunks = await create_document_chunks(item["issue_content"])
|
||||
|
||||
# Update document to READY with actual content
|
||||
document.title = f"{item['issue_identifier']}: {item['issue_title']}"
|
||||
document.content = summary_content
|
||||
document.content_hash = item["content_hash"]
|
||||
document.embedding = summary_embedding
|
||||
document.document_metadata = {
|
||||
"issue_id": item["issue_id"],
|
||||
"issue_identifier": item["issue_identifier"],
|
||||
"issue_title": item["issue_title"],
|
||||
"state": item["formatted_issue"].get("status", "Unknown"),
|
||||
"comment_count": item["comment_count"],
|
||||
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"connector_id": connector_id,
|
||||
}
|
||||
await safe_set_chunks(session, document, chunks)
|
||||
document.updated_at = get_current_timestamp()
|
||||
document.status = DocumentStatus.ready()
|
||||
|
||||
documents_indexed += 1
|
||||
|
||||
# Batch commit every 10 documents (for ready status updates)
|
||||
if documents_indexed % 10 == 0:
|
||||
logger.info(
|
||||
f"Committing batch: {documents_indexed} Jira issues processed so far"
|
||||
)
|
||||
await session.commit()
|
||||
connector_docs.append(doc)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing issue {item.get('issue_identifier', 'Unknown')}: {e!s}",
|
||||
f"Error building ConnectorDocument for issue {issue_identifier}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Mark document as failed with reason (visible in UI)
|
||||
try:
|
||||
document.status = DocumentStatus.failed(str(e))
|
||||
document.updated_at = get_current_timestamp()
|
||||
except Exception as status_error:
|
||||
logger.error(
|
||||
f"Failed to update document status to failed: {status_error}"
|
||||
)
|
||||
documents_failed += 1
|
||||
continue # Skip this issue and continue with others
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
pipeline = IndexingPipelineService(session)
|
||||
await pipeline.migrate_legacy_docs(connector_docs)
|
||||
|
||||
async def _get_llm(s: AsyncSession):
|
||||
return await get_user_long_context_llm(s, user_id, search_space_id)
|
||||
|
||||
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
|
||||
connector_docs,
|
||||
_get_llm,
|
||||
max_concurrency=3,
|
||||
on_heartbeat=on_heartbeat_callback,
|
||||
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
|
||||
)
|
||||
|
||||
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
|
||||
# This ensures the UI shows "Last indexed" instead of "Never indexed"
|
||||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||
|
||||
# Final commit to ensure all documents are persisted (safety net)
|
||||
logger.info(f"Final commit: Total {documents_indexed} Jira issues processed")
|
||||
try:
|
||||
await session.commit()
|
||||
logger.info("Successfully committed all JIRA document changes to database")
|
||||
except Exception as e:
|
||||
# Handle any remaining integrity errors gracefully (race conditions, etc.)
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in str(e).lower()
|
||||
or "uniqueviolationerror" in str(e).lower()
|
||||
|
|
@ -447,7 +288,6 @@ async def index_jira_issues(
|
|||
else:
|
||||
raise
|
||||
|
||||
# Build warning message if there were issues
|
||||
warning_parts = []
|
||||
if duplicate_content_count > 0:
|
||||
warning_parts.append(f"{duplicate_content_count} duplicate")
|
||||
|
|
@ -455,7 +295,6 @@ async def index_jira_issues(
|
|||
warning_parts.append(f"{documents_failed} failed")
|
||||
warning_message = ", ".join(warning_parts) if warning_parts else None
|
||||
|
||||
# Log success
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully completed JIRA indexing for connector {connector_id}",
|
||||
|
|
@ -466,17 +305,13 @@ async def index_jira_issues(
|
|||
"duplicate_content_count": duplicate_content_count,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"JIRA indexing completed: {documents_indexed} ready, "
|
||||
f"{documents_skipped} skipped, {documents_failed} failed "
|
||||
f"({duplicate_content_count} duplicate content)"
|
||||
)
|
||||
|
||||
# Clean up the connector
|
||||
await jira_client.close()
|
||||
|
||||
return documents_indexed, warning_message
|
||||
return documents_indexed, documents_skipped, warning_message
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
|
|
@ -487,11 +322,10 @@ async def index_jira_issues(
|
|||
{"error_type": "SQLAlchemyError"},
|
||||
)
|
||||
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
||||
# Clean up the connector in case of error
|
||||
if "jira_client" in locals():
|
||||
with contextlib.suppress(Exception):
|
||||
await jira_client.close()
|
||||
return 0, f"Database error: {db_error!s}"
|
||||
return 0, 0, f"Database error: {db_error!s}"
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
await task_logger.log_task_failure(
|
||||
|
|
@ -501,8 +335,7 @@ async def index_jira_issues(
|
|||
{"error_type": type(e).__name__},
|
||||
)
|
||||
logger.error(f"Failed to index JIRA issues: {e!s}", exc_info=True)
|
||||
# Clean up the connector in case of error
|
||||
if "jira_client" in locals():
|
||||
with contextlib.suppress(Exception):
|
||||
await jira_client.close()
|
||||
return 0, f"Failed to index JIRA issues: {e!s}"
|
||||
return 0, 0, f"Failed to index JIRA issues: {e!s}"
|
||||
|
|
|
|||
|
|
@ -1,48 +1,84 @@
|
|||
"""
|
||||
Linear connector indexer.
|
||||
|
||||
Implements 2-phase document status updates for real-time UI feedback:
|
||||
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
|
||||
- Phase 2: Process each document: pending → processing → ready/failed
|
||||
Uses the shared IndexingPipelineService for document deduplication,
|
||||
summarization, chunking, and embedding with bounded parallel indexing.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.connectors.linear_connector import LinearConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.db import DocumentType, SearchSourceConnectorType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import compute_content_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
||||
from .base import (
|
||||
calculate_date_range,
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document_by_hash,
|
||||
get_connector_by_id,
|
||||
get_current_timestamp,
|
||||
logger,
|
||||
safe_set_chunks,
|
||||
update_connector_last_indexed,
|
||||
)
|
||||
|
||||
# Type hint for heartbeat callback
|
||||
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||
|
||||
# Heartbeat interval in seconds - update notification every 30 seconds
|
||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||
|
||||
|
||||
def _build_connector_doc(
    issue: dict,
    formatted_issue: dict,
    issue_content: str,
    *,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    enable_summary: bool,
) -> ConnectorDocument:
    """Convert one Linear issue into a ``ConnectorDocument``.

    Args:
        issue: Raw issue payload; ``id``, ``identifier`` and ``title`` are
            read from it, each defaulting to an empty string when absent.
        formatted_issue: Formatted view of the same issue; ``state`` and
            ``priority`` default to ``"Unknown"`` and ``comments`` to an
            empty list when absent.
        issue_content: Markdown rendering of the issue, used as the
            document's source markdown and embedded in the fallback summary.
        connector_id: ID of the Linear connector that produced the issue.
        search_space_id: ID of the search space the document belongs to.
        user_id: ID recorded as the document's creator.
        enable_summary: Forwarded as ``should_summarize``.

    Returns:
        A ``ConnectorDocument`` typed as ``DocumentType.LINEAR_CONNECTOR``
        whose ``unique_id`` is the raw issue id.
    """
    # Identity fields come from the raw payload, in this fixed order.
    raw_id, identifier, headline = (
        issue.get(field, "") for field in ("id", "identifier", "title")
    )
    status = formatted_issue.get("state", "Unknown")
    urgency = formatted_issue.get("priority", "Unknown")
    n_comments = len(formatted_issue.get("comments", []))

    return ConnectorDocument(
        title=f"{identifier}: {headline}",
        source_markdown=issue_content,
        unique_id=raw_id,
        document_type=DocumentType.LINEAR_CONNECTOR,
        search_space_id=search_space_id,
        connector_id=connector_id,
        created_by_id=user_id,
        should_summarize=enable_summary,
        # Plain-text stand-in supplied alongside the summarize flag;
        # presumably used when no LLM summary is produced — confirm in
        # the indexing pipeline.
        fallback_summary=(
            f"Linear Issue {identifier}: {headline}\n\n"
            f"Status: {status}\n\n{issue_content}"
        ),
        metadata={
            "issue_id": raw_id,
            "issue_identifier": identifier,
            "issue_title": headline,
            "state": status,
            "priority": urgency,
            "comment_count": n_comments,
            "connector_id": connector_id,
            "document_type": "Linear Issue",
            "connector_type": "Linear",
        },
    )
|
||||
|
||||
|
||||
async def index_linear_issues(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
|
|
@ -52,26 +88,15 @@ async def index_linear_issues(
|
|||
end_date: str | None = None,
|
||||
update_last_indexed: bool = True,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, str | None]:
|
||||
) -> tuple[int, int, str | None]:
|
||||
"""
|
||||
Index Linear issues and comments.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Linear connector
|
||||
search_space_id: ID of the search space to store documents in
|
||||
user_id: ID of the user
|
||||
start_date: Start date for indexing (YYYY-MM-DD format)
|
||||
end_date: End date for indexing (YYYY-MM-DD format)
|
||||
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
|
||||
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
|
||||
|
||||
Returns:
|
||||
Tuple containing (number of documents indexed, error message or None)
|
||||
Tuple of (indexed_count, skipped_count, warning_or_error_message)
|
||||
"""
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="linear_issues_indexing",
|
||||
source="connector_indexing_task",
|
||||
|
|
@ -85,7 +110,7 @@ async def index_linear_issues(
|
|||
)
|
||||
|
||||
try:
|
||||
# Get the connector
|
||||
# ── Connector lookup ──────────────────────────────────────────
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Retrieving Linear connector {connector_id} from database",
|
||||
|
|
@ -104,11 +129,11 @@ async def index_linear_issues(
|
|||
{"error_type": "ConnectorNotFound"},
|
||||
)
|
||||
return (
|
||||
0,
|
||||
0,
|
||||
f"Connector with ID {connector_id} not found or is not a Linear connector",
|
||||
)
|
||||
|
||||
# Check if access_token exists (support both new OAuth format and old API key format)
|
||||
if not connector.config.get("access_token") and not connector.config.get(
|
||||
"LINEAR_API_KEY"
|
||||
):
|
||||
|
|
@ -118,26 +143,22 @@ async def index_linear_issues(
|
|||
"Missing Linear access token",
|
||||
{"error_type": "MissingToken"},
|
||||
)
|
||||
return 0, "Linear access token not found in connector config"
|
||||
return 0, 0, "Linear access token not found in connector config"
|
||||
|
||||
# Initialize Linear client with internal refresh capability
|
||||
# ── Client init ───────────────────────────────────────────────
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Initializing Linear client for connector {connector_id}",
|
||||
{"stage": "client_initialization"},
|
||||
)
|
||||
|
||||
# Create connector with session and connector_id for internal refresh
|
||||
# Token refresh will happen automatically when needed
|
||||
linear_client = LinearConnector(session=session, connector_id=connector_id)
|
||||
|
||||
# Handle 'undefined' string from frontend (treat as None)
|
||||
if start_date == "undefined" or start_date == "":
|
||||
start_date = None
|
||||
if end_date == "undefined" or end_date == "":
|
||||
end_date = None
|
||||
|
||||
# Calculate date range
|
||||
start_date_str, end_date_str = calculate_date_range(
|
||||
connector, start_date, end_date, default_days_back=365
|
||||
)
|
||||
|
|
@ -154,37 +175,34 @@ async def index_linear_issues(
|
|||
},
|
||||
)
|
||||
|
||||
# Get issues within date range
|
||||
# ── Fetch issues ──────────────────────────────────────────────
|
||||
try:
|
||||
issues, error = await linear_client.get_issues_by_date_range(
|
||||
start_date=start_date_str, end_date=end_date_str, include_comments=True
|
||||
start_date=start_date_str,
|
||||
end_date=end_date_str,
|
||||
include_comments=True,
|
||||
)
|
||||
|
||||
if error:
|
||||
# Don't treat "No issues found" as an error that should stop indexing
|
||||
if "No issues found" in error:
|
||||
logger.info(f"No Linear issues found: {error}")
|
||||
logger.info(
|
||||
"No issues found is not a critical error, continuing with update"
|
||||
)
|
||||
if update_last_indexed:
|
||||
await update_connector_last_indexed(
|
||||
session, connector, update_last_indexed
|
||||
)
|
||||
await session.commit()
|
||||
logger.info(
|
||||
f"Updated last_indexed_at to {connector.last_indexed_at} despite no issues found"
|
||||
)
|
||||
return 0, None
|
||||
return 0, 0, None
|
||||
else:
|
||||
logger.error(f"Failed to get Linear issues: {error}")
|
||||
return 0, f"Failed to get Linear issues: {error}"
|
||||
return 0, 0, f"Failed to get Linear issues: {error}"
|
||||
|
||||
logger.info(f"Retrieved {len(issues)} issues from Linear API")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception when calling Linear API: {e!s}", exc_info=True)
|
||||
return 0, f"Failed to get Linear issues: {e!s}"
|
||||
logger.error(
|
||||
f"Exception when calling Linear API: {e!s}", exc_info=True
|
||||
)
|
||||
return 0, 0, f"Failed to get Linear issues: {e!s}"
|
||||
|
||||
if not issues:
|
||||
logger.info("No Linear issues found for the specified date range")
|
||||
|
|
@ -193,19 +211,12 @@ async def index_linear_issues(
|
|||
session, connector, update_last_indexed
|
||||
)
|
||||
await session.commit()
|
||||
logger.info(
|
||||
f"Updated last_indexed_at to {connector.last_indexed_at} despite no issues found"
|
||||
)
|
||||
return 0, None # Return None instead of error message when no issues found
|
||||
return 0, 0, None
|
||||
|
||||
# Track the number of documents indexed
|
||||
documents_indexed = 0
|
||||
# ── Build ConnectorDocuments ──────────────────────────────────
|
||||
connector_docs: list[ConnectorDocument] = []
|
||||
documents_skipped = 0
|
||||
documents_failed = 0 # Track issues that failed processing
|
||||
skipped_issues = []
|
||||
|
||||
# Heartbeat tracking - update notification periodically to prevent appearing stuck
|
||||
last_heartbeat_time = time.time()
|
||||
duplicate_content_count = 0
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
|
|
@ -213,13 +224,6 @@ async def index_linear_issues(
|
|||
{"stage": "process_issues", "total_issues": len(issues)},
|
||||
)
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 1: Analyze all issues, create pending documents
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
# =======================================================================
|
||||
issues_to_process = [] # List of dicts with document and issue data
|
||||
new_documents_created = False
|
||||
|
||||
for issue in issues:
|
||||
try:
|
||||
issue_id = issue.get("id", "")
|
||||
|
|
@ -230,271 +234,102 @@ async def index_linear_issues(
|
|||
logger.warning(
|
||||
f"Skipping issue with missing ID or title: {issue_id or 'Unknown'}"
|
||||
)
|
||||
skipped_issues.append(
|
||||
f"{issue_identifier or 'Unknown'} (missing data)"
|
||||
)
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Format the issue first to get well-structured data
|
||||
formatted_issue = linear_client.format_issue(issue)
|
||||
|
||||
# Convert issue to markdown format
|
||||
issue_content = linear_client.format_issue_to_markdown(formatted_issue)
|
||||
issue_content = linear_client.format_issue_to_markdown(
|
||||
formatted_issue
|
||||
)
|
||||
|
||||
if not issue_content:
|
||||
logger.warning(
|
||||
f"Skipping issue with no content: {issue_identifier} - {issue_title}"
|
||||
)
|
||||
skipped_issues.append(f"{issue_identifier} (no content)")
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Generate unique identifier hash for this Linear issue
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.LINEAR_CONNECTOR, issue_id, search_space_id
|
||||
)
|
||||
|
||||
# Generate content hash
|
||||
content_hash = generate_content_hash(issue_content, search_space_id)
|
||||
|
||||
# Check if document with this unique identifier already exists
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, unique_identifier_hash
|
||||
)
|
||||
|
||||
state = formatted_issue.get("state", "Unknown")
|
||||
description = formatted_issue.get("description", "")
|
||||
comment_count = len(formatted_issue.get("comments", []))
|
||||
priority = formatted_issue.get("priority", "Unknown")
|
||||
|
||||
if existing_document:
|
||||
# Document exists - check if content has changed
|
||||
if existing_document.content_hash == content_hash:
|
||||
# Ensure status is ready (might have been stuck in processing/pending)
|
||||
if not DocumentStatus.is_state(
|
||||
existing_document.status, DocumentStatus.READY
|
||||
):
|
||||
existing_document.status = DocumentStatus.ready()
|
||||
logger.info(
|
||||
f"Document for Linear issue {issue_identifier} unchanged. Skipping."
|
||||
)
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Queue existing document for update (will be set to processing in Phase 2)
|
||||
issues_to_process.append(
|
||||
{
|
||||
"document": existing_document,
|
||||
"is_new": False,
|
||||
"issue_content": issue_content,
|
||||
"content_hash": content_hash,
|
||||
"issue_id": issue_id,
|
||||
"issue_identifier": issue_identifier,
|
||||
"issue_title": issue_title,
|
||||
"state": state,
|
||||
"description": description,
|
||||
"comment_count": comment_count,
|
||||
"priority": priority,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Document doesn't exist by unique_identifier_hash
|
||||
# Check if a document with the same content_hash exists (from another connector)
|
||||
with session.no_autoflush:
|
||||
duplicate_by_content = await check_duplicate_document_by_hash(
|
||||
session, content_hash
|
||||
)
|
||||
|
||||
if duplicate_by_content:
|
||||
logger.info(
|
||||
f"Linear issue {issue_identifier} already indexed by another connector "
|
||||
f"(existing document ID: {duplicate_by_content.id}, "
|
||||
f"type: {duplicate_by_content.document_type}). Skipping."
|
||||
)
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Create new document with PENDING status (visible in UI immediately)
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=f"{issue_identifier}: {issue_title}",
|
||||
document_type=DocumentType.LINEAR_CONNECTOR,
|
||||
document_metadata={
|
||||
"issue_id": issue_id,
|
||||
"issue_identifier": issue_identifier,
|
||||
"issue_title": issue_title,
|
||||
"state": state,
|
||||
"comment_count": comment_count,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
content="Pending...", # Placeholder until processed
|
||||
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
embedding=None,
|
||||
chunks=[], # Empty at creation - safe for async
|
||||
status=DocumentStatus.pending(), # Pending until processing starts
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
doc = _build_connector_doc(
|
||||
issue,
|
||||
formatted_issue,
|
||||
issue_content,
|
||||
connector_id=connector_id,
|
||||
)
|
||||
session.add(document)
|
||||
new_documents_created = True
|
||||
|
||||
issues_to_process.append(
|
||||
{
|
||||
"document": document,
|
||||
"is_new": True,
|
||||
"issue_content": issue_content,
|
||||
"content_hash": content_hash,
|
||||
"issue_id": issue_id,
|
||||
"issue_identifier": issue_identifier,
|
||||
"issue_title": issue_title,
|
||||
"state": state,
|
||||
"description": description,
|
||||
"comment_count": comment_count,
|
||||
"priority": priority,
|
||||
}
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=connector.enable_summary,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Phase 1 for issue: {e!s}", exc_info=True)
|
||||
documents_failed += 1
|
||||
with session.no_autoflush:
|
||||
duplicate = await check_duplicate_document_by_hash(
|
||||
session, compute_content_hash(doc)
|
||||
)
|
||||
if duplicate:
|
||||
logger.info(
|
||||
f"Linear issue {doc.title} already indexed by another connector "
|
||||
f"(existing document ID: {duplicate.id}, "
|
||||
f"type: {duplicate.document_type}). Skipping."
|
||||
)
|
||||
duplicate_content_count += 1
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Commit all pending documents - they all appear in UI now
|
||||
if new_documents_created:
|
||||
logger.info(
|
||||
f"Phase 1: Committing {len([i for i in issues_to_process if i['is_new']])} pending documents"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 2: Process each document one by one
|
||||
# Each document transitions: pending → processing → ready/failed
|
||||
# =======================================================================
|
||||
logger.info(f"Phase 2: Processing {len(issues_to_process)} documents")
|
||||
|
||||
for item in issues_to_process:
|
||||
# Send heartbeat periodically
|
||||
if on_heartbeat_callback:
|
||||
current_time = time.time()
|
||||
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat_callback(documents_indexed)
|
||||
last_heartbeat_time = current_time
|
||||
|
||||
document = item["document"]
|
||||
try:
|
||||
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
|
||||
document.status = DocumentStatus.processing()
|
||||
await session.commit()
|
||||
|
||||
# Heavy processing (LLM, embeddings, chunks)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"issue_id": item["issue_identifier"],
|
||||
"issue_title": item["issue_title"],
|
||||
"state": item["state"],
|
||||
"priority": item["priority"],
|
||||
"comment_count": item["comment_count"],
|
||||
"document_type": "Linear Issue",
|
||||
"connector_type": "Linear",
|
||||
}
|
||||
(
|
||||
summary_content,
|
||||
summary_embedding,
|
||||
) = await generate_document_summary(
|
||||
item["issue_content"], user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
summary_content = f"Linear Issue {item['issue_identifier']}: {item['issue_title']}\n\nStatus: {item['state']}\n\n{item['issue_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["issue_content"])
|
||||
|
||||
# Update document to READY with actual content
|
||||
document.title = f"{item['issue_identifier']}: {item['issue_title']}"
|
||||
document.content = summary_content
|
||||
document.content_hash = item["content_hash"]
|
||||
document.embedding = summary_embedding
|
||||
document.document_metadata = {
|
||||
"issue_id": item["issue_id"],
|
||||
"issue_identifier": item["issue_identifier"],
|
||||
"issue_title": item["issue_title"],
|
||||
"state": item["state"],
|
||||
"comment_count": item["comment_count"],
|
||||
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"connector_id": connector_id,
|
||||
}
|
||||
await safe_set_chunks(session, document, chunks)
|
||||
document.updated_at = get_current_timestamp()
|
||||
document.status = DocumentStatus.ready()
|
||||
|
||||
documents_indexed += 1
|
||||
|
||||
# Batch commit every 10 documents (for ready status updates)
|
||||
if documents_indexed % 10 == 0:
|
||||
logger.info(
|
||||
f"Committing batch: {documents_indexed} Linear issues processed so far"
|
||||
)
|
||||
await session.commit()
|
||||
connector_docs.append(doc)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing issue {item.get('issue_identifier', 'Unknown')}: {e!s}",
|
||||
f"Error building ConnectorDocument for issue: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Mark document as failed with reason (visible in UI)
|
||||
try:
|
||||
document.status = DocumentStatus.failed(str(e))
|
||||
document.updated_at = get_current_timestamp()
|
||||
except Exception as status_error:
|
||||
logger.error(
|
||||
f"Failed to update document status to failed: {status_error}"
|
||||
)
|
||||
skipped_issues.append(
|
||||
f"{item.get('issue_identifier', 'Unknown')} (processing error)"
|
||||
)
|
||||
documents_failed += 1
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
|
||||
# ── Pipeline: migrate legacy docs + parallel index ────────────
|
||||
pipeline = IndexingPipelineService(session)
|
||||
|
||||
await pipeline.migrate_legacy_docs(connector_docs)
|
||||
|
||||
async def _get_llm(s):
|
||||
return await get_user_long_context_llm(s, user_id, search_space_id)
|
||||
|
||||
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
|
||||
connector_docs,
|
||||
_get_llm,
|
||||
max_concurrency=3,
|
||||
on_heartbeat=on_heartbeat_callback,
|
||||
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
|
||||
)
|
||||
|
||||
# ── Finalize ──────────────────────────────────────────────────
|
||||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||
|
||||
# Final commit for any remaining documents not yet committed in batches
|
||||
logger.info(f"Final commit: Total {documents_indexed} Linear issues processed")
|
||||
logger.info(
|
||||
f"Final commit: Total {documents_indexed} Linear issues processed"
|
||||
)
|
||||
try:
|
||||
await session.commit()
|
||||
logger.info(
|
||||
"Successfully committed all Linear document changes to database"
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle any remaining integrity errors gracefully (race conditions, etc.)
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in str(e).lower()
|
||||
or "uniqueviolationerror" in str(e).lower()
|
||||
):
|
||||
logger.warning(
|
||||
f"Duplicate content_hash detected during final commit. "
|
||||
f"This may occur if the same issue was indexed by multiple connectors. "
|
||||
f"Rolling back and continuing. Error: {e!s}"
|
||||
)
|
||||
await session.rollback()
|
||||
else:
|
||||
raise
|
||||
|
||||
# Build warning message if there were issues
|
||||
warning_parts = []
|
||||
warning_parts: list[str] = []
|
||||
if duplicate_content_count > 0:
|
||||
warning_parts.append(f"{duplicate_content_count} duplicate")
|
||||
if documents_failed > 0:
|
||||
warning_parts.append(f"{documents_failed} failed")
|
||||
warning_message = ", ".join(warning_parts) if warning_parts else None
|
||||
|
||||
# Log success
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully completed Linear indexing for connector {connector_id}",
|
||||
|
|
@ -503,7 +338,7 @@ async def index_linear_issues(
|
|||
"documents_indexed": documents_indexed,
|
||||
"documents_skipped": documents_skipped,
|
||||
"documents_failed": documents_failed,
|
||||
"skipped_issues_count": len(skipped_issues),
|
||||
"duplicate_content_count": duplicate_content_count,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -511,7 +346,7 @@ async def index_linear_issues(
|
|||
f"Linear indexing completed: {documents_indexed} ready, "
|
||||
f"{documents_skipped} skipped, {documents_failed} failed"
|
||||
)
|
||||
return documents_indexed, warning_message
|
||||
return documents_indexed, documents_skipped, warning_message
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
|
|
@ -522,7 +357,7 @@ async def index_linear_issues(
|
|||
{"error_type": "SQLAlchemyError"},
|
||||
)
|
||||
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
||||
return 0, f"Database error: {db_error!s}"
|
||||
return 0, 0, f"Database error: {db_error!s}"
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
await task_logger.log_task_failure(
|
||||
|
|
@ -532,4 +367,4 @@ async def index_linear_issues(
|
|||
{"error_type": type(e).__name__},
|
||||
)
|
||||
logger.error(f"Failed to index Linear issues: {e!s}", exc_info=True)
|
||||
return 0, f"Failed to index Linear issues: {e!s}"
|
||||
return 0, 0, f"Failed to index Linear issues: {e!s}"
|
||||
|
|
|
|||
|
|
@ -1,12 +1,10 @@
|
|||
"""
|
||||
Notion connector indexer.
|
||||
|
||||
Implements real-time document status updates using a two-phase approach:
|
||||
- Phase 1: Create all documents with PENDING status (visible in UI immediately)
|
||||
- Phase 2: Process each document one by one (pending → processing → ready/failed)
|
||||
Uses the shared IndexingPipelineService for document deduplication,
|
||||
summarization, chunking, and embedding with bounded parallel indexing.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime
|
||||
|
||||
|
|
@ -14,42 +12,64 @@ from sqlalchemy.exc import SQLAlchemyError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.connectors.notion_history import NotionHistoryConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.db import DocumentType, SearchSourceConnectorType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import compute_content_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
from app.utils.notion_utils import process_blocks
|
||||
|
||||
from .base import (
|
||||
build_document_metadata_string,
|
||||
calculate_date_range,
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document_by_hash,
|
||||
get_connector_by_id,
|
||||
get_current_timestamp,
|
||||
logger,
|
||||
safe_set_chunks,
|
||||
update_connector_last_indexed,
|
||||
)
|
||||
|
||||
# Type alias for retry callback
|
||||
# Signature: async callback(retry_reason, attempt, max_attempts, wait_seconds) -> None
|
||||
RetryCallbackType = Callable[[str, int, int, float], Awaitable[None]]
|
||||
|
||||
# Type alias for heartbeat callback
|
||||
# Signature: async callback(indexed_count) -> None
|
||||
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||
|
||||
# Heartbeat interval in seconds - update notification every 30 seconds
|
||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||
|
||||
|
||||
def _build_connector_doc(
    page: dict,
    markdown_content: str,
    *,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    enable_summary: bool,
) -> ConnectorDocument:
    """Convert one Notion page into a ``ConnectorDocument``.

    Args:
        page: Raw page payload; ``page_id`` defaults to an empty string and
            ``title`` to ``"Untitled page (<page_id>)"`` when the key is
            absent.
        markdown_content: Markdown rendering of the page, used as the
            document's source markdown and embedded in the fallback summary.
        connector_id: ID of the Notion connector that produced the page.
        search_space_id: ID of the search space the document belongs to.
        user_id: ID recorded as the document's creator.
        enable_summary: Forwarded as ``should_summarize``.

    Returns:
        A ``ConnectorDocument`` typed as ``DocumentType.NOTION_CONNECTOR``
        whose ``unique_id`` is the page id.
    """
    notion_id = page.get("page_id", "")
    # The placeholder only applies when the "title" key is missing entirely.
    placeholder_title = f"Untitled page ({notion_id})"
    name = page.get("title", placeholder_title)

    return ConnectorDocument(
        title=name,
        source_markdown=markdown_content,
        unique_id=notion_id,
        document_type=DocumentType.NOTION_CONNECTOR,
        search_space_id=search_space_id,
        connector_id=connector_id,
        created_by_id=user_id,
        should_summarize=enable_summary,
        fallback_summary=f"Notion Page: {name}\n\n{markdown_content}",
        metadata={
            "page_title": name,
            "page_id": notion_id,
            "connector_id": connector_id,
            "document_type": "Notion Page",
            "connector_type": "Notion",
        },
    )
|
||||
|
||||
|
||||
async def index_notion_pages(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
|
|
@ -60,30 +80,15 @@ async def index_notion_pages(
|
|||
update_last_indexed: bool = True,
|
||||
on_retry_callback: RetryCallbackType | None = None,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, str | None]:
|
||||
) -> tuple[int, int, str | None]:
|
||||
"""
|
||||
Index Notion pages from all accessible pages.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Notion connector
|
||||
search_space_id: ID of the search space to store documents in
|
||||
user_id: ID of the user
|
||||
start_date: Start date for indexing (YYYY-MM-DD format)
|
||||
end_date: End date for indexing (YYYY-MM-DD format)
|
||||
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
|
||||
on_retry_callback: Optional callback for retry progress notifications.
|
||||
Signature: async callback(retry_reason, attempt, max_attempts, wait_seconds)
|
||||
retry_reason is one of: 'rate_limit', 'server_error', 'timeout'
|
||||
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
|
||||
Called periodically with (indexed_count) to prevent task appearing stuck.
|
||||
|
||||
Returns:
|
||||
Tuple containing (number of documents indexed, error message or None)
|
||||
Tuple of (indexed_count, skipped_count, warning_or_error_message)
|
||||
"""
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="notion_pages_indexing",
|
||||
source="connector_indexing_task",
|
||||
|
|
@ -97,7 +102,7 @@ async def index_notion_pages(
|
|||
)
|
||||
|
||||
try:
|
||||
# Get the connector
|
||||
# ── Connector lookup ──────────────────────────────────────────
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Retrieving Notion connector {connector_id} from database",
|
||||
|
|
@ -116,11 +121,11 @@ async def index_notion_pages(
|
|||
{"error_type": "ConnectorNotFound"},
|
||||
)
|
||||
return (
|
||||
0,
|
||||
0,
|
||||
f"Connector with ID {connector_id} not found or is not a Notion connector",
|
||||
)
|
||||
|
||||
# Check if access_token exists (support both new OAuth format and old integration token format)
|
||||
if not connector.config.get("access_token") and not connector.config.get(
|
||||
"NOTION_INTEGRATION_TOKEN"
|
||||
):
|
||||
|
|
@ -130,9 +135,9 @@ async def index_notion_pages(
|
|||
"Missing Notion access token",
|
||||
{"error_type": "MissingToken"},
|
||||
)
|
||||
return 0, "Notion access token not found in connector config"
|
||||
return 0, 0, "Notion access token not found in connector config"
|
||||
|
||||
# Initialize Notion client with internal refresh capability
|
||||
# ── Client init ───────────────────────────────────────────────
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Initializing Notion client for connector {connector_id}",
|
||||
|
|
@ -141,18 +146,15 @@ async def index_notion_pages(
|
|||
|
||||
logger.info(f"Initializing Notion client for connector {connector_id}")
|
||||
|
||||
# Handle 'undefined' string from frontend (treat as None)
|
||||
if start_date == "undefined" or start_date == "":
|
||||
start_date = None
|
||||
if end_date == "undefined" or end_date == "":
|
||||
end_date = None
|
||||
|
||||
# Calculate date range using the shared utility function
|
||||
start_date_str, end_date_str = calculate_date_range(
|
||||
connector, start_date, end_date, default_days_back=365
|
||||
)
|
||||
|
||||
# Convert YYYY-MM-DD to ISO format for Notion API
|
||||
start_date_iso = datetime.strptime(start_date_str, "%Y-%m-%d").strftime(
|
||||
"%Y-%m-%dT%H:%M:%SZ"
|
||||
)
|
||||
|
|
@ -160,13 +162,10 @@ async def index_notion_pages(
|
|||
"%Y-%m-%dT%H:%M:%SZ"
|
||||
)
|
||||
|
||||
# Create connector with session and connector_id for internal refresh
|
||||
# Token refresh will happen automatically when needed
|
||||
notion_client = NotionHistoryConnector(
|
||||
session=session, connector_id=connector_id
|
||||
)
|
||||
|
||||
# Set retry callback if provided (for user notifications during rate limits)
|
||||
if on_retry_callback:
|
||||
notion_client.set_retry_callback(on_retry_callback)
|
||||
|
||||
|
|
@ -182,21 +181,19 @@ async def index_notion_pages(
|
|||
},
|
||||
)
|
||||
|
||||
# Get all pages
|
||||
# ── Fetch pages ───────────────────────────────────────────────
|
||||
try:
|
||||
pages = await notion_client.get_all_pages(
|
||||
start_date=start_date_iso, end_date=end_date_iso
|
||||
)
|
||||
logger.info(f"Found {len(pages)} Notion pages")
|
||||
|
||||
# Get count of pages that had unsupported content skipped
|
||||
pages_with_skipped_content = notion_client.get_skipped_content_count()
|
||||
if pages_with_skipped_content > 0:
|
||||
logger.info(
|
||||
f"{pages_with_skipped_content} pages had Notion AI content skipped (not available via API)"
|
||||
)
|
||||
|
||||
# Check if using legacy integration token and log warning
|
||||
if notion_client.is_using_legacy_token():
|
||||
logger.warning(
|
||||
f"Connector {connector_id} is using legacy integration token. "
|
||||
|
|
@ -204,8 +201,6 @@ async def index_notion_pages(
|
|||
)
|
||||
except Exception as e:
|
||||
error_str = str(e)
|
||||
# Check if this is an unsupported block type error (transcription, ai_block, etc.)
|
||||
# These are known Notion API limitations and should be logged as warnings, not errors
|
||||
unsupported_block_errors = [
|
||||
"transcription is not supported",
|
||||
"ai_block is not supported",
|
||||
|
|
@ -216,7 +211,6 @@ async def index_notion_pages(
|
|||
)
|
||||
|
||||
if is_unsupported_block_error:
|
||||
# Log as warning since this is a known Notion API limitation
|
||||
logger.warning(
|
||||
f"Notion API limitation for connector {connector_id}: {error_str}. "
|
||||
"This is a known issue with Notion AI blocks (transcription, ai_block) "
|
||||
|
|
@ -229,7 +223,6 @@ async def index_notion_pages(
|
|||
{"error_type": "UnsupportedBlockType", "is_known_limitation": True},
|
||||
)
|
||||
else:
|
||||
# Log as error for other failures
|
||||
logger.error(
|
||||
f"Error fetching Notion pages for connector {connector_id}: {error_str}",
|
||||
exc_info=True,
|
||||
|
|
@ -242,7 +235,7 @@ async def index_notion_pages(
|
|||
)
|
||||
|
||||
await notion_client.close()
|
||||
return 0, f"Failed to get Notion pages: {e!s}"
|
||||
return 0, 0, f"Failed to get Notion pages: {e!s}"
|
||||
|
||||
if not pages:
|
||||
await task_logger.log_task_success(
|
||||
|
|
@ -252,21 +245,17 @@ async def index_notion_pages(
|
|||
{"pages_found": 0},
|
||||
)
|
||||
logger.info("No Notion pages found to index")
|
||||
# CRITICAL: Update timestamp even when no pages found so Zero syncs
|
||||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||
await update_connector_last_indexed(
|
||||
session, connector, update_last_indexed
|
||||
)
|
||||
await session.commit()
|
||||
await notion_client.close()
|
||||
return 0, None # Success with 0 pages, not an error
|
||||
return 0, 0, None
|
||||
|
||||
# Track the number of documents indexed
|
||||
documents_indexed = 0
|
||||
# ── Build ConnectorDocuments ──────────────────────────────────
|
||||
connector_docs: list[ConnectorDocument] = []
|
||||
documents_skipped = 0
|
||||
documents_failed = 0
|
||||
duplicate_content_count = 0
|
||||
skipped_pages = []
|
||||
|
||||
# Heartbeat tracking - update notification periodically to prevent appearing stuck
|
||||
last_heartbeat_time = time.time()
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
|
|
@ -274,13 +263,6 @@ async def index_notion_pages(
|
|||
{"stage": "process_pages", "total_pages": len(pages)},
|
||||
)
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 1: Analyze all pages, create pending documents
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
# =======================================================================
|
||||
pages_to_process = [] # List of dicts with document and page data
|
||||
new_documents_created = False
|
||||
|
||||
for page in pages:
|
||||
try:
|
||||
page_id = page.get("page_id")
|
||||
|
|
@ -293,225 +275,71 @@ async def index_notion_pages(
|
|||
|
||||
if not page_content:
|
||||
logger.info(f"No content found in page {page_title}. Skipping.")
|
||||
skipped_pages.append(f"{page_title} (no content)")
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Convert page content to markdown format
|
||||
markdown_content = f"# Notion Page: {page_title}\n\n"
|
||||
markdown_content += process_blocks(page_content)
|
||||
|
||||
# Format document metadata
|
||||
metadata_sections = [
|
||||
("METADATA", [f"PAGE_TITLE: {page_title}", f"PAGE_ID: {page_id}"]),
|
||||
(
|
||||
"CONTENT",
|
||||
[
|
||||
"FORMAT: markdown",
|
||||
"TEXT_START",
|
||||
markdown_content,
|
||||
"TEXT_END",
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
# Build the document string
|
||||
combined_document_string = build_document_metadata_string(
|
||||
metadata_sections
|
||||
if not markdown_content.strip():
|
||||
logger.warning(
|
||||
f"Skipping page with empty markdown: {page_title}"
|
||||
)
|
||||
|
||||
# Generate unique identifier hash for this Notion page
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.NOTION_CONNECTOR, page_id, search_space_id
|
||||
)
|
||||
|
||||
# Generate content hash
|
||||
content_hash = generate_content_hash(
|
||||
combined_document_string, search_space_id
|
||||
)
|
||||
|
||||
# Check if document with this unique identifier already exists
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, unique_identifier_hash
|
||||
)
|
||||
|
||||
if existing_document:
|
||||
# Document exists - check if content has changed
|
||||
if existing_document.content_hash == content_hash:
|
||||
# Ensure status is ready (might have been stuck in processing/pending)
|
||||
if not DocumentStatus.is_state(
|
||||
existing_document.status, DocumentStatus.READY
|
||||
):
|
||||
existing_document.status = DocumentStatus.ready()
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Queue existing document for update (will be set to processing in Phase 2)
|
||||
pages_to_process.append(
|
||||
{
|
||||
"document": existing_document,
|
||||
"is_new": False,
|
||||
"markdown_content": markdown_content,
|
||||
"content_hash": content_hash,
|
||||
"page_id": page_id,
|
||||
"page_title": page_title,
|
||||
}
|
||||
doc = _build_connector_doc(
|
||||
page,
|
||||
markdown_content,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=connector.enable_summary,
|
||||
)
|
||||
continue
|
||||
|
||||
# Document doesn't exist by unique_identifier_hash
|
||||
# Check if a document with the same content_hash exists (from another connector)
|
||||
with session.no_autoflush:
|
||||
duplicate_by_content = await check_duplicate_document_by_hash(
|
||||
session, content_hash
|
||||
duplicate = await check_duplicate_document_by_hash(
|
||||
session, compute_content_hash(doc)
|
||||
)
|
||||
|
||||
if duplicate_by_content:
|
||||
if duplicate:
|
||||
logger.info(
|
||||
f"Notion page {page_title} already indexed by another connector "
|
||||
f"(existing document ID: {duplicate_by_content.id}, "
|
||||
f"type: {duplicate_by_content.document_type}). Skipping."
|
||||
f"Notion page {doc.title} already indexed by another connector "
|
||||
f"(existing document ID: {duplicate.id}, "
|
||||
f"type: {duplicate.document_type}). Skipping."
|
||||
)
|
||||
duplicate_content_count += 1
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Create new document with PENDING status (visible in UI immediately)
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=page_title,
|
||||
document_type=DocumentType.NOTION_CONNECTOR,
|
||||
document_metadata={
|
||||
"page_title": page_title,
|
||||
"page_id": page_id,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
content="Pending...", # Placeholder until processed
|
||||
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
embedding=None,
|
||||
chunks=[], # Empty at creation - safe for async
|
||||
status=DocumentStatus.pending(), # Pending until processing starts
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
connector_id=connector_id,
|
||||
)
|
||||
session.add(document)
|
||||
new_documents_created = True
|
||||
|
||||
pages_to_process.append(
|
||||
{
|
||||
"document": document,
|
||||
"is_new": True,
|
||||
"markdown_content": markdown_content,
|
||||
"content_hash": content_hash,
|
||||
"page_id": page_id,
|
||||
"page_title": page_title,
|
||||
}
|
||||
)
|
||||
connector_docs.append(doc)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Phase 1 for page: {e!s}", exc_info=True)
|
||||
documents_failed += 1
|
||||
continue
|
||||
|
||||
# Commit all pending documents - they all appear in UI now
|
||||
if new_documents_created:
|
||||
logger.info(
|
||||
f"Phase 1: Committing {len([p for p in pages_to_process if p['is_new']])} pending documents"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 2: Process each document one by one
|
||||
# Each document transitions: pending → processing → ready/failed
|
||||
# =======================================================================
|
||||
logger.info(f"Phase 2: Processing {len(pages_to_process)} documents")
|
||||
|
||||
for item in pages_to_process:
|
||||
# Send heartbeat periodically
|
||||
if on_heartbeat_callback:
|
||||
current_time = time.time()
|
||||
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat_callback(documents_indexed)
|
||||
last_heartbeat_time = current_time
|
||||
|
||||
document = item["document"]
|
||||
try:
|
||||
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
|
||||
document.status = DocumentStatus.processing()
|
||||
await session.commit()
|
||||
|
||||
# Heavy processing (LLM, embeddings, chunks)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"page_title": item["page_title"],
|
||||
"page_id": item["page_id"],
|
||||
"document_type": "Notion Page",
|
||||
"connector_type": "Notion",
|
||||
}
|
||||
(
|
||||
summary_content,
|
||||
summary_embedding,
|
||||
) = await generate_document_summary(
|
||||
item["markdown_content"],
|
||||
user_llm,
|
||||
document_metadata_for_summary,
|
||||
)
|
||||
else:
|
||||
summary_content = f"Notion Page: {item['page_title']}\n\n{item['markdown_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["markdown_content"])
|
||||
|
||||
# Update document to READY with actual content
|
||||
document.title = item["page_title"]
|
||||
document.content = summary_content
|
||||
document.content_hash = item["content_hash"]
|
||||
document.embedding = summary_embedding
|
||||
document.document_metadata = {
|
||||
"page_title": item["page_title"],
|
||||
"page_id": item["page_id"],
|
||||
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"connector_id": connector_id,
|
||||
}
|
||||
await safe_set_chunks(session, document, chunks)
|
||||
document.updated_at = get_current_timestamp()
|
||||
document.status = DocumentStatus.ready()
|
||||
|
||||
documents_indexed += 1
|
||||
|
||||
# Batch commit every 10 documents (for ready status updates)
|
||||
if documents_indexed % 10 == 0:
|
||||
logger.info(
|
||||
f"Committing batch: {documents_indexed} Notion pages processed so far"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Notion page: {e!s}", exc_info=True)
|
||||
# Mark document as failed with reason (visible in UI)
|
||||
try:
|
||||
document.status = DocumentStatus.failed(str(e))
|
||||
document.updated_at = get_current_timestamp()
|
||||
except Exception as status_error:
|
||||
logger.error(
|
||||
f"Failed to update document status to failed: {status_error}"
|
||||
f"Error building ConnectorDocument for page: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
skipped_pages.append(f"{item['page_title']} (processing error)")
|
||||
documents_failed += 1
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
|
||||
# ── Pipeline: migrate legacy docs + parallel index ────────────
|
||||
pipeline = IndexingPipelineService(session)
|
||||
|
||||
await pipeline.migrate_legacy_docs(connector_docs)
|
||||
|
||||
async def _get_llm(s):
|
||||
return await get_user_long_context_llm(s, user_id, search_space_id)
|
||||
|
||||
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
|
||||
connector_docs,
|
||||
_get_llm,
|
||||
max_concurrency=3,
|
||||
on_heartbeat=on_heartbeat_callback,
|
||||
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
|
||||
)
|
||||
|
||||
# ── Finalize ──────────────────────────────────────────────────
|
||||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||
|
||||
total_processed = documents_indexed
|
||||
|
||||
# Final commit to ensure all documents are persisted (safety net)
|
||||
logger.info(f"Final commit: Total {documents_indexed} documents processed")
|
||||
try:
|
||||
await session.commit()
|
||||
|
|
@ -519,59 +347,53 @@ async def index_notion_pages(
|
|||
"Successfully committed all Notion document changes to database"
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle any remaining integrity errors gracefully (race conditions, etc.)
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in str(e).lower()
|
||||
or "uniqueviolationerror" in str(e).lower()
|
||||
):
|
||||
logger.warning(
|
||||
f"Duplicate content_hash detected during final commit. "
|
||||
f"This may occur if the same page was indexed by multiple connectors. "
|
||||
f"Rolling back and continuing. Error: {e!s}"
|
||||
)
|
||||
await session.rollback()
|
||||
# Don't fail the entire task - some documents may have been successfully indexed
|
||||
else:
|
||||
raise
|
||||
|
||||
# Get final count of pages with skipped Notion AI content
|
||||
# ── Build warning / notification message ──────────────────────
|
||||
pages_with_skipped_ai_content = notion_client.get_skipped_content_count()
|
||||
|
||||
# Build warning message if there were issues
|
||||
warning_parts = []
|
||||
warning_parts: list[str] = []
|
||||
if duplicate_content_count > 0:
|
||||
warning_parts.append(f"{duplicate_content_count} duplicate")
|
||||
if documents_failed > 0:
|
||||
warning_parts.append(f"{documents_failed} failed")
|
||||
warning_message = ", ".join(warning_parts) if warning_parts else None
|
||||
|
||||
# Prepare result message with user-friendly notification about skipped content
|
||||
result_message = None
|
||||
if skipped_pages:
|
||||
result_message = f"Processed {total_processed} pages. Skipped {len(skipped_pages)} pages: {', '.join(skipped_pages)}"
|
||||
else:
|
||||
result_message = f"Processed {total_processed} pages."
|
||||
|
||||
# Add user-friendly message about skipped Notion AI content
|
||||
notification_parts: list[str] = []
|
||||
if pages_with_skipped_ai_content > 0:
|
||||
result_message += (
|
||||
" Audio transcriptions and AI summaries from Notion aren't accessible "
|
||||
"via their API - all other content was saved."
|
||||
notification_parts.append(
|
||||
"Some Notion AI content couldn't be synced (API limitation)"
|
||||
)
|
||||
if notion_client.is_using_legacy_token():
|
||||
notification_parts.append(
|
||||
"Using legacy token. Reconnect with OAuth for better reliability."
|
||||
)
|
||||
if warning_parts:
|
||||
notification_parts.append(", ".join(warning_parts))
|
||||
|
||||
user_notification_message = (
|
||||
" ".join(notification_parts) if notification_parts else None
|
||||
)
|
||||
|
||||
# Log success
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully completed Notion indexing for connector {connector_id}",
|
||||
{
|
||||
"pages_processed": total_processed,
|
||||
"pages_processed": documents_indexed,
|
||||
"documents_indexed": documents_indexed,
|
||||
"documents_skipped": documents_skipped,
|
||||
"documents_failed": documents_failed,
|
||||
"duplicate_content_count": duplicate_content_count,
|
||||
"skipped_pages_count": len(skipped_pages),
|
||||
"pages_with_skipped_ai_content": pages_with_skipped_ai_content,
|
||||
"result_message": result_message,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -581,35 +403,9 @@ async def index_notion_pages(
|
|||
f"({duplicate_content_count} duplicate content)"
|
||||
)
|
||||
|
||||
# Clean up the async client
|
||||
await notion_client.close()
|
||||
|
||||
# Build user-friendly notification messages
|
||||
# This will be shown in the notification to inform users
|
||||
notification_parts = []
|
||||
|
||||
if pages_with_skipped_ai_content > 0:
|
||||
notification_parts.append(
|
||||
"Some Notion AI content couldn't be synced (API limitation)"
|
||||
)
|
||||
|
||||
if notion_client.is_using_legacy_token():
|
||||
notification_parts.append(
|
||||
"Using legacy token. Reconnect with OAuth for better reliability."
|
||||
)
|
||||
|
||||
# Include warning message if there were issues
|
||||
if warning_message:
|
||||
notification_parts.append(warning_message)
|
||||
|
||||
user_notification_message = (
|
||||
" ".join(notification_parts) if notification_parts else None
|
||||
)
|
||||
|
||||
return (
|
||||
total_processed,
|
||||
user_notification_message,
|
||||
)
|
||||
return documents_indexed, documents_skipped, user_notification_message
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
|
|
@ -622,10 +418,9 @@ async def index_notion_pages(
|
|||
logger.error(
|
||||
f"Database error during Notion indexing: {db_error!s}", exc_info=True
|
||||
)
|
||||
# Clean up the async client in case of error
|
||||
if "notion_client" in locals():
|
||||
await notion_client.close()
|
||||
return 0, f"Database error: {db_error!s}"
|
||||
return 0, 0, f"Database error: {db_error!s}"
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
await task_logger.log_task_failure(
|
||||
|
|
@ -635,7 +430,6 @@ async def index_notion_pages(
|
|||
{"error_type": type(e).__name__},
|
||||
)
|
||||
logger.error(f"Failed to index Notion pages: {e!s}", exc_info=True)
|
||||
# Clean up the async client in case of error
|
||||
if "notion_client" in locals():
|
||||
await notion_client.close()
|
||||
return 0, f"Failed to index Notion pages: {e!s}"
|
||||
return 0, 0, f"Failed to index Notion pages: {e!s}"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import hashlib
|
||||
import logging
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -11,6 +12,12 @@ from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# HuggingFace fast tokenizers (Rust-backed) are not thread-safe — concurrent
|
||||
# access from multiple threads causes "RuntimeError: Already borrowed".
|
||||
# This reentrant lock serialises tokenizer + embedding model access so that
|
||||
# asyncio.to_thread calls from index_batch_parallel don't collide.
|
||||
_embedding_lock = threading.RLock()
|
||||
|
||||
|
||||
def _get_embedding_max_tokens() -> int:
|
||||
"""Get the max token limit for the configured embedding model.
|
||||
|
|
@ -36,6 +43,7 @@ def truncate_for_embedding(text: str) -> str:
|
|||
if len(text) // 3 <= max_tokens:
|
||||
return text
|
||||
|
||||
with _embedding_lock:
|
||||
tokenizer = config.embedding_model_instance.get_tokenizer()
|
||||
tokens = tokenizer.encode(text)
|
||||
if len(tokens) <= max_tokens:
|
||||
|
|
@ -52,6 +60,7 @@ def embed_text(text: str) -> np.ndarray:
|
|||
"""Truncate text to fit and embed it. Drop-in replacement for
|
||||
``config.embedding_model_instance.embed(text)`` that never exceeds the
|
||||
model's context window."""
|
||||
with _embedding_lock:
|
||||
return config.embedding_model_instance.embed(truncate_for_embedding(text))
|
||||
|
||||
|
||||
|
|
@ -66,6 +75,7 @@ def embed_texts(texts: list[str]) -> list[np.ndarray]:
|
|||
"""
|
||||
if not texts:
|
||||
return []
|
||||
with _embedding_lock:
|
||||
truncated = [truncate_for_embedding(t) for t in texts]
|
||||
if config.is_local_embedding_model:
|
||||
return [config.embedding_model_instance.embed(t) for t in truncated]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,111 @@
|
|||
"""Integration tests: Calendar indexer builds ConnectorDocuments that flow through the pipeline."""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
def _cal_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
    """Build a Google Calendar ``ConnectorDocument`` fixture for *unique_id*.

    The title, markdown body, fallback summary and ``event_id`` metadata are
    all derived from *unique_id*; the start/end times are fixed placeholders.
    """
    event_metadata = {
        "event_id": unique_id,
        "start_time": "2025-01-15T10:00:00",
        "end_time": "2025-01-15T11:00:00",
        "document_type": "Google Calendar Event",
    }
    return ConnectorDocument(
        title=f"Event {unique_id}",
        source_markdown=f"## Calendar Event\n\nDetails for {unique_id}",
        unique_id=unique_id,
        document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
        search_space_id=search_space_id,
        connector_id=connector_id,
        created_by_id=user_id,
        should_summarize=True,
        fallback_summary=f"Calendar: Event {unique_id}",
        metadata=event_metadata,
    )
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_calendar_pipeline_creates_ready_document(
    db_session, db_search_space, db_connector, db_user, mocker
):
    """Prepare and index a single Calendar ConnectorDocument and expect a READY row."""
    space_id = db_search_space.id
    calendar_doc = _cal_doc(
        unique_id="evt-1",
        search_space_id=space_id,
        connector_id=db_connector.id,
        user_id=str(db_user.id),
    )

    pipeline = IndexingPipelineService(session=db_session)
    prepared_docs = await pipeline.prepare_for_indexing([calendar_doc])
    assert len(prepared_docs) == 1

    await pipeline.index(prepared_docs[0], calendar_doc, llm=mocker.Mock())

    # Fetch whatever document landed in this search space.
    space_query = select(Document).filter(Document.search_space_id == space_id)
    query_result = await db_session.execute(space_query)
    stored = query_result.scalars().first()

    assert stored is not None
    assert stored.document_type == DocumentType.GOOGLE_CALENDAR_CONNECTOR
    assert DocumentStatus.is_state(stored.status, DocumentStatus.READY)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_calendar_legacy_doc_migrated(
    db_session, db_search_space, db_connector, db_user, mocker
):
    """A pre-existing Composio Calendar row is rewritten to the native type and hash."""
    space_id = db_search_space.id
    user_id = str(db_user.id)
    evt_id = "evt-legacy-cal"

    # Seed a row identified by the legacy (Composio) document type.
    composio_hash = compute_identifier_hash(
        DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR.value, evt_id, space_id
    )
    seeded = Document(
        title="Old Calendar Event",
        document_type=DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
        content="old summary",
        content_hash=f"ch-{composio_hash[:12]}",
        unique_identifier_hash=composio_hash,
        source_markdown="## Old event",
        search_space_id=space_id,
        created_by_id=user_id,
        embedding=[0.1] * _EMBEDDING_DIM,
        status={"state": "ready"},
    )
    db_session.add(seeded)
    await db_session.flush()
    seeded_id = seeded.id

    incoming = _cal_doc(
        unique_id=evt_id,
        search_space_id=space_id,
        connector_id=db_connector.id,
        user_id=user_id,
    )

    pipeline = IndexingPipelineService(session=db_session)
    await pipeline.migrate_legacy_docs([incoming])

    # Re-read the same primary key and check it now carries the native identity.
    lookup = await db_session.execute(select(Document).filter(Document.id == seeded_id))
    migrated = lookup.scalars().first()

    assert migrated.document_type == DocumentType.GOOGLE_CALENDAR_CONNECTOR
    expected_hash = compute_identifier_hash(
        DocumentType.GOOGLE_CALENDAR_CONNECTOR.value, evt_id, space_id
    )
    assert migrated.unique_identifier_hash == expected_hash
|
||||
|
|
@ -0,0 +1,170 @@
|
|||
"""Integration tests: Drive indexer builds ConnectorDocuments that flow through the pipeline."""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
def _drive_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
    """Build a Google Drive ``ConnectorDocument`` fixture for *unique_id*.

    The title, markdown body, fallback summary and Drive file metadata are
    all derived from *unique_id*, presented as a ``.pdf`` file.
    """
    file_metadata = {
        "google_drive_file_id": unique_id,
        "google_drive_file_name": f"{unique_id}.pdf",
        "document_type": "Google Drive File",
    }
    return ConnectorDocument(
        title=f"File {unique_id}.pdf",
        source_markdown=f"## Document Content\n\nText from file {unique_id}",
        unique_id=unique_id,
        document_type=DocumentType.GOOGLE_DRIVE_FILE,
        search_space_id=search_space_id,
        connector_id=connector_id,
        created_by_id=user_id,
        should_summarize=True,
        fallback_summary=f"File: {unique_id}.pdf",
        metadata=file_metadata,
    )
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_drive_pipeline_creates_ready_document(
    db_session, db_search_space, db_connector, db_user, mocker
):
    """Prepare and index a single Drive ConnectorDocument and expect a READY row."""
    space_id = db_search_space.id
    drive_doc = _drive_doc(
        unique_id="file-abc",
        search_space_id=space_id,
        connector_id=db_connector.id,
        user_id=str(db_user.id),
    )

    pipeline = IndexingPipelineService(session=db_session)
    prepared_docs = await pipeline.prepare_for_indexing([drive_doc])
    assert len(prepared_docs) == 1

    await pipeline.index(prepared_docs[0], drive_doc, llm=mocker.Mock())

    # Fetch whatever document landed in this search space.
    space_query = select(Document).filter(Document.search_space_id == space_id)
    query_result = await db_session.execute(space_query)
    stored = query_result.scalars().first()

    assert stored is not None
    assert stored.document_type == DocumentType.GOOGLE_DRIVE_FILE
    assert DocumentStatus.is_state(stored.status, DocumentStatus.READY)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_drive_legacy_doc_migrated(
    db_session, db_search_space, db_connector, db_user, mocker
):
    """A pre-existing Composio Drive row is rewritten to the native type and hash."""
    space_id = db_search_space.id
    user_id = str(db_user.id)
    file_id = "file-legacy-drive"

    # Seed a row identified by the legacy (Composio) document type.
    composio_hash = compute_identifier_hash(
        DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR.value, file_id, space_id
    )
    seeded = Document(
        title="Old Drive File",
        document_type=DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
        content="old file summary",
        content_hash=f"ch-{composio_hash[:12]}",
        unique_identifier_hash=composio_hash,
        source_markdown="## Old file content",
        search_space_id=space_id,
        created_by_id=user_id,
        embedding=[0.1] * _EMBEDDING_DIM,
        status={"state": "ready"},
    )
    db_session.add(seeded)
    await db_session.flush()
    seeded_id = seeded.id

    incoming = _drive_doc(
        unique_id=file_id,
        search_space_id=space_id,
        connector_id=db_connector.id,
        user_id=user_id,
    )

    pipeline = IndexingPipelineService(session=db_session)
    await pipeline.migrate_legacy_docs([incoming])

    # Re-read the same primary key and check it now carries the native identity.
    lookup = await db_session.execute(select(Document).filter(Document.id == seeded_id))
    migrated = lookup.scalars().first()

    assert migrated.document_type == DocumentType.GOOGLE_DRIVE_FILE
    expected_hash = compute_identifier_hash(
        DocumentType.GOOGLE_DRIVE_FILE.value, file_id, space_id
    )
    assert migrated.unique_identifier_hash == expected_hash
|
||||
|
||||
|
||||
async def test_should_skip_file_skips_failed_document(
    db_session, db_search_space, db_user,
):
    """An unchanged-md5 file whose stored document is FAILED is skipped by ``_should_skip_file``."""
    import importlib
    import sys
    import types

    # If the indexers package is not loaded yet, register a bare placeholder
    # module under its name before importing the Drive indexer submodule, and
    # remove the placeholder again afterwards.
    package_name = "app.tasks.connector_indexers"
    needs_stub = package_name not in sys.modules
    if needs_stub:
        placeholder = types.ModuleType(package_name)
        placeholder.__path__ = ["app/tasks/connector_indexers"]
        placeholder.__package__ = package_name
        sys.modules[package_name] = placeholder

    try:
        drive_indexer = importlib.import_module(
            "app.tasks.connector_indexers.google_drive_indexer"
        )
        _should_skip_file = drive_indexer._should_skip_file
    finally:
        if needs_stub:
            sys.modules.pop(package_name, None)

    space_id = db_search_space.id
    file_id = "file-failed-drive"
    md5 = "abc123deadbeef"

    # Seed a FAILED document whose stored checksum matches the incoming file.
    doc_hash = compute_identifier_hash(
        DocumentType.GOOGLE_DRIVE_FILE.value, file_id, space_id
    )
    stored_metadata = {
        "google_drive_file_id": file_id,
        "google_drive_file_name": "Failed File.pdf",
        "md5_checksum": md5,
    }
    failed_doc = Document(
        title="Failed File.pdf",
        document_type=DocumentType.GOOGLE_DRIVE_FILE,
        content="LLM rate limit exceeded",
        content_hash=f"ch-{doc_hash[:12]}",
        unique_identifier_hash=doc_hash,
        source_markdown="## Real content",
        search_space_id=space_id,
        created_by_id=str(db_user.id),
        embedding=[0.1] * _EMBEDDING_DIM,
        status=DocumentStatus.failed("LLM rate limit exceeded"),
        document_metadata=stored_metadata,
    )
    db_session.add(failed_doc)
    await db_session.flush()

    incoming_file = {
        "id": file_id,
        "name": "Failed File.pdf",
        "mimeType": "application/pdf",
        "md5Checksum": md5,
    }

    should_skip, msg = await _should_skip_file(db_session, incoming_file, space_id)

    assert should_skip, "FAILED documents must be skipped during automatic sync"
    assert "failed" in msg.lower()
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
"""Integration tests: Gmail indexer builds ConnectorDocuments that flow through the pipeline."""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import (
|
||||
compute_identifier_hash,
|
||||
compute_unique_identifier_hash,
|
||||
)
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
def _gmail_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
    """Return a Gmail-typed ConnectorDocument whose fields are derived from *unique_id*.

    The title, markdown body, fallback summary and ``message_id`` metadata
    all embed *unique_id*; summarisation is requested.
    """
    message_metadata = {
        "message_id": unique_id,
        "from": "sender@example.com",
        "document_type": "Gmail Message",
    }
    return ConnectorDocument(
        title=f"Subject for {unique_id}",
        source_markdown=f"## Email\n\nBody of {unique_id}",
        unique_id=unique_id,
        document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
        search_space_id=search_space_id,
        connector_id=connector_id,
        created_by_id=user_id,
        should_summarize=True,
        fallback_summary=f"Gmail: Subject for {unique_id}",
        metadata=message_metadata,
    )
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_gmail_pipeline_creates_ready_document(
    db_session, db_search_space, db_connector, db_user, mocker
):
    """A Gmail ConnectorDocument flows through prepare + index to a READY document.

    After indexing, the first document found in the search space must be a
    Gmail document in the READY state with the original markdown.
    """
    space_id = db_search_space.id
    gmail_doc = _gmail_doc(
        unique_id="msg-pipeline-1",
        search_space_id=space_id,
        connector_id=db_connector.id,
        user_id=str(db_user.id),
    )
    pipeline = IndexingPipelineService(session=db_session)

    prepared = await pipeline.prepare_for_indexing([gmail_doc])
    assert len(prepared) == 1

    await pipeline.index(prepared[0], gmail_doc, llm=mocker.Mock())

    query = select(Document).filter(Document.search_space_id == space_id)
    stored = (await db_session.execute(query)).scalars().first()

    assert stored is not None
    assert stored.document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR
    assert DocumentStatus.is_state(stored.status, DocumentStatus.READY)
    assert stored.source_markdown == gmail_doc.source_markdown
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_gmail_legacy_doc_migrated_then_reused(
    db_session, db_search_space, db_connector, db_user, mocker
):
    """A legacy Composio Gmail doc is migrated then reused by the pipeline.

    After ``migrate_legacy_docs``, ``prepare_for_indexing`` must hand back
    the pre-existing row (same id) with the native Gmail type and hash.
    """
    space_id = db_search_space.id
    user_id = str(db_user.id)
    msg_id = "msg-legacy-gmail"

    composio_hash = compute_identifier_hash(
        DocumentType.COMPOSIO_GMAIL_CONNECTOR.value, msg_id, space_id
    )
    seeded = Document(
        title="Old Gmail",
        document_type=DocumentType.COMPOSIO_GMAIL_CONNECTOR,
        content="old summary",
        content_hash=f"ch-{composio_hash[:12]}",
        unique_identifier_hash=composio_hash,
        source_markdown="## Old content",
        search_space_id=space_id,
        created_by_id=user_id,
        embedding=[0.1] * _EMBEDDING_DIM,
        status={"state": "ready"},
    )
    db_session.add(seeded)
    await db_session.flush()
    seeded_id = seeded.id

    incoming = _gmail_doc(
        unique_id=msg_id,
        search_space_id=space_id,
        connector_id=db_connector.id,
        user_id=user_id,
    )

    pipeline = IndexingPipelineService(session=db_session)
    await pipeline.migrate_legacy_docs([incoming])

    prepared = await pipeline.prepare_for_indexing([incoming])
    assert len(prepared) == 1
    reused = prepared[0]
    assert reused.id == seeded_id
    assert reused.document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR

    expected_hash = compute_identifier_hash(
        DocumentType.GOOGLE_GMAIL_CONNECTOR.value, msg_id, space_id
    )
    assert reused.unique_identifier_hash == expected_hash
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
"""Integration tests for IndexingPipelineService.index_batch()."""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_index_batch_creates_ready_documents(
    db_session, db_search_space, make_connector_document, mocker
):
    """index_batch prepares and indexes a batch, resulting in READY documents.

    Two Gmail connector documents are submitted; two results and two READY
    rows with content and an embedding are expected.
    """
    space_id = db_search_space.id
    batch_inputs = (
        ("msg-batch-1", "## Email 1\n\nBody"),
        ("msg-batch-2", "## Email 2\n\nDifferent body"),
    )
    docs = [
        make_connector_document(
            document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
            unique_id=unique_id,
            search_space_id=space_id,
            source_markdown=markdown,
        )
        for unique_id, markdown in batch_inputs
    ]

    pipeline = IndexingPipelineService(session=db_session)
    results = await pipeline.index_batch(docs, llm=mocker.Mock())

    assert len(results) == 2

    query = select(Document).filter(Document.search_space_id == space_id)
    stored_rows = (await db_session.execute(query)).scalars().all()
    assert len(stored_rows) == 2

    for stored in stored_rows:
        assert DocumentStatus.is_state(stored.status, DocumentStatus.READY)
        assert stored.content is not None
        assert stored.embedding is not None
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_index_batch_empty_returns_empty(db_session, mocker):
    """index_batch with empty input returns an empty list."""
    pipeline = IndexingPipelineService(session=db_session)
    no_docs = []

    outcome = await pipeline.index_batch(no_docs, llm=mocker.Mock())

    assert outcome == []
|
||||
|
|
@ -0,0 +1,92 @@
|
|||
"""Integration tests for IndexingPipelineService.migrate_legacy_docs()."""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.db import Document, DocumentType
|
||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
async def test_legacy_composio_gmail_doc_migrated_in_db(
    db_session, db_search_space, db_user, make_connector_document
):
    """A Composio Gmail doc in the DB gets its hash and type updated to native.

    The row is re-read by primary key after ``migrate_legacy_docs`` and
    compared against the native Gmail type and identifier hash.
    """
    space_id = db_search_space.id
    user_id = str(db_user.id)
    unique_id = "msg-legacy-123"

    composio_hash = compute_identifier_hash(
        DocumentType.COMPOSIO_GMAIL_CONNECTOR.value, unique_id, space_id
    )
    expected_hash = compute_identifier_hash(
        DocumentType.GOOGLE_GMAIL_CONNECTOR.value, unique_id, space_id
    )

    seeded = Document(
        title="Old Gmail",
        document_type=DocumentType.COMPOSIO_GMAIL_CONNECTOR,
        content="legacy content",
        content_hash=f"ch-{composio_hash[:12]}",
        unique_identifier_hash=composio_hash,
        search_space_id=space_id,
        created_by_id=user_id,
        embedding=[0.1] * _EMBEDDING_DIM,
        status={"state": "ready"},
    )
    db_session.add(seeded)
    await db_session.flush()
    seeded_id = seeded.id

    incoming = make_connector_document(
        document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
        unique_id=unique_id,
        search_space_id=space_id,
    )

    pipeline = IndexingPipelineService(session=db_session)
    await pipeline.migrate_legacy_docs([incoming])

    by_id = select(Document).filter(Document.id == seeded_id)
    reloaded = (await db_session.execute(by_id)).scalars().first()

    assert reloaded.unique_identifier_hash == expected_hash
    assert reloaded.document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR
|
||||
|
||||
|
||||
async def test_no_legacy_doc_is_noop(
    db_session, db_search_space, make_connector_document
):
    """When no legacy document exists, migrate_legacy_docs does nothing.

    The search space must still contain no documents afterwards.
    """
    space_id = db_search_space.id
    incoming = make_connector_document(
        document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
        unique_id="evt-no-legacy",
        search_space_id=space_id,
    )

    pipeline = IndexingPipelineService(session=db_session)
    await pipeline.migrate_legacy_docs([incoming])

    in_space = select(Document).filter(Document.search_space_id == space_id)
    remaining = (await db_session.execute(in_space)).scalars().all()
    assert remaining == []
|
||||
|
||||
|
||||
async def test_non_google_type_is_skipped(
    db_session, db_search_space, make_connector_document
):
    """migrate_legacy_docs skips ConnectorDocuments that are not Google types.

    A ClickUp connector document is passed in; the call must complete and
    leave the search space without any documents.
    """
    connector_doc = make_connector_document(
        document_type=DocumentType.CLICKUP_CONNECTOR,
        unique_id="task-1",
        search_space_id=db_search_space.id,
    )

    service = IndexingPipelineService(session=db_session)
    await service.migrate_legacy_docs([connector_doc])

    # Without an assertion this test only proved "does not raise"; also pin
    # that nothing was written for the search space.
    result = await db_session.execute(
        select(Document).filter(Document.search_space_id == db_search_space.id)
    )
    assert result.scalars().all() == []
|
||||
34
surfsense_backend/tests/unit/connector_indexers/conftest.py
Normal file
34
surfsense_backend/tests/unit/connector_indexers/conftest.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
"""Pre-register the connector_indexers package to bypass a circular import
|
||||
in its ``__init__.py`` (airtable_indexer -> routes -> connector_indexers).
|
||||
|
||||
This lets tests import individual indexer modules (e.g.
|
||||
``google_drive_indexer``) without triggering the full package init.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
_BACKEND = Path(__file__).resolve().parents[3]
|
||||
|
||||
|
||||
def _stub_package(dotted: str, fs_dir: Path) -> None:
    """Register an empty package module named *dotted* unless one already exists.

    A newly created module gets ``__path__`` set to ``[str(fs_dir)]`` and
    ``__package__`` set to *dotted*.  Whether or not the module was created
    here, it is then attached as an attribute on its parent package when the
    parent is already present in ``sys.modules``.
    """
    if dotted not in sys.modules:
        placeholder = types.ModuleType(dotted)
        placeholder.__path__ = [str(fs_dir)]
        placeholder.__package__ = dotted
        sys.modules[dotted] = placeholder

    parent_name, dot, leaf = dotted.rpartition(".")
    if not dot:
        # Top-level name: there is no parent to attach to.
        return

    parent_module = sys.modules.get(parent_name)
    if parent_module is None:
        return
    setattr(parent_module, leaf, sys.modules[dotted])
|
||||
|
||||
|
||||
# Stub the parent package first so the child stub can be attached to it.
_stub_package(dotted="app.tasks", fs_dir=_BACKEND / "app" / "tasks")
_stub_package(
    dotted="app.tasks.connector_indexers",
    fs_dir=_BACKEND / "app" / "tasks" / "connector_indexers",
)
|
||||
|
|
@ -0,0 +1,373 @@
|
|||
"""Tests for Confluence indexer migrated to the unified parallel pipeline."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import app.tasks.connector_indexers.confluence_indexer as _mod
|
||||
from app.db import DocumentType
|
||||
from app.tasks.connector_indexers.confluence_indexer import (
|
||||
_build_connector_doc,
|
||||
index_confluence_pages,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
_USER_ID = "00000000-0000-0000-0000-000000000001"
|
||||
_CONNECTOR_ID = 42
|
||||
_SEARCH_SPACE_ID = 1
|
||||
|
||||
|
||||
def _make_page(
    page_id: str = "p1",
    title: str = "Home",
    space_id: str = "S1",
    body: str = "<p>Hello</p>",
    comments=None,
):
    """Build a page dict with id, title, spaceId, storage body and comments.

    A falsy *comments* value is replaced by a new empty list.
    """
    comment_list = comments if comments else []
    storage_body = {"storage": {"value": body}}
    return dict(
        id=page_id,
        title=title,
        spaceId=space_id,
        body=storage_body,
        comments=comment_list,
    )
|
||||
|
||||
|
||||
def _to_markdown(page: dict) -> str:
    """Render *page* as markdown: a title heading, the storage body, then comments.

    When the page has comments, a ``## Comments`` section is appended with
    one entry per comment showing its author id, creation date and body.
    Missing keys fall back to empty strings (``"Unknown"`` for the author).
    """
    heading = page.get("title", "")
    body_html = page.get("body", {}).get("storage", {}).get("value", "")
    comments = page.get("comments", [])

    def _render(comment: dict) -> str:
        text = comment.get("body", {}).get("storage", {}).get("value", "")
        version = comment.get("version", {})
        author = version.get("authorId", "Unknown")
        created = comment.get("version", {}).get("createdAt", "")
        return f"**Comment by {author}** ({created}):\n" f"{text}\n\n"

    comments_section = ""
    if comments:
        comments_section = "\n\n## Comments\n\n" + "".join(
            _render(comment) for comment in comments
        )
    return f"# {heading}\n\n{body_html}{comments_section}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 1: _build_connector_doc tracer bullet
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_build_connector_doc_produces_correct_fields():
    """_build_connector_doc copies page data and ids into the ConnectorDocument.

    Checks the top-level fields, the metadata entries and that the fallback
    summary contains both the page title and the markdown.
    """
    page = _make_page(
        page_id="abc-123",
        title="Engineering Handbook",
        space_id="ENG",
        comments=[{"id": "c1"}],
    )
    markdown = _to_markdown(page)

    doc = _build_connector_doc(
        page,
        markdown,
        connector_id=_CONNECTOR_ID,
        search_space_id=_SEARCH_SPACE_ID,
        user_id=_USER_ID,
        enable_summary=True,
    )

    assert doc.title == "Engineering Handbook"
    assert doc.unique_id == "abc-123"
    assert doc.document_type == DocumentType.CONFLUENCE_CONNECTOR
    assert doc.source_markdown == markdown
    assert doc.search_space_id == _SEARCH_SPACE_ID
    assert doc.connector_id == _CONNECTOR_ID
    assert doc.created_by_id == _USER_ID
    assert doc.should_summarize is True

    expected_metadata = {
        "page_id": "abc-123",
        "page_title": "Engineering Handbook",
        "space_id": "ENG",
        "comment_count": 1,
        "connector_id": _CONNECTOR_ID,
        "document_type": "Confluence Page",
        "connector_type": "Confluence",
    }
    for key, expected in expected_metadata.items():
        assert doc.metadata[key] == expected

    summary = doc.fallback_summary
    assert summary is not None
    assert "Engineering Handbook" in summary
    assert markdown in summary
|
||||
|
||||
|
||||
async def test_build_connector_doc_summary_disabled():
    """enable_summary=False yields a document with should_summarize False."""
    page = _make_page()
    markdown = _to_markdown(_make_page())

    doc = _build_connector_doc(
        page,
        markdown,
        connector_id=_CONNECTOR_ID,
        search_space_id=_SEARCH_SPACE_ID,
        user_id=_USER_ID,
        enable_summary=False,
    )

    assert doc.should_summarize is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared fixtures for Slices 2-7
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _mock_connector(enable_summary: bool = True):
    """Return a MagicMock connector with a token config and no last-indexed time."""
    connector = MagicMock()
    connector.configure_mock(
        config={"access_token": "tok"},
        enable_summary=enable_summary,
        last_indexed_at=None,
    )
    return connector
|
||||
|
||||
|
||||
def _mock_confluence_client(pages=None, error=None):
    """Return a mock client whose ``get_pages_by_date_range`` yields ``(pages, error)``.

    ``pages=None`` is replaced by an empty list; ``close`` is an AsyncMock.
    """
    page_list = [] if pages is None else pages
    fake_client = MagicMock()
    fake_client.get_pages_by_date_range = AsyncMock(return_value=(page_list, error))
    fake_client.close = AsyncMock()
    return fake_client
|
||||
|
||||
|
||||
@pytest.fixture
def confluence_mocks(monkeypatch):
    """Replace the Confluence indexer module's collaborators with mocks.

    By default the Confluence client returns one page from ``_make_page()``
    and the pipeline's ``index_batch_parallel`` returns ``([], 1, 0)``.
    The returned dict exposes the mocks so tests can reconfigure or inspect
    them.
    """
    session = AsyncMock()
    session.no_autoflush = MagicMock()

    connector = _mock_connector()
    client = _mock_confluence_client(pages=[_make_page()])

    task_logger = MagicMock()
    task_logger.log_task_start = AsyncMock(return_value=MagicMock())
    for method_name in ("log_task_progress", "log_task_success", "log_task_failure"):
        setattr(task_logger, method_name, AsyncMock())

    batch = AsyncMock(return_value=([], 1, 0))
    pipeline = MagicMock()
    pipeline.index_batch_parallel = batch
    pipeline.migrate_legacy_docs = AsyncMock()

    # Attribute name on the indexer module -> replacement object.
    replacements = {
        "get_connector_by_id": AsyncMock(return_value=connector),
        "ConfluenceHistoryConnector": MagicMock(return_value=client),
        "check_duplicate_document_by_hash": AsyncMock(return_value=None),
        "update_connector_last_indexed": AsyncMock(),
        "calculate_date_range": MagicMock(return_value=("2025-01-01", "2025-12-31")),
        "TaskLoggingService": MagicMock(return_value=task_logger),
        "IndexingPipelineService": MagicMock(return_value=pipeline),
    }
    for attr_name, replacement in replacements.items():
        monkeypatch.setattr(_mod, attr_name, replacement)

    return {
        "session": session,
        "connector": connector,
        "confluence_client": client,
        "task_logger": task_logger,
        "pipeline_mock": pipeline,
        "batch_mock": batch,
    }
|
||||
|
||||
|
||||
async def _run_index(mocks, **overrides):
    """Call ``index_confluence_pages`` with test defaults, honouring known overrides.

    Only the keys listed in ``defaults`` are read from *overrides*; the
    session always comes from ``mocks["session"]``.
    """
    defaults = {
        "connector_id": _CONNECTOR_ID,
        "search_space_id": _SEARCH_SPACE_ID,
        "user_id": _USER_ID,
        "start_date": "2025-01-01",
        "end_date": "2025-12-31",
        "update_last_indexed": True,
        "on_heartbeat_callback": None,
    }
    call_kwargs = {
        key: overrides.get(key, fallback) for key, fallback in defaults.items()
    }
    return await index_confluence_pages(session=mocks["session"], **call_kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 2: Full pipeline wiring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_one_page_calls_pipeline_and_returns_indexed_count(confluence_mocks):
    """One fetched page yields (1, 0, None) and a single Confluence doc sent to the batch call."""
    batch = confluence_mocks["batch_mock"]

    indexed, skipped, warning = await _run_index(confluence_mocks)

    assert indexed == 1
    assert skipped == 0
    assert warning is None

    batch.assert_called_once()
    submitted_docs = batch.call_args.args[0]
    assert len(submitted_docs) == 1
    assert submitted_docs[0].document_type == DocumentType.CONFLUENCE_CONNECTOR
|
||||
|
||||
|
||||
async def test_pipeline_called_with_max_concurrency_3(confluence_mocks):
    """The batch pipeline call receives ``max_concurrency=3`` as a keyword argument."""
    await _run_index(confluence_mocks)

    passed_kwargs = confluence_mocks["batch_mock"].call_args.kwargs
    assert passed_kwargs.get("max_concurrency") == 3
|
||||
|
||||
|
||||
async def test_migrate_legacy_docs_called_before_indexing(confluence_mocks):
    """``migrate_legacy_docs`` is invoked exactly once per indexing run.

    Only the call count is checked here, not the ordering relative to the
    batch call.
    """
    pipeline = confluence_mocks["pipeline_mock"]

    await _run_index(confluence_mocks)

    pipeline.migrate_legacy_docs.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 3: Page skipping (missing id/title/content)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_pages_with_missing_id_are_skipped(confluence_mocks):
    """A page with an empty id is counted as skipped and not sent to the pipeline."""
    fetched = [
        _make_page(page_id="p1", title="Valid"),
        _make_page(page_id="", title="Missing id"),
    ]
    client = confluence_mocks["confluence_client"]
    client.get_pages_by_date_range.return_value = (fetched, None)

    _, skipped, _ = await _run_index(confluence_mocks)

    submitted_docs = confluence_mocks["batch_mock"].call_args.args[0]
    assert len(submitted_docs) == 1
    assert skipped == 1
|
||||
|
||||
|
||||
async def test_pages_with_missing_title_are_skipped(confluence_mocks):
    """A page with an empty title is counted as skipped and not sent to the pipeline."""
    fetched = [
        _make_page(page_id="p1", title="Valid"),
        _make_page(page_id="p2", title=""),
    ]
    client = confluence_mocks["confluence_client"]
    client.get_pages_by_date_range.return_value = (fetched, None)

    _, skipped, _ = await _run_index(confluence_mocks)

    submitted_docs = confluence_mocks["batch_mock"].call_args.args[0]
    assert len(submitted_docs) == 1
    assert skipped == 1
|
||||
|
||||
|
||||
async def test_pages_with_no_content_are_skipped(confluence_mocks):
    """A page with an empty body is counted as skipped and not sent to the pipeline."""
    fetched = [
        _make_page(page_id="p1", title="Valid", body="<p>ok</p>"),
        _make_page(page_id="p2", title="Empty", body=""),
    ]
    client = confluence_mocks["confluence_client"]
    client.get_pages_by_date_range.return_value = (fetched, None)

    _, skipped, _ = await _run_index(confluence_mocks)

    submitted_docs = confluence_mocks["batch_mock"].call_args.args[0]
    assert len(submitted_docs) == 1
    assert skipped == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 4: Duplicate content skipping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_duplicate_content_pages_are_skipped(confluence_mocks, monkeypatch):
    """A page whose content hash matches an existing document is skipped.

    The duplicate check is replaced by a stub that reports a duplicate on
    its second call only.
    """
    fetched = [
        _make_page(page_id="p1", title="One"),
        _make_page(page_id="p2", title="Two"),
    ]
    client = confluence_mocks["confluence_client"]
    client.get_pages_by_date_range.return_value = (fetched, None)

    seen_hashes = []

    async def _check_dup(session, content_hash):
        seen_hashes.append(content_hash)
        if len(seen_hashes) != 2:
            return None
        # Second lookup: pretend another document already has this content.
        return MagicMock(id=99, document_type="OTHER")

    monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)

    _, skipped, _ = await _run_index(confluence_mocks)

    submitted_docs = confluence_mocks["batch_mock"].call_args.args[0]
    assert len(submitted_docs) == 1
    assert skipped == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 5: Heartbeat callback forwarding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_heartbeat_callback_forwarded_to_pipeline(confluence_mocks):
    """The ``on_heartbeat_callback`` argument reaches the batch call as ``on_heartbeat``."""
    callback = AsyncMock()

    await _run_index(confluence_mocks, on_heartbeat_callback=callback)

    passed_kwargs = confluence_mocks["batch_mock"].call_args.kwargs
    assert passed_kwargs.get("on_heartbeat") is callback
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 6: Empty pages and no-data success tuple
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_empty_pages_returns_zero_tuple(confluence_mocks):
    """No pages and no error gives (0, 0, None) and the batch pipeline is never called."""
    client = confluence_mocks["confluence_client"]
    client.get_pages_by_date_range.return_value = ([], None)

    indexed, skipped, warning = await _run_index(confluence_mocks)

    assert indexed == 0
    assert skipped == 0
    assert warning is None
    confluence_mocks["batch_mock"].assert_not_called()
|
||||
|
||||
|
||||
async def test_no_pages_error_message_returns_success_tuple(confluence_mocks):
    """A "No pages found" error string from the client is reported as (0, 0, None)."""
    client = confluence_mocks["confluence_client"]
    client.get_pages_by_date_range.return_value = (
        [],
        "No pages found in date range",
    )

    indexed, skipped, warning = await _run_index(confluence_mocks)

    assert indexed == 0
    assert skipped == 0
    assert warning is None
|
||||
|
||||
|
||||
async def test_api_error_still_returns_3_tuple(confluence_mocks):
    """Any other client error still yields a 3-tuple: zero counts plus a failure message."""
    client = confluence_mocks["confluence_client"]
    client.get_pages_by_date_range.return_value = ([], "API exploded")

    outcome = await _run_index(confluence_mocks)

    assert len(outcome) == 3
    indexed, skipped, message = outcome
    assert indexed == 0
    assert skipped == 0
    assert "Failed to get Confluence pages" in message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 7: Failed docs warning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_failed_docs_warning_in_result(confluence_mocks):
    """When the batch call reports 2 failures, the warning mentions "2 failed"."""
    batch = confluence_mocks["batch_mock"]
    batch.return_value = ([], 0, 2)

    outcome = await _run_index(confluence_mocks)
    warning = outcome[2]

    assert warning is not None
    assert "2 failed" in warning
|
||||
|
|
@ -0,0 +1,671 @@
|
|||
"""Tests for parallel download + indexing in the Google Drive indexer."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.tasks.connector_indexers.google_drive_indexer import (
|
||||
_download_files_parallel,
|
||||
_index_full_scan,
|
||||
_index_selected_files,
|
||||
_index_with_delta_sync,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
_USER_ID = "00000000-0000-0000-0000-000000000001"
|
||||
_CONNECTOR_ID = 42
|
||||
_SEARCH_SPACE_ID = 1
|
||||
|
||||
|
||||
def _make_file_dict(file_id: str, name: str, mime: str = "text/plain") -> dict:
    """Return a file dict with ``id``, ``name`` and ``mimeType`` keys."""
    return dict(id=file_id, name=name, mimeType=mime)
|
||||
|
||||
|
||||
def _mock_extract_ok(file_id: str, file_name: str):
    """Return a successful (markdown, metadata, None) tuple."""
    markdown = f"# Content of {file_name}"
    metadata = {
        "google_drive_file_id": file_id,
        "google_drive_file_name": file_name,
    }
    return markdown, metadata, None
|
||||
|
||||
|
||||
@pytest.fixture
def mock_drive_client():
    """Provide a bare MagicMock as the Drive client."""
    drive_client = MagicMock()
    return drive_client
|
||||
|
||||
|
||||
@pytest.fixture
def patch_extract(monkeypatch):
    """Provide a helper to set the download_and_extract_content mock.

    The helper installs an AsyncMock built from the given ``side_effect`` /
    ``return_value`` on the Drive indexer module and returns that mock.
    """
    target = "app.tasks.connector_indexers.google_drive_indexer.download_and_extract_content"

    def _patch(side_effect=None, return_value=None):
        replacement = AsyncMock(side_effect=side_effect, return_value=return_value)
        monkeypatch.setattr(target, replacement)
        return replacement

    return _patch
|
||||
|
||||
|
||||
async def test_single_file_returns_one_connector_document(
    mock_drive_client, patch_extract,
):
    """Tracer bullet: downloading one file produces one ConnectorDocument."""
    patch_extract(return_value=_mock_extract_ok("f1", "test.txt"))
    files = [_make_file_dict("f1", "test.txt")]

    docs, failed = await _download_files_parallel(
        mock_drive_client,
        files,
        connector_id=_CONNECTOR_ID,
        search_space_id=_SEARCH_SPACE_ID,
        user_id=_USER_ID,
        enable_summary=True,
    )

    assert len(docs) == 1
    assert failed == 0
    only_doc = docs[0]
    assert only_doc.title == "test.txt"
    assert only_doc.unique_id == "f1"
|
||||
|
||||
|
||||
async def test_multiple_files_all_produce_documents(
    mock_drive_client, patch_extract,
):
    """All files are downloaded and converted to ConnectorDocuments."""
    specs = [(f"f{i}", f"file{i}.txt") for i in range(3)]
    files = [_make_file_dict(file_id, name) for file_id, name in specs]
    patch_extract(
        side_effect=[_mock_extract_ok(file_id, name) for file_id, name in specs]
    )

    docs, failed = await _download_files_parallel(
        mock_drive_client,
        files,
        connector_id=_CONNECTOR_ID,
        search_space_id=_SEARCH_SPACE_ID,
        user_id=_USER_ID,
        enable_summary=True,
    )

    assert len(docs) == 3
    assert failed == 0
    produced_ids = {doc.unique_id for doc in docs}
    assert produced_ids == {"f0", "f1", "f2"}
|
||||
|
||||
|
||||
async def test_one_download_exception_does_not_block_others(
    mock_drive_client, patch_extract,
):
    """A RuntimeError in one download still lets the other files succeed."""
    files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
    # The second extraction call raises; the first and third succeed.
    extraction_outcomes = [
        _mock_extract_ok("f0", "file0.txt"),
        RuntimeError("network timeout"),
        _mock_extract_ok("f2", "file2.txt"),
    ]
    patch_extract(side_effect=extraction_outcomes)

    docs, failed = await _download_files_parallel(
        mock_drive_client,
        files,
        connector_id=_CONNECTOR_ID,
        search_space_id=_SEARCH_SPACE_ID,
        user_id=_USER_ID,
        enable_summary=True,
    )

    assert len(docs) == 2
    assert failed == 1
    produced_ids = {doc.unique_id for doc in docs}
    assert produced_ids == {"f0", "f2"}
|
||||
|
||||
|
||||
async def test_etl_error_counts_as_download_failure(
    mock_drive_client, patch_extract,
):
    """download_and_extract_content returning an error is counted as failed."""
    good_file = _make_file_dict("f0", "good.txt")
    bad_file = _make_file_dict("f1", "bad.txt")
    etl_failure = (None, {}, "ETL failed")
    patch_extract(side_effect=[_mock_extract_ok("f0", "good.txt"), etl_failure])

    docs, failed = await _download_files_parallel(
        mock_drive_client,
        [good_file, bad_file],
        connector_id=_CONNECTOR_ID,
        search_space_id=_SEARCH_SPACE_ID,
        user_id=_USER_ID,
        enable_summary=True,
    )

    assert len(docs) == 1
    assert failed == 1
|
||||
|
||||
|
||||
async def test_concurrency_bounded_by_semaphore(
    mock_drive_client, monkeypatch,
):
    """Peak concurrent downloads never exceeds max_concurrency.

    The extraction function is replaced by a stub that records how many
    calls are in flight at once while each call sleeps briefly.
    """
    lock = asyncio.Lock()
    running = 0
    peak = 0

    async def _slow_extract(client, file):
        nonlocal running, peak
        async with lock:
            running += 1
            if running > peak:
                peak = running
        await asyncio.sleep(0.05)
        async with lock:
            running -= 1
        return _mock_extract_ok(file["id"], file["name"])

    monkeypatch.setattr(
        "app.tasks.connector_indexers.google_drive_indexer.download_and_extract_content",
        _slow_extract,
    )

    files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(6)]

    docs, failed = await _download_files_parallel(
        mock_drive_client,
        files,
        connector_id=_CONNECTOR_ID,
        search_space_id=_SEARCH_SPACE_ID,
        user_id=_USER_ID,
        enable_summary=True,
        max_concurrency=2,
    )

    assert len(docs) == 6
    assert failed == 0
    assert peak <= 2, f"Peak concurrency was {peak}, expected <= 2"
|
||||
|
||||
|
||||
async def test_heartbeat_fires_during_parallel_downloads(
    mock_drive_client,
    monkeypatch,
):
    """The heartbeat callback runs at least once while downloads are slow.

    ``HEARTBEAT_INTERVAL_SECONDS`` is patched to 0 and every extraction
    sleeps for 0.05 seconds before returning.
    """
    import app.tasks.connector_indexers.google_drive_indexer as _mod

    monkeypatch.setattr(_mod, "HEARTBEAT_INTERVAL_SECONDS", 0)

    async def _delayed_extract(client, file):
        await asyncio.sleep(0.05)
        return _mock_extract_ok(file["id"], file["name"])

    monkeypatch.setattr(
        "app.tasks.connector_indexers.google_drive_indexer.download_and_extract_content",
        _delayed_extract,
    )

    observed_counts: list[int] = []

    async def _record_heartbeat(count: int):
        observed_counts.append(count)

    drive_files = [
        _make_file_dict(f"f{idx}", f"file{idx}.txt") for idx in range(3)
    ]

    documents, failure_count = await _download_files_parallel(
        mock_drive_client,
        drive_files,
        connector_id=_CONNECTOR_ID,
        search_space_id=_SEARCH_SPACE_ID,
        user_id=_USER_ID,
        enable_summary=True,
        on_heartbeat=_record_heartbeat,
    )

    assert len(documents) == 3
    assert failure_count == 0
    assert len(observed_counts) >= 1, "Heartbeat should have fired at least once"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 6, 6b, 6c -- _index_full_scan three-phase pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _folder_dict(file_id: str, name: str) -> dict:
|
||||
return {"id": file_id, "name": name, "mimeType": "application/vnd.google-apps.folder"}
|
||||
|
||||
|
||||
@pytest.fixture
def full_scan_mocks(mock_drive_client, monkeypatch):
    """Patch the indexer module so ``_index_full_scan`` can run in isolation.

    Returns a dict exposing the mocks. ``skip_results`` maps a file id to the
    tuple the fake ``_should_skip_file`` returns for it; ids that are absent
    yield ``(False, None)``.
    """
    import app.tasks.connector_indexers.google_drive_indexer as _mod

    skip_results: dict[str, tuple[bool, str | None]] = {}

    async def _fake_skip(session, file, search_space_id):
        return skip_results.get(file["id"], (False, None))

    download_mock = AsyncMock(return_value=([], 0))
    batch_mock = AsyncMock(return_value=([], 0, 0))
    pipeline_mock = MagicMock()
    pipeline_mock.index_batch_parallel = batch_mock

    replacements = {
        "_should_skip_file": _fake_skip,
        "_download_files_parallel": download_mock,
        "IndexingPipelineService": MagicMock(return_value=pipeline_mock),
        "get_user_long_context_llm": AsyncMock(return_value=MagicMock()),
    }
    for attr_name, replacement in replacements.items():
        monkeypatch.setattr(_mod, attr_name, replacement)

    task_logger = MagicMock()
    task_logger.log_task_progress = AsyncMock()

    return {
        "drive_client": mock_drive_client,
        "session": AsyncMock(),
        "connector": MagicMock(),
        "task_logger": task_logger,
        "log_entry": MagicMock(),
        "skip_results": skip_results,
        "download_mock": download_mock,
        "batch_mock": batch_mock,
        "pipeline_mock": pipeline_mock,
    }
|
||||
|
||||
|
||||
async def _run_full_scan(mocks, *, max_files=500, include_subfolders=False):
    """Call ``_index_full_scan`` with the fixture mocks and the fixed test ids."""
    positional_args = (
        mocks["drive_client"],
        mocks["session"],
        mocks["connector"],
        _CONNECTOR_ID,
        _SEARCH_SPACE_ID,
        _USER_ID,
        "folder-root",
        "My Folder",
        mocks["task_logger"],
        mocks["log_entry"],
        max_files,
    )
    return await _index_full_scan(
        *positional_args,
        include_subfolders=include_subfolders,
        enable_summary=True,
    )
|
||||
|
||||
|
||||
async def test_full_scan_three_phase_counts(full_scan_mocks, monkeypatch):
    """A full scan reports the right (indexed, skipped) totals.

    The listing holds one folder, one unchanged file, one renamed file and
    two new files. The rename is expected to count as indexed, and only the
    two new files should be passed to the download step.
    """
    import app.tasks.connector_indexers.google_drive_indexer as _mod

    listing = [
        _folder_dict("folder1", "SubFolder"),
        _make_file_dict("skip1", "unchanged.txt"),
        _make_file_dict("rename1", "renamed.txt"),
        _make_file_dict("new1", "new1.txt"),
        _make_file_dict("new2", "new2.txt"),
    ]
    folder_lister = AsyncMock(return_value=(listing, None, None))
    monkeypatch.setattr(_mod, "get_files_in_folder", folder_lister)

    skip_table = full_scan_mocks["skip_results"]
    skip_table["skip1"] = (True, "unchanged")
    skip_table["rename1"] = (True, "File renamed: 'old' → 'renamed.txt'")

    download_mock = full_scan_mocks["download_mock"]
    download_mock.return_value = ([MagicMock(), MagicMock()], 0)
    full_scan_mocks["batch_mock"].return_value = ([], 2, 0)

    indexed, skipped = await _run_full_scan(full_scan_mocks)

    assert indexed == 3  # 1 renamed + 2 from batch
    assert skipped == 1  # 1 unchanged

    download_mock.assert_called_once()
    files_sent = download_mock.call_args[0][1]
    assert len(files_sent) == 2
    assert {entry["id"] for entry in files_sent} == {"new1", "new2"}
|
||||
|
||||
|
||||
async def test_full_scan_respects_max_files(full_scan_mocks, monkeypatch):
    """With ten files listed and ``max_files=3``, only three reach the download step."""
    import app.tasks.connector_indexers.google_drive_indexer as _mod

    listing = [
        _make_file_dict(f"f{idx}", f"file{idx}.txt") for idx in range(10)
    ]
    folder_lister = AsyncMock(return_value=(listing, None, None))
    monkeypatch.setattr(_mod, "get_files_in_folder", folder_lister)

    download_mock = full_scan_mocks["download_mock"]
    download_mock.return_value = ([], 0)
    full_scan_mocks["batch_mock"].return_value = ([], 0, 0)

    await _run_full_scan(full_scan_mocks, max_files=3)

    files_sent = download_mock.call_args[0][1]
    assert len(files_sent) == 3
|
||||
|
||||
|
||||
async def test_full_scan_uses_max_concurrency_3_for_indexing(
    full_scan_mocks,
    monkeypatch,
):
    """``index_batch_parallel`` receives a concurrency limit of 3.

    The limit is accepted either as the ``max_concurrency`` keyword or as the
    third positional argument.
    """
    import app.tasks.connector_indexers.google_drive_indexer as _mod

    listing = [_make_file_dict("f1", "file1.txt")]
    folder_lister = AsyncMock(return_value=(listing, None, None))
    monkeypatch.setattr(_mod, "get_files_in_folder", folder_lister)

    full_scan_mocks["download_mock"].return_value = ([MagicMock()], 0)
    batch_mock = full_scan_mocks["batch_mock"]
    batch_mock.return_value = ([], 1, 0)

    await _run_full_scan(full_scan_mocks)

    args, kwargs = batch_mock.call_args
    passed_as_keyword = kwargs.get("max_concurrency") == 3
    passed_positionally = len(args) > 2 and args[2] == 3
    assert passed_as_keyword or passed_positionally
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 7 -- _index_with_delta_sync three-phase pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
    """Delta sync removes deleted/trashed files and batches the rest.

    Changes categorized as "removed" or "trashed" must reach
    ``_remove_document``; the two "modified" changes must be the files passed
    to ``_download_files_parallel``, and the indexed total must equal the
    count reported by the mocked ``index_batch_parallel``.
    """
    import app.tasks.connector_indexers.google_drive_indexer as _mod

    changes = [
        {"fileId": "del1", "removed": True},
        {"fileId": "del2", "file": {"id": "del2", "trashed": True}},
        {"fileId": "trash1", "file": {"id": "trash1", "trashed": True}},
        {"fileId": "mod1", "file": _make_file_dict("mod1", "modified1.txt")},
        {"fileId": "mod2", "file": _make_file_dict("mod2", "modified2.txt")},
    ]
    kind_by_file_id = {
        "del1": "removed",
        "del2": "removed",
        "trash1": "trashed",
        "mod1": "modified",
        "mod2": "modified",
    }

    def _fake_categorize(change):
        return kind_by_file_id[change["fileId"]]

    removed_ids: list[str] = []

    async def _fake_remove(session, file_id, search_space_id):
        removed_ids.append(file_id)

    download_mock = AsyncMock(return_value=([MagicMock(), MagicMock()], 0))
    pipeline_mock = MagicMock()
    pipeline_mock.index_batch_parallel = AsyncMock(return_value=([], 2, 0))

    replacements = {
        "fetch_all_changes": AsyncMock(
            return_value=(changes, "new-token", None)
        ),
        "categorize_change": _fake_categorize,
        "_remove_document": _fake_remove,
        "_should_skip_file": AsyncMock(return_value=(False, None)),
        "_download_files_parallel": download_mock,
        "IndexingPipelineService": MagicMock(return_value=pipeline_mock),
        "get_user_long_context_llm": AsyncMock(return_value=MagicMock()),
    }
    for attr_name, replacement in replacements.items():
        monkeypatch.setattr(_mod, attr_name, replacement)

    session = AsyncMock()
    task_logger = MagicMock()
    task_logger.log_task_progress = AsyncMock()

    indexed, skipped = await _index_with_delta_sync(
        MagicMock(),
        session,
        MagicMock(),
        _CONNECTOR_ID,
        _SEARCH_SPACE_ID,
        _USER_ID,
        "folder-root",
        "start-token-abc",
        task_logger,
        MagicMock(),
        max_files=500,
        enable_summary=True,
    )

    assert sorted(removed_ids) == ["del1", "del2", "trash1"]

    download_mock.assert_called_once()
    files_sent = download_mock.call_args[0][1]
    assert len(files_sent) == 2
    assert {entry["id"] for entry in files_sent} == {"mod1", "mod2"}

    assert indexed == 2
    assert skipped == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _index_selected_files -- parallel indexing of user-selected files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
def selected_files_mocks(mock_drive_client, monkeypatch):
    """Patch the indexer module for ``_index_selected_files`` tests.

    ``get_file_results`` maps a file id to the tuple the fake
    ``get_file_by_id`` returns; ids that are absent yield
    ``(None, "Not configured: <id>")``. ``skip_results`` does the same for
    the fake ``_should_skip_file``, defaulting to ``(False, None)``.
    """
    import app.tasks.connector_indexers.google_drive_indexer as _mod

    get_file_results: dict[str, tuple[dict | None, str | None]] = {}
    skip_results: dict[str, tuple[bool, str | None]] = {}

    async def _fake_get_file(client, file_id):
        fallback = (None, f"Not configured: {file_id}")
        return get_file_results.get(file_id, fallback)

    async def _fake_skip(session, file, search_space_id):
        return skip_results.get(file["id"], (False, None))

    download_and_index_mock = AsyncMock(return_value=(0, 0))

    replacements = {
        "get_file_by_id": _fake_get_file,
        "_should_skip_file": _fake_skip,
        "_download_and_index": download_and_index_mock,
    }
    for attr_name, replacement in replacements.items():
        monkeypatch.setattr(_mod, attr_name, replacement)

    return {
        "drive_client": mock_drive_client,
        "session": AsyncMock(),
        "get_file_results": get_file_results,
        "skip_results": skip_results,
        "download_and_index_mock": download_and_index_mock,
    }
|
||||
|
||||
|
||||
async def _run_selected(mocks, file_ids):
    """Call ``_index_selected_files`` with the fixture mocks and the fixed test ids."""
    fixed_kwargs = {
        "connector_id": _CONNECTOR_ID,
        "search_space_id": _SEARCH_SPACE_ID,
        "user_id": _USER_ID,
        "enable_summary": True,
    }
    return await _index_selected_files(
        mocks["drive_client"],
        mocks["session"],
        file_ids,
        **fixed_kwargs,
    )
|
||||
|
||||
|
||||
async def test_selected_files_single_file_indexed(selected_files_mocks):
    """Tracer bullet: one fetched, non-skipped file is indexed exactly once."""
    report_file = _make_file_dict("f1", "report.pdf")
    selected_files_mocks["get_file_results"]["f1"] = (report_file, None)
    download_and_index = selected_files_mocks["download_and_index_mock"]
    download_and_index.return_value = (1, 0)

    indexed, skipped, errors = await _run_selected(
        selected_files_mocks,
        [("f1", "report.pdf")],
    )

    assert indexed == 1
    assert skipped == 0
    assert errors == []
    download_and_index.assert_called_once()
|
||||
|
||||
|
||||
async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
    """A failed metadata fetch for one file is reported without blocking the others.

    The middle file's lookup returns ``(None, "HTTP 404")``; its display name
    and the error text must both appear in the single collected error.
    """
    fetch_table = selected_files_mocks["get_file_results"]
    fetch_table["f1"] = (_make_file_dict("f1", "first.txt"), None)
    fetch_table["f2"] = (None, "HTTP 404")
    fetch_table["f3"] = (_make_file_dict("f3", "third.txt"), None)
    selected_files_mocks["download_and_index_mock"].return_value = (2, 0)

    selection = [("f1", "first.txt"), ("f2", "mid.txt"), ("f3", "third.txt")]
    indexed, skipped, errors = await _run_selected(
        selected_files_mocks,
        selection,
    )

    assert indexed == 2
    assert skipped == 0
    assert len(errors) == 1
    only_error = errors[0]
    assert "mid.txt" in only_error
    assert "HTTP 404" in only_error
|
||||
|
||||
|
||||
async def test_selected_files_skip_rename_counting(selected_files_mocks):
    """Unchanged files count as skipped and renames count as indexed.

    Only the two files with no skip entry should be handed to
    ``_download_and_index``.
    """
    selection = [
        ("s1", "unchanged.txt"),
        ("r1", "renamed.txt"),
        ("n1", "new1.txt"),
        ("n2", "new2.txt"),
    ]
    fetch_table = selected_files_mocks["get_file_results"]
    for file_id, file_name in selection:
        fetch_table[file_id] = (_make_file_dict(file_id, file_name), None)

    skip_table = selected_files_mocks["skip_results"]
    skip_table["s1"] = (True, "unchanged")
    skip_table["r1"] = (True, "File renamed: 'old' \u2192 'renamed.txt'")

    download_and_index = selected_files_mocks["download_and_index_mock"]
    download_and_index.return_value = (2, 0)

    indexed, skipped, errors = await _run_selected(
        selected_files_mocks,
        selection,
    )

    assert indexed == 3  # 1 renamed + 2 batch
    assert skipped == 1  # 1 unchanged
    assert errors == []

    download_and_index.assert_called_once()
    # The files may arrive as the ``files`` keyword or as the third positional.
    args, kwargs = download_and_index.call_args
    files_sent = kwargs["files"] if "files" in kwargs else args[2]
    assert len(files_sent) == 2
    assert {entry["id"] for entry in files_sent} == {"n1", "n2"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# asyncio.to_thread verification — prove blocking calls run in parallel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_client_download_file_runs_in_thread_parallel():
    """Concurrent ``download_file`` calls must overlap their blocking work.

    ``_sync_download_file`` is replaced with a function that blocks in
    ``time.sleep`` for 0.2s. Three calls are gathered; run one after another
    they would need at least 0.6s, so the test asserts the total elapsed time
    stays below that serial minimum.
    """
    from app.connectors.google_drive.client import GoogleDriveClient

    block_seconds = 0.2
    num_calls = 3

    def _blocking_download(service, file_id, credentials):
        time.sleep(block_seconds)
        return b"fake-content", None

    # Bypass __init__ and set only the attributes assigned below.
    client = GoogleDriveClient.__new__(GoogleDriveClient)
    client.service = MagicMock()
    client._resolved_credentials = MagicMock()
    client._service_lock = asyncio.Lock()

    replacement = staticmethod(_blocking_download)
    with patch.object(GoogleDriveClient, "_sync_download_file", replacement):
        start = time.monotonic()
        pending = [client.download_file(f"file-{idx}") for idx in range(num_calls)]
        results = await asyncio.gather(*pending)
        elapsed = time.monotonic() - start

    for content, error in results:
        assert content == b"fake-content"
        assert error is None

    serial_minimum = block_seconds * num_calls
    assert elapsed < serial_minimum, (
        f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — "
        f"downloads are not running in parallel"
    )
|
||||
|
||||
|
||||
async def test_client_export_google_file_runs_in_thread_parallel():
    """Concurrent ``export_google_file`` calls must overlap their blocking work.

    ``_sync_export_google_file`` is replaced with a function that blocks in
    ``time.sleep`` for 0.2s. Three gathered calls must finish in less than
    the 0.6s a serial run would need.
    """
    from app.connectors.google_drive.client import GoogleDriveClient

    block_seconds = 0.2
    num_calls = 3

    def _blocking_export(service, file_id, mime_type, credentials):
        time.sleep(block_seconds)
        return b"exported", None

    # Bypass __init__ and set only the attributes assigned below.
    client = GoogleDriveClient.__new__(GoogleDriveClient)
    client.service = MagicMock()
    client._resolved_credentials = MagicMock()
    client._service_lock = asyncio.Lock()

    replacement = staticmethod(_blocking_export)
    with patch.object(GoogleDriveClient, "_sync_export_google_file", replacement):
        start = time.monotonic()
        pending = [
            client.export_google_file(f"file-{idx}", "application/pdf")
            for idx in range(num_calls)
        ]
        results = await asyncio.gather(*pending)
        elapsed = time.monotonic() - start

    for content, error in results:
        assert content == b"exported"
        assert error is None

    serial_minimum = block_seconds * num_calls
    assert elapsed < serial_minimum, (
        f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — "
        f"exports are not running in parallel"
    )
|
||||
|
|
@ -0,0 +1,372 @@
|
|||
"""Tests for Jira indexer migrated to the unified parallel pipeline."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import app.tasks.connector_indexers.jira_indexer as _mod
|
||||
from app.db import DocumentType
|
||||
from app.tasks.connector_indexers.jira_indexer import (
|
||||
_build_connector_doc,
|
||||
index_jira_issues,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
_USER_ID = "00000000-0000-0000-0000-000000000001"
|
||||
_CONNECTOR_ID = 42
|
||||
_SEARCH_SPACE_ID = 1
|
||||
|
||||
|
||||
def _make_issue(
|
||||
issue_key: str = "ENG-1",
|
||||
issue_id: str = "10001",
|
||||
title: str = "Fix login",
|
||||
):
|
||||
return {"key": issue_key, "id": issue_id, "title": title}
|
||||
|
||||
|
||||
def _make_formatted_issue(
|
||||
issue_key: str = "ENG-1",
|
||||
issue_id: str = "10001",
|
||||
title: str = "Fix login",
|
||||
status: str = "In Progress",
|
||||
priority: str = "High",
|
||||
comments=None,
|
||||
):
|
||||
return {
|
||||
"key": issue_key,
|
||||
"id": issue_id,
|
||||
"title": title,
|
||||
"status": status,
|
||||
"priority": priority,
|
||||
"comments": comments or [],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 1: _build_connector_doc tracer bullet
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_build_connector_doc_produces_correct_fields():
    """Tracer bullet: every field of the document built from one Jira issue is checked."""
    markdown = "# ENG-42: Fix auth bug\n\nBody"
    raw_issue = _make_issue(
        issue_key="ENG-42", issue_id="4242", title="Fix auth bug"
    )
    formatted_issue = _make_formatted_issue(
        issue_key="ENG-42",
        issue_id="4242",
        title="Fix auth bug",
        status="Done",
        priority="Urgent",
        comments=[{"id": "c1"}],
    )

    doc = _build_connector_doc(
        raw_issue,
        formatted_issue,
        markdown,
        connector_id=_CONNECTOR_ID,
        search_space_id=_SEARCH_SPACE_ID,
        user_id=_USER_ID,
        enable_summary=True,
    )

    # Top-level document attributes.
    assert doc.title == "ENG-42: 4242"
    assert doc.unique_id == "ENG-42"
    assert doc.document_type == DocumentType.JIRA_CONNECTOR
    assert doc.source_markdown == markdown
    assert doc.search_space_id == _SEARCH_SPACE_ID
    assert doc.connector_id == _CONNECTOR_ID
    assert doc.created_by_id == _USER_ID
    assert doc.should_summarize is True

    # Metadata entries, checked in the same order as listed here.
    expected_metadata = {
        "issue_id": "ENG-42",
        "issue_identifier": "ENG-42",
        "issue_title": "4242",
        "state": "Done",
        "priority": "Urgent",
        "comment_count": 1,
        "connector_id": _CONNECTOR_ID,
        "document_type": "Jira Issue",
        "connector_type": "Jira",
    }
    for field_name, expected_value in expected_metadata.items():
        assert doc.metadata[field_name] == expected_value

    summary = doc.fallback_summary
    assert summary is not None
    assert "ENG-42" in summary
    assert markdown in summary
|
||||
|
||||
|
||||
async def test_build_connector_doc_summary_disabled():
    """``enable_summary=False`` yields a document with ``should_summarize`` False."""
    build_kwargs = {
        "connector_id": _CONNECTOR_ID,
        "search_space_id": _SEARCH_SPACE_ID,
        "user_id": _USER_ID,
        "enable_summary": False,
    }
    doc = _build_connector_doc(
        _make_issue(),
        _make_formatted_issue(),
        "# content",
        **build_kwargs,
    )
    assert doc.should_summarize is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared fixtures for Slices 2-7
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _mock_connector(enable_summary: bool = True):
|
||||
c = MagicMock()
|
||||
c.config = {"access_token": "tok"}
|
||||
c.enable_summary = enable_summary
|
||||
c.last_indexed_at = None
|
||||
return c
|
||||
|
||||
|
||||
def _mock_jira_client(issues=None, error=None):
|
||||
client = MagicMock()
|
||||
client.get_issues_by_date_range = AsyncMock(
|
||||
return_value=(issues if issues is not None else [], error),
|
||||
)
|
||||
client.format_issue = MagicMock(
|
||||
side_effect=lambda i: _make_formatted_issue(
|
||||
issue_key=i.get("key", ""),
|
||||
issue_id=i.get("id", ""),
|
||||
title=i.get("title", ""),
|
||||
)
|
||||
)
|
||||
client.format_issue_to_markdown = MagicMock(
|
||||
side_effect=lambda fi: f"# {fi.get('key', '')}: {fi.get('id', '')}\n\nContent"
|
||||
)
|
||||
client.close = AsyncMock()
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
def jira_mocks(monkeypatch):
    """Patch the Jira indexer module so ``index_jira_issues`` can run in isolation.

    By default the mocked client returns one issue and the mocked pipeline's
    ``index_batch_parallel`` resolves to ``([], 1, 0)``. Returns a dict
    exposing the mocks.
    """
    session = AsyncMock()
    session.no_autoflush = MagicMock()

    connector = _mock_connector()
    jira_client = _mock_jira_client(issues=[_make_issue()])

    task_logger = MagicMock()
    task_logger.log_task_start = AsyncMock(return_value=MagicMock())
    task_logger.log_task_progress = AsyncMock()
    task_logger.log_task_success = AsyncMock()
    task_logger.log_task_failure = AsyncMock()

    batch_mock = AsyncMock(return_value=([], 1, 0))
    pipeline_mock = MagicMock()
    pipeline_mock.index_batch_parallel = batch_mock
    pipeline_mock.migrate_legacy_docs = AsyncMock()

    replacements = {
        "get_connector_by_id": AsyncMock(return_value=connector),
        "JiraHistoryConnector": MagicMock(return_value=jira_client),
        "check_duplicate_document_by_hash": AsyncMock(return_value=None),
        "update_connector_last_indexed": AsyncMock(),
        "calculate_date_range": MagicMock(
            return_value=("2025-01-01", "2025-12-31")
        ),
        "TaskLoggingService": MagicMock(return_value=task_logger),
        "IndexingPipelineService": MagicMock(return_value=pipeline_mock),
    }
    for attr_name, replacement in replacements.items():
        monkeypatch.setattr(_mod, attr_name, replacement)

    return {
        "session": session,
        "connector": connector,
        "jira_client": jira_client,
        "task_logger": task_logger,
        "pipeline_mock": pipeline_mock,
        "batch_mock": batch_mock,
    }
|
||||
|
||||
|
||||
async def _run_index(mocks, **overrides):
    """Call ``index_jira_issues`` with default test arguments.

    Any of the keyword names listed in ``defaults`` may be replaced through
    ``overrides``; other override keys are ignored.
    """
    defaults = {
        "connector_id": _CONNECTOR_ID,
        "search_space_id": _SEARCH_SPACE_ID,
        "user_id": _USER_ID,
        "start_date": "2025-01-01",
        "end_date": "2025-12-31",
        "update_last_indexed": True,
        "on_heartbeat_callback": None,
    }
    call_kwargs = {
        name: overrides.get(name, fallback)
        for name, fallback in defaults.items()
    }
    return await index_jira_issues(session=mocks["session"], **call_kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 2: Full pipeline wiring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_one_issue_calls_pipeline_and_returns_indexed_count(jira_mocks):
    """One issue produces one Jira document passed to the batch pipeline."""
    result = await _run_index(jira_mocks)
    indexed, skipped, warning = result
    assert indexed == 1
    assert skipped == 0
    assert warning is None

    batch_mock = jira_mocks["batch_mock"]
    batch_mock.assert_called_once()
    documents_sent = batch_mock.call_args[0][0]
    assert len(documents_sent) == 1
    first_document = documents_sent[0]
    assert first_document.document_type == DocumentType.JIRA_CONNECTOR
|
||||
|
||||
|
||||
async def test_pipeline_called_with_max_concurrency_3(jira_mocks):
    """The batch pipeline receives ``max_concurrency=3`` as a keyword argument."""
    await _run_index(jira_mocks)
    keyword_args = jira_mocks["batch_mock"].call_args[1]
    concurrency_limit = keyword_args.get("max_concurrency")
    assert concurrency_limit == 3
|
||||
|
||||
|
||||
async def test_migrate_legacy_docs_called_before_indexing(jira_mocks):
    """Running the indexer calls ``migrate_legacy_docs`` exactly once."""
    await _run_index(jira_mocks)
    migrate_mock = jira_mocks["pipeline_mock"].migrate_legacy_docs
    migrate_mock.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 3: Issue skipping (missing key/title/content)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_issues_with_missing_key_are_skipped(jira_mocks):
    """An issue whose ``key`` is empty is skipped and not sent to the pipeline."""
    valid_issue = _make_issue(issue_key="ENG-1", issue_id="10001")
    keyless_issue = {"key": "", "id": "10002", "title": "No key"}
    fetch_mock = jira_mocks["jira_client"].get_issues_by_date_range
    fetch_mock.return_value = ([valid_issue, keyless_issue], None)

    _, skipped, _ = await _run_index(jira_mocks)

    documents_sent = jira_mocks["batch_mock"].call_args[0][0]
    assert len(documents_sent) == 1
    assert skipped == 1
|
||||
|
||||
|
||||
async def test_issues_with_missing_title_are_skipped(jira_mocks):
    """An issue with an empty ``id`` is skipped and not sent to the pipeline."""
    valid_issue = _make_issue(issue_key="ENG-1", issue_id="10001")
    idless_issue = {"key": "ENG-2", "id": "", "title": "Missing id used as title"}
    fetch_mock = jira_mocks["jira_client"].get_issues_by_date_range
    fetch_mock.return_value = ([valid_issue, idless_issue], None)

    _, skipped, _ = await _run_index(jira_mocks)

    documents_sent = jira_mocks["batch_mock"].call_args[0][0]
    assert len(documents_sent) == 1
    assert skipped == 1
|
||||
|
||||
|
||||
async def test_issues_with_no_content_are_skipped(jira_mocks):
    """An issue whose rendered markdown is empty is skipped."""
    jira_client = jira_mocks["jira_client"]
    two_issues = [
        _make_issue(issue_key="ENG-1", issue_id="10001"),
        _make_issue(issue_key="ENG-2", issue_id="10002"),
    ]
    jira_client.get_issues_by_date_range.return_value = (two_issues, None)

    # First render has content; second render is the empty string.
    rendered_outputs = ["# ENG-1: 10001\n\nContent", ""]
    jira_client.format_issue_to_markdown.side_effect = rendered_outputs

    _, skipped, _ = await _run_index(jira_mocks)

    documents_sent = jira_mocks["batch_mock"].call_args[0][0]
    assert len(documents_sent) == 1
    assert skipped == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 4: Duplicate content skipping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_duplicate_content_issues_are_skipped(jira_mocks, monkeypatch):
    """An issue whose content hash matches an existing document is skipped.

    The fake duplicate check reports a match only on its second invocation.
    """
    two_issues = [
        _make_issue(issue_key="ENG-1", issue_id="10001"),
        _make_issue(issue_key="ENG-2", issue_id="10002"),
    ]
    fetch_mock = jira_mocks["jira_client"].get_issues_by_date_range
    fetch_mock.return_value = (two_issues, None)

    seen_hashes = []

    async def _check_dup(session, content_hash):
        seen_hashes.append(content_hash)
        if len(seen_hashes) != 2:
            return None
        existing = MagicMock()
        existing.id = 99
        existing.document_type = "OTHER"
        return existing

    monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)

    _, skipped, _ = await _run_index(jira_mocks)

    documents_sent = jira_mocks["batch_mock"].call_args[0][0]
    assert len(documents_sent) == 1
    assert skipped == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 5: Heartbeat callback forwarding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_heartbeat_callback_forwarded_to_pipeline(jira_mocks):
    """The caller's heartbeat callback reaches the pipeline as ``on_heartbeat``."""
    heartbeat_cb = AsyncMock()
    await _run_index(jira_mocks, on_heartbeat_callback=heartbeat_cb)
    keyword_args = jira_mocks["batch_mock"].call_args[1]
    forwarded = keyword_args.get("on_heartbeat")
    assert forwarded is heartbeat_cb
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 6: Empty issues and no-data success tuple
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_empty_issues_returns_zero_tuple(jira_mocks):
    """No issues and no error gives ``(0, 0, None)`` without calling the pipeline."""
    fetch_mock = jira_mocks["jira_client"].get_issues_by_date_range
    fetch_mock.return_value = ([], None)

    result = await _run_index(jira_mocks)
    indexed, skipped, warning = result

    assert indexed == 0
    assert skipped == 0
    assert warning is None
    jira_mocks["batch_mock"].assert_not_called()
|
||||
|
||||
|
||||
async def test_no_issues_error_message_returns_success_tuple(jira_mocks):
    """A "No issues found" error message is treated as success, not as a failure."""
    fetch_mock = jira_mocks["jira_client"].get_issues_by_date_range
    fetch_mock.return_value = ([], "No issues found in date range")

    result = await _run_index(jira_mocks)
    indexed, skipped, warning = result

    assert indexed == 0
    assert skipped == 0
    assert warning is None
|
||||
|
||||
|
||||
async def test_api_error_still_returns_3_tuple(jira_mocks):
    """A genuine API error still yields a 3-tuple whose last item reports the failure."""
    fetch_mock = jira_mocks["jira_client"].get_issues_by_date_range
    fetch_mock.return_value = ([], "API exploded")

    result = await _run_index(jira_mocks)

    assert len(result) == 3
    indexed, skipped, message = result[0], result[1], result[2]
    assert indexed == 0
    assert skipped == 0
    assert "Failed to get Jira issues" in message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 7: Failed docs warning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_failed_docs_warning_in_result(jira_mocks):
    """When the batch mock reports two failures, the warning text mentions "2 failed"."""
    # NOTE(review): the last element looks like the failed-document count — confirm against the pipeline.
    jira_mocks["batch_mock"].return_value = ([], 0, 2)

    *_, warning_msg = await _run_index(jira_mocks)

    assert warning_msg is not None
    assert "2 failed" in warning_msg
|
||||
|
|
@ -0,0 +1,355 @@
|
|||
"""Tests for Linear indexer migrated to the unified parallel pipeline."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import app.tasks.connector_indexers.linear_indexer as _mod
|
||||
from app.db import DocumentType
|
||||
from app.tasks.connector_indexers.linear_indexer import (
|
||||
_build_connector_doc,
|
||||
index_linear_issues,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
_USER_ID = "00000000-0000-0000-0000-000000000001"
|
||||
_CONNECTOR_ID = 42
|
||||
_SEARCH_SPACE_ID = 1
|
||||
|
||||
|
||||
def _make_issue(
|
||||
issue_id: str = "issue-1",
|
||||
identifier: str = "ENG-1",
|
||||
title: str = "Fix bug",
|
||||
):
|
||||
return {"id": issue_id, "identifier": identifier, "title": title}
|
||||
|
||||
|
||||
def _make_formatted_issue(
|
||||
issue_id: str = "issue-1",
|
||||
identifier: str = "ENG-1",
|
||||
title: str = "Fix bug",
|
||||
state: str = "In Progress",
|
||||
priority: str = "High",
|
||||
comments=None,
|
||||
):
|
||||
return {
|
||||
"id": issue_id,
|
||||
"identifier": identifier,
|
||||
"title": title,
|
||||
"state": state,
|
||||
"priority": priority,
|
||||
"description": "Some description",
|
||||
"comments": comments or [],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 1: _build_connector_doc tracer bullet
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_build_connector_doc_produces_correct_fields():
    """Tracer bullet: one Linear issue maps onto a ConnectorDocument with the expected fields."""
    raw_issue = _make_issue(issue_id="abc-123", identifier="ENG-42", title="Fix login bug")
    formatted_issue = _make_formatted_issue(
        issue_id="abc-123",
        identifier="ENG-42",
        title="Fix login bug",
        state="Done",
        priority="Urgent",
        comments=[{"id": "c1"}],
    )
    markdown = "# ENG-42: Fix login bug\n\nDescription here"

    doc = _build_connector_doc(
        raw_issue,
        formatted_issue,
        markdown,
        connector_id=_CONNECTOR_ID,
        search_space_id=_SEARCH_SPACE_ID,
        user_id=_USER_ID,
        enable_summary=True,
    )

    # Top-level document attributes.
    assert doc.title == "ENG-42: Fix login bug"
    assert doc.unique_id == "abc-123"
    assert doc.document_type == DocumentType.LINEAR_CONNECTOR
    assert doc.source_markdown == markdown
    assert doc.search_space_id == _SEARCH_SPACE_ID
    assert doc.connector_id == _CONNECTOR_ID
    assert doc.created_by_id == _USER_ID
    assert doc.should_summarize is True

    # Metadata entries, checked key by key.
    expected_metadata = {
        "issue_id": "abc-123",
        "issue_identifier": "ENG-42",
        "issue_title": "Fix login bug",
        "state": "Done",
        "priority": "Urgent",
        "comment_count": 1,
        "connector_id": _CONNECTOR_ID,
        "document_type": "Linear Issue",
        "connector_type": "Linear",
    }
    for key, expected_value in expected_metadata.items():
        assert doc.metadata[key] == expected_value

    # The fallback summary must mention the identifier and embed the markdown.
    assert doc.fallback_summary is not None
    assert "ENG-42" in doc.fallback_summary
    assert markdown in doc.fallback_summary
|
||||
|
||||
|
||||
async def test_build_connector_doc_summary_disabled():
    """With ``enable_summary=False`` the built document has ``should_summarize`` False."""
    build_kwargs = {
        "connector_id": _CONNECTOR_ID,
        "search_space_id": _SEARCH_SPACE_ID,
        "user_id": _USER_ID,
        "enable_summary": False,
    }

    doc = _build_connector_doc(
        _make_issue(),
        _make_formatted_issue(),
        "# content",
        **build_kwargs,
    )

    assert doc.should_summarize is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared fixtures for Slices 2-6
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _mock_connector(enable_summary: bool = True):
|
||||
c = MagicMock()
|
||||
c.config = {"access_token": "tok"}
|
||||
c.enable_summary = enable_summary
|
||||
c.last_indexed_at = None
|
||||
return c
|
||||
|
||||
|
||||
def _mock_linear_client(issues=None, error=None):
|
||||
client = MagicMock()
|
||||
client.get_issues_by_date_range = AsyncMock(
|
||||
return_value=(issues if issues is not None else [], error),
|
||||
)
|
||||
client.format_issue = MagicMock(side_effect=lambda i: _make_formatted_issue(
|
||||
issue_id=i.get("id", ""),
|
||||
identifier=i.get("identifier", ""),
|
||||
title=i.get("title", ""),
|
||||
))
|
||||
client.format_issue_to_markdown = MagicMock(
|
||||
side_effect=lambda fi: f"# {fi.get('identifier', '')}: {fi.get('title', '')}\n\nContent"
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
def linear_mocks(monkeypatch):
    """Patch the names ``index_linear_issues`` looks up on its module and expose the mocks.

    Returns a dict with the session, connector, Linear client, task logger,
    pipeline mock and the ``index_batch_parallel`` mock.
    """
    session = AsyncMock()
    session.no_autoflush = MagicMock()

    connector = _mock_connector()
    client = _mock_linear_client(issues=[_make_issue()])

    task_logger = MagicMock()
    task_logger.log_task_start = AsyncMock(return_value=MagicMock())
    for method_name in ("log_task_progress", "log_task_success", "log_task_failure"):
        setattr(task_logger, method_name, AsyncMock())

    # NOTE(review): the tuple looks like (documents, indexed count, failed count) — confirm.
    batch = AsyncMock(return_value=([], 1, 0))
    pipeline = MagicMock()
    pipeline.index_batch_parallel = batch
    pipeline.migrate_legacy_docs = AsyncMock()

    # Every entry replaces a module-level name on the indexer module under test.
    replacements = {
        "get_connector_by_id": AsyncMock(return_value=connector),
        "LinearConnector": MagicMock(return_value=client),
        "check_duplicate_document_by_hash": AsyncMock(return_value=None),
        "update_connector_last_indexed": AsyncMock(),
        "calculate_date_range": MagicMock(return_value=("2025-01-01", "2025-12-31")),
        "TaskLoggingService": MagicMock(return_value=task_logger),
        "IndexingPipelineService": MagicMock(return_value=pipeline),
    }
    for attr_name, replacement in replacements.items():
        monkeypatch.setattr(_mod, attr_name, replacement)

    return {
        "session": session,
        "connector": connector,
        "linear_client": client,
        "task_logger": task_logger,
        "pipeline_mock": pipeline,
        "batch_mock": batch,
    }
|
||||
|
||||
|
||||
async def _run_index(mocks, **overrides):
    """Call ``index_linear_issues`` with the mocked session and default arguments.

    Only the keyword names listed below can be overridden; any other key in
    ``overrides`` is ignored.
    """
    defaults = {
        "connector_id": _CONNECTOR_ID,
        "search_space_id": _SEARCH_SPACE_ID,
        "user_id": _USER_ID,
        "start_date": "2025-01-01",
        "end_date": "2025-12-31",
        "update_last_indexed": True,
        "on_heartbeat_callback": None,
    }
    call_kwargs = {name: overrides.get(name, default) for name, default in defaults.items()}
    return await index_linear_issues(session=mocks["session"], **call_kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 2: Full pipeline wiring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_one_issue_calls_pipeline_and_returns_indexed_count(linear_mocks):
    """A single valid issue reaches the batch mock and its indexed count is returned."""
    indexed_count, skipped_count, warning_msg = await _run_index(linear_mocks)

    assert indexed_count == 1
    assert skipped_count == 0
    assert warning_msg is None

    batch = linear_mocks["batch_mock"]
    batch.assert_called_once()

    # The first positional argument is the list of connector documents.
    docs_sent = batch.call_args.args[0]
    assert len(docs_sent) == 1
    assert docs_sent[0].document_type == DocumentType.LINEAR_CONNECTOR
|
||||
|
||||
|
||||
async def test_pipeline_called_with_max_concurrency_3(linear_mocks):
    """The batch mock receives ``max_concurrency=3`` as a keyword argument."""
    await _run_index(linear_mocks)

    batch_kwargs = linear_mocks["batch_mock"].call_args.kwargs
    assert batch_kwargs.get("max_concurrency") == 3
|
||||
|
||||
|
||||
async def test_migrate_legacy_docs_called_before_indexing(linear_mocks):
    """migrate_legacy_docs is called once on the pipeline, before index_batch_parallel."""
    await _run_index(linear_mocks)

    pipeline_mock = linear_mocks["pipeline_mock"]
    pipeline_mock.migrate_legacy_docs.assert_called_once()
    linear_mocks["batch_mock"].assert_called_once()

    # Both mocks were assigned as attributes of pipeline_mock, which attaches
    # them as children; the parent's mock_calls therefore records their calls
    # in the order they happened, letting us check the ordering itself.
    call_names = [name for name, _args, _kwargs in pipeline_mock.mock_calls]
    assert call_names.index("migrate_legacy_docs") < call_names.index(
        "index_batch_parallel"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 3: Issue skipping (missing ID / title)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_issues_with_missing_id_are_skipped(linear_mocks):
    """An issue with an empty ``id`` is counted as skipped and never reaches the batch mock."""
    valid_issue = _make_issue(issue_id="valid-1", identifier="ENG-1", title="Valid")
    issue_without_id = {"id": "", "identifier": "ENG-2", "title": "No ID"}
    linear_mocks["linear_client"].get_issues_by_date_range.return_value = (
        [valid_issue, issue_without_id],
        None,
    )

    _, skipped_count, _ = await _run_index(linear_mocks)

    docs_sent = linear_mocks["batch_mock"].call_args.args[0]
    assert len(docs_sent) == 1
    assert docs_sent[0].unique_id == "valid-1"
    assert skipped_count == 1
|
||||
|
||||
|
||||
async def test_issues_with_missing_title_are_skipped(linear_mocks):
    """An issue with an empty ``title`` is counted as skipped."""
    valid_issue = _make_issue(issue_id="valid-1", identifier="ENG-1", title="Valid")
    issue_without_title = {"id": "id-2", "identifier": "ENG-2", "title": ""}
    linear_mocks["linear_client"].get_issues_by_date_range.return_value = (
        [valid_issue, issue_without_title],
        None,
    )

    _, skipped_count, _ = await _run_index(linear_mocks)

    docs_sent = linear_mocks["batch_mock"].call_args.args[0]
    assert len(docs_sent) == 1
    assert skipped_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 4: Duplicate content skipping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_duplicate_content_issues_are_skipped(linear_mocks, monkeypatch):
    """An issue whose duplicate lookup returns an existing document is skipped."""
    linear_mocks["linear_client"].get_issues_by_date_range.return_value = (
        [
            _make_issue(issue_id="new-1", identifier="ENG-1", title="New"),
            _make_issue(issue_id="dup-1", identifier="ENG-2", title="Dup"),
        ],
        None,
    )

    lookups = {"count": 0}

    async def _check_dup(session, content_hash):
        # Report a pre-existing document only on the second lookup.
        lookups["count"] += 1
        if lookups["count"] != 2:
            return None
        existing = MagicMock()
        existing.id = 99
        existing.document_type = "OTHER"
        return existing

    monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)

    _, skipped_count, _ = await _run_index(linear_mocks)

    docs_sent = linear_mocks["batch_mock"].call_args.args[0]
    assert len(docs_sent) == 1
    assert skipped_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 5: Heartbeat callback forwarding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_heartbeat_callback_forwarded_to_pipeline(linear_mocks):
    """``on_heartbeat_callback`` reaches index_batch_parallel as the ``on_heartbeat`` keyword."""
    on_heartbeat = AsyncMock()

    await _run_index(linear_mocks, on_heartbeat_callback=on_heartbeat)

    # Identity check: the exact callback object must be passed through.
    forwarded_kwargs = linear_mocks["batch_mock"].call_args.kwargs
    assert forwarded_kwargs.get("on_heartbeat") is on_heartbeat
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 6: Empty issues early return and failed-docs warning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_empty_issues_returns_zero_tuple(linear_mocks):
    """With no issues the result is (0, 0, None) and the batch mock is never called."""
    linear_client = linear_mocks["linear_client"]
    linear_client.get_issues_by_date_range.return_value = ([], None)

    indexed_count, skipped_count, warning_msg = await _run_index(linear_mocks)

    assert indexed_count == 0
    assert skipped_count == 0
    assert warning_msg is None
    linear_mocks["batch_mock"].assert_not_called()
|
||||
|
||||
|
||||
async def test_failed_docs_warning_in_result(linear_mocks):
    """When the batch mock reports two failures, the warning text mentions "2 failed"."""
    # NOTE(review): the last element looks like the failed-document count — confirm against the pipeline.
    linear_mocks["batch_mock"].return_value = ([], 0, 2)

    *_, warning_msg = await _run_index(linear_mocks)

    assert warning_msg is not None
    assert "2 failed" in warning_msg
|
||||
|
|
@ -0,0 +1,345 @@
|
|||
"""Tests for Notion indexer migrated to the unified parallel pipeline."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import app.tasks.connector_indexers.notion_indexer as _mod
|
||||
from app.db import DocumentType
|
||||
from app.tasks.connector_indexers.notion_indexer import (
|
||||
_build_connector_doc,
|
||||
index_notion_pages,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
_USER_ID = "00000000-0000-0000-0000-000000000001"
|
||||
_CONNECTOR_ID = 42
|
||||
_SEARCH_SPACE_ID = 1
|
||||
|
||||
|
||||
def _make_page(page_id: str = "page-1", title: str = "Test Page", content=None):
|
||||
if content is None:
|
||||
content = [{"type": "paragraph", "content": "Hello world", "children": []}]
|
||||
return {"page_id": page_id, "title": title, "content": content}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 1: _build_connector_doc tracer bullet
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_build_connector_doc_produces_correct_fields():
    """Tracer bullet: one Notion page maps onto a ConnectorDocument with the expected fields."""
    notion_page = _make_page(page_id="abc-123", title="My Notion Page")
    markdown = "# My Notion Page\n\nHello world"

    doc = _build_connector_doc(
        notion_page,
        markdown,
        connector_id=_CONNECTOR_ID,
        search_space_id=_SEARCH_SPACE_ID,
        user_id=_USER_ID,
        enable_summary=True,
    )

    # Top-level document attributes.
    assert doc.title == "My Notion Page"
    assert doc.unique_id == "abc-123"
    assert doc.document_type == DocumentType.NOTION_CONNECTOR
    assert doc.source_markdown == markdown
    assert doc.search_space_id == _SEARCH_SPACE_ID
    assert doc.connector_id == _CONNECTOR_ID
    assert doc.created_by_id == _USER_ID
    assert doc.should_summarize is True

    # Metadata entries, checked key by key.
    expected_metadata = {
        "page_title": "My Notion Page",
        "page_id": "abc-123",
        "connector_id": _CONNECTOR_ID,
        "document_type": "Notion Page",
        "connector_type": "Notion",
    }
    for key, expected_value in expected_metadata.items():
        assert doc.metadata[key] == expected_value

    # The fallback summary must mention the title and embed the markdown.
    assert doc.fallback_summary is not None
    assert "My Notion Page" in doc.fallback_summary
    assert markdown in doc.fallback_summary
|
||||
|
||||
|
||||
async def test_build_connector_doc_summary_disabled():
    """With ``enable_summary=False`` the built document has ``should_summarize`` False."""
    build_kwargs = {
        "connector_id": _CONNECTOR_ID,
        "search_space_id": _SEARCH_SPACE_ID,
        "user_id": _USER_ID,
        "enable_summary": False,
    }

    doc = _build_connector_doc(_make_page(), "# content", **build_kwargs)

    assert doc.should_summarize is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared fixtures for Slices 2-7 (full index_notion_pages tests)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _mock_connector(enable_summary: bool = True):
|
||||
c = MagicMock()
|
||||
c.config = {"access_token": "tok"}
|
||||
c.enable_summary = enable_summary
|
||||
c.last_indexed_at = None
|
||||
return c
|
||||
|
||||
|
||||
def _mock_notion_client(pages=None, skipped_count=0, legacy_token=False):
|
||||
client = MagicMock()
|
||||
client.get_all_pages = AsyncMock(return_value=pages if pages is not None else [])
|
||||
client.get_skipped_content_count = MagicMock(return_value=skipped_count)
|
||||
client.is_using_legacy_token = MagicMock(return_value=legacy_token)
|
||||
client.close = AsyncMock()
|
||||
client.set_retry_callback = MagicMock()
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
def notion_mocks(monkeypatch):
    """Patch the names ``index_notion_pages`` looks up on its module and expose the mocks.

    Returns a dict with the session, connector, Notion client, task logger,
    pipeline mock and the ``index_batch_parallel`` mock.
    """
    session = AsyncMock()
    session.no_autoflush = MagicMock()

    connector = _mock_connector()
    client = _mock_notion_client(pages=[_make_page()])

    task_logger = MagicMock()
    task_logger.log_task_start = AsyncMock(return_value=MagicMock())
    for method_name in ("log_task_progress", "log_task_success", "log_task_failure"):
        setattr(task_logger, method_name, AsyncMock())

    # NOTE(review): the tuple looks like (documents, indexed count, failed count) — confirm.
    batch = AsyncMock(return_value=([], 1, 0))
    pipeline = MagicMock()
    pipeline.index_batch_parallel = batch
    pipeline.migrate_legacy_docs = AsyncMock()

    # Every entry replaces a module-level name on the indexer module under test.
    replacements = {
        "get_connector_by_id": AsyncMock(return_value=connector),
        "NotionHistoryConnector": MagicMock(return_value=client),
        "check_duplicate_document_by_hash": AsyncMock(return_value=None),
        "update_connector_last_indexed": AsyncMock(),
        "calculate_date_range": MagicMock(return_value=("2025-01-01", "2025-12-31")),
        "process_blocks": MagicMock(return_value="Converted markdown content"),
        "TaskLoggingService": MagicMock(return_value=task_logger),
        "IndexingPipelineService": MagicMock(return_value=pipeline),
    }
    for attr_name, replacement in replacements.items():
        monkeypatch.setattr(_mod, attr_name, replacement)

    return {
        "session": session,
        "connector": connector,
        "notion_client": client,
        "task_logger": task_logger,
        "pipeline_mock": pipeline,
        "batch_mock": batch,
    }
|
||||
|
||||
|
||||
async def _run_index(mocks, **overrides):
    """Call ``index_notion_pages`` with the mocked session and default arguments.

    Only the keyword names listed below can be overridden; any other key in
    ``overrides`` is ignored.
    """
    defaults = {
        "connector_id": _CONNECTOR_ID,
        "search_space_id": _SEARCH_SPACE_ID,
        "user_id": _USER_ID,
        "start_date": "2025-01-01",
        "end_date": "2025-12-31",
        "update_last_indexed": True,
        "on_retry_callback": None,
        "on_heartbeat_callback": None,
    }
    call_kwargs = {name: overrides.get(name, default) for name, default in defaults.items()}
    return await index_notion_pages(session=mocks["session"], **call_kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 2: Full pipeline wiring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_one_page_calls_pipeline_and_returns_indexed_count(notion_mocks):
    """A single valid page reaches the batch mock and its indexed count is returned."""
    indexed_count, skipped_count, warning_msg = await _run_index(notion_mocks)

    assert indexed_count == 1
    assert skipped_count == 0
    assert warning_msg is None

    batch = notion_mocks["batch_mock"]
    batch.assert_called_once()

    # The first positional argument is the list of connector documents.
    docs_sent = batch.call_args.args[0]
    assert len(docs_sent) == 1
    assert docs_sent[0].document_type == DocumentType.NOTION_CONNECTOR
|
||||
|
||||
|
||||
async def test_pipeline_called_with_max_concurrency_3(notion_mocks):
    """The batch mock receives ``max_concurrency=3`` as a keyword argument."""
    await _run_index(notion_mocks)

    batch_kwargs = notion_mocks["batch_mock"].call_args.kwargs
    assert batch_kwargs.get("max_concurrency") == 3
|
||||
|
||||
|
||||
async def test_migrate_legacy_docs_called_before_indexing(notion_mocks):
    """migrate_legacy_docs is called once on the pipeline, before index_batch_parallel."""
    await _run_index(notion_mocks)

    pipeline_mock = notion_mocks["pipeline_mock"]
    pipeline_mock.migrate_legacy_docs.assert_called_once()
    notion_mocks["batch_mock"].assert_called_once()

    # Both mocks were assigned as attributes of pipeline_mock, which attaches
    # them as children; the parent's mock_calls therefore records their calls
    # in the order they happened, letting us check the ordering itself.
    call_names = [name for name, _args, _kwargs in pipeline_mock.mock_calls]
    assert call_names.index("migrate_legacy_docs") < call_names.index(
        "index_batch_parallel"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 3: Page skipping (no content / missing ID)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_pages_with_missing_id_are_skipped(notion_mocks, monkeypatch):
    """A page dict without ``page_id`` is counted as skipped and never reaches the batch mock."""
    page_without_id = {
        "title": "No ID page",
        "content": [{"type": "paragraph", "content": "text", "children": []}],
    }
    notion_mocks["notion_client"].get_all_pages.return_value = [
        _make_page(page_id="valid-1"),
        page_without_id,
    ]

    _, skipped_count, _ = await _run_index(notion_mocks)

    docs_sent = notion_mocks["batch_mock"].call_args.args[0]
    assert len(docs_sent) == 1
    assert docs_sent[0].unique_id == "valid-1"
    assert skipped_count == 1
|
||||
|
||||
|
||||
async def test_pages_with_no_content_are_skipped(notion_mocks, monkeypatch):
    """A page whose ``content`` list is empty is counted as skipped."""
    notion_mocks["notion_client"].get_all_pages.return_value = [
        _make_page(page_id="valid-1"),
        _make_page(page_id="empty-1", content=[]),
    ]

    _, skipped_count, _ = await _run_index(notion_mocks)

    docs_sent = notion_mocks["batch_mock"].call_args.args[0]
    assert len(docs_sent) == 1
    assert skipped_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 4: Duplicate content skipping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_duplicate_content_pages_are_skipped(notion_mocks, monkeypatch):
    """A page whose duplicate lookup returns an existing document is skipped."""
    notion_mocks["notion_client"].get_all_pages.return_value = [
        _make_page(page_id="new-1"),
        _make_page(page_id="dup-1"),
    ]

    lookups = {"count": 0}

    async def _check_dup(session, content_hash):
        # Report a pre-existing document only on the second lookup.
        lookups["count"] += 1
        if lookups["count"] != 2:
            return None
        existing = MagicMock()
        existing.id = 99
        existing.document_type = "OTHER"
        return existing

    monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)

    _, skipped_count, _ = await _run_index(notion_mocks)

    docs_sent = notion_mocks["batch_mock"].call_args.args[0]
    assert len(docs_sent) == 1
    assert skipped_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 5: Heartbeat callback forwarding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_heartbeat_callback_forwarded_to_pipeline(notion_mocks):
    """``on_heartbeat_callback`` reaches index_batch_parallel as the ``on_heartbeat`` keyword."""
    on_heartbeat = AsyncMock()

    await _run_index(notion_mocks, on_heartbeat_callback=on_heartbeat)

    # Identity check: the exact callback object must be passed through.
    forwarded_kwargs = notion_mocks["batch_mock"].call_args.kwargs
    assert forwarded_kwargs.get("on_heartbeat") is on_heartbeat
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 6: Notion-specific warning messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_skipped_ai_content_warning_in_result(notion_mocks):
    """A non-zero skipped-content count yields a warning mentioning "API limitation"."""
    notion_client = notion_mocks["notion_client"]
    notion_client.get_skipped_content_count.return_value = 3

    *_, warning_msg = await _run_index(notion_mocks)

    assert warning_msg is not None
    assert "API limitation" in warning_msg
|
||||
|
||||
|
||||
async def test_legacy_token_warning_in_result(notion_mocks):
    """When the client reports a legacy token, the warning mentions it (case-insensitive)."""
    notion_client = notion_mocks["notion_client"]
    notion_client.is_using_legacy_token.return_value = True

    *_, warning_msg = await _run_index(notion_mocks)

    assert warning_msg is not None
    assert "legacy token" in warning_msg.lower()
|
||||
|
||||
|
||||
async def test_failed_docs_warning_in_result(notion_mocks):
    """When the batch mock reports two failures, the warning text mentions "2 failed"."""
    # NOTE(review): the last element looks like the failed-document count — confirm against the pipeline.
    notion_mocks["batch_mock"].return_value = ([], 0, 2)

    *_, warning_msg = await _run_index(notion_mocks)

    assert warning_msg is not None
    assert "2 failed" in warning_msg
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 7: Empty pages early return
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_empty_pages_returns_zero_tuple(notion_mocks):
    """With no pages the result is (0, 0, None) and the batch mock is never called."""
    notion_client = notion_mocks["notion_client"]
    notion_client.get_all_pages.return_value = []

    indexed_count, skipped_count, warning_msg = await _run_index(notion_mocks)

    assert indexed_count == 0
    assert skipped_count == 0
    assert warning_msg is None
    notion_mocks["batch_mock"].assert_not_called()
|
||||
|
|
@ -3,6 +3,7 @@ import pytest
|
|||
from app.db import DocumentType
|
||||
from app.indexing_pipeline.document_hashing import (
|
||||
compute_content_hash,
|
||||
compute_identifier_hash,
|
||||
compute_unique_identifier_hash,
|
||||
)
|
||||
|
||||
|
|
@ -61,3 +62,23 @@ def test_different_content_produces_different_content_hash(make_connector_docume
|
|||
doc_a = make_connector_document(source_markdown="Original content")
|
||||
doc_b = make_connector_document(source_markdown="Updated content")
|
||||
assert compute_content_hash(doc_a) != compute_content_hash(doc_b)
|
||||
|
||||
|
||||
def test_compute_identifier_hash_matches_connector_doc_hash(make_connector_document):
    """Hashing raw arguments gives the same value as hashing the equivalent ConnectorDocument."""
    connector_doc = make_connector_document(
        document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
        unique_id="msg-123",
        search_space_id=5,
    )

    hash_from_raw_args = compute_identifier_hash("GOOGLE_GMAIL_CONNECTOR", "msg-123", 5)
    hash_from_document = compute_unique_identifier_hash(connector_doc)

    assert hash_from_raw_args == hash_from_document
|
||||
|
||||
|
||||
def test_compute_identifier_hash_differs_for_different_inputs():
    """Changing any one of the three arguments changes the resulting hash."""
    # Each variant differs from the first in exactly one position.
    variants = [
        ("GOOGLE_DRIVE_FILE", "file-1", 1),
        ("GOOGLE_DRIVE_FILE", "file-2", 1),
        ("GOOGLE_DRIVE_FILE", "file-1", 2),
        ("COMPOSIO_GOOGLE_DRIVE_CONNECTOR", "file-1", 1),
    ]

    distinct_hashes = {compute_identifier_hash(*variant) for variant in variants}

    assert len(distinct_hashes) == 4
|
||||
|
|
|
|||
|
|
@ -0,0 +1,82 @@
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.db import Document, DocumentType
|
||||
from app.indexing_pipeline.document_hashing import compute_unique_identifier_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture
def mock_session():
    """Provide a bare AsyncMock to act as the database session."""
    session = AsyncMock()
    return session
|
||||
|
||||
|
||||
@pytest.fixture
def pipeline(mock_session):
    """Provide an IndexingPipelineService constructed with the mocked session."""
    service = IndexingPipelineService(mock_session)
    return service
|
||||
|
||||
|
||||
async def test_calls_prepare_then_index_per_document(
    pipeline, make_connector_document
):
    """index_batch awaits prepare_for_indexing once and index() once per prepared document."""
    connector_docs = [
        make_connector_document(
            document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
            unique_id=message_id,
            search_space_id=1,
        )
        for message_id in ("msg-1", "msg-2")
    ]

    # One ORM stand-in per connector document, linked through the identifier hash.
    orm_docs = []
    for connector_doc in connector_docs:
        orm_doc = MagicMock(spec=Document)
        orm_doc.unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
        orm_docs.append(orm_doc)

    mock_llm = MagicMock()

    def _echo_document(doc, cdoc, llm):
        # index() is stubbed to hand back the ORM document it was given.
        return doc

    pipeline.prepare_for_indexing = AsyncMock(return_value=orm_docs)
    pipeline.index = AsyncMock(side_effect=_echo_document)

    results = await pipeline.index_batch(connector_docs, mock_llm)

    pipeline.prepare_for_indexing.assert_awaited_once_with(connector_docs)
    assert pipeline.index.await_count == 2
    assert results == orm_docs
|
||||
|
||||
|
||||
async def test_empty_input_returns_empty(pipeline):
    """An empty connector_docs list yields an empty result list."""
    pipeline.prepare_for_indexing = AsyncMock(return_value=[])
    placeholder_llm = MagicMock()

    indexed_docs = await pipeline.index_batch([], placeholder_llm)

    assert indexed_docs == []
|
||||
|
||||
|
||||
async def test_skips_document_without_matching_connector_doc(
    pipeline, make_connector_document
):
    """A prepared document whose hash matches no ConnectorDocument is not indexed."""
    connector_doc = make_connector_document(
        document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
        unique_id="msg-1",
        search_space_id=1,
    )

    # This hash deliberately corresponds to none of the supplied connector documents.
    unmatched_orm = MagicMock(spec=Document)
    unmatched_orm.unique_identifier_hash = "nonexistent-hash"

    pipeline.prepare_for_indexing = AsyncMock(return_value=[unmatched_orm])
    pipeline.index = AsyncMock()

    indexed_docs = await pipeline.index_batch([connector_doc], MagicMock())

    pipeline.index.assert_not_awaited()
    assert indexed_docs == []
|
||||
|
|
@ -0,0 +1,186 @@
|
|||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.indexing_pipeline.document_hashing import compute_unique_identifier_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture
def mock_session():
    """Provide an AsyncMock session whose ``refresh`` is explicitly an AsyncMock."""
    db_session = AsyncMock()
    db_session.refresh = AsyncMock()
    return db_session
|
||||
|
||||
|
||||
@pytest.fixture
def pipeline(mock_session):
    """Build the IndexingPipelineService under test on top of the mocked session."""
    service = IndexingPipelineService(mock_session)
    return service
|
||||
|
||||
|
||||
def _make_orm_doc(connector_doc, doc_id):
    """Return a pending, Document-spec'd mock keyed to *connector_doc*'s identifier hash.

    Args:
        connector_doc: Connector document whose unique identifier hash the mock carries.
        doc_id: Value assigned to the mock's ``id`` attribute.
    """
    orm_doc = MagicMock(spec=Document)
    orm_doc.id = doc_id
    identifier_hash = compute_unique_identifier_hash(connector_doc)
    orm_doc.unique_identifier_hash = identifier_hash
    orm_doc.status = DocumentStatus.pending()
    return orm_doc
|
||||
|
||||
|
||||
async def test_index_calls_embed_and_chunk_via_to_thread(
    pipeline, make_connector_document, monkeypatch
):
    """index() dispatches both chunk_text and embed_texts through asyncio.to_thread."""
    offloaded_names = []
    real_to_thread = asyncio.to_thread

    async def recording_to_thread(func, *args, **kwargs):
        # Note which callable was offloaded, then delegate to the real helper.
        offloaded_names.append(func.__name__)
        return await real_to_thread(func, *args, **kwargs)

    def fake_embed(texts):
        # One constant vector of the configured dimension per input text.
        return [[0.1] * _EMBEDDING_DIM for _ in texts]

    # MagicMock has no __name__ by default; the recorder above reads it.
    chunk_stub = MagicMock(return_value=["chunk1"])
    chunk_stub.__name__ = "chunk_text"
    embed_stub = MagicMock(side_effect=fake_embed)
    embed_stub.__name__ = "embed_texts"

    monkeypatch.setattr(asyncio, "to_thread", recording_to_thread)
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.summarize_document",
        AsyncMock(return_value="Summary."),
    )
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.chunk_text",
        chunk_stub,
    )
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.embed_texts",
        embed_stub,
    )

    source_doc = make_connector_document(
        document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
        unique_id="msg-1",
        search_space_id=1,
    )
    orm_doc = MagicMock(spec=Document)
    orm_doc.id = 1
    orm_doc.status = DocumentStatus.pending()

    await pipeline.index(orm_doc, source_doc, llm=MagicMock())

    assert "chunk_text" in offloaded_names
    assert "embed_texts" in offloaded_names
|
||||
|
||||
|
||||
def _mock_session_factory(orm_docs_by_id):
|
||||
"""Replace get_celery_session_maker with a two-level callable.
|
||||
|
||||
get_celery_session_maker() -> session_maker
|
||||
session_maker() -> async context manager yielding a mock session
|
||||
"""
|
||||
|
||||
def _get_maker():
|
||||
def _make_session():
|
||||
session = MagicMock()
|
||||
session.get = AsyncMock(
|
||||
side_effect=lambda model, doc_id: orm_docs_by_id.get(doc_id)
|
||||
)
|
||||
ctx = MagicMock()
|
||||
ctx.__aenter__ = AsyncMock(return_value=session)
|
||||
ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
return ctx
|
||||
|
||||
return _make_session
|
||||
|
||||
return _get_maker
|
||||
|
||||
|
||||
async def test_batch_parallel_indexes_all_documents(
    pipeline, make_connector_document, monkeypatch
):
    """index_batch_parallel indexes all three documents and reports no failures."""
    connector_docs = [
        make_connector_document(
            document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
            unique_id=f"msg-{i}",
            search_space_id=1,
        )
        for i in range(3)
    ]

    # ORM ids are 1-based so the expected call list below reads [1, 2, 3].
    prepared = [
        _make_orm_doc(cd, doc_id=position)
        for position, cd in enumerate(connector_docs, start=1)
    ]
    pipeline.prepare_for_indexing = AsyncMock(return_value=prepared)

    monkeypatch.setattr(
        "app.tasks.celery_tasks.get_celery_session_maker",
        _mock_session_factory({orm.id: orm for orm in prepared}),
    )

    seen_ids = []

    async def fake_index(self, document, connector_doc, llm):
        # Record the visit and mark the document as ready.
        seen_ids.append(document.id)
        document.status = DocumentStatus.ready()
        return document

    monkeypatch.setattr(IndexingPipelineService, "index", fake_index)

    async def mock_get_llm(session):
        return MagicMock()

    _, indexed, failed = await pipeline.index_batch_parallel(
        connector_docs, mock_get_llm, max_concurrency=2
    )

    assert (indexed, failed) == (3, 0)
    # Sorted before comparing, so the order of the index calls is not asserted.
    assert sorted(seen_ids) == [1, 2, 3]
|
||||
|
||||
|
||||
async def test_batch_parallel_one_failure_does_not_affect_others(
    pipeline, make_connector_document, monkeypatch
):
    """A document whose index() raises is counted as failed; the other two are indexed."""
    connector_docs = [
        make_connector_document(
            document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
            unique_id=f"msg-{i}",
            search_space_id=1,
        )
        for i in range(3)
    ]

    prepared = [
        _make_orm_doc(cd, doc_id=position)
        for position, cd in enumerate(connector_docs, start=1)
    ]
    pipeline.prepare_for_indexing = AsyncMock(return_value=prepared)

    monkeypatch.setattr(
        "app.tasks.celery_tasks.get_celery_session_maker",
        _mock_session_factory({orm.id: orm for orm in prepared}),
    )

    async def failing_index(self, document, connector_doc, llm):
        # Only the document with id 2 raises; every other one is marked ready.
        if document.id != 2:
            document.status = DocumentStatus.ready()
            return document
        raise RuntimeError("LLM exploded")

    monkeypatch.setattr(IndexingPipelineService, "index", failing_index)

    async def mock_get_llm(session):
        return MagicMock()

    _, indexed, failed = await pipeline.index_batch_parallel(
        connector_docs, mock_get_llm, max_concurrency=4
    )

    assert (indexed, failed) == (2, 1)
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.db import Document, DocumentType
|
||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture
def mock_session():
    """Provide a bare AsyncMock standing in for the database session."""
    return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
def pipeline(mock_session):
    """Build the IndexingPipelineService under test on top of the mocked session."""
    service = IndexingPipelineService(mock_session)
    return service
|
||||
|
||||
|
||||
def _make_execute_side_effect(doc_by_hash: dict):
|
||||
"""Return a side_effect for session.execute that resolves documents by hash."""
|
||||
|
||||
async def _side_effect(stmt):
|
||||
result = MagicMock()
|
||||
for h, doc in doc_by_hash.items():
|
||||
if h in str(stmt.compile(compile_kwargs={"literal_binds": True})):
|
||||
result.scalars.return_value.first.return_value = doc
|
||||
return result
|
||||
result.scalars.return_value.first.return_value = None
|
||||
return result
|
||||
|
||||
return _side_effect
|
||||
|
||||
|
||||
async def test_updates_hash_and_type_for_legacy_document(
    pipeline, mock_session, make_connector_document
):
    """A legacy Composio Gmail row is rewritten to the native hash and document type."""
    connector_doc = make_connector_document(
        document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
        unique_id="msg-abc",
        search_space_id=1,
    )

    hash_before = compute_identifier_hash("COMPOSIO_GMAIL_CONNECTOR", "msg-abc", 1)
    hash_after = compute_identifier_hash("GOOGLE_GMAIL_CONNECTOR", "msg-abc", 1)

    legacy_row = MagicMock(spec=Document)
    legacy_row.document_type = DocumentType.COMPOSIO_GMAIL_CONNECTOR
    legacy_row.unique_identifier_hash = hash_before

    # Every execute() call resolves to the legacy row.
    query_result = MagicMock()
    query_result.scalars.return_value.first.return_value = legacy_row
    mock_session.execute = AsyncMock(return_value=query_result)

    await pipeline.migrate_legacy_docs([connector_doc])

    assert legacy_row.unique_identifier_hash == hash_after
    assert legacy_row.document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR
    mock_session.commit.assert_awaited_once()
|
||||
|
||||
|
||||
async def test_noop_when_no_legacy_document_exists(
    pipeline, mock_session, make_connector_document
):
    """When the lookup finds no legacy Composio row, the migration commits exactly once."""
    connector_doc = make_connector_document(
        document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
        unique_id="msg-xyz",
        search_space_id=1,
    )

    # The query result reports that no row exists.
    empty_result = MagicMock()
    empty_result.scalars.return_value.first.return_value = None
    mock_session.execute = AsyncMock(return_value=empty_result)

    await pipeline.migrate_legacy_docs([connector_doc])

    mock_session.commit.assert_awaited_once()
|
||||
|
||||
|
||||
async def test_skips_non_google_doc_types(
    pipeline, mock_session, make_connector_document
):
    """A Slack document triggers no execute() call, yet the migration commits once."""
    slack_doc = make_connector_document(
        document_type=DocumentType.SLACK_CONNECTOR,
        unique_id="slack-123",
        search_space_id=1,
    )

    await pipeline.migrate_legacy_docs([slack_doc])

    mock_session.execute.assert_not_awaited()
    mock_session.commit.assert_awaited_once()
|
||||
|
||||
|
||||
async def test_handles_all_three_google_types(
    pipeline, mock_session, make_connector_document
):
    """Rows stored under each Composio legacy type end up with the native Google type."""
    native_to_legacy = (
        (DocumentType.GOOGLE_GMAIL_CONNECTOR, "COMPOSIO_GMAIL_CONNECTOR"),
        (DocumentType.GOOGLE_CALENDAR_CONNECTOR, "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"),
        (DocumentType.GOOGLE_DRIVE_FILE, "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"),
    )

    for native_type, expected_legacy in native_to_legacy:
        connector_doc = make_connector_document(
            document_type=native_type,
            unique_id="id-1",
            search_space_id=1,
        )

        # The stored row starts out with the legacy Composio type.
        legacy_row = MagicMock(spec=Document)
        legacy_row.document_type = DocumentType(expected_legacy)

        query_result = MagicMock()
        query_result.scalars.return_value.first.return_value = legacy_row

        # Fresh mocks on each pass so one mapping's calls do not carry over.
        mock_session.execute = AsyncMock(return_value=query_result)
        mock_session.commit = AsyncMock()

        await pipeline.migrate_legacy_docs([connector_doc])

        assert legacy_row.document_type == native_type
|
||||
8862
surfsense_backend/uv.lock
generated
8862
surfsense_backend/uv.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,8 +1,14 @@
|
|||
import { loader } from "fumadocs-core/source";
|
||||
import type { Metadata } from "next";
|
||||
import { changelog } from "@/.source/server";
|
||||
import { formatDate } from "@/lib/utils";
|
||||
import { getMDXComponents } from "@/mdx-components";
|
||||
|
||||
export const metadata: Metadata = {
|
||||
title: "Changelog | SurfSense",
|
||||
description: "See what's new in SurfSense.",
|
||||
};
|
||||
|
||||
const source = loader({
|
||||
baseUrl: "/changelog",
|
||||
source: changelog.toFumadocsSource(),
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
import React from "react";
|
||||
import type { Metadata } from "next";
|
||||
import { ContactFormGridWithDetails } from "@/components/contact/contact-form";
|
||||
|
||||
export const metadata: Metadata = {
|
||||
title: "Contact | SurfSense",
|
||||
description: "Get in touch with the SurfSense team.",
|
||||
};
|
||||
|
||||
const page = () => {
|
||||
return (
|
||||
<div>
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
import React from "react";
|
||||
import type { Metadata } from "next";
|
||||
import PricingBasic from "@/components/pricing/pricing-section";
|
||||
|
||||
export const metadata: Metadata = {
|
||||
title: "Pricing | SurfSense",
|
||||
description: "Explore SurfSense plans and pricing options.",
|
||||
};
|
||||
|
||||
const page = () => {
|
||||
return (
|
||||
<div>
|
||||
|
|
|
|||
|
|
@ -473,14 +473,14 @@ export function DocumentsTableShell({
|
|||
}, [deletableSelectedIds, bulkDeleteDocuments, deleteDocument]);
|
||||
|
||||
const bulkDeleteBar = hasDeletableSelection ? (
|
||||
<div className="flex items-center justify-center py-1.5 border-b border-border/50 bg-destructive/5 shrink-0 animate-in fade-in slide-in-from-top-1 duration-150">
|
||||
<div className="absolute inset-x-0 top-0 z-10 flex items-center justify-center py-1 pointer-events-none animate-in fade-in duration-150">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setBulkDeleteConfirmOpen(true)}
|
||||
className="flex items-center gap-1.5 px-3 py-1 rounded-md bg-destructive text-destructive-foreground shadow-sm text-xs font-medium hover:bg-destructive/90 transition-colors"
|
||||
className="pointer-events-auto flex items-center gap-1.5 px-3 py-1 rounded-md bg-destructive text-destructive-foreground shadow-lg text-xs font-medium hover:bg-destructive/90 transition-colors"
|
||||
>
|
||||
<Trash2 size={12} />
|
||||
Delete ({deletableSelectedIds.length} selected)
|
||||
Delete {deletableSelectedIds.length} {deletableSelectedIds.length === 1 ? "item" : "items"}
|
||||
</button>
|
||||
</div>
|
||||
) : null;
|
||||
|
|
@ -526,7 +526,6 @@ export function DocumentsTableShell({
|
|||
</TableRow>
|
||||
</TableHeader>
|
||||
</Table>
|
||||
{bulkDeleteBar}
|
||||
{loading ? (
|
||||
<div className="flex-1 overflow-auto">
|
||||
<Table className="table-fixed w-full">
|
||||
|
|
@ -594,7 +593,8 @@ export function DocumentsTableShell({
|
|||
)}
|
||||
</div>
|
||||
) : (
|
||||
<div ref={desktopScrollRef} className="flex-1 overflow-auto">
|
||||
<div ref={desktopScrollRef} className="flex-1 overflow-auto relative">
|
||||
{bulkDeleteBar}
|
||||
<Table className="table-fixed w-full">
|
||||
<TableBody>
|
||||
{sorted.map((doc) => {
|
||||
|
|
@ -788,9 +788,6 @@ export function DocumentsTableShell({
|
|||
)}
|
||||
</div>
|
||||
|
||||
{/* Mobile bulk delete bar */}
|
||||
<div className="md:hidden">{bulkDeleteBar}</div>
|
||||
|
||||
{/* Mobile Card View */}
|
||||
{loading ? (
|
||||
<div className="md:hidden divide-y divide-border/50 flex-1 overflow-auto">
|
||||
|
|
@ -846,8 +843,9 @@ export function DocumentsTableShell({
|
|||
) : (
|
||||
<div
|
||||
ref={mobileScrollRef}
|
||||
className="md:hidden divide-y divide-border/50 flex-1 overflow-auto"
|
||||
className="md:hidden divide-y divide-border/50 flex-1 overflow-auto relative"
|
||||
>
|
||||
{bulkDeleteBar}
|
||||
{sorted.map((doc) => {
|
||||
const isMentioned = mentionedDocIds?.has(doc.id) ?? false;
|
||||
const statusState = doc.status?.state ?? "ready";
|
||||
|
|
|
|||
|
|
@ -595,6 +595,7 @@ function CreateInviteDialog({
|
|||
});
|
||||
} catch (error) {
|
||||
console.error("Failed to create invite:", error);
|
||||
toast.error("Failed to create invite. Please try again.");
|
||||
} finally {
|
||||
setCreating(false);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,9 +29,6 @@ export const createDocumentMutationAtom = atomWithMutation((get) => {
|
|||
queryClient.invalidateQueries({
|
||||
queryKey: cacheKeys.documents.globalQueryParams(documentsQueryParams),
|
||||
});
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: cacheKeys.documents.typeCounts(searchSpaceId ?? undefined),
|
||||
});
|
||||
},
|
||||
};
|
||||
});
|
||||
|
|
@ -75,9 +72,6 @@ export const updateDocumentMutationAtom = atomWithMutation((get) => {
|
|||
queryClient.invalidateQueries({
|
||||
queryKey: cacheKeys.documents.document(String(request.id)),
|
||||
});
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: cacheKeys.documents.typeCounts(searchSpaceId ?? undefined),
|
||||
});
|
||||
},
|
||||
};
|
||||
});
|
||||
|
|
@ -109,9 +103,6 @@ export const deleteDocumentMutationAtom = atomWithMutation((get) => {
|
|||
queryClient.invalidateQueries({
|
||||
queryKey: cacheKeys.documents.document(String(request.id)),
|
||||
});
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: cacheKeys.documents.typeCounts(searchSpaceId ?? undefined),
|
||||
});
|
||||
},
|
||||
};
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,38 +0,0 @@
|
|||
import { atomWithQuery } from "jotai-tanstack-query";
|
||||
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
||||
import type { SearchDocumentsRequest } from "@/contracts/types/document.types";
|
||||
import { documentsApiService } from "@/lib/apis/documents-api.service";
|
||||
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||
import { globalDocumentsQueryParamsAtom } from "./ui.atoms";
|
||||
|
||||
export const documentsAtom = atomWithQuery((get) => {
|
||||
const searchSpaceId = get(activeSearchSpaceIdAtom);
|
||||
const queryParams = get(globalDocumentsQueryParamsAtom);
|
||||
|
||||
return {
|
||||
queryKey: cacheKeys.documents.globalQueryParams(queryParams),
|
||||
enabled: !!searchSpaceId,
|
||||
queryFn: async () => {
|
||||
return documentsApiService.getDocuments({
|
||||
queryParams: queryParams,
|
||||
});
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
export const documentTypeCountsAtom = atomWithQuery((get) => {
|
||||
const searchSpaceId = get(activeSearchSpaceIdAtom);
|
||||
|
||||
return {
|
||||
queryKey: cacheKeys.documents.typeCounts(searchSpaceId ?? undefined),
|
||||
enabled: !!searchSpaceId,
|
||||
staleTime: 10 * 60 * 1000, // 10 minutes
|
||||
queryFn: async () => {
|
||||
return documentsApiService.getDocumentTypeCounts({
|
||||
queryParams: {
|
||||
search_space_id: searchSpaceId ?? undefined,
|
||||
},
|
||||
});
|
||||
},
|
||||
};
|
||||
});
|
||||
|
|
@ -4,7 +4,7 @@ import { useAtomValue, useSetAtom } from "jotai";
|
|||
import { AlertTriangle, Cable, Settings } from "lucide-react";
|
||||
import { forwardRef, useEffect, useImperativeHandle, useMemo, useState } from "react";
|
||||
import { createPortal } from "react-dom";
|
||||
import { documentTypeCountsAtom } from "@/atoms/documents/document-query.atoms";
|
||||
import { useZeroDocumentTypeCounts } from "@/hooks/use-zero-document-type-counts";
|
||||
import { statusInboxItemsAtom } from "@/atoms/inbox/status-inbox.atom";
|
||||
import {
|
||||
globalNewLLMConfigsAtom,
|
||||
|
|
@ -72,9 +72,9 @@ export const ConnectorIndicator = forwardRef<ConnectorIndicatorHandle, Connector
|
|||
|
||||
const llmConfigLoading = preferencesLoading || globalConfigsLoading;
|
||||
|
||||
// Fetch document type counts via the lightweight /type-counts endpoint (cached 10 min)
|
||||
const { data: documentTypeCounts, isFetching: documentTypesLoading } =
|
||||
useAtomValue(documentTypeCountsAtom);
|
||||
// Real-time document type counts via Zero (updates instantly as docs are indexed)
|
||||
const documentTypeCounts = useZeroDocumentTypeCounts(searchSpaceId);
|
||||
const documentTypesLoading = documentTypeCounts === undefined;
|
||||
|
||||
// Read status inbox items from shared atom (populated by LayoutDataProvider)
|
||||
// instead of creating a duplicate useInbox("status") hook.
|
||||
|
|
|
|||
|
|
@ -867,6 +867,9 @@ export const useConnectorDialog = () => {
|
|||
|
||||
setIsOpen(false);
|
||||
setIsFromOAuth(false);
|
||||
setIndexingConfig(null);
|
||||
setIndexingConnector(null);
|
||||
setIndexingConnectorConfig(null);
|
||||
|
||||
refreshConnectors();
|
||||
queryClient.invalidateQueries({
|
||||
|
|
@ -898,6 +901,9 @@ export const useConnectorDialog = () => {
|
|||
const handleSkipIndexing = useCallback(() => {
|
||||
setIsOpen(false);
|
||||
setIsFromOAuth(false);
|
||||
setIndexingConfig(null);
|
||||
setIndexingConnector(null);
|
||||
setIndexingConnectorConfig(null);
|
||||
}, [setIsOpen]);
|
||||
|
||||
// Handle starting edit mode
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
"use client";
|
||||
|
||||
import { FolderPlus } from "lucide-react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
|
|
@ -52,22 +51,25 @@ export function CreateFolderDialog({
|
|||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-sm">
|
||||
<DialogHeader>
|
||||
<DialogTitle className="flex items-center gap-2">
|
||||
<FolderPlus className="size-5 text-muted-foreground" />
|
||||
<DialogContent className="select-none max-w-[90vw] sm:max-w-sm p-4 sm:p-5 data-[state=open]:animate-none data-[state=closed]:animate-none">
|
||||
<DialogHeader className="space-y-2 pb-2">
|
||||
<div className="flex items-center gap-2 sm:gap-3">
|
||||
<div className="flex-1 min-w-0">
|
||||
<DialogTitle className="text-base sm:text-lg">
|
||||
{isSubfolder ? "New subfolder" : "New folder"}
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
<DialogDescription className="text-xs sm:text-sm mt-0.5">
|
||||
{isSubfolder
|
||||
? `Create a new folder inside "${parentFolderName}".`
|
||||
: "Create a new folder at the root level."}
|
||||
</DialogDescription>
|
||||
</div>
|
||||
</div>
|
||||
</DialogHeader>
|
||||
|
||||
<form onSubmit={handleSubmit} className="flex flex-col gap-4">
|
||||
<form onSubmit={handleSubmit} className="flex flex-col gap-3 sm:gap-4">
|
||||
<div className="flex flex-col gap-2">
|
||||
<Label htmlFor="folder-name">Folder name</Label>
|
||||
<Label htmlFor="folder-name" className="text-sm">Folder name</Label>
|
||||
<Input
|
||||
ref={inputRef}
|
||||
id="folder-name"
|
||||
|
|
@ -76,14 +78,24 @@ export function CreateFolderDialog({
|
|||
onChange={(e) => setName(e.target.value)}
|
||||
maxLength={255}
|
||||
autoComplete="off"
|
||||
className="text-sm h-9 sm:h-10"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<DialogFooter>
|
||||
<Button type="button" variant="outline" onClick={() => onOpenChange(false)}>
|
||||
<DialogFooter className="flex-row justify-end gap-2 pt-2 sm:pt-3">
|
||||
<Button
|
||||
type="button"
|
||||
variant="secondary"
|
||||
onClick={() => onOpenChange(false)}
|
||||
className="h-8 sm:h-9 text-xs sm:text-sm"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button type="submit" disabled={!name.trim()}>
|
||||
<Button
|
||||
type="submit"
|
||||
disabled={!name.trim()}
|
||||
className="h-8 sm:h-9 text-xs sm:text-sm"
|
||||
>
|
||||
Create
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
|
|
|
|||
|
|
@ -1,25 +1,32 @@
|
|||
"use client";
|
||||
|
||||
import { Eye, MoreHorizontal, Move, Pencil, Trash2 } from "lucide-react";
|
||||
import React, { useCallback } from "react";
|
||||
import { AlertCircle, Clock, Download, Eye, MoreHorizontal, Move, PenLine, Trash2 } from "lucide-react";
|
||||
import React, { useCallback, useRef, useState } from "react";
|
||||
import { useDrag } from "react-dnd";
|
||||
import { getDocumentTypeIcon } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentTypeIcon";
|
||||
import { ExportContextItems, ExportDropdownItems } from "@/components/shared/ExportMenuItems";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import {
|
||||
ContextMenu,
|
||||
ContextMenuContent,
|
||||
ContextMenuItem,
|
||||
ContextMenuSeparator,
|
||||
ContextMenuSub,
|
||||
ContextMenuSubContent,
|
||||
ContextMenuSubTrigger,
|
||||
ContextMenuTrigger,
|
||||
} from "@/components/ui/context-menu";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuSub,
|
||||
DropdownMenuSubContent,
|
||||
DropdownMenuSubTrigger,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
|
||||
import type { DocumentTypeEnum } from "@/contracts/types/document.types";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { DND_TYPES } from "./FolderNode";
|
||||
|
|
@ -41,6 +48,9 @@ interface DocumentNodeProps {
|
|||
onEdit: (doc: DocumentNodeDoc) => void;
|
||||
onDelete: (doc: DocumentNodeDoc) => void;
|
||||
onMove: (doc: DocumentNodeDoc) => void;
|
||||
onExport?: (doc: DocumentNodeDoc, format: string) => void;
|
||||
contextMenuOpen?: boolean;
|
||||
onContextMenuOpenChange?: (open: boolean) => void;
|
||||
}
|
||||
|
||||
export const DocumentNode = React.memo(function DocumentNode({
|
||||
|
|
@ -52,6 +62,9 @@ export const DocumentNode = React.memo(function DocumentNode({
|
|||
onEdit,
|
||||
onDelete,
|
||||
onMove,
|
||||
onExport,
|
||||
contextMenuOpen,
|
||||
onContextMenuOpenChange,
|
||||
}: DocumentNodeProps) {
|
||||
const statusState = doc.status?.state ?? "ready";
|
||||
const isSelectable = statusState !== "pending" && statusState !== "processing";
|
||||
|
|
@ -74,48 +87,90 @@ export const DocumentNode = React.memo(function DocumentNode({
|
|||
);
|
||||
|
||||
const isProcessing = statusState === "pending" || statusState === "processing";
|
||||
const [dropdownOpen, setDropdownOpen] = useState(false);
|
||||
const [exporting, setExporting] = useState<string | null>(null);
|
||||
const rowRef = useRef<HTMLButtonElement>(null);
|
||||
|
||||
const handleExport = useCallback(
|
||||
(format: string) => {
|
||||
if (!onExport) return;
|
||||
setExporting(format);
|
||||
onExport(doc, format);
|
||||
setTimeout(() => setExporting(null), 2000);
|
||||
},
|
||||
[doc, onExport]
|
||||
);
|
||||
|
||||
const attachRef = useCallback(
|
||||
(node: HTMLButtonElement | null) => {
|
||||
(rowRef as React.MutableRefObject<HTMLButtonElement | null>).current = node;
|
||||
drag(node);
|
||||
},
|
||||
[drag]
|
||||
);
|
||||
|
||||
return (
|
||||
<ContextMenu>
|
||||
<ContextMenu onOpenChange={onContextMenuOpenChange}>
|
||||
<ContextMenuTrigger asChild>
|
||||
{/* biome-ignore lint/a11y/useSemanticElements: div required for drag ref */}
|
||||
<div
|
||||
ref={drag}
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
<button
|
||||
type="button"
|
||||
ref={attachRef}
|
||||
className={cn(
|
||||
"group flex h-8 items-center gap-1.5 rounded-md px-1 text-sm hover:bg-accent/50 cursor-pointer select-none",
|
||||
"group flex h-8 w-full items-center gap-2.5 rounded-md px-1 text-sm hover:bg-accent/50 cursor-pointer select-none text-left",
|
||||
isMentioned && "bg-accent/30",
|
||||
isDragging && "opacity-40"
|
||||
)}
|
||||
style={{ paddingLeft: `${depth * 16 + 4}px` }}
|
||||
onClick={handleCheckChange}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" || e.key === " ") {
|
||||
e.preventDefault();
|
||||
handleCheckChange();
|
||||
}
|
||||
}}
|
||||
>
|
||||
{isSelectable ? (
|
||||
{(() => {
|
||||
if (statusState === "pending") {
|
||||
return (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
|
||||
<Clock className="h-3.5 w-3.5 text-muted-foreground/60" />
|
||||
</span>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="top">Pending - waiting to be synced</TooltipContent>
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
if (statusState === "processing") {
|
||||
return (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
|
||||
<Spinner size="xs" className="text-primary" />
|
||||
</span>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="top">Syncing</TooltipContent>
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
if (statusState === "failed") {
|
||||
return (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
|
||||
<AlertCircle className="h-3.5 w-3.5 text-destructive" />
|
||||
</span>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="top" className="max-w-xs">
|
||||
{doc.status?.reason || "Processing failed"}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<Checkbox
|
||||
checked={isMentioned}
|
||||
onCheckedChange={handleCheckChange}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="h-3.5 w-3.5 shrink-0"
|
||||
/>
|
||||
) : (
|
||||
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
|
||||
<span
|
||||
className={cn(
|
||||
"h-2 w-2 rounded-full",
|
||||
statusState === "processing" && "animate-pulse bg-amber-500",
|
||||
statusState === "pending" && "bg-muted-foreground/40",
|
||||
statusState === "failed" && "bg-destructive"
|
||||
)}
|
||||
/>
|
||||
</span>
|
||||
)}
|
||||
);
|
||||
})()}
|
||||
|
||||
<span className="flex-1 min-w-0 truncate">{doc.title}</span>
|
||||
|
||||
|
|
@ -126,25 +181,28 @@ export const DocumentNode = React.memo(function DocumentNode({
|
|||
)}
|
||||
</span>
|
||||
|
||||
<DropdownMenu>
|
||||
<DropdownMenu open={dropdownOpen} onOpenChange={setDropdownOpen}>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="h-6 w-6 shrink-0 opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
className={cn(
|
||||
"hidden sm:inline-flex h-6 w-6 shrink-0 hover:bg-transparent",
|
||||
dropdownOpen ? "opacity-100 bg-accent hover:bg-accent" : "opacity-0 group-hover:opacity-100"
|
||||
)}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<MoreHorizontal className="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end" className="w-44">
|
||||
<DropdownMenuContent align="end" className="w-40">
|
||||
<DropdownMenuItem onClick={() => onPreview(doc)}>
|
||||
<Eye className="mr-2 h-4 w-4" />
|
||||
Open
|
||||
</DropdownMenuItem>
|
||||
{isEditable && (
|
||||
<DropdownMenuItem onClick={() => onEdit(doc)}>
|
||||
<Pencil className="mr-2 h-4 w-4" />
|
||||
<PenLine className="mr-2 h-4 w-4" />
|
||||
Edit
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
|
|
@ -152,7 +210,17 @@ export const DocumentNode = React.memo(function DocumentNode({
|
|||
<Move className="mr-2 h-4 w-4" />
|
||||
Move to...
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuSeparator />
|
||||
{onExport && (
|
||||
<DropdownMenuSub>
|
||||
<DropdownMenuSubTrigger>
|
||||
<Download className="mr-2 h-4 w-4" />
|
||||
Export
|
||||
</DropdownMenuSubTrigger>
|
||||
<DropdownMenuSubContent className="min-w-[180px]">
|
||||
<ExportDropdownItems onExport={handleExport} exporting={exporting} />
|
||||
</DropdownMenuSubContent>
|
||||
</DropdownMenuSub>
|
||||
)}
|
||||
<DropdownMenuItem
|
||||
className="text-destructive focus:text-destructive"
|
||||
disabled={isProcessing}
|
||||
|
|
@ -163,17 +231,18 @@ export const DocumentNode = React.memo(function DocumentNode({
|
|||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
</button>
|
||||
</ContextMenuTrigger>
|
||||
|
||||
<ContextMenuContent className="w-44">
|
||||
{contextMenuOpen && (
|
||||
<ContextMenuContent className="w-40">
|
||||
<ContextMenuItem onClick={() => onPreview(doc)}>
|
||||
<Eye className="mr-2 h-4 w-4" />
|
||||
Open
|
||||
</ContextMenuItem>
|
||||
{isEditable && (
|
||||
<ContextMenuItem onClick={() => onEdit(doc)}>
|
||||
<Pencil className="mr-2 h-4 w-4" />
|
||||
<PenLine className="mr-2 h-4 w-4" />
|
||||
Edit
|
||||
</ContextMenuItem>
|
||||
)}
|
||||
|
|
@ -181,7 +250,17 @@ export const DocumentNode = React.memo(function DocumentNode({
|
|||
<Move className="mr-2 h-4 w-4" />
|
||||
Move to...
|
||||
</ContextMenuItem>
|
||||
<ContextMenuSeparator />
|
||||
{onExport && (
|
||||
<ContextMenuSub>
|
||||
<ContextMenuSubTrigger>
|
||||
<Download className="mr-2 h-4 w-4" />
|
||||
Export
|
||||
</ContextMenuSubTrigger>
|
||||
<ContextMenuSubContent className="min-w-[180px]">
|
||||
<ExportContextItems onExport={handleExport} exporting={exporting} />
|
||||
</ContextMenuSubContent>
|
||||
</ContextMenuSub>
|
||||
)}
|
||||
<ContextMenuItem
|
||||
className="text-destructive focus:text-destructive"
|
||||
disabled={isProcessing}
|
||||
|
|
@ -191,6 +270,7 @@ export const DocumentNode = React.memo(function DocumentNode({
|
|||
Delete
|
||||
</ContextMenuItem>
|
||||
</ContextMenuContent>
|
||||
)}
|
||||
</ContextMenu>
|
||||
);
|
||||
});
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import {
|
|||
FolderPlus,
|
||||
MoreHorizontal,
|
||||
Move,
|
||||
Pencil,
|
||||
PenLine,
|
||||
Trash2,
|
||||
} from "lucide-react";
|
||||
import React, { useCallback, useEffect, useRef, useState } from "react";
|
||||
|
|
@ -18,14 +18,12 @@ import {
|
|||
ContextMenu,
|
||||
ContextMenuContent,
|
||||
ContextMenuItem,
|
||||
ContextMenuSeparator,
|
||||
ContextMenuTrigger,
|
||||
} from "@/components/ui/context-menu";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
|
@ -66,6 +64,8 @@ interface FolderNodeProps {
|
|||
onReorderFolder?: (folderId: number, beforePos: string | null, afterPos: string | null) => void;
|
||||
siblingPositions?: { before: string | null; after: string | null };
|
||||
disabledDropIds?: Set<number>;
|
||||
contextMenuOpen?: boolean;
|
||||
onContextMenuOpenChange?: (open: boolean) => void;
|
||||
}
|
||||
|
||||
function getDropZone(
|
||||
|
|
@ -99,6 +99,8 @@ export const FolderNode = React.memo(function FolderNode({
|
|||
onReorderFolder,
|
||||
siblingPositions,
|
||||
disabledDropIds,
|
||||
contextMenuOpen,
|
||||
onContextMenuOpenChange,
|
||||
}: FolderNodeProps) {
|
||||
const [renameValue, setRenameValue] = useState(folder.name);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
|
@ -213,7 +215,7 @@ export const FolderNode = React.memo(function FolderNode({
|
|||
const FolderIcon = isExpanded ? FolderOpen : Folder;
|
||||
|
||||
return (
|
||||
<ContextMenu>
|
||||
<ContextMenu onOpenChange={onContextMenuOpenChange}>
|
||||
<ContextMenuTrigger asChild disabled={isRenaming}>
|
||||
{/* biome-ignore lint/a11y/useSemanticElements: div required for drag/drop refs */}
|
||||
<div
|
||||
|
|
@ -261,7 +263,8 @@ export const FolderNode = React.memo(function FolderNode({
|
|||
onBlur={handleRenameSubmit}
|
||||
onKeyDown={handleRenameKeyDown}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="flex-1 min-w-0 rounded border border-primary bg-background px-1 py-0.5 text-sm outline-none"
|
||||
placeholder="Enter folder name"
|
||||
className="flex-1 min-w-0 bg-transparent px-1 py-0.5 text-sm outline-none caret-primary placeholder:text-muted-foreground/50"
|
||||
/>
|
||||
) : (
|
||||
<span className="flex-1 min-w-0 truncate">{folder.name}</span>
|
||||
|
|
@ -279,13 +282,13 @@ export const FolderNode = React.memo(function FolderNode({
|
|||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="h-6 w-6 shrink-0 opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
className="hidden sm:inline-flex h-6 w-6 shrink-0 opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<MoreHorizontal className="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end" className="w-48">
|
||||
<DropdownMenuContent align="end" className="w-40">
|
||||
<DropdownMenuItem
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
|
|
@ -301,7 +304,7 @@ export const FolderNode = React.memo(function FolderNode({
|
|||
startRename();
|
||||
}}
|
||||
>
|
||||
<Pencil className="mr-2 h-4 w-4" />
|
||||
<PenLine className="mr-2 h-4 w-4" />
|
||||
Rename
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem
|
||||
|
|
@ -313,7 +316,6 @@ export const FolderNode = React.memo(function FolderNode({
|
|||
<Move className="mr-2 h-4 w-4" />
|
||||
Move to...
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuItem
|
||||
className="text-destructive focus:text-destructive"
|
||||
onClick={(e) => {
|
||||
|
|
@ -330,21 +332,20 @@ export const FolderNode = React.memo(function FolderNode({
|
|||
</div>
|
||||
</ContextMenuTrigger>
|
||||
|
||||
{!isRenaming && (
|
||||
<ContextMenuContent className="w-48">
|
||||
{!isRenaming && contextMenuOpen && (
|
||||
<ContextMenuContent className="w-40">
|
||||
<ContextMenuItem onClick={() => onCreateSubfolder(folder.id)}>
|
||||
<FolderPlus className="mr-2 h-4 w-4" />
|
||||
New subfolder
|
||||
</ContextMenuItem>
|
||||
<ContextMenuItem onClick={() => startRename()}>
|
||||
<Pencil className="mr-2 h-4 w-4" />
|
||||
<PenLine className="mr-2 h-4 w-4" />
|
||||
Rename
|
||||
</ContextMenuItem>
|
||||
<ContextMenuItem onClick={() => onMove(folder)}>
|
||||
<Move className="mr-2 h-4 w-4" />
|
||||
Move to...
|
||||
</ContextMenuItem>
|
||||
<ContextMenuSeparator />
|
||||
<ContextMenuItem
|
||||
className="text-destructive focus:text-destructive"
|
||||
onClick={() => onDelete(folder)}
|
||||
|
|
|
|||
|
|
@ -124,10 +124,18 @@ export function FolderPickerDialog({
|
|||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-sm">
|
||||
<DialogHeader>
|
||||
<DialogTitle>{title}</DialogTitle>
|
||||
{description && <DialogDescription>{description}</DialogDescription>}
|
||||
<DialogContent className="select-none max-w-[90vw] sm:max-w-sm p-4 sm:p-5 data-[state=open]:animate-none data-[state=closed]:animate-none">
|
||||
<DialogHeader className="space-y-2 pb-2">
|
||||
<div className="flex items-center gap-2 sm:gap-3">
|
||||
<div className="flex-1 min-w-0">
|
||||
<DialogTitle className="text-base sm:text-lg">{title}</DialogTitle>
|
||||
{description && (
|
||||
<DialogDescription className="text-xs sm:text-sm mt-0.5">
|
||||
{description}
|
||||
</DialogDescription>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</DialogHeader>
|
||||
|
||||
<div className="max-h-[300px] overflow-y-auto rounded-md border p-1">
|
||||
|
|
@ -147,11 +155,17 @@ export function FolderPickerDialog({
|
|||
{renderPickerLevel(null, 1)}
|
||||
</div>
|
||||
|
||||
<DialogFooter>
|
||||
<Button variant="outline" onClick={() => onOpenChange(false)}>
|
||||
<DialogFooter className="flex-row justify-end gap-2 pt-2 sm:pt-3">
|
||||
<Button
|
||||
variant="secondary"
|
||||
onClick={() => onOpenChange(false)}
|
||||
className="h-8 sm:h-9 text-xs sm:text-sm"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={handleConfirm}>Move here</Button>
|
||||
<Button onClick={handleConfirm} className="h-8 sm:h-9 text-xs sm:text-sm">
|
||||
Move here
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
"use client";
|
||||
|
||||
import { useAtom } from "jotai";
|
||||
import { TreePine } from "lucide-react";
|
||||
import { useCallback, useMemo } from "react";
|
||||
import { CirclePlus } from "lucide-react";
|
||||
import { useCallback, useMemo, useState } from "react";
|
||||
import { DndProvider } from "react-dnd";
|
||||
import { HTML5Backend } from "react-dnd-html5-backend";
|
||||
import { renamingFolderIdAtom } from "@/atoms/documents/folder.atoms";
|
||||
|
|
@ -28,6 +28,7 @@ interface FolderTreeViewProps {
|
|||
onEditDocument: (doc: DocumentNodeDoc) => void;
|
||||
onDeleteDocument: (doc: DocumentNodeDoc) => void;
|
||||
onMoveDocument: (doc: DocumentNodeDoc) => void;
|
||||
onExportDocument?: (doc: DocumentNodeDoc, format: string) => void;
|
||||
activeTypes: DocumentTypeEnum[];
|
||||
onDropIntoFolder?: (
|
||||
itemType: "folder" | "document",
|
||||
|
|
@ -62,6 +63,7 @@ export function FolderTreeView({
|
|||
onEditDocument,
|
||||
onDeleteDocument,
|
||||
onMoveDocument,
|
||||
onExportDocument,
|
||||
activeTypes,
|
||||
onDropIntoFolder,
|
||||
onReorderFolder,
|
||||
|
|
@ -80,6 +82,8 @@ export function FolderTreeView({
|
|||
return counts;
|
||||
}, [folders, foldersByParent, docsByFolder]);
|
||||
|
||||
const [openContextMenuId, setOpenContextMenuId] = useState<string | null>(null);
|
||||
|
||||
// Single subscription for rename state — derived boolean passed to each FolderNode
|
||||
const [renamingFolderId, setRenamingFolderId] = useAtom(renamingFolderIdAtom);
|
||||
const handleStartRename = useCallback(
|
||||
|
|
@ -157,6 +161,8 @@ export function FolderTreeView({
|
|||
onDropIntoFolder={onDropIntoFolder}
|
||||
onReorderFolder={onReorderFolder}
|
||||
siblingPositions={siblingPositions}
|
||||
contextMenuOpen={openContextMenuId === `folder-${f.id}`}
|
||||
onContextMenuOpenChange={(open) => setOpenContextMenuId(open ? `folder-${f.id}` : null)}
|
||||
/>
|
||||
);
|
||||
|
||||
|
|
@ -177,6 +183,9 @@ export function FolderTreeView({
|
|||
onEdit={onEditDocument}
|
||||
onDelete={onDeleteDocument}
|
||||
onMove={onMoveDocument}
|
||||
onExport={onExportDocument}
|
||||
contextMenuOpen={openContextMenuId === `doc-${d.id}`}
|
||||
onContextMenuOpenChange={(open) => setOpenContextMenuId(open ? `doc-${d.id}` : null)}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
@ -189,7 +198,7 @@ export function FolderTreeView({
|
|||
if (treeNodes.length === 0 && folders.length === 0 && documents.length === 0) {
|
||||
return (
|
||||
<div className="flex flex-1 flex-col items-center justify-center gap-3 px-4 py-12 text-muted-foreground">
|
||||
<TreePine className="h-10 w-10" />
|
||||
<CirclePlus className="h-10 w-10 rotate-45" />
|
||||
<p className="text-sm">No documents yet</p>
|
||||
</div>
|
||||
);
|
||||
|
|
@ -198,7 +207,7 @@ export function FolderTreeView({
|
|||
if (treeNodes.length === 0 && activeTypes.length > 0) {
|
||||
return (
|
||||
<div className="flex flex-1 flex-col items-center justify-center gap-3 px-4 py-12 text-muted-foreground">
|
||||
<TreePine className="h-10 w-10" />
|
||||
<CirclePlus className="h-10 w-10 rotate-45" />
|
||||
<p className="text-sm">No matching documents</p>
|
||||
</div>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -185,7 +185,7 @@ function DateTimePickerField({
|
|||
type="time"
|
||||
value={time}
|
||||
onChange={handleTimeChange}
|
||||
className="w-[120px] text-sm shrink-0 pl-1.5 [&::-webkit-calendar-picker-indicator]:order-first [&::-webkit-calendar-picker-indicator]:mr-1"
|
||||
className="w-[120px] text-sm shrink-0 appearance-none [&::-webkit-calendar-picker-indicator]:hidden [&::-webkit-calendar-picker-indicator]:appearance-none"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -396,10 +396,13 @@ export function AllPrivateChatsSidebarContent({
|
|||
variant="ghost"
|
||||
size="icon"
|
||||
className={cn(
|
||||
"h-6 w-6 shrink-0",
|
||||
"h-6 w-6 shrink-0 hover:bg-transparent",
|
||||
isMobile
|
||||
? "opacity-0 pointer-events-none absolute"
|
||||
: openDropdownId === thread.id
|
||||
? "opacity-100"
|
||||
: "md:opacity-0 md:group-hover:opacity-100 md:focus:opacity-100",
|
||||
openDropdownId === thread.id && "bg-accent hover:bg-accent",
|
||||
"transition-opacity"
|
||||
)}
|
||||
disabled={isBusy}
|
||||
|
|
|
|||
|
|
@ -396,10 +396,13 @@ export function AllSharedChatsSidebarContent({
|
|||
variant="ghost"
|
||||
size="icon"
|
||||
className={cn(
|
||||
"h-6 w-6 shrink-0",
|
||||
"h-6 w-6 shrink-0 hover:bg-transparent",
|
||||
isMobile
|
||||
? "opacity-0 pointer-events-none absolute"
|
||||
: openDropdownId === thread.id
|
||||
? "opacity-100"
|
||||
: "md:opacity-0 md:group-hover:opacity-100 md:focus:opacity-100",
|
||||
openDropdownId === thread.id && "bg-accent hover:bg-accent",
|
||||
"transition-opacity"
|
||||
)}
|
||||
disabled={isBusy}
|
||||
|
|
|
|||
|
|
@ -79,14 +79,21 @@ export function ChatListItem({
|
|||
: "bg-gradient-to-l from-sidebar from-60% to-transparent group-hover/item:from-accent",
|
||||
isMobile
|
||||
? "opacity-0"
|
||||
: isActive
|
||||
: isActive || dropdownOpen
|
||||
? "opacity-100"
|
||||
: "opacity-0 group-hover/item:opacity-100"
|
||||
)}
|
||||
>
|
||||
<DropdownMenu open={dropdownOpen} onOpenChange={setDropdownOpen}>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button variant="ghost" size="icon" className="pointer-events-auto h-6 w-6">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className={cn(
|
||||
"pointer-events-auto h-6 w-6 hover:bg-transparent",
|
||||
dropdownOpen && "bg-accent hover:bg-accent"
|
||||
)}
|
||||
>
|
||||
<MoreHorizontal className="h-3.5 w-3.5 text-muted-foreground" />
|
||||
<span className="sr-only">{t("more_options")}</span>
|
||||
</Button>
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import { useParams } from "next/navigation";
|
|||
import { useTranslations } from "next-intl";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { EXPORT_FILE_EXTENSIONS } from "@/components/shared/ExportMenuItems";
|
||||
import { DocumentsFilters } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsFilters";
|
||||
import {
|
||||
DocumentsTableShell,
|
||||
|
|
@ -33,6 +34,7 @@ import { useDocumentSearch } from "@/hooks/use-document-search";
|
|||
import { useDocuments } from "@/hooks/use-documents";
|
||||
import { useMediaQuery } from "@/hooks/use-media-query";
|
||||
import { foldersApiService } from "@/lib/apis/folders-api.service";
|
||||
import { authenticatedFetch } from "@/lib/auth-utils";
|
||||
import { queries } from "@/zero/queries/index";
|
||||
import { SidebarSlideOutPanel } from "./SidebarSlideOutPanel";
|
||||
|
||||
|
|
@ -234,6 +236,43 @@ export function DocumentsSidebar({
|
|||
setFolderPickerOpen(true);
|
||||
}, []);
|
||||
|
||||
// Fetches `doc` from the backend export endpoint in the requested `format` and
// triggers a browser download. Failures are logged and shown as a toast; nothing is thrown.
const handleExportDocument = useCallback(
	async (doc: DocumentNodeDoc, format: string) => {
		// Replace everything except letters, digits, space, underscore and hyphen,
		// cap the length at 80, and fall back to "document" if nothing is left.
		const safeTitle =
			doc.title
				.replace(/[^a-zA-Z0-9 _-]/g, "_")
				.trim()
				.slice(0, 80) || "document";
		// Format names are not always file extensions (e.g. "latex" maps to "tex");
		// unknown formats are used as the extension verbatim.
		const ext = EXPORT_FILE_EXTENSIONS[format] ?? format;

		try {
			const response = await authenticatedFetch(
				// Encode the query value so an unexpected format string cannot alter the URL.
				`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${doc.id}/export?format=${encodeURIComponent(format)}`,
				{ method: "GET" }
			);

			if (!response.ok) {
				// The error body may not be JSON; fall back to a generic message.
				const errorData = await response.json().catch(() => ({ detail: "Export failed" }));
				throw new Error(errorData.detail || "Export failed");
			}

			const blob = await response.blob();
			const url = URL.createObjectURL(blob);
			const a = document.createElement("a");
			try {
				a.href = url;
				a.download = `${safeTitle}.${ext}`;
				// The anchor is attached to the document only for the duration of the click.
				document.body.appendChild(a);
				a.click();
			} finally {
				// Always release the temporary anchor and object URL, even if the click throws.
				// `remove()` is a no-op when the anchor was never attached.
				a.remove();
				URL.revokeObjectURL(url);
			}
		} catch (err) {
			console.error(`Export ${format} failed:`, err);
			toast.error(err instanceof Error ? err.message : "Export failed");
		}
	},
	[searchSpaceId]
);
|
||||
|
||||
const handleFolderPickerSelect = useCallback(
|
||||
async (targetFolderId: number | null) => {
|
||||
if (!folderPickerTarget) return;
|
||||
|
|
@ -606,6 +645,7 @@ export function DocumentsSidebar({
|
|||
}}
|
||||
onDeleteDocument={(doc) => handleDeleteDocument(doc.id)}
|
||||
onMoveDocument={handleMoveDocument}
|
||||
onExportDocument={handleExportDocument}
|
||||
activeTypes={activeTypes}
|
||||
onDropIntoFolder={handleDropIntoFolder}
|
||||
onReorderFolder={handleReorderFolder}
|
||||
|
|
@ -617,7 +657,7 @@ export function DocumentsSidebar({
|
|||
open={folderPickerOpen}
|
||||
onOpenChange={setFolderPickerOpen}
|
||||
folders={treeFolders}
|
||||
title={folderPickerTarget?.type === "folder" ? "Move folder to..." : "Move document to..."}
|
||||
title={folderPickerTarget?.type === "folder" ? "Move folder to" : "Move document to"}
|
||||
description="Select a destination folder, or choose Root to move to the top level."
|
||||
disabledFolderIds={folderPickerTarget?.disabledIds}
|
||||
onSelect={handleFolderPickerSelect}
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ export function ChatShareButton({ thread, onVisibilityChange, className }: ChatS
|
|||
className={cn(
|
||||
"w-full flex items-center gap-2.5 px-2.5 py-2 rounded-md transition-all",
|
||||
"hover:bg-accent/50 dark:hover:bg-white/10 cursor-pointer",
|
||||
"focus:outline-none",
|
||||
"focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2",
|
||||
isSelected && "bg-accent/80 dark:bg-white/10"
|
||||
)}
|
||||
>
|
||||
|
|
@ -248,7 +248,7 @@ export function ChatShareButton({ thread, onVisibilityChange, className }: ChatS
|
|||
className={cn(
|
||||
"w-full flex items-center gap-2.5 px-2.5 py-2 rounded-md transition-all",
|
||||
"hover:bg-accent/50 dark:hover:bg-white/10 cursor-pointer",
|
||||
"focus:outline-none",
|
||||
"focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2",
|
||||
"disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
)}
|
||||
>
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import { useTheme } from "next-themes";
|
|||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { createPortal } from "react-dom";
|
||||
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
|
||||
import { documentTypeCountsAtom } from "@/atoms/documents/document-query.atoms";
|
||||
import { useZeroDocumentTypeCounts } from "@/hooks/use-zero-document-type-counts";
|
||||
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
||||
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
|
||||
import { useIsMobile } from "@/hooks/use-mobile";
|
||||
|
|
@ -452,8 +452,8 @@ export function OnboardingTour() {
|
|||
enabled: !!searchSpaceId,
|
||||
});
|
||||
|
||||
// Get document type counts
|
||||
const { data: documentTypeCounts } = useAtomValue(documentTypeCountsAtom);
|
||||
// Real-time document type counts via Zero
|
||||
const documentTypeCounts = useZeroDocumentTypeCounts(searchSpaceId);
|
||||
|
||||
// Get connectors
|
||||
const { data: connectors = [] } = useAtomValue(connectorsAtom);
|
||||
|
|
|
|||
|
|
@ -15,10 +15,9 @@ import {
|
|||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuLabel,
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { ExportDropdownItems, EXPORT_FILE_EXTENSIONS } from "@/components/shared/ExportMenuItems";
|
||||
import { useMediaQuery } from "@/hooks/use-media-query";
|
||||
import { baseApiService } from "@/lib/apis/base-api.service";
|
||||
import { authenticatedFetch } from "@/lib/auth-utils";
|
||||
|
|
@ -198,19 +197,6 @@ export function ReportPanelContent({
|
|||
}
|
||||
}, [currentMarkdown]);
|
||||
|
||||
// Maps backend format values to download file extensions
|
||||
const FILE_EXTENSIONS: Record<string, string> = {
|
||||
pdf: "pdf",
|
||||
docx: "docx",
|
||||
html: "html",
|
||||
latex: "tex",
|
||||
epub: "epub",
|
||||
odt: "odt",
|
||||
plain: "txt",
|
||||
md: "md",
|
||||
};
|
||||
|
||||
// Export report
|
||||
const handleExport = useCallback(
|
||||
async (format: string) => {
|
||||
setExporting(format);
|
||||
|
|
@ -219,7 +205,7 @@ export function ReportPanelContent({
|
|||
.replace(/[^a-zA-Z0-9 _-]/g, "_")
|
||||
.trim()
|
||||
.slice(0, 80) || "report";
|
||||
const ext = FILE_EXTENSIONS[format] ?? format;
|
||||
const ext = EXPORT_FILE_EXTENSIONS[format] ?? format;
|
||||
try {
|
||||
if (format === "md") {
|
||||
if (!currentMarkdown) return;
|
||||
|
|
@ -329,68 +315,11 @@ export function ReportPanelContent({
|
|||
align="start"
|
||||
className={`min-w-[200px] select-none${insideDrawer ? " z-[100]" : ""}`}
|
||||
>
|
||||
{!shareToken && (
|
||||
<>
|
||||
<DropdownMenuLabel className="text-xs text-muted-foreground">
|
||||
Documents
|
||||
</DropdownMenuLabel>
|
||||
<DropdownMenuItem
|
||||
onClick={() => handleExport("pdf")}
|
||||
disabled={exporting !== null}
|
||||
>
|
||||
PDF (.pdf)
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem
|
||||
onClick={() => handleExport("docx")}
|
||||
disabled={exporting !== null}
|
||||
>
|
||||
Word (.docx)
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem
|
||||
onClick={() => handleExport("odt")}
|
||||
disabled={exporting !== null}
|
||||
>
|
||||
OpenDocument (.odt)
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuLabel className="text-xs text-muted-foreground">
|
||||
Web & E-Book
|
||||
</DropdownMenuLabel>
|
||||
<DropdownMenuItem
|
||||
onClick={() => handleExport("html")}
|
||||
disabled={exporting !== null}
|
||||
>
|
||||
HTML (.html)
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem
|
||||
onClick={() => handleExport("epub")}
|
||||
disabled={exporting !== null}
|
||||
>
|
||||
EPUB (.epub)
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuLabel className="text-xs text-muted-foreground">
|
||||
Source & Plain
|
||||
</DropdownMenuLabel>
|
||||
<DropdownMenuItem
|
||||
onClick={() => handleExport("latex")}
|
||||
disabled={exporting !== null}
|
||||
>
|
||||
LaTeX (.tex)
|
||||
</DropdownMenuItem>
|
||||
</>
|
||||
)}
|
||||
<DropdownMenuItem onClick={() => handleExport("md")} disabled={exporting !== null}>
|
||||
Markdown (.md)
|
||||
</DropdownMenuItem>
|
||||
{!shareToken && (
|
||||
<DropdownMenuItem
|
||||
onClick={() => handleExport("plain")}
|
||||
disabled={exporting !== null}
|
||||
>
|
||||
Plain Text (.txt)
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
<ExportDropdownItems
|
||||
onExport={handleExport}
|
||||
exporting={exporting}
|
||||
showAllFormats={!shareToken}
|
||||
/>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ export function GeneralSettingsManager({ searchSpaceId }: GeneralSettingsManager
|
|||
const {
|
||||
data: searchSpace,
|
||||
isLoading: loading,
|
||||
isError,
|
||||
refetch: fetchSearchSpace,
|
||||
} = useQuery({
|
||||
queryKey: cacheKeys.searchSpaces.detail(searchSpaceId.toString()),
|
||||
|
|
@ -104,6 +105,17 @@ export function GeneralSettingsManager({ searchSpaceId }: GeneralSettingsManager
|
|||
);
|
||||
}
|
||||
|
||||
if (isError) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center gap-3 py-8 text-center">
|
||||
<p className="text-sm text-destructive">Failed to load settings.</p>
|
||||
<Button variant="outline" size="sm" onClick={() => fetchSearchSpace()}>
|
||||
Retry
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-4 md:space-y-6">
|
||||
<Alert className="bg-muted/50 py-3 md:py-4">
|
||||
|
|
|
|||
142
surfsense_web/components/shared/ExportMenuItems.tsx
Normal file
142
surfsense_web/components/shared/ExportMenuItems.tsx
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
"use client";
|
||||
|
||||
import { Loader2 } from "lucide-react";
|
||||
import { DropdownMenuItem, DropdownMenuLabel, DropdownMenuSeparator } from "@/components/ui/dropdown-menu";
|
||||
import { ContextMenuItem } from "@/components/ui/context-menu";
|
||||
|
||||
/**
 * Maps an export format identifier to the file extension used for the
 * downloaded file. Most formats map to themselves; `latex` and `plain` do not.
 * Callers fall back to the format identifier itself for unknown keys.
 */
export const EXPORT_FILE_EXTENSIONS: Record<string, string> = {
	pdf: "pdf",
	docx: "docx",
	html: "html",
	latex: "tex",
	epub: "epub",
	odt: "odt",
	plain: "txt",
	md: "md",
};

/** Props shared by the dropdown-menu and context-menu export item lists. */
interface ExportMenuItemsProps {
	/** Called with the format identifier (e.g. "pdf", "md") of the clicked entry. */
	onExport: (format: string) => void;
	/** Format currently being exported, or `null` when idle. Any non-null value disables every entry. */
	exporting: string | null;
	/**
	 * When `true` (the default) every format is listed (PDF, DOCX, ODT, HTML,
	 * EPUB, LaTeX, Markdown, plain text). When `false` only Markdown is shown.
	 */
	showAllFormats?: boolean;
}
|
||||
|
||||
/**
 * Export entries for a dropdown menu, grouped under section labels.
 *
 * With `showAllFormats` enabled all formats are listed in three labelled
 * sections; otherwise only the Markdown entry is rendered. While `exporting`
 * is non-null every entry is disabled and the matching entry shows a spinner.
 */
export function ExportDropdownItems({
	onExport,
	exporting,
	showAllFormats = true,
}: ExportMenuItemsProps) {
	// Any export in flight locks the whole list, not just the active entry.
	const busy = exporting !== null;

	// Muted heading that introduces a group of formats.
	const sectionLabel = (text: string) => (
		<DropdownMenuLabel className="text-xs text-muted-foreground">{text}</DropdownMenuLabel>
	);

	// A single entry: spinner while this format is exporting, then its label.
	const renderItem = (format: string, label: string) => (
		<DropdownMenuItem
			onClick={(event: React.MouseEvent) => {
				// Stop the click from bubbling to ancestors of the menu.
				event.stopPropagation();
				onExport(format);
			}}
			disabled={busy}
		>
			{exporting === format && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
			{label}
		</DropdownMenuItem>
	);

	return (
		<>
			{showAllFormats && (
				<>
					{sectionLabel("Documents")}
					{renderItem("pdf", "PDF (.pdf)")}
					{renderItem("docx", "Word (.docx)")}
					{renderItem("odt", "OpenDocument (.odt)")}
					<DropdownMenuSeparator />
					{sectionLabel("Web & E-Book")}
					{renderItem("html", "HTML (.html)")}
					{renderItem("epub", "EPUB (.epub)")}
					<DropdownMenuSeparator />
					{sectionLabel("Source & Plain")}
					{renderItem("latex", "LaTeX (.tex)")}
				</>
			)}
			{/* Markdown is the one format offered regardless of showAllFormats. */}
			{renderItem("md", "Markdown (.md)")}
			{showAllFormats && renderItem("plain", "Plain Text (.txt)")}
		</>
	);
}
|
||||
|
||||
/**
 * Export entries for a context menu, rendered as a flat list without section
 * labels or separators.
 *
 * With `showAllFormats` enabled all formats are listed; otherwise only the
 * Markdown entry is rendered. While `exporting` is non-null every entry is
 * disabled and the matching entry shows a spinner.
 */
export function ExportContextItems({
	onExport,
	exporting,
	showAllFormats = true,
}: ExportMenuItemsProps) {
	// Any export in flight locks the whole list, not just the active entry.
	const busy = exporting !== null;

	// A single entry: spinner while this format is exporting, then its label.
	const renderItem = (format: string, label: string) => (
		<ContextMenuItem
			onClick={(event: React.MouseEvent) => {
				// Stop the click from bubbling to ancestors of the menu.
				event.stopPropagation();
				onExport(format);
			}}
			disabled={busy}
		>
			{exporting === format && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
			{label}
		</ContextMenuItem>
	);

	return (
		<>
			{showAllFormats && (
				<>
					{renderItem("pdf", "PDF (.pdf)")}
					{renderItem("docx", "Word (.docx)")}
					{renderItem("odt", "OpenDocument (.odt)")}
					{renderItem("html", "HTML (.html)")}
					{renderItem("epub", "EPUB (.epub)")}
					{renderItem("latex", "LaTeX (.tex)")}
				</>
			)}
			{/* Markdown is the one format offered regardless of showAllFormats. */}
			{renderItem("md", "Markdown (.md)")}
			{showAllFormats && renderItem("plain", "Plain Text (.txt)")}
		</>
	);
}
|
||||
|
|
@ -253,6 +253,12 @@ function ApprovalCard({
|
|||
String(effectiveNewDescription ?? "") !== (event?.description ?? "");
|
||||
|
||||
const buildFinalArgs = useCallback(() => {
|
||||
const base = {
|
||||
event_id: event?.event_id,
|
||||
document_id: event?.document_id,
|
||||
connector_id: account?.id,
|
||||
};
|
||||
|
||||
if (pendingEdits) {
|
||||
const attendeesArr = pendingEdits.attendees
|
||||
? pendingEdits.attendees
|
||||
|
|
@ -260,22 +266,38 @@ function ApprovalCard({
|
|||
.map((e) => e.trim())
|
||||
.filter(Boolean)
|
||||
: null;
|
||||
const origAttendees = event?.attendees?.map((a) => a.email) ?? [];
|
||||
|
||||
return {
|
||||
event_id: event?.event_id,
|
||||
document_id: event?.document_id,
|
||||
connector_id: account?.id,
|
||||
new_summary: pendingEdits.summary || null,
|
||||
new_description: pendingEdits.description || null,
|
||||
new_start_datetime: pendingEdits.start_datetime || null,
|
||||
new_end_datetime: pendingEdits.end_datetime || null,
|
||||
new_location: pendingEdits.location || null,
|
||||
new_attendees: attendeesArr,
|
||||
...base,
|
||||
new_summary:
|
||||
pendingEdits.summary && pendingEdits.summary !== (event?.summary ?? "")
|
||||
? pendingEdits.summary
|
||||
: null,
|
||||
new_description:
|
||||
pendingEdits.description !== (event?.description ?? "")
|
||||
? pendingEdits.description || null
|
||||
: null,
|
||||
new_start_datetime:
|
||||
pendingEdits.start_datetime && pendingEdits.start_datetime !== (event?.start ?? "")
|
||||
? pendingEdits.start_datetime
|
||||
: null,
|
||||
new_end_datetime:
|
||||
pendingEdits.end_datetime && pendingEdits.end_datetime !== (event?.end ?? "")
|
||||
? pendingEdits.end_datetime
|
||||
: null,
|
||||
new_location:
|
||||
pendingEdits.location !== (event?.location ?? "")
|
||||
? pendingEdits.location || null
|
||||
: null,
|
||||
new_attendees:
|
||||
attendeesArr && attendeesArr.join(",") !== origAttendees.join(",")
|
||||
? attendeesArr
|
||||
: null,
|
||||
};
|
||||
}
|
||||
return {
|
||||
event_id: event?.event_id,
|
||||
document_id: event?.document_id,
|
||||
connector_id: account?.id,
|
||||
...base,
|
||||
new_summary: actionArgs.new_summary ?? null,
|
||||
new_description: actionArgs.new_description ?? null,
|
||||
new_start_datetime: actionArgs.new_start_datetime ?? null,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"use client";
|
||||
|
||||
import { AnimatePresence, motion } from "motion/react";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { createPortal } from "react-dom";
|
||||
|
||||
function isVideoSrc(src: string) {
|
||||
|
|
@ -17,6 +17,12 @@ function ExpandedMediaOverlay({
|
|||
alt: string;
|
||||
onClose: () => void;
|
||||
}) {
|
||||
const overlayRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
overlayRef.current?.focus();
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const handleKey = (e: KeyboardEvent) => {
|
||||
if (e.key === "Escape") onClose();
|
||||
|
|
@ -52,12 +58,20 @@ function ExpandedMediaOverlay({
|
|||
|
||||
return createPortal(
|
||||
<motion.div
|
||||
role="dialog"
|
||||
aria-modal="true"
|
||||
aria-label="Expanded media view"
|
||||
tabIndex={-1}
|
||||
ref={overlayRef}
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
transition={{ duration: 0.2 }}
|
||||
className="fixed inset-0 z-100 flex items-center justify-center bg-black/70 p-4 backdrop-blur-sm sm:p-8"
|
||||
onClick={onClose}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Escape") onClose();
|
||||
}}
|
||||
>
|
||||
{mediaElement}
|
||||
</motion.div>,
|
||||
|
|
|
|||
|
|
@ -5,10 +5,10 @@ import {
|
|||
FileText,
|
||||
Film,
|
||||
Globe,
|
||||
ImageIcon,
|
||||
type LucideIcon,
|
||||
Podcast,
|
||||
ScanLine,
|
||||
Sparkles,
|
||||
Wrench,
|
||||
} from "lucide-react";
|
||||
|
||||
|
|
@ -17,7 +17,7 @@ const TOOL_ICONS: Record<string, LucideIcon> = {
|
|||
generate_podcast: Podcast,
|
||||
generate_video_presentation: Film,
|
||||
generate_report: FileText,
|
||||
generate_image: Sparkles,
|
||||
generate_image: ImageIcon,
|
||||
scrape_webpage: ScanLine,
|
||||
web_search: Globe,
|
||||
search_surfsense_docs: BookOpen,
|
||||
|
|
|
|||
31
surfsense_web/hooks/use-zero-document-type-counts.ts
Normal file
31
surfsense_web/hooks/use-zero-document-type-counts.ts
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
"use client";
|
||||
|
||||
import { useQuery } from "@rocicorp/zero/react";
|
||||
import { useMemo } from "react";
|
||||
import { queries } from "@/zero/queries";
|
||||
|
||||
/**
 * Real-time document type counts derived from Zero's live document sync.
 *
 * Subscribes to the documents of one search space and tallies them by
 * `documentType`; the tally is recomputed whenever the query result changes.
 *
 * @param searchSpaceId - Search space to count documents for. Strings are
 *   converted with `Number()`. `null`, or a value that does not convert to a
 *   finite number, means "no search space".
 * @returns A map from document type to count, or `undefined` when there is no
 *   usable search space id or no query result.
 */
export function useZeroDocumentTypeCounts(
	searchSpaceId: number | string | null
): Record<string, number> | undefined {
	const parsedId = searchSpaceId != null ? Number(searchSpaceId) : null;
	// A non-numeric string converts to NaN, which would slip past the null checks
	// below and be sent to the query; treat it the same as a missing id.
	const numericId = parsedId != null && Number.isFinite(parsedId) ? parsedId : null;

	// The hook must be called on every render, so a placeholder id of -1 is used
	// when there is no usable id (presumably it matches no search space — TODO confirm).
	const [zeroDocuments] = useQuery(
		queries.documents.bySpace({ searchSpaceId: numericId ?? -1 })
	);

	return useMemo(() => {
		if (!zeroDocuments || numericId == null) return undefined;

		const counts: Record<string, number> = {};
		for (const doc of zeroDocuments) {
			// Only rows with an id and a non-empty title are counted.
			if (doc.id != null && doc.title != null && doc.title !== "") {
				counts[doc.documentType] = (counts[doc.documentType] || 0) + 1;
			}
		}
		return counts;
	}, [zeroDocuments, numericId]);
}
|
||||
|
|
@ -17,7 +17,6 @@ export const cacheKeys = {
|
|||
withQueryParams: (queries: GetDocumentsRequest["queryParams"]) =>
|
||||
["documents-with-queries", ...(queries ? Object.values(queries) : [])] as const,
|
||||
document: (documentId: string) => ["document", documentId] as const,
|
||||
typeCounts: (searchSpaceId?: string) => ["documents", "type-counts", searchSpaceId] as const,
|
||||
byChunk: (chunkId: string) => ["documents", "by-chunk", chunkId] as const,
|
||||
},
|
||||
logs: {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue