Merge pull request #774 from AnishSarkar22/fix/documents

fix: add MIME type fetch for composio drive connector and some fixes
This commit is contained in:
Rohan Verma 2026-02-03 13:52:04 -08:00 committed by GitHub
commit 02470fb21b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 494 additions and 264 deletions

View file

@ -4,6 +4,7 @@ Composio Google Drive Connector Module.
Provides Google Drive specific methods for data retrieval and indexing via Composio. Provides Google Drive specific methods for data retrieval and indexing via Composio.
""" """
import contextlib
import hashlib import hashlib
import json import json
import logging import logging
@ -179,13 +180,14 @@ class ComposioGoogleDriveConnector(ComposioConnector):
) )
async def get_drive_file_content( async def get_drive_file_content(
self, file_id: str self, file_id: str, original_mime_type: str | None = None
) -> tuple[bytes | None, str | None]: ) -> tuple[bytes | None, str | None]:
""" """
Download file content from Google Drive via Composio. Download file content from Google Drive via Composio.
Args: Args:
file_id: Google Drive file ID. file_id: Google Drive file ID.
original_mime_type: Original MIME type (used to detect Google Workspace files for export).
Returns: Returns:
Tuple of (file content bytes, error message). Tuple of (file content bytes, error message).
@ -200,6 +202,31 @@ class ComposioGoogleDriveConnector(ComposioConnector):
connected_account_id=connected_account_id, connected_account_id=connected_account_id,
entity_id=entity_id, entity_id=entity_id,
file_id=file_id, file_id=file_id,
original_mime_type=original_mime_type,
)
async def get_file_metadata(
self, file_id: str
) -> tuple[dict[str, Any] | None, str | None]:
"""
Get metadata for a specific file from Google Drive.
Args:
file_id: The ID of the file to get metadata for.
Returns:
Tuple of (metadata dict, error message).
"""
connected_account_id = await self.get_connected_account_id()
if not connected_account_id:
return None, "No connected account ID found"
entity_id = await self.get_entity_id()
service = await self._get_service()
return await service.get_file_metadata(
connected_account_id=connected_account_id,
entity_id=entity_id,
file_id=file_id,
) )
async def get_drive_start_page_token(self) -> tuple[str | None, str | None]: async def get_drive_start_page_token(self) -> tuple[str | None, str | None]:
@ -292,8 +319,10 @@ async def _process_file_content(
if isinstance(content, str): if isinstance(content, str):
content = content.encode("utf-8") content = content.encode("utf-8")
# Check if this is a binary file # Check if this is a binary file based on extension or MIME type
if _is_binary_file(file_name, mime_type): is_binary = _is_binary_file(file_name, mime_type)
if is_binary:
# Use ETL service for binary files (PDF, Office docs, etc.) # Use ETL service for binary files (PDF, Office docs, etc.)
temp_file_path = None temp_file_path = None
try: try:
@ -316,7 +345,7 @@ async def _process_file_content(
return extracted_text return extracted_text
else: else:
# Fallback if extraction fails # Fallback if extraction fails
logger.warning(f"Could not extract text from binary file {file_name}") logger.warning(f"ETL returned empty for binary file {file_name}")
return f"# {file_name}\n\n[Binary file - text extraction failed]\n\n**File ID:** {file_id}\n**Type:** {mime_type}\n" return f"# {file_name}\n\n[Binary file - text extraction failed]\n\n**File ID:** {file_id}\n**Type:** {mime_type}\n"
except Exception as e: except Exception as e:
@ -327,10 +356,8 @@ async def _process_file_content(
finally: finally:
# Cleanup temp file # Cleanup temp file
if temp_file_path and os.path.exists(temp_file_path): if temp_file_path and os.path.exists(temp_file_path):
try: with contextlib.suppress(Exception):
os.unlink(temp_file_path) os.unlink(temp_file_path)
except Exception as e:
logger.debug(f"Could not delete temp file {temp_file_path}: {e}")
else: else:
# Text file - try to decode as UTF-8 # Text file - try to decode as UTF-8
try: try:
@ -372,9 +399,13 @@ async def _extract_text_with_etl(
from logging import ERROR, getLogger from logging import ERROR, getLogger
etl_service = config.ETL_SERVICE etl_service = config.ETL_SERVICE
logger.debug(
f"[_extract_text_with_etl] START - file_path={file_path}, file_name={file_name}, etl_service={etl_service}"
)
try: try:
if etl_service == "UNSTRUCTURED": if etl_service == "UNSTRUCTURED":
logger.debug("[_extract_text_with_etl] Using UNSTRUCTURED ETL")
from langchain_unstructured import UnstructuredLoader from langchain_unstructured import UnstructuredLoader
from app.utils.document_converters import convert_document_to_markdown from app.utils.document_converters import convert_document_to_markdown
@ -390,11 +421,20 @@ async def _extract_text_with_etl(
) )
docs = await loader.aload() docs = await loader.aload()
logger.debug(
f"[_extract_text_with_etl] UNSTRUCTURED loaded {len(docs) if docs else 0} docs"
)
if docs: if docs:
return await convert_document_to_markdown(docs) result = await convert_document_to_markdown(docs)
logger.debug(
f"[_extract_text_with_etl] UNSTRUCTURED result: {len(result) if result else 0} chars"
)
return result
logger.debug("[_extract_text_with_etl] UNSTRUCTURED returned no docs")
return None return None
elif etl_service == "LLAMACLOUD": elif etl_service == "LLAMACLOUD":
logger.debug("[_extract_text_with_etl] Using LLAMACLOUD ETL")
from app.tasks.document_processors.file_processors import ( from app.tasks.document_processors.file_processors import (
parse_with_llamacloud_retry, parse_with_llamacloud_retry,
) )
@ -413,11 +453,22 @@ async def _extract_text_with_etl(
markdown_documents = await result.aget_markdown_documents( markdown_documents = await result.aget_markdown_documents(
split_by_page=False split_by_page=False
) )
logger.debug(
f"[_extract_text_with_etl] LLAMACLOUD got {len(markdown_documents) if markdown_documents else 0} markdown docs"
)
if markdown_documents: if markdown_documents:
return markdown_documents[0].text text = markdown_documents[0].text
logger.debug(
f"[_extract_text_with_etl] LLAMACLOUD result: {len(text) if text else 0} chars"
)
return text
logger.debug(
"[_extract_text_with_etl] LLAMACLOUD returned no markdown docs"
)
return None return None
elif etl_service == "DOCLING": elif etl_service == "DOCLING":
logger.debug("[_extract_text_with_etl] Using DOCLING ETL")
from app.services.docling_service import create_docling_service from app.services.docling_service import create_docling_service
docling_service = create_docling_service() docling_service = create_docling_service()
@ -441,16 +492,30 @@ async def _extract_text_with_etl(
result = await docling_service.process_document( result = await docling_service.process_document(
file_path, file_name file_path, file_name
) )
logger.debug(
f"[_extract_text_with_etl] DOCLING result keys: {list(result.keys()) if result else 'None'}"
)
finally: finally:
pdfminer_logger.setLevel(original_level) pdfminer_logger.setLevel(original_level)
return result.get("content") content = result.get("content")
logger.debug(
f"[_extract_text_with_etl] DOCLING content: {len(content) if content else 0} chars"
)
return content
else: else:
logger.warning(f"Unknown ETL service: {etl_service}") logger.warning(
f"[_extract_text_with_etl] Unknown ETL service: {etl_service}"
)
return None return None
except Exception as e: except Exception as e:
logger.error(f"ETL extraction failed for {file_name}: {e!s}") logger.error(
f"[_extract_text_with_etl] ETL extraction EXCEPTION for {file_name}: {e!s}"
)
import traceback
logger.error(f"[_extract_text_with_etl] Traceback: {traceback.format_exc()}")
return None return None
@ -979,7 +1044,7 @@ async def _index_composio_drive_full_scan(
all_files.extend(folder_files[:max_files_per_folder]) all_files.extend(folder_files[:max_files_per_folder])
logger.info(f"Found {len(folder_files)} files in folder {folder_name}") logger.info(f"Found {len(folder_files)} files in folder {folder_name}")
# Add specifically selected files # Add specifically selected files - fetch metadata to get mimeType
for selected_file in selected_files: for selected_file in selected_files:
file_id = selected_file.get("id") file_id = selected_file.get("id")
file_name = selected_file.get("name", "Unknown") file_name = selected_file.get("name", "Unknown")
@ -987,14 +1052,35 @@ async def _index_composio_drive_full_scan(
if not file_id: if not file_id:
continue continue
# Add file info (we'll fetch content later during indexing) # Fetch file metadata to get proper mimeType
all_files.append( metadata, meta_error = await composio_connector.get_file_metadata(file_id)
{ if metadata and not meta_error:
"id": file_id, all_files.append(
"name": file_name, {
"mimeType": "", # Will be determined later "id": file_id,
} "name": metadata.get("name") or file_name,
) "mimeType": metadata.get("mimeType", ""),
"modifiedTime": metadata.get("modifiedTime", ""),
"createdTime": metadata.get("createdTime", ""),
}
)
logger.info(
f"Fetched metadata for UI-selected file: {file_name} "
f"(mimeType={metadata.get('mimeType', 'unknown')})"
)
else:
# Fallback if metadata fetch fails - content-based detection will handle it
logger.warning(
f"Could not fetch metadata for file {file_name}: {meta_error}. "
f"Falling back to content-based detection."
)
all_files.append(
{
"id": file_id,
"name": file_name,
"mimeType": "", # Content-based detection will handle this
}
)
else: else:
# No selection specified - fetch all files (original behavior) # No selection specified - fetch all files (original behavior)
page_token = None page_token = None
@ -1128,8 +1214,10 @@ async def _process_single_drive_file(
session, unique_identifier_hash session, unique_identifier_hash
) )
# Get file content # Get file content (pass mime_type for Google Workspace export handling)
content, content_error = await composio_connector.get_drive_file_content(file_id) content, content_error = await composio_connector.get_drive_file_content(
file_id, original_mime_type=mime_type
)
if content_error or not content: if content_error or not content:
logger.warning(f"Could not get content for file {file_name}: {content_error}") logger.warning(f"Could not get content for file {file_name}: {content_error}")

View file

@ -42,10 +42,6 @@ from app.utils.connector_naming import (
) )
from app.utils.oauth_security import OAuthStateManager from app.utils.oauth_security import OAuthStateManager
# Note: We no longer use check_duplicate_connector for Composio connectors because
# Composio generates a new connected_account_id each time, even for the same Google account.
# Instead, we check for existing connectors by type/space/user and update them.
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@ -256,11 +252,6 @@ async def composio_callback(
"connectedAccountId" "connectedAccountId"
) or query_params.get("connected_account_id") ) or query_params.get("connected_account_id")
# DEBUG: Log query parameter received
logger.info(
f"DEBUG: Callback received - connectedAccountId: {query_params.get('connectedAccountId')}, connected_account_id: {query_params.get('connected_account_id')}, using: {final_connected_account_id}"
)
# If we still don't have a connected_account_id, warn but continue # If we still don't have a connected_account_id, warn but continue
# (the connector will be created but indexing won't work until updated) # (the connector will be created but indexing won't work until updated)
if not final_connected_account_id: if not final_connected_account_id:
@ -273,6 +264,9 @@ async def composio_callback(
f"Successfully got connected_account_id: {final_connected_account_id}" f"Successfully got connected_account_id: {final_connected_account_id}"
) )
# Build entity_id for Composio API calls (same format as used in initiate)
entity_id = f"surfsense_{user_id}"
# Build connector config # Build connector config
connector_config = { connector_config = {
"composio_connected_account_id": final_connected_account_id, "composio_connected_account_id": final_connected_account_id,
@ -290,20 +284,51 @@ async def composio_callback(
) )
connector_type = SearchSourceConnectorType(connector_type_str) connector_type = SearchSourceConnectorType(connector_type_str)
# Check for existing connector of the same type for this user/space # Get the base name for this connector type (e.g., "Google Drive", "Gmail")
# When reconnecting, Composio gives a new connected_account_id, so we need to base_name = get_base_name_for_type(connector_type)
# check by connector_type, user_id, and search_space_id instead of connected_account_id
# FIRST: Get the email for this connected account
# This is needed to determine if it's a reconnection (same email) or new account
email = None
try:
email = await service.get_connected_account_email(
connected_account_id=final_connected_account_id,
entity_id=entity_id,
toolkit_id=toolkit_id,
)
if email:
logger.info(f"Retrieved email {email} for {toolkit_id} connector")
except Exception as email_error:
logger.warning(f"Could not get email for connector: {email_error!s}")
# Generate the connector name (with email if available)
# Format: "Gmail (Composio) - john@gmail.com" or "Gmail (Composio) 1" if no email
if email:
connector_name = f"{base_name} (Composio) - {email}"
else:
# Fallback to generic naming if email not available
count = await count_connectors_of_type(
session, connector_type, space_id, user_id
)
if count == 0:
connector_name = f"{base_name} (Composio) 1"
else:
connector_name = f"{base_name} (Composio) {count + 1}"
# Check if a connector with this SAME name already exists (reconnection case)
# This allows multiple accounts (different emails) while supporting reconnection
existing_connector_result = await session.execute( existing_connector_result = await session.execute(
select(SearchSourceConnector).where( select(SearchSourceConnector).where(
SearchSourceConnector.connector_type == connector_type, SearchSourceConnector.connector_type == connector_type,
SearchSourceConnector.search_space_id == space_id, SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.user_id == user_id, SearchSourceConnector.user_id == user_id,
SearchSourceConnector.name == connector_name,
) )
) )
existing_connector = existing_connector_result.scalars().first() existing_connector = existing_connector_result.scalars().first()
if existing_connector: if existing_connector:
# Delete the old Composio connected account before updating # This is a RECONNECTION of the same account - update existing connector
old_connected_account_id = existing_connector.config.get( old_connected_account_id = existing_connector.config.get(
"composio_connected_account_id" "composio_connected_account_id"
) )
@ -320,22 +345,16 @@ async def composio_callback(
f"Deleted old Composio connected account {old_connected_account_id} " f"Deleted old Composio connected account {old_connected_account_id} "
f"before updating connector {existing_connector.id}" f"before updating connector {existing_connector.id}"
) )
else:
logger.warning(
f"Failed to delete old Composio connected account {old_connected_account_id}"
)
except Exception as delete_error: except Exception as delete_error:
# Log but don't fail - the old account may already be deleted
logger.warning( logger.warning(
f"Error deleting old Composio connected account {old_connected_account_id}: {delete_error!s}" f"Error deleting old Composio connected account {old_connected_account_id}: {delete_error!s}"
) )
# Update existing connector with new connected_account_id # Update existing connector with new connected_account_id
# IMPORTANT: Merge new credentials with existing config to preserve # Merge new credentials with existing config to preserve user settings
# user settings like selected_folders, selected_files, indexing_options,
# drive_page_token, etc. that would otherwise be wiped on reconnection.
logger.info( logger.info(
f"Updating existing Composio connector {existing_connector.id} with new connected_account_id {final_connected_account_id}" f"Reconnecting existing Composio connector {existing_connector.id} ({connector_name}) "
f"with new connected_account_id {final_connected_account_id}"
) )
existing_config = ( existing_config = (
existing_connector.config.copy() if existing_connector.config else {} existing_connector.config.copy() if existing_connector.config else {}
@ -347,28 +366,16 @@ async def composio_callback(
await session.commit() await session.commit()
await session.refresh(existing_connector) await session.refresh(existing_connector)
# Get the frontend connector ID based on toolkit_id
frontend_connector_id = TOOLKIT_TO_FRONTEND_CONNECTOR_ID.get( frontend_connector_id = TOOLKIT_TO_FRONTEND_CONNECTOR_ID.get(
toolkit_id, "composio-connector" toolkit_id, "composio-connector"
) )
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector={frontend_connector_id}&connectorId={existing_connector.id}" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector={frontend_connector_id}&connectorId={existing_connector.id}&view=configure"
) )
# This is a NEW account - create a new connector
try: try:
# Count existing connectors of this type to determine the number logger.info(f"Creating new Composio connector: {connector_name}")
count = await count_connectors_of_type(
session, connector_type, space_id, user_id
)
# Generate base name (e.g., "Gmail", "Google Drive")
base_name = get_base_name_for_type(connector_type)
# Format: "Gmail (Composio) 1", "Gmail (Composio) 2", etc.
if count == 0:
connector_name = f"{base_name} (Composio) 1"
else:
connector_name = f"{base_name} (Composio) {count + 1}"
db_connector = SearchSourceConnector( db_connector = SearchSourceConnector(
name=connector_name, name=connector_name,
@ -392,7 +399,7 @@ async def composio_callback(
toolkit_id, "composio-connector" toolkit_id, "composio-connector"
) )
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector={frontend_connector_id}&connectorId={db_connector.id}" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector={frontend_connector_id}&connectorId={db_connector.id}&view=configure"
) )
except IntegrityError as e: except IntegrityError as e:

View file

@ -15,17 +15,6 @@ from app.config import config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Mapping of toolkit IDs to their Composio auth config IDs
# These use Composio's managed OAuth (no custom credentials needed)
COMPOSIO_TOOLKIT_AUTH_CONFIGS = {
"googledrive": "default", # Uses Composio's managed Google OAuth
"gmail": "default",
"googlecalendar": "default",
"slack": "default",
"notion": "default",
"github": "default",
}
# Mapping of toolkit IDs to their display names # Mapping of toolkit IDs to their display names
COMPOSIO_TOOLKIT_NAMES = { COMPOSIO_TOOLKIT_NAMES = {
"googledrive": "Google Drive", "googledrive": "Google Drive",
@ -234,134 +223,6 @@ class ComposioService:
logger.error(f"Failed to initiate Composio connection: {e!s}") logger.error(f"Failed to initiate Composio connection: {e!s}")
raise raise
async def get_connected_account(
self, connected_account_id: str
) -> dict[str, Any] | None:
"""
Get details of a connected account.
Args:
connected_account_id: The Composio connected account ID.
Returns:
Connected account details or None if not found.
"""
try:
# Pass connected_account_id as positional argument (not keyword)
account = self.client.connected_accounts.get(connected_account_id)
return {
"id": account.id,
"status": getattr(account, "status", None),
"toolkit": getattr(account, "toolkit", None),
"user_id": getattr(account, "user_id", None),
}
except Exception as e:
logger.error(
f"Failed to get connected account {connected_account_id}: {e!s}"
)
return None
async def list_all_connections(self) -> list[dict[str, Any]]:
"""
List ALL connected accounts (for debugging).
Returns:
List of all connected account details.
"""
try:
accounts_response = self.client.connected_accounts.list()
if hasattr(accounts_response, "items"):
accounts = accounts_response.items
elif hasattr(accounts_response, "__iter__"):
accounts = accounts_response
else:
logger.warning(
f"Unexpected accounts response type: {type(accounts_response)}"
)
return []
result = []
for acc in accounts:
toolkit_raw = getattr(acc, "toolkit", None)
toolkit_info = None
if toolkit_raw:
if isinstance(toolkit_raw, str):
toolkit_info = toolkit_raw
elif hasattr(toolkit_raw, "slug"):
toolkit_info = toolkit_raw.slug
elif hasattr(toolkit_raw, "name"):
toolkit_info = toolkit_raw.name
else:
toolkit_info = str(toolkit_raw)
result.append(
{
"id": acc.id,
"status": getattr(acc, "status", None),
"toolkit": toolkit_info,
"user_id": getattr(acc, "user_id", None),
}
)
return result
except Exception as e:
logger.error(f"Failed to list all connections: {e!s}")
return []
async def list_user_connections(self, user_id: str) -> list[dict[str, Any]]:
"""
List all connected accounts for a user.
Args:
user_id: The user's unique identifier.
Returns:
List of connected account details.
"""
try:
accounts_response = self.client.connected_accounts.list(user_id=user_id)
# Handle paginated response (may have .items attribute) or direct list
if hasattr(accounts_response, "items"):
accounts = accounts_response.items
elif hasattr(accounts_response, "__iter__"):
accounts = accounts_response
else:
logger.warning(
f"Unexpected accounts response type: {type(accounts_response)}"
)
return []
result = []
for acc in accounts:
# Extract toolkit info - might be string or object
toolkit_raw = getattr(acc, "toolkit", None)
toolkit_info = None
if toolkit_raw:
if isinstance(toolkit_raw, str):
toolkit_info = toolkit_raw
elif hasattr(toolkit_raw, "slug"):
toolkit_info = toolkit_raw.slug
elif hasattr(toolkit_raw, "name"):
toolkit_info = toolkit_raw.name
else:
toolkit_info = toolkit_raw
result.append(
{
"id": acc.id,
"status": getattr(acc, "status", None),
"toolkit": toolkit_info,
}
)
logger.info(f"Found {len(result)} connections for user {user_id}: {result}")
return result
except Exception as e:
logger.error(f"Failed to list connections for user {user_id}: {e!s}")
return []
async def delete_connected_account(self, connected_account_id: str) -> bool: async def delete_connected_account(self, connected_account_id: str) -> bool:
""" """
Delete a connected account from Composio. Delete a connected account from Composio.
@ -449,8 +310,11 @@ class ComposioService:
""" """
try: try:
# Composio uses snake_case for parameters # Composio uses snake_case for parameters
# IMPORTANT: Include 'fields' to ensure mimeType is returned in the response
# Without this, Google Drive API may not include mimeType for some files
params = { params = {
"page_size": min(page_size, 100), "page_size": min(page_size, 100),
"fields": "files(id,name,mimeType,modifiedTime,createdTime),nextPageToken",
} }
if folder_id: if folder_id:
# List contents of a specific folder (exclude shortcuts - we don't have access to them) # List contents of a specific folder (exclude shortcuts - we don't have access to them)
@ -498,7 +362,11 @@ class ComposioService:
return [], None, str(e) return [], None, str(e)
async def get_drive_file_content( async def get_drive_file_content(
self, connected_account_id: str, entity_id: str, file_id: str self,
connected_account_id: str,
entity_id: str,
file_id: str,
original_mime_type: str | None = None,
) -> tuple[bytes | None, str | None]: ) -> tuple[bytes | None, str | None]:
""" """
Download file content from Google Drive via Composio. Download file content from Google Drive via Composio.
@ -507,10 +375,13 @@ class ComposioService:
to a local directory, and the local file path is provided in the response. to a local directory, and the local file path is provided in the response.
Response includes: file_path, file_name, size fields. Response includes: file_path, file_name, size fields.
For Google Workspace files (Docs, Sheets, Slides), exports to PDF format.
Args: Args:
connected_account_id: Composio connected account ID. connected_account_id: Composio connected account ID.
entity_id: The entity/user ID that owns the connected account. entity_id: The entity/user ID that owns the connected account.
file_id: Google Drive file ID. file_id: Google Drive file ID.
original_mime_type: Original MIME type of the file (used to detect Google Workspace files).
Returns: Returns:
Tuple of (file content bytes, error message). Tuple of (file content bytes, error message).
@ -518,10 +389,19 @@ class ComposioService:
from pathlib import Path from pathlib import Path
try: try:
params = {"file_id": file_id}
# For Google Workspace files, explicitly export as PDF
# This ensures consistent behavior and proper binary detection
if original_mime_type and original_mime_type.startswith(
"application/vnd.google-apps."
):
params["mime_type"] = "application/pdf"
result = await self.execute_tool( result = await self.execute_tool(
connected_account_id=connected_account_id, connected_account_id=connected_account_id,
tool_name="GOOGLEDRIVE_DOWNLOAD_FILE", tool_name="GOOGLEDRIVE_DOWNLOAD_FILE",
params={"file_id": file_id}, params=params,
entity_id=entity_id, entity_id=entity_id,
) )
@ -651,6 +531,60 @@ class ComposioService:
logger.error(f"Failed to get Drive file content: {e!s}") logger.error(f"Failed to get Drive file content: {e!s}")
return None, str(e) return None, str(e)
async def get_file_metadata(
self, connected_account_id: str, entity_id: str, file_id: str
) -> tuple[dict[str, Any] | None, str | None]:
"""
Get metadata for a specific file from Google Drive.
Args:
connected_account_id: Composio connected account ID.
entity_id: The entity/user ID that owns the connected account.
file_id: The ID of the file to get metadata for.
Returns:
Tuple of (metadata dict, error message).
"""
try:
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GOOGLEDRIVE_GET_FILE_METADATA",
params={
"file_id": file_id,
"fields": "id,name,mimeType,modifiedTime,createdTime,size",
},
entity_id=entity_id,
)
if not result.get("success"):
return None, result.get("error", "Unknown error")
data = result.get("data", {})
# Handle nested response structure
if isinstance(data, dict):
inner_data = data.get("data", data)
if isinstance(inner_data, dict):
# Extract metadata fields with fallbacks for camelCase/snake_case
metadata = {
"id": inner_data.get("id") or file_id,
"name": inner_data.get("name", ""),
"mimeType": inner_data.get("mimeType")
or inner_data.get("mime_type", ""),
"modifiedTime": inner_data.get("modifiedTime")
or inner_data.get("modified_time", ""),
"createdTime": inner_data.get("createdTime")
or inner_data.get("created_time", ""),
"size": inner_data.get("size", ""),
}
return metadata, None
return None, "Could not extract metadata from response"
except Exception as e:
logger.error(f"Failed to get file metadata: {e!s}")
return None, str(e)
async def get_drive_start_page_token( async def get_drive_start_page_token(
self, connected_account_id: str, entity_id: str self, connected_account_id: str, entity_id: str
) -> tuple[str | None, str | None]: ) -> tuple[str | None, str | None]:
@ -945,6 +879,178 @@ class ComposioService:
logger.error(f"Failed to list Calendar events: {e!s}") logger.error(f"Failed to list Calendar events: {e!s}")
return [], str(e) return [], str(e)
# ===== User Info Methods =====
async def get_connected_account_email(
self,
connected_account_id: str,
entity_id: str,
toolkit_id: str,
) -> str | None:
"""
Get the email address associated with a connected account.
Uses toolkit-specific API calls:
- Google Drive: List files and extract owner email
- Gmail: Get user profile
- Google Calendar: List events and extract organizer/creator email
Args:
connected_account_id: Composio connected account ID.
entity_id: The entity/user ID that owns the connected account.
toolkit_id: The toolkit identifier (googledrive, gmail, googlecalendar).
Returns:
Email address string or None if not available.
"""
try:
email = await self._extract_email_for_toolkit(
connected_account_id, entity_id, toolkit_id
)
if email:
logger.info(f"Retrieved email {email} for {toolkit_id} connector")
else:
logger.warning(f"Could not retrieve email for {toolkit_id} connector")
return email
except Exception as e:
logger.error(f"Failed to get email for {toolkit_id} connector: {e!s}")
return None
async def _extract_email_for_toolkit(
self,
connected_account_id: str,
entity_id: str,
toolkit_id: str,
) -> str | None:
"""Extract email based on toolkit type."""
if toolkit_id == "googledrive":
return await self._get_drive_owner_email(connected_account_id, entity_id)
elif toolkit_id == "gmail":
return await self._get_gmail_profile_email(connected_account_id, entity_id)
elif toolkit_id == "googlecalendar":
return await self._get_calendar_user_email(connected_account_id, entity_id)
return None
async def _get_drive_owner_email(
self, connected_account_id: str, entity_id: str
) -> str | None:
"""Get email from Google Drive file owner where me=True."""
# List files owned by the user and find one where owner.me=True
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GOOGLEDRIVE_LIST_FILES",
params={
"page_size": 10,
"fields": "files(owners)",
"q": "'me' in owners", # Only files owned by current user
},
entity_id=entity_id,
)
if not result.get("success"):
return None
data = result.get("data", {})
if not isinstance(data, dict):
return None
files = data.get("files") or data.get("data", {}).get("files", [])
for file in files:
owners = file.get("owners", [])
for owner in owners:
# Only return email if this is the current user (me=True)
if owner.get("me") and owner.get("emailAddress"):
return owner.get("emailAddress")
return None
async def _get_gmail_profile_email(
self, connected_account_id: str, entity_id: str
) -> str | None:
"""Get email from Gmail profile."""
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GMAIL_GET_PROFILE",
params={},
entity_id=entity_id,
)
if not result.get("success"):
return None
data = result.get("data", {})
if not isinstance(data, dict):
return None
return data.get("emailAddress") or data.get("data", {}).get("emailAddress")
async def _get_calendar_user_email(
self, connected_account_id: str, entity_id: str
) -> str | None:
"""Get email from Google Calendar primary calendar or event organizer/creator."""
# Method 1: Get primary calendar - the "summary" field is the user's email
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GOOGLECALENDAR_GET_CALENDAR",
params={"calendar_id": "primary"},
entity_id=entity_id,
)
if result.get("success"):
data = result.get("data", {})
if isinstance(data, dict):
# Handle nested structure: data['data']['calendar_data']['summary']
calendar_data = (
data.get("data", {}).get("calendar_data", {})
if isinstance(data.get("data"), dict)
else {}
)
summary = (
calendar_data.get("summary")
or calendar_data.get("id")
or data.get("data", {}).get("summary")
or data.get("summary")
)
if summary and "@" in summary:
return summary
# Method 2: Fallback - list events to get calendar summary (owner's email)
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GOOGLECALENDAR_EVENTS_LIST",
params={"max_results": 20},
entity_id=entity_id,
)
if not result.get("success"):
return None
data = result.get("data", {})
if not isinstance(data, dict):
return None
# The events list response contains 'summary' which is the calendar owner's email
nested_data = data.get("data", {}) if isinstance(data.get("data"), dict) else {}
summary = nested_data.get("summary") or data.get("summary")
if summary and "@" in summary:
return summary
# Method 3: Check event organizers/creators
items = nested_data.get("items", []) or data.get("items", [])
for event in items:
organizer = event.get("organizer", {})
if organizer.get("self"):
return organizer.get("email")
creator = event.get("creator", {})
if creator.get("self"):
return creator.get("email")
return None
# Singleton instance # Singleton instance
_composio_service: ComposioService | None = None _composio_service: ComposioService | None = None

View file

@ -166,8 +166,8 @@ async def _delete_connector_async(
user_id=UUID(user_id), user_id=UUID(user_id),
search_space_id=search_space_id, search_space_id=search_space_id,
type="connector_deletion", type="connector_deletion",
title=f"{connector_name} Removed", title=f"{connector_name} removed",
message=f"Connector and {total_deleted} {doc_text} have been removed from your knowledge base.", message=f"Cleanup complete. {total_deleted} {doc_text} removed.",
notification_metadata={ notification_metadata={
"connector_id": connector_id, "connector_id": connector_id,
"connector_name": connector_name, "connector_name": connector_name,

View file

@ -9,7 +9,6 @@ import {
import { useQueryClient } from "@tanstack/react-query"; import { useQueryClient } from "@tanstack/react-query";
import { useAtomValue, useSetAtom } from "jotai"; import { useAtomValue, useSetAtom } from "jotai";
import { useParams, useSearchParams } from "next/navigation"; import { useParams, useSearchParams } from "next/navigation";
import { useTranslations } from "next-intl";
import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner"; import { toast } from "sonner";
import { z } from "zod"; import { z } from "zod";
@ -39,7 +38,7 @@ import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast";
import { LinkPreviewToolUI } from "@/components/tool-ui/link-preview"; import { LinkPreviewToolUI } from "@/components/tool-ui/link-preview";
import { ScrapeWebpageToolUI } from "@/components/tool-ui/scrape-webpage"; import { ScrapeWebpageToolUI } from "@/components/tool-ui/scrape-webpage";
import { RecallMemoryToolUI, SaveMemoryToolUI } from "@/components/tool-ui/user-memory"; import { RecallMemoryToolUI, SaveMemoryToolUI } from "@/components/tool-ui/user-memory";
import { Spinner } from "@/components/ui/spinner"; import { Skeleton } from "@/components/ui/skeleton";
import { useChatSessionStateSync } from "@/hooks/use-chat-session-state"; import { useChatSessionStateSync } from "@/hooks/use-chat-session-state";
import { useMessagesElectric } from "@/hooks/use-messages-electric"; import { useMessagesElectric } from "@/hooks/use-messages-electric";
// import { WriteTodosToolUI } from "@/components/tool-ui/write-todos"; // import { WriteTodosToolUI } from "@/components/tool-ui/write-todos";
@ -53,12 +52,10 @@ import {
} from "@/lib/chat/podcast-state"; } from "@/lib/chat/podcast-state";
import { import {
appendMessage, appendMessage,
type ChatVisibility,
createThread, createThread,
getRegenerateUrl, getRegenerateUrl,
getThreadFull, getThreadFull,
getThreadMessages, getThreadMessages,
type MessageRecord,
type ThreadRecord, type ThreadRecord,
} from "@/lib/chat/thread-persistence"; } from "@/lib/chat/thread-persistence";
import { import {
@ -67,6 +64,7 @@ import {
trackChatMessageSent, trackChatMessageSent,
trackChatResponseReceived, trackChatResponseReceived,
} from "@/lib/posthog/events"; } from "@/lib/posthog/events";
import { documentsApiService } from "@/lib/apis/documents-api.service";
/** /**
* Extract thinking steps from message content * Extract thinking steps from message content
@ -137,7 +135,6 @@ interface ThinkingStepData {
} }
export default function NewChatPage() { export default function NewChatPage() {
const t = useTranslations("dashboard");
const params = useParams(); const params = useParams();
const queryClient = useQueryClient(); const queryClient = useQueryClient();
const [isInitializing, setIsInitializing] = useState(true); const [isInitializing, setIsInitializing] = useState(true);
@ -329,6 +326,33 @@ export default function NewChatPage() {
initializeThread(); initializeThread();
}, [initializeThread]); }, [initializeThread]);
// Prefetch document titles for @ mention picker
// Runs when user lands on page so data is ready when they type @
useEffect(() => {
if (!searchSpaceId) return;
const prefetchParams = {
search_space_id: searchSpaceId,
page: 0,
page_size: 20,
};
queryClient.prefetchQuery({
queryKey: ["document-titles", prefetchParams],
queryFn: () => documentsApiService.searchDocumentTitles({ queryParams: prefetchParams }),
staleTime: 60 * 1000,
});
queryClient.prefetchQuery({
queryKey: ["surfsense-docs-mention", "", false],
queryFn: () =>
documentsApiService.getSurfsenseDocs({
queryParams: { page: 0, page_size: 20 },
}),
staleTime: 3 * 60 * 1000,
});
}, [searchSpaceId, queryClient]);
// Handle scroll to comment from URL query params (e.g., from inbox item click) // Handle scroll to comment from URL query params (e.g., from inbox item click)
const searchParams = useSearchParams(); const searchParams = useSearchParams();
const targetCommentIdParam = searchParams.get("commentId"); const targetCommentIdParam = searchParams.get("commentId");
@ -367,19 +391,6 @@ export default function NewChatPage() {
setIsRunning(false); setIsRunning(false);
}, []); }, []);
// Handle visibility change from ChatShareButton
const handleVisibilityChange = useCallback(
(newVisibility: ChatVisibility) => {
setCurrentThread((prev) => (prev ? { ...prev, visibility: newVisibility } : null));
// Refetch all thread queries so sidebar reflects the change immediately
// Use predicate to match any query that starts with "threads"
queryClient.refetchQueries({
predicate: (query) => Array.isArray(query.queryKey) && query.queryKey[0] === "threads",
});
},
[queryClient]
);
// Handle new message from user // Handle new message from user
const onNew = useCallback( const onNew = useCallback(
async (message: AppendMessage) => { async (message: AppendMessage) => {
@ -1346,14 +1357,11 @@ export default function NewChatPage() {
); );
// Handle reloading/refreshing the last AI response // Handle reloading/refreshing the last AI response
const onReload = useCallback( const onReload = useCallback(async () => {
async (parentId: string | null) => { // parentId is the ID of the message to reload from (the user message)
// parentId is the ID of the message to reload from (the user message) // We call regenerate without a query to use the same query
// We call regenerate without a query to use the same query await handleRegenerate(null);
await handleRegenerate(null); }, [handleRegenerate]);
},
[handleRegenerate]
);
// Create external store runtime with attachment support // Create external store runtime with attachment support
const runtime = useExternalStoreRuntime({ const runtime = useExternalStoreRuntime({
@ -1372,9 +1380,39 @@ export default function NewChatPage() {
// Show loading state only when loading an existing thread // Show loading state only when loading an existing thread
if (isInitializing) { if (isInitializing) {
return ( return (
<div className="flex h-[calc(100vh-64px)] flex-col items-center justify-center gap-4"> <div className="flex h-[calc(100vh-64px)] flex-col bg-background px-4">
<Spinner size="lg" /> <div className="mx-auto w-full max-w-[44rem] flex flex-1 flex-col gap-6 py-8">
<div className="text-sm text-muted-foreground">{t("loading_chat")}</div> {/* User message */}
<div className="flex justify-end">
<Skeleton className="h-12 w-56 rounded-2xl" />
</div>
{/* Assistant message */}
<div className="flex flex-col gap-2">
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-[85%]" />
<Skeleton className="h-4 w-[70%]" />
</div>
{/* User message */}
<div className="flex justify-end">
<Skeleton className="h-12 w-40 rounded-2xl" />
</div>
{/* Assistant message */}
<div className="flex flex-col gap-2">
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-[90%]" />
<Skeleton className="h-4 w-[60%]" />
</div>
</div>
{/* Input bar */}
<div className="sticky bottom-0 pb-6 bg-background">
<div className="mx-auto w-full max-w-[44rem]">
<Skeleton className="h-24 w-full rounded-2xl" />
</div>
</div>
</div> </div>
); );
} }

View file

@ -1,6 +1,6 @@
"use client"; "use client";
import { keepPreviousData, useQuery, useQueryClient } from "@tanstack/react-query"; import { keepPreviousData, useQuery } from "@tanstack/react-query";
import { import {
forwardRef, forwardRef,
useCallback, useCallback,
@ -14,6 +14,7 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
import type { Document, SearchDocumentTitlesResponse } from "@/contracts/types/document.types"; import type { Document, SearchDocumentTitlesResponse } from "@/contracts/types/document.types";
import { documentsApiService } from "@/lib/apis/documents-api.service"; import { documentsApiService } from "@/lib/apis/documents-api.service";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
import { Skeleton } from "@/components/ui/skeleton";
export interface DocumentMentionPickerRef { export interface DocumentMentionPickerRef {
selectHighlighted: () => void; selectHighlighted: () => void;
@ -77,8 +78,6 @@ export const DocumentMentionPicker = forwardRef<
}, },
ref ref
) { ) {
const queryClient = useQueryClient();
// Debounced search value to minimize API calls and prevent race conditions // Debounced search value to minimize API calls and prevent race conditions
const search = externalSearch; const search = externalSearch;
const debouncedSearch = useDebounced(search, DEBOUNCE_MS); const debouncedSearch = useDebounced(search, DEBOUNCE_MS);
@ -106,32 +105,6 @@ export const DocumentMentionPicker = forwardRef<
const shouldSearch = debouncedSearch.trim().length > 0; const shouldSearch = debouncedSearch.trim().length > 0;
const isSingleCharSearch = debouncedSearch.trim().length === 1; const isSingleCharSearch = debouncedSearch.trim().length === 1;
// Prefetch initial data on mount for instant display when picker opens
useEffect(() => {
if (!searchSpaceId) return;
const prefetchParams = {
search_space_id: searchSpaceId,
page: 0,
page_size: PAGE_SIZE,
};
queryClient.prefetchQuery({
queryKey: ["document-titles", prefetchParams],
queryFn: () => documentsApiService.searchDocumentTitles({ queryParams: prefetchParams }),
staleTime: 60 * 1000,
});
queryClient.prefetchQuery({
queryKey: ["surfsense-docs-mention", "", false],
queryFn: () =>
documentsApiService.getSurfsenseDocs({
queryParams: { page: 0, page_size: PAGE_SIZE },
}),
staleTime: 3 * 60 * 1000,
});
}, [searchSpaceId, queryClient]);
// Reset pagination state when search query or search space changes. // Reset pagination state when search query or search space changes.
// Documents are not cleared to maintain visual continuity during fetches. // Documents are not cleared to maintain visual continuity during fetches.
// biome-ignore lint/correctness/useExhaustiveDependencies: Intentional reset on search/space change // biome-ignore lint/correctness/useExhaustiveDependencies: Intentional reset on search/space change
@ -439,8 +412,26 @@ export const DocumentMentionPicker = forwardRef<
onScroll={handleScroll} onScroll={handleScroll}
> >
{actualLoading ? ( {actualLoading ? (
<div className="flex items-center justify-center py-4"> <div className="py-1 px-2">
<div className="animate-spin h-5 w-5 border-2 border-primary border-t-transparent rounded-full" /> <div className="px-3 py-2">
<Skeleton className="h-[16px] w-24" />
</div>
{["a", "b", "c", "d", "e"].map((id, i) => (
<div
key={id}
className={cn(
"w-full flex items-center gap-2 px-3 py-2 text-left rounded-md",
i >= 3 && "hidden sm:flex"
)}
>
<span className="shrink-0">
<Skeleton className="h-4 w-4" />
</span>
<span className="flex-1 text-sm">
<Skeleton className="h-[20px]" style={{ width: `${60 + ((i * 7) % 30)}%` }} />
</span>
</div>
))}
</div> </div>
) : actualDocuments.length > 0 ? ( ) : actualDocuments.length > 0 ? (
<div className="py-1 px-2"> <div className="py-1 px-2">