SurfSense/surfsense_backend/app/services/gmail/tool_metadata_service.py
2026-05-02 21:16:03 -07:00

543 lines
20 KiB
Python

import asyncio
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from sqlalchemy import String, and_, cast, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.db import (
Document,
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
)
from app.services.composio_service import ComposioService
logger = logging.getLogger(__name__)
@dataclass
class GmailAccount:
    """Minimal view of a Gmail connector: its id, display name and email.

    The email is not stored separately; ``from_connector`` derives it from
    the connector name.
    """

    id: int
    name: str
    email: str

    @classmethod
    def from_connector(cls, connector: SearchSourceConnector) -> "GmailAccount":
        """Build an account from a connector row.

        The email is whatever follows the first ``" - "`` in the connector
        name; it is ``""`` when the name is empty or has no such separator.
        """
        # partition() splits on the first separator only, like split(sep, 1).
        _, separator, remainder = (connector.name or "").partition(" - ")
        parsed_email = remainder if separator else ""
        return cls(id=connector.id, name=connector.name, email=parsed_email)

    def to_dict(self) -> dict:
        """Return the account as a plain ``{"id", "name", "email"}`` dict."""
        return {key: getattr(self, key) for key in ("id", "name", "email")}
@dataclass
class GmailMessage:
    """Identifying fields of an indexed Gmail message.

    Populated from a ``Document`` row and its ``document_metadata`` mapping.
    """

    message_id: str
    thread_id: str
    subject: str
    sender: str
    date: str
    connector_id: int
    document_id: int

    @classmethod
    def from_document(cls, document: Document) -> "GmailMessage":
        """Build a message from a document row.

        Missing metadata keys default to ``""``, except ``subject`` which
        falls back to the document title. A ``None`` metadata mapping is
        treated as empty.
        """
        metadata = document.document_metadata or {}
        text_fields = {
            key: metadata.get(key, "")
            for key in ("message_id", "thread_id", "sender", "date")
        }
        return cls(
            subject=metadata.get("subject", document.title),
            connector_id=document.connector_id,
            document_id=document.id,
            **text_fields,
        )

    def to_dict(self) -> dict:
        """Return all fields as a plain dict, in declaration order."""
        field_order = (
            "message_id",
            "thread_id",
            "subject",
            "sender",
            "date",
            "connector_id",
            "document_id",
        )
        return {key: getattr(self, key) for key in field_order}
class GmailToolMetadataService:
def __init__(self, db_session: AsyncSession):
    """Keep the async DB session used by every query in this service."""
    self._db_session = db_session
def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
    """Return True when *connector* is of the Composio Gmail connector type."""
    composio_type = SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
    return connector.connector_type == composio_type
def _get_composio_connected_account_id(
self, connector: SearchSourceConnector
) -> str:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
raise ValueError("Composio connected_account_id not found")
return cca_id
def _unwrap_composio_data(self, data: Any) -> Any:
if isinstance(data, dict):
inner = data.get("data", data)
if isinstance(inner, dict):
return inner.get("response_data", inner)
return inner
return data
async def _execute_composio_gmail_tool(
    self,
    connector: SearchSourceConnector,
    tool_name: str,
    params: dict[str, Any],
) -> tuple[Any, str | None]:
    """Run a Composio Gmail tool for *connector*.

    Returns a ``(data, error)`` pair: on success ``data`` is the unwrapped
    tool payload and ``error`` is ``None``; on failure ``data`` is ``None``
    and ``error`` is the reported message (or a generic fallback).

    Raises:
        ValueError: if the connector has no Composio connected account id.
    """
    composio = ComposioService()
    call_kwargs = {
        "connected_account_id": self._get_composio_connected_account_id(connector),
        "tool_name": tool_name,
        "params": params,
        "entity_id": f"surfsense_{connector.user_id}",
    }
    outcome = await composio.execute_tool(**call_kwargs)
    if outcome.get("success"):
        return self._unwrap_composio_data(outcome.get("data")), None
    return None, outcome.get("error", "Unknown Composio Gmail error")
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
    """Build Google OAuth ``Credentials`` from a native Gmail connector config.

    Works on a shallow copy of ``connector.config``. When the config is
    flagged ``_token_encrypted`` and ``config.SECRET_KEY`` is set, the
    ``token``, ``refresh_token`` and ``client_secret`` values are decrypted
    before use. The stored ISO ``expiry`` string (if any) is parsed into a
    naive UTC datetime.

    Raises:
        ValueError: if *connector* is a Composio connector, which has no
            local OAuth material and must go through Composio tools.
    """
    if self._is_composio_connector(connector):
        raise ValueError(
            "Composio Gmail connectors must use Composio tool execution"
        )

    # Local imports, as in the rest of this method's dependencies.
    from datetime import timezone

    from app.config import config
    from app.utils.oauth_security import TokenEncryption

    config_data = dict(connector.config)

    if config_data.get("_token_encrypted", False) and config.SECRET_KEY:
        token_encryption = TokenEncryption(config.SECRET_KEY)
        for secret_field in ("token", "refresh_token", "client_secret"):
            if config_data.get(secret_field):
                config_data[secret_field] = token_encryption.decrypt_token(
                    config_data[secret_field]
                )

    expiry = None
    raw_expiry = config_data.get("expiry", "")
    if raw_expiry:
        # A trailing "Z" is dropped so the value parses as a naive datetime.
        expiry = datetime.fromisoformat(raw_expiry.replace("Z", ""))
        if expiry.tzinfo is not None:
            # An explicit offset (e.g. "+00:00") yields an aware datetime.
            # google-auth compares expiry against a naive UTC "now", so an
            # aware value would fail there; normalise to naive UTC instead.
            expiry = expiry.astimezone(timezone.utc).replace(tzinfo=None)

    return Credentials(
        token=config_data.get("token"),
        refresh_token=config_data.get("refresh_token"),
        token_uri=config_data.get("token_uri"),
        client_id=config_data.get("client_id"),
        client_secret=config_data.get("client_secret"),
        scopes=config_data.get("scopes", []),
        expiry=expiry,
    )
async def _check_account_health(self, connector_id: int) -> bool:
    """Probe whether a Gmail connector can still reach the Gmail API.

    Native connectors issue ``users().getProfile(userId='me')``; Composio
    connectors run the ``GMAIL_GET_PROFILE`` tool.

    Returns True when the account should be treated as expired/invalid:
    the connector row is missing, the Composio tool reported an error, or
    any exception was raised along the way. Returns False when the probe
    succeeded.
    """
    try:
        lookup = select(SearchSourceConnector).where(
            SearchSourceConnector.id == connector_id
        )
        connector = (await self._db_session.execute(lookup)).scalar_one_or_none()
        if not connector:
            return True

        if self._is_composio_connector(connector):
            _, probe_error = await self._execute_composio_gmail_tool(
                connector, "GMAIL_GET_PROFILE", {"user_id": "me"}
            )
            return bool(probe_error)

        credentials = await self._build_credentials(connector)
        gmail = build("gmail", "v1", credentials=credentials)

        def _probe_profile() -> Any:
            # Blocking client call; runs in the default executor below.
            return gmail.users().getProfile(userId="me").execute()

        await asyncio.get_event_loop().run_in_executor(None, _probe_profile)
    except Exception as exc:
        logger.warning(
            "Gmail connector %s health check failed: %s",
            connector_id,
            exc,
        )
        return True
    return False
async def _persist_auth_expired(self, connector_id: int) -> None:
    """Store ``auth_expired: True`` in the connector config, once.

    Does nothing when the connector row is missing or the flag is already
    truthy. Best-effort: any exception is logged as a warning and
    swallowed.
    """
    try:
        lookup = select(SearchSourceConnector).where(
            SearchSourceConnector.id == connector_id
        )
        row = (await self._db_session.execute(lookup)).scalar_one_or_none()
        if not row or row.config.get("auth_expired"):
            return

        # Assign a new dict and flag the attribute as modified so the
        # change to the config column is picked up on commit.
        row.config = {**row.config, "auth_expired": True}
        flag_modified(row, "config")
        await self._db_session.commit()
        await self._db_session.refresh(row)
    except Exception:
        logger.warning(
            "Failed to persist auth_expired for connector %s",
            connector_id,
            exc_info=True,
        )
async def _get_accounts(
    self, search_space_id: int, user_id: str
) -> list[GmailAccount]:
    """List the user's Gmail connectors in a search space as accounts.

    Includes both native Google and Composio Gmail connector types,
    ordered by ``last_indexed_at`` descending.
    """
    gmail_types = [
        SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
        SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
    ]
    criteria = and_(
        SearchSourceConnector.search_space_id == search_space_id,
        SearchSourceConnector.user_id == user_id,
        SearchSourceConnector.connector_type.in_(gmail_types),
    )
    # NOTE(review): if last_indexed_at can be NULL, a plain DESC may put
    # never-indexed connectors first on some databases; _resolve_email
    # uses nullslast() for its ordering. Confirm the intended order.
    statement = (
        select(SearchSourceConnector)
        .filter(criteria)
        .order_by(SearchSourceConnector.last_indexed_at.desc())
    )
    rows = (await self._db_session.execute(statement)).scalars().all()
    return [GmailAccount.from_connector(row) for row in rows]
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
    """Describe the Gmail accounts available for creating an email.

    Returns ``{"accounts": [...]}`` where each entry holds ``id``,
    ``name``, ``email`` and ``auth_expired``. With no connected account
    the result is an empty list plus an ``"error"`` message.

    For healthy accounts the ``email`` parsed from the connector name is
    replaced by the ``emailAddress`` of the live Gmail profile; a failure
    of that lookup is logged and the parsed value is kept.
    """
    accounts = await self._get_accounts(search_space_id, user_id)
    if not accounts:
        return {
            "accounts": [],
            "error": "No Gmail account connected",
        }

    entries: list[dict] = []
    for account in accounts:
        entry = account.to_dict()
        # Added up front; the same dict object is filled in below.
        entries.append(entry)

        expired = await self._check_account_health(account.id)
        entry["auth_expired"] = expired
        if expired:
            await self._persist_auth_expired(account.id)
            continue

        # NOTE(review): the health check above already fetched the
        # profile once; this is a second fetch for the same account.
        try:
            lookup = select(SearchSourceConnector).where(
                SearchSourceConnector.id == account.id
            )
            connector = (
                await self._db_session.execute(lookup)
            ).scalar_one_or_none()
            if not connector:
                continue

            if self._is_composio_connector(connector):
                profile, error = await self._execute_composio_gmail_tool(
                    connector, "GMAIL_GET_PROFILE", {"user_id": "me"}
                )
                if error:
                    raise RuntimeError(error)
            else:
                creds = await self._build_credentials(connector)
                service = build("gmail", "v1", credentials=creds)
                profile = await asyncio.get_event_loop().run_in_executor(
                    None,
                    # Bind the current service as a default argument so
                    # the lambda does not see a later loop iteration's.
                    lambda service=service: (
                        service.users().getProfile(userId="me").execute()
                    ),
                )
            entry["email"] = profile.get("emailAddress", "")
        except Exception:
            logger.warning(
                "Failed to fetch email for Gmail connector %s",
                account.id,
                exc_info=True,
            )

    return {"accounts": entries}
async def get_update_context(
    self, search_space_id: int, user_id: str, email_ref: str
) -> dict:
    """Gather what is needed to update the draft referred to by *email_ref*.

    Returns ``{"error": ...}`` when no indexed message matches. Otherwise
    returns ``account`` (with ``auth_expired``) and ``email`` dicts, plus
    ``draft_id`` when the document metadata has one and ``existing_body``
    when the account is healthy and the draft body could be fetched.
    """
    document, connector = await self._resolve_email(
        search_space_id, user_id, email_ref
    )
    if not document or not connector:
        return {
            "error": (
                f"Draft '{email_ref}' not found in your indexed Gmail messages. "
                "This could mean: (1) the draft doesn't exist, "
                "(2) it hasn't been indexed yet, "
                "or (3) the subject is different. "
                "Please check the exact draft subject in Gmail."
            )
        }

    account_info = GmailAccount.from_connector(connector).to_dict()
    message = GmailMessage.from_document(document)

    expired = await self._check_account_health(connector.id)
    account_info["auth_expired"] = expired
    if expired:
        await self._persist_auth_expired(connector.id)

    context: dict = {
        "account": account_info,
        "email": message.to_dict(),
    }

    metadata = document.document_metadata or {}
    stored_draft_id = metadata.get("draft_id")
    if stored_draft_id:
        context["draft_id"] = stored_draft_id

    # Only call the Gmail API when the credentials are believed valid.
    if not expired:
        body = await self._fetch_draft_body(
            connector, message.message_id, stored_draft_id
        )
        if body is not None:
            context["existing_body"] = body

    return context
async def _fetch_draft_body(
    self,
    connector: SearchSourceConnector,
    message_id: str,
    draft_id: str | None,
) -> str | None:
    """Fetch the text body of a Gmail draft.

    Uses *draft_id* when given; otherwise scans the account's drafts for
    the one whose message id equals *message_id*. Composio connectors go
    through the ``GMAIL_GET_DRAFT`` tool, native ones through
    ``drafts().get``.

    Returns ``None`` when the draft cannot be resolved, has no usable
    body, or on any failure (logged as a warning) so callers can degrade
    gracefully.
    """
    try:
        if self._is_composio_connector(connector):
            resolved_id = draft_id or await self._find_composio_draft_id(
                connector, message_id
            )
            if not resolved_id:
                return None
            draft, error = await self._execute_composio_gmail_tool(
                connector,
                "GMAIL_GET_DRAFT",
                {"user_id": "me", "draft_id": resolved_id, "format": "full"},
            )
            if error or not isinstance(draft, dict):
                return None
        else:
            creds = await self._build_credentials(connector)
            service = build("gmail", "v1", credentials=creds)
            resolved_id = draft_id or await self._find_draft_id(
                service, message_id
            )
            if not resolved_id:
                return None

            def _get_draft() -> Any:
                # Blocking client call; runs in the default executor.
                return (
                    service.users()
                    .drafts()
                    .get(userId="me", id=resolved_id, format="full")
                    .execute()
                )

            draft = await asyncio.get_event_loop().run_in_executor(
                None, _get_draft
            )

        # Both branches end with a draft dict shaped {"message": {"payload": ...}}.
        payload = draft.get("message", {}).get("payload", {})
        return self._extract_body_from_payload(payload)
    except Exception:
        logger.warning(
            "Failed to fetch draft body for message_id=%s",
            message_id,
            exc_info=True,
        )
        return None
async def _find_composio_draft_id(
    self, connector: SearchSourceConnector, message_id: str
) -> str | None:
    """Find a draft's id by its message id via the ``GMAIL_LIST_DRAFTS`` tool.

    Pages through the drafts 100 at a time. Returns the ``id`` of the
    first draft whose ``message.id`` equals *message_id*, or ``None`` when
    the tool reports an error, returns a non-dict payload, or the pages
    run out without a match.
    """
    cursor = ""
    while True:
        request: dict[str, Any] = {
            "user_id": "me",
            "max_results": 100,
            "verbose": False,
        }
        if cursor:
            request["page_token"] = cursor

        page, error = await self._execute_composio_gmail_tool(
            connector, "GMAIL_LIST_DRAFTS", request
        )
        if error or not isinstance(page, dict):
            return None

        hit = next(
            (
                candidate
                for candidate in page.get("drafts", [])
                if candidate.get("message", {}).get("id") == message_id
            ),
            None,
        )
        if hit is not None:
            return hit.get("id")

        # The token is looked up under both camelCase and snake_case keys.
        cursor = page.get("nextPageToken") or page.get("next_page_token") or ""
        if not cursor:
            return None
async def _find_draft_id(self, service: Any, message_id: str) -> str | None:
    """Find a draft's id by its message id by paging through ``drafts().list``.

    Returns ``None`` when no draft matches or when any exception occurs
    (logged as a warning).
    """
    try:
        next_token = None
        while True:
            list_kwargs: dict[str, Any] = {"userId": "me", "maxResults": 100}
            if next_token:
                list_kwargs["pageToken"] = next_token

            def _list_page(call_kwargs: dict[str, Any] = list_kwargs) -> Any:
                # Blocking client call; runs in the default executor.
                return service.users().drafts().list(**call_kwargs).execute()

            page = await asyncio.get_event_loop().run_in_executor(
                None, _list_page
            )
            for candidate in page.get("drafts", []):
                if candidate.get("message", {}).get("id") == message_id:
                    return candidate["id"]

            next_token = page.get("nextPageToken")
            if not next_token:
                return None
    except Exception:
        logger.warning(
            "Failed to look up draft by message_id=%s", message_id, exc_info=True
        )
        return None
@staticmethod
def _extract_body_from_payload(payload: dict) -> str | None:
"""Extract the plain-text (or html→text) body from a Gmail payload."""
import base64
def _get_parts(p: dict) -> list[dict]:
if "parts" in p:
parts: list[dict] = []
for sub in p["parts"]:
parts.extend(_get_parts(sub))
return parts
return [p]
parts = _get_parts(payload)
text_content = ""
for part in parts:
mime_type = part.get("mimeType", "")
data = part.get("body", {}).get("data", "")
if mime_type == "text/plain" and data:
text_content += base64.urlsafe_b64decode(data + "===").decode(
"utf-8", errors="ignore"
)
elif mime_type == "text/html" and data and not text_content:
from markdownify import markdownify as md
raw_html = base64.urlsafe_b64decode(data + "===").decode(
"utf-8", errors="ignore"
)
text_content = md(raw_html).strip()
return text_content.strip() if text_content.strip() else None
async def get_trash_context(
    self, search_space_id: int, user_id: str, email_ref: str
) -> dict:
    """Gather what is needed to trash the email referred to by *email_ref*.

    Returns ``{"error": ...}`` when no indexed message matches; otherwise
    ``account`` (with ``auth_expired``) and ``email`` dicts. An expired
    account is also flagged in the stored connector config.
    """
    document, connector = await self._resolve_email(
        search_space_id, user_id, email_ref
    )
    if not document or not connector:
        return {
            "error": (
                f"Email '{email_ref}' not found in your indexed Gmail messages. "
                "This could mean: (1) the email doesn't exist, "
                "(2) it hasn't been indexed yet, "
                "or (3) the subject is different."
            )
        }

    account_info = GmailAccount.from_connector(connector).to_dict()
    message = GmailMessage.from_document(document)

    expired = await self._check_account_health(connector.id)
    account_info["auth_expired"] = expired
    if expired:
        await self._persist_auth_expired(connector.id)

    return {
        "account": account_info,
        "email": message.to_dict(),
    }
async def _resolve_email(
    self, search_space_id: int, user_id: str, email_ref: str
) -> tuple[Document | None, SearchSourceConnector | None]:
    """Find the most recently updated Gmail document matching *email_ref*.

    Matches case-insensitively on either the ``subject`` entry of the
    document metadata or the document title. Only Gmail documents in the
    given search space whose connector belongs to *user_id* are
    considered.

    Returns ``(document, connector)``, or ``(None, None)`` when nothing
    matches.
    """
    gmail_document_types = [
        DocumentType.GOOGLE_GMAIL_CONNECTOR,
        DocumentType.COMPOSIO_GMAIL_CONNECTOR,
    ]
    # NOTE(review): depending on the column's JSON type and dialect, casting
    # a JSON element to String may keep JSON quoting around the value;
    # confirm the subject comparison matches as intended. The title
    # comparison acts as the other way in.
    subject_matches = func.lower(
        cast(Document.document_metadata["subject"], String)
    ) == func.lower(email_ref)
    title_matches = func.lower(Document.title) == func.lower(email_ref)

    statement = (
        select(Document, SearchSourceConnector)
        .join(
            SearchSourceConnector,
            Document.connector_id == SearchSourceConnector.id,
        )
        .filter(
            and_(
                Document.search_space_id == search_space_id,
                Document.document_type.in_(gmail_document_types),
                SearchSourceConnector.user_id == user_id,
                or_(subject_matches, title_matches),
            )
        )
        .order_by(Document.updated_at.desc().nullslast())
        .limit(1)
    )
    match = (await self._db_session.execute(statement)).first()
    if not match:
        return None, None
    return match[0], match[1]