Merge upstream/dev into feature/multi-agent

CREDO23 2026-05-05 01:44:46 +02:00
commit 5119915f4f
278 changed files with 34669 additions and 8970 deletions

View file

@@ -1,10 +1,25 @@
"""Celery tasks package."""
"""Celery tasks package.
Also hosts the small helpers every async celery task should use to
spin up its event loop. See :func:`run_async_celery_task` for the
canonical pattern.
"""
from __future__ import annotations
import asyncio
import contextlib
import logging
from collections.abc import Awaitable, Callable
from typing import TypeVar
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.config import config
logger = logging.getLogger(__name__)
_celery_engine = None
_celery_session_maker = None
@@ -26,3 +41,86 @@ def get_celery_session_maker() -> async_sessionmaker:
_celery_engine, expire_on_commit=False
)
return _celery_session_maker
def _dispose_shared_db_engine(loop: asyncio.AbstractEventLoop) -> None:
"""Drop the shared ``app.db.engine`` connection pool synchronously.
The shared engine (used by ``shielded_async_session`` and most
routes / services) is a module-level singleton with a real pool.
Each celery task creates a fresh ``asyncio`` event loop; asyncpg
connections cache a reference to whichever loop opened them. When
a subsequent task's loop pulls a stale connection from the pool,
SQLAlchemy's ``pool_pre_ping`` checkout crashes with::
AttributeError: 'NoneType' object has no attribute 'send'
File ".../asyncio/proactor_events.py", line 402, in _loop_writing
self._write_fut = self._loop._proactor.send(self._sock, data)
or hangs forever inside the asyncpg ``Connection._cancel`` cleanup
coroutine that can never run because its loop is gone.
Disposing the engine forces the pool to drop every cached
connection so the next checkout opens a fresh one on the current
loop. Safe to call from a task's finally block; failure is logged
but never propagated.
"""
try:
from app.db import engine as shared_engine
loop.run_until_complete(shared_engine.dispose())
except Exception:
logger.warning("Shared DB engine dispose() failed", exc_info=True)
T = TypeVar("T")
def run_async_celery_task[T](coro_factory: Callable[[], Awaitable[T]]) -> T:
"""Run an async coroutine inside a fresh event loop with proper
DB-engine cleanup.
This is the canonical entry point for every async celery task.
It handles three responsibilities that were previously copy-pasted
(incorrectly) across each task module:
1. Create a fresh ``asyncio`` loop and install it on the current
thread (celery's ``--pool=solo`` runs every task on the main
thread, but other pool types don't).
2. Dispose the shared ``app.db.engine`` BEFORE the task runs so
any stale connections left over from a previous task's loop
are dropped; this defends against tasks that crashed without
cleaning up.
3. Dispose the shared engine AFTER the task runs so the
connections we opened on this loop are released before the
loop closes (avoids ``coroutine 'Connection._cancel' was
never awaited`` warnings and the next-task hang).
Use as::
@celery_app.task(name="my_task", bind=True)
def my_task(self, *args):
return run_async_celery_task(lambda: _my_task_impl(*args))
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Defense-in-depth: prior task may have crashed before
# disposing. Idempotent — no-op if pool is already empty.
_dispose_shared_db_engine(loop)
return loop.run_until_complete(coro_factory())
finally:
# Drop any connections this task opened so they don't leak
# into the next task's loop.
_dispose_shared_db_engine(loop)
with contextlib.suppress(Exception):
loop.run_until_complete(loop.shutdown_asyncgens())
with contextlib.suppress(Exception):
asyncio.set_event_loop(None)
loop.close()
__all__ = [
"get_celery_session_maker",
"run_async_celery_task",
]
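A quick way to sanity-check the helper's contract (fresh loop per call, result passthrough, loop closed by the finally block); the probe coroutine and assertions are hypothetical, not part of this commit:

import asyncio
from app.tasks.celery_tasks import run_async_celery_task

async def _probe() -> int:
    # Hypothetical coroutine that records which loop it ran on.
    _probe.loop = asyncio.get_running_loop()
    return 42

def test_run_async_celery_task_contract() -> None:
    assert run_async_celery_task(_probe) == 42
    first_loop = _probe.loop
    assert first_loop.is_closed()           # the finally block closes the loop
    assert run_async_celery_task(_probe) == 42
    assert _probe.loop is not first_loop    # each call gets a brand-new loop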

View file

@@ -4,7 +4,7 @@ import logging
import traceback
from app.celery_app import celery_app
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -49,22 +49,15 @@ def index_notion_pages_task(
end_date: str,
):
"""Celery task to index Notion pages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_notion_pages(
return run_async_celery_task(
lambda: _index_notion_pages(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_notion_pages", connector_id)
raise
finally:
loop.close()
async def _index_notion_pages(
@@ -95,19 +88,11 @@ def index_github_repos_task(
end_date: str,
):
"""Celery task to index GitHub repositories."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_github_repos(
connector_id, search_space_id, user_id, start_date, end_date
)
return run_async_celery_task(
lambda: _index_github_repos(
connector_id, search_space_id, user_id, start_date, end_date
)
finally:
loop.close()
)
async def _index_github_repos(
@@ -138,19 +123,11 @@ def index_confluence_pages_task(
end_date: str,
):
"""Celery task to index Confluence pages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_confluence_pages(
connector_id, search_space_id, user_id, start_date, end_date
)
return run_async_celery_task(
lambda: _index_confluence_pages(
connector_id, search_space_id, user_id, start_date, end_date
)
finally:
loop.close()
)
async def _index_confluence_pages(
@@ -181,22 +158,15 @@ def index_google_calendar_events_task(
end_date: str,
):
"""Celery task to index Google Calendar events."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_google_calendar_events(
return run_async_celery_task(
lambda: _index_google_calendar_events(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_google_calendar_events", connector_id)
raise
finally:
loop.close()
async def _index_google_calendar_events(
@@ -227,19 +197,11 @@ def index_google_gmail_messages_task(
end_date: str,
):
"""Celery task to index Google Gmail messages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_google_gmail_messages(
connector_id, search_space_id, user_id, start_date, end_date
)
return run_async_celery_task(
lambda: _index_google_gmail_messages(
connector_id, search_space_id, user_id, start_date, end_date
)
finally:
loop.close()
)
async def _index_google_gmail_messages(
@@ -269,22 +231,14 @@ def index_google_drive_files_task(
items_dict: dict, # Dictionary with 'folders', 'files', and 'indexing_options'
):
"""Celery task to index Google Drive folders and files."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_google_drive_files(
connector_id,
search_space_id,
user_id,
items_dict,
)
return run_async_celery_task(
lambda: _index_google_drive_files(
connector_id,
search_space_id,
user_id,
items_dict,
)
finally:
loop.close()
)
async def _index_google_drive_files(
@@ -317,22 +271,14 @@ def index_onedrive_files_task(
items_dict: dict,
):
"""Celery task to index OneDrive folders and files."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_onedrive_files(
connector_id,
search_space_id,
user_id,
items_dict,
)
return run_async_celery_task(
lambda: _index_onedrive_files(
connector_id,
search_space_id,
user_id,
items_dict,
)
finally:
loop.close()
)
async def _index_onedrive_files(
@@ -365,22 +311,14 @@ def index_dropbox_files_task(
items_dict: dict,
):
"""Celery task to index Dropbox folders and files."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_dropbox_files(
connector_id,
search_space_id,
user_id,
items_dict,
)
return run_async_celery_task(
lambda: _index_dropbox_files(
connector_id,
search_space_id,
user_id,
items_dict,
)
finally:
loop.close()
)
async def _index_dropbox_files(
@@ -414,19 +352,11 @@ def index_elasticsearch_documents_task(
end_date: str,
):
"""Celery task to index Elasticsearch documents."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_elasticsearch_documents(
connector_id, search_space_id, user_id, start_date, end_date
)
return run_async_celery_task(
lambda: _index_elasticsearch_documents(
connector_id, search_space_id, user_id, start_date, end_date
)
finally:
loop.close()
)
async def _index_elasticsearch_documents(
@@ -457,22 +387,15 @@ def index_crawled_urls_task(
end_date: str,
):
"""Celery task to index Web page Urls."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_crawled_urls(
return run_async_celery_task(
lambda: _index_crawled_urls(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_crawled_urls", connector_id)
raise
finally:
loop.close()
async def _index_crawled_urls(
@@ -503,19 +426,11 @@ def index_bookstack_pages_task(
end_date: str,
):
"""Celery task to index BookStack pages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_bookstack_pages(
connector_id, search_space_id, user_id, start_date, end_date
)
return run_async_celery_task(
lambda: _index_bookstack_pages(
connector_id, search_space_id, user_id, start_date, end_date
)
finally:
loop.close()
)
async def _index_bookstack_pages(
@@ -546,19 +461,11 @@ def index_composio_connector_task(
end_date: str | None,
):
"""Celery task to index Composio connector content (Google Drive, Gmail, Calendar via Composio)."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_composio_connector(
connector_id, search_space_id, user_id, start_date, end_date
)
return run_async_celery_task(
lambda: _index_composio_connector(
connector_id, search_space_id, user_id, start_date, end_date
)
finally:
loop.close()
)
async def _index_composio_connector(

View file

@@ -11,7 +11,7 @@ from app.db import Document
from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -25,15 +25,7 @@ def reindex_document_task(self, document_id: int, user_id: str):
document_id: ID of document to reindex
user_id: ID of user who edited the document
"""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_reindex_document(document_id, user_id))
finally:
loop.close()
return run_async_celery_task(lambda: _reindex_document(document_id, user_id))
async def _reindex_document(document_id: int, user_id: str):

View file

@@ -11,7 +11,7 @@ from app.celery_app import celery_app
from app.config import config
from app.services.notification_service import NotificationService
from app.services.task_logging_service import TaskLoggingService
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
from app.tasks.connector_indexers.local_folder_indexer import (
index_local_folder,
index_uploaded_files,
@@ -105,12 +105,7 @@ async def _run_heartbeat_loop(notification_id: int):
)
def delete_document_task(self, document_id: int):
"""Celery task to delete a document and its chunks in batches."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_delete_document_background(document_id))
finally:
loop.close()
return run_async_celery_task(lambda: _delete_document_background(document_id))
async def _delete_document_background(document_id: int) -> None:
@@ -153,14 +148,9 @@ def delete_folder_documents_task(
folder_subtree_ids: list[int] | None = None,
):
"""Celery task to delete documents first, then the folder rows."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_delete_folder_documents(document_ids, folder_subtree_ids)
)
finally:
loop.close()
return run_async_celery_task(
lambda: _delete_folder_documents(document_ids, folder_subtree_ids)
)
async def _delete_folder_documents(
@@ -209,12 +199,9 @@ async def _delete_folder_documents(
)
def delete_search_space_task(self, search_space_id: int):
"""Celery task to delete a search space and heavy child rows in batches."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_delete_search_space_background(search_space_id))
finally:
loop.close()
return run_async_celery_task(
lambda: _delete_search_space_background(search_space_id)
)
async def _delete_search_space_background(search_space_id: int) -> None:
@@ -269,18 +256,11 @@ def process_extension_document_task(
search_space_id: ID of the search space
user_id: ID of the user
"""
# Create a new event loop for this task
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_process_extension_document(
individual_document_dict, search_space_id, user_id
)
return run_async_celery_task(
lambda: _process_extension_document(
individual_document_dict, search_space_id, user_id
)
finally:
loop.close()
)
async def _process_extension_document(
@@ -419,13 +399,9 @@ def process_youtube_video_task(self, url: str, search_space_id: int, user_id: st
search_space_id: ID of the search space
user_id: ID of the user
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_process_youtube_video(url, search_space_id, user_id))
finally:
loop.close()
return run_async_celery_task(
lambda: _process_youtube_video(url, search_space_id, user_id)
)
async def _process_youtube_video(url: str, search_space_id: int, user_id: str):
@@ -573,12 +549,9 @@ def process_file_upload_task(
except Exception as e:
logger.warning(f"[process_file_upload] Could not get file size: {e}")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_process_file_upload(file_path, filename, search_space_id, user_id)
run_async_celery_task(
lambda: _process_file_upload(file_path, filename, search_space_id, user_id)
)
logger.info(
f"[process_file_upload] Task completed successfully for: {filename}"
@@ -589,8 +562,6 @@ def process_file_upload_task(
f"Traceback:\n{traceback.format_exc()}"
)
raise
finally:
loop.close()
async def _process_file_upload(
@@ -811,25 +782,17 @@ def process_file_upload_with_document_task(
"File may have been removed before syncing could start."
)
# Mark document as failed since file is missing
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_mark_document_failed(
document_id,
"File not found. Please re-upload the file.",
)
run_async_celery_task(
lambda: _mark_document_failed(
document_id,
"File not found. Please re-upload the file.",
)
finally:
loop.close()
)
return
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_process_file_with_document(
run_async_celery_task(
lambda: _process_file_with_document(
document_id,
temp_path,
filename,
@@ -849,8 +812,6 @@ def process_file_upload_with_document_task(
f"Traceback:\n{traceback.format_exc()}"
)
raise
finally:
loop.close()
async def _mark_document_failed(document_id: int, reason: str):
@@ -1119,22 +1080,16 @@ def process_circleback_meeting_task(
search_space_id: ID of the search space
connector_id: ID of the Circleback connector (for deletion support)
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_process_circleback_meeting(
meeting_id,
meeting_name,
markdown_content,
metadata,
search_space_id,
connector_id,
)
return run_async_celery_task(
lambda: _process_circleback_meeting(
meeting_id,
meeting_name,
markdown_content,
metadata,
search_space_id,
connector_id,
)
finally:
loop.close()
)
async def _process_circleback_meeting(
@@ -1291,25 +1246,19 @@ def index_local_folder_task(
target_file_paths: list[str] | None = None,
):
"""Celery task to index a local folder. Config is passed directly — no connector row."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_local_folder_async(
search_space_id=search_space_id,
user_id=user_id,
folder_path=folder_path,
folder_name=folder_name,
exclude_patterns=exclude_patterns,
file_extensions=file_extensions,
root_folder_id=root_folder_id,
enable_summary=enable_summary,
target_file_paths=target_file_paths,
)
return run_async_celery_task(
lambda: _index_local_folder_async(
search_space_id=search_space_id,
user_id=user_id,
folder_path=folder_path,
folder_name=folder_name,
exclude_patterns=exclude_patterns,
file_extensions=file_extensions,
root_folder_id=root_folder_id,
enable_summary=enable_summary,
target_file_paths=target_file_paths,
)
finally:
loop.close()
)
async def _index_local_folder_async(
@@ -1441,23 +1390,18 @@ def index_uploaded_folder_files_task(
processing_mode: str = "basic",
):
"""Celery task to index files uploaded from the desktop app."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_uploaded_folder_files_async(
search_space_id=search_space_id,
user_id=user_id,
folder_name=folder_name,
root_folder_id=root_folder_id,
enable_summary=enable_summary,
file_mappings=file_mappings,
use_vision_llm=use_vision_llm,
processing_mode=processing_mode,
)
return run_async_celery_task(
lambda: _index_uploaded_folder_files_async(
search_space_id=search_space_id,
user_id=user_id,
folder_name=folder_name,
root_folder_id=root_folder_id,
enable_summary=enable_summary,
file_mappings=file_mappings,
use_vision_llm=use_vision_llm,
processing_mode=processing_mode,
)
finally:
loop.close()
)
async def _index_uploaded_folder_files_async(
@@ -1584,12 +1528,9 @@ def _ai_sort_lock_key(search_space_id: int) -> str:
@celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1)
def ai_sort_search_space_task(self, search_space_id: int, user_id: str):
"""Full AI sort for all documents in a search space."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_ai_sort_search_space_async(search_space_id, user_id))
finally:
loop.close()
return run_async_celery_task(
lambda: _ai_sort_search_space_async(search_space_id, user_id)
)
async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
@@ -1639,14 +1580,9 @@ async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
)
def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int):
"""Incremental AI sort for a single document after indexing."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_ai_sort_document_async(search_space_id, user_id, document_id)
)
finally:
loop.close()
return run_async_celery_task(
lambda: _ai_sort_document_async(search_space_id, user_id, document_id)
)
async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int):

View file

@@ -2,14 +2,13 @@
from __future__ import annotations
import asyncio
import logging
from app.celery_app import celery_app
from app.db import SearchSourceConnector
from app.schemas.obsidian_plugin import NotePayload
from app.services.obsidian_plugin_indexer import upsert_note
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -22,18 +21,13 @@ def index_obsidian_attachment_task(
user_id: str,
) -> None:
"""Process one Obsidian non-markdown attachment asynchronously."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_obsidian_attachment(
connector_id=connector_id,
payload_data=payload_data,
user_id=user_id,
)
return run_async_celery_task(
lambda: _index_obsidian_attachment(
connector_id=connector_id,
payload_data=payload_data,
user_id=user_id,
)
finally:
loop.close()
)
async def _index_obsidian_attachment(

View file

@@ -3,14 +3,22 @@
import asyncio
import logging
import sys
from contextlib import asynccontextmanager
from sqlalchemy import select
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState
from app.celery_app import celery_app
from app.config import config as app_config
from app.db import Podcast, PodcastStatus
from app.tasks.celery_tasks import get_celery_session_maker
from app.services.billable_calls import (
BillingSettlementError,
QuotaInsufficientError,
_resolve_agent_billing_for_search_space,
billable_call,
)
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -28,6 +36,13 @@ if sys.platform.startswith("win"):
# =============================================================================
@asynccontextmanager
async def _celery_billable_session():
"""Session factory used by billable_call inside the Celery worker loop."""
async with get_celery_session_maker()() as session:
yield session
@celery_app.task(name="generate_content_podcast", bind=True)
def generate_content_podcast_task(
self,
@@ -40,27 +55,22 @@ def generate_content_podcast_task(
Celery task to generate podcast from source content.
Updates existing podcast record created by the tool.
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(
_generate_content_podcast(
return run_async_celery_task(
lambda: _generate_content_podcast(
podcast_id,
source_content,
search_space_id,
user_prompt,
)
)
loop.run_until_complete(loop.shutdown_asyncgens())
return result
except Exception as e:
logger.error(f"Error generating content podcast: {e!s}")
loop.run_until_complete(_mark_podcast_failed(podcast_id))
try:
run_async_celery_task(lambda: _mark_podcast_failed(podcast_id))
except Exception:
logger.exception("Failed to mark podcast %s as failed", podcast_id)
return {"status": "failed", "podcast_id": podcast_id}
finally:
asyncio.set_event_loop(None)
loop.close()
async def _mark_podcast_failed(podcast_id: int) -> None:
@@ -96,6 +106,31 @@ async def _generate_content_podcast(
podcast.status = PodcastStatus.GENERATING
await session.commit()
try:
(
owner_user_id,
billing_tier,
base_model,
) = await _resolve_agent_billing_for_search_space(
session,
search_space_id,
thread_id=podcast.thread_id,
)
except ValueError as resolve_err:
logger.error(
"Podcast %s: cannot resolve billing for search_space=%s: %s",
podcast.id,
search_space_id,
resolve_err,
)
podcast.status = PodcastStatus.FAILED
await session.commit()
return {
"status": "failed",
"podcast_id": podcast.id,
"reason": "billing_resolution_failed",
}
graph_config = {
"configurable": {
"podcast_title": podcast.title,
@@ -109,9 +144,52 @@ async def _generate_content_podcast(
db_session=session,
)
graph_result = await podcaster_graph.ainvoke(
initial_state, config=graph_config
)
try:
async with billable_call(
user_id=owner_user_id,
search_space_id=search_space_id,
billing_tier=billing_tier,
base_model=base_model,
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
usage_type="podcast_generation",
call_details={
"podcast_id": podcast.id,
"title": podcast.title,
"thread_id": podcast.thread_id,
},
billable_session_factory=_celery_billable_session,
):
graph_result = await podcaster_graph.ainvoke(
initial_state, config=graph_config
)
except QuotaInsufficientError as exc:
logger.info(
"Podcast %s denied: out of premium credits "
"(used=%d/%d remaining=%d)",
podcast.id,
exc.used_micros,
exc.limit_micros,
exc.remaining_micros,
)
podcast.status = PodcastStatus.FAILED
await session.commit()
return {
"status": "failed",
"podcast_id": podcast.id,
"reason": "premium_quota_exhausted",
}
except BillingSettlementError:
logger.exception(
"Podcast %s: premium billing settlement failed",
podcast.id,
)
podcast.status = PodcastStatus.FAILED
await session.commit()
return {
"status": "failed",
"podcast_id": podcast.id,
"reason": "billing_settlement_failed",
}
podcast_transcript = graph_result.get("podcast_transcript", [])
file_path = graph_result.get("final_podcast_file_path", "")
@@ -133,7 +211,14 @@ async def _generate_content_podcast(
podcast.podcast_transcript = serializable_transcript
podcast.file_location = file_path
podcast.status = PodcastStatus.READY
logger.info(
"Podcast %s: committing READY transcript_entries=%d file=%s",
podcast.id,
len(serializable_transcript),
file_path,
)
await session.commit()
logger.info("Podcast %s: READY commit complete", podcast.id)
logger.info(f"Successfully generated podcast: {podcast.id}")

View file

@@ -7,7 +7,7 @@ from sqlalchemy.future import select
from app.celery_app import celery_app
from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
from app.utils.indexing_locks import is_connector_indexing_locked
logger = logging.getLogger(__name__)
@@ -20,15 +20,7 @@ def check_periodic_schedules_task():
This task runs every minute and triggers indexing for any connector
whose next_scheduled_at time has passed.
"""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_check_and_trigger_schedules())
finally:
loop.close()
return run_async_celery_task(_check_and_trigger_schedules)
async def _check_and_trigger_schedules():

View file

@@ -34,7 +34,7 @@ from sqlalchemy.future import select
from app.celery_app import celery_app
from app.config import config
from app.db import Document, DocumentStatus, Notification
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -69,16 +69,12 @@ def cleanup_stale_indexing_notifications_task():
Detection: Redis heartbeat key with 2-min TTL. Missing key = stale task.
Also marks associated pending/processing documents as failed.
"""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
async def _both() -> None:
await _cleanup_stale_notifications()
await _cleanup_stale_document_processing_notifications()
try:
loop.run_until_complete(_cleanup_stale_notifications())
loop.run_until_complete(_cleanup_stale_document_processing_notifications())
finally:
loop.close()
return run_async_celery_task(_both)
async def _cleanup_stale_notifications():

View file

@@ -2,7 +2,6 @@
from __future__ import annotations
import asyncio
import logging
from datetime import UTC, datetime, timedelta
@@ -18,7 +17,7 @@ from app.db import (
PremiumTokenPurchaseStatus,
)
from app.routes import stripe_routes
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -36,13 +35,7 @@ def get_stripe_client() -> StripeClient | None:
@celery_app.task(name="reconcile_pending_stripe_page_purchases")
def reconcile_pending_stripe_page_purchases_task():
"""Recover paid purchases that were left pending due to missed webhook handling."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_reconcile_pending_page_purchases())
finally:
loop.close()
return run_async_celery_task(_reconcile_pending_page_purchases)
async def _reconcile_pending_page_purchases() -> None:
@@ -141,13 +134,7 @@ async def _reconcile_pending_page_purchases() -> None:
@celery_app.task(name="reconcile_pending_stripe_token_purchases")
def reconcile_pending_stripe_token_purchases_task():
"""Recover paid token purchases that were left pending due to missed webhook handling."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_reconcile_pending_token_purchases())
finally:
loop.close()
return run_async_celery_task(_reconcile_pending_token_purchases)
async def _reconcile_pending_token_purchases() -> None:

View file

@@ -3,14 +3,22 @@
import asyncio
import logging
import sys
from contextlib import asynccontextmanager
from sqlalchemy import select
from app.agents.video_presentation.graph import graph as video_presentation_graph
from app.agents.video_presentation.state import State as VideoPresentationState
from app.celery_app import celery_app
from app.config import config as app_config
from app.db import VideoPresentation, VideoPresentationStatus
from app.tasks.celery_tasks import get_celery_session_maker
from app.services.billable_calls import (
BillingSettlementError,
QuotaInsufficientError,
_resolve_agent_billing_for_search_space,
billable_call,
)
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@@ -23,6 +31,13 @@ if sys.platform.startswith("win"):
)
@asynccontextmanager
async def _celery_billable_session():
"""Session factory used by billable_call inside the Celery worker loop."""
async with get_celery_session_maker()() as session:
yield session
@celery_app.task(name="generate_video_presentation", bind=True)
def generate_video_presentation_task(
self,
@@ -35,27 +50,30 @@ def generate_video_presentation_task(
Celery task to generate video presentation from source content.
Updates existing video presentation record created by the tool.
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(
_generate_video_presentation(
return run_async_celery_task(
lambda: _generate_video_presentation(
video_presentation_id,
source_content,
search_space_id,
user_prompt,
)
)
loop.run_until_complete(loop.shutdown_asyncgens())
return result
except Exception as e:
logger.error(f"Error generating video presentation: {e!s}")
loop.run_until_complete(_mark_video_presentation_failed(video_presentation_id))
# Mark FAILED in a fresh loop — the previous loop is closed.
# Swallow secondary failures; the row will simply stay in
# GENERATING and be flushed by the periodic stale cleanup.
try:
run_async_celery_task(
lambda: _mark_video_presentation_failed(video_presentation_id)
)
except Exception:
logger.exception(
"Failed to mark video presentation %s as failed",
video_presentation_id,
)
return {"status": "failed", "video_presentation_id": video_presentation_id}
finally:
asyncio.set_event_loop(None)
loop.close()
async def _mark_video_presentation_failed(video_presentation_id: int) -> None:
@@ -97,6 +115,32 @@ async def _generate_video_presentation(
video_pres.status = VideoPresentationStatus.GENERATING
await session.commit()
try:
(
owner_user_id,
billing_tier,
base_model,
) = await _resolve_agent_billing_for_search_space(
session,
search_space_id,
thread_id=video_pres.thread_id,
)
except ValueError as resolve_err:
logger.error(
"VideoPresentation %s: cannot resolve billing for "
"search_space=%s: %s",
video_pres.id,
search_space_id,
resolve_err,
)
video_pres.status = VideoPresentationStatus.FAILED
await session.commit()
return {
"status": "failed",
"video_presentation_id": video_pres.id,
"reason": "billing_resolution_failed",
}
graph_config = {
"configurable": {
"video_title": video_pres.title,
@@ -110,9 +154,52 @@ async def _generate_video_presentation(
db_session=session,
)
graph_result = await video_presentation_graph.ainvoke(
initial_state, config=graph_config
)
try:
async with billable_call(
user_id=owner_user_id,
search_space_id=search_space_id,
billing_tier=billing_tier,
base_model=base_model,
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS,
usage_type="video_presentation_generation",
call_details={
"video_presentation_id": video_pres.id,
"title": video_pres.title,
"thread_id": video_pres.thread_id,
},
billable_session_factory=_celery_billable_session,
):
graph_result = await video_presentation_graph.ainvoke(
initial_state, config=graph_config
)
except QuotaInsufficientError as exc:
logger.info(
"VideoPresentation %s denied: out of premium credits "
"(used=%d/%d remaining=%d)",
video_pres.id,
exc.used_micros,
exc.limit_micros,
exc.remaining_micros,
)
video_pres.status = VideoPresentationStatus.FAILED
await session.commit()
return {
"status": "failed",
"video_presentation_id": video_pres.id,
"reason": "premium_quota_exhausted",
}
except BillingSettlementError:
logger.exception(
"VideoPresentation %s: premium billing settlement failed",
video_pres.id,
)
video_pres.status = VideoPresentationStatus.FAILED
await session.commit()
return {
"status": "failed",
"video_presentation_id": video_pres.id,
"reason": "billing_settlement_failed",
}
# Serialize slides (parsed content + audio info merged)
slides_raw = graph_result.get("slides", [])
@@ -143,7 +230,14 @@ async def _generate_video_presentation(
video_pres.slides = serializable_slides
video_pres.scene_codes = serializable_scene_codes
video_pres.status = VideoPresentationStatus.READY
logger.info(
"VideoPresentation %s: committing READY slides=%d scene_codes=%d",
video_pres.id,
len(serializable_slides),
len(serializable_scene_codes),
)
await session.commit()
logger.info("VideoPresentation %s: READY commit complete", video_pres.id)
logger.info(f"Successfully generated video presentation: {video_pres.id}")

View file

@@ -0,0 +1,515 @@
"""Server-side mirror of the frontend's assistant-ui ``ContentPart`` projection.
Background
----------
The streaming chat task in ``stream_new_chat`` / ``stream_resume_chat`` yields
SSE events that the frontend folds into a ``ContentPartsState`` (see
``surfsense_web/lib/chat/streaming-state.ts`` and the matching pipeline in
``stream-pipeline.ts``). When a turn ends, the frontend calls
``buildContentForPersistence(...)`` and round-trips that ``ContentPart[]``
JSONB to ``POST /threads/{id}/messages``, which is what was historically
written to ``new_chat_messages.content``.
After the ghost-thread fix moved persistence server-side, the assistant
row is written by ``finalize_assistant_turn`` in the streaming finally
block. The frontend's later ``appendMessage`` is now a no-op (recovers
via the ``(thread_id, turn_id, role)`` partial unique index added in
migration 141), which means the *server* is now responsible for
producing the rich ``ContentPart[]`` shape the FE expects on history
reload: text + reasoning + tool-call cards (with ``args``, ``argsText``,
``result``, ``langchainToolCallId``) + thinking-step buckets +
step-separators.
This module is the in-memory accumulator that mirrors the FE state for
exactly that purpose. The streaming code calls ``on_text_*`` / ``on_reasoning_*``
/ ``on_tool_*`` / ``on_thinking_step`` / ``on_step_separator`` /
``mark_interrupted`` at the same call sites it yields the matching
``streaming_service.format_*`` SSE event, so the in-memory ``parts`` list
stays in lockstep with what the FE's pipeline would have produced live.
``snapshot()`` is then taken once in the ``finally`` block and persisted
in a single UPDATE.
Pure synchronous state: no DB I/O, no async, no flush callbacks. The
streaming code is responsible for driving lifecycle methods; this class
is a thin projection helper.
"""
from __future__ import annotations
import copy
import json
import logging
from typing import Any
logger = logging.getLogger(__name__)
# Mirrors the FE's filter in ``buildContentForPersistence`` / ``buildContentForUI``:
# only text/reasoning/tool-call parts count as "meaningful". data-thinking-steps
# and data-step-separator decorate the meaningful parts but never stand alone
# in a successful turn.
_MEANINGFUL_PART_TYPES: frozenset[str] = frozenset({"text", "reasoning", "tool-call"})
class AssistantContentBuilder:
"""Server-side projection of ``surfsense_web/lib/chat/streaming-state.ts``.
Output shape (deep copy of ``self.parts`` via ``snapshot()``) strictly
matches the FE ``ContentPart`` union::
| { type: "text"; text: string }
| { type: "reasoning"; text: string }
| { type: "tool-call"; toolCallId: str; toolName: str;
args: dict; result?: any; argsText?: str; langchainToolCallId?: str;
state?: "aborted" }
| { type: "data-thinking-steps"; data: { steps: ThinkingStepData[] } }
| { type: "data-step-separator"; data: { stepIndex: int } }
Order matches the wire order of the SSE events that drive the lifecycle
methods, with two FE-mirrored exceptions:
1. ``data-thinking-steps`` is a *singleton* and pinned at index 0 the
first time we see a ``data-thinking-step`` SSE event (the FE's
``updateThinkingSteps`` does ``unshift`` on first sight). Subsequent
thinking-step updates mutate that singleton in place.
2. ``data-step-separator`` is appended only when the message already has
meaningful content and the previous part isn't itself a separator
(so the FIRST step of a turn doesn't generate a leading divider).
"""
def __init__(self) -> None:
self.parts: list[dict[str, Any]] = []
# Index of the active text/reasoning part within ``parts`` while
# streaming is open; -1 means "no active part" and the next delta
# opens a fresh one. Mirrors ``ContentPartsState.currentTextPartIndex``.
self._current_text_idx: int = -1
self._current_reasoning_idx: int = -1
# ``ui_id``-keyed indexes for tool-call parts. ``ui_id`` is the
# synthetic ``call_<run_id>`` (legacy) or the LangChain
# ``tool_call.id`` (parity_v2) — same key the streaming layer
# threads through every ``tool-input-*`` / ``tool-output-*`` event.
self._tool_call_idx_by_ui_id: dict[str, int] = {}
# Live argsText accumulator (concatenated ``tool-input-delta`` chunks)
# so we can reproduce the FE's ``appendToolInputDelta`` behaviour
# before ``tool-input-available`` overwrites it with the
# pretty-printed final JSON.
self._args_text_by_ui_id: dict[str, str] = {}
# ------------------------------------------------------------------
# Text
# ------------------------------------------------------------------
def on_text_start(self, text_id: str) -> None:
"""Begin a fresh text block.
Symmetric to FE ``appendText``: opening text closes any active
reasoning so the renderer treats them as separate parts. The
actual text part isn't materialised here — it's lazily created
on the first ``on_text_delta`` so an empty start/end pair
leaves no trace. Matches the FE pipeline which has no explicit
``text-start`` handler at all.
"""
if self._current_reasoning_idx >= 0:
self._current_reasoning_idx = -1
def on_text_delta(self, text_id: str, delta: str) -> None:
if not delta:
return
if self._current_reasoning_idx >= 0:
# FE behaviour: a text delta after reasoning implicitly
# closes the reasoning block (see ``appendText`` lines
# 178-180).
self._current_reasoning_idx = -1
if (
self._current_text_idx >= 0
and 0 <= self._current_text_idx < len(self.parts)
and self.parts[self._current_text_idx].get("type") == "text"
):
self.parts[self._current_text_idx]["text"] += delta
return
self.parts.append({"type": "text", "text": delta})
self._current_text_idx = len(self.parts) - 1
def on_text_end(self, text_id: str) -> None:
"""Close the active text block.
Mirrors the wire-level ``text-end`` boundary the streaming layer
emits before tool calls / reasoning / step boundaries. The FE
pipeline implicitly closes via ``currentTextPartIndex = -1``
in ``addToolCall`` / ``appendReasoning`` / ``addStepSeparator``;
our helper does the same explicitly so callers don't have to
maintain that invariant per call site.
"""
self._current_text_idx = -1
# ------------------------------------------------------------------
# Reasoning
# ------------------------------------------------------------------
def on_reasoning_start(self, reasoning_id: str) -> None:
if self._current_text_idx >= 0:
self._current_text_idx = -1
def on_reasoning_delta(self, reasoning_id: str, delta: str) -> None:
if not delta:
return
if self._current_text_idx >= 0:
self._current_text_idx = -1
if (
self._current_reasoning_idx >= 0
and 0 <= self._current_reasoning_idx < len(self.parts)
and self.parts[self._current_reasoning_idx].get("type") == "reasoning"
):
self.parts[self._current_reasoning_idx]["text"] += delta
return
self.parts.append({"type": "reasoning", "text": delta})
self._current_reasoning_idx = len(self.parts) - 1
def on_reasoning_end(self, reasoning_id: str) -> None:
self._current_reasoning_idx = -1
# ------------------------------------------------------------------
# Tool calls
# ------------------------------------------------------------------
def on_tool_input_start(
self,
ui_id: str,
tool_name: str,
langchain_tool_call_id: str | None,
) -> None:
"""Register a tool-call card. Args are filled in by later events."""
if not ui_id:
return
# Skip duplicate registration: parity_v2 may emit
# ``tool-input-start`` from both ``on_chat_model_stream``
# (when tool_call_chunks register a name) and ``on_tool_start``
# (the canonical path). The FE de-dupes via ``toolCallIndices``;
# we mirror that here.
if ui_id in self._tool_call_idx_by_ui_id:
if langchain_tool_call_id:
idx = self._tool_call_idx_by_ui_id[ui_id]
part = self.parts[idx]
if not part.get("langchainToolCallId"):
part["langchainToolCallId"] = langchain_tool_call_id
return
part: dict[str, Any] = {
"type": "tool-call",
"toolCallId": ui_id,
"toolName": tool_name,
"args": {},
}
if langchain_tool_call_id:
part["langchainToolCallId"] = langchain_tool_call_id
self.parts.append(part)
self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
self._current_text_idx = -1
self._current_reasoning_idx = -1
def on_tool_input_delta(self, ui_id: str, args_chunk: str) -> None:
"""Append a streamed args-delta chunk to the matching card's argsText.
Mirrors FE ``appendToolInputDelta``: no-ops when no card has been
registered yet for the given ``ui_id``; the deltas have nowhere
safe to land.
"""
if not ui_id or not args_chunk:
return
idx = self._tool_call_idx_by_ui_id.get(ui_id)
if idx is None:
return
if not (0 <= idx < len(self.parts)):
return
part = self.parts[idx]
if part.get("type") != "tool-call":
return
new_text = (part.get("argsText") or "") + args_chunk
part["argsText"] = new_text
self._args_text_by_ui_id[ui_id] = new_text
def on_tool_input_available(
self,
ui_id: str,
tool_name: str,
args: dict[str, Any],
langchain_tool_call_id: str | None,
) -> None:
"""Finalize the tool-call card's input.
Mirrors FE ``stream-pipeline.ts`` lines 127-153: replaces ``argsText``
with ``json.dumps(input, indent=2)`` so the post-stream card renders
pretty-printed JSON, sets the full ``args`` dict, and backfills
``langchainToolCallId`` if it wasn't known at ``tool-input-start`` time.
Also creates the card if no prior ``tool-input-start`` registered it
(legacy parity_v2-OFF / late-registration paths).
"""
if not ui_id:
return
try:
final_args_text = json.dumps(args or {}, indent=2, ensure_ascii=False)
except (TypeError, ValueError):
# Defensive: ``args`` should already be JSON-safe (the
# streaming layer sanitizes it before emitting), but if a
# caller hands us a non-serializable value we still want
# to record the call without breaking the snapshot.
final_args_text = str(args)
idx = self._tool_call_idx_by_ui_id.get(ui_id)
if idx is not None and 0 <= idx < len(self.parts):
part = self.parts[idx]
if part.get("type") == "tool-call":
part["args"] = args or {}
part["argsText"] = final_args_text
if langchain_tool_call_id and not part.get("langchainToolCallId"):
part["langchainToolCallId"] = langchain_tool_call_id
return
# No prior tool-input-start: register the card now.
new_part: dict[str, Any] = {
"type": "tool-call",
"toolCallId": ui_id,
"toolName": tool_name,
"args": args or {},
"argsText": final_args_text,
}
if langchain_tool_call_id:
new_part["langchainToolCallId"] = langchain_tool_call_id
self.parts.append(new_part)
self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
self._current_text_idx = -1
self._current_reasoning_idx = -1
def on_tool_output_available(
self,
ui_id: str,
output: Any,
langchain_tool_call_id: str | None,
) -> None:
"""Attach the tool's output (``result``) to the matching card.
Mirrors FE ``updateToolCall``: backfill ``langchainToolCallId``
only if not already set (a NULL late-arriving value never blows
away an earlier known good one).
"""
if not ui_id:
return
idx = self._tool_call_idx_by_ui_id.get(ui_id)
if idx is None or not (0 <= idx < len(self.parts)):
return
part = self.parts[idx]
if part.get("type") != "tool-call":
return
part["result"] = output
if langchain_tool_call_id and not part.get("langchainToolCallId"):
part["langchainToolCallId"] = langchain_tool_call_id
# ------------------------------------------------------------------
# Thinking steps & step separators
# ------------------------------------------------------------------
def on_thinking_step(
self,
step_id: str,
title: str,
status: str,
items: list[str] | None,
) -> None:
"""Update / insert the singleton ``data-thinking-steps`` part.
Mirrors FE ``updateThinkingSteps``: maintain a single
``data-thinking-steps`` part anchored at index 0, unshifting it into
place on first sight and updating it in place thereafter. Each ``on_thinking_step`` call
replaces the entry in the steps list keyed by ``step_id`` (or
appends if new).
"""
if not step_id:
return
new_step = {
"id": step_id,
"title": title or "",
"status": status or "in_progress",
"items": list(items) if items else [],
}
# Find existing data-thinking-steps part.
existing_idx = -1
for i, p in enumerate(self.parts):
if p.get("type") == "data-thinking-steps":
existing_idx = i
break
if existing_idx >= 0:
current_steps = self.parts[existing_idx].get("data", {}).get("steps") or []
replaced = False
for i, step in enumerate(current_steps):
if step.get("id") == step_id:
current_steps[i] = new_step
replaced = True
break
if not replaced:
current_steps.append(new_step)
self.parts[existing_idx] = {
"type": "data-thinking-steps",
"data": {"steps": current_steps},
}
return
# First sight: unshift to position 0 (FE parity).
self.parts.insert(
0,
{
"type": "data-thinking-steps",
"data": {"steps": [new_step]},
},
)
# Bump tracked indices since we inserted at the head.
if self._current_text_idx >= 0:
self._current_text_idx += 1
if self._current_reasoning_idx >= 0:
self._current_reasoning_idx += 1
for ui_id, idx in list(self._tool_call_idx_by_ui_id.items()):
self._tool_call_idx_by_ui_id[ui_id] = idx + 1
def on_step_separator(self) -> None:
"""Append a ``data-step-separator`` between consecutive model steps.
Mirrors FE ``addStepSeparator``: only emit when the message
already has meaningful content AND the previous part isn't
itself a separator. ``stepIndex`` is the running count of
separators already in ``parts``.
"""
has_content = any(p.get("type") in _MEANINGFUL_PART_TYPES for p in self.parts)
if not has_content:
return
if self.parts and self.parts[-1].get("type") == "data-step-separator":
return
step_index = sum(
1 for p in self.parts if p.get("type") == "data-step-separator"
)
self.parts.append(
{
"type": "data-step-separator",
"data": {"stepIndex": step_index},
}
)
self._current_text_idx = -1
self._current_reasoning_idx = -1
# ------------------------------------------------------------------
# Interruption handling
# ------------------------------------------------------------------
def mark_interrupted(self) -> None:
"""Close any open text/reasoning and flip running tools to aborted.
Called from the streaming ``finally`` block before ``snapshot()`` so
the persisted JSONB reflects a coherent end-state even when the
client disconnected mid-turn or the agent hit a fatal error.
- Active text/reasoning blocks: simply lose their "active"
marker (no synthetic content appended). Whatever was streamed
stays as-is.
- Tool-call parts that never received a ``result`` get
``state="aborted"`` so the FE history loader can render them
as "interrupted" rather than "still running".
"""
self._current_text_idx = -1
self._current_reasoning_idx = -1
for part in self.parts:
if part.get("type") != "tool-call":
continue
if "result" in part:
continue
part["state"] = "aborted"
# ------------------------------------------------------------------
# Snapshot & introspection
# ------------------------------------------------------------------
def snapshot(self) -> list[dict[str, Any]]:
"""Return a deep copy of ``parts`` ready for SQL UPDATE / json.dumps.
Deep-copied so callers that finalize from the shielded ``finally``
block can't accidentally mutate the persisted payload while the
SQL UPDATE is in flight (the streaming layer doesn't touch the
builder after this call, but defensive copies are cheap and cheap
is what we want in a finally block).
"""
return copy.deepcopy(self.parts)
def is_empty(self) -> bool:
"""True if no meaningful content was captured.
``data-thinking-steps`` and ``data-step-separator`` decorate
meaningful content but don't count on their own — a turn that
only emitted a thinking step before being interrupted should
still be treated as empty for the status-marker fallback.
"""
return not any(p.get("type") in _MEANINGFUL_PART_TYPES for p in self.parts)
def stats(self) -> dict[str, int]:
"""Return counts of each part-type plus rough byte size.
Used by the streaming layer's perf logger so an ops dashboard
can correlate finalize latency with payload size, and so a
regression that quietly stops emitting tool-call parts (or
starts emitting hundreds) shows up in [PERF] grep rather than
only as a "history reload looks weird" bug report.
``bytes`` is the JSON-serialised payload length: what actually
crosses the wire to PostgreSQL's JSONB column. We compute it
with ``ensure_ascii=False`` to match the JSONB encoder's UTF-8
on-disk layout closely enough for back-of-the-envelope sizing.
Reasoning/text/tool-call/thinking-step/step-separator counts are
independent so any one can spike without the others.
Defensive: ``json.dumps`` failure (a non-serializable value
slipped past the streaming layer's sanitization) is reported as
``bytes=-1`` rather than raised; perf logging must not be the
thing that breaks the streaming finally block.
"""
text_blocks = 0
reasoning_blocks = 0
tool_calls = 0
tool_calls_completed = 0
tool_calls_aborted = 0
thinking_step_parts = 0
step_separators = 0
for part in self.parts:
kind = part.get("type")
if kind == "text":
text_blocks += 1
elif kind == "reasoning":
reasoning_blocks += 1
elif kind == "tool-call":
tool_calls += 1
if part.get("state") == "aborted":
tool_calls_aborted += 1
elif "result" in part:
tool_calls_completed += 1
elif kind == "data-thinking-steps":
thinking_step_parts += 1
elif kind == "data-step-separator":
step_separators += 1
try:
byte_size = len(json.dumps(self.parts, ensure_ascii=False, default=str))
except (TypeError, ValueError):
byte_size = -1
return {
"parts": len(self.parts),
"bytes": byte_size,
"text": text_blocks,
"reasoning": reasoning_blocks,
"tool_calls": tool_calls,
"tool_calls_completed": tool_calls_completed,
"tool_calls_aborted": tool_calls_aborted,
"thinking_step_parts": thinking_step_parts,
"step_separators": step_separators,
}
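A hedged sketch of how the streaming layer is expected to drive the builder; the event ids, tool name, and payloads below are invented for illustration:

builder = AssistantContentBuilder()

builder.on_thinking_step("step-1", "Searching documents", "in_progress", ["query: pricing"])
builder.on_text_start("t1")
builder.on_text_delta("t1", "Here is what I found")
builder.on_tool_input_start("call_abc", "search_documents", "lc-1")
builder.on_tool_input_available("call_abc", "search_documents", {"query": "pricing"}, "lc-1")
builder.on_tool_output_available("call_abc", {"hits": 3}, "lc-1")
builder.on_step_separator()
builder.on_text_delta("t2", " and here is a summary.")
builder.on_thinking_step("step-1", "Searching documents", "completed", ["3 hits"])

parts = builder.snapshot()
# Expected order: the singleton data-thinking-steps pinned at index 0, then the
# first text part, the tool-call card (args + argsText + result), one
# data-step-separator, and the second text part. builder.stats() reports the
# same breakdown plus the serialized byte size.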

View file

@@ -0,0 +1,534 @@
"""Server-side message persistence helpers for the streaming chat agent.
Historically the streaming task (``stream_new_chat``/``stream_resume_chat``)
left ``new_chat_messages`` empty and relied on the frontend to round-trip
``POST /threads/{id}/messages`` afterwards. That gave authenticated clients
a "ghost-thread" abuse vector: skip the round-trip and burn LLM tokens
without leaving an audit trail. These helpers move both writes (the user
turn that triggered the stream and the assistant turn the stream produced)
into the server itself, idempotent against the partial unique index
``uq_new_chat_messages_thread_turn_role`` so legacy frontends that *do*
keep posting via ``appendMessage`` simply hit the unique-index recovery
path on the second writer instead of creating duplicates.
Assistant turn lifecycle
------------------------
The assistant side is split into two helpers so we can capture the row id
*before* the stream produces any output:
* ``persist_assistant_shell`` runs immediately after ``persist_user_turn``
and INSERTs an empty assistant row anchored to ``(thread_id, turn_id,
ASSISTANT)``. Returns the row id so the streaming layer can correlate
later writes (token_usage, AgentActionLog future-correlation) against
a stable PK from the start of the turn.
* ``finalize_assistant_turn`` runs from the streaming ``finally`` block.
It UPDATEs the row's ``content`` to the rich ``ContentPart[]`` snapshot
produced server-side by ``AssistantContentBuilder`` and writes the
``token_usage`` row using ``INSERT ... ON CONFLICT DO NOTHING`` against
the ``uq_token_usage_message_id`` partial unique index from migration
142, hard-eliminating any race against ``append_message``'s recovery
branch.
Defensive contract
------------------
* Every helper runs inside ``shielded_async_session()`` so ``session.close()``
survives starlette's mid-stream cancel scope on client disconnect.
* ``persist_user_turn`` and ``persist_assistant_shell`` use ``INSERT ... ON
CONFLICT DO NOTHING ... RETURNING id`` keyed on the ``(thread_id, turn_id,
role)`` partial unique index. On conflict the insert silently no-ops at
the DB level; no Python ``IntegrityError`` is constructed, which
eliminates spurious debugger pauses and keeps logs clean. On conflict a
follow-up ``SELECT`` resolves the existing row id so the streaming layer
can correlate writes against a stable PK.
* ``finalize_assistant_turn`` is best-effort: it never raises. The
streaming ``finally`` block calls it from within
``anyio.CancelScope(shield=True)`` and any raised exception there
would mask the real error.
"""
from __future__ import annotations
import logging
import time
from datetime import UTC, datetime
from typing import Any
from uuid import UUID
from sqlalchemy import text as sa_text
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.future import select
from app.db import (
NewChatMessage,
NewChatMessageRole,
NewChatThread,
TokenUsage,
shielded_async_session,
)
from app.services.token_tracking_service import (
TurnTokenAccumulator,
)
from app.utils.perf import get_perf_logger
logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
# Empty initial assistant content. ``finalize_assistant_turn`` overwrites
# this in a single UPDATE at end-of-stream with the full ``ContentPart[]``
# snapshot produced by ``AssistantContentBuilder``. We persist a one-element
# list with an empty text part so a crash between shell-INSERT and finalize
# leaves the row in a FE-renderable shape (blank bubble) instead of
# blowing up the history loader.
_EMPTY_SHELL_CONTENT: list[dict[str, Any]] = [{"type": "text", "text": ""}]
# Substituted content for genuinely empty turns (no text, no reasoning,
# no tool calls). The streaming layer flips to this when
# ``AssistantContentBuilder.is_empty()`` returns True so the persisted
# row is at least somewhat self-describing instead of an empty text
# bubble. The FE's ``ContentPart`` union doesn't include ``status``
# yet, so the history loader will silently drop this part and render
# a blank bubble (matches today's behaviour for empty turns); a follow-up
# FE PR adds the explicit "no response" rendering.
_STATUS_NO_RESPONSE: list[dict[str, Any]] = [
{"type": "status", "text": "(no text response)"}
]
def _build_user_content(
user_query: str,
user_image_data_urls: list[str] | None,
mentioned_documents: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
"""Build the persisted user-message ``content`` (assistant-ui v2 parts).
Mirrors the shape the existing frontend posts via
``appendMessage`` (see ``surfsense_web/.../new-chat/[[...chat_id]]/page.tsx``):
[{"type": "text", "text": "..."},
{"type": "image", "image": "data:..."},
{"type": "mentioned-documents", "documents": [{"id": int,
"title": str, "document_type": str}, ...]}]
The companion reader is
``app.utils.user_message_multimodal.split_persisted_user_content_parts``
which expects exactly this shape; keep them in sync.
``mentioned_documents``: optional list of ``{id, title, document_type}``
dicts. When non-empty (and a ``mentioned-documents`` part is not already
in some other input shape), a single ``{"type": "mentioned-documents",
"documents": [...]}`` part is appended. Mirrors the FE injection at
``page.tsx:281-286`` (``persistUserTurn``).
"""
parts: list[dict[str, Any]] = [{"type": "text", "text": user_query or ""}]
for url in user_image_data_urls or ():
if isinstance(url, str) and url:
parts.append({"type": "image", "image": url})
if mentioned_documents:
normalized: list[dict[str, Any]] = []
for doc in mentioned_documents:
if not isinstance(doc, dict):
continue
doc_id = doc.get("id")
title = doc.get("title")
document_type = doc.get("document_type")
if doc_id is None or title is None or document_type is None:
continue
normalized.append(
{
"id": doc_id,
"title": str(title),
"document_type": str(document_type),
}
)
if normalized:
parts.append({"type": "mentioned-documents", "documents": normalized})
return parts
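# Illustrative sketch only (never called by the streaming path): the query,
# data URL and document values below are hypothetical, but the part shapes
# match what the docstring above describes and what
# ``split_persisted_user_content_parts`` reads back.
def _example_user_content_shape() -> list[dict[str, Any]]:
    return _build_user_content(
        user_query="summarise the attached spec",
        user_image_data_urls=["data:image/png;base64,iVBORw0KGgo="],
        mentioned_documents=[{"id": 42, "title": "Spec", "document_type": "FILE"}],
    )
    # -> [{"type": "text", "text": "summarise the attached spec"},
    #     {"type": "image", "image": "data:image/png;base64,iVBORw0KGgo="},
    #     {"type": "mentioned-documents",
    #      "documents": [{"id": 42, "title": "Spec", "document_type": "FILE"}]}]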
async def persist_user_turn(
*,
chat_id: int,
user_id: str | None,
turn_id: str,
user_query: str,
user_image_data_urls: list[str] | None = None,
mentioned_documents: list[dict[str, Any]] | None = None,
) -> int | None:
"""Persist the user-side row for a chat turn and return its ``id``.
Uses ``INSERT ... ON CONFLICT DO NOTHING ... RETURNING id`` keyed on the
``(thread_id, turn_id, role)`` partial unique index from migration 141
(``WHERE turn_id IS NOT NULL``). On conflict the insert silently no-ops
at the DB level: no Python ``IntegrityError`` is constructed, which
eliminates the debugger pause that ``justMyCode=false`` + async greenlet
interactions used to produce, and keeps production logs clean.
Returns the ``id`` of the row that exists for this turn after the call:
the freshly inserted ``id`` on the happy path, or the existing ``id``
when a previous writer (legacy FE ``appendMessage`` racing the SSE
stream, redelivered request, etc.) already wrote it. Returns ``None``
only on genuine DB failure; the caller should yield a streaming error
and abort the turn so we never produce a title/assistant row that
isn't anchored to a persisted user message.
Other constraint violations (FK, NOT NULL, etc.) still raise
``IntegrityError``; only the ``(thread_id, turn_id, role)`` collision
is silenced.
"""
if not turn_id:
# Defensive: turn_id is always populated by the streaming path
# before this helper is called. If it isn't, we cannot be
# idempotent against the unique index — refuse to write rather
# than create a row the unique index can't dedupe.
logger.error(
"persist_user_turn called without a turn_id (chat_id=%s); skipping",
chat_id,
)
return None
t0 = time.perf_counter()
outcome = "failed"
resolved_id: int | None = None
try:
async with shielded_async_session() as ws:
# Re-attach the thread row so we can also bump updated_at
# in the same write — keeps the sidebar ordering accurate
# when a user fires off a turn but never reaches the
# legacy appendMessage.
thread = await ws.get(NewChatThread, chat_id)
author_uuid: UUID | None = None
if user_id:
try:
author_uuid = UUID(user_id)
except (TypeError, ValueError):
logger.warning(
"persist_user_turn: invalid user_id=%r, persisting as anonymous",
user_id,
)
content_payload = _build_user_content(
user_query, user_image_data_urls, mentioned_documents
)
insert_stmt = (
pg_insert(NewChatMessage)
.values(
thread_id=chat_id,
role=NewChatMessageRole.USER,
content=content_payload,
author_id=author_uuid,
turn_id=turn_id,
)
.on_conflict_do_nothing(
index_elements=["thread_id", "turn_id", "role"],
index_where=sa_text("turn_id IS NOT NULL"),
)
.returning(NewChatMessage.id)
)
inserted_id = (await ws.execute(insert_stmt)).scalar()
if inserted_id is None:
# Conflict on partial unique index — another writer
# (legacy FE appendMessage, redelivered request, etc.)
# already persisted this row. Look it up and reuse.
lookup = await ws.execute(
select(NewChatMessage.id).where(
NewChatMessage.thread_id == chat_id,
NewChatMessage.turn_id == turn_id,
NewChatMessage.role == NewChatMessageRole.USER,
)
)
existing_id = lookup.scalars().first()
if existing_id is None:
# Conflict reported but no row found — extremely
# unlikely (concurrent DELETE). Surface as failure.
logger.warning(
"persist_user_turn: conflict but no matching row "
"(chat_id=%s, turn_id=%s)",
chat_id,
turn_id,
)
outcome = "integrity_no_match"
return None
resolved_id = int(existing_id)
outcome = "race_recovered"
else:
resolved_id = int(inserted_id)
outcome = "inserted"
# Bump thread.updated_at only on a real insert — when
# we recovered an existing row the prior writer
# already touched the thread.
if thread is not None:
thread.updated_at = datetime.now(UTC)
await ws.commit()
return resolved_id
except Exception:
logger.exception(
"persist_user_turn failed (chat_id=%s, turn_id=%s)",
chat_id,
turn_id,
)
return None
finally:
_perf_log.info(
"[persist_user_turn] outcome=%s chat_id=%s turn_id=%s "
"message_id=%s query_len=%d images=%d mentioned_docs=%d "
"in %.3fs",
outcome,
chat_id,
turn_id,
resolved_id,
len(user_query or ""),
len(user_image_data_urls or ()),
len(mentioned_documents or ()),
time.perf_counter() - t0,
)
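# Caller sketch (hypothetical streaming handler; it is assumed to have already
# resolved ``chat_id`` / ``user_id`` and generated ``turn_id``): a ``None``
# return is the only failure signal, so the turn must be aborted before any
# assistant shell or title is written.
async def _example_persist_user_turn_caller(
    chat_id: int, user_id: str, turn_id: str, user_query: str
) -> int | None:
    user_message_id = await persist_user_turn(
        chat_id=chat_id,
        user_id=user_id,
        turn_id=turn_id,
        user_query=user_query,
    )
    if user_message_id is None:
        # Genuine DB failure: the caller should emit a streaming error
        # event and stop; never continue into the assistant shell.
        return None
    return user_message_id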
async def persist_assistant_shell(
*,
chat_id: int,
user_id: str | None,
turn_id: str,
) -> int | None:
"""Pre-write an empty assistant row for the turn and return its id.
Inserts a placeholder ``new_chat_messages`` row (empty text content) so
the streaming layer has a stable ``message_id`` to correlate against
for the rest of the turn. ``finalize_assistant_turn`` overwrites the
``content`` field at end-of-stream with the rich ``ContentPart[]``
snapshot produced by ``AssistantContentBuilder``.
Returns the row id on success, ``None`` on a genuine DB failure (caller
should abort the turn rather than stream into a void).
Idempotent against the ``(thread_id, turn_id, ASSISTANT)`` partial unique
index from migration 141: if a row already exists (resume retry, racing
legacy frontend, redelivered request, etc.) we look it up by
``(thread_id, turn_id, role)`` and return its existing id. The streaming
layer is then free to UPDATE that row at finalize time.
"""
if not turn_id:
logger.error(
"persist_assistant_shell called without a turn_id (chat_id=%s); skipping",
chat_id,
)
return None
t0 = time.perf_counter()
outcome = "failed"
resolved_id: int | None = None
try:
async with shielded_async_session() as ws:
insert_stmt = (
pg_insert(NewChatMessage)
.values(
thread_id=chat_id,
role=NewChatMessageRole.ASSISTANT,
content=_EMPTY_SHELL_CONTENT,
author_id=None,
turn_id=turn_id,
)
.on_conflict_do_nothing(
index_elements=["thread_id", "turn_id", "role"],
index_where=sa_text("turn_id IS NOT NULL"),
)
.returning(NewChatMessage.id)
)
inserted_id = (await ws.execute(insert_stmt)).scalar()
if inserted_id is None:
# Conflict — another writer (legacy FE appendMessage,
# resume retry, redelivered request) wrote the
# (thread_id, turn_id, ASSISTANT) row first. Look it up
# so the streaming layer can UPDATE the same row at
# finalize time.
lookup = await ws.execute(
select(NewChatMessage.id).where(
NewChatMessage.thread_id == chat_id,
NewChatMessage.turn_id == turn_id,
NewChatMessage.role == NewChatMessageRole.ASSISTANT,
)
)
existing_id = lookup.scalars().first()
if existing_id is None:
logger.warning(
"persist_assistant_shell: conflict but no matching "
"(thread_id, turn_id, role) row found "
"(chat_id=%s, turn_id=%s)",
chat_id,
turn_id,
)
outcome = "integrity_no_match"
return None
resolved_id = int(existing_id)
outcome = "race_recovered"
else:
resolved_id = int(inserted_id)
outcome = "inserted"
await ws.commit()
return resolved_id
except Exception:
logger.exception(
"persist_assistant_shell failed (chat_id=%s, turn_id=%s)",
chat_id,
turn_id,
)
return None
finally:
_perf_log.info(
"[persist_assistant_shell] outcome=%s chat_id=%s turn_id=%s "
"message_id=%s in %.3fs",
outcome,
chat_id,
turn_id,
resolved_id,
time.perf_counter() - t0,
)
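# Idempotency sketch (illustrative only): calling the shell writer twice for
# the same (chat_id, turn_id) resolves to the same row id. The second call
# hits the partial unique index, takes the conflict branch, and recovers the
# existing id instead of inserting a duplicate row.
async def _example_shell_idempotency(chat_id: int, turn_id: str) -> bool:
    first = await persist_assistant_shell(
        chat_id=chat_id, user_id=None, turn_id=turn_id
    )
    second = await persist_assistant_shell(
        chat_id=chat_id, user_id=None, turn_id=turn_id
    )
    # Both ids are equal (or both None on genuine DB failure).
    return first == second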
async def finalize_assistant_turn(
*,
message_id: int,
chat_id: int,
search_space_id: int,
user_id: str | None,
turn_id: str,
content: list[dict[str, Any]],
accumulator: TurnTokenAccumulator | None,
) -> None:
"""Finalize the assistant row and write its token_usage.
Two writes in a single shielded session:
1. ``UPDATE new_chat_messages SET content = :c, updated_at = now()
WHERE id = :id`` overwrites the placeholder ``persist_assistant_shell``
wrote with the full ``ContentPart[]`` snapshot produced server-side.
2. ``INSERT INTO token_usage (...) VALUES (...) ON CONFLICT (message_id)
WHERE message_id IS NOT NULL DO NOTHING`` uses the partial unique
index ``uq_token_usage_message_id`` from migration 142 to make the
insert idempotent against ``append_message``'s recovery branch
(which uses the same ON CONFLICT clause).
Substitutes the status-marker payload when ``content`` is empty
(pure tool-call turn that aborted before any output, or interrupt
before any event arrived). The status marker is preferable to a
blank text bubble because token accounting still runs and an ops
dashboard can flag the row.
Best-effort never raises. The streaming ``finally`` calls this
from within ``anyio.CancelScope(shield=True)``; any raised exception
here would mask the real error that triggered the cleanup.
"""
if not turn_id:
logger.error(
"finalize_assistant_turn called without turn_id "
"(chat_id=%s, message_id=%s); skipping",
chat_id,
message_id,
)
return
if not message_id:
logger.error(
"finalize_assistant_turn called without message_id "
"(chat_id=%s, turn_id=%s); skipping",
chat_id,
turn_id,
)
return
payload: list[dict[str, Any]]
is_status_marker = False
if content:
payload = content
else:
payload = _STATUS_NO_RESPONSE
is_status_marker = True
t0 = time.perf_counter()
outcome = "failed"
token_usage_attempted = bool(
accumulator is not None and accumulator.calls and user_id
)
try:
async with shielded_async_session() as ws:
assistant_row = await ws.get(NewChatMessage, message_id)
if assistant_row is None:
logger.warning(
"finalize_assistant_turn: row not found "
"(chat_id=%s, message_id=%s, turn_id=%s); skipping",
chat_id,
message_id,
turn_id,
)
outcome = "row_missing"
return
assistant_row.content = payload
assistant_row.updated_at = datetime.now(UTC)
# Token usage. ``record_token_usage`` (used elsewhere) does
# SELECT-then-INSERT in two statements which races with
# ``append_message``. Switch to a single INSERT ... ON
# CONFLICT DO NOTHING keyed on the migration-142 partial
# unique index so the loser silently drops its write at
# the DB level — exactly one row per ``message_id``,
# regardless of which session committed first.
if accumulator is not None and accumulator.calls and user_id:
try:
user_uuid = UUID(user_id)
except (TypeError, ValueError):
logger.warning(
"finalize_assistant_turn: invalid user_id=%r, "
"skipping token_usage row",
user_id,
)
else:
insert_stmt = (
pg_insert(TokenUsage)
.values(
usage_type="chat",
prompt_tokens=accumulator.total_prompt_tokens,
completion_tokens=accumulator.total_completion_tokens,
total_tokens=accumulator.grand_total,
cost_micros=accumulator.total_cost_micros,
model_breakdown=accumulator.per_message_summary(),
call_details={"calls": accumulator.serialized_calls()},
thread_id=chat_id,
message_id=message_id,
search_space_id=search_space_id,
user_id=user_uuid,
)
.on_conflict_do_nothing(
index_elements=["message_id"],
index_where=sa_text("message_id IS NOT NULL"),
)
)
await ws.execute(insert_stmt)
await ws.commit()
outcome = "ok"
except Exception:
logger.exception(
"finalize_assistant_turn failed (chat_id=%s, message_id=%s, turn_id=%s)",
chat_id,
message_id,
turn_id,
)
finally:
_perf_log.info(
"[finalize_assistant_turn] outcome=%s chat_id=%s message_id=%s "
"turn_id=%s parts=%d status_marker=%s "
"token_usage_attempted=%s in %.3fs",
outcome,
chat_id,
message_id,
turn_id,
len(payload),
is_status_marker,
token_usage_attempted,
time.perf_counter() - t0,
)
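# End-of-stream sketch (illustrative only; ``anyio`` is assumed to be
# available in the streaming layer this module documents): finalize runs from
# ``finally`` inside a shielded cancel scope so stream cancellation cannot
# interrupt the last write, and because the helper never raises it cannot mask
# whichever error triggered the cleanup. Passing an empty ``content`` list
# exercises the ``_STATUS_NO_RESPONSE`` substitution described above.
async def _example_finalize_from_stream_finally(
    message_id: int,
    chat_id: int,
    search_space_id: int,
    user_id: str | None,
    turn_id: str,
    content_parts: list[dict[str, Any]],
) -> None:
    import anyio  # local import: only the streaming layer needs it
    try:
        ...  # stream model output, appending parts to ``content_parts``
    finally:
        with anyio.CancelScope(shield=True):
            await finalize_assistant_turn(
                message_id=message_id,
                chat_id=chat_id,
                search_space_id=search_space_id,
                user_id=user_id,
                turn_id=turn_id,
                content=content_parts,  # [] triggers the status marker
                accumulator=None,  # no token usage recorded in this sketch
            )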

File diff suppressed because it is too large

View file

@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
from app.services.composio_service import ComposioService
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.google_credentials import (
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
build_composio_credentials,
)
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
from .base import (
check_duplicate_document_by_hash,
@ -44,6 +42,10 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
HEARTBEAT_INTERVAL_SECONDS = 30
def _format_calendar_event_to_markdown(event: dict) -> str:
return GoogleCalendarConnector.format_event_to_markdown(None, event)
def _build_connector_doc(
event: dict,
event_markdown: str,
@ -150,7 +152,14 @@ async def index_google_calendar_events(
)
return 0, 0, f"Connector with ID {connector_id} not found"
# ── Credential building ───────────────────────────────────────
is_composio_connector = (
connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
)
calendar_client = None
composio_service = None
connected_account_id = None
# ── Credential/client building ────────────────────────────────
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
@ -161,7 +170,7 @@ async def index_google_calendar_events(
{"error_type": "MissingComposioAccount"},
)
return 0, 0, "Composio connected_account_id not found"
credentials = build_composio_credentials(connected_account_id)
composio_service = ComposioService()
else:
config_data = connector.config
@ -229,12 +238,13 @@ async def index_google_calendar_events(
{"stage": "client_initialization"},
)
calendar_client = GoogleCalendarConnector(
credentials=credentials,
session=session,
user_id=user_id,
connector_id=connector_id,
)
if not is_composio_connector:
calendar_client = GoogleCalendarConnector(
credentials=credentials,
session=session,
user_id=user_id,
connector_id=connector_id,
)
# Handle 'undefined' string from frontend (treat as None)
if start_date == "undefined" or start_date == "":
@ -300,9 +310,26 @@ async def index_google_calendar_events(
)
try:
events, error = await calendar_client.get_all_primary_calendar_events(
start_date=start_date_str, end_date=end_date_str
)
if is_composio_connector:
start_dt = parse_date_flexible(start_date_str).replace(
hour=0, minute=0, second=0, microsecond=0
)
end_dt = parse_date_flexible(end_date_str).replace(
hour=23, minute=59, second=59, microsecond=0
)
events, error = await composio_service.get_calendar_events(
connected_account_id=connected_account_id,
entity_id=f"surfsense_{user_id}",
time_min=start_dt.isoformat(),
time_max=end_dt.isoformat(),
max_results=250,
)
if not events and not error:
error = "No events found in the specified date range."
else:
events, error = await calendar_client.get_all_primary_calendar_events(
start_date=start_date_str, end_date=end_date_str
)
if error:
if "No events found" in error:
@ -381,7 +408,7 @@ async def index_google_calendar_events(
documents_skipped += 1
continue
event_markdown = calendar_client.format_event_to_markdown(event)
event_markdown = _format_calendar_event_to_markdown(event)
if not event_markdown.strip():
logger.warning(f"Skipping event with no content: {event_summary}")
documents_skipped += 1

View file

@ -9,6 +9,8 @@ import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import Any
from sqlalchemy import String, cast, select
from sqlalchemy.exc import SQLAlchemyError
@ -37,6 +39,7 @@ from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
from app.services.composio_service import ComposioService
from app.services.llm_service import get_user_long_context_llm
from app.services.page_limit_service import PageLimitService
from app.services.task_logging_service import TaskLoggingService
@ -45,10 +48,7 @@ from app.tasks.connector_indexers.base import (
get_connector_by_id,
update_connector_last_indexed,
)
from app.utils.google_credentials import (
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
build_composio_credentials,
)
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
ACCEPTED_DRIVE_CONNECTOR_TYPES = {
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
@ -61,6 +61,209 @@ HEARTBEAT_INTERVAL_SECONDS = 30
logger = logging.getLogger(__name__)
class ComposioDriveClient:
"""Google Drive client facade backed by Composio tool execution.
Composio-managed OAuth connections can execute tools without exposing raw
OAuth tokens through connected account state.
"""
def __init__(
self,
session: AsyncSession,
connector_id: int,
connected_account_id: str,
entity_id: str,
):
self.session = session
self.connector_id = connector_id
self.connected_account_id = connected_account_id
self.entity_id = entity_id
self.composio = ComposioService()
async def list_files(
self,
query: str = "",
fields: str = "nextPageToken, files(id, name, mimeType, modifiedTime, md5Checksum, size, webViewLink, parents, owners, createdTime, description)",
page_size: int = 100,
page_token: str | None = None,
) -> tuple[list[dict[str, Any]], str | None, str | None]:
params: dict[str, Any] = {
"page_size": min(page_size, 100),
"fields": fields,
}
if query:
params["q"] = query
if page_token:
params["page_token"] = page_token
result = await self.composio.execute_tool(
connected_account_id=self.connected_account_id,
tool_name="GOOGLEDRIVE_LIST_FILES",
params=params,
entity_id=self.entity_id,
)
if not result.get("success"):
return [], None, result.get("error", "Unknown error")
data = result.get("data", {})
files = []
next_token = None
if isinstance(data, dict):
inner_data = data.get("data", data)
if isinstance(inner_data, dict):
files = inner_data.get("files", [])
next_token = inner_data.get("nextPageToken") or inner_data.get(
"next_page_token"
)
elif isinstance(data, list):
files = data
return files, next_token, None
async def get_file_metadata(
self, file_id: str, fields: str = "*"
) -> tuple[dict[str, Any] | None, str | None]:
result = await self.composio.execute_tool(
connected_account_id=self.connected_account_id,
tool_name="GOOGLEDRIVE_GET_FILE_METADATA",
params={"file_id": file_id, "fields": fields},
entity_id=self.entity_id,
)
if not result.get("success"):
return None, result.get("error", "Unknown error")
data = result.get("data", {})
if isinstance(data, dict):
inner_data = data.get("data", data)
if isinstance(inner_data, dict):
return inner_data, None
return None, "Could not extract metadata from Composio response"
async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]:
return await self._download_file_content(file_id)
async def download_file_to_disk(
self,
file_id: str,
dest_path: str,
chunksize: int = 5 * 1024 * 1024,
) -> str | None:
del chunksize
content, error = await self.download_file(file_id)
if error:
return error
if content is None:
return "No content returned from Composio"
Path(dest_path).write_bytes(content)
return None
async def export_google_file(
self, file_id: str, mime_type: str
) -> tuple[bytes | None, str | None]:
return await self._download_file_content(file_id, mime_type=mime_type)
async def _download_file_content(
self, file_id: str, mime_type: str | None = None
) -> tuple[bytes | None, str | None]:
params: dict[str, Any] = {"file_id": file_id}
if mime_type:
params["mime_type"] = mime_type
result = await self.composio.execute_tool(
connected_account_id=self.connected_account_id,
tool_name="GOOGLEDRIVE_DOWNLOAD_FILE",
params=params,
entity_id=self.entity_id,
)
if not result.get("success"):
return None, result.get("error", "Unknown error")
return self._read_download_result(result.get("data"))
def _read_download_result(self, data: Any) -> tuple[bytes | None, str | None]:
if isinstance(data, bytes):
return data, None
file_path: str | None = None
if isinstance(data, str):
file_path = data
elif isinstance(data, dict):
inner_data = data.get("data", data)
if isinstance(inner_data, dict):
for key in ("file_path", "downloaded_file_content", "path", "uri"):
value = inner_data.get(key)
if isinstance(value, str):
file_path = value
break
if isinstance(value, dict):
nested = (
value.get("file_path")
or value.get("downloaded_file_content")
or value.get("path")
or value.get("uri")
or value.get("s3url")
)
if isinstance(nested, str):
file_path = nested
break
if not file_path:
return None, "No file path/content returned from Composio"
if file_path.startswith(("http://", "https://")):
try:
import urllib.request
with urllib.request.urlopen(file_path, timeout=60) as response:
return response.read(), None
except Exception as e:
return None, f"Failed to download Composio file URL: {e!s}"
path_obj = Path(file_path)
if path_obj.is_absolute() or ".composio" in str(path_obj):
if not path_obj.exists():
return None, f"File not found at path: {file_path}"
return path_obj.read_bytes(), None
try:
import base64
return base64.b64decode(file_path), None
except Exception:
return file_path.encode("utf-8"), None
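# Usage sketch (illustrative only, not part of the indexing flow): paginate
# through Drive files via the Composio facade. The ``session`` and connector
# identifiers are assumed to come from the surrounding indexer.
async def _example_composio_drive_listing(
    session: AsyncSession,
    connector_id: int,
    connected_account_id: str,
    user_id: str,
) -> list[dict[str, Any]]:
    client = ComposioDriveClient(
        session,
        connector_id,
        connected_account_id,
        entity_id=f"surfsense_{user_id}",
    )
    all_files: list[dict[str, Any]] = []
    page_token: str | None = None
    while True:
        files, page_token, error = await client.list_files(page_token=page_token)
        if error:
            break  # the real indexer logs and surfaces this error
        all_files.extend(files)
        if not page_token:
            break
    return all_files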
def _build_drive_client_for_connector(
session: AsyncSession,
connector_id: int,
connector: object,
user_id: str,
) -> tuple[GoogleDriveClient | ComposioDriveClient | None, str | None]:
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
return None, (
f"Composio connected_account_id not found for connector {connector_id}"
)
return (
ComposioDriveClient(
session,
connector_id,
connected_account_id,
entity_id=f"surfsense_{user_id}",
),
None,
)
token_encrypted = connector.config.get("_token_encrypted", False)
if token_encrypted and not config.SECRET_KEY:
return None, "SECRET_KEY not configured but credentials are marked as encrypted"
return GoogleDriveClient(session, connector_id), None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@ -927,34 +1130,17 @@ async def index_google_drive_files(
{"stage": "client_initialization"},
)
pre_built_credentials = None
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
await task_logger.log_task_failure(
log_entry,
error_msg,
"Missing Composio account",
{"error_type": "MissingComposioAccount"},
)
return 0, 0, error_msg, 0
pre_built_credentials = build_composio_credentials(connected_account_id)
else:
token_encrypted = connector.config.get("_token_encrypted", False)
if token_encrypted and not config.SECRET_KEY:
await task_logger.log_task_failure(
log_entry,
"SECRET_KEY not configured but credentials are encrypted",
"Missing SECRET_KEY",
{"error_type": "MissingSecretKey"},
)
return (
0,
0,
"SECRET_KEY not configured but credentials are marked as encrypted",
0,
)
drive_client, client_error = _build_drive_client_for_connector(
session, connector_id, connector, user_id
)
if client_error or not drive_client:
await task_logger.log_task_failure(
log_entry,
client_error or "Failed to initialize Google Drive client",
"Missing connector credentials",
{"error_type": "ClientInitializationError"},
)
return 0, 0, client_error, 0
connector_enable_summary = getattr(connector, "enable_summary", True)
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
@ -963,10 +1149,6 @@ async def index_google_drive_files(
from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id)
drive_client = GoogleDriveClient(
session, connector_id, credentials=pre_built_credentials
)
if not folder_id:
error_msg = "folder_id is required for Google Drive indexing"
await task_logger.log_task_failure(
@ -979,8 +1161,14 @@ async def index_google_drive_files(
folder_tokens = connector.config.get("folder_tokens", {})
start_page_token = folder_tokens.get(target_folder_id)
is_composio_connector = (
connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
)
can_use_delta = (
use_delta_sync and start_page_token and connector.last_indexed_at
not is_composio_connector
and use_delta_sync
and start_page_token
and connector.last_indexed_at
)
documents_unsupported = 0
@ -1051,7 +1239,16 @@ async def index_google_drive_files(
)
if documents_indexed > 0 or can_use_delta:
new_token, token_error = await get_start_page_token(drive_client)
if isinstance(drive_client, ComposioDriveClient):
(
new_token,
token_error,
) = await drive_client.composio.get_drive_start_page_token(
drive_client.connected_account_id,
drive_client.entity_id,
)
else:
new_token, token_error = await get_start_page_token(drive_client)
if new_token and not token_error:
await session.refresh(connector)
if "folder_tokens" not in connector.config:
@ -1137,32 +1334,17 @@ async def index_google_drive_single_file(
)
return 0, error_msg
pre_built_credentials = None
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
await task_logger.log_task_failure(
log_entry,
error_msg,
"Missing Composio account",
{"error_type": "MissingComposioAccount"},
)
return 0, error_msg
pre_built_credentials = build_composio_credentials(connected_account_id)
else:
token_encrypted = connector.config.get("_token_encrypted", False)
if token_encrypted and not config.SECRET_KEY:
await task_logger.log_task_failure(
log_entry,
"SECRET_KEY not configured but credentials are encrypted",
"Missing SECRET_KEY",
{"error_type": "MissingSecretKey"},
)
return (
0,
"SECRET_KEY not configured but credentials are marked as encrypted",
)
drive_client, client_error = _build_drive_client_for_connector(
session, connector_id, connector, user_id
)
if client_error or not drive_client:
await task_logger.log_task_failure(
log_entry,
client_error or "Failed to initialize Google Drive client",
"Missing connector credentials",
{"error_type": "ClientInitializationError"},
)
return 0, client_error
connector_enable_summary = getattr(connector, "enable_summary", True)
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
@ -1171,10 +1353,6 @@ async def index_google_drive_single_file(
from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id)
drive_client = GoogleDriveClient(
session, connector_id, credentials=pre_built_credentials
)
file, error = await get_file_by_id(drive_client, file_id)
if error or not file:
error_msg = f"Failed to fetch file {file_id}: {error or 'File not found'}"
@ -1276,32 +1454,18 @@ async def index_google_drive_selected_files(
)
return 0, 0, [error_msg]
pre_built_credentials = None
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
await task_logger.log_task_failure(
log_entry,
error_msg,
"Missing Composio account",
{"error_type": "MissingComposioAccount"},
)
return 0, 0, [error_msg]
pre_built_credentials = build_composio_credentials(connected_account_id)
else:
token_encrypted = connector.config.get("_token_encrypted", False)
if token_encrypted and not config.SECRET_KEY:
error_msg = (
"SECRET_KEY not configured but credentials are marked as encrypted"
)
await task_logger.log_task_failure(
log_entry,
error_msg,
"Missing SECRET_KEY",
{"error_type": "MissingSecretKey"},
)
return 0, 0, [error_msg]
drive_client, client_error = _build_drive_client_for_connector(
session, connector_id, connector, user_id
)
if client_error or not drive_client:
error_msg = client_error or "Failed to initialize Google Drive client"
await task_logger.log_task_failure(
log_entry,
error_msg,
"Missing connector credentials",
{"error_type": "ClientInitializationError"},
)
return 0, 0, [error_msg]
connector_enable_summary = getattr(connector, "enable_summary", True)
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
@ -1310,10 +1474,6 @@ async def index_google_drive_selected_files(
from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id)
drive_client = GoogleDriveClient(
session, connector_id, credentials=pre_built_credentials
)
indexed, skipped, unsupported, errors = await _index_selected_files(
drive_client,
session,

View file

@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
from app.services.composio_service import ComposioService
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.google_credentials import (
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
build_composio_credentials,
)
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
from .base import (
calculate_date_range,
@ -44,6 +42,62 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
HEARTBEAT_INTERVAL_SECONDS = 30
def _normalize_composio_gmail_message(message: dict) -> dict:
if message.get("payload"):
return message
headers = []
header_values = {
"Subject": message.get("subject"),
"From": message.get("from") or message.get("sender"),
"To": message.get("to") or message.get("recipient"),
"Date": message.get("date"),
}
for name, value in header_values.items():
if value:
headers.append({"name": name, "value": value})
return {
**message,
"id": message.get("id")
or message.get("message_id")
or message.get("messageId"),
"threadId": message.get("threadId") or message.get("thread_id"),
"payload": {"headers": headers},
"snippet": message.get("snippet", ""),
"messageText": message.get("messageText") or message.get("body") or "",
}
def _format_gmail_message_to_markdown(message: dict) -> str:
headers = {
header.get("name", "").lower(): header.get("value", "")
for header in message.get("payload", {}).get("headers", [])
if isinstance(header, dict)
}
subject = headers.get("subject", "No Subject")
from_email = headers.get("from", "Unknown Sender")
to_email = headers.get("to", "Unknown Recipient")
date_str = headers.get("date", "Unknown Date")
message_text = (
message.get("messageText")
or message.get("body")
or message.get("text")
or message.get("snippet", "")
)
return (
f"# {subject}\n\n"
f"**From:** {from_email}\n"
f"**To:** {to_email}\n"
f"**Date:** {date_str}\n\n"
f"## Message Content\n\n{message_text}\n\n"
f"## Message Details\n\n"
f"- **Message ID:** {message.get('id', 'Unknown')}\n"
f"- **Thread ID:** {message.get('threadId', 'Unknown')}\n"
)
def _build_connector_doc(
message: dict,
markdown_content: str,
@ -162,7 +216,14 @@ async def index_google_gmail_messages(
)
return 0, 0, error_msg
# ── Credential building ───────────────────────────────────────
is_composio_connector = (
connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
)
gmail_connector = None
composio_service = None
connected_account_id = None
# ── Credential/client building ────────────────────────────────
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
@ -173,7 +234,7 @@ async def index_google_gmail_messages(
{"error_type": "MissingComposioAccount"},
)
return 0, 0, "Composio connected_account_id not found"
credentials = build_composio_credentials(connected_account_id)
composio_service = ComposioService()
else:
config_data = connector.config
@ -241,9 +302,10 @@ async def index_google_gmail_messages(
{"stage": "client_initialization"},
)
gmail_connector = GoogleGmailConnector(
credentials, session, user_id, connector_id
)
if not is_composio_connector:
gmail_connector = GoogleGmailConnector(
credentials, session, user_id, connector_id
)
calculated_start_date, calculated_end_date = calculate_date_range(
connector, start_date, end_date, default_days_back=365
@ -254,11 +316,60 @@ async def index_google_gmail_messages(
f"Fetching emails for connector {connector_id} "
f"from {calculated_start_date} to {calculated_end_date}"
)
messages, error = await gmail_connector.get_recent_messages(
max_results=max_messages,
start_date=calculated_start_date,
end_date=calculated_end_date,
)
if is_composio_connector:
query_parts = []
if calculated_start_date:
query_parts.append(f"after:{calculated_start_date.replace('-', '/')}")
if calculated_end_date:
query_parts.append(f"before:{calculated_end_date.replace('-', '/')}")
query = " ".join(query_parts)
messages = []
page_token = None
error = None
while len(messages) < max_messages:
page_size = min(50, max_messages - len(messages))
(
page_messages,
page_token,
_estimate,
page_error,
) = await composio_service.get_gmail_messages(
connected_account_id=connected_account_id,
entity_id=f"surfsense_{user_id}",
query=query,
max_results=page_size,
page_token=page_token,
)
if page_error:
error = page_error
break
for page_message in page_messages:
message_id = (
page_message.get("id")
or page_message.get("message_id")
or page_message.get("messageId")
)
if message_id:
(
detail,
detail_error,
) = await composio_service.get_gmail_message_detail(
connected_account_id=connected_account_id,
entity_id=f"surfsense_{user_id}",
message_id=message_id,
)
if not detail_error and isinstance(detail, dict):
page_message = detail
messages.append(_normalize_composio_gmail_message(page_message))
if not page_token:
break
else:
messages, error = await gmail_connector.get_recent_messages(
max_results=max_messages,
start_date=calculated_start_date,
end_date=calculated_end_date,
)
if error:
error_message = error
@ -326,7 +437,12 @@ async def index_google_gmail_messages(
documents_skipped += 1
continue
markdown_content = gmail_connector.format_message_to_markdown(message)
if is_composio_connector:
markdown_content = _format_gmail_message_to_markdown(message)
else:
markdown_content = gmail_connector.format_message_to_markdown(
message
)
if not markdown_content.strip():
logger.warning(f"Skipping message with no content: {message_id}")
documents_skipped += 1