mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-04 21:32:39 +02:00
feat: enhance Google Drive account authentication handling
- Added checks for expired authentication in Google Drive file creation and deletion tools, returning appropriate error messages for re-authentication. - Updated the Google Drive tool metadata service to track account health and persist authentication status. - Enhanced UI components to display authentication errors and differentiate between valid and expired accounts, improving user experience during file operations.
This commit is contained in:
parent
d21593ee71
commit
75f0975674
7 changed files with 213 additions and 27 deletions
|
|
@ -87,6 +87,15 @@ def create_create_google_drive_file_tool(
|
|||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning("All Google Drive accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_drive",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -76,6 +76,18 @@ def create_delete_google_drive_file_tool(
|
|||
logger.error(f"Failed to fetch trash context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Google Drive account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_drive",
|
||||
}
|
||||
|
||||
file = context["file"]
|
||||
file_id = file["file_id"]
|
||||
document_id = file.get("document_id")
|
||||
|
|
|
|||
|
|
@ -1,15 +1,21 @@
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.connectors.google_drive.client import GoogleDriveClient
|
||||
from app.db import (
|
||||
Document,
|
||||
DocumentType,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -71,8 +77,17 @@ class GoogleDriveToolMetadataService:
|
|||
"error": "No Google Drive account connected",
|
||||
}
|
||||
|
||||
accounts_with_status = []
|
||||
for acc in accounts:
|
||||
acc_dict = acc.to_dict()
|
||||
auth_expired = await self._check_account_health(acc.id)
|
||||
acc_dict["auth_expired"] = auth_expired
|
||||
if auth_expired:
|
||||
await self._persist_auth_expired(acc.id)
|
||||
accounts_with_status.append(acc_dict)
|
||||
|
||||
return {
|
||||
"accounts": [acc.to_dict() for acc in accounts],
|
||||
"accounts": accounts_with_status,
|
||||
"supported_types": ["google_doc", "google_sheet"],
|
||||
}
|
||||
|
||||
|
|
@ -127,8 +142,14 @@ class GoogleDriveToolMetadataService:
|
|||
account = GoogleDriveAccount.from_connector(connector)
|
||||
file = GoogleDriveFile.from_document(document)
|
||||
|
||||
acc_dict = account.to_dict()
|
||||
auth_expired = await self._check_account_health(connector.id)
|
||||
acc_dict["auth_expired"] = auth_expired
|
||||
if auth_expired:
|
||||
await self._persist_auth_expired(connector.id)
|
||||
|
||||
return {
|
||||
"account": account.to_dict(),
|
||||
"account": acc_dict,
|
||||
"file": file.to_dict(),
|
||||
}
|
||||
|
||||
|
|
@ -151,3 +172,67 @@ class GoogleDriveToolMetadataService:
|
|||
)
|
||||
connectors = result.scalars().all()
|
||||
return [GoogleDriveAccount.from_connector(c) for c in connectors]
|
||||
|
||||
async def _check_account_health(self, connector_id: int) -> bool:
|
||||
"""Check if a Google Drive connector's credentials are still valid.
|
||||
|
||||
Uses a lightweight ``files.list(pageSize=1)`` call to verify access.
|
||||
|
||||
Returns True if the credentials are expired/invalid, False if healthy.
|
||||
"""
|
||||
try:
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id
|
||||
)
|
||||
)
|
||||
connector = result.scalar_one_or_none()
|
||||
if not connector:
|
||||
return True
|
||||
|
||||
pre_built_creds = None
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
pre_built_creds = build_composio_credentials(cca_id)
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=self._db_session,
|
||||
connector_id=connector_id,
|
||||
credentials=pre_built_creds,
|
||||
)
|
||||
await client.list_files(
|
||||
query="trashed = false", page_size=1, fields="files(id)"
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Google Drive connector %s health check failed: %s",
|
||||
connector_id,
|
||||
e,
|
||||
)
|
||||
return True
|
||||
|
||||
async def _persist_auth_expired(self, connector_id: int) -> None:
|
||||
"""Persist ``auth_expired: True`` to the connector config if not already set."""
|
||||
try:
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id
|
||||
)
|
||||
)
|
||||
db_connector = result.scalar_one_or_none()
|
||||
if db_connector and not db_connector.config.get("auth_expired"):
|
||||
db_connector.config = {**db_connector.config, "auth_expired": True}
|
||||
flag_modified(db_connector, "config")
|
||||
await self._db_session.commit()
|
||||
await self._db_session.refresh(db_connector)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue