mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-04 21:32:39 +02:00
fix: fixed composio issues
This commit is contained in:
parent
47b2994ec7
commit
cea8618aed
25 changed files with 1756 additions and 461 deletions
|
|
@ -0,0 +1,41 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.db import SearchSourceConnector
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
|
|
||||||
|
def split_recipients(value: str | None) -> list[str]:
|
||||||
|
if not value:
|
||||||
|
return []
|
||||||
|
return [recipient.strip() for recipient in value.split(",") if recipient.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
def unwrap_composio_data(data: Any) -> Any:
|
||||||
|
if isinstance(data, dict):
|
||||||
|
inner = data.get("data", data)
|
||||||
|
if isinstance(inner, dict):
|
||||||
|
return inner.get("response_data", inner)
|
||||||
|
return inner
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_composio_gmail_tool(
|
||||||
|
connector: SearchSourceConnector,
|
||||||
|
user_id: str,
|
||||||
|
tool_name: str,
|
||||||
|
params: dict[str, Any],
|
||||||
|
) -> tuple[Any, str | None]:
|
||||||
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
|
if not cca_id:
|
||||||
|
return None, "Composio connected account ID not found for this Gmail connector."
|
||||||
|
|
||||||
|
result = await ComposioService().execute_tool(
|
||||||
|
connected_account_id=cca_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
params=params,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
)
|
||||||
|
if not result.get("success"):
|
||||||
|
return None, result.get("error", "Unknown Composio Gmail error")
|
||||||
|
|
||||||
|
return unwrap_composio_data(result.get("data")), None
|
||||||
|
|
@ -157,16 +157,13 @@ def create_create_gmail_draft_tool(
|
||||||
f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
|
f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
is_composio_gmail = (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||||
):
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
if is_composio_gmail:
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
creds = build_composio_credentials(cca_id)
|
|
||||||
else:
|
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||||
|
|
@ -208,10 +205,6 @@ def create_create_gmail_draft_tool(
|
||||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
from googleapiclient.discovery import build
|
|
||||||
|
|
||||||
gmail_service = build("gmail", "v1", credentials=creds)
|
|
||||||
|
|
||||||
message = MIMEText(final_body)
|
message = MIMEText(final_body)
|
||||||
message["to"] = final_to
|
message["to"] = final_to
|
||||||
message["subject"] = final_subject
|
message["subject"] = final_subject
|
||||||
|
|
@ -222,15 +215,43 @@ def create_create_gmail_draft_tool(
|
||||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
created = await asyncio.get_event_loop().run_in_executor(
|
if is_composio_gmail:
|
||||||
None,
|
from app.agents.new_chat.tools.gmail.composio_helpers import (
|
||||||
lambda: (
|
execute_composio_gmail_tool,
|
||||||
gmail_service.users()
|
split_recipients,
|
||||||
.drafts()
|
)
|
||||||
.create(userId="me", body={"message": {"raw": raw}})
|
|
||||||
.execute()
|
created, error = await execute_composio_gmail_tool(
|
||||||
),
|
connector,
|
||||||
)
|
user_id,
|
||||||
|
"GMAIL_CREATE_EMAIL_DRAFT",
|
||||||
|
{
|
||||||
|
"user_id": "me",
|
||||||
|
"recipient_email": final_to,
|
||||||
|
"subject": final_subject,
|
||||||
|
"body": final_body,
|
||||||
|
"cc": split_recipients(final_cc),
|
||||||
|
"bcc": split_recipients(final_bcc),
|
||||||
|
"is_html": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(error)
|
||||||
|
if not isinstance(created, dict):
|
||||||
|
created = {}
|
||||||
|
else:
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
|
||||||
|
gmail_service = build("gmail", "v1", credentials=creds)
|
||||||
|
created = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: (
|
||||||
|
gmail_service.users()
|
||||||
|
.drafts()
|
||||||
|
.create(userId="me", body={"message": {"raw": raw}})
|
||||||
|
.execute()
|
||||||
|
),
|
||||||
|
)
|
||||||
except Exception as api_err:
|
except Exception as api_err:
|
||||||
from googleapiclient.errors import HttpError
|
from googleapiclient.errors import HttpError
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,54 @@ def create_read_gmail_email_tool(
|
||||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
connector.connector_type
|
||||||
|
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||||
|
):
|
||||||
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
|
if not cca_id:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Composio connected account ID not found.",
|
||||||
|
}
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.gmail.search_emails import (
|
||||||
|
_format_gmail_summary,
|
||||||
|
)
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
|
service = ComposioService()
|
||||||
|
detail, error = await service.get_gmail_message_detail(
|
||||||
|
connected_account_id=cca_id,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
return {"status": "error", "message": error}
|
||||||
|
if not detail:
|
||||||
|
return {
|
||||||
|
"status": "not_found",
|
||||||
|
"message": f"Email with ID '{message_id}' not found.",
|
||||||
|
}
|
||||||
|
|
||||||
|
summary = _format_gmail_summary(detail)
|
||||||
|
content = (
|
||||||
|
f"# {summary['subject']}\n\n"
|
||||||
|
f"**From:** {summary['from']}\n"
|
||||||
|
f"**To:** {summary['to']}\n"
|
||||||
|
f"**Date:** {summary['date']}\n\n"
|
||||||
|
f"## Message Content\n\n"
|
||||||
|
f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n"
|
||||||
|
f"## Message Details\n\n"
|
||||||
|
f"- **Message ID:** {summary['message_id']}\n"
|
||||||
|
f"- **Thread ID:** {summary['thread_id']}\n"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"message_id": summary["message_id"] or message_id,
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
|
||||||
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
|
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
|
||||||
|
|
||||||
creds = _build_credentials(connector)
|
creds = _build_credentials(connector)
|
||||||
|
|
|
||||||
|
|
@ -39,12 +39,7 @@ def _build_credentials(connector: SearchSourceConnector):
|
||||||
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||||
|
|
||||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
raise ValueError("Composio connectors must use Composio tool execution.")
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
|
||||||
if not cca_id:
|
|
||||||
raise ValueError("Composio connected account ID not found.")
|
|
||||||
return build_composio_credentials(cca_id)
|
|
||||||
|
|
||||||
from google.oauth2.credentials import Credentials
|
from google.oauth2.credentials import Credentials
|
||||||
|
|
||||||
|
|
@ -67,6 +62,63 @@ def _build_credentials(connector: SearchSourceConnector):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _gmail_headers(message: dict[str, Any]) -> dict[str, str]:
|
||||||
|
headers = message.get("payload", {}).get("headers", [])
|
||||||
|
return {
|
||||||
|
header.get("name", "").lower(): header.get("value", "")
|
||||||
|
for header in headers
|
||||||
|
if isinstance(header, dict)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _format_gmail_summary(message: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
headers = _gmail_headers(message)
|
||||||
|
return {
|
||||||
|
"message_id": message.get("id") or message.get("messageId"),
|
||||||
|
"thread_id": message.get("threadId"),
|
||||||
|
"subject": message.get("subject") or headers.get("subject", "No Subject"),
|
||||||
|
"from": message.get("sender") or headers.get("from", "Unknown"),
|
||||||
|
"to": message.get("to") or headers.get("to", ""),
|
||||||
|
"date": message.get("messageTimestamp") or headers.get("date", ""),
|
||||||
|
"snippet": message.get("snippet") or message.get("messageText", "")[:300],
|
||||||
|
"labels": message.get("labelIds", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _search_composio_gmail(
|
||||||
|
connector: SearchSourceConnector,
|
||||||
|
user_id: str,
|
||||||
|
query: str,
|
||||||
|
max_results: int,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
|
if not cca_id:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Composio connected account ID not found.",
|
||||||
|
}
|
||||||
|
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
|
service = ComposioService()
|
||||||
|
messages, _next_token, _estimate, error = await service.get_gmail_messages(
|
||||||
|
connected_account_id=cca_id,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
query=query,
|
||||||
|
max_results=max_results,
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
return {"status": "error", "message": error}
|
||||||
|
|
||||||
|
emails = [_format_gmail_summary(message) for message in messages]
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"emails": emails,
|
||||||
|
"total": len(emails),
|
||||||
|
"message": "No emails found." if not emails else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def create_search_gmail_tool(
|
def create_search_gmail_tool(
|
||||||
db_session: AsyncSession | None = None,
|
db_session: AsyncSession | None = None,
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
|
|
@ -110,6 +162,14 @@ def create_search_gmail_tool(
|
||||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
connector.connector_type
|
||||||
|
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||||
|
):
|
||||||
|
return await _search_composio_gmail(
|
||||||
|
connector, str(user_id), query, max_results
|
||||||
|
)
|
||||||
|
|
||||||
creds = _build_credentials(connector)
|
creds = _build_credentials(connector)
|
||||||
|
|
||||||
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||||
|
|
|
||||||
|
|
@ -158,16 +158,13 @@ def create_send_gmail_email_tool(
|
||||||
f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
|
f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
is_composio_gmail = (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||||
):
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
if is_composio_gmail:
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
creds = build_composio_credentials(cca_id)
|
|
||||||
else:
|
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||||
|
|
@ -209,10 +206,6 @@ def create_send_gmail_email_tool(
|
||||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
from googleapiclient.discovery import build
|
|
||||||
|
|
||||||
gmail_service = build("gmail", "v1", credentials=creds)
|
|
||||||
|
|
||||||
message = MIMEText(final_body)
|
message = MIMEText(final_body)
|
||||||
message["to"] = final_to
|
message["to"] = final_to
|
||||||
message["subject"] = final_subject
|
message["subject"] = final_subject
|
||||||
|
|
@ -223,15 +216,43 @@ def create_send_gmail_email_tool(
|
||||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
sent = await asyncio.get_event_loop().run_in_executor(
|
if is_composio_gmail:
|
||||||
None,
|
from app.agents.new_chat.tools.gmail.composio_helpers import (
|
||||||
lambda: (
|
execute_composio_gmail_tool,
|
||||||
gmail_service.users()
|
split_recipients,
|
||||||
.messages()
|
)
|
||||||
.send(userId="me", body={"raw": raw})
|
|
||||||
.execute()
|
sent, error = await execute_composio_gmail_tool(
|
||||||
),
|
connector,
|
||||||
)
|
user_id,
|
||||||
|
"GMAIL_SEND_EMAIL",
|
||||||
|
{
|
||||||
|
"user_id": "me",
|
||||||
|
"recipient_email": final_to,
|
||||||
|
"subject": final_subject,
|
||||||
|
"body": final_body,
|
||||||
|
"cc": split_recipients(final_cc),
|
||||||
|
"bcc": split_recipients(final_bcc),
|
||||||
|
"is_html": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(error)
|
||||||
|
if not isinstance(sent, dict):
|
||||||
|
sent = {}
|
||||||
|
else:
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
|
||||||
|
gmail_service = build("gmail", "v1", credentials=creds)
|
||||||
|
sent = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: (
|
||||||
|
gmail_service.users()
|
||||||
|
.messages()
|
||||||
|
.send(userId="me", body={"raw": raw})
|
||||||
|
.execute()
|
||||||
|
),
|
||||||
|
)
|
||||||
except Exception as api_err:
|
except Exception as api_err:
|
||||||
from googleapiclient.errors import HttpError
|
from googleapiclient.errors import HttpError
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -158,16 +158,13 @@ def create_trash_gmail_email_tool(
|
||||||
f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
|
f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
is_composio_gmail = (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||||
):
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
if is_composio_gmail:
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
creds = build_composio_credentials(cca_id)
|
|
||||||
else:
|
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||||
|
|
@ -209,20 +206,33 @@ def create_trash_gmail_email_tool(
|
||||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
from googleapiclient.discovery import build
|
|
||||||
|
|
||||||
gmail_service = build("gmail", "v1", credentials=creds)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
if is_composio_gmail:
|
||||||
None,
|
from app.agents.new_chat.tools.gmail.composio_helpers import (
|
||||||
lambda: (
|
execute_composio_gmail_tool,
|
||||||
gmail_service.users()
|
)
|
||||||
.messages()
|
|
||||||
.trash(userId="me", id=final_message_id)
|
_trashed, error = await execute_composio_gmail_tool(
|
||||||
.execute()
|
connector,
|
||||||
),
|
user_id,
|
||||||
)
|
"GMAIL_MOVE_TO_TRASH",
|
||||||
|
{"user_id": "me", "message_id": final_message_id},
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(error)
|
||||||
|
else:
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
|
||||||
|
gmail_service = build("gmail", "v1", credentials=creds)
|
||||||
|
await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: (
|
||||||
|
gmail_service.users()
|
||||||
|
.messages()
|
||||||
|
.trash(userId="me", id=final_message_id)
|
||||||
|
.execute()
|
||||||
|
),
|
||||||
|
)
|
||||||
except Exception as api_err:
|
except Exception as api_err:
|
||||||
from googleapiclient.errors import HttpError
|
from googleapiclient.errors import HttpError
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -188,16 +188,13 @@ def create_update_gmail_draft_tool(
|
||||||
f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
|
f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
is_composio_gmail = (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||||
):
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
if is_composio_gmail:
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
creds = build_composio_credentials(cca_id)
|
|
||||||
else:
|
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||||
|
|
@ -239,18 +236,22 @@ def create_update_gmail_draft_tool(
|
||||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
from googleapiclient.discovery import build
|
|
||||||
|
|
||||||
gmail_service = build("gmail", "v1", credentials=creds)
|
|
||||||
|
|
||||||
# Resolve draft_id if not already available
|
# Resolve draft_id if not already available
|
||||||
if not final_draft_id:
|
if not final_draft_id:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
|
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
|
||||||
)
|
)
|
||||||
final_draft_id = await _find_draft_id_by_message(
|
if is_composio_gmail:
|
||||||
gmail_service, message_id
|
final_draft_id = await _find_composio_draft_id_by_message(
|
||||||
)
|
connector, user_id, message_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
|
||||||
|
gmail_service = build("gmail", "v1", credentials=creds)
|
||||||
|
final_draft_id = await _find_draft_id_by_message(
|
||||||
|
gmail_service, message_id
|
||||||
|
)
|
||||||
|
|
||||||
if not final_draft_id:
|
if not final_draft_id:
|
||||||
return {
|
return {
|
||||||
|
|
@ -272,19 +273,48 @@ def create_update_gmail_draft_tool(
|
||||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
updated = await asyncio.get_event_loop().run_in_executor(
|
if is_composio_gmail:
|
||||||
None,
|
from app.agents.new_chat.tools.gmail.composio_helpers import (
|
||||||
lambda: (
|
execute_composio_gmail_tool,
|
||||||
gmail_service.users()
|
split_recipients,
|
||||||
.drafts()
|
)
|
||||||
.update(
|
|
||||||
userId="me",
|
updated, error = await execute_composio_gmail_tool(
|
||||||
id=final_draft_id,
|
connector,
|
||||||
body={"message": {"raw": raw}},
|
user_id,
|
||||||
)
|
"GMAIL_UPDATE_DRAFT",
|
||||||
.execute()
|
{
|
||||||
),
|
"user_id": "me",
|
||||||
)
|
"draft_id": final_draft_id,
|
||||||
|
"recipient_email": final_to,
|
||||||
|
"subject": final_subject,
|
||||||
|
"body": final_body,
|
||||||
|
"cc": split_recipients(final_cc),
|
||||||
|
"bcc": split_recipients(final_bcc),
|
||||||
|
"is_html": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(error)
|
||||||
|
if not isinstance(updated, dict):
|
||||||
|
updated = {}
|
||||||
|
else:
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
|
||||||
|
gmail_service = build("gmail", "v1", credentials=creds)
|
||||||
|
updated = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: (
|
||||||
|
gmail_service.users()
|
||||||
|
.drafts()
|
||||||
|
.update(
|
||||||
|
userId="me",
|
||||||
|
id=final_draft_id,
|
||||||
|
body={"message": {"raw": raw}},
|
||||||
|
)
|
||||||
|
.execute()
|
||||||
|
),
|
||||||
|
)
|
||||||
except Exception as api_err:
|
except Exception as api_err:
|
||||||
from googleapiclient.errors import HttpError
|
from googleapiclient.errors import HttpError
|
||||||
|
|
||||||
|
|
@ -408,3 +438,35 @@ async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to look up draft by message_id: {e}")
|
logger.warning(f"Failed to look up draft by message_id: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _find_composio_draft_id_by_message(
|
||||||
|
connector: Any, user_id: str, message_id: str
|
||||||
|
) -> str | None:
|
||||||
|
from app.agents.new_chat.tools.gmail.composio_helpers import (
|
||||||
|
execute_composio_gmail_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
page_token = ""
|
||||||
|
while True:
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"user_id": "me",
|
||||||
|
"max_results": 100,
|
||||||
|
"verbose": False,
|
||||||
|
}
|
||||||
|
if page_token:
|
||||||
|
params["page_token"] = page_token
|
||||||
|
|
||||||
|
data, error = await execute_composio_gmail_tool(
|
||||||
|
connector, user_id, "GMAIL_LIST_DRAFTS", params
|
||||||
|
)
|
||||||
|
if error or not isinstance(data, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
for draft in data.get("drafts", []):
|
||||||
|
if draft.get("message", {}).get("id") == message_id:
|
||||||
|
return draft.get("id")
|
||||||
|
|
||||||
|
page_token = data.get("nextPageToken") or data.get("next_page_token") or ""
|
||||||
|
if not page_token:
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -168,16 +168,13 @@ def create_create_calendar_event_tool(
|
||||||
f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
|
f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
is_composio_calendar = (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||||
):
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
if is_composio_calendar:
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
creds = build_composio_credentials(cca_id)
|
|
||||||
else:
|
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Composio connected account ID not found for this connector.",
|
"message": "Composio connected account ID not found for this connector.",
|
||||||
|
|
@ -211,10 +208,6 @@ def create_create_calendar_event_tool(
|
||||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
service = await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None, lambda: build("calendar", "v3", credentials=creds)
|
|
||||||
)
|
|
||||||
|
|
||||||
tz = context.get("timezone", "UTC")
|
tz = context.get("timezone", "UTC")
|
||||||
event_body: dict[str, Any] = {
|
event_body: dict[str, Any] = {
|
||||||
"summary": final_summary,
|
"summary": final_summary,
|
||||||
|
|
@ -231,14 +224,51 @@ def create_create_calendar_event_tool(
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
created = await asyncio.get_event_loop().run_in_executor(
|
if is_composio_calendar:
|
||||||
None,
|
from app.services.composio_service import ComposioService
|
||||||
lambda: (
|
|
||||||
service.events()
|
composio_params = {
|
||||||
.insert(calendarId="primary", body=event_body)
|
"calendar_id": "primary",
|
||||||
.execute()
|
"summary": final_summary,
|
||||||
),
|
"start_datetime": final_start_datetime,
|
||||||
)
|
"end_datetime": final_end_datetime,
|
||||||
|
"timezone": tz,
|
||||||
|
"attendees": final_attendees or [],
|
||||||
|
}
|
||||||
|
if final_description:
|
||||||
|
composio_params["description"] = final_description
|
||||||
|
if final_location:
|
||||||
|
composio_params["location"] = final_location
|
||||||
|
|
||||||
|
composio_result = await ComposioService().execute_tool(
|
||||||
|
connected_account_id=cca_id,
|
||||||
|
tool_name="GOOGLECALENDAR_CREATE_EVENT",
|
||||||
|
params=composio_params,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
)
|
||||||
|
if not composio_result.get("success"):
|
||||||
|
raise RuntimeError(
|
||||||
|
composio_result.get(
|
||||||
|
"error", "Unknown Composio Calendar error"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
created = composio_result.get("data", {})
|
||||||
|
if isinstance(created, dict):
|
||||||
|
created = created.get("data", created)
|
||||||
|
if isinstance(created, dict):
|
||||||
|
created = created.get("response_data", created)
|
||||||
|
else:
|
||||||
|
service = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, lambda: build("calendar", "v3", credentials=creds)
|
||||||
|
)
|
||||||
|
created = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: (
|
||||||
|
service.events()
|
||||||
|
.insert(calendarId="primary", body=event_body)
|
||||||
|
.execute()
|
||||||
|
),
|
||||||
|
)
|
||||||
except Exception as api_err:
|
except Exception as api_err:
|
||||||
from googleapiclient.errors import HttpError
|
from googleapiclient.errors import HttpError
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -159,16 +159,13 @@ def create_delete_calendar_event_tool(
|
||||||
f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
|
f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
is_composio_calendar = (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||||
):
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
if is_composio_calendar:
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
creds = build_composio_credentials(cca_id)
|
|
||||||
else:
|
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Composio connected account ID not found for this connector.",
|
"message": "Composio connected account ID not found for this connector.",
|
||||||
|
|
@ -202,19 +199,34 @@ def create_delete_calendar_event_tool(
|
||||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
service = await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None, lambda: build("calendar", "v3", credentials=creds)
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
if is_composio_calendar:
|
||||||
None,
|
from app.services.composio_service import ComposioService
|
||||||
lambda: (
|
|
||||||
service.events()
|
composio_result = await ComposioService().execute_tool(
|
||||||
.delete(calendarId="primary", eventId=final_event_id)
|
connected_account_id=cca_id,
|
||||||
.execute()
|
tool_name="GOOGLECALENDAR_DELETE_EVENT",
|
||||||
),
|
params={"calendar_id": "primary", "event_id": final_event_id},
|
||||||
)
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
)
|
||||||
|
if not composio_result.get("success"):
|
||||||
|
raise RuntimeError(
|
||||||
|
composio_result.get(
|
||||||
|
"error", "Unknown Composio Calendar error"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
service = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, lambda: build("calendar", "v3", credentials=creds)
|
||||||
|
)
|
||||||
|
await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: (
|
||||||
|
service.events()
|
||||||
|
.delete(calendarId="primary", eventId=final_event_id)
|
||||||
|
.execute()
|
||||||
|
),
|
||||||
|
)
|
||||||
except Exception as api_err:
|
except Exception as api_err:
|
||||||
from googleapiclient.errors import HttpError
|
from googleapiclient.errors import HttpError
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,35 @@ _CALENDAR_TYPES = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _to_calendar_boundary(value: str, *, is_end: bool) -> str:
|
||||||
|
if "T" in value:
|
||||||
|
return value
|
||||||
|
time = "23:59:59" if is_end else "00:00:00"
|
||||||
|
return f"{value}T{time}Z"
|
||||||
|
|
||||||
|
|
||||||
|
def _format_calendar_events(events_raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
events = []
|
||||||
|
for ev in events_raw:
|
||||||
|
start = ev.get("start", {})
|
||||||
|
end = ev.get("end", {})
|
||||||
|
attendees_raw = ev.get("attendees", [])
|
||||||
|
events.append(
|
||||||
|
{
|
||||||
|
"event_id": ev.get("id"),
|
||||||
|
"summary": ev.get("summary", "No Title"),
|
||||||
|
"start": start.get("dateTime") or start.get("date", ""),
|
||||||
|
"end": end.get("dateTime") or end.get("date", ""),
|
||||||
|
"location": ev.get("location", ""),
|
||||||
|
"description": ev.get("description", ""),
|
||||||
|
"html_link": ev.get("htmlLink", ""),
|
||||||
|
"attendees": [a.get("email", "") for a in attendees_raw[:10]],
|
||||||
|
"status": ev.get("status", ""),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return events
|
||||||
|
|
||||||
|
|
||||||
def create_search_calendar_events_tool(
|
def create_search_calendar_events_tool(
|
||||||
db_session: AsyncSession | None = None,
|
db_session: AsyncSession | None = None,
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
|
|
@ -61,22 +90,47 @@ def create_search_calendar_events_tool(
|
||||||
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
|
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
|
||||||
}
|
}
|
||||||
|
|
||||||
creds = _build_credentials(connector)
|
if (
|
||||||
|
connector.connector_type
|
||||||
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||||
|
):
|
||||||
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
|
if not cca_id:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Composio connected account ID not found for this connector.",
|
||||||
|
}
|
||||||
|
|
||||||
from app.connectors.google_calendar_connector import GoogleCalendarConnector
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
cal = GoogleCalendarConnector(
|
events_raw, error = await ComposioService().get_calendar_events(
|
||||||
credentials=creds,
|
connected_account_id=cca_id,
|
||||||
session=db_session,
|
entity_id=f"surfsense_{user_id}",
|
||||||
user_id=user_id,
|
time_min=_to_calendar_boundary(start_date, is_end=False),
|
||||||
connector_id=connector.id,
|
time_max=_to_calendar_boundary(end_date, is_end=True),
|
||||||
)
|
max_results=max_results,
|
||||||
|
)
|
||||||
|
if not events_raw and not error:
|
||||||
|
error = "No events found in the specified date range."
|
||||||
|
else:
|
||||||
|
creds = _build_credentials(connector)
|
||||||
|
|
||||||
events_raw, error = await cal.get_all_primary_calendar_events(
|
from app.connectors.google_calendar_connector import (
|
||||||
start_date=start_date,
|
GoogleCalendarConnector,
|
||||||
end_date=end_date,
|
)
|
||||||
max_results=max_results,
|
|
||||||
)
|
cal = GoogleCalendarConnector(
|
||||||
|
credentials=creds,
|
||||||
|
session=db_session,
|
||||||
|
user_id=user_id,
|
||||||
|
connector_id=connector.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
events_raw, error = await cal.get_all_primary_calendar_events(
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
max_results=max_results,
|
||||||
|
)
|
||||||
|
|
||||||
if error:
|
if error:
|
||||||
if (
|
if (
|
||||||
|
|
@ -97,24 +151,7 @@ def create_search_calendar_events_tool(
|
||||||
}
|
}
|
||||||
return {"status": "error", "message": error}
|
return {"status": "error", "message": error}
|
||||||
|
|
||||||
events = []
|
events = _format_calendar_events(events_raw)
|
||||||
for ev in events_raw:
|
|
||||||
start = ev.get("start", {})
|
|
||||||
end = ev.get("end", {})
|
|
||||||
attendees_raw = ev.get("attendees", [])
|
|
||||||
events.append(
|
|
||||||
{
|
|
||||||
"event_id": ev.get("id"),
|
|
||||||
"summary": ev.get("summary", "No Title"),
|
|
||||||
"start": start.get("dateTime") or start.get("date", ""),
|
|
||||||
"end": end.get("dateTime") or end.get("date", ""),
|
|
||||||
"location": ev.get("location", ""),
|
|
||||||
"description": ev.get("description", ""),
|
|
||||||
"html_link": ev.get("htmlLink", ""),
|
|
||||||
"attendees": [a.get("email", "") for a in attendees_raw[:10]],
|
|
||||||
"status": ev.get("status", ""),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"status": "success", "events": events, "total": len(events)}
|
return {"status": "success", "events": events, "total": len(events)}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -192,16 +192,13 @@ def create_update_calendar_event_tool(
|
||||||
f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
|
f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
is_composio_calendar = (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||||
):
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
if is_composio_calendar:
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
creds = build_composio_credentials(cca_id)
|
|
||||||
else:
|
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Composio connected account ID not found for this connector.",
|
"message": "Composio connected account ID not found for this connector.",
|
||||||
|
|
@ -235,10 +232,6 @@ def create_update_calendar_event_tool(
|
||||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
service = await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None, lambda: build("calendar", "v3", credentials=creds)
|
|
||||||
)
|
|
||||||
|
|
||||||
update_body: dict[str, Any] = {}
|
update_body: dict[str, Any] = {}
|
||||||
if final_new_summary is not None:
|
if final_new_summary is not None:
|
||||||
update_body["summary"] = final_new_summary
|
update_body["summary"] = final_new_summary
|
||||||
|
|
@ -264,18 +257,65 @@ def create_update_calendar_event_tool(
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
updated = await asyncio.get_event_loop().run_in_executor(
|
if is_composio_calendar:
|
||||||
None,
|
from app.services.composio_service import ComposioService
|
||||||
lambda: (
|
|
||||||
service.events()
|
composio_params: dict[str, Any] = {
|
||||||
.patch(
|
"calendar_id": "primary",
|
||||||
calendarId="primary",
|
"event_id": final_event_id,
|
||||||
eventId=final_event_id,
|
}
|
||||||
body=update_body,
|
if final_new_summary is not None:
|
||||||
|
composio_params["summary"] = final_new_summary
|
||||||
|
if final_new_start_datetime is not None:
|
||||||
|
composio_params["start_time"] = final_new_start_datetime
|
||||||
|
if final_new_end_datetime is not None:
|
||||||
|
composio_params["end_time"] = final_new_end_datetime
|
||||||
|
if final_new_description is not None:
|
||||||
|
composio_params["description"] = final_new_description
|
||||||
|
if final_new_location is not None:
|
||||||
|
composio_params["location"] = final_new_location
|
||||||
|
if final_new_attendees is not None:
|
||||||
|
composio_params["attendees"] = [
|
||||||
|
e.strip() for e in final_new_attendees if e.strip()
|
||||||
|
]
|
||||||
|
if not _is_date_only(
|
||||||
|
final_new_start_datetime or final_new_end_datetime or ""
|
||||||
|
):
|
||||||
|
composio_params["timezone"] = context.get("timezone", "UTC")
|
||||||
|
|
||||||
|
composio_result = await ComposioService().execute_tool(
|
||||||
|
connected_account_id=cca_id,
|
||||||
|
tool_name="GOOGLECALENDAR_PATCH_EVENT",
|
||||||
|
params=composio_params,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
)
|
||||||
|
if not composio_result.get("success"):
|
||||||
|
raise RuntimeError(
|
||||||
|
composio_result.get(
|
||||||
|
"error", "Unknown Composio Calendar error"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
.execute()
|
updated = composio_result.get("data", {})
|
||||||
),
|
if isinstance(updated, dict):
|
||||||
)
|
updated = updated.get("data", updated)
|
||||||
|
if isinstance(updated, dict):
|
||||||
|
updated = updated.get("response_data", updated)
|
||||||
|
else:
|
||||||
|
service = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, lambda: build("calendar", "v3", credentials=creds)
|
||||||
|
)
|
||||||
|
updated = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: (
|
||||||
|
service.events()
|
||||||
|
.patch(
|
||||||
|
calendarId="primary",
|
||||||
|
eventId=final_event_id,
|
||||||
|
body=update_body,
|
||||||
|
)
|
||||||
|
.execute()
|
||||||
|
),
|
||||||
|
)
|
||||||
except Exception as api_err:
|
except Exception as api_err:
|
||||||
from googleapiclient.errors import HttpError
|
from googleapiclient.errors import HttpError
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -179,29 +179,59 @@ def create_create_google_drive_file_tool(
|
||||||
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
|
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
pre_built_creds = None
|
is_composio_drive = (
|
||||||
if (
|
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||||
):
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
if is_composio_drive:
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
pre_built_creds = build_composio_credentials(cca_id)
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Composio connected account ID not found for this Drive connector.",
|
||||||
|
}
|
||||||
client = GoogleDriveClient(
|
client = GoogleDriveClient(
|
||||||
session=db_session,
|
session=db_session,
|
||||||
connector_id=actual_connector_id,
|
connector_id=actual_connector_id,
|
||||||
credentials=pre_built_creds,
|
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
created = await client.create_file(
|
if is_composio_drive:
|
||||||
name=final_name,
|
from app.services.composio_service import ComposioService
|
||||||
mime_type=mime_type,
|
|
||||||
parent_folder_id=final_parent_folder_id,
|
params: dict[str, Any] = {
|
||||||
content=final_content,
|
"name": final_name,
|
||||||
)
|
"mimeType": mime_type,
|
||||||
|
"fields": "id,name,webViewLink,mimeType",
|
||||||
|
}
|
||||||
|
if final_parent_folder_id:
|
||||||
|
params["parents"] = [final_parent_folder_id]
|
||||||
|
if final_content:
|
||||||
|
params["description"] = final_content[:4096]
|
||||||
|
|
||||||
|
result = await ComposioService().execute_tool(
|
||||||
|
connected_account_id=cca_id,
|
||||||
|
tool_name="GOOGLEDRIVE_CREATE_FILE",
|
||||||
|
params=params,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
)
|
||||||
|
if not result.get("success"):
|
||||||
|
raise RuntimeError(
|
||||||
|
result.get("error", "Unknown Composio Drive error")
|
||||||
|
)
|
||||||
|
created = result.get("data", {})
|
||||||
|
if isinstance(created, dict):
|
||||||
|
created = created.get("data", created)
|
||||||
|
if isinstance(created, dict):
|
||||||
|
created = created.get("response_data", created)
|
||||||
|
if not isinstance(created, dict):
|
||||||
|
created = {}
|
||||||
|
else:
|
||||||
|
created = await client.create_file(
|
||||||
|
name=final_name,
|
||||||
|
mime_type=mime_type,
|
||||||
|
parent_folder_id=final_parent_folder_id,
|
||||||
|
content=final_content,
|
||||||
|
)
|
||||||
except HttpError as http_err:
|
except HttpError as http_err:
|
||||||
if http_err.resp.status == 403:
|
if http_err.resp.status == 403:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
|
||||||
|
|
@ -158,24 +158,38 @@ def create_delete_google_drive_file_tool(
|
||||||
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
|
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
pre_built_creds = None
|
is_composio_drive = (
|
||||||
if (
|
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||||
):
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
if is_composio_drive:
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
pre_built_creds = build_composio_credentials(cca_id)
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Composio connected account ID not found for this Drive connector.",
|
||||||
|
}
|
||||||
|
|
||||||
client = GoogleDriveClient(
|
client = GoogleDriveClient(
|
||||||
session=db_session,
|
session=db_session,
|
||||||
connector_id=connector.id,
|
connector_id=connector.id,
|
||||||
credentials=pre_built_creds,
|
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
await client.trash_file(file_id=final_file_id)
|
if is_composio_drive:
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
|
result = await ComposioService().execute_tool(
|
||||||
|
connected_account_id=cca_id,
|
||||||
|
tool_name="GOOGLEDRIVE_TRASH_FILE",
|
||||||
|
params={"file_id": final_file_id},
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
)
|
||||||
|
if not result.get("success"):
|
||||||
|
raise RuntimeError(
|
||||||
|
result.get("error", "Unknown Composio Drive error")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await client.trash_file(file_id=final_file_id)
|
||||||
except HttpError as http_err:
|
except HttpError as http_err:
|
||||||
if http_err.resp.status == 403:
|
if http_err.resp.status == 403:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,7 @@ DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
|
||||||
{
|
{
|
||||||
"create_gmail_draft",
|
"create_gmail_draft",
|
||||||
"update_gmail_draft",
|
"update_gmail_draft",
|
||||||
|
"create_calendar_event",
|
||||||
"create_notion_page",
|
"create_notion_page",
|
||||||
"create_confluence_page",
|
"create_confluence_page",
|
||||||
"create_google_drive_file",
|
"create_google_drive_file",
|
||||||
|
|
|
||||||
|
|
@ -649,13 +649,9 @@ async def list_composio_drive_folders(
|
||||||
"""
|
"""
|
||||||
List folders AND files in user's Google Drive via Composio.
|
List folders AND files in user's Google Drive via Composio.
|
||||||
|
|
||||||
Uses the same GoogleDriveClient / list_folder_contents path as the native
|
Uses Composio's Google Drive tool execution path so managed OAuth tokens
|
||||||
connector, with Composio-sourced credentials. This means auth errors
|
do not need to be exposed through connected account state.
|
||||||
propagate identically (Google returns 401 → exception → auth_expired flag).
|
|
||||||
"""
|
"""
|
||||||
from app.connectors.google_drive import GoogleDriveClient, list_folder_contents
|
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
|
||||||
|
|
||||||
if not ComposioService.is_enabled():
|
if not ComposioService.is_enabled():
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=503,
|
status_code=503,
|
||||||
|
|
@ -689,10 +685,37 @@ async def list_composio_drive_folders(
|
||||||
detail="Composio connected account not found. Please reconnect the connector.",
|
detail="Composio connected account not found. Please reconnect the connector.",
|
||||||
)
|
)
|
||||||
|
|
||||||
credentials = build_composio_credentials(composio_connected_account_id)
|
service = ComposioService()
|
||||||
drive_client = GoogleDriveClient(session, connector_id, credentials=credentials)
|
entity_id = f"surfsense_{user.id}"
|
||||||
|
items = []
|
||||||
|
page_token = None
|
||||||
|
error = None
|
||||||
|
|
||||||
items, error = await list_folder_contents(drive_client, parent_id=parent_id)
|
while True:
|
||||||
|
page_items, next_token, page_error = await service.get_drive_files(
|
||||||
|
connected_account_id=composio_connected_account_id,
|
||||||
|
entity_id=entity_id,
|
||||||
|
folder_id=parent_id,
|
||||||
|
page_token=page_token,
|
||||||
|
page_size=100,
|
||||||
|
)
|
||||||
|
if page_error:
|
||||||
|
error = page_error
|
||||||
|
break
|
||||||
|
|
||||||
|
items.extend(page_items)
|
||||||
|
if not next_token:
|
||||||
|
break
|
||||||
|
page_token = next_token
|
||||||
|
|
||||||
|
for item in items:
|
||||||
|
item["isFolder"] = (
|
||||||
|
item.get("mimeType") == "application/vnd.google-apps.folder"
|
||||||
|
)
|
||||||
|
|
||||||
|
items.sort(
|
||||||
|
key=lambda item: (not item["isFolder"], item.get("name", "").lower())
|
||||||
|
)
|
||||||
|
|
||||||
if error:
|
if error:
|
||||||
error_lower = error.lower()
|
error_lower = error.lower()
|
||||||
|
|
|
||||||
|
|
@ -408,12 +408,37 @@ class ComposioService:
|
||||||
files = []
|
files = []
|
||||||
next_token = None
|
next_token = None
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
|
inner_data = data.get("data", data)
|
||||||
|
response_data = (
|
||||||
|
inner_data.get("response_data", {})
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else {}
|
||||||
|
)
|
||||||
# Try direct access first, then nested
|
# Try direct access first, then nested
|
||||||
files = data.get("files", []) or data.get("data", {}).get("files", [])
|
files = (
|
||||||
|
data.get("files", [])
|
||||||
|
or (
|
||||||
|
inner_data.get("files", [])
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
or response_data.get("files", [])
|
||||||
|
)
|
||||||
next_token = (
|
next_token = (
|
||||||
data.get("nextPageToken")
|
data.get("nextPageToken")
|
||||||
or data.get("next_page_token")
|
or data.get("next_page_token")
|
||||||
or data.get("data", {}).get("nextPageToken")
|
or (
|
||||||
|
inner_data.get("nextPageToken")
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
inner_data.get("next_page_token")
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
or response_data.get("nextPageToken")
|
||||||
|
or response_data.get("next_page_token")
|
||||||
)
|
)
|
||||||
elif isinstance(data, list):
|
elif isinstance(data, list):
|
||||||
files = data
|
files = data
|
||||||
|
|
@ -819,24 +844,61 @@ class ComposioService:
|
||||||
next_token = None
|
next_token = None
|
||||||
result_size_estimate = None
|
result_size_estimate = None
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
|
inner_data = data.get("data", data)
|
||||||
|
response_data = (
|
||||||
|
inner_data.get("response_data", {})
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else {}
|
||||||
|
)
|
||||||
messages = (
|
messages = (
|
||||||
data.get("messages", [])
|
data.get("messages", [])
|
||||||
or data.get("data", {}).get("messages", [])
|
or (
|
||||||
|
inner_data.get("messages", [])
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
or response_data.get("messages", [])
|
||||||
or data.get("emails", [])
|
or data.get("emails", [])
|
||||||
|
or (
|
||||||
|
inner_data.get("emails", [])
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
or response_data.get("emails", [])
|
||||||
)
|
)
|
||||||
# Check for pagination token in various possible locations
|
# Check for pagination token in various possible locations
|
||||||
next_token = (
|
next_token = (
|
||||||
data.get("nextPageToken")
|
data.get("nextPageToken")
|
||||||
or data.get("next_page_token")
|
or data.get("next_page_token")
|
||||||
or data.get("data", {}).get("nextPageToken")
|
or (
|
||||||
or data.get("data", {}).get("next_page_token")
|
inner_data.get("nextPageToken")
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
inner_data.get("next_page_token")
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
or response_data.get("nextPageToken")
|
||||||
|
or response_data.get("next_page_token")
|
||||||
)
|
)
|
||||||
# Extract resultSizeEstimate if available (Gmail API provides this)
|
# Extract resultSizeEstimate if available (Gmail API provides this)
|
||||||
result_size_estimate = (
|
result_size_estimate = (
|
||||||
data.get("resultSizeEstimate")
|
data.get("resultSizeEstimate")
|
||||||
or data.get("result_size_estimate")
|
or data.get("result_size_estimate")
|
||||||
or data.get("data", {}).get("resultSizeEstimate")
|
or (
|
||||||
or data.get("data", {}).get("result_size_estimate")
|
inner_data.get("resultSizeEstimate")
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
inner_data.get("result_size_estimate")
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
or response_data.get("resultSizeEstimate")
|
||||||
|
or response_data.get("result_size_estimate")
|
||||||
)
|
)
|
||||||
elif isinstance(data, list):
|
elif isinstance(data, list):
|
||||||
messages = data
|
messages = data
|
||||||
|
|
@ -864,7 +926,7 @@ class ComposioService:
|
||||||
try:
|
try:
|
||||||
result = await self.execute_tool(
|
result = await self.execute_tool(
|
||||||
connected_account_id=connected_account_id,
|
connected_account_id=connected_account_id,
|
||||||
tool_name="GMAIL_GET_MESSAGE_BY_MESSAGE_ID",
|
tool_name="GMAIL_FETCH_MESSAGE_BY_MESSAGE_ID",
|
||||||
params={"message_id": message_id}, # snake_case
|
params={"message_id": message_id}, # snake_case
|
||||||
entity_id=entity_id,
|
entity_id=entity_id,
|
||||||
)
|
)
|
||||||
|
|
@ -872,7 +934,13 @@ class ComposioService:
|
||||||
if not result.get("success"):
|
if not result.get("success"):
|
||||||
return None, result.get("error", "Unknown error")
|
return None, result.get("error", "Unknown error")
|
||||||
|
|
||||||
return result.get("data"), None
|
data = result.get("data")
|
||||||
|
if isinstance(data, dict):
|
||||||
|
inner_data = data.get("data", data)
|
||||||
|
if isinstance(inner_data, dict):
|
||||||
|
return inner_data.get("response_data", inner_data), None
|
||||||
|
|
||||||
|
return data, None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get Gmail message detail: {e!s}")
|
logger.error(f"Failed to get Gmail message detail: {e!s}")
|
||||||
|
|
@ -928,10 +996,27 @@ class ComposioService:
|
||||||
# Try different possible response structures
|
# Try different possible response structures
|
||||||
events = []
|
events = []
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
|
inner_data = data.get("data", data)
|
||||||
|
response_data = (
|
||||||
|
inner_data.get("response_data", {})
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else {}
|
||||||
|
)
|
||||||
events = (
|
events = (
|
||||||
data.get("items", [])
|
data.get("items", [])
|
||||||
or data.get("data", {}).get("items", [])
|
or (
|
||||||
|
inner_data.get("items", [])
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
or response_data.get("items", [])
|
||||||
or data.get("events", [])
|
or data.get("events", [])
|
||||||
|
or (
|
||||||
|
inner_data.get("events", [])
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
or response_data.get("events", [])
|
||||||
)
|
)
|
||||||
elif isinstance(data, list):
|
elif isinstance(data, list):
|
||||||
events = data
|
events = data
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from app.db import (
|
||||||
SearchSourceConnector,
|
SearchSourceConnector,
|
||||||
SearchSourceConnectorType,
|
SearchSourceConnectorType,
|
||||||
)
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -78,14 +78,49 @@ class GmailToolMetadataService:
|
||||||
def __init__(self, db_session: AsyncSession):
|
def __init__(self, db_session: AsyncSession):
|
||||||
self._db_session = db_session
|
self._db_session = db_session
|
||||||
|
|
||||||
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
|
def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
|
||||||
if (
|
return (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||||
):
|
)
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
|
||||||
if cca_id:
|
def _get_composio_connected_account_id(
|
||||||
return build_composio_credentials(cca_id)
|
self, connector: SearchSourceConnector
|
||||||
|
) -> str:
|
||||||
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
|
if not cca_id:
|
||||||
|
raise ValueError("Composio connected_account_id not found")
|
||||||
|
return cca_id
|
||||||
|
|
||||||
|
def _unwrap_composio_data(self, data: Any) -> Any:
|
||||||
|
if isinstance(data, dict):
|
||||||
|
inner = data.get("data", data)
|
||||||
|
if isinstance(inner, dict):
|
||||||
|
return inner.get("response_data", inner)
|
||||||
|
return inner
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def _execute_composio_gmail_tool(
|
||||||
|
self,
|
||||||
|
connector: SearchSourceConnector,
|
||||||
|
tool_name: str,
|
||||||
|
params: dict[str, Any],
|
||||||
|
) -> tuple[Any, str | None]:
|
||||||
|
result = await ComposioService().execute_tool(
|
||||||
|
connected_account_id=self._get_composio_connected_account_id(connector),
|
||||||
|
tool_name=tool_name,
|
||||||
|
params=params,
|
||||||
|
entity_id=f"surfsense_{connector.user_id}",
|
||||||
|
)
|
||||||
|
if not result.get("success"):
|
||||||
|
return None, result.get("error", "Unknown Composio Gmail error")
|
||||||
|
return self._unwrap_composio_data(result.get("data")), None
|
||||||
|
|
||||||
|
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
|
||||||
|
if self._is_composio_connector(connector):
|
||||||
|
raise ValueError(
|
||||||
|
"Composio Gmail connectors must use Composio tool execution"
|
||||||
|
)
|
||||||
|
|
||||||
config_data = dict(connector.config)
|
config_data = dict(connector.config)
|
||||||
|
|
||||||
|
|
@ -139,6 +174,12 @@ class GmailToolMetadataService:
|
||||||
if not connector:
|
if not connector:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
if self._is_composio_connector(connector):
|
||||||
|
_profile, error = await self._execute_composio_gmail_tool(
|
||||||
|
connector, "GMAIL_GET_PROFILE", {"user_id": "me"}
|
||||||
|
)
|
||||||
|
return bool(error)
|
||||||
|
|
||||||
creds = await self._build_credentials(connector)
|
creds = await self._build_credentials(connector)
|
||||||
service = build("gmail", "v1", credentials=creds)
|
service = build("gmail", "v1", credentials=creds)
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
await asyncio.get_event_loop().run_in_executor(
|
||||||
|
|
@ -221,14 +262,21 @@ class GmailToolMetadataService:
|
||||||
)
|
)
|
||||||
connector = result.scalar_one_or_none()
|
connector = result.scalar_one_or_none()
|
||||||
if connector:
|
if connector:
|
||||||
creds = await self._build_credentials(connector)
|
if self._is_composio_connector(connector):
|
||||||
service = build("gmail", "v1", credentials=creds)
|
profile, error = await self._execute_composio_gmail_tool(
|
||||||
profile = await asyncio.get_event_loop().run_in_executor(
|
connector, "GMAIL_GET_PROFILE", {"user_id": "me"}
|
||||||
None,
|
)
|
||||||
lambda service=service: (
|
if error:
|
||||||
service.users().getProfile(userId="me").execute()
|
raise RuntimeError(error)
|
||||||
),
|
else:
|
||||||
)
|
creds = await self._build_credentials(connector)
|
||||||
|
service = build("gmail", "v1", credentials=creds)
|
||||||
|
profile = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda service=service: (
|
||||||
|
service.users().getProfile(userId="me").execute()
|
||||||
|
),
|
||||||
|
)
|
||||||
acc_dict["email"] = profile.get("emailAddress", "")
|
acc_dict["email"] = profile.get("emailAddress", "")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
@ -298,6 +346,23 @@ class GmailToolMetadataService:
|
||||||
Returns ``None`` on any failure so callers can degrade gracefully.
|
Returns ``None`` on any failure so callers can degrade gracefully.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
if self._is_composio_connector(connector):
|
||||||
|
if not draft_id:
|
||||||
|
draft_id = await self._find_composio_draft_id(connector, message_id)
|
||||||
|
if not draft_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
draft, error = await self._execute_composio_gmail_tool(
|
||||||
|
connector,
|
||||||
|
"GMAIL_GET_DRAFT",
|
||||||
|
{"user_id": "me", "draft_id": draft_id, "format": "full"},
|
||||||
|
)
|
||||||
|
if error or not isinstance(draft, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
payload = draft.get("message", {}).get("payload", {})
|
||||||
|
return self._extract_body_from_payload(payload)
|
||||||
|
|
||||||
creds = await self._build_credentials(connector)
|
creds = await self._build_credentials(connector)
|
||||||
service = build("gmail", "v1", credentials=creds)
|
service = build("gmail", "v1", credentials=creds)
|
||||||
|
|
||||||
|
|
@ -326,6 +391,33 @@ class GmailToolMetadataService:
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def _find_composio_draft_id(
|
||||||
|
self, connector: SearchSourceConnector, message_id: str
|
||||||
|
) -> str | None:
|
||||||
|
page_token = ""
|
||||||
|
while True:
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"user_id": "me",
|
||||||
|
"max_results": 100,
|
||||||
|
"verbose": False,
|
||||||
|
}
|
||||||
|
if page_token:
|
||||||
|
params["page_token"] = page_token
|
||||||
|
|
||||||
|
data, error = await self._execute_composio_gmail_tool(
|
||||||
|
connector, "GMAIL_LIST_DRAFTS", params
|
||||||
|
)
|
||||||
|
if error or not isinstance(data, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
for draft in data.get("drafts", []):
|
||||||
|
if draft.get("message", {}).get("id") == message_id:
|
||||||
|
return draft.get("id")
|
||||||
|
|
||||||
|
page_token = data.get("nextPageToken") or data.get("next_page_token") or ""
|
||||||
|
if not page_token:
|
||||||
|
return None
|
||||||
|
|
||||||
async def _find_draft_id(self, service: Any, message_id: str) -> str | None:
|
async def _find_draft_id(self, service: Any, message_id: str) -> str | None:
|
||||||
"""Resolve a draft ID from its message ID by scanning drafts.list."""
|
"""Resolve a draft ID from its message ID by scanning drafts.list."""
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ from app.db import (
|
||||||
SearchSourceConnector,
|
SearchSourceConnector,
|
||||||
SearchSourceConnectorType,
|
SearchSourceConnectorType,
|
||||||
)
|
)
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -21,7 +22,6 @@ from app.utils.document_converters import (
|
||||||
generate_document_summary,
|
generate_document_summary,
|
||||||
generate_unique_identifier_hash,
|
generate_unique_identifier_hash,
|
||||||
)
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -203,23 +203,46 @@ class GoogleCalendarKBSyncService:
|
||||||
logger.warning("Document %s not found in KB", document_id)
|
logger.warning("Document %s not found in KB", document_id)
|
||||||
return {"status": "not_indexed"}
|
return {"status": "not_indexed"}
|
||||||
|
|
||||||
creds = await self._build_credentials_for_connector(connector_id)
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
service = await loop.run_in_executor(
|
|
||||||
None, lambda: build("calendar", "v3", credentials=creds)
|
|
||||||
)
|
|
||||||
|
|
||||||
calendar_id = (document.document_metadata or {}).get(
|
calendar_id = (document.document_metadata or {}).get(
|
||||||
"calendar_id"
|
"calendar_id"
|
||||||
) or "primary"
|
) or "primary"
|
||||||
live_event = await loop.run_in_executor(
|
connector = await self._get_connector(connector_id)
|
||||||
None,
|
if (
|
||||||
lambda: (
|
connector.connector_type
|
||||||
service.events()
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||||
.get(calendarId=calendar_id, eventId=event_id)
|
):
|
||||||
.execute()
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
),
|
if not cca_id:
|
||||||
)
|
raise ValueError("Composio connected_account_id not found")
|
||||||
|
composio_result = await ComposioService().execute_tool(
|
||||||
|
connected_account_id=cca_id,
|
||||||
|
tool_name="GOOGLECALENDAR_EVENTS_GET",
|
||||||
|
params={"calendar_id": calendar_id, "event_id": event_id},
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
)
|
||||||
|
if not composio_result.get("success"):
|
||||||
|
raise RuntimeError(
|
||||||
|
composio_result.get("error", "Unknown Composio Calendar error")
|
||||||
|
)
|
||||||
|
live_event = composio_result.get("data", {})
|
||||||
|
if isinstance(live_event, dict):
|
||||||
|
live_event = live_event.get("data", live_event)
|
||||||
|
if isinstance(live_event, dict):
|
||||||
|
live_event = live_event.get("response_data", live_event)
|
||||||
|
else:
|
||||||
|
creds = await self._build_credentials_for_connector(connector_id)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
service = await loop.run_in_executor(
|
||||||
|
None, lambda: build("calendar", "v3", credentials=creds)
|
||||||
|
)
|
||||||
|
live_event = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: (
|
||||||
|
service.events()
|
||||||
|
.get(calendarId=calendar_id, eventId=event_id)
|
||||||
|
.execute()
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
event_summary = live_event.get("summary", "")
|
event_summary = live_event.get("summary", "")
|
||||||
description = live_event.get("description", "")
|
description = live_event.get("description", "")
|
||||||
|
|
@ -322,7 +345,7 @@ class GoogleCalendarKBSyncService:
|
||||||
await self.db_session.rollback()
|
await self.db_session.rollback()
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
async def _build_credentials_for_connector(self, connector_id: int) -> Credentials:
|
async def _get_connector(self, connector_id: int) -> SearchSourceConnector:
|
||||||
result = await self.db_session.execute(
|
result = await self.db_session.execute(
|
||||||
select(SearchSourceConnector).where(
|
select(SearchSourceConnector).where(
|
||||||
SearchSourceConnector.id == connector_id
|
SearchSourceConnector.id == connector_id
|
||||||
|
|
@ -331,15 +354,17 @@ class GoogleCalendarKBSyncService:
|
||||||
connector = result.scalar_one_or_none()
|
connector = result.scalar_one_or_none()
|
||||||
if not connector:
|
if not connector:
|
||||||
raise ValueError(f"Connector {connector_id} not found")
|
raise ValueError(f"Connector {connector_id} not found")
|
||||||
|
return connector
|
||||||
|
|
||||||
|
async def _build_credentials_for_connector(self, connector_id: int) -> Credentials:
|
||||||
|
connector = await self._get_connector(connector_id)
|
||||||
if (
|
if (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||||
):
|
):
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
raise ValueError(
|
||||||
if cca_id:
|
"Composio Calendar connectors must use Composio tool execution"
|
||||||
return build_composio_credentials(cca_id)
|
)
|
||||||
raise ValueError("Composio connected_account_id not found")
|
|
||||||
|
|
||||||
config_data = dict(connector.config)
|
config_data = dict(connector.config)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ from app.db import (
|
||||||
SearchSourceConnector,
|
SearchSourceConnector,
|
||||||
SearchSourceConnectorType,
|
SearchSourceConnectorType,
|
||||||
)
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -94,15 +94,49 @@ class GoogleCalendarToolMetadataService:
|
||||||
def __init__(self, db_session: AsyncSession):
|
def __init__(self, db_session: AsyncSession):
|
||||||
self._db_session = db_session
|
self._db_session = db_session
|
||||||
|
|
||||||
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
|
def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
|
||||||
if (
|
return (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||||
):
|
)
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
|
||||||
if cca_id:
|
def _get_composio_connected_account_id(
|
||||||
return build_composio_credentials(cca_id)
|
self, connector: SearchSourceConnector
|
||||||
|
) -> str:
|
||||||
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
|
if not cca_id:
|
||||||
raise ValueError("Composio connected_account_id not found")
|
raise ValueError("Composio connected_account_id not found")
|
||||||
|
return cca_id
|
||||||
|
|
||||||
|
async def _execute_composio_calendar_tool(
|
||||||
|
self,
|
||||||
|
connector: SearchSourceConnector,
|
||||||
|
tool_name: str,
|
||||||
|
params: dict,
|
||||||
|
) -> tuple[dict | list | None, str | None]:
|
||||||
|
service = ComposioService()
|
||||||
|
result = await service.execute_tool(
|
||||||
|
connected_account_id=self._get_composio_connected_account_id(connector),
|
||||||
|
tool_name=tool_name,
|
||||||
|
params=params,
|
||||||
|
entity_id=f"surfsense_{connector.user_id}",
|
||||||
|
)
|
||||||
|
if not result.get("success"):
|
||||||
|
return None, result.get("error", "Unknown Composio Calendar error")
|
||||||
|
|
||||||
|
data = result.get("data")
|
||||||
|
if isinstance(data, dict):
|
||||||
|
inner = data.get("data", data)
|
||||||
|
if isinstance(inner, dict):
|
||||||
|
return inner.get("response_data", inner), None
|
||||||
|
return inner, None
|
||||||
|
return data, None
|
||||||
|
|
||||||
|
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
|
||||||
|
if self._is_composio_connector(connector):
|
||||||
|
raise ValueError(
|
||||||
|
"Composio Calendar connectors must use Composio tool execution"
|
||||||
|
)
|
||||||
|
|
||||||
config_data = dict(connector.config)
|
config_data = dict(connector.config)
|
||||||
|
|
||||||
|
|
@ -156,6 +190,14 @@ class GoogleCalendarToolMetadataService:
|
||||||
if not connector:
|
if not connector:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
if self._is_composio_connector(connector):
|
||||||
|
_data, error = await self._execute_composio_calendar_tool(
|
||||||
|
connector,
|
||||||
|
"GOOGLECALENDAR_GET_CALENDAR",
|
||||||
|
{"calendar_id": "primary"},
|
||||||
|
)
|
||||||
|
return bool(error)
|
||||||
|
|
||||||
creds = await self._build_credentials(connector)
|
creds = await self._build_credentials(connector)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
await loop.run_in_executor(
|
await loop.run_in_executor(
|
||||||
|
|
@ -255,16 +297,48 @@ class GoogleCalendarToolMetadataService:
|
||||||
timezone_str = ""
|
timezone_str = ""
|
||||||
if connector:
|
if connector:
|
||||||
try:
|
try:
|
||||||
creds = await self._build_credentials(connector)
|
if self._is_composio_connector(connector):
|
||||||
loop = asyncio.get_event_loop()
|
cal_list, cal_error = await self._execute_composio_calendar_tool(
|
||||||
service = await loop.run_in_executor(
|
connector, "GOOGLECALENDAR_LIST_CALENDARS", {}
|
||||||
None, lambda: build("calendar", "v3", credentials=creds)
|
)
|
||||||
)
|
if cal_error:
|
||||||
|
raise RuntimeError(cal_error)
|
||||||
|
(
|
||||||
|
settings,
|
||||||
|
settings_error,
|
||||||
|
) = await self._execute_composio_calendar_tool(
|
||||||
|
connector,
|
||||||
|
"GOOGLECALENDAR_SETTINGS_GET",
|
||||||
|
{"setting": "timezone"},
|
||||||
|
)
|
||||||
|
if not settings_error and isinstance(settings, dict):
|
||||||
|
timezone_str = settings.get("value", "")
|
||||||
|
else:
|
||||||
|
creds = await self._build_credentials(connector)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
service = await loop.run_in_executor(
|
||||||
|
None, lambda: build("calendar", "v3", credentials=creds)
|
||||||
|
)
|
||||||
|
|
||||||
cal_list = await loop.run_in_executor(
|
cal_list = await loop.run_in_executor(
|
||||||
None, lambda: service.calendarList().list().execute()
|
None, lambda: service.calendarList().list().execute()
|
||||||
)
|
)
|
||||||
for cal in cal_list.get("items", []):
|
|
||||||
|
tz_setting = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: service.settings().get(setting="timezone").execute(),
|
||||||
|
)
|
||||||
|
timezone_str = tz_setting.get("value", "")
|
||||||
|
|
||||||
|
calendar_items = []
|
||||||
|
if isinstance(cal_list, dict):
|
||||||
|
calendar_items = (
|
||||||
|
cal_list.get("items") or cal_list.get("calendars") or []
|
||||||
|
)
|
||||||
|
elif isinstance(cal_list, list):
|
||||||
|
calendar_items = cal_list
|
||||||
|
|
||||||
|
for cal in calendar_items:
|
||||||
calendars.append(
|
calendars.append(
|
||||||
{
|
{
|
||||||
"id": cal.get("id", ""),
|
"id": cal.get("id", ""),
|
||||||
|
|
@ -272,12 +346,6 @@ class GoogleCalendarToolMetadataService:
|
||||||
"primary": cal.get("primary", False),
|
"primary": cal.get("primary", False),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
tz_setting = await loop.run_in_executor(
|
|
||||||
None,
|
|
||||||
lambda: service.settings().get(setting="timezone").execute(),
|
|
||||||
)
|
|
||||||
timezone_str = tz_setting.get("value", "")
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to fetch calendars/timezone for connector %s",
|
"Failed to fetch calendars/timezone for connector %s",
|
||||||
|
|
@ -321,20 +389,29 @@ class GoogleCalendarToolMetadataService:
|
||||||
|
|
||||||
event_dict = event.to_dict()
|
event_dict = event.to_dict()
|
||||||
try:
|
try:
|
||||||
creds = await self._build_credentials(connector)
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
service = await loop.run_in_executor(
|
|
||||||
None, lambda: build("calendar", "v3", credentials=creds)
|
|
||||||
)
|
|
||||||
calendar_id = event.calendar_id or "primary"
|
calendar_id = event.calendar_id or "primary"
|
||||||
live_event = await loop.run_in_executor(
|
if self._is_composio_connector(connector):
|
||||||
None,
|
live_event, error = await self._execute_composio_calendar_tool(
|
||||||
lambda: (
|
connector,
|
||||||
service.events()
|
"GOOGLECALENDAR_EVENTS_GET",
|
||||||
.get(calendarId=calendar_id, eventId=event.event_id)
|
{"calendar_id": calendar_id, "event_id": event.event_id},
|
||||||
.execute()
|
)
|
||||||
),
|
if error:
|
||||||
)
|
raise RuntimeError(error)
|
||||||
|
else:
|
||||||
|
creds = await self._build_credentials(connector)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
service = await loop.run_in_executor(
|
||||||
|
None, lambda: build("calendar", "v3", credentials=creds)
|
||||||
|
)
|
||||||
|
live_event = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: (
|
||||||
|
service.events()
|
||||||
|
.get(calendarId=calendar_id, eventId=event.event_id)
|
||||||
|
.execute()
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
event_dict["summary"] = live_event.get("summary", event_dict["summary"])
|
event_dict["summary"] = live_event.get("summary", event_dict["summary"])
|
||||||
event_dict["description"] = live_event.get(
|
event_dict["description"] = live_event.get(
|
||||||
|
|
@ -376,12 +453,30 @@ class GoogleCalendarToolMetadataService:
|
||||||
) -> dict:
|
) -> dict:
|
||||||
resolved = await self._resolve_event(search_space_id, user_id, event_ref)
|
resolved = await self._resolve_event(search_space_id, user_id, event_ref)
|
||||||
if not resolved:
|
if not resolved:
|
||||||
|
live_resolved = await self._resolve_live_event(
|
||||||
|
search_space_id, user_id, event_ref
|
||||||
|
)
|
||||||
|
if not live_resolved:
|
||||||
|
return {
|
||||||
|
"error": (
|
||||||
|
f"Event '{event_ref}' not found in your indexed or live Google Calendar events. "
|
||||||
|
"This could mean: (1) the event doesn't exist, "
|
||||||
|
"(2) the event name is different, or "
|
||||||
|
"(3) the connected calendar account cannot access it."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
connector, live_event = live_resolved
|
||||||
|
account = GoogleCalendarAccount.from_connector(connector)
|
||||||
|
acc_dict = account.to_dict()
|
||||||
|
auth_expired = await self._check_account_health(connector.id)
|
||||||
|
acc_dict["auth_expired"] = auth_expired
|
||||||
|
if auth_expired:
|
||||||
|
await self._persist_auth_expired(connector.id)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"error": (
|
"account": acc_dict,
|
||||||
f"Event '{event_ref}' not found in your indexed Google Calendar events. "
|
"event": self._event_dict_from_live_event(live_event),
|
||||||
"This could mean: (1) the event doesn't exist, (2) it hasn't been indexed yet, "
|
|
||||||
"or (3) the event name is different."
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
document, connector = resolved
|
document, connector = resolved
|
||||||
|
|
@ -429,3 +524,110 @@ class GoogleCalendarToolMetadataService:
|
||||||
if row:
|
if row:
|
||||||
return row[0], row[1]
|
return row[0], row[1]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def _resolve_live_event(
|
||||||
|
self, search_space_id: int, user_id: str, event_ref: str
|
||||||
|
) -> tuple[SearchSourceConnector, dict] | None:
|
||||||
|
result = await self._db_session.execute(
|
||||||
|
select(SearchSourceConnector)
|
||||||
|
.filter(
|
||||||
|
and_(
|
||||||
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
SearchSourceConnector.user_id == user_id,
|
||||||
|
SearchSourceConnector.connector_type.in_(CALENDAR_CONNECTOR_TYPES),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.order_by(SearchSourceConnector.last_indexed_at.desc())
|
||||||
|
)
|
||||||
|
connectors = result.scalars().all()
|
||||||
|
|
||||||
|
for connector in connectors:
|
||||||
|
try:
|
||||||
|
events = await self._search_live_events(connector, event_ref)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to search live calendar events for connector %s",
|
||||||
|
connector.id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not events:
|
||||||
|
continue
|
||||||
|
|
||||||
|
normalized_ref = event_ref.strip().lower()
|
||||||
|
exact_match = next(
|
||||||
|
(
|
||||||
|
event
|
||||||
|
for event in events
|
||||||
|
if event.get("summary", "").strip().lower() == normalized_ref
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
return connector, exact_match or events[0]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _search_live_events(
|
||||||
|
self, connector: SearchSourceConnector, event_ref: str
|
||||||
|
) -> list[dict]:
|
||||||
|
if self._is_composio_connector(connector):
|
||||||
|
data, error = await self._execute_composio_calendar_tool(
|
||||||
|
connector,
|
||||||
|
"GOOGLECALENDAR_EVENTS_LIST",
|
||||||
|
{
|
||||||
|
"calendar_id": "primary",
|
||||||
|
"q": event_ref,
|
||||||
|
"max_results": 10,
|
||||||
|
"single_events": True,
|
||||||
|
"order_by": "startTime",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(error)
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return data.get("items") or data.get("events") or []
|
||||||
|
return data if isinstance(data, list) else []
|
||||||
|
|
||||||
|
creds = await self._build_credentials(connector)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
service = await loop.run_in_executor(
|
||||||
|
None, lambda: build("calendar", "v3", credentials=creds)
|
||||||
|
)
|
||||||
|
response = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: (
|
||||||
|
service.events()
|
||||||
|
.list(
|
||||||
|
calendarId="primary",
|
||||||
|
q=event_ref,
|
||||||
|
maxResults=10,
|
||||||
|
singleEvents=True,
|
||||||
|
orderBy="startTime",
|
||||||
|
)
|
||||||
|
.execute()
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return response.get("items", [])
|
||||||
|
|
||||||
|
def _event_dict_from_live_event(self, event: dict) -> dict:
|
||||||
|
start_data = event.get("start", {})
|
||||||
|
end_data = event.get("end", {})
|
||||||
|
return {
|
||||||
|
"event_id": event.get("id", ""),
|
||||||
|
"summary": event.get("summary", "No Title"),
|
||||||
|
"start": start_data.get("dateTime", start_data.get("date", "")),
|
||||||
|
"end": end_data.get("dateTime", end_data.get("date", "")),
|
||||||
|
"description": event.get("description", ""),
|
||||||
|
"location": event.get("location", ""),
|
||||||
|
"attendees": [
|
||||||
|
{
|
||||||
|
"email": attendee.get("email", ""),
|
||||||
|
"responseStatus": attendee.get("responseStatus", ""),
|
||||||
|
}
|
||||||
|
for attendee in event.get("attendees", [])
|
||||||
|
],
|
||||||
|
"calendar_id": event.get("calendarId", "primary"),
|
||||||
|
"document_id": None,
|
||||||
|
"indexed_at": None,
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from app.db import (
|
||||||
SearchSourceConnector,
|
SearchSourceConnector,
|
||||||
SearchSourceConnectorType,
|
SearchSourceConnectorType,
|
||||||
)
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -67,6 +67,42 @@ class GoogleDriveToolMetadataService:
|
||||||
def __init__(self, db_session: AsyncSession):
|
def __init__(self, db_session: AsyncSession):
|
||||||
self._db_session = db_session
|
self._db_session = db_session
|
||||||
|
|
||||||
|
def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
|
||||||
|
return (
|
||||||
|
connector.connector_type
|
||||||
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_composio_connected_account_id(
|
||||||
|
self, connector: SearchSourceConnector
|
||||||
|
) -> str:
|
||||||
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
|
if not cca_id:
|
||||||
|
raise ValueError("Composio connected_account_id not found")
|
||||||
|
return cca_id
|
||||||
|
|
||||||
|
async def _execute_composio_drive_tool(
|
||||||
|
self,
|
||||||
|
connector: SearchSourceConnector,
|
||||||
|
tool_name: str,
|
||||||
|
params: dict,
|
||||||
|
) -> tuple[dict | list | None, str | None]:
|
||||||
|
result = await ComposioService().execute_tool(
|
||||||
|
connected_account_id=self._get_composio_connected_account_id(connector),
|
||||||
|
tool_name=tool_name,
|
||||||
|
params=params,
|
||||||
|
entity_id=f"surfsense_{connector.user_id}",
|
||||||
|
)
|
||||||
|
if not result.get("success"):
|
||||||
|
return None, result.get("error", "Unknown Composio Drive error")
|
||||||
|
data = result.get("data")
|
||||||
|
if isinstance(data, dict):
|
||||||
|
inner = data.get("data", data)
|
||||||
|
if isinstance(inner, dict):
|
||||||
|
return inner.get("response_data", inner), None
|
||||||
|
return inner, None
|
||||||
|
return data, None
|
||||||
|
|
||||||
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
|
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
|
||||||
accounts = await self._get_google_drive_accounts(search_space_id, user_id)
|
accounts = await self._get_google_drive_accounts(search_space_id, user_id)
|
||||||
|
|
||||||
|
|
@ -200,19 +236,21 @@ class GoogleDriveToolMetadataService:
|
||||||
if not connector:
|
if not connector:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
pre_built_creds = None
|
if self._is_composio_connector(connector):
|
||||||
if (
|
_data, error = await self._execute_composio_drive_tool(
|
||||||
connector.connector_type
|
connector,
|
||||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
"GOOGLEDRIVE_LIST_FILES",
|
||||||
):
|
{
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
"q": "trashed = false",
|
||||||
if cca_id:
|
"page_size": 1,
|
||||||
pre_built_creds = build_composio_credentials(cca_id)
|
"fields": "files(id)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return bool(error)
|
||||||
|
|
||||||
client = GoogleDriveClient(
|
client = GoogleDriveClient(
|
||||||
session=self._db_session,
|
session=self._db_session,
|
||||||
connector_id=connector_id,
|
connector_id=connector_id,
|
||||||
credentials=pre_built_creds,
|
|
||||||
)
|
)
|
||||||
await client.list_files(
|
await client.list_files(
|
||||||
query="trashed = false", page_size=1, fields="files(id)"
|
query="trashed = false", page_size=1, fields="files(id)"
|
||||||
|
|
@ -274,19 +312,39 @@ class GoogleDriveToolMetadataService:
|
||||||
parent_folders[connector_id] = []
|
parent_folders[connector_id] = []
|
||||||
continue
|
continue
|
||||||
|
|
||||||
pre_built_creds = None
|
if self._is_composio_connector(connector):
|
||||||
if (
|
data, error = await self._execute_composio_drive_tool(
|
||||||
connector.connector_type
|
connector,
|
||||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
"GOOGLEDRIVE_LIST_FILES",
|
||||||
):
|
{
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
"q": "mimeType = 'application/vnd.google-apps.folder' and trashed = false and 'root' in parents",
|
||||||
if cca_id:
|
"fields": "files(id,name)",
|
||||||
pre_built_creds = build_composio_credentials(cca_id)
|
"page_size": 50,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to list folders for connector %s: %s",
|
||||||
|
connector_id,
|
||||||
|
error,
|
||||||
|
)
|
||||||
|
parent_folders[connector_id] = []
|
||||||
|
continue
|
||||||
|
folders = []
|
||||||
|
if isinstance(data, dict):
|
||||||
|
folders = data.get("files", [])
|
||||||
|
elif isinstance(data, list):
|
||||||
|
folders = data
|
||||||
|
parent_folders[connector_id] = [
|
||||||
|
{"folder_id": f["id"], "name": f["name"]}
|
||||||
|
for f in folders
|
||||||
|
if f.get("id") and f.get("name")
|
||||||
|
]
|
||||||
|
continue
|
||||||
|
|
||||||
client = GoogleDriveClient(
|
client = GoogleDriveClient(
|
||||||
session=self._db_session,
|
session=self._db_session,
|
||||||
connector_id=connector_id,
|
connector_id=connector_id,
|
||||||
credentials=pre_built_creds,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
folders, _, error = await client.list_files(
|
folders, _, error = await client.list_files(
|
||||||
|
|
|
||||||
|
|
@ -96,6 +96,46 @@ def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
|
||||||
return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
|
return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
|
||||||
|
|
||||||
|
|
||||||
|
def _first_interrupt_value(state: Any) -> dict[str, Any] | None:
|
||||||
|
"""Return the first LangGraph interrupt payload across all snapshot tasks."""
|
||||||
|
def _extract_interrupt_value(candidate: Any) -> dict[str, Any] | None:
|
||||||
|
if isinstance(candidate, dict):
|
||||||
|
value = candidate.get("value", candidate)
|
||||||
|
return value if isinstance(value, dict) else None
|
||||||
|
value = getattr(candidate, "value", None)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return value
|
||||||
|
if isinstance(candidate, (list, tuple)):
|
||||||
|
for item in candidate:
|
||||||
|
extracted = _extract_interrupt_value(item)
|
||||||
|
if extracted is not None:
|
||||||
|
return extracted
|
||||||
|
return None
|
||||||
|
|
||||||
|
for task in getattr(state, "tasks", ()) or ():
|
||||||
|
try:
|
||||||
|
interrupts = getattr(task, "interrupts", ()) or ()
|
||||||
|
except (AttributeError, IndexError, TypeError):
|
||||||
|
interrupts = ()
|
||||||
|
if not interrupts:
|
||||||
|
extracted = _extract_interrupt_value(task)
|
||||||
|
if extracted is not None:
|
||||||
|
return extracted
|
||||||
|
continue
|
||||||
|
for interrupt_item in interrupts:
|
||||||
|
extracted = _extract_interrupt_value(interrupt_item)
|
||||||
|
if extracted is not None:
|
||||||
|
return extracted
|
||||||
|
try:
|
||||||
|
state_interrupts = getattr(state, "interrupts", ()) or ()
|
||||||
|
except (AttributeError, IndexError, TypeError):
|
||||||
|
state_interrupts = ()
|
||||||
|
extracted = _extract_interrupt_value(state_interrupts)
|
||||||
|
if extracted is not None:
|
||||||
|
return extracted
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
|
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
|
||||||
"""Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts.
|
"""Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts.
|
||||||
|
|
||||||
|
|
@ -2178,10 +2218,10 @@ async def _stream_agent_events(
|
||||||
result.agent_called_update_memory = called_update_memory
|
result.agent_called_update_memory = called_update_memory
|
||||||
_log_file_contract("turn_outcome", result)
|
_log_file_contract("turn_outcome", result)
|
||||||
|
|
||||||
is_interrupted = state.tasks and any(task.interrupts for task in state.tasks)
|
interrupt_value = _first_interrupt_value(state)
|
||||||
if is_interrupted:
|
if interrupt_value is not None:
|
||||||
result.is_interrupted = True
|
result.is_interrupted = True
|
||||||
result.interrupt_value = state.tasks[0].interrupts[0].value
|
result.interrupt_value = interrupt_value
|
||||||
yield streaming_service.format_interrupt_request(result.interrupt_value)
|
yield streaming_service.format_interrupt_request(result.interrupt_value)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import (
|
||||||
IndexingPipelineService,
|
IndexingPipelineService,
|
||||||
PlaceholderInfo,
|
PlaceholderInfo,
|
||||||
)
|
)
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
from app.utils.google_credentials import (
|
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||||
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
|
|
||||||
build_composio_credentials,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .base import (
|
from .base import (
|
||||||
check_duplicate_document_by_hash,
|
check_duplicate_document_by_hash,
|
||||||
|
|
@ -44,6 +42,10 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||||
|
|
||||||
|
|
||||||
|
def _format_calendar_event_to_markdown(event: dict) -> str:
|
||||||
|
return GoogleCalendarConnector.format_event_to_markdown(None, event)
|
||||||
|
|
||||||
|
|
||||||
def _build_connector_doc(
|
def _build_connector_doc(
|
||||||
event: dict,
|
event: dict,
|
||||||
event_markdown: str,
|
event_markdown: str,
|
||||||
|
|
@ -150,7 +152,14 @@ async def index_google_calendar_events(
|
||||||
)
|
)
|
||||||
return 0, 0, f"Connector with ID {connector_id} not found"
|
return 0, 0, f"Connector with ID {connector_id} not found"
|
||||||
|
|
||||||
# ── Credential building ───────────────────────────────────────
|
is_composio_connector = (
|
||||||
|
connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||||
|
)
|
||||||
|
calendar_client = None
|
||||||
|
composio_service = None
|
||||||
|
connected_account_id = None
|
||||||
|
|
||||||
|
# ── Credential/client building ────────────────────────────────
|
||||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||||
if not connected_account_id:
|
if not connected_account_id:
|
||||||
|
|
@ -161,7 +170,7 @@ async def index_google_calendar_events(
|
||||||
{"error_type": "MissingComposioAccount"},
|
{"error_type": "MissingComposioAccount"},
|
||||||
)
|
)
|
||||||
return 0, 0, "Composio connected_account_id not found"
|
return 0, 0, "Composio connected_account_id not found"
|
||||||
credentials = build_composio_credentials(connected_account_id)
|
composio_service = ComposioService()
|
||||||
else:
|
else:
|
||||||
config_data = connector.config
|
config_data = connector.config
|
||||||
|
|
||||||
|
|
@ -229,12 +238,13 @@ async def index_google_calendar_events(
|
||||||
{"stage": "client_initialization"},
|
{"stage": "client_initialization"},
|
||||||
)
|
)
|
||||||
|
|
||||||
calendar_client = GoogleCalendarConnector(
|
if not is_composio_connector:
|
||||||
credentials=credentials,
|
calendar_client = GoogleCalendarConnector(
|
||||||
session=session,
|
credentials=credentials,
|
||||||
user_id=user_id,
|
session=session,
|
||||||
connector_id=connector_id,
|
user_id=user_id,
|
||||||
)
|
connector_id=connector_id,
|
||||||
|
)
|
||||||
|
|
||||||
# Handle 'undefined' string from frontend (treat as None)
|
# Handle 'undefined' string from frontend (treat as None)
|
||||||
if start_date == "undefined" or start_date == "":
|
if start_date == "undefined" or start_date == "":
|
||||||
|
|
@ -300,9 +310,26 @@ async def index_google_calendar_events(
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
events, error = await calendar_client.get_all_primary_calendar_events(
|
if is_composio_connector:
|
||||||
start_date=start_date_str, end_date=end_date_str
|
start_dt = parse_date_flexible(start_date_str).replace(
|
||||||
)
|
hour=0, minute=0, second=0, microsecond=0
|
||||||
|
)
|
||||||
|
end_dt = parse_date_flexible(end_date_str).replace(
|
||||||
|
hour=23, minute=59, second=59, microsecond=0
|
||||||
|
)
|
||||||
|
events, error = await composio_service.get_calendar_events(
|
||||||
|
connected_account_id=connected_account_id,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
time_min=start_dt.isoformat(),
|
||||||
|
time_max=end_dt.isoformat(),
|
||||||
|
max_results=250,
|
||||||
|
)
|
||||||
|
if not events and not error:
|
||||||
|
error = "No events found in the specified date range."
|
||||||
|
else:
|
||||||
|
events, error = await calendar_client.get_all_primary_calendar_events(
|
||||||
|
start_date=start_date_str, end_date=end_date_str
|
||||||
|
)
|
||||||
|
|
||||||
if error:
|
if error:
|
||||||
if "No events found" in error:
|
if "No events found" in error:
|
||||||
|
|
@ -381,7 +408,7 @@ async def index_google_calendar_events(
|
||||||
documents_skipped += 1
|
documents_skipped += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
event_markdown = calendar_client.format_event_to_markdown(event)
|
event_markdown = _format_calendar_event_to_markdown(event)
|
||||||
if not event_markdown.strip():
|
if not event_markdown.strip():
|
||||||
logger.warning(f"Skipping event with no content: {event_summary}")
|
logger.warning(f"Skipping event with no content: {event_summary}")
|
||||||
documents_skipped += 1
|
documents_skipped += 1
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,8 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import String, cast, select
|
from sqlalchemy import String, cast, select
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
@ -37,6 +39,7 @@ from app.indexing_pipeline.indexing_pipeline_service import (
|
||||||
IndexingPipelineService,
|
IndexingPipelineService,
|
||||||
PlaceholderInfo,
|
PlaceholderInfo,
|
||||||
)
|
)
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
from app.services.page_limit_service import PageLimitService
|
from app.services.page_limit_service import PageLimitService
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
|
|
@ -45,10 +48,7 @@ from app.tasks.connector_indexers.base import (
|
||||||
get_connector_by_id,
|
get_connector_by_id,
|
||||||
update_connector_last_indexed,
|
update_connector_last_indexed,
|
||||||
)
|
)
|
||||||
from app.utils.google_credentials import (
|
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||||
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
|
|
||||||
build_composio_credentials,
|
|
||||||
)
|
|
||||||
|
|
||||||
ACCEPTED_DRIVE_CONNECTOR_TYPES = {
|
ACCEPTED_DRIVE_CONNECTOR_TYPES = {
|
||||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||||
|
|
@ -61,6 +61,209 @@ HEARTBEAT_INTERVAL_SECONDS = 30
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ComposioDriveClient:
|
||||||
|
"""Google Drive client facade backed by Composio tool execution.
|
||||||
|
|
||||||
|
Composio-managed OAuth connections can execute tools without exposing raw
|
||||||
|
OAuth tokens through connected account state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
connector_id: int,
|
||||||
|
connected_account_id: str,
|
||||||
|
entity_id: str,
|
||||||
|
):
|
||||||
|
self.session = session
|
||||||
|
self.connector_id = connector_id
|
||||||
|
self.connected_account_id = connected_account_id
|
||||||
|
self.entity_id = entity_id
|
||||||
|
self.composio = ComposioService()
|
||||||
|
|
||||||
|
async def list_files(
|
||||||
|
self,
|
||||||
|
query: str = "",
|
||||||
|
fields: str = "nextPageToken, files(id, name, mimeType, modifiedTime, md5Checksum, size, webViewLink, parents, owners, createdTime, description)",
|
||||||
|
page_size: int = 100,
|
||||||
|
page_token: str | None = None,
|
||||||
|
) -> tuple[list[dict[str, Any]], str | None, str | None]:
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"page_size": min(page_size, 100),
|
||||||
|
"fields": fields,
|
||||||
|
}
|
||||||
|
if query:
|
||||||
|
params["q"] = query
|
||||||
|
if page_token:
|
||||||
|
params["page_token"] = page_token
|
||||||
|
|
||||||
|
result = await self.composio.execute_tool(
|
||||||
|
connected_account_id=self.connected_account_id,
|
||||||
|
tool_name="GOOGLEDRIVE_LIST_FILES",
|
||||||
|
params=params,
|
||||||
|
entity_id=self.entity_id,
|
||||||
|
)
|
||||||
|
if not result.get("success"):
|
||||||
|
return [], None, result.get("error", "Unknown error")
|
||||||
|
|
||||||
|
data = result.get("data", {})
|
||||||
|
files = []
|
||||||
|
next_token = None
|
||||||
|
if isinstance(data, dict):
|
||||||
|
inner_data = data.get("data", data)
|
||||||
|
if isinstance(inner_data, dict):
|
||||||
|
files = inner_data.get("files", [])
|
||||||
|
next_token = inner_data.get("nextPageToken") or inner_data.get(
|
||||||
|
"next_page_token"
|
||||||
|
)
|
||||||
|
elif isinstance(data, list):
|
||||||
|
files = data
|
||||||
|
|
||||||
|
return files, next_token, None
|
||||||
|
|
||||||
|
async def get_file_metadata(
|
||||||
|
self, file_id: str, fields: str = "*"
|
||||||
|
) -> tuple[dict[str, Any] | None, str | None]:
|
||||||
|
result = await self.composio.execute_tool(
|
||||||
|
connected_account_id=self.connected_account_id,
|
||||||
|
tool_name="GOOGLEDRIVE_GET_FILE_METADATA",
|
||||||
|
params={"file_id": file_id, "fields": fields},
|
||||||
|
entity_id=self.entity_id,
|
||||||
|
)
|
||||||
|
if not result.get("success"):
|
||||||
|
return None, result.get("error", "Unknown error")
|
||||||
|
|
||||||
|
data = result.get("data", {})
|
||||||
|
if isinstance(data, dict):
|
||||||
|
inner_data = data.get("data", data)
|
||||||
|
if isinstance(inner_data, dict):
|
||||||
|
return inner_data, None
|
||||||
|
|
||||||
|
return None, "Could not extract metadata from Composio response"
|
||||||
|
|
||||||
|
async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]:
|
||||||
|
return await self._download_file_content(file_id)
|
||||||
|
|
||||||
|
async def download_file_to_disk(
|
||||||
|
self,
|
||||||
|
file_id: str,
|
||||||
|
dest_path: str,
|
||||||
|
chunksize: int = 5 * 1024 * 1024,
|
||||||
|
) -> str | None:
|
||||||
|
del chunksize
|
||||||
|
content, error = await self.download_file(file_id)
|
||||||
|
if error:
|
||||||
|
return error
|
||||||
|
if content is None:
|
||||||
|
return "No content returned from Composio"
|
||||||
|
Path(dest_path).write_bytes(content)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def export_google_file(
|
||||||
|
self, file_id: str, mime_type: str
|
||||||
|
) -> tuple[bytes | None, str | None]:
|
||||||
|
return await self._download_file_content(file_id, mime_type=mime_type)
|
||||||
|
|
||||||
|
async def _download_file_content(
|
||||||
|
self, file_id: str, mime_type: str | None = None
|
||||||
|
) -> tuple[bytes | None, str | None]:
|
||||||
|
params: dict[str, Any] = {"file_id": file_id}
|
||||||
|
if mime_type:
|
||||||
|
params["mime_type"] = mime_type
|
||||||
|
|
||||||
|
result = await self.composio.execute_tool(
|
||||||
|
connected_account_id=self.connected_account_id,
|
||||||
|
tool_name="GOOGLEDRIVE_DOWNLOAD_FILE",
|
||||||
|
params=params,
|
||||||
|
entity_id=self.entity_id,
|
||||||
|
)
|
||||||
|
if not result.get("success"):
|
||||||
|
return None, result.get("error", "Unknown error")
|
||||||
|
|
||||||
|
return self._read_download_result(result.get("data"))
|
||||||
|
|
||||||
|
def _read_download_result(self, data: Any) -> tuple[bytes | None, str | None]:
|
||||||
|
if isinstance(data, bytes):
|
||||||
|
return data, None
|
||||||
|
|
||||||
|
file_path: str | None = None
|
||||||
|
if isinstance(data, str):
|
||||||
|
file_path = data
|
||||||
|
elif isinstance(data, dict):
|
||||||
|
inner_data = data.get("data", data)
|
||||||
|
if isinstance(inner_data, dict):
|
||||||
|
for key in ("file_path", "downloaded_file_content", "path", "uri"):
|
||||||
|
value = inner_data.get(key)
|
||||||
|
if isinstance(value, str):
|
||||||
|
file_path = value
|
||||||
|
break
|
||||||
|
if isinstance(value, dict):
|
||||||
|
nested = (
|
||||||
|
value.get("file_path")
|
||||||
|
or value.get("downloaded_file_content")
|
||||||
|
or value.get("path")
|
||||||
|
or value.get("uri")
|
||||||
|
or value.get("s3url")
|
||||||
|
)
|
||||||
|
if isinstance(nested, str):
|
||||||
|
file_path = nested
|
||||||
|
break
|
||||||
|
|
||||||
|
if not file_path:
|
||||||
|
return None, "No file path/content returned from Composio"
|
||||||
|
|
||||||
|
if file_path.startswith(("http://", "https://")):
|
||||||
|
try:
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
with urllib.request.urlopen(file_path, timeout=60) as response:
|
||||||
|
return response.read(), None
|
||||||
|
except Exception as e:
|
||||||
|
return None, f"Failed to download Composio file URL: {e!s}"
|
||||||
|
|
||||||
|
path_obj = Path(file_path)
|
||||||
|
if path_obj.is_absolute() or ".composio" in str(path_obj):
|
||||||
|
if not path_obj.exists():
|
||||||
|
return None, f"File not found at path: {file_path}"
|
||||||
|
return path_obj.read_bytes(), None
|
||||||
|
|
||||||
|
try:
|
||||||
|
import base64
|
||||||
|
|
||||||
|
return base64.b64decode(file_path), None
|
||||||
|
except Exception:
|
||||||
|
return file_path.encode("utf-8"), None
|
||||||
|
|
||||||
|
|
||||||
|
def _build_drive_client_for_connector(
|
||||||
|
session: AsyncSession,
|
||||||
|
connector_id: int,
|
||||||
|
connector: object,
|
||||||
|
user_id: str,
|
||||||
|
) -> tuple[GoogleDriveClient | ComposioDriveClient | None, str | None]:
|
||||||
|
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||||
|
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||||
|
if not connected_account_id:
|
||||||
|
return None, (
|
||||||
|
f"Composio connected_account_id not found for connector {connector_id}"
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
ComposioDriveClient(
|
||||||
|
session,
|
||||||
|
connector_id,
|
||||||
|
connected_account_id,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||||
|
if token_encrypted and not config.SECRET_KEY:
|
||||||
|
return None, "SECRET_KEY not configured but credentials are marked as encrypted"
|
||||||
|
|
||||||
|
return GoogleDriveClient(session, connector_id), None
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helpers
|
# Helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -927,34 +1130,17 @@ async def index_google_drive_files(
|
||||||
{"stage": "client_initialization"},
|
{"stage": "client_initialization"},
|
||||||
)
|
)
|
||||||
|
|
||||||
pre_built_credentials = None
|
drive_client, client_error = _build_drive_client_for_connector(
|
||||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
session, connector_id, connector, user_id
|
||||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
)
|
||||||
if not connected_account_id:
|
if client_error or not drive_client:
|
||||||
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
|
await task_logger.log_task_failure(
|
||||||
await task_logger.log_task_failure(
|
log_entry,
|
||||||
log_entry,
|
client_error or "Failed to initialize Google Drive client",
|
||||||
error_msg,
|
"Missing connector credentials",
|
||||||
"Missing Composio account",
|
{"error_type": "ClientInitializationError"},
|
||||||
{"error_type": "MissingComposioAccount"},
|
)
|
||||||
)
|
return 0, 0, client_error, 0
|
||||||
return 0, 0, error_msg, 0
|
|
||||||
pre_built_credentials = build_composio_credentials(connected_account_id)
|
|
||||||
else:
|
|
||||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
|
||||||
if token_encrypted and not config.SECRET_KEY:
|
|
||||||
await task_logger.log_task_failure(
|
|
||||||
log_entry,
|
|
||||||
"SECRET_KEY not configured but credentials are encrypted",
|
|
||||||
"Missing SECRET_KEY",
|
|
||||||
{"error_type": "MissingSecretKey"},
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
"SECRET_KEY not configured but credentials are marked as encrypted",
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
|
|
||||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||||
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
|
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
|
||||||
|
|
@ -963,10 +1149,6 @@ async def index_google_drive_files(
|
||||||
from app.services.llm_service import get_vision_llm
|
from app.services.llm_service import get_vision_llm
|
||||||
|
|
||||||
vision_llm = await get_vision_llm(session, search_space_id)
|
vision_llm = await get_vision_llm(session, search_space_id)
|
||||||
drive_client = GoogleDriveClient(
|
|
||||||
session, connector_id, credentials=pre_built_credentials
|
|
||||||
)
|
|
||||||
|
|
||||||
if not folder_id:
|
if not folder_id:
|
||||||
error_msg = "folder_id is required for Google Drive indexing"
|
error_msg = "folder_id is required for Google Drive indexing"
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
|
|
@ -979,8 +1161,14 @@ async def index_google_drive_files(
|
||||||
|
|
||||||
folder_tokens = connector.config.get("folder_tokens", {})
|
folder_tokens = connector.config.get("folder_tokens", {})
|
||||||
start_page_token = folder_tokens.get(target_folder_id)
|
start_page_token = folder_tokens.get(target_folder_id)
|
||||||
|
is_composio_connector = (
|
||||||
|
connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||||
|
)
|
||||||
can_use_delta = (
|
can_use_delta = (
|
||||||
use_delta_sync and start_page_token and connector.last_indexed_at
|
not is_composio_connector
|
||||||
|
and use_delta_sync
|
||||||
|
and start_page_token
|
||||||
|
and connector.last_indexed_at
|
||||||
)
|
)
|
||||||
|
|
||||||
documents_unsupported = 0
|
documents_unsupported = 0
|
||||||
|
|
@ -1051,7 +1239,16 @@ async def index_google_drive_files(
|
||||||
)
|
)
|
||||||
|
|
||||||
if documents_indexed > 0 or can_use_delta:
|
if documents_indexed > 0 or can_use_delta:
|
||||||
new_token, token_error = await get_start_page_token(drive_client)
|
if isinstance(drive_client, ComposioDriveClient):
|
||||||
|
(
|
||||||
|
new_token,
|
||||||
|
token_error,
|
||||||
|
) = await drive_client.composio.get_drive_start_page_token(
|
||||||
|
drive_client.connected_account_id,
|
||||||
|
drive_client.entity_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_token, token_error = await get_start_page_token(drive_client)
|
||||||
if new_token and not token_error:
|
if new_token and not token_error:
|
||||||
await session.refresh(connector)
|
await session.refresh(connector)
|
||||||
if "folder_tokens" not in connector.config:
|
if "folder_tokens" not in connector.config:
|
||||||
|
|
@ -1137,32 +1334,17 @@ async def index_google_drive_single_file(
|
||||||
)
|
)
|
||||||
return 0, error_msg
|
return 0, error_msg
|
||||||
|
|
||||||
pre_built_credentials = None
|
drive_client, client_error = _build_drive_client_for_connector(
|
||||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
session, connector_id, connector, user_id
|
||||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
)
|
||||||
if not connected_account_id:
|
if client_error or not drive_client:
|
||||||
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
|
await task_logger.log_task_failure(
|
||||||
await task_logger.log_task_failure(
|
log_entry,
|
||||||
log_entry,
|
client_error or "Failed to initialize Google Drive client",
|
||||||
error_msg,
|
"Missing connector credentials",
|
||||||
"Missing Composio account",
|
{"error_type": "ClientInitializationError"},
|
||||||
{"error_type": "MissingComposioAccount"},
|
)
|
||||||
)
|
return 0, client_error
|
||||||
return 0, error_msg
|
|
||||||
pre_built_credentials = build_composio_credentials(connected_account_id)
|
|
||||||
else:
|
|
||||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
|
||||||
if token_encrypted and not config.SECRET_KEY:
|
|
||||||
await task_logger.log_task_failure(
|
|
||||||
log_entry,
|
|
||||||
"SECRET_KEY not configured but credentials are encrypted",
|
|
||||||
"Missing SECRET_KEY",
|
|
||||||
{"error_type": "MissingSecretKey"},
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
0,
|
|
||||||
"SECRET_KEY not configured but credentials are marked as encrypted",
|
|
||||||
)
|
|
||||||
|
|
||||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||||
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
|
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
|
||||||
|
|
@ -1171,10 +1353,6 @@ async def index_google_drive_single_file(
|
||||||
from app.services.llm_service import get_vision_llm
|
from app.services.llm_service import get_vision_llm
|
||||||
|
|
||||||
vision_llm = await get_vision_llm(session, search_space_id)
|
vision_llm = await get_vision_llm(session, search_space_id)
|
||||||
drive_client = GoogleDriveClient(
|
|
||||||
session, connector_id, credentials=pre_built_credentials
|
|
||||||
)
|
|
||||||
|
|
||||||
file, error = await get_file_by_id(drive_client, file_id)
|
file, error = await get_file_by_id(drive_client, file_id)
|
||||||
if error or not file:
|
if error or not file:
|
||||||
error_msg = f"Failed to fetch file {file_id}: {error or 'File not found'}"
|
error_msg = f"Failed to fetch file {file_id}: {error or 'File not found'}"
|
||||||
|
|
@ -1276,32 +1454,18 @@ async def index_google_drive_selected_files(
|
||||||
)
|
)
|
||||||
return 0, 0, [error_msg]
|
return 0, 0, [error_msg]
|
||||||
|
|
||||||
pre_built_credentials = None
|
drive_client, client_error = _build_drive_client_for_connector(
|
||||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
session, connector_id, connector, user_id
|
||||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
)
|
||||||
if not connected_account_id:
|
if client_error or not drive_client:
|
||||||
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
|
error_msg = client_error or "Failed to initialize Google Drive client"
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
log_entry,
|
log_entry,
|
||||||
error_msg,
|
error_msg,
|
||||||
"Missing Composio account",
|
"Missing connector credentials",
|
||||||
{"error_type": "MissingComposioAccount"},
|
{"error_type": "ClientInitializationError"},
|
||||||
)
|
)
|
||||||
return 0, 0, [error_msg]
|
return 0, 0, [error_msg]
|
||||||
pre_built_credentials = build_composio_credentials(connected_account_id)
|
|
||||||
else:
|
|
||||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
|
||||||
if token_encrypted and not config.SECRET_KEY:
|
|
||||||
error_msg = (
|
|
||||||
"SECRET_KEY not configured but credentials are marked as encrypted"
|
|
||||||
)
|
|
||||||
await task_logger.log_task_failure(
|
|
||||||
log_entry,
|
|
||||||
error_msg,
|
|
||||||
"Missing SECRET_KEY",
|
|
||||||
{"error_type": "MissingSecretKey"},
|
|
||||||
)
|
|
||||||
return 0, 0, [error_msg]
|
|
||||||
|
|
||||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||||
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
|
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
|
||||||
|
|
@ -1310,10 +1474,6 @@ async def index_google_drive_selected_files(
|
||||||
from app.services.llm_service import get_vision_llm
|
from app.services.llm_service import get_vision_llm
|
||||||
|
|
||||||
vision_llm = await get_vision_llm(session, search_space_id)
|
vision_llm = await get_vision_llm(session, search_space_id)
|
||||||
drive_client = GoogleDriveClient(
|
|
||||||
session, connector_id, credentials=pre_built_credentials
|
|
||||||
)
|
|
||||||
|
|
||||||
indexed, skipped, unsupported, errors = await _index_selected_files(
|
indexed, skipped, unsupported, errors = await _index_selected_files(
|
||||||
drive_client,
|
drive_client,
|
||||||
session,
|
session,
|
||||||
|
|
|
||||||
|
|
@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import (
|
||||||
IndexingPipelineService,
|
IndexingPipelineService,
|
||||||
PlaceholderInfo,
|
PlaceholderInfo,
|
||||||
)
|
)
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
from app.utils.google_credentials import (
|
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||||
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
|
|
||||||
build_composio_credentials,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .base import (
|
from .base import (
|
||||||
calculate_date_range,
|
calculate_date_range,
|
||||||
|
|
@ -44,6 +42,62 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_composio_gmail_message(message: dict) -> dict:
|
||||||
|
if message.get("payload"):
|
||||||
|
return message
|
||||||
|
|
||||||
|
headers = []
|
||||||
|
header_values = {
|
||||||
|
"Subject": message.get("subject"),
|
||||||
|
"From": message.get("from") or message.get("sender"),
|
||||||
|
"To": message.get("to") or message.get("recipient"),
|
||||||
|
"Date": message.get("date"),
|
||||||
|
}
|
||||||
|
for name, value in header_values.items():
|
||||||
|
if value:
|
||||||
|
headers.append({"name": name, "value": value})
|
||||||
|
|
||||||
|
return {
|
||||||
|
**message,
|
||||||
|
"id": message.get("id")
|
||||||
|
or message.get("message_id")
|
||||||
|
or message.get("messageId"),
|
||||||
|
"threadId": message.get("threadId") or message.get("thread_id"),
|
||||||
|
"payload": {"headers": headers},
|
||||||
|
"snippet": message.get("snippet", ""),
|
||||||
|
"messageText": message.get("messageText") or message.get("body") or "",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _format_gmail_message_to_markdown(message: dict) -> str:
|
||||||
|
headers = {
|
||||||
|
header.get("name", "").lower(): header.get("value", "")
|
||||||
|
for header in message.get("payload", {}).get("headers", [])
|
||||||
|
if isinstance(header, dict)
|
||||||
|
}
|
||||||
|
subject = headers.get("subject", "No Subject")
|
||||||
|
from_email = headers.get("from", "Unknown Sender")
|
||||||
|
to_email = headers.get("to", "Unknown Recipient")
|
||||||
|
date_str = headers.get("date", "Unknown Date")
|
||||||
|
message_text = (
|
||||||
|
message.get("messageText")
|
||||||
|
or message.get("body")
|
||||||
|
or message.get("text")
|
||||||
|
or message.get("snippet", "")
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"# {subject}\n\n"
|
||||||
|
f"**From:** {from_email}\n"
|
||||||
|
f"**To:** {to_email}\n"
|
||||||
|
f"**Date:** {date_str}\n\n"
|
||||||
|
f"## Message Content\n\n{message_text}\n\n"
|
||||||
|
f"## Message Details\n\n"
|
||||||
|
f"- **Message ID:** {message.get('id', 'Unknown')}\n"
|
||||||
|
f"- **Thread ID:** {message.get('threadId', 'Unknown')}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _build_connector_doc(
|
def _build_connector_doc(
|
||||||
message: dict,
|
message: dict,
|
||||||
markdown_content: str,
|
markdown_content: str,
|
||||||
|
|
@ -162,7 +216,14 @@ async def index_google_gmail_messages(
|
||||||
)
|
)
|
||||||
return 0, 0, error_msg
|
return 0, 0, error_msg
|
||||||
|
|
||||||
# ── Credential building ───────────────────────────────────────
|
is_composio_connector = (
|
||||||
|
connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||||
|
)
|
||||||
|
gmail_connector = None
|
||||||
|
composio_service = None
|
||||||
|
connected_account_id = None
|
||||||
|
|
||||||
|
# ── Credential/client building ────────────────────────────────
|
||||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||||
if not connected_account_id:
|
if not connected_account_id:
|
||||||
|
|
@ -173,7 +234,7 @@ async def index_google_gmail_messages(
|
||||||
{"error_type": "MissingComposioAccount"},
|
{"error_type": "MissingComposioAccount"},
|
||||||
)
|
)
|
||||||
return 0, 0, "Composio connected_account_id not found"
|
return 0, 0, "Composio connected_account_id not found"
|
||||||
credentials = build_composio_credentials(connected_account_id)
|
composio_service = ComposioService()
|
||||||
else:
|
else:
|
||||||
config_data = connector.config
|
config_data = connector.config
|
||||||
|
|
||||||
|
|
@ -241,9 +302,10 @@ async def index_google_gmail_messages(
|
||||||
{"stage": "client_initialization"},
|
{"stage": "client_initialization"},
|
||||||
)
|
)
|
||||||
|
|
||||||
gmail_connector = GoogleGmailConnector(
|
if not is_composio_connector:
|
||||||
credentials, session, user_id, connector_id
|
gmail_connector = GoogleGmailConnector(
|
||||||
)
|
credentials, session, user_id, connector_id
|
||||||
|
)
|
||||||
|
|
||||||
calculated_start_date, calculated_end_date = calculate_date_range(
|
calculated_start_date, calculated_end_date = calculate_date_range(
|
||||||
connector, start_date, end_date, default_days_back=365
|
connector, start_date, end_date, default_days_back=365
|
||||||
|
|
@ -254,11 +316,60 @@ async def index_google_gmail_messages(
|
||||||
f"Fetching emails for connector {connector_id} "
|
f"Fetching emails for connector {connector_id} "
|
||||||
f"from {calculated_start_date} to {calculated_end_date}"
|
f"from {calculated_start_date} to {calculated_end_date}"
|
||||||
)
|
)
|
||||||
messages, error = await gmail_connector.get_recent_messages(
|
if is_composio_connector:
|
||||||
max_results=max_messages,
|
query_parts = []
|
||||||
start_date=calculated_start_date,
|
if calculated_start_date:
|
||||||
end_date=calculated_end_date,
|
query_parts.append(f"after:{calculated_start_date.replace('-', '/')}")
|
||||||
)
|
if calculated_end_date:
|
||||||
|
query_parts.append(f"before:{calculated_end_date.replace('-', '/')}")
|
||||||
|
query = " ".join(query_parts)
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
page_token = None
|
||||||
|
error = None
|
||||||
|
while len(messages) < max_messages:
|
||||||
|
page_size = min(50, max_messages - len(messages))
|
||||||
|
(
|
||||||
|
page_messages,
|
||||||
|
page_token,
|
||||||
|
_estimate,
|
||||||
|
page_error,
|
||||||
|
) = await composio_service.get_gmail_messages(
|
||||||
|
connected_account_id=connected_account_id,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
query=query,
|
||||||
|
max_results=page_size,
|
||||||
|
page_token=page_token,
|
||||||
|
)
|
||||||
|
if page_error:
|
||||||
|
error = page_error
|
||||||
|
break
|
||||||
|
for page_message in page_messages:
|
||||||
|
message_id = (
|
||||||
|
page_message.get("id")
|
||||||
|
or page_message.get("message_id")
|
||||||
|
or page_message.get("messageId")
|
||||||
|
)
|
||||||
|
if message_id:
|
||||||
|
(
|
||||||
|
detail,
|
||||||
|
detail_error,
|
||||||
|
) = await composio_service.get_gmail_message_detail(
|
||||||
|
connected_account_id=connected_account_id,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
|
if not detail_error and isinstance(detail, dict):
|
||||||
|
page_message = detail
|
||||||
|
messages.append(_normalize_composio_gmail_message(page_message))
|
||||||
|
if not page_token:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
messages, error = await gmail_connector.get_recent_messages(
|
||||||
|
max_results=max_messages,
|
||||||
|
start_date=calculated_start_date,
|
||||||
|
end_date=calculated_end_date,
|
||||||
|
)
|
||||||
|
|
||||||
if error:
|
if error:
|
||||||
error_message = error
|
error_message = error
|
||||||
|
|
@ -326,7 +437,12 @@ async def index_google_gmail_messages(
|
||||||
documents_skipped += 1
|
documents_skipped += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
markdown_content = gmail_connector.format_message_to_markdown(message)
|
if is_composio_connector:
|
||||||
|
markdown_content = _format_gmail_message_to_markdown(message)
|
||||||
|
else:
|
||||||
|
markdown_content = gmail_connector.format_message_to_markdown(
|
||||||
|
message
|
||||||
|
)
|
||||||
if not markdown_content.strip():
|
if not markdown_content.strip():
|
||||||
logger.warning(f"Skipping message with no content: {message_id}")
|
logger.warning(f"Skipping message with no content: {message_id}")
|
||||||
documents_skipped += 1
|
documents_skipped += 1
|
||||||
|
|
|
||||||
|
|
@ -51,22 +51,34 @@ class _FakeToolMessage:
|
||||||
tool_call_id: str | None = None
|
tool_call_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _FakeInterrupt:
|
||||||
|
value: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _FakeTask:
|
||||||
|
interrupts: tuple[_FakeInterrupt, ...] = ()
|
||||||
|
|
||||||
|
|
||||||
class _FakeAgentState:
|
class _FakeAgentState:
|
||||||
"""Stand-in for ``StateSnapshot`` returned by ``aget_state``."""
|
"""Stand-in for ``StateSnapshot`` returned by ``aget_state``."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, tasks: list[Any] | None = None) -> None:
|
||||||
# Empty values keeps the cloud-fallback safety-net branch a no-op,
|
# Empty values keeps the cloud-fallback safety-net branch a no-op,
|
||||||
# and an empty ``tasks`` list keeps the post-stream interrupt
|
# and empty ``tasks`` keep the post-stream interrupt check a no-op too.
|
||||||
# check a no-op too.
|
|
||||||
self.values: dict[str, Any] = {}
|
self.values: dict[str, Any] = {}
|
||||||
self.tasks: list[Any] = []
|
self.tasks: list[Any] = tasks or []
|
||||||
|
|
||||||
|
|
||||||
class _FakeAgent:
|
class _FakeAgent:
|
||||||
"""Replays a list of ``astream_events`` events."""
|
"""Replays a list of ``astream_events`` events."""
|
||||||
|
|
||||||
def __init__(self, events: list[dict[str, Any]]) -> None:
|
def __init__(
|
||||||
|
self, events: list[dict[str, Any]], state: _FakeAgentState | None = None
|
||||||
|
) -> None:
|
||||||
self._events = events
|
self._events = events
|
||||||
|
self._state = state or _FakeAgentState()
|
||||||
|
|
||||||
async def astream_events( # type: ignore[no-untyped-def]
|
async def astream_events( # type: ignore[no-untyped-def]
|
||||||
self, _input_data: Any, *, config: dict[str, Any], version: str
|
self, _input_data: Any, *, config: dict[str, Any], version: str
|
||||||
|
|
@ -79,7 +91,7 @@ class _FakeAgent:
|
||||||
# Called once after astream_events drains so the cloud-fallback
|
# Called once after astream_events drains so the cloud-fallback
|
||||||
# safety net can inspect staged filesystem work. The fake stays
|
# safety net can inspect staged filesystem work. The fake stays
|
||||||
# empty so the safety net is a no-op.
|
# empty so the safety net is a no-op.
|
||||||
return _FakeAgentState()
|
return self._state
|
||||||
|
|
||||||
|
|
||||||
def _model_stream(
|
def _model_stream(
|
||||||
|
|
@ -170,11 +182,13 @@ def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
async def _drain(
|
||||||
|
events: list[dict[str, Any]], state: _FakeAgentState | None = None
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
"""Run ``_stream_agent_events`` against a fake agent and return the
|
"""Run ``_stream_agent_events`` against a fake agent and return the
|
||||||
SSE payloads (parsed JSON) it yielded.
|
SSE payloads (parsed JSON) it yielded.
|
||||||
"""
|
"""
|
||||||
agent = _FakeAgent(events)
|
agent = _FakeAgent(events, state=state)
|
||||||
service = VercelStreamingService()
|
service = VercelStreamingService()
|
||||||
result = StreamResult()
|
result = StreamResult()
|
||||||
config = {"configurable": {"thread_id": "test-thread"}}
|
config = {"configurable": {"thread_id": "test-thread"}}
|
||||||
|
|
@ -525,3 +539,29 @@ async def test_unmatched_fallback_still_attaches_lc_id(
|
||||||
assert len(starts) == 1
|
assert len(starts) == 1
|
||||||
assert starts[0]["toolCallId"].startswith("call_run-1")
|
assert starts[0]["toolCallId"].startswith("call_run-1")
|
||||||
assert starts[0]["langchainToolCallId"] == "lc-orphan"
|
assert starts[0]["langchainToolCallId"] == "lc-orphan"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_interrupt_request_uses_task_that_contains_interrupt(
|
||||||
|
parity_v2_on: None,
|
||||||
|
) -> None:
|
||||||
|
interrupt_payload = {
|
||||||
|
"type": "calendar_event_create",
|
||||||
|
"action": {
|
||||||
|
"tool": "create_calendar_event",
|
||||||
|
"params": {"summary": "mom bday"},
|
||||||
|
},
|
||||||
|
"context": {},
|
||||||
|
}
|
||||||
|
state = _FakeAgentState(
|
||||||
|
tasks=[
|
||||||
|
_FakeTask(interrupts=()),
|
||||||
|
_FakeTask(interrupts=(_FakeInterrupt(value=interrupt_payload),)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
payloads = await _drain([], state=state)
|
||||||
|
|
||||||
|
interrupts = _of_type(payloads, "data-interrupt-request")
|
||||||
|
assert len(interrupts) == 1
|
||||||
|
assert interrupts[0]["data"]["action_requests"][0]["name"] == "create_calendar_event"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue