mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-12 17:22:38 +02:00
Merge branch 'dev' into feat/e2e-testing
This commit is contained in:
commit
fa31da9937
100 changed files with 3751 additions and 1122 deletions
|
|
@ -2,6 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from .main_agent import create_surfsense_deep_agent
|
||||
from .main_agent import create_multi_agent_chat_deep_agent
|
||||
|
||||
__all__ = ["create_surfsense_deep_agent"]
|
||||
__all__ = ["create_multi_agent_chat_deep_agent"]
|
||||
|
|
|
|||
|
|
@ -2,6 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from .runtime import create_surfsense_deep_agent
|
||||
from .runtime import create_multi_agent_chat_deep_agent
|
||||
|
||||
__all__ = ["create_surfsense_deep_agent"]
|
||||
__all__ = ["create_multi_agent_chat_deep_agent"]
|
||||
|
|
|
|||
|
|
@ -11,6 +11,9 @@ from langchain_core.language_models import BaseChatModel
|
|||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.middleware import (
|
||||
build_main_agent_deepagent_middleware,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.permissions import (
|
||||
ToolsPermissions,
|
||||
)
|
||||
|
|
@ -19,8 +22,6 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
|||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from .middleware import build_main_agent_deepagent_middleware
|
||||
|
||||
|
||||
def build_compiled_agent_graph_sync(
|
||||
*,
|
||||
|
|
|
|||
|
|
@ -1,7 +0,0 @@
|
|||
"""Main-agent graph middleware assembly (SurfSense + LangChain + deepagents)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .deepagent_stack import build_main_agent_deepagent_middleware
|
||||
|
||||
__all__ = ["build_main_agent_deepagent_middleware"]
|
||||
|
|
@ -1,506 +0,0 @@
|
|||
"""Assemble the main-agent deep-agent middleware list (LangChain + SurfSense + deepagents)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from deepagents import SubAgent
|
||||
from deepagents.backends import StateBackend
|
||||
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from deepagents.middleware.skills import SkillsMiddleware
|
||||
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
||||
from langchain.agents.middleware import (
|
||||
LLMToolSelectorMiddleware,
|
||||
ModelCallLimitMiddleware,
|
||||
ModelFallbackMiddleware,
|
||||
TodoListMiddleware,
|
||||
ToolCallLimitMiddleware,
|
||||
)
|
||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.subagents import (
|
||||
build_subagents,
|
||||
get_subagents_to_exclude,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.permissions import (
|
||||
ToolsPermissions,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import (
|
||||
ActionLogMiddleware,
|
||||
AnonymousDocumentMiddleware,
|
||||
BusyMutexMiddleware,
|
||||
ClearToolUsesEdit,
|
||||
DedupHITLToolCallsMiddleware,
|
||||
DoomLoopMiddleware,
|
||||
FileIntentMiddleware,
|
||||
KnowledgeBasePersistenceMiddleware,
|
||||
KnowledgePriorityMiddleware,
|
||||
KnowledgeTreeMiddleware,
|
||||
MemoryInjectionMiddleware,
|
||||
NoopInjectionMiddleware,
|
||||
OtelSpanMiddleware,
|
||||
PermissionMiddleware,
|
||||
RetryAfterMiddleware,
|
||||
SpillingContextEditingMiddleware,
|
||||
SpillToBackendEdit,
|
||||
SurfSenseFilesystemMiddleware,
|
||||
ToolCallNameRepairMiddleware,
|
||||
build_skills_backend_factory,
|
||||
create_surfsense_compaction_middleware,
|
||||
default_skills_sources,
|
||||
)
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
from app.agents.new_chat.plugin_loader import (
|
||||
PluginContext,
|
||||
load_allowed_plugin_names_from_env,
|
||||
load_plugin_middlewares,
|
||||
)
|
||||
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from ...context_prune.prune_tool_names import safe_exclude_tools
|
||||
from .checkpointed_subagent_middleware import SurfSenseCheckpointedSubAgentMiddleware
|
||||
|
||||
|
||||
def build_main_agent_deepagent_middleware(
    *,
    llm: BaseChatModel,
    tools: Sequence[BaseTool],
    backend_resolver: Any,
    filesystem_mode: FilesystemMode,
    search_space_id: int,
    user_id: str | None,
    thread_id: int | None,
    visibility: ChatVisibility,
    anon_session_id: str | None,
    available_connectors: list[str] | None,
    available_document_types: list[str] | None,
    mentioned_document_ids: list[int] | None,
    max_input_tokens: int | None,
    flags: AgentFeatureFlags,
    subagent_dependencies: dict[str, Any],
    checkpointer: Checkpointer,
    mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None,
    disabled_tools: list[str] | None = None,
) -> list[Any]:
    """Build the ordered middleware list for ``create_agent`` (Nones already stripped).

    Assembles, in order: the general-purpose (GP) subagent's own middleware,
    permission rulesets (mirrored into the GP subagent's ``interrupt_on``),
    registry subagents, and every feature-flag-gated parent middleware.
    The final list preserves a deliberate ordering; entries disabled by
    flags are built as ``None`` and filtered out at the end.

    Args:
        llm: Chat model shared by the parent agent and all subagents.
        tools: Parent-bound tools; their names seed permission/repair sets.
        backend_resolver: Filesystem backend passed to filesystem and
            context-spilling middleware.
        max_input_tokens: Token budget driving context-editing thresholds;
            ``None``/0 disables context editing.
        flags: Feature flags; ``disable_new_agent_stack`` globally vetoes
            most optional middleware.
        checkpointer: LangGraph checkpointer for checkpointed subagent runs.
        mcp_tools_by_agent: Per-subagent MCP tool permissions (optional).
        disabled_tools: Tool names to withhold from registry subagents.

    Returns:
        Middleware instances in execution order, with ``None`` entries removed.

    Raises:
        Exception: re-raises any failure from ``build_subagents`` (logged first).
    """
    # Shared between the GP subagent stack and the parent stack so both see
    # the same memory-injection instance.
    _memory_middleware = MemoryInjectionMiddleware(
        user_id=user_id,
        search_space_id=search_space_id,
        thread_visibility=visibility,
    )

    # Base middleware for the general-purpose subagent.
    gp_middleware = [
        TodoListMiddleware(),
        _memory_middleware,
        FileIntentMiddleware(llm=llm),
        SurfSenseFilesystemMiddleware(
            backend=backend_resolver,
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
            created_by_id=user_id,
            thread_id=thread_id,
        ),
        create_surfsense_compaction_middleware(llm, StateBackend),
        PatchToolCallsMiddleware(),
        AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
    ]

    # Build permission rulesets up front so the GP subagent can mirror ``ask``
    # rules into ``interrupt_on``: tool calls emitted from within ``task`` runs
    # never reach the parent's ``PermissionMiddleware``.
    is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
    permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack
    permission_rulesets: list[Ruleset] = []
    if permission_enabled or is_desktop_fs:
        # Baseline allow-all; later rulesets layer asks/denies on top.
        permission_rulesets.append(
            Ruleset(
                rules=[Rule(permission="*", pattern="*", action="allow")],
                origin="surfsense_defaults",
            )
        )
    if is_desktop_fs:
        # Destructive filesystem ops on a user's local folder always prompt.
        permission_rulesets.append(
            Ruleset(
                rules=[
                    Rule(permission="rm", pattern="*", action="ask"),
                    Rule(permission="rmdir", pattern="*", action="ask"),
                    Rule(permission="move_file", pattern="*", action="ask"),
                    Rule(permission="edit_file", pattern="*", action="ask"),
                    Rule(permission="write_file", pattern="*", action="ask"),
                ],
                origin="desktop_safety",
            )
        )

    # Tools that self-prompt via ``request_approval`` must not also appear
    # as ``ask`` rules — that would double-prompt the user for one call.
    _tool_names_in_use = {t.name for t in tools}

    # Deny parent-bound tools whose ``required_connector`` is missing.
    # No-op today (connector subagents are pruned upstream); guards future
    # additions to the parent's tool list.
    if permission_enabled:
        _available_set = set(available_connectors or [])
        _synthesized: list[Rule] = []
        for tool_def in BUILTIN_TOOLS:
            if tool_def.name not in _tool_names_in_use:
                continue
            rc = tool_def.required_connector
            if rc and rc not in _available_set:
                _synthesized.append(
                    Rule(permission=tool_def.name, pattern="*", action="deny")
                )
        if _synthesized:
            permission_rulesets.append(
                Ruleset(rules=_synthesized, origin="connector_synthesized")
            )
    # Only mirror asks for tools the GP subagent actually carries.
    gp_interrupt_on: dict[str, bool] = {
        rule.permission: True
        for rs in permission_rulesets
        for rule in rs.rules
        if rule.action == "ask" and rule.permission in _tool_names_in_use
    }

    general_purpose_spec: SubAgent = {  # type: ignore[typeddict-unknown-key]
        **GENERAL_PURPOSE_SUBAGENT,
        "model": llm,
        "tools": tools,
        "middleware": gp_middleware,
    }
    if gp_interrupt_on:
        general_purpose_spec["interrupt_on"] = gp_interrupt_on

    # Deny-only on subagents: ``task`` runs bypass the parent's
    # PermissionMiddleware, while bucket-based ask gates own the ask path.
    subagent_deny_rulesets: list[Ruleset] = [
        Ruleset(
            rules=[r for r in rs.rules if r.action == "deny"],
            origin=rs.origin,
        )
        for rs in permission_rulesets
    ]
    subagent_deny_rulesets = [rs for rs in subagent_deny_rulesets if rs.rules]

    subagent_deny_permission_mw: PermissionMiddleware | None = (
        PermissionMiddleware(rulesets=subagent_deny_rulesets)
        if subagent_deny_rulesets
        else None
    )

    if subagent_deny_permission_mw is not None:
        # Run deny check on already-repaired tool calls; insert before
        # PatchToolCallsMiddleware (append if the slot moves).
        _patch_idx = next(
            (
                i
                for i, m in enumerate(gp_middleware)
                if isinstance(m, PatchToolCallsMiddleware)
            ),
            len(gp_middleware),
        )
        gp_middleware.insert(_patch_idx, subagent_deny_permission_mw)

    registry_subagents: list[SubAgent] = []
    try:
        # Each registry subagent gets its own todo list + filesystem access,
        # plus the deny-only permission gate when one exists.
        subagent_extra_middleware: list[Any] = [
            TodoListMiddleware(),
            SurfSenseFilesystemMiddleware(
                backend=backend_resolver,
                filesystem_mode=filesystem_mode,
                search_space_id=search_space_id,
                created_by_id=user_id,
                thread_id=thread_id,
            ),
        ]
        if subagent_deny_permission_mw is not None:
            subagent_extra_middleware.append(subagent_deny_permission_mw)
        registry_subagents = build_subagents(
            dependencies=subagent_dependencies,
            model=llm,
            extra_middleware=subagent_extra_middleware,
            mcp_tools_by_agent=mcp_tools_by_agent or {},
            exclude=get_subagents_to_exclude(available_connectors),
            disabled_tools=disabled_tools,
        )
        logging.info(
            "Registry subagents: %s",
            [s["name"] for s in registry_subagents],
        )
    except Exception:
        # Fail loudly: a broken subagent registry is a config error, not a
        # degradation to hide.
        logging.exception("Registry subagent build failed")
        raise

    subagent_specs: list[SubAgent] = [general_purpose_spec, *registry_subagents]

    # Parent-level compaction (separate instance from the GP subagent's).
    summarization_mw = create_surfsense_compaction_middleware(llm, StateBackend)

    context_edit_mw = None
    if (
        flags.enable_context_editing
        and not flags.disable_new_agent_stack
        and max_input_tokens
    ):
        # Spill first (moves payloads to the backend), then clear (trims
        # what remains); both trigger at 55% of budget and free >= 15%.
        spill_edit = SpillToBackendEdit(
            trigger=int(max_input_tokens * 0.55),
            clear_at_least=int(max_input_tokens * 0.15),
            keep=5,
            exclude_tools=safe_exclude_tools(tools),
            clear_tool_inputs=True,
        )
        clear_edit = ClearToolUsesEdit(
            trigger=int(max_input_tokens * 0.55),
            clear_at_least=int(max_input_tokens * 0.15),
            keep=5,
            exclude_tools=safe_exclude_tools(tools),
            clear_tool_inputs=True,
            placeholder="[cleared - older tool output trimmed for context]",
        )
        context_edit_mw = SpillingContextEditingMiddleware(
            edits=[spill_edit, clear_edit],
            backend_resolver=backend_resolver,
        )

    retry_mw = (
        RetryAfterMiddleware(max_retries=3)
        if flags.enable_retry_after and not flags.disable_new_agent_stack
        else None
    )
    fallback_mw: ModelFallbackMiddleware | None = None
    if flags.enable_model_fallback and not flags.disable_new_agent_stack:
        try:
            fallback_mw = ModelFallbackMiddleware(
                "openai:gpt-4o-mini",
                "anthropic:claude-3-5-haiku-20241022",
            )
        except Exception:
            # Missing provider credentials/packages, presumably — degrade
            # to no fallback rather than failing the build.
            logging.warning("ModelFallbackMiddleware init failed; skipping.")
            fallback_mw = None
    model_call_limit_mw = (
        ModelCallLimitMiddleware(
            thread_limit=120,
            run_limit=80,
            exit_behavior="end",
        )
        if flags.enable_model_call_limit and not flags.disable_new_agent_stack
        else None
    )
    tool_call_limit_mw = (
        ToolCallLimitMiddleware(
            thread_limit=300, run_limit=80, exit_behavior="continue"
        )
        if flags.enable_tool_call_limit and not flags.disable_new_agent_stack
        else None
    )

    noop_mw = (
        NoopInjectionMiddleware()
        if flags.enable_compaction_v2 and not flags.disable_new_agent_stack
        else None
    )

    repair_mw = None
    if flags.enable_tool_call_repair and not flags.disable_new_agent_stack:
        registered_names: set[str] = {t.name for t in tools}
        # Built-in deepagents/filesystem tool names the model may emit that
        # are not in ``tools`` but are still valid call targets.
        registered_names |= {
            "write_todos",
            "ls",
            "read_file",
            "write_file",
            "edit_file",
            "glob",
            "grep",
            "execute",
            "task",
            "mkdir",
            "cd",
            "pwd",
            "move_file",
            "rm",
            "rmdir",
            "list_tree",
            "execute_code",
        }
        repair_mw = ToolCallNameRepairMiddleware(
            registered_tool_names=registered_names,
            fuzzy_match_threshold=None,
        )

    doom_loop_mw = (
        DoomLoopMiddleware(threshold=3)
        if flags.enable_doom_loop and not flags.disable_new_agent_stack
        else None
    )

    permission_mw: PermissionMiddleware | None = (
        PermissionMiddleware(rulesets=permission_rulesets)
        if permission_rulesets
        else None
    )

    action_log_mw: ActionLogMiddleware | None = None
    if (
        flags.enable_action_log
        and not flags.disable_new_agent_stack
        and thread_id is not None
    ):
        try:
            tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS}
            action_log_mw = ActionLogMiddleware(
                thread_id=thread_id,
                search_space_id=search_space_id,
                user_id=user_id,
                tool_definitions=tool_defs_by_name,
            )
        except Exception:  # pragma: no cover - defensive
            logging.warning(
                "ActionLogMiddleware init failed; running without it.",
                exc_info=True,
            )
            action_log_mw = None

    busy_mutex_mw: BusyMutexMiddleware | None = (
        BusyMutexMiddleware()
        if flags.enable_busy_mutex and not flags.disable_new_agent_stack
        else None
    )

    otel_mw: OtelSpanMiddleware | None = (
        OtelSpanMiddleware()
        if flags.enable_otel and not flags.disable_new_agent_stack
        else None
    )

    plugin_middlewares: list[Any] = []
    if flags.enable_plugin_loader and not flags.disable_new_agent_stack:
        try:
            allowed_names = load_allowed_plugin_names_from_env()
            if allowed_names:
                plugin_middlewares = load_plugin_middlewares(
                    PluginContext.build(
                        search_space_id=search_space_id,
                        user_id=user_id,
                        thread_visibility=visibility,
                        llm=llm,
                    ),
                    allowed_plugin_names=allowed_names,
                )
        except Exception:  # pragma: no cover - defensive
            logging.warning(
                "Plugin loader failed; continuing without plugins.",
                exc_info=True,
            )
            plugin_middlewares = []

    skills_mw: SkillsMiddleware | None = None
    if flags.enable_skills and not flags.disable_new_agent_stack:
        try:
            # Cloud mode scopes skills to the search space; otherwise global.
            skills_factory = build_skills_backend_factory(
                search_space_id=search_space_id
                if filesystem_mode == FilesystemMode.CLOUD
                else None,
            )
            skills_mw = SkillsMiddleware(
                backend=skills_factory,
                sources=default_skills_sources(),
            )
        except Exception as exc:  # pragma: no cover - defensive
            logging.warning("SkillsMiddleware init failed; skipping: %s", exc)
            skills_mw = None

    selector_mw: LLMToolSelectorMiddleware | None = None
    if (
        flags.enable_llm_tool_selector
        and not flags.disable_new_agent_stack
        and len(tools) > 30
    ):
        try:
            selector_mw = LLMToolSelectorMiddleware(
                model="openai:gpt-4o-mini",
                max_tools=12,
                always_include=[
                    name
                    for name in (
                        "update_memory",
                        "get_connected_accounts",
                        "scrape_webpage",
                    )
                    if name in {t.name for t in tools}
                ],
            )
        except Exception:
            logging.warning("LLMToolSelectorMiddleware init failed; skipping.")
            selector_mw = None

    # Final parent stack. Order matters; ``None`` placeholders (disabled
    # features, non-cloud entries) are stripped on return.
    deepagent_middleware = [
        busy_mutex_mw,
        otel_mw,
        TodoListMiddleware(),
        _memory_middleware,
        AnonymousDocumentMiddleware(
            anon_session_id=anon_session_id,
        )
        if filesystem_mode == FilesystemMode.CLOUD
        else None,
        KnowledgeTreeMiddleware(
            search_space_id=search_space_id,
            filesystem_mode=filesystem_mode,
            llm=llm,
        )
        if filesystem_mode == FilesystemMode.CLOUD
        else None,
        KnowledgePriorityMiddleware(
            llm=llm,
            search_space_id=search_space_id,
            filesystem_mode=filesystem_mode,
            available_connectors=available_connectors,
            available_document_types=available_document_types,
            mentioned_document_ids=mentioned_document_ids,
        ),
        FileIntentMiddleware(llm=llm),
        SurfSenseFilesystemMiddleware(
            backend=backend_resolver,
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
            created_by_id=user_id,
            thread_id=thread_id,
        ),
        KnowledgeBasePersistenceMiddleware(
            search_space_id=search_space_id,
            created_by_id=user_id,
            filesystem_mode=filesystem_mode,
            thread_id=thread_id,
        )
        if filesystem_mode == FilesystemMode.CLOUD
        else None,
        skills_mw,
        SurfSenseCheckpointedSubAgentMiddleware(
            checkpointer=checkpointer,
            backend=StateBackend,
            subagents=subagent_specs,
        ),
        selector_mw,
        model_call_limit_mw,
        tool_call_limit_mw,
        context_edit_mw,
        summarization_mw,
        noop_mw,
        retry_mw,
        fallback_mw,
        repair_mw,
        permission_mw,
        doom_loop_mw,
        action_log_mw,
        PatchToolCallsMiddleware(),
        DedupHITLToolCallsMiddleware(agent_tools=list(tools)),
        *plugin_middlewares,
        AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
    ]
    return [m for m in deepagent_middleware if m is not None]
|
||||
|
|
@ -2,6 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from .factory import create_surfsense_deep_agent
|
||||
from .factory import create_multi_agent_chat_deep_agent
|
||||
|
||||
__all__ = ["create_surfsense_deep_agent"]
|
||||
__all__ = ["create_multi_agent_chat_deep_agent"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,117 @@
|
|||
"""Compiled agent graph caching for the multi-agent path."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions
|
||||
from app.agents.new_chat.agent_cache import (
|
||||
flags_signature,
|
||||
get_cache,
|
||||
stable_hash,
|
||||
system_prompt_hash,
|
||||
tools_signature,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
|
||||
|
||||
|
||||
def mcp_signature(mcp_tools_by_agent: dict[str, ToolsPermissions]) -> str:
    """Hash the per-agent MCP tool surface so a change rotates the cache key."""

    def _names(entries) -> list[str]:
        # Sorted tool names; missing "name" keys collapse to "".
        return sorted(entry.get("name", "") for entry in entries)

    rows = [
        (agent_name, _names(perms.get("allow", [])), _names(perms.get("ask", [])))
        for agent_name, perms in sorted(mcp_tools_by_agent.items())
    ]
    return stable_hash(rows)
|
||||
|
||||
|
||||
async def build_agent_with_cache(
    *,
    llm: BaseChatModel,
    tools: Sequence[BaseTool],
    final_system_prompt: str,
    backend_resolver: Any,
    filesystem_mode: FilesystemMode,
    search_space_id: int,
    user_id: str | None,
    thread_id: int | None,
    visibility: ChatVisibility,
    anon_session_id: str | None,
    available_connectors: list[str],
    available_document_types: list[str],
    mentioned_document_ids: list[int] | None,
    max_input_tokens: int | None,
    flags: AgentFeatureFlags,
    checkpointer: Checkpointer,
    subagent_dependencies: dict[str, Any],
    mcp_tools_by_agent: dict[str, ToolsPermissions],
    disabled_tools: list[str] | None,
    config_id: str | None,
) -> Any:
    """Compile the multi-agent graph, serving from cache when key components are stable.

    Compilation runs in a worker thread (it is synchronous and CPU-bound);
    when ``flags.enable_agent_cache`` is on (and the new agent stack is not
    disabled), the compiled graph is memoized under a key derived from every
    per-request input that middleware captures at construction time.

    Returns:
        The compiled agent graph (type owned by the graph builder).
    """

    async def _build() -> Any:
        # Uncached path: forward every argument to the synchronous compiler
        # off the event loop.
        return await asyncio.to_thread(
            build_compiled_agent_graph_sync,
            llm=llm,
            tools=tools,
            final_system_prompt=final_system_prompt,
            backend_resolver=backend_resolver,
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
            user_id=user_id,
            thread_id=thread_id,
            visibility=visibility,
            anon_session_id=anon_session_id,
            available_connectors=available_connectors,
            available_document_types=available_document_types,
            mentioned_document_ids=mentioned_document_ids,
            max_input_tokens=max_input_tokens,
            flags=flags,
            checkpointer=checkpointer,
            subagent_dependencies=subagent_dependencies,
            mcp_tools_by_agent=mcp_tools_by_agent,
            disabled_tools=disabled_tools,
        )

    if not (flags.enable_agent_cache and not flags.disable_new_agent_stack):
        return await _build()

    # Every per-request value any middleware closes over at __init__ must be in
    # the key, otherwise a hit will leak state across threads. Bump the schema
    # version when the component list changes shape.
    cache_key = stable_hash(
        "multi-agent-v1",
        config_id,
        thread_id,
        user_id,
        search_space_id,
        visibility,
        filesystem_mode,
        anon_session_id,
        tools_signature(
            tools,
            available_connectors=available_connectors,
            available_document_types=available_document_types,
        ),
        mcp_signature(mcp_tools_by_agent),
        flags_signature(flags),
        system_prompt_hash(final_system_prompt),
        max_input_tokens,
        sorted(disabled_tools) if disabled_tools else None,
    )
    return await get_cache().get_or_build(cache_key, builder=_build)
|
||||
|
||||
|
||||
__all__ = ["build_agent_with_cache", "mcp_signature"]
|
||||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
|
|
@ -26,23 +25,24 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
|
|||
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
|
||||
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool
|
||||
from app.agents.new_chat.tools.registry import build_tools_async
|
||||
from app.db import ChatVisibility
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
|
||||
from ..system_prompt import build_main_agent_system_prompt
|
||||
from ..tools import (
|
||||
MAIN_AGENT_SURFSENSE_TOOL_NAMES,
|
||||
MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED,
|
||||
)
|
||||
from .agent_cache import build_agent_with_cache
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
async def create_surfsense_deep_agent(
|
||||
async def create_multi_agent_chat_deep_agent(
|
||||
llm: BaseChatModel,
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
|
|
@ -62,6 +62,9 @@ async def create_surfsense_deep_agent(
|
|||
):
|
||||
"""Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled."""
|
||||
_t_agent_total = time.perf_counter()
|
||||
|
||||
apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id)
|
||||
|
||||
filesystem_selection = filesystem_selection or FilesystemSelection()
|
||||
backend_resolver = build_backend_resolver(
|
||||
filesystem_selection,
|
||||
|
|
@ -85,7 +88,18 @@ async def create_surfsense_deep_agent(
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.warning("Failed to discover available connectors/document types: %s", e)
|
||||
logging.warning(
|
||||
"Connector/doc-type discovery failed; excluding connector subagents this turn: %s",
|
||||
e,
|
||||
)
|
||||
|
||||
# Fail closed: a None list short-circuits ``get_subagents_to_exclude`` to "exclude
|
||||
# nothing", which would silently advertise every connector specialist on a flaky
|
||||
# discovery call. Empty list excludes connector-gated subagents while keeping builtins.
|
||||
if available_connectors is None:
|
||||
available_connectors = []
|
||||
if available_document_types is None:
|
||||
available_document_types = []
|
||||
_perf_log.info(
|
||||
"[create_agent] Connector/doc-type discovery in %.3fs",
|
||||
time.perf_counter() - _t0,
|
||||
|
|
@ -115,7 +129,18 @@ async def create_surfsense_deep_agent(
|
|||
}
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
mcp_tools_by_agent = await load_mcp_tools_by_connector(db_session, search_space_id)
|
||||
try:
|
||||
mcp_tools_by_agent = await load_mcp_tools_by_connector(
|
||||
db_session, search_space_id
|
||||
)
|
||||
except Exception as e:
|
||||
# Degrade to builtins-only rather than aborting the turn: a transient
|
||||
# DB or MCP-server hiccup should not deny the user a response.
|
||||
logging.warning(
|
||||
"MCP tool discovery failed; subagents will run without MCP tools this turn: %s",
|
||||
e,
|
||||
)
|
||||
mcp_tools_by_agent = {}
|
||||
_perf_log.info(
|
||||
"[create_agent] load_mcp_tools_by_connector in %.3fs (%d buckets)",
|
||||
time.perf_counter() - _t0,
|
||||
|
|
@ -195,9 +220,10 @@ async def create_surfsense_deep_agent(
|
|||
|
||||
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
|
||||
|
||||
config_id = agent_config.config_id if agent_config is not None else None
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent = await asyncio.to_thread(
|
||||
build_compiled_agent_graph_sync,
|
||||
agent = await build_agent_with_cache(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
final_system_prompt=final_system_prompt,
|
||||
|
|
@ -217,6 +243,7 @@ async def create_surfsense_deep_agent(
|
|||
subagent_dependencies=dependencies,
|
||||
mcp_tools_by_agent=mcp_tools_by_agent,
|
||||
disabled_tools=disabled_tools,
|
||||
config_id=config_id,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[create_agent] Middleware stack + graph compiled in %.3fs",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
"""Multi-agent middleware stack assembly."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .stack import build_main_agent_deepagent_middleware
|
||||
|
||||
__all__ = ["build_main_agent_deepagent_middleware"]
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
"""Audit row per tool call (reversibility metadata)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import ActionLogMiddleware
|
||||
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_action_log_mw(
    *,
    flags: AgentFeatureFlags,
    thread_id: int | None,
    search_space_id: int,
    user_id: str | None,
) -> ActionLogMiddleware | None:
    """Create the per-tool-call audit middleware, or ``None`` when disabled.

    Requires both the ``enable_action_log`` flag and a concrete thread id;
    construction failures are logged and degrade to ``None``.
    """
    if not enabled(flags, "enable_action_log") or thread_id is None:
        return None
    try:
        return ActionLogMiddleware(
            thread_id=thread_id,
            search_space_id=search_space_id,
            user_id=user_id,
            tool_definitions={td.name: td for td in BUILTIN_TOOLS},
        )
    except Exception:  # pragma: no cover - defensive
        logging.warning(
            "ActionLogMiddleware init failed; running without it.",
            exc_info=True,
        )
        return None
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
"""Anonymous document hydration from Redis (cloud only)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import AnonymousDocumentMiddleware
|
||||
|
||||
|
||||
def build_anonymous_doc_mw(
    *,
    filesystem_mode: FilesystemMode,
    anon_session_id: str | None,
) -> AnonymousDocumentMiddleware | None:
    """Hydrate anonymous documents from Redis; cloud filesystem mode only."""
    return (
        AnonymousDocumentMiddleware(anon_session_id=anon_session_id)
        if filesystem_mode == FilesystemMode.CLOUD
        else None
    )
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""Per-thread cooperative lock around the whole turn."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import BusyMutexMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_busy_mutex_mw(flags: AgentFeatureFlags) -> BusyMutexMiddleware | None:
    """Per-thread cooperative lock around the whole turn; ``None`` when disabled."""
    if not enabled(flags, "enable_busy_mutex"):
        return None
    return BusyMutexMiddleware()
|
||||
|
|
@ -69,9 +69,16 @@ def build_task_tool_with_parent_config(
|
|||
raise ValueError(msg)
|
||||
|
||||
state_update = {k: v for k, v in result.items() if k not in EXCLUDED_STATE_KEYS}
|
||||
message_text = (
|
||||
result["messages"][-1].text.rstrip() if result["messages"][-1].text else ""
|
||||
)
|
||||
messages = result["messages"]
|
||||
if not messages:
|
||||
msg = (
|
||||
"CompiledSubAgent returned an empty 'messages' list. "
|
||||
"Subagents must produce at least one message so the parent has "
|
||||
"output to forward back to the user."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
last_text = getattr(messages[-1], "text", None) or ""
|
||||
message_text = last_text.rstrip()
|
||||
return Command(
|
||||
update={
|
||||
**state_update,
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
"""Spill + clear-tool-uses passes to keep payloads under budget."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.multi_agent_chat.main_agent.context_prune.prune_tool_names import (
|
||||
safe_exclude_tools,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import (
|
||||
ClearToolUsesEdit,
|
||||
SpillingContextEditingMiddleware,
|
||||
SpillToBackendEdit,
|
||||
)
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_context_editing_mw(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
max_input_tokens: int | None,
|
||||
tools: Sequence[BaseTool],
|
||||
backend_resolver: Any,
|
||||
) -> SpillingContextEditingMiddleware | None:
|
||||
if not enabled(flags, "enable_context_editing") or not max_input_tokens:
|
||||
return None
|
||||
spill_edit = SpillToBackendEdit(
|
||||
trigger=int(max_input_tokens * 0.55),
|
||||
clear_at_least=int(max_input_tokens * 0.15),
|
||||
keep=5,
|
||||
exclude_tools=safe_exclude_tools(tools),
|
||||
clear_tool_inputs=True,
|
||||
)
|
||||
clear_edit = ClearToolUsesEdit(
|
||||
trigger=int(max_input_tokens * 0.55),
|
||||
clear_at_least=int(max_input_tokens * 0.15),
|
||||
keep=5,
|
||||
exclude_tools=safe_exclude_tools(tools),
|
||||
clear_tool_inputs=True,
|
||||
placeholder="[cleared - older tool output trimmed for context]",
|
||||
)
|
||||
return SpillingContextEditingMiddleware(
|
||||
edits=[spill_edit, clear_edit],
|
||||
backend_resolver=backend_resolver,
|
||||
)
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
"""Drop duplicate HITL tool calls before execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware
|
||||
|
||||
|
||||
def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware:
|
||||
return DedupHITLToolCallsMiddleware(agent_tools=list(tools))
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
"""Stop N identical tool calls in a row via interrupt."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import DoomLoopMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_doom_loop_mw(flags: AgentFeatureFlags) -> DoomLoopMiddleware | None:
|
||||
return (
|
||||
DoomLoopMiddleware(threshold=3) if enabled(flags, "enable_doom_loop") else None
|
||||
)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
"""Commit staged cloud filesystem mutations to Postgres at end of turn."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import KnowledgeBasePersistenceMiddleware
|
||||
|
||||
|
||||
def build_kb_persistence_mw(
|
||||
*,
|
||||
filesystem_mode: FilesystemMode,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
thread_id: int | None,
|
||||
) -> KnowledgeBasePersistenceMiddleware | None:
|
||||
if filesystem_mode != FilesystemMode.CLOUD:
|
||||
return None
|
||||
return KnowledgeBasePersistenceMiddleware(
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
filesystem_mode=filesystem_mode,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
"""KB priority planner: <priority_documents> injection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import KnowledgePriorityMiddleware
|
||||
|
||||
|
||||
def build_knowledge_priority_mw(
|
||||
*,
|
||||
llm: BaseChatModel,
|
||||
search_space_id: int,
|
||||
filesystem_mode: FilesystemMode,
|
||||
available_connectors: list[str] | None,
|
||||
available_document_types: list[str] | None,
|
||||
mentioned_document_ids: list[int] | None,
|
||||
) -> KnowledgePriorityMiddleware:
|
||||
return KnowledgePriorityMiddleware(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
filesystem_mode=filesystem_mode,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
"""<workspace_tree> injection (cloud only)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import KnowledgeTreeMiddleware
|
||||
|
||||
|
||||
def build_knowledge_tree_mw(
|
||||
*,
|
||||
filesystem_mode: FilesystemMode,
|
||||
search_space_id: int,
|
||||
llm: BaseChatModel,
|
||||
) -> KnowledgeTreeMiddleware | None:
|
||||
if filesystem_mode != FilesystemMode.CLOUD:
|
||||
return None
|
||||
return KnowledgeTreeMiddleware(
|
||||
search_space_id=search_space_id,
|
||||
filesystem_mode=filesystem_mode,
|
||||
llm=llm,
|
||||
)
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""Provider-compat: append a `_noop` tool when tools=[] but history has tool calls."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import NoopInjectionMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_noop_injection_mw(flags: AgentFeatureFlags) -> NoopInjectionMiddleware | None:
|
||||
return NoopInjectionMiddleware() if enabled(flags, "enable_compaction_v2") else None
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""OTel spans on model and tool calls."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import OtelSpanMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_otel_mw(flags: AgentFeatureFlags) -> OtelSpanMiddleware | None:
|
||||
return OtelSpanMiddleware() if enabled(flags, "enable_otel") else None
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
"""Tail-of-stack plugin slot driven by env allowlist."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.plugin_loader import (
|
||||
PluginContext,
|
||||
load_allowed_plugin_names_from_env,
|
||||
load_plugin_middlewares,
|
||||
)
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_plugin_middlewares(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
visibility: ChatVisibility,
|
||||
llm: BaseChatModel,
|
||||
) -> list[Any]:
|
||||
if not enabled(flags, "enable_plugin_loader"):
|
||||
return []
|
||||
try:
|
||||
allowed_names = load_allowed_plugin_names_from_env()
|
||||
if not allowed_names:
|
||||
return []
|
||||
return load_plugin_middlewares(
|
||||
PluginContext.build(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
thread_visibility=visibility,
|
||||
llm=llm,
|
||||
),
|
||||
allowed_plugin_names=allowed_names,
|
||||
)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logging.warning(
|
||||
"Plugin loader failed; continuing without plugins.",
|
||||
exc_info=True,
|
||||
)
|
||||
return []
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
"""Repair miscased / unknown tool names to the registered set or invalid_tool."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import ToolCallNameRepairMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
# deepagents-built-in tool names the repair pass treats as known.
|
||||
_DEEPAGENT_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
|
||||
{
|
||||
"write_todos",
|
||||
"ls",
|
||||
"read_file",
|
||||
"write_file",
|
||||
"edit_file",
|
||||
"glob",
|
||||
"grep",
|
||||
"execute",
|
||||
"task",
|
||||
"mkdir",
|
||||
"cd",
|
||||
"pwd",
|
||||
"move_file",
|
||||
"rm",
|
||||
"rmdir",
|
||||
"list_tree",
|
||||
"execute_code",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def build_repair_mw(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
tools: Sequence[BaseTool],
|
||||
) -> ToolCallNameRepairMiddleware | None:
|
||||
if not enabled(flags, "enable_tool_call_repair"):
|
||||
return None
|
||||
registered_names: set[str] = {t.name for t in tools}
|
||||
registered_names |= _DEEPAGENT_BUILTIN_TOOL_NAMES
|
||||
return ToolCallNameRepairMiddleware(
|
||||
registered_tool_names=registered_names,
|
||||
fuzzy_match_threshold=None,
|
||||
)
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
"""LLM-based tool subset selection (only when >30 tools)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain.agents.middleware import LLMToolSelectorMiddleware
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_selector_mw(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
tools: Sequence[BaseTool],
|
||||
) -> LLMToolSelectorMiddleware | None:
|
||||
if not enabled(flags, "enable_llm_tool_selector") or len(tools) <= 30:
|
||||
return None
|
||||
try:
|
||||
return LLMToolSelectorMiddleware(
|
||||
model="openai:gpt-4o-mini",
|
||||
max_tools=12,
|
||||
always_include=[
|
||||
name
|
||||
for name in (
|
||||
"update_memory",
|
||||
"get_connected_accounts",
|
||||
"scrape_webpage",
|
||||
)
|
||||
if name in {t.name for t in tools}
|
||||
],
|
||||
)
|
||||
except Exception:
|
||||
logging.warning("LLMToolSelectorMiddleware init failed; skipping.")
|
||||
return None
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
"""Skill discovery + injection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from deepagents.middleware.skills import SkillsMiddleware
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import (
|
||||
build_skills_backend_factory,
|
||||
default_skills_sources,
|
||||
)
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_skills_mw(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
filesystem_mode: FilesystemMode,
|
||||
search_space_id: int,
|
||||
) -> SkillsMiddleware | None:
|
||||
if not enabled(flags, "enable_skills"):
|
||||
return None
|
||||
try:
|
||||
skills_factory = build_skills_backend_factory(
|
||||
search_space_id=search_space_id
|
||||
if filesystem_mode == FilesystemMode.CLOUD
|
||||
else None,
|
||||
)
|
||||
return SkillsMiddleware(
|
||||
backend=skills_factory,
|
||||
sources=default_skills_sources(),
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logging.warning("SkillsMiddleware init failed; skipping: %s", exc)
|
||||
return None
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
"""Anthropic prompt caching annotations on system/tool/message blocks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||
|
||||
|
||||
def build_anthropic_cache_mw() -> AnthropicPromptCachingMiddleware:
|
||||
return AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
"""Context-window summarization with SurfSense protected sections."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from deepagents.backends import StateBackend
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.new_chat.middleware import create_surfsense_compaction_middleware
|
||||
|
||||
|
||||
def build_compaction_mw(llm: BaseChatModel) -> Any:
|
||||
return create_surfsense_compaction_middleware(llm, StateBackend)
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
"""File-intent classifier that gates strict write contracts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.new_chat.middleware import FileIntentMiddleware
|
||||
|
||||
|
||||
def build_file_intent_mw(llm: BaseChatModel) -> FileIntentMiddleware:
|
||||
return FileIntentMiddleware(llm=llm)
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
"""SurfSense filesystem tools/middleware."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import SurfSenseFilesystemMiddleware
|
||||
|
||||
|
||||
def build_filesystem_mw(
|
||||
*,
|
||||
backend_resolver: Any,
|
||||
filesystem_mode: FilesystemMode,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
thread_id: int | None,
|
||||
) -> SurfSenseFilesystemMiddleware:
|
||||
return SurfSenseFilesystemMiddleware(
|
||||
backend=backend_resolver,
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
"""Single source of truth for the feature-flag predicate."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
|
||||
|
||||
def enabled(flags: AgentFeatureFlags, attr: str) -> bool:
|
||||
"""``flags.<attr>`` is on AND the new-agent-stack kill switch is off."""
|
||||
return getattr(flags, attr) and not flags.disable_new_agent_stack
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
"""User/team memory injection prepended to the conversation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.middleware import MemoryInjectionMiddleware
|
||||
from app.db import ChatVisibility
|
||||
|
||||
|
||||
def build_memory_mw(
|
||||
*,
|
||||
user_id: str | None,
|
||||
search_space_id: int,
|
||||
visibility: ChatVisibility,
|
||||
) -> MemoryInjectionMiddleware:
|
||||
return MemoryInjectionMiddleware(
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
thread_visibility=visibility,
|
||||
)
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
"""Repair dangling tool-call sequences before each agent turn."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
|
||||
|
||||
def build_patch_tool_calls_mw() -> PatchToolCallsMiddleware:
|
||||
return PatchToolCallsMiddleware()
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""Permission rulesets fanned out to parent / general-purpose / subagent stacks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .context import PermissionContext, build_permission_context
|
||||
from .middleware import build_full_permission_mw
|
||||
|
||||
__all__ = [
|
||||
"PermissionContext",
|
||||
"build_full_permission_mw",
|
||||
"build_permission_context",
|
||||
]
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
"""Derive shared permission context once; fan out to all three stack layers.
|
||||
|
||||
The context carries:
|
||||
- ``rulesets``: full ask/deny/allow rules for the main-agent permission middleware.
|
||||
- ``general_purpose_interrupt_on``: ``ask`` rules mirrored as deepagents
|
||||
``interrupt_on`` so HITL still triggers from inside ``task`` runs (subagents
|
||||
bypass the main-agent permission middleware).
|
||||
- ``subagent_deny_mw``: a deny-only ``PermissionMiddleware`` instance shared
|
||||
across the general-purpose and registry subagent stacks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import PermissionMiddleware
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
|
||||
|
||||
from ..flags import enabled
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PermissionContext:
|
||||
rulesets: list[Ruleset]
|
||||
general_purpose_interrupt_on: dict[str, bool]
|
||||
subagent_deny_mw: PermissionMiddleware | None
|
||||
|
||||
|
||||
def build_permission_context(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
filesystem_mode: FilesystemMode,
|
||||
tools: Sequence[BaseTool],
|
||||
available_connectors: list[str] | None,
|
||||
) -> PermissionContext:
|
||||
is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
|
||||
permission_enabled = enabled(flags, "enable_permission")
|
||||
|
||||
rulesets: list[Ruleset] = []
|
||||
if permission_enabled or is_desktop_fs:
|
||||
rulesets.append(
|
||||
Ruleset(
|
||||
rules=[Rule(permission="*", pattern="*", action="allow")],
|
||||
origin="surfsense_defaults",
|
||||
)
|
||||
)
|
||||
if is_desktop_fs:
|
||||
rulesets.append(
|
||||
Ruleset(
|
||||
rules=[
|
||||
Rule(permission="rm", pattern="*", action="ask"),
|
||||
Rule(permission="rmdir", pattern="*", action="ask"),
|
||||
Rule(permission="move_file", pattern="*", action="ask"),
|
||||
Rule(permission="edit_file", pattern="*", action="ask"),
|
||||
Rule(permission="write_file", pattern="*", action="ask"),
|
||||
],
|
||||
origin="desktop_safety",
|
||||
)
|
||||
)
|
||||
|
||||
tool_names_in_use = {t.name for t in tools}
|
||||
|
||||
if permission_enabled:
|
||||
available_set = set(available_connectors or [])
|
||||
synthesized: list[Rule] = []
|
||||
for tool_def in BUILTIN_TOOLS:
|
||||
if tool_def.name not in tool_names_in_use:
|
||||
continue
|
||||
rc = tool_def.required_connector
|
||||
if rc and rc not in available_set:
|
||||
synthesized.append(
|
||||
Rule(permission=tool_def.name, pattern="*", action="deny")
|
||||
)
|
||||
if synthesized:
|
||||
rulesets.append(Ruleset(rules=synthesized, origin="connector_synthesized"))
|
||||
|
||||
general_purpose_interrupt_on: dict[str, bool] = {
|
||||
rule.permission: True
|
||||
for rs in rulesets
|
||||
for rule in rs.rules
|
||||
if rule.action == "ask" and rule.permission in tool_names_in_use
|
||||
}
|
||||
|
||||
deny_rulesets = [
|
||||
Ruleset(
|
||||
rules=[r for r in rs.rules if r.action == "deny"],
|
||||
origin=rs.origin,
|
||||
)
|
||||
for rs in rulesets
|
||||
]
|
||||
deny_rulesets = [rs for rs in deny_rulesets if rs.rules]
|
||||
|
||||
subagent_deny_mw: PermissionMiddleware | None = (
|
||||
PermissionMiddleware(rulesets=deny_rulesets) if deny_rulesets else None
|
||||
)
|
||||
|
||||
return PermissionContext(
|
||||
rulesets=rulesets,
|
||||
general_purpose_interrupt_on=general_purpose_interrupt_on,
|
||||
subagent_deny_mw=subagent_deny_mw,
|
||||
)
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
"""Main-agent permission middleware (full ask/deny/allow rules)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.middleware import PermissionMiddleware
|
||||
from app.agents.new_chat.permissions import Ruleset
|
||||
|
||||
|
||||
def build_full_permission_mw(rulesets: list[Ruleset]) -> PermissionMiddleware | None:
|
||||
return PermissionMiddleware(rulesets=rulesets) if rulesets else None
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
"""Resilience middleware shared as the same instances across parent / general-purpose / registry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .bundle import ResilienceBundle, build_resilience_bundle
|
||||
|
||||
__all__ = ["ResilienceBundle", "build_resilience_bundle"]
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
"""Construct each resilience middleware once; same instances flow into every consumer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import (
|
||||
ModelCallLimitMiddleware,
|
||||
ToolCallLimitMiddleware,
|
||||
)
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import RetryAfterMiddleware
|
||||
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||
ScopedModelFallbackMiddleware,
|
||||
)
|
||||
|
||||
from .fallback import build_fallback_mw
|
||||
from .model_call_limit import build_model_call_limit_mw
|
||||
from .retry import build_retry_mw
|
||||
from .tool_call_limit import build_tool_call_limit_mw
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResilienceBundle:
|
||||
retry: RetryAfterMiddleware | None
|
||||
fallback: ScopedModelFallbackMiddleware | None
|
||||
model_call_limit: ModelCallLimitMiddleware | None
|
||||
tool_call_limit: ToolCallLimitMiddleware | None
|
||||
|
||||
def as_list(self) -> list[Any]:
|
||||
return [
|
||||
m
|
||||
for m in (
|
||||
self.retry,
|
||||
self.fallback,
|
||||
self.model_call_limit,
|
||||
self.tool_call_limit,
|
||||
)
|
||||
if m is not None
|
||||
]
|
||||
|
||||
|
||||
def build_resilience_bundle(flags: AgentFeatureFlags) -> ResilienceBundle:
|
||||
return ResilienceBundle(
|
||||
retry=build_retry_mw(flags),
|
||||
fallback=build_fallback_mw(flags),
|
||||
model_call_limit=build_model_call_limit_mw(flags),
|
||||
tool_call_limit=build_tool_call_limit_mw(flags),
|
||||
)
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
"""Switch to a fallback model on provider/network errors only."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||
ScopedModelFallbackMiddleware,
|
||||
)
|
||||
|
||||
from ..flags import enabled
|
||||
|
||||
|
||||
def build_fallback_mw(
|
||||
flags: AgentFeatureFlags,
|
||||
) -> ScopedModelFallbackMiddleware | None:
|
||||
if not enabled(flags, "enable_model_fallback"):
|
||||
return None
|
||||
try:
|
||||
return ScopedModelFallbackMiddleware(
|
||||
"openai:gpt-4o-mini",
|
||||
"anthropic:claude-3-5-haiku-20241022",
|
||||
)
|
||||
except Exception:
|
||||
logging.warning("ScopedModelFallbackMiddleware init failed; skipping.")
|
||||
return None
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
"""Cap model calls per thread / per run to prevent runaway cost."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.agents.middleware import ModelCallLimitMiddleware
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
|
||||
from ..flags import enabled
|
||||
|
||||
|
||||
def build_model_call_limit_mw(
|
||||
flags: AgentFeatureFlags,
|
||||
) -> ModelCallLimitMiddleware | None:
|
||||
if not enabled(flags, "enable_model_call_limit"):
|
||||
return None
|
||||
return ModelCallLimitMiddleware(
|
||||
thread_limit=120,
|
||||
run_limit=80,
|
||||
exit_behavior="end",
|
||||
)
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
"""Retry on transient model errors (e.g. Retry-After-bearing 429s)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import RetryAfterMiddleware
|
||||
|
||||
from ..flags import enabled
|
||||
|
||||
|
||||
def build_retry_mw(flags: AgentFeatureFlags) -> RetryAfterMiddleware | None:
|
||||
return (
|
||||
RetryAfterMiddleware(max_retries=3)
|
||||
if enabled(flags, "enable_retry_after")
|
||||
else None
|
||||
)
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
"""Cap tool calls per thread / per run to bound infinite-loop blast radius."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.agents.middleware import ToolCallLimitMiddleware
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
|
||||
from ..flags import enabled
|
||||
|
||||
|
||||
def build_tool_call_limit_mw(
|
||||
flags: AgentFeatureFlags,
|
||||
) -> ToolCallLimitMiddleware | None:
|
||||
if not enabled(flags, "enable_tool_call_limit"):
|
||||
return None
|
||||
return ToolCallLimitMiddleware(
|
||||
thread_limit=300,
|
||||
run_limit=80,
|
||||
exit_behavior="continue",
|
||||
)
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
"""Todo-list middleware (each consumer needs its own instance)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.agents.middleware import TodoListMiddleware
|
||||
|
||||
|
||||
def build_todos_mw() -> TodoListMiddleware:
|
||||
return TodoListMiddleware()
|
||||
|
|
@ -0,0 +1,216 @@
|
|||
"""Main-agent middleware list assembly: one line per slot."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from deepagents import SubAgent
|
||||
from deepagents.backends import StateBackend
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.subagents import (
|
||||
build_subagents,
|
||||
get_subagents_to_exclude,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.builtins.general_purpose.agent import (
|
||||
build_subagent as build_general_purpose_subagent,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from .main_agent.action_log import build_action_log_mw
|
||||
from .main_agent.anonymous_doc import build_anonymous_doc_mw
|
||||
from .main_agent.busy_mutex import build_busy_mutex_mw
|
||||
from .main_agent.checkpointed_subagent_middleware import (
|
||||
SurfSenseCheckpointedSubAgentMiddleware,
|
||||
)
|
||||
from .main_agent.context_editing import build_context_editing_mw
|
||||
from .main_agent.dedup_hitl import build_dedup_hitl_mw
|
||||
from .main_agent.doom_loop import build_doom_loop_mw
|
||||
from .main_agent.kb_persistence import build_kb_persistence_mw
|
||||
from .main_agent.knowledge_priority import build_knowledge_priority_mw
|
||||
from .main_agent.knowledge_tree import build_knowledge_tree_mw
|
||||
from .main_agent.noop_injection import build_noop_injection_mw
|
||||
from .main_agent.otel import build_otel_mw
|
||||
from .main_agent.plugins import build_plugin_middlewares
|
||||
from .main_agent.repair import build_repair_mw
|
||||
from .main_agent.selector import build_selector_mw
|
||||
from .main_agent.skills import build_skills_mw
|
||||
from .shared.anthropic_cache import build_anthropic_cache_mw
|
||||
from .shared.compaction import build_compaction_mw
|
||||
from .shared.file_intent import build_file_intent_mw
|
||||
from .shared.filesystem import build_filesystem_mw
|
||||
from .shared.memory import build_memory_mw
|
||||
from .shared.patch_tool_calls import build_patch_tool_calls_mw
|
||||
from .shared.permissions import (
|
||||
build_full_permission_mw,
|
||||
build_permission_context,
|
||||
)
|
||||
from .shared.resilience import build_resilience_bundle
|
||||
from .shared.todos import build_todos_mw
|
||||
from .subagent.extras import build_subagent_extras
|
||||
|
||||
|
||||
def build_main_agent_deepagent_middleware(
|
||||
*,
|
||||
llm: BaseChatModel,
|
||||
tools: Sequence[BaseTool],
|
||||
backend_resolver: Any,
|
||||
filesystem_mode: FilesystemMode,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
thread_id: int | None,
|
||||
visibility: ChatVisibility,
|
||||
anon_session_id: str | None,
|
||||
available_connectors: list[str] | None,
|
||||
available_document_types: list[str] | None,
|
||||
mentioned_document_ids: list[int] | None,
|
||||
max_input_tokens: int | None,
|
||||
flags: AgentFeatureFlags,
|
||||
subagent_dependencies: dict[str, Any],
|
||||
checkpointer: Checkpointer,
|
||||
mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
) -> list[Any]:
|
||||
"""Ordered middleware for ``create_agent`` (None entries already stripped)."""
|
||||
permissions = build_permission_context(
|
||||
flags=flags,
|
||||
filesystem_mode=filesystem_mode,
|
||||
tools=tools,
|
||||
available_connectors=available_connectors,
|
||||
)
|
||||
resilience = build_resilience_bundle(flags)
|
||||
|
||||
# Single instance threaded into both the main-agent stack and the general-purpose subagent.
|
||||
memory_mw = build_memory_mw(
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
visibility=visibility,
|
||||
)
|
||||
|
||||
general_purpose_subagent = build_general_purpose_subagent(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
backend_resolver=backend_resolver,
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
permissions=permissions,
|
||||
resilience=resilience,
|
||||
memory_mw=memory_mw,
|
||||
)
|
||||
|
||||
subagents_registry: list[SubAgent] = []
|
||||
try:
|
||||
subagent_extras = build_subagent_extras(
|
||||
permissions=permissions,
|
||||
resilience=resilience,
|
||||
)
|
||||
subagents_registry = build_subagents(
|
||||
dependencies=subagent_dependencies,
|
||||
model=llm,
|
||||
extra_middleware=subagent_extras,
|
||||
mcp_tools_by_agent=mcp_tools_by_agent or {},
|
||||
exclude=get_subagents_to_exclude(available_connectors),
|
||||
disabled_tools=disabled_tools,
|
||||
)
|
||||
logging.debug(
|
||||
"Subagents registry: %s",
|
||||
[s["name"] for s in subagents_registry],
|
||||
)
|
||||
except Exception:
|
||||
# Degrade to general-purpose-only rather than aborting the turn:
|
||||
# one bad subagent dep should not deny the user a response.
|
||||
logging.exception(
|
||||
"Subagents registry build failed; falling back to general-purpose only"
|
||||
)
|
||||
subagents_registry = []
|
||||
|
||||
subagents: list[SubAgent] = [general_purpose_subagent, *subagents_registry]
|
||||
|
||||
stack: list[Any] = [
|
||||
build_busy_mutex_mw(flags),
|
||||
build_otel_mw(flags),
|
||||
build_todos_mw(),
|
||||
memory_mw,
|
||||
build_anonymous_doc_mw(
|
||||
filesystem_mode=filesystem_mode, anon_session_id=anon_session_id
|
||||
),
|
||||
build_knowledge_tree_mw(
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
llm=llm,
|
||||
),
|
||||
build_knowledge_priority_mw(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
filesystem_mode=filesystem_mode,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
),
|
||||
build_file_intent_mw(llm),
|
||||
build_filesystem_mw(
|
||||
backend_resolver=backend_resolver,
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
),
|
||||
build_kb_persistence_mw(
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
),
|
||||
build_skills_mw(
|
||||
flags=flags,
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
),
|
||||
SurfSenseCheckpointedSubAgentMiddleware(
|
||||
checkpointer=checkpointer,
|
||||
backend=StateBackend,
|
||||
subagents=subagents,
|
||||
),
|
||||
build_selector_mw(flags=flags, tools=tools),
|
||||
resilience.model_call_limit,
|
||||
resilience.tool_call_limit,
|
||||
build_context_editing_mw(
|
||||
flags=flags,
|
||||
max_input_tokens=max_input_tokens,
|
||||
tools=tools,
|
||||
backend_resolver=backend_resolver,
|
||||
),
|
||||
build_compaction_mw(llm),
|
||||
build_noop_injection_mw(flags),
|
||||
resilience.retry,
|
||||
resilience.fallback,
|
||||
build_repair_mw(flags=flags, tools=tools),
|
||||
build_full_permission_mw(permissions.rulesets),
|
||||
build_doom_loop_mw(flags),
|
||||
build_action_log_mw(
|
||||
flags=flags,
|
||||
thread_id=thread_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
),
|
||||
build_patch_tool_calls_mw(),
|
||||
build_dedup_hitl_mw(tools),
|
||||
*build_plugin_middlewares(
|
||||
flags=flags,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
visibility=visibility,
|
||||
llm=llm,
|
||||
),
|
||||
build_anthropic_cache_mw(),
|
||||
]
|
||||
return [m for m in stack if m is not None]
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
"""Extra middleware threaded into every registry subagent's stack.
|
||||
|
||||
Registry subagents are scoped to one domain (deliverables, research, memory,
|
||||
connectors, MCP) and never read or write the SurfSense filesystem — that
|
||||
capability belongs to the main agent and is delegated to the general-purpose
|
||||
subagent as an escape hatch. Keeping FS off the registry stacks avoids
|
||||
polluting their tool surface with FS tools they never act on.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ..shared.permissions import PermissionContext
|
||||
from ..shared.resilience import ResilienceBundle
|
||||
from ..shared.todos import build_todos_mw
|
||||
|
||||
|
||||
def build_subagent_extras(
    *,
    permissions: PermissionContext,
    resilience: ResilienceBundle,
) -> list[Any]:
    """Return the middleware appended to every registry subagent's stack.

    Always starts with the todos middleware, then adds the optional
    subagent deny middleware (when permissions define one), and finally
    the resilience middlewares in their bundle order.
    """
    deny_mw = permissions.subagent_deny_mw
    stack: list[Any] = [build_todos_mw()]
    if deny_mw is not None:
        stack.append(deny_mw)
    stack += resilience.as_list()
    return stack
|
||||
|
|
@ -0,0 +1,105 @@
|
|||
"""General-purpose subagent for the multi-agent main agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from deepagents import SubAgent
|
||||
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.shared.anthropic_cache import (
|
||||
build_anthropic_cache_mw,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.compaction import (
|
||||
build_compaction_mw,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.file_intent import (
|
||||
build_file_intent_mw,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.filesystem import (
|
||||
build_filesystem_mw,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.patch_tool_calls import (
|
||||
build_patch_tool_calls_mw,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions import (
|
||||
PermissionContext,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.resilience import (
|
||||
ResilienceBundle,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.todos import build_todos_mw
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import MemoryInjectionMiddleware
|
||||
|
||||
NAME = "general-purpose"
|
||||
|
||||
|
||||
def build_subagent(
    *,
    llm: BaseChatModel,
    tools: Sequence[BaseTool],
    backend_resolver: Any,
    filesystem_mode: FilesystemMode,
    search_space_id: int,
    user_id: str | None,
    thread_id: int | None,
    permissions: PermissionContext,
    resilience: ResilienceBundle,
    memory_mw: MemoryInjectionMiddleware,
) -> SubAgent:
    """Deny + resilience inserts encapsulated here so the orchestrator never mutates the list."""

    def _first_index_of(stack: list[Any], cls: type) -> int:
        # Position of the first middleware of type ``cls``; falls back to the
        # end of the stack when none is present.
        for pos, entry in enumerate(stack):
            if isinstance(entry, cls):
                return pos
        return len(stack)

    stack: list[Any] = [
        build_todos_mw(),
        memory_mw,
        build_file_intent_mw(llm),
        build_filesystem_mw(
            backend_resolver=backend_resolver,
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
            user_id=user_id,
            thread_id=thread_id,
        ),
        build_compaction_mw(llm),
        build_patch_tool_calls_mw(),
        build_anthropic_cache_mw(),
    ]

    # Permission denies must be evaluated before tool-call patching runs.
    deny_mw = permissions.subagent_deny_mw
    if deny_mw is not None:
        stack.insert(_first_index_of(stack, PatchToolCallsMiddleware), deny_mw)

    # Resilience wrappers slot in just ahead of the Anthropic cache layer,
    # preserving the bundle's internal ordering.
    resilience_stack = resilience.as_list()
    if resilience_stack:
        anchor = _first_index_of(stack, AnthropicPromptCachingMiddleware)
        stack[anchor:anchor] = resilience_stack

    spec: dict[str, Any] = {
        **GENERAL_PURPOSE_SUBAGENT,
        "model": llm,
        "tools": tools,
        "middleware": stack,
    }
    if permissions.general_purpose_interrupt_on:
        spec["interrupt_on"] = permissions.general_purpose_interrupt_on
    return cast(SubAgent, spec)
|
||||
|
|
@ -168,20 +168,46 @@ def create_create_calendar_event_tool(
|
|||
f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
tz = context.get("timezone", "UTC")
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
(
|
||||
event_id,
|
||||
html_link,
|
||||
error,
|
||||
) = await ComposioService().create_calendar_event(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
summary=final_summary,
|
||||
start_datetime=final_start_datetime,
|
||||
end_datetime=final_end_datetime,
|
||||
timezone=tz,
|
||||
description=final_description,
|
||||
location=final_location,
|
||||
attendees=final_attendees,
|
||||
)
|
||||
if error:
|
||||
return {"status": "error", "message": error}
|
||||
created = {
|
||||
"id": event_id,
|
||||
"summary": final_summary,
|
||||
"htmlLink": html_link,
|
||||
}
|
||||
logger.info(
|
||||
f"Calendar event created via Composio: id={event_id}, summary={final_summary}"
|
||||
)
|
||||
else:
|
||||
config_data = dict(connector.config)
|
||||
|
||||
|
|
@ -211,70 +237,69 @@ def create_create_calendar_event_tool(
|
|||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
tz = context.get("timezone", "UTC")
|
||||
event_body: dict[str, Any] = {
|
||||
"summary": final_summary,
|
||||
"start": {"dateTime": final_start_datetime, "timeZone": tz},
|
||||
"end": {"dateTime": final_end_datetime, "timeZone": tz},
|
||||
}
|
||||
if final_description:
|
||||
event_body["description"] = final_description
|
||||
if final_location:
|
||||
event_body["location"] = final_location
|
||||
if final_attendees:
|
||||
event_body["attendees"] = [
|
||||
{"email": e.strip()} for e in final_attendees if e.strip()
|
||||
]
|
||||
|
||||
try:
|
||||
created = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.insert(calendarId="primary", body=event_body)
|
||||
.execute()
|
||||
),
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
event_body: dict[str, Any] = {
|
||||
"summary": final_summary,
|
||||
"start": {"dateTime": final_start_datetime, "timeZone": tz},
|
||||
"end": {"dateTime": final_end_datetime, "timeZone": tz},
|
||||
}
|
||||
if final_description:
|
||||
event_body["description"] = final_description
|
||||
if final_location:
|
||||
event_body["location"] = final_location
|
||||
if final_attendees:
|
||||
event_body["attendees"] = [
|
||||
{"email": e.strip()} for e in final_attendees if e.strip()
|
||||
]
|
||||
|
||||
try:
|
||||
created = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.insert(calendarId="primary", body=event_body)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
logger.info(
|
||||
f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
|
||||
)
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Calendar event created via Google API: id={created.get('id')}, summary={created.get('summary')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -163,16 +163,22 @@ def create_delete_calendar_event_tool(
|
|||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
error = await ComposioService().delete_calendar_event(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
event_id=final_event_id,
|
||||
)
|
||||
if error:
|
||||
return {"status": "error", "message": error}
|
||||
else:
|
||||
config_data = dict(connector.config)
|
||||
|
||||
|
|
@ -202,51 +208,51 @@ def create_delete_calendar_event_tool(
|
|||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.delete(calendarId="primary", eventId=final_event_id)
|
||||
.execute()
|
||||
),
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
try:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.delete(calendarId="primary", eventId=final_event_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Calendar event deleted: event_id={final_event_id}")
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,14 @@ _CALENDAR_TYPES = [
|
|||
]
|
||||
|
||||
|
||||
def _to_calendar_boundary(value: str, *, is_end: bool) -> str:
|
||||
"""Promote a bare YYYY-MM-DD to RFC3339 with a day-edge time, leave full datetimes alone."""
|
||||
if "T" in value:
|
||||
return value
|
||||
time = "23:59:59" if is_end else "00:00:00"
|
||||
return f"{value}T{time}Z"
|
||||
|
||||
|
||||
def create_search_calendar_events_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
|
|
@ -61,22 +69,47 @@ def create_search_calendar_events_tool(
|
|||
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
|
||||
}
|
||||
|
||||
creds = _build_credentials(connector)
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
|
||||
from app.connectors.google_calendar_connector import GoogleCalendarConnector
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
cal = GoogleCalendarConnector(
|
||||
credentials=creds,
|
||||
session=db_session,
|
||||
user_id=user_id,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
events_raw, error = await ComposioService().get_calendar_events(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
time_min=_to_calendar_boundary(start_date, is_end=False),
|
||||
time_max=_to_calendar_boundary(end_date, is_end=True),
|
||||
max_results=max_results,
|
||||
)
|
||||
if not events_raw and not error:
|
||||
error = "No events found in the specified date range."
|
||||
else:
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
events_raw, error = await cal.get_all_primary_calendar_events(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
max_results=max_results,
|
||||
)
|
||||
from app.connectors.google_calendar_connector import (
|
||||
GoogleCalendarConnector,
|
||||
)
|
||||
|
||||
cal = GoogleCalendarConnector(
|
||||
credentials=creds,
|
||||
session=db_session,
|
||||
user_id=user_id,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
|
||||
events_raw, error = await cal.get_all_primary_calendar_events(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
if error:
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -192,20 +192,62 @@ def create_update_calendar_event_tool(
|
|||
f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
has_changes = any(
|
||||
v is not None
|
||||
for v in (
|
||||
final_new_summary,
|
||||
final_new_start_datetime,
|
||||
final_new_end_datetime,
|
||||
final_new_description,
|
||||
final_new_location,
|
||||
final_new_attendees,
|
||||
)
|
||||
)
|
||||
if not has_changes:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No changes specified. Please provide at least one field to update.",
|
||||
}
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
tz_for_composio: str | None = None
|
||||
if final_new_start_datetime is not None and not _is_date_only(
|
||||
final_new_start_datetime
|
||||
):
|
||||
tz_for_composio = (
|
||||
context.get("timezone") if isinstance(context, dict) else None
|
||||
)
|
||||
|
||||
_, html_link, error = await ComposioService().update_calendar_event(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
event_id=final_event_id,
|
||||
summary=final_new_summary,
|
||||
start_time=final_new_start_datetime,
|
||||
end_time=final_new_end_datetime,
|
||||
timezone=tz_for_composio,
|
||||
description=final_new_description,
|
||||
location=final_new_location,
|
||||
attendees=final_new_attendees,
|
||||
)
|
||||
if error:
|
||||
return {"status": "error", "message": error}
|
||||
updated = {"htmlLink": html_link}
|
||||
logger.info(
|
||||
f"Calendar event updated via Composio: event_id={final_event_id}"
|
||||
)
|
||||
else:
|
||||
config_data = dict(connector.config)
|
||||
|
||||
|
|
@ -235,81 +277,79 @@ def create_update_calendar_event_tool(
|
|||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
update_body: dict[str, Any] = {}
|
||||
if final_new_summary is not None:
|
||||
update_body["summary"] = final_new_summary
|
||||
if final_new_start_datetime is not None:
|
||||
update_body["start"] = _build_time_body(
|
||||
final_new_start_datetime, context
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
if final_new_end_datetime is not None:
|
||||
update_body["end"] = _build_time_body(final_new_end_datetime, context)
|
||||
if final_new_description is not None:
|
||||
update_body["description"] = final_new_description
|
||||
if final_new_location is not None:
|
||||
update_body["location"] = final_new_location
|
||||
if final_new_attendees is not None:
|
||||
update_body["attendees"] = [
|
||||
{"email": e.strip()} for e in final_new_attendees if e.strip()
|
||||
]
|
||||
|
||||
if not update_body:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No changes specified. Please provide at least one field to update.",
|
||||
}
|
||||
|
||||
try:
|
||||
updated = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.patch(
|
||||
calendarId="primary",
|
||||
eventId=final_event_id,
|
||||
body=update_body,
|
||||
)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
update_body: dict[str, Any] = {}
|
||||
if final_new_summary is not None:
|
||||
update_body["summary"] = final_new_summary
|
||||
if final_new_start_datetime is not None:
|
||||
update_body["start"] = _build_time_body(
|
||||
final_new_start_datetime, context
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
if final_new_end_datetime is not None:
|
||||
update_body["end"] = _build_time_body(
|
||||
final_new_end_datetime, context
|
||||
)
|
||||
if final_new_description is not None:
|
||||
update_body["description"] = final_new_description
|
||||
if final_new_location is not None:
|
||||
update_body["location"] = final_new_location
|
||||
if final_new_attendees is not None:
|
||||
update_body["attendees"] = [
|
||||
{"email": e.strip()} for e in final_new_attendees if e.strip()
|
||||
]
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
try:
|
||||
updated = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.patch(
|
||||
calendarId="primary",
|
||||
eventId=final_event_id,
|
||||
body=update_body,
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
logger.info(f"Calendar event updated: event_id={final_event_id}")
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Calendar event updated via Google API: event_id={final_event_id}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
if document_id is not None:
|
||||
|
|
|
|||
|
|
@ -161,16 +161,39 @@ def create_create_gmail_draft_tool(
|
|||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
(
|
||||
draft_id,
|
||||
draft_message_id,
|
||||
draft_thread_id,
|
||||
error,
|
||||
) = await ComposioService().create_gmail_draft(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
to=final_to,
|
||||
subject=final_subject,
|
||||
body=final_body,
|
||||
cc=final_cc,
|
||||
bcc=final_bcc,
|
||||
)
|
||||
if error:
|
||||
return {"status": "error", "message": error}
|
||||
created = {
|
||||
"id": draft_id,
|
||||
"message": {
|
||||
"id": draft_message_id,
|
||||
"threadId": draft_thread_id,
|
||||
},
|
||||
}
|
||||
logger.info(f"Gmail draft created via Composio: id={draft_id}")
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
|
|
@ -208,63 +231,65 @@ def create_create_gmail_draft_tool(
|
|||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
message = MIMEText(final_body)
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
message = MIMEText(final_body)
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
created = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.drafts()
|
||||
.create(userId="me", body={"message": {"raw": raw}})
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
try:
|
||||
created = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.drafts()
|
||||
.create(userId="me", body={"message": {"raw": raw}})
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
logger.info(f"Gmail draft created: id={created.get('id')}")
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Gmail draft created via Google API: id={created.get('id')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -50,7 +50,56 @@ def create_read_gmail_email_tool(
|
|||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import (
|
||||
_format_gmail_summary,
|
||||
)
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
detail, error = await ComposioService().get_gmail_message_detail(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
message_id=message_id,
|
||||
)
|
||||
if error:
|
||||
return {"status": "error", "message": error}
|
||||
if not detail:
|
||||
return {
|
||||
"status": "not_found",
|
||||
"message": f"Email with ID '{message_id}' not found.",
|
||||
}
|
||||
|
||||
summary = _format_gmail_summary(detail)
|
||||
content = (
|
||||
f"# {summary['subject']}\n\n"
|
||||
f"**From:** {summary['from']}\n"
|
||||
f"**To:** {summary['to']}\n"
|
||||
f"**Date:** {summary['date']}\n\n"
|
||||
f"## Message Content\n\n"
|
||||
f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n"
|
||||
f"## Message Details\n\n"
|
||||
f"- **Message ID:** {summary['message_id']}\n"
|
||||
f"- **Thread ID:** {summary['thread_id']}\n"
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"message_id": summary["message_id"] or message_id,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import (
|
||||
_build_credentials,
|
||||
)
|
||||
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
|
@ -15,57 +14,6 @@ _GMAIL_TYPES = [
|
|||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
_token_encryption_cache: object | None = None
|
||||
|
||||
|
||||
def _get_token_encryption():
|
||||
global _token_encryption_cache
|
||||
if _token_encryption_cache is None:
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise RuntimeError("SECRET_KEY not configured for token decryption.")
|
||||
_token_encryption_cache = TokenEncryption(config.SECRET_KEY)
|
||||
return _token_encryption_cache
|
||||
|
||||
|
||||
def _build_credentials(connector: SearchSourceConnector):
|
||||
"""Build Google OAuth Credentials from a connector's stored config.
|
||||
|
||||
Handles both native OAuth connectors (with encrypted tokens) and
|
||||
Composio-backed connectors. Shared by Gmail and Calendar tools.
|
||||
"""
|
||||
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||
|
||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
raise ValueError("Composio connected account ID not found.")
|
||||
return build_composio_credentials(cca_id)
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
cfg = dict(connector.config)
|
||||
if cfg.get("_token_encrypted"):
|
||||
enc = _get_token_encryption()
|
||||
for key in ("token", "refresh_token", "client_secret"):
|
||||
if cfg.get(key):
|
||||
cfg[key] = enc.decrypt_token(cfg[key])
|
||||
|
||||
exp = (cfg.get("expiry") or "").replace("Z", "")
|
||||
return Credentials(
|
||||
token=cfg.get("token"),
|
||||
refresh_token=cfg.get("refresh_token"),
|
||||
token_uri=cfg.get("token_uri"),
|
||||
client_id=cfg.get("client_id"),
|
||||
client_secret=cfg.get("client_secret"),
|
||||
scopes=cfg.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
|
||||
def create_search_gmail_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
|
|
@ -110,6 +58,50 @@ def create_search_gmail_tool(
|
|||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import (
|
||||
_format_gmail_summary,
|
||||
)
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
(
|
||||
messages,
|
||||
_next,
|
||||
_estimate,
|
||||
error,
|
||||
) = await ComposioService().get_gmail_messages(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
query=query,
|
||||
max_results=max_results,
|
||||
)
|
||||
if error:
|
||||
return {"status": "error", "message": error}
|
||||
|
||||
emails = [_format_gmail_summary(m) for m in messages]
|
||||
if not emails:
|
||||
return {
|
||||
"status": "success",
|
||||
"emails": [],
|
||||
"total": 0,
|
||||
"message": "No emails found.",
|
||||
}
|
||||
return {"status": "success", "emails": emails, "total": len(emails)}
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import (
|
||||
_build_credentials,
|
||||
)
|
||||
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||
|
|
|
|||
|
|
@ -162,16 +162,31 @@ def create_send_gmail_email_tool(
|
|||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
(
|
||||
sent_message_id,
|
||||
sent_thread_id,
|
||||
error,
|
||||
) = await ComposioService().send_gmail_email(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
to=final_to,
|
||||
subject=final_subject,
|
||||
body=final_body,
|
||||
cc=final_cc,
|
||||
bcc=final_bcc,
|
||||
)
|
||||
if error:
|
||||
return {"status": "error", "message": error}
|
||||
sent = {"id": sent_message_id, "threadId": sent_thread_id}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
|
|
@ -209,61 +224,61 @@ def create_send_gmail_email_tool(
|
|||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
message = MIMEText(final_body)
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
message = MIMEText(final_body)
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
sent = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.messages()
|
||||
.send(userId="me", body={"raw": raw})
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
try:
|
||||
sent = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.messages()
|
||||
.send(userId="me", body={"raw": raw})
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
|
||||
|
|
|
|||
|
|
@ -162,16 +162,22 @@ def create_trash_gmail_email_tool(
|
|||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
error = await ComposioService().trash_gmail_message(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
message_id=final_message_id,
|
||||
)
|
||||
if error:
|
||||
return {"status": "error", "message": error}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
|
|
@ -209,49 +215,49 @@ def create_trash_gmail_email_tool(
|
|||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
try:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.messages()
|
||||
.trash(userId="me", id=final_message_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {api_err}"
|
||||
try:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.messages()
|
||||
.trash(userId="me", id=final_message_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
f"Insufficient permissions for connector {connector.id}: {api_err}"
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Gmail email trashed: message_id={final_message_id}")
|
||||
|
||||
|
|
|
|||
|
|
@ -192,16 +192,51 @@ def create_update_gmail_draft_tool(
|
|||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
|
||||
if not final_draft_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
"Could not find this draft in Gmail. "
|
||||
"It may have already been sent or deleted."
|
||||
),
|
||||
}
|
||||
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
(
|
||||
new_draft_id,
|
||||
new_message_id,
|
||||
error,
|
||||
) = await ComposioService().update_gmail_draft(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
draft_id=final_draft_id,
|
||||
to=final_to or None,
|
||||
subject=final_subject,
|
||||
body=final_body,
|
||||
cc=final_cc,
|
||||
bcc=final_bcc,
|
||||
)
|
||||
if error:
|
||||
if "not found" in error.lower() or "no longer" in error.lower():
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
|
||||
}
|
||||
return {"status": "error", "message": error}
|
||||
|
||||
updated = {
|
||||
"id": new_draft_id or final_draft_id,
|
||||
"message": {"id": new_message_id} if new_message_id else {},
|
||||
}
|
||||
logger.info(f"Gmail draft updated via Composio: id={updated.get('id')}")
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
|
|
@ -239,88 +274,90 @@ def create_update_gmail_draft_tool(
|
|||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
# Resolve draft_id if not already available
|
||||
if not final_draft_id:
|
||||
logger.info(
|
||||
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
|
||||
)
|
||||
final_draft_id = await _find_draft_id_by_message(
|
||||
gmail_service, message_id
|
||||
)
|
||||
|
||||
if not final_draft_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
"Could not find this draft in Gmail. "
|
||||
"It may have already been sent or deleted."
|
||||
),
|
||||
}
|
||||
|
||||
message = MIMEText(final_body)
|
||||
if final_to:
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
updated = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.drafts()
|
||||
.update(
|
||||
userId="me",
|
||||
id=final_draft_id,
|
||||
body={"message": {"raw": raw}},
|
||||
)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {api_err}"
|
||||
# Resolve draft_id if not already available
|
||||
if not final_draft_id:
|
||||
logger.info(
|
||||
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
|
||||
)
|
||||
final_draft_id = await _find_draft_id_by_message(
|
||||
gmail_service, message_id
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 404:
|
||||
if not final_draft_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
|
||||
"message": (
|
||||
"Could not find this draft in Gmail. "
|
||||
"It may have already been sent or deleted."
|
||||
),
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Gmail draft updated: id={updated.get('id')}")
|
||||
message = MIMEText(final_body)
|
||||
if final_to:
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
updated = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.drafts()
|
||||
.update(
|
||||
userId="me",
|
||||
id=final_draft_id,
|
||||
body={"message": {"raw": raw}},
|
||||
)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 404:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Gmail draft updated via Google API: id={updated.get('id')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
if document_id:
|
||||
|
|
|
|||
|
|
@ -179,59 +179,96 @@ def create_create_google_drive_file_tool(
|
|||
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
pre_built_creds = None
|
||||
async def _flag_auth_expired() -> None:
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
pre_built_creds = build_composio_credentials(cca_id)
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Google Drive connector.",
|
||||
}
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=db_session,
|
||||
connector_id=actual_connector_id,
|
||||
credentials=pre_built_creds,
|
||||
)
|
||||
try:
|
||||
created = await client.create_file(
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
created, error = await ComposioService().create_drive_file_from_text(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
name=final_name,
|
||||
mime_type=mime_type,
|
||||
parent_folder_id=final_parent_folder_id,
|
||||
content=final_content,
|
||||
parent_id=final_parent_folder_id,
|
||||
)
|
||||
except HttpError as http_err:
|
||||
if http_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
if error or not created:
|
||||
err_lower = (error or "").lower()
|
||||
if (
|
||||
"insufficient" in err_lower
|
||||
or "permission" in err_lower
|
||||
or "403" in err_lower
|
||||
):
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
f"Insufficient permissions for Composio Drive connector {actual_connector_id}: {error}"
|
||||
)
|
||||
await _flag_auth_expired()
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
logger.error(
|
||||
f"Composio Drive create_file failed for connector {actual_connector_id}: {error}"
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
"status": "error",
|
||||
"message": "Something went wrong while creating the file. Please try again.",
|
||||
}
|
||||
raise
|
||||
else:
|
||||
client = GoogleDriveClient(
|
||||
session=db_session,
|
||||
connector_id=actual_connector_id,
|
||||
)
|
||||
try:
|
||||
created = await client.create_file(
|
||||
name=final_name,
|
||||
mime_type=mime_type,
|
||||
parent_folder_id=final_parent_folder_id,
|
||||
content=final_content,
|
||||
)
|
||||
except HttpError as http_err:
|
||||
if http_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
|
||||
)
|
||||
await _flag_auth_expired()
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
|
||||
|
|
|
|||
|
|
@ -158,51 +158,84 @@ def create_delete_google_drive_file_tool(
|
|||
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
|
||||
)
|
||||
|
||||
pre_built_creds = None
|
||||
async def _flag_auth_expired() -> None:
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
pre_built_creds = build_composio_credentials(cca_id)
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=db_session,
|
||||
connector_id=connector.id,
|
||||
credentials=pre_built_creds,
|
||||
)
|
||||
try:
|
||||
await client.trash_file(file_id=final_file_id)
|
||||
except HttpError as http_err:
|
||||
if http_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {http_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Google Drive connector.",
|
||||
}
|
||||
raise
|
||||
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
error = await ComposioService().trash_drive_file(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
file_id=final_file_id,
|
||||
)
|
||||
if error:
|
||||
err_lower = error.lower()
|
||||
if (
|
||||
"insufficient" in err_lower
|
||||
or "permission" in err_lower
|
||||
or "403" in err_lower
|
||||
):
|
||||
logger.warning(
|
||||
f"Insufficient permissions for Composio Drive connector {connector.id}: {error}"
|
||||
)
|
||||
await _flag_auth_expired()
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
logger.error(
|
||||
f"Composio Drive trash_file failed for connector {connector.id}: {error}"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while trashing the file. Please try again.",
|
||||
}
|
||||
else:
|
||||
client = GoogleDriveClient(
|
||||
session=db_session,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
try:
|
||||
await client.trash_file(file_id=final_file_id)
|
||||
except HttpError as http_err:
|
||||
if http_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {http_err}"
|
||||
)
|
||||
await _flag_auth_expired()
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ from langchain.agents import create_agent
|
|||
from langchain.agents.middleware import (
|
||||
LLMToolSelectorMiddleware,
|
||||
ModelCallLimitMiddleware,
|
||||
ModelFallbackMiddleware,
|
||||
TodoListMiddleware,
|
||||
ToolCallLimitMiddleware,
|
||||
)
|
||||
|
|
@ -77,6 +76,9 @@ from app.agents.new_chat.middleware import (
|
|||
create_surfsense_compaction_middleware,
|
||||
default_skills_sources,
|
||||
)
|
||||
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||
ScopedModelFallbackMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
from app.agents.new_chat.plugin_loader import (
|
||||
PluginContext,
|
||||
|
|
@ -792,15 +794,15 @@ def _build_compiled_agent_blocking(
|
|||
# Fallback chain — primary is the agent's own model; we add cheap
|
||||
# alternatives. Off by default; only the first call site that
|
||||
# configures the chain via env should enable it.
|
||||
fallback_mw: ModelFallbackMiddleware | None = None
|
||||
fallback_mw: ScopedModelFallbackMiddleware | None = None
|
||||
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
|
||||
try:
|
||||
fallback_mw = ModelFallbackMiddleware(
|
||||
fallback_mw = ScopedModelFallbackMiddleware(
|
||||
"openai:gpt-4o-mini",
|
||||
"anthropic:claude-3-5-haiku-20241022",
|
||||
)
|
||||
except Exception:
|
||||
logging.warning("ModelFallbackMiddleware init failed; skipping.")
|
||||
logging.warning("ScopedModelFallbackMiddleware init failed; skipping.")
|
||||
fallback_mw = None
|
||||
model_call_limit_mw = (
|
||||
ModelCallLimitMiddleware(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,91 @@
|
|||
"""Fallback only on provider/network errors; let programming bugs raise."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware import ModelFallbackMiddleware
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
# Matched by class name across the MRO so we don't have to import every
|
||||
# provider SDK (openai/anthropic/google/...). Extend as new providers ship.
|
||||
_FALLBACK_ELIGIBLE_NAMES: frozenset[str] = frozenset(
|
||||
{
|
||||
"RateLimitError",
|
||||
"APIStatusError",
|
||||
"InternalServerError",
|
||||
"ServiceUnavailableError",
|
||||
"BadGatewayError",
|
||||
"GatewayTimeoutError",
|
||||
"APIConnectionError",
|
||||
"APITimeoutError",
|
||||
"ConnectError",
|
||||
"ConnectTimeout",
|
||||
"ReadTimeout",
|
||||
"RemoteProtocolError",
|
||||
"TimeoutError",
|
||||
"TimeoutException",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _is_fallback_eligible(exc: BaseException) -> bool:
|
||||
return any(cls.__name__ in _FALLBACK_ELIGIBLE_NAMES for cls in type(exc).__mro__)
|
||||
|
||||
|
||||
class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
|
||||
"""Re-raise non-provider exceptions instead of walking the fallback chain."""
|
||||
|
||||
def wrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[Any],
|
||||
handler: Callable[[ModelRequest[Any]], ModelResponse[Any]],
|
||||
) -> ModelResponse[Any] | AIMessage:
|
||||
last_exception: Exception
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
|
||||
for fallback_model in self.models:
|
||||
try:
|
||||
return handler(request.override(model=fallback_model))
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
continue
|
||||
|
||||
raise last_exception
|
||||
|
||||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[Any],
|
||||
handler: Callable[[ModelRequest[Any]], Awaitable[ModelResponse[Any]]],
|
||||
) -> ModelResponse[Any] | AIMessage:
|
||||
last_exception: Exception
|
||||
try:
|
||||
return await handler(request)
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
|
||||
for fallback_model in self.models:
|
||||
try:
|
||||
return await handler(request.override(model=fallback_model))
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
continue
|
||||
|
||||
raise last_exception
|
||||
|
|
@ -1,5 +1,15 @@
|
|||
import re
|
||||
|
||||
from app.config import config
|
||||
|
||||
# Regex that matches a Markdown table block (header + separator + one or more rows)
|
||||
# A table block starts with a | at the beginning of a line and ends when a
|
||||
# non-table line (or end of string) is encountered.
|
||||
_TABLE_BLOCK_RE = re.compile(
|
||||
r"(?:(?:^|\n)(?=[ \t]*\|)(?:[ \t]*\|[^\n]*\n)+)",
|
||||
re.MULTILINE,
|
||||
)
|
||||
|
||||
|
||||
def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]:
|
||||
"""Chunk a text string using the configured chunker and return the chunk texts."""
|
||||
|
|
@ -7,3 +17,43 @@ def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]:
|
|||
config.code_chunker_instance if use_code_chunker else config.chunker_instance
|
||||
)
|
||||
return [c.text for c in chunker.chunk(text)]
|
||||
|
||||
|
||||
def chunk_text_hybrid(text: str) -> list[str]:
|
||||
"""Table-aware chunker that prevents Markdown tables from being split mid-row.
|
||||
|
||||
Algorithm:
|
||||
1. Scan the document for Markdown table blocks.
|
||||
2. Each table block is emitted as a single, unmodified chunk so that its
|
||||
header, separator row, and data rows always stay together.
|
||||
3. The non-table prose segments between (and around) tables are passed through
|
||||
the normal ``chunk_text`` chunker and their sub-chunks are interleaved in
|
||||
document order.
|
||||
|
||||
This ensures that table data is never sliced in the middle by the token-based
|
||||
chunker, which would otherwise produce garbled rows that are useless for RAG.
|
||||
|
||||
Fixes #1334.
|
||||
"""
|
||||
chunks: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
for match in _TABLE_BLOCK_RE.finditer(text):
|
||||
# Prose before this table
|
||||
prose = text[cursor : match.start()].strip()
|
||||
if prose:
|
||||
chunks.extend(chunk_text(prose))
|
||||
|
||||
# The table itself is kept as one indivisible chunk
|
||||
table_block = match.group(0).strip()
|
||||
if table_block:
|
||||
chunks.append(table_block)
|
||||
|
||||
cursor = match.end()
|
||||
|
||||
# Remaining prose after the last table (or entire text if no tables)
|
||||
trailing = text[cursor:].strip()
|
||||
if trailing:
|
||||
chunks.extend(chunk_text(trailing))
|
||||
|
||||
return chunks
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from app.db import (
|
|||
DocumentType,
|
||||
)
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_chunker import chunk_text
|
||||
from app.indexing_pipeline.document_chunker import chunk_text, chunk_text_hybrid
|
||||
from app.indexing_pipeline.document_embedder import embed_texts
|
||||
from app.indexing_pipeline.document_hashing import (
|
||||
compute_content_hash,
|
||||
|
|
@ -387,11 +387,19 @@ class IndexingPipelineService:
|
|||
)
|
||||
|
||||
t_step = time.perf_counter()
|
||||
chunk_texts = await asyncio.to_thread(
|
||||
chunk_text,
|
||||
connector_doc.source_markdown,
|
||||
use_code_chunker=connector_doc.should_use_code_chunker,
|
||||
)
|
||||
if connector_doc.should_use_code_chunker:
|
||||
chunk_texts = await asyncio.to_thread(
|
||||
chunk_text,
|
||||
connector_doc.source_markdown,
|
||||
use_code_chunker=True,
|
||||
)
|
||||
else:
|
||||
# Use the table-aware hybrid chunker so Markdown tables are not
|
||||
# split mid-row (see issue #1334).
|
||||
chunk_texts = await asyncio.to_thread(
|
||||
chunk_text_hybrid,
|
||||
connector_doc.source_markdown,
|
||||
)
|
||||
|
||||
texts_to_embed = [content, *chunk_texts]
|
||||
embeddings = await asyncio.to_thread(embed_texts, texts_to_embed)
|
||||
|
|
|
|||
|
|
@ -1027,6 +1027,505 @@ class ComposioService:
|
|||
logger.error(f"Failed to list Calendar events: {e!s}")
|
||||
return [], str(e)
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_response_data(data: Any) -> Any:
|
||||
"""Composio responses often nest the meaningful payload under
|
||||
``data.data.response_data``. Walk that envelope safely and return
|
||||
whichever inner dict actually has the result keys."""
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
inner = data.get("data", data)
|
||||
if isinstance(inner, dict):
|
||||
return inner.get("response_data", inner)
|
||||
return inner
|
||||
|
||||
@staticmethod
|
||||
def _split_email_csv(value: str | None) -> list[str] | None:
|
||||
"""Tools accept comma-separated cc/bcc strings; Composio expects an array."""
|
||||
if not value:
|
||||
return None
|
||||
addrs = [e.strip() for e in value.split(",") if e.strip()]
|
||||
return addrs or None
|
||||
|
||||
# ===== Gmail write methods =====
|
||||
|
||||
async def send_gmail_email(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
entity_id: str,
|
||||
to: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
cc: str | None = None,
|
||||
bcc: str | None = None,
|
||||
is_html: bool = False,
|
||||
) -> tuple[str | None, str | None, str | None]:
|
||||
"""Send a Gmail message via the Composio ``GMAIL_SEND_EMAIL`` toolkit.
|
||||
|
||||
Returns:
|
||||
Tuple of (message_id, thread_id, error). On success ``error`` is
|
||||
None and at least one of the IDs is populated when Composio
|
||||
returns them; on failure both IDs are None.
|
||||
"""
|
||||
try:
|
||||
params: dict[str, Any] = {
|
||||
"recipient_email": to,
|
||||
"subject": subject,
|
||||
"body": body,
|
||||
"is_html": is_html,
|
||||
}
|
||||
if cc:
|
||||
cc_list = self._split_email_csv(cc)
|
||||
if cc_list:
|
||||
params["cc"] = cc_list
|
||||
if bcc:
|
||||
bcc_list = self._split_email_csv(bcc)
|
||||
if bcc_list:
|
||||
params["bcc"] = bcc_list
|
||||
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GMAIL_SEND_EMAIL",
|
||||
params=params,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return None, None, result.get("error", "Unknown error")
|
||||
|
||||
payload = self._unwrap_response_data(result.get("data", {}))
|
||||
message_id = None
|
||||
thread_id = None
|
||||
if isinstance(payload, dict):
|
||||
message_id = (
|
||||
payload.get("id")
|
||||
or payload.get("message_id")
|
||||
or payload.get("messageId")
|
||||
)
|
||||
thread_id = payload.get("threadId") or payload.get("thread_id")
|
||||
return message_id, thread_id, None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send Gmail email: {e!s}")
|
||||
return None, None, str(e)
|
||||
|
||||
async def create_gmail_draft(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
entity_id: str,
|
||||
to: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
cc: str | None = None,
|
||||
bcc: str | None = None,
|
||||
is_html: bool = False,
|
||||
) -> tuple[str | None, str | None, str | None, str | None]:
|
||||
"""Create a Gmail draft via the Composio ``GMAIL_CREATE_EMAIL_DRAFT`` toolkit.
|
||||
|
||||
Returns:
|
||||
Tuple of (draft_id, message_id, thread_id, error). On success
|
||||
``error`` is None and ``draft_id`` is populated.
|
||||
"""
|
||||
try:
|
||||
params: dict[str, Any] = {
|
||||
"recipient_email": to,
|
||||
"subject": subject,
|
||||
"body": body,
|
||||
"is_html": is_html,
|
||||
}
|
||||
cc_list = self._split_email_csv(cc)
|
||||
if cc_list:
|
||||
params["cc"] = cc_list
|
||||
bcc_list = self._split_email_csv(bcc)
|
||||
if bcc_list:
|
||||
params["bcc"] = bcc_list
|
||||
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GMAIL_CREATE_EMAIL_DRAFT",
|
||||
params=params,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return None, None, None, result.get("error", "Unknown error")
|
||||
|
||||
payload = self._unwrap_response_data(result.get("data", {}))
|
||||
draft_id = None
|
||||
message_id = None
|
||||
thread_id = None
|
||||
if isinstance(payload, dict):
|
||||
draft_id = payload.get("id") or payload.get("draft_id")
|
||||
draft_message = payload.get("message") or {}
|
||||
if isinstance(draft_message, dict):
|
||||
message_id = draft_message.get("id") or draft_message.get(
|
||||
"message_id"
|
||||
)
|
||||
thread_id = draft_message.get("threadId") or draft_message.get(
|
||||
"thread_id"
|
||||
)
|
||||
if message_id is None:
|
||||
message_id = payload.get("message_id") or payload.get("messageId")
|
||||
if thread_id is None:
|
||||
thread_id = payload.get("thread_id") or payload.get("threadId")
|
||||
return draft_id, message_id, thread_id, None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create Gmail draft: {e!s}")
|
||||
return None, None, None, str(e)
|
||||
|
||||
async def update_gmail_draft(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
entity_id: str,
|
||||
draft_id: str,
|
||||
to: str | None = None,
|
||||
subject: str | None = None,
|
||||
body: str | None = None,
|
||||
cc: str | None = None,
|
||||
bcc: str | None = None,
|
||||
is_html: bool = False,
|
||||
) -> tuple[str | None, str | None, str | None]:
|
||||
"""Update an existing Gmail draft via ``GMAIL_UPDATE_DRAFT``.
|
||||
|
||||
Returns:
|
||||
Tuple of (draft_id, message_id, error).
|
||||
"""
|
||||
try:
|
||||
params: dict[str, Any] = {
|
||||
"draft_id": draft_id,
|
||||
"is_html": is_html,
|
||||
}
|
||||
if to:
|
||||
params["recipient_email"] = to
|
||||
if subject is not None:
|
||||
params["subject"] = subject
|
||||
if body is not None:
|
||||
params["body"] = body
|
||||
cc_list = self._split_email_csv(cc)
|
||||
if cc_list:
|
||||
params["cc"] = cc_list
|
||||
bcc_list = self._split_email_csv(bcc)
|
||||
if bcc_list:
|
||||
params["bcc"] = bcc_list
|
||||
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GMAIL_UPDATE_DRAFT",
|
||||
params=params,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return None, None, result.get("error", "Unknown error")
|
||||
|
||||
payload = self._unwrap_response_data(result.get("data", {}))
|
||||
new_draft_id = draft_id
|
||||
message_id = None
|
||||
if isinstance(payload, dict):
|
||||
new_draft_id = payload.get("id") or payload.get("draft_id") or draft_id
|
||||
draft_message = payload.get("message") or {}
|
||||
if isinstance(draft_message, dict):
|
||||
message_id = draft_message.get("id") or draft_message.get(
|
||||
"message_id"
|
||||
)
|
||||
if message_id is None:
|
||||
message_id = payload.get("message_id") or payload.get("messageId")
|
||||
return new_draft_id, message_id, None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update Gmail draft: {e!s}")
|
||||
return None, None, str(e)
|
||||
|
||||
async def trash_gmail_message(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
entity_id: str,
|
||||
message_id: str,
|
||||
) -> str | None:
|
||||
"""Move a Gmail message to trash via ``GMAIL_MOVE_TO_TRASH``.
|
||||
|
||||
Returns the error message on failure, ``None`` on success.
|
||||
"""
|
||||
try:
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GMAIL_MOVE_TO_TRASH",
|
||||
params={"message_id": message_id},
|
||||
entity_id=entity_id,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return result.get("error", "Unknown error")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to trash Gmail message: {e!s}")
|
||||
return str(e)
|
||||
|
||||
# ===== Google Calendar write methods =====
|
||||
|
||||
async def create_calendar_event(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
entity_id: str,
|
||||
summary: str,
|
||||
start_datetime: str,
|
||||
end_datetime: str,
|
||||
timezone: str | None = None,
|
||||
description: str | None = None,
|
||||
location: str | None = None,
|
||||
attendees: list[str] | None = None,
|
||||
calendar_id: str = "primary",
|
||||
) -> tuple[str | None, str | None, str | None]:
|
||||
"""Create a Google Calendar event via ``GOOGLECALENDAR_CREATE_EVENT``.
|
||||
|
||||
Composio strips trailing timezone info on ``start_datetime`` /
|
||||
``end_datetime`` and uses the ``timezone`` field as the IANA name,
|
||||
so callers may pass ISO 8601 strings with or without offsets.
|
||||
|
||||
Returns:
|
||||
Tuple of (event_id, html_link, error).
|
||||
"""
|
||||
try:
|
||||
params: dict[str, Any] = {
|
||||
"summary": summary,
|
||||
"start_datetime": start_datetime,
|
||||
"end_datetime": end_datetime,
|
||||
"calendar_id": calendar_id,
|
||||
}
|
||||
if timezone:
|
||||
params["timezone"] = timezone
|
||||
if description:
|
||||
params["description"] = description
|
||||
if location:
|
||||
params["location"] = location
|
||||
if attendees:
|
||||
params["attendees"] = [a for a in attendees if a]
|
||||
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GOOGLECALENDAR_CREATE_EVENT",
|
||||
params=params,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return None, None, result.get("error", "Unknown error")
|
||||
|
||||
payload = self._unwrap_response_data(result.get("data", {}))
|
||||
event_id = None
|
||||
html_link = None
|
||||
if isinstance(payload, dict):
|
||||
event_id = payload.get("id") or payload.get("event_id")
|
||||
html_link = payload.get("htmlLink") or payload.get("html_link")
|
||||
return event_id, html_link, None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create Calendar event: {e!s}")
|
||||
return None, None, str(e)
|
||||
|
||||
async def update_calendar_event(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
entity_id: str,
|
||||
event_id: str,
|
||||
summary: str | None = None,
|
||||
start_time: str | None = None,
|
||||
end_time: str | None = None,
|
||||
timezone: str | None = None,
|
||||
description: str | None = None,
|
||||
location: str | None = None,
|
||||
attendees: list[str] | None = None,
|
||||
calendar_id: str = "primary",
|
||||
) -> tuple[str | None, str | None, str | None]:
|
||||
"""Patch an existing Google Calendar event via ``GOOGLECALENDAR_PATCH_EVENT``.
|
||||
|
||||
Uses PATCH (not PUT) semantics so omitted fields are preserved.
|
||||
|
||||
Returns:
|
||||
Tuple of (event_id, html_link, error).
|
||||
"""
|
||||
try:
|
||||
params: dict[str, Any] = {
|
||||
"event_id": event_id,
|
||||
"calendar_id": calendar_id,
|
||||
}
|
||||
if summary is not None:
|
||||
params["summary"] = summary
|
||||
if start_time is not None:
|
||||
params["start_time"] = start_time
|
||||
if end_time is not None:
|
||||
params["end_time"] = end_time
|
||||
if timezone:
|
||||
params["timezone"] = timezone
|
||||
if description is not None:
|
||||
params["description"] = description
|
||||
if location is not None:
|
||||
params["location"] = location
|
||||
if attendees is not None:
|
||||
params["attendees"] = [a for a in attendees if a]
|
||||
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GOOGLECALENDAR_PATCH_EVENT",
|
||||
params=params,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return None, None, result.get("error", "Unknown error")
|
||||
|
||||
payload = self._unwrap_response_data(result.get("data", {}))
|
||||
new_event_id = event_id
|
||||
html_link = None
|
||||
if isinstance(payload, dict):
|
||||
new_event_id = payload.get("id") or payload.get("event_id") or event_id
|
||||
html_link = payload.get("htmlLink") or payload.get("html_link")
|
||||
return new_event_id, html_link, None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to patch Calendar event: {e!s}")
|
||||
return None, None, str(e)
|
||||
|
||||
async def delete_calendar_event(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
entity_id: str,
|
||||
event_id: str,
|
||||
calendar_id: str = "primary",
|
||||
) -> str | None:
|
||||
"""Delete a Google Calendar event via ``GOOGLECALENDAR_DELETE_EVENT``.
|
||||
|
||||
Returns the error message on failure, ``None`` on success (idempotent
|
||||
on already-deleted events).
|
||||
"""
|
||||
try:
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GOOGLECALENDAR_DELETE_EVENT",
|
||||
params={
|
||||
"event_id": event_id,
|
||||
"calendar_id": calendar_id,
|
||||
},
|
||||
entity_id=entity_id,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return result.get("error", "Unknown error")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete Calendar event: {e!s}")
|
||||
return str(e)
|
||||
|
||||
# ===== Google Drive write methods =====
|
||||
|
||||
@staticmethod
|
||||
def _drive_web_view_link(file_id: str, mime_type: str | None) -> str:
|
||||
"""Synthesize a Google Drive ``webViewLink`` from id + mimeType.
|
||||
|
||||
Composio's ``GOOGLEDRIVE_CREATE_FILE_FROM_TEXT`` returns flat
|
||||
metadata (id, name, mimeType) but does not always include a
|
||||
``webViewLink``. We rebuild the canonical UI URL based on the
|
||||
Workspace MIME type so callers can keep using a single field.
|
||||
"""
|
||||
if not file_id:
|
||||
return ""
|
||||
mt = (mime_type or "").lower()
|
||||
if mt == "application/vnd.google-apps.document":
|
||||
return f"https://docs.google.com/document/d/{file_id}/edit"
|
||||
if mt == "application/vnd.google-apps.spreadsheet":
|
||||
return f"https://docs.google.com/spreadsheets/d/{file_id}/edit"
|
||||
if mt == "application/vnd.google-apps.presentation":
|
||||
return f"https://docs.google.com/presentation/d/{file_id}/edit"
|
||||
if mt == "application/vnd.google-apps.folder":
|
||||
return f"https://drive.google.com/drive/folders/{file_id}"
|
||||
return f"https://drive.google.com/file/d/{file_id}/view"
|
||||
|
||||
async def create_drive_file_from_text(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
entity_id: str,
|
||||
name: str,
|
||||
mime_type: str,
|
||||
content: str | None = None,
|
||||
parent_id: str | None = None,
|
||||
) -> tuple[dict[str, Any] | None, str | None]:
|
||||
"""Create a Google Drive file from text via ``GOOGLEDRIVE_CREATE_FILE_FROM_TEXT``.
|
||||
|
||||
Composio's tool requires ``text_content`` even for "empty" files;
|
||||
an empty string is accepted. Native Workspace types (Docs, Sheets)
|
||||
are produced by setting ``mime_type`` to the Google Apps MIME, and
|
||||
Drive auto-converts the text payload (e.g. CSV → Sheet).
|
||||
|
||||
Returns:
|
||||
Tuple of (file_meta, error). ``file_meta`` keys:
|
||||
``id``, ``name``, ``mimeType``, ``webViewLink``.
|
||||
"""
|
||||
try:
|
||||
params: dict[str, Any] = {
|
||||
"file_name": name,
|
||||
"mime_type": mime_type,
|
||||
"text_content": content if content is not None else "",
|
||||
}
|
||||
if parent_id:
|
||||
params["parent_id"] = parent_id
|
||||
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GOOGLEDRIVE_CREATE_FILE_FROM_TEXT",
|
||||
params=params,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return None, result.get("error", "Unknown error")
|
||||
|
||||
payload = self._unwrap_response_data(result.get("data", {}))
|
||||
file_id: str | None = None
|
||||
file_name: str | None = name
|
||||
mime: str | None = mime_type
|
||||
web_view_link: str | None = None
|
||||
|
||||
if isinstance(payload, dict):
|
||||
file_id = (
|
||||
payload.get("id") or payload.get("file_id") or payload.get("fileId")
|
||||
)
|
||||
file_name = payload.get("name") or payload.get("file_name") or name
|
||||
mime = payload.get("mimeType") or payload.get("mime_type") or mime_type
|
||||
web_view_link = payload.get("webViewLink") or payload.get(
|
||||
"web_view_link"
|
||||
)
|
||||
|
||||
if not file_id:
|
||||
return None, "Composio response did not include a file id"
|
||||
|
||||
if not web_view_link:
|
||||
web_view_link = self._drive_web_view_link(file_id, mime)
|
||||
|
||||
return (
|
||||
{
|
||||
"id": file_id,
|
||||
"name": file_name,
|
||||
"mimeType": mime,
|
||||
"webViewLink": web_view_link,
|
||||
},
|
||||
None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create Drive file: {e!s}")
|
||||
return None, str(e)
|
||||
|
||||
async def trash_drive_file(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
entity_id: str,
|
||||
file_id: str,
|
||||
) -> str | None:
|
||||
"""Move a Google Drive file to trash via ``GOOGLEDRIVE_TRASH_FILE``.
|
||||
|
||||
Returns the error message on failure, ``None`` on success.
|
||||
"""
|
||||
try:
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GOOGLEDRIVE_TRASH_FILE",
|
||||
params={"file_id": file_id},
|
||||
entity_id=entity_id,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return result.get("error", "Unknown error")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to trash Drive file: {e!s}")
|
||||
return str(e)
|
||||
|
||||
# ===== User Info Methods =====
|
||||
|
||||
async def get_connected_account_email(
|
||||
|
|
|
|||
|
|
@ -28,9 +28,7 @@ from langchain_core.messages import HumanMessage
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.agents.multi_agent_chat import (
|
||||
create_surfsense_deep_agent as create_registry_deep_agent,
|
||||
)
|
||||
from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent
|
||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
|
|
@ -577,6 +575,43 @@ async def _preflight_llm(llm: Any) -> None:
|
|||
)
|
||||
|
||||
|
||||
async def _build_main_agent_for_thread(
|
||||
agent_factory: Any,
|
||||
*,
|
||||
llm: Any,
|
||||
search_space_id: int,
|
||||
db_session: Any,
|
||||
connector_service: ConnectorService,
|
||||
checkpointer: Any,
|
||||
user_id: str | None,
|
||||
thread_id: int | None,
|
||||
agent_config: AgentConfig | None,
|
||||
firecrawl_api_key: str | None,
|
||||
thread_visibility: ChatVisibility | None,
|
||||
filesystem_selection: FilesystemSelection | None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
) -> Any:
|
||||
"""Single (re)build path so the agent factory cannot drift across
|
||||
initial build, preflight repin, and mid-stream 429 recovery for one
|
||||
``thread_id``: a graph swap mid-turn would corrupt checkpointer state."""
|
||||
return await agent_factory(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=db_session,
|
||||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=thread_visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
)
|
||||
|
||||
|
||||
async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None:
|
||||
"""Wait for a discarded speculative agent build to release shared state.
|
||||
|
||||
|
|
@ -2767,7 +2802,7 @@ async def stream_new_chat(
|
|||
|
||||
_t0 = time.perf_counter()
|
||||
agent_factory = (
|
||||
create_registry_deep_agent
|
||||
create_multi_agent_chat_deep_agent
|
||||
if use_multi_agent
|
||||
else create_surfsense_deep_agent
|
||||
)
|
||||
|
|
@ -2776,7 +2811,8 @@ async def stream_new_chat(
|
|||
# if preflight reports 429 we will discard this future and rebuild
|
||||
# against the freshly pinned config below.
|
||||
agent_build_task = asyncio.create_task(
|
||||
agent_factory(
|
||||
_build_main_agent_for_thread(
|
||||
agent_factory,
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
|
|
@ -2787,9 +2823,9 @@ async def stream_new_chat(
|
|||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
filesystem_selection=filesystem_selection,
|
||||
),
|
||||
name="agent_build:stream_new_chat",
|
||||
)
|
||||
|
|
@ -3466,7 +3502,8 @@ async def stream_new_chat(
|
|||
title_task = None
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent = await create_surfsense_deep_agent(
|
||||
agent = await _build_main_agent_for_thread(
|
||||
agent_factory,
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
|
|
@ -3477,9 +3514,9 @@ async def stream_new_chat(
|
|||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
filesystem_selection=filesystem_selection,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Runtime rate-limit recovery repinned "
|
||||
|
|
@ -4130,12 +4167,13 @@ async def stream_resume_chat(
|
|||
|
||||
_t0 = time.perf_counter()
|
||||
agent_factory = (
|
||||
create_registry_deep_agent
|
||||
create_multi_agent_chat_deep_agent
|
||||
if _app_config.MULTI_AGENT_CHAT_ENABLED
|
||||
else create_surfsense_deep_agent
|
||||
)
|
||||
agent_build_task = asyncio.create_task(
|
||||
agent_factory(
|
||||
_build_main_agent_for_thread(
|
||||
agent_factory,
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
|
|
@ -4224,7 +4262,8 @@ async def stream_resume_chat(
|
|||
"fallback_config_id": llm_config_id,
|
||||
},
|
||||
)
|
||||
agent = await agent_factory(
|
||||
agent = await _build_main_agent_for_thread(
|
||||
agent_factory,
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
|
|
@ -4409,7 +4448,8 @@ async def stream_resume_chat(
|
|||
raise stream_exc
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent = await create_surfsense_deep_agent(
|
||||
agent = await _build_main_agent_for_thread(
|
||||
agent_factory,
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
|
|
@ -4421,6 +4461,7 @@ async def stream_resume_chat(
|
|||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_resume] Runtime rate-limit recovery repinned "
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "surf-new-backend"
|
||||
version = "0.0.22"
|
||||
version = "0.0.23"
|
||||
description = "SurfSense Backend"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
|
|
|
|||
|
|
@ -0,0 +1,208 @@
|
|||
"""End-to-end resume-bridge tests against a real LangGraph subagent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.types import Command, interrupt
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||
build_task_tool_with_parent_config,
|
||||
)
|
||||
|
||||
|
||||
class _SubagentState(TypedDict, total=False):
|
||||
messages: list
|
||||
decision_text: str
|
||||
|
||||
|
||||
def _build_single_interrupt_subagent():
|
||||
def approve_node(state):
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
decision = interrupt(
|
||||
{
|
||||
"action_requests": [
|
||||
{
|
||||
"name": "do_thing",
|
||||
"args": {"x": 1},
|
||||
"description": "test action",
|
||||
}
|
||||
],
|
||||
"review_configs": [{}],
|
||||
}
|
||||
)
|
||||
return {
|
||||
"messages": [AIMessage(content="done")],
|
||||
"decision_text": repr(decision),
|
||||
}
|
||||
|
||||
graph = StateGraph(_SubagentState)
|
||||
graph.add_node("approve", approve_node)
|
||||
graph.add_edge(START, "approve")
|
||||
graph.add_edge("approve", END)
|
||||
return graph.compile(checkpointer=InMemorySaver())
|
||||
|
||||
|
||||
def _make_runtime(config: dict) -> ToolRuntime:
|
||||
return ToolRuntime(
|
||||
state={"messages": [HumanMessage(content="seed")]},
|
||||
context=None,
|
||||
config=config,
|
||||
stream_writer=None,
|
||||
tool_call_id="parent-tcid-1",
|
||||
store=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_bridge_dispatches_decision_into_pending_subagent():
|
||||
"""Side-channel decision must reach the subagent's pending interrupt verbatim."""
|
||||
subagent = _build_single_interrupt_subagent()
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
[
|
||||
{
|
||||
"name": "approver",
|
||||
"description": "approves things",
|
||||
"runnable": subagent,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
parent_config: dict = {
|
||||
"configurable": {"thread_id": "shared-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
|
||||
snap = await subagent.aget_state(parent_config)
|
||||
assert snap.tasks and snap.tasks[0].interrupts, (
|
||||
"fixture broken: subagent should be paused on its interrupt"
|
||||
)
|
||||
|
||||
parent_config["configurable"]["surfsense_resume_value"] = {
|
||||
"decisions": ["APPROVED"]
|
||||
}
|
||||
runtime = _make_runtime(parent_config)
|
||||
|
||||
result = await task_tool.coroutine(
|
||||
description="please approve",
|
||||
subagent_type="approver",
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
assert isinstance(result, Command)
|
||||
update = result.update
|
||||
assert update["decision_text"] == repr({"decisions": ["APPROVED"]})
|
||||
assert "surfsense_resume_value" not in parent_config["configurable"]
|
||||
|
||||
final = await subagent.aget_state(parent_config)
|
||||
assert not final.tasks or all(not t.interrupts for t in final.tasks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pending_interrupt_without_resume_value_raises_runtime_error():
    """Bridge must fail loud rather than silently replay the user's interrupt."""
    # Same fixture shape as the happy-path test: one subagent paused on an
    # interrupt, reachable through the task tool.
    subagent = _build_single_interrupt_subagent()
    task_tool = build_task_tool_with_parent_config(
        [
            {
                "name": "approver",
                "description": "approves things",
                "runnable": subagent,
            }
        ]
    )

    parent_config: dict = {
        "configurable": {"thread_id": "guard-thread"},
        "recursion_limit": 100,
    }
    await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
    snap = await subagent.aget_state(parent_config)
    assert snap.tasks and snap.tasks[0].interrupts, "fixture broken"

    # Deliberately do NOT queue ``surfsense_resume_value`` on the config.
    runtime = _make_runtime(parent_config)

    # Re-entering the task tool with a pending interrupt but no queued
    # decision is a bridge invariant violation and must raise, not re-prompt.
    with pytest.raises(RuntimeError, match="resume bridge is broken"):
        await task_tool.coroutine(
            description="please approve",
            subagent_type="approver",
            runtime=runtime,
        )
|
||||
|
||||
|
||||
def _build_bundle_subagent():
    """Compile a one-node subagent that interrupts with a three-action bundle.

    The node records ``repr`` of whatever the interrupt resolves to, so tests
    can assert the exact decision payload that was delivered.
    """

    def bundle_node(state):
        from langchain_core.messages import AIMessage

        payload = {
            "action_requests": [
                {"name": action, "args": {}, "description": ""}
                for action in ("create_a", "create_b", "create_c")
            ],
            "review_configs": [{}, {}, {}],
        }
        decision = interrupt(payload)
        return {
            "messages": [AIMessage(content="bundle-done")],
            "decision_text": repr(decision),
        }

    builder = StateGraph(_SubagentState)
    builder.add_node("bundle", bundle_node)
    builder.add_edge(START, "bundle")
    builder.add_edge("bundle", END)
    return builder.compile(checkpointer=InMemorySaver())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_bundle_three_mixed_decisions_arrive_in_order():
    """Approve / edit / reject for a 3-action bundle must land at ordinals 0/1/2."""
    subagent = _build_bundle_subagent()
    task_tool = build_task_tool_with_parent_config(
        [
            {
                "name": "bundler",
                "description": "creates a bundle",
                "runnable": subagent,
            }
        ]
    )

    parent_config: dict = {
        "configurable": {"thread_id": "bundle-thread"},
        "recursion_limit": 100,
    }
    # Run until the subagent pauses on its 3-action bundle interrupt.
    await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)

    # One mixed decision per action, in bundle order: approve / edit / reject.
    decisions_payload = {
        "decisions": [
            {"type": "approve", "args": {}},
            {"type": "edit", "args": {"args": {"name": "edited-b"}}},
            {"type": "reject", "args": {"message": "no thanks"}},
        ]
    }
    parent_config["configurable"]["surfsense_resume_value"] = decisions_payload
    runtime = _make_runtime(parent_config)

    result = await task_tool.coroutine(
        description="run bundle",
        subagent_type="bundler",
        runtime=runtime,
    )

    # The fixture node stored ``repr(decision)``; parse it back and verify
    # the payload arrived verbatim and order was preserved (ordinals 0/1/2).
    assert isinstance(result, Command)
    received = ast.literal_eval(result.update["decision_text"])
    assert received == decisions_payload
    assert received["decisions"][0]["type"] == "approve"
    assert received["decisions"][1]["type"] == "edit"
    assert received["decisions"][1]["args"] == {"args": {"name": "edited-b"}}
    assert received["decisions"][2]["type"] == "reject"
||||
|
|
@ -0,0 +1,55 @@
|
|||
"""Pins the first-wins assumption of ``get_first_pending_subagent_interrupt``.
|
||||
|
||||
The bridge currently relies on at-most-one pending interrupt per snapshot
|
||||
(sequential tool nodes). If parallel tool calls are ever enabled, the bridge
|
||||
needs an id-aware lookup; these tests will need to be revisited at that point.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume import (
|
||||
get_first_pending_subagent_interrupt,
|
||||
)
|
||||
|
||||
|
||||
class TestGetFirstPendingSubagentInterrupt:
    """Pins the first-wins lookup over a state snapshot's pending interrupts."""

    def test_returns_first_when_multiple_top_level_interrupts_pending(self):
        pending = (
            SimpleNamespace(id="i-1", value={"decision": "approve"}),
            SimpleNamespace(id="i-2", value={"decision": "reject"}),
        )
        snapshot = SimpleNamespace(interrupts=pending, tasks=())

        result = get_first_pending_subagent_interrupt(snapshot)

        assert result == ("i-1", {"decision": "approve"})

    def test_returns_first_when_multiple_subtask_interrupts_pending(self):
        # Interrupts nested under a task, not at the top level of the snapshot.
        nested = (
            SimpleNamespace(id="i-A", value="approve"),
            SimpleNamespace(id="i-B", value="reject"),
        )
        snapshot = SimpleNamespace(
            interrupts=(),
            tasks=(SimpleNamespace(interrupts=nested),),
        )

        assert get_first_pending_subagent_interrupt(snapshot) == ("i-A", "approve")

    def test_returns_none_when_no_interrupts(self):
        empty_snapshot = SimpleNamespace(interrupts=(), tasks=())

        assert get_first_pending_subagent_interrupt(empty_snapshot) == (None, None)

    def test_returns_none_when_state_is_none(self):
        assert get_first_pending_subagent_interrupt(None) == (None, None)

    def test_skips_interrupts_with_none_value(self):
        # A ``None`` value does not count as a pending interrupt.
        snapshot = SimpleNamespace(
            interrupts=(
                SimpleNamespace(id="i-empty", value=None),
                SimpleNamespace(id="i-real", value="approve"),
            ),
            tasks=(),
        )

        assert get_first_pending_subagent_interrupt(snapshot) == ("i-real", "approve")

    def test_normalizes_non_string_id_to_none(self):
        # Non-string ids are normalized to ``None`` while the value survives.
        snapshot = SimpleNamespace(
            interrupts=(SimpleNamespace(id=12345, value="approve"),),
            tasks=(),
        )

        assert get_first_pending_subagent_interrupt(snapshot) == (None, "approve")
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
"""Resume side-channel must be read exactly once per turn."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import (
|
||||
consume_surfsense_resume,
|
||||
has_surfsense_resume,
|
||||
)
|
||||
|
||||
|
||||
def _runtime_with_config(config: dict) -> ToolRuntime:
    """Build a bare ``ToolRuntime`` carrying only *config*; all else is inert."""
    placeholder_fields = dict(
        state=None,
        context=None,
        stream_writer=None,
        tool_call_id="tcid-test",
        store=None,
    )
    return ToolRuntime(config=config, **placeholder_fields)
||||
|
||||
|
||||
class TestConsumeSurfsenseResume:
    """``consume_surfsense_resume`` pops the side-channel payload exactly once."""

    def test_pops_value_on_first_call(self):
        payload = {"decisions": ["approve"]}
        rt = _runtime_with_config(
            {"configurable": {"surfsense_resume_value": payload}}
        )

        assert consume_surfsense_resume(rt) == {"decisions": ["approve"]}

    def test_second_call_returns_none(self):
        configurable: dict = {"surfsense_resume_value": {"decisions": ["approve"]}}
        rt = _runtime_with_config({"configurable": configurable})

        # First call consumes; second must see an empty channel.
        consume_surfsense_resume(rt)

        assert consume_surfsense_resume(rt) is None
        assert "surfsense_resume_value" not in configurable

    def test_returns_none_when_no_payload_queued(self):
        rt = _runtime_with_config({"configurable": {}})

        assert consume_surfsense_resume(rt) is None

    def test_returns_none_when_configurable_missing(self):
        rt = _runtime_with_config({})

        assert consume_surfsense_resume(rt) is None
|
||||
|
||||
|
||||
class TestHasSurfsenseResume:
    """``has_surfsense_resume`` is a non-destructive peek at the side channel."""

    def test_true_when_payload_queued(self):
        rt = _runtime_with_config(
            {"configurable": {"surfsense_resume_value": "approve"}}
        )

        assert has_surfsense_resume(rt) is True

    def test_does_not_consume_payload(self):
        configurable = {"surfsense_resume_value": "approve"}
        rt = _runtime_with_config({"configurable": configurable})

        # Peeking must leave the queued payload untouched.
        has_surfsense_resume(rt)

        assert configurable == {"surfsense_resume_value": "approve"}

    def test_false_when_payload_absent(self):
        rt = _runtime_with_config({"configurable": {}})

        assert has_surfsense_resume(rt) is False
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
"""Subagent resilience contract: ``extra_middleware`` reaches the agent chain."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import ModelFallbackMiddleware
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.language_models.fake_chat_models import (
|
||||
FakeMessagesListChatModel,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
|
||||
pack_subagent,
|
||||
)
|
||||
|
||||
|
||||
class RateLimitError(Exception):
    """Name matches the scoped-fallback eligibility allowlist."""

    # NOTE(review): the fallback middleware appears to match on the exception
    # *class name*, so this local stand-in triggers the same path as a real
    # provider RateLimitError — confirm against the allowlist implementation.
|
||||
|
||||
|
||||
class _AlwaysFailingChatModel(BaseChatModel):
    """Chat model stub whose every generation path raises ``RateLimitError``.

    Used as the primary model in the resilience test so the fallback
    middleware is forced to finish the turn.
    """

    @property
    def _llm_type(self) -> str:
        return "always-failing-test-model"

    def _boom(self) -> RateLimitError:
        # Single source of truth for the failure; previously the message and
        # raise were duplicated verbatim across all four entry points.
        return RateLimitError("primary llm exploded")

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        raise self._boom()

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        raise self._boom()

    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]:
        raise self._boom()

    async def _astream(
        self, *args: Any, **kwargs: Any
    ) -> AsyncIterator[ChatGeneration]:
        raise self._boom()
        yield  # pragma: no cover - unreachable, satisfies async generator typing
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_subagent_recovers_when_primary_llm_fails():
    """Fallback in ``extra_middleware`` must finish the turn when primary raises."""
    # Primary always raises RateLimitError; the fake fallback answers once.
    primary = _AlwaysFailingChatModel()
    fallback = FakeMessagesListChatModel(
        responses=[AIMessage(content="recovered via fallback")]
    )

    # The contract under test: middleware passed via ``extra_middleware`` must
    # survive ``pack_subagent`` and reach the compiled agent chain.
    spec = pack_subagent(
        name="resilience_test",
        description="test subagent",
        system_prompt="be helpful",
        tools=[],
        model=primary,
        extra_middleware=[ModelFallbackMiddleware(fallback)],
    )

    agent = create_agent(
        model=spec["model"],
        tools=spec["tools"],
        middleware=spec["middleware"],
        system_prompt=spec["system_prompt"],
    )

    result = await agent.ainvoke({"messages": [HumanMessage(content="hi")]})

    # If the fallback middleware was dropped, the turn would raise instead of
    # producing this AIMessage.
    final = result["messages"][-1]
    assert isinstance(final, AIMessage)
    assert final.content == "recovered via fallback"
|
||||
|
|
@ -0,0 +1,128 @@
|
|||
"""``ScopedModelFallbackMiddleware`` triggers fallback only on provider errors."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
|
||||
class _RaisingChatModel(BaseChatModel):
    """Chat model stub that raises a configurable exception on every path."""

    # Exception instance raised by all four generation entry points.
    exc_to_raise: Any

    @property
    def _llm_type(self) -> str:
        return "raising-test-model"

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        error = self.exc_to_raise
        raise error

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        error = self.exc_to_raise
        raise error

    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]:
        error = self.exc_to_raise
        raise error

    async def _astream(
        self, *args: Any, **kwargs: Any
    ) -> AsyncIterator[ChatGeneration]:
        error = self.exc_to_raise
        raise error
        yield  # pragma: no cover - unreachable
|
||||
|
||||
|
||||
class _RecordingChatModel(BaseChatModel):
    """Chat model stub that counts calls and always replies ``response_text``."""

    response_text: str = "fallback-ok"
    call_count: int = 0

    @property
    def _llm_type(self) -> str:
        return "recording-test-model"

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        self.call_count += 1
        reply = AIMessage(content=self.response_text)
        return ChatResult(generations=[ChatGeneration(message=reply)])

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Delegate to the sync path so counting and the reply stay identical.
        return self._generate(messages, stop, None, **kwargs)
|
||||
|
||||
|
||||
class RateLimitError(Exception):
    """Name matches the scoped-fallback eligibility allowlist."""

    # NOTE(review): eligibility seems to key on the exception class name, so
    # this local class mimics a provider 429 — confirm against
    # ScopedModelFallbackMiddleware's allowlist.
|
||||
|
||||
|
||||
def _build_agent(primary: BaseChatModel, fallback: BaseChatModel):
    """Compile a tool-less agent whose only middleware is the scoped fallback."""
    # Imports are function-local to avoid importing the agent stack at module
    # import time.
    from langchain.agents import create_agent

    from app.agents.new_chat.middleware.scoped_model_fallback import (
        ScopedModelFallbackMiddleware,
    )

    scoped_fallback = ScopedModelFallbackMiddleware(fallback)
    return create_agent(
        model=primary,
        tools=[],
        middleware=[scoped_fallback],
        system_prompt="be helpful",
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_provider_errors_trigger_fallback():
    """Eligible exception names must drive the fallback chain."""
    flaky_primary = _RaisingChatModel(exc_to_raise=RateLimitError("429 from provider"))
    backup = _RecordingChatModel(response_text="recovered")

    agent = _build_agent(flaky_primary, backup)
    result = await agent.ainvoke({"messages": [("user", "hi")]})

    # The fallback must have answered exactly once and produced the final turn.
    last_message = result["messages"][-1]
    assert isinstance(last_message, AIMessage)
    assert last_message.content == "recovered"
    assert backup.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_programming_errors_propagate_without_invoking_fallback():
    """Non-eligible exceptions must propagate; fallback must not be invoked."""
    buggy_primary = _RaisingChatModel(exc_to_raise=KeyError("missing_state_field"))
    backup = _RecordingChatModel(response_text="should-never-arrive")

    agent = _build_agent(buggy_primary, backup)

    # A KeyError is a programming error, not a provider error: it must surface
    # unchanged instead of being masked by the fallback chain.
    with pytest.raises(KeyError, match="missing_state_field"):
        await agent.ainvoke({"messages": [("user", "hi")]})

    assert backup.call_count == 0
|
||||
|
|
@ -202,6 +202,15 @@ class FakeBudgetLLM:
|
|||
|
||||
|
||||
class TestKnowledgeBaseSearchMiddlewarePlanner:
|
||||
    @pytest.fixture(autouse=True)
    def _disable_planner_runnable(self, monkeypatch):
        """Force the legacy planner path for every test in this class."""
        # ``FakeLLM`` is a duck-typed mock; ``create_agent`` (used when the
        # planner Runnable path is enabled) calls ``.bind()`` on the LLM,
        # which the mock does not implement. Pin the flag off so the
        # planner falls through to the legacy ``self.llm.ainvoke`` path
        # these tests assert against (``llm.calls[0]["config"]``).
        monkeypatch.setenv("SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "false")
|
||||
|
||||
def test_render_recent_conversation_prefers_latest_messages_under_budget(self):
|
||||
messages = [
|
||||
HumanMessage(content="old user context " * 40),
|
||||
|
|
|
|||
2
surfsense_backend/uv.lock
generated
2
surfsense_backend/uv.lock
generated
|
|
@ -7947,7 +7947,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "surf-new-backend"
|
||||
version = "0.0.22"
|
||||
version = "0.0.23"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "alembic" },
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue