diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index 9bf38cad6..8dc2fa98f 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -472,7 +472,7 @@ async def create_surfsense_deep_agent( SubAgentMiddleware(backend=StateBackend, subagents=[general_purpose_spec]), create_summarization_middleware(llm, StateBackend), PatchToolCallsMiddleware(), - DedupHITLToolCallsMiddleware(), + DedupHITLToolCallsMiddleware(agent_tools=tools), AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"), ] diff --git a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py b/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py index f5e8f1235..bc6f7fd9e 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py +++ b/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py @@ -20,19 +20,39 @@ from langgraph.runtime import Runtime logger = logging.getLogger(__name__) -_HITL_TOOL_DEDUP_KEYS: dict[str, str] = { - "delete_calendar_event": "event_title_or_id", - "update_calendar_event": "event_title_or_id", - "trash_gmail_email": "email_subject_or_id", +_NATIVE_HITL_TOOL_DEDUP_KEYS: dict[str, str] = { + # Gmail + "send_gmail_email": "subject", + "create_gmail_draft": "subject", "update_gmail_draft": "draft_subject_or_id", + "trash_gmail_email": "email_subject_or_id", + # Google Calendar + "create_calendar_event": "title", + "update_calendar_event": "event_title_or_id", + "delete_calendar_event": "event_title_or_id", + # Google Drive + "create_google_drive_file": "file_name", "delete_google_drive_file": "file_name", + # OneDrive + "create_onedrive_file": "file_name", "delete_onedrive_file": "file_name", - "delete_notion_page": "page_title", + # Dropbox + "create_dropbox_file": "file_name", + "delete_dropbox_file": "file_name", + # Notion + "create_notion_page": "title", "update_notion_page": "page_title", - "delete_linear_issue": "issue_ref", + "delete_notion_page": "page_title", + # Linear + "create_linear_issue": "title", "update_linear_issue": "issue_ref", + "delete_linear_issue": "issue_ref", + # Jira + "create_jira_issue": "summary", "update_jira_issue": "issue_title_or_key", "delete_jira_issue": "issue_title_or_key", + # Confluence + "create_confluence_page": "title", "update_confluence_page": "page_title_or_id", "delete_confluence_page": "page_title_or_id", } @@ -43,22 +63,38 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg] Only the **first** occurrence of each (tool-name, primary-arg-value) pair is kept; subsequent duplicates are silently dropped. + + The dedup map is built from two sources: + + 1. A comprehensive list of native HITL tools (hardcoded above). + 2. Any ``StructuredTool`` instances passed via *agent_tools* whose + ``metadata`` contains ``{"hitl": True, "hitl_dedup_key": "..."}``. + This is how MCP tools automatically get dedup support. """ tools = () + def __init__(self, *, agent_tools: list[Any] | None = None) -> None: + self._dedup_keys: dict[str, str] = dict(_NATIVE_HITL_TOOL_DEDUP_KEYS) + for t in agent_tools or []: + meta = getattr(t, "metadata", None) or {} + if meta.get("hitl") and meta.get("hitl_dedup_key"): + self._dedup_keys[t.name] = meta["hitl_dedup_key"] + def after_model( self, state: AgentState, runtime: Runtime[Any] ) -> dict[str, Any] | None: - return self._dedup(state) + return self._dedup(state, self._dedup_keys) async def aafter_model( self, state: AgentState, runtime: Runtime[Any] ) -> dict[str, Any] | None: - return self._dedup(state) + return self._dedup(state, self._dedup_keys) @staticmethod - def _dedup(state: AgentState) -> dict[str, Any] | None: # type: ignore[type-arg] + def _dedup( + state: AgentState, dedup_keys: dict[str, str] # type: ignore[type-arg] + ) -> dict[str, Any] | None: messages = state.get("messages") if not messages: return None @@ -73,7 +109,7 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg] for tc in tool_calls: name = tc.get("name", "") - dedup_key_arg = _HITL_TOOL_DEDUP_KEYS.get(name) + dedup_key_arg = dedup_keys.get(name) if dedup_key_arg is not None: arg_val = str(tc.get("args", {}).get(dedup_key_arg, "")).lower() key = (name, arg_val) diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py index b4d532b76..b76f4d757 100644 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py @@ -2,7 +2,7 @@ import logging from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified @@ -65,54 +65,28 @@ def create_create_confluence_page_tool( "connector_type": "confluence", } - approval = interrupt( - { - "type": "confluence_page_creation", - "action": { - "tool": "create_confluence_page", - "params": { - "title": title, - "content": content, - "space_id": space_id, - "connector_id": connector_id, - }, - }, - "context": context, - } + result = request_approval( + action_type="confluence_page_creation", + tool_name="create_confluence_page", + params={ + "title": title, + "content": content, + "space_id": space_id, + "connector_id": connector_id, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", - "message": "User declined. The page was not created.", + "message": "User declined. Do not retry or suggest alternatives.", } - final_params: dict[str, Any] = {} - edited_action = decision.get("edited_action") - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_title = final_params.get("title", title) - final_content = final_params.get("content", content) or "" - final_space_id = final_params.get("space_id", space_id) - final_connector_id = final_params.get("connector_id", connector_id) + final_title = result.params.get("title", title) + final_content = result.params.get("content", content) or "" + final_space_id = result.params.get("space_id", space_id) + final_connector_id = result.params.get("connector_id", connector_id) if not final_title or not final_title.strip(): return {"status": "error", "message": "Page title cannot be empty."} diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py index ba1dae653..070efaf57 100644 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py @@ -2,7 +2,7 @@ import logging from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified @@ -74,54 +74,28 @@ def create_delete_confluence_page_tool( document_id = page_data["document_id"] connector_id_from_context = context.get("account", {}).get("id") - approval = interrupt( - { - "type": "confluence_page_deletion", - "action": { - "tool": "delete_confluence_page", - "params": { - "page_id": page_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - }, - "context": context, - } + result = request_approval( + action_type="confluence_page_deletion", + tool_name="delete_confluence_page", + params={ + "page_id": page_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", - "message": "User declined. The page was not deleted.", + "message": "User declined. Do not retry or suggest alternatives.", } - final_params: dict[str, Any] = {} - edited_action = decision.get("edited_action") - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_page_id = final_params.get("page_id", page_id) - final_connector_id = final_params.get( + final_page_id = result.params.get("page_id", page_id) + final_connector_id = result.params.get( "connector_id", connector_id_from_context ) - final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb) + final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) from sqlalchemy.future import select diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py index 913896f83..c80df9710 100644 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py @@ -2,7 +2,7 @@ import logging from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified @@ -78,62 +78,36 @@ def create_update_confluence_page_tool( document_id = page_data.get("document_id") connector_id_from_context = context.get("account", {}).get("id") - approval = interrupt( - { - "type": "confluence_page_update", - "action": { - "tool": "update_confluence_page", - "params": { - "page_id": page_id, - "document_id": document_id, - "new_title": new_title, - "new_content": new_content, - "version": current_version, - "connector_id": connector_id_from_context, - }, - }, - "context": context, - } + result = request_approval( + action_type="confluence_page_update", + tool_name="update_confluence_page", + params={ + "page_id": page_id, + "document_id": document_id, + "new_title": new_title, + "new_content": new_content, + "version": current_version, + "connector_id": connector_id_from_context, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", - "message": "User declined. The page was not updated.", + "message": "User declined. Do not retry or suggest alternatives.", } - final_params: dict[str, Any] = {} - edited_action = decision.get("edited_action") - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_page_id = final_params.get("page_id", page_id) - final_title = final_params.get("new_title", new_title) or current_title - final_content = final_params.get("new_content", new_content) + final_page_id = result.params.get("page_id", page_id) + final_title = result.params.get("new_title", new_title) or current_title + final_content = result.params.get("new_content", new_content) if final_content is None: final_content = current_body - final_version = final_params.get("version", current_version) - final_connector_id = final_params.get( + final_version = result.params.get("version", current_version) + final_connector_id = result.params.get( "connector_id", connector_id_from_context ) - final_document_id = final_params.get("document_id", document_id) + final_document_id = result.params.get("document_id", document_id) from sqlalchemy.future import select diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py index ed8034861..6e2578334 100644 --- a/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Any, Literal from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select @@ -159,56 +159,30 @@ def create_create_dropbox_file_tool( "supported_types": _SUPPORTED_TYPES, } - approval = interrupt( - { - "type": "dropbox_file_creation", - "action": { - "tool": "create_dropbox_file", - "params": { - "name": name, - "file_type": file_type, - "content": content, - "connector_id": None, - "parent_folder_path": None, - }, - }, - "context": context, - } + result = request_approval( + action_type="dropbox_file_creation", + tool_name="create_dropbox_file", + params={ + "name": name, + "file_type": file_type, + "content": content, + "connector_id": None, + "parent_folder_path": None, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", - "message": "User declined. The file was not created.", + "message": "User declined. Do not retry or suggest alternatives.", } - final_params: dict[str, Any] = {} - edited_action = decision.get("edited_action") - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_name = final_params.get("name", name) - final_file_type = final_params.get("file_type", file_type) - final_content = final_params.get("content", content) - final_connector_id = final_params.get("connector_id") - final_parent_folder_path = final_params.get("parent_folder_path") + final_name = result.params.get("name", name) + final_file_type = result.params.get("file_type", file_type) + final_content = result.params.get("content", content) + final_connector_id = result.params.get("connector_id") + final_parent_folder_path = result.params.get("parent_folder_path") if not final_name or not final_name.strip(): return {"status": "error", "message": "File name cannot be empty."} diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py index e15dc3092..620b39aa2 100644 --- a/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py @@ -2,7 +2,7 @@ import logging from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy import String, and_, cast, func from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select @@ -174,53 +174,26 @@ def create_delete_dropbox_file_tool( }, } - approval = interrupt( - { - "type": "dropbox_file_trash", - "action": { - "tool": "delete_dropbox_file", - "params": { - "file_path": file_path, - "connector_id": connector.id, - "delete_from_kb": delete_from_kb, - }, - }, - "context": context, - } + result = request_approval( + action_type="dropbox_file_trash", + tool_name="delete_dropbox_file", + params={ + "file_path": file_path, + "connector_id": connector.id, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - logger.info(f"User decision: {decision_type}") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", - "message": "User declined. The file was not deleted. Do not ask again or suggest alternatives.", + "message": "User declined. Do not retry or suggest alternatives.", } - final_params: dict[str, Any] = {} - edited_action = decision.get("edited_action") - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_file_path = final_params.get("file_path", file_path) - final_connector_id = final_params.get("connector_id", connector.id) - final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb) + final_file_path = result.params.get("file_path", file_path) + final_connector_id = result.params.get("connector_id", connector.id) + final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) if final_connector_id != connector.id: result = await db_session.execute( diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py index a812f621a..974f9b4af 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py @@ -6,7 +6,7 @@ from email.mime.text import MIMEText from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from app.services.gmail import GmailToolMetadataService @@ -85,60 +85,32 @@ def create_create_gmail_draft_tool( logger.info( f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'" ) - approval = interrupt( - { - "type": "gmail_draft_creation", - "action": { - "tool": "create_gmail_draft", - "params": { - "to": to, - "subject": subject, - "body": body, - "cc": cc, - "bcc": bcc, - "connector_id": None, - }, - }, - "context": context, - } + result = request_approval( + action_type="gmail_draft_creation", + tool_name="create_gmail_draft", + params={ + "to": to, + "subject": subject, + "body": body, + "cc": cc, + "bcc": bcc, + "connector_id": None, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - logger.warning("No approval decision received") - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - logger.info(f"User decision: {decision_type}") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", "message": "User declined. The draft was not created. Do not ask again or suggest alternatives.", } - final_params: dict[str, Any] = {} - edited_action = decision.get("edited_action") - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_to = final_params.get("to", to) - final_subject = final_params.get("subject", subject) - final_body = final_params.get("body", body) - final_cc = final_params.get("cc", cc) - final_bcc = final_params.get("bcc", bcc) - final_connector_id = final_params.get("connector_id") + final_to = result.params.get("to", to) + final_subject = result.params.get("subject", subject) + final_body = result.params.get("body", body) + final_cc = result.params.get("cc", cc) + final_bcc = result.params.get("bcc", bcc) + final_connector_id = result.params.get("connector_id") from sqlalchemy.future import select diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py index 2599578bd..a1c713f0a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py @@ -6,7 +6,7 @@ from email.mime.text import MIMEText from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from app.services.gmail import GmailToolMetadataService @@ -86,60 +86,32 @@ def create_send_gmail_email_tool( logger.info( f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'" ) - approval = interrupt( - { - "type": "gmail_email_send", - "action": { - "tool": "send_gmail_email", - "params": { - "to": to, - "subject": subject, - "body": body, - "cc": cc, - "bcc": bcc, - "connector_id": None, - }, - }, - "context": context, - } + result = request_approval( + action_type="gmail_email_send", + tool_name="send_gmail_email", + params={ + "to": to, + "subject": subject, + "body": body, + "cc": cc, + "bcc": bcc, + "connector_id": None, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - logger.warning("No approval decision received") - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - logger.info(f"User decision: {decision_type}") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", "message": "User declined. The email was not sent. Do not ask again or suggest alternatives.", } - final_params: dict[str, Any] = {} - edited_action = decision.get("edited_action") - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_to = final_params.get("to", to) - final_subject = final_params.get("subject", subject) - final_body = final_params.get("body", body) - final_cc = final_params.get("cc", cc) - final_bcc = final_params.get("bcc", bcc) - final_connector_id = final_params.get("connector_id") + final_to = result.params.get("to", to) + final_subject = result.params.get("subject", subject) + final_body = result.params.get("body", body) + final_cc = result.params.get("cc", cc) + final_bcc = result.params.get("bcc", bcc) + final_connector_id = result.params.get("connector_id") from sqlalchemy.future import select diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py index 146020845..cab97ee8a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py @@ -4,7 +4,7 @@ from datetime import datetime from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from app.services.gmail import GmailToolMetadataService @@ -101,56 +101,28 @@ def create_trash_gmail_email_tool( logger.info( f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})" ) - approval = interrupt( - { - "type": "gmail_email_trash", - "action": { - "tool": "trash_gmail_email", - "params": { - "message_id": message_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - }, - "context": context, - } + result = request_approval( + action_type="gmail_email_trash", + tool_name="trash_gmail_email", + params={ + "message_id": message_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - logger.warning("No approval decision received") - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - logger.info(f"User decision: {decision_type}") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", "message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.", } - edited_action = decision.get("edited_action") - final_params: dict[str, Any] = {} - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_message_id = final_params.get("message_id", message_id) - final_connector_id = final_params.get( + final_message_id = result.params.get("message_id", message_id) + final_connector_id = result.params.get( "connector_id", connector_id_from_context ) - final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb) + final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) if not final_connector_id: return { diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py index 28deec2b4..1d53ac9ce 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py @@ -6,7 +6,7 @@ from email.mime.text import MIMEText from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from app.services.gmail import GmailToolMetadataService @@ -122,65 +122,37 @@ def create_update_gmail_draft_tool( f"Requesting approval for updating Gmail draft: '{original_subject}' " f"(message_id={message_id}, draft_id={draft_id_from_context})" ) - approval = interrupt( - { - "type": "gmail_draft_update", - "action": { - "tool": "update_gmail_draft", - "params": { - "message_id": message_id, - "draft_id": draft_id_from_context, - "to": final_to_default, - "subject": final_subject_default, - "body": body, - "cc": cc, - "bcc": bcc, - "connector_id": connector_id_from_context, - }, - }, - "context": context, - } + result = request_approval( + action_type="gmail_draft_update", + tool_name="update_gmail_draft", + params={ + "message_id": message_id, + "draft_id": draft_id_from_context, + "to": final_to_default, + "subject": final_subject_default, + "body": body, + "cc": cc, + "bcc": bcc, + "connector_id": connector_id_from_context, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - logger.warning("No approval decision received") - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - logger.info(f"User decision: {decision_type}") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", "message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.", } - final_params: dict[str, Any] = {} - edited_action = decision.get("edited_action") - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_to = final_params.get("to", final_to_default) - final_subject = final_params.get("subject", final_subject_default) - final_body = final_params.get("body", body) - final_cc = final_params.get("cc", cc) - final_bcc = final_params.get("bcc", bcc) - final_connector_id = final_params.get( + final_to = result.params.get("to", final_to_default) + final_subject = result.params.get("subject", final_subject_default) + final_body = result.params.get("body", body) + final_cc = result.params.get("cc", cc) + final_bcc = result.params.get("bcc", bcc) + final_connector_id = result.params.get( "connector_id", connector_id_from_context ) - final_draft_id = final_params.get("draft_id", draft_id_from_context) + final_draft_id = result.params.get("draft_id", draft_id_from_context) if not final_connector_id: return { diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py index 592ced5ec..37bcf083e 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py @@ -6,9 +6,9 @@ from typing import Any from google.oauth2.credentials import Credentials from googleapiclient.discovery import build from langchain_core.tools import tool -from langgraph.types import interrupt from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.services.google_calendar import GoogleCalendarToolMetadataService logger = logging.getLogger(__name__) @@ -90,63 +90,35 @@ def create_create_calendar_event_tool( logger.info( f"Requesting approval for creating calendar event: summary='{summary}'" ) - approval = interrupt( - { - "type": "google_calendar_event_creation", - "action": { - "tool": "create_calendar_event", - "params": { - "summary": summary, - "start_datetime": start_datetime, - "end_datetime": end_datetime, - "description": description, - "location": location, - "attendees": attendees, - "timezone": context.get("timezone"), - "connector_id": None, - }, - }, - "context": context, - } + result = request_approval( + action_type="google_calendar_event_creation", + tool_name="create_calendar_event", + params={ + "summary": summary, + "start_datetime": start_datetime, + "end_datetime": end_datetime, + "description": description, + "location": location, + "attendees": attendees, + "timezone": context.get("timezone"), + "connector_id": None, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - logger.warning("No approval decision received") - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - logger.info(f"User decision: {decision_type}") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", "message": "User declined. The event was not created. Do not ask again or suggest alternatives.", } - final_params: dict[str, Any] = {} - edited_action = decision.get("edited_action") - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_summary = final_params.get("summary", summary) - final_start_datetime = final_params.get("start_datetime", start_datetime) - final_end_datetime = final_params.get("end_datetime", end_datetime) - final_description = final_params.get("description", description) - final_location = final_params.get("location", location) - final_attendees = final_params.get("attendees", attendees) - final_connector_id = final_params.get("connector_id") + final_summary = result.params.get("summary", summary) + final_start_datetime = result.params.get("start_datetime", start_datetime) + final_end_datetime = result.params.get("end_datetime", end_datetime) + final_description = result.params.get("description", description) + final_location = result.params.get("location", location) + final_attendees = result.params.get("attendees", attendees) + final_connector_id = result.params.get("connector_id") if not final_summary or not final_summary.strip(): return {"status": "error", "message": "Event summary cannot be empty."} diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py index 8b088487c..4d9d69b4b 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py @@ -6,9 +6,9 @@ from typing import Any from google.oauth2.credentials import Credentials from googleapiclient.discovery import build from langchain_core.tools import tool -from langgraph.types import interrupt from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.services.google_calendar import GoogleCalendarToolMetadataService logger = logging.getLogger(__name__) @@ -100,56 +100,28 @@ def create_delete_calendar_event_tool( logger.info( f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})" ) - approval = interrupt( - { - "type": "google_calendar_event_deletion", - "action": { - "tool": "delete_calendar_event", - "params": { - "event_id": event_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - }, - "context": context, - } + result = request_approval( + action_type="google_calendar_event_deletion", + tool_name="delete_calendar_event", + params={ + "event_id": event_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - logger.warning("No approval decision received") - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - logger.info(f"User decision: {decision_type}") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", "message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.", } - edited_action = decision.get("edited_action") - final_params: dict[str, Any] = {} - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_event_id = final_params.get("event_id", event_id) - final_connector_id = final_params.get( + final_event_id = result.params.get("event_id", event_id) + final_connector_id = result.params.get( "connector_id", connector_id_from_context ) - final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb) + final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) if not final_connector_id: return { diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py index ed826f1b8..45ff6dfb9 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py @@ -6,9 +6,9 @@ from typing import Any from google.oauth2.credentials import Credentials from googleapiclient.discovery import build from langchain_core.tools import tool -from langgraph.types import interrupt from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.services.google_calendar import GoogleCalendarToolMetadataService logger = logging.getLogger(__name__) @@ -116,71 +116,43 @@ def create_update_calendar_event_tool( logger.info( f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})" ) - approval = interrupt( - { - "type": "google_calendar_event_update", - "action": { - "tool": "update_calendar_event", - "params": { - "event_id": event_id, - "document_id": document_id, - "connector_id": connector_id_from_context, - "new_summary": new_summary, - "new_start_datetime": new_start_datetime, - "new_end_datetime": new_end_datetime, - "new_description": new_description, - "new_location": new_location, - "new_attendees": new_attendees, - }, - }, - "context": context, - } + result = request_approval( + action_type="google_calendar_event_update", + tool_name="update_calendar_event", + params={ + "event_id": event_id, + "document_id": document_id, + "connector_id": connector_id_from_context, + "new_summary": new_summary, + "new_start_datetime": new_start_datetime, + "new_end_datetime": new_end_datetime, + "new_description": new_description, + "new_location": new_location, + "new_attendees": new_attendees, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - logger.warning("No approval decision received") - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - logger.info(f"User decision: {decision_type}") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", "message": "User declined. The event was not updated. Do not ask again or suggest alternatives.", } - edited_action = decision.get("edited_action") - final_params: dict[str, Any] = {} - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_event_id = final_params.get("event_id", event_id) - final_connector_id = final_params.get( + final_event_id = result.params.get("event_id", event_id) + final_connector_id = result.params.get( "connector_id", connector_id_from_context ) - final_new_summary = final_params.get("new_summary", new_summary) - final_new_start_datetime = final_params.get( + final_new_summary = result.params.get("new_summary", new_summary) + final_new_start_datetime = result.params.get( "new_start_datetime", new_start_datetime ) - final_new_end_datetime = final_params.get( + final_new_end_datetime = result.params.get( "new_end_datetime", new_end_datetime ) - final_new_description = final_params.get("new_description", new_description) - final_new_location = final_params.get("new_location", new_location) - final_new_attendees = final_params.get("new_attendees", new_attendees) + final_new_description = result.params.get("new_description", new_description) + final_new_location = result.params.get("new_location", new_location) + final_new_attendees = result.params.get("new_attendees", new_attendees) if not final_connector_id: return { diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py index a4fee0965..f36db8f3f 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py @@ -3,9 +3,9 @@ from typing import Any, Literal from googleapiclient.errors import HttpError from langchain_core.tools import tool -from langgraph.types import interrupt from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.google_drive.client import GoogleDriveClient from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET from app.services.google_drive import GoogleDriveToolMetadataService @@ -99,58 +99,30 @@ def create_create_google_drive_file_tool( logger.info( f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'" ) - approval = interrupt( - { - "type": "google_drive_file_creation", - "action": { - "tool": "create_google_drive_file", - "params": { - "name": name, - "file_type": file_type, - "content": content, - "connector_id": None, - "parent_folder_id": None, - }, - }, - "context": context, - } + result = request_approval( + action_type="google_drive_file_creation", + tool_name="create_google_drive_file", + params={ + "name": name, + "file_type": file_type, + "content": content, + "connector_id": None, + "parent_folder_id": None, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - logger.warning("No approval decision received") - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - logger.info(f"User decision: {decision_type}") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", "message": "User declined. The file was not created. Do not ask again or suggest alternatives.", } - final_params: dict[str, Any] = {} - edited_action = decision.get("edited_action") - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_name = final_params.get("name", name) - final_file_type = final_params.get("file_type", file_type) - final_content = final_params.get("content", content) - final_connector_id = final_params.get("connector_id") - final_parent_folder_id = final_params.get("parent_folder_id") + final_name = result.params.get("name", name) + final_file_type = result.params.get("file_type", file_type) + final_content = result.params.get("content", content) + final_connector_id = result.params.get("connector_id") + final_parent_folder_id = result.params.get("parent_folder_id") if not final_name or not final_name.strip(): return {"status": "error", "message": "File name cannot be empty."} diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py index fdf7f9cd3..832afff0d 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py @@ -3,9 +3,9 @@ from typing import Any from googleapiclient.errors import HttpError from langchain_core.tools import tool -from langgraph.types import interrupt from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.connectors.google_drive.client import GoogleDriveClient from app.services.google_drive import GoogleDriveToolMetadataService @@ -101,56 +101,28 @@ def create_delete_google_drive_file_tool( logger.info( f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})" ) - approval = interrupt( - { - "type": "google_drive_file_trash", - "action": { - "tool": "delete_google_drive_file", - "params": { - "file_id": file_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - }, - "context": context, - } + result = request_approval( + action_type="google_drive_file_trash", + tool_name="delete_google_drive_file", + params={ + "file_id": file_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - logger.warning("No approval decision received") - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - logger.info(f"User decision: {decision_type}") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", "message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.", } - edited_action = decision.get("edited_action") - final_params: dict[str, Any] = {} - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_file_id = final_params.get("file_id", file_id) - final_connector_id = final_params.get( + final_file_id = result.params.get("file_id", file_id) + final_connector_id = result.params.get( "connector_id", connector_id_from_context ) - final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb) + final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) if not final_connector_id: return { diff --git a/surfsense_backend/app/agents/new_chat/tools/hitl.py b/surfsense_backend/app/agents/new_chat/tools/hitl.py new file mode 100644 index 000000000..a1ac90dc7 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/hitl.py @@ -0,0 +1,140 @@ +"""Unified HITL (Human-in-the-Loop) approval utility. + +Provides a single ``request_approval()`` function that encapsulates the +interrupt payload creation, decision parsing, and parameter merging logic +shared by every sensitive tool (native connectors and MCP tools alike). + +Usage inside a tool:: + + from app.agents.new_chat.tools.hitl import request_approval + + result = request_approval( + action_type="gmail_email_send", + tool_name="send_gmail_email", + params={"to": to, "subject": subject, "body": body}, + context=context, + ) + if result.rejected: + return {"status": "rejected", "message": "User declined."} + # result.params contains the final (possibly edited) parameters +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import Any + +from langgraph.types import interrupt + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True, slots=True) +class HITLResult: + """Outcome of a human-in-the-loop approval request.""" + + rejected: bool + decision_type: str + params: dict[str, Any] = field(default_factory=dict) + + +def _parse_decision(approval: Any) -> tuple[str, dict[str, Any]]: + """Extract the first valid decision and its edited parameters. + + Returns: + (decision_type, edited_params) where *decision_type* is one of + ``"approve"``, ``"edit"``, or ``"reject"`` and *edited_params* is + the dict of user-modified arguments (empty when there are none). + + Raises: + ValueError: when no usable decision dict can be found. + """ + decisions_raw = approval.get("decisions", []) if isinstance(approval, dict) else [] + decisions = decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] + decisions = [d for d in decisions if isinstance(d, dict)] + + if not decisions: + raise ValueError("No approval decision received") + + decision = decisions[0] + decision_type: str = decision.get("type") or decision.get("decision_type") or "approve" + + edited_params: dict[str, Any] = {} + edited_action = decision.get("edited_action") + if isinstance(edited_action, dict): + edited_args = edited_action.get("args") + if isinstance(edited_args, dict): + edited_params = edited_args + elif isinstance(decision.get("args"), dict): + edited_params = decision["args"] + + return decision_type, edited_params + + +def request_approval( + *, + action_type: str, + tool_name: str, + params: dict[str, Any], + context: dict[str, Any] | None = None, + trusted_tools: list[str] | None = None, +) -> HITLResult: + """Pause the graph for user approval and return the decision. + + This is a **synchronous** helper (not ``async``) because + ``langgraph.types.interrupt`` is itself synchronous — it raises a + ``GraphInterrupt`` exception that the LangGraph runtime catches. + + Parameters + ---------- + action_type: + A label that the frontend uses to select the correct approval card + (e.g. ``"gmail_email_send"``, ``"mcp_tool_call"``). + tool_name: + The registered LangChain tool name (e.g. ``"send_gmail_email"``). + params: + The original tool arguments. These are shown in the approval card + and used as defaults when the user does not edit anything. + context: + Rich metadata from a ``*ToolMetadataService`` (accounts, folders, + labels, etc.). For MCP tools this can hold the server name and + tool description. + trusted_tools: + An allow-list of tool names the user has previously marked as + "Always Allow". If *tool_name* appears in this list, HITL is + skipped and the tool executes immediately. + + Returns + ------- + HITLResult + ``result.rejected`` is ``True`` when the user chose to deny the + action. Otherwise ``result.params`` contains the final parameter + dict — either the originals or the user-edited version merged on + top. + """ + if trusted_tools and tool_name in trusted_tools: + logger.info("Tool '%s' is user-trusted — skipping HITL", tool_name) + return HITLResult(rejected=False, decision_type="trusted", params=dict(params)) + + approval = interrupt( + { + "type": action_type, + "action": {"tool": tool_name, "params": params}, + "context": context or {}, + } + ) + + try: + decision_type, edited_params = _parse_decision(approval) + except ValueError: + logger.warning("No approval decision received for %s", tool_name) + return HITLResult(rejected=False, decision_type="error", params=params) + + logger.info("User decision for %s: %s", tool_name, decision_type) + + if decision_type == "reject": + return HITLResult(rejected=True, decision_type="reject", params=params) + + final_params = {**params, **edited_params} if edited_params else dict(params) + return HITLResult(rejected=False, decision_type=decision_type, params=final_params) diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py index d441c49f3..0b3332694 100644 --- a/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py @@ -3,7 +3,7 @@ import logging from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified @@ -69,58 +69,32 @@ def create_create_jira_issue_tool( "connector_type": "jira", } - approval = interrupt( - { - "type": "jira_issue_creation", - "action": { - "tool": "create_jira_issue", - "params": { - "project_key": project_key, - "summary": summary, - "issue_type": issue_type, - "description": description, - "priority": priority, - "connector_id": connector_id, - }, - }, - "context": context, - } + result = request_approval( + action_type="jira_issue_creation", + tool_name="create_jira_issue", + params={ + "project_key": project_key, + "summary": summary, + "issue_type": issue_type, + "description": description, + "priority": priority, + "connector_id": connector_id, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", - "message": "User declined. The issue was not created.", + "message": "User declined. Do not retry or suggest alternatives.", } - final_params: dict[str, Any] = {} - edited_action = decision.get("edited_action") - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_project_key = final_params.get("project_key", project_key) - final_summary = final_params.get("summary", summary) - final_issue_type = final_params.get("issue_type", issue_type) - final_description = final_params.get("description", description) - final_priority = final_params.get("priority", priority) - final_connector_id = final_params.get("connector_id", connector_id) + final_project_key = result.params.get("project_key", project_key) + final_summary = result.params.get("summary", summary) + final_issue_type = result.params.get("issue_type", issue_type) + final_description = result.params.get("description", description) + final_priority = result.params.get("priority", priority) + final_connector_id = result.params.get("connector_id", connector_id) if not final_summary or not final_summary.strip(): return {"status": "error", "message": "Issue summary cannot be empty."} diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py index 2f8c370ad..52d4556a5 100644 --- a/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py @@ -3,7 +3,7 @@ import logging from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified @@ -71,54 +71,28 @@ def create_delete_jira_issue_tool( document_id = issue_data["document_id"] connector_id_from_context = context.get("account", {}).get("id") - approval = interrupt( - { - "type": "jira_issue_deletion", - "action": { - "tool": "delete_jira_issue", - "params": { - "issue_key": issue_key, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - }, - "context": context, - } + result = request_approval( + action_type="jira_issue_deletion", + tool_name="delete_jira_issue", + params={ + "issue_key": issue_key, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", - "message": "User declined. The issue was not deleted.", + "message": "User declined. Do not retry or suggest alternatives.", } - final_params: dict[str, Any] = {} - edited_action = decision.get("edited_action") - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_issue_key = final_params.get("issue_key", issue_key) - final_connector_id = final_params.get( + final_issue_key = result.params.get("issue_key", issue_key) + final_connector_id = result.params.get( "connector_id", connector_id_from_context ) - final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb) + final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) from sqlalchemy.future import select diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py index c2b948ae3..9c676fea3 100644 --- a/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py @@ -3,7 +3,7 @@ import logging from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified @@ -75,60 +75,34 @@ def create_update_jira_issue_tool( document_id = issue_data.get("document_id") connector_id_from_context = context.get("account", {}).get("id") - approval = interrupt( - { - "type": "jira_issue_update", - "action": { - "tool": "update_jira_issue", - "params": { - "issue_key": issue_key, - "document_id": document_id, - "new_summary": new_summary, - "new_description": new_description, - "new_priority": new_priority, - "connector_id": connector_id_from_context, - }, - }, - "context": context, - } + result = request_approval( + action_type="jira_issue_update", + tool_name="update_jira_issue", + params={ + "issue_key": issue_key, + "document_id": document_id, + "new_summary": new_summary, + "new_description": new_description, + "new_priority": new_priority, + "connector_id": connector_id_from_context, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", - "message": "User declined. The issue was not updated.", + "message": "User declined. Do not retry or suggest alternatives.", } - final_params: dict[str, Any] = {} - edited_action = decision.get("edited_action") - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_issue_key = final_params.get("issue_key", issue_key) - final_summary = final_params.get("new_summary", new_summary) - final_description = final_params.get("new_description", new_description) - final_priority = final_params.get("new_priority", new_priority) - final_connector_id = final_params.get( + final_issue_key = result.params.get("issue_key", issue_key) + final_summary = result.params.get("new_summary", new_summary) + final_description = result.params.get("new_description", new_description) + final_priority = result.params.get("new_priority", new_priority) + final_connector_id = result.params.get( "connector_id", connector_id_from_context ) - final_document_id = final_params.get("document_id", document_id) + final_document_id = result.params.get("document_id", document_id) from sqlalchemy.future import select diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py index 2b5d37903..d8005bd5c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py @@ -2,7 +2,7 @@ import logging from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.linear_connector import LinearAPIError, LinearConnector @@ -94,65 +94,37 @@ def create_create_linear_issue_tool( } logger.info(f"Requesting approval for creating Linear issue: '{title}'") - approval = interrupt( - { - "type": "linear_issue_creation", - "action": { - "tool": "create_linear_issue", - "params": { - "title": title, - "description": description, - "team_id": None, - "state_id": None, - "assignee_id": None, - "priority": None, - "label_ids": [], - "connector_id": connector_id, - }, - }, - "context": context, - } + result = request_approval( + action_type="linear_issue_creation", + tool_name="create_linear_issue", + params={ + "title": title, + "description": description, + "team_id": None, + "state_id": None, + "assignee_id": None, + "priority": None, + "label_ids": [], + "connector_id": connector_id, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - logger.warning("No approval decision received") - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - logger.info(f"User decision: {decision_type}") - - if decision_type == "reject": + if result.rejected: logger.info("Linear issue creation rejected by user") return { "status": "rejected", - "message": "User declined. The issue was not created. Do not ask again or suggest alternatives.", + "message": "User declined. Do not retry or suggest alternatives.", } - final_params: dict[str, Any] = {} - edited_action = decision.get("edited_action") - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_title = final_params.get("title", title) - final_description = final_params.get("description", description) - final_team_id = final_params.get("team_id") - final_state_id = final_params.get("state_id") - final_assignee_id = final_params.get("assignee_id") - final_priority = final_params.get("priority") - final_label_ids = final_params.get("label_ids") or [] - final_connector_id = final_params.get("connector_id", connector_id) + final_title = result.params.get("title", title) + final_description = result.params.get("description", description) + final_team_id = result.params.get("team_id") + final_state_id = result.params.get("state_id") + final_assignee_id = result.params.get("assignee_id") + final_priority = result.params.get("priority") + final_label_ids = result.params.get("label_ids") or [] + final_connector_id = result.params.get("connector_id", connector_id) if not final_title or not final_title.strip(): logger.error("Title is empty or contains only whitespace") diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py index 9f4a60953..d8bc88d82 100644 --- a/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py @@ -2,7 +2,7 @@ import logging from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.linear_connector import LinearAPIError, LinearConnector @@ -114,57 +114,29 @@ def create_delete_linear_issue_tool( f"Requesting approval for deleting Linear issue: '{issue_ref}' " f"(id={issue_id}, delete_from_kb={delete_from_kb})" ) - approval = interrupt( - { - "type": "linear_issue_deletion", - "action": { - "tool": "delete_linear_issue", - "params": { - "issue_id": issue_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - }, - "context": context, - } + result = request_approval( + action_type="linear_issue_deletion", + tool_name="delete_linear_issue", + params={ + "issue_id": issue_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - logger.warning("No approval decision received") - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - logger.info(f"User decision: {decision_type}") - - if decision_type == "reject": + if result.rejected: logger.info("Linear issue deletion rejected by user") return { "status": "rejected", - "message": "User declined. The issue was not deleted. Do not ask again or suggest alternatives.", + "message": "User declined. Do not retry or suggest alternatives.", } - edited_action = decision.get("edited_action") - final_params: dict[str, Any] = {} - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_issue_id = final_params.get("issue_id", issue_id) - final_connector_id = final_params.get( + final_issue_id = result.params.get("issue_id", issue_id) + final_connector_id = result.params.get( "connector_id", connector_id_from_context ) - final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb) + final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) logger.info( f"Deleting Linear issue with final params: issue_id={final_issue_id}, " diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py index 19af851c1..7f6d952e5 100644 --- a/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py @@ -2,7 +2,7 @@ import logging from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.linear_connector import LinearAPIError, LinearConnector @@ -130,69 +130,41 @@ def create_update_linear_issue_tool( logger.info( f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})" ) - approval = interrupt( - { - "type": "linear_issue_update", - "action": { - "tool": "update_linear_issue", - "params": { - "issue_id": issue_id, - "document_id": document_id, - "new_title": new_title, - "new_description": new_description, - "new_state_id": new_state_id, - "new_assignee_id": new_assignee_id, - "new_priority": new_priority, - "new_label_ids": new_label_ids, - "connector_id": connector_id_from_context, - }, - }, - "context": context, - } + result = request_approval( + action_type="linear_issue_update", + tool_name="update_linear_issue", + params={ + "issue_id": issue_id, + "document_id": document_id, + "new_title": new_title, + "new_description": new_description, + "new_state_id": new_state_id, + "new_assignee_id": new_assignee_id, + "new_priority": new_priority, + "new_label_ids": new_label_ids, + "connector_id": connector_id_from_context, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - logger.warning("No approval decision received") - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - logger.info(f"User decision: {decision_type}") - - if decision_type == "reject": + if result.rejected: logger.info("Linear issue update rejected by user") return { "status": "rejected", - "message": "User declined. The issue was not updated. Do not ask again or suggest alternatives.", + "message": "User declined. Do not retry or suggest alternatives.", } - edited_action = decision.get("edited_action") - final_params: dict[str, Any] = {} - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_issue_id = final_params.get("issue_id", issue_id) - final_document_id = final_params.get("document_id", document_id) - final_new_title = final_params.get("new_title", new_title) - final_new_description = final_params.get("new_description", new_description) - final_new_state_id = final_params.get("new_state_id", new_state_id) - final_new_assignee_id = final_params.get("new_assignee_id", new_assignee_id) - final_new_priority = final_params.get("new_priority", new_priority) - final_new_label_ids: list[str] | None = final_params.get( + final_issue_id = result.params.get("issue_id", issue_id) + final_document_id = result.params.get("document_id", document_id) + final_new_title = result.params.get("new_title", new_title) + final_new_description = result.params.get("new_description", new_description) + final_new_state_id = result.params.get("new_state_id", new_state_id) + final_new_assignee_id = result.params.get("new_assignee_id", new_assignee_id) + final_new_priority = result.params.get("new_priority", new_priority) + final_new_label_ids: list[str] | None = result.params.get( "new_label_ids", new_label_ids ) - final_connector_id = final_params.get( + final_connector_id = result.params.get( "connector_id", connector_id_from_context ) diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 2fb7ffb06..9743d049d 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -7,7 +7,11 @@ Supports both transport types: - stdio: Local process-based MCP servers (command, args, env) - streamable-http/http/sse: Remote HTTP-based MCP servers (url, headers) -This implements real MCP protocol support similar to Cursor's implementation. +All MCP tools are unconditionally gated by HITL (Human-in-the-Loop) approval. +Per the MCP spec: "Clients MUST consider tool annotations to be untrusted unless +they come from trusted servers." Users can bypass HITL for specific tools by +clicking "Always Allow", which adds the tool name to the connector's +``config.trusted_tools`` allow-list. """ import logging @@ -21,6 +25,7 @@ from pydantic import BaseModel, create_model from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.mcp_client import MCPClient from app.db import SearchSourceConnector, SearchSourceConnectorType @@ -49,27 +54,15 @@ def _create_dynamic_input_model_from_schema( tool_name: str, input_schema: dict[str, Any], ) -> type[BaseModel]: - """Create a Pydantic model from MCP tool's JSON schema. - - Args: - tool_name: Name of the tool (used for model class name) - input_schema: JSON schema from MCP server - - Returns: - Pydantic model class for tool input validation - - """ + """Create a Pydantic model from MCP tool's JSON schema.""" properties = input_schema.get("properties", {}) required_fields = input_schema.get("required", []) - # Build Pydantic field definitions field_definitions = {} for param_name, param_schema in properties.items(): param_description = param_schema.get("description", "") is_required = param_name in required_fields - # Use Any type for complex schemas to preserve structure - # This allows the MCP server to do its own validation from typing import Any as AnyType from pydantic import Field @@ -85,7 +78,6 @@ def _create_dynamic_input_model_from_schema( Field(None, description=param_description), ) - # Create dynamic model model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input" return create_model(model_name, **field_definitions) @@ -93,55 +85,70 @@ def _create_dynamic_input_model_from_schema( async def _create_mcp_tool_from_definition_stdio( tool_def: dict[str, Any], mcp_client: MCPClient, + *, + connector_name: str = "", + connector_id: int | None = None, + trusted_tools: list[str] | None = None, ) -> StructuredTool: """Create a LangChain tool from an MCP tool definition (stdio transport). - Args: - tool_def: Tool definition from MCP server with name, description, input_schema - mcp_client: MCP client instance for calling the tool - - Returns: - LangChain StructuredTool instance - + All MCP tools are unconditionally wrapped with HITL approval. + ``request_approval()`` is called OUTSIDE the try/except so that + ``GraphInterrupt`` propagates cleanly to LangGraph. """ tool_name = tool_def.get("name", "unnamed_tool") tool_description = tool_def.get("description", "No description provided") input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) - # Log the actual schema for debugging logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}") - # Create dynamic input model from schema input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema) async def mcp_tool_call(**kwargs) -> str: """Execute the MCP tool call via the client with retry support.""" logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}") + # HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph + hitl_result = request_approval( + action_type="mcp_tool_call", + tool_name=tool_name, + params=kwargs, + context={ + "mcp_server": connector_name, + "tool_description": tool_description, + "mcp_transport": "stdio", + "mcp_connector_id": connector_id, + }, + trusted_tools=trusted_tools, + ) + if hitl_result.rejected: + return "Tool call rejected by user." + call_kwargs = hitl_result.params + try: - # Connect to server and call tool (connect has built-in retry logic) async with mcp_client.connect(): - result = await mcp_client.call_tool(tool_name, kwargs) + result = await mcp_client.call_tool(tool_name, call_kwargs) return str(result) except RuntimeError as e: - # Connection failures after all retries error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}" logger.error(error_msg) return f"Error: {error_msg}" except Exception as e: - # Tool execution or other errors error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}" logger.exception(error_msg) return f"Error: {error_msg}" - # Create StructuredTool with response_format to preserve exact schema tool = StructuredTool( name=tool_name, description=tool_description, coroutine=mcp_tool_call, args_schema=input_model, - # Store the original MCP schema as metadata so we can access it later - metadata={"mcp_input_schema": input_schema, "mcp_transport": "stdio"}, + metadata={ + "mcp_input_schema": input_schema, + "mcp_transport": "stdio", + "hitl": True, + "hitl_dedup_key": next(iter(input_schema.get("required", [])), None), + }, ) logger.info(f"Created MCP tool (stdio): '{tool_name}'") @@ -152,43 +159,54 @@ async def _create_mcp_tool_from_definition_http( tool_def: dict[str, Any], url: str, headers: dict[str, str], + *, + connector_name: str = "", + connector_id: int | None = None, + trusted_tools: list[str] | None = None, ) -> StructuredTool: """Create a LangChain tool from an MCP tool definition (HTTP transport). - Args: - tool_def: Tool definition from MCP server with name, description, input_schema - url: URL of the MCP server - headers: HTTP headers for authentication - - Returns: - LangChain StructuredTool instance - + All MCP tools are unconditionally wrapped with HITL approval. + ``request_approval()`` is called OUTSIDE the try/except so that + ``GraphInterrupt`` propagates cleanly to LangGraph. """ tool_name = tool_def.get("name", "unnamed_tool") tool_description = tool_def.get("description", "No description provided") input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) - # Log the actual schema for debugging logger.info(f"MCP HTTP tool '{tool_name}' input schema: {input_schema}") - # Create dynamic input model from schema input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema) async def mcp_http_tool_call(**kwargs) -> str: """Execute the MCP tool call via HTTP transport.""" logger.info(f"MCP HTTP tool '{tool_name}' called with params: {kwargs}") + # HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph + hitl_result = request_approval( + action_type="mcp_tool_call", + tool_name=tool_name, + params=kwargs, + context={ + "mcp_server": connector_name, + "tool_description": tool_description, + "mcp_transport": "http", + "mcp_connector_id": connector_id, + }, + trusted_tools=trusted_tools, + ) + if hitl_result.rejected: + return "Tool call rejected by user." + call_kwargs = hitl_result.params + try: async with ( streamablehttp_client(url, headers=headers) as (read, write, _), ClientSession(read, write) as session, ): await session.initialize() + response = await session.call_tool(tool_name, arguments=call_kwargs) - # Call the tool - response = await session.call_tool(tool_name, arguments=kwargs) - - # Extract content from response result = [] for content in response.content: if hasattr(content, "text"): @@ -209,7 +227,6 @@ async def _create_mcp_tool_from_definition_http( logger.exception(error_msg) return f"Error: {error_msg}" - # Create StructuredTool tool = StructuredTool( name=tool_name, description=tool_description, @@ -219,6 +236,8 @@ async def _create_mcp_tool_from_definition_http( "mcp_input_schema": input_schema, "mcp_transport": "http", "mcp_url": url, + "hitl": True, + "hitl_dedup_key": next(iter(input_schema.get("required", [])), None), }, ) @@ -230,20 +249,11 @@ async def _load_stdio_mcp_tools( connector_id: int, connector_name: str, server_config: dict[str, Any], + trusted_tools: list[str] | None = None, ) -> list[StructuredTool]: - """Load tools from a stdio-based MCP server. - - Args: - connector_id: Connector ID for logging - connector_name: Connector name for logging - server_config: Server configuration with command, args, env - - Returns: - List of tools from the MCP server - """ + """Load tools from a stdio-based MCP server.""" tools: list[StructuredTool] = [] - # Validate required command field command = server_config.get("command") if not command or not isinstance(command, str): logger.warning( @@ -251,7 +261,6 @@ async def _load_stdio_mcp_tools( ) return tools - # Validate args field (must be list if present) args = server_config.get("args", []) if not isinstance(args, list): logger.warning( @@ -259,7 +268,6 @@ async def _load_stdio_mcp_tools( ) return tools - # Validate env field (must be dict if present) env = server_config.get("env", {}) if not isinstance(env, dict): logger.warning( @@ -267,10 +275,8 @@ async def _load_stdio_mcp_tools( ) return tools - # Create MCP client mcp_client = MCPClient(command, args, env) - # Connect and discover tools async with mcp_client.connect(): tool_definitions = await mcp_client.list_tools() @@ -279,10 +285,15 @@ async def _load_stdio_mcp_tools( f"'{command}' (connector {connector_id})" ) - # Create LangChain tools from definitions for tool_def in tool_definitions: try: - tool = await _create_mcp_tool_from_definition_stdio(tool_def, mcp_client) + tool = await _create_mcp_tool_from_definition_stdio( + tool_def, + mcp_client, + connector_name=connector_name, + connector_id=connector_id, + trusted_tools=trusted_tools, + ) tools.append(tool) except Exception as e: logger.exception( @@ -297,20 +308,11 @@ async def _load_http_mcp_tools( connector_id: int, connector_name: str, server_config: dict[str, Any], + trusted_tools: list[str] | None = None, ) -> list[StructuredTool]: - """Load tools from an HTTP-based MCP server. - - Args: - connector_id: Connector ID for logging - connector_name: Connector name for logging - server_config: Server configuration with url, headers - - Returns: - List of tools from the MCP server - """ + """Load tools from an HTTP-based MCP server.""" tools: list[StructuredTool] = [] - # Validate required url field url = server_config.get("url") if not url or not isinstance(url, str): logger.warning( @@ -318,7 +320,6 @@ async def _load_http_mcp_tools( ) return tools - # Validate headers field (must be dict if present) headers = server_config.get("headers", {}) if not isinstance(headers, dict): logger.warning( @@ -326,7 +327,6 @@ async def _load_http_mcp_tools( ) return tools - # Connect and discover tools via HTTP try: async with ( streamablehttp_client(url, headers=headers) as (read, write, _), @@ -334,7 +334,6 @@ async def _load_http_mcp_tools( ): await session.initialize() - # List available tools response = await session.list_tools() tool_definitions = [] for tool in response.tools: @@ -353,11 +352,15 @@ async def _load_http_mcp_tools( f"'{url}' (connector {connector_id})" ) - # Create LangChain tools from definitions for tool_def in tool_definitions: try: tool = await _create_mcp_tool_from_definition_http( - tool_def, url, headers + tool_def, + url, + headers, + connector_name=connector_name, + connector_id=connector_id, + trusted_tools=trusted_tools, ) tools.append(tool) except Exception as e: @@ -398,14 +401,6 @@ async def load_mcp_tools( Results are cached per search space for up to 5 minutes to avoid re-spawning MCP server processes on every chat message. - - Args: - session: Database session - search_space_id: User's search space ID - - Returns: - List of LangChain StructuredTool instances - """ _evict_expired_mcp_cache() @@ -436,6 +431,7 @@ async def load_mcp_tools( try: config = connector.config or {} server_config = config.get("server_config", {}) + trusted_tools = config.get("trusted_tools", []) if not server_config or not isinstance(server_config, dict): logger.warning( @@ -447,11 +443,17 @@ async def load_mcp_tools( if transport in ("streamable-http", "http", "sse"): connector_tools = await _load_http_mcp_tools( - connector.id, connector.name, server_config + connector.id, + connector.name, + server_config, + trusted_tools=trusted_tools, ) else: connector_tools = await _load_stdio_mcp_tools( - connector.id, connector.name, server_config + connector.id, + connector.name, + server_config, + trusted_tools=trusted_tools, ) tools.extend(connector_tools) diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py index 5bb0c52d1..396f3fe0d 100644 --- a/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py @@ -2,7 +2,7 @@ import logging from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector @@ -99,61 +99,29 @@ def create_create_notion_page_tool( } logger.info(f"Requesting approval for creating Notion page: '{title}'") - approval = interrupt( - { - "type": "notion_page_creation", - "action": { - "tool": "create_notion_page", - "params": { - "title": title, - "content": content, - "parent_page_id": None, - "connector_id": connector_id, - }, - }, - "context": context, - } + result = request_approval( + action_type="notion_page_creation", + tool_name="create_notion_page", + params={ + "title": title, + "content": content, + "parent_page_id": None, + "connector_id": connector_id, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - logger.warning("No approval decision received") - return { - "status": "error", - "message": "No approval decision received", - } - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - logger.info(f"User decision: {decision_type}") - - if decision_type == "reject": + if result.rejected: logger.info("Notion page creation rejected by user") return { "status": "rejected", - "message": "User declined. The page was not created. Do not ask again or suggest alternatives.", + "message": "User declined. Do not retry or suggest alternatives.", } - edited_action = decision.get("edited_action") - final_params: dict[str, Any] = {} - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - # Some interrupt payloads place args directly on the decision. - final_params = decision["args"] - - final_title = final_params.get("title", title) - final_content = final_params.get("content", content) - final_parent_page_id = final_params.get("parent_page_id") - final_connector_id = final_params.get("connector_id", connector_id) + final_title = result.params.get("title", title) + final_content = result.params.get("content", content) + final_parent_page_id = result.params.get("parent_page_id") + final_connector_id = result.params.get("connector_id", connector_id) if not final_title or not final_title.strip(): logger.error("Title is empty or contains only whitespace") diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py index fbb7c5004..92e395624 100644 --- a/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py @@ -2,7 +2,7 @@ import logging from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector @@ -114,63 +114,29 @@ def create_delete_notion_page_tool( f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})" ) - # Request approval before deleting - approval = interrupt( - { - "type": "notion_page_deletion", - "action": { - "tool": "delete_notion_page", - "params": { - "page_id": page_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - }, - "context": context, - } + result = request_approval( + action_type="notion_page_deletion", + tool_name="delete_notion_page", + params={ + "page_id": page_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - logger.warning("No approval decision received") - return { - "status": "error", - "message": "No approval decision received", - } - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - logger.info(f"User decision: {decision_type}") - - if decision_type == "reject": + if result.rejected: logger.info("Notion page deletion rejected by user") return { "status": "rejected", - "message": "User declined. The page was not deleted. Do not ask again or suggest alternatives.", + "message": "User declined. Do not retry or suggest alternatives.", } - # Extract edited action arguments (if user modified the checkbox) - edited_action = decision.get("edited_action") - final_params: dict[str, Any] = {} - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - # Some interrupt payloads place args directly on the decision. - final_params = decision["args"] - - final_page_id = final_params.get("page_id", page_id) - final_connector_id = final_params.get( + final_page_id = result.params.get("page_id", page_id) + final_connector_id = result.params.get( "connector_id", connector_id_from_context ) - final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb) + final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) logger.info( f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}" diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py index 25f2b9918..ee7b8f256 100644 --- a/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py @@ -2,7 +2,7 @@ import logging from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector @@ -127,59 +127,27 @@ def create_update_notion_page_tool( logger.info( f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})" ) - approval = interrupt( - { - "type": "notion_page_update", - "action": { - "tool": "update_notion_page", - "params": { - "page_id": page_id, - "content": content, - "connector_id": connector_id_from_context, - }, - }, - "context": context, - } + result = request_approval( + action_type="notion_page_update", + tool_name="update_notion_page", + params={ + "page_id": page_id, + "content": content, + "connector_id": connector_id_from_context, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - logger.warning("No approval decision received") - return { - "status": "error", - "message": "No approval decision received", - } - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - logger.info(f"User decision: {decision_type}") - - if decision_type == "reject": + if result.rejected: logger.info("Notion page update rejected by user") return { "status": "rejected", - "message": "User declined. The page was not updated. Do not ask again or suggest alternatives.", + "message": "User declined. Do not retry or suggest alternatives.", } - edited_action = decision.get("edited_action") - final_params: dict[str, Any] = {} - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - # Some interrupt payloads place args directly on the decision. - final_params = decision["args"] - - final_page_id = final_params.get("page_id", page_id) - final_content = final_params.get("content", content) - final_connector_id = final_params.get( + final_page_id = result.params.get("page_id", page_id) + final_content = result.params.get("content", content) + final_connector_id = result.params.get( "connector_id", connector_id_from_context ) diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py index 8dffb18dd..5050c7885 100644 --- a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select @@ -145,54 +145,28 @@ def create_create_onedrive_file_tool( "parent_folders": parent_folders, } - approval = interrupt( - { - "type": "onedrive_file_creation", - "action": { - "tool": "create_onedrive_file", - "params": { - "name": name, - "content": content, - "connector_id": None, - "parent_folder_id": None, - }, - }, - "context": context, - } + result = request_approval( + action_type="onedrive_file_creation", + tool_name="create_onedrive_file", + params={ + "name": name, + "content": content, + "connector_id": None, + "parent_folder_id": None, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", - "message": "User declined. The file was not created.", + "message": "User declined. Do not retry or suggest alternatives.", } - final_params: dict[str, Any] = {} - edited_action = decision.get("edited_action") - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_name = final_params.get("name", name) - final_content = final_params.get("content", content) - final_connector_id = final_params.get("connector_id") - final_parent_folder_id = final_params.get("parent_folder_id") + final_name = result.params.get("name", name) + final_content = result.params.get("content", content) + final_connector_id = result.params.get("connector_id") + final_parent_folder_id = result.params.get("parent_folder_id") if not final_name or not final_name.strip(): return {"status": "error", "message": "File name cannot be empty."} diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py index 79d8222fd..6997e1d52 100644 --- a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py @@ -2,7 +2,7 @@ import logging from typing import Any from langchain_core.tools import tool -from langgraph.types import interrupt +from app.agents.new_chat.tools.hitl import request_approval from sqlalchemy import String, and_, cast, func from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select @@ -174,53 +174,26 @@ def create_delete_onedrive_file_tool( }, } - approval = interrupt( - { - "type": "onedrive_file_trash", - "action": { - "tool": "delete_onedrive_file", - "params": { - "file_id": file_id, - "connector_id": connector.id, - "delete_from_kb": delete_from_kb, - }, - }, - "context": context, - } + result = request_approval( + action_type="onedrive_file_trash", + tool_name="delete_onedrive_file", + params={ + "file_id": file_id, + "connector_id": connector.id, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - decisions_raw = ( - approval.get("decisions", []) if isinstance(approval, dict) else [] - ) - decisions = ( - decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] - ) - decisions = [d for d in decisions if isinstance(d, dict)] - if not decisions: - return {"status": "error", "message": "No approval decision received"} - - decision = decisions[0] - decision_type = decision.get("type") or decision.get("decision_type") - logger.info(f"User decision: {decision_type}") - - if decision_type == "reject": + if result.rejected: return { "status": "rejected", - "message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.", + "message": "User declined. Do not retry or suggest alternatives.", } - final_params: dict[str, Any] = {} - edited_action = decision.get("edited_action") - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - final_params = edited_args - elif isinstance(decision.get("args"), dict): - final_params = decision["args"] - - final_file_id = final_params.get("file_id", file_id) - final_connector_id = final_params.get("connector_id", connector.id) - final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb) + final_file_id = result.params.get("file_id", file_id) + final_connector_id = result.params.get("connector_id", connector.id) + final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) if final_connector_id != connector.id: result = await db_session.execute( diff --git a/surfsense_backend/app/routes/editor_routes.py b/surfsense_backend/app/routes/editor_routes.py index 829b2cf69..0f986e416 100644 --- a/surfsense_backend/app/routes/editor_routes.py +++ b/surfsense_backend/app/routes/editor_routes.py @@ -139,9 +139,19 @@ async def get_editor_content( status_code=409, detail="This document is still being processed. Please wait a moment and try again.", ) + if state == "failed": + reason = ( + doc_status.get("reason", "Unknown error") + if isinstance(doc_status, dict) + else "Unknown error" + ) + raise HTTPException( + status_code=422, + detail=f"Processing failed: {reason}. You can delete this document and re-upload it.", + ) raise HTTPException( status_code=400, - detail="This document has no viewable content yet. It may still be syncing. Try again in a few seconds, or re-upload if the issue persists.", + detail="This document has no content. It may not have been processed correctly. Try deleting and re-uploading it.", ) markdown_content = "\n\n".join(chunk_contents) diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index bb20da65d..b87ce28c9 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -636,9 +636,16 @@ async def delete_search_source_connector( ) # Delete the connector record + search_space_id = db_connector.search_space_id + is_mcp = db_connector.connector_type == SearchSourceConnectorType.MCP_CONNECTOR await session.delete(db_connector) await session.commit() + if is_mcp: + from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache + + invalidate_mcp_tools_cache(search_space_id) + logger.info( f"Connector {connector_id} ({connector_name}) deleted successfully. " f"Total documents deleted: {total_deleted}" @@ -3624,3 +3631,114 @@ async def get_drive_picker_token( status_code=500, detail="Failed to retrieve access token. Check server logs for details.", ) from e + + +# ============================================================================= +# MCP Tool Trust (Allow-List) Routes +# ============================================================================= + + +class MCPTrustToolRequest(BaseModel): + tool_name: str + + +@router.post("/connectors/mcp/{connector_id}/trust-tool") +async def trust_mcp_tool( + connector_id: int, + body: MCPTrustToolRequest, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Add a tool to the MCP connector's trusted (always-allow) list. + + Once trusted, the tool executes without HITL approval on subsequent calls. + """ + try: + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.MCP_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + raise HTTPException(status_code=404, detail="MCP connector not found") + + config = dict(connector.config or {}) + trusted: list[str] = list(config.get("trusted_tools", [])) + if body.tool_name not in trusted: + trusted.append(body.tool_name) + config["trusted_tools"] = trusted + connector.config = config + + from sqlalchemy.orm.attributes import flag_modified + + flag_modified(connector, "config") + await session.commit() + + from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache + + invalidate_mcp_tools_cache(connector.search_space_id) + + return {"status": "ok", "trusted_tools": trusted} + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to trust MCP tool: {e!s}", exc_info=True) + await session.rollback() + raise HTTPException( + status_code=500, detail=f"Failed to trust tool: {e!s}" + ) from e + + +@router.post("/connectors/mcp/{connector_id}/untrust-tool") +async def untrust_mcp_tool( + connector_id: int, + body: MCPTrustToolRequest, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Remove a tool from the MCP connector's trusted list. + + The tool will require HITL approval again on subsequent calls. + """ + try: + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.MCP_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + raise HTTPException(status_code=404, detail="MCP connector not found") + + config = dict(connector.config or {}) + trusted: list[str] = list(config.get("trusted_tools", [])) + if body.tool_name in trusted: + trusted.remove(body.tool_name) + config["trusted_tools"] = trusted + connector.config = config + + from sqlalchemy.orm.attributes import flag_modified + + flag_modified(connector, "config") + await session.commit() + + from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache + + invalidate_mcp_tools_cache(connector.search_space_id) + + return {"status": "ok", "trusted_tools": trusted} + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to untrust MCP tool: {e!s}", exc_info=True) + await session.rollback() + raise HTTPException( + status_code=500, detail=f"Failed to untrust tool: {e!s}" + ) from e diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 0b1369340..58eb58f4b 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -798,7 +798,7 @@ export default function NewChatPage() { }); } else { const tcId = `interrupt-${action.name}`; - addToolCall(contentPartsState, TOOLS_WITH_UI, tcId, action.name, action.args); + addToolCall(contentPartsState, TOOLS_WITH_UI, tcId, action.name, action.args, true); updateToolCall(contentPartsState, tcId, { result: { __interrupt__: true, ...interruptData }, }); @@ -1125,7 +1125,7 @@ export default function NewChatPage() { }); } else { const tcId = `interrupt-${action.name}`; - addToolCall(contentPartsState, TOOLS_WITH_UI, tcId, action.name, action.args); + addToolCall(contentPartsState, TOOLS_WITH_UI, tcId, action.name, action.args, true); updateToolCall(contentPartsState, tcId, { result: { __interrupt__: true, diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/ProfileContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/ProfileContent.tsx index c1e81283b..32377194a 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/ProfileContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/ProfileContent.tsx @@ -78,7 +78,7 @@ export function ProfileContent() { ) : (
-
+
diff --git a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx index 2d55f4d20..f5de0be7d 100644 --- a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx +++ b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx @@ -499,10 +499,14 @@ export const InlineMentionEditor = forwardRef 0) { const range = selection.getRangeAt(0); @@ -512,63 +516,41 @@ export const InlineMentionEditor = forwardRef= 0; i--) { - if (textContent[i] === "@") { - atIndex = i; - break; - } - // Stop if we hit a space (@ must be at word boundary) if (textContent[i] === " " || textContent[i] === "\n") { + wordStart = i + 1; break; } } - if (atIndex !== -1) { - const query = textContent.slice(atIndex + 1, cursorPos); - // Only trigger if query doesn't start with space + let triggerChar: "@" | "/" | null = null; + let triggerIndex = -1; + for (let i = wordStart; i < cursorPos; i++) { + if (textContent[i] === "@" || textContent[i] === "/") { + triggerChar = textContent[i] as "@" | "/"; + triggerIndex = i; + break; + } + } + + if (triggerChar === "@" && triggerIndex !== -1) { + const query = textContent.slice(triggerIndex + 1, cursorPos); if (!query.startsWith(" ")) { shouldTriggerMention = true; mentionQuery = query; } - } - } - } - - // Check for / actions (same pattern as @) - let shouldTriggerAction = false; - let actionQuery = ""; - - if (!shouldTriggerMention && selection && selection.rangeCount > 0) { - const range = selection.getRangeAt(0); - const textNode = range.startContainer; - - if (textNode.nodeType === Node.TEXT_NODE) { - const textContent = textNode.textContent || ""; - const cursorPos = range.startOffset; - - let slashIndex = -1; - for (let i = cursorPos - 1; i >= 0; i--) { - if (textContent[i] === "/") { - slashIndex = i; - break; - } - if (textContent[i] === " " || textContent[i] === "\n") { - break; - } - } - - if ( - slashIndex !== -1 && - (slashIndex === 0 || - textContent[slashIndex - 1] === " " || - textContent[slashIndex - 1] === "\n") - ) { - const query = textContent.slice(slashIndex + 1, cursorPos); - if (!query.startsWith(" ")) { - shouldTriggerAction = true; - actionQuery = query; + } else if (triggerChar === "/" && triggerIndex !== -1) { + if ( + triggerIndex === 0 || + textContent[triggerIndex - 1] === " " || + textContent[triggerIndex - 1] === "\n" + ) { + const query = textContent.slice(triggerIndex + 1, cursorPos); + if (!query.startsWith(" ")) { + shouldTriggerAction = true; + actionQuery = query; + } } } } diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 7fb01401f..59797fc72 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -28,8 +28,7 @@ import { import { AnimatePresence, motion } from "motion/react"; import Image from "next/image"; import { useParams } from "next/navigation"; -import { type FC, useCallback, useEffect, useLayoutEffect, useMemo, useRef, useState } from "react"; -import { createPortal } from "react-dom"; +import { type FC, useCallback, useEffect, useMemo, useRef, useState } from "react"; import { agentToolsAtom, disabledToolsAtom, @@ -124,16 +123,18 @@ const ThreadContent: FC = () => { }} /> + !thread.isEmpty}> +
+ + - !thread.isEmpty}> -
- -
-
+ !thread.isEmpty}> + +
@@ -339,10 +340,7 @@ const Composer: FC = () => { const [showPromptPicker, setShowPromptPicker] = useState(false); const [mentionQuery, setMentionQuery] = useState(""); const [actionQuery, setActionQuery] = useState(""); - const [containerPos, setContainerPos] = useState({ bottom: "200px", left: "50%", top: "auto" }); const editorRef = useRef(null); - const editorContainerRef = useRef(null); - const composerBoxRef = useRef(null); const documentPickerRef = useRef(null); const promptPickerRef = useRef(null); const viewportRef = useRef(null); @@ -363,38 +361,13 @@ const Composer: FC = () => { viewportRef.current = document.querySelector(".aui-thread-viewport"); }, []); - // Compute picker positions using ResizeObserver to avoid layout reads during render - useLayoutEffect(() => { - if (!editorContainerRef.current) return; - - const updatePosition = () => { - if (!editorContainerRef.current) return; - const rect = editorContainerRef.current.getBoundingClientRect(); - const composerRect = composerBoxRef.current?.getBoundingClientRect(); - setContainerPos({ - bottom: `${window.innerHeight - rect.top + 8}px`, - left: `${rect.left}px`, - top: composerRect ? `${composerRect.bottom + 8}px` : "auto", - }); - }; - - updatePosition(); - const ro = new ResizeObserver(updatePosition); - ro.observe(editorContainerRef.current); - if (composerBoxRef.current) { - ro.observe(composerBoxRef.current); - } - - return () => ro.disconnect(); - }, []); - const electronAPI = useElectronAPI(); const [clipboardInitialText, setClipboardInitialText] = useState(); const clipboardLoadedRef = useRef(false); useEffect(() => { if (!electronAPI || clipboardLoadedRef.current) return; clipboardLoadedRef.current = true; - electronAPI.getQuickAskText().then((text) => { + electronAPI.getQuickAskText().then((text: string) => { if (text) { setClipboardInitialText(text); } @@ -587,23 +560,15 @@ const Composer: FC = () => { // Submit message (blocked during streaming, document picker open, or AI responding to another user) const handleSubmit = useCallback(() => { - if (isThreadRunning || isBlockedByOtherUser) { - return; - } - if (!showDocumentPopover && !showPromptPicker) { - if (clipboardInitialText) { - const userText = editorRef.current?.getText() ?? ""; - const combined = userText ? `${userText}\n\n${clipboardInitialText}` : clipboardInitialText; - aui.composer().setText(combined); - setClipboardInitialText(undefined); - } - aui.composer().send(); - editorRef.current?.clear(); - setMentionedDocuments([]); - setSidebarDocs([]); - } if (isThreadRunning || isBlockedByOtherUser) return; - if (showDocumentPopover) return; + if (showDocumentPopover || showPromptPicker) return; + + if (clipboardInitialText) { + const userText = editorRef.current?.getText() ?? ""; + const combined = userText ? `${userText}\n\n${clipboardInitialText}` : clipboardInitialText; + aui.composer().setText(combined); + setClipboardInitialText(undefined); + } const viewportEl = viewportRef.current; const heightBefore = viewportEl?.scrollHeight ?? 0; @@ -617,18 +582,14 @@ const Composer: FC = () => { // assistant message so that scrolling-to-bottom actually positions the // user message at the TOP of the viewport. That slack height is // calculated asynchronously (ResizeObserver → style → layout). - // - // We poll via rAF for ~2 s, re-scrolling whenever scrollHeight changes - // (user msg render → assistant placeholder → ViewportSlack min-height → - // first streamed content). Backup setTimeout calls cover cases where - // the batcher's 50 ms throttle delays the DOM update past the rAF. + // Poll via rAF for ~500ms, re-scrolling whenever scrollHeight changes. const scrollToBottom = () => threadViewportStore.getState().scrollToBottom({ behavior: "instant" }); let lastHeight = heightBefore; let frames = 0; let cancelled = false; - const POLL_FRAMES = 120; + const POLL_FRAMES = 30; const pollAndScroll = () => { if (cancelled) return; @@ -648,16 +609,11 @@ const Composer: FC = () => { const t1 = setTimeout(scrollToBottom, 100); const t2 = setTimeout(scrollToBottom, 300); - const t3 = setTimeout(scrollToBottom, 600); - // Cleanup if component unmounts during the polling window. The ref is - // checked inside pollAndScroll; timeouts are cleared in the return below. - // Store cleanup fn so it can be called from a useEffect cleanup if needed. submitCleanupRef.current = () => { cancelled = true; clearTimeout(t1); clearTimeout(t2); - clearTimeout(t3); }; }, [ showDocumentPopover, @@ -705,28 +661,54 @@ const Composer: FC = () => { ); return ( - + -
+ {showDocumentPopover && ( +
+ { + setShowDocumentPopover(false); + setMentionQuery(""); + }} + initialSelectedDocuments={mentionedDocuments} + externalSearch={mentionQuery} + /> +
+ )} + {showPromptPicker && ( +
+ { + setShowPromptPicker(false); + setActionQuery(""); + }} + externalSearch={actionQuery} + /> +
+ )} +
{clipboardInitialText && ( setClipboardInitialText(undefined)} /> )} - {/* Inline editor with @mention support */} -
+
{ className="min-h-[24px]" />
- {/* Document picker popover (portal to body for proper z-index stacking) */} - {showDocumentPopover && - typeof document !== "undefined" && - createPortal( - { - setShowDocumentPopover(false); - setMentionQuery(""); - }} - initialSelectedDocuments={mentionedDocuments} - externalSearch={mentionQuery} - containerStyle={{ - bottom: containerPos.bottom, - left: containerPos.left, - }} - />, - document.body - )} - {showPromptPicker && - typeof document !== "undefined" && - createPortal( - { - setShowPromptPicker(false); - setActionQuery(""); - }} - externalSearch={actionQuery} - containerStyle={{ - position: "fixed", - ...(clipboardInitialText - ? { top: containerPos.top } - : { bottom: containerPos.bottom }), - left: containerPos.left, - zIndex: 50, - }} - />, - document.body - )} diff --git a/surfsense_web/components/assistant-ui/tool-fallback.tsx b/surfsense_web/components/assistant-ui/tool-fallback.tsx index b658dba6d..d9833b387 100644 --- a/surfsense_web/components/assistant-ui/tool-fallback.tsx +++ b/surfsense_web/components/assistant-ui/tool-fallback.tsx @@ -1,14 +1,16 @@ import type { ToolCallMessagePartComponent } from "@assistant-ui/react"; import { CheckIcon, ChevronDownIcon, ChevronUpIcon, XCircleIcon } from "lucide-react"; import { useMemo, useState } from "react"; +import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval"; import { getToolIcon } from "@/contracts/enums/toolIcons"; +import { isInterruptResult } from "@/lib/hitl"; import { cn } from "@/lib/utils"; function formatToolName(name: string): string { return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); } -export const ToolFallback: ToolCallMessagePartComponent = ({ +const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ toolName, argsText, result, @@ -145,3 +147,10 @@ export const ToolFallback: ToolCallMessagePartComponent = ({
); }; + +export const ToolFallback: ToolCallMessagePartComponent = (props) => { + if (isInterruptResult(props.result)) { + return ; + } + return ; +}; diff --git a/surfsense_web/components/documents/DocumentNode.tsx b/surfsense_web/components/documents/DocumentNode.tsx index 04cec5f89..478037520 100644 --- a/surfsense_web/components/documents/DocumentNode.tsx +++ b/surfsense_web/components/documents/DocumentNode.tsx @@ -82,11 +82,12 @@ export const DocumentNode = React.memo(function DocumentNode({ onContextMenuOpenChange, }: DocumentNodeProps) { const statusState = doc.status?.state ?? "ready"; - const isSelectable = statusState !== "pending" && statusState !== "processing"; + const isFailed = statusState === "failed"; + const isProcessing = statusState === "pending" || statusState === "processing"; + const isUnavailable = isProcessing || isFailed; + const isSelectable = !isUnavailable; const isEditable = - EDITABLE_DOCUMENT_TYPES.has(doc.document_type) && - statusState !== "pending" && - statusState !== "processing"; + EDITABLE_DOCUMENT_TYPES.has(doc.document_type) && !isUnavailable; const handleCheckChange = useCallback(() => { if (isSelectable) { @@ -103,7 +104,6 @@ export const DocumentNode = React.memo(function DocumentNode({ [doc.id] ); - const isProcessing = statusState === "pending" || statusState === "processing"; const [dropdownOpen, setDropdownOpen] = useState(false); const [exporting, setExporting] = useState(null); const [titleTooltipOpen, setTitleTooltipOpen] = useState(false); @@ -261,38 +261,38 @@ export const DocumentNode = React.memo(function DocumentNode({ className="w-40" onClick={(e) => e.stopPropagation()} > - onPreview(doc)} disabled={isProcessing}> - - Open + onPreview(doc)} disabled={isUnavailable}> + + Open + + {isEditable && ( + onEdit(doc)}> + + Edit - {isEditable && ( - onEdit(doc)}> - - Edit - - )} - onMove(doc)}> - - Move to... + )} + onMove(doc)}> + + Move to... + + {onExport && ( + + + + Export + + + + + + )} + {onVersionHistory && isVersionableType(doc.document_type) && ( + onVersionHistory(doc)}> + + Versions - {onExport && ( - - - - Export - - - - - - )} - {onVersionHistory && isVersionableType(doc.document_type) && ( - onVersionHistory(doc)}> - - Versions - - )} - onDelete(doc)}> + )} + onDelete(doc)}> Delete @@ -304,38 +304,38 @@ export const DocumentNode = React.memo(function DocumentNode({ {contextMenuOpen && ( e.stopPropagation()}> - onPreview(doc)} disabled={isProcessing}> - - Open + onPreview(doc)} disabled={isUnavailable}> + + Open + + {isEditable && ( + onEdit(doc)}> + + Edit - {isEditable && ( - onEdit(doc)}> - - Edit - - )} - onMove(doc)}> - - Move to... + )} + onMove(doc)}> + + Move to... + + {onExport && ( + + + + Export + + + + + + )} + {onVersionHistory && isVersionableType(doc.document_type) && ( + onVersionHistory(doc)}> + + Versions - {onExport && ( - - - - Export - - - - - - )} - {onVersionHistory && isVersionableType(doc.document_type) && ( - onVersionHistory(doc)}> - - Versions - - )} - onDelete(doc)}> + )} + onDelete(doc)}> Delete diff --git a/surfsense_web/components/documents/DocumentsFilters.tsx b/surfsense_web/components/documents/DocumentsFilters.tsx index 703c9c3b4..d43f3680b 100644 --- a/surfsense_web/components/documents/DocumentsFilters.tsx +++ b/surfsense_web/components/documents/DocumentsFilters.tsx @@ -1,6 +1,6 @@ "use client"; -import { Download, FolderPlus, ListFilter, Loader2, Search, Upload, X } from "lucide-react"; +import { FolderPlus, ListFilter, Search, Upload, X } from "lucide-react"; import { useTranslations } from "next-intl"; import React, { useCallback, useMemo, useRef, useState } from "react"; import { useDocumentUploadDialog } from "@/components/assistant-ui/document-upload-popup"; @@ -20,8 +20,6 @@ export function DocumentsFilters({ onToggleType, activeTypes, onCreateFolder, - onExportKB, - isExporting, }: { typeCounts: Partial>; onSearch: (v: string) => void; @@ -29,8 +27,6 @@ export function DocumentsFilters({ onToggleType: (type: DocumentTypeEnum, checked: boolean) => void; activeTypes: DocumentTypeEnum[]; onCreateFolder?: () => void; - onExportKB?: () => void; - isExporting?: boolean; }) { const t = useTranslations("documents"); const id = React.useId(); @@ -85,33 +81,8 @@ export function DocumentsFilters({ New folder - - )} - - {onExportKB && ( - - - { - e.preventDefault(); - onExportKB(); - }} - > - {isExporting ? ( - - ) : ( - - )} - - - - {isExporting ? "Exporting…" : "Export knowledge base"} - - - )} + + )} diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx index 7e9c33a1a..aecf55a27 100644 --- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx +++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx @@ -532,16 +532,14 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid const isOutOfSync = currentThreadState.id !== null && !params?.chat_id; if (isOutOfSync) { - // First sync Next.js router by navigating to the current chat's actual URL - // This updates the router's internal state to match the browser URL resetCurrentThread(); - router.replace(`/dashboard/${searchSpaceId}/new-chat/${currentThreadState.id}`); - // Allow router to sync, then navigate to fresh new-chat - setTimeout(() => { - router.push(`/dashboard/${searchSpaceId}/new-chat`); - }, 0); + // Immediately set the browser URL so the page remounts with a clean /new-chat path + window.history.replaceState(null, "", `/dashboard/${searchSpaceId}/new-chat`); + // Force-remount the page component to reset all React state synchronously + setChatResetKey((k) => k + 1); + // Sync Next.js router internals so useParams/usePathname stay correct going forward + router.replace(`/dashboard/${searchSpaceId}/new-chat`); } else { - // Normal navigation - router is in sync router.push(`/dashboard/${searchSpaceId}/new-chat`); } }, [router, searchSpaceId, currentThreadState.id, params?.chat_id, resetCurrentThread]); diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index 20b25a2d2..d69f48606 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -406,22 +406,13 @@ export function DocumentsSidebar({ setFolderPickerOpen(true); }, []); - const [isExportingKB, setIsExportingKB] = useState(false); + const [, setIsExportingKB] = useState(false); const [exportWarningOpen, setExportWarningOpen] = useState(false); const [exportWarningContext, setExportWarningContext] = useState<{ - type: "kb" | "folder"; - folder?: FolderDisplay; + folder: FolderDisplay; pendingCount: number; } | null>(null); - const pendingDocuments = useMemo( - () => - treeDocuments.filter( - (d) => d.status?.state === "pending" || d.status?.state === "processing" - ), - [treeDocuments] - ); - const doExport = useCallback(async (url: string, downloadName: string) => { const response = await authenticatedFetch(url, { method: "GET" }); if (!response.ok) { @@ -440,68 +431,28 @@ export function DocumentsSidebar({ URL.revokeObjectURL(blobUrl); }, []); - const handleExportKB = useCallback(async () => { - if (isExportingKB) return; - - if (pendingDocuments.length > 0) { - setExportWarningContext({ type: "kb", pendingCount: pendingDocuments.length }); - setExportWarningOpen(true); - return; - } - - setIsExportingKB(true); - try { - await doExport( - `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/export`, - "knowledge-base.zip" - ); - toast.success("Knowledge base exported"); - } catch (err) { - console.error("KB export failed:", err); - toast.error(err instanceof Error ? err.message : "Export failed"); - } finally { - setIsExportingKB(false); - } - }, [searchSpaceId, isExportingKB, pendingDocuments.length, doExport]); - const handleExportWarningConfirm = useCallback(async () => { setExportWarningOpen(false); const ctx = exportWarningContext; - if (!ctx) return; + if (!ctx?.folder) return; - if (ctx.type === "kb") { - setIsExportingKB(true); - try { - await doExport( - `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/export`, - "knowledge-base.zip" - ); - toast.success("Knowledge base exported"); - } catch (err) { - console.error("KB export failed:", err); - toast.error(err instanceof Error ? err.message : "Export failed"); - } finally { - setIsExportingKB(false); - } - } else if (ctx.type === "folder" && ctx.folder) { - setIsExportingKB(true); - try { - const safeName = - ctx.folder.name - .replace(/[^a-zA-Z0-9 _-]/g, "_") - .trim() - .slice(0, 80) || "folder"; - await doExport( - `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/export?folder_id=${ctx.folder.id}`, - `${safeName}.zip` - ); - toast.success(`Folder "${ctx.folder.name}" exported`); - } catch (err) { - console.error("Folder export failed:", err); - toast.error(err instanceof Error ? err.message : "Export failed"); - } finally { - setIsExportingKB(false); - } + setIsExportingKB(true); + try { + const safeName = + ctx.folder.name + .replace(/[^a-zA-Z0-9 _-]/g, "_") + .trim() + .slice(0, 80) || "folder"; + await doExport( + `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/export?folder_id=${ctx.folder.id}`, + `${safeName}.zip` + ); + toast.success(`Folder "${ctx.folder.name}" exported`); + } catch (err) { + console.error("Folder export failed:", err); + toast.error(err instanceof Error ? err.message : "Export failed"); + } finally { + setIsExportingKB(false); } setExportWarningContext(null); }, [exportWarningContext, searchSpaceId, doExport]); @@ -530,7 +481,6 @@ export function DocumentsSidebar({ const folderPendingCount = getPendingCountInSubtree(folder.id); if (folderPendingCount > 0) { setExportWarningContext({ - type: "folder", folder, pendingCount: folderPendingCount, }); @@ -677,9 +627,10 @@ export function DocumentsSidebar({ function collectSubtreeDocs(parentId: number): DocumentNodeDoc[] { const directDocs = (treeDocuments ?? []).filter( (d) => - d.folderId === parentId && - d.status?.state !== "pending" && - d.status?.state !== "processing" + d.folderId === parentId && + d.status?.state !== "pending" && + d.status?.state !== "processing" && + d.status?.state !== "failed" ); const childFolders = foldersByParent[String(parentId)] ?? []; const descendantDocs = childFolders.flatMap((cf) => collectSubtreeDocs(cf.id)); @@ -954,8 +905,6 @@ export function DocumentsSidebar({ onToggleType={onToggleType} activeTypes={activeTypes} onCreateFolder={() => handleCreateFolder(null)} - onExportKB={handleExportKB} - isExporting={isExportingKB} />
diff --git a/surfsense_web/components/new-chat/document-mention-picker.tsx b/surfsense_web/components/new-chat/document-mention-picker.tsx index 9c6521f31..f2985278d 100644 --- a/surfsense_web/components/new-chat/document-mention-picker.tsx +++ b/surfsense_web/components/new-chat/document-mention-picker.tsx @@ -29,8 +29,6 @@ interface DocumentMentionPickerProps { onDone: () => void; initialSelectedDocuments?: Pick[]; externalSearch?: string; - /** Positioning styles for the container */ - containerStyle?: React.CSSProperties; } const PAGE_SIZE = 20; @@ -75,7 +73,6 @@ export const DocumentMentionPicker = forwardRef< onDone, initialSelectedDocuments = [], externalSearch = "", - containerStyle, }, ref ) { @@ -394,19 +391,9 @@ export const DocumentMentionPicker = forwardRef< [selectableDocuments, highlightedIndex, handleSelectDocument, onDone] ); - // Hide popup when there are no documents to display (regardless of fetch state) - // Search continues in background; popup reappears when results arrive - if (!actualLoading && actualDocuments.length === 0) { - return null; - } - return (
)}
- ) : null} + ) : ( +
+

No matching documents

+
+ )}
); diff --git a/surfsense_web/components/new-chat/prompt-picker.tsx b/surfsense_web/components/new-chat/prompt-picker.tsx index 3e6457b8c..b1723f0be 100644 --- a/surfsense_web/components/new-chat/prompt-picker.tsx +++ b/surfsense_web/components/new-chat/prompt-picker.tsx @@ -15,7 +15,7 @@ import { import { promptsAtom } from "@/atoms/prompts/prompts-query.atoms"; import { userSettingsDialogAtom } from "@/atoms/settings/settings-dialog.atoms"; -import { Spinner } from "@/components/ui/spinner"; +import { Skeleton } from "@/components/ui/skeleton"; import { cn } from "@/lib/utils"; export interface PromptPickerRef { @@ -28,11 +28,10 @@ interface PromptPickerProps { onSelect: (action: { name: string; prompt: string; mode: "transform" | "explore" }) => void; onDone: () => void; externalSearch?: string; - containerStyle?: React.CSSProperties; } export const PromptPicker = forwardRef(function PromptPicker( - { onSelect, onDone, externalSearch = "", containerStyle }, + { onSelect, onDone, externalSearch = "" }, ref ) { const setUserSettingsDialog = useSetAtom(userSettingsDialogAtom); @@ -60,13 +59,21 @@ export const PromptPicker = forwardRef(funct } } + const createPromptIndex = filtered.length; + const totalItems = filtered.length + 1; + const handleSelect = useCallback( (index: number) => { + if (index === createPromptIndex) { + onDone(); + setUserSettingsDialog({ open: true, initialTab: "prompts" }); + return; + } const action = filtered[index]; if (!action) return; onSelect({ name: action.name, prompt: action.prompt, mode: action.mode }); }, - [filtered, onSelect] + [filtered, onSelect, createPromptIndex, onDone, setUserSettingsDialog] ); useEffect(() => { @@ -93,69 +100,98 @@ export const PromptPicker = forwardRef(funct () => ({ selectHighlighted: () => handleSelect(highlightedIndex), moveUp: () => { - if (filtered.length === 0) return; shouldScrollRef.current = true; - setHighlightedIndex((prev) => (prev > 0 ? prev - 1 : filtered.length - 1)); + setHighlightedIndex((prev) => (prev > 0 ? prev - 1 : totalItems - 1)); }, moveDown: () => { - if (filtered.length === 0) return; shouldScrollRef.current = true; - setHighlightedIndex((prev) => (prev < filtered.length - 1 ? prev + 1 : 0)); + setHighlightedIndex((prev) => (prev < totalItems - 1 ? prev + 1 : 0)); }, }), - [filtered.length, highlightedIndex, handleSelect] + [totalItems, highlightedIndex, handleSelect] ); return ( -
-
+
+
{isLoading ? ( -
- +
+
+ +
+ {["a", "b", "c", "d", "e"].map((id, i) => ( +
= 3 && "hidden sm:flex" + )} + > + + + + + + +
+ ))}
) : isError ? ( -

Failed to load prompts

+
+

Failed to load prompts

+
) : filtered.length === 0 ? ( -

No matching prompts

+
+

No matching prompts

+
) : ( - filtered.map((action, index) => ( +
+
+ Saved Prompts +
+ {filtered.map((action, index) => ( + + ))} + +
- )) +
)} - -
-
); diff --git a/surfsense_web/components/settings/model-config-manager.tsx b/surfsense_web/components/settings/agent-model-manager.tsx similarity index 83% rename from surfsense_web/components/settings/model-config-manager.tsx rename to surfsense_web/components/settings/agent-model-manager.tsx index f83251426..f7a2fb824 100644 --- a/surfsense_web/components/settings/model-config-manager.tsx +++ b/surfsense_web/components/settings/agent-model-manager.tsx @@ -10,7 +10,6 @@ import { MessageSquareQuote, RefreshCw, Trash2, - Wand2, } from "lucide-react"; import { useMemo, useState } from "react"; import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; @@ -43,7 +42,7 @@ import { useMediaQuery } from "@/hooks/use-media-query"; import { getProviderIcon } from "@/lib/provider-icons"; import { cn } from "@/lib/utils"; -interface ModelConfigManagerProps { +interface AgentModelManagerProps { searchSpaceId: number; } @@ -55,7 +54,7 @@ function getInitials(name: string): string { return name.slice(0, 2).toUpperCase(); } -export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) { +export function AgentModelManager({ searchSpaceId }: AgentModelManagerProps) { const isDesktop = useMediaQuery("(min-width: 768px)"); // Mutations const { mutateAsync: deleteConfig, isPending: isDeleting } = useAtomValue( @@ -208,28 +207,26 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) { {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( - {/* Header */} -
+ {/* Header: Icon + Name */} +
+
- {/* Provider + Model */} -
- - -
{/* Feature badges */}
-
{/* Footer */} -
- - - +
+ + +
+ + +
@@ -262,20 +259,25 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) {
- {/* Header: Name + Actions */} -
-
-

- {config.name} -

- {config.description && ( -

- {config.description} -

- )} + {/* Header: Icon + Name + Actions */} +
+
+
+ {getProviderIcon(config.provider, { className: "size-4" })} +
+
+

+ {config.name} +

+ {config.description && ( +

+ {config.description} +

+ )} +
{(canUpdate || canDelete) && ( -
+
{canUpdate && ( @@ -284,7 +286,7 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) { variant="ghost" size="icon" onClick={() => openEditDialog(config)} - className="h-7 w-7 text-muted-foreground hover:text-foreground" + className="h-7 w-7 rounded-lg text-muted-foreground hover:text-foreground" > @@ -301,7 +303,7 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) { variant="ghost" size="icon" onClick={() => setConfigToDelete(config)} - className="h-7 w-7 text-muted-foreground hover:text-destructive" + className="h-7 w-7 rounded-lg text-muted-foreground hover:text-destructive" > @@ -314,20 +316,12 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) { )}
- {/* Provider + Model */} -
- {getProviderIcon(config.provider, { className: "size-3.5 shrink-0" })} - - {config.model_name} - -
- {/* Feature badges */}
{config.citations_enabled && ( Citations @@ -336,8 +330,8 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) { {!config.use_default_system_instructions && config.system_instructions && ( Custom @@ -346,8 +340,8 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) {
{/* Footer: Date + Creator */} -
- +
+ {new Date(config.created_at).toLocaleDateString(undefined, { year: "numeric", month: "short", @@ -356,11 +350,11 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) { {member && ( <> - + -
+
{member.avatarUrl && ( @@ -369,7 +363,7 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) { {getInitials(member.name)} - + {member.name}
diff --git a/surfsense_web/components/settings/general-settings-manager.tsx b/surfsense_web/components/settings/general-settings-manager.tsx index e83525982..15d44906b 100644 --- a/surfsense_web/components/settings/general-settings-manager.tsx +++ b/surfsense_web/components/settings/general-settings-manager.tsx @@ -2,18 +2,18 @@ import { useQuery } from "@tanstack/react-query"; import { useAtomValue } from "jotai"; -import { Info } from "lucide-react"; +import { FolderArchive, Info } from "lucide-react"; import { useTranslations } from "next-intl"; -import { useEffect, useState } from "react"; +import { useCallback, useEffect, useState } from "react"; import { toast } from "sonner"; import { updateSearchSpaceMutationAtom } from "@/atoms/search-spaces/search-space-mutation.atoms"; import { Alert, AlertDescription } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; -import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { Skeleton } from "@/components/ui/skeleton"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; +import { authenticatedFetch } from "@/lib/auth-utils"; import { cacheKeys } from "@/lib/query-client/cache-keys"; import { Spinner } from "../ui/spinner"; @@ -40,6 +40,37 @@ export function GeneralSettingsManager({ searchSpaceId }: GeneralSettingsManager const [name, setName] = useState(""); const [description, setDescription] = useState(""); const [saving, setSaving] = useState(false); + const [isExporting, setIsExporting] = useState(false); + + const handleExportKB = useCallback(async () => { + if (isExporting) return; + setIsExporting(true); + try { + const response = await authenticatedFetch( + `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/export`, + { method: "GET" } + ); + if (!response.ok) { + const errorData = await response.json().catch(() => ({ detail: "Export failed" })); + throw new Error(errorData.detail || "Export failed"); + } + const blob = await response.blob(); + const url = URL.createObjectURL(blob); + const a = document.createElement("a"); + a.href = url; + a.download = "knowledge-base.zip"; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + toast.success("Knowledge base exported"); + } catch (err) { + console.error("KB export failed:", err); + toast.error(err instanceof Error ? err.message : "Export failed"); + } finally { + setIsExporting(false); + } + }, [searchSpaceId, isExporting]); // Initialize state from fetched search space useEffect(() => { @@ -83,16 +114,10 @@ export function GeneralSettingsManager({ searchSpaceId }: GeneralSettingsManager if (loading) { return (
- - - - - - - - - - +
+ + +
); } @@ -113,61 +138,45 @@ export function GeneralSettingsManager({ searchSpaceId }: GeneralSettingsManager - Update your search space name and description. These details help identify and organize - your workspace. + Update your search space name and description. - {/* Search Space Details Card */} - - - - Search Space Details - - Manage the basic information for this search space. - - - -
- - setName(e.target.value)} - className="text-sm md:text-base h-9 md:h-10" - /> -

- {t("general_name_description")} -

-
+ +
+
+ + setName(e.target.value)} + /> +

+ {t("general_name_description")} +

+
-
- - setDescription(e.target.value)} - className="text-sm md:text-base h-9 md:h-10" - /> -

- {t("general_description_description")} -

-
- - +
+ + setDescription(e.target.value)} + /> +

+ {t("general_description_description")} +

+
+
- {/* Action Buttons */} -
+
+ +
+
+ +

+ Download all documents in this search space as a ZIP of markdown files. +

+
+ +
); } diff --git a/surfsense_web/components/settings/image-model-manager.tsx b/surfsense_web/components/settings/image-model-manager.tsx index 831519aa2..fb28e5b1c 100644 --- a/surfsense_web/components/settings/image-model-manager.tsx +++ b/surfsense_web/components/settings/image-model-manager.tsx @@ -1,7 +1,7 @@ "use client"; import { useAtomValue } from "jotai"; -import { AlertCircle, Dot, Edit3, Info, RefreshCw, Trash2, Wand2 } from "lucide-react"; +import { AlertCircle, Dot, Edit3, Info, RefreshCw, Trash2 } from "lucide-react"; import { useMemo, useState } from "react"; import { deleteImageGenConfigMutationAtom } from "@/atoms/image-gen-config/image-gen-config-mutation.atoms"; import { @@ -209,20 +209,20 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( -
+
+
-
- - -
-
- - - +
+ + +
+ + +
@@ -255,20 +255,25 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
- {/* Header: Name + Actions */} -
-
-

- {config.name} -

- {config.description && ( -

- {config.description} -

- )} + {/* Header: Icon + Name + Actions */} +
+
+
+ {getProviderIcon(config.provider, { className: "size-4" })} +
+
+

+ {config.name} +

+ {config.description && ( +

+ {config.description} +

+ )} +
{(canUpdate || canDelete) && ( -
+
{canUpdate && ( @@ -277,7 +282,7 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { variant="ghost" size="icon" onClick={() => openEditDialog(config)} - className="h-7 w-7 text-muted-foreground hover:text-foreground" + className="h-7 w-7 rounded-lg text-muted-foreground hover:text-foreground" > @@ -294,7 +299,7 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { variant="ghost" size="icon" onClick={() => setConfigToDelete(config)} - className="h-7 w-7 text-muted-foreground hover:text-destructive" + className="h-7 w-7 rounded-lg text-muted-foreground hover:text-destructive" > @@ -307,17 +312,9 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { )}
- {/* Provider + Model */} -
- {getProviderIcon(config.provider, { className: "size-3.5 shrink-0" })} - - {config.model_name} - -
- {/* Footer: Date + Creator */} -
- +
+ {new Date(config.created_at).toLocaleDateString(undefined, { year: "numeric", month: "short", @@ -326,11 +323,11 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { {member && ( <> - + -
+
{member.avatarUrl && ( @@ -339,7 +336,7 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { {getInitials(member.name)} - + {member.name}
diff --git a/surfsense_web/components/settings/llm-role-manager.tsx b/surfsense_web/components/settings/llm-role-manager.tsx index e280db493..d6eb7c64d 100644 --- a/surfsense_web/components/settings/llm-role-manager.tsx +++ b/surfsense_web/components/settings/llm-role-manager.tsx @@ -6,7 +6,7 @@ import { Bot, CircleCheck, CircleDashed, - Eye, + ScanEye, FileText, ImageIcon, RefreshCw, @@ -74,7 +74,7 @@ const ROLE_DESCRIPTIONS = { configType: "image" as const, }, vision: { - icon: Eye, + icon: ScanEye, title: "Vision LLM", description: "Vision-capable model for screenshot analysis and context extraction", color: "text-muted-foreground", diff --git a/surfsense_web/components/settings/roles-manager.tsx b/surfsense_web/components/settings/roles-manager.tsx index 5b30b5f60..b72a53854 100644 --- a/surfsense_web/components/settings/roles-manager.tsx +++ b/surfsense_web/components/settings/roles-manager.tsx @@ -4,6 +4,7 @@ import { useQuery } from "@tanstack/react-query"; import { useAtomValue } from "jotai"; import { Bot, + ChevronDown, Edit2, FileText, Globe, @@ -47,7 +48,6 @@ import { DialogDescription, DialogHeader, DialogTitle, - DialogTrigger, } from "@/components/ui/dialog"; import { DropdownMenu, @@ -58,7 +58,6 @@ import { } from "@/components/ui/dropdown-menu"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; -import { ScrollArea } from "@/components/ui/scroll-area"; import { Spinner } from "@/components/ui/spinner"; import type { PermissionInfo } from "@/contracts/types/permissions.types"; import type { @@ -319,100 +318,6 @@ export function RolesManager({ searchSpaceId }: { searchSpaceId: number }) { ); } -// ============ Role Permissions Display ============ - -function RolePermissionsDialog({ - permissions, - roleName, - children, -}: { - permissions: string[]; - roleName: string; - children: React.ReactNode; -}) { - const isFullAccess = permissions.includes("*"); - - const grouped: Record = {}; - if (!isFullAccess) { - for (const perm of permissions) { - const [category, action] = perm.split(":"); - if (!grouped[category]) grouped[category] = []; - grouped[category].push(action); - } - } - - const sortedCategories = Object.keys(grouped).sort((a, b) => { - const orderA = CATEGORY_CONFIG[a]?.order ?? 99; - const orderB = CATEGORY_CONFIG[b]?.order ?? 99; - return orderA - orderB; - }); - - const categoryCount = sortedCategories.length; - - return ( - - {children} - - - {roleName} — Permissions - - {isFullAccess - ? "This role has unrestricted access to all resources" - : `${permissions.length} permissions across ${categoryCount} categories`} - - - {isFullAccess ? ( -
-
- -
-
-

Full access

-

- All permissions granted across every category -

-
-
- ) : ( - -
- {sortedCategories.map((category) => { - const actions = grouped[category]; - const config = CATEGORY_CONFIG[category] || { - label: category, - icon: FileText, - }; - const IconComponent = config.icon; - return ( -
-
- - {config.label} -
-
- {actions.map((action) => ( - - {ACTION_LABELS[action] || action.replace(/_/g, " ")} - - ))} -
-
- ); - })} -
-
- )} -
-
- ); -} - function PermissionsBadge({ permissions }: { permissions: string[] }) { if (permissions.includes("*")) { return ( @@ -463,6 +368,7 @@ function RolesContent({ }) { const [showCreateRole, setShowCreateRole] = useState(false); const [editingRoleId, setEditingRoleId] = useState(null); + const [expandedRoleId, setExpandedRoleId] = useState(null); if (loading) { return ( @@ -508,91 +414,170 @@ function RolesContent({ )}
- {roles.map((role) => ( -
-
-
- - - + {role.is_default && ( + + Default + + )} +
+ {role.description && ( +

+ {role.description} +

+ )} + + +
+ +
+ + {!role.is_system_role && ( +
+ + + + + e.preventDefault()}> + {canUpdate && ( + setEditingRoleId(role.id)}> + + Edit Role + + )} + {canDelete && ( + <> + + + + e.preventDefault()}> + + Delete Role + + + + + Delete role? + + This will permanently delete the "{role.name}" role. + Members with this role will lose their permissions. + + + + Cancel + onDeleteRole(role.id)} + className="bg-destructive text-destructive-foreground hover:bg-destructive/90" + > + Delete + + + + + + )} + + +
+ )} + +
-
- -
- - {!role.is_system_role && ( -
- - - - - e.preventDefault()}> - {canUpdate && ( - setEditingRoleId(role.id)}> - - Edit Role - - )} - {canDelete && ( - <> - - - - e.preventDefault()}> - - Delete Role - - - - - Delete role? - - This will permanently delete the "{role.name}" role. - Members with this role will lose their permissions. - - - - Cancel - onDeleteRole(role.id)} - className="bg-destructive text-destructive-foreground hover:bg-destructive/90" + {isExpanded && ( +
+ {isFullAccess ? ( +
+ +

+ Full access — all permissions granted across every category +

+
+ ) : ( +
+ {sortedCategories.map((category) => { + const actions = grouped[category]; + const config = CATEGORY_CONFIG[category] || { + label: category, + icon: FileText, + }; + const IconComponent = config.icon; + return ( +
+
+ + + {config.label} + +
+
+ {actions.map((action) => ( + - Delete - - - - - - )} - - + {ACTION_LABELS[action] || action.replace(/_/g, " ")} + + ))} +
+
+ ); + })} +
+ )}
)}
-
- ))} + ); + })}
); @@ -676,46 +661,51 @@ function PermissionsEditor({ return (
-
+
onToggleCategory(category)} - onClick={(e) => e.stopPropagation()} aria-label={`Select all ${config.label} permissions`} /> -
toggleCategoryExpanded(category)} > - -
+ +
+
- +
{isExpanded && (
@@ -726,28 +716,29 @@ function PermissionsEditor({ const isSelected = selectedPermissions.includes(perm.value); return ( -
+ onTogglePermission(perm.value)} - onClick={(e) => e.stopPropagation()} className="shrink-0" /> - +
); })}
diff --git a/surfsense_web/components/settings/search-space-settings-dialog.tsx b/surfsense_web/components/settings/search-space-settings-dialog.tsx index 56ad0ab8f..e021e1e41 100644 --- a/surfsense_web/components/settings/search-space-settings-dialog.tsx +++ b/surfsense_web/components/settings/search-space-settings-dialog.tsx @@ -7,7 +7,7 @@ import { Brain, CircleUser, Earth, - Eye, + ScanEye, ImageIcon, ListChecks, UserKey, @@ -25,10 +25,10 @@ const GeneralSettingsManager = dynamic( })), { ssr: false } ); -const ModelConfigManager = dynamic( +const AgentModelManager = dynamic( () => - import("@/components/settings/model-config-manager").then((m) => ({ - default: m.ModelConfigManager, + import("@/components/settings/agent-model-manager").then((m) => ({ + default: m.AgentModelManager, })), { ssr: false } ); @@ -88,7 +88,7 @@ export function SearchSpaceSettingsDialog({ searchSpaceId }: SearchSpaceSettings const navItems = [ { value: "general", label: t("nav_general"), icon: }, { value: "roles", label: t("nav_role_assignments"), icon: }, - { value: "models", label: t("nav_agent_configs"), icon: }, + { value: "models", label: t("nav_agent_models"), icon: }, { value: "image-models", label: t("nav_image_models"), @@ -97,7 +97,7 @@ export function SearchSpaceSettingsDialog({ searchSpaceId }: SearchSpaceSettings { value: "vision-models", label: t("nav_vision_models"), - icon: , + icon: , }, { value: "team-roles", label: t("nav_team_roles"), icon: }, { @@ -115,7 +115,7 @@ export function SearchSpaceSettingsDialog({ searchSpaceId }: SearchSpaceSettings const content: Record = { general: , - models: , + models: , roles: , "image-models": , "vision-models": , diff --git a/surfsense_web/components/settings/vision-model-manager.tsx b/surfsense_web/components/settings/vision-model-manager.tsx index 57ea8c205..81528c86a 100644 --- a/surfsense_web/components/settings/vision-model-manager.tsx +++ b/surfsense_web/components/settings/vision-model-manager.tsx @@ -208,20 +208,20 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) { {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( -
+
+
-
- - -
-
- - - +
+ + +
+ + +
@@ -253,19 +253,25 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
-
-
-

- {config.name} -

- {config.description && ( -

- {config.description} -

- )} + {/* Header: Icon + Name + Actions */} +
+
+
+ {getProviderIcon(config.provider, { className: "size-4" })} +
+
+

+ {config.name} +

+ {config.description && ( +

+ {config.description} +

+ )} +
{(canUpdate || canDelete) && ( -
+
{canUpdate && ( @@ -274,7 +280,7 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) { variant="ghost" size="icon" onClick={() => openEditDialog(config)} - className="h-7 w-7 text-muted-foreground hover:text-foreground" + className="h-6 w-6 text-muted-foreground hover:text-foreground" > @@ -291,7 +297,7 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) { variant="ghost" size="icon" onClick={() => setConfigToDelete(config)} - className="h-7 w-7 text-muted-foreground hover:text-destructive" + className="h-7 w-7 rounded-lg text-muted-foreground hover:text-destructive" > @@ -304,17 +310,9 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) { )}
-
- {getProviderIcon(config.provider, { - className: "size-3.5 shrink-0", - })} - - {config.model_name} - -
- -
- + {/* Footer: Date + Creator */} +
+ {new Date(config.created_at).toLocaleDateString(undefined, { year: "numeric", month: "short", @@ -323,11 +321,11 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) { {member && ( <> - + -
+
{member.avatarUrl && ( @@ -336,7 +334,7 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) { {getInitials(member.name)} - + {member.name}
diff --git a/surfsense_web/components/tool-ui/confluence/create-confluence-page.tsx b/surfsense_web/components/tool-ui/confluence/create-confluence-page.tsx index d22f317b8..f8cd6ee15 100644 --- a/surfsense_web/components/tool-ui/confluence/create-confluence-page.tsx +++ b/surfsense_web/components/tool-ui/confluence/create-confluence-page.tsx @@ -15,6 +15,8 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; +import { isInterruptResult, useHitlDecision } from "@/lib/hitl"; +import type { InterruptResult, HitlDecision } from "@/lib/hitl"; import { useHitlPhase } from "@/hooks/use-hitl-phase"; interface ConfluenceAccount { @@ -30,24 +32,10 @@ interface ConfluenceSpace { name: string; } -interface InterruptResult { - __interrupt__: true; - __decided__?: "approve" | "reject" | "edit"; - __completed__?: boolean; - action_requests: Array<{ - name: string; - args: Record; - }>; - review_configs: Array<{ - action_name: string; - allowed_decisions: Array<"approve" | "edit" | "reject">; - }>; - interrupt_type?: string; - context?: { - accounts?: ConfluenceAccount[]; - spaces?: ConfluenceSpace[]; - error?: string; - }; +type CreateConfluencePageInterruptContext = { + accounts?: ConfluenceAccount[]; + spaces?: ConfluenceSpace[]; + error?: string; } interface SuccessResult { @@ -76,21 +64,12 @@ interface InsufficientPermissionsResult { } type CreateConfluencePageResult = - | InterruptResult + | InterruptResult | SuccessResult | ErrorResult | AuthErrorResult | InsufficientPermissionsResult; -function isInterruptResult(result: unknown): result is InterruptResult { - return ( - typeof result === "object" && - result !== null && - "__interrupt__" in result && - (result as InterruptResult).__interrupt__ === true - ); -} - function isErrorResult(result: unknown): result is ErrorResult { return ( typeof result === "object" && @@ -124,12 +103,8 @@ function ApprovalCard({ onDecision, }: { args: { title: string; content?: string; space_id?: string }; - interruptData: InterruptResult; - onDecision: (decision: { - type: "approve" | "reject" | "edit"; - message?: string; - edited_action?: { name: string; args: Record }; - }) => void; + interruptData: InterruptResult; + onDecision: (decision: HitlDecision) => void; }) { const { phase, setProcessing, setRejected } = useHitlPhase(interruptData); const [isPanelOpen, setIsPanelOpen] = useState(false); @@ -464,18 +439,16 @@ export const CreateConfluencePageToolUI = ({ { title: string; content?: string; space_id?: string }, CreateConfluencePageResult >) => { + const { dispatch } = useHitlDecision(); + if (!result) return null; if (isInterruptResult(result)) { return ( { - window.dispatchEvent( - new CustomEvent("hitl-decision", { detail: { decisions: [decision] } }) - ); - }} + interruptData={result as InterruptResult} + onDecision={(decision) => dispatch([decision])} /> ); } diff --git a/surfsense_web/components/tool-ui/confluence/delete-confluence-page.tsx b/surfsense_web/components/tool-ui/confluence/delete-confluence-page.tsx index 258f259f0..37d73377e 100644 --- a/surfsense_web/components/tool-ui/confluence/delete-confluence-page.tsx +++ b/surfsense_web/components/tool-ui/confluence/delete-confluence-page.tsx @@ -6,38 +6,26 @@ import { useCallback, useEffect, useState } from "react"; import { TextShimmerLoader } from "@/components/prompt-kit/loader"; import { Button } from "@/components/ui/button"; import { Checkbox } from "@/components/ui/checkbox"; +import { isInterruptResult, useHitlDecision } from "@/lib/hitl"; +import type { InterruptResult, HitlDecision } from "@/lib/hitl"; import { useHitlPhase } from "@/hooks/use-hitl-phase"; -interface InterruptResult { - __interrupt__: true; - __decided__?: "approve" | "reject"; - __completed__?: boolean; - action_requests: Array<{ +type DeleteConfluencePageInterruptContext = { + account?: { + id: number; name: string; - args: Record; - }>; - review_configs: Array<{ - action_name: string; - allowed_decisions: Array<"approve" | "reject">; - }>; - interrupt_type?: string; - context?: { - account?: { - id: number; - name: string; - base_url: string; - auth_expired?: boolean; - }; - page?: { - page_id: string; - page_title: string; - space_id: string; - connector_id?: number; - document_id?: number; - indexed_at?: string; - }; - error?: string; + base_url: string; + auth_expired?: boolean; }; + page?: { + page_id: string; + page_title: string; + space_id: string; + connector_id?: number; + document_id?: number; + indexed_at?: string; + }; + error?: string; } interface SuccessResult { @@ -77,7 +65,7 @@ interface InsufficientPermissionsResult { } type DeleteConfluencePageResult = - | InterruptResult + | InterruptResult | SuccessResult | ErrorResult | NotFoundResult @@ -85,15 +73,6 @@ type DeleteConfluencePageResult = | AuthErrorResult | InsufficientPermissionsResult; -function isInterruptResult(result: unknown): result is InterruptResult { - return ( - typeof result === "object" && - result !== null && - "__interrupt__" in result && - (result as InterruptResult).__interrupt__ === true - ); -} - function isErrorResult(result: unknown): result is ErrorResult { return ( typeof result === "object" && @@ -145,12 +124,8 @@ function ApprovalCard({ interruptData, onDecision, }: { - interruptData: InterruptResult; - onDecision: (decision: { - type: "approve" | "reject"; - message?: string; - edited_action?: { name: string; args: Record }; - }) => void; + interruptData: InterruptResult; + onDecision: (decision: HitlDecision) => void; }) { const { phase, setProcessing, setRejected } = useHitlPhase(interruptData); const [deleteFromKb, setDeleteFromKb] = useState(false); @@ -402,18 +377,15 @@ export const DeleteConfluencePageToolUI = ({ { page_title_or_id: string; delete_from_kb?: boolean }, DeleteConfluencePageResult >) => { + const { dispatch } = useHitlDecision(); + if (!result) return null; if (isInterruptResult(result)) { return ( { - const event = new CustomEvent("hitl-decision", { - detail: { decisions: [decision] }, - }); - window.dispatchEvent(event); - }} + interruptData={result as InterruptResult} + onDecision={(decision) => dispatch([decision])} /> ); } diff --git a/surfsense_web/components/tool-ui/confluence/update-confluence-page.tsx b/surfsense_web/components/tool-ui/confluence/update-confluence-page.tsx index 8bc0772a4..76df10e40 100644 --- a/surfsense_web/components/tool-ui/confluence/update-confluence-page.tsx +++ b/surfsense_web/components/tool-ui/confluence/update-confluence-page.tsx @@ -8,39 +8,27 @@ import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; import { PlateEditor } from "@/components/editor/plate-editor"; import { TextShimmerLoader } from "@/components/prompt-kit/loader"; import { Button } from "@/components/ui/button"; +import { isInterruptResult, useHitlDecision } from "@/lib/hitl"; +import type { InterruptResult, HitlDecision } from "@/lib/hitl"; import { useHitlPhase } from "@/hooks/use-hitl-phase"; -interface InterruptResult { - __interrupt__: true; - __decided__?: "approve" | "reject" | "edit"; - __completed__?: boolean; - action_requests: Array<{ +type UpdateConfluencePageInterruptContext = { + account?: { + id: number; name: string; - args: Record; - }>; - review_configs: Array<{ - action_name: string; - allowed_decisions: Array<"approve" | "edit" | "reject">; - }>; - interrupt_type?: string; - context?: { - account?: { - id: number; - name: string; - base_url: string; - auth_expired?: boolean; - }; - page?: { - page_id: string; - page_title: string; - space_id: string; - body: string; - version: number; - document_id: number; - indexed_at?: string; - }; - error?: string; + base_url: string; + auth_expired?: boolean; }; + page?: { + page_id: string; + page_title: string; + space_id: string; + body: string; + version: number; + document_id: number; + indexed_at?: string; + }; + error?: string; } interface SuccessResult { @@ -74,22 +62,13 @@ interface InsufficientPermissionsResult { } type UpdateConfluencePageResult = - | InterruptResult + | InterruptResult | SuccessResult | ErrorResult | NotFoundResult | AuthErrorResult | InsufficientPermissionsResult; -function isInterruptResult(result: unknown): result is InterruptResult { - return ( - typeof result === "object" && - result !== null && - "__interrupt__" in result && - (result as InterruptResult).__interrupt__ === true - ); -} - function isErrorResult(result: unknown): result is ErrorResult { return ( typeof result === "object" && @@ -136,12 +115,8 @@ function ApprovalCard({ new_title?: string; new_content?: string; }; - interruptData: InterruptResult; - onDecision: (decision: { - type: "approve" | "reject" | "edit"; - message?: string; - edited_action?: { name: string; args: Record }; - }) => void; + interruptData: InterruptResult; + onDecision: (decision: HitlDecision) => void; }) { const { phase, setProcessing, setRejected } = useHitlPhase(interruptData); @@ -502,18 +477,16 @@ export const UpdateConfluencePageToolUI = ({ }, UpdateConfluencePageResult >) => { + const { dispatch } = useHitlDecision(); + if (!result) return null; if (isInterruptResult(result)) { return ( { - window.dispatchEvent( - new CustomEvent("hitl-decision", { detail: { decisions: [decision] } }) - ); - }} + interruptData={result as InterruptResult} + onDecision={(decision) => dispatch([decision])} /> ); } diff --git a/surfsense_web/components/tool-ui/dropbox/create-file.tsx b/surfsense_web/components/tool-ui/dropbox/create-file.tsx index ac45f1f5b..15d454c76 100644 --- a/surfsense_web/components/tool-ui/dropbox/create-file.tsx +++ b/surfsense_web/components/tool-ui/dropbox/create-file.tsx @@ -16,6 +16,8 @@ import { SelectValue, } from "@/components/ui/select"; import { useHitlPhase } from "@/hooks/use-hitl-phase"; +import { isInterruptResult, useHitlDecision } from "@/lib/hitl"; +import type { InterruptResult, HitlDecision } from "@/lib/hitl"; interface DropboxAccount { id: number; @@ -29,21 +31,11 @@ interface SupportedType { label: string; } -interface InterruptResult { - __interrupt__: true; - __decided__?: "approve" | "reject" | "edit"; - __completed__?: boolean; - action_requests: Array<{ name: string; args: Record }>; - review_configs: Array<{ - action_name: string; - allowed_decisions: Array<"approve" | "edit" | "reject">; - }>; - context?: { - accounts?: DropboxAccount[]; - parent_folders?: Record>; - supported_types?: SupportedType[]; - error?: string; - }; +type DropboxCreateFileContext = { + accounts?: DropboxAccount[]; + parent_folders?: Record>; + supported_types?: SupportedType[]; + error?: string; } interface SuccessResult { @@ -65,16 +57,7 @@ interface AuthErrorResult { connector_type?: string; } -type CreateDropboxFileResult = InterruptResult | SuccessResult | ErrorResult | AuthErrorResult; - -function isInterruptResult(result: unknown): result is InterruptResult { - return ( - typeof result === "object" && - result !== null && - "__interrupt__" in result && - (result as InterruptResult).__interrupt__ === true - ); -} +type CreateDropboxFileResult = InterruptResult | SuccessResult | ErrorResult | AuthErrorResult; function isErrorResult(result: unknown): result is ErrorResult { return ( @@ -100,12 +83,8 @@ function ApprovalCard({ onDecision, }: { args: { name: string; file_type?: string; content?: string }; - interruptData: InterruptResult; - onDecision: (decision: { - type: "approve" | "reject" | "edit"; - message?: string; - edited_action?: { name: string; args: Record }; - }) => void; + interruptData: InterruptResult; + onDecision: (decision: HitlDecision) => void; }) { const { phase, setProcessing, setRejected } = useHitlPhase(interruptData); const [isPanelOpen, setIsPanelOpen] = useState(false); @@ -455,17 +434,14 @@ export const CreateDropboxFileToolUI = ({ { name: string; file_type?: string; content?: string }, CreateDropboxFileResult >) => { + const { dispatch } = useHitlDecision(); if (!result) return null; if (isInterruptResult(result)) { return ( { - window.dispatchEvent( - new CustomEvent("hitl-decision", { detail: { decisions: [decision] } }) - ); - }} + interruptData={result as InterruptResult} + onDecision={(decision) => dispatch([decision])} /> ); } diff --git a/surfsense_web/components/tool-ui/dropbox/trash-file.tsx b/surfsense_web/components/tool-ui/dropbox/trash-file.tsx index 0b38777c1..a2fadd20f 100644 --- a/surfsense_web/components/tool-ui/dropbox/trash-file.tsx +++ b/surfsense_web/components/tool-ui/dropbox/trash-file.tsx @@ -7,6 +7,8 @@ import { TextShimmerLoader } from "@/components/prompt-kit/loader"; import { Button } from "@/components/ui/button"; import { Checkbox } from "@/components/ui/checkbox"; import { useHitlPhase } from "@/hooks/use-hitl-phase"; +import { isInterruptResult, useHitlDecision } from "@/lib/hitl"; +import type { InterruptResult, HitlDecision } from "@/lib/hitl"; interface DropboxAccount { id: number; @@ -22,13 +24,10 @@ interface DropboxFile { document_id?: number; } -interface InterruptResult { - __interrupt__: true; - __decided__?: "approve" | "reject"; - __completed__?: boolean; - action_requests: Array<{ name: string; args: Record }>; - review_configs: Array<{ action_name: string; allowed_decisions: Array<"approve" | "reject"> }>; - context?: { account?: DropboxAccount; file?: DropboxFile; error?: string }; +type DropboxTrashFileContext = { + account?: DropboxAccount; + file?: DropboxFile; + error?: string; } interface SuccessResult { @@ -52,20 +51,12 @@ interface AuthErrorResult { } type DeleteDropboxFileResult = - | InterruptResult + | InterruptResult | SuccessResult | ErrorResult | NotFoundResult | AuthErrorResult; -function isInterruptResult(result: unknown): result is InterruptResult { - return ( - typeof result === "object" && - result !== null && - "__interrupt__" in result && - (result as InterruptResult).__interrupt__ === true - ); -} function isErrorResult(result: unknown): result is ErrorResult { return ( typeof result === "object" && @@ -95,12 +86,8 @@ function ApprovalCard({ interruptData, onDecision, }: { - interruptData: InterruptResult; - onDecision: (decision: { - type: "approve" | "reject"; - message?: string; - edited_action?: { name: string; args: Record }; - }) => void; + interruptData: InterruptResult; + onDecision: (decision: HitlDecision) => void; }) { const { phase, setProcessing, setRejected } = useHitlPhase(interruptData); const [deleteFromKb, setDeleteFromKb] = useState(false); @@ -308,16 +295,13 @@ export const DeleteDropboxFileToolUI = ({ { file_name: string; delete_from_kb?: boolean }, DeleteDropboxFileResult >) => { + const { dispatch } = useHitlDecision(); if (!result) return null; if (isInterruptResult(result)) { return ( { - window.dispatchEvent( - new CustomEvent("hitl-decision", { detail: { decisions: [decision] } }) - ); - }} + interruptData={result as InterruptResult} + onDecision={(decision) => dispatch([decision])} /> ); } diff --git a/surfsense_web/components/tool-ui/generic-hitl-approval.tsx b/surfsense_web/components/tool-ui/generic-hitl-approval.tsx new file mode 100644 index 000000000..2a26b18f7 --- /dev/null +++ b/surfsense_web/components/tool-ui/generic-hitl-approval.tsx @@ -0,0 +1,264 @@ +"use client"; + +import type { ToolCallMessagePartComponent } from "@assistant-ui/react"; +import { CornerDownLeftIcon, Pen } from "lucide-react"; +import { useCallback, useEffect, useMemo, useState } from "react"; +import { TextShimmerLoader } from "@/components/prompt-kit/loader"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Textarea } from "@/components/ui/textarea"; +import { useHitlPhase } from "@/hooks/use-hitl-phase"; +import { connectorsApiService } from "@/lib/apis/connectors-api.service"; +import { isInterruptResult, useHitlDecision } from "@/lib/hitl"; +import type { HitlDecision, InterruptResult } from "@/lib/hitl"; + +function ParamEditor({ + params, + onChange, + disabled, +}: { + params: Record; + onChange: (updated: Record) => void; + disabled: boolean; +}) { + const entries = Object.entries(params); + if (entries.length === 0) return null; + + return ( +
+ {entries.map(([key, value]) => { + const strValue = value == null ? "" : String(value); + const isLong = strValue.length > 120; + const fieldId = `hitl-param-${key}`; + + return ( +
+ + {isLong ? ( +