mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-28 10:26:33 +02:00
refactor: enhance deduplication logic for HITL tool calls
Updated the deduplication mechanism in the DedupHITLToolCallsMiddleware to utilize a comprehensive list of native HITL tools. The deduplication keys are now dynamically populated from both hardcoded values and metadata from StructuredTool instances. Additionally, integrated HITL approval into MCP tool creation, ensuring all tools are gated by user approval, with the ability to bypass for trusted tools.
This commit is contained in:
parent
0c4fd30cce
commit
82c7d4a2ab
2 changed files with 137 additions and 99 deletions
|
|
@ -20,19 +20,39 @@ from langgraph.runtime import Runtime
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_HITL_TOOL_DEDUP_KEYS: dict[str, str] = {
|
||||
"delete_calendar_event": "event_title_or_id",
|
||||
"update_calendar_event": "event_title_or_id",
|
||||
"trash_gmail_email": "email_subject_or_id",
|
||||
_NATIVE_HITL_TOOL_DEDUP_KEYS: dict[str, str] = {
|
||||
# Gmail
|
||||
"send_gmail_email": "subject",
|
||||
"create_gmail_draft": "subject",
|
||||
"update_gmail_draft": "draft_subject_or_id",
|
||||
"trash_gmail_email": "email_subject_or_id",
|
||||
# Google Calendar
|
||||
"create_calendar_event": "title",
|
||||
"update_calendar_event": "event_title_or_id",
|
||||
"delete_calendar_event": "event_title_or_id",
|
||||
# Google Drive
|
||||
"create_google_drive_file": "file_name",
|
||||
"delete_google_drive_file": "file_name",
|
||||
# OneDrive
|
||||
"create_onedrive_file": "file_name",
|
||||
"delete_onedrive_file": "file_name",
|
||||
"delete_notion_page": "page_title",
|
||||
# Dropbox
|
||||
"create_dropbox_file": "file_name",
|
||||
"delete_dropbox_file": "file_name",
|
||||
# Notion
|
||||
"create_notion_page": "title",
|
||||
"update_notion_page": "page_title",
|
||||
"delete_linear_issue": "issue_ref",
|
||||
"delete_notion_page": "page_title",
|
||||
# Linear
|
||||
"create_linear_issue": "title",
|
||||
"update_linear_issue": "issue_ref",
|
||||
"delete_linear_issue": "issue_ref",
|
||||
# Jira
|
||||
"create_jira_issue": "summary",
|
||||
"update_jira_issue": "issue_title_or_key",
|
||||
"delete_jira_issue": "issue_title_or_key",
|
||||
# Confluence
|
||||
"create_confluence_page": "title",
|
||||
"update_confluence_page": "page_title_or_id",
|
||||
"delete_confluence_page": "page_title_or_id",
|
||||
}
|
||||
|
|
@ -43,22 +63,38 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
|
||||
Only the **first** occurrence of each (tool-name, primary-arg-value)
|
||||
pair is kept; subsequent duplicates are silently dropped.
|
||||
|
||||
The dedup map is built from two sources:
|
||||
|
||||
1. A comprehensive list of native HITL tools (hardcoded above).
|
||||
2. Any ``StructuredTool`` instances passed via *agent_tools* whose
|
||||
``metadata`` contains ``{"hitl": True, "hitl_dedup_key": "..."}``.
|
||||
This is how MCP tools automatically get dedup support.
|
||||
"""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(self, *, agent_tools: list[Any] | None = None) -> None:
|
||||
self._dedup_keys: dict[str, str] = dict(_NATIVE_HITL_TOOL_DEDUP_KEYS)
|
||||
for t in agent_tools or []:
|
||||
meta = getattr(t, "metadata", None) or {}
|
||||
if meta.get("hitl") and meta.get("hitl_dedup_key"):
|
||||
self._dedup_keys[t.name] = meta["hitl_dedup_key"]
|
||||
|
||||
def after_model(
|
||||
self, state: AgentState, runtime: Runtime[Any]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._dedup(state)
|
||||
return self._dedup(state, self._dedup_keys)
|
||||
|
||||
async def aafter_model(
|
||||
self, state: AgentState, runtime: Runtime[Any]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._dedup(state)
|
||||
return self._dedup(state, self._dedup_keys)
|
||||
|
||||
@staticmethod
|
||||
def _dedup(state: AgentState) -> dict[str, Any] | None: # type: ignore[type-arg]
|
||||
def _dedup(
|
||||
state: AgentState, dedup_keys: dict[str, str] # type: ignore[type-arg]
|
||||
) -> dict[str, Any] | None:
|
||||
messages = state.get("messages")
|
||||
if not messages:
|
||||
return None
|
||||
|
|
@ -73,7 +109,7 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
|
||||
for tc in tool_calls:
|
||||
name = tc.get("name", "")
|
||||
dedup_key_arg = _HITL_TOOL_DEDUP_KEYS.get(name)
|
||||
dedup_key_arg = dedup_keys.get(name)
|
||||
if dedup_key_arg is not None:
|
||||
arg_val = str(tc.get("args", {}).get(dedup_key_arg, "")).lower()
|
||||
key = (name, arg_val)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue