mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
merge upstream/dev into feat/migrate-electric-to-zero
Resolve 8 conflicts: - Accept upstream deletion of 3 composio_*_connector.py (unified Google connectors) - Accept our deletion of ElectricProvider.tsx, use-connectors-electric.ts, use-messages-electric.ts (replaced by Zero equivalents) - Keep both new deps in package.json (@rocicorp/zero + @slate-serializers/html) - Regenerate pnpm-lock.yaml
This commit is contained in:
commit
5d8a62a4a6
207 changed files with 28023 additions and 12247 deletions
|
|
@ -100,7 +100,8 @@ TEAMS_CLIENT_ID=your_teams_client_id_here
|
|||
TEAMS_CLIENT_SECRET=your_teams_client_secret_here
|
||||
TEAMS_REDIRECT_URI=http://localhost:8000/api/v1/auth/teams/connector/callback
|
||||
|
||||
#Composio Coonnector
|
||||
# Composio Connector
|
||||
# NOTE: Disable "Mask Connected Account Secrets" in Composio dashboard (Settings → Project Settings) for Google indexing to work.
|
||||
COMPOSIO_API_KEY=your_api_key_here
|
||||
COMPOSIO_ENABLED=TRUE
|
||||
COMPOSIO_REDIRECT_URI=http://localhost:8000/api/v1/auth/composio/connector/callback
|
||||
|
|
|
|||
|
|
@ -21,6 +21,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.agents.new_chat.middleware.dedup_tool_calls import (
|
||||
DedupHITLToolCallsMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.system_prompt import (
|
||||
build_configurable_system_prompt,
|
||||
build_surfsense_system_prompt,
|
||||
|
|
@ -65,10 +68,11 @@ _CONNECTOR_TYPE_TO_SEARCHABLE: dict[str, str] = {
|
|||
"BOOKSTACK_CONNECTOR": "BOOKSTACK_CONNECTOR",
|
||||
"CIRCLEBACK_CONNECTOR": "CIRCLEBACK", # Connector type differs from document type
|
||||
"OBSIDIAN_CONNECTOR": "OBSIDIAN_CONNECTOR",
|
||||
# Composio connectors
|
||||
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
|
||||
"COMPOSIO_GMAIL_CONNECTOR": "COMPOSIO_GMAIL_CONNECTOR",
|
||||
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
|
||||
# Composio connectors (unified to native document types).
|
||||
# Reverse of NATIVE_TO_LEGACY_DOCTYPE in app.db.
|
||||
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "GOOGLE_DRIVE_FILE",
|
||||
"COMPOSIO_GMAIL_CONNECTOR": "GOOGLE_GMAIL_CONNECTOR",
|
||||
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "GOOGLE_CALENDAR_CONNECTOR",
|
||||
}
|
||||
|
||||
# Document types that don't come from SearchSourceConnector but should always be searchable
|
||||
|
|
@ -292,6 +296,69 @@ async def create_surfsense_deep_agent(
|
|||
]
|
||||
modified_disabled_tools.extend(linear_tools)
|
||||
|
||||
# Disable Google Drive action tools if no Google Drive connector is configured
|
||||
has_google_drive_connector = (
|
||||
available_connectors is not None and "GOOGLE_DRIVE_FILE" in available_connectors
|
||||
)
|
||||
if not has_google_drive_connector:
|
||||
google_drive_tools = [
|
||||
"create_google_drive_file",
|
||||
"delete_google_drive_file",
|
||||
]
|
||||
modified_disabled_tools.extend(google_drive_tools)
|
||||
|
||||
# Disable Google Calendar action tools if no Google Calendar connector is configured
|
||||
has_google_calendar_connector = (
|
||||
available_connectors is not None
|
||||
and "GOOGLE_CALENDAR_CONNECTOR" in available_connectors
|
||||
)
|
||||
if not has_google_calendar_connector:
|
||||
calendar_tools = [
|
||||
"create_calendar_event",
|
||||
"update_calendar_event",
|
||||
"delete_calendar_event",
|
||||
]
|
||||
modified_disabled_tools.extend(calendar_tools)
|
||||
|
||||
# Disable Gmail action tools if no Gmail connector is configured
|
||||
has_gmail_connector = (
|
||||
available_connectors is not None
|
||||
and "GOOGLE_GMAIL_CONNECTOR" in available_connectors
|
||||
)
|
||||
if not has_gmail_connector:
|
||||
gmail_tools = [
|
||||
"create_gmail_draft",
|
||||
"update_gmail_draft",
|
||||
"send_gmail_email",
|
||||
"trash_gmail_email",
|
||||
]
|
||||
modified_disabled_tools.extend(gmail_tools)
|
||||
|
||||
# Disable Jira action tools if no Jira connector is configured
|
||||
has_jira_connector = (
|
||||
available_connectors is not None and "JIRA_CONNECTOR" in available_connectors
|
||||
)
|
||||
if not has_jira_connector:
|
||||
jira_tools = [
|
||||
"create_jira_issue",
|
||||
"update_jira_issue",
|
||||
"delete_jira_issue",
|
||||
]
|
||||
modified_disabled_tools.extend(jira_tools)
|
||||
|
||||
# Disable Confluence action tools if no Confluence connector is configured
|
||||
has_confluence_connector = (
|
||||
available_connectors is not None
|
||||
and "CONFLUENCE_CONNECTOR" in available_connectors
|
||||
)
|
||||
if not has_confluence_connector:
|
||||
confluence_tools = [
|
||||
"create_confluence_page",
|
||||
"update_confluence_page",
|
||||
"delete_confluence_page",
|
||||
]
|
||||
modified_disabled_tools.extend(confluence_tools)
|
||||
|
||||
# Build tools using the async registry (includes MCP tools)
|
||||
_t0 = time.perf_counter()
|
||||
tools = await build_tools_async(
|
||||
|
|
@ -345,6 +412,7 @@ async def create_surfsense_deep_agent(
|
|||
system_prompt=system_prompt,
|
||||
context_schema=SurfSenseContextSchema,
|
||||
checkpointer=checkpointer,
|
||||
middleware=[DedupHITLToolCallsMiddleware()],
|
||||
**deep_agent_kwargs,
|
||||
)
|
||||
_perf_log.info(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,93 @@
|
|||
"""Middleware that deduplicates HITL tool calls within a single LLM response.
|
||||
|
||||
When the LLM emits multiple calls to the same HITL tool with the same
|
||||
primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``),
|
||||
only the first call is kept. Non-HITL tools are never touched.
|
||||
|
||||
This runs in the ``after_model`` hook — **before** any tool executes — so
|
||||
the duplicate call is stripped from the AIMessage that gets checkpointed.
|
||||
That means it is also safe across LangGraph ``interrupt()`` boundaries:
|
||||
the removed call will never appear on graph resume.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_HITL_TOOL_DEDUP_KEYS: dict[str, str] = {
|
||||
"delete_calendar_event": "event_title_or_id",
|
||||
"update_calendar_event": "event_title_or_id",
|
||||
"trash_gmail_email": "email_subject_or_id",
|
||||
"update_gmail_draft": "draft_subject_or_id",
|
||||
"delete_google_drive_file": "file_name",
|
||||
"delete_notion_page": "page_title",
|
||||
"update_notion_page": "page_title",
|
||||
"delete_linear_issue": "issue_ref",
|
||||
"update_linear_issue": "issue_ref",
|
||||
"update_jira_issue": "issue_title_or_key",
|
||||
"delete_jira_issue": "issue_title_or_key",
|
||||
"update_confluence_page": "page_title_or_id",
|
||||
"delete_confluence_page": "page_title_or_id",
|
||||
}
|
||||
|
||||
|
||||
class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Remove duplicate HITL tool calls from a single LLM response.
|
||||
|
||||
Only the **first** occurrence of each (tool-name, primary-arg-value)
|
||||
pair is kept; subsequent duplicates are silently dropped.
|
||||
"""
|
||||
|
||||
tools = ()
|
||||
|
||||
def after_model(
|
||||
self, state: AgentState, runtime: Runtime[Any]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._dedup(state)
|
||||
|
||||
async def aafter_model(
|
||||
self, state: AgentState, runtime: Runtime[Any]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._dedup(state)
|
||||
|
||||
@staticmethod
|
||||
def _dedup(state: AgentState) -> dict[str, Any] | None: # type: ignore[type-arg]
|
||||
messages = state.get("messages")
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_msg = messages[-1]
|
||||
if last_msg.type != "ai" or not getattr(last_msg, "tool_calls", None):
|
||||
return None
|
||||
|
||||
tool_calls: list[dict[str, Any]] = last_msg.tool_calls
|
||||
seen: set[tuple[str, str]] = set()
|
||||
deduped: list[dict[str, Any]] = []
|
||||
|
||||
for tc in tool_calls:
|
||||
name = tc.get("name", "")
|
||||
dedup_key_arg = _HITL_TOOL_DEDUP_KEYS.get(name)
|
||||
if dedup_key_arg is not None:
|
||||
arg_val = str(tc.get("args", {}).get(dedup_key_arg, "")).lower()
|
||||
key = (name, arg_val)
|
||||
if key in seen:
|
||||
logger.info(
|
||||
"Dedup: dropped duplicate HITL tool call %s(%s)",
|
||||
name,
|
||||
arg_val,
|
||||
)
|
||||
continue
|
||||
seen.add(key)
|
||||
deduped.append(tc)
|
||||
|
||||
if len(deduped) == len(tool_calls):
|
||||
return None
|
||||
|
||||
updated_msg = last_msg.model_copy(update={"tool_calls": deduped})
|
||||
return {"messages": [updated_msg]}
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
"""Confluence tools for creating, updating, and deleting pages."""
|
||||
|
||||
from .create_page import create_create_confluence_page_tool
|
||||
from .delete_page import create_delete_confluence_page_tool
|
||||
from .update_page import create_update_confluence_page_tool
|
||||
|
||||
__all__ = [
|
||||
"create_create_confluence_page_tool",
|
||||
"create_delete_confluence_page_tool",
|
||||
"create_update_confluence_page_tool",
|
||||
]
|
||||
|
|
@ -0,0 +1,237 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.services.confluence import ConfluenceToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_create_confluence_page_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
@tool
|
||||
async def create_confluence_page(
|
||||
title: str,
|
||||
content: str | None = None,
|
||||
space_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new page in Confluence.
|
||||
|
||||
Use this tool when the user explicitly asks to create a new Confluence page.
|
||||
|
||||
Args:
|
||||
title: Title of the page.
|
||||
content: Optional HTML/storage format content for the page body.
|
||||
space_id: Optional Confluence space ID to create the page in.
|
||||
|
||||
Returns:
|
||||
Dictionary with status, page_id, and message.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", do NOT retry.
|
||||
- If status is "insufficient_permissions", inform user to re-authenticate.
|
||||
"""
|
||||
logger.info(f"create_confluence_page called: title='{title}'")
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Confluence tool not properly configured.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = ConfluenceToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Confluence accounts need re-authentication.",
|
||||
"connector_type": "confluence",
|
||||
}
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "confluence_page_creation",
|
||||
"action": {
|
||||
"tool": "create_confluence_page",
|
||||
"params": {
|
||||
"title": title,
|
||||
"content": content,
|
||||
"space_id": space_id,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The page was not created.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_title = final_params.get("title", title)
|
||||
final_content = final_params.get("content", content) or ""
|
||||
final_space_id = final_params.get("space_id", space_id)
|
||||
final_connector_id = final_params.get("connector_id", connector_id)
|
||||
|
||||
if not final_title or not final_title.strip():
|
||||
return {"status": "error", "message": "Page title cannot be empty."}
|
||||
if not final_space_id:
|
||||
return {"status": "error", "message": "A space must be selected."}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
actual_connector_id = final_connector_id
|
||||
if actual_connector_id is None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Confluence connector found.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == actual_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Confluence connector is invalid.",
|
||||
}
|
||||
|
||||
try:
|
||||
client = ConfluenceHistoryConnector(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
api_result = await client.create_page(
|
||||
space_id=final_space_id,
|
||||
title=final_title,
|
||||
body=final_content,
|
||||
)
|
||||
await client.close()
|
||||
except Exception as api_err:
|
||||
if (
|
||||
"http 403" in str(api_err).lower()
|
||||
or "status code 403" in str(api_err).lower()
|
||||
):
|
||||
try:
|
||||
_conn = connector
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
page_id = str(api_result.get("id", ""))
|
||||
page_links = (
|
||||
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
|
||||
)
|
||||
page_url = ""
|
||||
if page_links.get("base") and page_links.get("webui"):
|
||||
page_url = f"{page_links['base']}{page_links['webui']}"
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.confluence import ConfluenceKBSyncService
|
||||
|
||||
kb_service = ConfluenceKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
page_id=page_id,
|
||||
page_title=final_title,
|
||||
space_id=final_space_id,
|
||||
body_content=final_content,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"page_id": page_id,
|
||||
"page_url": page_url,
|
||||
"message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error(f"Error creating Confluence page: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while creating the page.",
|
||||
}
|
||||
|
||||
return create_confluence_page
|
||||
|
|
@ -0,0 +1,215 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.services.confluence import ConfluenceToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_delete_confluence_page_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
@tool
|
||||
async def delete_confluence_page(
|
||||
page_title_or_id: str,
|
||||
delete_from_kb: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Delete a Confluence page.
|
||||
|
||||
Use this tool when the user asks to delete or remove a Confluence page.
|
||||
|
||||
Args:
|
||||
page_title_or_id: The page title or ID to identify the page.
|
||||
delete_from_kb: Whether to also remove from the knowledge base.
|
||||
|
||||
Returns:
|
||||
Dictionary with status, message, and deleted_from_kb.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", do NOT retry.
|
||||
- If status is "not_found", relay the message to the user.
|
||||
- If status is "insufficient_permissions", inform user to re-authenticate.
|
||||
"""
|
||||
logger.info(
|
||||
f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Confluence tool not properly configured.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = ConfluenceToolMetadataService(db_session)
|
||||
context = await metadata_service.get_deletion_context(
|
||||
search_space_id, user_id, page_title_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "confluence",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
page_data = context["page"]
|
||||
page_id = page_data["page_id"]
|
||||
page_title = page_data.get("page_title", "")
|
||||
document_id = page_data["document_id"]
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "confluence_page_deletion",
|
||||
"action": {
|
||||
"tool": "delete_confluence_page",
|
||||
"params": {
|
||||
"page_id": page_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The page was not deleted.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_page_id = final_params.get("page_id", page_id)
|
||||
final_connector_id = final_params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this page.",
|
||||
}
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Confluence connector is invalid.",
|
||||
}
|
||||
|
||||
try:
|
||||
client = ConfluenceHistoryConnector(
|
||||
session=db_session, connector_id=final_connector_id
|
||||
)
|
||||
await client.delete_page(final_page_id)
|
||||
await client.close()
|
||||
except Exception as api_err:
|
||||
if (
|
||||
"http 403" in str(api_err).lower()
|
||||
or "status code 403" in str(api_err).lower()
|
||||
):
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": final_connector_id,
|
||||
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
|
||||
message = f"Confluence page '{page_title}' deleted successfully."
|
||||
if deleted_from_kb:
|
||||
message += " Also removed from the knowledge base."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"page_id": final_page_id,
|
||||
"deleted_from_kb": deleted_from_kb,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error(f"Error deleting Confluence page: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while deleting the page.",
|
||||
}
|
||||
|
||||
return delete_confluence_page
|
||||
|
|
@ -0,0 +1,244 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.services.confluence import ConfluenceToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_update_confluence_page_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
@tool
|
||||
async def update_confluence_page(
|
||||
page_title_or_id: str,
|
||||
new_title: str | None = None,
|
||||
new_content: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing Confluence page.
|
||||
|
||||
Use this tool when the user asks to modify or edit a Confluence page.
|
||||
|
||||
Args:
|
||||
page_title_or_id: The page title or ID to identify the page.
|
||||
new_title: Optional new title for the page.
|
||||
new_content: Optional new HTML/storage format content.
|
||||
|
||||
Returns:
|
||||
Dictionary with status and message.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", do NOT retry.
|
||||
- If status is "not_found", relay the message to the user.
|
||||
- If status is "insufficient_permissions", inform user to re-authenticate.
|
||||
"""
|
||||
logger.info(
|
||||
f"update_confluence_page called: page_title_or_id='{page_title_or_id}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Confluence tool not properly configured.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = ConfluenceToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, page_title_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "confluence",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
page_data = context["page"]
|
||||
page_id = page_data["page_id"]
|
||||
current_title = page_data["page_title"]
|
||||
current_body = page_data.get("body", "")
|
||||
current_version = page_data.get("version", 1)
|
||||
document_id = page_data.get("document_id")
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "confluence_page_update",
|
||||
"action": {
|
||||
"tool": "update_confluence_page",
|
||||
"params": {
|
||||
"page_id": page_id,
|
||||
"document_id": document_id,
|
||||
"new_title": new_title,
|
||||
"new_content": new_content,
|
||||
"version": current_version,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The page was not updated.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_page_id = final_params.get("page_id", page_id)
|
||||
final_title = final_params.get("new_title", new_title) or current_title
|
||||
final_content = final_params.get("new_content", new_content)
|
||||
if final_content is None:
|
||||
final_content = current_body
|
||||
final_version = final_params.get("version", current_version)
|
||||
final_connector_id = final_params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_document_id = final_params.get("document_id", document_id)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this page.",
|
||||
}
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Confluence connector is invalid.",
|
||||
}
|
||||
|
||||
try:
|
||||
client = ConfluenceHistoryConnector(
|
||||
session=db_session, connector_id=final_connector_id
|
||||
)
|
||||
api_result = await client.update_page(
|
||||
page_id=final_page_id,
|
||||
title=final_title,
|
||||
body=final_content,
|
||||
version_number=final_version + 1,
|
||||
)
|
||||
await client.close()
|
||||
except Exception as api_err:
|
||||
if (
|
||||
"http 403" in str(api_err).lower()
|
||||
or "status code 403" in str(api_err).lower()
|
||||
):
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": final_connector_id,
|
||||
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
page_links = (
|
||||
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
|
||||
)
|
||||
page_url = ""
|
||||
if page_links.get("base") and page_links.get("webui"):
|
||||
page_url = f"{page_links['base']}{page_links['webui']}"
|
||||
|
||||
kb_message_suffix = ""
|
||||
if final_document_id:
|
||||
try:
|
||||
from app.services.confluence import ConfluenceKBSyncService
|
||||
|
||||
kb_service = ConfluenceKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_update(
|
||||
document_id=final_document_id,
|
||||
page_id=final_page_id,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = (
|
||||
" The knowledge base will be updated in the next sync."
|
||||
)
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after update failed: {kb_err}")
|
||||
kb_message_suffix = (
|
||||
" The knowledge base will be updated in the next sync."
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"page_id": final_page_id,
|
||||
"page_url": page_url,
|
||||
"message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error(f"Error updating Confluence page: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while updating the page.",
|
||||
}
|
||||
|
||||
return update_confluence_page
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
from app.agents.new_chat.tools.gmail.create_draft import (
|
||||
create_create_gmail_draft_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.gmail.send_email import (
|
||||
create_send_gmail_email_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.gmail.trash_email import (
|
||||
create_trash_gmail_email_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.gmail.update_draft import (
|
||||
create_update_gmail_draft_tool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_create_gmail_draft_tool",
|
||||
"create_send_gmail_email_tool",
|
||||
"create_trash_gmail_email_tool",
|
||||
"create_update_gmail_draft_tool",
|
||||
]
|
||||
|
|
@ -0,0 +1,341 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.gmail import GmailToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_create_gmail_draft_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def create_gmail_draft(
|
||||
to: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
cc: str | None = None,
|
||||
bcc: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a draft email in Gmail.
|
||||
|
||||
Use when the user asks to draft, compose, or prepare an email without
|
||||
sending it.
|
||||
|
||||
Args:
|
||||
to: Recipient email address.
|
||||
subject: Email subject line.
|
||||
body: Email body content.
|
||||
cc: Optional CC recipient(s), comma-separated.
|
||||
bcc: Optional BCC recipient(s), comma-separated.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", or "error"
|
||||
- draft_id: Gmail draft ID (if success)
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined the action.
|
||||
Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
|
||||
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
|
||||
Inform the user they need to re-authenticate and do NOT retry the action.
|
||||
|
||||
Examples:
|
||||
- "Draft an email to alice@example.com about the meeting"
|
||||
- "Compose a reply to Bob about the project update"
|
||||
"""
|
||||
logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'")
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Gmail tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GmailToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning("All Gmail accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "gmail_draft_creation",
|
||||
"action": {
|
||||
"tool": "create_gmail_draft",
|
||||
"params": {
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": None,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The draft was not created. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_to = final_params.get("to", to)
|
||||
final_subject = final_params.get("subject", subject)
|
||||
final_body = final_params.get("body", body)
|
||||
final_cc = final_params.get("cc", cc)
|
||||
final_bcc = final_params.get("bcc", bcc)
|
||||
final_connector_id = final_params.get("connector_id")
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_gmail_types = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Gmail connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
config_data = dict(connector.config)
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
message = MIMEText(final_body)
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
created = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.drafts()
|
||||
.create(userId="me", body={"message": {"raw": raw}})
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Gmail draft created: id={created.get('id')}")
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.gmail import GmailKBSyncService
|
||||
|
||||
kb_service = GmailKBSyncService(db_session)
|
||||
draft_message = created.get("message", {})
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
message_id=draft_message.get("id", ""),
|
||||
thread_id=draft_message.get("threadId", ""),
|
||||
subject=final_subject,
|
||||
sender="me",
|
||||
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
body_text=final_body,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
draft_id=created.get("id"),
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"draft_id": created.get("id"),
|
||||
"message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error creating Gmail draft: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while creating the draft. Please try again.",
|
||||
}
|
||||
|
||||
return create_gmail_draft
|
||||
343
surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py
Normal file
343
surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py
Normal file
|
|
@ -0,0 +1,343 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.gmail import GmailToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_send_gmail_email_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def send_gmail_email(
|
||||
to: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
cc: str | None = None,
|
||||
bcc: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Send an email via Gmail.
|
||||
|
||||
Use when the user explicitly asks to send an email. This sends the
|
||||
email immediately - it cannot be unsent.
|
||||
|
||||
Args:
|
||||
to: Recipient email address.
|
||||
subject: Email subject line.
|
||||
body: Email body content.
|
||||
cc: Optional CC recipient(s), comma-separated.
|
||||
bcc: Optional BCC recipient(s), comma-separated.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", or "error"
|
||||
- message_id: Gmail message ID (if success)
|
||||
- thread_id: Gmail thread ID (if success)
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined the action.
|
||||
Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
|
||||
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
|
||||
Inform the user they need to re-authenticate and do NOT retry the action.
|
||||
|
||||
Examples:
|
||||
- "Send an email to alice@example.com about the meeting"
|
||||
- "Email Bob the project update"
|
||||
"""
|
||||
logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'")
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Gmail tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GmailToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning("All Gmail accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "gmail_email_send",
|
||||
"action": {
|
||||
"tool": "send_gmail_email",
|
||||
"params": {
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": None,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_to = final_params.get("to", to)
|
||||
final_subject = final_params.get("subject", subject)
|
||||
final_body = final_params.get("body", body)
|
||||
final_cc = final_params.get("cc", cc)
|
||||
final_bcc = final_params.get("bcc", bcc)
|
||||
final_connector_id = final_params.get("connector_id")
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_gmail_types = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Gmail connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
config_data = dict(connector.config)
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
message = MIMEText(final_body)
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
sent = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.messages()
|
||||
.send(userId="me", body={"raw": raw})
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.gmail import GmailKBSyncService
|
||||
|
||||
kb_service = GmailKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
message_id=sent.get("id", ""),
|
||||
thread_id=sent.get("threadId", ""),
|
||||
subject=final_subject,
|
||||
sender="me",
|
||||
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
body_text=final_body,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after send failed: {kb_err}")
|
||||
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message_id": sent.get("id"),
|
||||
"thread_id": sent.get("threadId"),
|
||||
"message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error sending Gmail email: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while sending the email. Please try again.",
|
||||
}
|
||||
|
||||
return send_gmail_email
|
||||
337
surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py
Normal file
337
surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py
Normal file
|
|
@ -0,0 +1,337 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.gmail import GmailToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_trash_gmail_email_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def trash_gmail_email(
|
||||
email_subject_or_id: str,
|
||||
delete_from_kb: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Move an email or draft to trash in Gmail.
|
||||
|
||||
Use when the user asks to delete, remove, or trash an email or draft.
|
||||
|
||||
Args:
|
||||
email_subject_or_id: The exact subject line or message ID of the
|
||||
email to trash (as it appears in the inbox).
|
||||
delete_from_kb: Whether to also remove the email from the knowledge base.
|
||||
Default is False.
|
||||
Set to True to remove from both Gmail and knowledge base.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", "not_found", or "error"
|
||||
- message_id: Gmail message ID (if success)
|
||||
- deleted_from_kb: whether the document was removed from the knowledge base
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Respond with a brief
|
||||
acknowledgment and do NOT retry or suggest alternatives.
|
||||
- If status is "not_found", relay the exact message to the user and ask them
|
||||
to verify the email subject or check if it has been indexed.
|
||||
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
|
||||
Inform the user they need to re-authenticate and do NOT retry this tool.
|
||||
Examples:
|
||||
- "Delete the email about 'Meeting Cancelled'"
|
||||
- "Trash the email from Bob about the project"
|
||||
"""
|
||||
logger.info(
|
||||
f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Gmail tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GmailToolMetadataService(db_session)
|
||||
context = await metadata_service.get_trash_context(
|
||||
search_space_id, user_id, email_subject_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Email not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch trash context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Gmail account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
|
||||
email = context["email"]
|
||||
message_id = email["message_id"]
|
||||
document_id = email.get("document_id")
|
||||
connector_id_from_context = context["account"]["id"]
|
||||
|
||||
if not message_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Message ID is missing from the indexed document. Please re-index the email and try again.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "gmail_email_trash",
|
||||
"action": {
|
||||
"tool": "trash_gmail_email",
|
||||
"params": {
|
||||
"message_id": message_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
edited_action = decision.get("edited_action")
|
||||
final_params: dict[str, Any] = {}
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_message_id = final_params.get("message_id", message_id)
|
||||
final_connector_id = final_params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this email.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_gmail_types = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Gmail connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
config_data = dict(connector.config)
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
try:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.messages()
|
||||
.trash(userId="me", id=final_message_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Gmail email trashed: message_id={final_message_id}")
|
||||
|
||||
trash_result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"message_id": final_message_id,
|
||||
"message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.",
|
||||
}
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
trash_result["warning"] = (
|
||||
f"Email trashed, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
trash_result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
trash_result["message"] = (
|
||||
f"{trash_result.get('message', '')} (also removed from knowledge base)"
|
||||
)
|
||||
|
||||
return trash_result
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error trashing Gmail email: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while trashing the email. Please try again.",
|
||||
}
|
||||
|
||||
return trash_gmail_email
|
||||
|
|
@ -0,0 +1,438 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.gmail import GmailToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_update_gmail_draft_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def update_gmail_draft(
|
||||
draft_subject_or_id: str,
|
||||
body: str,
|
||||
to: str | None = None,
|
||||
subject: str | None = None,
|
||||
cc: str | None = None,
|
||||
bcc: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing Gmail draft.
|
||||
|
||||
Use when the user asks to modify, edit, or add content to an existing
|
||||
email draft. This replaces the draft content with the new version.
|
||||
The user will be able to review and edit the content before it is applied.
|
||||
|
||||
If the user simply wants to "edit" a draft without specifying exact changes,
|
||||
generate the body yourself using your best understanding of the conversation
|
||||
context. The user will review and can freely edit the content in the approval
|
||||
card before confirming.
|
||||
|
||||
IMPORTANT: This tool is ONLY for modifying Gmail draft content, NOT for
|
||||
deleting/trashing drafts (use trash_gmail_email instead), Notion pages,
|
||||
calendar events, or any other content type.
|
||||
|
||||
Args:
|
||||
draft_subject_or_id: The exact subject line of the draft to update
|
||||
(as it appears in Gmail drafts).
|
||||
body: The full updated body content for the draft. Generate this
|
||||
yourself based on the user's request and conversation context.
|
||||
to: Optional new recipient email address (keeps original if omitted).
|
||||
subject: Optional new subject line (keeps original if omitted).
|
||||
cc: Optional CC recipient(s), comma-separated.
|
||||
bcc: Optional BCC recipient(s), comma-separated.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", "not_found", or "error"
|
||||
- draft_id: Gmail draft ID (if success)
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined the action.
|
||||
Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
|
||||
- If status is "not_found", relay the exact message to the user and ask them
|
||||
to verify the draft subject or check if it has been indexed.
|
||||
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
|
||||
Inform the user they need to re-authenticate and do NOT retry the action.
|
||||
|
||||
Examples:
|
||||
- "Update the Kurseong Plan draft with the new itinerary details"
|
||||
- "Edit my draft about the project proposal and change the recipient"
|
||||
- "Let me edit the meeting notes draft" (call with current body content so user can edit in the approval card)
|
||||
"""
|
||||
logger.info(
|
||||
f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Gmail tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GmailToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, draft_subject_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Draft not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch update context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Gmail account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
|
||||
email = context["email"]
|
||||
message_id = email["message_id"]
|
||||
document_id = email.get("document_id")
|
||||
connector_id_from_context = account["id"]
|
||||
draft_id_from_context = context.get("draft_id")
|
||||
|
||||
original_subject = email.get("subject", draft_subject_or_id)
|
||||
final_subject_default = subject if subject else original_subject
|
||||
final_to_default = to if to else ""
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for updating Gmail draft: '{original_subject}' "
|
||||
f"(message_id={message_id}, draft_id={draft_id_from_context})"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "gmail_draft_update",
|
||||
"action": {
|
||||
"tool": "update_gmail_draft",
|
||||
"params": {
|
||||
"message_id": message_id,
|
||||
"draft_id": draft_id_from_context,
|
||||
"to": final_to_default,
|
||||
"subject": final_subject_default,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_to = final_params.get("to", final_to_default)
|
||||
final_subject = final_params.get("subject", final_subject_default)
|
||||
final_body = final_params.get("body", body)
|
||||
final_cc = final_params.get("cc", cc)
|
||||
final_bcc = final_params.get("bcc", bcc)
|
||||
final_connector_id = final_params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_draft_id = final_params.get("draft_id", draft_id_from_context)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this draft.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_gmail_types = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Gmail connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
config_data = dict(connector.config)
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
# Resolve draft_id if not already available
|
||||
if not final_draft_id:
|
||||
logger.info(
|
||||
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
|
||||
)
|
||||
final_draft_id = await _find_draft_id_by_message(
|
||||
gmail_service, message_id
|
||||
)
|
||||
|
||||
if not final_draft_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
"Could not find this draft in Gmail. "
|
||||
"It may have already been sent or deleted."
|
||||
),
|
||||
}
|
||||
|
||||
message = MIMEText(final_body)
|
||||
if final_to:
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
updated = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.drafts()
|
||||
.update(
|
||||
userId="me",
|
||||
id=final_draft_id,
|
||||
body={"message": {"raw": raw}},
|
||||
)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 404:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Gmail draft updated: id={updated.get('id')}")
|
||||
|
||||
kb_message_suffix = ""
|
||||
if document_id:
|
||||
try:
|
||||
from sqlalchemy.future import select as sa_select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
sa_select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
document.source_markdown = final_body
|
||||
document.title = final_subject
|
||||
meta = dict(document.document_metadata or {})
|
||||
meta["subject"] = final_subject
|
||||
meta["draft_id"] = updated.get("id", final_draft_id)
|
||||
updated_msg = updated.get("message", {})
|
||||
if updated_msg.get("id"):
|
||||
meta["message_id"] = updated_msg["id"]
|
||||
document.document_metadata = meta
|
||||
flag_modified(document, "document_metadata")
|
||||
await db_session.commit()
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
logger.info(
|
||||
f"KB document {document_id} updated for draft {final_draft_id}"
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB update after draft edit failed: {kb_err}")
|
||||
await db_session.rollback()
|
||||
kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"draft_id": updated.get("id"),
|
||||
"message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error updating Gmail draft: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while updating the draft. Please try again.",
|
||||
}
|
||||
|
||||
return update_gmail_draft
|
||||
|
||||
|
||||
async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str | None:
|
||||
"""Look up a draft's ID by its message ID via the Gmail API."""
|
||||
try:
|
||||
page_token = None
|
||||
while True:
|
||||
kwargs: dict[str, Any] = {"userId": "me", "maxResults": 100}
|
||||
if page_token:
|
||||
kwargs["pageToken"] = page_token
|
||||
|
||||
response = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda kwargs=kwargs: (
|
||||
gmail_service.users().drafts().list(**kwargs).execute()
|
||||
),
|
||||
)
|
||||
|
||||
for draft in response.get("drafts", []):
|
||||
if draft.get("message", {}).get("id") == message_id:
|
||||
return draft["id"]
|
||||
|
||||
page_token = response.get("nextPageToken")
|
||||
if not page_token:
|
||||
break
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to look up draft by message_id: {e}")
|
||||
return None
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
from app.agents.new_chat.tools.google_calendar.create_event import (
|
||||
create_create_calendar_event_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.google_calendar.delete_event import (
|
||||
create_delete_calendar_event_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.google_calendar.update_event import (
|
||||
create_update_calendar_event_tool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_create_calendar_event_tool",
|
||||
"create_delete_calendar_event_tool",
|
||||
"create_update_calendar_event_tool",
|
||||
]
|
||||
|
|
@ -0,0 +1,352 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.google_calendar import GoogleCalendarToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_create_calendar_event_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def create_calendar_event(
|
||||
summary: str,
|
||||
start_datetime: str,
|
||||
end_datetime: str,
|
||||
description: str | None = None,
|
||||
location: str | None = None,
|
||||
attendees: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new event on Google Calendar.
|
||||
|
||||
Use when the user asks to schedule, create, or add a calendar event.
|
||||
Ask for event details if not provided.
|
||||
|
||||
Args:
|
||||
summary: The event title.
|
||||
start_datetime: Start time in ISO 8601 format (e.g. "2026-03-20T10:00:00").
|
||||
end_datetime: End time in ISO 8601 format (e.g. "2026-03-20T11:00:00").
|
||||
description: Optional event description.
|
||||
location: Optional event location.
|
||||
attendees: Optional list of attendee email addresses.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", "auth_error", or "error"
|
||||
- event_id: Google Calendar event ID (if success)
|
||||
- html_link: URL to open the event (if success)
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined the action.
|
||||
Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
|
||||
|
||||
Examples:
|
||||
- "Schedule a meeting with John tomorrow at 10am"
|
||||
- "Create a calendar event for the team standup"
|
||||
"""
|
||||
logger.info(
|
||||
f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Google Calendar tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GoogleCalendarToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning(
|
||||
"All Google Calendar accounts have expired authentication"
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_calendar",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for creating calendar event: summary='{summary}'"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "google_calendar_event_creation",
|
||||
"action": {
|
||||
"tool": "create_calendar_event",
|
||||
"params": {
|
||||
"summary": summary,
|
||||
"start_datetime": start_datetime,
|
||||
"end_datetime": end_datetime,
|
||||
"description": description,
|
||||
"location": location,
|
||||
"attendees": attendees,
|
||||
"timezone": context.get("timezone"),
|
||||
"connector_id": None,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_summary = final_params.get("summary", summary)
|
||||
final_start_datetime = final_params.get("start_datetime", start_datetime)
|
||||
final_end_datetime = final_params.get("end_datetime", end_datetime)
|
||||
final_description = final_params.get("description", description)
|
||||
final_location = final_params.get("location", location)
|
||||
final_attendees = final_params.get("attendees", attendees)
|
||||
final_connector_id = final_params.get("connector_id")
|
||||
|
||||
if not final_summary or not final_summary.strip():
|
||||
return {"status": "error", "message": "Event summary cannot be empty."}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_calendar_types = [
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
]
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_calendar_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_calendar_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
else:
|
||||
config_data = dict(connector.config)
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and app_config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(app_config.SECRET_KEY)
|
||||
for key in ("token", "refresh_token", "client_secret"):
|
||||
if config_data.get(key):
|
||||
config_data[key] = token_encryption.decrypt_token(
|
||||
config_data[key]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
tz = context.get("timezone", "UTC")
|
||||
event_body: dict[str, Any] = {
|
||||
"summary": final_summary,
|
||||
"start": {"dateTime": final_start_datetime, "timeZone": tz},
|
||||
"end": {"dateTime": final_end_datetime, "timeZone": tz},
|
||||
}
|
||||
if final_description:
|
||||
event_body["description"] = final_description
|
||||
if final_location:
|
||||
event_body["location"] = final_location
|
||||
if final_attendees:
|
||||
event_body["attendees"] = [
|
||||
{"email": e.strip()} for e in final_attendees if e.strip()
|
||||
]
|
||||
|
||||
try:
|
||||
created = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.insert(calendarId="primary", body=event_body)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.google_calendar import GoogleCalendarKBSyncService
|
||||
|
||||
kb_service = GoogleCalendarKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
event_id=created.get("id"),
|
||||
event_summary=final_summary,
|
||||
calendar_id="primary",
|
||||
start_time=final_start_datetime,
|
||||
end_time=final_end_datetime,
|
||||
location=final_location,
|
||||
html_link=created.get("htmlLink"),
|
||||
description=final_description,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"event_id": created.get("id"),
|
||||
"html_link": created.get("htmlLink"),
|
||||
"message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error creating calendar event: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while creating the event. Please try again.",
|
||||
}
|
||||
|
||||
return create_calendar_event
|
||||
|
|
@ -0,0 +1,332 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.google_calendar import GoogleCalendarToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_delete_calendar_event_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def delete_calendar_event(
|
||||
event_title_or_id: str,
|
||||
delete_from_kb: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Delete a Google Calendar event.
|
||||
|
||||
Use when the user asks to delete, remove, or cancel a calendar event.
|
||||
|
||||
Args:
|
||||
event_title_or_id: The exact title or event ID of the event to delete.
|
||||
delete_from_kb: Whether to also remove the event from the knowledge base.
|
||||
Default is False.
|
||||
Set to True to remove from both Google Calendar and knowledge base.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", "not_found", "auth_error", or "error"
|
||||
- event_id: Google Calendar event ID (if success)
|
||||
- deleted_from_kb: whether the document was removed from the knowledge base
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Respond with a brief
|
||||
acknowledgment and do NOT retry or suggest alternatives.
|
||||
- If status is "not_found", relay the exact message to the user and ask them
|
||||
to verify the event name or check if it has been indexed.
|
||||
Examples:
|
||||
- "Delete the team standup event"
|
||||
- "Cancel my dentist appointment on Friday"
|
||||
"""
|
||||
logger.info(
|
||||
f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Google Calendar tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GoogleCalendarToolMetadataService(db_session)
|
||||
context = await metadata_service.get_deletion_context(
|
||||
search_space_id, user_id, event_title_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Event not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch deletion context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Google Calendar account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_calendar",
|
||||
}
|
||||
|
||||
event = context["event"]
|
||||
event_id = event["event_id"]
|
||||
document_id = event.get("document_id")
|
||||
connector_id_from_context = context["account"]["id"]
|
||||
|
||||
if not event_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "google_calendar_event_deletion",
|
||||
"action": {
|
||||
"tool": "delete_calendar_event",
|
||||
"params": {
|
||||
"event_id": event_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
edited_action = decision.get("edited_action")
|
||||
final_params: dict[str, Any] = {}
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_event_id = final_params.get("event_id", event_id)
|
||||
final_connector_id = final_params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this event.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_calendar_types = [
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_calendar_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
else:
|
||||
config_data = dict(connector.config)
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and app_config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(app_config.SECRET_KEY)
|
||||
for key in ("token", "refresh_token", "client_secret"):
|
||||
if config_data.get(key):
|
||||
config_data[key] = token_encryption.decrypt_token(
|
||||
config_data[key]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.delete(calendarId="primary", eventId=final_event_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Calendar event deleted: event_id={final_event_id}")
|
||||
|
||||
delete_result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"event_id": final_event_id,
|
||||
"message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.",
|
||||
}
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
delete_result["warning"] = (
|
||||
f"Event deleted, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
delete_result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
delete_result["message"] = (
|
||||
f"{delete_result.get('message', '')} (also removed from knowledge base)"
|
||||
)
|
||||
|
||||
return delete_result
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error deleting calendar event: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while deleting the event. Please try again.",
|
||||
}
|
||||
|
||||
return delete_calendar_event
|
||||
|
|
@ -0,0 +1,382 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.google_calendar import GoogleCalendarToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_update_calendar_event_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def update_calendar_event(
|
||||
event_title_or_id: str,
|
||||
new_summary: str | None = None,
|
||||
new_start_datetime: str | None = None,
|
||||
new_end_datetime: str | None = None,
|
||||
new_description: str | None = None,
|
||||
new_location: str | None = None,
|
||||
new_attendees: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing Google Calendar event.
|
||||
|
||||
Use when the user asks to modify, reschedule, or change a calendar event.
|
||||
|
||||
Args:
|
||||
event_title_or_id: The exact title or event ID of the event to update.
|
||||
new_summary: New event title (if changing).
|
||||
new_start_datetime: New start time in ISO 8601 format (if rescheduling).
|
||||
new_end_datetime: New end time in ISO 8601 format (if rescheduling).
|
||||
new_description: New event description (if changing).
|
||||
new_location: New event location (if changing).
|
||||
new_attendees: New list of attendee email addresses (if changing).
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", "not_found", "auth_error", or "error"
|
||||
- event_id: Google Calendar event ID (if success)
|
||||
- html_link: URL to open the event (if success)
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Respond with a brief
|
||||
acknowledgment and do NOT retry or suggest alternatives.
|
||||
- If status is "not_found", relay the exact message to the user and ask them
|
||||
to verify the event name or check if it has been indexed.
|
||||
Examples:
|
||||
- "Reschedule the team standup to 3pm"
|
||||
- "Change the location of my dentist appointment"
|
||||
"""
|
||||
logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'")
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Google Calendar tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GoogleCalendarToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, event_title_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Event not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch update context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
if context.get("auth_expired"):
|
||||
logger.warning("Google Calendar account has expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_calendar",
|
||||
}
|
||||
|
||||
event = context["event"]
|
||||
event_id = event["event_id"]
|
||||
document_id = event.get("document_id")
|
||||
connector_id_from_context = context["account"]["id"]
|
||||
|
||||
if not event_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "google_calendar_event_update",
|
||||
"action": {
|
||||
"tool": "update_calendar_event",
|
||||
"params": {
|
||||
"event_id": event_id,
|
||||
"document_id": document_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"new_summary": new_summary,
|
||||
"new_start_datetime": new_start_datetime,
|
||||
"new_end_datetime": new_end_datetime,
|
||||
"new_description": new_description,
|
||||
"new_location": new_location,
|
||||
"new_attendees": new_attendees,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
edited_action = decision.get("edited_action")
|
||||
final_params: dict[str, Any] = {}
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_event_id = final_params.get("event_id", event_id)
|
||||
final_connector_id = final_params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_new_summary = final_params.get("new_summary", new_summary)
|
||||
final_new_start_datetime = final_params.get(
|
||||
"new_start_datetime", new_start_datetime
|
||||
)
|
||||
final_new_end_datetime = final_params.get(
|
||||
"new_end_datetime", new_end_datetime
|
||||
)
|
||||
final_new_description = final_params.get("new_description", new_description)
|
||||
final_new_location = final_params.get("new_location", new_location)
|
||||
final_new_attendees = final_params.get("new_attendees", new_attendees)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this event.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_calendar_types = [
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_calendar_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
else:
|
||||
config_data = dict(connector.config)
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and app_config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(app_config.SECRET_KEY)
|
||||
for key in ("token", "refresh_token", "client_secret"):
|
||||
if config_data.get(key):
|
||||
config_data[key] = token_encryption.decrypt_token(
|
||||
config_data[key]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
update_body: dict[str, Any] = {}
|
||||
if final_new_summary is not None:
|
||||
update_body["summary"] = final_new_summary
|
||||
if final_new_start_datetime is not None:
|
||||
tz = (
|
||||
context.get("timezone", "UTC")
|
||||
if isinstance(context, dict)
|
||||
else "UTC"
|
||||
)
|
||||
update_body["start"] = {
|
||||
"dateTime": final_new_start_datetime,
|
||||
"timeZone": tz,
|
||||
}
|
||||
if final_new_end_datetime is not None:
|
||||
tz = (
|
||||
context.get("timezone", "UTC")
|
||||
if isinstance(context, dict)
|
||||
else "UTC"
|
||||
)
|
||||
update_body["end"] = {
|
||||
"dateTime": final_new_end_datetime,
|
||||
"timeZone": tz,
|
||||
}
|
||||
if final_new_description is not None:
|
||||
update_body["description"] = final_new_description
|
||||
if final_new_location is not None:
|
||||
update_body["location"] = final_new_location
|
||||
if final_new_attendees is not None:
|
||||
update_body["attendees"] = [
|
||||
{"email": e.strip()} for e in final_new_attendees if e.strip()
|
||||
]
|
||||
|
||||
if not update_body:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No changes specified. Please provide at least one field to update.",
|
||||
}
|
||||
|
||||
try:
|
||||
updated = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.patch(
|
||||
calendarId="primary",
|
||||
eventId=final_event_id,
|
||||
body=update_body,
|
||||
)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Calendar event updated: event_id={final_event_id}")
|
||||
|
||||
kb_message_suffix = ""
|
||||
if document_id is not None:
|
||||
try:
|
||||
from app.services.google_calendar import GoogleCalendarKBSyncService
|
||||
|
||||
kb_service = GoogleCalendarKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_update(
|
||||
document_id=document_id,
|
||||
event_id=final_event_id,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after update failed: {kb_err}")
|
||||
kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"event_id": final_event_id,
|
||||
"html_link": updated.get("htmlLink"),
|
||||
"message": f"Successfully updated the calendar event.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error updating calendar event: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while updating the event. Please try again.",
|
||||
}
|
||||
|
||||
return update_calendar_event
|
||||
|
|
@ -32,13 +32,16 @@ def create_create_google_drive_file_tool(
|
|||
"""Create a new Google Doc or Google Sheet in Google Drive.
|
||||
|
||||
Use this tool when the user explicitly asks to create a new document
|
||||
or spreadsheet in Google Drive.
|
||||
or spreadsheet in Google Drive. The user MUST specify a topic before
|
||||
you call this tool. If the request does not contain a topic (e.g.
|
||||
"create a drive doc" or "make a Google Sheet"), ask what the file
|
||||
should be about. Never call this tool without a clear topic from the user.
|
||||
|
||||
Args:
|
||||
name: The file name (without extension).
|
||||
file_type: Either "google_doc" or "google_sheet".
|
||||
content: Optional initial content. For google_doc, provide markdown text.
|
||||
For google_sheet, provide CSV-formatted text.
|
||||
content: Optional initial content. Generate from the user's topic.
|
||||
For google_doc, provide markdown text. For google_sheet, provide CSV-formatted text.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
|
|
@ -55,8 +58,8 @@ def create_create_google_drive_file_tool(
|
|||
Inform the user they need to re-authenticate and do NOT retry the action.
|
||||
|
||||
Examples:
|
||||
- "Create a Google Doc called 'Meeting Notes'"
|
||||
- "Create a spreadsheet named 'Budget 2026' with some sample data"
|
||||
- "Create a Google Doc with today's meeting notes"
|
||||
- "Create a spreadsheet for the 2026 budget"
|
||||
"""
|
||||
logger.info(
|
||||
f"create_google_drive_file called: name='{name}', type='{file_type}'"
|
||||
|
|
@ -84,6 +87,15 @@ def create_create_google_drive_file_tool(
|
|||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning("All Google Drive accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_drive",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
|
||||
)
|
||||
|
|
@ -154,14 +166,18 @@ def create_create_google_drive_file_tool(
|
|||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_drive_types = [
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
]
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnector.connector_type.in_(_drive_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
|
@ -176,8 +192,7 @@ def create_create_google_drive_file_tool(
|
|||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnector.connector_type.in_(_drive_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
|
@ -191,8 +206,22 @@ def create_create_google_drive_file_tool(
|
|||
logger.info(
|
||||
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
pre_built_creds = None
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
pre_built_creds = build_composio_credentials(cca_id)
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
session=db_session,
|
||||
connector_id=actual_connector_id,
|
||||
credentials=pre_built_creds,
|
||||
)
|
||||
try:
|
||||
created = await client.create_file(
|
||||
|
|
@ -206,22 +235,65 @@ def create_create_google_drive_file_tool(
|
|||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate.",
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.google_drive import GoogleDriveKBSyncService
|
||||
|
||||
kb_service = GoogleDriveKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
file_id=created.get("id"),
|
||||
file_name=created.get("name", final_name),
|
||||
mime_type=mime_type,
|
||||
web_view_link=created.get("webViewLink"),
|
||||
content=final_content,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"file_id": created.get("id"),
|
||||
"name": created.get("name"),
|
||||
"web_view_link": created.get("webViewLink"),
|
||||
"message": f"Successfully created '{created.get('name')}' in Google Drive.",
|
||||
"message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -47,7 +47,6 @@ def create_delete_google_drive_file_tool(
|
|||
to verify the file name or check if it has been indexed.
|
||||
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
|
||||
Inform the user they need to re-authenticate and do NOT retry this tool.
|
||||
|
||||
Examples:
|
||||
- "Delete the 'Meeting Notes' file from Google Drive"
|
||||
- "Trash the 'Old Budget' spreadsheet"
|
||||
|
|
@ -76,6 +75,18 @@ def create_delete_google_drive_file_tool(
|
|||
logger.error(f"Failed to fetch trash context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Google Drive account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_drive",
|
||||
}
|
||||
|
||||
file = context["file"]
|
||||
file_id = file["file_id"]
|
||||
document_id = file.get("document_id")
|
||||
|
|
@ -151,13 +162,17 @@ def create_delete_google_drive_file_tool(
|
|||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_drive_types = [
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnector.connector_type.in_(_drive_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
|
@ -170,7 +185,23 @@ def create_delete_google_drive_file_tool(
|
|||
logger.info(
|
||||
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
|
||||
)
|
||||
client = GoogleDriveClient(session=db_session, connector_id=connector.id)
|
||||
|
||||
pre_built_creds = None
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
pre_built_creds = build_composio_credentials(cca_id)
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=db_session,
|
||||
connector_id=connector.id,
|
||||
credentials=pre_built_creds,
|
||||
)
|
||||
try:
|
||||
await client.trash_file(file_id=final_file_id)
|
||||
except HttpError as http_err:
|
||||
|
|
@ -178,10 +209,26 @@ def create_delete_google_drive_file_tool(
|
|||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {http_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate.",
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
|
|
|
|||
11
surfsense_backend/app/agents/new_chat/tools/jira/__init__.py
Normal file
11
surfsense_backend/app/agents/new_chat/tools/jira/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
"""Jira tools for creating, updating, and deleting issues."""
|
||||
|
||||
from .create_issue import create_create_jira_issue_tool
|
||||
from .delete_issue import create_delete_jira_issue_tool
|
||||
from .update_issue import create_update_jira_issue_tool
|
||||
|
||||
__all__ = [
|
||||
"create_create_jira_issue_tool",
|
||||
"create_delete_jira_issue_tool",
|
||||
"create_update_jira_issue_tool",
|
||||
]
|
||||
242
surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py
Normal file
242
surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.services.jira import JiraToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_create_jira_issue_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
@tool
|
||||
async def create_jira_issue(
|
||||
project_key: str,
|
||||
summary: str,
|
||||
issue_type: str = "Task",
|
||||
description: str | None = None,
|
||||
priority: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new issue in Jira.
|
||||
|
||||
Use this tool when the user explicitly asks to create a new Jira issue/ticket.
|
||||
|
||||
Args:
|
||||
project_key: The Jira project key (e.g. "PROJ", "ENG").
|
||||
summary: Short, descriptive issue title.
|
||||
issue_type: Issue type (default "Task"). Others: "Bug", "Story", "Epic".
|
||||
description: Optional description body for the issue.
|
||||
priority: Optional priority name (e.g. "High", "Medium", "Low").
|
||||
|
||||
Returns:
|
||||
Dictionary with status, issue_key, and message.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user declined. Do NOT retry.
|
||||
- If status is "insufficient_permissions", inform user to re-authenticate.
|
||||
"""
|
||||
logger.info(
|
||||
f"create_jira_issue called: project_key='{project_key}', summary='{summary}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Jira tool not properly configured."}
|
||||
|
||||
try:
|
||||
metadata_service = JiraToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Jira accounts need re-authentication.",
|
||||
"connector_type": "jira",
|
||||
}
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "jira_issue_creation",
|
||||
"action": {
|
||||
"tool": "create_jira_issue",
|
||||
"params": {
|
||||
"project_key": project_key,
|
||||
"summary": summary,
|
||||
"issue_type": issue_type,
|
||||
"description": description,
|
||||
"priority": priority,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The issue was not created.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_project_key = final_params.get("project_key", project_key)
|
||||
final_summary = final_params.get("summary", summary)
|
||||
final_issue_type = final_params.get("issue_type", issue_type)
|
||||
final_description = final_params.get("description", description)
|
||||
final_priority = final_params.get("priority", priority)
|
||||
final_connector_id = final_params.get("connector_id", connector_id)
|
||||
|
||||
if not final_summary or not final_summary.strip():
|
||||
return {"status": "error", "message": "Issue summary cannot be empty."}
|
||||
if not final_project_key:
|
||||
return {"status": "error", "message": "A project must be selected."}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
actual_connector_id = final_connector_id
|
||||
if actual_connector_id is None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Jira connector found."}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == actual_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Jira connector is invalid.",
|
||||
}
|
||||
|
||||
try:
|
||||
jira_history = JiraHistoryConnector(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
jira_client = await jira_history._get_jira_client()
|
||||
api_result = await asyncio.to_thread(
|
||||
jira_client.create_issue,
|
||||
project_key=final_project_key,
|
||||
summary=final_summary,
|
||||
issue_type=final_issue_type,
|
||||
description=final_description,
|
||||
priority=final_priority,
|
||||
)
|
||||
except Exception as api_err:
|
||||
if "status code 403" in str(api_err).lower():
|
||||
try:
|
||||
_conn = connector
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
issue_key = api_result.get("key", "")
|
||||
issue_url = (
|
||||
f"{jira_history._base_url}/browse/{issue_key}"
|
||||
if jira_history._base_url and issue_key
|
||||
else ""
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.jira import JiraKBSyncService
|
||||
|
||||
kb_service = JiraKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
issue_id=issue_key,
|
||||
issue_identifier=issue_key,
|
||||
issue_title=final_summary,
|
||||
description=final_description,
|
||||
state="To Do",
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"issue_key": issue_key,
|
||||
"issue_url": issue_url,
|
||||
"message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error(f"Error creating Jira issue: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while creating the issue.",
|
||||
}
|
||||
|
||||
return create_jira_issue
|
||||
209
surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py
Normal file
209
surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py
Normal file
|
|
@ -0,0 +1,209 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.services.jira import JiraToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_delete_jira_issue_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
@tool
|
||||
async def delete_jira_issue(
|
||||
issue_title_or_key: str,
|
||||
delete_from_kb: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Delete a Jira issue.
|
||||
|
||||
Use this tool when the user asks to delete or remove a Jira issue.
|
||||
|
||||
Args:
|
||||
issue_title_or_key: The issue key (e.g. "PROJ-42") or title.
|
||||
delete_from_kb: Whether to also remove from the knowledge base.
|
||||
|
||||
Returns:
|
||||
Dictionary with status, message, and deleted_from_kb.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", do NOT retry.
|
||||
- If status is "not_found", relay the message to the user.
|
||||
- If status is "insufficient_permissions", inform user to re-authenticate.
|
||||
"""
|
||||
logger.info(
|
||||
f"delete_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Jira tool not properly configured."}
|
||||
|
||||
try:
|
||||
metadata_service = JiraToolMetadataService(db_session)
|
||||
context = await metadata_service.get_deletion_context(
|
||||
search_space_id, user_id, issue_title_or_key
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "jira",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
issue_data = context["issue"]
|
||||
issue_key = issue_data["issue_id"]
|
||||
document_id = issue_data["document_id"]
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "jira_issue_deletion",
|
||||
"action": {
|
||||
"tool": "delete_jira_issue",
|
||||
"params": {
|
||||
"issue_key": issue_key,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The issue was not deleted.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_issue_key = final_params.get("issue_key", issue_key)
|
||||
final_connector_id = final_params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this issue.",
|
||||
}
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Jira connector is invalid.",
|
||||
}
|
||||
|
||||
try:
|
||||
jira_history = JiraHistoryConnector(
|
||||
session=db_session, connector_id=final_connector_id
|
||||
)
|
||||
jira_client = await jira_history._get_jira_client()
|
||||
await asyncio.to_thread(jira_client.delete_issue, final_issue_key)
|
||||
except Exception as api_err:
|
||||
if "status code 403" in str(api_err).lower():
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": final_connector_id,
|
||||
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
|
||||
message = f"Jira issue {final_issue_key} deleted successfully."
|
||||
if deleted_from_kb:
|
||||
message += " Also removed from the knowledge base."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"issue_key": final_issue_key,
|
||||
"deleted_from_kb": deleted_from_kb,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error(f"Error deleting Jira issue: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while deleting the issue.",
|
||||
}
|
||||
|
||||
return delete_jira_issue
|
||||
252
surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py
Normal file
252
surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py
Normal file
|
|
@ -0,0 +1,252 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.services.jira import JiraToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_update_jira_issue_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
@tool
|
||||
async def update_jira_issue(
|
||||
issue_title_or_key: str,
|
||||
new_summary: str | None = None,
|
||||
new_description: str | None = None,
|
||||
new_priority: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing Jira issue.
|
||||
|
||||
Use this tool when the user asks to modify, edit, or update a Jira issue.
|
||||
|
||||
Args:
|
||||
issue_title_or_key: The issue key (e.g. "PROJ-42") or title to identify the issue.
|
||||
new_summary: Optional new title/summary for the issue.
|
||||
new_description: Optional new description.
|
||||
new_priority: Optional new priority name.
|
||||
|
||||
Returns:
|
||||
Dictionary with status and message.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", do NOT retry.
|
||||
- If status is "not_found", relay the message and ask user to verify.
|
||||
- If status is "insufficient_permissions", inform user to re-authenticate.
|
||||
"""
|
||||
logger.info(
|
||||
f"update_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Jira tool not properly configured."}
|
||||
|
||||
try:
|
||||
metadata_service = JiraToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, issue_title_or_key
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "jira",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
issue_data = context["issue"]
|
||||
issue_key = issue_data["issue_id"]
|
||||
document_id = issue_data.get("document_id")
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "jira_issue_update",
|
||||
"action": {
|
||||
"tool": "update_jira_issue",
|
||||
"params": {
|
||||
"issue_key": issue_key,
|
||||
"document_id": document_id,
|
||||
"new_summary": new_summary,
|
||||
"new_description": new_description,
|
||||
"new_priority": new_priority,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The issue was not updated.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_issue_key = final_params.get("issue_key", issue_key)
|
||||
final_summary = final_params.get("new_summary", new_summary)
|
||||
final_description = final_params.get("new_description", new_description)
|
||||
final_priority = final_params.get("new_priority", new_priority)
|
||||
final_connector_id = final_params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_document_id = final_params.get("document_id", document_id)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this issue.",
|
||||
}
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Jira connector is invalid.",
|
||||
}
|
||||
|
||||
fields: dict[str, Any] = {}
|
||||
if final_summary:
|
||||
fields["summary"] = final_summary
|
||||
if final_description is not None:
|
||||
fields["description"] = {
|
||||
"type": "doc",
|
||||
"version": 1,
|
||||
"content": [
|
||||
{
|
||||
"type": "paragraph",
|
||||
"content": [{"type": "text", "text": final_description}],
|
||||
}
|
||||
],
|
||||
}
|
||||
if final_priority:
|
||||
fields["priority"] = {"name": final_priority}
|
||||
|
||||
if not fields:
|
||||
return {"status": "error", "message": "No changes specified."}
|
||||
|
||||
try:
|
||||
jira_history = JiraHistoryConnector(
|
||||
session=db_session, connector_id=final_connector_id
|
||||
)
|
||||
jira_client = await jira_history._get_jira_client()
|
||||
await asyncio.to_thread(
|
||||
jira_client.update_issue, final_issue_key, fields
|
||||
)
|
||||
except Exception as api_err:
|
||||
if "status code 403" in str(api_err).lower():
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": final_connector_id,
|
||||
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
issue_url = (
|
||||
f"{jira_history._base_url}/browse/{final_issue_key}"
|
||||
if jira_history._base_url and final_issue_key
|
||||
else ""
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
if final_document_id:
|
||||
try:
|
||||
from app.services.jira import JiraKBSyncService
|
||||
|
||||
kb_service = JiraKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_update(
|
||||
document_id=final_document_id,
|
||||
issue_id=final_issue_key,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = (
|
||||
" The knowledge base will be updated in the next sync."
|
||||
)
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after update failed: {kb_err}")
|
||||
kb_message_suffix = (
|
||||
" The knowledge base will be updated in the next sync."
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"issue_key": final_issue_key,
|
||||
"issue_url": issue_url,
|
||||
"message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error(f"Error updating Jira issue: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while updating the issue.",
|
||||
}
|
||||
|
||||
return update_jira_issue
|
||||
|
|
@ -9,6 +9,7 @@ This module provides:
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
|
|
@ -19,7 +20,7 @@ from langchain_core.tools import StructuredTool
|
|||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import shielded_async_session
|
||||
from app.db import NATIVE_TO_LEGACY_DOCTYPE, shielded_async_session
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
|
|
@ -60,7 +61,7 @@ def _is_degenerate_query(query: str) -> bool:
|
|||
|
||||
async def _browse_recent_documents(
|
||||
search_space_id: int,
|
||||
document_type: str | None,
|
||||
document_type: str | list[str] | None,
|
||||
top_k: int,
|
||||
start_date: datetime | None,
|
||||
end_date: datetime | None,
|
||||
|
|
@ -83,14 +84,22 @@ async def _browse_recent_documents(
|
|||
base_conditions = [Document.search_space_id == search_space_id]
|
||||
|
||||
if document_type is not None:
|
||||
if isinstance(document_type, str):
|
||||
try:
|
||||
doc_type_enum = DocumentType[document_type]
|
||||
base_conditions.append(Document.document_type == doc_type_enum)
|
||||
except KeyError:
|
||||
return []
|
||||
type_list = (
|
||||
document_type if isinstance(document_type, list) else [document_type]
|
||||
)
|
||||
doc_type_enums = []
|
||||
for dt in type_list:
|
||||
if isinstance(dt, str):
|
||||
with contextlib.suppress(KeyError):
|
||||
doc_type_enums.append(DocumentType[dt])
|
||||
else:
|
||||
doc_type_enums.append(dt)
|
||||
if not doc_type_enums:
|
||||
return []
|
||||
if len(doc_type_enums) == 1:
|
||||
base_conditions.append(Document.document_type == doc_type_enums[0])
|
||||
else:
|
||||
base_conditions.append(Document.document_type == document_type)
|
||||
base_conditions.append(Document.document_type.in_(doc_type_enums))
|
||||
|
||||
if start_date is not None:
|
||||
base_conditions.append(Document.updated_at >= start_date)
|
||||
|
|
@ -195,10 +204,6 @@ _ALL_CONNECTORS: list[str] = [
|
|||
"CRAWLED_URL",
|
||||
"CIRCLEBACK",
|
||||
"OBSIDIAN_CONNECTOR",
|
||||
# Composio connectors
|
||||
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
|
||||
"COMPOSIO_GMAIL_CONNECTOR",
|
||||
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
|
||||
]
|
||||
|
||||
# Human-readable descriptions for each connector type
|
||||
|
|
@ -228,10 +233,6 @@ CONNECTOR_DESCRIPTIONS: dict[str, str] = {
|
|||
"BOOKSTACK_CONNECTOR": "BookStack pages (personal documentation)",
|
||||
"CIRCLEBACK": "Circleback meeting notes, transcripts, and action items",
|
||||
"OBSIDIAN_CONNECTOR": "Obsidian vault notes and markdown files (personal notes)",
|
||||
# Composio connectors
|
||||
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "Google Drive files via Composio (personal cloud storage)",
|
||||
"COMPOSIO_GMAIL_CONNECTOR": "Gmail emails via Composio (personal emails)",
|
||||
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "Google Calendar events via Composio (personal calendar)",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -352,6 +353,20 @@ def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
|
|||
return max(_MIN_TOOL_OUTPUT_CHARS, min(budget, _MAX_TOOL_OUTPUT_CHARS))
|
||||
|
||||
|
||||
_INTERNAL_METADATA_KEYS: frozenset[str] = frozenset(
|
||||
{
|
||||
"message_id",
|
||||
"thread_id",
|
||||
"event_id",
|
||||
"calendar_id",
|
||||
"google_drive_file_id",
|
||||
"page_id",
|
||||
"issue_id",
|
||||
"connector_id",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def format_documents_for_context(
|
||||
documents: list[dict[str, Any]],
|
||||
*,
|
||||
|
|
@ -480,7 +495,10 @@ def format_documents_for_context(
|
|||
total_docs = len(grouped)
|
||||
|
||||
for doc_idx, g in enumerate(grouped.values()):
|
||||
metadata_json = json.dumps(g["metadata"], ensure_ascii=False)
|
||||
metadata_clean = {
|
||||
k: v for k, v in g["metadata"].items() if k not in _INTERNAL_METADATA_KEYS
|
||||
}
|
||||
metadata_json = json.dumps(metadata_clean, ensure_ascii=False)
|
||||
is_live_search = g["document_type"] in live_search_connectors
|
||||
|
||||
doc_lines: list[str] = [
|
||||
|
|
@ -617,7 +635,12 @@ async def search_knowledge_base_async(
|
|||
if available_document_types:
|
||||
doc_types_set = set(available_document_types)
|
||||
before_count = len(connectors)
|
||||
connectors = [c for c in connectors if c in doc_types_set]
|
||||
connectors = [
|
||||
c
|
||||
for c in connectors
|
||||
if c in doc_types_set
|
||||
or NATIVE_TO_LEGACY_DOCTYPE.get(c, "") in doc_types_set
|
||||
]
|
||||
skipped = before_count - len(connectors)
|
||||
if skipped:
|
||||
perf.info(
|
||||
|
|
@ -654,6 +677,13 @@ async def search_knowledge_base_async(
|
|||
)
|
||||
browse_connectors = connectors if connectors else [None] # type: ignore[list-item]
|
||||
|
||||
expanded_browse = []
|
||||
for c in browse_connectors:
|
||||
if c is not None and c in NATIVE_TO_LEGACY_DOCTYPE:
|
||||
expanded_browse.append([c, NATIVE_TO_LEGACY_DOCTYPE[c]])
|
||||
else:
|
||||
expanded_browse.append(c)
|
||||
|
||||
browse_results = await asyncio.gather(
|
||||
*[
|
||||
_browse_recent_documents(
|
||||
|
|
@ -663,7 +693,7 @@ async def search_knowledge_base_async(
|
|||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
for c in browse_connectors
|
||||
for c in expanded_browse
|
||||
]
|
||||
)
|
||||
for docs in browse_results:
|
||||
|
|
@ -779,6 +809,10 @@ async def search_knowledge_base_async(
|
|||
|
||||
deduplicated.append(doc)
|
||||
|
||||
# Sort by RRF score so the most relevant documents from ANY connector
|
||||
# appear first, preventing budget truncation from hiding top results.
|
||||
deduplicated.sort(key=lambda d: d.get("score", 0), reverse=True)
|
||||
|
||||
output_budget = _compute_tool_output_budget(max_input_tokens)
|
||||
result = format_documents_for_context(deduplicated, max_chars=output_budget)
|
||||
|
||||
|
|
|
|||
|
|
@ -38,11 +38,13 @@ def create_create_linear_issue_tool(
|
|||
"""Create a new issue in Linear.
|
||||
|
||||
Use this tool when the user explicitly asks to create, add, or file
|
||||
a new issue / ticket / task in Linear.
|
||||
a new issue / ticket / task in Linear. The user MUST describe the issue
|
||||
before you call this tool. If the request is vague, ask what the issue
|
||||
should be about. Never call this tool without a clear topic from the user.
|
||||
|
||||
Args:
|
||||
title: Short, descriptive issue title.
|
||||
description: Optional markdown body for the issue.
|
||||
title: Short, descriptive issue title. Infer from the user's request.
|
||||
description: Optional markdown body for the issue. Generate from context.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
|
|
@ -57,9 +59,9 @@ def create_create_linear_issue_tool(
|
|||
and move on. Do NOT retry, troubleshoot, or suggest alternatives.
|
||||
|
||||
Examples:
|
||||
- "Create a Linear issue titled 'Fix login bug'"
|
||||
- "Add a ticket for the payment timeout problem"
|
||||
- "File an issue about the broken search feature"
|
||||
- "Create a Linear issue for the login bug"
|
||||
- "File a ticket about the payment timeout problem"
|
||||
- "Add an issue for the broken search feature"
|
||||
"""
|
||||
logger.info(f"create_linear_issue called: title='{title}'")
|
||||
|
||||
|
|
@ -82,6 +84,15 @@ def create_create_linear_issue_tool(
|
|||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
workspaces = context.get("workspaces", [])
|
||||
if workspaces and all(w.get("auth_expired") for w in workspaces):
|
||||
logger.warning("All Linear accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "linear",
|
||||
}
|
||||
|
||||
logger.info(f"Requesting approval for creating Linear issue: '{title}'")
|
||||
approval = interrupt(
|
||||
{
|
||||
|
|
@ -215,12 +226,36 @@ def create_create_linear_issue_tool(
|
|||
logger.info(
|
||||
f"Linear issue created: {result.get('identifier')} - {result.get('title')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.linear import LinearKBSyncService
|
||||
|
||||
kb_service = LinearKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
issue_id=result.get("id"),
|
||||
issue_identifier=result.get("identifier", ""),
|
||||
issue_title=result.get("title", final_title),
|
||||
issue_url=result.get("url"),
|
||||
description=final_description,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"issue_id": result.get("id"),
|
||||
"identifier": result.get("identifier"),
|
||||
"url": result.get("url"),
|
||||
"message": result.get("message"),
|
||||
"message": (result.get("message", "") + kb_message_suffix),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -64,7 +64,6 @@ def create_delete_linear_issue_tool(
|
|||
- If status is "not_found", inform the user conversationally using the exact message
|
||||
provided. Do NOT treat this as an error. Simply relay the message and ask the user
|
||||
to verify the issue title or identifier, or check if it has been indexed.
|
||||
|
||||
Examples:
|
||||
- "Delete the 'Fix login bug' Linear issue"
|
||||
- "Archive ENG-42"
|
||||
|
|
@ -91,6 +90,14 @@ def create_delete_linear_issue_tool(
|
|||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
logger.warning(f"Auth expired for delete context: {error_msg}")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "linear",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Issue not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
|
|
|
|||
|
|
@ -103,6 +103,14 @@ def create_update_linear_issue_tool(
|
|||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
logger.warning(f"Auth expired for update context: {error_msg}")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "linear",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Issue not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
|
|
|
|||
|
|
@ -33,17 +33,21 @@ def create_create_notion_page_tool(
|
|||
@tool
|
||||
async def create_notion_page(
|
||||
title: str,
|
||||
content: str,
|
||||
content: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new page in Notion with the given title and content.
|
||||
|
||||
Use this tool when the user asks you to create, save, or publish
|
||||
something to Notion. The page will be created in the user's
|
||||
configured Notion workspace.
|
||||
configured Notion workspace. The user MUST specify a topic before you
|
||||
call this tool. If the request does not contain a topic (e.g. "create a
|
||||
notion page"), ask what the page should be about. Never call this tool
|
||||
without a clear topic from the user.
|
||||
|
||||
Args:
|
||||
title: The title of the Notion page.
|
||||
content: The markdown content for the page body (supports headings, lists, paragraphs).
|
||||
content: Optional markdown content for the page body (supports headings, lists, paragraphs).
|
||||
Generate this yourself based on the user's topic.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
|
|
@ -58,8 +62,8 @@ def create_create_notion_page_tool(
|
|||
and move on. Do NOT troubleshoot or suggest alternatives.
|
||||
|
||||
Examples:
|
||||
- "Create a Notion page titled 'Meeting Notes' with content 'Discussed project timeline'"
|
||||
- "Save this to Notion with title 'Research Summary'"
|
||||
- "Create a Notion page about our Q2 roadmap"
|
||||
- "Save a summary of today's discussion to Notion"
|
||||
"""
|
||||
logger.info(f"create_notion_page called: title='{title}'")
|
||||
|
||||
|
|
@ -85,6 +89,15 @@ def create_create_notion_page_tool(
|
|||
"message": context["error"],
|
||||
}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning("All Notion accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "notion",
|
||||
}
|
||||
|
||||
logger.info(f"Requesting approval for creating Notion page: '{title}'")
|
||||
approval = interrupt(
|
||||
{
|
||||
|
|
@ -215,6 +228,34 @@ def create_create_notion_page_tool(
|
|||
logger.info(
|
||||
f"create_page result: {result.get('status')} - {result.get('message', '')}"
|
||||
)
|
||||
|
||||
if result.get("status") == "success":
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.notion import NotionKBSyncService
|
||||
|
||||
kb_service = NotionKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
page_id=result.get("page_id"),
|
||||
page_title=result.get("title", final_title),
|
||||
page_url=result.get("url"),
|
||||
content=final_content,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
result["message"] = result.get("message", "") + kb_message_suffix
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -95,8 +95,19 @@ def create_delete_notion_page_tool(
|
|||
"message": error_msg,
|
||||
}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Notion account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
}
|
||||
|
||||
page_id = context.get("page_id")
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
connector_id_from_context = account.get("id")
|
||||
document_id = context.get("document_id")
|
||||
|
||||
logger.info(
|
||||
|
|
@ -262,6 +273,18 @@ def create_delete_notion_page_tool(
|
|||
raise
|
||||
|
||||
logger.error(f"Error deleting Notion page: {e}", exc_info=True)
|
||||
error_str = str(e).lower()
|
||||
if isinstance(e, NotionAPIError) and (
|
||||
"401" in error_str or "unauthorized" in error_str
|
||||
):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": str(e),
|
||||
"connector_id": connector_id_from_context
|
||||
if "connector_id_from_context" in dir()
|
||||
else None,
|
||||
"connector_type": "notion",
|
||||
}
|
||||
if isinstance(e, ValueError | NotionAPIError):
|
||||
message = str(e)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -33,16 +33,19 @@ def create_update_notion_page_tool(
|
|||
@tool
|
||||
async def update_notion_page(
|
||||
page_title: str,
|
||||
content: str,
|
||||
content: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing Notion page by appending new content.
|
||||
|
||||
Use this tool when the user asks you to add content to, modify, or update
|
||||
a Notion page. The new content will be appended to the existing page content.
|
||||
The user MUST specify what to add before you call this tool. If the
|
||||
request is vague, ask what content they want added.
|
||||
|
||||
Args:
|
||||
page_title: The title of the Notion page to update.
|
||||
content: The markdown content to append to the page body (supports headings, lists, paragraphs).
|
||||
content: Optional markdown content to append to the page body (supports headings, lists, paragraphs).
|
||||
Generate this yourself based on the user's request.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
|
|
@ -60,10 +63,9 @@ def create_update_notion_page_tool(
|
|||
Example: "I couldn't find the page '[page_title]' in your indexed Notion pages. [message details]"
|
||||
Do NOT treat this as an error. Do NOT invent information. Simply relay the message and
|
||||
ask the user to verify the page title or check if it's been indexed.
|
||||
|
||||
Examples:
|
||||
- "Add 'New meeting notes from today' to the 'Meeting Notes' Notion page"
|
||||
- "Append the following to the 'Project Plan' Notion page: '# Status Update\n\nCompleted phase 1'"
|
||||
- "Add today's meeting notes to the 'Meeting Notes' Notion page"
|
||||
- "Update the 'Project Plan' page with a status update on phase 1"
|
||||
"""
|
||||
logger.info(
|
||||
f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}"
|
||||
|
|
@ -107,6 +109,17 @@ def create_update_notion_page_tool(
|
|||
"message": error_msg,
|
||||
}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Notion account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
}
|
||||
|
||||
page_id = context.get("page_id")
|
||||
document_id = context.get("document_id")
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
|
@ -261,6 +274,18 @@ def create_update_notion_page_tool(
|
|||
raise
|
||||
|
||||
logger.error(f"Error updating Notion page: {e}", exc_info=True)
|
||||
error_str = str(e).lower()
|
||||
if isinstance(e, NotionAPIError) and (
|
||||
"401" in error_str or "unauthorized" in error_str
|
||||
):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": str(e),
|
||||
"connector_id": connector_id_from_context
|
||||
if "connector_id_from_context" in dir()
|
||||
else None,
|
||||
"connector_type": "notion",
|
||||
}
|
||||
if isinstance(e, ValueError | NotionAPIError):
|
||||
message = str(e)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -45,12 +45,33 @@ from langchain_core.tools import BaseTool
|
|||
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from .confluence import (
|
||||
create_create_confluence_page_tool,
|
||||
create_delete_confluence_page_tool,
|
||||
create_update_confluence_page_tool,
|
||||
)
|
||||
from .display_image import create_display_image_tool
|
||||
from .generate_image import create_generate_image_tool
|
||||
from .gmail import (
|
||||
create_create_gmail_draft_tool,
|
||||
create_send_gmail_email_tool,
|
||||
create_trash_gmail_email_tool,
|
||||
create_update_gmail_draft_tool,
|
||||
)
|
||||
from .google_calendar import (
|
||||
create_create_calendar_event_tool,
|
||||
create_delete_calendar_event_tool,
|
||||
create_update_calendar_event_tool,
|
||||
)
|
||||
from .google_drive import (
|
||||
create_create_google_drive_file_tool,
|
||||
create_delete_google_drive_file_tool,
|
||||
)
|
||||
from .jira import (
|
||||
create_create_jira_issue_tool,
|
||||
create_delete_jira_issue_tool,
|
||||
create_update_jira_issue_tool,
|
||||
)
|
||||
from .knowledge_base import create_search_knowledge_base_tool
|
||||
from .linear import (
|
||||
create_create_linear_issue_tool,
|
||||
|
|
@ -257,7 +278,8 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
|
||||
),
|
||||
# =========================================================================
|
||||
# LINEAR TOOLS - create, update, delete issues (WIP - hidden from UI)
|
||||
# LINEAR TOOLS - create, update, delete issues
|
||||
# Auto-disabled when no Linear connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_linear_issue",
|
||||
|
|
@ -268,8 +290,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
enabled_by_default=False,
|
||||
hidden=True,
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_linear_issue",
|
||||
|
|
@ -280,8 +300,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
enabled_by_default=False,
|
||||
hidden=True,
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_linear_issue",
|
||||
|
|
@ -292,11 +310,10 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
enabled_by_default=False,
|
||||
hidden=True,
|
||||
),
|
||||
# =========================================================================
|
||||
# NOTION TOOLS - create, update, delete pages (WIP - hidden from UI)
|
||||
# NOTION TOOLS - create, update, delete pages
|
||||
# Auto-disabled when no Notion connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_notion_page",
|
||||
|
|
@ -307,8 +324,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
enabled_by_default=False,
|
||||
hidden=True,
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_notion_page",
|
||||
|
|
@ -319,8 +334,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
enabled_by_default=False,
|
||||
hidden=True,
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_notion_page",
|
||||
|
|
@ -331,11 +344,10 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
enabled_by_default=False,
|
||||
hidden=True,
|
||||
),
|
||||
# =========================================================================
|
||||
# GOOGLE DRIVE TOOLS - create files, delete files (WIP - hidden from UI)
|
||||
# GOOGLE DRIVE TOOLS - create files, delete files
|
||||
# Auto-disabled when no Google Drive connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_google_drive_file",
|
||||
|
|
@ -346,8 +358,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
enabled_by_default=False,
|
||||
hidden=True,
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_google_drive_file",
|
||||
|
|
@ -358,8 +368,152 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
enabled_by_default=False,
|
||||
hidden=True,
|
||||
),
|
||||
# =========================================================================
|
||||
# GOOGLE CALENDAR TOOLS - create, update, delete events
|
||||
# Auto-disabled when no Google Calendar connector is configured
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_calendar_event",
|
||||
description="Create a new event on Google Calendar",
|
||||
factory=lambda deps: create_create_calendar_event_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_calendar_event",
|
||||
description="Update an existing indexed Google Calendar event",
|
||||
factory=lambda deps: create_update_calendar_event_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_calendar_event",
|
||||
description="Delete an existing indexed Google Calendar event",
|
||||
factory=lambda deps: create_delete_calendar_event_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
# =========================================================================
|
||||
# GMAIL TOOLS - create drafts, update drafts, send emails, trash emails
|
||||
# Auto-disabled when no Gmail connector is configured
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_gmail_draft",
|
||||
description="Create a draft email in Gmail",
|
||||
factory=lambda deps: create_create_gmail_draft_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="send_gmail_email",
|
||||
description="Send an email via Gmail",
|
||||
factory=lambda deps: create_send_gmail_email_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="trash_gmail_email",
|
||||
description="Move an indexed email to trash in Gmail",
|
||||
factory=lambda deps: create_trash_gmail_email_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_gmail_draft",
|
||||
description="Update an existing Gmail draft",
|
||||
factory=lambda deps: create_update_gmail_draft_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
# =========================================================================
|
||||
# JIRA TOOLS - create, update, delete issues
|
||||
# Auto-disabled when no Jira connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_jira_issue",
|
||||
description="Create a new issue in the user's Jira project",
|
||||
factory=lambda deps: create_create_jira_issue_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_jira_issue",
|
||||
description="Update an existing indexed Jira issue",
|
||||
factory=lambda deps: create_update_jira_issue_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_jira_issue",
|
||||
description="Delete an existing indexed Jira issue",
|
||||
factory=lambda deps: create_delete_jira_issue_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
# =========================================================================
|
||||
# CONFLUENCE TOOLS - create, update, delete pages
|
||||
# Auto-disabled when no Confluence connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_confluence_page",
|
||||
description="Create a new page in the user's Confluence space",
|
||||
factory=lambda deps: create_create_confluence_page_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_confluence_page",
|
||||
description="Update an existing indexed Confluence page",
|
||||
factory=lambda deps: create_update_confluence_page_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_confluence_page",
|
||||
description="Delete an existing indexed Confluence page",
|
||||
factory=lambda deps: create_delete_confluence_page_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,719 +0,0 @@
|
|||
"""
|
||||
Composio Gmail Connector Module.
|
||||
|
||||
Provides Gmail specific methods for data retrieval and indexing via Composio.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from markdownify import markdownify as md
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.connectors.composio_connector import ComposioConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.services.composio_service import TOOLKIT_TO_DOCUMENT_TYPE
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.connector_indexers.base import (
|
||||
calculate_date_range,
|
||||
check_duplicate_document_by_hash,
|
||||
safe_set_chunks,
|
||||
)
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
||||
# Heartbeat configuration
|
||||
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_current_timestamp() -> datetime:
|
||||
"""Get the current timestamp with timezone for updated_at field."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
async def check_document_by_unique_identifier(
|
||||
session: AsyncSession, unique_identifier_hash: str
|
||||
) -> Document | None:
|
||||
"""Check if a document with the given unique identifier hash already exists."""
|
||||
existing_doc_result = await session.execute(
|
||||
select(Document)
|
||||
.options(selectinload(Document.chunks))
|
||||
.where(Document.unique_identifier_hash == unique_identifier_hash)
|
||||
)
|
||||
return existing_doc_result.scalars().first()
|
||||
|
||||
|
||||
async def update_connector_last_indexed(
|
||||
session: AsyncSession,
|
||||
connector,
|
||||
update_last_indexed: bool = True,
|
||||
) -> None:
|
||||
"""Update the last_indexed_at timestamp for a connector."""
|
||||
if update_last_indexed:
|
||||
connector.last_indexed_at = datetime.now(UTC)
|
||||
logger.info(f"Updated last_indexed_at to {connector.last_indexed_at}")
|
||||
|
||||
|
||||
class ComposioGmailConnector(ComposioConnector):
|
||||
"""
|
||||
Gmail specific Composio connector.
|
||||
|
||||
Provides methods for listing messages, getting message details, and formatting
|
||||
Gmail messages from Gmail via Composio.
|
||||
"""
|
||||
|
||||
async def list_gmail_messages(
|
||||
self,
|
||||
query: str = "",
|
||||
max_results: int = 50,
|
||||
page_token: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], str | None, int | None, str | None]:
|
||||
"""
|
||||
List Gmail messages via Composio with pagination support.
|
||||
|
||||
Args:
|
||||
query: Gmail search query.
|
||||
max_results: Maximum number of messages per page (default: 50).
|
||||
page_token: Optional pagination token for next page.
|
||||
|
||||
Returns:
|
||||
Tuple of (messages list, next_page_token, result_size_estimate, error message).
|
||||
"""
|
||||
connected_account_id = await self.get_connected_account_id()
|
||||
if not connected_account_id:
|
||||
return [], None, None, "No connected account ID found"
|
||||
|
||||
entity_id = await self.get_entity_id()
|
||||
service = await self._get_service()
|
||||
return await service.get_gmail_messages(
|
||||
connected_account_id=connected_account_id,
|
||||
entity_id=entity_id,
|
||||
query=query,
|
||||
max_results=max_results,
|
||||
page_token=page_token,
|
||||
)
|
||||
|
||||
async def get_gmail_message_detail(
|
||||
self, message_id: str
|
||||
) -> tuple[dict[str, Any] | None, str | None]:
|
||||
"""
|
||||
Get full details of a Gmail message via Composio.
|
||||
|
||||
Args:
|
||||
message_id: Gmail message ID.
|
||||
|
||||
Returns:
|
||||
Tuple of (message details, error message).
|
||||
"""
|
||||
connected_account_id = await self.get_connected_account_id()
|
||||
if not connected_account_id:
|
||||
return None, "No connected account ID found"
|
||||
|
||||
entity_id = await self.get_entity_id()
|
||||
service = await self._get_service()
|
||||
return await service.get_gmail_message_detail(
|
||||
connected_account_id=connected_account_id,
|
||||
entity_id=entity_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _html_to_markdown(html: str) -> str:
|
||||
"""Convert HTML (especially email layouts with nested tables) to clean markdown."""
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
for tag in soup.find_all(["style", "script", "img"]):
|
||||
tag.decompose()
|
||||
for tag in soup.find_all(
|
||||
["table", "thead", "tbody", "tfoot", "tr", "td", "th"]
|
||||
):
|
||||
tag.unwrap()
|
||||
return md(str(soup)).strip()
|
||||
|
||||
def format_gmail_message_to_markdown(self, message: dict[str, Any]) -> str:
|
||||
"""
|
||||
Format a Gmail message to markdown.
|
||||
|
||||
Args:
|
||||
message: Message object from Composio's GMAIL_FETCH_EMAILS response.
|
||||
Composio structure: messageId, messageText, messageTimestamp,
|
||||
payload.headers, labelIds, attachmentList
|
||||
|
||||
Returns:
|
||||
Formatted markdown string.
|
||||
"""
|
||||
try:
|
||||
# Composio uses 'messageId' (camelCase)
|
||||
message_id = message.get("messageId", "") or message.get("id", "")
|
||||
label_ids = message.get("labelIds", [])
|
||||
|
||||
# Extract headers from payload
|
||||
payload = message.get("payload", {})
|
||||
headers = payload.get("headers", [])
|
||||
|
||||
# Parse headers into a dict
|
||||
header_dict = {}
|
||||
for header in headers:
|
||||
name = header.get("name", "").lower()
|
||||
value = header.get("value", "")
|
||||
header_dict[name] = value
|
||||
|
||||
# Extract key information
|
||||
subject = header_dict.get("subject", "No Subject")
|
||||
from_email = header_dict.get("from", "Unknown Sender")
|
||||
to_email = header_dict.get("to", "Unknown Recipient")
|
||||
# Composio provides messageTimestamp directly
|
||||
date_str = message.get("messageTimestamp", "") or header_dict.get(
|
||||
"date", "Unknown Date"
|
||||
)
|
||||
|
||||
# Build markdown content
|
||||
markdown_content = f"# {subject}\n\n"
|
||||
markdown_content += f"**From:** {from_email}\n"
|
||||
markdown_content += f"**To:** {to_email}\n"
|
||||
markdown_content += f"**Date:** {date_str}\n"
|
||||
|
||||
if label_ids:
|
||||
markdown_content += f"**Labels:** {', '.join(label_ids)}\n"
|
||||
|
||||
markdown_content += "\n---\n\n"
|
||||
|
||||
# Composio provides full message text in 'messageText' which is often raw HTML
|
||||
message_text = message.get("messageText", "")
|
||||
if message_text:
|
||||
message_text = self._html_to_markdown(message_text)
|
||||
markdown_content += f"## Content\n\n{message_text}\n\n"
|
||||
else:
|
||||
# Fallback to snippet if no messageText
|
||||
snippet = message.get("snippet", "")
|
||||
if snippet:
|
||||
markdown_content += f"## Preview\n\n{snippet}\n\n"
|
||||
|
||||
# Add attachment info if present
|
||||
attachments = message.get("attachmentList", [])
|
||||
if attachments:
|
||||
markdown_content += "## Attachments\n\n"
|
||||
for att in attachments:
|
||||
att_name = att.get("filename", att.get("name", "Unknown"))
|
||||
markdown_content += f"- {att_name}\n"
|
||||
markdown_content += "\n"
|
||||
|
||||
# Add message metadata
|
||||
markdown_content += "## Message Details\n\n"
|
||||
markdown_content += f"- **Message ID:** {message_id}\n"
|
||||
|
||||
return markdown_content
|
||||
|
||||
except Exception as e:
|
||||
return f"Error formatting message to markdown: {e!s}"
|
||||
|
||||
|
||||
# ============ Indexer Functions ============
|
||||
|
||||
|
||||
async def _analyze_gmail_messages_phase1(
|
||||
session: AsyncSession,
|
||||
messages: list[dict[str, Any]],
|
||||
composio_connector: ComposioGmailConnector,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> tuple[list[dict[str, Any]], int, int]:
|
||||
"""
|
||||
Phase 1: Analyze all messages, create pending documents.
|
||||
Makes ALL documents visible in the UI immediately with pending status.
|
||||
|
||||
Returns:
|
||||
Tuple of (messages_to_process, documents_skipped, duplicate_content_count)
|
||||
"""
|
||||
messages_to_process = []
|
||||
documents_skipped = 0
|
||||
duplicate_content_count = 0
|
||||
|
||||
for message in messages:
|
||||
try:
|
||||
# Composio uses 'messageId' (camelCase), not 'id'
|
||||
message_id = message.get("messageId", "") or message.get("id", "")
|
||||
if not message_id:
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Extract message info from Composio response
|
||||
payload = message.get("payload", {})
|
||||
headers = payload.get("headers", [])
|
||||
|
||||
subject = "No Subject"
|
||||
sender = "Unknown Sender"
|
||||
date_str = message.get("messageTimestamp", "Unknown Date")
|
||||
|
||||
for header in headers:
|
||||
name = header.get("name", "").lower()
|
||||
value = header.get("value", "")
|
||||
if name == "subject":
|
||||
subject = value
|
||||
elif name == "from":
|
||||
sender = value
|
||||
elif name == "date":
|
||||
date_str = value
|
||||
|
||||
# Format to markdown using the full message data
|
||||
markdown_content = composio_connector.format_gmail_message_to_markdown(
|
||||
message
|
||||
)
|
||||
|
||||
# Check for empty content
|
||||
if not markdown_content.strip():
|
||||
logger.warning(f"Skipping Gmail message with no content: {subject}")
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Generate unique identifier
|
||||
document_type = DocumentType(TOOLKIT_TO_DOCUMENT_TYPE["gmail"])
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
document_type, f"gmail_{message_id}", search_space_id
|
||||
)
|
||||
|
||||
content_hash = generate_content_hash(markdown_content, search_space_id)
|
||||
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, unique_identifier_hash
|
||||
)
|
||||
|
||||
# Get label IDs and thread_id from Composio response
|
||||
label_ids = message.get("labelIds", [])
|
||||
thread_id = message.get("threadId", "") or message.get("thread_id", "")
|
||||
|
||||
if existing_document:
|
||||
if existing_document.content_hash == content_hash:
|
||||
# Ensure status is ready (might have been stuck in processing/pending)
|
||||
if not DocumentStatus.is_state(
|
||||
existing_document.status, DocumentStatus.READY
|
||||
):
|
||||
existing_document.status = DocumentStatus.ready()
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Queue existing document for update (will be set to processing in Phase 2)
|
||||
messages_to_process.append(
|
||||
{
|
||||
"document": existing_document,
|
||||
"is_new": False,
|
||||
"markdown_content": markdown_content,
|
||||
"content_hash": content_hash,
|
||||
"message_id": message_id,
|
||||
"thread_id": thread_id,
|
||||
"subject": subject,
|
||||
"sender": sender,
|
||||
"date_str": date_str,
|
||||
"label_ids": label_ids,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Document doesn't exist by unique_identifier_hash
|
||||
# Check if a document with the same content_hash exists (from standard connector)
|
||||
with session.no_autoflush:
|
||||
duplicate_by_content = await check_duplicate_document_by_hash(
|
||||
session, content_hash
|
||||
)
|
||||
|
||||
if duplicate_by_content:
|
||||
logger.info(
|
||||
f"Message {subject} already indexed by another connector "
|
||||
f"(existing document ID: {duplicate_by_content.id}, "
|
||||
f"type: {duplicate_by_content.document_type}). Skipping."
|
||||
)
|
||||
duplicate_content_count += 1
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Create new document with PENDING status (visible in UI immediately)
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=subject,
|
||||
document_type=DocumentType(TOOLKIT_TO_DOCUMENT_TYPE["gmail"]),
|
||||
document_metadata={
|
||||
"message_id": message_id,
|
||||
"thread_id": thread_id,
|
||||
"subject": subject,
|
||||
"sender": sender,
|
||||
"date": date_str,
|
||||
"labels": label_ids,
|
||||
"connector_id": connector_id,
|
||||
"toolkit_id": "gmail",
|
||||
"source": "composio",
|
||||
},
|
||||
content="Pending...", # Placeholder until processed
|
||||
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
embedding=None,
|
||||
chunks=[], # Empty at creation - safe for async
|
||||
status=DocumentStatus.pending(), # Pending until processing starts
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
connector_id=connector_id,
|
||||
)
|
||||
session.add(document)
|
||||
|
||||
messages_to_process.append(
|
||||
{
|
||||
"document": document,
|
||||
"is_new": True,
|
||||
"markdown_content": markdown_content,
|
||||
"content_hash": content_hash,
|
||||
"message_id": message_id,
|
||||
"thread_id": thread_id,
|
||||
"subject": subject,
|
||||
"sender": sender,
|
||||
"date_str": date_str,
|
||||
"label_ids": label_ids,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Phase 1 for message: {e!s}", exc_info=True)
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
return messages_to_process, documents_skipped, duplicate_content_count
|
||||
|
||||
|
||||
async def _process_gmail_messages_phase2(
|
||||
session: AsyncSession,
|
||||
messages_to_process: list[dict[str, Any]],
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
enable_summary: bool = False,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Phase 2: Process each document one by one.
|
||||
Each document transitions: pending → processing → ready/failed
|
||||
|
||||
Returns:
|
||||
Tuple of (documents_indexed, documents_failed)
|
||||
"""
|
||||
documents_indexed = 0
|
||||
documents_failed = 0
|
||||
last_heartbeat_time = time.time()
|
||||
|
||||
for item in messages_to_process:
|
||||
# Send heartbeat periodically
|
||||
if on_heartbeat_callback:
|
||||
current_time = time.time()
|
||||
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat_callback(documents_indexed)
|
||||
last_heartbeat_time = current_time
|
||||
|
||||
document = item["document"]
|
||||
try:
|
||||
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
|
||||
document.status = DocumentStatus.processing()
|
||||
await session.commit()
|
||||
|
||||
# Heavy processing (LLM, embeddings, chunks)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm and enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"message_id": item["message_id"],
|
||||
"thread_id": item["thread_id"],
|
||||
"subject": item["subject"],
|
||||
"sender": item["sender"],
|
||||
"document_type": "Gmail Message (Composio)",
|
||||
}
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
item["markdown_content"], user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
summary_content = f"Gmail: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}\n\n{item['markdown_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["markdown_content"])
|
||||
|
||||
# Update document to READY with actual content
|
||||
document.title = item["subject"]
|
||||
document.content = summary_content
|
||||
document.content_hash = item["content_hash"]
|
||||
document.embedding = summary_embedding
|
||||
document.document_metadata = {
|
||||
"message_id": item["message_id"],
|
||||
"thread_id": item["thread_id"],
|
||||
"subject": item["subject"],
|
||||
"sender": item["sender"],
|
||||
"date": item["date_str"],
|
||||
"labels": item["label_ids"],
|
||||
"connector_id": connector_id,
|
||||
"source": "composio",
|
||||
}
|
||||
await safe_set_chunks(session, document, chunks)
|
||||
document.updated_at = get_current_timestamp()
|
||||
document.status = DocumentStatus.ready()
|
||||
|
||||
documents_indexed += 1
|
||||
|
||||
# Batch commit every 10 documents (for ready status updates)
|
||||
if documents_indexed % 10 == 0:
|
||||
logger.info(
|
||||
f"Committing batch: {documents_indexed} Gmail messages processed so far"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Gmail message: {e!s}", exc_info=True)
|
||||
# Mark document as failed with reason (visible in UI)
|
||||
try:
|
||||
document.status = DocumentStatus.failed(str(e))
|
||||
document.updated_at = get_current_timestamp()
|
||||
except Exception as status_error:
|
||||
logger.error(
|
||||
f"Failed to update document status to failed: {status_error}"
|
||||
)
|
||||
documents_failed += 1
|
||||
continue
|
||||
|
||||
return documents_indexed, documents_failed
|
||||
|
||||
|
||||
async def index_composio_gmail(
|
||||
session: AsyncSession,
|
||||
connector,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str | None,
|
||||
end_date: str | None,
|
||||
task_logger: TaskLoggingService,
|
||||
log_entry,
|
||||
update_last_indexed: bool = True,
|
||||
max_items: int = 1000,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, str]:
|
||||
"""Index Gmail messages via Composio with real-time document status updates."""
|
||||
try:
|
||||
composio_connector = ComposioGmailConnector(session, connector_id)
|
||||
|
||||
# Normalize date values - handle "undefined" strings from frontend
|
||||
if start_date == "undefined" or start_date == "":
|
||||
start_date = None
|
||||
if end_date == "undefined" or end_date == "":
|
||||
end_date = None
|
||||
|
||||
# Use provided dates directly if both are provided, otherwise calculate from last_indexed_at
|
||||
if start_date is not None and end_date is not None:
|
||||
start_date_str = start_date
|
||||
end_date_str = end_date
|
||||
else:
|
||||
start_date_str, end_date_str = calculate_date_range(
|
||||
connector, start_date, end_date, default_days_back=365
|
||||
)
|
||||
|
||||
# Build query with date range
|
||||
query_parts = []
|
||||
if start_date_str:
|
||||
query_parts.append(f"after:{start_date_str.replace('-', '/')}")
|
||||
if end_date_str:
|
||||
query_parts.append(f"before:{end_date_str.replace('-', '/')}")
|
||||
query = " ".join(query_parts) if query_parts else ""
|
||||
|
||||
logger.info(
|
||||
f"Gmail query for connector {connector_id}: '{query}' "
|
||||
f"(start_date={start_date_str}, end_date={end_date_str})"
|
||||
)
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Fetching Gmail messages via Composio for connector {connector_id}",
|
||||
{"stage": "fetching_messages"},
|
||||
)
|
||||
|
||||
# =======================================================================
|
||||
# FETCH ALL MESSAGES FIRST
|
||||
# =======================================================================
|
||||
batch_size = 50
|
||||
page_token = None
|
||||
all_messages = []
|
||||
result_size_estimate = None
|
||||
last_heartbeat_time = time.time()
|
||||
|
||||
while len(all_messages) < max_items:
|
||||
# Send heartbeat periodically
|
||||
if on_heartbeat_callback:
|
||||
current_time = time.time()
|
||||
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat_callback(len(all_messages))
|
||||
last_heartbeat_time = current_time
|
||||
|
||||
remaining = max_items - len(all_messages)
|
||||
current_batch_size = min(batch_size, remaining)
|
||||
|
||||
(
|
||||
messages,
|
||||
next_token,
|
||||
result_size_estimate_batch,
|
||||
error,
|
||||
) = await composio_connector.list_gmail_messages(
|
||||
query=query,
|
||||
max_results=current_batch_size,
|
||||
page_token=page_token,
|
||||
)
|
||||
|
||||
if error:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry, f"Failed to fetch Gmail messages: {error}", {}
|
||||
)
|
||||
return 0, f"Failed to fetch Gmail messages: {error}"
|
||||
|
||||
if not messages:
|
||||
break
|
||||
|
||||
if result_size_estimate is None and result_size_estimate_batch is not None:
|
||||
result_size_estimate = result_size_estimate_batch
|
||||
logger.info(
|
||||
f"Gmail API estimated {result_size_estimate} total messages for query: '{query}'"
|
||||
)
|
||||
|
||||
all_messages.extend(messages)
|
||||
logger.info(
|
||||
f"Fetched {len(messages)} messages (total: {len(all_messages)})"
|
||||
)
|
||||
|
||||
if not next_token or len(messages) < current_batch_size:
|
||||
break
|
||||
|
||||
page_token = next_token
|
||||
|
||||
if not all_messages:
|
||||
success_msg = "No Gmail messages found in the specified date range"
|
||||
await task_logger.log_task_success(
|
||||
log_entry, success_msg, {"messages_count": 0}
|
||||
)
|
||||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||
await session.commit()
|
||||
return (
|
||||
0,
|
||||
None,
|
||||
) # Return None (not error) when no items found - this is success with 0 items
|
||||
|
||||
logger.info(f"Found {len(all_messages)} Gmail messages to index via Composio")
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 1: Analyze all messages, create pending documents
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
# =======================================================================
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Phase 1: Creating pending documents for {len(all_messages)} messages",
|
||||
{"stage": "phase1_pending"},
|
||||
)
|
||||
|
||||
(
|
||||
messages_to_process,
|
||||
documents_skipped,
|
||||
duplicate_content_count,
|
||||
) = await _analyze_gmail_messages_phase1(
|
||||
session=session,
|
||||
messages=all_messages,
|
||||
composio_connector=composio_connector,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Commit all pending documents - they all appear in UI now
|
||||
new_documents_count = len([m for m in messages_to_process if m["is_new"]])
|
||||
if new_documents_count > 0:
|
||||
logger.info(f"Phase 1: Committing {new_documents_count} pending documents")
|
||||
await session.commit()
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 2: Process each document one by one
|
||||
# Each document transitions: pending → processing → ready/failed
|
||||
# =======================================================================
|
||||
logger.info(f"Phase 2: Processing {len(messages_to_process)} documents")
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Phase 2: Processing {len(messages_to_process)} documents",
|
||||
{"stage": "phase2_processing"},
|
||||
)
|
||||
|
||||
documents_indexed, documents_failed = await _process_gmail_messages_phase2(
|
||||
session=session,
|
||||
messages_to_process=messages_to_process,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=getattr(connector, "enable_summary", False),
|
||||
on_heartbeat_callback=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
# CRITICAL: Always update timestamp so Zero syncs
|
||||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||
|
||||
# Final commit to ensure all documents are persisted
|
||||
logger.info(f"Final commit: Total {documents_indexed} Gmail messages processed")
|
||||
try:
|
||||
await session.commit()
|
||||
logger.info(
|
||||
"Successfully committed all Composio Gmail document changes to database"
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle any remaining integrity errors gracefully
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in str(e).lower()
|
||||
or "uniqueviolationerror" in str(e).lower()
|
||||
):
|
||||
logger.warning(
|
||||
f"Duplicate content_hash detected during final commit. "
|
||||
f"Rolling back and continuing. Error: {e!s}"
|
||||
)
|
||||
await session.rollback()
|
||||
else:
|
||||
raise
|
||||
|
||||
# Build warning message if there were issues
|
||||
warning_parts = []
|
||||
if duplicate_content_count > 0:
|
||||
warning_parts.append(f"{duplicate_content_count} duplicate")
|
||||
if documents_failed > 0:
|
||||
warning_parts.append(f"{documents_failed} failed")
|
||||
warning_message = ", ".join(warning_parts) if warning_parts else None
|
||||
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully completed Gmail indexing via Composio for connector {connector_id}",
|
||||
{
|
||||
"documents_indexed": documents_indexed,
|
||||
"documents_skipped": documents_skipped,
|
||||
"documents_failed": documents_failed,
|
||||
"duplicate_content_count": duplicate_content_count,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Composio Gmail indexing completed: {documents_indexed} ready, "
|
||||
f"{documents_skipped} skipped, {documents_failed} failed "
|
||||
f"({duplicate_content_count} duplicate content)"
|
||||
)
|
||||
return documents_indexed, warning_message
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to index Gmail via Composio: {e!s}", exc_info=True)
|
||||
return 0, f"Failed to index Gmail via Composio: {e!s}"
|
||||
|
|
@ -1,566 +0,0 @@
|
|||
"""
|
||||
Composio Google Calendar Connector Module.
|
||||
|
||||
Provides Google Calendar specific methods for data retrieval and indexing via Composio.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.connectors.composio_connector import ComposioConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.services.composio_service import TOOLKIT_TO_DOCUMENT_TYPE
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.connector_indexers.base import (
|
||||
calculate_date_range,
|
||||
check_duplicate_document_by_hash,
|
||||
safe_set_chunks,
|
||||
)
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
||||
# Heartbeat configuration
|
||||
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_current_timestamp() -> datetime:
|
||||
"""Get the current timestamp with timezone for updated_at field."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
async def check_document_by_unique_identifier(
|
||||
session: AsyncSession, unique_identifier_hash: str
|
||||
) -> Document | None:
|
||||
"""Check if a document with the given unique identifier hash already exists."""
|
||||
existing_doc_result = await session.execute(
|
||||
select(Document)
|
||||
.options(selectinload(Document.chunks))
|
||||
.where(Document.unique_identifier_hash == unique_identifier_hash)
|
||||
)
|
||||
return existing_doc_result.scalars().first()
|
||||
|
||||
|
||||
async def update_connector_last_indexed(
|
||||
session: AsyncSession,
|
||||
connector,
|
||||
update_last_indexed: bool = True,
|
||||
) -> None:
|
||||
"""Update the last_indexed_at timestamp for a connector."""
|
||||
if update_last_indexed:
|
||||
connector.last_indexed_at = datetime.now(UTC)
|
||||
logger.info(f"Updated last_indexed_at to {connector.last_indexed_at}")
|
||||
|
||||
|
||||
class ComposioGoogleCalendarConnector(ComposioConnector):
|
||||
"""
|
||||
Google Calendar specific Composio connector.
|
||||
|
||||
Provides methods for listing calendar events and formatting them from
|
||||
Google Calendar via Composio.
|
||||
"""
|
||||
|
||||
async def list_calendar_events(
|
||||
self,
|
||||
time_min: str | None = None,
|
||||
time_max: str | None = None,
|
||||
max_results: int = 250,
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
"""
|
||||
List Google Calendar events via Composio.
|
||||
|
||||
Args:
|
||||
time_min: Start time (RFC3339 format).
|
||||
time_max: End time (RFC3339 format).
|
||||
max_results: Maximum number of events.
|
||||
|
||||
Returns:
|
||||
Tuple of (events list, error message).
|
||||
"""
|
||||
connected_account_id = await self.get_connected_account_id()
|
||||
if not connected_account_id:
|
||||
return [], "No connected account ID found"
|
||||
|
||||
entity_id = await self.get_entity_id()
|
||||
service = await self._get_service()
|
||||
return await service.get_calendar_events(
|
||||
connected_account_id=connected_account_id,
|
||||
entity_id=entity_id,
|
||||
time_min=time_min,
|
||||
time_max=time_max,
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
def format_calendar_event_to_markdown(self, event: dict[str, Any]) -> str:
|
||||
"""
|
||||
Format a Google Calendar event to markdown.
|
||||
|
||||
Args:
|
||||
event: Event object from Google Calendar API.
|
||||
|
||||
Returns:
|
||||
Formatted markdown string.
|
||||
"""
|
||||
try:
|
||||
# Extract basic event information
|
||||
summary = event.get("summary", "No Title")
|
||||
description = event.get("description", "")
|
||||
location = event.get("location", "")
|
||||
|
||||
# Extract start and end times
|
||||
start = event.get("start", {})
|
||||
end = event.get("end", {})
|
||||
|
||||
start_time = start.get("dateTime") or start.get("date", "")
|
||||
end_time = end.get("dateTime") or end.get("date", "")
|
||||
|
||||
# Format times for display
|
||||
def format_time(time_str: str) -> str:
|
||||
if not time_str:
|
||||
return "Unknown"
|
||||
try:
|
||||
if "T" in time_str:
|
||||
dt = datetime.fromisoformat(time_str.replace("Z", "+00:00"))
|
||||
return dt.strftime("%Y-%m-%d %H:%M")
|
||||
return time_str
|
||||
except Exception:
|
||||
return time_str
|
||||
|
||||
start_formatted = format_time(start_time)
|
||||
end_formatted = format_time(end_time)
|
||||
|
||||
# Extract attendees
|
||||
attendees = event.get("attendees", [])
|
||||
attendee_list = []
|
||||
for attendee in attendees:
|
||||
email = attendee.get("email", "")
|
||||
display_name = attendee.get("displayName", email)
|
||||
response_status = attendee.get("responseStatus", "")
|
||||
attendee_list.append(f"- {display_name} ({response_status})")
|
||||
|
||||
# Build markdown content
|
||||
markdown_content = f"# {summary}\n\n"
|
||||
markdown_content += f"**Start:** {start_formatted}\n"
|
||||
markdown_content += f"**End:** {end_formatted}\n"
|
||||
|
||||
if location:
|
||||
markdown_content += f"**Location:** {location}\n"
|
||||
|
||||
markdown_content += "\n"
|
||||
|
||||
if description:
|
||||
markdown_content += f"## Description\n\n{description}\n\n"
|
||||
|
||||
if attendee_list:
|
||||
markdown_content += "## Attendees\n\n"
|
||||
markdown_content += "\n".join(attendee_list)
|
||||
markdown_content += "\n\n"
|
||||
|
||||
# Add event metadata
|
||||
markdown_content += "## Event Details\n\n"
|
||||
markdown_content += f"- **Event ID:** {event.get('id', 'Unknown')}\n"
|
||||
markdown_content += f"- **Created:** {event.get('created', 'Unknown')}\n"
|
||||
markdown_content += f"- **Updated:** {event.get('updated', 'Unknown')}\n"
|
||||
|
||||
return markdown_content
|
||||
|
||||
except Exception as e:
|
||||
return f"Error formatting event to markdown: {e!s}"
|
||||
|
||||
|
||||
# ============ Indexer Functions ============
|
||||
|
||||
|
||||
async def index_composio_google_calendar(
|
||||
session: AsyncSession,
|
||||
connector,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str | None,
|
||||
end_date: str | None,
|
||||
task_logger: TaskLoggingService,
|
||||
log_entry,
|
||||
update_last_indexed: bool = True,
|
||||
max_items: int = 2500,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, str]:
|
||||
"""Index Google Calendar events via Composio."""
|
||||
try:
|
||||
composio_connector = ComposioGoogleCalendarConnector(session, connector_id)
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Fetching Google Calendar events via Composio for connector {connector_id}",
|
||||
{"stage": "fetching_events"},
|
||||
)
|
||||
|
||||
# Normalize date values - handle "undefined" strings from frontend
|
||||
if start_date == "undefined" or start_date == "":
|
||||
start_date = None
|
||||
if end_date == "undefined" or end_date == "":
|
||||
end_date = None
|
||||
|
||||
# Use provided dates directly if both are provided, otherwise calculate from last_indexed_at
|
||||
# This ensures user-selected dates are respected (matching non-Composio Calendar connector behavior)
|
||||
if start_date is not None and end_date is not None:
|
||||
# User provided both dates - use them directly
|
||||
start_date_str = start_date
|
||||
end_date_str = end_date
|
||||
else:
|
||||
# Calculate date range with defaults (uses last_indexed_at or 365 days back)
|
||||
# This ensures indexing works even when user doesn't specify dates
|
||||
start_date_str, end_date_str = calculate_date_range(
|
||||
connector, start_date, end_date, default_days_back=365
|
||||
)
|
||||
|
||||
# Build time range for API call
|
||||
time_min = f"{start_date_str}T00:00:00Z"
|
||||
time_max = f"{end_date_str}T23:59:59Z"
|
||||
|
||||
logger.info(
|
||||
f"Google Calendar query for connector {connector_id}: "
|
||||
f"(start_date={start_date_str}, end_date={end_date_str})"
|
||||
)
|
||||
|
||||
events, error = await composio_connector.list_calendar_events(
|
||||
time_min=time_min,
|
||||
time_max=time_max,
|
||||
max_results=max_items,
|
||||
)
|
||||
|
||||
if error:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry, f"Failed to fetch Calendar events: {error}", {}
|
||||
)
|
||||
return 0, f"Failed to fetch Calendar events: {error}"
|
||||
|
||||
if not events:
|
||||
success_msg = "No Google Calendar events found in the specified date range"
|
||||
await task_logger.log_task_success(
|
||||
log_entry, success_msg, {"events_count": 0}
|
||||
)
|
||||
# CRITICAL: Update timestamp even when no events found so Zero syncs and UI shows indexed status
|
||||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||
await session.commit()
|
||||
return (
|
||||
0,
|
||||
None,
|
||||
) # Return None (not error) when no items found - this is success with 0 items
|
||||
|
||||
logger.info(f"Found {len(events)} Google Calendar events to index via Composio")
|
||||
|
||||
documents_indexed = 0
|
||||
documents_skipped = 0
|
||||
documents_failed = 0 # Track events that failed processing
|
||||
duplicate_content_count = (
|
||||
0 # Track events skipped due to duplicate content_hash
|
||||
)
|
||||
last_heartbeat_time = time.time()
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 1: Analyze all events, create pending documents
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
# =======================================================================
|
||||
events_to_process = [] # List of dicts with document and event data
|
||||
new_documents_created = False
|
||||
|
||||
for event in events:
|
||||
try:
|
||||
# Handle both standard Google API and potential Composio variations
|
||||
event_id = event.get("id", "") or event.get("eventId", "")
|
||||
summary = (
|
||||
event.get("summary", "") or event.get("title", "") or "No Title"
|
||||
)
|
||||
|
||||
if not event_id:
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Format to markdown
|
||||
markdown_content = composio_connector.format_calendar_event_to_markdown(
|
||||
event
|
||||
)
|
||||
|
||||
# Generate unique identifier
|
||||
document_type = DocumentType(TOOLKIT_TO_DOCUMENT_TYPE["googlecalendar"])
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
document_type, f"calendar_{event_id}", search_space_id
|
||||
)
|
||||
|
||||
content_hash = generate_content_hash(markdown_content, search_space_id)
|
||||
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, unique_identifier_hash
|
||||
)
|
||||
|
||||
# Extract event times
|
||||
start = event.get("start", {})
|
||||
end = event.get("end", {})
|
||||
start_time = start.get("dateTime") or start.get("date", "")
|
||||
end_time = end.get("dateTime") or end.get("date", "")
|
||||
location = event.get("location", "")
|
||||
|
||||
if existing_document:
|
||||
if existing_document.content_hash == content_hash:
|
||||
# Ensure status is ready (might have been stuck in processing/pending)
|
||||
if not DocumentStatus.is_state(
|
||||
existing_document.status, DocumentStatus.READY
|
||||
):
|
||||
existing_document.status = DocumentStatus.ready()
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Queue existing document for update (will be set to processing in Phase 2)
|
||||
events_to_process.append(
|
||||
{
|
||||
"document": existing_document,
|
||||
"is_new": False,
|
||||
"markdown_content": markdown_content,
|
||||
"content_hash": content_hash,
|
||||
"event_id": event_id,
|
||||
"summary": summary,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"location": location,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Document doesn't exist by unique_identifier_hash
|
||||
# Check if a document with the same content_hash exists (from standard connector)
|
||||
with session.no_autoflush:
|
||||
duplicate_by_content = await check_duplicate_document_by_hash(
|
||||
session, content_hash
|
||||
)
|
||||
|
||||
if duplicate_by_content:
|
||||
logger.info(
|
||||
f"Event {summary} already indexed by another connector "
|
||||
f"(existing document ID: {duplicate_by_content.id}, "
|
||||
f"type: {duplicate_by_content.document_type}). Skipping."
|
||||
)
|
||||
duplicate_content_count += 1
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Create new document with PENDING status (visible in UI immediately)
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=summary,
|
||||
document_type=DocumentType(
|
||||
TOOLKIT_TO_DOCUMENT_TYPE["googlecalendar"]
|
||||
),
|
||||
document_metadata={
|
||||
"event_id": event_id,
|
||||
"summary": summary,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"location": location,
|
||||
"connector_id": connector_id,
|
||||
"toolkit_id": "googlecalendar",
|
||||
"source": "composio",
|
||||
},
|
||||
content="Pending...", # Placeholder until processed
|
||||
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
embedding=None,
|
||||
chunks=[], # Empty at creation - safe for async
|
||||
status=DocumentStatus.pending(), # Pending until processing starts
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
connector_id=connector_id,
|
||||
)
|
||||
session.add(document)
|
||||
new_documents_created = True
|
||||
|
||||
events_to_process.append(
|
||||
{
|
||||
"document": document,
|
||||
"is_new": True,
|
||||
"markdown_content": markdown_content,
|
||||
"content_hash": content_hash,
|
||||
"event_id": event_id,
|
||||
"summary": summary,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"location": location,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Phase 1 for event: {e!s}", exc_info=True)
|
||||
documents_failed += 1
|
||||
continue
|
||||
|
||||
# Commit all pending documents - they all appear in UI now
|
||||
if new_documents_created:
|
||||
logger.info(
|
||||
f"Phase 1: Committing {len([e for e in events_to_process if e['is_new']])} pending documents"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 2: Process each document one by one
|
||||
# Each document transitions: pending → processing → ready/failed
|
||||
# =======================================================================
|
||||
logger.info(f"Phase 2: Processing {len(events_to_process)} documents")
|
||||
|
||||
for item in events_to_process:
|
||||
# Send heartbeat periodically
|
||||
if on_heartbeat_callback:
|
||||
current_time = time.time()
|
||||
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat_callback(documents_indexed)
|
||||
last_heartbeat_time = current_time
|
||||
|
||||
document = item["document"]
|
||||
try:
|
||||
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
|
||||
document.status = DocumentStatus.processing()
|
||||
await session.commit()
|
||||
|
||||
# Heavy processing (LLM, embeddings, chunks)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"event_id": item["event_id"],
|
||||
"summary": item["summary"],
|
||||
"start_time": item["start_time"],
|
||||
"document_type": "Google Calendar Event (Composio)",
|
||||
}
|
||||
(
|
||||
summary_content,
|
||||
summary_embedding,
|
||||
) = await generate_document_summary(
|
||||
item["markdown_content"],
|
||||
user_llm,
|
||||
document_metadata_for_summary,
|
||||
)
|
||||
else:
|
||||
summary_content = (
|
||||
f"Calendar: {item['summary']}\n\n{item['markdown_content']}"
|
||||
)
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["markdown_content"])
|
||||
|
||||
# Update document to READY with actual content
|
||||
document.title = item["summary"]
|
||||
document.content = summary_content
|
||||
document.content_hash = item["content_hash"]
|
||||
document.embedding = summary_embedding
|
||||
document.document_metadata = {
|
||||
"event_id": item["event_id"],
|
||||
"summary": item["summary"],
|
||||
"start_time": item["start_time"],
|
||||
"end_time": item["end_time"],
|
||||
"location": item["location"],
|
||||
"connector_id": connector_id,
|
||||
"source": "composio",
|
||||
}
|
||||
await safe_set_chunks(session, document, chunks)
|
||||
document.updated_at = get_current_timestamp()
|
||||
document.status = DocumentStatus.ready()
|
||||
|
||||
documents_indexed += 1
|
||||
|
||||
# Batch commit every 10 documents (for ready status updates)
|
||||
if documents_indexed % 10 == 0:
|
||||
logger.info(
|
||||
f"Committing batch: {documents_indexed} Google Calendar events processed so far"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Calendar event: {e!s}", exc_info=True)
|
||||
# Mark document as failed with reason (visible in UI)
|
||||
try:
|
||||
document.status = DocumentStatus.failed(str(e))
|
||||
document.updated_at = get_current_timestamp()
|
||||
except Exception as status_error:
|
||||
logger.error(
|
||||
f"Failed to update document status to failed: {status_error}"
|
||||
)
|
||||
documents_failed += 1
|
||||
continue
|
||||
|
||||
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
|
||||
# This ensures the UI shows "Last indexed" instead of "Never indexed"
|
||||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||
|
||||
# Final commit to ensure all documents are persisted (safety net)
|
||||
# This matches the pattern used in non-Composio Gmail indexer
|
||||
logger.info(
|
||||
f"Final commit: Total {documents_indexed} Google Calendar events processed"
|
||||
)
|
||||
try:
|
||||
await session.commit()
|
||||
logger.info(
|
||||
"Successfully committed all Composio Google Calendar document changes to database"
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle any remaining integrity errors gracefully (race conditions, etc.)
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in str(e).lower()
|
||||
or "uniqueviolationerror" in str(e).lower()
|
||||
):
|
||||
logger.warning(
|
||||
f"Duplicate content_hash detected during final commit. "
|
||||
f"This may occur if the same event was indexed by multiple connectors. "
|
||||
f"Rolling back and continuing. Error: {e!s}"
|
||||
)
|
||||
await session.rollback()
|
||||
# Don't fail the entire task - some documents may have been successfully indexed
|
||||
else:
|
||||
raise
|
||||
|
||||
# Build warning message if there were issues
|
||||
warning_parts = []
|
||||
if duplicate_content_count > 0:
|
||||
warning_parts.append(f"{duplicate_content_count} duplicate")
|
||||
if documents_failed > 0:
|
||||
warning_parts.append(f"{documents_failed} failed")
|
||||
warning_message = ", ".join(warning_parts) if warning_parts else None
|
||||
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully completed Google Calendar indexing via Composio for connector {connector_id}",
|
||||
{
|
||||
"documents_indexed": documents_indexed,
|
||||
"documents_skipped": documents_skipped,
|
||||
"documents_failed": documents_failed,
|
||||
"duplicate_content_count": duplicate_content_count,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Composio Google Calendar indexing completed: {documents_indexed} ready, "
|
||||
f"{documents_skipped} skipped, {documents_failed} failed "
|
||||
f"({duplicate_content_count} duplicate content)"
|
||||
)
|
||||
return documents_indexed, warning_message
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to index Google Calendar via Composio: {e!s}", exc_info=True
|
||||
)
|
||||
return 0, f"Failed to index Google Calendar via Composio: {e!s}"
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -14,7 +14,6 @@ from sqlalchemy.future import select
|
|||
from app.config import config
|
||||
from app.connectors.confluence_connector import ConfluenceConnector
|
||||
from app.db import SearchSourceConnector
|
||||
from app.routes.confluence_add_connector_route import refresh_confluence_token
|
||||
from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
|
|
@ -190,7 +189,11 @@ class ConfluenceHistoryConnector:
|
|||
f"Connector {self._connector_id} not found; cannot refresh token."
|
||||
)
|
||||
|
||||
# Refresh token
|
||||
# Lazy import to avoid circular dependency
|
||||
from app.routes.confluence_add_connector_route import (
|
||||
refresh_confluence_token,
|
||||
)
|
||||
|
||||
connector = await refresh_confluence_token(self._session, connector)
|
||||
|
||||
# Reload credentials after refresh
|
||||
|
|
@ -341,6 +344,61 @@ class ConfluenceHistoryConnector:
|
|||
logger.error(f"Confluence API request error: {e!s}", exc_info=True)
|
||||
raise Exception(f"Confluence API request failed: {e!s}") from e
|
||||
|
||||
async def _make_api_request_with_method(
|
||||
self,
|
||||
endpoint: str,
|
||||
method: str = "GET",
|
||||
json_payload: dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Make a request to the Confluence API with a specified HTTP method."""
|
||||
if not self._use_oauth:
|
||||
raise ValueError("Write operations require OAuth authentication")
|
||||
|
||||
token = await self._get_valid_token()
|
||||
base_url = await self._get_base_url()
|
||||
http_client = await self._get_client()
|
||||
|
||||
url = f"{base_url}/wiki/api/v2/{endpoint}"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
try:
|
||||
method_upper = method.upper()
|
||||
if method_upper == "POST":
|
||||
response = await http_client.post(
|
||||
url, headers=headers, json=json_payload, params=params
|
||||
)
|
||||
elif method_upper == "PUT":
|
||||
response = await http_client.put(
|
||||
url, headers=headers, json=json_payload, params=params
|
||||
)
|
||||
elif method_upper == "DELETE":
|
||||
response = await http_client.delete(url, headers=headers, params=params)
|
||||
else:
|
||||
response = await http_client.get(url, headers=headers, params=params)
|
||||
|
||||
response.raise_for_status()
|
||||
if response.status_code == 204 or not response.text:
|
||||
return {"status": "success"}
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_detail = {
|
||||
"status_code": e.response.status_code,
|
||||
"url": str(e.request.url),
|
||||
"response_text": e.response.text,
|
||||
}
|
||||
logger.error(f"Confluence API HTTP error: {error_detail}")
|
||||
raise Exception(
|
||||
f"Confluence API request failed (HTTP {e.response.status_code}): {e.response.text}"
|
||||
) from e
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Confluence API request error: {e!s}", exc_info=True)
|
||||
raise Exception(f"Confluence API request failed: {e!s}") from e
|
||||
|
||||
async def get_all_spaces(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch all spaces from Confluence.
|
||||
|
|
@ -593,6 +651,65 @@ class ConfluenceHistoryConnector:
|
|||
except Exception as e:
|
||||
return [], f"Error fetching pages: {e!s}"
|
||||
|
||||
async def get_page(self, page_id: str) -> dict[str, Any]:
|
||||
"""Fetch a single page by ID with body content."""
|
||||
return await self._make_api_request(
|
||||
f"pages/{page_id}", params={"body-format": "storage"}
|
||||
)
|
||||
|
||||
async def create_page(
|
||||
self,
|
||||
space_id: str,
|
||||
title: str,
|
||||
body: str,
|
||||
parent_page_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new Confluence page."""
|
||||
payload: dict[str, Any] = {
|
||||
"spaceId": space_id,
|
||||
"title": title,
|
||||
"body": {
|
||||
"representation": "storage",
|
||||
"value": body,
|
||||
},
|
||||
"status": "current",
|
||||
}
|
||||
if parent_page_id:
|
||||
payload["parentId"] = parent_page_id
|
||||
return await self._make_api_request_with_method(
|
||||
"pages", method="POST", json_payload=payload
|
||||
)
|
||||
|
||||
async def update_page(
|
||||
self,
|
||||
page_id: str,
|
||||
title: str,
|
||||
body: str,
|
||||
version_number: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing Confluence page (requires version number)."""
|
||||
payload: dict[str, Any] = {
|
||||
"id": page_id,
|
||||
"title": title,
|
||||
"body": {
|
||||
"representation": "storage",
|
||||
"value": body,
|
||||
},
|
||||
"version": {
|
||||
"number": version_number,
|
||||
},
|
||||
"status": "current",
|
||||
}
|
||||
return await self._make_api_request_with_method(
|
||||
f"pages/{page_id}", method="PUT", json_payload=payload
|
||||
)
|
||||
|
||||
async def delete_page(self, page_id: str) -> dict[str, Any]:
|
||||
"""Delete a Confluence page."""
|
||||
return await self._make_api_request_with_method(
|
||||
f"pages/{page_id}", method="DELETE"
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP client connection."""
|
||||
if self._http_client:
|
||||
|
|
|
|||
|
|
@ -52,44 +52,39 @@ class GoogleCalendarConnector:
|
|||
) -> Credentials:
|
||||
"""
|
||||
Get valid Google OAuth credentials.
|
||||
Returns:
|
||||
Google OAuth credentials
|
||||
Raises:
|
||||
ValueError: If credentials have not been set
|
||||
Exception: If credential refresh fails
|
||||
|
||||
Supports both native OAuth (with refresh_token) and Composio-sourced
|
||||
credentials (with refresh_handler). For Composio credentials, validation
|
||||
and DB persistence are skipped since Composio manages its own tokens.
|
||||
"""
|
||||
if not all(
|
||||
[
|
||||
self._credentials.client_id,
|
||||
self._credentials.client_secret,
|
||||
self._credentials.refresh_token,
|
||||
]
|
||||
has_standard_refresh = bool(self._credentials.refresh_token)
|
||||
|
||||
if has_standard_refresh and not all(
|
||||
[self._credentials.client_id, self._credentials.client_secret]
|
||||
):
|
||||
raise ValueError(
|
||||
"Google OAuth credentials (client_id, client_secret, refresh_token) must be set"
|
||||
"Google OAuth credentials (client_id, client_secret) must be set"
|
||||
)
|
||||
|
||||
if self._credentials and not self._credentials.expired:
|
||||
return self._credentials
|
||||
|
||||
# Create credentials from refresh token
|
||||
self._credentials = Credentials(
|
||||
token=self._credentials.token,
|
||||
refresh_token=self._credentials.refresh_token,
|
||||
token_uri=self._credentials.token_uri,
|
||||
client_id=self._credentials.client_id,
|
||||
client_secret=self._credentials.client_secret,
|
||||
scopes=self._credentials.scopes,
|
||||
expiry=self._credentials.expiry,
|
||||
)
|
||||
if has_standard_refresh:
|
||||
self._credentials = Credentials(
|
||||
token=self._credentials.token,
|
||||
refresh_token=self._credentials.refresh_token,
|
||||
token_uri=self._credentials.token_uri,
|
||||
client_id=self._credentials.client_id,
|
||||
client_secret=self._credentials.client_secret,
|
||||
scopes=self._credentials.scopes,
|
||||
expiry=self._credentials.expiry,
|
||||
)
|
||||
|
||||
# Refresh the token if needed
|
||||
if self._credentials.expired or not self._credentials.valid:
|
||||
try:
|
||||
self._credentials.refresh(Request())
|
||||
# Update the connector config in DB
|
||||
if self._session:
|
||||
# Use connector_id if available, otherwise fall back to user_id query
|
||||
# Only persist refreshed token for native OAuth (Composio manages its own)
|
||||
if has_standard_refresh and self._session:
|
||||
if self._connector_id:
|
||||
result = await self._session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
|
|
@ -110,7 +105,6 @@ class GoogleCalendarConnector:
|
|||
"GOOGLE_CALENDAR_CONNECTOR connector not found; cannot persist refreshed token."
|
||||
)
|
||||
|
||||
# Encrypt sensitive credentials before storing
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
|
|
@ -119,7 +113,6 @@ class GoogleCalendarConnector:
|
|||
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
# Encrypt sensitive fields
|
||||
if creds_dict.get("token"):
|
||||
creds_dict["token"] = token_encryption.encrypt_token(
|
||||
creds_dict["token"]
|
||||
|
|
@ -143,7 +136,6 @@ class GoogleCalendarConnector:
|
|||
await self._session.commit()
|
||||
except Exception as e:
|
||||
error_str = str(e)
|
||||
# Check if this is an invalid_grant error (token expired/revoked)
|
||||
if (
|
||||
"invalid_grant" in error_str.lower()
|
||||
or "token has been expired or revoked" in error_str.lower()
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import io
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
from googleapiclient.http import MediaIoBaseUpload
|
||||
|
|
@ -15,16 +16,24 @@ from .file_types import GOOGLE_DOC, GOOGLE_SHEET
|
|||
class GoogleDriveClient:
|
||||
"""Client for Google Drive API operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession, connector_id: int):
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
credentials: "Credentials | None" = None,
|
||||
):
|
||||
"""
|
||||
Initialize Google Drive client.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Drive connector
|
||||
credentials: Pre-built credentials (e.g. from Composio). If None,
|
||||
credentials are loaded from the DB connector config.
|
||||
"""
|
||||
self.session = session
|
||||
self.connector_id = connector_id
|
||||
self._credentials = credentials
|
||||
self.service = None
|
||||
|
||||
async def get_service(self):
|
||||
|
|
@ -41,7 +50,12 @@ class GoogleDriveClient:
|
|||
return self.service
|
||||
|
||||
try:
|
||||
credentials = await get_valid_credentials(self.session, self.connector_id)
|
||||
if self._credentials:
|
||||
credentials = self._credentials
|
||||
else:
|
||||
credentials = await get_valid_credentials(
|
||||
self.session, self.connector_id
|
||||
)
|
||||
self.service = build("drive", "v3", credentials=credentials)
|
||||
return self.service
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ async def download_and_process_file(
|
|||
task_logger: TaskLoggingService,
|
||||
log_entry: Log,
|
||||
connector_id: int | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> tuple[Any, str | None, dict[str, Any] | None]:
|
||||
"""
|
||||
Download Google Drive file and process using Surfsense file processors.
|
||||
|
|
@ -95,6 +96,7 @@ async def download_and_process_file(
|
|||
},
|
||||
}
|
||||
# Include connector_id for de-indexing support
|
||||
connector_info["enable_summary"] = enable_summary
|
||||
if connector_id is not None:
|
||||
connector_info["connector_id"] = connector_id
|
||||
|
||||
|
|
|
|||
|
|
@ -81,44 +81,39 @@ class GoogleGmailConnector:
|
|||
) -> Credentials:
|
||||
"""
|
||||
Get valid Google OAuth credentials.
|
||||
Returns:
|
||||
Google OAuth credentials
|
||||
Raises:
|
||||
ValueError: If credentials have not been set
|
||||
Exception: If credential refresh fails
|
||||
|
||||
Supports both native OAuth (with refresh_token) and Composio-sourced
|
||||
credentials (with refresh_handler). For Composio credentials, validation
|
||||
and DB persistence are skipped since Composio manages its own tokens.
|
||||
"""
|
||||
if not all(
|
||||
[
|
||||
self._credentials.client_id,
|
||||
self._credentials.client_secret,
|
||||
self._credentials.refresh_token,
|
||||
]
|
||||
has_standard_refresh = bool(self._credentials.refresh_token)
|
||||
|
||||
if has_standard_refresh and not all(
|
||||
[self._credentials.client_id, self._credentials.client_secret]
|
||||
):
|
||||
raise ValueError(
|
||||
"Google OAuth credentials (client_id, client_secret, refresh_token) must be set"
|
||||
"Google OAuth credentials (client_id, client_secret) must be set"
|
||||
)
|
||||
|
||||
if self._credentials and not self._credentials.expired:
|
||||
return self._credentials
|
||||
|
||||
# Create credentials from refresh token
|
||||
self._credentials = Credentials(
|
||||
token=self._credentials.token,
|
||||
refresh_token=self._credentials.refresh_token,
|
||||
token_uri=self._credentials.token_uri,
|
||||
client_id=self._credentials.client_id,
|
||||
client_secret=self._credentials.client_secret,
|
||||
scopes=self._credentials.scopes,
|
||||
expiry=self._credentials.expiry,
|
||||
)
|
||||
if has_standard_refresh:
|
||||
self._credentials = Credentials(
|
||||
token=self._credentials.token,
|
||||
refresh_token=self._credentials.refresh_token,
|
||||
token_uri=self._credentials.token_uri,
|
||||
client_id=self._credentials.client_id,
|
||||
client_secret=self._credentials.client_secret,
|
||||
scopes=self._credentials.scopes,
|
||||
expiry=self._credentials.expiry,
|
||||
)
|
||||
|
||||
# Refresh the token if needed
|
||||
if self._credentials.expired or not self._credentials.valid:
|
||||
try:
|
||||
self._credentials.refresh(Request())
|
||||
# Update the connector config in DB
|
||||
if self._session:
|
||||
# Use connector_id if available, otherwise fall back to user_id query
|
||||
# Only persist refreshed token for native OAuth (Composio manages its own)
|
||||
if has_standard_refresh and self._session:
|
||||
if self._connector_id:
|
||||
result = await self._session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
|
|
@ -138,12 +133,38 @@ class GoogleGmailConnector:
|
|||
raise RuntimeError(
|
||||
"GMAIL connector not found; cannot persist refreshed token."
|
||||
)
|
||||
connector.config = json.loads(self._credentials.to_json())
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
creds_dict = json.loads(self._credentials.to_json())
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if creds_dict.get("token"):
|
||||
creds_dict["token"] = token_encryption.encrypt_token(
|
||||
creds_dict["token"]
|
||||
)
|
||||
if creds_dict.get("refresh_token"):
|
||||
creds_dict["refresh_token"] = (
|
||||
token_encryption.encrypt_token(
|
||||
creds_dict["refresh_token"]
|
||||
)
|
||||
)
|
||||
if creds_dict.get("client_secret"):
|
||||
creds_dict["client_secret"] = (
|
||||
token_encryption.encrypt_token(
|
||||
creds_dict["client_secret"]
|
||||
)
|
||||
)
|
||||
creds_dict["_token_encrypted"] = True
|
||||
|
||||
connector.config = creds_dict
|
||||
flag_modified(connector, "config")
|
||||
await self._session.commit()
|
||||
except Exception as e:
|
||||
error_str = str(e)
|
||||
# Check if this is an invalid_grant error (token expired/revoked)
|
||||
if (
|
||||
"invalid_grant" in error_str.lower()
|
||||
or "token has been expired or revoked" in error_str.lower()
|
||||
|
|
|
|||
|
|
@ -167,14 +167,23 @@ class JiraConnector:
|
|||
# Use direct base URL (works for both OAuth and legacy)
|
||||
url = f"{self.base_url}/rest/api/{self.api_version}/{endpoint}"
|
||||
|
||||
if method.upper() == "POST":
|
||||
method_upper = method.upper()
|
||||
if method_upper == "POST":
|
||||
response = requests.post(
|
||||
url, headers=headers, json=json_payload, timeout=500
|
||||
)
|
||||
elif method_upper == "PUT":
|
||||
response = requests.put(
|
||||
url, headers=headers, json=json_payload, timeout=500
|
||||
)
|
||||
elif method_upper == "DELETE":
|
||||
response = requests.delete(url, headers=headers, params=params, timeout=500)
|
||||
else:
|
||||
response = requests.get(url, headers=headers, params=params, timeout=500)
|
||||
|
||||
if response.status_code == 200:
|
||||
if response.status_code in (200, 201, 204):
|
||||
if response.status_code == 204 or not response.text:
|
||||
return {"status": "success"}
|
||||
return response.json()
|
||||
else:
|
||||
raise Exception(
|
||||
|
|
@ -352,6 +361,91 @@ class JiraConnector:
|
|||
except Exception as e:
|
||||
return [], f"Error fetching issues: {e!s}"
|
||||
|
||||
def get_myself(self) -> dict[str, Any]:
|
||||
"""Fetch the current user's profile (health check)."""
|
||||
return self.make_api_request("myself")
|
||||
|
||||
def get_projects(self) -> list[dict[str, Any]]:
|
||||
"""Fetch all projects the user has access to."""
|
||||
result = self.make_api_request("project/search")
|
||||
return result.get("values", [])
|
||||
|
||||
def get_issue_types(self) -> list[dict[str, Any]]:
|
||||
"""Fetch all issue types."""
|
||||
return self.make_api_request("issuetype")
|
||||
|
||||
def get_priorities(self) -> list[dict[str, Any]]:
|
||||
"""Fetch all priority levels."""
|
||||
return self.make_api_request("priority")
|
||||
|
||||
def get_issue(self, issue_id_or_key: str) -> dict[str, Any]:
|
||||
"""Fetch a single issue by ID or key."""
|
||||
return self.make_api_request(f"issue/{issue_id_or_key}")
|
||||
|
||||
def create_issue(
|
||||
self,
|
||||
project_key: str,
|
||||
summary: str,
|
||||
issue_type: str = "Task",
|
||||
description: str | None = None,
|
||||
priority: str | None = None,
|
||||
assignee_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new Jira issue."""
|
||||
fields: dict[str, Any] = {
|
||||
"project": {"key": project_key},
|
||||
"summary": summary,
|
||||
"issuetype": {"name": issue_type},
|
||||
}
|
||||
if description:
|
||||
fields["description"] = {
|
||||
"type": "doc",
|
||||
"version": 1,
|
||||
"content": [
|
||||
{
|
||||
"type": "paragraph",
|
||||
"content": [{"type": "text", "text": description}],
|
||||
}
|
||||
],
|
||||
}
|
||||
if priority:
|
||||
fields["priority"] = {"name": priority}
|
||||
if assignee_id:
|
||||
fields["assignee"] = {"accountId": assignee_id}
|
||||
|
||||
return self.make_api_request(
|
||||
"issue", method="POST", json_payload={"fields": fields}
|
||||
)
|
||||
|
||||
def update_issue(
|
||||
self, issue_id_or_key: str, fields: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing Jira issue fields."""
|
||||
return self.make_api_request(
|
||||
f"issue/{issue_id_or_key}",
|
||||
method="PUT",
|
||||
json_payload={"fields": fields},
|
||||
)
|
||||
|
||||
def delete_issue(self, issue_id_or_key: str) -> dict[str, Any]:
|
||||
"""Delete a Jira issue."""
|
||||
return self.make_api_request(f"issue/{issue_id_or_key}", method="DELETE")
|
||||
|
||||
def get_transitions(self, issue_id_or_key: str) -> list[dict[str, Any]]:
|
||||
"""Get available transitions for an issue (for status changes)."""
|
||||
result = self.make_api_request(f"issue/{issue_id_or_key}/transitions")
|
||||
return result.get("transitions", [])
|
||||
|
||||
def transition_issue(
|
||||
self, issue_id_or_key: str, transition_id: str
|
||||
) -> dict[str, Any]:
|
||||
"""Transition an issue to a new status."""
|
||||
return self.make_api_request(
|
||||
f"issue/{issue_id_or_key}/transitions",
|
||||
method="POST",
|
||||
json_payload={"transition": {"id": transition_id}},
|
||||
)
|
||||
|
||||
def format_issue(self, issue: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Format an issue for easier consumption.
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ from sqlalchemy.future import select
|
|||
from app.config import config
|
||||
from app.connectors.jira_connector import JiraConnector
|
||||
from app.db import SearchSourceConnector
|
||||
from app.routes.jira_add_connector_route import refresh_jira_token
|
||||
from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
|
|
@ -184,7 +183,9 @@ class JiraHistoryConnector:
|
|||
f"Connector {self._connector_id} not found; cannot refresh token."
|
||||
)
|
||||
|
||||
# Refresh token
|
||||
# Lazy import to avoid circular dependency
|
||||
from app.routes.jira_add_connector_route import refresh_jira_token
|
||||
|
||||
connector = await refresh_jira_token(self._session, connector)
|
||||
|
||||
# Reload credentials after refresh
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from notion_client import AsyncClient
|
||||
from notion_client.errors import APIResponseError
|
||||
from notion_markdown import to_notion
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
|
|
@ -834,106 +834,8 @@ class NotionHistoryConnector:
|
|||
return None
|
||||
|
||||
def _markdown_to_blocks(self, markdown: str) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert markdown content to Notion blocks.
|
||||
|
||||
This is a simple converter that handles basic markdown.
|
||||
For more complex markdown, consider using a proper markdown parser.
|
||||
|
||||
Args:
|
||||
markdown: Markdown content
|
||||
|
||||
Returns:
|
||||
List of Notion block objects
|
||||
"""
|
||||
blocks = []
|
||||
lines = markdown.split("\n")
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Heading 1
|
||||
if line.startswith("# "):
|
||||
blocks.append(
|
||||
{
|
||||
"object": "block",
|
||||
"type": "heading_1",
|
||||
"heading_1": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": line[2:]}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Heading 2
|
||||
elif line.startswith("## "):
|
||||
blocks.append(
|
||||
{
|
||||
"object": "block",
|
||||
"type": "heading_2",
|
||||
"heading_2": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": line[3:]}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Heading 3
|
||||
elif line.startswith("### "):
|
||||
blocks.append(
|
||||
{
|
||||
"object": "block",
|
||||
"type": "heading_3",
|
||||
"heading_3": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": line[4:]}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Bullet list
|
||||
elif line.startswith("- ") or line.startswith("* "):
|
||||
blocks.append(
|
||||
{
|
||||
"object": "block",
|
||||
"type": "bulleted_list_item",
|
||||
"bulleted_list_item": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": line[2:]}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Numbered list
|
||||
elif match := re.match(r"^(\d+)\.\s+(.*)$", line):
|
||||
content = match.group(2) # Extract text after "number. "
|
||||
blocks.append(
|
||||
{
|
||||
"object": "block",
|
||||
"type": "numbered_list_item",
|
||||
"numbered_list_item": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": content}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Regular paragraph
|
||||
else:
|
||||
blocks.append(
|
||||
{
|
||||
"object": "block",
|
||||
"type": "paragraph",
|
||||
"paragraph": {
|
||||
"rich_text": [{"type": "text", "text": {"content": line}}]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return blocks
|
||||
"""Convert markdown content to Notion blocks using notion-markdown."""
|
||||
return to_notion(markdown)
|
||||
|
||||
async def create_page(
|
||||
self, title: str, content: str, parent_page_id: str | None = None
|
||||
|
|
|
|||
|
|
@ -63,6 +63,16 @@ class DocumentType(StrEnum):
|
|||
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"
|
||||
|
||||
|
||||
# Native Google document types → their legacy Composio equivalents.
|
||||
# Old documents may still carry the Composio type until they are re-indexed;
|
||||
# search, browse, and indexing must transparently handle both.
|
||||
NATIVE_TO_LEGACY_DOCTYPE: dict[str, str] = {
|
||||
"GOOGLE_DRIVE_FILE": "COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
|
||||
"GOOGLE_GMAIL_CONNECTOR": "COMPOSIO_GMAIL_CONNECTOR",
|
||||
"GOOGLE_CALENDAR_CONNECTOR": "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
|
||||
}
|
||||
|
||||
|
||||
class SearchSourceConnectorType(StrEnum):
|
||||
SERPER_API = "SERPER_API" # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT
|
||||
TAVILY_API = "TAVILY_API"
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
|
|
@ -157,7 +158,7 @@ class ChucksHybridSearchRetriever:
|
|||
query_text: str,
|
||||
top_k: int,
|
||||
search_space_id: int,
|
||||
document_type: str | None = None,
|
||||
document_type: str | list[str] | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
query_embedding: list | None = None,
|
||||
|
|
@ -217,18 +218,24 @@ class ChucksHybridSearchRetriever:
|
|||
func.coalesce(Document.status["state"].astext, "ready") != "deleting",
|
||||
]
|
||||
|
||||
# Add document type filter if provided
|
||||
# Add document type filter if provided (single string or list of strings)
|
||||
if document_type is not None:
|
||||
# Convert string to enum value if needed
|
||||
if isinstance(document_type, str):
|
||||
try:
|
||||
doc_type_enum = DocumentType[document_type]
|
||||
base_conditions.append(Document.document_type == doc_type_enum)
|
||||
except KeyError:
|
||||
# If the document type doesn't exist in the enum, return empty results
|
||||
return []
|
||||
type_list = (
|
||||
document_type if isinstance(document_type, list) else [document_type]
|
||||
)
|
||||
doc_type_enums = []
|
||||
for dt in type_list:
|
||||
if isinstance(dt, str):
|
||||
with contextlib.suppress(KeyError):
|
||||
doc_type_enums.append(DocumentType[dt])
|
||||
else:
|
||||
doc_type_enums.append(dt)
|
||||
if not doc_type_enums:
|
||||
return []
|
||||
if len(doc_type_enums) == 1:
|
||||
base_conditions.append(Document.document_type == doc_type_enums[0])
|
||||
else:
|
||||
base_conditions.append(Document.document_type == document_type)
|
||||
base_conditions.append(Document.document_type.in_(doc_type_enums))
|
||||
|
||||
# Add time-based filtering if provided
|
||||
if start_date is not None:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
|
|
@ -149,7 +150,7 @@ class DocumentHybridSearchRetriever:
|
|||
query_text: str,
|
||||
top_k: int,
|
||||
search_space_id: int,
|
||||
document_type: str | None = None,
|
||||
document_type: str | list[str] | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
query_embedding: list | None = None,
|
||||
|
|
@ -197,18 +198,24 @@ class DocumentHybridSearchRetriever:
|
|||
func.coalesce(Document.status["state"].astext, "ready") != "deleting",
|
||||
]
|
||||
|
||||
# Add document type filter if provided
|
||||
# Add document type filter if provided (single string or list of strings)
|
||||
if document_type is not None:
|
||||
# Convert string to enum value if needed
|
||||
if isinstance(document_type, str):
|
||||
try:
|
||||
doc_type_enum = DocumentType[document_type]
|
||||
base_conditions.append(Document.document_type == doc_type_enum)
|
||||
except KeyError:
|
||||
# If the document type doesn't exist in the enum, return empty results
|
||||
return []
|
||||
type_list = (
|
||||
document_type if isinstance(document_type, list) else [document_type]
|
||||
)
|
||||
doc_type_enums = []
|
||||
for dt in type_list:
|
||||
if isinstance(dt, str):
|
||||
with contextlib.suppress(KeyError):
|
||||
doc_type_enums.append(DocumentType[dt])
|
||||
else:
|
||||
doc_type_enums.append(dt)
|
||||
if not doc_type_enums:
|
||||
return []
|
||||
if len(doc_type_enums) == 1:
|
||||
base_conditions.append(Document.document_type == doc_type_enums[0])
|
||||
else:
|
||||
base_conditions.append(Document.document_type == document_type)
|
||||
base_conditions.append(Document.document_type.in_(doc_type_enums))
|
||||
|
||||
# Add time-based filtering if provided
|
||||
if start_date is not None:
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ async def airtable_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=airtable_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=airtable_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -316,7 +316,7 @@ async def airtable_callback(
|
|||
f"Duplicate Airtable connector detected for user {user_id} with email {user_email}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=airtable-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=airtable-connector"
|
||||
)
|
||||
|
||||
# Generate a unique, user-friendly connector name
|
||||
|
|
@ -348,7 +348,7 @@ async def airtable_callback(
|
|||
# Redirect to the frontend with success params for indexing config
|
||||
# Using query params to auto-open the popup with config view on new-chat page
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=airtable-connector&connectorId={new_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=airtable-connector&connectorId={new_connector.id}"
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
|
|
|
|||
|
|
@ -148,7 +148,7 @@ async def clickup_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=clickup_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=clickup_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -326,7 +326,7 @@ async def clickup_callback(
|
|||
|
||||
# Redirect to the frontend with success params
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=clickup-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=clickup-connector"
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
|
|
|
|||
|
|
@ -208,7 +208,7 @@ async def composio_callback(
|
|||
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=composio_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=composio_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -263,6 +263,15 @@ async def composio_callback(
|
|||
logger.info(
|
||||
f"Successfully got connected_account_id: {final_connected_account_id}"
|
||||
)
|
||||
# Wait for Composio to finish exchanging the auth code for tokens.
|
||||
try:
|
||||
service.wait_for_connection(final_connected_account_id, timeout=30.0)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"wait_for_connection timed out for {final_connected_account_id}, "
|
||||
"proceeding anyway",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Build entity_id for Composio API calls (same format as used in initiate)
|
||||
entity_id = f"surfsense_{user_id}"
|
||||
|
|
@ -370,7 +379,7 @@ async def composio_callback(
|
|||
toolkit_id, "composio-connector"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector={frontend_connector_id}&connectorId={existing_connector.id}&view=configure"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector={frontend_connector_id}&connectorId={existing_connector.id}"
|
||||
)
|
||||
|
||||
# This is a NEW account - create a new connector
|
||||
|
|
@ -399,7 +408,7 @@ async def composio_callback(
|
|||
toolkit_id, "composio-connector"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector={frontend_connector_id}&connectorId={db_connector.id}&view=configure"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector={frontend_connector_id}&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
except IntegrityError as e:
|
||||
|
|
@ -425,6 +434,211 @@ async def composio_callback(
|
|||
) from e
|
||||
|
||||
|
||||
COMPOSIO_CONNECTOR_TYPES = {
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/auth/composio/connector/reauth")
|
||||
async def reauth_composio_connector(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""
|
||||
Initiate Composio re-authentication for an expired connected account.
|
||||
|
||||
Uses Composio's refresh API so the same connected_account_id stays valid
|
||||
after the user completes the OAuth flow again.
|
||||
|
||||
Query params:
|
||||
space_id: Search space ID the connector belongs to
|
||||
connector_id: ID of the existing Composio connector to re-authenticate
|
||||
return_url: Optional frontend path to redirect to after completion
|
||||
"""
|
||||
if not ComposioService.is_enabled():
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Composio integration is not enabled."
|
||||
)
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type.in_(COMPOSIO_CONNECTOR_TYPES),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Composio connector not found or access denied",
|
||||
)
|
||||
|
||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||
if not connected_account_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Composio connected account ID not found. Please reconnect the connector.",
|
||||
)
|
||||
|
||||
# Build callback URL with secure state
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(
|
||||
space_id,
|
||||
user.id,
|
||||
toolkit_id=connector.config.get("toolkit_id", ""),
|
||||
connector_id=connector_id,
|
||||
return_url=return_url,
|
||||
)
|
||||
|
||||
callback_base = config.COMPOSIO_REDIRECT_URI
|
||||
if not callback_base:
|
||||
backend_url = config.BACKEND_URL or "http://localhost:8000"
|
||||
callback_base = (
|
||||
f"{backend_url}/api/v1/auth/composio/connector/reauth/callback"
|
||||
)
|
||||
else:
|
||||
# Replace the normal callback path with the reauth one
|
||||
callback_base = callback_base.replace(
|
||||
"/auth/composio/connector/callback",
|
||||
"/auth/composio/connector/reauth/callback",
|
||||
)
|
||||
|
||||
callback_url = f"{callback_base}?state={state_encoded}"
|
||||
|
||||
service = ComposioService()
|
||||
refresh_result = service.refresh_connected_account(
|
||||
connected_account_id=connected_account_id,
|
||||
redirect_url=callback_url,
|
||||
)
|
||||
|
||||
if refresh_result["redirect_url"] is None:
|
||||
# Token refreshed server-side; clear auth_expired immediately
|
||||
if connector.config.get("auth_expired"):
|
||||
connector.config = {**connector.config, "auth_expired": False}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
logger.info(
|
||||
f"Composio account {connected_account_id} refreshed server-side (no redirect needed)"
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Authentication refreshed successfully.",
|
||||
}
|
||||
|
||||
logger.info(f"Initiating Composio re-auth for connector {connector_id}")
|
||||
return {"auth_url": refresh_result["redirect_url"]}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Composio re-auth: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Composio re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/composio/connector/reauth/callback")
|
||||
async def composio_reauth_callback(
|
||||
request: Request,
|
||||
state: str | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""
|
||||
Handle Composio re-authentication callback.
|
||||
|
||||
Clears the auth_expired flag and redirects the user back to the frontend.
|
||||
The connected_account_id has not changed — Composio refreshed it in place.
|
||||
"""
|
||||
try:
|
||||
if not state:
|
||||
raise HTTPException(status_code=400, detail="Missing state parameter")
|
||||
|
||||
state_manager = get_state_manager()
|
||||
try:
|
||||
data = state_manager.validate_state(state)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid state parameter: {e!s}"
|
||||
) from e
|
||||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
return_url = data.get("return_url")
|
||||
|
||||
if not reauth_connector_id:
|
||||
raise HTTPException(status_code=400, detail="Missing connector_id in state")
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth callback",
|
||||
)
|
||||
|
||||
# Wait for Composio to finish processing new tokens before proceeding.
|
||||
# Without this, get_access_token() may return stale credentials.
|
||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||
if connected_account_id:
|
||||
try:
|
||||
service = ComposioService()
|
||||
service.wait_for_connection(connected_account_id, timeout=30.0)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"wait_for_connection timed out for connector {reauth_connector_id}, "
|
||||
"proceeding anyway — tokens may not be ready yet",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Clear auth_expired flag
|
||||
connector.config = {**connector.config, "auth_expired": False}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
|
||||
logger.info(f"Composio re-auth completed for connector {reauth_connector_id}")
|
||||
|
||||
if return_url and return_url.startswith("/"):
|
||||
return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}{return_url}")
|
||||
|
||||
frontend_connector_id = TOOLKIT_TO_FRONTEND_CONNECTOR_ID.get(
|
||||
connector.config.get("toolkit_id", ""), "composio-connector"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector={frontend_connector_id}&connectorId={reauth_connector_id}"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Composio reauth callback: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to complete Composio re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/connectors/{connector_id}/composio-drive/folders")
|
||||
async def list_composio_drive_folders(
|
||||
connector_id: int,
|
||||
|
|
@ -433,31 +647,23 @@ async def list_composio_drive_folders(
|
|||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
List folders AND files in user's Google Drive via Composio with hierarchical support.
|
||||
List folders AND files in user's Google Drive via Composio.
|
||||
|
||||
This is called at index time from the manage connector page to display
|
||||
the complete file system (folders and files). Only folders are selectable.
|
||||
|
||||
Args:
|
||||
connector_id: ID of the Composio Google Drive connector
|
||||
parent_id: Optional parent folder ID to list contents (None for root)
|
||||
|
||||
Returns:
|
||||
JSON with list of items: {
|
||||
"items": [
|
||||
{"id": str, "name": str, "mimeType": str, "isFolder": bool, ...},
|
||||
...
|
||||
]
|
||||
}
|
||||
Uses the same GoogleDriveClient / list_folder_contents path as the native
|
||||
connector, with Composio-sourced credentials. This means auth errors
|
||||
propagate identically (Google returns 401 → exception → auth_expired flag).
|
||||
"""
|
||||
from app.connectors.google_drive import GoogleDriveClient, list_folder_contents
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
if not ComposioService.is_enabled():
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Composio integration is not enabled.",
|
||||
)
|
||||
|
||||
connector = None
|
||||
try:
|
||||
# Get connector and verify ownership
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
|
|
@ -474,7 +680,6 @@ async def list_composio_drive_folders(
|
|||
detail="Composio Google Drive connector not found or access denied",
|
||||
)
|
||||
|
||||
# Get Composio connected account ID from config
|
||||
composio_connected_account_id = connector.config.get(
|
||||
"composio_connected_account_id"
|
||||
)
|
||||
|
|
@ -484,63 +689,43 @@ async def list_composio_drive_folders(
|
|||
detail="Composio connected account not found. Please reconnect the connector.",
|
||||
)
|
||||
|
||||
# Initialize Composio service and fetch files
|
||||
service = ComposioService()
|
||||
entity_id = f"surfsense_{user.id}"
|
||||
credentials = build_composio_credentials(composio_connected_account_id)
|
||||
drive_client = GoogleDriveClient(session, connector_id, credentials=credentials)
|
||||
|
||||
# Fetch files/folders from Composio Google Drive
|
||||
files, _next_token, error = await service.get_drive_files(
|
||||
connected_account_id=composio_connected_account_id,
|
||||
entity_id=entity_id,
|
||||
folder_id=parent_id,
|
||||
page_size=100,
|
||||
)
|
||||
items, error = await list_folder_contents(drive_client, parent_id=parent_id)
|
||||
|
||||
if error:
|
||||
logger.error(f"Failed to list Composio Drive files: {error}")
|
||||
error_lower = error.lower()
|
||||
if (
|
||||
"401" in error
|
||||
or "invalid_grant" in error_lower
|
||||
or "token has been expired or revoked" in error_lower
|
||||
or "invalid credentials" in error_lower
|
||||
or "authentication failed" in error_lower
|
||||
):
|
||||
try:
|
||||
if connector and not connector.config.get("auth_expired"):
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
logger.info(
|
||||
f"Marked Composio connector {connector_id} as auth_expired"
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to persist auth_expired for connector {connector_id}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Google Drive authentication expired. Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to list folder contents: {error}"
|
||||
)
|
||||
|
||||
# Transform files to match the expected format with isFolder field
|
||||
items = []
|
||||
for file_info in files:
|
||||
file_id = file_info.get("id", "") or file_info.get("fileId", "")
|
||||
file_name = (
|
||||
file_info.get("name", "") or file_info.get("fileName", "") or "Untitled"
|
||||
)
|
||||
mime_type = file_info.get("mimeType", "") or file_info.get("mime_type", "")
|
||||
|
||||
if not file_id:
|
||||
continue
|
||||
|
||||
is_folder = mime_type == "application/vnd.google-apps.folder"
|
||||
|
||||
items.append(
|
||||
{
|
||||
"id": file_id,
|
||||
"name": file_name,
|
||||
"mimeType": mime_type,
|
||||
"isFolder": is_folder,
|
||||
"parents": file_info.get("parents", []),
|
||||
"size": file_info.get("size"),
|
||||
"iconLink": file_info.get("iconLink"),
|
||||
}
|
||||
)
|
||||
|
||||
# Sort: folders first, then files, both alphabetically
|
||||
folders = sorted(
|
||||
[item for item in items if item["isFolder"]],
|
||||
key=lambda x: x["name"].lower(),
|
||||
)
|
||||
files_list = sorted(
|
||||
[item for item in items if not item["isFolder"]],
|
||||
key=lambda x: x["name"].lower(),
|
||||
)
|
||||
items = folders + files_list
|
||||
|
||||
folder_count = len(folders)
|
||||
file_count = len(files_list)
|
||||
folder_count = sum(1 for item in items if item.get("isFolder", False))
|
||||
file_count = len(items) - folder_count
|
||||
|
||||
logger.info(
|
||||
f"Listed {len(items)} total items ({folder_count} folders, {file_count} files) for Composio connector {connector_id}"
|
||||
|
|
@ -553,6 +738,31 @@ async def list_composio_drive_folders(
|
|||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing Composio Drive contents: {e!s}", exc_info=True)
|
||||
error_lower = str(e).lower()
|
||||
if (
|
||||
"invalid_grant" in error_lower
|
||||
or "token has been expired or revoked" in error_lower
|
||||
or "invalid credentials" in error_lower
|
||||
or "authentication failed" in error_lower
|
||||
or "401" in str(e)
|
||||
):
|
||||
try:
|
||||
if connector and not connector.config.get("auth_expired"):
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
logger.info(
|
||||
f"Marked Composio connector {connector_id} as auth_expired"
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to persist auth_expired for connector {connector_id}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Google Drive authentication expired. Please re-authenticate.",
|
||||
) from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to list Drive contents: {e!s}"
|
||||
) from e
|
||||
|
|
|
|||
|
|
@ -46,6 +46,8 @@ SCOPES = [
|
|||
"read:space:confluence",
|
||||
"read:page:confluence",
|
||||
"read:comment:confluence",
|
||||
"write:page:confluence", # Required for creating/updating pages
|
||||
"delete:page:confluence", # Required for deleting pages
|
||||
"offline_access", # Required for refresh tokens
|
||||
]
|
||||
|
||||
|
|
@ -170,7 +172,7 @@ async def confluence_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=confluence_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=confluence_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -196,6 +198,8 @@ async def confluence_callback(
|
|||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
# Validate redirect URI (security: ensure it matches configured value)
|
||||
if not config.CONFLUENCE_REDIRECT_URI:
|
||||
|
|
@ -292,6 +296,46 @@ async def confluence_callback(
|
|||
"_token_encrypted": True,
|
||||
}
|
||||
|
||||
# Handle re-authentication: update existing connector instead of creating new one
|
||||
if reauth_connector_id:
|
||||
from sqlalchemy.future import select as sa_select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
result = await session.execute(
|
||||
sa_select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
db_connector.config = {
|
||||
**connector_config,
|
||||
"auth_expired": False,
|
||||
}
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
f"Re-authenticated Confluence connector {db_connector.id} for user {user_id}"
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}?reauth=success&connector=confluence-connector"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?reauth=success&connector=confluence-connector"
|
||||
)
|
||||
|
||||
# Extract unique identifier from connector credentials
|
||||
connector_identifier = extract_identifier_from_credentials(
|
||||
SearchSourceConnectorType.CONFLUENCE_CONNECTOR, connector_config
|
||||
|
|
@ -310,7 +354,7 @@ async def confluence_callback(
|
|||
f"Duplicate Confluence connector detected for user {user_id} with instance {connector_identifier}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=confluence-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=confluence-connector"
|
||||
)
|
||||
|
||||
# Generate a unique, user-friendly connector name
|
||||
|
|
@ -341,7 +385,7 @@ async def confluence_callback(
|
|||
|
||||
# Redirect to the frontend with success params
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=confluence-connector&connectorId={new_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=confluence-connector&connectorId={new_connector.id}"
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
|
|
@ -372,6 +416,73 @@ async def confluence_callback(
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/confluence/connector/reauth")
|
||||
async def reauth_confluence(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Initiate Confluence re-authentication to upgrade OAuth scopes."""
|
||||
try:
|
||||
from sqlalchemy.future import select
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Confluence connector not found or access denied",
|
||||
)
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
||||
from urllib.parse import urlencode
|
||||
|
||||
auth_params = {
|
||||
"audience": "api.atlassian.com",
|
||||
"client_id": config.ATLASSIAN_CLIENT_ID,
|
||||
"scope": " ".join(SCOPES),
|
||||
"redirect_uri": config.CONFLUENCE_REDIRECT_URI,
|
||||
"state": state_encoded,
|
||||
"response_type": "code",
|
||||
"prompt": "consent",
|
||||
}
|
||||
|
||||
auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
f"Initiating Confluence re-auth for user {user.id}, connector {connector_id}"
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Confluence re-auth: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Confluence re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
async def refresh_confluence_token(
|
||||
session: AsyncSession, connector: SearchSourceConnector
|
||||
) -> SearchSourceConnector:
|
||||
|
|
|
|||
|
|
@ -172,7 +172,7 @@ async def discord_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=discord_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=discord_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -311,7 +311,7 @@ async def discord_callback(
|
|||
f"Duplicate Discord connector detected for user {user_id} with server {connector_identifier}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=discord-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=discord-connector"
|
||||
)
|
||||
|
||||
# Generate a unique, user-friendly connector name
|
||||
|
|
@ -342,7 +342,7 @@ async def discord_callback(
|
|||
|
||||
# Redirect to the frontend with success params
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=discord-connector&connectorId={new_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=discord-connector&connectorId={new_connector.id}"
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
|
|
|
|||
|
|
@ -10,8 +10,10 @@ from fastapi import APIRouter, Depends, HTTPException, Request
|
|||
from fastapi.responses import RedirectResponse
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.google_gmail_connector import fetch_google_user_email
|
||||
|
|
@ -32,7 +34,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
SCOPES = ["https://www.googleapis.com/auth/calendar.readonly"]
|
||||
SCOPES = ["https://www.googleapis.com/auth/calendar.events"]
|
||||
REDIRECT_URI = config.GOOGLE_CALENDAR_REDIRECT_URI
|
||||
|
||||
# Initialize security utilities
|
||||
|
|
@ -111,6 +113,66 @@ async def connect_calendar(space_id: int, user: User = Depends(current_active_us
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/google/calendar/connector/reauth")
|
||||
async def reauth_calendar(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Initiate Google Calendar re-authentication for an existing connector."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Google Calendar connector not found or access denied",
|
||||
)
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
prompt="consent",
|
||||
include_granted_scopes="true",
|
||||
state=state_encoded,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Initiating Google Calendar re-auth for user {user.id}, connector {connector_id}"
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Calendar re-auth: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Calendar re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/google/calendar/connector/callback")
|
||||
async def calendar_callback(
|
||||
request: Request,
|
||||
|
|
@ -137,7 +199,7 @@ async def calendar_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_calendar_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=google_calendar_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -197,6 +259,42 @@ async def calendar_callback(
|
|||
# Mark that credentials are encrypted for backward compatibility
|
||||
creds_dict["_token_encrypted"] = True
|
||||
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
if reauth_connector_id:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
db_connector.config = {**creds_dict}
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
f"Re-authenticated Calendar connector {db_connector.id} for user {user_id}"
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-calendar-connector&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
# Check for duplicate connector (same account already connected)
|
||||
is_duplicate = await check_duplicate_connector(
|
||||
session,
|
||||
|
|
@ -210,7 +308,7 @@ async def calendar_callback(
|
|||
f"Duplicate Google Calendar connector detected for user {user_id} with email {user_email}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=google-calendar-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=google-calendar-connector"
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -236,7 +334,7 @@ async def calendar_callback(
|
|||
# Redirect to the frontend with success params for indexing config
|
||||
# Using query params to auto-open the popup with config view on new-chat page
|
||||
return RedirectResponse(
|
||||
f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-calendar-connector&connectorId={db_connector.id}"
|
||||
f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-calendar-connector&connectorId={db_connector.id}"
|
||||
)
|
||||
except ValidationError as e:
|
||||
await session.rollback()
|
||||
|
|
|
|||
|
|
@ -257,7 +257,7 @@ async def drive_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_drive_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=google_drive_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -345,6 +345,7 @@ async def drive_callback(
|
|||
db_connector.config = {
|
||||
**creds_dict,
|
||||
"start_page_token": existing_start_page_token,
|
||||
"auth_expired": False,
|
||||
}
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
|
|
@ -360,7 +361,7 @@ async def drive_callback(
|
|||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-drive-connector&connectorId={db_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-drive-connector&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
is_duplicate = await check_duplicate_connector(
|
||||
|
|
@ -375,7 +376,7 @@ async def drive_callback(
|
|||
f"Duplicate Google Drive connector detected for user {user_id} with email {user_email}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=google-drive-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=google-drive-connector"
|
||||
)
|
||||
|
||||
# Generate a unique, user-friendly connector name
|
||||
|
|
@ -425,7 +426,7 @@ async def drive_callback(
|
|||
)
|
||||
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-drive-connector&connectorId={db_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-drive-connector&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -502,11 +503,35 @@ async def list_google_drive_folders(
|
|||
items, error = await list_folder_contents(drive_client, parent_id=parent_id)
|
||||
|
||||
if error:
|
||||
error_lower = error.lower()
|
||||
if (
|
||||
"401" in error
|
||||
or "invalid_grant" in error_lower
|
||||
or "token has been expired or revoked" in error_lower
|
||||
or "invalid credentials" in error_lower
|
||||
or "authentication failed" in error_lower
|
||||
):
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
try:
|
||||
if connector and not connector.config.get("auth_expired"):
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
logger.info(f"Marked connector {connector_id} as auth_expired")
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to persist auth_expired for connector {connector_id}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Google Drive authentication expired. Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to list folder contents: {error}"
|
||||
)
|
||||
|
||||
# Count folders and files for better logging
|
||||
folder_count = sum(1 for item in items if item.get("isFolder", False))
|
||||
file_count = len(items) - folder_count
|
||||
|
||||
|
|
@ -515,7 +540,6 @@ async def list_google_drive_folders(
|
|||
+ (f" in folder {parent_id}" if parent_id else " in ROOT")
|
||||
)
|
||||
|
||||
# Log first few items for debugging
|
||||
if items:
|
||||
logger.info(f"First 3 items: {[item.get('name') for item in items[:3]]}")
|
||||
|
||||
|
|
@ -525,6 +549,31 @@ async def list_google_drive_folders(
|
|||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing Drive contents: {e!s}", exc_info=True)
|
||||
error_lower = str(e).lower()
|
||||
if (
|
||||
"401" in str(e)
|
||||
or "invalid_grant" in error_lower
|
||||
or "token has been expired or revoked" in error_lower
|
||||
or "invalid credentials" in error_lower
|
||||
or "authentication failed" in error_lower
|
||||
):
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
try:
|
||||
if connector and not connector.config.get("auth_expired"):
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
logger.info(f"Marked connector {connector_id} as auth_expired")
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to persist auth_expired for connector {connector_id}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Google Drive authentication expired. Please re-authenticate.",
|
||||
) from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to list Drive contents: {e!s}"
|
||||
) from e
|
||||
|
|
|
|||
|
|
@ -10,8 +10,10 @@ from fastapi import APIRouter, Depends, HTTPException, Request
|
|||
from fastapi.responses import RedirectResponse
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.google_gmail_connector import fetch_google_user_email
|
||||
|
|
@ -71,7 +73,7 @@ def get_google_flow():
|
|||
}
|
||||
},
|
||||
scopes=[
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/gmail.modify",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"openid",
|
||||
|
|
@ -129,6 +131,66 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user)
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/google/gmail/connector/reauth")
|
||||
async def reauth_gmail(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Initiate Gmail re-authentication for an existing connector."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Gmail connector not found or access denied",
|
||||
)
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
prompt="consent",
|
||||
include_granted_scopes="true",
|
||||
state=state_encoded,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Initiating Gmail re-auth for user {user.id}, connector {connector_id}"
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Gmail re-auth: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Gmail re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/google/gmail/connector/callback")
|
||||
async def gmail_callback(
|
||||
request: Request,
|
||||
|
|
@ -168,7 +230,7 @@ async def gmail_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_gmail_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=google_gmail_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -228,6 +290,42 @@ async def gmail_callback(
|
|||
# Mark that credentials are encrypted for backward compatibility
|
||||
creds_dict["_token_encrypted"] = True
|
||||
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
if reauth_connector_id:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
db_connector.config = {**creds_dict}
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
f"Re-authenticated Gmail connector {db_connector.id} for user {user_id}"
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-gmail-connector&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
# Check for duplicate connector (same account already connected)
|
||||
is_duplicate = await check_duplicate_connector(
|
||||
session,
|
||||
|
|
@ -241,7 +339,7 @@ async def gmail_callback(
|
|||
f"Duplicate Gmail connector detected for user {user_id} with email {user_email}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=google-gmail-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=google-gmail-connector"
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -272,7 +370,7 @@ async def gmail_callback(
|
|||
# Redirect to the frontend with success params for indexing config
|
||||
# Using query params to auto-open the popup with config view on new-chat page
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-gmail-connector&connectorId={db_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-gmail-connector&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
except IntegrityError as e:
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ ACCESSIBLE_RESOURCES_URL = "https://api.atlassian.com/oauth/token/accessible-res
|
|||
SCOPES = [
|
||||
"read:jira-work",
|
||||
"read:jira-user",
|
||||
"write:jira-work", # Required for creating/updating/deleting issues
|
||||
"offline_access", # Required for refresh tokens
|
||||
]
|
||||
|
||||
|
|
@ -167,7 +168,7 @@ async def jira_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=jira_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=jira_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -193,6 +194,8 @@ async def jira_callback(
|
|||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
# Validate redirect URI (security: ensure it matches configured value)
|
||||
if not config.JIRA_REDIRECT_URI:
|
||||
|
|
@ -310,6 +313,46 @@ async def jira_callback(
|
|||
"_token_encrypted": True,
|
||||
}
|
||||
|
||||
# Handle re-authentication: update existing connector instead of creating new one
|
||||
if reauth_connector_id:
|
||||
from sqlalchemy.future import select as sa_select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
result = await session.execute(
|
||||
sa_select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
db_connector.config = {
|
||||
**connector_config,
|
||||
"auth_expired": False,
|
||||
}
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
f"Re-authenticated Jira connector {db_connector.id} for user {user_id}"
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}?reauth=success&connector=jira-connector"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?reauth=success&connector=jira-connector"
|
||||
)
|
||||
|
||||
# Extract unique identifier from connector credentials
|
||||
connector_identifier = extract_identifier_from_credentials(
|
||||
SearchSourceConnectorType.JIRA_CONNECTOR, connector_config
|
||||
|
|
@ -328,7 +371,7 @@ async def jira_callback(
|
|||
f"Duplicate Jira connector detected for user {user_id} with instance {connector_identifier}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=jira-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=jira-connector"
|
||||
)
|
||||
|
||||
# Generate a unique, user-friendly connector name
|
||||
|
|
@ -359,7 +402,7 @@ async def jira_callback(
|
|||
|
||||
# Redirect to the frontend with success params
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=jira-connector&connectorId={new_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=jira-connector&connectorId={new_connector.id}"
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
|
|
@ -390,6 +433,73 @@ async def jira_callback(
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/jira/connector/reauth")
|
||||
async def reauth_jira(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Initiate Jira re-authentication to upgrade OAuth scopes."""
|
||||
try:
|
||||
from sqlalchemy.future import select
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Jira connector not found or access denied",
|
||||
)
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
||||
from urllib.parse import urlencode
|
||||
|
||||
auth_params = {
|
||||
"audience": "api.atlassian.com",
|
||||
"client_id": config.ATLASSIAN_CLIENT_ID,
|
||||
"scope": " ".join(SCOPES),
|
||||
"redirect_uri": config.JIRA_REDIRECT_URI,
|
||||
"state": state_encoded,
|
||||
"response_type": "code",
|
||||
"prompt": "consent",
|
||||
}
|
||||
|
||||
auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
f"Initiating Jira re-auth for user {user.id}, connector {connector_id}"
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Jira re-auth: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Jira re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
async def refresh_jira_token(
|
||||
session: AsyncSession, connector: SearchSourceConnector
|
||||
) -> SearchSourceConnector:
|
||||
|
|
|
|||
|
|
@ -12,8 +12,10 @@ import httpx
|
|||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.linear_connector import fetch_linear_organization_name
|
||||
|
|
@ -127,6 +129,70 @@ async def connect_linear(space_id: int, user: User = Depends(current_active_user
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/linear/connector/reauth")
|
||||
async def reauth_linear(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Initiate Linear re-authentication for an existing connector."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Linear connector not found or access denied",
|
||||
)
|
||||
|
||||
if not config.LINEAR_CLIENT_ID:
|
||||
raise HTTPException(status_code=500, detail="Linear OAuth not configured.")
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
||||
from urllib.parse import urlencode
|
||||
|
||||
auth_params = {
|
||||
"client_id": config.LINEAR_CLIENT_ID,
|
||||
"response_type": "code",
|
||||
"redirect_uri": config.LINEAR_REDIRECT_URI,
|
||||
"scope": " ".join(SCOPES),
|
||||
"state": state_encoded,
|
||||
}
|
||||
auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
f"Initiating Linear re-auth for user {user.id}, connector {connector_id}"
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Linear re-auth: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Linear re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/linear/connector/callback")
|
||||
async def linear_callback(
|
||||
request: Request,
|
||||
|
|
@ -166,7 +232,7 @@ async def linear_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=linear_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=linear_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -267,6 +333,43 @@ async def linear_callback(
|
|||
"_token_encrypted": True,
|
||||
}
|
||||
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
if reauth_connector_id:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
connector_config["organization_name"] = org_name
|
||||
db_connector.config = connector_config
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
f"Re-authenticated Linear connector {db_connector.id} for user {user_id}"
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=linear-connector&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
# Check for duplicate connector (same organization already connected)
|
||||
is_duplicate = await check_duplicate_connector(
|
||||
session,
|
||||
|
|
@ -280,7 +383,7 @@ async def linear_callback(
|
|||
f"Duplicate Linear connector detected for user {user_id} with org {org_name}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=linear-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=linear-connector"
|
||||
)
|
||||
|
||||
# Generate a unique, user-friendly connector name
|
||||
|
|
@ -292,6 +395,7 @@ async def linear_callback(
|
|||
org_name,
|
||||
)
|
||||
# Create new connector
|
||||
connector_config["organization_name"] = org_name
|
||||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
|
|
@ -311,7 +415,7 @@ async def linear_callback(
|
|||
|
||||
# Redirect to the frontend with success params
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=linear-connector&connectorId={new_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=linear-connector&connectorId={new_connector.id}"
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
|
|
@ -342,6 +446,22 @@ async def linear_callback(
|
|||
) from e
|
||||
|
||||
|
||||
async def _mark_connector_auth_expired(
|
||||
session: AsyncSession, connector: SearchSourceConnector
|
||||
) -> None:
|
||||
"""Persist auth_expired flag in the connector config so the frontend can show a re-auth prompt."""
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to persist auth_expired flag for connector {connector.id}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def refresh_linear_token(
|
||||
session: AsyncSession, connector: SearchSourceConnector
|
||||
) -> SearchSourceConnector:
|
||||
|
|
@ -375,6 +495,7 @@ async def refresh_linear_token(
|
|||
) from e
|
||||
|
||||
if not refresh_token:
|
||||
await _mark_connector_auth_expired(session, connector)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No refresh token available. Please re-authenticate.",
|
||||
|
|
@ -417,6 +538,7 @@ async def refresh_linear_token(
|
|||
or "expired" in error_lower
|
||||
or "revoked" in error_lower
|
||||
):
|
||||
await _mark_connector_auth_expired(session, connector)
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Linear authentication failed. Please re-authenticate.",
|
||||
|
|
@ -453,10 +575,16 @@ async def refresh_linear_token(
|
|||
credentials.expires_at = expires_at
|
||||
credentials.scope = token_json.get("scope")
|
||||
|
||||
# Update connector config with encrypted tokens
|
||||
# Update connector config with encrypted tokens, preserving non-credential fields
|
||||
credentials_dict = credentials.to_dict()
|
||||
credentials_dict["_token_encrypted"] = True
|
||||
if connector.config.get("organization_name"):
|
||||
credentials_dict["organization_name"] = connector.config[
|
||||
"organization_name"
|
||||
]
|
||||
credentials_dict.pop("auth_expired", None)
|
||||
connector.config = credentials_dict
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,8 +12,10 @@ import httpx
|
|||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
|
|
@ -124,6 +126,70 @@ async def connect_notion(space_id: int, user: User = Depends(current_active_user
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/notion/connector/reauth")
|
||||
async def reauth_notion(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Initiate Notion re-authentication for an existing connector."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.NOTION_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Notion connector not found or access denied",
|
||||
)
|
||||
|
||||
if not config.NOTION_CLIENT_ID:
|
||||
raise HTTPException(status_code=500, detail="Notion OAuth not configured.")
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
||||
from urllib.parse import urlencode
|
||||
|
||||
auth_params = {
|
||||
"client_id": config.NOTION_CLIENT_ID,
|
||||
"response_type": "code",
|
||||
"owner": "user",
|
||||
"redirect_uri": config.NOTION_REDIRECT_URI,
|
||||
"state": state_encoded,
|
||||
}
|
||||
auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
f"Initiating Notion re-auth for user {user.id}, connector {connector_id}"
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Notion re-auth: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Notion re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/notion/connector/callback")
|
||||
async def notion_callback(
|
||||
request: Request,
|
||||
|
|
@ -163,7 +229,7 @@ async def notion_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=notion_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=notion_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -266,6 +332,42 @@ async def notion_callback(
|
|||
"_token_encrypted": True,
|
||||
}
|
||||
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
if reauth_connector_id:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.NOTION_CONNECTOR,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
db_connector.config = connector_config
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
f"Re-authenticated Notion connector {db_connector.id} for user {user_id}"
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=notion-connector&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
# Extract unique identifier from connector credentials
|
||||
connector_identifier = extract_identifier_from_credentials(
|
||||
SearchSourceConnectorType.NOTION_CONNECTOR, connector_config
|
||||
|
|
@ -284,7 +386,7 @@ async def notion_callback(
|
|||
f"Duplicate Notion connector detected for user {user_id} with workspace {connector_identifier}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=notion-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=notion-connector"
|
||||
)
|
||||
|
||||
# Generate a unique, user-friendly connector name
|
||||
|
|
@ -315,7 +417,7 @@ async def notion_callback(
|
|||
|
||||
# Redirect to the frontend with success params
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=notion-connector&connectorId={new_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=notion-connector&connectorId={new_connector.id}"
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
|
|
@ -346,6 +448,22 @@ async def notion_callback(
|
|||
) from e
|
||||
|
||||
|
||||
async def _mark_connector_auth_expired(
|
||||
session: AsyncSession, connector: SearchSourceConnector
|
||||
) -> None:
|
||||
"""Persist auth_expired flag in the connector config so the frontend can show a re-auth prompt."""
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to persist auth_expired flag for connector {connector.id}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def refresh_notion_token(
|
||||
session: AsyncSession, connector: SearchSourceConnector
|
||||
) -> SearchSourceConnector:
|
||||
|
|
@ -379,6 +497,7 @@ async def refresh_notion_token(
|
|||
) from e
|
||||
|
||||
if not refresh_token:
|
||||
await _mark_connector_auth_expired(session, connector)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No refresh token available. Please re-authenticate.",
|
||||
|
|
@ -421,6 +540,7 @@ async def refresh_notion_token(
|
|||
or "expired" in error_lower
|
||||
or "revoked" in error_lower
|
||||
):
|
||||
await _mark_connector_auth_expired(session, connector)
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Notion authentication failed. Please re-authenticate.",
|
||||
|
|
@ -469,7 +589,9 @@ async def refresh_notion_token(
|
|||
# Update connector config with encrypted tokens
|
||||
credentials_dict = credentials.to_dict()
|
||||
credentials_dict["_token_encrypted"] = True
|
||||
credentials_dict.pop("auth_expired", None)
|
||||
connector.config = credentials_dict
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ from app.tasks.connector_indexers import (
|
|||
index_slack_messages,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.connector_naming import ensure_unique_connector_name
|
||||
from app.utils.indexing_locks import (
|
||||
acquire_connector_indexing_lock,
|
||||
release_connector_indexing_lock,
|
||||
|
|
@ -189,6 +190,12 @@ async def create_search_source_connector(
|
|||
# Prepare connector data
|
||||
connector_data = connector.model_dump()
|
||||
|
||||
# MCP connectors support multiple instances — ensure unique name
|
||||
if connector.connector_type == SearchSourceConnectorType.MCP_CONNECTOR:
|
||||
connector_data["name"] = await ensure_unique_connector_name(
|
||||
session, connector_data["name"], search_space_id, user.id
|
||||
)
|
||||
|
||||
# Automatically set next_scheduled_at if periodic indexing is enabled
|
||||
if (
|
||||
connector.periodic_indexing_enabled
|
||||
|
|
@ -949,23 +956,46 @@ async def index_connector_content(
|
|||
index_google_drive_files_task,
|
||||
)
|
||||
|
||||
if not drive_items or not drive_items.has_items():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Google Drive indexing requires drive_items body parameter with folders or files",
|
||||
if drive_items and drive_items.has_items():
|
||||
logger.info(
|
||||
f"Triggering Google Drive indexing for connector {connector_id} into search space {search_space_id}, "
|
||||
f"folders: {len(drive_items.folders)}, files: {len(drive_items.files)}"
|
||||
)
|
||||
items_dict = drive_items.model_dump()
|
||||
else:
|
||||
# Quick Index / periodic sync: fall back to stored config
|
||||
config = connector.config or {}
|
||||
selected_folders = config.get("selected_folders", [])
|
||||
selected_files = config.get("selected_files", [])
|
||||
if not selected_folders and not selected_files:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Google Drive indexing requires folders or files to be configured. "
|
||||
"Please select folders/files to index.",
|
||||
)
|
||||
indexing_options = config.get(
|
||||
"indexing_options",
|
||||
{
|
||||
"max_files_per_folder": 100,
|
||||
"incremental_sync": True,
|
||||
"include_subfolders": True,
|
||||
},
|
||||
)
|
||||
items_dict = {
|
||||
"folders": selected_folders,
|
||||
"files": selected_files,
|
||||
"indexing_options": indexing_options,
|
||||
}
|
||||
logger.info(
|
||||
f"Triggering Google Drive indexing for connector {connector_id} into search space {search_space_id} "
|
||||
f"using existing config"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Triggering Google Drive indexing for connector {connector_id} into search space {search_space_id}, "
|
||||
f"folders: {len(drive_items.folders)}, files: {len(drive_items.files)}"
|
||||
)
|
||||
|
||||
# Pass structured data to Celery task
|
||||
index_google_drive_files_task.delay(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
str(user.id),
|
||||
drive_items.model_dump(), # Convert to dict for JSON serialization
|
||||
items_dict,
|
||||
)
|
||||
response_message = "Google Drive indexing started in the background."
|
||||
|
||||
|
|
@ -1061,7 +1091,7 @@ async def index_connector_content(
|
|||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_composio_connector_task,
|
||||
index_google_drive_files_task,
|
||||
)
|
||||
|
||||
# For Composio Google Drive, if drive_items is provided, update connector config
|
||||
|
|
@ -1095,34 +1125,72 @@ async def index_connector_content(
|
|||
else:
|
||||
logger.info(
|
||||
f"Triggering Composio Google Drive indexing for connector {connector_id} into search space {search_space_id} "
|
||||
f"using existing config (from {indexing_from} to {indexing_to})"
|
||||
f"using existing config"
|
||||
)
|
||||
|
||||
index_composio_connector_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
# Extract config and build items_dict for index_google_drive_files_task
|
||||
config = connector.config or {}
|
||||
selected_folders = config.get("selected_folders", [])
|
||||
selected_files = config.get("selected_files", [])
|
||||
if not selected_folders and not selected_files:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Composio Google Drive indexing requires folders or files to be configured. "
|
||||
"Please select folders/files to index.",
|
||||
)
|
||||
indexing_options = config.get(
|
||||
"indexing_options",
|
||||
{
|
||||
"max_files_per_folder": 100,
|
||||
"incremental_sync": True,
|
||||
"include_subfolders": True,
|
||||
},
|
||||
)
|
||||
items_dict = {
|
||||
"folders": selected_folders,
|
||||
"files": selected_files,
|
||||
"indexing_options": indexing_options,
|
||||
}
|
||||
index_google_drive_files_task.delay(
|
||||
connector_id, search_space_id, str(user.id), items_dict
|
||||
)
|
||||
response_message = (
|
||||
"Composio Google Drive indexing started in the background."
|
||||
)
|
||||
|
||||
elif connector.connector_type in [
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
]:
|
||||
elif (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_composio_connector_task,
|
||||
index_google_gmail_messages_task,
|
||||
)
|
||||
|
||||
# For Composio Gmail and Calendar, use the same date calculation logic as normal connectors
|
||||
# This ensures consistent behavior and uses last_indexed_at to reduce API calls
|
||||
# (includes special case: if indexed today, go back 1 day to avoid missing data)
|
||||
logger.info(
|
||||
f"Triggering Composio connector indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
f"Triggering Composio Gmail indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_composio_connector_task.delay(
|
||||
index_google_gmail_messages_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Composio connector indexing started in the background."
|
||||
response_message = "Composio Gmail indexing started in the background."
|
||||
|
||||
elif (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_google_calendar_events_task,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Triggering Composio Google Calendar indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_google_calendar_events_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = (
|
||||
"Composio Google Calendar indexing started in the background."
|
||||
)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -1229,6 +1297,48 @@ async def run_slack_indexing(
|
|||
)
|
||||
|
||||
|
||||
_AUTH_ERROR_PATTERNS = (
|
||||
"failed to refresh linear oauth",
|
||||
"failed to refresh your notion connection",
|
||||
"failed to refresh notion token",
|
||||
"authentication failed",
|
||||
"auth_expired",
|
||||
"token has been expired or revoked",
|
||||
"invalid_grant",
|
||||
)
|
||||
|
||||
|
||||
def _is_auth_error(error_message: str) -> bool:
|
||||
"""Check if an error message indicates an OAuth token expiry failure."""
|
||||
if not error_message:
|
||||
return False
|
||||
lower = error_message.lower()
|
||||
return any(pattern in lower for pattern in _AUTH_ERROR_PATTERNS)
|
||||
|
||||
|
||||
async def _persist_auth_expired(session: AsyncSession, connector_id: int) -> None:
|
||||
"""Flag a connector as auth_expired so the frontend shows a re-auth prompt."""
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id
|
||||
)
|
||||
)
|
||||
connector = result.scalar_one_or_none()
|
||||
if connector and not connector.config.get("auth_expired"):
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
logger.info(f"Marked connector {connector_id} as auth_expired")
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to persist auth_expired for connector {connector_id}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def _run_indexing_with_notifications(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
|
|
@ -1513,6 +1623,8 @@ async def _run_indexing_with_notifications(
|
|||
else:
|
||||
# Actual failure
|
||||
logger.error(f"Indexing failed: {error_or_warning}")
|
||||
if _is_auth_error(str(error_or_warning)):
|
||||
await _persist_auth_expired(session, connector_id)
|
||||
if notification:
|
||||
# Refresh notification to ensure it's not stale after indexing function commits
|
||||
await session.refresh(notification)
|
||||
|
|
@ -1577,6 +1689,9 @@ async def _run_indexing_with_notifications(
|
|||
except Exception as e:
|
||||
logger.error(f"Error in indexing task: {e!s}", exc_info=True)
|
||||
|
||||
if _is_auth_error(str(e)):
|
||||
await _persist_auth_expired(session, connector_id)
|
||||
|
||||
# Update notification on exception
|
||||
if notification:
|
||||
try:
|
||||
|
|
@ -2172,10 +2287,9 @@ async def run_google_gmail_indexing(
|
|||
end_date: str | None,
|
||||
update_last_indexed: bool,
|
||||
on_heartbeat_callback=None,
|
||||
) -> tuple[int, str | None]:
|
||||
# Use a reasonable default for max_messages
|
||||
) -> tuple[int, int, str | None]:
|
||||
max_messages = 1000
|
||||
indexed_count, error_message = await index_google_gmail_messages(
|
||||
indexed_count, skipped_count, error_message = await index_google_gmail_messages(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
|
|
@ -2186,8 +2300,7 @@ async def run_google_gmail_indexing(
|
|||
max_messages=max_messages,
|
||||
on_heartbeat_callback=on_heartbeat_callback,
|
||||
)
|
||||
# index_google_gmail_messages returns (int, str) but we need (int, str | None)
|
||||
return indexed_count, error_message if error_message else None
|
||||
return indexed_count, skipped_count, error_message if error_message else None
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
|
|
@ -2223,6 +2336,7 @@ async def run_google_drive_indexing(
|
|||
items = GoogleDriveIndexRequest(**items_dict)
|
||||
indexing_options = items.indexing_options
|
||||
total_indexed = 0
|
||||
total_skipped = 0
|
||||
errors = []
|
||||
|
||||
# Get connector info for notification
|
||||
|
|
@ -2260,7 +2374,11 @@ async def run_google_drive_indexing(
|
|||
# Index each folder with indexing options
|
||||
for folder in items.folders:
|
||||
try:
|
||||
indexed_count, error_message = await index_google_drive_files(
|
||||
(
|
||||
indexed_count,
|
||||
skipped_count,
|
||||
error_message,
|
||||
) = await index_google_drive_files(
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
|
|
@ -2272,6 +2390,7 @@ async def run_google_drive_indexing(
|
|||
max_files=indexing_options.max_files_per_folder,
|
||||
include_subfolders=indexing_options.include_subfolders,
|
||||
)
|
||||
total_skipped += skipped_count
|
||||
if error_message:
|
||||
errors.append(f"Folder '{folder.name}': {error_message}")
|
||||
else:
|
||||
|
|
@ -2312,9 +2431,15 @@ async def run_google_drive_indexing(
|
|||
logger.error(
|
||||
f"Google Drive indexing completed with errors for connector {connector_id}: {error_message}"
|
||||
)
|
||||
if _is_auth_error(error_message):
|
||||
await _persist_auth_expired(session, connector_id)
|
||||
error_message = (
|
||||
"Google Drive authentication expired. Please re-authenticate."
|
||||
)
|
||||
else:
|
||||
# Update notification to storing stage
|
||||
if notification:
|
||||
await session.refresh(notification)
|
||||
await NotificationService.connector_indexing.notify_indexing_progress(
|
||||
session=session,
|
||||
notification=notification,
|
||||
|
|
@ -2338,6 +2463,7 @@ async def run_google_drive_indexing(
|
|||
notification=notification,
|
||||
indexed_count=total_indexed,
|
||||
error_message=error_message,
|
||||
skipped_count=total_skipped,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -2715,9 +2841,14 @@ async def create_mcp_connector(
|
|||
"You don't have permission to create connectors in this search space",
|
||||
)
|
||||
|
||||
# Ensure unique name across MCP connectors in this search space
|
||||
unique_name = await ensure_unique_connector_name(
|
||||
session, connector_data.name, search_space_id, user.id
|
||||
)
|
||||
|
||||
# Create the connector with single server config
|
||||
db_connector = SearchSourceConnector(
|
||||
name=connector_data.name,
|
||||
name=unique_name,
|
||||
connector_type=SearchSourceConnectorType.MCP_CONNECTOR,
|
||||
is_indexable=False, # MCP connectors are not indexable
|
||||
config={"server_config": connector_data.server_config.model_dump()},
|
||||
|
|
@ -3136,6 +3267,12 @@ async def get_drive_picker_token(
|
|||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get Drive picker token: {e!s}", exc_info=True)
|
||||
if _is_auth_error(str(e)):
|
||||
await _persist_auth_expired(session, connector_id)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Google Drive authentication expired. Please re-authenticate.",
|
||||
) from e
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to retrieve access token. Check server logs for details.",
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ async def slack_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=slack_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=slack_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -296,7 +296,7 @@ async def slack_callback(
|
|||
f"Duplicate Slack connector detected for user {user_id} with workspace {connector_identifier}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=slack-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=slack-connector"
|
||||
)
|
||||
|
||||
# Generate a unique, user-friendly connector name
|
||||
|
|
@ -328,7 +328,7 @@ async def slack_callback(
|
|||
|
||||
# Redirect to the frontend with success params
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=slack-connector&connectorId={new_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=slack-connector&connectorId={new_connector.id}"
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
|
|
|
|||
|
|
@ -36,32 +36,14 @@ TOOLKIT_TO_CONNECTOR_TYPE = {
|
|||
}
|
||||
|
||||
# Mapping of toolkit IDs to document types
|
||||
TOOLKIT_TO_DOCUMENT_TYPE = {
|
||||
"googledrive": "COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
|
||||
"gmail": "COMPOSIO_GMAIL_CONNECTOR",
|
||||
"googlecalendar": "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
|
||||
}
|
||||
# Google Drive, Gmail, Calendar use unified native indexers - not in this registry
|
||||
TOOLKIT_TO_DOCUMENT_TYPE: dict[str, str] = {}
|
||||
|
||||
# Mapping of toolkit IDs to their indexer functions
|
||||
# Format: toolkit_id -> (module_path, function_name, supports_date_filter)
|
||||
# supports_date_filter: True if the indexer accepts start_date/end_date params
|
||||
TOOLKIT_TO_INDEXER = {
|
||||
"googledrive": (
|
||||
"app.connectors.composio_google_drive_connector",
|
||||
"index_composio_google_drive",
|
||||
False, # Google Drive doesn't use date filtering
|
||||
),
|
||||
"gmail": (
|
||||
"app.connectors.composio_gmail_connector",
|
||||
"index_composio_gmail",
|
||||
True, # Gmail uses date filtering
|
||||
),
|
||||
"googlecalendar": (
|
||||
"app.connectors.composio_google_calendar_connector",
|
||||
"index_composio_google_calendar",
|
||||
True, # Calendar uses date filtering
|
||||
),
|
||||
}
|
||||
# Google Drive, Gmail, Calendar use unified native indexers - not in this registry
|
||||
TOOLKIT_TO_INDEXER: dict[str, tuple[str, str, bool]] = {}
|
||||
|
||||
|
||||
class ComposioService:
|
||||
|
|
@ -247,6 +229,68 @@ class ComposioService:
|
|||
)
|
||||
return False
|
||||
|
||||
def refresh_connected_account(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
redirect_url: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Refresh an expired Composio connected account.
|
||||
|
||||
For OAuth flows this returns a redirect_url the user must visit to
|
||||
re-authorise. The same connected_account_id stays valid afterwards.
|
||||
|
||||
Args:
|
||||
connected_account_id: The Composio connected account nanoid.
|
||||
redirect_url: Where Composio should redirect after re-auth.
|
||||
|
||||
Returns:
|
||||
Dict with id, status, and redirect_url (None when no redirect needed).
|
||||
"""
|
||||
kwargs: dict[str, Any] = {}
|
||||
if redirect_url is not None:
|
||||
kwargs["body_redirect_url"] = redirect_url
|
||||
result = self.client.connected_accounts.refresh(
|
||||
nanoid=connected_account_id,
|
||||
**kwargs,
|
||||
)
|
||||
return {
|
||||
"id": result.id,
|
||||
"status": result.status,
|
||||
"redirect_url": result.redirect_url,
|
||||
}
|
||||
|
||||
def wait_for_connection(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
timeout: float = 30.0,
|
||||
) -> str:
|
||||
"""
|
||||
Poll Composio until the connected account reaches ACTIVE status.
|
||||
|
||||
Must be called after refresh() / initiate() to ensure Composio has
|
||||
finished exchanging the authorization code for valid tokens.
|
||||
|
||||
Returns:
|
||||
The final account status string (should be "ACTIVE").
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the account does not become ACTIVE within *timeout*.
|
||||
"""
|
||||
try:
|
||||
account = self.client.connected_accounts.wait_for_connection(
|
||||
id=connected_account_id,
|
||||
timeout=timeout,
|
||||
)
|
||||
status = getattr(account, "status", "UNKNOWN")
|
||||
logger.info(f"Composio account {connected_account_id} is now {status}")
|
||||
return status
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Timeout/error waiting for Composio account {connected_account_id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
def get_access_token(self, connected_account_id: str) -> str:
|
||||
"""Retrieve the raw OAuth access token for a Composio connected account."""
|
||||
account = self.client.connected_accounts.get(nanoid=connected_account_id)
|
||||
|
|
@ -258,6 +302,12 @@ class ComposioService:
|
|||
access_token = getattr(token, "access_token", None)
|
||||
if not access_token:
|
||||
raise ValueError(f"No access_token in state.val for {connected_account_id}")
|
||||
if len(access_token) < 20:
|
||||
raise ValueError(
|
||||
f"Composio returned a masked access_token ({len(access_token)} chars) "
|
||||
f"for account {connected_account_id}. Disable 'Mask Connected Account "
|
||||
f"Secrets' in Composio dashboard: Settings → Project Settings."
|
||||
)
|
||||
return access_token
|
||||
|
||||
async def execute_tool(
|
||||
|
|
|
|||
13
surfsense_backend/app/services/confluence/__init__.py
Normal file
13
surfsense_backend/app/services/confluence/__init__.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from app.services.confluence.kb_sync_service import ConfluenceKBSyncService
|
||||
from app.services.confluence.tool_metadata_service import (
|
||||
ConfluencePage,
|
||||
ConfluenceToolMetadataService,
|
||||
ConfluenceWorkspace,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ConfluenceKBSyncService",
|
||||
"ConfluencePage",
|
||||
"ConfluenceToolMetadataService",
|
||||
"ConfluenceWorkspace",
|
||||
]
|
||||
240
surfsense_backend/app/services/confluence/kb_sync_service.py
Normal file
240
surfsense_backend/app/services/confluence/kb_sync_service.py
Normal file
|
|
@ -0,0 +1,240 @@
|
|||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConfluenceKBSyncService:
|
||||
"""Syncs Confluence page documents to the knowledge base after HITL actions."""
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db_session = db_session
|
||||
|
||||
async def sync_after_create(
|
||||
self,
|
||||
page_id: str,
|
||||
page_title: str,
|
||||
space_id: str,
|
||||
body_content: str | None,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> dict:
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document_by_hash,
|
||||
get_current_timestamp,
|
||||
safe_set_chunks,
|
||||
)
|
||||
|
||||
try:
|
||||
unique_hash = generate_unique_identifier_hash(
|
||||
DocumentType.CONFLUENCE_CONNECTOR, page_id, search_space_id
|
||||
)
|
||||
|
||||
existing = await check_document_by_unique_identifier(
|
||||
self.db_session, unique_hash
|
||||
)
|
||||
if existing:
|
||||
return {"status": "success"}
|
||||
|
||||
indexable_content = (body_content or "").strip()
|
||||
if not indexable_content:
|
||||
indexable_content = f"Confluence Page: {page_title}"
|
||||
|
||||
page_content = f"# {page_title}\n\n{indexable_content}"
|
||||
|
||||
content_hash = generate_content_hash(page_content, search_space_id)
|
||||
|
||||
with self.db_session.no_autoflush:
|
||||
dup = await check_duplicate_document_by_hash(
|
||||
self.db_session, content_hash
|
||||
)
|
||||
if dup:
|
||||
content_hash = unique_hash
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
search_space_id,
|
||||
disable_streaming=True,
|
||||
)
|
||||
|
||||
doc_metadata_for_summary = {
|
||||
"page_title": page_title,
|
||||
"space_id": space_id,
|
||||
"document_type": "Confluence Page",
|
||||
"connector_type": "Confluence",
|
||||
}
|
||||
|
||||
if user_llm:
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
page_content, user_llm, doc_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
summary_content = f"Confluence Page: {page_title}\n\n{page_content}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(page_content)
|
||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
document = Document(
|
||||
title=page_title,
|
||||
document_type=DocumentType.CONFLUENCE_CONNECTOR,
|
||||
document_metadata={
|
||||
"page_id": page_id,
|
||||
"page_title": page_title,
|
||||
"space_id": space_id,
|
||||
"comment_count": 0,
|
||||
"indexed_at": now_str,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
content=summary_content,
|
||||
content_hash=content_hash,
|
||||
unique_identifier_hash=unique_hash,
|
||||
embedding=summary_embedding,
|
||||
search_space_id=search_space_id,
|
||||
connector_id=connector_id,
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
)
|
||||
|
||||
self.db_session.add(document)
|
||||
await self.db_session.flush()
|
||||
await safe_set_chunks(self.db_session, document, chunks)
|
||||
await self.db_session.commit()
|
||||
|
||||
logger.info(
|
||||
"KB sync after create succeeded: doc_id=%s, page=%s",
|
||||
document.id,
|
||||
page_title,
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in error_str
|
||||
or "uniqueviolationerror" in error_str
|
||||
):
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": "Duplicate document detected"}
|
||||
|
||||
logger.error(
|
||||
"KB sync after create failed for page %s: %s",
|
||||
page_title,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def sync_after_update(
|
||||
self,
|
||||
document_id: int,
|
||||
page_id: str,
|
||||
user_id: str,
|
||||
search_space_id: int,
|
||||
) -> dict:
|
||||
from app.tasks.connector_indexers.base import (
|
||||
get_current_timestamp,
|
||||
safe_set_chunks,
|
||||
)
|
||||
|
||||
try:
|
||||
document = await self.db_session.get(Document, document_id)
|
||||
if not document:
|
||||
return {"status": "not_indexed"}
|
||||
|
||||
connector_id = document.connector_id
|
||||
if not connector_id:
|
||||
return {"status": "error", "message": "Document has no connector_id"}
|
||||
|
||||
client = ConfluenceHistoryConnector(
|
||||
session=self.db_session, connector_id=connector_id
|
||||
)
|
||||
page_data = await client.get_page(page_id)
|
||||
await client.close()
|
||||
|
||||
page_title = page_data.get("title", "")
|
||||
body_obj = page_data.get("body", {})
|
||||
body_content = ""
|
||||
if isinstance(body_obj, dict):
|
||||
storage = body_obj.get("storage", {})
|
||||
if isinstance(storage, dict):
|
||||
body_content = storage.get("value", "")
|
||||
|
||||
page_content = f"# {page_title}\n\n{body_content}"
|
||||
|
||||
if not page_content.strip():
|
||||
return {"status": "error", "message": "Page produced empty content"}
|
||||
|
||||
space_id = (document.document_metadata or {}).get("space_id", "")
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
doc_meta = {
|
||||
"page_title": page_title,
|
||||
"space_id": space_id,
|
||||
"document_type": "Confluence Page",
|
||||
"connector_type": "Confluence",
|
||||
}
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
page_content, user_llm, doc_meta
|
||||
)
|
||||
else:
|
||||
summary_content = f"Confluence Page: {page_title}\n\n{page_content}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(page_content)
|
||||
|
||||
document.title = page_title
|
||||
document.content = summary_content
|
||||
document.content_hash = generate_content_hash(page_content, search_space_id)
|
||||
document.embedding = summary_embedding
|
||||
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
document.document_metadata = {
|
||||
**(document.document_metadata or {}),
|
||||
"page_id": page_id,
|
||||
"page_title": page_title,
|
||||
"space_id": space_id,
|
||||
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"connector_id": connector_id,
|
||||
}
|
||||
flag_modified(document, "document_metadata")
|
||||
await safe_set_chunks(self.db_session, document, chunks)
|
||||
document.updated_at = get_current_timestamp()
|
||||
|
||||
await self.db_session.commit()
|
||||
|
||||
logger.info(
|
||||
"KB sync successful for document %s (%s)",
|
||||
document_id,
|
||||
page_title,
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"KB sync failed for document %s: %s", document_id, e, exc_info=True
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
|
@ -0,0 +1,314 @@
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import and_, func, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.db import (
|
||||
Document,
|
||||
DocumentType,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfluenceWorkspace:
|
||||
"""Represents a Confluence connector as a workspace for tool context."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
base_url: str
|
||||
|
||||
@classmethod
|
||||
def from_connector(cls, connector: SearchSourceConnector) -> "ConfluenceWorkspace":
|
||||
return cls(
|
||||
id=connector.id,
|
||||
name=connector.name,
|
||||
base_url=connector.config.get("base_url", ""),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"base_url": self.base_url,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfluencePage:
|
||||
"""Represents an indexed Confluence page resolved from the knowledge base."""
|
||||
|
||||
page_id: str
|
||||
page_title: str
|
||||
space_id: str
|
||||
connector_id: int
|
||||
document_id: int
|
||||
indexed_at: str | None
|
||||
|
||||
@classmethod
|
||||
def from_document(cls, document: Document) -> "ConfluencePage":
|
||||
meta = document.document_metadata or {}
|
||||
return cls(
|
||||
page_id=meta.get("page_id", ""),
|
||||
page_title=meta.get("page_title", document.title),
|
||||
space_id=meta.get("space_id", ""),
|
||||
connector_id=document.connector_id,
|
||||
document_id=document.id,
|
||||
indexed_at=meta.get("indexed_at"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"page_id": self.page_id,
|
||||
"page_title": self.page_title,
|
||||
"space_id": self.space_id,
|
||||
"connector_id": self.connector_id,
|
||||
"document_id": self.document_id,
|
||||
"indexed_at": self.indexed_at,
|
||||
}
|
||||
|
||||
|
||||
class ConfluenceToolMetadataService:
|
||||
"""Builds interrupt context for Confluence HITL tools."""
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self._db_session = db_session
|
||||
|
||||
async def _check_account_health(self, connector: SearchSourceConnector) -> bool:
|
||||
"""Check if the Confluence connector auth is still valid.
|
||||
|
||||
Returns True if auth is expired/invalid, False if healthy.
|
||||
"""
|
||||
try:
|
||||
client = ConfluenceHistoryConnector(
|
||||
session=self._db_session, connector_id=connector.id
|
||||
)
|
||||
await client._get_valid_token()
|
||||
await client.close()
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Confluence connector %s health check failed: %s", connector.id, e
|
||||
)
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await self._db_session.commit()
|
||||
await self._db_session.refresh(connector)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return True
|
||||
|
||||
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
|
||||
"""Return context needed to create a new Confluence page.
|
||||
|
||||
Fetches all connected accounts, and for the first healthy one fetches spaces.
|
||||
"""
|
||||
connectors = await self._get_all_confluence_connectors(search_space_id, user_id)
|
||||
if not connectors:
|
||||
return {"error": "No Confluence account connected"}
|
||||
|
||||
accounts = []
|
||||
spaces = []
|
||||
fetched_context = False
|
||||
|
||||
for connector in connectors:
|
||||
auth_expired = await self._check_account_health(connector)
|
||||
workspace = ConfluenceWorkspace.from_connector(connector)
|
||||
accounts.append(
|
||||
{
|
||||
**workspace.to_dict(),
|
||||
"auth_expired": auth_expired,
|
||||
}
|
||||
)
|
||||
|
||||
if not auth_expired and not fetched_context:
|
||||
try:
|
||||
client = ConfluenceHistoryConnector(
|
||||
session=self._db_session, connector_id=connector.id
|
||||
)
|
||||
raw_spaces = await client.get_all_spaces()
|
||||
spaces = [
|
||||
{"id": s.get("id"), "key": s.get("key"), "name": s.get("name")}
|
||||
for s in raw_spaces
|
||||
]
|
||||
await client.close()
|
||||
fetched_context = True
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to fetch Confluence spaces for connector %s: %s",
|
||||
connector.id,
|
||||
e,
|
||||
)
|
||||
|
||||
return {
|
||||
"accounts": accounts,
|
||||
"spaces": spaces,
|
||||
}
|
||||
|
||||
async def get_update_context(
|
||||
self, search_space_id: int, user_id: str, page_ref: str
|
||||
) -> dict:
|
||||
"""Return context needed to update an indexed Confluence page.
|
||||
|
||||
Resolves the page from KB, then fetches current content and version from API.
|
||||
"""
|
||||
document = await self._resolve_page(search_space_id, user_id, page_ref)
|
||||
if not document:
|
||||
return {
|
||||
"error": f"Page '{page_ref}' not found in your synced Confluence pages. "
|
||||
"Please make sure the page is indexed in your knowledge base."
|
||||
}
|
||||
|
||||
connector = await self._get_connector_for_document(document, user_id)
|
||||
if not connector:
|
||||
return {"error": "Connector not found or access denied"}
|
||||
|
||||
auth_expired = await self._check_account_health(connector)
|
||||
if auth_expired:
|
||||
return {
|
||||
"error": "Confluence authentication has expired. Please re-authenticate.",
|
||||
"auth_expired": True,
|
||||
"connector_id": connector.id,
|
||||
}
|
||||
|
||||
workspace = ConfluenceWorkspace.from_connector(connector)
|
||||
page = ConfluencePage.from_document(document)
|
||||
|
||||
try:
|
||||
client = ConfluenceHistoryConnector(
|
||||
session=self._db_session, connector_id=connector.id
|
||||
)
|
||||
page_data = await client.get_page(page.page_id)
|
||||
await client.close()
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"401" in error_str
|
||||
or "403" in error_str
|
||||
or "authentication" in error_str
|
||||
):
|
||||
return {
|
||||
"error": f"Failed to fetch Confluence page: {e!s}",
|
||||
"auth_expired": True,
|
||||
"connector_id": connector.id,
|
||||
}
|
||||
return {"error": f"Failed to fetch Confluence page: {e!s}"}
|
||||
|
||||
body_storage = ""
|
||||
body_obj = page_data.get("body", {})
|
||||
if isinstance(body_obj, dict):
|
||||
storage = body_obj.get("storage", {})
|
||||
if isinstance(storage, dict):
|
||||
body_storage = storage.get("value", "")
|
||||
|
||||
version_obj = page_data.get("version", {})
|
||||
version_number = (
|
||||
version_obj.get("number", 1) if isinstance(version_obj, dict) else 1
|
||||
)
|
||||
|
||||
return {
|
||||
"account": {**workspace.to_dict(), "auth_expired": False},
|
||||
"page": {
|
||||
"page_id": page.page_id,
|
||||
"page_title": page_data.get("title", page.page_title),
|
||||
"space_id": page.space_id,
|
||||
"body": body_storage,
|
||||
"version": version_number,
|
||||
"document_id": page.document_id,
|
||||
"indexed_at": page.indexed_at,
|
||||
},
|
||||
}
|
||||
|
||||
async def get_deletion_context(
|
||||
self, search_space_id: int, user_id: str, page_ref: str
|
||||
) -> dict:
|
||||
"""Return context needed to delete a Confluence page (KB metadata only)."""
|
||||
document = await self._resolve_page(search_space_id, user_id, page_ref)
|
||||
if not document:
|
||||
return {
|
||||
"error": f"Page '{page_ref}' not found in your synced Confluence pages. "
|
||||
"Please make sure the page is indexed in your knowledge base."
|
||||
}
|
||||
|
||||
connector = await self._get_connector_for_document(document, user_id)
|
||||
if not connector:
|
||||
return {"error": "Connector not found or access denied"}
|
||||
|
||||
auth_expired = connector.config.get("auth_expired", False)
|
||||
workspace = ConfluenceWorkspace.from_connector(connector)
|
||||
page = ConfluencePage.from_document(document)
|
||||
|
||||
return {
|
||||
"account": {**workspace.to_dict(), "auth_expired": auth_expired},
|
||||
"page": page.to_dict(),
|
||||
}
|
||||
|
||||
async def _resolve_page(
|
||||
self, search_space_id: int, user_id: str, page_ref: str
|
||||
) -> Document | None:
|
||||
"""Resolve a page from KB: page_title -> document.title."""
|
||||
ref_lower = page_ref.lower()
|
||||
|
||||
result = await self._db_session.execute(
|
||||
select(Document)
|
||||
.join(
|
||||
SearchSourceConnector, Document.connector_id == SearchSourceConnector.id
|
||||
)
|
||||
.filter(
|
||||
and_(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.CONFLUENCE_CONNECTOR,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
or_(
|
||||
func.lower(Document.document_metadata.op("->>")("page_title"))
|
||||
== ref_lower,
|
||||
func.lower(Document.title) == ref_lower,
|
||||
),
|
||||
)
|
||||
)
|
||||
.order_by(Document.updated_at.desc().nullslast())
|
||||
.limit(1)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def _get_all_confluence_connectors(
|
||||
self, search_space_id: int, user_id: str
|
||||
) -> list[SearchSourceConnector]:
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
and_(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def _get_connector_for_document(
|
||||
self, document: Document, user_id: str
|
||||
) -> SearchSourceConnector | None:
|
||||
if not document.connector_id:
|
||||
return None
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
and_(
|
||||
SearchSourceConnector.id == document.connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
|
@ -11,6 +11,7 @@ from sqlalchemy.future import select
|
|||
from tavily import TavilyClient
|
||||
|
||||
from app.db import (
|
||||
NATIVE_TO_LEGACY_DOCTYPE,
|
||||
Chunk,
|
||||
Document,
|
||||
SearchSourceConnector,
|
||||
|
|
@ -219,7 +220,7 @@ class ConnectorService:
|
|||
self,
|
||||
query_text: str,
|
||||
search_space_id: int,
|
||||
document_type: str,
|
||||
document_type: str | list[str],
|
||||
top_k: int = 20,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
|
|
@ -241,7 +242,8 @@ class ConnectorService:
|
|||
Args:
|
||||
query_text: The search query text
|
||||
search_space_id: The search space ID to search within
|
||||
document_type: Document type to filter (e.g., "FILE", "CRAWLED_URL")
|
||||
document_type: Document type(s) to filter (e.g., "FILE", "CRAWLED_URL",
|
||||
or a list for multi-type queries)
|
||||
top_k: Number of results to return
|
||||
start_date: Optional start date for filtering documents by updated_at
|
||||
end_date: Optional end date for filtering documents by updated_at
|
||||
|
|
@ -254,6 +256,16 @@ class ConnectorService:
|
|||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Expand native Google types to include legacy Composio equivalents
|
||||
# so old documents remain searchable until re-indexed.
|
||||
if isinstance(document_type, str) and document_type in NATIVE_TO_LEGACY_DOCTYPE:
|
||||
resolved_type: str | list[str] = [
|
||||
document_type,
|
||||
NATIVE_TO_LEGACY_DOCTYPE[document_type],
|
||||
]
|
||||
else:
|
||||
resolved_type = document_type
|
||||
|
||||
# RRF constant
|
||||
k = 60
|
||||
|
||||
|
|
@ -276,7 +288,7 @@ class ConnectorService:
|
|||
"query_text": query_text,
|
||||
"top_k": retriever_top_k,
|
||||
"search_space_id": search_space_id,
|
||||
"document_type": document_type,
|
||||
"document_type": resolved_type,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"query_embedding": query_embedding,
|
||||
|
|
@ -2746,299 +2758,6 @@ class ConnectorService:
|
|||
|
||||
return result_object, obsidian_docs
|
||||
|
||||
# =========================================================================
|
||||
# Composio Connector Search Methods
|
||||
# =========================================================================
|
||||
|
||||
async def search_composio_google_drive(
|
||||
self,
|
||||
user_query: str,
|
||||
search_space_id: int,
|
||||
top_k: int = 20,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Search for Composio Google Drive files and return both the source information
|
||||
and langchain documents.
|
||||
|
||||
Uses combined chunk-level and document-level hybrid search with RRF fusion.
|
||||
|
||||
Args:
|
||||
user_query: The user's query
|
||||
search_space_id: The search space ID to search in
|
||||
top_k: Maximum number of results to return
|
||||
start_date: Optional start date for filtering documents by updated_at
|
||||
end_date: Optional end date for filtering documents by updated_at
|
||||
|
||||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
composio_drive_docs = await self._combined_rrf_search(
|
||||
query_text=user_query,
|
||||
search_space_id=search_space_id,
|
||||
document_type="COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
|
||||
top_k=top_k,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
# Early return if no results
|
||||
if not composio_drive_docs:
|
||||
return {
|
||||
"id": 54,
|
||||
"name": "Google Drive (Composio)",
|
||||
"type": "COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
|
||||
"sources": [],
|
||||
}, []
|
||||
|
||||
def _title_fn(doc_info: dict[str, Any], metadata: dict[str, Any]) -> str:
|
||||
return (
|
||||
doc_info.get("title")
|
||||
or metadata.get("title")
|
||||
or metadata.get("file_name")
|
||||
or "Untitled Document"
|
||||
)
|
||||
|
||||
def _url_fn(_doc_info: dict[str, Any], metadata: dict[str, Any]) -> str:
|
||||
return metadata.get("url") or metadata.get("web_view_link") or ""
|
||||
|
||||
def _description_fn(
|
||||
chunk: dict[str, Any], _doc_info: dict[str, Any], metadata: dict[str, Any]
|
||||
) -> str:
|
||||
description = self._chunk_preview(chunk.get("content", ""), limit=200)
|
||||
info_parts = []
|
||||
mime_type = metadata.get("mime_type")
|
||||
modified_time = metadata.get("modified_time")
|
||||
if mime_type:
|
||||
info_parts.append(f"Type: {mime_type}")
|
||||
if modified_time:
|
||||
info_parts.append(f"Modified: {modified_time}")
|
||||
if info_parts:
|
||||
description = (description + " | " + " | ".join(info_parts)).strip(" |")
|
||||
return description
|
||||
|
||||
def _extra_fields_fn(
|
||||
_chunk: dict[str, Any], _doc_info: dict[str, Any], metadata: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"mime_type": metadata.get("mime_type", ""),
|
||||
"file_id": metadata.get("file_id", ""),
|
||||
"modified_time": metadata.get("modified_time", ""),
|
||||
}
|
||||
|
||||
sources_list = self._build_chunk_sources_from_documents(
|
||||
composio_drive_docs,
|
||||
title_fn=_title_fn,
|
||||
url_fn=_url_fn,
|
||||
description_fn=_description_fn,
|
||||
extra_fields_fn=_extra_fields_fn,
|
||||
)
|
||||
|
||||
# Create result object
|
||||
result_object = {
|
||||
"id": 54,
|
||||
"name": "Google Drive (Composio)",
|
||||
"type": "COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
|
||||
"sources": sources_list,
|
||||
}
|
||||
|
||||
return result_object, composio_drive_docs
|
||||
|
||||
async def search_composio_gmail(
|
||||
self,
|
||||
user_query: str,
|
||||
search_space_id: int,
|
||||
top_k: int = 20,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Search for Composio Gmail messages and return both the source information
|
||||
and langchain documents.
|
||||
|
||||
Uses combined chunk-level and document-level hybrid search with RRF fusion.
|
||||
|
||||
Args:
|
||||
user_query: The user's query
|
||||
search_space_id: The search space ID to search in
|
||||
top_k: Maximum number of results to return
|
||||
start_date: Optional start date for filtering documents by updated_at
|
||||
end_date: Optional end date for filtering documents by updated_at
|
||||
|
||||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
composio_gmail_docs = await self._combined_rrf_search(
|
||||
query_text=user_query,
|
||||
search_space_id=search_space_id,
|
||||
document_type="COMPOSIO_GMAIL_CONNECTOR",
|
||||
top_k=top_k,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
# Early return if no results
|
||||
if not composio_gmail_docs:
|
||||
return {
|
||||
"id": 55,
|
||||
"name": "Gmail (Composio)",
|
||||
"type": "COMPOSIO_GMAIL_CONNECTOR",
|
||||
"sources": [],
|
||||
}, []
|
||||
|
||||
def _title_fn(doc_info: dict[str, Any], metadata: dict[str, Any]) -> str:
|
||||
return (
|
||||
doc_info.get("title")
|
||||
or metadata.get("subject")
|
||||
or metadata.get("title")
|
||||
or "Untitled Email"
|
||||
)
|
||||
|
||||
def _url_fn(_doc_info: dict[str, Any], metadata: dict[str, Any]) -> str:
|
||||
return metadata.get("url") or ""
|
||||
|
||||
def _description_fn(
|
||||
chunk: dict[str, Any], _doc_info: dict[str, Any], metadata: dict[str, Any]
|
||||
) -> str:
|
||||
description = self._chunk_preview(chunk.get("content", ""), limit=200)
|
||||
info_parts = []
|
||||
sender = metadata.get("from") or metadata.get("sender")
|
||||
date = metadata.get("date") or metadata.get("received_at")
|
||||
if sender:
|
||||
info_parts.append(f"From: {sender}")
|
||||
if date:
|
||||
info_parts.append(f"Date: {date}")
|
||||
if info_parts:
|
||||
description = (description + " | " + " | ".join(info_parts)).strip(" |")
|
||||
return description
|
||||
|
||||
def _extra_fields_fn(
|
||||
_chunk: dict[str, Any], _doc_info: dict[str, Any], metadata: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"message_id": metadata.get("message_id", ""),
|
||||
"thread_id": metadata.get("thread_id", ""),
|
||||
"from": metadata.get("from", ""),
|
||||
"to": metadata.get("to", ""),
|
||||
"date": metadata.get("date", ""),
|
||||
}
|
||||
|
||||
sources_list = self._build_chunk_sources_from_documents(
|
||||
composio_gmail_docs,
|
||||
title_fn=_title_fn,
|
||||
url_fn=_url_fn,
|
||||
description_fn=_description_fn,
|
||||
extra_fields_fn=_extra_fields_fn,
|
||||
)
|
||||
|
||||
# Create result object
|
||||
result_object = {
|
||||
"id": 55,
|
||||
"name": "Gmail (Composio)",
|
||||
"type": "COMPOSIO_GMAIL_CONNECTOR",
|
||||
"sources": sources_list,
|
||||
}
|
||||
|
||||
return result_object, composio_gmail_docs
|
||||
|
||||
async def search_composio_google_calendar(
|
||||
self,
|
||||
user_query: str,
|
||||
search_space_id: int,
|
||||
top_k: int = 20,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Search for Composio Google Calendar events and return both the source information
|
||||
and langchain documents.
|
||||
|
||||
Uses combined chunk-level and document-level hybrid search with RRF fusion.
|
||||
|
||||
Args:
|
||||
user_query: The user's query
|
||||
search_space_id: The search space ID to search in
|
||||
top_k: Maximum number of results to return
|
||||
start_date: Optional start date for filtering documents by updated_at
|
||||
end_date: Optional end date for filtering documents by updated_at
|
||||
|
||||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
composio_calendar_docs = await self._combined_rrf_search(
|
||||
query_text=user_query,
|
||||
search_space_id=search_space_id,
|
||||
document_type="COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
|
||||
top_k=top_k,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
# Early return if no results
|
||||
if not composio_calendar_docs:
|
||||
return {
|
||||
"id": 56,
|
||||
"name": "Google Calendar (Composio)",
|
||||
"type": "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
|
||||
"sources": [],
|
||||
}, []
|
||||
|
||||
def _title_fn(doc_info: dict[str, Any], metadata: dict[str, Any]) -> str:
|
||||
return (
|
||||
doc_info.get("title")
|
||||
or metadata.get("summary")
|
||||
or metadata.get("title")
|
||||
or "Untitled Event"
|
||||
)
|
||||
|
||||
def _url_fn(_doc_info: dict[str, Any], metadata: dict[str, Any]) -> str:
|
||||
return metadata.get("url") or metadata.get("html_link") or ""
|
||||
|
||||
def _description_fn(
|
||||
chunk: dict[str, Any], _doc_info: dict[str, Any], metadata: dict[str, Any]
|
||||
) -> str:
|
||||
description = self._chunk_preview(chunk.get("content", ""), limit=200)
|
||||
info_parts = []
|
||||
start_time = metadata.get("start_time") or metadata.get("start")
|
||||
end_time = metadata.get("end_time") or metadata.get("end")
|
||||
if start_time:
|
||||
info_parts.append(f"Start: {start_time}")
|
||||
if end_time:
|
||||
info_parts.append(f"End: {end_time}")
|
||||
if info_parts:
|
||||
description = (description + " | " + " | ".join(info_parts)).strip(" |")
|
||||
return description
|
||||
|
||||
def _extra_fields_fn(
|
||||
_chunk: dict[str, Any], _doc_info: dict[str, Any], metadata: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"event_id": metadata.get("event_id", ""),
|
||||
"calendar_id": metadata.get("calendar_id", ""),
|
||||
"start_time": metadata.get("start_time", ""),
|
||||
"end_time": metadata.get("end_time", ""),
|
||||
"location": metadata.get("location", ""),
|
||||
}
|
||||
|
||||
sources_list = self._build_chunk_sources_from_documents(
|
||||
composio_calendar_docs,
|
||||
title_fn=_title_fn,
|
||||
url_fn=_url_fn,
|
||||
description_fn=_description_fn,
|
||||
extra_fields_fn=_extra_fields_fn,
|
||||
)
|
||||
|
||||
# Create result object
|
||||
result_object = {
|
||||
"id": 56,
|
||||
"name": "Google Calendar (Composio)",
|
||||
"type": "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
|
||||
"sources": sources_list,
|
||||
}
|
||||
|
||||
return result_object, composio_calendar_docs
|
||||
|
||||
# =========================================================================
|
||||
# Utility Methods for Connector Discovery
|
||||
# =========================================================================
|
||||
|
|
|
|||
13
surfsense_backend/app/services/gmail/__init__.py
Normal file
13
surfsense_backend/app/services/gmail/__init__.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from app.services.gmail.kb_sync_service import GmailKBSyncService
|
||||
from app.services.gmail.tool_metadata_service import (
|
||||
GmailAccount,
|
||||
GmailMessage,
|
||||
GmailToolMetadataService,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"GmailAccount",
|
||||
"GmailKBSyncService",
|
||||
"GmailMessage",
|
||||
"GmailToolMetadataService",
|
||||
]
|
||||
169
surfsense_backend/app/services/gmail/kb_sync_service.py
Normal file
169
surfsense_backend/app/services/gmail/kb_sync_service.py
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GmailKBSyncService:
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db_session = db_session
|
||||
|
||||
async def sync_after_create(
|
||||
self,
|
||||
message_id: str,
|
||||
thread_id: str,
|
||||
subject: str,
|
||||
sender: str,
|
||||
date_str: str,
|
||||
body_text: str | None,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
draft_id: str | None = None,
|
||||
) -> dict:
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document_by_hash,
|
||||
get_current_timestamp,
|
||||
safe_set_chunks,
|
||||
)
|
||||
|
||||
try:
|
||||
unique_hash = generate_unique_identifier_hash(
|
||||
DocumentType.GOOGLE_GMAIL_CONNECTOR, message_id, search_space_id
|
||||
)
|
||||
|
||||
existing = await check_document_by_unique_identifier(
|
||||
self.db_session, unique_hash
|
||||
)
|
||||
if existing:
|
||||
logger.info(
|
||||
"Document for Gmail message %s already exists (doc_id=%s), skipping",
|
||||
message_id,
|
||||
existing.id,
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
indexable_content = (
|
||||
f"Gmail Message: {subject}\n\nFrom: {sender}\nDate: {date_str}\n\n"
|
||||
f"{body_text or ''}"
|
||||
).strip()
|
||||
if not indexable_content:
|
||||
indexable_content = f"Gmail message: {subject}"
|
||||
|
||||
content_hash = generate_content_hash(indexable_content, search_space_id)
|
||||
|
||||
with self.db_session.no_autoflush:
|
||||
dup = await check_duplicate_document_by_hash(
|
||||
self.db_session, content_hash
|
||||
)
|
||||
if dup:
|
||||
logger.info(
|
||||
"Content-hash collision for Gmail message %s -- identical content "
|
||||
"exists in doc %s. Using unique_identifier_hash as content_hash.",
|
||||
message_id,
|
||||
dup.id,
|
||||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
search_space_id,
|
||||
disable_streaming=True,
|
||||
)
|
||||
|
||||
doc_metadata_for_summary = {
|
||||
"subject": subject,
|
||||
"sender": sender,
|
||||
"document_type": "Gmail Message",
|
||||
"connector_type": "Gmail",
|
||||
}
|
||||
|
||||
if user_llm:
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
indexable_content, user_llm, doc_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
logger.warning("No LLM configured -- using fallback summary")
|
||||
summary_content = f"Gmail Message: {subject}\n\n{indexable_content}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(indexable_content)
|
||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
doc_metadata = {
|
||||
"message_id": message_id,
|
||||
"thread_id": thread_id,
|
||||
"subject": subject,
|
||||
"sender": sender,
|
||||
"date": date_str,
|
||||
"connector_id": connector_id,
|
||||
"indexed_at": now_str,
|
||||
}
|
||||
if draft_id:
|
||||
doc_metadata["draft_id"] = draft_id
|
||||
|
||||
document = Document(
|
||||
title=subject,
|
||||
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
|
||||
document_metadata=doc_metadata,
|
||||
content=summary_content,
|
||||
content_hash=content_hash,
|
||||
unique_identifier_hash=unique_hash,
|
||||
embedding=summary_embedding,
|
||||
search_space_id=search_space_id,
|
||||
connector_id=connector_id,
|
||||
source_markdown=body_text,
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
)
|
||||
|
||||
self.db_session.add(document)
|
||||
await self.db_session.flush()
|
||||
await safe_set_chunks(self.db_session, document, chunks)
|
||||
await self.db_session.commit()
|
||||
|
||||
logger.info(
|
||||
"KB sync after create succeeded: doc_id=%s, subject=%s, chunks=%d",
|
||||
document.id,
|
||||
subject,
|
||||
len(chunks),
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in error_str
|
||||
or "uniqueviolationerror" in error_str
|
||||
):
|
||||
logger.warning(
|
||||
"Duplicate constraint hit during KB sync for message %s. "
|
||||
"Rolling back -- periodic indexer will handle it. Error: %s",
|
||||
message_id,
|
||||
e,
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": "Duplicate document detected"}
|
||||
|
||||
logger.error(
|
||||
"KB sync after create failed for message %s: %s",
|
||||
message_id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": str(e)}
|
||||
451
surfsense_backend/app/services/gmail/tool_metadata_service.py
Normal file
451
surfsense_backend/app/services/gmail/tool_metadata_service.py
Normal file
|
|
@ -0,0 +1,451 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from sqlalchemy import String, and_, cast, func, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.db import (
|
||||
Document,
|
||||
DocumentType,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GmailAccount:
|
||||
id: int
|
||||
name: str
|
||||
email: str
|
||||
|
||||
@classmethod
|
||||
def from_connector(cls, connector: SearchSourceConnector) -> "GmailAccount":
|
||||
email = ""
|
||||
if connector.name and " - " in connector.name:
|
||||
email = connector.name.split(" - ", 1)[1]
|
||||
return cls(id=connector.id, name=connector.name, email=email)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {"id": self.id, "name": self.name, "email": self.email}
|
||||
|
||||
|
||||
@dataclass
|
||||
class GmailMessage:
|
||||
message_id: str
|
||||
thread_id: str
|
||||
subject: str
|
||||
sender: str
|
||||
date: str
|
||||
connector_id: int
|
||||
document_id: int
|
||||
|
||||
@classmethod
|
||||
def from_document(cls, document: Document) -> "GmailMessage":
|
||||
meta = document.document_metadata or {}
|
||||
return cls(
|
||||
message_id=meta.get("message_id", ""),
|
||||
thread_id=meta.get("thread_id", ""),
|
||||
subject=meta.get("subject", document.title),
|
||||
sender=meta.get("sender", ""),
|
||||
date=meta.get("date", ""),
|
||||
connector_id=document.connector_id,
|
||||
document_id=document.id,
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"message_id": self.message_id,
|
||||
"thread_id": self.thread_id,
|
||||
"subject": self.subject,
|
||||
"sender": self.sender,
|
||||
"date": self.date,
|
||||
"connector_id": self.connector_id,
|
||||
"document_id": self.document_id,
|
||||
}
|
||||
|
||||
|
||||
class GmailToolMetadataService:
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self._db_session = db_session
|
||||
|
||||
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
return build_composio_credentials(cca_id)
|
||||
|
||||
config_data = dict(connector.config)
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
return Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
async def _check_account_health(self, connector_id: int) -> bool:
|
||||
"""Check if a Gmail connector's credentials are still valid.
|
||||
|
||||
Uses a lightweight ``users().getProfile(userId='me')`` call.
|
||||
|
||||
Returns True if the credentials are expired/invalid, False if healthy.
|
||||
"""
|
||||
try:
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id
|
||||
)
|
||||
)
|
||||
connector = result.scalar_one_or_none()
|
||||
if not connector:
|
||||
return True
|
||||
|
||||
creds = await self._build_credentials(connector)
|
||||
service = build("gmail", "v1", credentials=creds)
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: service.users().getProfile(userId="me").execute()
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Gmail connector %s health check failed: %s",
|
||||
connector_id,
|
||||
e,
|
||||
)
|
||||
return True
|
||||
|
||||
async def _persist_auth_expired(self, connector_id: int) -> None:
|
||||
"""Persist ``auth_expired: True`` to the connector config if not already set."""
|
||||
try:
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id
|
||||
)
|
||||
)
|
||||
db_connector = result.scalar_one_or_none()
|
||||
if db_connector and not db_connector.config.get("auth_expired"):
|
||||
db_connector.config = {**db_connector.config, "auth_expired": True}
|
||||
flag_modified(db_connector, "config")
|
||||
await self._db_session.commit()
|
||||
await self._db_session.refresh(db_connector)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _get_accounts(
|
||||
self, search_space_id: int, user_id: str
|
||||
) -> list[GmailAccount]:
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector)
|
||||
.filter(
|
||||
and_(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(
|
||||
[
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
.order_by(SearchSourceConnector.last_indexed_at.desc())
|
||||
)
|
||||
connectors = result.scalars().all()
|
||||
return [GmailAccount.from_connector(c) for c in connectors]
|
||||
|
||||
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
|
||||
accounts = await self._get_accounts(search_space_id, user_id)
|
||||
|
||||
if not accounts:
|
||||
return {
|
||||
"accounts": [],
|
||||
"error": "No Gmail account connected",
|
||||
}
|
||||
|
||||
accounts_with_status = []
|
||||
for acc in accounts:
|
||||
acc_dict = acc.to_dict()
|
||||
auth_expired = await self._check_account_health(acc.id)
|
||||
acc_dict["auth_expired"] = auth_expired
|
||||
if auth_expired:
|
||||
await self._persist_auth_expired(acc.id)
|
||||
else:
|
||||
try:
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == acc.id
|
||||
)
|
||||
)
|
||||
connector = result.scalar_one_or_none()
|
||||
if connector:
|
||||
creds = await self._build_credentials(connector)
|
||||
service = build("gmail", "v1", credentials=creds)
|
||||
profile = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda service=service: (
|
||||
service.users().getProfile(userId="me").execute()
|
||||
),
|
||||
)
|
||||
acc_dict["email"] = profile.get("emailAddress", "")
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch email for Gmail connector %s",
|
||||
acc.id,
|
||||
exc_info=True,
|
||||
)
|
||||
accounts_with_status.append(acc_dict)
|
||||
|
||||
return {"accounts": accounts_with_status}
|
||||
|
||||
async def get_update_context(
|
||||
self, search_space_id: int, user_id: str, email_ref: str
|
||||
) -> dict:
|
||||
document, connector = await self._resolve_email(
|
||||
search_space_id, user_id, email_ref
|
||||
)
|
||||
|
||||
if not document or not connector:
|
||||
return {
|
||||
"error": (
|
||||
f"Draft '{email_ref}' not found in your indexed Gmail messages. "
|
||||
"This could mean: (1) the draft doesn't exist, "
|
||||
"(2) it hasn't been indexed yet, "
|
||||
"or (3) the subject is different. "
|
||||
"Please check the exact draft subject in Gmail."
|
||||
)
|
||||
}
|
||||
|
||||
account = GmailAccount.from_connector(connector)
|
||||
message = GmailMessage.from_document(document)
|
||||
|
||||
acc_dict = account.to_dict()
|
||||
auth_expired = await self._check_account_health(connector.id)
|
||||
acc_dict["auth_expired"] = auth_expired
|
||||
if auth_expired:
|
||||
await self._persist_auth_expired(connector.id)
|
||||
|
||||
result: dict = {
|
||||
"account": acc_dict,
|
||||
"email": message.to_dict(),
|
||||
}
|
||||
|
||||
meta = document.document_metadata or {}
|
||||
if meta.get("draft_id"):
|
||||
result["draft_id"] = meta["draft_id"]
|
||||
|
||||
if not auth_expired:
|
||||
existing_body = await self._fetch_draft_body(
|
||||
connector, message.message_id, meta.get("draft_id")
|
||||
)
|
||||
if existing_body is not None:
|
||||
result["existing_body"] = existing_body
|
||||
|
||||
return result
|
||||
|
||||
async def _fetch_draft_body(
|
||||
self,
|
||||
connector: SearchSourceConnector,
|
||||
message_id: str,
|
||||
draft_id: str | None,
|
||||
) -> str | None:
|
||||
"""Fetch the plain-text body of a Gmail draft via the API.
|
||||
|
||||
Tries ``drafts.get`` first (if *draft_id* is available), then falls
|
||||
back to scanning ``drafts.list`` to resolve the draft by *message_id*.
|
||||
Returns ``None`` on any failure so callers can degrade gracefully.
|
||||
"""
|
||||
try:
|
||||
creds = await self._build_credentials(connector)
|
||||
service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
if not draft_id:
|
||||
draft_id = await self._find_draft_id(service, message_id)
|
||||
if not draft_id:
|
||||
return None
|
||||
|
||||
draft = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.users()
|
||||
.drafts()
|
||||
.get(userId="me", id=draft_id, format="full")
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
|
||||
payload = draft.get("message", {}).get("payload", {})
|
||||
return self._extract_body_from_payload(payload)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch draft body for message_id=%s",
|
||||
message_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
async def _find_draft_id(self, service: Any, message_id: str) -> str | None:
|
||||
"""Resolve a draft ID from its message ID by scanning drafts.list."""
|
||||
try:
|
||||
page_token = None
|
||||
while True:
|
||||
kwargs: dict[str, Any] = {"userId": "me", "maxResults": 100}
|
||||
if page_token:
|
||||
kwargs["pageToken"] = page_token
|
||||
response = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda kwargs=kwargs: (
|
||||
service.users().drafts().list(**kwargs).execute()
|
||||
),
|
||||
)
|
||||
for draft in response.get("drafts", []):
|
||||
if draft.get("message", {}).get("id") == message_id:
|
||||
return draft["id"]
|
||||
page_token = response.get("nextPageToken")
|
||||
if not page_token:
|
||||
break
|
||||
return None
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to look up draft by message_id=%s", message_id, exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_body_from_payload(payload: dict) -> str | None:
|
||||
"""Extract the plain-text (or html→text) body from a Gmail payload."""
|
||||
import base64
|
||||
|
||||
def _get_parts(p: dict) -> list[dict]:
|
||||
if "parts" in p:
|
||||
parts: list[dict] = []
|
||||
for sub in p["parts"]:
|
||||
parts.extend(_get_parts(sub))
|
||||
return parts
|
||||
return [p]
|
||||
|
||||
parts = _get_parts(payload)
|
||||
text_content = ""
|
||||
for part in parts:
|
||||
mime_type = part.get("mimeType", "")
|
||||
data = part.get("body", {}).get("data", "")
|
||||
if mime_type == "text/plain" and data:
|
||||
text_content += base64.urlsafe_b64decode(data + "===").decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
elif mime_type == "text/html" and data and not text_content:
|
||||
from markdownify import markdownify as md
|
||||
|
||||
raw_html = base64.urlsafe_b64decode(data + "===").decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
text_content = md(raw_html).strip()
|
||||
|
||||
return text_content.strip() if text_content.strip() else None
|
||||
|
||||
async def get_trash_context(
|
||||
self, search_space_id: int, user_id: str, email_ref: str
|
||||
) -> dict:
|
||||
document, connector = await self._resolve_email(
|
||||
search_space_id, user_id, email_ref
|
||||
)
|
||||
|
||||
if not document or not connector:
|
||||
return {
|
||||
"error": (
|
||||
f"Email '{email_ref}' not found in your indexed Gmail messages. "
|
||||
"This could mean: (1) the email doesn't exist, "
|
||||
"(2) it hasn't been indexed yet, "
|
||||
"or (3) the subject is different."
|
||||
)
|
||||
}
|
||||
|
||||
account = GmailAccount.from_connector(connector)
|
||||
message = GmailMessage.from_document(document)
|
||||
|
||||
acc_dict = account.to_dict()
|
||||
auth_expired = await self._check_account_health(connector.id)
|
||||
acc_dict["auth_expired"] = auth_expired
|
||||
if auth_expired:
|
||||
await self._persist_auth_expired(connector.id)
|
||||
|
||||
return {
|
||||
"account": acc_dict,
|
||||
"email": message.to_dict(),
|
||||
}
|
||||
|
||||
async def _resolve_email(
|
||||
self, search_space_id: int, user_id: str, email_ref: str
|
||||
) -> tuple[Document | None, SearchSourceConnector | None]:
|
||||
result = await self._db_session.execute(
|
||||
select(Document, SearchSourceConnector)
|
||||
.join(
|
||||
SearchSourceConnector,
|
||||
Document.connector_id == SearchSourceConnector.id,
|
||||
)
|
||||
.filter(
|
||||
and_(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type.in_(
|
||||
[
|
||||
DocumentType.GOOGLE_GMAIL_CONNECTOR,
|
||||
DocumentType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
),
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
or_(
|
||||
func.lower(cast(Document.document_metadata["subject"], String))
|
||||
== func.lower(email_ref),
|
||||
func.lower(Document.title) == func.lower(email_ref),
|
||||
),
|
||||
)
|
||||
)
|
||||
.order_by(Document.updated_at.desc().nullslast())
|
||||
.limit(1)
|
||||
)
|
||||
row = result.first()
|
||||
if row:
|
||||
return row[0], row[1]
|
||||
return None, None
|
||||
13
surfsense_backend/app/services/google_calendar/__init__.py
Normal file
13
surfsense_backend/app/services/google_calendar/__init__.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from app.services.google_calendar.kb_sync_service import GoogleCalendarKBSyncService
|
||||
from app.services.google_calendar.tool_metadata_service import (
|
||||
GoogleCalendarAccount,
|
||||
GoogleCalendarEvent,
|
||||
GoogleCalendarToolMetadataService,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"GoogleCalendarAccount",
|
||||
"GoogleCalendarEvent",
|
||||
"GoogleCalendarKBSyncService",
|
||||
"GoogleCalendarToolMetadataService",
|
||||
]
|
||||
|
|
@ -0,0 +1,374 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.db import (
|
||||
Document,
|
||||
DocumentType,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GoogleCalendarKBSyncService:
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db_session = db_session
|
||||
|
||||
async def sync_after_create(
|
||||
self,
|
||||
event_id: str,
|
||||
event_summary: str,
|
||||
calendar_id: str,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
location: str | None,
|
||||
html_link: str | None,
|
||||
description: str | None,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> dict:
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document_by_hash,
|
||||
get_current_timestamp,
|
||||
safe_set_chunks,
|
||||
)
|
||||
|
||||
try:
|
||||
unique_hash = generate_unique_identifier_hash(
|
||||
DocumentType.GOOGLE_CALENDAR_CONNECTOR, event_id, search_space_id
|
||||
)
|
||||
|
||||
existing = await check_document_by_unique_identifier(
|
||||
self.db_session, unique_hash
|
||||
)
|
||||
if existing:
|
||||
logger.info(
|
||||
"Document for Calendar event %s already exists (doc_id=%s), skipping",
|
||||
event_id,
|
||||
existing.id,
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
indexable_content = (
|
||||
f"Google Calendar Event: {event_summary}\n\n"
|
||||
f"Start: {start_time}\n"
|
||||
f"End: {end_time}\n"
|
||||
f"Location: {location or 'N/A'}\n\n"
|
||||
f"{description or ''}"
|
||||
).strip()
|
||||
|
||||
content_hash = generate_content_hash(indexable_content, search_space_id)
|
||||
|
||||
with self.db_session.no_autoflush:
|
||||
dup = await check_duplicate_document_by_hash(
|
||||
self.db_session, content_hash
|
||||
)
|
||||
if dup:
|
||||
logger.info(
|
||||
"Content-hash collision for Calendar event %s -- identical content "
|
||||
"exists in doc %s. Using unique_identifier_hash as content_hash.",
|
||||
event_id,
|
||||
dup.id,
|
||||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
search_space_id,
|
||||
disable_streaming=True,
|
||||
)
|
||||
|
||||
doc_metadata_for_summary = {
|
||||
"event_summary": event_summary,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"document_type": "Google Calendar Event",
|
||||
"connector_type": "Google Calendar",
|
||||
}
|
||||
|
||||
if user_llm:
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
indexable_content, user_llm, doc_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
logger.warning("No LLM configured -- using fallback summary")
|
||||
summary_content = (
|
||||
f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
|
||||
)
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(indexable_content)
|
||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
document = Document(
|
||||
title=event_summary,
|
||||
document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
document_metadata={
|
||||
"event_id": event_id,
|
||||
"event_summary": event_summary,
|
||||
"calendar_id": calendar_id,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"location": location,
|
||||
"html_link": html_link,
|
||||
"source_connector": "google_calendar",
|
||||
"indexed_at": now_str,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
content=summary_content,
|
||||
content_hash=content_hash,
|
||||
unique_identifier_hash=unique_hash,
|
||||
embedding=summary_embedding,
|
||||
search_space_id=search_space_id,
|
||||
connector_id=connector_id,
|
||||
source_markdown=indexable_content,
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
)
|
||||
|
||||
self.db_session.add(document)
|
||||
await self.db_session.flush()
|
||||
await safe_set_chunks(self.db_session, document, chunks)
|
||||
await self.db_session.commit()
|
||||
|
||||
logger.info(
|
||||
"KB sync after create succeeded: doc_id=%s, event=%s, chunks=%d",
|
||||
document.id,
|
||||
event_summary,
|
||||
len(chunks),
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in error_str
|
||||
or "uniqueviolationerror" in error_str
|
||||
):
|
||||
logger.warning(
|
||||
"Duplicate constraint hit during KB sync for event %s. "
|
||||
"Rolling back -- periodic indexer will handle it. Error: %s",
|
||||
event_id,
|
||||
e,
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": "Duplicate document detected"}
|
||||
|
||||
logger.error(
|
||||
"KB sync after create failed for event %s: %s",
|
||||
event_id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def sync_after_update(
|
||||
self,
|
||||
document_id: int,
|
||||
event_id: str,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> dict:
|
||||
from app.tasks.connector_indexers.base import (
|
||||
get_current_timestamp,
|
||||
safe_set_chunks,
|
||||
)
|
||||
|
||||
try:
|
||||
document = await self.db_session.get(Document, document_id)
|
||||
if not document:
|
||||
logger.warning("Document %s not found in KB", document_id)
|
||||
return {"status": "not_indexed"}
|
||||
|
||||
creds = await self._build_credentials_for_connector(connector_id)
|
||||
loop = asyncio.get_event_loop()
|
||||
service = await loop.run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
calendar_id = (document.document_metadata or {}).get(
|
||||
"calendar_id", "primary"
|
||||
)
|
||||
live_event = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.get(calendarId=calendar_id, eventId=event_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
|
||||
event_summary = live_event.get("summary", "")
|
||||
description = live_event.get("description", "")
|
||||
location = live_event.get("location", "")
|
||||
|
||||
start_data = live_event.get("start", {})
|
||||
start_time = start_data.get("dateTime", start_data.get("date", ""))
|
||||
|
||||
end_data = live_event.get("end", {})
|
||||
end_time = end_data.get("dateTime", end_data.get("date", ""))
|
||||
|
||||
attendees = [
|
||||
{
|
||||
"email": a.get("email", ""),
|
||||
"responseStatus": a.get("responseStatus", ""),
|
||||
}
|
||||
for a in live_event.get("attendees", [])
|
||||
]
|
||||
|
||||
indexable_content = (
|
||||
f"Google Calendar Event: {event_summary}\n\n"
|
||||
f"Start: {start_time}\n"
|
||||
f"End: {end_time}\n"
|
||||
f"Location: {location or 'N/A'}\n\n"
|
||||
f"{description or ''}"
|
||||
).strip()
|
||||
|
||||
if not indexable_content:
|
||||
return {"status": "error", "message": "Event produced empty content"}
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||
)
|
||||
|
||||
doc_metadata_for_summary = {
|
||||
"event_summary": event_summary,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"document_type": "Google Calendar Event",
|
||||
"connector_type": "Google Calendar",
|
||||
}
|
||||
|
||||
if user_llm:
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
indexable_content, user_llm, doc_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
summary_content = (
|
||||
f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
|
||||
)
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(indexable_content)
|
||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
document.title = event_summary
|
||||
document.content = summary_content
|
||||
document.content_hash = generate_content_hash(
|
||||
indexable_content, search_space_id
|
||||
)
|
||||
document.embedding = summary_embedding
|
||||
|
||||
document.document_metadata = {
|
||||
**(document.document_metadata or {}),
|
||||
"event_id": event_id,
|
||||
"event_summary": event_summary,
|
||||
"calendar_id": calendar_id,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"location": location,
|
||||
"description": description,
|
||||
"attendees": attendees,
|
||||
"html_link": live_event.get("htmlLink", ""),
|
||||
"indexed_at": now_str,
|
||||
"connector_id": connector_id,
|
||||
}
|
||||
flag_modified(document, "document_metadata")
|
||||
|
||||
await safe_set_chunks(self.db_session, document, chunks)
|
||||
document.updated_at = get_current_timestamp()
|
||||
|
||||
await self.db_session.commit()
|
||||
|
||||
logger.info(
|
||||
"KB sync after update succeeded for document %s (event: %s)",
|
||||
document_id,
|
||||
event_summary,
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"KB sync after update failed for document %s: %s",
|
||||
document_id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def _build_credentials_for_connector(self, connector_id: int) -> Credentials:
|
||||
result = await self.db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id
|
||||
)
|
||||
)
|
||||
connector = result.scalar_one_or_none()
|
||||
if not connector:
|
||||
raise ValueError(f"Connector {connector_id} not found")
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
return build_composio_credentials(cca_id)
|
||||
raise ValueError("Composio connected_account_id not found")
|
||||
|
||||
config_data = dict(connector.config)
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and app_config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(app_config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
return Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
|
@ -0,0 +1,431 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from sqlalchemy import String, and_, cast, func, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.db import (
|
||||
Document,
|
||||
DocumentType,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CALENDAR_CONNECTOR_TYPES = [
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
]
|
||||
|
||||
CALENDAR_DOCUMENT_TYPES = [
|
||||
DocumentType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class GoogleCalendarAccount:
|
||||
id: int
|
||||
name: str
|
||||
|
||||
@classmethod
|
||||
def from_connector(
|
||||
cls, connector: SearchSourceConnector
|
||||
) -> "GoogleCalendarAccount":
|
||||
return cls(id=connector.id, name=connector.name)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {"id": self.id, "name": self.name}
|
||||
|
||||
|
||||
@dataclass
|
||||
class GoogleCalendarEvent:
|
||||
event_id: str
|
||||
summary: str
|
||||
start: str
|
||||
end: str
|
||||
description: str
|
||||
location: str
|
||||
attendees: list
|
||||
calendar_id: str
|
||||
document_id: int
|
||||
indexed_at: str | None
|
||||
|
||||
@classmethod
|
||||
def from_document(cls, document: Document) -> "GoogleCalendarEvent":
|
||||
meta = document.document_metadata or {}
|
||||
return cls(
|
||||
event_id=meta.get("event_id", ""),
|
||||
summary=meta.get("event_summary", document.title),
|
||||
start=meta.get("start_time", ""),
|
||||
end=meta.get("end_time", ""),
|
||||
description=meta.get("description", ""),
|
||||
location=meta.get("location", ""),
|
||||
attendees=meta.get("attendees", []),
|
||||
calendar_id=meta.get("calendar_id", "primary"),
|
||||
document_id=document.id,
|
||||
indexed_at=meta.get("indexed_at"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"event_id": self.event_id,
|
||||
"summary": self.summary,
|
||||
"start": self.start,
|
||||
"end": self.end,
|
||||
"description": self.description,
|
||||
"location": self.location,
|
||||
"attendees": self.attendees,
|
||||
"calendar_id": self.calendar_id,
|
||||
"document_id": self.document_id,
|
||||
"indexed_at": self.indexed_at,
|
||||
}
|
||||
|
||||
|
||||
class GoogleCalendarToolMetadataService:
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self._db_session = db_session
|
||||
|
||||
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
return build_composio_credentials(cca_id)
|
||||
raise ValueError("Composio connected_account_id not found")
|
||||
|
||||
config_data = dict(connector.config)
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and app_config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(app_config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
return Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
async def _check_account_health(self, connector_id: int) -> bool:
|
||||
"""Check if a Google Calendar connector's credentials are still valid.
|
||||
|
||||
Uses a lightweight calendarList.list(maxResults=1) call to verify access.
|
||||
|
||||
Returns True if the credentials are expired/invalid, False if healthy.
|
||||
"""
|
||||
try:
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id
|
||||
)
|
||||
)
|
||||
connector = result.scalar_one_or_none()
|
||||
if not connector:
|
||||
return True
|
||||
|
||||
creds = await self._build_credentials(connector)
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
build("calendar", "v3", credentials=creds)
|
||||
.calendarList()
|
||||
.list(maxResults=1)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Google Calendar connector %s health check failed: %s",
|
||||
connector_id,
|
||||
e,
|
||||
)
|
||||
return True
|
||||
|
||||
async def _persist_auth_expired(self, connector_id: int) -> None:
|
||||
"""Persist ``auth_expired: True`` to the connector config if not already set."""
|
||||
try:
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id
|
||||
)
|
||||
)
|
||||
db_connector = result.scalar_one_or_none()
|
||||
if db_connector and not db_connector.config.get("auth_expired"):
|
||||
db_connector.config = {**db_connector.config, "auth_expired": True}
|
||||
flag_modified(db_connector, "config")
|
||||
await self._db_session.commit()
|
||||
await self._db_session.refresh(db_connector)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _get_accounts(
|
||||
self, search_space_id: int, user_id: str
|
||||
) -> list[GoogleCalendarAccount]:
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector)
|
||||
.filter(
|
||||
and_(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(CALENDAR_CONNECTOR_TYPES),
|
||||
)
|
||||
)
|
||||
.order_by(SearchSourceConnector.last_indexed_at.desc())
|
||||
)
|
||||
connectors = result.scalars().all()
|
||||
return [GoogleCalendarAccount.from_connector(c) for c in connectors]
|
||||
|
||||
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
|
||||
accounts = await self._get_accounts(search_space_id, user_id)
|
||||
|
||||
if not accounts:
|
||||
return {
|
||||
"accounts": [],
|
||||
"error": "No Google Calendar account connected",
|
||||
}
|
||||
|
||||
accounts_with_status = []
|
||||
for acc in accounts:
|
||||
acc_dict = acc.to_dict()
|
||||
auth_expired = await self._check_account_health(acc.id)
|
||||
acc_dict["auth_expired"] = auth_expired
|
||||
if auth_expired:
|
||||
await self._persist_auth_expired(acc.id)
|
||||
accounts_with_status.append(acc_dict)
|
||||
|
||||
healthy_account = next(
|
||||
(a for a in accounts_with_status if not a.get("auth_expired")), None
|
||||
)
|
||||
if not healthy_account:
|
||||
return {
|
||||
"accounts": accounts_with_status,
|
||||
"calendars": [],
|
||||
"timezone": "",
|
||||
"error": "All connected Google Calendar accounts have expired credentials",
|
||||
}
|
||||
|
||||
connector_id = healthy_account["id"]
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id
|
||||
)
|
||||
)
|
||||
connector = result.scalar_one_or_none()
|
||||
|
||||
calendars = []
|
||||
timezone_str = ""
|
||||
if connector:
|
||||
try:
|
||||
creds = await self._build_credentials(connector)
|
||||
loop = asyncio.get_event_loop()
|
||||
service = await loop.run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
cal_list = await loop.run_in_executor(
|
||||
None, lambda: service.calendarList().list().execute()
|
||||
)
|
||||
for cal in cal_list.get("items", []):
|
||||
calendars.append(
|
||||
{
|
||||
"id": cal.get("id", ""),
|
||||
"summary": cal.get("summary", ""),
|
||||
"primary": cal.get("primary", False),
|
||||
}
|
||||
)
|
||||
|
||||
tz_setting = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: service.settings().get(setting="timezone").execute(),
|
||||
)
|
||||
timezone_str = tz_setting.get("value", "")
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch calendars/timezone for connector %s",
|
||||
connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return {
|
||||
"accounts": accounts_with_status,
|
||||
"calendars": calendars,
|
||||
"timezone": timezone_str,
|
||||
}
|
||||
|
||||
async def get_update_context(
|
||||
self, search_space_id: int, user_id: str, event_ref: str
|
||||
) -> dict:
|
||||
resolved = await self._resolve_event(search_space_id, user_id, event_ref)
|
||||
if not resolved:
|
||||
return {
|
||||
"error": (
|
||||
f"Event '{event_ref}' not found in your indexed Google Calendar events. "
|
||||
"This could mean: (1) the event doesn't exist, (2) it hasn't been indexed yet, "
|
||||
"or (3) the event name is different."
|
||||
)
|
||||
}
|
||||
|
||||
document, connector = resolved
|
||||
account = GoogleCalendarAccount.from_connector(connector)
|
||||
event = GoogleCalendarEvent.from_document(document)
|
||||
|
||||
acc_dict = account.to_dict()
|
||||
auth_expired = await self._check_account_health(connector.id)
|
||||
acc_dict["auth_expired"] = auth_expired
|
||||
if auth_expired:
|
||||
await self._persist_auth_expired(connector.id)
|
||||
return {
|
||||
"error": "Google Calendar credentials have expired. Please re-authenticate.",
|
||||
"auth_expired": True,
|
||||
"connector_id": connector.id,
|
||||
}
|
||||
|
||||
event_dict = event.to_dict()
|
||||
try:
|
||||
creds = await self._build_credentials(connector)
|
||||
loop = asyncio.get_event_loop()
|
||||
service = await loop.run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
calendar_id = event.calendar_id or "primary"
|
||||
live_event = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.get(calendarId=calendar_id, eventId=event.event_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
|
||||
event_dict["summary"] = live_event.get("summary", event_dict["summary"])
|
||||
event_dict["description"] = live_event.get(
|
||||
"description", event_dict["description"]
|
||||
)
|
||||
event_dict["location"] = live_event.get("location", event_dict["location"])
|
||||
|
||||
start_data = live_event.get("start", {})
|
||||
event_dict["start"] = start_data.get(
|
||||
"dateTime", start_data.get("date", event_dict["start"])
|
||||
)
|
||||
|
||||
end_data = live_event.get("end", {})
|
||||
event_dict["end"] = end_data.get(
|
||||
"dateTime", end_data.get("date", event_dict["end"])
|
||||
)
|
||||
|
||||
event_dict["attendees"] = [
|
||||
{
|
||||
"email": a.get("email", ""),
|
||||
"responseStatus": a.get("responseStatus", ""),
|
||||
}
|
||||
for a in live_event.get("attendees", [])
|
||||
]
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch live event data for event %s, using KB metadata",
|
||||
event.event_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return {
|
||||
"account": acc_dict,
|
||||
"event": event_dict,
|
||||
}
|
||||
|
||||
async def get_deletion_context(
|
||||
self, search_space_id: int, user_id: str, event_ref: str
|
||||
) -> dict:
|
||||
resolved = await self._resolve_event(search_space_id, user_id, event_ref)
|
||||
if not resolved:
|
||||
return {
|
||||
"error": (
|
||||
f"Event '{event_ref}' not found in your indexed Google Calendar events. "
|
||||
"This could mean: (1) the event doesn't exist, (2) it hasn't been indexed yet, "
|
||||
"or (3) the event name is different."
|
||||
)
|
||||
}
|
||||
|
||||
document, connector = resolved
|
||||
account = GoogleCalendarAccount.from_connector(connector)
|
||||
event = GoogleCalendarEvent.from_document(document)
|
||||
|
||||
acc_dict = account.to_dict()
|
||||
auth_expired = await self._check_account_health(connector.id)
|
||||
acc_dict["auth_expired"] = auth_expired
|
||||
if auth_expired:
|
||||
await self._persist_auth_expired(connector.id)
|
||||
|
||||
return {
|
||||
"account": acc_dict,
|
||||
"event": event.to_dict(),
|
||||
}
|
||||
|
||||
async def _resolve_event(
|
||||
self, search_space_id: int, user_id: str, event_ref: str
|
||||
) -> tuple[Document, SearchSourceConnector] | None:
|
||||
result = await self._db_session.execute(
|
||||
select(Document, SearchSourceConnector)
|
||||
.join(
|
||||
SearchSourceConnector,
|
||||
Document.connector_id == SearchSourceConnector.id,
|
||||
)
|
||||
.filter(
|
||||
and_(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type.in_(CALENDAR_DOCUMENT_TYPES),
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
or_(
|
||||
func.lower(
|
||||
cast(Document.document_metadata["event_summary"], String)
|
||||
)
|
||||
== func.lower(event_ref),
|
||||
func.lower(Document.title) == func.lower(event_ref),
|
||||
),
|
||||
)
|
||||
)
|
||||
.order_by(Document.updated_at.desc().nullslast())
|
||||
.limit(1)
|
||||
)
|
||||
row = result.first()
|
||||
if row:
|
||||
return row[0], row[1]
|
||||
return None
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
from app.services.google_drive.kb_sync_service import GoogleDriveKBSyncService
|
||||
from app.services.google_drive.tool_metadata_service import (
|
||||
GoogleDriveAccount,
|
||||
GoogleDriveFile,
|
||||
|
|
@ -7,5 +8,6 @@ from app.services.google_drive.tool_metadata_service import (
|
|||
__all__ = [
|
||||
"GoogleDriveAccount",
|
||||
"GoogleDriveFile",
|
||||
"GoogleDriveKBSyncService",
|
||||
"GoogleDriveToolMetadataService",
|
||||
]
|
||||
|
|
|
|||
164
surfsense_backend/app/services/google_drive/kb_sync_service.py
Normal file
164
surfsense_backend/app/services/google_drive/kb_sync_service.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GoogleDriveKBSyncService:
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db_session = db_session
|
||||
|
||||
async def sync_after_create(
|
||||
self,
|
||||
file_id: str,
|
||||
file_name: str,
|
||||
mime_type: str,
|
||||
web_view_link: str | None,
|
||||
content: str | None,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> dict:
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document_by_hash,
|
||||
get_current_timestamp,
|
||||
safe_set_chunks,
|
||||
)
|
||||
|
||||
try:
|
||||
unique_hash = generate_unique_identifier_hash(
|
||||
DocumentType.GOOGLE_DRIVE_FILE, file_id, search_space_id
|
||||
)
|
||||
|
||||
existing = await check_document_by_unique_identifier(
|
||||
self.db_session, unique_hash
|
||||
)
|
||||
if existing:
|
||||
logger.info(
|
||||
"Document for Drive file %s already exists (doc_id=%s), skipping",
|
||||
file_id,
|
||||
existing.id,
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
indexable_content = (content or "").strip()
|
||||
if not indexable_content:
|
||||
indexable_content = (
|
||||
f"Google Drive file: {file_name} (type: {mime_type})"
|
||||
)
|
||||
|
||||
content_hash = generate_content_hash(indexable_content, search_space_id)
|
||||
|
||||
with self.db_session.no_autoflush:
|
||||
dup = await check_duplicate_document_by_hash(
|
||||
self.db_session, content_hash
|
||||
)
|
||||
if dup:
|
||||
logger.info(
|
||||
"Content-hash collision for Drive file %s — identical content "
|
||||
"exists in doc %s. Using unique_identifier_hash as content_hash.",
|
||||
file_id,
|
||||
dup.id,
|
||||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
search_space_id,
|
||||
disable_streaming=True,
|
||||
)
|
||||
|
||||
doc_metadata_for_summary = {
|
||||
"file_name": file_name,
|
||||
"mime_type": mime_type,
|
||||
"document_type": "Google Drive File",
|
||||
"connector_type": "Google Drive",
|
||||
}
|
||||
|
||||
if user_llm:
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
indexable_content, user_llm, doc_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
logger.warning("No LLM configured — using fallback summary")
|
||||
summary_content = (
|
||||
f"Google Drive File: {file_name}\n\n{indexable_content}"
|
||||
)
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(indexable_content)
|
||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
document = Document(
|
||||
title=file_name,
|
||||
document_type=DocumentType.GOOGLE_DRIVE_FILE,
|
||||
document_metadata={
|
||||
"google_drive_file_id": file_id,
|
||||
"google_drive_file_name": file_name,
|
||||
"google_drive_mime_type": mime_type,
|
||||
"web_view_link": web_view_link,
|
||||
"source_connector": "google_drive",
|
||||
"indexed_at": now_str,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
content=summary_content,
|
||||
content_hash=content_hash,
|
||||
unique_identifier_hash=unique_hash,
|
||||
embedding=summary_embedding,
|
||||
search_space_id=search_space_id,
|
||||
connector_id=connector_id,
|
||||
source_markdown=content,
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
)
|
||||
|
||||
self.db_session.add(document)
|
||||
await self.db_session.flush()
|
||||
await safe_set_chunks(self.db_session, document, chunks)
|
||||
await self.db_session.commit()
|
||||
|
||||
logger.info(
|
||||
"KB sync after create succeeded: doc_id=%s, file=%s, chunks=%d",
|
||||
document.id,
|
||||
file_name,
|
||||
len(chunks),
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in error_str
|
||||
or "uniqueviolationerror" in error_str
|
||||
):
|
||||
logger.warning(
|
||||
"Duplicate constraint hit during KB sync for file %s. "
|
||||
"Rolling back — periodic indexer will handle it. Error: %s",
|
||||
file_id,
|
||||
e,
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": "Duplicate document detected"}
|
||||
|
||||
logger.error(
|
||||
"KB sync after create failed for file %s: %s",
|
||||
file_id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
|
@ -1,15 +1,21 @@
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.connectors.google_drive.client import GoogleDriveClient
|
||||
from app.db import (
|
||||
Document,
|
||||
DocumentType,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -68,12 +74,25 @@ class GoogleDriveToolMetadataService:
|
|||
return {
|
||||
"accounts": [],
|
||||
"supported_types": [],
|
||||
"parent_folders": {},
|
||||
"error": "No Google Drive account connected",
|
||||
}
|
||||
|
||||
accounts_with_status = []
|
||||
for acc in accounts:
|
||||
acc_dict = acc.to_dict()
|
||||
auth_expired = await self._check_account_health(acc.id)
|
||||
acc_dict["auth_expired"] = auth_expired
|
||||
if auth_expired:
|
||||
await self._persist_auth_expired(acc.id)
|
||||
accounts_with_status.append(acc_dict)
|
||||
|
||||
parent_folders = await self._get_parent_folders_by_account(accounts_with_status)
|
||||
|
||||
return {
|
||||
"accounts": [acc.to_dict() for acc in accounts],
|
||||
"accounts": accounts_with_status,
|
||||
"supported_types": ["google_doc", "google_sheet"],
|
||||
"parent_folders": parent_folders,
|
||||
}
|
||||
|
||||
async def get_trash_context(
|
||||
|
|
@ -92,6 +111,8 @@ class GoogleDriveToolMetadataService:
|
|||
SearchSourceConnector.user_id == user_id,
|
||||
)
|
||||
)
|
||||
.order_by(Document.updated_at.desc().nullslast())
|
||||
.limit(1)
|
||||
)
|
||||
document = result.scalars().first()
|
||||
|
||||
|
|
@ -112,8 +133,12 @@ class GoogleDriveToolMetadataService:
|
|||
and_(
|
||||
SearchSourceConnector.id == document.connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnector.connector_type.in_(
|
||||
[
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
@ -125,8 +150,14 @@ class GoogleDriveToolMetadataService:
|
|||
account = GoogleDriveAccount.from_connector(connector)
|
||||
file = GoogleDriveFile.from_document(document)
|
||||
|
||||
acc_dict = account.to_dict()
|
||||
auth_expired = await self._check_account_health(connector.id)
|
||||
acc_dict["auth_expired"] = auth_expired
|
||||
if auth_expired:
|
||||
await self._persist_auth_expired(connector.id)
|
||||
|
||||
return {
|
||||
"account": account.to_dict(),
|
||||
"account": acc_dict,
|
||||
"file": file.to_dict(),
|
||||
}
|
||||
|
||||
|
|
@ -139,11 +170,150 @@ class GoogleDriveToolMetadataService:
|
|||
and_(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnector.connector_type.in_(
|
||||
[
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
.order_by(SearchSourceConnector.last_indexed_at.desc())
|
||||
)
|
||||
connectors = result.scalars().all()
|
||||
return [GoogleDriveAccount.from_connector(c) for c in connectors]
|
||||
|
||||
async def _check_account_health(self, connector_id: int) -> bool:
|
||||
"""Check if a Google Drive connector's credentials are still valid.
|
||||
|
||||
Uses a lightweight ``files.list(pageSize=1)`` call to verify access.
|
||||
|
||||
Returns True if the credentials are expired/invalid, False if healthy.
|
||||
"""
|
||||
try:
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id
|
||||
)
|
||||
)
|
||||
connector = result.scalar_one_or_none()
|
||||
if not connector:
|
||||
return True
|
||||
|
||||
pre_built_creds = None
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
pre_built_creds = build_composio_credentials(cca_id)
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=self._db_session,
|
||||
connector_id=connector_id,
|
||||
credentials=pre_built_creds,
|
||||
)
|
||||
await client.list_files(
|
||||
query="trashed = false", page_size=1, fields="files(id)"
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Google Drive connector %s health check failed: %s",
|
||||
connector_id,
|
||||
e,
|
||||
)
|
||||
return True
|
||||
|
||||
async def _persist_auth_expired(self, connector_id: int) -> None:
|
||||
"""Persist ``auth_expired: True`` to the connector config if not already set."""
|
||||
try:
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id
|
||||
)
|
||||
)
|
||||
db_connector = result.scalar_one_or_none()
|
||||
if db_connector and not db_connector.config.get("auth_expired"):
|
||||
db_connector.config = {**db_connector.config, "auth_expired": True}
|
||||
flag_modified(db_connector, "config")
|
||||
await self._db_session.commit()
|
||||
await self._db_session.refresh(db_connector)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _get_parent_folders_by_account(
|
||||
self, accounts_with_status: list[dict]
|
||||
) -> dict[int, list[dict]]:
|
||||
"""Fetch root-level folders for each healthy account.
|
||||
|
||||
Skips accounts where ``auth_expired`` is True so we don't waste an API
|
||||
call that will fail anyway.
|
||||
"""
|
||||
parent_folders: dict[int, list[dict]] = {}
|
||||
|
||||
for acc in accounts_with_status:
|
||||
connector_id = acc["id"]
|
||||
if acc.get("auth_expired"):
|
||||
parent_folders[connector_id] = []
|
||||
continue
|
||||
|
||||
try:
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id
|
||||
)
|
||||
)
|
||||
connector = result.scalar_one_or_none()
|
||||
if not connector:
|
||||
parent_folders[connector_id] = []
|
||||
continue
|
||||
|
||||
pre_built_creds = None
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
pre_built_creds = build_composio_credentials(cca_id)
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=self._db_session,
|
||||
connector_id=connector_id,
|
||||
credentials=pre_built_creds,
|
||||
)
|
||||
|
||||
folders, _, error = await client.list_files(
|
||||
query="mimeType = 'application/vnd.google-apps.folder' and trashed = false and 'root' in parents",
|
||||
fields="files(id, name)",
|
||||
page_size=50,
|
||||
)
|
||||
|
||||
if error:
|
||||
logger.warning(
|
||||
"Failed to list folders for connector %s: %s",
|
||||
connector_id,
|
||||
error,
|
||||
)
|
||||
parent_folders[connector_id] = []
|
||||
else:
|
||||
parent_folders[connector_id] = [
|
||||
{"folder_id": f["id"], "name": f["name"]}
|
||||
for f in folders
|
||||
if f.get("id") and f.get("name")
|
||||
]
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Error fetching folders for connector %s",
|
||||
connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
parent_folders[connector_id] = []
|
||||
|
||||
return parent_folders
|
||||
|
|
|
|||
13
surfsense_backend/app/services/jira/__init__.py
Normal file
13
surfsense_backend/app/services/jira/__init__.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from app.services.jira.kb_sync_service import JiraKBSyncService
|
||||
from app.services.jira.tool_metadata_service import (
|
||||
JiraIssue,
|
||||
JiraToolMetadataService,
|
||||
JiraWorkspace,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"JiraIssue",
|
||||
"JiraKBSyncService",
|
||||
"JiraToolMetadataService",
|
||||
"JiraWorkspace",
|
||||
]
|
||||
254
surfsense_backend/app/services/jira/kb_sync_service.py
Normal file
254
surfsense_backend/app/services/jira/kb_sync_service.py
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JiraKBSyncService:
|
||||
"""Syncs Jira issue documents to the knowledge base after HITL actions."""
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db_session = db_session
|
||||
|
||||
async def sync_after_create(
|
||||
self,
|
||||
issue_id: str,
|
||||
issue_identifier: str,
|
||||
issue_title: str,
|
||||
description: str | None,
|
||||
state: str | None,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> dict:
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document_by_hash,
|
||||
get_current_timestamp,
|
||||
safe_set_chunks,
|
||||
)
|
||||
|
||||
try:
|
||||
unique_hash = generate_unique_identifier_hash(
|
||||
DocumentType.JIRA_CONNECTOR, issue_id, search_space_id
|
||||
)
|
||||
|
||||
existing = await check_document_by_unique_identifier(
|
||||
self.db_session, unique_hash
|
||||
)
|
||||
if existing:
|
||||
logger.info(
|
||||
"Document for Jira issue %s already exists (doc_id=%s), skipping",
|
||||
issue_identifier,
|
||||
existing.id,
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
indexable_content = (description or "").strip()
|
||||
if not indexable_content:
|
||||
indexable_content = f"Jira Issue {issue_identifier}: {issue_title}"
|
||||
|
||||
issue_content = (
|
||||
f"# {issue_identifier}: {issue_title}\n\n{indexable_content}"
|
||||
)
|
||||
|
||||
content_hash = generate_content_hash(issue_content, search_space_id)
|
||||
|
||||
with self.db_session.no_autoflush:
|
||||
dup = await check_duplicate_document_by_hash(
|
||||
self.db_session, content_hash
|
||||
)
|
||||
if dup:
|
||||
content_hash = unique_hash
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
search_space_id,
|
||||
disable_streaming=True,
|
||||
)
|
||||
|
||||
doc_metadata_for_summary = {
|
||||
"issue_id": issue_identifier,
|
||||
"issue_title": issue_title,
|
||||
"document_type": "Jira Issue",
|
||||
"connector_type": "Jira",
|
||||
}
|
||||
|
||||
if user_llm:
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
issue_content, user_llm, doc_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
summary_content = (
|
||||
f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
|
||||
)
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(issue_content)
|
||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
document = Document(
|
||||
title=f"{issue_identifier}: {issue_title}",
|
||||
document_type=DocumentType.JIRA_CONNECTOR,
|
||||
document_metadata={
|
||||
"issue_id": issue_id,
|
||||
"issue_identifier": issue_identifier,
|
||||
"issue_title": issue_title,
|
||||
"state": state or "Unknown",
|
||||
"indexed_at": now_str,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
content=summary_content,
|
||||
content_hash=content_hash,
|
||||
unique_identifier_hash=unique_hash,
|
||||
embedding=summary_embedding,
|
||||
search_space_id=search_space_id,
|
||||
connector_id=connector_id,
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
)
|
||||
|
||||
self.db_session.add(document)
|
||||
await self.db_session.flush()
|
||||
await safe_set_chunks(self.db_session, document, chunks)
|
||||
await self.db_session.commit()
|
||||
|
||||
logger.info(
|
||||
"KB sync after create succeeded: doc_id=%s, issue=%s",
|
||||
document.id,
|
||||
issue_identifier,
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in error_str
|
||||
or "uniqueviolationerror" in error_str
|
||||
):
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": "Duplicate document detected"}
|
||||
|
||||
logger.error(
|
||||
"KB sync after create failed for issue %s: %s",
|
||||
issue_identifier,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def sync_after_update(
|
||||
self,
|
||||
document_id: int,
|
||||
issue_id: str,
|
||||
user_id: str,
|
||||
search_space_id: int,
|
||||
) -> dict:
|
||||
from app.tasks.connector_indexers.base import (
|
||||
get_current_timestamp,
|
||||
safe_set_chunks,
|
||||
)
|
||||
|
||||
try:
|
||||
document = await self.db_session.get(Document, document_id)
|
||||
if not document:
|
||||
return {"status": "not_indexed"}
|
||||
|
||||
connector_id = document.connector_id
|
||||
if not connector_id:
|
||||
return {"status": "error", "message": "Document has no connector_id"}
|
||||
|
||||
jira_history = JiraHistoryConnector(
|
||||
session=self.db_session, connector_id=connector_id
|
||||
)
|
||||
jira_client = await jira_history._get_jira_client()
|
||||
issue_raw = await asyncio.to_thread(jira_client.get_issue, issue_id)
|
||||
formatted = jira_client.format_issue(issue_raw)
|
||||
issue_content = jira_client.format_issue_to_markdown(formatted)
|
||||
|
||||
if not issue_content:
|
||||
return {"status": "error", "message": "Issue produced empty content"}
|
||||
|
||||
issue_identifier = formatted.get("key", "")
|
||||
issue_title = formatted.get("title", "")
|
||||
state = formatted.get("status", "Unknown")
|
||||
comment_count = len(formatted.get("comments", []))
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
doc_meta = {
|
||||
"issue_key": issue_identifier,
|
||||
"issue_title": issue_title,
|
||||
"status": state,
|
||||
"document_type": "Jira Issue",
|
||||
"connector_type": "Jira",
|
||||
}
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
issue_content, user_llm, doc_meta
|
||||
)
|
||||
else:
|
||||
summary_content = (
|
||||
f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
|
||||
)
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(issue_content)
|
||||
|
||||
document.title = f"{issue_identifier}: {issue_title}"
|
||||
document.content = summary_content
|
||||
document.content_hash = generate_content_hash(
|
||||
issue_content, search_space_id
|
||||
)
|
||||
document.embedding = summary_embedding
|
||||
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
document.document_metadata = {
|
||||
**(document.document_metadata or {}),
|
||||
"issue_id": issue_id,
|
||||
"issue_identifier": issue_identifier,
|
||||
"issue_title": issue_title,
|
||||
"state": state,
|
||||
"comment_count": comment_count,
|
||||
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"connector_id": connector_id,
|
||||
}
|
||||
flag_modified(document, "document_metadata")
|
||||
await safe_set_chunks(self.db_session, document, chunks)
|
||||
document.updated_at = get_current_timestamp()
|
||||
|
||||
await self.db_session.commit()
|
||||
|
||||
logger.info(
|
||||
"KB sync successful for document %s (%s: %s)",
|
||||
document_id,
|
||||
issue_identifier,
|
||||
issue_title,
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"KB sync failed for document %s: %s", document_id, e, exc_info=True
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": str(e)}
|
||||
332
surfsense_backend/app/services/jira/tool_metadata_service.py
Normal file
332
surfsense_backend/app/services/jira/tool_metadata_service.py
Normal file
|
|
@ -0,0 +1,332 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import and_, func, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.db import (
|
||||
Document,
|
||||
DocumentType,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class JiraWorkspace:
|
||||
"""Represents a Jira connector as a workspace for tool context."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
base_url: str
|
||||
|
||||
@classmethod
|
||||
def from_connector(cls, connector: SearchSourceConnector) -> "JiraWorkspace":
|
||||
return cls(
|
||||
id=connector.id,
|
||||
name=connector.name,
|
||||
base_url=connector.config.get("base_url", ""),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"base_url": self.base_url,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class JiraIssue:
|
||||
"""Represents an indexed Jira issue resolved from the knowledge base."""
|
||||
|
||||
issue_id: str
|
||||
issue_identifier: str
|
||||
issue_title: str
|
||||
state: str
|
||||
connector_id: int
|
||||
document_id: int
|
||||
indexed_at: str | None
|
||||
|
||||
@classmethod
|
||||
def from_document(cls, document: Document) -> "JiraIssue":
|
||||
meta = document.document_metadata or {}
|
||||
return cls(
|
||||
issue_id=meta.get("issue_id", ""),
|
||||
issue_identifier=meta.get("issue_identifier", ""),
|
||||
issue_title=meta.get("issue_title", document.title),
|
||||
state=meta.get("state", ""),
|
||||
connector_id=document.connector_id,
|
||||
document_id=document.id,
|
||||
indexed_at=meta.get("indexed_at"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"issue_id": self.issue_id,
|
||||
"issue_identifier": self.issue_identifier,
|
||||
"issue_title": self.issue_title,
|
||||
"state": self.state,
|
||||
"connector_id": self.connector_id,
|
||||
"document_id": self.document_id,
|
||||
"indexed_at": self.indexed_at,
|
||||
}
|
||||
|
||||
|
||||
class JiraToolMetadataService:
|
||||
"""Builds interrupt context for Jira HITL tools."""
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self._db_session = db_session
|
||||
|
||||
async def _check_account_health(self, connector: SearchSourceConnector) -> bool:
|
||||
"""Check if the Jira connector auth is still valid.
|
||||
|
||||
Returns True if auth is expired/invalid, False if healthy.
|
||||
"""
|
||||
try:
|
||||
jira_history = JiraHistoryConnector(
|
||||
session=self._db_session, connector_id=connector.id
|
||||
)
|
||||
jira_client = await jira_history._get_jira_client()
|
||||
await asyncio.to_thread(jira_client.get_myself)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning("Jira connector %s health check failed: %s", connector.id, e)
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await self._db_session.commit()
|
||||
await self._db_session.refresh(connector)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return True
|
||||
|
||||
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
|
||||
"""Return context needed to create a new Jira issue.
|
||||
|
||||
Fetches all connected Jira accounts, and for the first healthy one
|
||||
fetches projects, issue types, and priorities.
|
||||
"""
|
||||
connectors = await self._get_all_jira_connectors(search_space_id, user_id)
|
||||
if not connectors:
|
||||
return {"error": "No Jira account connected"}
|
||||
|
||||
accounts = []
|
||||
projects = []
|
||||
issue_types = []
|
||||
priorities = []
|
||||
fetched_context = False
|
||||
|
||||
for connector in connectors:
|
||||
auth_expired = await self._check_account_health(connector)
|
||||
workspace = JiraWorkspace.from_connector(connector)
|
||||
account_info = {
|
||||
**workspace.to_dict(),
|
||||
"auth_expired": auth_expired,
|
||||
}
|
||||
accounts.append(account_info)
|
||||
|
||||
if not auth_expired and not fetched_context:
|
||||
try:
|
||||
jira_history = JiraHistoryConnector(
|
||||
session=self._db_session, connector_id=connector.id
|
||||
)
|
||||
jira_client = await jira_history._get_jira_client()
|
||||
raw_projects = await asyncio.to_thread(jira_client.get_projects)
|
||||
projects = [
|
||||
{"id": p.get("id"), "key": p.get("key"), "name": p.get("name")}
|
||||
for p in raw_projects
|
||||
]
|
||||
raw_types = await asyncio.to_thread(jira_client.get_issue_types)
|
||||
seen_type_names: set[str] = set()
|
||||
issue_types = []
|
||||
for t in raw_types:
|
||||
if t.get("subtask", False):
|
||||
continue
|
||||
name = t.get("name")
|
||||
if name not in seen_type_names:
|
||||
seen_type_names.add(name)
|
||||
issue_types.append({"id": t.get("id"), "name": name})
|
||||
raw_priorities = await asyncio.to_thread(jira_client.get_priorities)
|
||||
priorities = [
|
||||
{"id": p.get("id"), "name": p.get("name")}
|
||||
for p in raw_priorities
|
||||
]
|
||||
fetched_context = True
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to fetch Jira context for connector %s: %s",
|
||||
connector.id,
|
||||
e,
|
||||
)
|
||||
|
||||
return {
|
||||
"accounts": accounts,
|
||||
"projects": projects,
|
||||
"issue_types": issue_types,
|
||||
"priorities": priorities,
|
||||
}
|
||||
|
||||
async def get_update_context(
|
||||
self, search_space_id: int, user_id: str, issue_ref: str
|
||||
) -> dict:
|
||||
"""Return context needed to update an indexed Jira issue.
|
||||
|
||||
Resolves the issue from the KB, then fetches current details from the Jira API.
|
||||
"""
|
||||
document = await self._resolve_issue(search_space_id, user_id, issue_ref)
|
||||
if not document:
|
||||
return {
|
||||
"error": f"Issue '{issue_ref}' not found in your synced Jira issues. "
|
||||
"Please make sure the issue is indexed in your knowledge base."
|
||||
}
|
||||
|
||||
connector = await self._get_connector_for_document(document, user_id)
|
||||
if not connector:
|
||||
return {"error": "Connector not found or access denied"}
|
||||
|
||||
auth_expired = await self._check_account_health(connector)
|
||||
if auth_expired:
|
||||
return {
|
||||
"error": "Jira authentication has expired. Please re-authenticate.",
|
||||
"auth_expired": True,
|
||||
"connector_id": connector.id,
|
||||
}
|
||||
|
||||
workspace = JiraWorkspace.from_connector(connector)
|
||||
issue = JiraIssue.from_document(document)
|
||||
|
||||
try:
|
||||
jira_history = JiraHistoryConnector(
|
||||
session=self._db_session, connector_id=connector.id
|
||||
)
|
||||
jira_client = await jira_history._get_jira_client()
|
||||
issue_data = await asyncio.to_thread(jira_client.get_issue, issue.issue_id)
|
||||
formatted = jira_client.format_issue(issue_data)
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"401" in error_str
|
||||
or "403" in error_str
|
||||
or "authentication" in error_str
|
||||
):
|
||||
return {
|
||||
"error": f"Failed to fetch Jira issue: {e!s}",
|
||||
"auth_expired": True,
|
||||
"connector_id": connector.id,
|
||||
}
|
||||
return {"error": f"Failed to fetch Jira issue: {e!s}"}
|
||||
|
||||
return {
|
||||
"account": {**workspace.to_dict(), "auth_expired": False},
|
||||
"issue": {
|
||||
"issue_id": formatted.get("key", issue.issue_id),
|
||||
"issue_identifier": formatted.get("key", issue.issue_identifier),
|
||||
"issue_title": formatted.get("title", issue.issue_title),
|
||||
"state": formatted.get("status", "Unknown"),
|
||||
"priority": formatted.get("priority", "Unknown"),
|
||||
"issue_type": formatted.get("issue_type", "Unknown"),
|
||||
"assignee": formatted.get("assignee"),
|
||||
"description": formatted.get("description"),
|
||||
"project": formatted.get("project", ""),
|
||||
"document_id": issue.document_id,
|
||||
"indexed_at": issue.indexed_at,
|
||||
},
|
||||
}
|
||||
|
||||
async def get_deletion_context(
|
||||
self, search_space_id: int, user_id: str, issue_ref: str
|
||||
) -> dict:
|
||||
"""Return context needed to delete a Jira issue (KB metadata only, no API call)."""
|
||||
document = await self._resolve_issue(search_space_id, user_id, issue_ref)
|
||||
if not document:
|
||||
return {
|
||||
"error": f"Issue '{issue_ref}' not found in your synced Jira issues. "
|
||||
"Please make sure the issue is indexed in your knowledge base."
|
||||
}
|
||||
|
||||
connector = await self._get_connector_for_document(document, user_id)
|
||||
if not connector:
|
||||
return {"error": "Connector not found or access denied"}
|
||||
|
||||
auth_expired = connector.config.get("auth_expired", False)
|
||||
workspace = JiraWorkspace.from_connector(connector)
|
||||
issue = JiraIssue.from_document(document)
|
||||
|
||||
return {
|
||||
"account": {**workspace.to_dict(), "auth_expired": auth_expired},
|
||||
"issue": issue.to_dict(),
|
||||
}
|
||||
|
||||
async def _resolve_issue(
|
||||
self, search_space_id: int, user_id: str, issue_ref: str
|
||||
) -> Document | None:
|
||||
"""Resolve an issue from KB: issue_identifier -> issue_title -> document.title."""
|
||||
ref_lower = issue_ref.lower()
|
||||
|
||||
result = await self._db_session.execute(
|
||||
select(Document)
|
||||
.join(
|
||||
SearchSourceConnector, Document.connector_id == SearchSourceConnector.id
|
||||
)
|
||||
.filter(
|
||||
and_(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.JIRA_CONNECTOR,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
or_(
|
||||
func.lower(
|
||||
Document.document_metadata.op("->>")("issue_identifier")
|
||||
)
|
||||
== ref_lower,
|
||||
func.lower(Document.document_metadata.op("->>")("issue_title"))
|
||||
== ref_lower,
|
||||
func.lower(Document.title) == ref_lower,
|
||||
),
|
||||
)
|
||||
)
|
||||
.order_by(Document.updated_at.desc().nullslast())
|
||||
.limit(1)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def _get_all_jira_connectors(
|
||||
self, search_space_id: int, user_id: str
|
||||
) -> list[SearchSourceConnector]:
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
and_(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def _get_connector_for_document(
|
||||
self, document: Document, user_id: str
|
||||
) -> SearchSourceConnector | None:
|
||||
if not document.connector_id:
|
||||
return None
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
and_(
|
||||
SearchSourceConnector.id == document.connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
|
@ -4,29 +4,174 @@ from datetime import datetime
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.connectors.linear_connector import LinearConnector
|
||||
from app.db import Document
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LinearKBSyncService:
|
||||
"""Re-indexes a single Linear issue document after a successful update.
|
||||
"""Syncs Linear issue documents to the knowledge base after HITL actions.
|
||||
|
||||
Mirrors the indexer's Phase-2 logic exactly: fetch fresh issue content,
|
||||
run generate_document_summary, create_document_chunks, then update the
|
||||
document row in the knowledge base.
|
||||
Provides sync_after_create (new issue) and sync_after_update (existing issue).
|
||||
Both mirror the indexer's Phase-2 logic: generate summary, create chunks,
|
||||
then persist the document row.
|
||||
"""
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db_session = db_session
|
||||
|
||||
async def sync_after_create(
|
||||
self,
|
||||
issue_id: str,
|
||||
issue_identifier: str,
|
||||
issue_title: str,
|
||||
issue_url: str | None,
|
||||
description: str | None,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> dict:
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document_by_hash,
|
||||
get_current_timestamp,
|
||||
safe_set_chunks,
|
||||
)
|
||||
|
||||
try:
|
||||
unique_hash = generate_unique_identifier_hash(
|
||||
DocumentType.LINEAR_CONNECTOR, issue_id, search_space_id
|
||||
)
|
||||
|
||||
existing = await check_document_by_unique_identifier(
|
||||
self.db_session, unique_hash
|
||||
)
|
||||
if existing:
|
||||
logger.info(
|
||||
"Document for Linear issue %s already exists (doc_id=%s), skipping",
|
||||
issue_identifier,
|
||||
existing.id,
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
indexable_content = (description or "").strip()
|
||||
if not indexable_content:
|
||||
indexable_content = f"Linear Issue {issue_identifier}: {issue_title}"
|
||||
|
||||
issue_content = (
|
||||
f"# {issue_identifier}: {issue_title}\n\n{indexable_content}"
|
||||
)
|
||||
|
||||
content_hash = generate_content_hash(issue_content, search_space_id)
|
||||
|
||||
with self.db_session.no_autoflush:
|
||||
dup = await check_duplicate_document_by_hash(
|
||||
self.db_session, content_hash
|
||||
)
|
||||
if dup:
|
||||
logger.info(
|
||||
"Content-hash collision for Linear issue %s — identical content "
|
||||
"exists in doc %s. Using unique_identifier_hash as content_hash.",
|
||||
issue_identifier,
|
||||
dup.id,
|
||||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
search_space_id,
|
||||
disable_streaming=True,
|
||||
)
|
||||
|
||||
doc_metadata_for_summary = {
|
||||
"issue_id": issue_identifier,
|
||||
"issue_title": issue_title,
|
||||
"document_type": "Linear Issue",
|
||||
"connector_type": "Linear",
|
||||
}
|
||||
|
||||
if user_llm:
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
issue_content, user_llm, doc_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
logger.warning("No LLM configured — using fallback summary")
|
||||
summary_content = (
|
||||
f"Linear Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
|
||||
)
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(issue_content)
|
||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
document = Document(
|
||||
title=f"{issue_identifier}: {issue_title}",
|
||||
document_type=DocumentType.LINEAR_CONNECTOR,
|
||||
document_metadata={
|
||||
"issue_id": issue_id,
|
||||
"issue_identifier": issue_identifier,
|
||||
"issue_title": issue_title,
|
||||
"issue_url": issue_url,
|
||||
"source_connector": "linear",
|
||||
"indexed_at": now_str,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
content=summary_content,
|
||||
content_hash=content_hash,
|
||||
unique_identifier_hash=unique_hash,
|
||||
embedding=summary_embedding,
|
||||
search_space_id=search_space_id,
|
||||
connector_id=connector_id,
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
)
|
||||
|
||||
self.db_session.add(document)
|
||||
await self.db_session.flush()
|
||||
await safe_set_chunks(self.db_session, document, chunks)
|
||||
await self.db_session.commit()
|
||||
|
||||
logger.info(
|
||||
"KB sync after create succeeded: doc_id=%s, issue=%s, chunks=%d",
|
||||
document.id,
|
||||
issue_identifier,
|
||||
len(chunks),
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in error_str
|
||||
or "uniqueviolationerror" in error_str
|
||||
):
|
||||
logger.warning(
|
||||
"Duplicate constraint hit during KB sync for issue %s. "
|
||||
"Rolling back — periodic indexer will handle it. Error: %s",
|
||||
issue_identifier,
|
||||
e,
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": "Duplicate document detected"}
|
||||
|
||||
logger.error(
|
||||
"KB sync after create failed for issue %s: %s",
|
||||
issue_identifier,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def sync_after_update(
|
||||
self,
|
||||
document_id: int,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import and_, func, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.connectors.linear_connector import LinearConnector
|
||||
from app.db import (
|
||||
|
|
@ -12,6 +14,8 @@ from app.db import (
|
|||
SearchSourceConnectorType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LinearWorkspace:
|
||||
|
|
@ -109,7 +113,34 @@ class LinearToolMetadataService:
|
|||
priorities = await self._fetch_priority_values(linear_client)
|
||||
teams = await self._fetch_teams_context(linear_client)
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to fetch Linear context: {e!s}"}
|
||||
logger.warning(
|
||||
"Linear connector %s (%s) auth failed, flagging as expired: %s",
|
||||
connector.id,
|
||||
workspace.name,
|
||||
e,
|
||||
)
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await self._db_session.commit()
|
||||
await self._db_session.refresh(connector)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
workspaces.append(
|
||||
{
|
||||
"id": workspace.id,
|
||||
"name": workspace.name,
|
||||
"organization_name": workspace.organization_name,
|
||||
"teams": [],
|
||||
"priorities": [],
|
||||
"auth_expired": True,
|
||||
}
|
||||
)
|
||||
continue
|
||||
workspaces.append(
|
||||
{
|
||||
"id": workspace.id,
|
||||
|
|
@ -117,6 +148,7 @@ class LinearToolMetadataService:
|
|||
"organization_name": workspace.organization_name,
|
||||
"teams": teams,
|
||||
"priorities": priorities,
|
||||
"auth_expired": False,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -137,8 +169,8 @@ class LinearToolMetadataService:
|
|||
document = await self._resolve_issue(search_space_id, user_id, issue_ref)
|
||||
if not document:
|
||||
return {
|
||||
"error": f"Issue '{issue_ref}' not found in your indexed Linear issues. "
|
||||
"This could mean: (1) the issue doesn't exist, (2) it hasn't been indexed yet, "
|
||||
"error": f"Issue '{issue_ref}' not found in your synced Linear issues. "
|
||||
"This could mean: (1) the issue doesn't exist, (2) it hasn't been synced yet, "
|
||||
"or (3) the title or identifier is different."
|
||||
}
|
||||
|
||||
|
|
@ -157,6 +189,17 @@ class LinearToolMetadataService:
|
|||
priorities = await self._fetch_priority_values(linear_client)
|
||||
issue_api = await self._fetch_issue_context(linear_client, issue.id)
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"401" in error_str
|
||||
or "authentication" in error_str
|
||||
or "re-authenticate" in error_str
|
||||
):
|
||||
return {
|
||||
"error": f"Failed to fetch Linear issue context: {e!s}",
|
||||
"auth_expired": True,
|
||||
"connector_id": connector.id,
|
||||
}
|
||||
return {"error": f"Failed to fetch Linear issue context: {e!s}"}
|
||||
|
||||
if not issue_api:
|
||||
|
|
@ -210,8 +253,8 @@ class LinearToolMetadataService:
|
|||
document = await self._resolve_issue(search_space_id, user_id, issue_ref)
|
||||
if not document:
|
||||
return {
|
||||
"error": f"Issue '{issue_ref}' not found in your indexed Linear issues. "
|
||||
"This could mean: (1) the issue doesn't exist, (2) it hasn't been indexed yet, "
|
||||
"error": f"Issue '{issue_ref}' not found in your synced Linear issues. "
|
||||
"This could mean: (1) the issue doesn't exist, (2) it hasn't been synced yet, "
|
||||
"or (3) the title or identifier is different."
|
||||
}
|
||||
|
||||
|
|
@ -319,6 +362,7 @@ class LinearToolMetadataService:
|
|||
),
|
||||
)
|
||||
)
|
||||
.order_by(Document.updated_at.desc().nullslast())
|
||||
.limit(1)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
|
|
|||
|
|
@ -3,13 +3,14 @@ from datetime import datetime
|
|||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document
|
||||
from app.db import Document, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -19,6 +20,144 @@ class NotionKBSyncService:
|
|||
def __init__(self, db_session: AsyncSession):
|
||||
self.db_session = db_session
|
||||
|
||||
async def sync_after_create(
|
||||
self,
|
||||
page_id: str,
|
||||
page_title: str,
|
||||
page_url: str | None,
|
||||
content: str | None,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> dict:
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document_by_hash,
|
||||
get_current_timestamp,
|
||||
safe_set_chunks,
|
||||
)
|
||||
|
||||
try:
|
||||
unique_hash = generate_unique_identifier_hash(
|
||||
DocumentType.NOTION_CONNECTOR, page_id, search_space_id
|
||||
)
|
||||
|
||||
existing = await check_document_by_unique_identifier(
|
||||
self.db_session, unique_hash
|
||||
)
|
||||
if existing:
|
||||
logger.info(
|
||||
"Document for Notion page %s already exists (doc_id=%s), skipping",
|
||||
page_id,
|
||||
existing.id,
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
indexable_content = (content or "").strip()
|
||||
if not indexable_content:
|
||||
indexable_content = f"Notion Page: {page_title}"
|
||||
|
||||
markdown_content = f"# Notion Page: {page_title}\n\n{indexable_content}"
|
||||
|
||||
content_hash = generate_content_hash(markdown_content, search_space_id)
|
||||
|
||||
with self.db_session.no_autoflush:
|
||||
dup = await check_duplicate_document_by_hash(
|
||||
self.db_session, content_hash
|
||||
)
|
||||
if dup:
|
||||
logger.info(
|
||||
"Content-hash collision for Notion page %s — identical content "
|
||||
"exists in doc %s. Using unique_identifier_hash as content_hash.",
|
||||
page_id,
|
||||
dup.id,
|
||||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
search_space_id,
|
||||
disable_streaming=True,
|
||||
)
|
||||
|
||||
doc_metadata_for_summary = {
|
||||
"page_title": page_title,
|
||||
"page_id": page_id,
|
||||
"document_type": "Notion Page",
|
||||
"connector_type": "Notion",
|
||||
}
|
||||
|
||||
if user_llm:
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
markdown_content, user_llm, doc_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
logger.warning("No LLM configured — using fallback summary")
|
||||
summary_content = f"Notion Page: {page_title}\n\n{markdown_content}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(markdown_content)
|
||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
document = Document(
|
||||
title=page_title,
|
||||
document_type=DocumentType.NOTION_CONNECTOR,
|
||||
document_metadata={
|
||||
"page_title": page_title,
|
||||
"page_id": page_id,
|
||||
"page_url": page_url,
|
||||
"source_connector": "notion",
|
||||
"indexed_at": now_str,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
content=summary_content,
|
||||
content_hash=content_hash,
|
||||
unique_identifier_hash=unique_hash,
|
||||
embedding=summary_embedding,
|
||||
search_space_id=search_space_id,
|
||||
connector_id=connector_id,
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
)
|
||||
|
||||
self.db_session.add(document)
|
||||
await self.db_session.flush()
|
||||
await safe_set_chunks(self.db_session, document, chunks)
|
||||
await self.db_session.commit()
|
||||
|
||||
logger.info(
|
||||
"KB sync after create succeeded: doc_id=%s, page=%s, chunks=%d",
|
||||
document.id,
|
||||
page_title,
|
||||
len(chunks),
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in error_str
|
||||
or "uniqueviolationerror" in error_str
|
||||
):
|
||||
logger.warning(
|
||||
"Duplicate constraint hit during KB sync for page %s. "
|
||||
"Rolling back — periodic indexer will handle it. Error: %s",
|
||||
page_id,
|
||||
e,
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": "Duplicate document detected"}
|
||||
|
||||
logger.error(
|
||||
"KB sync after create failed for page %s: %s",
|
||||
page_id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def sync_after_update(
|
||||
self,
|
||||
document_id: int,
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.connectors.notion_history import NotionHistoryConnector
|
||||
from app.db import (
|
||||
Document,
|
||||
DocumentType,
|
||||
|
|
@ -11,6 +14,8 @@ from app.db import (
|
|||
SearchSourceConnectorType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotionAccount:
|
||||
|
|
@ -83,8 +88,37 @@ class NotionToolMetadataService:
|
|||
search_space_id, accounts
|
||||
)
|
||||
|
||||
accounts_with_status = []
|
||||
for acc in accounts:
|
||||
acc_dict = acc.to_dict()
|
||||
auth_expired = await self._check_account_health(acc.id)
|
||||
acc_dict["auth_expired"] = auth_expired
|
||||
if auth_expired:
|
||||
try:
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == acc.id
|
||||
)
|
||||
)
|
||||
db_connector = result.scalar_one_or_none()
|
||||
if db_connector and not db_connector.config.get("auth_expired"):
|
||||
db_connector.config = {
|
||||
**db_connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(db_connector, "config")
|
||||
await self._db_session.commit()
|
||||
await self._db_session.refresh(db_connector)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
acc.id,
|
||||
exc_info=True,
|
||||
)
|
||||
accounts_with_status.append(acc_dict)
|
||||
|
||||
return {
|
||||
"accounts": [acc.to_dict() for acc in accounts],
|
||||
"accounts": accounts_with_status,
|
||||
"parent_pages": parent_pages,
|
||||
}
|
||||
|
||||
|
|
@ -104,13 +138,15 @@ class NotionToolMetadataService:
|
|||
SearchSourceConnector.user_id == user_id,
|
||||
)
|
||||
)
|
||||
.order_by(Document.updated_at.desc().nullslast())
|
||||
.limit(1)
|
||||
)
|
||||
document = result.scalars().first()
|
||||
|
||||
if not document:
|
||||
return {
|
||||
"error": f"Page '{page_title}' not found in your indexed Notion pages. "
|
||||
"This could mean: (1) the page doesn't exist, (2) it hasn't been indexed yet, "
|
||||
"error": f"Page '{page_title}' not found in your synced Notion pages. "
|
||||
"This could mean: (1) the page doesn't exist, (2) it hasn't been synced yet, "
|
||||
"or (3) the page title is different. Please check the exact page title in Notion."
|
||||
}
|
||||
|
||||
|
|
@ -136,12 +172,33 @@ class NotionToolMetadataService:
|
|||
if not page_id:
|
||||
return {"error": "Page ID not found in document metadata"}
|
||||
|
||||
current_title = document.title
|
||||
document_id = document.id
|
||||
indexed_at = document.document_metadata.get("indexed_at")
|
||||
|
||||
acc_dict = account.to_dict()
|
||||
auth_expired = await self._check_account_health(connector.id)
|
||||
acc_dict["auth_expired"] = auth_expired
|
||||
if auth_expired:
|
||||
try:
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await self._db_session.commit()
|
||||
await self._db_session.refresh(connector)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return {
|
||||
"account": account.to_dict(),
|
||||
"account": acc_dict,
|
||||
"page_id": page_id,
|
||||
"current_title": document.title,
|
||||
"document_id": document.id,
|
||||
"indexed_at": document.document_metadata.get("indexed_at"),
|
||||
"current_title": current_title,
|
||||
"document_id": document_id,
|
||||
"indexed_at": indexed_at,
|
||||
}
|
||||
|
||||
async def get_delete_context(
|
||||
|
|
@ -167,6 +224,26 @@ class NotionToolMetadataService:
|
|||
connectors = result.scalars().all()
|
||||
return [NotionAccount.from_connector(conn) for conn in connectors]
|
||||
|
||||
async def _check_account_health(self, connector_id: int) -> bool:
|
||||
"""Check if a Notion connector's token is still valid.
|
||||
|
||||
Uses a lightweight ``users.me()`` call to verify the token.
|
||||
|
||||
Returns True if the token is expired/invalid, False if healthy.
|
||||
"""
|
||||
try:
|
||||
connector = NotionHistoryConnector(
|
||||
session=self._db_session, connector_id=connector_id
|
||||
)
|
||||
client = await connector._get_client()
|
||||
await client.users.me()
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Notion connector %s health check failed: %s", connector_id, e
|
||||
)
|
||||
return True
|
||||
|
||||
async def _get_parent_pages_by_account(
|
||||
self, search_space_id: int, accounts: list[NotionAccount]
|
||||
) -> dict:
|
||||
|
|
|
|||
|
|
@ -55,7 +55,6 @@ async def _check_and_trigger_schedules():
|
|||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_airtable_records_task,
|
||||
index_clickup_tasks_task,
|
||||
index_composio_connector_task,
|
||||
index_confluence_pages_task,
|
||||
index_crawled_urls_task,
|
||||
index_discord_messages_task,
|
||||
|
|
@ -88,10 +87,10 @@ async def _check_and_trigger_schedules():
|
|||
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task,
|
||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task,
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task,
|
||||
# Composio connector types
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: index_composio_connector_task,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR: index_composio_connector_task,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: index_composio_connector_task,
|
||||
# Composio connector types (unified with native Google tasks)
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR: index_google_gmail_messages_task,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task,
|
||||
}
|
||||
|
||||
# Trigger indexing for each due connector
|
||||
|
|
@ -129,11 +128,11 @@ async def _check_and_trigger_schedules():
|
|||
f"({connector.connector_type.value})"
|
||||
)
|
||||
|
||||
# Special handling for Google Drive - uses config for folder/file selection
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
# Special handling for Google Drive (native and Composio) - uses config for folder/file selection
|
||||
if connector.connector_type in [
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
]:
|
||||
connector_config = connector.config or {}
|
||||
selected_folders = connector_config.get("selected_folders", [])
|
||||
selected_files = connector_config.get("selected_files", [])
|
||||
|
|
|
|||
|
|
@ -936,6 +936,19 @@ async def _stream_agent_events(
|
|||
"delete_linear_issue",
|
||||
"create_google_drive_file",
|
||||
"delete_google_drive_file",
|
||||
"create_gmail_draft",
|
||||
"update_gmail_draft",
|
||||
"send_gmail_email",
|
||||
"trash_gmail_email",
|
||||
"create_calendar_event",
|
||||
"update_calendar_event",
|
||||
"delete_calendar_event",
|
||||
"create_jira_issue",
|
||||
"update_jira_issue",
|
||||
"delete_jira_issue",
|
||||
"create_confluence_page",
|
||||
"update_confluence_page",
|
||||
"delete_confluence_page",
|
||||
):
|
||||
yield streaming_service.format_tool_output_available(
|
||||
tool_call_id,
|
||||
|
|
|
|||
|
|
@ -25,6 +25,10 @@ from app.utils.document_converters import (
|
|||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
from app.utils.google_credentials import (
|
||||
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
|
||||
build_composio_credentials,
|
||||
)
|
||||
|
||||
from .base import (
|
||||
check_document_by_unique_identifier,
|
||||
|
|
@ -37,6 +41,11 @@ from .base import (
|
|||
update_connector_last_indexed,
|
||||
)
|
||||
|
||||
ACCEPTED_CALENDAR_CONNECTOR_TYPES = {
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
}
|
||||
|
||||
# Type hint for heartbeat callback
|
||||
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||
|
||||
|
|
@ -53,7 +62,7 @@ async def index_google_calendar_events(
|
|||
end_date: str | None = None,
|
||||
update_last_indexed: bool = True,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, str | None]:
|
||||
) -> tuple[int, int, str | None]:
|
||||
"""
|
||||
Index Google Calendar events.
|
||||
|
||||
|
|
@ -69,7 +78,7 @@ async def index_google_calendar_events(
|
|||
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
|
||||
|
||||
Returns:
|
||||
Tuple containing (number of documents indexed, error message or None)
|
||||
Tuple containing (number of documents indexed, number of documents skipped, error message or None)
|
||||
"""
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
|
|
@ -87,10 +96,12 @@ async def index_google_calendar_events(
|
|||
)
|
||||
|
||||
try:
|
||||
# Get the connector from the database
|
||||
connector = await get_connector_by_id(
|
||||
session, connector_id, SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR
|
||||
)
|
||||
# Accept both native and Composio Calendar connectors
|
||||
connector = None
|
||||
for ct in ACCEPTED_CALENDAR_CONNECTOR_TYPES:
|
||||
connector = await get_connector_by_id(session, connector_id, ct)
|
||||
if connector:
|
||||
break
|
||||
|
||||
if not connector:
|
||||
await task_logger.log_task_failure(
|
||||
|
|
@ -99,71 +110,80 @@ async def index_google_calendar_events(
|
|||
"Connector not found",
|
||||
{"error_type": "ConnectorNotFound"},
|
||||
)
|
||||
return 0, f"Connector with ID {connector_id} not found"
|
||||
return 0, 0, f"Connector with ID {connector_id} not found"
|
||||
|
||||
# Get the Google Calendar credentials from the connector config
|
||||
config_data = connector.config
|
||||
|
||||
# Decrypt sensitive credentials if encrypted (for backward compatibility)
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
try:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
|
||||
# Decrypt sensitive fields
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Decrypted Google Calendar credentials for connector {connector_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
# Build credentials based on connector type
|
||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||
if not connected_account_id:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to decrypt Google Calendar credentials for connector {connector_id}: {e!s}",
|
||||
"Credential decryption failed",
|
||||
{"error_type": "CredentialDecryptionError"},
|
||||
f"Composio connected_account_id not found for connector {connector_id}",
|
||||
"Missing Composio account",
|
||||
{"error_type": "MissingComposioAccount"},
|
||||
)
|
||||
return 0, f"Failed to decrypt Google Calendar credentials: {e!s}"
|
||||
return 0, 0, "Composio connected_account_id not found"
|
||||
credentials = build_composio_credentials(connected_account_id)
|
||||
else:
|
||||
config_data = connector.config
|
||||
|
||||
exp = config_data.get("expiry", "").replace("Z", "")
|
||||
credentials = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes"),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
if (
|
||||
not credentials.client_id
|
||||
or not credentials.client_secret
|
||||
or not credentials.refresh_token
|
||||
):
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Google Calendar credentials not found in connector config for connector {connector_id}",
|
||||
"Missing Google Calendar credentials",
|
||||
{"error_type": "MissingCredentials"},
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
try:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
logger.info(
|
||||
f"Decrypted Google Calendar credentials for connector {connector_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to decrypt Google Calendar credentials for connector {connector_id}: {e!s}",
|
||||
"Credential decryption failed",
|
||||
{"error_type": "CredentialDecryptionError"},
|
||||
)
|
||||
return 0, 0, f"Failed to decrypt Google Calendar credentials: {e!s}"
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
credentials = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
return 0, "Google Calendar credentials not found in connector config"
|
||||
|
||||
# Initialize Google Calendar client
|
||||
if (
|
||||
not credentials.client_id
|
||||
or not credentials.client_secret
|
||||
or not credentials.refresh_token
|
||||
):
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Google Calendar credentials not found in connector config for connector {connector_id}",
|
||||
"Missing Google Calendar credentials",
|
||||
{"error_type": "MissingCredentials"},
|
||||
)
|
||||
return 0, 0, "Google Calendar credentials not found in connector config"
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Initializing Google Calendar client for connector {connector_id}",
|
||||
|
|
@ -281,7 +301,7 @@ async def index_google_calendar_events(
|
|||
f"No Google Calendar events found in date range {start_date_str} to {end_date_str}",
|
||||
{"events_found": 0},
|
||||
)
|
||||
return 0, None
|
||||
return 0, 0, None
|
||||
else:
|
||||
logger.error(f"Failed to get Google Calendar events: {error}")
|
||||
# Check if this is an authentication error that requires re-authentication
|
||||
|
|
@ -301,13 +321,13 @@ async def index_google_calendar_events(
|
|||
error,
|
||||
{"error_type": error_type},
|
||||
)
|
||||
return 0, error_message
|
||||
return 0, 0, error_message
|
||||
|
||||
logger.info(f"Retrieved {len(events)} events from Google Calendar API")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Google Calendar events: {e!s}", exc_info=True)
|
||||
return 0, f"Error fetching Google Calendar events: {e!s}"
|
||||
return 0, 0, f"Error fetching Google Calendar events: {e!s}"
|
||||
|
||||
documents_indexed = 0
|
||||
documents_skipped = 0
|
||||
|
|
@ -363,6 +383,31 @@ async def index_google_calendar_events(
|
|||
session, unique_identifier_hash
|
||||
)
|
||||
|
||||
# Fallback: legacy Composio hash
|
||||
if not existing_document:
|
||||
legacy_hash = generate_unique_identifier_hash(
|
||||
DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
event_id,
|
||||
search_space_id,
|
||||
)
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, legacy_hash
|
||||
)
|
||||
if existing_document:
|
||||
existing_document.unique_identifier_hash = (
|
||||
unique_identifier_hash
|
||||
)
|
||||
if (
|
||||
existing_document.document_type
|
||||
== DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
existing_document.document_type = (
|
||||
DocumentType.GOOGLE_CALENDAR_CONNECTOR
|
||||
)
|
||||
logger.info(
|
||||
f"Migrated legacy Composio Calendar document: {event_id}"
|
||||
)
|
||||
|
||||
if existing_document:
|
||||
# Document exists - check if content has changed
|
||||
if existing_document.content_hash == content_hash:
|
||||
|
|
@ -609,7 +654,7 @@ async def index_google_calendar_events(
|
|||
f"{documents_skipped} skipped, {documents_failed} failed "
|
||||
f"({duplicate_content_count} duplicate content)"
|
||||
)
|
||||
return total_processed, warning_message
|
||||
return total_processed, documents_skipped, warning_message
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
|
|
@ -620,7 +665,7 @@ async def index_google_calendar_events(
|
|||
{"error_type": "SQLAlchemyError"},
|
||||
)
|
||||
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
||||
return 0, f"Database error: {db_error!s}"
|
||||
return 0, 0, f"Database error: {db_error!s}"
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
await task_logger.log_task_failure(
|
||||
|
|
@ -630,4 +675,4 @@ async def index_google_calendar_events(
|
|||
{"error_type": type(e).__name__},
|
||||
)
|
||||
logger.error(f"Failed to index Google Calendar events: {e!s}", exc_info=True)
|
||||
return 0, f"Failed to index Google Calendar events: {e!s}"
|
||||
return 0, 0, f"Failed to index Google Calendar events: {e!s}"
|
||||
|
|
|
|||
|
|
@ -31,6 +31,15 @@ from app.tasks.connector_indexers.base import (
|
|||
update_connector_last_indexed,
|
||||
)
|
||||
from app.utils.document_converters import generate_unique_identifier_hash
|
||||
from app.utils.google_credentials import (
|
||||
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
|
||||
build_composio_credentials,
|
||||
)
|
||||
|
||||
ACCEPTED_DRIVE_CONNECTOR_TYPES = {
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
}
|
||||
|
||||
# Type hint for heartbeat callback
|
||||
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||
|
|
@ -53,7 +62,7 @@ async def index_google_drive_files(
|
|||
max_files: int = 500,
|
||||
include_subfolders: bool = False,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, str | None]:
|
||||
) -> tuple[int, int, str | None]:
|
||||
"""
|
||||
Index Google Drive files for a specific connector.
|
||||
|
||||
|
|
@ -71,7 +80,7 @@ async def index_google_drive_files(
|
|||
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
|
||||
|
||||
Returns:
|
||||
Tuple of (number_of_indexed_files, error_message)
|
||||
Tuple of (number_of_indexed_files, number_of_skipped_files, error_message)
|
||||
"""
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
|
|
@ -89,16 +98,19 @@ async def index_google_drive_files(
|
|||
)
|
||||
|
||||
try:
|
||||
connector = await get_connector_by_id(
|
||||
session, connector_id, SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR
|
||||
)
|
||||
# Accept both native and Composio Drive connectors
|
||||
connector = None
|
||||
for ct in ACCEPTED_DRIVE_CONNECTOR_TYPES:
|
||||
connector = await get_connector_by_id(session, connector_id, ct)
|
||||
if connector:
|
||||
break
|
||||
|
||||
if not connector:
|
||||
error_msg = f"Google Drive connector with ID {connector_id} not found"
|
||||
await task_logger.log_task_failure(
|
||||
log_entry, error_msg, {"error_type": "ConnectorNotFound"}
|
||||
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
|
||||
)
|
||||
return 0, error_msg
|
||||
return 0, 0, error_msg
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
|
|
@ -106,34 +118,51 @@ async def index_google_drive_files(
|
|||
{"stage": "client_initialization"},
|
||||
)
|
||||
|
||||
# Check if credentials are encrypted (only when explicitly marked)
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
if token_encrypted:
|
||||
# Credentials are explicitly marked as encrypted, will be decrypted during client initialization
|
||||
if not config.SECRET_KEY:
|
||||
# Build credentials based on connector type
|
||||
pre_built_credentials = None
|
||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||
if not connected_account_id:
|
||||
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"SECRET_KEY not configured but credentials are marked as encrypted for connector {connector_id}",
|
||||
"Missing SECRET_KEY for token decryption",
|
||||
{"error_type": "MissingSecretKey"},
|
||||
error_msg,
|
||||
"Missing Composio account",
|
||||
{"error_type": "MissingComposioAccount"},
|
||||
)
|
||||
return (
|
||||
0,
|
||||
"SECRET_KEY not configured but credentials are marked as encrypted",
|
||||
return 0, 0, error_msg
|
||||
pre_built_credentials = build_composio_credentials(connected_account_id)
|
||||
else:
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
if token_encrypted:
|
||||
if not config.SECRET_KEY:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"SECRET_KEY not configured but credentials are marked as encrypted for connector {connector_id}",
|
||||
"Missing SECRET_KEY for token decryption",
|
||||
{"error_type": "MissingSecretKey"},
|
||||
)
|
||||
return (
|
||||
0,
|
||||
0,
|
||||
"SECRET_KEY not configured but credentials are marked as encrypted",
|
||||
)
|
||||
logger.info(
|
||||
f"Google Drive credentials are encrypted for connector {connector_id}, will decrypt during client initialization"
|
||||
)
|
||||
logger.info(
|
||||
f"Google Drive credentials are encrypted for connector {connector_id}, will decrypt during client initialization"
|
||||
)
|
||||
# If _token_encrypted is False or not set, treat credentials as plaintext
|
||||
|
||||
drive_client = GoogleDriveClient(session, connector_id)
|
||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||
|
||||
drive_client = GoogleDriveClient(
|
||||
session, connector_id, credentials=pre_built_credentials
|
||||
)
|
||||
|
||||
if not folder_id:
|
||||
error_msg = "folder_id is required for Google Drive indexing"
|
||||
await task_logger.log_task_failure(
|
||||
log_entry, error_msg, {"error_type": "MissingParameter"}
|
||||
)
|
||||
return 0, error_msg
|
||||
return 0, 0, error_msg
|
||||
|
||||
target_folder_id = folder_id
|
||||
target_folder_name = folder_name or "Selected Folder"
|
||||
|
|
@ -164,7 +193,33 @@ async def index_google_drive_files(
|
|||
max_files=max_files,
|
||||
include_subfolders=include_subfolders,
|
||||
on_heartbeat_callback=on_heartbeat_callback,
|
||||
enable_summary=connector_enable_summary,
|
||||
)
|
||||
documents_indexed, documents_skipped = result
|
||||
|
||||
# Reconciliation: full scan re-indexes documents that were manually
|
||||
# deleted from SurfSense but still exist in Google Drive.
|
||||
# Already-indexed files are skipped via md5/modifiedTime checks,
|
||||
# so the overhead is just one API listing call + fast DB lookups.
|
||||
logger.info("Running reconciliation scan after delta sync")
|
||||
reconcile_result = await _index_full_scan(
|
||||
drive_client=drive_client,
|
||||
session=session,
|
||||
connector=connector,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
folder_id=target_folder_id,
|
||||
folder_name=target_folder_name,
|
||||
task_logger=task_logger,
|
||||
log_entry=log_entry,
|
||||
max_files=max_files,
|
||||
include_subfolders=include_subfolders,
|
||||
on_heartbeat_callback=on_heartbeat_callback,
|
||||
enable_summary=connector_enable_summary,
|
||||
)
|
||||
documents_indexed += reconcile_result[0]
|
||||
documents_skipped += reconcile_result[1]
|
||||
else:
|
||||
logger.info(f"Using full scan for connector {connector_id}")
|
||||
result = await _index_full_scan(
|
||||
|
|
@ -181,9 +236,9 @@ async def index_google_drive_files(
|
|||
max_files=max_files,
|
||||
include_subfolders=include_subfolders,
|
||||
on_heartbeat_callback=on_heartbeat_callback,
|
||||
enable_summary=connector_enable_summary,
|
||||
)
|
||||
|
||||
documents_indexed, documents_skipped = result
|
||||
documents_indexed, documents_skipped = result
|
||||
|
||||
if documents_indexed > 0 or can_use_delta_sync:
|
||||
new_token, token_error = await get_start_page_token(drive_client)
|
||||
|
|
@ -217,7 +272,7 @@ async def index_google_drive_files(
|
|||
logger.info(
|
||||
f"Google Drive indexing completed: {documents_indexed} files indexed, {documents_skipped} skipped"
|
||||
)
|
||||
return documents_indexed, None
|
||||
return documents_indexed, documents_skipped, None
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
|
|
@ -228,7 +283,7 @@ async def index_google_drive_files(
|
|||
{"error_type": "SQLAlchemyError"},
|
||||
)
|
||||
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
||||
return 0, f"Database error: {db_error!s}"
|
||||
return 0, 0, f"Database error: {db_error!s}"
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
await task_logger.log_task_failure(
|
||||
|
|
@ -238,7 +293,7 @@ async def index_google_drive_files(
|
|||
{"error_type": type(e).__name__},
|
||||
)
|
||||
logger.error(f"Failed to index Google Drive files: {e!s}", exc_info=True)
|
||||
return 0, f"Failed to index Google Drive files: {e!s}"
|
||||
return 0, 0, f"Failed to index Google Drive files: {e!s}"
|
||||
|
||||
|
||||
async def index_google_drive_single_file(
|
||||
|
|
@ -278,14 +333,17 @@ async def index_google_drive_single_file(
|
|||
)
|
||||
|
||||
try:
|
||||
connector = await get_connector_by_id(
|
||||
session, connector_id, SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR
|
||||
)
|
||||
# Accept both native and Composio Drive connectors
|
||||
connector = None
|
||||
for ct in ACCEPTED_DRIVE_CONNECTOR_TYPES:
|
||||
connector = await get_connector_by_id(session, connector_id, ct)
|
||||
if connector:
|
||||
break
|
||||
|
||||
if not connector:
|
||||
error_msg = f"Google Drive connector with ID {connector_id} not found"
|
||||
await task_logger.log_task_failure(
|
||||
log_entry, error_msg, {"error_type": "ConnectorNotFound"}
|
||||
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
|
||||
)
|
||||
return 0, error_msg
|
||||
|
||||
|
|
@ -295,27 +353,42 @@ async def index_google_drive_single_file(
|
|||
{"stage": "client_initialization"},
|
||||
)
|
||||
|
||||
# Check if credentials are encrypted (only when explicitly marked)
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
if token_encrypted:
|
||||
# Credentials are explicitly marked as encrypted, will be decrypted during client initialization
|
||||
if not config.SECRET_KEY:
|
||||
pre_built_credentials = None
|
||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||
if not connected_account_id:
|
||||
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"SECRET_KEY not configured but credentials are marked as encrypted for connector {connector_id}",
|
||||
"Missing SECRET_KEY for token decryption",
|
||||
{"error_type": "MissingSecretKey"},
|
||||
error_msg,
|
||||
"Missing Composio account",
|
||||
{"error_type": "MissingComposioAccount"},
|
||||
)
|
||||
return (
|
||||
0,
|
||||
"SECRET_KEY not configured but credentials are marked as encrypted",
|
||||
return 0, error_msg
|
||||
pre_built_credentials = build_composio_credentials(connected_account_id)
|
||||
else:
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
if token_encrypted:
|
||||
if not config.SECRET_KEY:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"SECRET_KEY not configured but credentials are marked as encrypted for connector {connector_id}",
|
||||
"Missing SECRET_KEY for token decryption",
|
||||
{"error_type": "MissingSecretKey"},
|
||||
)
|
||||
return (
|
||||
0,
|
||||
"SECRET_KEY not configured but credentials are marked as encrypted",
|
||||
)
|
||||
logger.info(
|
||||
f"Google Drive credentials are encrypted for connector {connector_id}, will decrypt during client initialization"
|
||||
)
|
||||
logger.info(
|
||||
f"Google Drive credentials are encrypted for connector {connector_id}, will decrypt during client initialization"
|
||||
)
|
||||
# If _token_encrypted is False or not set, treat credentials as plaintext
|
||||
|
||||
drive_client = GoogleDriveClient(session, connector_id)
|
||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||
|
||||
drive_client = GoogleDriveClient(
|
||||
session, connector_id, credentials=pre_built_credentials
|
||||
)
|
||||
|
||||
# Fetch the file metadata
|
||||
file, error = await get_file_by_id(drive_client, file_id)
|
||||
|
|
@ -362,6 +435,7 @@ async def index_google_drive_single_file(
|
|||
task_logger=task_logger,
|
||||
log_entry=log_entry,
|
||||
pending_document=pending_doc,
|
||||
enable_summary=connector_enable_summary,
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
|
@ -433,6 +507,7 @@ async def _index_full_scan(
|
|||
max_files: int,
|
||||
include_subfolders: bool = False,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> tuple[int, int]:
|
||||
"""Perform full scan indexing of a folder.
|
||||
|
||||
|
|
@ -467,6 +542,7 @@ async def _index_full_scan(
|
|||
|
||||
# Queue of folders to process: (folder_id, folder_name)
|
||||
folders_to_process = [(folder_id, folder_name)]
|
||||
first_listing_error: str | None = None
|
||||
|
||||
logger.info("Phase 1: Collecting files and creating pending documents")
|
||||
|
||||
|
|
@ -486,6 +562,8 @@ async def _index_full_scan(
|
|||
|
||||
if error:
|
||||
logger.error(f"Error listing files in {current_folder_name}: {error}")
|
||||
if first_listing_error is None:
|
||||
first_listing_error = error
|
||||
break
|
||||
|
||||
if not files:
|
||||
|
|
@ -531,6 +609,19 @@ async def _index_full_scan(
|
|||
if not page_token:
|
||||
break
|
||||
|
||||
if not files_to_process and first_listing_error:
|
||||
error_lower = first_listing_error.lower()
|
||||
if (
|
||||
"401" in first_listing_error
|
||||
or "invalid credentials" in error_lower
|
||||
or "authError" in first_listing_error
|
||||
):
|
||||
raise Exception(
|
||||
f"Google Drive authentication failed. Please re-authenticate. "
|
||||
f"(Error: {first_listing_error})"
|
||||
)
|
||||
raise Exception(f"Failed to list Google Drive files: {first_listing_error}")
|
||||
|
||||
# Commit all pending documents - they all appear in UI now
|
||||
if new_documents_created:
|
||||
logger.info(
|
||||
|
|
@ -562,6 +653,7 @@ async def _index_full_scan(
|
|||
task_logger=task_logger,
|
||||
log_entry=log_entry,
|
||||
pending_document=pending_doc,
|
||||
enable_summary=enable_summary,
|
||||
)
|
||||
|
||||
documents_indexed += indexed
|
||||
|
|
@ -592,6 +684,7 @@ async def _index_with_delta_sync(
|
|||
max_files: int,
|
||||
include_subfolders: bool = False,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> tuple[int, int]:
|
||||
"""Perform delta sync indexing using change tracking.
|
||||
|
||||
|
|
@ -614,7 +707,17 @@ async def _index_with_delta_sync(
|
|||
|
||||
if error:
|
||||
logger.error(f"Error fetching changes: {error}")
|
||||
return 0, 0
|
||||
error_lower = error.lower()
|
||||
if (
|
||||
"401" in error
|
||||
or "invalid credentials" in error_lower
|
||||
or "authError" in error
|
||||
):
|
||||
raise Exception(
|
||||
f"Google Drive authentication failed. Please re-authenticate. "
|
||||
f"(Error: {error})"
|
||||
)
|
||||
raise Exception(f"Failed to fetch Google Drive changes: {error}")
|
||||
|
||||
if not changes:
|
||||
logger.info("No changes detected since last sync")
|
||||
|
|
@ -703,6 +806,7 @@ async def _index_with_delta_sync(
|
|||
task_logger=task_logger,
|
||||
log_entry=log_entry,
|
||||
pending_document=pending_doc,
|
||||
enable_summary=enable_summary,
|
||||
)
|
||||
|
||||
documents_indexed += indexed
|
||||
|
|
@ -763,10 +867,25 @@ async def _create_pending_document_for_file(
|
|||
DocumentType.GOOGLE_DRIVE_FILE, file_id, search_space_id
|
||||
)
|
||||
|
||||
# Check if document exists
|
||||
# Check if document exists (primary hash first, then legacy Composio hash)
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, unique_identifier_hash
|
||||
)
|
||||
if not existing_document:
|
||||
legacy_hash = generate_unique_identifier_hash(
|
||||
DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, file_id, search_space_id
|
||||
)
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, legacy_hash
|
||||
)
|
||||
if existing_document:
|
||||
existing_document.unique_identifier_hash = unique_identifier_hash
|
||||
if (
|
||||
existing_document.document_type
|
||||
== DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
existing_document.document_type = DocumentType.GOOGLE_DRIVE_FILE
|
||||
logger.info(f"Migrated legacy Composio document to native type: {file_id}")
|
||||
|
||||
if existing_document:
|
||||
# Check if this is a rename-only update (content unchanged)
|
||||
|
|
@ -862,12 +981,26 @@ async def _check_rename_only_update(
|
|||
)
|
||||
existing_document = await check_document_by_unique_identifier(session, primary_hash)
|
||||
|
||||
# If not found by primary hash, try searching by metadata (for legacy documents)
|
||||
# Fallback: legacy Composio hash
|
||||
if not existing_document:
|
||||
legacy_hash = generate_unique_identifier_hash(
|
||||
DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, file_id, search_space_id
|
||||
)
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, legacy_hash
|
||||
)
|
||||
|
||||
# Fallback: metadata search (covers old filename-based hashes)
|
||||
if not existing_document:
|
||||
result = await session.execute(
|
||||
select(Document).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.GOOGLE_DRIVE_FILE,
|
||||
Document.document_type.in_(
|
||||
[
|
||||
DocumentType.GOOGLE_DRIVE_FILE,
|
||||
DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
]
|
||||
),
|
||||
cast(Document.document_metadata["google_drive_file_id"], String)
|
||||
== file_id,
|
||||
)
|
||||
|
|
@ -876,6 +1009,17 @@ async def _check_rename_only_update(
|
|||
if existing_document:
|
||||
logger.debug(f"Found legacy document by metadata for file_id: {file_id}")
|
||||
|
||||
# Migrate legacy Composio document to native type
|
||||
if existing_document:
|
||||
if existing_document.unique_identifier_hash != primary_hash:
|
||||
existing_document.unique_identifier_hash = primary_hash
|
||||
if (
|
||||
existing_document.document_type
|
||||
== DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
existing_document.document_type = DocumentType.GOOGLE_DRIVE_FILE
|
||||
logger.info(f"Migrated legacy Composio Drive document: {file_id}")
|
||||
|
||||
if not existing_document:
|
||||
# New file, needs full processing
|
||||
return False, None
|
||||
|
|
@ -957,6 +1101,7 @@ async def _process_single_file(
|
|||
task_logger: TaskLoggingService,
|
||||
log_entry: any,
|
||||
pending_document: Document | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> tuple[int, int, int]:
|
||||
"""
|
||||
Process a single file by downloading and using Surfsense's file processor.
|
||||
|
|
@ -1020,6 +1165,7 @@ async def _process_single_file(
|
|||
task_logger=task_logger,
|
||||
log_entry=log_entry,
|
||||
connector_id=connector_id,
|
||||
enable_summary=enable_summary,
|
||||
)
|
||||
|
||||
if error:
|
||||
|
|
@ -1088,12 +1234,26 @@ async def _remove_document(session: AsyncSession, file_id: str, search_space_id:
|
|||
session, unique_identifier_hash
|
||||
)
|
||||
|
||||
# If not found, search by metadata (for legacy documents with filename-based hash)
|
||||
# Fallback: legacy Composio hash
|
||||
if not existing_document:
|
||||
legacy_hash = generate_unique_identifier_hash(
|
||||
DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, file_id, search_space_id
|
||||
)
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, legacy_hash
|
||||
)
|
||||
|
||||
# Fallback: metadata search (covers old filename-based hashes, both native and Composio)
|
||||
if not existing_document:
|
||||
result = await session.execute(
|
||||
select(Document).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.GOOGLE_DRIVE_FILE,
|
||||
Document.document_type.in_(
|
||||
[
|
||||
DocumentType.GOOGLE_DRIVE_FILE,
|
||||
DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
]
|
||||
),
|
||||
cast(Document.document_metadata["google_drive_file_id"], String)
|
||||
== file_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -30,6 +30,10 @@ from app.utils.document_converters import (
|
|||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
from app.utils.google_credentials import (
|
||||
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
|
||||
build_composio_credentials,
|
||||
)
|
||||
|
||||
from .base import (
|
||||
calculate_date_range,
|
||||
|
|
@ -42,6 +46,11 @@ from .base import (
|
|||
update_connector_last_indexed,
|
||||
)
|
||||
|
||||
ACCEPTED_GMAIL_CONNECTOR_TYPES = {
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
}
|
||||
|
||||
# Type hint for heartbeat callback
|
||||
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||
|
||||
|
|
@ -59,7 +68,7 @@ async def index_google_gmail_messages(
|
|||
update_last_indexed: bool = True,
|
||||
max_messages: int = 1000,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, str]:
|
||||
) -> tuple[int, int, str | None]:
|
||||
"""
|
||||
Index Gmail messages for a specific connector.
|
||||
|
||||
|
|
@ -75,7 +84,7 @@ async def index_google_gmail_messages(
|
|||
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
|
||||
|
||||
Returns:
|
||||
Tuple of (number_of_indexed_messages, status_message)
|
||||
Tuple of (number_of_indexed_messages, number_of_skipped_messages, status_message)
|
||||
"""
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
|
|
@ -94,90 +103,98 @@ async def index_google_gmail_messages(
|
|||
)
|
||||
|
||||
try:
|
||||
# Get connector by id
|
||||
connector = await get_connector_by_id(
|
||||
session, connector_id, SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR
|
||||
)
|
||||
# Accept both native and Composio Gmail connectors
|
||||
connector = None
|
||||
for ct in ACCEPTED_GMAIL_CONNECTOR_TYPES:
|
||||
connector = await get_connector_by_id(session, connector_id, ct)
|
||||
if connector:
|
||||
break
|
||||
|
||||
if not connector:
|
||||
error_msg = f"Gmail connector with ID {connector_id} not found"
|
||||
await task_logger.log_task_failure(
|
||||
log_entry, error_msg, {"error_type": "ConnectorNotFound"}
|
||||
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
|
||||
)
|
||||
return 0, error_msg
|
||||
return 0, 0, error_msg
|
||||
|
||||
# Get the Google Gmail credentials from the connector config
|
||||
config_data = connector.config
|
||||
|
||||
# Decrypt sensitive credentials if encrypted (for backward compatibility)
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
try:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
|
||||
# Decrypt sensitive fields
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Decrypted Google Gmail credentials for connector {connector_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
# Build credentials based on connector type
|
||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||
if not connected_account_id:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to decrypt Google Gmail credentials for connector {connector_id}: {e!s}",
|
||||
"Credential decryption failed",
|
||||
{"error_type": "CredentialDecryptionError"},
|
||||
f"Composio connected_account_id not found for connector {connector_id}",
|
||||
"Missing Composio account",
|
||||
{"error_type": "MissingComposioAccount"},
|
||||
)
|
||||
return 0, f"Failed to decrypt Google Gmail credentials: {e!s}"
|
||||
return 0, 0, "Composio connected_account_id not found"
|
||||
credentials = build_composio_credentials(connected_account_id)
|
||||
else:
|
||||
config_data = connector.config
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
credentials = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
if (
|
||||
not credentials.client_id
|
||||
or not credentials.client_secret
|
||||
or not credentials.refresh_token
|
||||
):
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Google gmail credentials not found in connector config for connector {connector_id}",
|
||||
"Missing Google gmail credentials",
|
||||
{"error_type": "MissingCredentials"},
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
try:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
logger.info(
|
||||
f"Decrypted Google Gmail credentials for connector {connector_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to decrypt Google Gmail credentials for connector {connector_id}: {e!s}",
|
||||
"Credential decryption failed",
|
||||
{"error_type": "CredentialDecryptionError"},
|
||||
)
|
||||
return 0, 0, f"Failed to decrypt Google Gmail credentials: {e!s}"
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
credentials = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
return 0, "Google gmail credentials not found in connector config"
|
||||
|
||||
# Initialize Google gmail client
|
||||
if (
|
||||
not credentials.client_id
|
||||
or not credentials.client_secret
|
||||
or not credentials.refresh_token
|
||||
):
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Google gmail credentials not found in connector config for connector {connector_id}",
|
||||
"Missing Google gmail credentials",
|
||||
{"error_type": "MissingCredentials"},
|
||||
)
|
||||
return 0, 0, "Google gmail credentials not found in connector config"
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Initializing Google gmail client for connector {connector_id}",
|
||||
{"stage": "client_initialization"},
|
||||
)
|
||||
|
||||
# Initialize Google gmail connector
|
||||
gmail_connector = GoogleGmailConnector(
|
||||
credentials, session, user_id, connector_id
|
||||
)
|
||||
|
|
@ -215,14 +232,14 @@ async def index_google_gmail_messages(
|
|||
await task_logger.log_task_failure(
|
||||
log_entry, error_message, error, {"error_type": error_type}
|
||||
)
|
||||
return 0, error_message
|
||||
return 0, 0, error_message
|
||||
|
||||
if not messages:
|
||||
success_msg = "No Google gmail messages found in the specified date range"
|
||||
await task_logger.log_task_success(
|
||||
log_entry, success_msg, {"messages_count": 0}
|
||||
)
|
||||
return 0, success_msg
|
||||
return 0, 0, success_msg
|
||||
|
||||
logger.info(f"Found {len(messages)} Google gmail messages to index")
|
||||
|
||||
|
|
@ -293,6 +310,31 @@ async def index_google_gmail_messages(
|
|||
session, unique_identifier_hash
|
||||
)
|
||||
|
||||
# Fallback: legacy Composio hash
|
||||
if not existing_document:
|
||||
legacy_hash = generate_unique_identifier_hash(
|
||||
DocumentType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
message_id,
|
||||
search_space_id,
|
||||
)
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, legacy_hash
|
||||
)
|
||||
if existing_document:
|
||||
existing_document.unique_identifier_hash = (
|
||||
unique_identifier_hash
|
||||
)
|
||||
if (
|
||||
existing_document.document_type
|
||||
== DocumentType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
existing_document.document_type = (
|
||||
DocumentType.GOOGLE_GMAIL_CONNECTOR
|
||||
)
|
||||
logger.info(
|
||||
f"Migrated legacy Composio Gmail document: {message_id}"
|
||||
)
|
||||
|
||||
if existing_document:
|
||||
# Document exists - check if content has changed
|
||||
if existing_document.content_hash == content_hash:
|
||||
|
|
@ -531,10 +573,7 @@ async def index_google_gmail_messages(
|
|||
f"{documents_skipped} skipped, {documents_failed} failed "
|
||||
f"({duplicate_content_count} duplicate content)"
|
||||
)
|
||||
return (
|
||||
total_processed,
|
||||
warning_message,
|
||||
) # Return warning_message (None on success)
|
||||
return total_processed, documents_skipped, warning_message
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
|
|
@ -545,7 +584,7 @@ async def index_google_gmail_messages(
|
|||
{"error_type": "SQLAlchemyError"},
|
||||
)
|
||||
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
||||
return 0, f"Database error: {db_error!s}"
|
||||
return 0, 0, f"Database error: {db_error!s}"
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
await task_logger.log_task_failure(
|
||||
|
|
@ -555,4 +594,4 @@ async def index_google_gmail_messages(
|
|||
{"error_type": type(e).__name__},
|
||||
)
|
||||
logger.error(f"Failed to index Google gmail emails: {e!s}", exc_info=True)
|
||||
return 0, f"Failed to index Google gmail emails: {e!s}"
|
||||
return 0, 0, f"Failed to index Google gmail emails: {e!s}"
|
||||
|
|
|
|||
|
|
@ -170,7 +170,34 @@ async def handle_existing_document_update(
|
|||
logging.info(f"Document for file {filename} unchanged. Skipping.")
|
||||
return True, existing_document
|
||||
else:
|
||||
# Content has changed - need to re-process
|
||||
# Content has changed — guard against content_hash collision before
|
||||
# expensive ETL processing. A collision means the exact same content
|
||||
# already lives in a *different* document (e.g. a manual upload of the
|
||||
# same file). Proceeding would trigger a unique-constraint violation
|
||||
# on ix_documents_content_hash.
|
||||
collision_doc = await check_duplicate_document(session, content_hash)
|
||||
if collision_doc and collision_doc.id != existing_document.id:
|
||||
logging.warning(
|
||||
"Content-hash collision for %s: identical content exists in "
|
||||
"document #%s (%s). Skipping re-processing.",
|
||||
filename,
|
||||
collision_doc.id,
|
||||
collision_doc.document_type,
|
||||
)
|
||||
if DocumentStatus.is_state(
|
||||
existing_document.status, DocumentStatus.PENDING
|
||||
) or DocumentStatus.is_state(
|
||||
existing_document.status, DocumentStatus.PROCESSING
|
||||
):
|
||||
# Pending/processing doc has no real content yet — remove it
|
||||
# so the UI doesn't show a contentless entry.
|
||||
await session.delete(existing_document)
|
||||
await session.commit()
|
||||
return True, None
|
||||
|
||||
# Document already has valid content — keep it as-is.
|
||||
return True, existing_document
|
||||
|
||||
logging.info(f"Content changed for file {filename}. Updating document.")
|
||||
return False, None
|
||||
|
||||
|
|
@ -411,6 +438,7 @@ async def add_received_file_document_using_unstructured(
|
|||
search_space_id: int,
|
||||
user_id: str,
|
||||
connector: dict | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Process and store a file document using Unstructured service.
|
||||
|
|
@ -471,9 +499,13 @@ async def add_received_file_document_using_unstructured(
|
|||
"etl_service": "UNSTRUCTURED",
|
||||
"document_type": "File Document",
|
||||
}
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
file_in_markdown, user_llm, document_metadata
|
||||
)
|
||||
if enable_summary:
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
file_in_markdown, user_llm, document_metadata
|
||||
)
|
||||
else:
|
||||
summary_content = f"File: {file_name}\n\n{file_in_markdown[:4000]}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
# Process chunks
|
||||
chunks = await create_document_chunks(file_in_markdown)
|
||||
|
|
@ -493,14 +525,13 @@ async def add_received_file_document_using_unstructured(
|
|||
existing_document.source_markdown = file_in_markdown
|
||||
existing_document.content_needs_reindexing = False
|
||||
existing_document.updated_at = get_current_timestamp()
|
||||
existing_document.status = DocumentStatus.ready() # Mark as ready
|
||||
existing_document.status = DocumentStatus.ready()
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(existing_document)
|
||||
document = existing_document
|
||||
else:
|
||||
# Create new document
|
||||
# Determine document type based on connector
|
||||
doc_type = DocumentType.FILE
|
||||
if connector and connector.get("type") == DocumentType.GOOGLE_DRIVE_FILE:
|
||||
doc_type = DocumentType.GOOGLE_DRIVE_FILE
|
||||
|
|
@ -523,7 +554,7 @@ async def add_received_file_document_using_unstructured(
|
|||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
connector_id=connector.get("connector_id") if connector else None,
|
||||
status=DocumentStatus.ready(), # Mark as ready
|
||||
status=DocumentStatus.ready(),
|
||||
)
|
||||
|
||||
session.add(document)
|
||||
|
|
@ -533,6 +564,12 @@ async def add_received_file_document_using_unstructured(
|
|||
return document
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
if "ix_documents_content_hash" in str(db_error):
|
||||
logging.warning(
|
||||
"content_hash collision during commit for %s (Unstructured). Skipping.",
|
||||
file_name,
|
||||
)
|
||||
return None
|
||||
raise db_error
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
|
|
@ -546,6 +583,7 @@ async def add_received_file_document_using_llamacloud(
|
|||
search_space_id: int,
|
||||
user_id: str,
|
||||
connector: dict | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Process and store document content parsed by LlamaCloud.
|
||||
|
|
@ -605,16 +643,19 @@ async def add_received_file_document_using_llamacloud(
|
|||
"etl_service": "LLAMACLOUD",
|
||||
"document_type": "File Document",
|
||||
}
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
file_in_markdown, user_llm, document_metadata
|
||||
)
|
||||
if enable_summary:
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
file_in_markdown, user_llm, document_metadata
|
||||
)
|
||||
else:
|
||||
summary_content = f"File: {file_name}\n\n{file_in_markdown[:4000]}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
# Process chunks
|
||||
chunks = await create_document_chunks(file_in_markdown)
|
||||
|
||||
# Update or create document
|
||||
if existing_document:
|
||||
# Update existing document
|
||||
existing_document.title = file_name
|
||||
existing_document.content = summary_content
|
||||
existing_document.content_hash = content_hash
|
||||
|
|
@ -627,14 +668,12 @@ async def add_received_file_document_using_llamacloud(
|
|||
existing_document.source_markdown = file_in_markdown
|
||||
existing_document.content_needs_reindexing = False
|
||||
existing_document.updated_at = get_current_timestamp()
|
||||
existing_document.status = DocumentStatus.ready() # Mark as ready
|
||||
existing_document.status = DocumentStatus.ready()
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(existing_document)
|
||||
document = existing_document
|
||||
else:
|
||||
# Create new document
|
||||
# Determine document type based on connector
|
||||
doc_type = DocumentType.FILE
|
||||
if connector and connector.get("type") == DocumentType.GOOGLE_DRIVE_FILE:
|
||||
doc_type = DocumentType.GOOGLE_DRIVE_FILE
|
||||
|
|
@ -657,7 +696,7 @@ async def add_received_file_document_using_llamacloud(
|
|||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
connector_id=connector.get("connector_id") if connector else None,
|
||||
status=DocumentStatus.ready(), # Mark as ready
|
||||
status=DocumentStatus.ready(),
|
||||
)
|
||||
|
||||
session.add(document)
|
||||
|
|
@ -667,6 +706,12 @@ async def add_received_file_document_using_llamacloud(
|
|||
return document
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
if "ix_documents_content_hash" in str(db_error):
|
||||
logging.warning(
|
||||
"content_hash collision during commit for %s (LlamaCloud). Skipping.",
|
||||
file_name,
|
||||
)
|
||||
return None
|
||||
raise db_error
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
|
|
@ -682,6 +727,7 @@ async def add_received_file_document_using_docling(
|
|||
search_space_id: int,
|
||||
user_id: str,
|
||||
connector: dict | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Process and store document content parsed by Docling.
|
||||
|
|
@ -734,33 +780,32 @@ async def add_received_file_document_using_docling(
|
|||
f"No long context LLM configured for user {user_id} in search_space {search_space_id}"
|
||||
)
|
||||
|
||||
# Generate summary using chunked processing for large documents
|
||||
from app.services.docling_service import create_docling_service
|
||||
if enable_summary:
|
||||
from app.services.docling_service import create_docling_service
|
||||
|
||||
docling_service = create_docling_service()
|
||||
docling_service = create_docling_service()
|
||||
|
||||
summary_content = await docling_service.process_large_document_summary(
|
||||
content=file_in_markdown, llm=user_llm, document_title=file_name
|
||||
)
|
||||
summary_content = await docling_service.process_large_document_summary(
|
||||
content=file_in_markdown, llm=user_llm, document_title=file_name
|
||||
)
|
||||
|
||||
# Enhance summary with metadata
|
||||
document_metadata = {
|
||||
"file_name": file_name,
|
||||
"etl_service": "DOCLING",
|
||||
"document_type": "File Document",
|
||||
}
|
||||
metadata_parts = []
|
||||
metadata_parts.append("# DOCUMENT METADATA")
|
||||
document_metadata = {
|
||||
"file_name": file_name,
|
||||
"etl_service": "DOCLING",
|
||||
"document_type": "File Document",
|
||||
}
|
||||
metadata_parts = ["# DOCUMENT METADATA"]
|
||||
for key, value in document_metadata.items():
|
||||
if value:
|
||||
formatted_key = key.replace("_", " ").title()
|
||||
metadata_parts.append(f"**{formatted_key}:** {value}")
|
||||
|
||||
for key, value in document_metadata.items():
|
||||
if value: # Only include non-empty values
|
||||
formatted_key = key.replace("_", " ").title()
|
||||
metadata_parts.append(f"**{formatted_key}:** {value}")
|
||||
|
||||
metadata_section = "\n".join(metadata_parts)
|
||||
enhanced_summary_content = (
|
||||
f"{metadata_section}\n\n# DOCUMENT SUMMARY\n\n{summary_content}"
|
||||
)
|
||||
metadata_section = "\n".join(metadata_parts)
|
||||
enhanced_summary_content = (
|
||||
f"{metadata_section}\n\n# DOCUMENT SUMMARY\n\n{summary_content}"
|
||||
)
|
||||
else:
|
||||
enhanced_summary_content = f"File: {file_name}\n\n{file_in_markdown[:4000]}"
|
||||
|
||||
summary_embedding = embed_text(enhanced_summary_content)
|
||||
|
||||
|
|
@ -822,6 +867,12 @@ async def add_received_file_document_using_docling(
|
|||
return document
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
if "ix_documents_content_hash" in str(db_error):
|
||||
logging.warning(
|
||||
"content_hash collision during commit for %s (Docling). Skipping.",
|
||||
file_name,
|
||||
)
|
||||
return None
|
||||
raise db_error
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
|
|
@ -1219,9 +1270,17 @@ async def process_file_in_background(
|
|||
print("Error deleting temp file", e)
|
||||
pass
|
||||
|
||||
# Pass the documents to the existing background task
|
||||
enable_summary = (
|
||||
connector.get("enable_summary", True) if connector else True
|
||||
)
|
||||
result = await add_received_file_document_using_unstructured(
|
||||
session, filename, docs, search_space_id, user_id, connector
|
||||
session,
|
||||
filename,
|
||||
docs,
|
||||
search_space_id,
|
||||
user_id,
|
||||
connector,
|
||||
enable_summary=enable_summary,
|
||||
)
|
||||
|
||||
if connector:
|
||||
|
|
@ -1362,7 +1421,9 @@ async def process_file_in_background(
|
|||
# Extract text content from the markdown documents
|
||||
markdown_content = doc.text
|
||||
|
||||
# Process the documents using our LlamaCloud background task
|
||||
enable_summary = (
|
||||
connector.get("enable_summary", True) if connector else True
|
||||
)
|
||||
doc_result = await add_received_file_document_using_llamacloud(
|
||||
session,
|
||||
filename,
|
||||
|
|
@ -1370,6 +1431,7 @@ async def process_file_in_background(
|
|||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
connector=connector,
|
||||
enable_summary=enable_summary,
|
||||
)
|
||||
|
||||
# Track if this document was successfully created
|
||||
|
|
@ -1516,7 +1578,9 @@ async def process_file_in_background(
|
|||
session, notification, stage="chunking"
|
||||
)
|
||||
|
||||
# Process the document using our Docling background task
|
||||
enable_summary = (
|
||||
connector.get("enable_summary", True) if connector else True
|
||||
)
|
||||
doc_result = await add_received_file_document_using_docling(
|
||||
session,
|
||||
filename,
|
||||
|
|
@ -1524,6 +1588,7 @@ async def process_file_in_background(
|
|||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
connector=connector,
|
||||
enable_summary=enable_summary,
|
||||
)
|
||||
|
||||
if doc_result:
|
||||
|
|
|
|||
|
|
@ -158,6 +158,28 @@ async def _handle_existing_document_update(
|
|||
logging.info(f"Document for markdown file {filename} unchanged. Skipping.")
|
||||
return True, existing_document
|
||||
else:
|
||||
# Content has changed — guard against content_hash collision (same
|
||||
# content already lives in a different document).
|
||||
collision_doc = await check_duplicate_document(session, content_hash)
|
||||
if collision_doc and collision_doc.id != existing_document.id:
|
||||
logging.warning(
|
||||
"Content-hash collision for markdown %s: identical content "
|
||||
"exists in document #%s (%s). Skipping re-processing.",
|
||||
filename,
|
||||
collision_doc.id,
|
||||
collision_doc.document_type,
|
||||
)
|
||||
if DocumentStatus.is_state(
|
||||
existing_document.status, DocumentStatus.PENDING
|
||||
) or DocumentStatus.is_state(
|
||||
existing_document.status, DocumentStatus.PROCESSING
|
||||
):
|
||||
await session.delete(existing_document)
|
||||
await session.commit()
|
||||
return True, None
|
||||
|
||||
return True, existing_document
|
||||
|
||||
logging.info(
|
||||
f"Content changed for markdown file {filename}. Updating document."
|
||||
)
|
||||
|
|
@ -312,6 +334,12 @@ async def add_received_markdown_file_document(
|
|||
return document
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
if "ix_documents_content_hash" in str(db_error):
|
||||
logging.warning(
|
||||
"content_hash collision during commit for %s (markdown). Skipping.",
|
||||
file_name,
|
||||
)
|
||||
return None
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Database error processing markdown file: {file_name}",
|
||||
|
|
|
|||
|
|
@ -159,6 +159,44 @@ async def check_duplicate_connector(
|
|||
return (result.scalar() or 0) > 0
|
||||
|
||||
|
||||
async def ensure_unique_connector_name(
|
||||
session: AsyncSession,
|
||||
name: str,
|
||||
search_space_id: int,
|
||||
user_id: UUID,
|
||||
) -> str:
|
||||
"""
|
||||
Ensure a connector name is unique within a user's search space.
|
||||
|
||||
If the name already exists, appends a counter suffix: (2), (3), etc.
|
||||
Uses the same suffix format as generate_unique_connector_name.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
name: Desired connector name
|
||||
search_space_id: The search space ID
|
||||
user_id: The user ID
|
||||
|
||||
Returns:
|
||||
Unique name, either the original or with a counter suffix
|
||||
"""
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector.name).where(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
)
|
||||
)
|
||||
existing_names = {row[0] for row in result.all()}
|
||||
|
||||
if name not in existing_names:
|
||||
return name
|
||||
|
||||
counter = 2
|
||||
while f"{name} ({counter})" in existing_names:
|
||||
counter += 1
|
||||
return f"{name} ({counter})"
|
||||
|
||||
|
||||
async def generate_unique_connector_name(
|
||||
session: AsyncSession,
|
||||
connector_type: SearchSourceConnectorType,
|
||||
|
|
|
|||
41
surfsense_backend/app/utils/google_credentials.py
Normal file
41
surfsense_backend/app/utils/google_credentials.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
"""Shared Google OAuth credential utilities for native and Composio connectors."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from app.db import SearchSourceConnectorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COMPOSIO_GOOGLE_CONNECTOR_TYPES = {
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
}
|
||||
|
||||
|
||||
def build_composio_credentials(connected_account_id: str) -> Credentials:
|
||||
"""
|
||||
Build Google OAuth Credentials backed by Composio's token management.
|
||||
|
||||
The returned Credentials object uses a refresh_handler that fetches
|
||||
fresh access tokens from Composio on demand, so it works seamlessly
|
||||
with googleapiclient.discovery.build().
|
||||
"""
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
service = ComposioService()
|
||||
access_token = service.get_access_token(connected_account_id)
|
||||
|
||||
def composio_refresh_handler(request, scopes):
|
||||
fresh_token = service.get_access_token(connected_account_id)
|
||||
expiry = datetime.now(UTC).replace(tzinfo=None) + timedelta(minutes=55)
|
||||
return fresh_token, expiry
|
||||
|
||||
return Credentials(
|
||||
token=access_token,
|
||||
expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(minutes=55),
|
||||
refresh_handler=composio_refresh_handler,
|
||||
)
|
||||
|
|
@ -72,6 +72,7 @@ dependencies = [
|
|||
"daytona>=0.146.0",
|
||||
"langchain-daytona>=0.0.2",
|
||||
"pypandoc>=1.16.2",
|
||||
"notion-markdown>=0.7.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,328 @@
|
|||
"""Shared fixtures for Google unification integration tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.db import (
|
||||
Chunk,
|
||||
Document,
|
||||
DocumentType,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
SearchSpace,
|
||||
User,
|
||||
)
|
||||
|
||||
EMBEDDING_DIM = app_config.embedding_model_instance.dimension
|
||||
DUMMY_EMBEDDING = [0.1] * EMBEDDING_DIM
|
||||
|
||||
|
||||
def make_document(
|
||||
*,
|
||||
title: str,
|
||||
document_type: DocumentType,
|
||||
content: str,
|
||||
search_space_id: int,
|
||||
created_by_id: str,
|
||||
) -> Document:
|
||||
"""Build a Document instance with unique hashes and a dummy embedding."""
|
||||
uid = uuid.uuid4().hex[:12]
|
||||
return Document(
|
||||
title=title,
|
||||
document_type=document_type,
|
||||
content=content,
|
||||
content_hash=f"content-{uid}",
|
||||
unique_identifier_hash=f"uid-{uid}",
|
||||
source_markdown=content,
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=created_by_id,
|
||||
embedding=DUMMY_EMBEDDING,
|
||||
updated_at=datetime.now(UTC),
|
||||
status={"state": "ready"},
|
||||
)
|
||||
|
||||
|
||||
def make_chunk(*, content: str, document_id: int) -> Chunk:
|
||||
return Chunk(
|
||||
content=content,
|
||||
document_id=document_id,
|
||||
embedding=DUMMY_EMBEDDING,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Savepoint-based fixture (used by retriever tests that receive db_session)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def seed_google_docs(
|
||||
db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
|
||||
):
|
||||
"""Insert a native Drive doc, a legacy Composio Drive doc, and a FILE doc.
|
||||
|
||||
Returns a dict with keys ``native_doc``, ``legacy_doc``, ``file_doc``,
|
||||
plus ``search_space`` and ``user``.
|
||||
"""
|
||||
user_id = str(db_user.id)
|
||||
space_id = db_search_space.id
|
||||
|
||||
native_doc = make_document(
|
||||
title="Native Drive Document",
|
||||
document_type=DocumentType.GOOGLE_DRIVE_FILE,
|
||||
content="quarterly report from native google drive connector",
|
||||
search_space_id=space_id,
|
||||
created_by_id=user_id,
|
||||
)
|
||||
legacy_doc = make_document(
|
||||
title="Legacy Composio Drive Document",
|
||||
document_type=DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
content="quarterly report from composio google drive connector",
|
||||
search_space_id=space_id,
|
||||
created_by_id=user_id,
|
||||
)
|
||||
file_doc = make_document(
|
||||
title="Uploaded PDF",
|
||||
document_type=DocumentType.FILE,
|
||||
content="unrelated uploaded file about quarterly reports",
|
||||
search_space_id=space_id,
|
||||
created_by_id=user_id,
|
||||
)
|
||||
|
||||
db_session.add_all([native_doc, legacy_doc, file_doc])
|
||||
await db_session.flush()
|
||||
|
||||
native_chunk = make_chunk(
|
||||
content="quarterly report from native google drive connector",
|
||||
document_id=native_doc.id,
|
||||
)
|
||||
legacy_chunk = make_chunk(
|
||||
content="quarterly report from composio google drive connector",
|
||||
document_id=legacy_doc.id,
|
||||
)
|
||||
file_chunk = make_chunk(
|
||||
content="unrelated uploaded file about quarterly reports",
|
||||
document_id=file_doc.id,
|
||||
)
|
||||
|
||||
db_session.add_all([native_chunk, legacy_chunk, file_chunk])
|
||||
await db_session.flush()
|
||||
|
||||
return {
|
||||
"native_doc": native_doc,
|
||||
"legacy_doc": legacy_doc,
|
||||
"file_doc": file_doc,
|
||||
"search_space": db_search_space,
|
||||
"user": db_user,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Committed-data fixture (used by service / browse tests that create their
|
||||
# own sessions internally and therefore cannot see savepoint-scoped data)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def committed_google_data(async_engine):
|
||||
"""Insert native, legacy, and FILE docs via a committed transaction.
|
||||
|
||||
Yields ``{"search_space_id": int, "user_id": str}``.
|
||||
Cleans up by deleting the search space (cascades to documents / chunks).
|
||||
"""
|
||||
space_id = None
|
||||
|
||||
async with async_engine.begin() as conn:
|
||||
session = AsyncSession(bind=conn, expire_on_commit=False)
|
||||
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email=f"google-test-{uuid.uuid4().hex[:6]}@surfsense.net",
|
||||
hashed_password="hashed",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
space = SearchSpace(name=f"Google Test {uuid.uuid4().hex[:6]}", user_id=user.id)
|
||||
session.add(space)
|
||||
await session.flush()
|
||||
space_id = space.id
|
||||
user_id = str(user.id)
|
||||
|
||||
native_doc = make_document(
|
||||
title="Native Drive Doc",
|
||||
document_type=DocumentType.GOOGLE_DRIVE_FILE,
|
||||
content="quarterly budget from native google drive",
|
||||
search_space_id=space_id,
|
||||
created_by_id=user_id,
|
||||
)
|
||||
legacy_doc = make_document(
|
||||
title="Legacy Composio Drive Doc",
|
||||
document_type=DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
content="quarterly budget from composio google drive",
|
||||
search_space_id=space_id,
|
||||
created_by_id=user_id,
|
||||
)
|
||||
file_doc = make_document(
|
||||
title="Plain File",
|
||||
document_type=DocumentType.FILE,
|
||||
content="quarterly budget uploaded as file",
|
||||
search_space_id=space_id,
|
||||
created_by_id=user_id,
|
||||
)
|
||||
session.add_all([native_doc, legacy_doc, file_doc])
|
||||
await session.flush()
|
||||
|
||||
for doc in [native_doc, legacy_doc, file_doc]:
|
||||
session.add(
|
||||
Chunk(
|
||||
content=doc.content,
|
||||
document_id=doc.id,
|
||||
embedding=DUMMY_EMBEDDING,
|
||||
)
|
||||
)
|
||||
await session.flush()
|
||||
|
||||
yield {"search_space_id": space_id, "user_id": user_id}
|
||||
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.execute(
|
||||
text("DELETE FROM searchspaces WHERE id = :sid"), {"sid": space_id}
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Monkeypatch fixtures for system boundaries
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_session_factory(async_engine, monkeypatch):
|
||||
"""Replace ``async_session_maker`` in connector_service with one bound to the test engine."""
|
||||
test_maker = async_sessionmaker(async_engine, expire_on_commit=False)
|
||||
monkeypatch.setattr(
|
||||
"app.services.connector_service.async_session_maker", test_maker
|
||||
)
|
||||
return test_maker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_embed(monkeypatch):
|
||||
"""Mock the embedding model (system boundary) to return a fixed vector."""
|
||||
mock = MagicMock(return_value=DUMMY_EMBEDDING)
|
||||
monkeypatch.setattr("app.config.config.embedding_model_instance.embed", mock)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_shielded_session(async_engine, monkeypatch):
|
||||
"""Replace ``shielded_async_session`` in the knowledge_base module
|
||||
with one that yields sessions from the test engine."""
|
||||
test_maker = async_sessionmaker(async_engine, expire_on_commit=False)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _test_shielded():
|
||||
async with test_maker() as session:
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.tools.knowledge_base.shielded_async_session",
|
||||
_test_shielded,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Indexer test helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_session_factory(async_engine):
|
||||
"""Create a session factory bound to the test engine."""
|
||||
return async_sessionmaker(async_engine, expire_on_commit=False)
|
||||
|
||||
|
||||
def mock_task_logger():
|
||||
"""Return a fully-mocked TaskLoggingService with async methods."""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
mock = AsyncMock()
|
||||
mock.log_task_start = AsyncMock(return_value=MagicMock())
|
||||
mock.log_task_progress = AsyncMock()
|
||||
mock.log_task_failure = AsyncMock()
|
||||
mock.log_task_success = AsyncMock()
|
||||
return mock
|
||||
|
||||
|
||||
async def seed_connector(
|
||||
async_engine,
|
||||
*,
|
||||
connector_type: SearchSourceConnectorType,
|
||||
config: dict,
|
||||
name_prefix: str = "test",
|
||||
):
|
||||
"""Seed a connector with committed data. Returns dict and cleanup function.
|
||||
|
||||
Yields ``{"connector_id", "search_space_id", "user_id"}``.
|
||||
"""
|
||||
space_id = None
|
||||
|
||||
async with async_engine.begin() as conn:
|
||||
session = AsyncSession(bind=conn, expire_on_commit=False)
|
||||
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email=f"{name_prefix}-{uuid.uuid4().hex[:6]}@surfsense.net",
|
||||
hashed_password="hashed",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
space = SearchSpace(
|
||||
name=f"{name_prefix} {uuid.uuid4().hex[:6]}", user_id=user.id
|
||||
)
|
||||
session.add(space)
|
||||
await session.flush()
|
||||
space_id = space.id
|
||||
|
||||
connector = SearchSourceConnector(
|
||||
name=f"{name_prefix} connector",
|
||||
connector_type=connector_type,
|
||||
is_indexable=True,
|
||||
config=config,
|
||||
search_space_id=space_id,
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(connector)
|
||||
await session.flush()
|
||||
connector_id = connector.id
|
||||
user_id = str(user.id)
|
||||
|
||||
return {
|
||||
"connector_id": connector_id,
|
||||
"search_space_id": space_id,
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
|
||||
async def cleanup_space(async_engine, space_id: int):
|
||||
"""Delete a search space (cascades to connectors/documents)."""
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.execute(
|
||||
text("DELETE FROM searchspaces WHERE id = :sid"), {"sid": space_id}
|
||||
)
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
"""Integration test: _browse_recent_documents returns docs of multiple types.
|
||||
|
||||
Exercises the browse path (degenerate-query fallback) with a real PostgreSQL
|
||||
database. Verifies that passing a list of document types correctly returns
|
||||
documents of all listed types -- the same ``.in_()`` SQL path used by hybrid
|
||||
search but through the browse/recency-ordered code path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
async def test_browse_recent_documents_with_list_type_returns_both(
|
||||
committed_google_data, patched_shielded_session
|
||||
):
|
||||
"""_browse_recent_documents returns docs of all types when given a list."""
|
||||
from app.agents.new_chat.tools.knowledge_base import _browse_recent_documents
|
||||
|
||||
space_id = committed_google_data["search_space_id"]
|
||||
|
||||
results = await _browse_recent_documents(
|
||||
search_space_id=space_id,
|
||||
document_type=["GOOGLE_DRIVE_FILE", "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"],
|
||||
top_k=10,
|
||||
start_date=None,
|
||||
end_date=None,
|
||||
)
|
||||
|
||||
returned_types = set()
|
||||
for doc in results:
|
||||
doc_info = doc.get("document", {})
|
||||
dtype = doc_info.get("document_type")
|
||||
if dtype:
|
||||
returned_types.add(dtype)
|
||||
|
||||
assert "GOOGLE_DRIVE_FILE" in returned_types, (
|
||||
"Native Drive docs should appear in browse results"
|
||||
)
|
||||
assert "COMPOSIO_GOOGLE_DRIVE_CONNECTOR" in returned_types, (
|
||||
"Legacy Composio Drive docs should appear in browse results"
|
||||
)
|
||||
|
|
@ -0,0 +1,177 @@
|
|||
"""Integration tests: Calendar indexer credential resolution for Composio vs native connectors.
|
||||
|
||||
Exercises ``index_google_calendar_events`` with a real PostgreSQL database
|
||||
containing seeded connector records. Google API and Composio SDK are
|
||||
mocked at their system boundaries.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.db import SearchSourceConnectorType
|
||||
|
||||
from .conftest import (
|
||||
cleanup_space,
|
||||
make_session_factory,
|
||||
mock_task_logger,
|
||||
seed_connector,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
_COMPOSIO_ACCOUNT_ID = "composio-calendar-test-789"
|
||||
_INDEXER_MODULE = "app.tasks.connector_indexers.google_calendar_indexer"
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def composio_calendar(async_engine):
|
||||
data = await seed_connector(
|
||||
async_engine,
|
||||
connector_type=SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
config={"composio_connected_account_id": _COMPOSIO_ACCOUNT_ID},
|
||||
name_prefix="cal-composio",
|
||||
)
|
||||
yield data
|
||||
await cleanup_space(async_engine, data["search_space_id"])
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def composio_calendar_no_id(async_engine):
|
||||
data = await seed_connector(
|
||||
async_engine,
|
||||
connector_type=SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
config={},
|
||||
name_prefix="cal-noid",
|
||||
)
|
||||
yield data
|
||||
await cleanup_space(async_engine, data["search_space_id"])
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def native_calendar(async_engine):
|
||||
data = await seed_connector(
|
||||
async_engine,
|
||||
connector_type=SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
config={
|
||||
"token": "fake",
|
||||
"refresh_token": "fake",
|
||||
"client_id": "fake",
|
||||
"client_secret": "fake",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
},
|
||||
name_prefix="cal-native",
|
||||
)
|
||||
yield data
|
||||
await cleanup_space(async_engine, data["search_space_id"])
|
||||
|
||||
|
||||
@patch(f"{_INDEXER_MODULE}.TaskLoggingService")
|
||||
@patch(f"{_INDEXER_MODULE}.GoogleCalendarConnector")
|
||||
@patch(f"{_INDEXER_MODULE}.build_composio_credentials")
|
||||
async def test_composio_calendar_uses_composio_credentials(
|
||||
mock_build_creds,
|
||||
mock_cal_cls,
|
||||
mock_tl_cls,
|
||||
async_engine,
|
||||
composio_calendar,
|
||||
):
|
||||
"""Calendar indexer calls build_composio_credentials for a Composio connector."""
|
||||
from app.tasks.connector_indexers.google_calendar_indexer import (
|
||||
index_google_calendar_events,
|
||||
)
|
||||
|
||||
data = composio_calendar
|
||||
mock_creds = MagicMock(name="composio-creds")
|
||||
mock_build_creds.return_value = mock_creds
|
||||
mock_tl_cls.return_value = mock_task_logger()
|
||||
|
||||
mock_cal_instance = MagicMock()
|
||||
mock_cal_instance.get_all_primary_calendar_events = AsyncMock(
|
||||
return_value=([], None)
|
||||
)
|
||||
mock_cal_cls.return_value = mock_cal_instance
|
||||
|
||||
maker = make_session_factory(async_engine)
|
||||
async with maker() as session:
|
||||
await index_google_calendar_events(
|
||||
session=session,
|
||||
connector_id=data["connector_id"],
|
||||
search_space_id=data["search_space_id"],
|
||||
user_id=data["user_id"],
|
||||
)
|
||||
|
||||
mock_build_creds.assert_called_once_with(_COMPOSIO_ACCOUNT_ID)
|
||||
mock_cal_cls.assert_called_once()
|
||||
_, kwargs = mock_cal_cls.call_args
|
||||
assert kwargs.get("credentials") is mock_creds
|
||||
|
||||
|
||||
@patch(f"{_INDEXER_MODULE}.TaskLoggingService")
|
||||
@patch(f"{_INDEXER_MODULE}.build_composio_credentials")
|
||||
async def test_composio_calendar_without_account_id_returns_error(
|
||||
mock_build_creds,
|
||||
mock_tl_cls,
|
||||
async_engine,
|
||||
composio_calendar_no_id,
|
||||
):
|
||||
"""Calendar indexer returns error when Composio connector lacks connected_account_id."""
|
||||
from app.tasks.connector_indexers.google_calendar_indexer import (
|
||||
index_google_calendar_events,
|
||||
)
|
||||
|
||||
data = composio_calendar_no_id
|
||||
mock_tl_cls.return_value = mock_task_logger()
|
||||
|
||||
maker = make_session_factory(async_engine)
|
||||
async with maker() as session:
|
||||
count, _skipped, error = await index_google_calendar_events(
|
||||
session=session,
|
||||
connector_id=data["connector_id"],
|
||||
search_space_id=data["search_space_id"],
|
||||
user_id=data["user_id"],
|
||||
)
|
||||
|
||||
assert count == 0
|
||||
assert error is not None
|
||||
assert "composio" in error.lower()
|
||||
mock_build_creds.assert_not_called()
|
||||
|
||||
|
||||
@patch(f"{_INDEXER_MODULE}.TaskLoggingService")
|
||||
@patch(f"{_INDEXER_MODULE}.GoogleCalendarConnector")
|
||||
@patch(f"{_INDEXER_MODULE}.build_composio_credentials")
|
||||
async def test_native_calendar_does_not_use_composio_credentials(
|
||||
mock_build_creds,
|
||||
mock_cal_cls,
|
||||
mock_tl_cls,
|
||||
async_engine,
|
||||
native_calendar,
|
||||
):
|
||||
"""Calendar indexer does NOT call build_composio_credentials for a native connector."""
|
||||
from app.tasks.connector_indexers.google_calendar_indexer import (
|
||||
index_google_calendar_events,
|
||||
)
|
||||
|
||||
data = native_calendar
|
||||
mock_tl_cls.return_value = mock_task_logger()
|
||||
|
||||
mock_cal_instance = MagicMock()
|
||||
mock_cal_instance.get_all_primary_calendar_events = AsyncMock(
|
||||
return_value=([], None)
|
||||
)
|
||||
mock_cal_cls.return_value = mock_cal_instance
|
||||
|
||||
maker = make_session_factory(async_engine)
|
||||
async with maker() as session:
|
||||
await index_google_calendar_events(
|
||||
session=session,
|
||||
connector_id=data["connector_id"],
|
||||
search_space_id=data["search_space_id"],
|
||||
user_id=data["user_id"],
|
||||
)
|
||||
|
||||
mock_build_creds.assert_not_called()
|
||||
|
|
@ -0,0 +1,174 @@
|
|||
"""Integration tests: Drive indexer credential resolution for Composio vs native connectors.
|
||||
|
||||
Exercises ``index_google_drive_files`` with a real PostgreSQL database
|
||||
containing seeded connector records. Google API and Composio SDK are
|
||||
mocked at their system boundaries.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.db import SearchSourceConnectorType
|
||||
|
||||
from .conftest import (
|
||||
cleanup_space,
|
||||
make_session_factory,
|
||||
mock_task_logger,
|
||||
seed_connector,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
_COMPOSIO_ACCOUNT_ID = "composio-test-account-123"
|
||||
_INDEXER_MODULE = "app.tasks.connector_indexers.google_drive_indexer"
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def committed_drive_connector(async_engine):
|
||||
data = await seed_connector(
|
||||
async_engine,
|
||||
connector_type=SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
config={"composio_connected_account_id": _COMPOSIO_ACCOUNT_ID},
|
||||
name_prefix="drive-composio",
|
||||
)
|
||||
yield data
|
||||
await cleanup_space(async_engine, data["search_space_id"])
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def committed_native_drive_connector(async_engine):
|
||||
data = await seed_connector(
|
||||
async_engine,
|
||||
connector_type=SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
config={
|
||||
"token": "fake-token",
|
||||
"refresh_token": "fake-refresh",
|
||||
"client_id": "fake-client-id",
|
||||
"client_secret": "fake-secret",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
},
|
||||
name_prefix="drive-native",
|
||||
)
|
||||
yield data
|
||||
await cleanup_space(async_engine, data["search_space_id"])
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def committed_composio_no_account_id(async_engine):
|
||||
data = await seed_connector(
|
||||
async_engine,
|
||||
connector_type=SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
config={},
|
||||
name_prefix="drive-noid",
|
||||
)
|
||||
yield data
|
||||
await cleanup_space(async_engine, data["search_space_id"])
|
||||
|
||||
|
||||
@patch(f"{_INDEXER_MODULE}.TaskLoggingService")
|
||||
@patch(f"{_INDEXER_MODULE}.GoogleDriveClient")
|
||||
@patch(f"{_INDEXER_MODULE}.build_composio_credentials")
|
||||
async def test_composio_connector_uses_composio_credentials(
|
||||
mock_build_creds,
|
||||
mock_client_cls,
|
||||
mock_task_logger_cls,
|
||||
async_engine,
|
||||
committed_drive_connector,
|
||||
):
|
||||
"""Drive indexer calls build_composio_credentials for a Composio connector
|
||||
and passes the result to GoogleDriveClient."""
|
||||
from app.tasks.connector_indexers.google_drive_indexer import (
|
||||
index_google_drive_files,
|
||||
)
|
||||
|
||||
data = committed_drive_connector
|
||||
mock_creds = MagicMock(name="composio-credentials")
|
||||
mock_build_creds.return_value = mock_creds
|
||||
mock_task_logger_cls.return_value = mock_task_logger()
|
||||
|
||||
maker = make_session_factory(async_engine)
|
||||
async with maker() as session:
|
||||
await index_google_drive_files(
|
||||
session=session,
|
||||
connector_id=data["connector_id"],
|
||||
search_space_id=data["search_space_id"],
|
||||
user_id=data["user_id"],
|
||||
folder_id="test-folder-id",
|
||||
)
|
||||
|
||||
mock_build_creds.assert_called_once_with(_COMPOSIO_ACCOUNT_ID)
|
||||
mock_client_cls.assert_called_once()
|
||||
_, kwargs = mock_client_cls.call_args
|
||||
assert kwargs.get("credentials") is mock_creds
|
||||
|
||||
|
||||
@patch(f"{_INDEXER_MODULE}.TaskLoggingService")
|
||||
@patch(f"{_INDEXER_MODULE}.build_composio_credentials")
|
||||
async def test_composio_connector_without_account_id_returns_error(
|
||||
mock_build_creds,
|
||||
mock_task_logger_cls,
|
||||
async_engine,
|
||||
committed_composio_no_account_id,
|
||||
):
|
||||
"""Drive indexer returns an error when Composio connector lacks connected_account_id."""
|
||||
from app.tasks.connector_indexers.google_drive_indexer import (
|
||||
index_google_drive_files,
|
||||
)
|
||||
|
||||
data = committed_composio_no_account_id
|
||||
mock_task_logger_cls.return_value = mock_task_logger()
|
||||
|
||||
maker = make_session_factory(async_engine)
|
||||
async with maker() as session:
|
||||
count, _skipped, error = await index_google_drive_files(
|
||||
session=session,
|
||||
connector_id=data["connector_id"],
|
||||
search_space_id=data["search_space_id"],
|
||||
user_id=data["user_id"],
|
||||
folder_id="test-folder-id",
|
||||
)
|
||||
|
||||
assert count == 0
|
||||
assert error is not None
|
||||
assert (
|
||||
"composio_connected_account_id" in error.lower() or "composio" in error.lower()
|
||||
)
|
||||
mock_build_creds.assert_not_called()
|
||||
|
||||
|
||||
@patch(f"{_INDEXER_MODULE}.TaskLoggingService")
|
||||
@patch(f"{_INDEXER_MODULE}.GoogleDriveClient")
|
||||
@patch(f"{_INDEXER_MODULE}.build_composio_credentials")
|
||||
async def test_native_connector_does_not_use_composio_credentials(
|
||||
mock_build_creds,
|
||||
mock_client_cls,
|
||||
mock_task_logger_cls,
|
||||
async_engine,
|
||||
committed_native_drive_connector,
|
||||
):
|
||||
"""Drive indexer does NOT call build_composio_credentials for a native connector."""
|
||||
from app.tasks.connector_indexers.google_drive_indexer import (
|
||||
index_google_drive_files,
|
||||
)
|
||||
|
||||
data = committed_native_drive_connector
|
||||
mock_task_logger_cls.return_value = mock_task_logger()
|
||||
|
||||
maker = make_session_factory(async_engine)
|
||||
async with maker() as session:
|
||||
await index_google_drive_files(
|
||||
session=session,
|
||||
connector_id=data["connector_id"],
|
||||
search_space_id=data["search_space_id"],
|
||||
user_id=data["user_id"],
|
||||
folder_id="test-folder-id",
|
||||
)
|
||||
|
||||
mock_build_creds.assert_not_called()
|
||||
mock_client_cls.assert_called_once()
|
||||
_, kwargs = mock_client_cls.call_args
|
||||
assert kwargs.get("credentials") is None
|
||||
|
|
@ -0,0 +1,173 @@
|
|||
"""Integration tests: Gmail indexer credential resolution for Composio vs native connectors.
|
||||
|
||||
Exercises ``index_google_gmail_messages`` with a real PostgreSQL database
|
||||
containing seeded connector records. Google API and Composio SDK are
|
||||
mocked at their system boundaries.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.db import SearchSourceConnectorType
|
||||
|
||||
from .conftest import (
|
||||
cleanup_space,
|
||||
make_session_factory,
|
||||
mock_task_logger,
|
||||
seed_connector,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
_COMPOSIO_ACCOUNT_ID = "composio-gmail-test-456"
|
||||
_INDEXER_MODULE = "app.tasks.connector_indexers.google_gmail_indexer"
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def composio_gmail(async_engine):
|
||||
data = await seed_connector(
|
||||
async_engine,
|
||||
connector_type=SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
config={"composio_connected_account_id": _COMPOSIO_ACCOUNT_ID},
|
||||
name_prefix="gmail-composio",
|
||||
)
|
||||
yield data
|
||||
await cleanup_space(async_engine, data["search_space_id"])
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def composio_gmail_no_id(async_engine):
|
||||
data = await seed_connector(
|
||||
async_engine,
|
||||
connector_type=SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
config={},
|
||||
name_prefix="gmail-noid",
|
||||
)
|
||||
yield data
|
||||
await cleanup_space(async_engine, data["search_space_id"])
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def native_gmail(async_engine):
|
||||
data = await seed_connector(
|
||||
async_engine,
|
||||
connector_type=SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
config={
|
||||
"token": "fake",
|
||||
"refresh_token": "fake",
|
||||
"client_id": "fake",
|
||||
"client_secret": "fake",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
},
|
||||
name_prefix="gmail-native",
|
||||
)
|
||||
yield data
|
||||
await cleanup_space(async_engine, data["search_space_id"])
|
||||
|
||||
|
||||
@patch(f"{_INDEXER_MODULE}.TaskLoggingService")
|
||||
@patch(f"{_INDEXER_MODULE}.GoogleGmailConnector")
|
||||
@patch(f"{_INDEXER_MODULE}.build_composio_credentials")
|
||||
async def test_composio_gmail_uses_composio_credentials(
|
||||
mock_build_creds,
|
||||
mock_gmail_cls,
|
||||
mock_tl_cls,
|
||||
async_engine,
|
||||
composio_gmail,
|
||||
):
|
||||
"""Gmail indexer calls build_composio_credentials for a Composio connector."""
|
||||
from app.tasks.connector_indexers.google_gmail_indexer import (
|
||||
index_google_gmail_messages,
|
||||
)
|
||||
|
||||
data = composio_gmail
|
||||
mock_creds = MagicMock(name="composio-creds")
|
||||
mock_build_creds.return_value = mock_creds
|
||||
mock_tl_cls.return_value = mock_task_logger()
|
||||
|
||||
mock_gmail_instance = MagicMock()
|
||||
mock_gmail_instance.get_recent_messages = AsyncMock(return_value=([], None))
|
||||
mock_gmail_cls.return_value = mock_gmail_instance
|
||||
|
||||
maker = make_session_factory(async_engine)
|
||||
async with maker() as session:
|
||||
await index_google_gmail_messages(
|
||||
session=session,
|
||||
connector_id=data["connector_id"],
|
||||
search_space_id=data["search_space_id"],
|
||||
user_id=data["user_id"],
|
||||
)
|
||||
|
||||
mock_build_creds.assert_called_once_with(_COMPOSIO_ACCOUNT_ID)
|
||||
mock_gmail_cls.assert_called_once()
|
||||
args, _ = mock_gmail_cls.call_args
|
||||
assert args[0] is mock_creds
|
||||
|
||||
|
||||
@patch(f"{_INDEXER_MODULE}.TaskLoggingService")
|
||||
@patch(f"{_INDEXER_MODULE}.build_composio_credentials")
|
||||
async def test_composio_gmail_without_account_id_returns_error(
|
||||
mock_build_creds,
|
||||
mock_tl_cls,
|
||||
async_engine,
|
||||
composio_gmail_no_id,
|
||||
):
|
||||
"""Gmail indexer returns error when Composio connector lacks connected_account_id."""
|
||||
from app.tasks.connector_indexers.google_gmail_indexer import (
|
||||
index_google_gmail_messages,
|
||||
)
|
||||
|
||||
data = composio_gmail_no_id
|
||||
mock_tl_cls.return_value = mock_task_logger()
|
||||
|
||||
maker = make_session_factory(async_engine)
|
||||
async with maker() as session:
|
||||
count, _skipped, error = await index_google_gmail_messages(
|
||||
session=session,
|
||||
connector_id=data["connector_id"],
|
||||
search_space_id=data["search_space_id"],
|
||||
user_id=data["user_id"],
|
||||
)
|
||||
|
||||
assert count == 0
|
||||
assert error is not None
|
||||
assert "composio" in error.lower()
|
||||
mock_build_creds.assert_not_called()
|
||||
|
||||
|
||||
@patch(f"{_INDEXER_MODULE}.TaskLoggingService")
|
||||
@patch(f"{_INDEXER_MODULE}.GoogleGmailConnector")
|
||||
@patch(f"{_INDEXER_MODULE}.build_composio_credentials")
|
||||
async def test_native_gmail_does_not_use_composio_credentials(
|
||||
mock_build_creds,
|
||||
mock_gmail_cls,
|
||||
mock_tl_cls,
|
||||
async_engine,
|
||||
native_gmail,
|
||||
):
|
||||
"""Gmail indexer does NOT call build_composio_credentials for a native connector."""
|
||||
from app.tasks.connector_indexers.google_gmail_indexer import (
|
||||
index_google_gmail_messages,
|
||||
)
|
||||
|
||||
data = native_gmail
|
||||
mock_tl_cls.return_value = mock_task_logger()
|
||||
|
||||
mock_gmail_instance = MagicMock()
|
||||
mock_gmail_instance.get_recent_messages = AsyncMock(return_value=([], None))
|
||||
mock_gmail_cls.return_value = mock_gmail_instance
|
||||
|
||||
maker = make_session_factory(async_engine)
|
||||
async with maker() as session:
|
||||
await index_google_gmail_messages(
|
||||
session=session,
|
||||
connector_id=data["connector_id"],
|
||||
search_space_id=data["search_space_id"],
|
||||
user_id=data["user_id"],
|
||||
)
|
||||
|
||||
mock_build_creds.assert_not_called()
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
"""Integration tests: hybrid search correctly filters by document type lists.
|
||||
|
||||
These tests exercise the public ``hybrid_search`` method on
|
||||
``ChucksHybridSearchRetriever`` with a real PostgreSQL database.
|
||||
They verify that the ``.in_()`` SQL path works for list-of-types filtering,
|
||||
which is the foundation of the Google unification changes.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
|
||||
from .conftest import DUMMY_EMBEDDING
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
async def test_list_of_types_returns_both_matching_doc_types(
|
||||
db_session, seed_google_docs
|
||||
):
|
||||
"""Searching with a list of document types returns documents of ALL listed types."""
|
||||
space_id = seed_google_docs["search_space"].id
|
||||
|
||||
retriever = ChucksHybridSearchRetriever(db_session)
|
||||
results = await retriever.hybrid_search(
|
||||
query_text="quarterly report",
|
||||
top_k=10,
|
||||
search_space_id=space_id,
|
||||
document_type=["GOOGLE_DRIVE_FILE", "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"],
|
||||
query_embedding=DUMMY_EMBEDDING,
|
||||
)
|
||||
|
||||
returned_types = {
|
||||
r["document"]["document_type"] for r in results if r.get("document")
|
||||
}
|
||||
assert "GOOGLE_DRIVE_FILE" in returned_types
|
||||
assert "COMPOSIO_GOOGLE_DRIVE_CONNECTOR" in returned_types
|
||||
assert "FILE" not in returned_types
|
||||
|
||||
|
||||
async def test_single_string_type_returns_only_that_type(db_session, seed_google_docs):
|
||||
"""Searching with a single string type returns only documents of that exact type."""
|
||||
space_id = seed_google_docs["search_space"].id
|
||||
|
||||
retriever = ChucksHybridSearchRetriever(db_session)
|
||||
results = await retriever.hybrid_search(
|
||||
query_text="quarterly report",
|
||||
top_k=10,
|
||||
search_space_id=space_id,
|
||||
document_type="GOOGLE_DRIVE_FILE",
|
||||
query_embedding=DUMMY_EMBEDDING,
|
||||
)
|
||||
|
||||
returned_types = {
|
||||
r["document"]["document_type"] for r in results if r.get("document")
|
||||
}
|
||||
assert returned_types == {"GOOGLE_DRIVE_FILE"}
|
||||
|
||||
|
||||
async def test_all_invalid_types_returns_empty(db_session, seed_google_docs):
|
||||
"""Searching with a list of nonexistent types returns an empty list, no exceptions."""
|
||||
space_id = seed_google_docs["search_space"].id
|
||||
|
||||
retriever = ChucksHybridSearchRetriever(db_session)
|
||||
results = await retriever.hybrid_search(
|
||||
query_text="quarterly report",
|
||||
top_k=10,
|
||||
search_space_id=space_id,
|
||||
document_type=["NONEXISTENT_TYPE"],
|
||||
query_embedding=DUMMY_EMBEDDING,
|
||||
)
|
||||
|
||||
assert results == []
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
"""Integration tests: ConnectorService search transparently includes legacy Composio docs.
|
||||
|
||||
These tests exercise ``ConnectorService.search_google_drive`` and
|
||||
``ConnectorService.search_files`` through a real PostgreSQL database.
|
||||
They verify that the legacy-type alias expansion works end-to-end:
|
||||
searching for native Google Drive docs also returns old Composio-typed docs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.connector_service import ConnectorService
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
async def test_search_google_drive_includes_legacy_composio_docs(
|
||||
async_engine, committed_google_data, patched_session_factory, patched_embed
|
||||
):
|
||||
"""search_google_drive returns both GOOGLE_DRIVE_FILE and COMPOSIO_GOOGLE_DRIVE_CONNECTOR docs."""
|
||||
space_id = committed_google_data["search_space_id"]
|
||||
|
||||
async with patched_session_factory() as session:
|
||||
service = ConnectorService(session, search_space_id=space_id)
|
||||
_, raw_docs = await service.search_google_drive(
|
||||
user_query="quarterly budget",
|
||||
search_space_id=space_id,
|
||||
top_k=10,
|
||||
)
|
||||
|
||||
returned_types = set()
|
||||
for doc in raw_docs:
|
||||
doc_info = doc.get("document", {})
|
||||
dtype = doc_info.get("document_type")
|
||||
if dtype:
|
||||
returned_types.add(dtype)
|
||||
|
||||
assert "GOOGLE_DRIVE_FILE" in returned_types, (
|
||||
"Native Drive docs should appear in search_google_drive results"
|
||||
)
|
||||
assert "COMPOSIO_GOOGLE_DRIVE_CONNECTOR" in returned_types, (
|
||||
"Legacy Composio Drive docs should appear in search_google_drive results"
|
||||
)
|
||||
assert "FILE" not in returned_types, (
|
||||
"Plain FILE docs should NOT appear in search_google_drive results"
|
||||
)
|
||||
|
||||
|
||||
async def test_search_files_does_not_include_google_types(
|
||||
async_engine, committed_google_data, patched_session_factory, patched_embed
|
||||
):
|
||||
"""search_files returns only FILE docs, not Google Drive docs."""
|
||||
space_id = committed_google_data["search_space_id"]
|
||||
|
||||
async with patched_session_factory() as session:
|
||||
service = ConnectorService(session, search_space_id=space_id)
|
||||
_, raw_docs = await service.search_files(
|
||||
user_query="quarterly budget",
|
||||
search_space_id=space_id,
|
||||
top_k=10,
|
||||
)
|
||||
|
||||
returned_types = set()
|
||||
for doc in raw_docs:
|
||||
doc_info = doc.get("document", {})
|
||||
dtype = doc_info.get("document_type")
|
||||
if dtype:
|
||||
returned_types.add(dtype)
|
||||
|
||||
if returned_types:
|
||||
assert "GOOGLE_DRIVE_FILE" not in returned_types
|
||||
assert "COMPOSIO_GOOGLE_DRIVE_CONNECTOR" not in returned_types
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
"""Unit tests: build_composio_credentials returns valid Google Credentials.
|
||||
|
||||
Mocks the Composio SDK (external system boundary) and verifies that the
|
||||
returned ``google.oauth2.credentials.Credentials`` object is correctly
|
||||
configured with a token and a working refresh handler.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@patch("app.services.composio_service.ComposioService")
|
||||
def test_returns_credentials_with_token_and_expiry(mock_composio_service):
|
||||
"""build_composio_credentials returns a Credentials object with the Composio access token."""
|
||||
mock_service = MagicMock()
|
||||
mock_service.get_access_token.return_value = "fake-access-token"
|
||||
mock_composio_service.return_value = mock_service
|
||||
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
creds = build_composio_credentials("test-account-id")
|
||||
|
||||
assert isinstance(creds, Credentials)
|
||||
assert creds.token == "fake-access-token"
|
||||
assert creds.expiry is not None
|
||||
assert creds.expiry > datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
|
||||
@patch("app.services.composio_service.ComposioService")
|
||||
def test_refresh_handler_fetches_fresh_token(mock_composio_service):
|
||||
"""The refresh_handler on the returned Credentials fetches a new token from Composio."""
|
||||
mock_service = MagicMock()
|
||||
mock_service.get_access_token.side_effect = [
|
||||
"initial-token",
|
||||
"refreshed-token",
|
||||
]
|
||||
mock_composio_service.return_value = mock_service
|
||||
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
creds = build_composio_credentials("test-account-id")
|
||||
assert creds.token == "initial-token"
|
||||
|
||||
refresh_handler = creds._refresh_handler
|
||||
assert callable(refresh_handler)
|
||||
|
||||
new_token, new_expiry = refresh_handler(request=None, scopes=None)
|
||||
|
||||
assert new_token == "refreshed-token"
|
||||
assert new_expiry > datetime.now(UTC).replace(tzinfo=None)
|
||||
assert mock_service.get_access_token.call_count == 2
|
||||
|
|
@ -0,0 +1,242 @@
|
|||
"""Unit tests: Gmail, Calendar, and Drive connectors accept Composio-sourced credentials.
|
||||
|
||||
These tests exercise the REAL connector code with Composio-style credentials
|
||||
(token + expiry + refresh_handler, but NO refresh_token / client_id / client_secret).
|
||||
Only the Google API boundary (``googleapiclient.discovery.build``) is mocked.
|
||||
|
||||
This verifies Phase 2b: the relaxed validation in ``_get_credentials()`` correctly
|
||||
allows Composio credentials through without raising ValueError or persisting to DB.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _utcnow_naive() -> datetime:
|
||||
"""Return current UTC time as a naive datetime (matches google-auth convention)."""
|
||||
return datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
|
||||
def _composio_credentials(*, expired: bool = False) -> Credentials:
|
||||
"""Create a Credentials object that mimics build_composio_credentials output."""
|
||||
if expired:
|
||||
expiry = _utcnow_naive() - timedelta(hours=1)
|
||||
else:
|
||||
expiry = _utcnow_naive() + timedelta(hours=1)
|
||||
|
||||
def refresh_handler(request, scopes):
|
||||
return "refreshed-token", _utcnow_naive() + timedelta(hours=1)
|
||||
|
||||
return Credentials(
|
||||
token="composio-access-token",
|
||||
expiry=expiry,
|
||||
refresh_handler=refresh_handler,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gmail
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch("app.connectors.google_gmail_connector.build")
|
||||
async def test_gmail_accepts_valid_composio_credentials(mock_build):
|
||||
"""GoogleGmailConnector.get_user_profile succeeds with Composio credentials
|
||||
that have no client_id, client_secret, or refresh_token."""
|
||||
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||
|
||||
creds = _composio_credentials(expired=False)
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service.users.return_value.getProfile.return_value.execute.return_value = {
|
||||
"emailAddress": "test@example.com",
|
||||
"messagesTotal": 42,
|
||||
"threadsTotal": 10,
|
||||
"historyId": "12345",
|
||||
}
|
||||
mock_build.return_value = mock_service
|
||||
|
||||
connector = GoogleGmailConnector(
|
||||
creds,
|
||||
session=MagicMock(),
|
||||
user_id="test-user",
|
||||
)
|
||||
|
||||
profile, error = await connector.get_user_profile()
|
||||
|
||||
assert error is None
|
||||
assert profile["email_address"] == "test@example.com"
|
||||
mock_build.assert_called_once_with("gmail", "v1", credentials=creds)
|
||||
|
||||
|
||||
@patch("app.connectors.google_gmail_connector.Request")
|
||||
@patch("app.connectors.google_gmail_connector.build")
|
||||
async def test_gmail_refreshes_expired_composio_credentials(
|
||||
mock_build, mock_request_cls
|
||||
):
|
||||
"""GoogleGmailConnector handles expired Composio credentials via refresh_handler
|
||||
without attempting DB persistence."""
|
||||
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||
|
||||
creds = _composio_credentials(expired=True)
|
||||
assert creds.expired
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service.users.return_value.getProfile.return_value.execute.return_value = {
|
||||
"emailAddress": "test@example.com",
|
||||
"messagesTotal": 42,
|
||||
"threadsTotal": 10,
|
||||
"historyId": "12345",
|
||||
}
|
||||
mock_build.return_value = mock_service
|
||||
|
||||
mock_session = AsyncMock()
|
||||
connector = GoogleGmailConnector(
|
||||
creds,
|
||||
session=mock_session,
|
||||
user_id="test-user",
|
||||
)
|
||||
|
||||
profile, error = await connector.get_user_profile()
|
||||
|
||||
assert error is None
|
||||
assert profile["email_address"] == "test@example.com"
|
||||
assert creds.token == "refreshed-token"
|
||||
assert not creds.expired
|
||||
mock_session.execute.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Calendar
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch("app.connectors.google_calendar_connector.build")
|
||||
async def test_calendar_accepts_valid_composio_credentials(mock_build):
|
||||
"""GoogleCalendarConnector.get_calendars succeeds with Composio credentials
|
||||
that have no client_id, client_secret, or refresh_token."""
|
||||
from app.connectors.google_calendar_connector import GoogleCalendarConnector
|
||||
|
||||
creds = _composio_credentials(expired=False)
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service.calendarList.return_value.list.return_value.execute.return_value = {
|
||||
"items": [{"id": "primary", "summary": "My Calendar", "primary": True}],
|
||||
}
|
||||
mock_build.return_value = mock_service
|
||||
|
||||
connector = GoogleCalendarConnector(
|
||||
creds,
|
||||
session=MagicMock(),
|
||||
user_id="test-user",
|
||||
)
|
||||
|
||||
calendars, error = await connector.get_calendars()
|
||||
|
||||
assert error is None
|
||||
assert len(calendars) == 1
|
||||
assert calendars[0]["summary"] == "My Calendar"
|
||||
mock_build.assert_called_once_with("calendar", "v3", credentials=creds)
|
||||
|
||||
|
||||
@patch("app.connectors.google_calendar_connector.Request")
|
||||
@patch("app.connectors.google_calendar_connector.build")
|
||||
async def test_calendar_refreshes_expired_composio_credentials(
|
||||
mock_build, mock_request_cls
|
||||
):
|
||||
"""GoogleCalendarConnector handles expired Composio credentials via refresh_handler
|
||||
without attempting DB persistence."""
|
||||
from app.connectors.google_calendar_connector import GoogleCalendarConnector
|
||||
|
||||
creds = _composio_credentials(expired=True)
|
||||
assert creds.expired
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service.calendarList.return_value.list.return_value.execute.return_value = {
|
||||
"items": [{"id": "primary", "summary": "My Calendar", "primary": True}],
|
||||
}
|
||||
mock_build.return_value = mock_service
|
||||
|
||||
mock_session = AsyncMock()
|
||||
connector = GoogleCalendarConnector(
|
||||
creds,
|
||||
session=mock_session,
|
||||
user_id="test-user",
|
||||
)
|
||||
|
||||
calendars, error = await connector.get_calendars()
|
||||
|
||||
assert error is None
|
||||
assert len(calendars) == 1
|
||||
assert creds.token == "refreshed-token"
|
||||
assert not creds.expired
|
||||
mock_session.execute.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Drive
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch("app.connectors.google_drive.client.build")
|
||||
async def test_drive_client_uses_prebuilt_composio_credentials(mock_build):
|
||||
"""GoogleDriveClient with pre-built Composio credentials uses them directly,
|
||||
bypassing DB credential loading via get_valid_credentials."""
|
||||
from app.connectors.google_drive.client import GoogleDriveClient
|
||||
|
||||
creds = _composio_credentials(expired=False)
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service.files.return_value.list.return_value.execute.return_value = {
|
||||
"files": [],
|
||||
"nextPageToken": None,
|
||||
}
|
||||
mock_build.return_value = mock_service
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=MagicMock(),
|
||||
connector_id=999,
|
||||
credentials=creds,
|
||||
)
|
||||
|
||||
files, _next_token, error = await client.list_files()
|
||||
|
||||
assert error is None
|
||||
assert files == []
|
||||
mock_build.assert_called_once_with("drive", "v3", credentials=creds)
|
||||
|
||||
|
||||
@patch("app.connectors.google_drive.client.get_valid_credentials")
|
||||
@patch("app.connectors.google_drive.client.build")
|
||||
async def test_drive_client_prebuilt_creds_skip_db_loading(mock_build, mock_get_valid):
|
||||
"""GoogleDriveClient does NOT call get_valid_credentials when pre-built
|
||||
credentials are provided."""
|
||||
from app.connectors.google_drive.client import GoogleDriveClient
|
||||
|
||||
creds = _composio_credentials(expired=False)
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service.files.return_value.list.return_value.execute.return_value = {
|
||||
"files": [],
|
||||
"nextPageToken": None,
|
||||
}
|
||||
mock_build.return_value = mock_service
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=MagicMock(),
|
||||
connector_id=999,
|
||||
credentials=creds,
|
||||
)
|
||||
|
||||
await client.list_files()
|
||||
|
||||
mock_get_valid.assert_not_called()
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue