mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
451 lines
16 KiB
Python
451 lines
16 KiB
Python
import asyncio
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from typing import Any
|
|
|
|
from google.oauth2.credentials import Credentials
|
|
from googleapiclient.discovery import build
|
|
from sqlalchemy import String, and_, cast, func, or_
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.future import select
|
|
from sqlalchemy.orm.attributes import flag_modified
|
|
|
|
from app.db import (
|
|
Document,
|
|
DocumentType,
|
|
SearchSourceConnector,
|
|
SearchSourceConnectorType,
|
|
)
|
|
from app.utils.google_credentials import build_composio_credentials
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class GmailAccount:
    """Lightweight, serializable view of a Gmail search-source connector."""

    id: int  # primary key of the SearchSourceConnector row
    name: str  # connector display name, copied as stored
    email: str  # text after the first " - " in ``name``; "" when absent

    @classmethod
    def from_connector(cls, connector: SearchSourceConnector) -> "GmailAccount":
        """Build an account view from *connector*.

        The email is derived from a display name of the form
        ``"<label> - <email>"``: everything after the first ``" - "``.
        A name without that separator (or an empty/``None`` name) yields
        an empty email.
        """
        # partition() returns ("...", "", "") when the separator is missing,
        # so the tail is already "" in that case.
        _, _, tail = (connector.name or "").partition(" - ")
        return cls(id=connector.id, name=connector.name, email=tail)

    def to_dict(self) -> dict:
        """Return the account as a plain ``{"id", "name", "email"}`` dict."""
        return dict(id=self.id, name=self.name, email=self.email)
|
|
|
|
|
|
@dataclass
class GmailMessage:
    """Gmail message details reconstructed from an indexed ``Document``."""

    message_id: str  # Gmail message id from metadata ("" if missing)
    thread_id: str  # Gmail thread id from metadata ("" if missing)
    subject: str  # metadata subject, falling back to the document title
    sender: str  # metadata sender ("" if missing)
    date: str  # metadata date value, copied as-is (not parsed here)
    connector_id: int  # Document.connector_id
    document_id: int  # Document.id

    @classmethod
    def from_document(cls, document: Document) -> "GmailMessage":
        """Create a message view from *document* and its metadata.

        Missing metadata keys default to ``""``; a missing ``subject`` falls
        back to ``document.title``. ``None`` metadata is treated as empty.
        """
        metadata = document.document_metadata or {}
        plain_fields = {
            key: metadata.get(key, "")
            for key in ("message_id", "thread_id", "sender", "date")
        }
        return cls(
            subject=metadata.get("subject", document.title),
            connector_id=document.connector_id,
            document_id=document.id,
            **plain_fields,
        )

    def to_dict(self) -> dict:
        """Serialize every field into a plain dict, in declaration order."""
        field_names = (
            "message_id",
            "thread_id",
            "subject",
            "sender",
            "date",
            "connector_id",
            "document_id",
        )
        return {name: getattr(self, name) for name in field_names}
|
|
|
|
|
|
class GmailToolMetadataService:
    """Collect Gmail account and message context for the Gmail agent tools.

    Database access goes through the injected ``AsyncSession``; blocking
    Google API client ``execute()`` calls are run in worker threads via
    ``asyncio.to_thread`` so they do not stall the event loop.
    """

    def __init__(self, db_session: AsyncSession):
        self._db_session = db_session

    async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
        """Build Google API credentials for *connector*.

        Composio Gmail connectors that carry a ``composio_connected_account_id``
        use ``build_composio_credentials``. Otherwise the credentials are
        assembled from the connector config, decrypting ``token``,
        ``refresh_token`` and ``client_secret`` when the config is flagged
        ``_token_encrypted`` and a ``SECRET_KEY`` is configured.
        """
        raw_config = connector.config or {}

        if (
            connector.connector_type
            == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
        ):
            cca_id = raw_config.get("composio_connected_account_id")
            if cca_id:
                return build_composio_credentials(cca_id)

        # Work on a copy so decrypted secrets never leak back into the ORM row.
        config_data = dict(raw_config)

        from app.config import config
        from app.utils.oauth_security import TokenEncryption

        token_encrypted = config_data.get("_token_encrypted", False)
        if token_encrypted and config.SECRET_KEY:
            token_encryption = TokenEncryption(config.SECRET_KEY)
            for secret_key in ("token", "refresh_token", "client_secret"):
                if config_data.get(secret_key):
                    config_data[secret_key] = token_encryption.decrypt_token(
                        config_data[secret_key]
                    )

        return Credentials(
            token=config_data.get("token"),
            refresh_token=config_data.get("refresh_token"),
            token_uri=config_data.get("token_uri"),
            client_id=config_data.get("client_id"),
            client_secret=config_data.get("client_secret"),
            scopes=config_data.get("scopes", []),
            expiry=self._parse_expiry(config_data.get("expiry")),
        )

    @staticmethod
    def _parse_expiry(raw: str | None) -> datetime | None:
        """Parse a stored ISO-8601 expiry into a naive UTC datetime.

        google-auth compares ``Credentials.expiry`` with a naive UTC "now",
        so an aware value would raise ``TypeError``. A ``Z`` suffix is
        stripped, and an explicit UTC offset is folded into the value before
        ``tzinfo`` is dropped.
        """
        if not raw:
            return None
        parsed = datetime.fromisoformat(raw.replace("Z", ""))
        offset = parsed.utcoffset()
        if offset is not None:
            parsed = parsed.replace(tzinfo=None) - offset
        return parsed

    async def _get_connector(self, connector_id: int) -> SearchSourceConnector | None:
        """Load a connector row by primary key, or ``None`` if it does not exist."""
        result = await self._db_session.execute(
            select(SearchSourceConnector).where(
                SearchSourceConnector.id == connector_id
            )
        )
        return result.scalar_one_or_none()

    async def _fetch_profile(self, connector: SearchSourceConnector) -> dict:
        """Call Gmail ``users().getProfile(userId='me')`` for *connector*.

        Errors from credential building or the API call propagate.
        """
        creds = await self._build_credentials(connector)
        service = build("gmail", "v1", credentials=creds)
        request = service.users().getProfile(userId="me")
        profile = await asyncio.to_thread(request.execute)
        return profile or {}

    async def _probe_profile(self, connector_id: int) -> dict | None:
        """Return the Gmail profile for *connector_id*, or ``None`` if unusable.

        ``None`` covers a missing connector and any failure while building
        credentials or calling the API (logged as a warning).
        """
        try:
            connector = await self._get_connector(connector_id)
            if not connector:
                return None
            return await self._fetch_profile(connector)
        except Exception as e:
            # NOTE(review): every failure (including transient network errors)
            # is treated as expired credentials; callers persist that flag.
            logger.warning(
                "Gmail connector %s health check failed: %s",
                connector_id,
                e,
            )
            return None

    async def _check_account_health(self, connector_id: int) -> bool:
        """Check if a Gmail connector's credentials are still valid.

        Uses a lightweight ``users().getProfile(userId='me')`` call.

        Returns True if the credentials are expired/invalid (or the connector
        cannot be found), False if healthy.
        """
        return await self._probe_profile(connector_id) is None

    async def _persist_auth_expired(self, connector_id: int) -> None:
        """Persist ``auth_expired: True`` to the connector config if not already set.

        Best-effort: failures are logged, and the session is rolled back so a
        failed commit does not leave it unusable for later queries.
        """
        try:
            db_connector = await self._get_connector(connector_id)
            if db_connector is None:
                return
            current_config = db_connector.config or {}
            if current_config.get("auth_expired"):
                return
            # Reassign a new dict so the JSON column change is detected.
            db_connector.config = {**current_config, "auth_expired": True}
            flag_modified(db_connector, "config")
            await self._db_session.commit()
            await self._db_session.refresh(db_connector)
        except Exception:
            logger.warning(
                "Failed to persist auth_expired for connector %s",
                connector_id,
                exc_info=True,
            )
            try:
                await self._db_session.rollback()
            except Exception:
                logger.debug(
                    "Rollback after auth_expired persist failure failed",
                    exc_info=True,
                )

    async def _get_accounts(
        self, search_space_id: int, user_id: str
    ) -> list[GmailAccount]:
        """List the user's Gmail connectors (native and Composio) in a search space.

        Most recently indexed connectors come first; never-indexed ones
        (NULL ``last_indexed_at``) are placed last, matching the NULLS LAST
        ordering used in ``_resolve_email``.
        """
        result = await self._db_session.execute(
            select(SearchSourceConnector)
            .filter(
                and_(
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(
                        [
                            SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                            SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
                        ]
                    ),
                )
            )
            .order_by(SearchSourceConnector.last_indexed_at.desc().nullslast())
        )
        connectors = result.scalars().all()
        return [GmailAccount.from_connector(c) for c in connectors]

    async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
        """Return the user's Gmail accounts annotated with auth status.

        Each account dict gains ``auth_expired``. For healthy accounts the
        ``email`` is replaced by the live Gmail profile address; expired
        accounts are flagged in the database.
        """
        accounts = await self._get_accounts(search_space_id, user_id)

        if not accounts:
            return {
                "accounts": [],
                "error": "No Gmail account connected",
            }

        accounts_with_status = []
        for acc in accounts:
            acc_dict = acc.to_dict()
            # A single getProfile call serves both as the health check and as
            # the source of the account's live email address.
            profile = await self._probe_profile(acc.id)
            auth_expired = profile is None
            acc_dict["auth_expired"] = auth_expired
            if auth_expired:
                await self._persist_auth_expired(acc.id)
            else:
                acc_dict["email"] = profile.get("emailAddress", "")
            accounts_with_status.append(acc_dict)

        return {"accounts": accounts_with_status}

    async def get_update_context(
        self, search_space_id: int, user_id: str, email_ref: str
    ) -> dict:
        """Return account, message and (when reachable) current draft body.

        The result includes ``draft_id`` when the indexed metadata has one, and
        ``existing_body`` when credentials are healthy and the draft body could
        be fetched. Returns ``{"error": ...}`` when the draft is not indexed.
        """
        document, connector = await self._resolve_email(
            search_space_id, user_id, email_ref
        )

        if not document or not connector:
            return {
                "error": (
                    f"Draft '{email_ref}' not found in your indexed Gmail messages. "
                    "This could mean: (1) the draft doesn't exist, "
                    "(2) it hasn't been indexed yet, "
                    "or (3) the subject is different. "
                    "Please check the exact draft subject in Gmail."
                )
            }

        connector_id = connector.id
        account = GmailAccount.from_connector(connector)
        message = GmailMessage.from_document(document)
        # Read ORM attributes before _persist_auth_expired may commit or roll
        # back, which (depending on session settings) expires loaded objects.
        meta = document.document_metadata or {}
        draft_id = meta.get("draft_id")

        acc_dict = account.to_dict()
        auth_expired = await self._check_account_health(connector_id)
        acc_dict["auth_expired"] = auth_expired
        if auth_expired:
            await self._persist_auth_expired(connector_id)

        result: dict = {
            "account": acc_dict,
            "email": message.to_dict(),
        }

        if draft_id:
            result["draft_id"] = draft_id

        if not auth_expired:
            existing_body = await self._fetch_draft_body(
                connector, message.message_id, draft_id
            )
            if existing_body is not None:
                result["existing_body"] = existing_body

        return result

    async def _fetch_draft_body(
        self,
        connector: SearchSourceConnector,
        message_id: str,
        draft_id: str | None,
    ) -> str | None:
        """Fetch the plain-text body of a Gmail draft via the API.

        Uses *draft_id* when given; otherwise resolves it by scanning
        ``drafts.list`` for *message_id*, then calls ``drafts.get``.
        Returns ``None`` on any failure so callers can degrade gracefully.
        """
        try:
            creds = await self._build_credentials(connector)
            service = build("gmail", "v1", credentials=creds)

            if not draft_id:
                draft_id = await self._find_draft_id(service, message_id)
            if not draft_id:
                return None

            request = (
                service.users().drafts().get(userId="me", id=draft_id, format="full")
            )
            draft = await asyncio.to_thread(request.execute)

            payload = draft.get("message", {}).get("payload", {})
            return self._extract_body_from_payload(payload)
        except Exception:
            logger.warning(
                "Failed to fetch draft body for message_id=%s",
                message_id,
                exc_info=True,
            )
            return None

    async def _find_draft_id(self, service: Any, message_id: str) -> str | None:
        """Resolve a draft ID from its message ID by scanning drafts.list.

        Pages through drafts 100 at a time. Returns ``None`` when *message_id*
        is empty, no draft matches, or an API call fails.
        """
        if not message_id:
            # Nothing to match against; avoid paging through every draft.
            return None
        try:
            page_token: str | None = None
            while True:
                params: dict[str, Any] = {"userId": "me", "maxResults": 100}
                if page_token:
                    params["pageToken"] = page_token
                request = service.users().drafts().list(**params)
                response = await asyncio.to_thread(request.execute)
                for draft in response.get("drafts", []):
                    if draft.get("message", {}).get("id") == message_id:
                        return draft["id"]
                page_token = response.get("nextPageToken")
                if not page_token:
                    return None
        except Exception:
            logger.warning(
                "Failed to look up draft by message_id=%s", message_id, exc_info=True
            )
            return None

    @staticmethod
    def _extract_body_from_payload(payload: dict) -> str | None:
        """Extract the plain-text (or html→text) body from a Gmail payload.

        Walks the MIME tree down to its leaf parts. All ``text/plain`` leaves
        are concatenated; only when they contain no visible text is the first
        ``text/html`` leaf that converts to non-empty markdown used, wherever
        it appears in the tree. Returns ``None`` when no text is found.
        """
        import base64

        def _leaf_parts(part: dict) -> list[dict]:
            # Multipart containers carry "parts"; anything else is a leaf.
            if "parts" in part:
                return [leaf for sub in part["parts"] for leaf in _leaf_parts(sub)]
            return [part]

        def _decode(data: str) -> str:
            # Gmail bodies are URL-safe base64, often unpadded; the lenient
            # decoder ignores surplus "=", so over-padding is safe.
            return base64.urlsafe_b64decode(data + "===").decode(
                "utf-8", errors="ignore"
            )

        plain_chunks: list[str] = []
        html_chunks: list[str] = []
        for part in _leaf_parts(payload):
            data = (part.get("body") or {}).get("data", "")
            if not data:
                continue
            mime_type = part.get("mimeType", "")
            if mime_type == "text/plain":
                plain_chunks.append(_decode(data))
            elif mime_type == "text/html":
                html_chunks.append(data)

        text_content = "".join(plain_chunks).strip()
        if not text_content and html_chunks:
            from markdownify import markdownify as md

            for data in html_chunks:
                text_content = md(_decode(data)).strip()
                if text_content:
                    break

        return text_content or None

    async def get_trash_context(
        self, search_space_id: int, user_id: str, email_ref: str
    ) -> dict:
        """Return account and message context for trashing *email_ref*.

        Returns ``{"error": ...}`` when no indexed email matches.
        """
        document, connector = await self._resolve_email(
            search_space_id, user_id, email_ref
        )

        if not document or not connector:
            return {
                "error": (
                    f"Email '{email_ref}' not found in your indexed Gmail messages. "
                    "This could mean: (1) the email doesn't exist, "
                    "(2) it hasn't been indexed yet, "
                    "or (3) the subject is different."
                )
            }

        connector_id = connector.id
        account = GmailAccount.from_connector(connector)
        message = GmailMessage.from_document(document)

        acc_dict = account.to_dict()
        auth_expired = await self._check_account_health(connector_id)
        acc_dict["auth_expired"] = auth_expired
        if auth_expired:
            await self._persist_auth_expired(connector_id)

        return {
            "account": acc_dict,
            "email": message.to_dict(),
        }

    async def _resolve_email(
        self, search_space_id: int, user_id: str, email_ref: str
    ) -> tuple[Document | None, SearchSourceConnector | None]:
        """Find the most recently updated indexed Gmail document for *email_ref*.

        Matches case-insensitively on the metadata ``subject`` or the document
        title, limited to the search space and to connectors owned by
        *user_id*. Returns ``(document, connector)`` or ``(None, None)``.
        """
        # ``as_string()`` extracts the JSON value as text (``->>`` on
        # PostgreSQL). ``cast(<json element>, String)`` keeps the JSON quoting
        # (e.g. '"Subject"') and therefore never equals the bare reference.
        metadata_subject = Document.document_metadata["subject"].as_string()
        result = await self._db_session.execute(
            select(Document, SearchSourceConnector)
            .join(
                SearchSourceConnector,
                Document.connector_id == SearchSourceConnector.id,
            )
            .filter(
                and_(
                    Document.search_space_id == search_space_id,
                    Document.document_type.in_(
                        [
                            DocumentType.GOOGLE_GMAIL_CONNECTOR,
                            DocumentType.COMPOSIO_GMAIL_CONNECTOR,
                        ]
                    ),
                    SearchSourceConnector.user_id == user_id,
                    or_(
                        func.lower(metadata_subject) == func.lower(email_ref),
                        func.lower(Document.title) == func.lower(email_ref),
                    ),
                )
            )
            .order_by(Document.updated_at.desc().nullslast())
            .limit(1)
        )
        row = result.first()
        if row is None:
            return None, None
        document, connector = row
        return document, connector
|