Mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-05-07 06:42:39 +02:00)

Merge upstream/dev into feature/multi-agent

commit 5119915f4f
278 changed files with 34669 additions and 8970 deletions
@@ -1,10 +1,25 @@
-"""Celery tasks package."""
"""Celery tasks package.

Also hosts the small helpers every async celery task should use to
spin up its event loop. See :func:`run_async_celery_task` for the
canonical pattern.
"""

from __future__ import annotations

import asyncio
import contextlib
import logging
from collections.abc import Awaitable, Callable
from typing import TypeVar

from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool

from app.config import config

logger = logging.getLogger(__name__)

_celery_engine = None
_celery_session_maker = None


@@ -26,3 +41,86 @@ def get_celery_session_maker() -> async_sessionmaker:
            _celery_engine, expire_on_commit=False
        )
    return _celery_session_maker


def _dispose_shared_db_engine(loop: asyncio.AbstractEventLoop) -> None:
    """Drop the shared ``app.db.engine`` connection pool synchronously.

    The shared engine (used by ``shielded_async_session`` and most
    routes / services) is a module-level singleton with a real pool.
    Each celery task creates a fresh ``asyncio`` event loop; asyncpg
    connections cache a reference to whichever loop opened them. When
    a subsequent task's loop pulls a stale connection from the pool,
    SQLAlchemy's ``pool_pre_ping`` checkout crashes with::

        AttributeError: 'NoneType' object has no attribute 'send'
          File ".../asyncio/proactor_events.py", line 402, in _loop_writing
            self._write_fut = self._loop._proactor.send(self._sock, data)

    or hangs forever inside the asyncpg ``Connection._cancel`` cleanup
    coroutine that can never run because its loop is gone.

    Disposing the engine forces the pool to drop every cached
    connection so the next checkout opens a fresh one on the current
    loop. Safe to call from a task's finally block; failure is logged
    but never propagated.
    """
    try:
        from app.db import engine as shared_engine

        loop.run_until_complete(shared_engine.dispose())
    except Exception:
        logger.warning("Shared DB engine dispose() failed", exc_info=True)


T = TypeVar("T")


def run_async_celery_task[T](coro_factory: Callable[[], Awaitable[T]]) -> T:
    """Run an async coroutine inside a fresh event loop with proper
    DB-engine cleanup.

    This is the canonical entry point for every async celery task.
    It performs three responsibilities that were previously copy-pasted
    (incorrectly) across each task module:

    1. Create a fresh ``asyncio`` loop and install it on the current
       thread (celery's ``--pool=solo`` runs every task on the main
       thread, but other pool types don't).
    2. Dispose the shared ``app.db.engine`` BEFORE the task runs so
       any stale connections left over from a previous task's loop
       are dropped — defends against tasks that crashed without
       cleaning up.
    3. Dispose the shared engine AFTER the task runs so the
       connections we opened on this loop are released before the
       loop closes (avoids ``coroutine 'Connection._cancel' was
       never awaited`` warnings and the next-task hang).

    Use as::

        @celery_app.task(name="my_task", bind=True)
        def my_task(self, *args):
            return run_async_celery_task(lambda: _my_task_impl(*args))
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        # Defense-in-depth: prior task may have crashed before
        # disposing. Idempotent — no-op if pool is already empty.
        _dispose_shared_db_engine(loop)
        return loop.run_until_complete(coro_factory())
    finally:
        # Drop any connections this task opened so they don't leak
        # into the next task's loop.
        _dispose_shared_db_engine(loop)
        with contextlib.suppress(Exception):
            loop.run_until_complete(loop.shutdown_asyncgens())
        with contextlib.suppress(Exception):
            asyncio.set_event_loop(None)
            loop.close()


__all__ = [
    "get_celery_session_maker",
    "run_async_celery_task",
]

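For reference, a minimal sketch of how a task module is expected to combine the two exported helpers. The task name ``example_task`` and the ``_example_impl`` coroutine are hypothetical, not part of this change; the pattern follows the task modules rewritten below.

    import logging

    from app.celery_app import celery_app
    from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task

    logger = logging.getLogger(__name__)


    @celery_app.task(name="example_task", bind=True)
    def example_task(self, search_space_id: int, user_id: str):
        """Synchronous celery entry point; loop handling is delegated to the helper."""
        return run_async_celery_task(lambda: _example_impl(search_space_id, user_id))


    async def _example_impl(search_space_id: int, user_id: str):
        # The NullPool-backed celery session maker opens its connections on this
        # task's own event loop, so nothing stale leaks into the next task.
        async with get_celery_session_maker()() as session:
            ...  # actual async work against `session` goes here
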
@@ -4,7 +4,7 @@ import logging
 import traceback
 
 from app.celery_app import celery_app
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
 
 logger = logging.getLogger(__name__)
 
@ -49,22 +49,15 @@ def index_notion_pages_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Notion pages."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_notion_pages(
|
||||
return run_async_celery_task(
|
||||
lambda: _index_notion_pages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_notion_pages", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_notion_pages(
|
||||
|
|
@ -95,19 +88,11 @@ def index_github_repos_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index GitHub repositories."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_github_repos(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_github_repos(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_github_repos(
|
||||
|
|
@ -138,19 +123,11 @@ def index_confluence_pages_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Confluence pages."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_confluence_pages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_confluence_pages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_confluence_pages(
|
||||
|
|
@ -181,22 +158,15 @@ def index_google_calendar_events_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Google Calendar events."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_google_calendar_events(
|
||||
return run_async_celery_task(
|
||||
lambda: _index_google_calendar_events(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_google_calendar_events", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_google_calendar_events(
|
||||
|
|
@ -227,19 +197,11 @@ def index_google_gmail_messages_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Google Gmail messages."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_google_gmail_messages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_google_gmail_messages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_google_gmail_messages(
|
||||
|
|
@ -269,22 +231,14 @@ def index_google_drive_files_task(
|
|||
items_dict: dict, # Dictionary with 'folders', 'files', and 'indexing_options'
|
||||
):
|
||||
"""Celery task to index Google Drive folders and files."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_google_drive_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_google_drive_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_google_drive_files(
|
||||
|
|
@ -317,22 +271,14 @@ def index_onedrive_files_task(
|
|||
items_dict: dict,
|
||||
):
|
||||
"""Celery task to index OneDrive folders and files."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_onedrive_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_onedrive_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_onedrive_files(
|
||||
|
|
@ -365,22 +311,14 @@ def index_dropbox_files_task(
|
|||
items_dict: dict,
|
||||
):
|
||||
"""Celery task to index Dropbox folders and files."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_dropbox_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_dropbox_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_dropbox_files(
|
||||
|
|
@ -414,19 +352,11 @@ def index_elasticsearch_documents_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Elasticsearch documents."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_elasticsearch_documents(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_elasticsearch_documents(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_elasticsearch_documents(
|
||||
|
|
@ -457,22 +387,15 @@ def index_crawled_urls_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Web page Urls."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_crawled_urls(
|
||||
return run_async_celery_task(
|
||||
lambda: _index_crawled_urls(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_crawled_urls", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_crawled_urls(
|
||||
|
|
@ -503,19 +426,11 @@ def index_bookstack_pages_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index BookStack pages."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_bookstack_pages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_bookstack_pages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_bookstack_pages(
|
||||
|
|
@ -546,19 +461,11 @@ def index_composio_connector_task(
|
|||
end_date: str | None,
|
||||
):
|
||||
"""Celery task to index Composio connector content (Google Drive, Gmail, Calendar via Composio)."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_composio_connector(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_composio_connector(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_composio_connector(
|
||||
|
|
|
|||
|
|
@@ -11,7 +11,7 @@ from app.db import Document
 from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter
 from app.services.llm_service import get_user_long_context_llm
 from app.services.task_logging_service import TaskLoggingService
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
 
 logger = logging.getLogger(__name__)
 
|
|
@ -25,15 +25,7 @@ def reindex_document_task(self, document_id: int, user_id: str):
|
|||
document_id: ID of document to reindex
|
||||
user_id: ID of user who edited the document
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(_reindex_document(document_id, user_id))
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(lambda: _reindex_document(document_id, user_id))
|
||||
|
||||
|
||||
async def _reindex_document(document_id: int, user_id: str):
|
||||
|
|
|
|||
|
|
@@ -11,7 +11,7 @@ from app.celery_app import celery_app
 from app.config import config
 from app.services.notification_service import NotificationService
 from app.services.task_logging_service import TaskLoggingService
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
 from app.tasks.connector_indexers.local_folder_indexer import (
     index_local_folder,
     index_uploaded_files,
|
|
@ -105,12 +105,7 @@ async def _run_heartbeat_loop(notification_id: int):
|
|||
)
|
||||
def delete_document_task(self, document_id: int):
|
||||
"""Celery task to delete a document and its chunks in batches."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(_delete_document_background(document_id))
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(lambda: _delete_document_background(document_id))
|
||||
|
||||
|
||||
async def _delete_document_background(document_id: int) -> None:
|
||||
|
|
@ -153,14 +148,9 @@ def delete_folder_documents_task(
|
|||
folder_subtree_ids: list[int] | None = None,
|
||||
):
|
||||
"""Celery task to delete documents first, then the folder rows."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_delete_folder_documents(document_ids, folder_subtree_ids)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(
|
||||
lambda: _delete_folder_documents(document_ids, folder_subtree_ids)
|
||||
)
|
||||
|
||||
|
||||
async def _delete_folder_documents(
|
||||
|
|
@ -209,12 +199,9 @@ async def _delete_folder_documents(
|
|||
)
|
||||
def delete_search_space_task(self, search_space_id: int):
|
||||
"""Celery task to delete a search space and heavy child rows in batches."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(_delete_search_space_background(search_space_id))
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(
|
||||
lambda: _delete_search_space_background(search_space_id)
|
||||
)
|
||||
|
||||
|
||||
async def _delete_search_space_background(search_space_id: int) -> None:
|
||||
|
|
@ -269,18 +256,11 @@ def process_extension_document_task(
|
|||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
"""
|
||||
# Create a new event loop for this task
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_process_extension_document(
|
||||
individual_document_dict, search_space_id, user_id
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _process_extension_document(
|
||||
individual_document_dict, search_space_id, user_id
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _process_extension_document(
|
||||
|
|
@ -419,13 +399,9 @@ def process_youtube_video_task(self, url: str, search_space_id: int, user_id: st
|
|||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(_process_youtube_video(url, search_space_id, user_id))
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(
|
||||
lambda: _process_youtube_video(url, search_space_id, user_id)
|
||||
)
|
||||
|
||||
|
||||
async def _process_youtube_video(url: str, search_space_id: int, user_id: str):
|
||||
|
|
@ -573,12 +549,9 @@ def process_file_upload_task(
|
|||
except Exception as e:
|
||||
logger.warning(f"[process_file_upload] Could not get file size: {e}")
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_process_file_upload(file_path, filename, search_space_id, user_id)
|
||||
run_async_celery_task(
|
||||
lambda: _process_file_upload(file_path, filename, search_space_id, user_id)
|
||||
)
|
||||
logger.info(
|
||||
f"[process_file_upload] Task completed successfully for: {filename}"
|
||||
|
|
@ -589,8 +562,6 @@ def process_file_upload_task(
|
|||
f"Traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _process_file_upload(
|
||||
|
|
@ -811,25 +782,17 @@ def process_file_upload_with_document_task(
|
|||
"File may have been removed before syncing could start."
|
||||
)
|
||||
# Mark document as failed since file is missing
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_mark_document_failed(
|
||||
document_id,
|
||||
"File not found. Please re-upload the file.",
|
||||
)
|
||||
run_async_celery_task(
|
||||
lambda: _mark_document_failed(
|
||||
document_id,
|
||||
"File not found. Please re-upload the file.",
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
return
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_process_file_with_document(
|
||||
run_async_celery_task(
|
||||
lambda: _process_file_with_document(
|
||||
document_id,
|
||||
temp_path,
|
||||
filename,
|
||||
|
|
@ -849,8 +812,6 @@ def process_file_upload_with_document_task(
|
|||
f"Traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _mark_document_failed(document_id: int, reason: str):
|
||||
|
|
@ -1119,22 +1080,16 @@ def process_circleback_meeting_task(
|
|||
search_space_id: ID of the search space
|
||||
connector_id: ID of the Circleback connector (for deletion support)
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_process_circleback_meeting(
|
||||
meeting_id,
|
||||
meeting_name,
|
||||
markdown_content,
|
||||
metadata,
|
||||
search_space_id,
|
||||
connector_id,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _process_circleback_meeting(
|
||||
meeting_id,
|
||||
meeting_name,
|
||||
markdown_content,
|
||||
metadata,
|
||||
search_space_id,
|
||||
connector_id,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _process_circleback_meeting(
|
||||
|
|
@ -1291,25 +1246,19 @@ def index_local_folder_task(
|
|||
target_file_paths: list[str] | None = None,
|
||||
):
|
||||
"""Celery task to index a local folder. Config is passed directly — no connector row."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_local_folder_async(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
folder_path=folder_path,
|
||||
folder_name=folder_name,
|
||||
exclude_patterns=exclude_patterns,
|
||||
file_extensions=file_extensions,
|
||||
root_folder_id=root_folder_id,
|
||||
enable_summary=enable_summary,
|
||||
target_file_paths=target_file_paths,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_local_folder_async(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
folder_path=folder_path,
|
||||
folder_name=folder_name,
|
||||
exclude_patterns=exclude_patterns,
|
||||
file_extensions=file_extensions,
|
||||
root_folder_id=root_folder_id,
|
||||
enable_summary=enable_summary,
|
||||
target_file_paths=target_file_paths,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_local_folder_async(
|
||||
|
|
@ -1441,23 +1390,18 @@ def index_uploaded_folder_files_task(
|
|||
processing_mode: str = "basic",
|
||||
):
|
||||
"""Celery task to index files uploaded from the desktop app."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_uploaded_folder_files_async(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
folder_name=folder_name,
|
||||
root_folder_id=root_folder_id,
|
||||
enable_summary=enable_summary,
|
||||
file_mappings=file_mappings,
|
||||
use_vision_llm=use_vision_llm,
|
||||
processing_mode=processing_mode,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_uploaded_folder_files_async(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
folder_name=folder_name,
|
||||
root_folder_id=root_folder_id,
|
||||
enable_summary=enable_summary,
|
||||
file_mappings=file_mappings,
|
||||
use_vision_llm=use_vision_llm,
|
||||
processing_mode=processing_mode,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_uploaded_folder_files_async(
|
||||
|
|
@ -1584,12 +1528,9 @@ def _ai_sort_lock_key(search_space_id: int) -> str:
|
|||
@celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1)
|
||||
def ai_sort_search_space_task(self, search_space_id: int, user_id: str):
|
||||
"""Full AI sort for all documents in a search space."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(_ai_sort_search_space_async(search_space_id, user_id))
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(
|
||||
lambda: _ai_sort_search_space_async(search_space_id, user_id)
|
||||
)
|
||||
|
||||
|
||||
async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
|
||||
|
|
@ -1639,14 +1580,9 @@ async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
|
|||
)
|
||||
def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int):
|
||||
"""Incremental AI sort for a single document after indexing."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_ai_sort_document_async(search_space_id, user_id, document_id)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(
|
||||
lambda: _ai_sort_document_async(search_space_id, user_id, document_id)
|
||||
)
|
||||
|
||||
|
||||
async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int):
|
||||
|
|
|
|||
|
|
@@ -2,14 +2,13 @@
 
 from __future__ import annotations
 
-import asyncio
 import logging
 
 from app.celery_app import celery_app
 from app.db import SearchSourceConnector
 from app.schemas.obsidian_plugin import NotePayload
 from app.services.obsidian_plugin_indexer import upsert_note
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
 
 logger = logging.getLogger(__name__)
 
@ -22,18 +21,13 @@ def index_obsidian_attachment_task(
|
|||
user_id: str,
|
||||
) -> None:
|
||||
"""Process one Obsidian non-markdown attachment asynchronously."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_obsidian_attachment(
|
||||
connector_id=connector_id,
|
||||
payload_data=payload_data,
|
||||
user_id=user_id,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_obsidian_attachment(
|
||||
connector_id=connector_id,
|
||||
payload_data=payload_data,
|
||||
user_id=user_id,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_obsidian_attachment(
|
||||
|
|
|
|||
|
|
@ -3,14 +3,22 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.podcaster.graph import graph as podcaster_graph
|
||||
from app.agents.podcaster.state import State as PodcasterState
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config as app_config
|
||||
from app.db import Podcast, PodcastStatus
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.services.billable_calls import (
|
||||
BillingSettlementError,
|
||||
QuotaInsufficientError,
|
||||
_resolve_agent_billing_for_search_space,
|
||||
billable_call,
|
||||
)
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -28,6 +36,13 @@ if sys.platform.startswith("win"):
|
|||
# =============================================================================
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _celery_billable_session():
|
||||
"""Session factory used by billable_call inside the Celery worker loop."""
|
||||
async with get_celery_session_maker()() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@celery_app.task(name="generate_content_podcast", bind=True)
|
||||
def generate_content_podcast_task(
|
||||
self,
|
||||
|
|
@ -40,27 +55,22 @@ def generate_content_podcast_task(
|
|||
Celery task to generate podcast from source content.
|
||||
Updates existing podcast record created by the tool.
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
result = loop.run_until_complete(
|
||||
_generate_content_podcast(
|
||||
return run_async_celery_task(
|
||||
lambda: _generate_content_podcast(
|
||||
podcast_id,
|
||||
source_content,
|
||||
search_space_id,
|
||||
user_prompt,
|
||||
)
|
||||
)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating content podcast: {e!s}")
|
||||
loop.run_until_complete(_mark_podcast_failed(podcast_id))
|
||||
try:
|
||||
run_async_celery_task(lambda: _mark_podcast_failed(podcast_id))
|
||||
except Exception:
|
||||
logger.exception("Failed to mark podcast %s as failed", podcast_id)
|
||||
return {"status": "failed", "podcast_id": podcast_id}
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _mark_podcast_failed(podcast_id: int) -> None:
|
||||
|
|
@ -96,6 +106,31 @@ async def _generate_content_podcast(
|
|||
podcast.status = PodcastStatus.GENERATING
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
(
|
||||
owner_user_id,
|
||||
billing_tier,
|
||||
base_model,
|
||||
) = await _resolve_agent_billing_for_search_space(
|
||||
session,
|
||||
search_space_id,
|
||||
thread_id=podcast.thread_id,
|
||||
)
|
||||
except ValueError as resolve_err:
|
||||
logger.error(
|
||||
"Podcast %s: cannot resolve billing for search_space=%s: %s",
|
||||
podcast.id,
|
||||
search_space_id,
|
||||
resolve_err,
|
||||
)
|
||||
podcast.status = PodcastStatus.FAILED
|
||||
await session.commit()
|
||||
return {
|
||||
"status": "failed",
|
||||
"podcast_id": podcast.id,
|
||||
"reason": "billing_resolution_failed",
|
||||
}
|
||||
|
||||
graph_config = {
|
||||
"configurable": {
|
||||
"podcast_title": podcast.title,
|
||||
|
|
@ -109,9 +144,52 @@ async def _generate_content_podcast(
|
|||
db_session=session,
|
||||
)
|
||||
|
||||
graph_result = await podcaster_graph.ainvoke(
|
||||
initial_state, config=graph_config
|
||||
)
|
||||
try:
|
||||
async with billable_call(
|
||||
user_id=owner_user_id,
|
||||
search_space_id=search_space_id,
|
||||
billing_tier=billing_tier,
|
||||
base_model=base_model,
|
||||
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
|
||||
usage_type="podcast_generation",
|
||||
call_details={
|
||||
"podcast_id": podcast.id,
|
||||
"title": podcast.title,
|
||||
"thread_id": podcast.thread_id,
|
||||
},
|
||||
billable_session_factory=_celery_billable_session,
|
||||
):
|
||||
graph_result = await podcaster_graph.ainvoke(
|
||||
initial_state, config=graph_config
|
||||
)
|
||||
except QuotaInsufficientError as exc:
|
||||
logger.info(
|
||||
"Podcast %s denied: out of premium credits "
|
||||
"(used=%d/%d remaining=%d)",
|
||||
podcast.id,
|
||||
exc.used_micros,
|
||||
exc.limit_micros,
|
||||
exc.remaining_micros,
|
||||
)
|
||||
podcast.status = PodcastStatus.FAILED
|
||||
await session.commit()
|
||||
return {
|
||||
"status": "failed",
|
||||
"podcast_id": podcast.id,
|
||||
"reason": "premium_quota_exhausted",
|
||||
}
|
||||
except BillingSettlementError:
|
||||
logger.exception(
|
||||
"Podcast %s: premium billing settlement failed",
|
||||
podcast.id,
|
||||
)
|
||||
podcast.status = PodcastStatus.FAILED
|
||||
await session.commit()
|
||||
return {
|
||||
"status": "failed",
|
||||
"podcast_id": podcast.id,
|
||||
"reason": "billing_settlement_failed",
|
||||
}
|
||||
|
||||
podcast_transcript = graph_result.get("podcast_transcript", [])
|
||||
file_path = graph_result.get("final_podcast_file_path", "")
|
||||
|
|
@ -133,7 +211,14 @@ async def _generate_content_podcast(
|
|||
podcast.podcast_transcript = serializable_transcript
|
||||
podcast.file_location = file_path
|
||||
podcast.status = PodcastStatus.READY
|
||||
logger.info(
|
||||
"Podcast %s: committing READY transcript_entries=%d file=%s",
|
||||
podcast.id,
|
||||
len(serializable_transcript),
|
||||
file_path,
|
||||
)
|
||||
await session.commit()
|
||||
logger.info("Podcast %s: READY commit complete", podcast.id)
|
||||
|
||||
logger.info(f"Successfully generated podcast: {podcast.id}")
|
||||
|
||||
|
|
|
|||
|
|
@@ -7,7 +7,7 @@ from sqlalchemy.future import select
 
 from app.celery_app import celery_app
 from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
 from app.utils.indexing_locks import is_connector_indexing_locked
 
 logger = logging.getLogger(__name__)
@@ -20,15 +20,7 @@ def check_periodic_schedules_task():
     This task runs every minute and triggers indexing for any connector
     whose next_scheduled_at time has passed.
     """
-    import asyncio
-
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-
-    try:
-        loop.run_until_complete(_check_and_trigger_schedules())
-    finally:
-        loop.close()
+    return run_async_celery_task(_check_and_trigger_schedules)
 
 
 async def _check_and_trigger_schedules():

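For context, the one-minute cadence the docstring mentions normally comes from Celery beat rather than from the task itself. A minimal sketch of that wiring, assuming a hypothetical registered task name; the real beat entry lives wherever ``celery_app`` is configured and may differ.

    from app.celery_app import celery_app

    # Hypothetical beat entry; the task name and the schedule key are
    # assumptions for illustration, not taken from this diff.
    celery_app.conf.beat_schedule = {
        "check-periodic-schedules": {
            "task": "check_periodic_schedules",
            "schedule": 60.0,  # seconds
        },
    }
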
@@ -34,7 +34,7 @@ from sqlalchemy.future import select
 from app.celery_app import celery_app
 from app.config import config
 from app.db import Document, DocumentStatus, Notification
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
 
 logger = logging.getLogger(__name__)
 
|
|
@ -69,16 +69,12 @@ def cleanup_stale_indexing_notifications_task():
|
|||
Detection: Redis heartbeat key with 2-min TTL. Missing key = stale task.
|
||||
Also marks associated pending/processing documents as failed.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
async def _both() -> None:
|
||||
await _cleanup_stale_notifications()
|
||||
await _cleanup_stale_document_processing_notifications()
|
||||
|
||||
try:
|
||||
loop.run_until_complete(_cleanup_stale_notifications())
|
||||
loop.run_until_complete(_cleanup_stale_document_processing_notifications())
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(_both)
|
||||
|
||||
|
||||
async def _cleanup_stale_notifications():
|
||||
|
|
|
|||
|
|
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-import asyncio
 import logging
 from datetime import UTC, datetime, timedelta
 
|
|
@@ -18,7 +17,7 @@ from app.db import (
     PremiumTokenPurchaseStatus,
 )
 from app.routes import stripe_routes
-from app.tasks.celery_tasks import get_celery_session_maker
+from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
 
 logger = logging.getLogger(__name__)
 
|
|
@ -36,13 +35,7 @@ def get_stripe_client() -> StripeClient | None:
|
|||
@celery_app.task(name="reconcile_pending_stripe_page_purchases")
|
||||
def reconcile_pending_stripe_page_purchases_task():
|
||||
"""Recover paid purchases that were left pending due to missed webhook handling."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(_reconcile_pending_page_purchases())
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(_reconcile_pending_page_purchases)
|
||||
|
||||
|
||||
async def _reconcile_pending_page_purchases() -> None:
|
||||
|
|
@ -141,13 +134,7 @@ async def _reconcile_pending_page_purchases() -> None:
|
|||
@celery_app.task(name="reconcile_pending_stripe_token_purchases")
|
||||
def reconcile_pending_stripe_token_purchases_task():
|
||||
"""Recover paid token purchases that were left pending due to missed webhook handling."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(_reconcile_pending_token_purchases())
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(_reconcile_pending_token_purchases)
|
||||
|
||||
|
||||
async def _reconcile_pending_token_purchases() -> None:
|
||||
|
|
|
|||
|
|
@ -3,14 +3,22 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.video_presentation.graph import graph as video_presentation_graph
|
||||
from app.agents.video_presentation.state import State as VideoPresentationState
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config as app_config
|
||||
from app.db import VideoPresentation, VideoPresentationStatus
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.services.billable_calls import (
|
||||
BillingSettlementError,
|
||||
QuotaInsufficientError,
|
||||
_resolve_agent_billing_for_search_space,
|
||||
billable_call,
|
||||
)
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -23,6 +31,13 @@ if sys.platform.startswith("win"):
|
|||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _celery_billable_session():
|
||||
"""Session factory used by billable_call inside the Celery worker loop."""
|
||||
async with get_celery_session_maker()() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@celery_app.task(name="generate_video_presentation", bind=True)
|
||||
def generate_video_presentation_task(
|
||||
self,
|
||||
|
|
@ -35,27 +50,30 @@ def generate_video_presentation_task(
|
|||
Celery task to generate video presentation from source content.
|
||||
Updates existing video presentation record created by the tool.
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
result = loop.run_until_complete(
|
||||
_generate_video_presentation(
|
||||
return run_async_celery_task(
|
||||
lambda: _generate_video_presentation(
|
||||
video_presentation_id,
|
||||
source_content,
|
||||
search_space_id,
|
||||
user_prompt,
|
||||
)
|
||||
)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating video presentation: {e!s}")
|
||||
loop.run_until_complete(_mark_video_presentation_failed(video_presentation_id))
|
||||
# Mark FAILED in a fresh loop — the previous loop is closed.
|
||||
# Swallow secondary failures; the row will simply stay in
|
||||
# GENERATING and be flushed by the periodic stale cleanup.
|
||||
try:
|
||||
run_async_celery_task(
|
||||
lambda: _mark_video_presentation_failed(video_presentation_id)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to mark video presentation %s as failed",
|
||||
video_presentation_id,
|
||||
)
|
||||
return {"status": "failed", "video_presentation_id": video_presentation_id}
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _mark_video_presentation_failed(video_presentation_id: int) -> None:
|
||||
|
|
@ -97,6 +115,32 @@ async def _generate_video_presentation(
|
|||
video_pres.status = VideoPresentationStatus.GENERATING
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
(
|
||||
owner_user_id,
|
||||
billing_tier,
|
||||
base_model,
|
||||
) = await _resolve_agent_billing_for_search_space(
|
||||
session,
|
||||
search_space_id,
|
||||
thread_id=video_pres.thread_id,
|
||||
)
|
||||
except ValueError as resolve_err:
|
||||
logger.error(
|
||||
"VideoPresentation %s: cannot resolve billing for "
|
||||
"search_space=%s: %s",
|
||||
video_pres.id,
|
||||
search_space_id,
|
||||
resolve_err,
|
||||
)
|
||||
video_pres.status = VideoPresentationStatus.FAILED
|
||||
await session.commit()
|
||||
return {
|
||||
"status": "failed",
|
||||
"video_presentation_id": video_pres.id,
|
||||
"reason": "billing_resolution_failed",
|
||||
}
|
||||
|
||||
graph_config = {
|
||||
"configurable": {
|
||||
"video_title": video_pres.title,
|
||||
|
|
@ -110,9 +154,52 @@ async def _generate_video_presentation(
|
|||
db_session=session,
|
||||
)
|
||||
|
||||
graph_result = await video_presentation_graph.ainvoke(
|
||||
initial_state, config=graph_config
|
||||
)
|
||||
try:
|
||||
async with billable_call(
|
||||
user_id=owner_user_id,
|
||||
search_space_id=search_space_id,
|
||||
billing_tier=billing_tier,
|
||||
base_model=base_model,
|
||||
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS,
|
||||
usage_type="video_presentation_generation",
|
||||
call_details={
|
||||
"video_presentation_id": video_pres.id,
|
||||
"title": video_pres.title,
|
||||
"thread_id": video_pres.thread_id,
|
||||
},
|
||||
billable_session_factory=_celery_billable_session,
|
||||
):
|
||||
graph_result = await video_presentation_graph.ainvoke(
|
||||
initial_state, config=graph_config
|
||||
)
|
||||
except QuotaInsufficientError as exc:
|
||||
logger.info(
|
||||
"VideoPresentation %s denied: out of premium credits "
|
||||
"(used=%d/%d remaining=%d)",
|
||||
video_pres.id,
|
||||
exc.used_micros,
|
||||
exc.limit_micros,
|
||||
exc.remaining_micros,
|
||||
)
|
||||
video_pres.status = VideoPresentationStatus.FAILED
|
||||
await session.commit()
|
||||
return {
|
||||
"status": "failed",
|
||||
"video_presentation_id": video_pres.id,
|
||||
"reason": "premium_quota_exhausted",
|
||||
}
|
||||
except BillingSettlementError:
|
||||
logger.exception(
|
||||
"VideoPresentation %s: premium billing settlement failed",
|
||||
video_pres.id,
|
||||
)
|
||||
video_pres.status = VideoPresentationStatus.FAILED
|
||||
await session.commit()
|
||||
return {
|
||||
"status": "failed",
|
||||
"video_presentation_id": video_pres.id,
|
||||
"reason": "billing_settlement_failed",
|
||||
}
|
||||
|
||||
# Serialize slides (parsed content + audio info merged)
|
||||
slides_raw = graph_result.get("slides", [])
|
||||
|
|
@ -143,7 +230,14 @@ async def _generate_video_presentation(
|
|||
video_pres.slides = serializable_slides
|
||||
video_pres.scene_codes = serializable_scene_codes
|
||||
video_pres.status = VideoPresentationStatus.READY
|
||||
logger.info(
|
||||
"VideoPresentation %s: committing READY slides=%d scene_codes=%d",
|
||||
video_pres.id,
|
||||
len(serializable_slides),
|
||||
len(serializable_scene_codes),
|
||||
)
|
||||
await session.commit()
|
||||
logger.info("VideoPresentation %s: READY commit complete", video_pres.id)
|
||||
|
||||
logger.info(f"Successfully generated video presentation: {video_pres.id}")
|
||||
|
||||
|
|

surfsense_backend/app/tasks/chat/content_builder.py (new file, 515 lines)

@@ -0,0 +1,515 @@
"""Server-side mirror of the frontend's assistant-ui ``ContentPart`` projection.

Background
----------
The streaming chat task in ``stream_new_chat`` / ``stream_resume_chat`` yields
SSE events that the frontend folds into a ``ContentPartsState`` (see
``surfsense_web/lib/chat/streaming-state.ts`` and the matching pipeline in
``stream-pipeline.ts``). When a turn ends, the frontend calls
``buildContentForPersistence(...)`` and round-trips that ``ContentPart[]``
JSONB to ``POST /threads/{id}/messages``, which is what was historically
written to ``new_chat_messages.content``.

After the ghost-thread fix moved persistence server-side, the assistant
row is written by ``finalize_assistant_turn`` in the streaming finally
block. The frontend's later ``appendMessage`` is now a no-op (recovers
via the ``(thread_id, turn_id, role)`` partial unique index added in
migration 141), which means the *server* is now responsible for
producing the rich ``ContentPart[]`` shape the FE expects on history
reload — text + reasoning + tool-call cards (with ``args``, ``argsText``,
``result``, ``langchainToolCallId``) + thinking-step buckets +
step-separators.

This module is the in-memory accumulator that mirrors the FE state for
exactly that purpose. The streaming code calls ``on_text_*`` / ``on_reasoning_*``
/ ``on_tool_*`` / ``on_thinking_step`` / ``on_step_separator`` /
``mark_interrupted`` at the same call sites it yields the matching
``streaming_service.format_*`` SSE event, so the in-memory ``parts`` list
stays in lockstep with what the FE's pipeline would have produced live.
``snapshot()`` is then taken once in the ``finally`` block and persisted
in a single UPDATE.

Pure synchronous state — no DB I/O, no async, no flush callbacks. The
streaming code is responsible for driving lifecycle methods; this class
is a thin projection helper.
"""

from __future__ import annotations

import copy
import json
import logging
from typing import Any

logger = logging.getLogger(__name__)


# Mirrors the FE's filter in ``buildContentForPersistence`` / ``buildContentForUI``:
# only text/reasoning/tool-call parts count as "meaningful". data-thinking-steps
# and data-step-separator decorate the meaningful parts but never stand alone
# in a successful turn.
_MEANINGFUL_PART_TYPES: frozenset[str] = frozenset({"text", "reasoning", "tool-call"})


class AssistantContentBuilder:
    """Server-side projection of ``surfsense_web/lib/chat/streaming-state.ts``.

    Output shape (deep copy of ``self.parts`` via ``snapshot()``) strictly
    matches the FE ``ContentPart`` union::

        | { type: "text"; text: string }
        | { type: "reasoning"; text: string }
        | { type: "tool-call"; toolCallId: str; toolName: str;
            args: dict; result?: any; argsText?: str; langchainToolCallId?: str;
            state?: "aborted" }
        | { type: "data-thinking-steps"; data: { steps: ThinkingStepData[] } }
        | { type: "data-step-separator"; data: { stepIndex: int } }

    Order matches the wire order of the SSE events that drive the lifecycle
    methods, with two FE-mirrored exceptions:

    1. ``data-thinking-steps`` is a *singleton* and pinned at index 0 the
       first time we see a ``data-thinking-step`` SSE event (the FE's
       ``updateThinkingSteps`` does ``unshift`` on first sight). Subsequent
       thinking-step updates mutate that singleton in place.
    2. ``data-step-separator`` is appended only when the message already has
       meaningful content and the previous part isn't itself a separator
       (so the FIRST step of a turn doesn't generate a leading divider).
    """

    def __init__(self) -> None:
        self.parts: list[dict[str, Any]] = []
        # Index of the active text/reasoning part within ``parts`` while
        # streaming is open; -1 means "no active part" and the next delta
        # opens a fresh one. Mirrors ``ContentPartsState.currentTextPartIndex``.
        self._current_text_idx: int = -1
        self._current_reasoning_idx: int = -1
        # ``ui_id``-keyed indexes for tool-call parts. ``ui_id`` is the
        # synthetic ``call_<run_id>`` (legacy) or the LangChain
        # ``tool_call.id`` (parity_v2) — same key the streaming layer
        # threads through every ``tool-input-*`` / ``tool-output-*`` event.
        self._tool_call_idx_by_ui_id: dict[str, int] = {}
        # Live argsText accumulator (concatenated ``tool-input-delta`` chunks)
        # so we can reproduce the FE's ``appendToolInputDelta`` behaviour
        # before ``tool-input-available`` overwrites it with the
        # pretty-printed final JSON.
        self._args_text_by_ui_id: dict[str, str] = {}

# ------------------------------------------------------------------
|
||||
# Text
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def on_text_start(self, text_id: str) -> None:
|
||||
"""Begin a fresh text block.
|
||||
|
||||
Symmetric to FE ``appendText``: opening text closes any active
|
||||
reasoning so the renderer treats them as separate parts. The
|
||||
actual text part isn't materialised here — it's lazily created
|
||||
on the first ``on_text_delta`` so an empty start/end pair
|
||||
leaves no trace. Matches the FE pipeline which has no explicit
|
||||
``text-start`` handler at all.
|
||||
"""
|
||||
if self._current_reasoning_idx >= 0:
|
||||
self._current_reasoning_idx = -1
|
||||
|
||||
def on_text_delta(self, text_id: str, delta: str) -> None:
|
||||
if not delta:
|
||||
return
|
||||
if self._current_reasoning_idx >= 0:
|
||||
# FE behaviour: a text delta after reasoning implicitly
|
||||
# closes the reasoning block (see ``appendText`` lines
|
||||
# 178-180).
|
||||
self._current_reasoning_idx = -1
|
||||
if (
|
||||
self._current_text_idx >= 0
|
||||
and 0 <= self._current_text_idx < len(self.parts)
|
||||
and self.parts[self._current_text_idx].get("type") == "text"
|
||||
):
|
||||
self.parts[self._current_text_idx]["text"] += delta
|
||||
return
|
||||
self.parts.append({"type": "text", "text": delta})
|
||||
self._current_text_idx = len(self.parts) - 1
|
||||
|
||||
def on_text_end(self, text_id: str) -> None:
|
||||
"""Close the active text block.
|
||||
|
||||
Mirrors the wire-level ``text-end`` boundary the streaming layer
|
||||
emits before tool calls / reasoning / step boundaries. The FE
|
||||
pipeline implicitly closes via ``currentTextPartIndex = -1``
|
||||
in ``addToolCall`` / ``appendReasoning`` / ``addStepSeparator``;
|
||||
our helper does the same explicitly so callers don't have to
|
||||
maintain that invariant per call site.
|
||||
"""
|
||||
self._current_text_idx = -1
|
||||
|
||||
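    # Illustration of the delta-merging behaviour above (hypothetical values,
    # not part of the file): consecutive deltas extend the same text part
    # until the block is closed, after which the next delta opens a new part.
    #
    #   b = AssistantContentBuilder()
    #   b.on_text_start("t1")
    #   b.on_text_delta("t1", "Hel")
    #   b.on_text_delta("t1", "lo")
    #   b.on_text_end("t1")
    #   assert b.parts == [{"type": "text", "text": "Hello"}]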
# ------------------------------------------------------------------
|
||||
# Reasoning
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def on_reasoning_start(self, reasoning_id: str) -> None:
|
||||
if self._current_text_idx >= 0:
|
||||
self._current_text_idx = -1
|
||||
|
||||
def on_reasoning_delta(self, reasoning_id: str, delta: str) -> None:
|
||||
if not delta:
|
||||
return
|
||||
if self._current_text_idx >= 0:
|
||||
self._current_text_idx = -1
|
||||
if (
|
||||
self._current_reasoning_idx >= 0
|
||||
and 0 <= self._current_reasoning_idx < len(self.parts)
|
||||
and self.parts[self._current_reasoning_idx].get("type") == "reasoning"
|
||||
):
|
||||
self.parts[self._current_reasoning_idx]["text"] += delta
|
||||
return
|
||||
self.parts.append({"type": "reasoning", "text": delta})
|
||||
self._current_reasoning_idx = len(self.parts) - 1
|
||||
|
||||
def on_reasoning_end(self, reasoning_id: str) -> None:
|
||||
self._current_reasoning_idx = -1
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool calls
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def on_tool_input_start(
|
||||
self,
|
||||
ui_id: str,
|
||||
tool_name: str,
|
||||
langchain_tool_call_id: str | None,
|
||||
) -> None:
|
||||
"""Register a tool-call card. Args are filled in by later events."""
|
||||
if not ui_id:
|
||||
return
|
||||
# Skip duplicate registration: parity_v2 may emit
|
||||
# ``tool-input-start`` from both ``on_chat_model_stream``
|
||||
# (when tool_call_chunks register a name) and ``on_tool_start``
|
||||
# (the canonical path). The FE de-dupes via ``toolCallIndices``;
|
||||
# we mirror that here.
|
||||
if ui_id in self._tool_call_idx_by_ui_id:
|
||||
if langchain_tool_call_id:
|
||||
idx = self._tool_call_idx_by_ui_id[ui_id]
|
||||
part = self.parts[idx]
|
||||
if not part.get("langchainToolCallId"):
|
||||
part["langchainToolCallId"] = langchain_tool_call_id
|
||||
return
|
||||
|
||||
part: dict[str, Any] = {
|
||||
"type": "tool-call",
|
||||
"toolCallId": ui_id,
|
||||
"toolName": tool_name,
|
||||
"args": {},
|
||||
}
|
||||
if langchain_tool_call_id:
|
||||
part["langchainToolCallId"] = langchain_tool_call_id
|
||||
self.parts.append(part)
|
||||
self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
|
||||
|
||||
self._current_text_idx = -1
|
||||
self._current_reasoning_idx = -1
|
||||
|
||||
def on_tool_input_delta(self, ui_id: str, args_chunk: str) -> None:
|
||||
"""Append a streamed args-delta chunk to the matching card's argsText.
|
||||
|
||||
Mirrors FE ``appendToolInputDelta``: no-ops when no card has been
|
||||
registered yet for the given ``ui_id`` — the deltas have nowhere
|
||||
safe to land.
|
||||
"""
|
||||
if not ui_id or not args_chunk:
|
||||
return
|
||||
idx = self._tool_call_idx_by_ui_id.get(ui_id)
|
||||
if idx is None:
|
||||
return
|
||||
if not (0 <= idx < len(self.parts)):
|
||||
return
|
||||
part = self.parts[idx]
|
||||
if part.get("type") != "tool-call":
|
||||
return
|
||||
new_text = (part.get("argsText") or "") + args_chunk
|
||||
part["argsText"] = new_text
|
||||
self._args_text_by_ui_id[ui_id] = new_text
|
||||
|
||||
def on_tool_input_available(
|
||||
self,
|
||||
ui_id: str,
|
||||
tool_name: str,
|
||||
args: dict[str, Any],
|
||||
langchain_tool_call_id: str | None,
|
||||
) -> None:
|
||||
"""Finalize the tool-call card's input.
|
||||
|
||||
Mirrors FE ``stream-pipeline.ts`` lines 127-153: replaces ``argsText``
|
||||
with ``json.dumps(input, indent=2)`` so the post-stream card renders
|
||||
pretty-printed JSON, sets the full ``args`` dict, and backfills
|
||||
``langchainToolCallId`` if it wasn't known at ``tool-input-start`` time.
|
||||
Also creates the card if no prior ``tool-input-start`` registered it
|
||||
(legacy parity_v2-OFF / late-registration paths).
|
||||
"""
|
||||
if not ui_id:
|
||||
return
|
||||
try:
|
||||
final_args_text = json.dumps(args or {}, indent=2, ensure_ascii=False)
|
||||
except (TypeError, ValueError):
|
||||
# Defensive: ``args`` should already be JSON-safe (the
|
||||
# streaming layer sanitizes it before emitting), but if a
|
||||
# caller hands us a non-serializable value we still want
|
||||
# to record the call without breaking the snapshot.
|
||||
final_args_text = str(args)
|
||||
|
||||
idx = self._tool_call_idx_by_ui_id.get(ui_id)
|
||||
if idx is not None and 0 <= idx < len(self.parts):
|
||||
part = self.parts[idx]
|
||||
if part.get("type") == "tool-call":
|
||||
part["args"] = args or {}
|
||||
part["argsText"] = final_args_text
|
||||
if langchain_tool_call_id and not part.get("langchainToolCallId"):
|
||||
part["langchainToolCallId"] = langchain_tool_call_id
|
||||
return
|
||||
|
||||
# No prior tool-input-start: register the card now.
|
||||
new_part: dict[str, Any] = {
|
||||
"type": "tool-call",
|
||||
"toolCallId": ui_id,
|
||||
"toolName": tool_name,
|
||||
"args": args or {},
|
||||
"argsText": final_args_text,
|
||||
}
|
||||
if langchain_tool_call_id:
|
||||
new_part["langchainToolCallId"] = langchain_tool_call_id
|
||||
self.parts.append(new_part)
|
||||
self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
|
||||
|
||||
self._current_text_idx = -1
|
||||
self._current_reasoning_idx = -1
|
||||
|
||||
def on_tool_output_available(
|
||||
self,
|
||||
ui_id: str,
|
||||
output: Any,
|
||||
langchain_tool_call_id: str | None,
|
||||
) -> None:
|
||||
"""Attach the tool's output (``result``) to the matching card.
|
||||
|
||||
Mirrors FE ``updateToolCall``: backfill ``langchainToolCallId``
|
||||
only if not already set (a NULL late-arriving value never blows
|
||||
away an earlier known good one).
|
||||
"""
|
||||
if not ui_id:
|
||||
return
|
||||
idx = self._tool_call_idx_by_ui_id.get(ui_id)
|
||||
if idx is None or not (0 <= idx < len(self.parts)):
|
||||
return
|
||||
part = self.parts[idx]
|
||||
if part.get("type") != "tool-call":
|
||||
return
|
||||
part["result"] = output
|
||||
if langchain_tool_call_id and not part.get("langchainToolCallId"):
|
||||
part["langchainToolCallId"] = langchain_tool_call_id
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Thinking steps & step separators
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def on_thinking_step(
|
||||
self,
|
||||
step_id: str,
|
||||
title: str,
|
||||
status: str,
|
||||
items: list[str] | None,
|
||||
) -> None:
|
||||
"""Update / insert the singleton ``data-thinking-steps`` part.
|
||||
|
||||
Mirrors FE ``updateThinkingSteps``: maintain a single
|
||||
``data-thinking-steps`` part anchored at index 0, replacing or
|
||||
unshifting on first sight. Each ``on_thinking_step`` call
|
||||
replaces the entry in the steps list keyed by ``step_id`` (or
|
||||
appends if new).
|
||||
"""
|
||||
if not step_id:
|
||||
return
|
||||
|
||||
new_step = {
|
||||
"id": step_id,
|
||||
"title": title or "",
|
||||
"status": status or "in_progress",
|
||||
"items": list(items) if items else [],
|
||||
}
|
||||
|
||||
# Find existing data-thinking-steps part.
|
||||
existing_idx = -1
|
||||
for i, p in enumerate(self.parts):
|
||||
if p.get("type") == "data-thinking-steps":
|
||||
existing_idx = i
|
||||
break
|
||||
|
||||
if existing_idx >= 0:
|
||||
current_steps = self.parts[existing_idx].get("data", {}).get("steps") or []
|
||||
replaced = False
|
||||
for i, step in enumerate(current_steps):
|
||||
if step.get("id") == step_id:
|
||||
current_steps[i] = new_step
|
||||
replaced = True
|
||||
break
|
||||
if not replaced:
|
||||
current_steps.append(new_step)
|
||||
self.parts[existing_idx] = {
|
||||
"type": "data-thinking-steps",
|
||||
"data": {"steps": current_steps},
|
||||
}
|
||||
return
|
||||
|
||||
# First sight: unshift to position 0 (FE parity).
|
||||
self.parts.insert(
|
||||
0,
|
||||
{
|
||||
"type": "data-thinking-steps",
|
||||
"data": {"steps": [new_step]},
|
||||
},
|
||||
)
|
||||
# Bump tracked indices since we inserted at the head.
|
||||
if self._current_text_idx >= 0:
|
||||
self._current_text_idx += 1
|
||||
if self._current_reasoning_idx >= 0:
|
||||
self._current_reasoning_idx += 1
|
||||
for ui_id, idx in list(self._tool_call_idx_by_ui_id.items()):
|
||||
self._tool_call_idx_by_ui_id[ui_id] = idx + 1
|
||||
|
||||
def on_step_separator(self) -> None:
|
||||
"""Append a ``data-step-separator`` between consecutive model steps.
|
||||
|
||||
Mirrors FE ``addStepSeparator``: only emit when the message
|
||||
already has meaningful content AND the previous part isn't
|
||||
itself a separator. ``stepIndex`` is the running count of
|
||||
separators already in ``parts``.
|
||||
"""
|
||||
has_content = any(p.get("type") in _MEANINGFUL_PART_TYPES for p in self.parts)
|
||||
if not has_content:
|
||||
return
|
||||
if self.parts and self.parts[-1].get("type") == "data-step-separator":
|
||||
return
|
||||
step_index = sum(
|
||||
1 for p in self.parts if p.get("type") == "data-step-separator"
|
||||
)
|
||||
self.parts.append(
|
||||
{
|
||||
"type": "data-step-separator",
|
||||
"data": {"stepIndex": step_index},
|
||||
}
|
||||
)
|
||||
self._current_text_idx = -1
|
||||
self._current_reasoning_idx = -1
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Interruption handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def mark_interrupted(self) -> None:
|
||||
"""Close any open text/reasoning and flip running tools to aborted.
|
||||
|
||||
Called from the streaming ``finally`` block before ``snapshot()`` so
|
||||
the persisted JSONB reflects a coherent end-state even when the
|
||||
client disconnected mid-turn or the agent hit a fatal error.
|
||||
|
||||
- Active text/reasoning blocks: simply lose their "active"
|
||||
marker (no synthetic content appended). Whatever was streamed
|
||||
stays as-is.
|
||||
- Tool-call parts that never received a ``result`` get
|
||||
``state="aborted"`` so the FE history loader can render them
|
||||
as "interrupted" rather than "still running".
|
||||
"""
|
||||
self._current_text_idx = -1
|
||||
self._current_reasoning_idx = -1
|
||||
for part in self.parts:
|
||||
if part.get("type") != "tool-call":
|
||||
continue
|
||||
if "result" in part:
|
||||
continue
|
||||
part["state"] = "aborted"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Snapshot & introspection
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def snapshot(self) -> list[dict[str, Any]]:
|
||||
"""Return a deep copy of ``parts`` ready for SQL UPDATE / json.dumps.
|
||||
|
||||
Deep-copied so callers that finalize from the shielded ``finally``
|
||||
block can't accidentally mutate the persisted payload while the
|
||||
SQL UPDATE is in flight (the streaming layer doesn't touch the
|
||||
builder after this call, but defensive copies are cheap and cheap
|
||||
is what we want in a finally block).
|
||||
"""
|
||||
return copy.deepcopy(self.parts)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""True if no meaningful content was captured.
|
||||
|
||||
``data-thinking-steps`` and ``data-step-separator`` decorate
|
||||
meaningful content but don't count on their own — a turn that
|
||||
only emitted a thinking step before being interrupted should
|
||||
still be treated as empty for the status-marker fallback.
|
||||
"""
|
||||
return not any(p.get("type") in _MEANINGFUL_PART_TYPES for p in self.parts)
|
||||
|
||||
def stats(self) -> dict[str, int]:
|
||||
"""Return counts of each part-type plus rough byte size.
|
||||
|
||||
Used by the streaming layer's perf logger so an ops dashboard
|
||||
can correlate finalize latency with payload size, and so a
|
||||
regression that quietly stops emitting tool-call parts (or
|
||||
starts emitting hundreds) shows up in [PERF] grep rather than
|
||||
only as a "history reload looks weird" bug report.
|
||||
|
||||
``bytes`` is the JSON-serialised payload length — what actually
|
||||
crosses the wire to PostgreSQL's JSONB column. We compute it
|
||||
with ``ensure_ascii=False`` to match the JSONB encoder's UTF-8
|
||||
on-disk layout closely enough for back-of-the-envelope sizing.
|
||||
Reasoning/text/tool-call/thinking-step/step-separator counts are
|
||||
independent so any one can spike without the others.
|
||||
|
||||
Defensive: ``json.dumps`` failure (a non-serializable value
|
||||
slipped past the streaming layer's sanitization) is reported as
|
||||
``bytes=-1`` rather than raised — perf logging must not be the
|
||||
thing that breaks the streaming finally block.
|
||||
"""
|
||||
text_blocks = 0
|
||||
reasoning_blocks = 0
|
||||
tool_calls = 0
|
||||
tool_calls_completed = 0
|
||||
tool_calls_aborted = 0
|
||||
thinking_step_parts = 0
|
||||
step_separators = 0
|
||||
|
||||
for part in self.parts:
|
||||
kind = part.get("type")
|
||||
if kind == "text":
|
||||
text_blocks += 1
|
||||
elif kind == "reasoning":
|
||||
reasoning_blocks += 1
|
||||
elif kind == "tool-call":
|
||||
tool_calls += 1
|
||||
if part.get("state") == "aborted":
|
||||
tool_calls_aborted += 1
|
||||
elif "result" in part:
|
||||
tool_calls_completed += 1
|
||||
elif kind == "data-thinking-steps":
|
||||
thinking_step_parts += 1
|
||||
elif kind == "data-step-separator":
|
||||
step_separators += 1
|
||||
|
||||
try:
|
||||
byte_size = len(json.dumps(self.parts, ensure_ascii=False, default=str))
|
||||
except (TypeError, ValueError):
|
||||
byte_size = -1
|
||||
|
||||
return {
|
||||
"parts": len(self.parts),
|
||||
"bytes": byte_size,
|
||||
"text": text_blocks,
|
||||
"reasoning": reasoning_blocks,
|
||||
"tool_calls": tool_calls,
|
||||
"tool_calls_completed": tool_calls_completed,
|
||||
"tool_calls_aborted": tool_calls_aborted,
|
||||
"thinking_step_parts": thinking_step_parts,
|
||||
"step_separators": step_separators,
|
||||
}
|
||||
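For orientation, a minimal sketch of how a streaming task's finally block might drain the builder above; the ``builder`` and ``_perf_log`` names are taken from this diff, but the exact call site is an assumption rather than code from this commit.

# Hypothetical call-site sketch (assumed, not part of this diff):
builder.mark_interrupted()        # flip still-running tool calls to "aborted"
parts = builder.snapshot()        # deep copy, safe to persist while the loop unwinds
if builder.is_empty():
    parts = []                    # let the persistence layer substitute the status marker
_perf_log.info("[content_builder] %s", builder.stats())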
surfsense_backend/app/tasks/chat/persistence.py (new file, 534 lines)

@@ -0,0 +1,534 @@
"""Server-side message persistence helpers for the streaming chat agent.
|
||||
|
||||
Historically the streaming task (``stream_new_chat``/``stream_resume_chat``)
|
||||
left ``new_chat_messages`` empty and relied on the frontend to round-trip
|
||||
``POST /threads/{id}/messages`` afterwards. That gave authenticated clients
|
||||
a "ghost-thread" abuse vector: skip the round-trip and burn LLM tokens
|
||||
without leaving an audit trail. These helpers move both writes (the user
|
||||
turn that triggered the stream and the assistant turn the stream produced)
|
||||
into the server itself, idempotent against the partial unique index
|
||||
``uq_new_chat_messages_thread_turn_role`` so legacy frontends that *do*
|
||||
keep posting via ``appendMessage`` simply hit the unique-index recovery
|
||||
path on the second writer instead of creating duplicates.
|
||||
|
||||
Assistant turn lifecycle
|
||||
------------------------
|
||||
The assistant side is split into two helpers so we can capture the row id
|
||||
*before* the stream produces any output:
|
||||
|
||||
* ``persist_assistant_shell`` runs immediately after ``persist_user_turn``
|
||||
and INSERTs an empty assistant row anchored to ``(thread_id, turn_id,
|
||||
ASSISTANT)``. Returns the row id so the streaming layer can correlate
|
||||
later writes (token_usage, AgentActionLog future-correlation) against
|
||||
a stable PK from the start of the turn.
|
||||
* ``finalize_assistant_turn`` runs from the streaming ``finally`` block.
|
||||
It UPDATEs the row's ``content`` to the rich ``ContentPart[]`` snapshot
|
||||
produced server-side by ``AssistantContentBuilder`` and writes the
|
||||
``token_usage`` row using ``INSERT ... ON CONFLICT DO NOTHING`` against
|
||||
the ``uq_token_usage_message_id`` partial unique index from migration
|
||||
142, hard-eliminating any race against ``append_message``'s recovery
|
||||
branch.
|
||||
|
||||
Defensive contract
|
||||
------------------
|
||||
|
||||
* Every helper runs inside ``shielded_async_session()`` so ``session.close()``
|
||||
survives starlette's mid-stream cancel scope on client disconnect.
|
||||
* ``persist_user_turn`` and ``persist_assistant_shell`` use ``INSERT ... ON
|
||||
CONFLICT DO NOTHING ... RETURNING id`` keyed on the ``(thread_id, turn_id,
|
||||
role)`` partial unique index. On conflict the insert silently no-ops at
|
||||
the DB level — no Python ``IntegrityError`` is constructed, which
|
||||
eliminates spurious debugger pauses and keeps logs clean. On conflict a
|
||||
follow-up ``SELECT`` resolves the existing row id so the streaming layer
|
||||
can correlate writes against a stable PK.
|
||||
* ``finalize_assistant_turn`` is best-effort: it never raises. The
|
||||
streaming ``finally`` block calls it from within
|
||||
``anyio.CancelScope(shield=True)`` and any raised exception there
|
||||
would mask the real error.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import text as sa_text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import (
|
||||
NewChatMessage,
|
||||
NewChatMessageRole,
|
||||
NewChatThread,
|
||||
TokenUsage,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.services.token_tracking_service import (
|
||||
TurnTokenAccumulator,
|
||||
)
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
# Empty initial assistant content. ``finalize_assistant_turn`` overwrites
|
||||
# this in a single UPDATE at end-of-stream with the full ``ContentPart[]``
|
||||
# snapshot produced by ``AssistantContentBuilder``. We persist a one-element
|
||||
# list with an empty text part so a crash between shell-INSERT and finalize
|
||||
# leaves the row in a FE-renderable shape (blank bubble) instead of
|
||||
# blowing up the history loader.
|
||||
_EMPTY_SHELL_CONTENT: list[dict[str, Any]] = [{"type": "text", "text": ""}]
|
||||
|
||||
# Substituted content for genuinely empty turns (no text, no reasoning,
|
||||
# no tool calls). The streaming layer flips to this when
|
||||
# ``AssistantContentBuilder.is_empty()`` returns True so the persisted
|
||||
# row is at least somewhat self-describing instead of an empty text
|
||||
# bubble. The FE's ``ContentPart`` union doesn't include ``status``
|
||||
# yet, so the history loader will silently drop this part and render
|
||||
# a blank bubble (matches today's behaviour for empty turns); a follow-up
|
||||
# FE PR adds the explicit "no response" rendering.
|
||||
_STATUS_NO_RESPONSE: list[dict[str, Any]] = [
|
||||
{"type": "status", "text": "(no text response)"}
|
||||
]
|
||||
|
||||
|
||||
def _build_user_content(
|
||||
user_query: str,
|
||||
user_image_data_urls: list[str] | None,
|
||||
mentioned_documents: list[dict[str, Any]] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build the persisted user-message ``content`` (assistant-ui v2 parts).
|
||||
|
||||
Mirrors the shape the existing frontend posts via
|
||||
``appendMessage`` (see ``surfsense_web/.../new-chat/[[...chat_id]]/page.tsx``):
|
||||
|
||||
[{"type": "text", "text": "..."},
|
||||
{"type": "image", "image": "data:..."},
|
||||
{"type": "mentioned-documents", "documents": [{"id": int,
|
||||
"title": str, "document_type": str}, ...]}]
|
||||
|
||||
The companion reader is
|
||||
``app.utils.user_message_multimodal.split_persisted_user_content_parts``
|
||||
which expects exactly this shape — keep them in sync.
|
||||
|
||||
``mentioned_documents``: optional list of ``{id, title, document_type}``
|
||||
dicts. When non-empty (and a ``mentioned-documents`` part is not already
|
||||
in some other input shape), a single ``{"type": "mentioned-documents",
|
||||
"documents": [...]}`` part is appended. Mirrors the FE injection at
|
||||
``page.tsx:281-286`` (``persistUserTurn``).
|
||||
"""
|
||||
parts: list[dict[str, Any]] = [{"type": "text", "text": user_query or ""}]
|
||||
for url in user_image_data_urls or ():
|
||||
if isinstance(url, str) and url:
|
||||
parts.append({"type": "image", "image": url})
|
||||
if mentioned_documents:
|
||||
normalized: list[dict[str, Any]] = []
|
||||
for doc in mentioned_documents:
|
||||
if not isinstance(doc, dict):
|
||||
continue
|
||||
doc_id = doc.get("id")
|
||||
title = doc.get("title")
|
||||
document_type = doc.get("document_type")
|
||||
if doc_id is None or title is None or document_type is None:
|
||||
continue
|
||||
normalized.append(
|
||||
{
|
||||
"id": doc_id,
|
||||
"title": str(title),
|
||||
"document_type": str(document_type),
|
||||
}
|
||||
)
|
||||
if normalized:
|
||||
parts.append({"type": "mentioned-documents", "documents": normalized})
|
||||
return parts
|
||||
|
||||
|
||||
async def persist_user_turn(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_id: str | None,
|
||||
turn_id: str,
|
||||
user_query: str,
|
||||
user_image_data_urls: list[str] | None = None,
|
||||
mentioned_documents: list[dict[str, Any]] | None = None,
|
||||
) -> int | None:
|
||||
"""Persist the user-side row for a chat turn and return its ``id``.
|
||||
|
||||
Uses ``INSERT ... ON CONFLICT DO NOTHING ... RETURNING id`` keyed on the
|
||||
``(thread_id, turn_id, role)`` partial unique index from migration 141
|
||||
(``WHERE turn_id IS NOT NULL``). On conflict the insert silently no-ops
|
||||
at the DB level — no Python ``IntegrityError`` is constructed, which
|
||||
eliminates the debugger pause that ``justMyCode=false`` + async greenlet
|
||||
interactions used to produce, and keeps production logs clean.
|
||||
|
||||
Returns the ``id`` of the row that exists for this turn after the call:
|
||||
the freshly inserted ``id`` on the happy path, or the existing ``id``
|
||||
when a previous writer (legacy FE ``appendMessage`` racing the SSE
|
||||
stream, redelivered request, etc.) already wrote it. Returns ``None``
|
||||
only on genuine DB failure; the caller should yield a streaming error
|
||||
and abort the turn so we never produce a title/assistant row that
|
||||
isn't anchored to a persisted user message.
|
||||
|
||||
Other constraint violations (FK, NOT NULL, etc.) still raise
|
||||
``IntegrityError`` — only the ``(thread_id, turn_id, role)`` collision
|
||||
is silenced.
|
||||
"""
|
||||
if not turn_id:
|
||||
# Defensive: turn_id is always populated by the streaming path
|
||||
# before this helper is called. If it isn't, we cannot be
|
||||
# idempotent against the unique index — refuse to write rather
|
||||
# than create a row the unique index can't dedupe.
|
||||
logger.error(
|
||||
"persist_user_turn called without a turn_id (chat_id=%s); skipping",
|
||||
chat_id,
|
||||
)
|
||||
return None
|
||||
|
||||
t0 = time.perf_counter()
|
||||
outcome = "failed"
|
||||
resolved_id: int | None = None
|
||||
try:
|
||||
async with shielded_async_session() as ws:
|
||||
# Re-attach the thread row so we can also bump updated_at
|
||||
# in the same write — keeps the sidebar ordering accurate
|
||||
# when a user fires off a turn but never reaches the
|
||||
# legacy appendMessage.
|
||||
thread = await ws.get(NewChatThread, chat_id)
|
||||
author_uuid: UUID | None = None
|
||||
if user_id:
|
||||
try:
|
||||
author_uuid = UUID(user_id)
|
||||
except (TypeError, ValueError):
|
||||
logger.warning(
|
||||
"persist_user_turn: invalid user_id=%r, persisting as anonymous",
|
||||
user_id,
|
||||
)
|
||||
|
||||
content_payload = _build_user_content(
|
||||
user_query, user_image_data_urls, mentioned_documents
|
||||
)
|
||||
insert_stmt = (
|
||||
pg_insert(NewChatMessage)
|
||||
.values(
|
||||
thread_id=chat_id,
|
||||
role=NewChatMessageRole.USER,
|
||||
content=content_payload,
|
||||
author_id=author_uuid,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=["thread_id", "turn_id", "role"],
|
||||
index_where=sa_text("turn_id IS NOT NULL"),
|
||||
)
|
||||
.returning(NewChatMessage.id)
|
||||
)
|
||||
inserted_id = (await ws.execute(insert_stmt)).scalar()
|
||||
|
||||
if inserted_id is None:
|
||||
# Conflict on partial unique index — another writer
|
||||
# (legacy FE appendMessage, redelivered request, etc.)
|
||||
# already persisted this row. Look it up and reuse.
|
||||
lookup = await ws.execute(
|
||||
select(NewChatMessage.id).where(
|
||||
NewChatMessage.thread_id == chat_id,
|
||||
NewChatMessage.turn_id == turn_id,
|
||||
NewChatMessage.role == NewChatMessageRole.USER,
|
||||
)
|
||||
)
|
||||
existing_id = lookup.scalars().first()
|
||||
if existing_id is None:
|
||||
# Conflict reported but no row found — extremely
|
||||
# unlikely (concurrent DELETE). Surface as failure.
|
||||
logger.warning(
|
||||
"persist_user_turn: conflict but no matching row "
|
||||
"(chat_id=%s, turn_id=%s)",
|
||||
chat_id,
|
||||
turn_id,
|
||||
)
|
||||
outcome = "integrity_no_match"
|
||||
return None
|
||||
resolved_id = int(existing_id)
|
||||
outcome = "race_recovered"
|
||||
else:
|
||||
resolved_id = int(inserted_id)
|
||||
outcome = "inserted"
|
||||
# Bump thread.updated_at only on a real insert — when
|
||||
# we recovered an existing row the prior writer
|
||||
# already touched the thread.
|
||||
if thread is not None:
|
||||
thread.updated_at = datetime.now(UTC)
|
||||
|
||||
await ws.commit()
|
||||
return resolved_id
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"persist_user_turn failed (chat_id=%s, turn_id=%s)",
|
||||
chat_id,
|
||||
turn_id,
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
_perf_log.info(
|
||||
"[persist_user_turn] outcome=%s chat_id=%s turn_id=%s "
|
||||
"message_id=%s query_len=%d images=%d mentioned_docs=%d "
|
||||
"in %.3fs",
|
||||
outcome,
|
||||
chat_id,
|
||||
turn_id,
|
||||
resolved_id,
|
||||
len(user_query or ""),
|
||||
len(user_image_data_urls or ()),
|
||||
len(mentioned_documents or ()),
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
|
||||
|
||||
async def persist_assistant_shell(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_id: str | None,
|
||||
turn_id: str,
|
||||
) -> int | None:
|
||||
"""Pre-write an empty assistant row for the turn and return its id.
|
||||
|
||||
Inserts a placeholder ``new_chat_messages`` row (empty text content) so
|
||||
the streaming layer has a stable ``message_id`` to correlate against
|
||||
for the rest of the turn. ``finalize_assistant_turn`` overwrites the
|
||||
``content`` field at end-of-stream with the rich ``ContentPart[]``
|
||||
snapshot produced by ``AssistantContentBuilder``.
|
||||
|
||||
Returns the row id on success, ``None`` on a genuine DB failure (caller
|
||||
should abort the turn rather than stream into a void).
|
||||
|
||||
Idempotent against the ``(thread_id, turn_id, ASSISTANT)`` partial unique
|
||||
index from migration 141: if a row already exists (resume retry, racing
|
||||
legacy frontend, redelivered request, etc.) we look it up by
|
||||
``(thread_id, turn_id, role)`` and return its existing id. The streaming
|
||||
layer is then free to UPDATE that row at finalize time.
|
||||
"""
|
||||
if not turn_id:
|
||||
logger.error(
|
||||
"persist_assistant_shell called without a turn_id (chat_id=%s); skipping",
|
||||
chat_id,
|
||||
)
|
||||
return None
|
||||
|
||||
t0 = time.perf_counter()
|
||||
outcome = "failed"
|
||||
resolved_id: int | None = None
|
||||
try:
|
||||
async with shielded_async_session() as ws:
|
||||
insert_stmt = (
|
||||
pg_insert(NewChatMessage)
|
||||
.values(
|
||||
thread_id=chat_id,
|
||||
role=NewChatMessageRole.ASSISTANT,
|
||||
content=_EMPTY_SHELL_CONTENT,
|
||||
author_id=None,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=["thread_id", "turn_id", "role"],
|
||||
index_where=sa_text("turn_id IS NOT NULL"),
|
||||
)
|
||||
.returning(NewChatMessage.id)
|
||||
)
|
||||
inserted_id = (await ws.execute(insert_stmt)).scalar()
|
||||
|
||||
if inserted_id is None:
|
||||
# Conflict — another writer (legacy FE appendMessage,
|
||||
# resume retry, redelivered request) wrote the
|
||||
# (thread_id, turn_id, ASSISTANT) row first. Look it up
|
||||
# so the streaming layer can UPDATE the same row at
|
||||
# finalize time.
|
||||
lookup = await ws.execute(
|
||||
select(NewChatMessage.id).where(
|
||||
NewChatMessage.thread_id == chat_id,
|
||||
NewChatMessage.turn_id == turn_id,
|
||||
NewChatMessage.role == NewChatMessageRole.ASSISTANT,
|
||||
)
|
||||
)
|
||||
existing_id = lookup.scalars().first()
|
||||
if existing_id is None:
|
||||
logger.warning(
|
||||
"persist_assistant_shell: conflict but no matching "
|
||||
"(thread_id, turn_id, role) row found "
|
||||
"(chat_id=%s, turn_id=%s)",
|
||||
chat_id,
|
||||
turn_id,
|
||||
)
|
||||
outcome = "integrity_no_match"
|
||||
return None
|
||||
resolved_id = int(existing_id)
|
||||
outcome = "race_recovered"
|
||||
else:
|
||||
resolved_id = int(inserted_id)
|
||||
outcome = "inserted"
|
||||
|
||||
await ws.commit()
|
||||
return resolved_id
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"persist_assistant_shell failed (chat_id=%s, turn_id=%s)",
|
||||
chat_id,
|
||||
turn_id,
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
_perf_log.info(
|
||||
"[persist_assistant_shell] outcome=%s chat_id=%s turn_id=%s "
|
||||
"message_id=%s in %.3fs",
|
||||
outcome,
|
||||
chat_id,
|
||||
turn_id,
|
||||
resolved_id,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
|
||||
|
||||
async def finalize_assistant_turn(
|
||||
*,
|
||||
message_id: int,
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
turn_id: str,
|
||||
content: list[dict[str, Any]],
|
||||
accumulator: TurnTokenAccumulator | None,
|
||||
) -> None:
|
||||
"""Finalize the assistant row and write its token_usage.
|
||||
|
||||
Two writes in a single shielded session:
|
||||
|
||||
1. ``UPDATE new_chat_messages SET content = :c, updated_at = now()
|
||||
WHERE id = :id`` — overwrites the placeholder ``persist_assistant_shell``
|
||||
wrote with the full ``ContentPart[]`` snapshot produced server-side.
|
||||
2. ``INSERT INTO token_usage (...) VALUES (...) ON CONFLICT (message_id)
|
||||
WHERE message_id IS NOT NULL DO NOTHING`` — uses the partial unique
|
||||
index ``uq_token_usage_message_id`` from migration 142 to make the
|
||||
insert idempotent against ``append_message``'s recovery branch
|
||||
(which uses the same ON CONFLICT clause).
|
||||
|
||||
Substitutes the status-marker payload when ``content`` is empty
|
||||
(pure tool-call turn that aborted before any output, or interrupt
|
||||
before any event arrived). The status marker is preferable to a
|
||||
blank text bubble because token accounting still runs and an ops
|
||||
dashboard can flag the row.
|
||||
|
||||
Best-effort — never raises. The streaming ``finally`` calls this
|
||||
from within ``anyio.CancelScope(shield=True)``; any raised exception
|
||||
here would mask the real error that triggered the cleanup.
|
||||
"""
|
||||
if not turn_id:
|
||||
logger.error(
|
||||
"finalize_assistant_turn called without turn_id "
|
||||
"(chat_id=%s, message_id=%s); skipping",
|
||||
chat_id,
|
||||
message_id,
|
||||
)
|
||||
return
|
||||
if not message_id:
|
||||
logger.error(
|
||||
"finalize_assistant_turn called without message_id "
|
||||
"(chat_id=%s, turn_id=%s); skipping",
|
||||
chat_id,
|
||||
turn_id,
|
||||
)
|
||||
return
|
||||
|
||||
payload: list[dict[str, Any]]
|
||||
is_status_marker = False
|
||||
if content:
|
||||
payload = content
|
||||
else:
|
||||
payload = _STATUS_NO_RESPONSE
|
||||
is_status_marker = True
|
||||
|
||||
t0 = time.perf_counter()
|
||||
outcome = "failed"
|
||||
token_usage_attempted = bool(
|
||||
accumulator is not None and accumulator.calls and user_id
|
||||
)
|
||||
try:
|
||||
async with shielded_async_session() as ws:
|
||||
assistant_row = await ws.get(NewChatMessage, message_id)
|
||||
if assistant_row is None:
|
||||
logger.warning(
|
||||
"finalize_assistant_turn: row not found "
|
||||
"(chat_id=%s, message_id=%s, turn_id=%s); skipping",
|
||||
chat_id,
|
||||
message_id,
|
||||
turn_id,
|
||||
)
|
||||
outcome = "row_missing"
|
||||
return
|
||||
|
||||
assistant_row.content = payload
|
||||
assistant_row.updated_at = datetime.now(UTC)
|
||||
|
||||
# Token usage. ``record_token_usage`` (used elsewhere) does
|
||||
# SELECT-then-INSERT in two statements which races with
|
||||
# ``append_message``. Switch to a single INSERT ... ON
|
||||
# CONFLICT DO NOTHING keyed on the migration-142 partial
|
||||
# unique index so the loser silently drops its write at
|
||||
# the DB level — exactly one row per ``message_id``,
|
||||
# regardless of which session committed first.
|
||||
if accumulator is not None and accumulator.calls and user_id:
|
||||
try:
|
||||
user_uuid = UUID(user_id)
|
||||
except (TypeError, ValueError):
|
||||
logger.warning(
|
||||
"finalize_assistant_turn: invalid user_id=%r, "
|
||||
"skipping token_usage row",
|
||||
user_id,
|
||||
)
|
||||
else:
|
||||
insert_stmt = (
|
||||
pg_insert(TokenUsage)
|
||||
.values(
|
||||
usage_type="chat",
|
||||
prompt_tokens=accumulator.total_prompt_tokens,
|
||||
completion_tokens=accumulator.total_completion_tokens,
|
||||
total_tokens=accumulator.grand_total,
|
||||
cost_micros=accumulator.total_cost_micros,
|
||||
model_breakdown=accumulator.per_message_summary(),
|
||||
call_details={"calls": accumulator.serialized_calls()},
|
||||
thread_id=chat_id,
|
||||
message_id=message_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_uuid,
|
||||
)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=["message_id"],
|
||||
index_where=sa_text("message_id IS NOT NULL"),
|
||||
)
|
||||
)
|
||||
await ws.execute(insert_stmt)
|
||||
|
||||
await ws.commit()
|
||||
outcome = "ok"
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"finalize_assistant_turn failed (chat_id=%s, message_id=%s, turn_id=%s)",
|
||||
chat_id,
|
||||
message_id,
|
||||
turn_id,
|
||||
)
|
||||
finally:
|
||||
_perf_log.info(
|
||||
"[finalize_assistant_turn] outcome=%s chat_id=%s message_id=%s "
|
||||
"turn_id=%s parts=%d status_marker=%s "
|
||||
"token_usage_attempted=%s in %.3fs",
|
||||
outcome,
|
||||
chat_id,
|
||||
message_id,
|
||||
turn_id,
|
||||
len(payload),
|
||||
is_status_marker,
|
||||
token_usage_attempted,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
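A condensed sketch of the intended call order around a streamed turn may help tie the three helpers together. The wrapper function, its parameters, and the ``builder``/``accumulator`` objects are assumptions for illustration; only the helper names and keyword signatures come from the file above.

import anyio

async def _run_turn_sketch(chat_id, search_space_id, user_id, turn_id, query, builder, accumulator):
    # Hypothetical orchestration (assumed caller, not part of this diff).
    if await persist_user_turn(
        chat_id=chat_id, user_id=user_id, turn_id=turn_id, user_query=query
    ) is None:
        return  # abort: never run a turn without a persisted user row
    message_id = await persist_assistant_shell(
        chat_id=chat_id, user_id=user_id, turn_id=turn_id
    )
    if message_id is None:
        return  # abort rather than stream into a void
    try:
        ...  # stream the agent turn, feeding the AssistantContentBuilder
    finally:
        with anyio.CancelScope(shield=True):
            await finalize_assistant_turn(
                message_id=message_id,
                chat_id=chat_id,
                search_space_id=search_space_id,
                user_id=user_id,
                turn_id=turn_id,
                content=[] if builder.is_empty() else builder.snapshot(),
                accumulator=accumulator,
            )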
File diff suppressed because it is too large
@@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import (
    IndexingPipelineService,
    PlaceholderInfo,
)
from app.services.composio_service import ComposioService
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.google_credentials import (
    COMPOSIO_GOOGLE_CONNECTOR_TYPES,
    build_composio_credentials,
)
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES

from .base import (
    check_duplicate_document_by_hash,

@@ -44,6 +42,10 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
HEARTBEAT_INTERVAL_SECONDS = 30


def _format_calendar_event_to_markdown(event: dict) -> str:
    return GoogleCalendarConnector.format_event_to_markdown(None, event)


def _build_connector_doc(
    event: dict,
    event_markdown: str,

@@ -150,7 +152,14 @@ async def index_google_calendar_events(
        )
        return 0, 0, f"Connector with ID {connector_id} not found"

    # ── Credential building ───────────────────────────────────────
    is_composio_connector = (
        connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
    )
    calendar_client = None
    composio_service = None
    connected_account_id = None

    # ── Credential/client building ────────────────────────────────
    if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
        connected_account_id = connector.config.get("composio_connected_account_id")
        if not connected_account_id:

@@ -161,7 +170,7 @@ async def index_google_calendar_events(
                {"error_type": "MissingComposioAccount"},
            )
            return 0, 0, "Composio connected_account_id not found"
        credentials = build_composio_credentials(connected_account_id)
        composio_service = ComposioService()
    else:
        config_data = connector.config

@@ -229,12 +238,13 @@ async def index_google_calendar_events(
        {"stage": "client_initialization"},
    )

    calendar_client = GoogleCalendarConnector(
        credentials=credentials,
        session=session,
        user_id=user_id,
        connector_id=connector_id,
    )
    if not is_composio_connector:
        calendar_client = GoogleCalendarConnector(
            credentials=credentials,
            session=session,
            user_id=user_id,
            connector_id=connector_id,
        )

    # Handle 'undefined' string from frontend (treat as None)
    if start_date == "undefined" or start_date == "":

@@ -300,9 +310,26 @@ async def index_google_calendar_events(
    )

    try:
        events, error = await calendar_client.get_all_primary_calendar_events(
            start_date=start_date_str, end_date=end_date_str
        )
        if is_composio_connector:
            start_dt = parse_date_flexible(start_date_str).replace(
                hour=0, minute=0, second=0, microsecond=0
            )
            end_dt = parse_date_flexible(end_date_str).replace(
                hour=23, minute=59, second=59, microsecond=0
            )
            events, error = await composio_service.get_calendar_events(
                connected_account_id=connected_account_id,
                entity_id=f"surfsense_{user_id}",
                time_min=start_dt.isoformat(),
                time_max=end_dt.isoformat(),
                max_results=250,
            )
            if not events and not error:
                error = "No events found in the specified date range."
        else:
            events, error = await calendar_client.get_all_primary_calendar_events(
                start_date=start_date_str, end_date=end_date_str
            )

        if error:
            if "No events found" in error:

@@ -381,7 +408,7 @@ async def index_google_calendar_events(
                documents_skipped += 1
                continue

            event_markdown = calendar_client.format_event_to_markdown(event)
            event_markdown = _format_calendar_event_to_markdown(event)
            if not event_markdown.strip():
                logger.warning(f"Skipping event with no content: {event_summary}")
                documents_skipped += 1
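One note on the ``_format_calendar_event_to_markdown`` wrapper above: calling ``GoogleCalendarConnector.format_event_to_markdown(None, event)`` through the class only works while the formatter never touches ``self``. A tiny self-contained illustration of that pattern (the ``_Formatter`` class is invented for the example):

class _Formatter:
    def format_event_to_markdown(self, event: dict) -> str:
        # Never reads self, so any value (even None) can stand in for it.
        return f"# {event.get('summary', 'Untitled')}"

print(_Formatter.format_event_to_markdown(None, {"summary": "Standup"}))  # -> "# Standup"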
@@ -9,6 +9,8 @@ import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import Any

from sqlalchemy import String, cast, select
from sqlalchemy.exc import SQLAlchemyError

@@ -37,6 +39,7 @@ from app.indexing_pipeline.indexing_pipeline_service import (
    IndexingPipelineService,
    PlaceholderInfo,
)
from app.services.composio_service import ComposioService
from app.services.llm_service import get_user_long_context_llm
from app.services.page_limit_service import PageLimitService
from app.services.task_logging_service import TaskLoggingService

@@ -45,10 +48,7 @@ from app.tasks.connector_indexers.base import (
    get_connector_by_id,
    update_connector_last_indexed,
)
from app.utils.google_credentials import (
    COMPOSIO_GOOGLE_CONNECTOR_TYPES,
    build_composio_credentials,
)
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES

ACCEPTED_DRIVE_CONNECTOR_TYPES = {
    SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,

@@ -61,6 +61,209 @@ HEARTBEAT_INTERVAL_SECONDS = 30

logger = logging.getLogger(__name__)


class ComposioDriveClient:
    """Google Drive client facade backed by Composio tool execution.

    Composio-managed OAuth connections can execute tools without exposing raw
    OAuth tokens through connected account state.
    """

    def __init__(
        self,
        session: AsyncSession,
        connector_id: int,
        connected_account_id: str,
        entity_id: str,
    ):
        self.session = session
        self.connector_id = connector_id
        self.connected_account_id = connected_account_id
        self.entity_id = entity_id
        self.composio = ComposioService()

    async def list_files(
        self,
        query: str = "",
        fields: str = "nextPageToken, files(id, name, mimeType, modifiedTime, md5Checksum, size, webViewLink, parents, owners, createdTime, description)",
        page_size: int = 100,
        page_token: str | None = None,
    ) -> tuple[list[dict[str, Any]], str | None, str | None]:
        params: dict[str, Any] = {
            "page_size": min(page_size, 100),
            "fields": fields,
        }
        if query:
            params["q"] = query
        if page_token:
            params["page_token"] = page_token

        result = await self.composio.execute_tool(
            connected_account_id=self.connected_account_id,
            tool_name="GOOGLEDRIVE_LIST_FILES",
            params=params,
            entity_id=self.entity_id,
        )
        if not result.get("success"):
            return [], None, result.get("error", "Unknown error")

        data = result.get("data", {})
        files = []
        next_token = None
        if isinstance(data, dict):
            inner_data = data.get("data", data)
            if isinstance(inner_data, dict):
                files = inner_data.get("files", [])
                next_token = inner_data.get("nextPageToken") or inner_data.get(
                    "next_page_token"
                )
        elif isinstance(data, list):
            files = data

        return files, next_token, None

    async def get_file_metadata(
        self, file_id: str, fields: str = "*"
    ) -> tuple[dict[str, Any] | None, str | None]:
        result = await self.composio.execute_tool(
            connected_account_id=self.connected_account_id,
            tool_name="GOOGLEDRIVE_GET_FILE_METADATA",
            params={"file_id": file_id, "fields": fields},
            entity_id=self.entity_id,
        )
        if not result.get("success"):
            return None, result.get("error", "Unknown error")

        data = result.get("data", {})
        if isinstance(data, dict):
            inner_data = data.get("data", data)
            if isinstance(inner_data, dict):
                return inner_data, None

        return None, "Could not extract metadata from Composio response"

    async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]:
        return await self._download_file_content(file_id)

    async def download_file_to_disk(
        self,
        file_id: str,
        dest_path: str,
        chunksize: int = 5 * 1024 * 1024,
    ) -> str | None:
        del chunksize
        content, error = await self.download_file(file_id)
        if error:
            return error
        if content is None:
            return "No content returned from Composio"
        Path(dest_path).write_bytes(content)
        return None

    async def export_google_file(
        self, file_id: str, mime_type: str
    ) -> tuple[bytes | None, str | None]:
        return await self._download_file_content(file_id, mime_type=mime_type)

    async def _download_file_content(
        self, file_id: str, mime_type: str | None = None
    ) -> tuple[bytes | None, str | None]:
        params: dict[str, Any] = {"file_id": file_id}
        if mime_type:
            params["mime_type"] = mime_type

        result = await self.composio.execute_tool(
            connected_account_id=self.connected_account_id,
            tool_name="GOOGLEDRIVE_DOWNLOAD_FILE",
            params=params,
            entity_id=self.entity_id,
        )
        if not result.get("success"):
            return None, result.get("error", "Unknown error")

        return self._read_download_result(result.get("data"))

    def _read_download_result(self, data: Any) -> tuple[bytes | None, str | None]:
        if isinstance(data, bytes):
            return data, None

        file_path: str | None = None
        if isinstance(data, str):
            file_path = data
        elif isinstance(data, dict):
            inner_data = data.get("data", data)
            if isinstance(inner_data, dict):
                for key in ("file_path", "downloaded_file_content", "path", "uri"):
                    value = inner_data.get(key)
                    if isinstance(value, str):
                        file_path = value
                        break
                    if isinstance(value, dict):
                        nested = (
                            value.get("file_path")
                            or value.get("downloaded_file_content")
                            or value.get("path")
                            or value.get("uri")
                            or value.get("s3url")
                        )
                        if isinstance(nested, str):
                            file_path = nested
                            break

        if not file_path:
            return None, "No file path/content returned from Composio"

        if file_path.startswith(("http://", "https://")):
            try:
                import urllib.request

                with urllib.request.urlopen(file_path, timeout=60) as response:
                    return response.read(), None
            except Exception as e:
                return None, f"Failed to download Composio file URL: {e!s}"

        path_obj = Path(file_path)
        if path_obj.is_absolute() or ".composio" in str(path_obj):
            if not path_obj.exists():
                return None, f"File not found at path: {file_path}"
            return path_obj.read_bytes(), None

        try:
            import base64

            return base64.b64decode(file_path), None
        except Exception:
            return file_path.encode("utf-8"), None


def _build_drive_client_for_connector(
    session: AsyncSession,
    connector_id: int,
    connector: object,
    user_id: str,
) -> tuple[GoogleDriveClient | ComposioDriveClient | None, str | None]:
    if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
        connected_account_id = connector.config.get("composio_connected_account_id")
        if not connected_account_id:
            return None, (
                f"Composio connected_account_id not found for connector {connector_id}"
            )
        return (
            ComposioDriveClient(
                session,
                connector_id,
                connected_account_id,
                entity_id=f"surfsense_{user_id}",
            ),
            None,
        )

    token_encrypted = connector.config.get("_token_encrypted", False)
    if token_encrypted and not config.SECRET_KEY:
        return None, "SECRET_KEY not configured but credentials are marked as encrypted"

    return GoogleDriveClient(session, connector_id), None


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

@@ -927,34 +1130,17 @@ async def index_google_drive_files(
        {"stage": "client_initialization"},
    )

    pre_built_credentials = None
    if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
        connected_account_id = connector.config.get("composio_connected_account_id")
        if not connected_account_id:
            error_msg = f"Composio connected_account_id not found for connector {connector_id}"
            await task_logger.log_task_failure(
                log_entry,
                error_msg,
                "Missing Composio account",
                {"error_type": "MissingComposioAccount"},
            )
            return 0, 0, error_msg, 0
        pre_built_credentials = build_composio_credentials(connected_account_id)
    else:
        token_encrypted = connector.config.get("_token_encrypted", False)
        if token_encrypted and not config.SECRET_KEY:
            await task_logger.log_task_failure(
                log_entry,
                "SECRET_KEY not configured but credentials are encrypted",
                "Missing SECRET_KEY",
                {"error_type": "MissingSecretKey"},
            )
            return (
                0,
                0,
                "SECRET_KEY not configured but credentials are marked as encrypted",
                0,
            )
    drive_client, client_error = _build_drive_client_for_connector(
        session, connector_id, connector, user_id
    )
    if client_error or not drive_client:
        await task_logger.log_task_failure(
            log_entry,
            client_error or "Failed to initialize Google Drive client",
            "Missing connector credentials",
            {"error_type": "ClientInitializationError"},
        )
        return 0, 0, client_error, 0

    connector_enable_summary = getattr(connector, "enable_summary", True)
    connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)

@@ -963,10 +1149,6 @@ async def index_google_drive_files(
        from app.services.llm_service import get_vision_llm

        vision_llm = await get_vision_llm(session, search_space_id)
    drive_client = GoogleDriveClient(
        session, connector_id, credentials=pre_built_credentials
    )

    if not folder_id:
        error_msg = "folder_id is required for Google Drive indexing"
        await task_logger.log_task_failure(

@@ -979,8 +1161,14 @@ async def index_google_drive_files(

    folder_tokens = connector.config.get("folder_tokens", {})
    start_page_token = folder_tokens.get(target_folder_id)
    is_composio_connector = (
        connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
    )
    can_use_delta = (
        use_delta_sync and start_page_token and connector.last_indexed_at
        not is_composio_connector
        and use_delta_sync
        and start_page_token
        and connector.last_indexed_at
    )

    documents_unsupported = 0

@@ -1051,7 +1239,16 @@ async def index_google_drive_files(
    )

    if documents_indexed > 0 or can_use_delta:
        new_token, token_error = await get_start_page_token(drive_client)
        if isinstance(drive_client, ComposioDriveClient):
            (
                new_token,
                token_error,
            ) = await drive_client.composio.get_drive_start_page_token(
                drive_client.connected_account_id,
                drive_client.entity_id,
            )
        else:
            new_token, token_error = await get_start_page_token(drive_client)
        if new_token and not token_error:
            await session.refresh(connector)
            if "folder_tokens" not in connector.config:

@@ -1137,32 +1334,17 @@ async def index_google_drive_single_file(
        )
        return 0, error_msg

    pre_built_credentials = None
    if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
        connected_account_id = connector.config.get("composio_connected_account_id")
        if not connected_account_id:
            error_msg = f"Composio connected_account_id not found for connector {connector_id}"
            await task_logger.log_task_failure(
                log_entry,
                error_msg,
                "Missing Composio account",
                {"error_type": "MissingComposioAccount"},
            )
            return 0, error_msg
        pre_built_credentials = build_composio_credentials(connected_account_id)
    else:
        token_encrypted = connector.config.get("_token_encrypted", False)
        if token_encrypted and not config.SECRET_KEY:
            await task_logger.log_task_failure(
                log_entry,
                "SECRET_KEY not configured but credentials are encrypted",
                "Missing SECRET_KEY",
                {"error_type": "MissingSecretKey"},
            )
            return (
                0,
                "SECRET_KEY not configured but credentials are marked as encrypted",
            )
    drive_client, client_error = _build_drive_client_for_connector(
        session, connector_id, connector, user_id
    )
    if client_error or not drive_client:
        await task_logger.log_task_failure(
            log_entry,
            client_error or "Failed to initialize Google Drive client",
            "Missing connector credentials",
            {"error_type": "ClientInitializationError"},
        )
        return 0, client_error

    connector_enable_summary = getattr(connector, "enable_summary", True)
    connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)

@@ -1171,10 +1353,6 @@ async def index_google_drive_single_file(
        from app.services.llm_service import get_vision_llm

        vision_llm = await get_vision_llm(session, search_space_id)
    drive_client = GoogleDriveClient(
        session, connector_id, credentials=pre_built_credentials
    )

    file, error = await get_file_by_id(drive_client, file_id)
    if error or not file:
        error_msg = f"Failed to fetch file {file_id}: {error or 'File not found'}"

@@ -1276,32 +1454,18 @@ async def index_google_drive_selected_files(
        )
        return 0, 0, [error_msg]

    pre_built_credentials = None
    if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
        connected_account_id = connector.config.get("composio_connected_account_id")
        if not connected_account_id:
            error_msg = f"Composio connected_account_id not found for connector {connector_id}"
            await task_logger.log_task_failure(
                log_entry,
                error_msg,
                "Missing Composio account",
                {"error_type": "MissingComposioAccount"},
            )
            return 0, 0, [error_msg]
        pre_built_credentials = build_composio_credentials(connected_account_id)
    else:
        token_encrypted = connector.config.get("_token_encrypted", False)
        if token_encrypted and not config.SECRET_KEY:
            error_msg = (
                "SECRET_KEY not configured but credentials are marked as encrypted"
            )
            await task_logger.log_task_failure(
                log_entry,
                error_msg,
                "Missing SECRET_KEY",
                {"error_type": "MissingSecretKey"},
            )
            return 0, 0, [error_msg]
    drive_client, client_error = _build_drive_client_for_connector(
        session, connector_id, connector, user_id
    )
    if client_error or not drive_client:
        error_msg = client_error or "Failed to initialize Google Drive client"
        await task_logger.log_task_failure(
            log_entry,
            error_msg,
            "Missing connector credentials",
            {"error_type": "ClientInitializationError"},
        )
        return 0, 0, [error_msg]

    connector_enable_summary = getattr(connector, "enable_summary", True)
    connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)

@@ -1310,10 +1474,6 @@ async def index_google_drive_selected_files(
        from app.services.llm_service import get_vision_llm

        vision_llm = await get_vision_llm(session, search_space_id)
    drive_client = GoogleDriveClient(
        session, connector_id, credentials=pre_built_credentials
    )

    indexed, skipped, unsupported, errors = await _index_selected_files(
        drive_client,
        session,
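As a usage note for the facade above, a small hypothetical pagination helper (not part of this diff) shows how the ``(files, next_token, error)`` triple returned by ``ComposioDriveClient.list_files`` is meant to be drained:

async def _list_all_drive_files(client: ComposioDriveClient) -> list[dict]:
    collected: list[dict] = []
    token: str | None = None
    while True:
        files, token, error = await client.list_files(page_token=token)
        if error:
            raise RuntimeError(f"Drive listing failed: {error}")
        collected.extend(files)
        if not token:  # Composio returned no nextPageToken: last page reached
            return collected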
@@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import (
    IndexingPipelineService,
    PlaceholderInfo,
)
from app.services.composio_service import ComposioService
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.google_credentials import (
    COMPOSIO_GOOGLE_CONNECTOR_TYPES,
    build_composio_credentials,
)
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES

from .base import (
    calculate_date_range,

@@ -44,6 +42,62 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
HEARTBEAT_INTERVAL_SECONDS = 30


def _normalize_composio_gmail_message(message: dict) -> dict:
    if message.get("payload"):
        return message

    headers = []
    header_values = {
        "Subject": message.get("subject"),
        "From": message.get("from") or message.get("sender"),
        "To": message.get("to") or message.get("recipient"),
        "Date": message.get("date"),
    }
    for name, value in header_values.items():
        if value:
            headers.append({"name": name, "value": value})

    return {
        **message,
        "id": message.get("id")
        or message.get("message_id")
        or message.get("messageId"),
        "threadId": message.get("threadId") or message.get("thread_id"),
        "payload": {"headers": headers},
        "snippet": message.get("snippet", ""),
        "messageText": message.get("messageText") or message.get("body") or "",
    }


def _format_gmail_message_to_markdown(message: dict) -> str:
    headers = {
        header.get("name", "").lower(): header.get("value", "")
        for header in message.get("payload", {}).get("headers", [])
        if isinstance(header, dict)
    }
    subject = headers.get("subject", "No Subject")
    from_email = headers.get("from", "Unknown Sender")
    to_email = headers.get("to", "Unknown Recipient")
    date_str = headers.get("date", "Unknown Date")
    message_text = (
        message.get("messageText")
        or message.get("body")
        or message.get("text")
        or message.get("snippet", "")
    )

    return (
        f"# {subject}\n\n"
        f"**From:** {from_email}\n"
        f"**To:** {to_email}\n"
        f"**Date:** {date_str}\n\n"
        f"## Message Content\n\n{message_text}\n\n"
        f"## Message Details\n\n"
        f"- **Message ID:** {message.get('id', 'Unknown')}\n"
        f"- **Thread ID:** {message.get('threadId', 'Unknown')}\n"
    )


def _build_connector_doc(
    message: dict,
    markdown_content: str,

@@ -162,7 +216,14 @@ async def index_google_gmail_messages(
        )
        return 0, 0, error_msg

    # ── Credential building ───────────────────────────────────────
    is_composio_connector = (
        connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
    )
    gmail_connector = None
    composio_service = None
    connected_account_id = None

    # ── Credential/client building ────────────────────────────────
    if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
        connected_account_id = connector.config.get("composio_connected_account_id")
        if not connected_account_id:

@@ -173,7 +234,7 @@ async def index_google_gmail_messages(
                {"error_type": "MissingComposioAccount"},
            )
            return 0, 0, "Composio connected_account_id not found"
        credentials = build_composio_credentials(connected_account_id)
        composio_service = ComposioService()
    else:
        config_data = connector.config

@@ -241,9 +302,10 @@ async def index_google_gmail_messages(
        {"stage": "client_initialization"},
    )

    gmail_connector = GoogleGmailConnector(
        credentials, session, user_id, connector_id
    )
    if not is_composio_connector:
        gmail_connector = GoogleGmailConnector(
            credentials, session, user_id, connector_id
        )

    calculated_start_date, calculated_end_date = calculate_date_range(
        connector, start_date, end_date, default_days_back=365

@@ -254,11 +316,60 @@ async def index_google_gmail_messages(
        f"Fetching emails for connector {connector_id} "
        f"from {calculated_start_date} to {calculated_end_date}"
    )
    messages, error = await gmail_connector.get_recent_messages(
        max_results=max_messages,
        start_date=calculated_start_date,
        end_date=calculated_end_date,
    )
    if is_composio_connector:
        query_parts = []
        if calculated_start_date:
            query_parts.append(f"after:{calculated_start_date.replace('-', '/')}")
        if calculated_end_date:
            query_parts.append(f"before:{calculated_end_date.replace('-', '/')}")
        query = " ".join(query_parts)

        messages = []
        page_token = None
        error = None
        while len(messages) < max_messages:
            page_size = min(50, max_messages - len(messages))
            (
                page_messages,
                page_token,
                _estimate,
                page_error,
            ) = await composio_service.get_gmail_messages(
                connected_account_id=connected_account_id,
                entity_id=f"surfsense_{user_id}",
                query=query,
                max_results=page_size,
                page_token=page_token,
            )
            if page_error:
                error = page_error
                break
            for page_message in page_messages:
                message_id = (
                    page_message.get("id")
                    or page_message.get("message_id")
                    or page_message.get("messageId")
                )
                if message_id:
                    (
                        detail,
                        detail_error,
                    ) = await composio_service.get_gmail_message_detail(
                        connected_account_id=connected_account_id,
                        entity_id=f"surfsense_{user_id}",
                        message_id=message_id,
                    )
                    if not detail_error and isinstance(detail, dict):
                        page_message = detail
                messages.append(_normalize_composio_gmail_message(page_message))
            if not page_token:
                break
    else:
        messages, error = await gmail_connector.get_recent_messages(
            max_results=max_messages,
            start_date=calculated_start_date,
            end_date=calculated_end_date,
        )

    if error:
        error_message = error

@@ -326,7 +437,12 @@ async def index_google_gmail_messages(
            documents_skipped += 1
            continue

        markdown_content = gmail_connector.format_message_to_markdown(message)
        if is_composio_connector:
            markdown_content = _format_gmail_message_to_markdown(message)
        else:
            markdown_content = gmail_connector.format_message_to_markdown(
                message
            )
        if not markdown_content.strip():
            logger.warning(f"Skipping message with no content: {message_id}")
            documents_skipped += 1
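The Composio branch above builds a plain Gmail search query from the calculated date window; a standalone sketch of that transformation (function name invented, dates assumed to arrive as "YYYY-MM-DD" strings, as the ``.replace('-', '/')`` calls above imply):

def _gmail_date_query(start: str | None, end: str | None) -> str:
    # Gmail search syntax wants slash-separated dates: after:YYYY/MM/DD before:YYYY/MM/DD.
    parts = []
    if start:
        parts.append(f"after:{start.replace('-', '/')}")
    if end:
        parts.append(f"before:{end.replace('-', '/')}")
    return " ".join(parts)

assert _gmail_date_query("2024-01-01", "2024-06-30") == "after:2024/01/01 before:2024/06/30"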