mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-29 10:56:24 +02:00
- Added a guideline to ensure that each tool (Gmail, Google Calendar, Google Drive, Linear, Notion) is called only once per user request.
- Updated documentation to clarify that the system automatically selects the most relevant match when multiple items share the same title or subject, improving the user experience and preventing redundant calls.
445 lines
16 KiB
Python
445 lines
16 KiB
Python
import asyncio
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from typing import Any
|
|
|
|
from google.oauth2.credentials import Credentials
|
|
from googleapiclient.discovery import build
|
|
from sqlalchemy import String, and_, cast, func, or_
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.future import select
|
|
from sqlalchemy.orm.attributes import flag_modified
|
|
|
|
from app.db import (
|
|
Document,
|
|
DocumentType,
|
|
SearchSourceConnector,
|
|
SearchSourceConnectorType,
|
|
)
|
|
from app.utils.google_credentials import build_composio_credentials
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class GmailAccount:
    """Minimal view of a Gmail ``SearchSourceConnector`` row for the tool layer."""

    id: int  # primary key of the connector row
    name: str  # connector display name, copied verbatim from the row
    email: str  # text after the first " - " in ``name``, or "" when absent

    @classmethod
    def from_connector(cls, connector: SearchSourceConnector) -> "GmailAccount":
        """Create an account view from *connector*.

        The email is whatever follows the first ``" - "`` in the connector
        name. An empty/missing name, or a name without that separator,
        yields an empty email.
        """
        display_name = connector.name
        _head, separator, remainder = (display_name or "").partition(" - ")
        parsed_email = remainder if separator else ""
        return cls(id=connector.id, name=display_name, email=parsed_email)

    def to_dict(self) -> dict:
        """Serialise the account to a plain ``dict`` with keys id/name/email."""
        return dict(id=self.id, name=self.name, email=self.email)
|
|
|
|
|
|
@dataclass
class GmailMessage:
    """Snapshot of an indexed Gmail message, built from a ``Document`` row."""

    message_id: str  # metadata["message_id"], or ""
    thread_id: str  # metadata["thread_id"], or ""
    subject: str  # metadata["subject"], falling back to the document title
    sender: str  # metadata["sender"], or ""
    date: str  # metadata["date"], or ""
    connector_id: int  # connector that indexed the document
    document_id: int  # primary key of the document row

    @classmethod
    def from_document(cls, document: Document) -> "GmailMessage":
        """Create a message view from *document* and its ``document_metadata``.

        A missing metadata dict is treated as empty. Absent keys default to
        ``""``, except ``subject``, which defaults to ``document.title``.
        """
        metadata = document.document_metadata or {}

        def _pick(key: str, fallback: str = "") -> str:
            return metadata.get(key, fallback)

        message_id = _pick("message_id")
        thread_id = _pick("thread_id")
        subject = _pick("subject", document.title)
        sender = _pick("sender")
        sent_at = _pick("date")
        return cls(
            message_id=message_id,
            thread_id=thread_id,
            subject=subject,
            sender=sender,
            date=sent_at,
            connector_id=document.connector_id,
            document_id=document.id,
        )

    def to_dict(self) -> dict:
        """Serialise every field, in declaration order, to a plain ``dict``."""
        field_names = (
            "message_id",
            "thread_id",
            "subject",
            "sender",
            "date",
            "connector_id",
            "document_id",
        )
        return {name: getattr(self, name) for name in field_names}
|
|
|
|
|
|
class GmailToolMetadataService:
    """Build the context dicts consumed by the Gmail agent tools.

    Each public ``get_*_context`` coroutine returns a plain ``dict``: either
    the requested context or a dict carrying an ``"error"`` message. Account
    health is probed live with a Gmail ``users().getProfile`` call. When a
    probe fails, ``auth_expired: True`` is persisted on the connector config.

    All database access goes through the single ``AsyncSession`` given to the
    constructor. ``AsyncSession`` is not safe for concurrent use, so the
    coroutines of one instance should be awaited one at a time.
    """

    def __init__(self, db_session: AsyncSession):
        self._db_session = db_session

    # ------------------------------------------------------------------
    # Low-level helpers
    # ------------------------------------------------------------------

    @staticmethod
    async def _run_blocking(func: Any) -> Any:
        """Run the synchronous callable *func* in a worker thread and return its result.

        The googleapiclient ``.execute()`` calls block on HTTP, so they must
        not run on the event loop thread.
        """
        return await asyncio.to_thread(func)

    async def _get_connector(
        self, connector_id: int
    ) -> SearchSourceConnector | None:
        """Load a connector row by primary key; ``None`` if it does not exist."""
        result = await self._db_session.execute(
            select(SearchSourceConnector).where(
                SearchSourceConnector.id == connector_id
            )
        )
        return result.scalar_one_or_none()

    async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
        """Build Google OAuth credentials for *connector*.

        Composio connectors that carry a ``composio_connected_account_id``
        delegate to ``build_composio_credentials``. Otherwise the OAuth fields
        stored in ``connector.config`` are used. ``token``, ``refresh_token``
        and ``client_secret`` are decrypted when the config is flagged
        ``_token_encrypted`` and ``config.SECRET_KEY`` is set.
        """
        # Work on a copy so decryption never mutates the ORM-tracked config.
        config_data = dict(connector.config or {})

        if (
            connector.connector_type
            == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
        ):
            cca_id = config_data.get("composio_connected_account_id")
            if cca_id:
                return build_composio_credentials(cca_id)

        # Function-level imports kept from the original module.
        from app.config import config
        from app.utils.oauth_security import TokenEncryption

        if config_data.get("_token_encrypted", False) and config.SECRET_KEY:
            token_encryption = TokenEncryption(config.SECRET_KEY)
            for secret_key in ("token", "refresh_token", "client_secret"):
                if config_data.get(secret_key):
                    config_data[secret_key] = token_encryption.decrypt_token(
                        config_data[secret_key]
                    )

        return Credentials(
            token=config_data.get("token"),
            refresh_token=config_data.get("refresh_token"),
            token_uri=config_data.get("token_uri"),
            client_id=config_data.get("client_id"),
            client_secret=config_data.get("client_secret"),
            scopes=config_data.get("scopes", []),
            expiry=self._parse_expiry(config_data.get("expiry")),
        )

    @staticmethod
    def _parse_expiry(raw: str | None) -> datetime | None:
        """Parse a stored expiry timestamp into a naive UTC ``datetime``.

        A ``Z`` suffix is stripped because ``datetime.fromisoformat`` rejects
        it before Python 3.11. A timestamp with an explicit UTC offset is
        converted to UTC and made naive. google-auth compares ``expiry``
        against a naive UTC "now", and mixing naive with aware datetimes
        raises ``TypeError``.

        Returns ``None`` for a missing/empty value.
        """
        if not raw:
            return None
        from datetime import timezone

        parsed = datetime.fromisoformat(raw.replace("Z", ""))
        if parsed.tzinfo is not None:
            parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
        return parsed

    async def _fetch_profile(self, connector: SearchSourceConnector) -> dict:
        """Return the ``users().getProfile(userId='me')`` response for *connector*.

        Propagates any error from building credentials or from the API call.
        """
        creds = await self._build_credentials(connector)
        service = build("gmail", "v1", credentials=creds)
        return await self._run_blocking(
            lambda: service.users().getProfile(userId="me").execute()
        )

    # ------------------------------------------------------------------
    # Account health
    # ------------------------------------------------------------------

    async def _probe_account(self, connector_id: int) -> tuple[bool, dict | None]:
        """Probe a connector once and return ``(auth_expired, profile)``.

        ``profile`` is the Gmail ``getProfile`` response when the probe
        succeeds, and ``None`` otherwise. A connector id that no longer
        exists counts as expired. Any error is logged and reported as
        expired.
        """
        try:
            connector = await self._get_connector(connector_id)
            if not connector:
                return True, None
            profile = await self._fetch_profile(connector)
        except Exception as e:
            logger.warning(
                "Gmail connector %s health check failed: %s",
                connector_id,
                e,
            )
            return True, None
        return False, profile

    async def _check_account_health(self, connector_id: int) -> bool:
        """Check whether a Gmail connector's credentials still work.

        Uses a lightweight ``users().getProfile(userId='me')`` call.

        Note the inverted sense of the result: returns ``True`` when the
        credentials are expired/invalid (or the connector is missing), and
        ``False`` when the account is healthy.
        """
        auth_expired, _profile = await self._probe_account(connector_id)
        return auth_expired

    async def _persist_auth_expired(self, connector_id: int) -> None:
        """Persist ``auth_expired: True`` on the connector config if not already set.

        Best effort: failures are logged, never raised. After a failure the
        session is rolled back. A failed flush/commit leaves the session's
        transaction unusable, and every later query on this shared session
        (e.g. the next account's health probe) would then fail too.
        """
        try:
            db_connector = await self._get_connector(connector_id)
            current_config = (db_connector.config or {}) if db_connector else {}
            if db_connector and not current_config.get("auth_expired"):
                db_connector.config = {**current_config, "auth_expired": True}
                # Reassigning should suffice, but flag explicitly so the JSON
                # column is always written.
                flag_modified(db_connector, "config")
                await self._db_session.commit()
                await self._db_session.refresh(db_connector)
        except Exception:
            logger.warning(
                "Failed to persist auth_expired for connector %s",
                connector_id,
                exc_info=True,
            )
            try:
                await self._db_session.rollback()
            except Exception:
                logger.debug(
                    "Rollback after failed auth_expired persist also failed "
                    "for connector %s",
                    connector_id,
                    exc_info=True,
                )

    async def _account_with_status(self, connector: SearchSourceConnector) -> dict:
        """Serialise *connector* as an account dict with a live ``auth_expired`` flag.

        When the probe fails, the flag is also persisted on the connector.
        """
        acc_dict = GmailAccount.from_connector(connector).to_dict()
        auth_expired = await self._check_account_health(connector.id)
        acc_dict["auth_expired"] = auth_expired
        if auth_expired:
            await self._persist_auth_expired(connector.id)
        return acc_dict

    # ------------------------------------------------------------------
    # Accounts
    # ------------------------------------------------------------------

    async def _get_accounts(
        self, search_space_id: int, user_id: str
    ) -> list[GmailAccount]:
        """Return the user's Gmail connectors (Google or Composio) in a search space.

        Rows are ordered by ``last_indexed_at`` descending.
        """
        result = await self._db_session.execute(
            select(SearchSourceConnector)
            .filter(
                and_(
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_([
                        SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                        SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
                    ]),
                )
            )
            .order_by(SearchSourceConnector.last_indexed_at.desc())
        )
        return [GmailAccount.from_connector(c) for c in result.scalars().all()]

    async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
        """Return every Gmail account of the user with its live health status.

        Each account dict gains ``auth_expired``. For healthy accounts,
        ``email`` is replaced by the ``emailAddress`` reported by Gmail (kept
        as parsed from the connector name if Gmail returns none). A single
        ``getProfile`` call per account serves both the health check and the
        email lookup.

        Returns ``{"accounts": [], "error": ...}`` when no Gmail connector
        exists.
        """
        accounts = await self._get_accounts(search_space_id, user_id)
        if not accounts:
            return {
                "accounts": [],
                "error": "No Gmail account connected",
            }

        accounts_with_status = []
        for acc in accounts:
            acc_dict = acc.to_dict()
            auth_expired, profile = await self._probe_account(acc.id)
            acc_dict["auth_expired"] = auth_expired
            if auth_expired:
                await self._persist_auth_expired(acc.id)
            elif profile is not None:
                acc_dict["email"] = profile.get("emailAddress") or acc_dict["email"]
            accounts_with_status.append(acc_dict)

        return {"accounts": accounts_with_status}

    # ------------------------------------------------------------------
    # Draft update
    # ------------------------------------------------------------------

    async def get_update_context(
        self, search_space_id: int, user_id: str, email_ref: str
    ) -> dict:
        """Return the context needed to update the draft matching *email_ref*.

        The result holds ``account`` (with ``auth_expired``) and ``email``.
        It also holds ``draft_id`` when the indexed metadata has one, and
        ``existing_body`` when the account is healthy and the draft body
        could be fetched. Returns ``{"error": ...}`` when no indexed message
        matches.
        """
        document, connector = await self._resolve_email(
            search_space_id, user_id, email_ref
        )

        if not document or not connector:
            return {
                "error": (
                    f"Draft '{email_ref}' not found in your indexed Gmail messages. "
                    "This could mean: (1) the draft doesn't exist, "
                    "(2) it hasn't been indexed yet, "
                    "or (3) the subject is different. "
                    "Please check the exact draft subject in Gmail."
                )
            }

        # Read document state before the health check, which may commit.
        message = GmailMessage.from_document(document)
        meta = document.document_metadata or {}
        draft_id = meta.get("draft_id")

        acc_dict = await self._account_with_status(connector)
        auth_expired = acc_dict["auth_expired"]

        result: dict = {
            "account": acc_dict,
            "email": message.to_dict(),
        }
        if draft_id:
            result["draft_id"] = draft_id

        if not auth_expired:
            existing_body = await self._fetch_draft_body(
                connector, message.message_id, draft_id
            )
            if existing_body is not None:
                result["existing_body"] = existing_body

        return result

    async def _fetch_draft_body(
        self,
        connector: SearchSourceConnector,
        message_id: str,
        draft_id: str | None,
    ) -> str | None:
        """Fetch the body text of a Gmail draft via the API.

        If *draft_id* is not known, it is first resolved by scanning
        ``drafts.list`` for a draft whose message id is *message_id*. The
        draft is then loaded with ``drafts.get(format="full")`` and its body
        extracted. Returns ``None`` on any failure so callers can degrade
        gracefully.
        """
        try:
            creds = await self._build_credentials(connector)
            service = build("gmail", "v1", credentials=creds)

            if not draft_id:
                draft_id = await self._find_draft_id(service, message_id)
            if not draft_id:
                return None

            resolved_id = draft_id
            draft = await self._run_blocking(
                lambda: service.users()
                .drafts()
                .get(userId="me", id=resolved_id, format="full")
                .execute()
            )

            payload = draft.get("message", {}).get("payload", {})
            return self._extract_body_from_payload(payload)
        except Exception:
            logger.warning(
                "Failed to fetch draft body for message_id=%s",
                message_id,
                exc_info=True,
            )
            return None

    async def _find_draft_id(self, service: Any, message_id: str) -> str | None:
        """Resolve a draft id from its message id by paging through ``drafts.list``.

        Returns ``None`` when no draft matches or the lookup fails (logged).
        """
        try:
            page_token = None
            while True:
                kwargs: dict[str, Any] = {"userId": "me", "maxResults": 100}
                if page_token:
                    kwargs["pageToken"] = page_token
                response = await self._run_blocking(
                    lambda: service.users().drafts().list(**kwargs).execute()
                )
                for draft in response.get("drafts", []):
                    if draft.get("message", {}).get("id") == message_id:
                        return draft["id"]
                page_token = response.get("nextPageToken")
                if not page_token:
                    return None
        except Exception:
            logger.warning(
                "Failed to look up draft by message_id=%s", message_id, exc_info=True
            )
            return None

    @staticmethod
    def _extract_body_from_payload(payload: dict) -> str | None:
        """Extract the body text from a Gmail message payload.

        All ``text/plain`` leaf parts are decoded and concatenated. Only when
        that yields no non-whitespace text is the first ``text/html`` part
        with non-empty Markdown used, converted via ``markdownify``. Returns
        ``None`` when nothing usable is found.
        """
        import base64

        def _leaf_parts(part: dict) -> list[dict]:
            # Multipart nodes keep their children under "parts"; flatten depth-first.
            if "parts" in part:
                leaves: list[dict] = []
                for sub in part["parts"]:
                    leaves.extend(_leaf_parts(sub))
                return leaves
            return [part]

        def _decode(data: str) -> str:
            # Gmail body data is base64url and may be unpadded; the "===" suffix
            # supplies padding for the decoder.
            return base64.urlsafe_b64decode(data + "===").decode(
                "utf-8", errors="ignore"
            )

        plain_chunks: list[str] = []
        html_chunks: list[str] = []
        for part in _leaf_parts(payload):
            data = part.get("body", {}).get("data", "")
            if not data:
                continue
            mime_type = part.get("mimeType", "")
            if mime_type == "text/plain":
                plain_chunks.append(_decode(data))
            elif mime_type == "text/html":
                html_chunks.append(data)

        text_content = "".join(plain_chunks).strip()
        if not text_content and html_chunks:
            from markdownify import markdownify as md

            for data in html_chunks:
                text_content = md(_decode(data)).strip()
                if text_content:
                    break

        return text_content or None

    # ------------------------------------------------------------------
    # Trash
    # ------------------------------------------------------------------

    async def get_trash_context(
        self, search_space_id: int, user_id: str, email_ref: str
    ) -> dict:
        """Return the account and message context for trashing *email_ref*.

        Returns ``{"error": ...}`` when no indexed message matches.
        """
        document, connector = await self._resolve_email(
            search_space_id, user_id, email_ref
        )

        if not document or not connector:
            return {
                "error": (
                    f"Email '{email_ref}' not found in your indexed Gmail messages. "
                    "This could mean: (1) the email doesn't exist, "
                    "(2) it hasn't been indexed yet, "
                    "or (3) the subject is different."
                )
            }

        # Build the message view before the health check, which may commit.
        message = GmailMessage.from_document(document)
        acc_dict = await self._account_with_status(connector)

        return {
            "account": acc_dict,
            "email": message.to_dict(),
        }

    async def _resolve_email(
        self, search_space_id: int, user_id: str, email_ref: str
    ) -> tuple[Document | None, SearchSourceConnector | None]:
        """Find the user's indexed Gmail document whose subject or title matches *email_ref*.

        Matching is case-insensitive on either ``document_metadata["subject"]``
        or ``Document.title``. When several match, the most recently updated
        one wins (NULL ``updated_at`` last). Returns ``(None, None)`` when
        nothing matches.
        """
        # ``as_string()`` extracts the JSON value as text (``->>`` on
        # PostgreSQL). A plain CAST of the JSON element yields its JSON text,
        # quotes included, which can never equal the user's plain reference.
        subject_text = Document.document_metadata["subject"].as_string()
        needle = func.lower(email_ref)

        result = await self._db_session.execute(
            select(Document, SearchSourceConnector)
            .join(
                SearchSourceConnector,
                Document.connector_id == SearchSourceConnector.id,
            )
            .filter(
                and_(
                    Document.search_space_id == search_space_id,
                    Document.document_type.in_([
                        DocumentType.GOOGLE_GMAIL_CONNECTOR,
                        DocumentType.COMPOSIO_GMAIL_CONNECTOR,
                    ]),
                    SearchSourceConnector.user_id == user_id,
                    or_(
                        func.lower(subject_text) == needle,
                        func.lower(Document.title) == needle,
                    ),
                )
            )
            .order_by(Document.updated_at.desc().nullslast())
            .limit(1)
        )
        row = result.first()
        if row:
            return row[0], row[1]
        return None, None
|